indiejoseph committed
Commit
35f8f29
1 Parent(s): 9021fd5

Update app.py

Files changed (1): app.py (+8, -14)
app.py CHANGED
@@ -5,7 +5,7 @@ from typing import Iterator
 import gradio as gr
 import spaces
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 MAX_MAX_NEW_TOKENS = 4096
 DEFAULT_MAX_NEW_TOKENS = 2048
@@ -39,7 +39,7 @@ def generate(
     top_p: float = 0.9,
     top_k: int = 50,
     repetition_penalty: float = 1.2,
-) -> Iterator[str]:
+) -> str:
     conversation = []
     if system_prompt:
         conversation.append({"role": "system", "content": system_prompt})
@@ -52,26 +52,20 @@ def generate(
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
     input_ids = input_ids.to(model.device)
-
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(
-        {"input_ids": input_ids},
-        streamer=streamer,
+    output_ids = model.generate(
+        input_ids,
         max_new_tokens=max_new_tokens,
         do_sample=True,
         top_p=top_p,
         top_k=top_k,
         temperature=temperature,
         num_beams=1,
-        repetition_penalty=repetition_penalty,
+        repetition_penalty=repetition_penalty
     )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
 
-    outputs = []
-    for text in streamer:
-        outputs.append(text)
-        yield "".join(outputs)
+
+    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+    return response
+
 
 
 chat_interface = gr.ChatInterface(
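
For reference, below is a minimal sketch of how the non-streaming generate() reads once this commit is applied. Only the lines visible in the hunks above are taken from app.py; the checkpoint name, the module-level model/tokenizer setup, the chat-history loop, the apply_chat_template call, and the temperature default are assumptions added to make the sketch self-contained, not verbatim source.

# Hedged sketch of the updated, non-streaming generation path.
# Anything not shown in the diff hunks above (model id, model/tokenizer
# loading, history handling, apply_chat_template, temperature default)
# is an assumption and is marked as such below.
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MAX_INPUT_TOKEN_LENGTH = 4096  # assumed value; the constant itself exists in app.py

model_id = "your-org/your-model"  # placeholder checkpoint, not from the diff
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)


def generate(
    message: str,
    chat_history: list,
    system_prompt: str = "",
    max_new_tokens: int = 2048,
    temperature: float = 0.7,  # assumed default; not visible in the hunks
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> str:  # was Iterator[str] before this commit
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:  # assumed tuple-style history format
        conversation.append({"role": "user", "content": user})
        conversation.append({"role": "assistant", "content": assistant})
    conversation.append({"role": "user", "content": message})

    # Assumed tokenization step; this part of app.py lies outside the diff hunks.
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # Single blocking call replaces the TextIteratorStreamer + Thread pattern.
    output_ids = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # output_ids[0] is the full sequence (prompt + completion); it is decoded
    # once and returned as a plain string instead of being yielded incrementally.
    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return response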