zyliu committed on
Commit
01657a2
1 Parent(s): 7c0c777

update gradio_web_server.py and model_worker.py

Browse files
Files changed (2) hide show
  1. gradio_web_server.py +6 -1
  2. model_worker.py +28 -17
gradio_web_server.py CHANGED
@@ -818,7 +818,7 @@ if __name__ == "__main__":
818
  parser = argparse.ArgumentParser()
819
  parser.add_argument("--host", type=str, default="0.0.0.0")
820
  parser.add_argument("--port", type=int, default=11000)
821
- parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
822
  parser.add_argument("--concurrency-count", type=int, default=10)
823
  parser.add_argument(
824
  "--model-list-mode", type=str, default="once", choices=["once", "reload"]
@@ -829,6 +829,11 @@ if __name__ == "__main__":
829
  parser.add_argument("--embed", action="store_true")
830
  args = parser.parse_args()
831
  logger.info(f"args: {args}")
 
 
 
 
 
832
 
833
  models = get_model_list()
834
 
 
818
  parser = argparse.ArgumentParser()
819
  parser.add_argument("--host", type=str, default="0.0.0.0")
820
  parser.add_argument("--port", type=int, default=11000)
821
+ parser.add_argument("--controller-url", type=str, default=None)
822
  parser.add_argument("--concurrency-count", type=int, default=10)
823
  parser.add_argument(
824
  "--model-list-mode", type=str, default="once", choices=["once", "reload"]
 
829
  parser.add_argument("--embed", action="store_true")
830
  args = parser.parse_args()
831
  logger.info(f"args: {args}")
832
+ if not args.controller_url:
833
+ args.controller_url = os.environ.get("CONTROLLER_URL", None)
834
+
835
+ if not args.controller_url:
836
+ raise ValueError("controller-url is required.")
837
 
838
  models = get_model_list()
839
 
model_worker.py CHANGED
@@ -160,6 +160,25 @@ def split_model(model_name):
160
  return device_map
161
 
162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  class ModelWorker:
164
  def __init__(
165
  self,
@@ -325,8 +344,6 @@ class ModelWorker:
325
  "queue_length": self.get_queue_length(),
326
  }
327
 
328
- # @torch.inference_mode()
329
- @spaces.GPU(duration=120)
330
  def generate_stream(self, params):
331
  system_message = params["prompt"][0]["content"]
332
  send_messages = params["prompt"][1:]
@@ -428,20 +445,14 @@ class ModelWorker:
428
  streamer=streamer,
429
  )
430
  logger.info(f"Generation config: {generation_config}")
431
-
432
- with torch.no_grad():
433
- thread = Thread(
434
- target=self.model.chat,
435
- kwargs=dict(
436
- tokenizer=self.tokenizer,
437
- pixel_values=pixel_values,
438
- question=question,
439
- history=history,
440
- return_history=False,
441
- generation_config=generation_config,
442
- ),
443
- )
444
- thread.start()
445
 
446
  generated_text = ""
447
  for new_text in streamer:
@@ -541,4 +552,4 @@ if __name__ == "__main__":
541
  args.load_8bit,
542
  args.device,
543
  )
544
- uvicorn.run(app, host=args.host, port=args.port, log_level="info", workers=1)
 
160
  return device_map
161
 
162
 
163
+ @spaces.GPU(duration=120)
164
+ def multi_thread_infer(
165
+ model, tokenizer, pixel_values, question, history, generation_config
166
+ ):
167
+ with torch.no_grad():
168
+ thread = Thread(
169
+ target=model.chat,
170
+ kwargs=dict(
171
+ tokenizer=tokenizer,
172
+ pixel_values=pixel_values,
173
+ question=question,
174
+ history=history,
175
+ return_history=False,
176
+ generation_config=generation_config,
177
+ ),
178
+ )
179
+ thread.start()
180
+
181
+
182
  class ModelWorker:
183
  def __init__(
184
  self,
 
344
  "queue_length": self.get_queue_length(),
345
  }
346
 
 
 
347
  def generate_stream(self, params):
348
  system_message = params["prompt"][0]["content"]
349
  send_messages = params["prompt"][1:]
 
445
  streamer=streamer,
446
  )
447
  logger.info(f"Generation config: {generation_config}")
448
+ multi_thread_infer(
449
+ self.model,
450
+ self.tokenizer,
451
+ pixel_values,
452
+ question,
453
+ history,
454
+ generation_config,
455
+ )
 
 
 
 
 
 
456
 
457
  generated_text = ""
458
  for new_text in streamer:
 
552
  args.load_8bit,
553
  args.device,
554
  )
555
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")