File size: 1,971 Bytes
3a9aacf 2fd3831 df273ff c9089bd dfa084c c9089bd a671856 c9089bd d102e03 c9089bd 7a75a15 0499581 df273ff 0499581 7a75a15 df273ff 30f253f 7a75a15 d73a8e9 7a75a15 d73a8e9 22e2fd1 16ec708 3a9aacf e63724c c9089bd 3a9aacf bf8a943 16ec708 e63724c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
from time import time
import streamlit as st
from grouped_sampling import GroupedSamplingPipeLine
def generate_text(
    pipeline: GroupedSamplingPipeLine,
    prompt: str,
    output_length: int,
) -> str:
    """
    Run the pipeline once on a single prompt and return the generated text.

    :param pipeline: The callable pipeline that produces the text.
    :param prompt: The prompt, forwarded to the pipeline as ``prompt_s``.
    :param output_length: Forwarded to the pipeline as ``max_new_tokens``;
        expected to be a positive number of tokens (not checked here).
    :return: The ``"generated_text"`` entry of the pipeline's result.
    """
    # NOTE(review): presumably return_full_text=False makes the pipeline leave
    # the prompt out of its output - confirm against the grouped_sampling docs.
    call_arguments = {
        "prompt_s": prompt,
        "max_new_tokens": output_length,
        "return_text": True,
        "return_full_text": False,
    }
    pipeline_answer = pipeline(**call_arguments)
    return pipeline_answer["generated_text"]
def on_form_submit(
    pipeline: GroupedSamplingPipeLine,
    output_length: int,
    prompt: str,
) -> str:
    """
    Called when the user submits the form.

    Validates the input, generates text with the given pipeline and reports
    the start and the duration of the generation both in the Streamlit page
    and on standard output.

    :param pipeline: The pipeline to use. GroupedSamplingPipeLine.
    :param output_length: The size of the text to generate in tokens. int > 0.
    :param prompt: The prompt to use. A non-empty str.
    :return: The output of the model. str.
    :raises TypeError: If the prompt is not a string
        or the output length is not an integer.
    :raises ValueError: If the prompt is empty or the output length is <= 0.

    Any exception raised by the pipeline itself propagates unchanged.
    """
    # Validate everything before touching the UI or the (slow) model, so that
    # bad input fails fast with a clear message.
    if not isinstance(prompt, str):
        raise TypeError(
            f"The prompt must be a string, got {type(prompt).__name__}."
        )
    if not prompt:
        raise ValueError("The prompt must not be empty.")
    # bool is a subclass of int, but True/False is never a sensible token count.
    if isinstance(output_length, bool) or not isinstance(output_length, int):
        raise TypeError(
            "The output length must be an integer, "
            f"got {type(output_length).__name__}."
        )
    if output_length <= 0:
        raise ValueError(
            f"The output length must be positive, got {output_length}."
        )

    start_message = "Generating text..."
    st.write(start_message)
    print(start_message)

    generation_start_time = time()
    generated_text = generate_text(
        pipeline=pipeline,
        prompt=prompt,
        output_length=output_length,
    )
    generation_time = time() - generation_start_time

    # The same message goes to the page (for the user) and to stdout (for logs).
    end_message = f"Finished generating text in {generation_time:,.2f} seconds."
    st.write(end_message)
    print(end_message)
    return generated_text
|