Tonic committed
Commit a3cff49 • 1 Parent(s): ee5b6dd

improve interface and ZeroGPU logic

Files changed (1)
  1. app.py +6 -3
app.py CHANGED
@@ -7,6 +7,8 @@ import spaces
 model_path = "nvidia/Mistral-NeMo-Minitron-8B-Instruct"
 tokenizer = AutoTokenizer.from_pretrained(model_path)
 model = AutoModelForCausalLM.from_pretrained(model_path)
+if tokenizer.pad_token_id is None:
+    tokenizer.pad_token_id = tokenizer.eos_token_id
 
 pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
 
@@ -36,15 +38,16 @@ def generate_response(message, history, system_message, max_tokens, temperature,
     if use_pipeline:
         response = pipe(full_prompt, max_new_tokens=max_tokens, temperature=temperature, top_p=top_p, do_sample=True)[0]['generated_text']
     else:
-        inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
-
+        inputs = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True)
+        inputs = {k: v.to(model.device) for k, v in inputs.items()}
         with torch.no_grad():
             output_ids = model.generate(
                 inputs.input_ids,
                 max_new_tokens=max_tokens,
                 temperature=temperature,
                 top_p=top_p,
-                do_sample=True
+                do_sample=True,
+                attention_mask=inputs['attention_mask']
             )
 
         response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
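
For context, here is a minimal, self-contained sketch of the non-pipeline generation path as it stands after this commit, assuming the same model and tokenizer as app.py. The prompt string and the sampling values (max_new_tokens, temperature, top_p) are placeholder assumptions, and the tokenized tensors are read by key since the inputs are converted to a plain dict:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_path = "nvidia/Mistral-NeMo-Minitron-8B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path)

    # Fall back to the EOS token when the tokenizer has no pad token,
    # so that padding=True below has a valid pad_token_id to work with.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    full_prompt = "Write a haiku about GPUs."  # placeholder prompt

    # Tokenize with padding/truncation, then move every tensor to the model's device.
    inputs = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        output_ids = model.generate(
            inputs["input_ids"],                      # key access: inputs is now a plain dict
            attention_mask=inputs["attention_mask"],  # mask passed explicitly, as added in this commit
            max_new_tokens=256,                       # placeholder sampling values
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
        )

    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    print(response)

Passing the attention_mask explicitly tells generate() which positions are padding, and the eos_token_id fallback makes padding=True safe for tokenizers that ship without a dedicated pad token.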