from time import time

import streamlit as st
from grouped_sampling import GroupedSamplingPipeLine


def generate_text(
        pipeline: GroupedSamplingPipeLine,
        prompt: str,
        output_length: int,
) -> str:
    """
    Generates text using the given pipeline.
    :param pipeline: The pipeline to use. GroupedSamplingPipeLine.
    :param prompt: The prompt to use. str.
    :param output_length: The size of the text to generate in tokens. int > 0.
    :return: The generated text. str.
    """
    # NOTE(review): return_full_text=False presumably asks the pipeline to
    # return only the newly generated continuation (not prompt + continuation);
    # confirm against the grouped_sampling documentation.
    # The pipeline's result is indexed by "generated_text", so it is assumed to
    # be a mapping for a single prompt — TODO confirm.
    return pipeline(
        prompt_s=prompt,
        max_new_tokens=output_length,
        return_text=True,
        return_full_text=False,
    )["generated_text"]


def on_form_submit(
        pipeline: GroupedSamplingPipeLine,
        output_length: int,
        prompt: str,
) -> str:
    """
    Called when the user submits the form.
    Validates the inputs, generates text with the pipeline, and reports
    progress and the elapsed generation time both on the Streamlit page
    (st.write) and on stdout (print).
    :param pipeline: The pipeline to use. GroupedSamplingPipeLine.
    :param output_length: The number of new tokens to generate. int > 0.
    :param prompt: The prompt to use. Non-empty str.
    :return: The text returned by generate_text for this prompt. str.
    :raises ValueError: If the prompt is empty or output_length <= 0.
        Any exception raised by the pipeline itself propagates unchanged.
    """
    # Validate before writing anything to the page or invoking the model.
    if not prompt:
        raise ValueError("The prompt must not be empty.")
    if output_length <= 0:
        raise ValueError("The output length must be a positive integer.")
    st.write("Generating text...")
    print("Generating text...")
    generation_start_time = time()
    generated_text = generate_text(
        pipeline=pipeline,
        prompt=prompt,
        output_length=output_length,
    )
    generation_end_time = time()
    generation_time = generation_end_time - generation_start_time
    st.write(f"Finished generating text in {generation_time:,.2f} seconds.")
    print(f"Finished generating text in {generation_time:,.2f} seconds.")
    return generated_text