jupyterjazz committed on
Commit
b2590e9
1 Parent(s): 6343db7

feat: multi cls tokenizer

Browse files

Signed-off-by: jupyterjazz <[email protected]>

Files changed (1) hide show
  1. tokenizer.py +53 -32
tokenizer.py CHANGED
@@ -1,48 +1,69 @@
1
- import torch
2
- import numpy as np
3
- from transformers import RobertaTokenizer, BatchEncoding
4
  import warnings
5
 
 
 
 
 
6
 
7
  class JinaTokenizer(RobertaTokenizer):
8
- def __init__(self, *args, task_type_vocab_size=6, **kwargs):
 
 
9
  super().__init__(*args, **kwargs)
10
  self.task_type_vocab_size = task_type_vocab_size
 
11
 
12
  def __call__(self, *args, task_type=None, **kwargs):
13
- batch_encoding = super().__call__(*args, **kwargs)
14
- batch_encoding = BatchEncoding(
15
- {
16
- 'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
17
- **batch_encoding,
18
- },
19
- tensor_type=kwargs.get('return_tensors'),
 
 
20
  )
21
- return batch_encoding
22
 
23
- def _batch_encode_plus(self, *args, task_type=None, **kwargs):
24
- batch_encoding = super()._batch_encode_plus(*args, **kwargs)
25
- if task_type is not None:
26
- batch_encoding = BatchEncoding(
27
- {
28
- 'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
29
- **batch_encoding,
30
- },
31
- tensor_type=kwargs.get('return_tensors'),
32
  )
33
- return batch_encoding
 
 
34
 
35
- def _encode_plus(self, *args, task_type=None, **kwargs):
36
- batch_encoding = super()._encode_plus(*args, **kwargs)
37
  if task_type is not None:
38
- batch_encoding = BatchEncoding(
39
- {
40
- 'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
41
- **batch_encoding,
42
- },
43
- tensor_type=kwargs.get('return_tensors'),
44
- )
45
- return batch_encoding
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  @staticmethod
48
  def _get_task_type_ids(batch_encoding: BatchEncoding, task_type: int):
 
 
 
 
1
  import warnings
2
 
3
+ import numpy as np
4
+ import torch
5
+ from transformers import BatchEncoding, RobertaTokenizer
6
+
7
 
8
  class JinaTokenizer(RobertaTokenizer):
9
+ def __init__(
10
+ self, *args, task_type_vocab_size=6, cls_token_interval=None, **kwargs
11
+ ):
12
  super().__init__(*args, **kwargs)
13
  self.task_type_vocab_size = task_type_vocab_size
14
+ self.cls_token_interval = cls_token_interval
15
 
16
  def __call__(self, *args, task_type=None, **kwargs):
17
+ kwargs['task_type'] = task_type
18
+ return super().__call__(*args, **kwargs)
19
+
20
+ def _encode_plus(self, *args, **kwargs):
21
+ return self._process_encoding(super()._encode_plus(*args, **kwargs), **kwargs)
22
+
23
+ def _batch_encode_plus(self, *args, **kwargs):
24
+ return self._process_encoding(
25
+ super()._batch_encode_plus(*args, **kwargs), **kwargs
26
  )
 
27
 
28
+ def _process_encoding(self, batch_encoding, **kwargs):
29
+ task_type = kwargs.get('task_type')
30
+ if self.cls_token_interval is not None:
31
+ modified_input_ids, modified_attention_mask = self._insert_cls_tokens(
32
+ batch_encoding
 
 
 
 
33
  )
34
+ batch_encoding['input_ids'] = modified_input_ids
35
+ if 'attention_mask' in batch_encoding:
36
+ batch_encoding['attention_mask'] = modified_attention_mask
37
 
 
 
38
  if task_type is not None:
39
+ task_type_ids = self._get_task_type_ids(batch_encoding, task_type)
40
+ batch_encoding['task_type_ids'] = task_type_ids
41
+
42
+ return BatchEncoding(batch_encoding, tensor_type=kwargs.get('return_tensors'))
43
+
44
+ def _insert_cls_tokens(self, batch_encoding):
45
+ new_input_ids = []
46
+ new_attention_masks = []
47
+ sequences = batch_encoding['input_ids'].tolist()
48
+
49
+ for sequence in sequences:
50
+ modified_sequence = [sequence[0]]
51
+
52
+ for i in range(1, len(sequence), self.cls_token_interval):
53
+ chunk = sequence[i : i + self.cls_token_interval]
54
+ modified_sequence.extend(chunk)
55
+
56
+ if i + self.cls_token_interval < len(sequence):
57
+ modified_sequence.append(self.cls_token_id)
58
+
59
+ attention_mask = [1 for _ in range(len(modified_sequence))]
60
+ new_input_ids.append(modified_sequence)
61
+ new_attention_masks.append(attention_mask)
62
+
63
+ new_input_ids = torch.tensor(new_input_ids, dtype=torch.long)
64
+ new_attention_masks = torch.tensor(new_attention_masks, dtype=torch.long)
65
+
66
+ return new_input_ids, new_attention_masks
67
 
68
  @staticmethod
69
  def _get_task_type_ids(batch_encoding: BatchEncoding, task_type: int):