import streamlit as st
import torch
from threading import Thread
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)
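
# The app is a standard Streamlit script; assuming this file is saved as
# app.py, it can be launched with `streamlit run app.py`. The 4-bit loading
# below additionally requires the `bitsandbytes` and `accelerate` packages.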


# App title
st.set_page_config(page_title="😶‍🌫️ FuseChat Model")

# Hugging Face repo id of the FuseChat checkpoint
model_name = "FuseAI/FuseChat-7B-v2.0"

@st.cache_resource
def load_model(model_name):
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
    )

    # Fall back to the EOS token (or 0) when the tokenizer defines no pad token
    if tokenizer.pad_token_id is None:
        if tokenizer.eos_token_id is not None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        else:
            tokenizer.pad_token_id = 0

    # Load the weights in 4-bit precision and let accelerate place them on
    # the available devices
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        load_in_4bit=True,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )

    model.eval()
    return model, tokenizer
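
# Note: recent transformers releases deprecate the bare `load_in_4bit=True`
# argument in favor of an explicit quantization config; an equivalent sketch
# (assuming bitsandbytes is installed):
#
#   from transformers import BitsAndBytesConfig
#   model = AutoModelForCausalLM.from_pretrained(
#       model_name,
#       device_map="auto",
#       quantization_config=BitsAndBytesConfig(load_in_4bit=True),
#       torch_dtype=torch.bfloat16,
#       trust_remote_code=True,
#   )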


with st.sidebar:
    st.title('😶‍🌫️ FuseChat-v2.0')
    st.write('This chatbot is created using FuseChat, a model developed by FuseAI.')
    # Sampling controls; inside `with st.sidebar:` plain st.* widgets already
    # land in the sidebar, so the st.sidebar.* prefix is unnecessary
    temperature = st.slider('temperature', min_value=0.01, max_value=5.0, value=0.1, step=0.01)
    top_p = st.slider('top_p', min_value=0.01, max_value=1.0, value=0.9, step=0.01)
    top_k = st.slider('top_k', min_value=1, max_value=1000, value=50, step=1)
    repetition_penalty = st.slider('repetition penalty', min_value=1.0, max_value=2.0, value=1.2, step=0.05)
    max_new_tokens = st.slider('max new tokens', min_value=32, max_value=2000, value=512, step=8)

with st.spinner('Loading model...'):
    model, tokenizer = load_model(model_name)

# Initialize the chat history with a greeting from the assistant
if "messages" not in st.session_state:
    st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]

# Display chat messages from history
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])

def set_query(query):
    st.session_state.messages.append({"role": "user", "content": query})

# Offer a few clickable example questions in the sidebar
candidate_questions = [
    "Can you tell me a joke?",
    "Write a quicksort code in Python.",
    "Write a poem about love in Shakespearean tone.",
]
for question in candidate_questions:
    st.sidebar.button(label=question, on_click=set_query, args=[question])

def clear_chat_history():
    st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]

st.sidebar.button('Clear Chat History', on_click=clear_chat_history)


def generate_fusechat_response():
    # FuseChat-7B uses the OpenChat-style "GPT4 Correct ..." turn format
    # string_dialogue = "You are a helpful and harmless assistant."
    string_dialogue = ""
    for dict_message in st.session_state.messages:
        if dict_message["role"] == "user":
            string_dialogue += "GPT4 Correct User: " + dict_message["content"] + "<|end_of_turn|>"
        else:
            string_dialogue += "GPT4 Correct Assistant: " + dict_message["content"] + "<|end_of_turn|>"

    input_ids = tokenizer(f"{string_dialogue}GPT4 Correct Assistant: ", return_tensors="pt").input_ids
    input_ids = input_ids.to(model.device)  # move the prompt onto the model's device
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Run generation on a background thread so the streamer can be consumed
    # here while tokens are still being produced
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Yield text chunks as they arrive instead of blocking until generation
    # finishes, so the UI can render the response incrementally
    for text in streamer:
        yield text

# User-provided prompt
if prompt := st.chat_input("Hello there! How are you doing?"):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)

# Generate a new response if the last message is not from the assistant
if st.session_state.messages[-1]["role"] != "assistant":
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            placeholder = st.empty()
            full_response = ''
            # Render the streamed chunks incrementally with a cursor marker
            for chunk in generate_fusechat_response():
                full_response += chunk
                placeholder.markdown(full_response + "▌")
            placeholder.markdown(full_response)
    message = {"role": "assistant", "content": full_response}
    st.session_state.messages.append(message)