zyliu committed on
Commit
389bff0
1 Parent(s): 46446a7

update model_worker.py

Browse files
Files changed (1) hide show
  1. model_worker.py +13 -9
model_worker.py CHANGED
@@ -183,8 +183,8 @@ class ModelWorker:
183
  else:
184
  self.model_name = model_name
185
 
 
186
  logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
187
-
188
  tokenizer = AutoTokenizer.from_pretrained(
189
  model_path, trust_remote_code=True, use_fast=False
190
  )
@@ -225,6 +225,18 @@ class ModelWorker:
225
  )
226
  self.heart_beat_thread.start()
227
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  def reload_model(self):
229
  del self.model
230
  torch.cuda.empty_cache()
@@ -311,15 +323,7 @@ class ModelWorker:
311
  @spaces.GPU
312
  @torch.inference_mode()
313
  def generate_stream(self, params):
314
- try:
315
- import flash_attn
316
- except ImportError:
317
 
318
- def install_flash_attn():
319
- os.system("pip install flash-attn==2.5.9.post1")
320
-
321
- install_flash_attn()
322
- # import flash_attn
323
  system_message = params["prompt"][0]["content"]
324
  send_messages = params["prompt"][1:]
325
  max_input_tiles = params["max_input_tiles"]
 
183
  else:
184
  self.model_name = model_name
185
 
186
+ self.import_flash_attn()
187
  logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
 
188
  tokenizer = AutoTokenizer.from_pretrained(
189
  model_path, trust_remote_code=True, use_fast=False
190
  )
 
225
  )
226
  self.heart_beat_thread.start()
227
 
228
+ @spaces.GPU
229
+ def import_flash_attn(self):
230
+ try:
231
+ import flash_attn
232
+ except ImportError:
233
+
234
+ def install_flash_attn():
235
+ os.system("pip install flash-attn==2.5.9.post1")
236
+
237
+ install_flash_attn()
238
+ # import flash_attn
239
+
240
  def reload_model(self):
241
  del self.model
242
  torch.cuda.empty_cache()
 
323
  @spaces.GPU
324
  @torch.inference_mode()
325
  def generate_stream(self, params):
 
 
 
326
 
 
 
 
 
 
327
  system_message = params["prompt"][0]["content"]
328
  send_messages = params["prompt"][1:]
329
  max_input_tiles = params["max_input_tiles"]