|
import json
import os

import gradio as gr
from huggingface_hub import InferenceClient
from openai import OpenAI
from tenacity import retry, wait_random_exponential, stop_after_attempt
|
|
|
# API key is read from the environment so it never lives in the source tree.
OPENAI_KEY = os.getenv("OPENAI_KEY")

# Default model for chat_completion_request. The name is referenced as a
# default argument further down, so it has to exist at import time.
# NOTE(review): the default below is a placeholder choice of a model that
# supports tool calling — confirm the intended model or set GPT_MODEL in the
# environment.
GPT_MODEL = os.getenv("GPT_MODEL", "gpt-4o-mini")

# Shared client used for every completion request in this module.
client = OpenAI(api_key=OPENAI_KEY)
|
|
|
def get_current_weather(location, unit="celsius"):
    """Return a canned weather report for *location* as a JSON string.

    This is a stub used as a tool-calling target, not a real weather lookup.

    Args:
        location: Free-form place name. Any string containing "taipei"
            (case-insensitive) is reported as Taipei; everything else,
            including ``None`` or a non-string value, is reported with an
            unknown temperature.
        unit: Temperature unit echoed back for a known location. A falsy
            value (e.g. ``None`` from a tool call that omitted it) falls
            back to ``"celsius"``.

    Returns:
        A JSON object string with ``location`` and ``temperature`` keys, plus
        ``unit`` when the location is known.
    """
    # Tool-call arguments are optional on the wire, so callers may pass None.
    unit = unit or "celsius"
    if isinstance(location, str) and "taipei" in location.lower():
        return json.dumps({"location": "Taipei", "temperature": "10", "unit": unit})
    return json.dumps({"location": location, "temperature": "unknown"})
|
|
|
@retry(
    wait=wait_random_exponential(multiplier=1, max=40),
    stop=stop_after_attempt(3),
    reraise=True,  # surface the real API error, not tenacity's RetryError
)
def chat_completion_request(messages, tools=None, tool_choice=None, model=GPT_MODEL):
    """Request a chat completion, retrying failed attempts with backoff.

    Args:
        messages: Conversation so far, in chat-completions message format.
        tools: Optional list of tool definitions offered to the model.
        tool_choice: Optional tool selection policy (e.g. ``"auto"``).
        model: Model name; defaults to the module-level ``GPT_MODEL``.

    Returns:
        The completion object returned by ``client.chat.completions.create``.

    Raises:
        Exception: The error from the last attempt, once all three attempts
            have failed. Each failure is printed before it propagates.
    """
    request = {"model": model, "messages": messages}
    # Leave unset options out of the request entirely instead of sending
    # explicit nulls for them.
    if tools is not None:
        request["tools"] = tools
    if tool_choice is not None:
        request["tool_choice"] = tool_choice

    try:
        return client.chat.completions.create(**request)
    except Exception as e:
        print("Unable to generate ChatCompletion response")
        print(f"Exception: {e}")
        # Re-raise so the retry decorator actually sees the failure; handing
        # the exception back as a return value would count as success.
        raise
|
|
|
# JSON-schema fragments for the two arguments of get_current_weather.
_location_schema = {
    "type": "string",
    "description": "The city and state, e.g. San Francisco, CA",
}
_unit_schema = {
    "type": "string",
    "enum": ["celsius", "fahrenheit"],
    "description": "The temperature unit to use. Infer this from the users location.",
}

# Tool definitions offered to the model: a single function tool describing
# get_current_weather, with both arguments marked as required.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {"location": _location_schema, "unit": _unit_schema},
                "required": ["location", "unit"],
            },
        },
    },
]
|
|
|
|
|
def _safe_chat_completion(messages, **kwargs):
    """Call chat_completion_request; return ``(response, None)`` or ``(None, error_text)``."""
    try:
        result = chat_completion_request(messages, **kwargs)
    except Exception as exc:
        # UI boundary: show the failure in the chat instead of crashing it.
        return None, f"Sorry, the request failed: {exc}"
    # Also covers an exception object handed back as the return value.
    if isinstance(result, Exception):
        return None, f"Sorry, the request failed: {result}"
    return result, None


def _run_tool_call(tool_call, available_functions):
    """Execute one requested tool call and return its result as a JSON string."""
    function_name = tool_call.function.name
    function_to_call = available_functions.get(function_name)
    if function_to_call is None:
        return json.dumps({"error": f"unknown function: {function_name}"})

    try:
        function_args = json.loads(tool_call.function.arguments or "{}")
    except json.JSONDecodeError as exc:
        return json.dumps({"error": f"invalid arguments: {exc}"})
    if not isinstance(function_args, dict):
        return json.dumps({"error": "arguments must be a JSON object"})

    # Missing arguments fall back to values the weather stub can handle.
    return function_to_call(
        location=function_args.get("location") or "",
        unit=function_args.get("unit") or "celsius",
    )


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens=None,
    temperature=None,
    top_p=None,
) -> str:
    """Answer one chat turn, letting the model call ``get_current_weather``.

    Args:
        message: The user's newest message.
        history: Earlier turns as ``(user_text, assistant_text)`` pairs;
            empty entries are skipped.
        system_message: Content of the leading system message.
        max_tokens: Accepted so the signature matches the additional inputs
            the ChatInterface in this file supplies; not forwarded, because
            chat_completion_request exposes no sampling parameters.
        temperature: Same as ``max_tokens``: accepted, not forwarded.
        top_p: Same as ``max_tokens``: accepted, not forwarded.

    Returns:
        The assistant's reply text. If the model requested tool calls, this
        is the reply produced after the tool results were sent back. If a
        request fails, a short error message is returned instead.
    """
    messages = [{"role": "system", "content": system_message}]
    for user_text, assistant_text in history:
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})
    messages.append({"role": "user", "content": message})

    response, error_text = _safe_chat_completion(
        messages, tools=tools, tool_choice="auto"
    )
    if error_text is not None:
        return error_text

    response_message = response.choices[0].message
    tool_calls = response_message.tool_calls
    if not tool_calls:
        # Plain answer: the model did not ask for any tool.
        return response_message.content or ""

    available_functions = {
        "get_current_weather": get_current_weather,
    }
    # The assistant message carrying the tool calls must precede their results.
    messages.append(response_message)
    for tool_call in tool_calls:
        # Every tool call gets a result message, even when it could not run.
        messages.append(
            {
                "tool_call_id": tool_call.id,
                "role": "tool",
                "name": tool_call.function.name,
                "content": _run_tool_call(tool_call, available_functions),
            }
        )

    second_response, error_text = _safe_chat_completion(messages)
    if error_text is not None:
        return error_text
    print(second_response)
    return second_response.choices[0].message.content or ""
|
|
|
""" |
|
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface |
|
""" |
|
# Extra controls shown alongside the chat box. Their values are handed to the
# chat callback after the message and history, in the order listed below.
_system_message_box = gr.Textbox(value="", label="System message")
_max_tokens_slider = gr.Slider(
    minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"
)
_temperature_slider = gr.Slider(
    minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"
)
_top_p_slider = gr.Slider(
    minimum=0.1,
    maximum=1.0,
    value=0.95,
    step=0.05,
    label="Top-p (nucleus sampling)",
)

# Chat UI wired to respond().
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        _system_message_box,
        _max_tokens_slider,
        _temperature_slider,
        _top_p_slider,
    ],
)


# Start the web app only when the file is run as a script.
if __name__ == "__main__":
    demo.launch()