Walmart-the-bag committed on
Commit
5a3e826
1 Parent(s): 2239d70

faster inference :)

Files changed (1)
  1. app.py +9 -2
app.py CHANGED
@@ -5,10 +5,17 @@ from transformers import StoppingCriteria, StoppingCriteriaList
 import torch
 import spaces
 import os
+import subprocess
 
+# Install flash attention
+subprocess.run(
+    "pip install flash-attn --no-build-isolation",
+    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
+    shell=True,
+)
 model_name = "microsoft/Phi-3-medium-128k-instruct"
 from transformers import AutoModelForCausalLM, AutoTokenizer
-model = AutoModelForCausalLM.from_pretrained(model_name, device_map='cuda', torch_dtype=torch.float16, trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(model_name, device_map='cuda', _attn_implementation="flash_attention_2", torch_dtype=torch.float16, trust_remote_code=True)
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 class StopOnTokens(StoppingCriteria):
@@ -19,7 +26,7 @@ class StopOnTokens(StoppingCriteria):
                 return True
         return False
 model.to('cuda')
-@spaces.GPU()
+@spaces.GPU(queue=False)
 def predict(message, history, temperature, max_tokens, top_p, top_k):
     history_transformer_format = history + [[message, ""]]
     stop = StopOnTokens()
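
For reference, a rough sketch of what the affected startup code in app.py looks like after this commit, assembled only from the added lines above (not the full file): flash-attn is installed when the Space starts, with FLASH_ATTENTION_SKIP_CUDA_BUILD set so the CUDA kernels are not compiled from source, and the model is then loaded with the flash_attention_2 attention implementation. The argument formatting and comments here are illustrative, not a verbatim copy of the file.

import subprocess
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Install flash-attn at startup; FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE skips
# building the CUDA kernels from source. Note that passing env= replaces the
# child process environment rather than extending the inherited one, which is
# how the commit passes this single variable.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)

model_name = "microsoft/Phi-3-medium-128k-instruct"
# Load Phi-3 in fp16 on the GPU with the flash_attention_2 backend, as in the
# diff above.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="cuda",
    _attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)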