Spaces:
Runtime error
Runtime error
Jonatan Asketorp
commited on
Commit
•
b3f2672
1
Parent(s):
457d9a8
Add onnx filter
Browse files- background_task.py +15 -4
background_task.py
CHANGED
@@ -191,7 +191,17 @@ def match(model1, model2):
|
|
191 |
print(f"Match {model1_id} against {model2_id} ended.")
|
192 |
|
193 |
|
194 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
195 |
"""
|
196 |
Get the list of models from the hub and the ELO file.
|
197 |
|
@@ -200,14 +210,15 @@ def get_models_list(filter_bad_models) -> list:
|
|
200 |
models = []
|
201 |
models_ids = []
|
202 |
data = pd.read_csv(os.path.join(DATASET_REPO_URL, "resolve", "main", ELO_FILENAME))
|
203 |
-
|
|
|
204 |
for i, row in data.iterrows():
|
205 |
model_id = row["author"] + "/" + row["model"]
|
206 |
-
if model_id in filter_bad_models:
|
207 |
continue
|
208 |
models.append(Model(row["author"], row["model"], row["elo"], row["games_played"]))
|
209 |
models_ids.append(model_id)
|
210 |
-
for model in
|
211 |
if model.modelId in filter_bad_models:
|
212 |
continue
|
213 |
author, name = model.modelId.split("/")[0], model.modelId.split("/")[1]
|
|
|
191 |
print(f"Match {model1_id} against {model2_id} ended.")
|
192 |
|
193 |
|
194 |
+
def check_for_onnx_file(model_info: ModelInfo) -> bool:
|
195 |
+
"""
|
196 |
+
Checks if the model contains a `.onnx` file.
|
197 |
+
"""
|
198 |
+
for repo_file in model_info.siblings:
|
199 |
+
if repo_file.rfilename.endswith(".onnx"):
|
200 |
+
return True
|
201 |
+
return False
|
202 |
+
|
203 |
+
|
204 |
+
def get_models_list(filter_bad_models):
|
205 |
"""
|
206 |
Get the list of models from the hub and the ELO file.
|
207 |
|
|
|
210 |
models = []
|
211 |
models_ids = []
|
212 |
data = pd.read_csv(os.path.join(DATASET_REPO_URL, "resolve", "main", ELO_FILENAME))
|
213 |
+
models_with_onnx_on_hub = filter(check_for_onnx_file, api.list_models(filter=["reinforcement-learning", "ml-agents", "ML-Agents-SoccerTwos"]))
|
214 |
+
model_ids_with_onnx = {model.modelId for model in models_with_onnx_on_hub}
|
215 |
for i, row in data.iterrows():
|
216 |
model_id = row["author"] + "/" + row["model"]
|
217 |
+
if model_id in filter_bad_models or model_id not in model_ids_with_onnx:
|
218 |
continue
|
219 |
models.append(Model(row["author"], row["model"], row["elo"], row["games_played"]))
|
220 |
models_ids.append(model_id)
|
221 |
+
for model in models_with_onnx_on_hub:
|
222 |
if model.modelId in filter_bad_models:
|
223 |
continue
|
224 |
author, name = model.modelId.split("/")[0], model.modelId.split("/")[1]
|