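"""Gradio chat demo for RouteLLM's matrix factorization ("mf") router,
which routes harder queries to GPT-4 and the rest to Mixtral 8x7B."""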
import re

import gradio as gr
from routellm.controller import Controller

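# Default sampling temperature and routing threshold; the threshold is
# calibrated so that roughly half of all calls reach the strong model.
# "mf" selects the matrix factorization router.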
TEMPERATURE = 0.8
THRESHOLD = 0.11593
ROUTER = "mf"

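# The Controller exposes an OpenAI-compatible client that routes each request
# to the strong or weak model according to the router's score.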
client = Controller(
    routers=[ROUTER],
    strong_model="gpt-4-1106-preview",
    weak_model="anyscale/mistralai/Mixtral-8x7B-Instruct-v0.1",
)
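# NOTE: assumes API keys for the underlying providers (e.g. OPENAI_API_KEY
# for GPT-4 and an Anyscale key for Mixtral) are set in the environment.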


def predict(message, history, threshold, temperature):
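    # Rebuild the conversation in the OpenAI chat format, starting from a
    # fixed system prompt; Gradio supplies history as (user, assistant) pairs.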
    history_openai_format = [
        {"role": "system", "content": "You are a helpful AI assistant."}
    ]
    for human, assistant in history:
        history_openai_format.append({"role": "user", "content": human})
        history_openai_format.append(
            {
                "role": "assistant",
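                # Strip the "[model]" tag that this app prepends to streamed
                # replies, so the models never see it as conversation text.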
"content": re.sub(r"^\[.*?\]\s*", "", assistant), |
|
} |
|
) |
|
history_openai_format.append({"role": "user", "content": message}) |
|
|
|
|
|
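    # The router and threshold are encoded in the model name (e.g.
    # "router-mf-0.11593"), which the controller parses to route the call.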
    stream = client.chat.completions.create(
        model=f"router-{ROUTER}-{threshold}",
        messages=history_openai_format,
        temperature=temperature,
        stream=True,
    )

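    # Stream the reply, prefixing it with the name of the model the router
    # actually selected (reported as chunk.model).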
partial_message = "" |
|
for i, chunk in enumerate(stream): |
|
print(chunk) |
|
if i == 0: |
|
model_prefix = f"[{chunk.model}]\n" |
|
yield model_prefix |
|
partial_message += model_prefix |
|
partial_message += chunk.choices[0].delta.content or "" |
|
yield partial_message |
|
|
|
|
|
|
|
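# Chat UI: the two sliders are passed through to predict() as additional
# inputs, so the routing threshold and temperature can be tuned live.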
demo = gr.ChatInterface(
    predict,
    additional_inputs=[
        gr.Slider(label="Threshold", minimum=0, maximum=1, value=THRESHOLD, step=0.01),
        gr.Slider(
            label="Temperature", minimum=0, maximum=1, value=TEMPERATURE, step=0.1
        ),
    ],
    title="RouteLLM",
    description="This is a demo of our matrix factorization router, calibrated so that approximately 50% of calls (the harder ones) are routed to GPT-4, with the remaining calls routed to Mixtral 8x7B.\n\nCheck out https://github.com/lm-sys/RouteLLM for details!",
)

if __name__ == "__main__":
    demo.launch()