Christina Theodoris committed on
Commit
bb217cf
1 Parent(s): 3d06203

Add filtering for start state cells prior to in silico perturbation when modeling cell states

Browse files
Files changed (1) hide show
  1. geneformer/in_silico_perturber.py +10 -0
geneformer/in_silico_perturber.py CHANGED
@@ -447,6 +447,9 @@ class InSilicoPerturber:
447
  if type(attr_value) not in {list, dict}:
448
  if attr_value in valid_options:
449
  continue
 
 
 
450
  valid_type = False
451
  for option in valid_options:
452
  if (option in [int,list,dict]) and isinstance(attr_value, option):
@@ -555,6 +558,13 @@ class InSilicoPerturber:
555
  self.gene_token_dict,
556
  self.forward_batch_size,
557
  self.nproc)
 
 
 
 
 
 
 
558
  self.in_silico_perturb(model,
559
  filtered_input_data,
560
  layer_to_quant,
 
447
  if type(attr_value) not in {list, dict}:
448
  if attr_value in valid_options:
449
  continue
450
+ if attr_name in ["anchor_gene"]:
451
+ if type(attr_name) in {str}:
452
+ continue
453
  valid_type = False
454
  for option in valid_options:
455
  if (option in [int,list,dict]) and isinstance(attr_value, option):
 
558
  self.gene_token_dict,
559
  self.forward_batch_size,
560
  self.nproc)
561
+ # filter for start state cells
562
+ start_state = list(self.cell_states_to_model.values())[0][0][0]
563
+ def filter_for_origin(example):
564
+ return example[list(self.cell_states_to_model.keys())[0]] in [start_state]
565
+
566
+ filtered_input_data = filtered_input_data.filter(filter_for_origin, num_proc=self.nproc)
567
+
568
  self.in_silico_perturb(model,
569
  filtered_input_data,
570
  layer_to_quant,