ctheodoris commited on
Commit
3fe35ba
1 Parent(s): 2e64874

update refs to get_model_emb_dims

Browse files
Files changed (1) hide show
  1. geneformer/emb_extractor.py +2 -2
geneformer/emb_extractor.py CHANGED
@@ -622,13 +622,13 @@ class EmbExtractor:
622
 
623
  if self.exact_summary_stat == "exact_mean":
624
  embs = embs.mean(dim=0)
625
- emb_dims = pu.get_model_embedding_dimensions(model)
626
  embs_df = pd.DataFrame(
627
  embs_df[0:emb_dims-1].mean(axis="rows"), columns=[self.exact_summary_stat]
628
  ).T
629
  elif self.exact_summary_stat == "exact_median":
630
  embs = torch.median(embs, dim=0)[0]
631
- emb_dims = pu.get_model_embedding_dimensions(model)
632
  embs_df = pd.DataFrame(
633
  embs_df[0:emb_dims-1].median(axis="rows"), columns=[self.exact_summary_stat]
634
  ).T
 
622
 
623
  if self.exact_summary_stat == "exact_mean":
624
  embs = embs.mean(dim=0)
625
+ emb_dims = pu.get_model_emb_dims(model)
626
  embs_df = pd.DataFrame(
627
  embs_df[0:emb_dims-1].mean(axis="rows"), columns=[self.exact_summary_stat]
628
  ).T
629
  elif self.exact_summary_stat == "exact_median":
630
  embs = torch.median(embs, dim=0)[0]
631
+ emb_dims = pu.get_model_emb_dims(model)
632
  embs_df = pd.DataFrame(
633
  embs_df[0:emb_dims-1].median(axis="rows"), columns=[self.exact_summary_stat]
634
  ).T