mxzh1995 committed on
Commit
cc377e0
1 Parent(s): 7aa80b2

Add application file

Browse files
Files changed (1) hide show
  1. chat.py +32 -0
chat.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ @Time : 2023/3/13 16:18
4
+ @Auth : mxzh
5
+ @DESC: Serve the Baichuan2-7B-Chat model through a Gradio text interface (single-turn; the full reply is returned at once, not streamed)
6
+ """
7
+
8
+ import gradio as gr
9
+ import torch
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer
11
+ from transformers.generation.utils import GenerationConfig
12
+
13
+
14
def load_model(model_name="baichuan-inc/Baichuan2-7B-Chat"):
    """Load a chat model and its tokenizer from a Hugging Face checkpoint.

    Args:
        model_name: Checkpoint identifier passed to every ``from_pretrained``
            call. Defaults to the Baichuan2-7B-Chat checkpoint, so calling
            ``load_model()`` with no arguments behaves as before.

    Returns:
        A ``(model, tokenizer)`` tuple. The model's ``generation_config`` is
        replaced with the one loaded from the same checkpoint.
    """
    # NOTE(security): trust_remote_code=True lets the checkpoint repository's
    # own Python code run locally; only point model_name at a trusted source.
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        use_fast=False,
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    # Use the checkpoint's own generation settings for later chat calls.
    model.generation_config = GenerationConfig.from_pretrained(model_name)
    return model, tokenizer
19
+
20
+
21
def chat(content):
    """Send one user message to the chat model and return its reply.

    Each call is a single-turn conversation: only ``content`` is sent, and no
    earlier exchanges are included.

    Args:
        content: Text of the user's message.

    Returns:
        Whatever ``model.chat`` returns for the one-message conversation.
    """
    # ``model`` and ``tokenizer`` are module-level names bound in the
    # ``__main__`` block before the interface starts serving requests.
    conversation = [{"role": "user", "content": content}]
    return model.chat(tokenizer, conversation)
26
+
27
+
28
if __name__ == "__main__":
    # Bind the model and tokenizer as module globals first: chat() looks
    # them up by name when handling a request.
    model, tokenizer = load_model()

    # Plain text-in / text-out Gradio UI backed by chat().
    demo = gr.Interface(
        fn=chat,
        inputs="text",
        outputs="text",
    )
    demo.launch()
32
+