demo / app.py
iojw's picture
Add max tokens
d51b0f1
raw
history blame
2.34 kB
import re
import gradio as gr
from routellm.controller import Controller
# Default sampling temperature exposed in the UI slider.
TEMPERATURE = 0.8
# Default routing threshold: calibrated so ~50% of (harder) calls go to the
# strong model -- see the interface description below.
THRESHOLD = 0.11593
# Router algorithm name ("mf" = matrix factorization); used to build the
# pseudo model name passed to the RouteLLM controller.
ROUTER = "mf"
# RouteLLM controller acting as an OpenAI-compatible client: routes each
# request to either the strong or the weak model based on the threshold.
client = Controller(
    routers=["mf"],
    strong_model="gpt-4-1106-preview",
    weak_model="anyscale/mistralai/Mixtral-8x7B-Instruct-v0.1",
)
def predict(message, history, threshold, temperature):
    """Stream a routed chat completion for a Gradio ChatInterface.

    Args:
        message: The latest user message (str).
        history: Prior turns as (user, assistant) string pairs -- the legacy
            Gradio tuple format. TODO confirm against the Gradio version used.
        threshold: Routing threshold; encoded into the model name so the
            RouteLLM controller decides strong vs. weak model per request.
        temperature: Sampling temperature forwarded to the completion call.

    Yields:
        Progressively longer response strings. The first yield is a
        "**[model-name]**" prefix identifying which model was routed to.
    """
    # Convert the (user, assistant) tuple history to OpenAI message format.
    history_openai_format = [
        {"role": "system", "content": "You are a helpful AI assistant."}
    ]
    for human, assistant in history:
        history_openai_format.append({"role": "user", "content": human})
        history_openai_format.append(
            {
                "role": "assistant",
                # Strip the "**[model]**" prefix this function prepended
                # when the reply was originally streamed.
                "content": re.sub(r"^\*\*\[.*?\]\*\*\s*", "", assistant),
            }
        )
    history_openai_format.append({"role": "user", "content": message})

    # The threshold rides along inside the model name; the controller
    # parses it out to make the routing decision.
    stream = client.chat.completions.create(
        model=f"router-{ROUTER}-{threshold}",
        messages=history_openai_format,
        temperature=temperature,
        stream=True,
        max_tokens=512,
    )

    # Accumulate and re-yield the full message so Gradio re-renders the
    # growing reply on every chunk.
    partial_message = ""
    for i, chunk in enumerate(stream):
        if i == 0:
            # Show which model the router picked before any content arrives.
            model_prefix = f"**[{chunk.model}]**\n"
            yield model_prefix
            partial_message += model_prefix
        # Guard: some OpenAI-compatible streams emit chunks with an empty
        # `choices` list (e.g. a trailing usage chunk); indexing blindly
        # would raise IndexError.
        if chunk.choices:
            partial_message += chunk.choices[0].delta.content or ""
        yield partial_message
# Create and launch a chat interface with Gradio.
# The two sliders are passed to `predict` as its extra positional
# arguments (threshold, temperature), after (message, history).
demo = gr.ChatInterface(
    predict,
    additional_inputs=[
        gr.Slider(label="Threshold", minimum=0, maximum=1, value=THRESHOLD, step=0.01),
        gr.Slider(
            label="Temperature", minimum=0, maximum=1, value=TEMPERATURE, step=0.1
        ),
    ],
    title="RouteLLM",
    description="This is a demo of our matrix factorization router, calibrated so that approximately 50% of calls (those that are harder) are routed to GPT-4, with remaining calls routed to Mixtral 8x7B.\n\nCheck out https://github.com/lm-sys/RouteLLM for details!",
)
# Only start the web server when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()