import re

import gradio as gr

from routellm.controller import Controller

# Default generation settings. THRESHOLD is calibrated so that roughly half of
# calls (the harder ones) are routed to the strong model.
TEMPERATURE = 0.8
THRESHOLD = 0.11593
ROUTER = "mf"  # matrix-factorization router

client = Controller(
    routers=[ROUTER],  # keep in sync with the ROUTER constant used in predict()
    strong_model="gpt-4-1106-preview",
    weak_model="anyscale/mistralai/Mixtral-8x7B-Instruct-v0.1",
)


def predict(message, history, threshold, temperature):
    """Stream a routed chat completion for *message*.

    Args:
        message: The latest user message.
        history: Gradio tuple-format history — a list of (user, assistant) pairs.
        threshold: Routing threshold passed to RouteLLM via the model name.
        temperature: Sampling temperature for text generation.

    Yields:
        The accumulated partial reply. The first yield is a "**[model]**"
        prefix naming the model the router selected (taken from the first
        stream chunk).
    """
    # Convert the chat history to OpenAI message format.
    history_openai_format = [
        {"role": "system", "content": "You are a helpful AI assistant."}
    ]
    for human, assistant in history:
        history_openai_format.append({"role": "user", "content": human})
        history_openai_format.append(
            {
                "role": "assistant",
                # Strip the "**[model]**" prefix we prepend for display, so the
                # model never sees its own name tag in the conversation.
                "content": re.sub(r"^\*\*\[.*?\]\*\*\s*", "", assistant),
            }
        )
    history_openai_format.append({"role": "user", "content": message})

    # Create a chat completion request and send it to the API server. The
    # routing threshold is encoded in the model name, per the RouteLLM
    # convention "router-<router>-<threshold>".
    stream = client.chat.completions.create(
        model=f"router-{ROUTER}-{threshold}",
        messages=history_openai_format,
        temperature=temperature,
        stream=True,
        max_tokens=512,
    )

    # Read and return generated text from the response stream.
    partial_message = ""
    for i, chunk in enumerate(stream):
        if i == 0:
            # The first chunk reveals which model the router selected.
            if chunk.model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
                model_name = "Mixtral-8x7B-Instruct-v0.1"
            else:
                model_name = chunk.model
            model_prefix = f"**[{model_name}]**\n"
            yield model_prefix
            partial_message += model_prefix
        # Some stream chunks (e.g. keep-alive/usage chunks) carry an empty
        # choices list; indexing them blindly would raise IndexError.
        if chunk.choices:
            partial_message += chunk.choices[0].delta.content or ""
        yield partial_message


# Create and launch a chat interface with Gradio.
demo = gr.ChatInterface(
    predict,
    additional_inputs=[
        gr.Slider(label="Threshold", minimum=0, maximum=1, value=THRESHOLD, step=0.01),
        gr.Slider(
            label="Temperature", minimum=0, maximum=1, value=TEMPERATURE, step=0.1
        ),
    ],
    title="RouteLLM",
    fill_height=True,
    description="This is a demo of our matrix factorization router, calibrated so that approximately 50% of calls (those that are harder) are routed to GPT-4, with remaining calls routed to Mixtral 8x7B.\n\nCheck out https://github.com/lm-sys/RouteLLM for details!",
)

if __name__ == "__main__":
    demo.launch()