yonikremer commited on
Commit
30f253f
1 Parent(s): c7b7b1d

changed pipeline's parameters

Browse files
Files changed (1) hide show
  1. hanlde_form_submit.py +10 -6
hanlde_form_submit.py CHANGED
@@ -1,7 +1,6 @@
1
  import streamlit as st
2
  from grouped_sampling import GroupedSamplingPipeLine
3
 
4
-
5
  from supported_models import get_supported_model_names
6
 
7
 
@@ -19,15 +18,17 @@ def create_pipeline(model_name: str, group_size) -> GroupedSamplingPipeLine:
19
  model_name=model_name,
20
  group_size=group_size,
21
  end_of_sentence_stop=True,
 
 
22
  )
23
 
24
 
25
  @st.cache
26
- def on_form_submit(model_name: str, group_size: int, prompt: str) -> str:
27
  """
28
  Called when the user submits the form.
29
  :param model_name: The name of the model to use.
30
- :param group_size: The size of the groups to use.
31
  :param prompt: The prompt to use.
32
  :return: The output of the model.
33
  """
@@ -36,7 +37,10 @@ def on_form_submit(model_name: str, group_size: int, prompt: str) -> str:
36
  f"Supported models are all the models in:"
37
  f" https://huggingface.co/models?pipeline_tag=text-generation&library=pytorch")
38
  pipeline = create_pipeline(
39
- model_name,
40
- group_size,
41
  )
42
- return pipeline(prompt)["generated_text"]
 
 
 
 
1
  import streamlit as st
2
  from grouped_sampling import GroupedSamplingPipeLine
3
 
 
4
  from supported_models import get_supported_model_names
5
 
6
 
 
18
  model_name=model_name,
19
  group_size=group_size,
20
  end_of_sentence_stop=True,
21
+ temp=0.5,
22
+ top_p=0.6,
23
  )
24
 
25
 
26
  @st.cache
27
+ def on_form_submit(model_name: str, output_length: int, prompt: str) -> str:
28
  """
29
  Called when the user submits the form.
30
  :param model_name: The name of the model to use.
31
+ :param output_length: The size of the groups to use.
32
  :param prompt: The prompt to use.
33
  :return: The output of the model.
34
  """
 
37
  f"Supported models are all the models in:"
38
  f" https://huggingface.co/models?pipeline_tag=text-generation&library=pytorch")
39
  pipeline = create_pipeline(
40
+ model_name=model_name,
41
+ group_size=output_length,
42
  )
43
+ return pipeline(
44
+ prompt_s=prompt,
45
+ max_new_tokens=output_length,
46
+ )["generated_text"]