wangrongsheng committed
Commit 046f527
1 Parent(s): bb2872c

support two models

Files changed (1):
  app.py +74 -34
app.py CHANGED
@@ -36,10 +36,16 @@ if torch.cuda.is_available():
     model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
     tokenizer = AutoTokenizer.from_pretrained(model_id)
     tokenizer.use_default_system_prompt = False
+
+    model_id_zh = "FarReelAILab/Machine_Mindset_zh_INTJ"
+    model_zh = AutoModelForCausalLM.from_pretrained(model_id_zh, torch_dtype=torch.float16, device_map="auto")
+    tokenizer_zh = AutoTokenizer.from_pretrained(model_id_zh)
+    tokenizer_zh.use_default_system_prompt = False
 
 
 @spaces.GPU
 def generate(
     message: str,
     chat_history: list[tuple[str, str]],
+    select_model: str,
     system_prompt: str,
@@ -49,43 +55,78 @@ def generate(
     top_k: int = 50,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
-    conversation = []
-    if system_prompt:
-        conversation.append({"role": "system", "content": system_prompt})
-    for user, assistant in chat_history:
-        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-    conversation.append({"role": "user", "content": message})
-
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
-    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-    input_ids = input_ids.to(model.device)
-
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(
-        {"input_ids": input_ids},
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        num_beams=1,
-        repetition_penalty=repetition_penalty,
-    )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    outputs = []
-    for text in streamer:
-        outputs.append(text)
-        yield "".join(outputs)
+    if select_model == "INTJ-en":
+        conversation = []
+        if system_prompt:
+            conversation.append({"role": "system", "content": system_prompt})
+        for user, assistant in chat_history:
+            conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+        conversation.append({"role": "user", "content": message})
+
+        input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+            gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+        input_ids = input_ids.to(model.device)
+
+        streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+        generate_kwargs = dict(
+            {"input_ids": input_ids},
+            streamer=streamer,
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            top_p=top_p,
+            top_k=top_k,
+            temperature=temperature,
+            num_beams=1,
+            repetition_penalty=repetition_penalty,
+        )
+        t = Thread(target=model.generate, kwargs=generate_kwargs)
+        t.start()
+
+        outputs = []
+        for text in streamer:
+            outputs.append(text)
+            yield "".join(outputs)
+
+    elif select_model == "INTJ-zh":
+        conversation = []
+        if system_prompt:
+            conversation.append({"role": "system", "content": system_prompt})
+        for user, assistant in chat_history:
+            conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+        conversation.append({"role": "user", "content": message})
+
+        input_ids = tokenizer_zh.apply_chat_template(conversation, return_tensors="pt")
+        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+            gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+        input_ids = input_ids.to(model_zh.device)
+
+        streamer = TextIteratorStreamer(tokenizer_zh, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+        generate_kwargs = dict(
+            {"input_ids": input_ids},
+            streamer=streamer,
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            top_p=top_p,
+            top_k=top_k,
+            temperature=temperature,
+            num_beams=1,
+            repetition_penalty=repetition_penalty,
+        )
+        t = Thread(target=model_zh.generate, kwargs=generate_kwargs)
+        t.start()
+
+        outputs = []
+        for text in streamer:
+            outputs.append(text)
+            yield "".join(outputs)
 
 chat_interface = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
+        gr.Dropdown(choices=["INTJ-en", "INTJ-zh"], value="INTJ-en", label="Select Model"),
         gr.Textbox(label="System prompt", lines=6),
         gr.Slider(
             label="Max new tokens",
@@ -129,7 +170,6 @@ chat_interface = gr.ChatInterface(
         ["Can you explain briefly to me what is the Python programming language?"],
         ["Explain the plot of Cinderella in a sentence."],
         ["How many hours does it take a man to eat a Helicopter?"],
-        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
     ],
 )
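
Note: the two branches added to generate are identical except for which model/tokenizer pair they touch, so the dropdown value could instead index a small registry. A minimal sketch of that refactor (not part of this commit), assuming the globals model, tokenizer, model_zh, tokenizer_zh, and MAX_INPUT_TOKEN_LENGTH loaded earlier in app.py, with the @spaces.GPU decorator omitted for brevity:

from threading import Thread
from typing import Iterator

import gradio as gr
from transformers import TextIteratorStreamer

# Hypothetical registry mapping the dropdown value to a (model, tokenizer)
# pair; the four globals are assumed to be loaded as in app.py above.
MODELS = {
    "INTJ-en": (model, tokenizer),
    "INTJ-zh": (model_zh, tokenizer_zh),
}

def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    select_model: str,
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    active_model, active_tokenizer = MODELS[select_model]

    # Build the chat-template conversation exactly as both branches do.
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    input_ids = active_tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(active_model.device)

    # Stream tokens from a background generation thread, as in the original code.
    streamer = TextIteratorStreamer(active_tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    Thread(target=active_model.generate, kwargs=generate_kwargs).start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

With this shape, supporting a third model becomes a one-line registry entry rather than another copied branch.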
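
For context on the parameter order: gr.ChatInterface calls its function as fn(message, history, *additional_inputs), passing the additional inputs in the order they are listed, so the dropdown value arrives right after the chat history. A tiny self-contained example of that calling convention (echo_generate is a hypothetical stub, not the app's model code):

import gradio as gr

def echo_generate(message, chat_history, select_model, system_prompt):
    # Arguments arrive as (message, history, *additional_inputs), in listed order.
    return f"[{select_model}] system={system_prompt!r} you said: {message}"

demo = gr.ChatInterface(
    fn=echo_generate,
    additional_inputs=[
        gr.Dropdown(choices=["INTJ-en", "INTJ-zh"], value="INTJ-en", label="Select Model"),
        gr.Textbox(label="System prompt", lines=6),
    ],
)

if __name__ == "__main__":
    demo.queue().launch()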