dictionaries from parent dir

#405
Files changed (1)
  1. geneformer/mtl/collators.py +13 -10
geneformer/mtl/collators.py CHANGED
@@ -1,18 +1,24 @@
  # imports
  import torch
+ import pickle
  from ..collator_for_classification import DataCollatorForGeneClassification
- from . import TOKEN_DICTIONARY # import the token dictionary from the mtl module's init
+ from .. import TOKEN_DICTIONARY_FILE

- """
- Geneformer collator for multi-task cell classification.
- """
+ """Geneformer collator for multi-task cell classification."""

  class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
      class_type = "cell"

+     @staticmethod
+     def load_token_dictionary():
+         with open(TOKEN_DICTIONARY_FILE, 'rb') as f:
+             return pickle.load(f)
+
      def __init__(self, *args, **kwargs) -> None:
-         # Use the loaded token dictionary from the mtl module's init
-         super().__init__(token_dictionary=TOKEN_DICTIONARY, *args, **kwargs)
+         # Load the token dictionary
+         token_dictionary = self.load_token_dictionary()
+         # Use the loaded token dictionary
+         super().__init__(token_dictionary=token_dictionary, *args, **kwargs)

      def _prepare_batch(self, features):
          # Process inputs as usual
@@ -29,7 +35,6 @@ class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassificati
          if "label" in features[0]:
              # Initialize labels dictionary for all tasks
              labels = {task: [] for task in features[0]["label"].keys()}
-
              # Populate labels for each task
              for feature in features:
                  for task, label in feature["label"].items():
@@ -57,7 +62,6 @@ class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassificati

      def __call__(self, features):
          batch = self._prepare_batch(features)
-
          for k, v in batch.items():
              if torch.is_tensor(v):
                  batch[k] = v.clone().detach()
@@ -69,5 +73,4 @@ class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassificati
                  }
              else:
                  batch[k] = torch.tensor(v, dtype=torch.int64)
-
-         return batch
+         return batch
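
For reference, a minimal usage sketch of the updated collator (not part of this PR). The change makes the collator self-contained: instead of relying on the mtl module's __init__ to have loaded TOKEN_DICTIONARY at import time, it unpickles TOKEN_DICTIONARY_FILE from the parent geneformer package when constructed. The sketch assumes geneformer is installed with the token dictionary pickle bundled in the parent package, that the parent collator's defaults allow zero-argument construction, and that each feature dict carries tokenized "input_ids" plus a per-task "label" mapping as _prepare_batch expects; the task names are invented for illustration.

from geneformer.mtl.collators import DataCollatorForMultitaskCellClassification

# Token dictionary is now unpickled internally from TOKEN_DICTIONARY_FILE,
# so none needs to be passed in here (assuming parent defaults suffice).
collator = DataCollatorForMultitaskCellClassification()

# Hypothetical features: tokenized cells with one label per task
# ("cell_type" / "disease_state" are made-up task names).
features = [
    {"input_ids": [5, 17, 42, 9], "label": {"cell_type": 0, "disease_state": 1}},
    {"input_ids": [8, 3, 21], "label": {"cell_type": 2, "disease_state": 0}},
]

batch = collator(features)
# __call__ converts the prepared batch values to int64 tensors
# (per-task label dicts included) and returns the assembled batch.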