indiejoseph committed
Commit
039d7bc
1 Parent(s): f9c87b8

Update app.py

Files changed (1)
  1. app.py +4 -3
app.py CHANGED
@@ -24,7 +24,7 @@ if not torch.cuda.is_available():
 
 if torch.cuda.is_available():
     model_id = "hon9kon9ize/CantoneseLLMChat-preview20240326"
-    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
+    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
     model = torch.compile(model)
     tokenizer = AutoTokenizer.from_pretrained(model_id)
     tokenizer.use_default_system_prompt = False
@@ -48,7 +48,9 @@ def generate(
     conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
     conversation.append({"role": "user", "content": message})
 
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+    print(chat_history)
+
+    input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors='pt')
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
@@ -60,7 +62,6 @@ def generate(
         top_p=top_p,
         top_k=top_k,
         temperature=temperature,
-        num_beams=1,
         repetition_penalty=repetition_penalty
     )
 
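
Why torch_dtype=torch.bfloat16: the updated from_pretrained call loads the weights in bfloat16 instead of the float32 default, roughly halving weight memory while keeping float32's exponent range. A minimal sketch of the new loading path, assuming only the model id from the diff (the dtype check at the end is illustrative):

import torch
from transformers import AutoModelForCausalLM

model_id = "hon9kon9ize/CantoneseLLMChat-preview20240326"
# device_map="auto" lets accelerate place layers on the available devices;
# torch_dtype=torch.bfloat16 loads the weights directly in half precision.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
print(next(model.parameters()).dtype)  # torch.bfloat16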
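
Why add_generation_prompt=True: without it, apply_chat_template ends the encoded prompt at the user's message, so the model may continue the user's turn; with it, the template's assistant header is appended and generation starts as the assistant. A small sketch against the same tokenizer, with a hypothetical one-turn conversation:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hon9kon9ize/CantoneseLLMChat-preview20240326")
conversation = [{"role": "user", "content": "你好"}]  # hypothetical single turn
# tokenize=True returns token ids directly; add_generation_prompt=True appends
# the assistant header defined by the model's chat template.
input_ids = tokenizer.apply_chat_template(
    conversation,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
)
print(input_ids.shape)  # (1, prompt_length)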
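
Two smaller points from the surrounding context: the unchanged trimming step keeps only the trailing MAX_INPUT_TOKEN_LENGTH tokens, so the oldest turns are dropped first, and removing num_beams=1 from the generation kwargs is a no-op, since 1 is the transformers default. A self-contained sketch of the trimming behaviour, with an assumed limit and a random stand-in tensor:

import torch

MAX_INPUT_TOKEN_LENGTH = 4096  # assumed value; the app defines its own constant
input_ids = torch.randint(0, 32000, (1, 5000))  # hypothetical over-long prompt
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
    # Truncate from the left so the most recent turns survive.
    input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
print(input_ids.shape)  # torch.Size([1, 4096])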