zyliu committed on
Commit
46446a7
1 Parent(s): 498ea76

update app.py and model_worker.py

Browse files
Files changed (2) hide show
  1. app.py +0 -12
  2. model_worker.py +10 -0
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import spaces
2
  import fire
3
  import subprocess
4
  import os
@@ -7,17 +6,6 @@ import signal
7
  import subprocess
8
  import atexit
9
 
10
- try:
11
- import flash_attn
12
- except ImportError:
13
-
14
- @spaces.GPU
15
- def install_flash_attn():
16
- os.system("pip install flash-attn==2.5.9.post1")
17
-
18
- # install_flash_attn()
19
- # import flash_attn
20
-
21
 
22
  def kill_processes_by_cmd_substring(cmd_substring):
23
  # execute `ps -ef` and obtain its output
 
 
1
  import fire
2
  import subprocess
3
  import os
 
6
  import subprocess
7
  import atexit
8
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  def kill_processes_by_cmd_substring(cmd_substring):
11
  # execute `ps -ef` and obtain its output
model_worker.py CHANGED
@@ -8,6 +8,7 @@
8
  A model worker executes the model.
9
  """
10
  import spaces
 
11
  import argparse
12
  import asyncio
13
 
@@ -310,6 +311,15 @@ class ModelWorker:
310
  @spaces.GPU
311
  @torch.inference_mode()
312
  def generate_stream(self, params):
 
 
 
 
 
 
 
 
 
313
  system_message = params["prompt"][0]["content"]
314
  send_messages = params["prompt"][1:]
315
  max_input_tiles = params["max_input_tiles"]
 
8
  A model worker executes the model.
9
  """
10
  import spaces
11
+ import os
12
  import argparse
13
  import asyncio
14
 
 
311
  @spaces.GPU
312
  @torch.inference_mode()
313
  def generate_stream(self, params):
314
+ try:
315
+ import flash_attn
316
+ except ImportError:
317
+
318
+ def install_flash_attn():
319
+ os.system("pip install flash-attn==2.5.9.post1")
320
+
321
+ install_flash_attn()
322
+ # import flash_attn
323
  system_message = params["prompt"][0]["content"]
324
  send_messages = params["prompt"][1:]
325
  max_input_tiles = params["max_input_tiles"]