File size: 1,539 Bytes
2fd3831 7a75a15 b1dd47e 7a75a15 30f253f 7a75a15 2fd3831 30f253f 7a75a15 30f253f 7a75a15 b1dd47e 7a75a15 30f253f 7a75a15 30f253f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import streamlit as st
from grouped_sampling import GroupedSamplingPipeLine
from supported_models import get_supported_model_names
# Model names that on_form_submit accepts; evaluated once, when this module is imported.
SUPPORTED_MODEL_NAMES = get_supported_model_names()
def create_pipeline(
    model_name: str,
    group_size: int,
    *,
    end_of_sentence_stop: bool = True,
    temp: float = 0.5,
    top_p: float = 0.6,
) -> GroupedSamplingPipeLine:
    """
    Creates a pipeline with the given model name and group size.

    The sampling settings used to be hard-coded; they are now keyword-only
    parameters whose defaults are the previous values, so existing callers
    get identical pipelines.

    :param model_name: The name of the model to use.
    :param group_size: The size of the groups to use.
    :param end_of_sentence_stop: Forwarded unchanged to GroupedSamplingPipeLine
        as ``end_of_sentence_stop``. Defaults to True.
    :param temp: Forwarded as the pipeline's ``temp`` argument (presumably the
        sampling temperature; confirm against GroupedSamplingPipeLine).
        Defaults to 0.5.
    :param top_p: Forwarded as the pipeline's ``top_p`` argument. Defaults to 0.6.
    :return: A pipeline with the given model name, group size and sampling settings.
    """
    return GroupedSamplingPipeLine(
        model_name=model_name,
        group_size=group_size,
        end_of_sentence_stop=end_of_sentence_stop,
        temp=temp,
        top_p=top_p,
    )
# NOTE(review): st.cache is deprecated in recent Streamlit releases in favour of
# st.cache_data / st.cache_resource. Confirm the pinned Streamlit version before migrating.
@st.cache
def on_form_submit(model_name: str, output_length: int, prompt: str) -> str:
    """
    Called when the user submits the form.

    Validates the model name, builds a pipeline for it and generates text
    from the prompt.

    :param model_name: The name of the model to use. It must be one of
        SUPPORTED_MODEL_NAMES.
    :param output_length: Used both as the pipeline's group size and as the
        maximum number of new tokens to generate.
    :param prompt: The prompt to use.
    :return: The ``"generated_text"`` entry of the pipeline's output.
    :raises ValueError: If ``model_name`` is not in SUPPORTED_MODEL_NAMES.
    """
    if model_name not in SUPPORTED_MODEL_NAMES:
        # Adjacent literals are concatenated with no separator, so each piece
        # carries its own trailing space (the original read "supported.Supported").
        raise ValueError(
            f"The selected model {model_name} is not supported. "
            "Supported models are all the models in: "
            "https://huggingface.co/models?pipeline_tag=text-generation&library=pytorch"
        )
    # output_length doubles as the group size. Presumably this makes the whole
    # continuation be sampled as a single group; confirm intent.
    pipeline = create_pipeline(
        model_name=model_name,
        group_size=output_length,
    )
    return pipeline(
        prompt_s=prompt,
        max_new_tokens=output_length,
    )["generated_text"]
|