artificialguybr committed on
Commit
81c24b6
1 Parent(s): 466a76b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -4,10 +4,11 @@ import mdtex2html
4
  import torch
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
6
  from transformers.generation import GenerationConfig
 
7
 
8
  # Initialize model and tokenizer
9
  tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-14B-Chat", trust_remote_code=True)
10
- model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-14B-Chat", device_map="auto", trust_remote_code=True).eval()
11
  model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-14B-Chat", trust_remote_code=True)
12
 
13
  # Postprocess function
 
4
  import torch
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
6
  from transformers.generation import GenerationConfig
7
+ from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
8
 
9
  # Initialize model and tokenizer
10
  tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-14B-Chat", trust_remote_code=True)
11
+ model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-14B-Chat", device_map="auto", trust_remote_code=True, use_flash_attn=True).eval()
12
  model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-14B-Chat", trust_remote_code=True)
13
 
14
  # Postprocess function