|
import json |
|
import torch |
|
import streamlit as st |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
from transformers.generation.utils import GenerationConfig |
|
|
|
|
|
# Page chrome: the browser-tab title and the on-page heading share one label.
_PAGE_TITLE = "Baichuan-13B-Chat"
st.set_page_config(page_title=_PAGE_TITLE)
st.title(_PAGE_TITLE)
|
|
|
@st.cache_resource
def init_model():
    """Load the Baichuan-13B-Chat model and tokenizer, quantized to int8.

    The function is wrapped in ``st.cache_resource`` so Streamlit can reuse
    the loaded objects instead of reloading them on every script rerun.

    Returns:
        A ``(model, tokenizer)`` tuple. The model has its generation config
        loaded from the same checkpoint, has been passed through
        ``quantize(8)`` and moved to the CUDA device.

    Note:
        ``trust_remote_code=True`` runs Python code shipped with the
        checkpoint; only use it with a checkpoint you trust.
        ``.cuda()`` requires a CUDA-capable device.
    """
    checkpoint = "baichuan-inc/Baichuan-13B-Chat"

    # Deliberately no ``device_map="auto"``: the fp16 weights must stay on the
    # CPU until they are quantized. Dispatching them to the GPU first would
    # need the full-precision memory that int8 quantization is meant to save,
    # and a model dispatched by accelerate may refuse the later ``.cuda()``.
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    model.generation_config = GenerationConfig.from_pretrained(checkpoint)

    # The slow (Python) tokenizer is requested explicitly via use_fast=False.
    tokenizer = AutoTokenizer.from_pretrained(
        checkpoint,
        use_fast=False,
        trust_remote_code=True,
    )

    # Quantize on the CPU, then move only the int8 weights to the GPU.
    model = model.quantize(8).cuda()
    return model, tokenizer
|
|
|
|
|
def clear_chat_history():
    """Discard the stored conversation so the next rerun starts a fresh chat.

    Used as a button ``on_click`` callback. It is a no-op when no history is
    stored (for example if the callback fires twice before the history is
    recreated), instead of raising on the missing key.
    """
    if "messages" in st.session_state:
        del st.session_state.messages
|
|
|
|
|
def init_chat_history():
    """Render the greeting and any stored conversation, then return the history.

    A fixed assistant greeting is always drawn first. If
    ``st.session_state`` has no ``messages`` entry yet, an empty list is
    stored there; otherwise every stored message is redrawn in order.

    Returns:
        The list held in ``st.session_state.messages``. Each entry is read as
        a mapping with ``"role"`` and ``"content"`` keys.
    """
    # "Technologist" emoji written as explicit code points (person + zero-width
    # joiner + laptop). The joiner is invisible and easily lost in copy/paste,
    # which would turn the avatar into two separate emoji.
    user_avatar = "\U0001F9D1\u200D\U0001F4BB"

    with st.chat_message("assistant", avatar='🤖'):
        st.markdown("Greetings! I am the BaiChuan large language model, delighted to assist you.🥰")

    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay the conversation so far (nothing to draw for a fresh session).
    for message in st.session_state.messages:
        avatar = user_avatar if message["role"] == "user" else '🤖'
        with st.chat_message(message["role"], avatar=avatar):
            st.markdown(message["content"])

    return st.session_state.messages
|
|
|
|
|
def main():
    """Run one pass of the chat page.

    Loads the model and tokenizer, redraws the stored conversation and, when
    the user has submitted a prompt, streams the model's reply into the page
    and appends both turns to the history. The prompt and the full history
    are also printed to stdout.
    """
    model, tokenizer = init_model()
    messages = init_chat_history()

    # Same "technologist" avatar as the history view, spelled with explicit
    # code points so the invisible zero-width joiner cannot get lost.
    user_avatar = "\U0001F9D1\u200D\U0001F4BB"

    if prompt := st.chat_input("Shift + Enter for a new line, Enter to send"):
        with st.chat_message("user", avatar=user_avatar):
            st.markdown(prompt)
        messages.append({"role": "user", "content": prompt})
        print(f"[user] {prompt}", flush=True)

        # Loop-invariant: check for the MPS backend once, not per chunk.
        on_mps = torch.backends.mps.is_available()

        # Pre-bind the reply so an empty stream records an empty assistant
        # turn instead of raising NameError after the loop.
        response = ""
        with st.chat_message("assistant", avatar='🤖'):
            placeholder = st.empty()
            # Each yielded value overwrites the placeholder and the last one
            # is kept as the reply (presumably each value is the cumulative
            # text so far; confirm against the checkpoint's model.chat).
            for response in model.chat(tokenizer, messages, stream=True):
                placeholder.markdown(response)
                if on_mps:
                    torch.mps.empty_cache()
        messages.append({"role": "assistant", "content": response})
        print(json.dumps(messages, ensure_ascii=False), flush=True)

        # NOTE(review): the original indentation was lost; this button is
        # assumed to sit inside the prompt branch. Confirm the intended level.
        st.button("Reset Chat", on_click=clear_chat_history)
|
|
|
|
|
# Script entry point: build the chat page when this file is executed directly.
if __name__ == "__main__":
    main()
|
|