aresnow committed
Commit 371c2f5
1 Parent(s): d6220ca

limit model size

Files changed (1): app.py +469 -7
app.py CHANGED
@@ -1,10 +1,472 @@
- from xinference.deploy.local import main
- from xoscar.utils import get_next_port
-
- address = f"0.0.0.0:{get_next_port()}"
-
- main(
-     address=address,
-     host="0.0.0.0",
-     port=get_next_port(),
+ # Copyright 2022-2023 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import asyncio
+ import os
+ import urllib.request
+ import uuid
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+
+ import gradio as gr
+
+ from xinference.locale.utils import Locale
+ from xinference.model import MODEL_FAMILIES, ModelSpec
+ from xinference.core.api import SyncSupervisorAPI
+
+ if TYPE_CHECKING:
+     from xinference.types import ChatCompletionChunk, ChatCompletionMessage
+
+ MODEL_TO_FAMILIES = dict(
+     (model_family.model_name, model_family)
+     for model_family in MODEL_FAMILIES
+     if model_family.model_name != "baichuan"
  )
+
+
+ class GradioApp:
+     def __init__(
+         self,
+         supervisor_address: str,
+         gladiator_num: int = 2,
+         max_model_num: int = 2,
+         use_launched_model: bool = False,
+     ):
+         self._api = SyncSupervisorAPI(supervisor_address)
+         self._gladiator_num = gladiator_num
+         self._max_model_num = max_model_num
+         self._use_launched_model = use_launched_model
+         self._locale = Locale()
+
+     def _create_model(
+         self,
+         model_name: str,
+         model_size_in_billions: Optional[int] = None,
+         model_format: Optional[str] = None,
+         quantization: Optional[str] = None,
+     ):
+         model_uid = str(uuid.uuid1())
+         models = self._api.list_models()
+         if len(models) >= self._max_model_num:
+             self._api.terminate_model(models[0][0])
+         return self._api.launch_model(
+             model_uid, model_name, model_size_in_billions, model_format, quantization
+         )
+
+     async def generate(
+         self,
+         model: str,
+         message: str,
+         chat: List[List[str]],
+         max_token: int,
+         temperature: float,
+         top_p: float,
+         window_size: int,
+         show_finish_reason: bool,
+     ):
+         if not message:
+             yield message, chat
+         else:
+             try:
+                 model_ref = self._api.get_model(model)
+             except KeyError:
+                 raise gr.Error(self._locale(f"Please create model first"))
+
+             history: "List[ChatCompletionMessage]" = []
+             for c in chat:
+                 history.append({"role": "user", "content": c[0]})
+
+                 out = c[1]
+                 finish_reason_idx = out.find(f"[{self._locale('stop reason')}: ")
+                 if finish_reason_idx != -1:
+                     out = out[:finish_reason_idx]
+                 history.append({"role": "assistant", "content": out})
+
+             if window_size != 0:
+                 history = history[-(window_size // 2) :]
+
+             # chatglm only support even number of conversation history.
+             if len(history) % 2 != 0:
+                 history = history[1:]
+
+             generate_config = dict(
+                 max_tokens=max_token,
+                 temperature=temperature,
+                 top_p=top_p,
+                 stream=True,
+             )
+             chat += [[message, ""]]
+             chat_generator = await model_ref.chat(
+                 message,
+                 chat_history=history,
+                 generate_config=generate_config,
+             )
+
+             chunk: Optional["ChatCompletionChunk"] = None
+             async for chunk in chat_generator:
+                 assert chunk is not None
+                 delta = chunk["choices"][0]["delta"]
+                 if "content" not in delta:
+                     continue
+                 else:
+                     chat[-1][1] += delta["content"]
+                     yield "", chat
+             if show_finish_reason and chunk is not None:
+                 chat[-1][
+                     1
+                 ] += f"[{self._locale('stop reason')}: {chunk['choices'][0]['finish_reason']}]"
+                 yield "", chat
+
+     def _build_chatbot(self, model_uid: str, model_name: str):
+         with gr.Accordion(self._locale("Parameters"), open=False):
+             max_token = gr.Slider(
+                 128,
+                 1024,
+                 value=256,
+                 step=1,
+                 label=self._locale("Max tokens"),
+                 info=self._locale("The maximum number of tokens to generate."),
+             )
+             temperature = gr.Slider(
+                 0.2,
+                 1,
+                 value=0.8,
+                 step=0.01,
+                 label=self._locale("Temperature"),
+                 info=self._locale("The temperature to use for sampling."),
+             )
+             top_p = gr.Slider(
+                 0.2,
+                 1,
+                 value=0.95,
+                 step=0.01,
+                 label=self._locale("Top P"),
+                 info=self._locale("The top-p value to use for sampling."),
+             )
+             window_size = gr.Slider(
+                 0,
+                 50,
+                 value=10,
+                 step=1,
+                 label=self._locale("Window size"),
+                 info=self._locale("Window size of chat history."),
+             )
+             show_finish_reason = gr.Checkbox(
+                 label=f"{self._locale('Show stop reason')}"
+             )
+         chat = gr.Chatbot(label=model_name)
+         text = gr.Textbox(visible=False)
+         model_uid = gr.Textbox(model_uid, visible=False)
+         text.change(
+             self.generate,
+             [
+                 model_uid,
+                 text,
+                 chat,
+                 max_token,
+                 temperature,
+                 top_p,
+                 window_size,
+                 show_finish_reason,
+             ],
+             [text, chat],
+         )
+         return (
+             text,
+             chat,
+             max_token,
+             temperature,
+             top_p,
+             show_finish_reason,
+             window_size,
+             model_uid,
+         )
+
+     def _build_chat_column(self):
+         with gr.Column():
+             with gr.Row():
+                 model_name = gr.Dropdown(
+                     choices=list(MODEL_TO_FAMILIES.keys()),
+                     label=self._locale("model name"),
+                     scale=2,
+                 )
+                 model_format = gr.Dropdown(
+                     choices=[],
+                     interactive=False,
+                     label=self._locale("model format"),
+                     scale=2,
+                 )
+                 model_size_in_billions = gr.Dropdown(
+                     choices=[],
+                     interactive=False,
+                     label=self._locale("model size in billions"),
+                     scale=1,
+                 )
+                 quantization = gr.Dropdown(
+                     choices=[],
+                     interactive=False,
+                     label=self._locale("quantization"),
+                     scale=1,
+                 )
+                 create_model = gr.Button(value=self._locale("create"))
+
+             def select_model_name(model_name: str):
+                 if model_name:
+                     model_family = MODEL_TO_FAMILIES[model_name]
+                     formats = [model_family.model_format]
+                     model_sizes_in_billions = [
+                         str(b) for b in model_family.model_sizes_in_billions
+                     ]
+                     quantizations = model_family.quantizations
+                     return (
+                         gr.Dropdown.update(
+                             choices=formats,
+                             interactive=True,
+                             value=model_family.model_format,
+                         ),
+                         gr.Dropdown.update(
+                             choices=model_sizes_in_billions[:1],
+                             interactive=True,
+                             value=model_sizes_in_billions[0],
+                         ),
+                         gr.Dropdown.update(
+                             choices=quantizations,
+                             interactive=True,
+                             value=quantizations[0],
+                         ),
+                     )
+                 else:
+                     return (
+                         gr.Dropdown.update(),
+                         gr.Dropdown.update(),
+                         gr.Dropdown.update(),
+                     )
+
+             model_name.change(
+                 select_model_name,
+                 inputs=[model_name],
+                 outputs=[model_format, model_size_in_billions, quantization],
+             )
+
+             components = self._build_chatbot("", "")
+             model_text = components[0]
+             chat, model_uid = components[1], components[-1]
+
+             def select_model(
+                 _model_name: str,
+                 _model_format: str,
+                 _model_size_in_billions: str,
+                 _quantization: str,
+                 progress=gr.Progress(),
+             ):
+                 model_family = MODEL_TO_FAMILIES[_model_name]
+                 cache_path, meta_path = model_family.generate_cache_path(
+                     int(_model_size_in_billions), _quantization
+                 )
+                 if not (os.path.exists(cache_path) and os.path.exists(meta_path)):
+                     if os.path.exists(cache_path):
+                         os.remove(cache_path)
+                     url = model_family.url_generator(
+                         int(_model_size_in_billions), _quantization
+                     )
+                     full_name = (
+                         f"{str(model_family)}-{_model_size_in_billions}b-{_quantization}"
+                     )
+                     try:
+                         urllib.request.urlretrieve(
+                             url,
+                             cache_path,
+                             reporthook=lambda block_num, block_size, total_size: progress(
+                                 block_num * block_size / total_size,
+                                 desc=self._locale("Downloading"),
+                             ),
+                         )
+                         # write a meta file to record if download finished
+                         with open(meta_path, "w") as f:
+                             f.write(full_name)
+                     except:
+                         if os.path.exists(cache_path):
+                             os.remove(cache_path)
+
+                 model_uid = self._create_model(
+                     _model_name, int(_model_size_in_billions), _model_format, _quantization
+                 )
+                 return gr.Chatbot.update(
+                     label="-".join(
+                         [_model_name, _model_size_in_billions, _model_format, _quantization]
+                     ),
+                     value=[],
+                 ), gr.Textbox.update(value=model_uid)
+
+             def clear_chat(
+                 _model_name: str,
+                 _model_format: str,
+                 _model_size_in_billions: str,
+                 _quantization: str,
+             ):
+                 full_name = "-".join(
+                     [_model_name, _model_size_in_billions, _model_format, _quantization]
+                 )
+                 return str(uuid.uuid4()), gr.Chatbot.update(
+                     label=full_name,
+                     value=[],
+                 )
+
+             invisible_text = gr.Textbox(visible=False)
+             create_model.click(
+                 clear_chat,
+                 inputs=[model_name, model_format, model_size_in_billions, quantization],
+                 outputs=[invisible_text, chat],
+             )
+
+             invisible_text.change(
+                 select_model,
+                 inputs=[model_name, model_format, model_size_in_billions, quantization],
+                 outputs=[chat, model_uid],
+                 postprocess=False,
+             )
+             return chat, model_text
+
+     def _build_arena(self):
+         with gr.Box():
+             with gr.Row():
+                 chat_and_text = [
+                     self._build_chat_column() for _ in range(self._gladiator_num)
+                 ]
+                 chats = [c[0] for c in chat_and_text]
+                 texts = [c[1] for c in chat_and_text]
+
+             msg = gr.Textbox(label=self._locale("Input"))
+
+             def update_message(text_in: str):
+                 return "", text_in, text_in
+
+             msg.submit(update_message, inputs=[msg], outputs=[msg] + texts)
+
+             gr.ClearButton(components=[msg] + chats + texts)
+
+     def _build_single(self):
+         chat, model_text = self._build_chat_column()
+
+         msg = gr.Textbox(label=self._locale("Input"))
+
+         def update_message(text_in: str):
+             return "", text_in
+
+         msg.submit(update_message, inputs=[msg], outputs=[msg, model_text])
+         gr.ClearButton(components=[chat, msg, model_text])
+
+     def _build_single_with_launched(
+         self, models: List[Tuple[str, ModelSpec]], default_index: int
+     ):
+         uid_to_model_spec: Dict[str, ModelSpec] = dict((m[0], m[1]) for m in models)
+         choices = [
+             "-".join(
+                 [
+                     s.model_name,
+                     str(s.model_size_in_billions),
+                     s.model_format,
+                     s.quantization,
+                 ]
+             )
+             for s in uid_to_model_spec.values()
+         ]
+         choice_to_uid = dict(zip(choices, uid_to_model_spec.keys()))
+         model_selection = gr.Dropdown(
+             label=self._locale("select model"),
+             choices=choices,
+             value=choices[default_index],
+         )
+         components = self._build_chatbot(
+             models[default_index][0], choices[default_index]
+         )
+         model_text = components[0]
+         model_uid = components[-1]
+         chat = components[1]
+
+         def select_model(model_name):
+             uid = choice_to_uid[model_name]
+             return gr.Chatbot.update(label=model_name), uid
+
+         model_selection.change(
+             select_model, inputs=[model_selection], outputs=[chat, model_uid]
+         )
+         return chat, model_text
+
+     def _build_arena_with_launched(self, models: List[Tuple[str, ModelSpec]]):
+         with gr.Box():
+             with gr.Row():
+                 chat_and_text = [
+                     self._build_single_with_launched(models, i)
+                     for i in range(self._gladiator_num)
+                 ]
+                 chats = [c[0] for c in chat_and_text]
+                 texts = [c[1] for c in chat_and_text]
+
+             msg = gr.Textbox(label=self._locale("Input"))
+
+             def update_message(text_in: str):
+                 return "", text_in, text_in
+
+             msg.submit(update_message, inputs=[msg], outputs=[msg] + texts)
+
+             gr.ClearButton(components=[msg] + chats + texts)
+
+     def build(self):
+         if self._use_launched_model:
+             models = self._api.list_models()
+             with gr.Blocks() as blocks:
+                 with gr.Tab(self._locale("Chat")):
+                     chat, model_text = self._build_single_with_launched(models, 0)
+                     msg = gr.Textbox(label=self._locale("Input"))
+
+                     def update_message(text_in: str):
+                         return "", text_in
+
+                     msg.submit(update_message, inputs=[msg], outputs=[msg, model_text])
+                     gr.ClearButton(components=[chat, msg, model_text])
+                 if len(models) > 2:
+                     with gr.Tab(self._locale("Arena")):
+                         self._build_arena_with_launched(models)
+         else:
+             with gr.Blocks() as blocks:
+                 with gr.Tab(self._locale("Chat")):
+                     self._build_single()
+                 with gr.Tab(self._locale("Arena")):
+                     self._build_arena()
+         blocks.queue(concurrency_count=40)
+         return blocks
+
+
+ async def launch_xinference():
+     import xoscar as xo
+     from xinference.core.service import SupervisorActor
+
+     pool = await xo.create_actor_pool(address="0.0.0.0", n_process=0)
+     await xo.create_actor(
+         SupervisorActor, address=pool.external_address, uid=SupervisorActor.uid()
+     )
+     gradio_block = GradioApp(pool.external_address).build()
+     gradio_block.launch()
+
+
+ if __name__ == "__main__":
+     loop = asyncio.get_event_loop()
+     task = loop.create_task(launch_xinference())
+
+     try:
+         loop.run_until_complete(task)
+     except KeyboardInterrupt:
+         task.cancel()
+         loop.run_until_complete(task)
+         # avoid displaying exception-unhandled warnings
+         task.exception()
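
The `_max_model_num` cap enforced in `_create_model` is what keeps the Space within its resource budget: at most two models stay resident, and the oldest is terminated before a new one launches. A minimal sketch of that first-in-first-out eviction policy, with a hypothetical `ModelRegistry` standing in for `SyncSupervisorAPI`:

```python
from collections import OrderedDict


class ModelRegistry:
    """Hypothetical stand-in for SyncSupervisorAPI, illustrating the
    eviction policy in GradioApp._create_model."""

    def __init__(self, max_model_num: int = 2):
        self._max_model_num = max_model_num
        self._models: "OrderedDict[str, str]" = OrderedDict()  # uid -> name

    def launch(self, uid: str, name: str) -> str:
        # When the cap is reached, terminate the first (oldest) model
        # before launching a new one, just as _create_model does with
        # self._api.terminate_model(models[0][0]).
        if len(self._models) >= self._max_model_num:
            oldest_uid, _ = self._models.popitem(last=False)
            print(f"terminated {oldest_uid}")
        self._models[uid] = name
        return uid


registry = ModelRegistry(max_model_num=2)
registry.launch("uid-1", "model-a")
registry.launch("uid-2", "model-b")
registry.launch("uid-3", "model-c")  # evicts uid-1
```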
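
`select_model` treats a cache file without its companion meta file as a partial download: the meta file is written only after `urlretrieve` completes, so an interrupted transfer is cleaned up and retried on the next attempt. The same pattern in isolation (the name `download_once` is ours; unlike the commit, this sketch re-raises the failure):

```python
import os
import urllib.request


def download_once(url: str, cache_path: str, meta_path: str) -> None:
    # A finished download is marked by the meta file existing alongside
    # the cache file; anything else counts as incomplete.
    if os.path.exists(cache_path) and os.path.exists(meta_path):
        return
    if os.path.exists(cache_path):
        os.remove(cache_path)  # leftover from an interrupted download
    try:
        urllib.request.urlretrieve(url, cache_path)
        with open(meta_path, "w") as f:
            f.write(url)  # record that the download finished
    except Exception:
        if os.path.exists(cache_path):
            os.remove(cache_path)
        raise
```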
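
`generate` streams output by being a generator bound to a Gradio event: each `yield "", chat` clears the input textbox and repaints the chatbot with the partially built reply. A standalone sketch of the same wiring, with a dummy echo model in place of xinference (Gradio 3.x; generator callbacks need the queue enabled):

```python
import gradio as gr


def generate(message, chat):
    # Append a new [user, assistant] turn, then grow the assistant half
    # chunk by chunk; yielding ("", chat) clears the textbox while the
    # reply streams in, mirroring GradioApp.generate.
    chat = chat + [[message, ""]]
    for token in message.split():  # dummy "model": echoes the input
        chat[-1][1] += token + " "
        yield "", chat


with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Input")
    msg.submit(generate, inputs=[msg, chatbot], outputs=[msg, chatbot])

demo.queue()

if __name__ == "__main__":
    demo.launch()
```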