Christina Theodoris
committed on
Commit
•
0d675a3
1
Parent(s):
316d817
Add load_model support for training mode and fix the anchor gene validation error
Browse files
geneformer/in_silico_perturber.py
CHANGED
@@ -263,7 +263,7 @@ class InSilicoPerturber:
|
|
263 |
"Current valid options for 'perturb_type': 'delete' or 'overexpress'"
|
264 |
)
|
265 |
raise
|
266 |
-
if (self.combos > 0) and (self.
|
267 |
logger.error(
|
268 |
"Combination perturbation without anchor gene is currently under development. "
|
269 |
"Currently, must provide anchor gene for combination perturbation."
|
@@ -416,7 +416,9 @@ class InSilicoPerturber:
|
|
416 |
)
|
417 |
|
418 |
### load model and define parameters ###
|
419 |
-
model = pu.load_model(
|
|
|
|
|
420 |
self.max_len = pu.get_model_input_size(model)
|
421 |
layer_to_quant = pu.quant_layers(model) + self.emb_layer
|
422 |
|
|
|
263 |
"Current valid options for 'perturb_type': 'delete' or 'overexpress'"
|
264 |
)
|
265 |
raise
|
266 |
+
if (self.combos > 0) and (self.anchor_gene is None):
|
267 |
logger.error(
|
268 |
"Combination perturbation without anchor gene is currently under development. "
|
269 |
"Currently, must provide anchor gene for combination perturbation."
|
|
|
416 |
)
|
417 |
|
418 |
### load model and define parameters ###
|
419 |
+
model = pu.load_model(
|
420 |
+
self.model_type, self.num_classes, model_directory, mode="eval"
|
421 |
+
)
|
422 |
self.max_len = pu.get_model_input_size(model)
|
423 |
layer_to_quant = pu.quant_layers(model) + self.emb_layer
|
424 |
|
geneformer/perturber_utils.py
CHANGED
@@ -108,28 +108,36 @@ def slice_by_inds_to_perturb(filtered_input_data, cell_inds_to_perturb):
|
|
108 |
|
109 |
|
110 |
# load model to GPU
|
111 |
-
def load_model(model_type, num_classes, model_directory):
|
|
|
|
|
|
|
|
|
|
|
112 |
if model_type == "Pretrained":
|
113 |
model = BertForMaskedLM.from_pretrained(
|
114 |
-
model_directory,
|
|
|
|
|
115 |
)
|
116 |
elif model_type == "GeneClassifier":
|
117 |
model = BertForTokenClassification.from_pretrained(
|
118 |
model_directory,
|
119 |
num_labels=num_classes,
|
120 |
-
output_hidden_states=
|
121 |
output_attentions=False,
|
122 |
)
|
123 |
elif model_type == "CellClassifier":
|
124 |
model = BertForSequenceClassification.from_pretrained(
|
125 |
model_directory,
|
126 |
num_labels=num_classes,
|
127 |
-
output_hidden_states=
|
128 |
output_attentions=False,
|
129 |
)
|
130 |
-
# put the model in eval mode for fwd pass
|
131 |
-
|
132 |
-
|
|
|
133 |
return model
|
134 |
|
135 |
|
|
|
108 |
|
109 |
|
110 |
# load model to GPU
|
111 |
+
def load_model(model_type, num_classes, model_directory, mode="eval"):
    """
    Load a model from model_directory and move it to the GPU.

    Parameters
    ----------
    model_type : str
        One of "Pretrained", "GeneClassifier", "CellClassifier";
        selects the transformers model class to instantiate.
    num_classes : int
        Number of output labels; used only by the classifier model types.
    model_directory : str
        Path or model id passed to ``from_pretrained``.
    mode : str, default "eval"
        "eval" enables hidden-state output and switches the model to
        eval mode for the forward pass; "train" disables hidden-state
        output. Defaults to "eval" so existing three-argument callers
        keep the original behavior.

    Returns
    -------
    The loaded model, moved to "cuda".

    Raises
    ------
    ValueError
        If ``mode`` or ``model_type`` is not a recognized option.
    """
    if mode == "eval":
        output_hidden_states = True
    elif mode == "train":
        output_hidden_states = False
    else:
        # Previously an unknown mode fell through and raised a confusing
        # NameError on output_hidden_states; fail fast instead.
        raise ValueError("mode must be 'eval' or 'train'")

    if model_type == "Pretrained":
        model = BertForMaskedLM.from_pretrained(
            model_directory,
            output_hidden_states=output_hidden_states,
            output_attentions=False,
        )
    elif model_type == "GeneClassifier":
        model = BertForTokenClassification.from_pretrained(
            model_directory,
            num_labels=num_classes,
            output_hidden_states=output_hidden_states,
            output_attentions=False,
        )
    elif model_type == "CellClassifier":
        model = BertForSequenceClassification.from_pretrained(
            model_directory,
            num_labels=num_classes,
            output_hidden_states=output_hidden_states,
            output_attentions=False,
        )
    else:
        # Previously an unknown model_type fell through and raised a
        # NameError on model; fail fast with a clear message instead.
        raise ValueError(
            "model_type must be 'Pretrained', 'GeneClassifier', or 'CellClassifier'"
        )

    # if eval mode, put the model in eval mode for fwd pass
    if mode == "eval":
        model.eval()
    model = model.to("cuda")
    return model
|
142 |
|
143 |
|