Nicholas Meisburger commited on
Commit
c7d63a8
1 Parent(s): bb3a7f9

add temperature and beam width sliders

Browse files
Files changed (1) hide show
  1. app.py +14 -6
app.py CHANGED
@@ -8,14 +8,15 @@ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
8
  model = bolt.GenerativeModel.load("./generative.model")
9
 
10
 
11
- def generate(prompt):
12
  prompt = tokenizer.encode(prompt)
13
 
14
  stream = model.streaming_generation(
15
  input_tokens=prompt,
16
  prediction_chunk_size=2,
17
  max_predictions=80,
18
- beam_width=10,
 
19
  )
20
 
21
  for res in stream:
@@ -23,14 +24,21 @@ def generate(prompt):
23
 
24
 
25
  with gr.Blocks() as demo:
 
26
  output = gr.TextArea(label="Output")
27
- prompt = gr.Textbox(
28
- label="Prompt",
 
 
 
 
 
29
  )
30
- prompt.submit(generate, inputs=[prompt], outputs=[output])
 
31
 
32
  btn = gr.Button(value="Generate")
33
- btn.click(generate, inputs=[prompt], outputs=[output])
34
 
35
  gr.ClearButton(components=[prompt, output])
36
 
 
8
  model = bolt.GenerativeModel.load("./generative.model")
9
 
10
 
11
+ def generate(prompt, beam_width, temperature):
12
  prompt = tokenizer.encode(prompt)
13
 
14
  stream = model.streaming_generation(
15
  input_tokens=prompt,
16
  prediction_chunk_size=2,
17
  max_predictions=80,
18
+ beam_width=beam_width,
19
+ temperature=temperature if temperature > 0 else None,
20
  )
21
 
22
  for res in stream:
 
24
 
25
 
26
  with gr.Blocks() as demo:
27
+ prompt = gr.Textbox(label="Prompt", autofocus=True)
28
  output = gr.TextArea(label="Output")
29
+ beam_width = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Beam Width")
30
+ temperature = gr.Slider(
31
+ minimum=0,
32
+ maximum=3,
33
+ step=0.1,
34
+ value=1.2,
35
+ label="Temperature (0 means temperature isn't used)",
36
  )
37
+
38
+ prompt.submit(generate, inputs=[prompt, beam_width, temperature], outputs=[output])
39
 
40
  btn = gr.Button(value="Generate")
41
+ btn.click(generate, inputs=[prompt, beam_width, temperature], outputs=[output])
42
 
43
  gr.ClearButton(components=[prompt, output])
44