import streamlit as st
import torch
import time
from threading import Thread
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)

# Emoji/cursor written as escapes so the source survives re-encoding.
APP_EMOJI = "\U0001F636\u200d\U0001F32B\ufe0f"  # "face in clouds" emoji
STREAM_CURSOR = "\u258c"  # block cursor shown after a reply that is still streaming
GREETING = "How may I assist you today?"

# App title
st.set_page_config(page_title=f"{APP_EMOJI} FuseChat Model")

root_path = "FuseAI"
model_name = "FuseAI/FuseChat-7B-v2.0"


def _resolve_model_path(name):
    """Return the Hub id / path to load for *name*.

    ``name`` values that already contain a "/" (an "org/model" id or a local
    path) are used as-is; a bare model name is placed under ``root_path``.
    Blindly prefixing would turn "FuseAI/FuseChat-7B-v2.0" into the
    non-existent "FuseAI/FuseAI/FuseChat-7B-v2.0".
    """
    return name if "/" in name else f"{root_path}/{name}"


@st.cache_resource
def load_model(model_name):
    """Load (once per process, via ``st.cache_resource``) the tokenizer and model.

    Args:
        model_name: "org/model" Hub id, local path, or bare name under ``root_path``.

    Returns:
        ``(model, tokenizer)``; the model is 4-bit quantized and in eval mode.
    """
    model_path = _resolve_model_path(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    # generate() needs a pad id; fall back to EOS, or 0 if there is no EOS either.
    if tokenizer.pad_token_id is None:
        if tokenizer.eos_token_id is not None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        else:
            tokenizer.pad_token_id = 0
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        # Equivalent to the deprecated bare ``load_in_4bit=True`` kwarg.
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    model.eval()
    return model, tokenizer


with st.sidebar:
    st.title(f"{APP_EMOJI} FuseChat-v2.0")
    st.write('This chatbot is created using FuseChat, a model developed by FuseAI')
    temperature = st.slider('temperature', min_value=0.01, max_value=5.0, value=0.1, step=0.01)
    top_p = st.slider('top_p', min_value=0.01, max_value=1.0, value=0.9, step=0.01)
    top_k = st.slider('top_k', min_value=1, max_value=1000, value=50, step=1)
    repetition_penalty = st.slider('repetition penalty', min_value=1., max_value=2., value=1.2, step=0.05)
    max_length = st.slider('max new tokens', min_value=32, max_value=2000, value=512, step=8)

with st.spinner('loading model..'):
    model, tokenizer = load_model(model_name)


def _initial_messages():
    """Return a fresh chat history holding only the assistant greeting."""
    return [{"role": "assistant", "content": GREETING}]


# Store LLM generated responses
if "messages" not in st.session_state:
    st.session_state.messages = _initial_messages()

# Display chat messages
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])


def set_query(query):
    """Button callback: append *query* to the history as a user turn."""
    st.session_state.messages.append({"role": "user", "content": query})


# Candidate questions offered as one-click buttons in the sidebar.
candidate_questions = [
    "Can you tell me a joke?",
    "Write a quicksort code in Python.",
    "Write a poem about love in Shakespearean tone.",
]
for question in candidate_questions:
    st.sidebar.button(label=question, on_click=set_query, args=[question])


def clear_chat_history():
    """Button callback: reset the conversation to the initial greeting."""
    st.session_state.messages = _initial_messages()


st.sidebar.button('Clear Chat History', on_click=clear_chat_history)


def _build_prompt(messages):
    """Render *messages* in the "GPT4 Correct User/Assistant" turn format.

    Every turn ends with ``<|end_of_turn|>``; the result ends with an open
    assistant turn for the model to complete.
    """
    turns = []
    for msg in messages:
        speaker = "GPT4 Correct User" if msg["role"] == "user" else "GPT4 Correct Assistant"
        turns.append(f"{speaker}: {msg['content']}<|end_of_turn|>")
    return "".join(turns) + "GPT4 Correct Assistant: "


def stream_fusechat_response():
    """Yield the assistant reply to the current history as decoded text chunks.

    ``model.generate`` runs in a background thread and feeds a
    ``TextIteratorStreamer``; chunks are yielded as they arrive. The prompt and
    special tokens are skipped. Iteration raises if no new text arrives within
    the streamer's 10 s timeout.
    """
    prompt = _build_prompt(st.session_state.messages)
    # Move tensors to the model's device and keep the attention mask.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        streamer=streamer,
        max_new_tokens=max_length,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
        pad_token_id=tokenizer.pad_token_id,
    )
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    yield from streamer
    # The streamer is exhausted only once generate() has finished, so this is quick.
    thread.join()


def generate_fusechat_response():
    """Return the complete assistant reply to the current history as one string."""
    return "".join(stream_fusechat_response())


# User-provided prompt
if prompt := st.chat_input("Hello there! How are you doing?"):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)

# Generate a new response if last message is not from assistant
if st.session_state.messages[-1]["role"] != "assistant":
    with st.chat_message("assistant"):
        placeholder = st.empty()
        full_response = ""
        with st.spinner("Thinking..."):
            # Render tokens as they are produced instead of replaying a finished reply.
            for chunk in stream_fusechat_response():
                full_response += chunk
                placeholder.markdown(full_response + STREAM_CURSOR)
        placeholder.markdown(full_response)
    st.session_state.messages.append({"role": "assistant", "content": full_response})