yonikremer committed
Commit 7a75a15
1 Parent(s): a95851c

moved functions to new file

Files changed (2):
  1. app.py +7 -24
  2. hanlde_form_submit.py +30 -0
app.py CHANGED
@@ -1,41 +1,24 @@
 """
 The Streamlit app for the project demo.
-In the demo, the user can write a prompt and the model will generate a response using the grouped sampling algorithm.
+In the demo, the user can write a prompt
+and the model will generate a response using the grouped sampling algorithm.
 """
 
 import streamlit as st
-from grouped_sampling import GroupedSamplingPipeLine
 
-available_models_list = "https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads"
+from hanlde_form_submit import on_form_submit
 
 
-def create_pipeline(model_name: str, group_size) -> GroupedSamplingPipeLine:
-    """
-    Creates a pipeline with the given model name and group size.
-    :param model_name: The name of the model to use.
-    :param group_size: The size of the groups to use.
-    :return: A pipeline with the given model name and group size.
-    """
-    return GroupedSamplingPipeLine(model_name=model_name, group_size=group_size)
-
-
-def on_form_submit(model_name: str, group_size: int, prompt: str) -> str:
-    """
-    Called when the user submits the form.
-    :param model_name: The name of the model to use.
-    :param group_size: The size of the groups to use.
-    :param prompt: The prompt to use.
-    :return: The output of the model.
-    """
-    pipeline = create_pipeline(model_name, group_size)
-    return pipeline(prompt)["generated_text"]
+AVAILABLE_MODEL_NAMES = "https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads"
 
 
 with st.form("request_form"):
     selected_model_name: str = st.text_input(
         label="Model name",
         value="gpt2",
-        help=f"The name of the model to use. Must be a model from this list: {available_models_list}"
+        help=f"The name of the model to use."
+             f" Must be a model from this list:"
+             f" {AVAILABLE_MODEL_NAMES}"
     )
 
     output_length: int = st.number_input(
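The app.py hunk is truncated right after the output length input, so this view does not show how the form finishes or where the refactored on_form_submit is called. A minimal sketch of the likely wiring, assuming typical Streamlit form handling; the group size and prompt widgets, their labels and defaults, and the submit block are assumptions, not lines from the commit:

import streamlit as st

from hanlde_form_submit import on_form_submit

with st.form("request_form"):
    selected_model_name: str = st.text_input(
        label="Model name",
        value="gpt2",
    )
    # Hypothetical widgets: the hunk stops at output_length, so the labels and
    # defaults from here on are guesses, not part of the commit.
    output_length: int = st.number_input(label="Output length", min_value=1, value=100)
    selected_group_size: int = st.number_input(label="Group size", min_value=1, value=8)
    prompt: str = st.text_area(label="Prompt", value="Hello, my name is")
    submitted: bool = st.form_submit_button("Generate")

if submitted:
    # on_form_submit builds a GroupedSamplingPipeLine and returns the generated text.
    # Note: the handler added in this commit does not take output_length.
    generated_text: str = on_form_submit(selected_model_name, selected_group_size, prompt)
    st.write(generated_text)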
hanlde_form_submit.py ADDED
@@ -0,0 +1,30 @@
+from grouped_sampling import GroupedSamplingPipeLine
+
+
+def create_pipeline(model_name: str, group_size) -> GroupedSamplingPipeLine:
+    """
+    Creates a pipeline with the given model name and group size.
+    :param model_name: The name of the model to use.
+    :param group_size: The size of the groups to use.
+    :return: A pipeline with the given model name and group size.
+    """
+    return GroupedSamplingPipeLine(
+        model_name=model_name,
+        group_size=group_size,
+        end_of_sentence_stop=True,
+    )
+
+
+def on_form_submit(model_name: str, group_size: int, prompt: str) -> str:
+    """
+    Called when the user submits the form.
+    :param model_name: The name of the model to use.
+    :param group_size: The size of the groups to use.
+    :param prompt: The prompt to use.
+    :return: The output of the model.
+    """
+    pipeline = create_pipeline(
+        model_name,
+        group_size,
+    )
+    return pipeline(prompt)["generated_text"]
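For reference, a minimal usage sketch of the new module outside Streamlit, assuming the grouped-sampling package is installed and the chosen model can be loaded; the model name, group size, and prompt below are arbitrary illustrative values, not part of the commit:

from hanlde_form_submit import create_pipeline, on_form_submit

# Build the pipeline directly to see the arguments it forwards to GroupedSamplingPipeLine.
pipeline = create_pipeline("gpt2", 8)

# Or go through the form handler, exactly as app.py now does.
generated: str = on_form_submit(model_name="gpt2", group_size=8, prompt="Hello, my name is")
print(generated)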