codys12 committed
Commit 99ab088
1 Parent(s): 4e4fd76
Files changed (1)
  1. app.py +21 -15
app.py CHANGED
@@ -7,7 +7,7 @@ from typing import Iterator
 import gradio as gr
 import spaces
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
 
 DESCRIPTION = "# Mistral-7B"
 
@@ -21,8 +21,9 @@ MAX_INPUT_TOKEN_LENGTH = 4096
 if torch.cuda.is_available():
     model_id = "codys12/MergeLlama-7b"
     model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map=0, cache_dir="/data")
-    model.cuda()
     tokenizer = AutoTokenizer.from_pretrained(model_id)
+    tokenizer.pad_token = tokenizer.eos_token
+    tokenizer.padding_side = "right"
 
 
 @spaces.GPU
@@ -50,20 +51,25 @@ def generate(
         input_ids = input_ids[-MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning("Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
 
-    input_ids = tokenizer(current_input, return_tensors="pt").to("cuda")
-
-    # Generate
-    output_ids = model.generate(input_ids=input_ids,
-                                max_new_tokens=max_new_tokens,
-                                do_sample=True,
-                                top_p=top_p,
-                                top_k=top_k,
-                                temperature=temperature,
-                                repetition_penalty=repetition_penalty)
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_special_tokens=True)
+    generate_kwargs = dict(
+        {"input_ids": input_ids},
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        num_beams=1,
+        repetition_penalty=repetition_penalty,
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
 
-    # Stream output
-    for id in output_ids.tolist()[0]:
-        yield tokenizer.decode(id)
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        yield "".join(outputs)
 
 
 chat_interface = gr.ChatInterface(
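
For context, the change above replaces a blocking generate-then-decode loop with the threaded streaming pattern: model.generate() runs in a background thread while a TextIteratorStreamer yields decoded text chunks that the Gradio callback re-yields as they arrive. Note that the new import line pulls in TextStreamer even though the added code constructs a TextIteratorStreamer, so the sketch below imports the latter. It is a minimal, self-contained illustration of the pattern, not the Space's code: the tiny placeholder checkpoint, the hard-coded prompt, and the sampling values are assumptions standing in for MergeLlama-7b and the chat history.

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Placeholder checkpoint for illustration; the Space loads codys12/MergeLlama-7b.
model_id = "sshleifer/tiny-gpt2"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Stand-in prompt; app.py builds this from the chat history instead.
inputs = tokenizer("Example prompt for streaming generation.", return_tensors="pt")

# The streamer turns generated tokens into decoded text chunks that can be
# consumed from another thread.
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_special_tokens=True)
generate_kwargs = dict(
    inputs,
    streamer=streamer,
    max_new_tokens=32,
    do_sample=True,
    top_p=0.9,
    top_k=50,
    temperature=0.7,
)

# generate() blocks until it finishes, so it runs in a background thread while
# the foreground loop surfaces partial text from the streamer.
thread = Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()

outputs = []
for text in streamer:
    outputs.append(text)
    print("".join(outputs))  # in app.py this is `yield "".join(outputs)`

thread.join()

The practical difference from the removed code is that decoding happens incrementally inside the streamer, so the caller can show partial output as it is produced instead of decoding output_ids only after generation completes.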