# Hugging Face Space (status: Paused)
import json

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig
# The browser-tab title and the heading at the top of the page share one label.
# set_page_config is kept as the first Streamlit call in the script.
_APP_TITLE = "Baichuan-13B-Chat"
st.set_page_config(page_title=_APP_TITLE)
st.title(_APP_TITLE)
@st.cache_resource
def init_model(model_name: str = "baichuan-inc/Baichuan-13B-Chat"):
    """Load the chat model (with its generation config) and its tokenizer.

    Streamlit re-executes the whole script on every user interaction, so the
    result is cached with ``st.cache_resource``. Without the cache, the 13B
    model would be reloaded from scratch for every message.

    Args:
        model_name: Hugging Face Hub repo id (or local path) to load from.
            Defaults to the Baichuan-13B-Chat checkpoint this app was built for.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        # Lets transformers run the model code hosted in the repo; the
        # non-standard ``chat`` method used by the app presumably comes from it.
        trust_remote_code=True,
    )
    # Use the checkpoint's own generation defaults instead of transformers' generic ones.
    model.generation_config = GenerationConfig.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        use_fast=False,
        trust_remote_code=True,
    )
    return model, tokenizer
def clear_chat_history() -> None:
    """Discard the stored conversation so the next run starts a fresh chat.

    This is the ``on_click`` callback of the "Reset Chat" button. It uses
    ``pop`` with a default so the reset is idempotent: unlike
    ``del st.session_state.messages``, it does not raise when no history has
    been stored in this session.
    """
    st.session_state.pop("messages", None)
def init_chat_history() -> list:
    """Render the greeting and any stored conversation, then return the history.

    The history lives in ``st.session_state.messages`` as a list of
    ``{"role": ..., "content": ...}`` dicts and is created empty on first use.
    The returned list is the session-state object itself, so callers that
    append to it update the stored history.

    The avatar and greeting emoji were previously mojibake (UTF-8 decoded as
    ISO-8859-7, e.g. 'π€'). They are restored to the intended characters so
    Streamlit receives real emoji avatars.
    """
    with st.chat_message("assistant", avatar='🤖'):
        st.markdown("Greetings! I am the BaiChuan large language model, delighted to assist you.🥰")

    if "messages" in st.session_state:
        # Replay earlier turns so the conversation survives Streamlit reruns.
        for message in st.session_state.messages:
            avatar = '🧑‍💻' if message["role"] == "user" else '🤖'
            with st.chat_message(message["role"], avatar=avatar):
                st.markdown(message["content"])
    else:
        st.session_state.messages = []

    return st.session_state.messages
def main() -> None:
    """Run one Streamlit pass of the chat app.

    Steps performed on each pass:
    - Load the model (cached by ``init_model``) and replay the stored
      conversation.
    - If the user submitted a prompt, stream the model's reply into the page,
      append both turns to the history and log them to stdout.
    - Whenever a history exists, offer a button that clears it.
    """
    model, tokenizer = init_model()
    messages = init_chat_history()

    if prompt := st.chat_input("Shift + Enter for a new line, Enter to send"):
        with st.chat_message("user", avatar='🧑‍💻'):
            st.markdown(prompt)
        messages.append({"role": "user", "content": prompt})
        print(f"[user] {prompt}", flush=True)

        # Hoisted out of the per-chunk loop; only the result is needed inside it.
        release_mps_cache = torch.backends.mps.is_available()
        # Pre-bound so an empty stream stores an empty reply instead of
        # raising NameError on the append below.
        response = ""
        with st.chat_message("assistant", avatar='🤖'):
            placeholder = st.empty()
            # Each yielded value replaces the placeholder's content and the last
            # one is stored as the reply, so the stream presumably yields the
            # cumulative text so far rather than deltas.
            for response in model.chat(tokenizer, messages, stream=True):
                placeholder.markdown(response)
                if release_mps_cache:
                    torch.mps.empty_cache()
        messages.append({"role": "assistant", "content": response})
        print(json.dumps(messages, ensure_ascii=False), flush=True)

    # Shown whenever there is history, not only on the run right after a reply
    # (e.g. also after a manual rerun that replays an existing conversation).
    if messages:
        st.button("Reset Chat", on_click=clear_chat_history)
if __name__ == "__main__":
    # Script entry point; `streamlit run` presumably executes this file as __main__.
    main()