chat1 / chat.py
mxzh1995's picture
Add application file
cc377e0
raw
history blame contribute delete
No virus
1.03 kB
# -*- coding: utf-8 -*-
"""
@Time : 2023/3/13 16:18
@Auth : mxzh
@DESC: Serve the local Baichuan2-7B-Chat model through a simple gradio Interface (text in, text out, non-streaming)
"""
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig
def load_model(model_name="baichuan-inc/Baichuan2-7B-Chat"):
    """Load a Baichuan2 chat checkpoint together with its tokenizer.

    The weights, the tokenizer and the generation config are all read from
    the same ``model_name``, so they can never drift apart.

    Args:
        model_name: Hugging Face Hub repo id or local directory of the
            checkpoint. Defaults to ``"baichuan-inc/Baichuan2-7B-Chat"``.

    Returns:
        A ``(model, tokenizer)`` tuple. The model has the checkpoint's
        ``GenerationConfig`` attached as ``model.generation_config``.
    """
    # trust_remote_code=True executes Python code shipped in the checkpoint
    # repo; presumably that is where ``model.chat`` (used by chat()) comes
    # from -- TODO confirm against the model card.
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, use_fast=False, trust_remote_code=True
    )
    # device_map="auto" lets transformers/accelerate place the weights on the
    # available devices; bfloat16 halves memory compared with float32.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    # Use the checkpoint's own sampling defaults rather than library defaults.
    model.generation_config = GenerationConfig.from_pretrained(model_name)
    return model, tokenizer
def chat(content, model=None, tokenizer=None):
    """Send one user message to the chat model and return its reply.

    Every call builds a fresh single-turn conversation. No history is kept
    between calls.

    Args:
        content: The user's message text (the gradio "text" input).
        model: Object exposing ``chat(tokenizer, messages)``. If omitted, the
            module-level ``model`` bound by the ``__main__`` block is used.
        tokenizer: Tokenizer forwarded to ``model.chat``. If omitted, the
            module-level ``tokenizer`` is used.

    Returns:
        Whatever ``model.chat`` returns. For Baichuan2 this is presumably the
        reply string, which gradio shows in the "text" output.

    Raises:
        NameError: If no model/tokenizer was passed and none has been loaded
            at module level, e.g. when this module is imported instead of
            being run as a script.
    """
    # Fall back to the globals set in ``__main__`` so that
    # ``gr.Interface(fn=chat, inputs="text")`` keeps working with a single
    # positional argument.
    if model is None:
        model = globals().get("model")
    if tokenizer is None:
        tokenizer = globals().get("tokenizer")
    if model is None or tokenizer is None:
        raise NameError(
            "model/tokenizer are not loaded: call load_model() first "
            "or pass them to chat() explicitly"
        )
    messages = [{"role": "user", "content": content}]
    return model.chat(tokenizer, messages)
if __name__ == "__main__":
    # Load the weights once at startup. chat() reads `model` and `tokenizer`
    # from module globals, so they must be bound here under these names.
    model, tokenizer = load_model()

    # Single text box in, single text box out; launch() starts the web UI.
    iface = gr.Interface(
        fn=chat,
        inputs="text",
        outputs="text",
    )
    iface.launch()