File size: 2,302 Bytes
1d0039a c111a7e 1d0039a c111a7e 1d0039a c111a7e 1d0039a c111a7e 1d0039a c111a7e 1d0039a f33abcd c111a7e 1d0039a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import logging
import re

import gradio as gr
from routellm.controller import Controller
# Default value of the "Temperature" slider in the chat UI.
TEMPERATURE = 0.8
# Default value of the "Threshold" slider in the chat UI; it is embedded in the
# model name that predict() sends ("router-<ROUTER>-<threshold>").
THRESHOLD = 0.11593
# Router identifier. It is used both to configure the controller below and to
# build the model name in predict(), so it must be defined in exactly one place.
ROUTER = "mf"

# Shared controller used by predict() for every chat request. The router list
# is derived from ROUTER so the two can never drift apart.
client = Controller(
    routers=[ROUTER],
    strong_model="gpt-4-1106-preview",
    weak_model="anyscale/mistralai/Mixtral-8x7B-Instruct-v0.1",
)
_logger = logging.getLogger(__name__)

# Matches the leading "[model-name]" tag that predict() prepends to each reply,
# together with any whitespace that follows it.
_MODEL_TAG_RE = re.compile(r"^\[.*?\]\s*")


def predict(message, history, threshold, temperature):
    """Stream a routed chat completion for the chat UI.

    Args:
        message: The new user message.
        history: Iterable of ``(user_text, assistant_text)`` pairs from earlier
            turns. Assistant texts are expected to start with the
            ``[model-name]`` tag added by this function; the tag is stripped
            before the text is sent back to the model.
        threshold: Routing threshold; embedded in the requested model name as
            ``router-<ROUTER>-<threshold>``.
        temperature: Sampling temperature forwarded to the completion call.

    Yields:
        The reply accumulated so far. The first yielded value is just the
        ``[model-name]`` tag taken from the first stream chunk; every later
        value is the tag followed by all text received up to that point.
    """
    # Convert chat history to OpenAI format.
    messages = [{"role": "system", "content": "You are a helpful AI assistant."}]
    for human, assistant in history:
        messages.append({"role": "user", "content": human})
        messages.append(
            {
                "role": "assistant",
                # Remove the model tag from the stored reply. A turn without a
                # reply (None) is sent as empty text instead of crashing the
                # regex substitution, which keeps user/assistant roles
                # alternating.
                "content": _MODEL_TAG_RE.sub("", assistant or ""),
            }
        )
    messages.append({"role": "user", "content": message})

    # Create a streaming chat completion request through the router.
    stream = client.chat.completions.create(
        model=f"router-{ROUTER}-{threshold}",
        messages=messages,
        temperature=temperature,
        stream=True,
    )
    _logger.debug("stream: %s", stream)

    # Read the response stream and yield the text generated so far.
    partial_message = ""
    for index, chunk in enumerate(stream):
        _logger.debug("chunk: %s", chunk)
        if index == 0:
            # Show which model the router picked before any text arrives.
            partial_message = f"[{chunk.model}]\n"
            yield partial_message
        if not chunk.choices:
            # Some stream chunks carry no choices; there is nothing to append.
            continue
        partial_message += chunk.choices[0].delta.content or ""
        yield partial_message
# Text shown under the title of the chat page.
_DESCRIPTION = (
    "This is a demo of our matrix factorization router, calibrated so that "
    "approximately 50% of calls (those that are harder) are routed to GPT-4, "
    "with remaining calls routed to Mixtral 8x7B.\n\n"
    "Check out https://github.com/lm-sys/RouteLLM for details!"
)

# Extra controls whose values are passed to predict() after (message, history),
# in this order: threshold, temperature.
_threshold_slider = gr.Slider(
    label="Threshold",
    minimum=0,
    maximum=1,
    value=THRESHOLD,
    step=0.01,
)
_temperature_slider = gr.Slider(
    label="Temperature",
    minimum=0,
    maximum=1,
    value=TEMPERATURE,
    step=0.1,
)

# Chat interface that streams predict()'s output.
demo = gr.ChatInterface(
    fn=predict,
    additional_inputs=[_threshold_slider, _temperature_slider],
    title="RouteLLM",
    description=_DESCRIPTION,
)

# Only start the web server when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()
|