# NOTE(review): the lines that were here ("yonikremer's picture", "raw",
# "history blame", "1.56 kB", etc.) were Hugging Face file-viewer chrome
# accidentally scraped into the source; they are not part of the program
# and have been converted to this comment so the file parses.
"""
The Streamlit app for the project demo.
In the demo, the user can write a prompt
and the model will generate a response using the grouped sampling algorithm.
"""
import streamlit as st
from torch.cuda import CudaError
from available_models import AVAILABLE_MODELS
from hanlde_form_submit import on_form_submit
st.title("A Single Usage is All You Need - Demo")

with st.form("request_form"):
    # Model picker; option names come from the project's AVAILABLE_MODELS.
    selected_model_name: str = st.selectbox(
        label="choose a model",
        options=AVAILABLE_MODELS,
        help="opt-iml-max-30b generates better texts but is slower",
    )

    # Number of tokens to generate; capped at 1024 to bound generation time.
    output_length: int = st.number_input(
        label="the length of the output (in tokens)",
        min_value=1,
        max_value=1024,
        value=5,
    )

    # Free-text prompt, pre-filled with a keyword-to-sentence example.
    submitted_prompt: str = st.text_area(
        label="prompt",
        value="""
Keywords: cat, look, mouse
What is a sentence that includes all these keywords?
Answer:""",
        max_chars=1024,
    )

    submitted: bool = st.form_submit_button(
        label="generate text",
        disabled=False,
    )

if submitted:
    try:
        # Run the grouped-sampling generation (project helper).
        output = on_form_submit(
            selected_model_name,
            output_length,
            submitted_prompt,
        )
    except CudaError:
        # NOTE(review): assumes CUDA errors here mean out-of-memory. Recent
        # torch raises torch.cuda.OutOfMemoryError, a RuntimeError subclass,
        # which would be caught by the branch below instead — verify on a GPU.
        st.error("Out of memory. Please try a smaller model, shorter prompt, or a smaller output length.")
    except (ValueError, TypeError, RuntimeError) as error:
        # Show the error text, not the exception object's repr.
        st.error(str(error))
    else:
        # Only reached when generation succeeded (no exception raised).
        st.write(f"Generated text: {output}")