import gradio as gr
from pathlib import Path

from tsai_gpt.generate_for_app import generate_for_app

# Trained Pythia 160M weights in Lit-GPT format; the directory is named after
# the Llama-2 config/tokenizer that the checkpoint reuses.
pythia_model = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/lit_model.pth")
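
# Fail-fast check (an addition, not in the original app): surface a clear
# error if the checkpoint has not been downloaded/converted yet.
assert pythia_model.exists(), f"Checkpoint not found: {pythia_model}"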
|
def generate_text(prompt):
    # Single sample, capped at 200 new tokens so that CPU inference stays
    # within roughly 10 seconds on the Hugging Face free tier.
    generated_text = generate_for_app(
        prompt,
        num_samples=1,
        max_new_tokens=200,
        temperature=0.9,
        checkpoint_dir=pythia_model.parent,
    )
    return generated_text
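
# Quick smoke test without the UI (hypothetical prompt string; assumes the
# checkpoint directory above is in place):
#     print(generate_text("Once upon a time"))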
|
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Example of text generation with our Pythia 160M model trained on the RedPajama sample data

        The model checkpoint lives in the `checkpoints/meta-llama/Llama-2-7b-chat-hf`
        directory. The hyperparameters are exactly the ones emitted by the training
        notebook (main.ipynb). The final loss is below 3.5, which yields sentences
        that are syntactically correct but semantically meaningless.

        Keep in mind the output is capped at 200 tokens so that inference runs in
        reasonable time (about 10 s) on CPU (the Hugging Face free tier). GPU
        inference can produce much longer sequences.

        Click the "Generate text!" button to see the generated text.
        """
    )
|
    generate_button = gr.Button("Generate text!")
    prompt_box = gr.Textbox(label="Enter your prompt here")
    output_box = gr.Textbox(label="Text generated by the trained Pythia 160M model")
    generate_button.click(
        fn=generate_text,
        inputs=prompt_box,
        outputs=output_box,
        api_name="text_generation_sample",
    )
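
    # Hypothetical client-side call against the endpoint exposed via api_name,
    # using the separate gradio_client package (not a dependency of this app):
    #     from gradio_client import Client
    #     client = Client("http://127.0.0.1:7860")
    #     print(client.predict("Once upon a time", api_name="/text_generation_sample"))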
|
if __name__ == "__main__":
    demo.launch()
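
# To run locally (assuming this file is saved as app.py and the tsai_gpt
# package plus the checkpoint above are in place): `python app.py`, then open
# the local URL that Gradio prints (http://127.0.0.1:7860 by default). On a
# Hugging Face Space, the `demo` object in app.py is picked up automatically.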