File size: 1,971 Bytes
3a9aacf
 
2fd3831
df273ff
c9089bd
 
 
 
 
 
 
 
 
dfa084c
 
 
 
c9089bd
 
a671856
c9089bd
d102e03
 
c9089bd
7a75a15
 
0499581
df273ff
0499581
 
 
7a75a15
 
df273ff
30f253f
7a75a15
 
d73a8e9
 
 
 
 
7a75a15
d73a8e9
22e2fd1
 
16ec708
3a9aacf
e63724c
c9089bd
 
 
 
3a9aacf
 
bf8a943
16ec708
e63724c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from time import perf_counter, time

import streamlit as st
from grouped_sampling import GroupedSamplingPipeLine


def generate_text(
        pipeline: GroupedSamplingPipeLine,
        prompt: str,
        output_length: int,
) -> str:
    """
    Run the pipeline on a single prompt and return the generated text.
    :param pipeline: The GroupedSamplingPipeLine that performs the generation.
    :param prompt: The input text to generate from. str.
    :param output_length: Maximum number of new tokens to generate. int > 0.
    :return: The "generated_text" entry of the pipeline's output. str.
    """
    # return_full_text=False asks the pipeline not to echo the prompt back.
    generation_kwargs = dict(
        prompt_s=prompt,
        max_new_tokens=output_length,
        return_text=True,
        return_full_text=False,
    )
    pipeline_output = pipeline(**generation_kwargs)
    return pipeline_output["generated_text"]


def on_form_submit(
        pipeline: GroupedSamplingPipeLine,
        output_length: int,
        prompt: str,
) -> str:
    """
    Called when the user submits the form: validate the input, generate text
    and report how long the generation took.
    :param pipeline: The pipeline to use. GroupedSamplingPipeLine.
    :param output_length: The number of new tokens to generate. int > 0.
    :param prompt: The prompt to use. Non-empty str.
    :return: The text generated by the model. str.
    :raises ValueError: If the prompt is empty or output_length is not positive.
     Any exception raised by the pipeline during generation propagates unchanged.
    """
    # Validate everything up front so the user never sees a "Generating text..."
    # message for a request that is rejected.
    if len(prompt) == 0:
        raise ValueError("The prompt must not be empty.")
    if output_length <= 0:
        raise ValueError("The output length must be a positive integer.")

    # Progress is shown on the Streamlit page and mirrored to the server console.
    st.write("Generating text...")
    print("Generating text...")
    # perf_counter is monotonic and high-resolution, so the measured duration is
    # unaffected by system clock adjustments (unlike time.time()).
    generation_start_time = perf_counter()
    generated_text = generate_text(
        pipeline=pipeline,
        prompt=prompt,
        output_length=output_length,
    )
    generation_time = perf_counter() - generation_start_time
    st.write(f"Finished generating text in {generation_time:,.2f} seconds.")
    print(f"Finished generating text in {generation_time:,.2f} seconds.")
    return generated_text