yonikremer commited on
Commit
7924ca5
1 Parent(s): 45e9ce6

added a supported model check

Browse files
Files changed (2) hide show
  1. hanlde_form_submit.py +11 -0
  2. supported_models.py +17 -2
hanlde_form_submit.py CHANGED
@@ -6,6 +6,8 @@ from grouped_sampling import GroupedSamplingPipeLine
6
 
7
  from download_repo import download_repository
8
  from prompt_engeneering import rewrite_prompt
 
 
9
 
10
 
11
  def is_downloaded(model_name: str) -> bool:
@@ -93,6 +95,15 @@ def on_form_submit(
93
  TypeError: If the output length is not an integer or the prompt is not a string.
94
  RuntimeError: If the model is not found.
95
  """
 
 
 
 
 
 
 
 
 
96
  if len(prompt) == 0:
97
  raise ValueError(f"The prompt must not be empty.")
98
  st.write(f"Loading model: {model_name}...")
 
6
 
7
  from download_repo import download_repository
8
  from prompt_engeneering import rewrite_prompt
9
+ from supported_models import is_supported, SUPPORTED_MODEL_NAME_PAGES_FORMAT, BLACKLISTED_MODEL_NAMES, \
10
+ BLACKLISTED_ORGANIZATIONS
11
 
12
 
13
  def is_downloaded(model_name: str) -> bool:
 
95
  TypeError: If the output length is not an integer or the prompt is not a string.
96
  RuntimeError: If the model is not found.
97
  """
98
+ if not is_supported(model_name, 1, 1):
99
+ raise ValueError(
100
+ f"The model: {model_name} is not supported."
101
+ f"The supported models are the models from {SUPPORTED_MODEL_NAME_PAGES_FORMAT}"
102
+ f" that satisfy the following conditions:\n"
103
+ f"1. The model has at least one like and one download.\n"
104
+ f"2. The model is not one of: {BLACKLISTED_MODEL_NAMES}.\n"
105
+ f"3. The model was not created any of those organizations: {BLACKLISTED_ORGANIZATIONS}.\n"
106
+ )
107
  if len(prompt) == 0:
108
  raise ValueError(f"The prompt must not be empty.")
109
  st.write(f"Loading model: {model_name}...")
supported_models.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from typing import Generator, Set, Union, List, Optional
2
 
3
  import requests
@@ -137,14 +138,28 @@ def generate_supported_model_names(
137
  )
138
 
139
 
 
140
  def get_supported_model_names(
141
  min_number_of_downloads: int = DEFAULT_MIN_NUMBER_OF_DOWNLOADS,
142
  min_number_of_likes: int = DEFAULT_MIN_NUMBER_OF_LIKES,
143
  ) -> Set[str]:
144
- return set(generate_supported_model_names(
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  min_number_of_downloads=min_number_of_downloads,
146
  min_number_of_likes=min_number_of_likes,
147
- ))
148
 
149
 
150
  if __name__ == "__main__":
 
1
+ from functools import lru_cache
2
  from typing import Generator, Set, Union, List, Optional
3
 
4
  import requests
 
138
  )
139
 
140
 
141
+ @lru_cache
142
  def get_supported_model_names(
143
  min_number_of_downloads: int = DEFAULT_MIN_NUMBER_OF_DOWNLOADS,
144
  min_number_of_likes: int = DEFAULT_MIN_NUMBER_OF_LIKES,
145
  ) -> Set[str]:
146
+ return set(
147
+ generate_supported_model_names(
148
+ min_number_of_downloads=min_number_of_downloads,
149
+ min_number_of_likes=min_number_of_likes,
150
+ )
151
+ )
152
+
153
+
154
+ def is_supported(
155
+ model_name: str,
156
+ min_number_of_downloads: int = DEFAULT_MIN_NUMBER_OF_DOWNLOADS,
157
+ min_number_of_likes: int = DEFAULT_MIN_NUMBER_OF_LIKES,
158
+ ) -> bool:
159
+ return model_name in get_supported_model_names(
160
  min_number_of_downloads=min_number_of_downloads,
161
  min_number_of_likes=min_number_of_likes,
162
+ )
163
 
164
 
165
  if __name__ == "__main__":