Walmart-the-bag committed
Commit 5a3e826 • 1 Parent(s): 2239d70

faster inference :)
app.py CHANGED

@@ -5,10 +5,17 @@ from transformers import StoppingCriteria, StoppingCriteriaList
 import torch
 import spaces
 import os
+import subprocess
 
+# Install flash attention
+subprocess.run(
+    "pip install flash-attn --no-build-isolation",
+    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
+    shell=True,
+)
 model_name = "microsoft/Phi-3-medium-128k-instruct"
 from transformers import AutoModelForCausalLM, AutoTokenizer
-model = AutoModelForCausalLM.from_pretrained(model_name, device_map='cuda', torch_dtype=torch.float16, trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(model_name, device_map='cuda', _attn_implementation="flash_attention_2", torch_dtype=torch.float16, trust_remote_code=True)
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 class StopOnTokens(StoppingCriteria):
@@ -19,7 +26,7 @@ class StopOnTokens(StoppingCriteria):
                 return True
         return False
 model.to('cuda')
-@spaces.GPU()
+@spaces.GPU(queue=False)
 def predict(message, history, temperature, max_tokens, top_p, top_k):
     history_transformer_format = history + [[message, ""]]
     stop = StopOnTokens()
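For context, the "faster inference" comes from two things visible in the diff: flash-attn is installed at startup with FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE (so pip takes the prebuilt-wheel path instead of compiling the CUDA kernels), and the model is loaded with _attn_implementation="flash_attention_2". A minimal sketch of a more defensive version of that load, assuming the public attn_implementation kwarg and an "eager" fallback (neither is part of this commit), might look like:

# Sketch only, not from this commit: use FlashAttention-2 when the
# flash_attn package imported cleanly, otherwise fall back to the
# default "eager" attention so the Space still boots without it.
# Assumes a CUDA GPU is present, as it is under @spaces.GPU.
import importlib.util

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "microsoft/Phi-3-medium-128k-instruct"

# find_spec returns None when the package is not installed.
attn_impl = (
    "flash_attention_2"
    if importlib.util.find_spec("flash_attn") is not None
    else "eager"  # assumption: a slower but safe fallback
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="cuda",
    torch_dtype=torch.float16,      # fp16 halves memory and speeds up matmuls
    attn_implementation=attn_impl,  # public spelling of _attn_implementation
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

Two smaller notes on the commit itself: env={...} in subprocess.run replaces the child process's entire environment, so a variant that preserves it would pass {**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"}; and the decorator change to @spaces.GPU(queue=False) is ZeroGPU-specific, presumably opting the handler out of the Spaces GPU queue.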