|
|
|
""" |
|
@Time : 2023/3/13 16:18 |
|
@Auth : mxzh |
|
@DESC:使用gradio chatbot组件部署 openai接口, 流式返回 |
|
""" |
|
|
|
import gradio as gr |
|
import torch |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
from transformers.generation.utils import GenerationConfig |
|
|
|
|
|
def load_model(model_name="baichuan-inc/Baichuan2-7B-Chat"):
    """Load a Baichuan2 chat model and its tokenizer.

    Args:
        model_name: Hugging Face repo id or local path of the checkpoint.
            Defaults to the original hard-coded model, so existing
            ``load_model()`` callers behave unchanged.

    Returns:
        A ``(model, tokenizer)`` tuple ready for ``model.chat(...)``.
    """
    # The repo id was repeated three times below; hoisting it into a single
    # defaulted parameter removes the duplication and lets the same function
    # serve other Baichuan2 checkpoints.
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True)
    # model.chat() reads sampling/generation settings from generation_config.
    model.generation_config = GenerationConfig.from_pretrained(model_name)
    return model, tokenizer
|
|
|
|
|
def chat(content):
    """Send one user message to the loaded model and return its reply.

    Relies on the module-level ``model`` and ``tokenizer`` created in the
    ``__main__`` block before the Gradio interface is launched.
    """
    conversation = [{"role": "user", "content": content}]
    return model.chat(tokenizer, conversation)
|
|
|
|
|
if __name__ == "__main__": |
|
model, tokenizer = load_model() |
|
iface = gr.Interface(fn=chat, inputs="text", outputs="text") |
|
iface.launch() |
|
|
|
|