# app.py — Hugging Face Space source, last updated by lvwerra (HF staff)
# in commit 199759c ("Update app.py"); file size 5.54 kB.
import json
import os
import shutil
import gradio as gr
from huggingface_hub import Repository
from text_generation import Client
from share_btn import community_icon_html, loading_icon_html, share_js, share_btn_css
# Endpoint configuration taken from environment variables (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN")
API_URL = os.getenv("API_URL")

# Special tokens used to build fill-in-the-middle prompts, plus the marker a
# user types into the prompt to request an infill.
FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX = "<fim_prefix>", "<fim_middle>", "<fim_suffix>"
FIM_INDICATOR = "<FILL_HERE>"
# Markdown help text rendered in the UI, describing prompt layouts built from
# special tokens. The fill-in-the-middle marker shown here must stay in sync
# with FIM_INDICATOR, the marker generate() searches the prompt for.
FORMATS = """## Model formats
### Prefixes
Any combination of the three:
```
<reponame>REPONAME<filename>FILENAME<gh_stars>STARS\nCode<eos>
```
Stars can be: 0, 1-10, 10-100, 100-1000, 1000+
### Commits
```
<commit_before>code<commit_msg>text<commit_after>code<|endoftext|>
```
### Jupyter structure
```
<start_jupyter><jupyter_text>text<jupyter_code>code<jupyter_output>output<jupyter_text>
```
### Fill-in-the-middle
```
code before<FILL_HERE>code after
```
"""
# UI theme: Monochrome base with indigo/blue accents and slate neutrals,
# small corner radius, and Open Sans with generic system fallbacks.
theme = gr.themes.Monochrome(
    font=[
        gr.themes.GoogleFont("Open Sans"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
    radius_size=gr.themes.sizes.radius_sm,
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
)
# Streaming text-generation client for the endpoint at API_URL.
# NOTE: no Authorization header is sent; HF_TOKEN is read from the
# environment but deliberately not passed here (previously left as a
# commented-out `headers=` argument).
client = Client(API_URL)
def generate(prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
    """Stream a completion of ``prompt`` from the inference endpoint.

    If the prompt contains exactly one FIM_INDICATOR marker, it is rewritten
    into the fill-in-the-middle layout
    ``FIM_PREFIX + prefix + FIM_SUFFIX + suffix + FIM_MIDDLE``. The streamed
    text is then displayed between the user's prefix and suffix.

    Args:
        prompt: Text to complete, optionally containing one FIM_INDICATOR.
        temperature: Sampling temperature; values below 1e-2 are raised to 1e-2.
        max_new_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to repeated tokens.

    Yields:
        The accumulated text after each streamed token. In FIM mode, one more
        value follows with the suffix appended.

    Returns:
        The final text, as the generator's return value (StopIteration.value).

    Raises:
        ValueError: If the prompt contains more than one FIM_INDICATOR.
    """
    temperature = float(temperature)
    # NOTE(review): presumably keeps temperature strictly positive for sampling.
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        # Normalized like the other numeric inputs; a token count is an integer.
        max_new_tokens=int(max_new_tokens),
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    fim_mode = FIM_INDICATOR in prompt
    if fim_mode:
        # Split on the same marker that was detected (previously a different,
        # misspelled marker was used, so FIM mode could never succeed).
        parts = prompt.split(FIM_INDICATOR)
        if len(parts) != 2:
            raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt!")
        prefix, suffix = parts
        prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
        # Show the user's own prefix, not the FIM-encoded model prompt.
        output = prefix
    else:
        output = prompt

    stream = client.generate_stream(prompt, **generate_kwargs)

    for response in stream:
        output += response.token.text
        yield output

    if fim_mode:
        # Yield (not just return) the completed text: consumers that iterate
        # the generator never see its return value, so the suffix was lost.
        output += suffix
        yield output
    return output
# Prompts offered as one-click examples in the UI.
examples = [
    # function stubs
    "def hello_world():",
    "def fibonacci(n):",
    # class stubs
    "class TransformerDecoder(nn.Module):",
    "class ComplexNumbers:",
]
def process_example(args):
    """Run ``generate`` to completion for an example and return its final text.

    Args:
        args: The example prompt (the single value bound to the examples'
            ``inputs``).

    Returns:
        The last value yielded by ``generate``. If ``generate`` yields
        nothing, the prompt itself is returned, rather than failing on an
        unbound loop variable.
    """
    result = args
    for result in generate(args):
        pass
    return result
# Hide elements with the `generating` class (presumably Gradio's in-progress
# indicator) and append the share-button styling.
css = ".generating {visibility: hidden}" + share_btn_css

# Page layout: a header, then a wide prompt/output column beside a narrow
# column of sampling controls.
with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
    with gr.Column():
        gr.Markdown(
            """\
# BigCode - Playground
_Note:_ this is an internal playground - please do not share. The deployment can also change and thus the space not work as we continue development.\
"""
        )
        with gr.Row():
            # Left: prompt, output, share button, examples and format help.
            with gr.Column(scale=3):
                instruction = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your prompt here",
                    elem_id="q-input",
                )
                submit = gr.Button("Generate", variant="primary")
                output = gr.Code(elem_id="q-output")

                with gr.Group(elem_id="share-btn-container"):
                    community_icon = gr.HTML(community_icon_html, visible=True)
                    loading_icon = gr.HTML(loading_icon_html, visible=True)
                    share_button = gr.Button(
                        "Share to community",
                        elem_id="share-btn",
                        visible=True,
                    )

                gr.Examples(
                    fn=process_example,
                    examples=examples,
                    inputs=[instruction],
                    outputs=[output],
                    cache_examples=False,
                )
                gr.Markdown(FORMATS)

            # Right: sampling parameters forwarded to generate().
            with gr.Column(scale=1):
                temperature = gr.Slider(
                    label="Temperature",
                    info="Higher values produce more diverse outputs",
                    minimum=0.0,
                    maximum=2.0,
                    step=0.1,
                    value=0.2,
                    interactive=True,
                )
                max_new_tokens = gr.Slider(
                    label="Max new tokens",
                    info="The maximum numbers of new tokens",
                    minimum=0,
                    maximum=4096,
                    step=4,
                    value=256,
                    interactive=True,
                )
                top_p = gr.Slider(
                    label="Top-p (nucleus sampling)",
                    info="Higher values sample more low-probability tokens",
                    minimum=0.0,
                    maximum=1,
                    step=0.05,
                    value=0.90,
                    interactive=True,
                )
                repetition_penalty = gr.Slider(
                    label="Repetition penalty",
                    info="Penalize repeated tokens",
                    minimum=1.0,
                    maximum=2.0,
                    step=0.05,
                    value=1.2,
                    interactive=True,
                )

    # The button and pressing Enter in the prompt box both stream generate()
    # into the output pane, using the same inputs in generate()'s parameter order.
    generate_inputs = [instruction, temperature, max_new_tokens, top_p, repetition_penalty]
    submit.click(generate, inputs=generate_inputs, outputs=[output])
    instruction.submit(generate, inputs=generate_inputs, outputs=[output])
    share_button.click(None, [], [], _js=share_js)

demo.queue(concurrency_count=16).launch(debug=True)