Christina Theodoris
committed on
Commit
•
eeba323
1
Parent(s):
f75f5ac
update examples for predict_eval and handle roc for 2 cell classes
Browse files
- examples/cell_classification.ipynb +1 -2
- geneformer/classifier.py +52 -34
- geneformer/evaluation_utils.py +1 -1
examples/cell_classification.ipynb
CHANGED
@@ -266,8 +266,7 @@
|
|
266 |
" id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
|
267 |
" output_directory=output_dir,\n",
|
268 |
" output_prefix=output_prefix,\n",
|
269 |
-
" split_id_dict=train_valid_id_split_dict
|
270 |
-
" predict=True)"
|
271 |
]
|
272 |
},
|
273 |
{
|
|
|
266 |
" id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
|
267 |
" output_directory=output_dir,\n",
|
268 |
" output_prefix=output_prefix,\n",
|
269 |
+
" split_id_dict=train_valid_id_split_dict)"
|
|
|
270 |
]
|
271 |
},
|
272 |
{
|
geneformer/classifier.py
CHANGED
@@ -30,7 +30,7 @@ Geneformer classifier.
|
|
30 |
... id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl",
|
31 |
... output_directory="path/to/output_directory",
|
32 |
... output_prefix="output_prefix",
|
33 |
-
...
|
34 |
>>> cc.plot_conf_mat(conf_mat_dict={"Geneformer": all_metrics["conf_matrix"]},
|
35 |
... output_directory="path/to/output_directory",
|
36 |
... output_prefix="output_prefix",
|
@@ -308,7 +308,7 @@ class Classifier:
|
|
308 |
output_directory,
|
309 |
output_prefix,
|
310 |
split_id_dict=None,
|
311 |
-
test_size=
|
312 |
attr_to_split=None,
|
313 |
attr_to_balance=None,
|
314 |
max_trials=100,
|
@@ -417,27 +417,48 @@ class Classifier:
|
|
417 |
data_dict["test"].save_to_disk(test_data_output_path)
|
418 |
elif (test_size is not None) and (self.classifier == "cell"):
|
419 |
if 1 > test_size > 0:
|
420 |
-
|
421 |
-
data
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
438 |
).with_suffix(".dataset")
|
439 |
-
|
440 |
-
|
441 |
else:
|
442 |
data_output_path = (
|
443 |
Path(output_directory) / f"{output_prefix}_labeled"
|
@@ -1012,7 +1033,7 @@ class Classifier:
|
|
1012 |
model = pu.load_model(model_type, num_classes, model_directory, "eval")
|
1013 |
|
1014 |
# evaluate the model
|
1015 |
-
|
1016 |
model,
|
1017 |
num_classes,
|
1018 |
id_class_dict,
|
@@ -1023,24 +1044,21 @@ class Classifier:
|
|
1023 |
)
|
1024 |
|
1025 |
all_conf_mat_df = pd.DataFrame(
|
1026 |
-
|
1027 |
columns=id_class_dict.values(),
|
1028 |
index=id_class_dict.values(),
|
1029 |
)
|
1030 |
all_metrics = {
|
1031 |
"conf_matrix": all_conf_mat_df,
|
1032 |
-
"macro_f1":
|
1033 |
-
"acc":
|
1034 |
}
|
1035 |
all_roc_metrics = None # roc metrics not reported for multiclass
|
|
|
1036 |
if num_classes == 2:
|
1037 |
mean_fpr = np.linspace(0, 1, 100)
|
1038 |
-
|
1039 |
-
all_roc_auc =
|
1040 |
-
all_tpr_wt = [result["roc_metrics"]["tpr_wt"] for result in results]
|
1041 |
-
mean_tpr, roc_auc, roc_auc_sd = eu.get_cross_valid_roc_metrics(
|
1042 |
-
all_tpr, all_roc_auc, all_tpr_wt
|
1043 |
-
)
|
1044 |
all_roc_metrics = {
|
1045 |
"mean_tpr": mean_tpr,
|
1046 |
"mean_fpr": mean_fpr,
|
@@ -1137,7 +1155,7 @@ class Classifier:
|
|
1137 |
|
1138 |
predictions_file : path
|
1139 |
| Path of model predictions output to plot
|
1140 |
-
| (saved output from self.validate if
|
1141 |
| (or saved output from self.evaluate_saved_model)
|
1142 |
id_class_dict_file : Path
|
1143 |
| Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
|
@@ -1173,7 +1191,7 @@ class Classifier:
|
|
1173 |
predictions_logits = np.array(predictions["predictions"])
|
1174 |
true_ids = predictions["label_ids"]
|
1175 |
else:
|
1176 |
-
# format is output from self.validate if
|
1177 |
predictions_logits = predictions.predictions
|
1178 |
true_ids = predictions.label_ids
|
1179 |
|
|
|
30 |
... id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl",
|
31 |
... output_directory="path/to/output_directory",
|
32 |
... output_prefix="output_prefix",
|
33 |
+
... predict_eval=True)
|
34 |
>>> cc.plot_conf_mat(conf_mat_dict={"Geneformer": all_metrics["conf_matrix"]},
|
35 |
... output_directory="path/to/output_directory",
|
36 |
... output_prefix="output_prefix",
|
|
|
308 |
output_directory,
|
309 |
output_prefix,
|
310 |
split_id_dict=None,
|
311 |
+
test_size=None,
|
312 |
attr_to_split=None,
|
313 |
attr_to_balance=None,
|
314 |
max_trials=100,
|
|
|
417 |
data_dict["test"].save_to_disk(test_data_output_path)
|
418 |
elif (test_size is not None) and (self.classifier == "cell"):
|
419 |
if 1 > test_size > 0:
|
420 |
+
if attr_to_split is None:
|
421 |
+
data_dict = data.train_test_split(
|
422 |
+
test_size=test_size,
|
423 |
+
stratify_by_column=self.stratify_splits_col,
|
424 |
+
seed=42,
|
425 |
+
)
|
426 |
+
train_data_output_path = (
|
427 |
+
Path(output_directory) / f"{output_prefix}_labeled_train"
|
428 |
+
).with_suffix(".dataset")
|
429 |
+
test_data_output_path = (
|
430 |
+
Path(output_directory) / f"{output_prefix}_labeled_test"
|
431 |
+
).with_suffix(".dataset")
|
432 |
+
data_dict["train"].save_to_disk(train_data_output_path)
|
433 |
+
data_dict["test"].save_to_disk(test_data_output_path)
|
434 |
+
else:
|
435 |
+
data_dict, balance_df = cu.balance_attr_splits(
|
436 |
+
data,
|
437 |
+
attr_to_split,
|
438 |
+
attr_to_balance,
|
439 |
+
test_size,
|
440 |
+
max_trials,
|
441 |
+
pval_threshold,
|
442 |
+
self.cell_state_dict["state_key"],
|
443 |
+
self.nproc,
|
444 |
+
)
|
445 |
+
balance_df.to_csv(
|
446 |
+
f"{output_directory}/{output_prefix}_train_test_balance_df.csv"
|
447 |
+
)
|
448 |
+
train_data_output_path = (
|
449 |
+
Path(output_directory) / f"{output_prefix}_labeled_train"
|
450 |
+
).with_suffix(".dataset")
|
451 |
+
test_data_output_path = (
|
452 |
+
Path(output_directory) / f"{output_prefix}_labeled_test"
|
453 |
+
).with_suffix(".dataset")
|
454 |
+
data_dict["train"].save_to_disk(train_data_output_path)
|
455 |
+
data_dict["test"].save_to_disk(test_data_output_path)
|
456 |
+
else:
|
457 |
+
data_output_path = (
|
458 |
+
Path(output_directory) / f"{output_prefix}_labeled"
|
459 |
).with_suffix(".dataset")
|
460 |
+
data.save_to_disk(data_output_path)
|
461 |
+
print(data_output_path)
|
462 |
else:
|
463 |
data_output_path = (
|
464 |
Path(output_directory) / f"{output_prefix}_labeled"
|
|
|
1033 |
model = pu.load_model(model_type, num_classes, model_directory, "eval")
|
1034 |
|
1035 |
# evaluate the model
|
1036 |
+
result = self.evaluate_model(
|
1037 |
model,
|
1038 |
num_classes,
|
1039 |
id_class_dict,
|
|
|
1044 |
)
|
1045 |
|
1046 |
all_conf_mat_df = pd.DataFrame(
|
1047 |
+
result["conf_mat"],
|
1048 |
columns=id_class_dict.values(),
|
1049 |
index=id_class_dict.values(),
|
1050 |
)
|
1051 |
all_metrics = {
|
1052 |
"conf_matrix": all_conf_mat_df,
|
1053 |
+
"macro_f1": result["macro_f1"],
|
1054 |
+
"acc": result["acc"],
|
1055 |
}
|
1056 |
all_roc_metrics = None # roc metrics not reported for multiclass
|
1057 |
+
|
1058 |
if num_classes == 2:
|
1059 |
mean_fpr = np.linspace(0, 1, 100)
|
1060 |
+
mean_tpr = result["roc_metrics"]["interp_tpr"]
|
1061 |
+
all_roc_auc = result["roc_metrics"]["auc"]
|
|
|
|
|
|
|
|
|
1062 |
all_roc_metrics = {
|
1063 |
"mean_tpr": mean_tpr,
|
1064 |
"mean_fpr": mean_fpr,
|
|
|
1155 |
|
1156 |
predictions_file : path
|
1157 |
| Path of model predictions output to plot
|
1158 |
+
| (saved output from self.validate if predict_eval=True)
|
1159 |
| (or saved output from self.evaluate_saved_model)
|
1160 |
id_class_dict_file : Path
|
1161 |
| Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
|
|
|
1191 |
predictions_logits = np.array(predictions["predictions"])
|
1192 |
true_ids = predictions["label_ids"]
|
1193 |
else:
|
1194 |
+
# format is output from self.validate if predict_eval=True
|
1195 |
predictions_logits = predictions.predictions
|
1196 |
true_ids = predictions.label_ids
|
1197 |
|
geneformer/evaluation_utils.py
CHANGED
@@ -201,10 +201,10 @@ def plot_ROC(roc_metric_dict, model_style_dict, title, output_dir, output_prefix
|
|
201 |
plt.ylabel("True Positive Rate")
|
202 |
plt.title(title)
|
203 |
plt.legend(loc="lower right")
|
204 |
-
plt.show()
|
205 |
|
206 |
output_file = (Path(output_dir) / f"{output_prefix}_roc").with_suffix(".pdf")
|
207 |
plt.savefig(output_file, bbox_inches="tight")
|
|
|
208 |
|
209 |
|
210 |
# plot confusion matrix
|
|
|
201 |
plt.ylabel("True Positive Rate")
|
202 |
plt.title(title)
|
203 |
plt.legend(loc="lower right")
|
|
|
204 |
|
205 |
output_file = (Path(output_dir) / f"{output_prefix}_roc").with_suffix(".pdf")
|
206 |
plt.savefig(output_file, bbox_inches="tight")
|
207 |
+
plt.show()
|
208 |
|
209 |
|
210 |
# plot confusion matrix
|