Files changed (9)
  1. README.md +0 -33
  2. config.json +0 -40
  3. configuration_bert.py +6 -6
  4. convert_v2_weights.py +0 -151
  5. mha.py +0 -4
  6. mlp.py +0 -47
  7. modeling_bert.py +27 -145
  8. modeling_lora.py +15 -75
  9. tokenizer.py +88 -0
README.md DELETED
@@ -1,33 +0,0 @@
- # BERT with Flash-Attention
- ### Installing dependencies
- To run the model on GPU, you need to install Flash Attention.
- You may either install it from PyPI (which may not work with fused-dense) or from source.
- To install from source, clone the GitHub repository:
- ```console
- git clone git@github.com:Dao-AILab/flash-attention.git
- ```
- The code provided here should work with commit `43950dd`.
- Change to the cloned repo and install:
- ```console
- cd flash-attention && python setup.py install
- ```
- This will compile the flash-attention kernel, which will take some time.
-
- If you would like to use fused MLPs (e.g. to use activation checkpointing),
- you may also install fused-dense from source:
- ```console
- cd csrc/fused_dense_lib && python setup.py install
- ```
-
-
- ### Configuration
- The config adds some new parameters:
- `use_flash_attn`: If `True`, always use flash attention. If `None`, use flash attention when a GPU is available. If `False`, never use flash attention (works on CPU).
- `window_size`: Size (left and right) of the local attention window. If `(-1, -1)`, use global attention.
- `dense_seq_output`: If `True`, only the hidden states of the masked-out tokens (around 15%) are passed to the classifier heads. I set this to `True` for pretraining.
- `fused_mlp`: Whether to use fused-dense. Useful to reduce VRAM in combination with activation checkpointing.
- `mlp_checkpoint_lvl`: One of `{0, 1, 2}`. Increasing this increases the amount of activation checkpointing within the MLP. Keep this at 0 for pretraining and use gradient accumulation instead. For embedding training, increase it as much as needed.
- `last_layer_subset`: If `True`, only the last layer is computed for a subset of tokens. I left this as `False`.
- `use_qk_norm`: Whether to use QK normalization.
- `num_loras`: Number of LoRAs to use when initializing a `BertLoRA` model. Has no effect on other models.
-
config.json DELETED
@@ -1,40 +0,0 @@
- {
-   "_name_or_path": "jinaai/jina-bert-flash-implementation",
-   "auto_map": {
-     "AutoConfig": "jinaai/jina-bert-flash-implementation--configuration_bert.JinaBertConfig",
-     "AutoModel": "jinaai/jina-bert-flash-implementation--modeling_bert.BertModel",
-     "AutoModelForPreTraining": "jinaai/jina-bert-flash-implementation--modeling_bert.BertForPreTraining",
-     "AutoModelForMaskedLM": "jinaai/jina-bert-flash-implementation--modeling_bert.BertForPreTraining"
-   },
-   "attention_probs_dropout_prob": 0.1,
-   "classifier_dropout": null,
-   "dense_seq_output": false,
-   "emb_pooler": null,
-   "fused_bias_fc": false,
-   "fused_dropout_add_ln": false,
-   "hidden_act": "gelu",
-   "hidden_dropout_prob": 0.1,
-   "hidden_size": 768,
-   "initializer_range": 0.02,
-   "intermediate_size": 3072,
-   "last_layer_subset": false,
-   "layer_norm_eps": 1e-12,
-   "mlp_checkpoint_lvl": 0,
-   "mlp_type": "glu",
-   "model_type": "bert",
-   "num_attention_heads": 12,
-   "num_hidden_layers": 12,
-   "num_loras": 5,
-   "pad_token_id": 0,
-   "pad_vocab_size_multiple": 1,
-   "torch_dtype": "float16",
-   "transformers_version": "4.39.3",
-   "type_vocab_size": 2,
-   "use_flash_attn": true,
-   "use_qk_norm": false,
-   "vocab_size": 30528,
-   "window_size": [
-     -1,
-     -1
-   ]
- }
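Since `auto_map` points `AutoConfig`/`AutoModel` at the remote code in this repository, the usual loading path goes through `trust_remote_code`. A sketch, assuming the weights are hosted alongside a config like the one above (the repo id is taken from `_name_or_path`):

```python
import torch
from transformers import AutoConfig, AutoModel

repo_id = "jinaai/jina-bert-flash-implementation"  # from "_name_or_path" above
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(
    repo_id,
    config=config,
    trust_remote_code=True,
    torch_dtype=torch.float16,  # matches "torch_dtype" in the config
)
```
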
configuration_bert.py CHANGED
@@ -75,24 +75,24 @@ class JinaBertConfig(PretrainedConfig):
          pad_token_id=0,
          window_size=(-1, -1),
          dense_seq_output=False,
-         mlp_type='mlp',
+         fused_mlp=False,
          mlp_checkpoint_lvl=0,
          last_layer_subset=False,
          fused_dropout_add_ln=False,
          fused_bias_fc=False,
          pad_vocab_size_multiple=1,
+         num_tasks=0,
          use_flash_attn=True,
          use_qk_norm=True,
          emb_pooler=None,
          classifier_dropout=None,
-         num_loras=5,
          **kwargs,
      ):
          assert 'position_embedding_type' not in kwargs
          assert 'max_position_embeddings' not in kwargs
          super().__init__(pad_token_id=pad_token_id, **kwargs)
  
-         if mlp_type == 'fused_mlp' and hidden_act not in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]:
+         if fused_mlp and hidden_act not in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]:
              raise ValueError('Fused MLP only supports approximate gelu')
  
          self.vocab_size = vocab_size
@@ -108,14 +108,14 @@ class JinaBertConfig(PretrainedConfig):
          self.layer_norm_eps = layer_norm_eps
          self.window_size = window_size
          self.dense_seq_output = dense_seq_output
-         self.mlp_type= mlp_type
+         self.fused_mlp = fused_mlp
          self.mlp_checkpoint_lvl = mlp_checkpoint_lvl
          self.last_layer_subset = last_layer_subset
          self.fused_dropout_add_ln = fused_dropout_add_ln
          self.fused_bias_fc = fused_bias_fc
          self.pad_vocab_size_multiple = pad_vocab_size_multiple
+         self.num_tasks = num_tasks
          self.use_flash_attn = use_flash_attn
          self.use_qk_norm = use_qk_norm
          self.emb_pooler = emb_pooler
-         self.classifier_dropout = classifier_dropout
-         self.num_loras = num_loras
+         self.classifier_dropout = classifier_dropout
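A small sketch of how the two new constructor arguments behave, based on the validation added above: `fused_mlp` must be paired with an approximate GELU, and `num_tasks` sizes the task-type embedding table added in `modeling_bert.py`. The concrete values are illustrative only.

```python
from configuration_bert import JinaBertConfig  # assumes this repo is importable

# fused_mlp requires an approximate gelu variant, otherwise __init__ raises.
config = JinaBertConfig(fused_mlp=True, hidden_act="gelu_pytorch_tanh", num_tasks=6)

try:
    JinaBertConfig(fused_mlp=True, hidden_act="gelu")
except ValueError as err:
    print(err)  # "Fused MLP only supports approximate gelu"
```
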
 
convert_v2_weights.py DELETED
@@ -1,151 +0,0 @@
- import re
- from collections import OrderedDict
- from transformers import AutoModel, AutoTokenizer
- from .configuration_bert import JinaBertConfig
- import torch
- from .modeling_bert import BertModel
-
- def remap_state_dict(state_dict, config: JinaBertConfig):
-     """
-     Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
-     """
-
-     # LayerNorm
-     def key_mapping_ln_gamma_beta(key):
-         key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
-         key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
-         return key
-
-     state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())
-
-     # Layers
-     def key_mapping_layers(key):
-         return re.sub(r"^encoder.layer.", "encoder.layers.", key)
-
-     state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
-
-     # LayerNorm
-     def key_mapping_ln(key):
-         key = re.sub(r"^embeddings.LayerNorm.", "emb_ln.", key)
-         key = re.sub(
-             r"^encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
-             r"encoder.layers.\1.norm1.\2",
-             key,
-         )
-         key = re.sub(
-             r"^encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
-             r"encoder.layers.\1.norm2.\2",
-             key,
-         )
-         key = re.sub(
-             r"^cls.predictions.transform.LayerNorm.(weight|bias)",
-             r"cls.predictions.transform.layer_norm.\1",
-             key,
-         )
-         return key
-
-     state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
-
-     # MLP
-     def key_mapping_mlp(key):
-         key = re.sub(
-             r"^encoder.layers.(\d+).intermediate.dense.(weight|bias)",
-             r"encoder.layers.\1.mlp.fc1.\2",
-             key,
-         )
-         key = re.sub(
-             r"^encoder.layers.(\d+).output.dense.(weight|bias)",
-             r"encoder.layers.\1.mlp.fc2.\2",
-             key,
-         )
-         return key
-
-     state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
-
-     # Attention
-     last_layer_subset = getattr(config, "last_layer_subset", False)
-     for d in range(config.num_hidden_layers):
-         Wq = state_dict.pop(f"encoder.layers.{d}.attention.self.query.weight")
-         Wk = state_dict.pop(f"encoder.layers.{d}.attention.self.key.weight")
-         Wv = state_dict.pop(f"encoder.layers.{d}.attention.self.value.weight")
-         bq = state_dict.pop(f"encoder.layers.{d}.attention.self.query.bias")
-         bk = state_dict.pop(f"encoder.layers.{d}.attention.self.key.bias")
-         bv = state_dict.pop(f"encoder.layers.{d}.attention.self.value.bias")
-         if not (last_layer_subset and d == config.num_hidden_layers - 1):
-             state_dict[f"encoder.layers.{d}.mixer.Wqkv.weight"] = torch.cat(
-                 [Wq, Wk, Wv], dim=0
-             )
-             state_dict[f"encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
-         else:
-             state_dict[f"encoder.layers.{d}.mixer.Wq.weight"] = Wq
-             state_dict[f"encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat([Wk, Wv], dim=0)
-             state_dict[f"encoder.layers.{d}.mixer.Wq.bias"] = bq
-             state_dict[f"encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat([bk, bv], dim=0)
-
-     def key_mapping_attn(key):
-         return re.sub(
-             r"^encoder.layers.(\d+).attention.output.dense.(weight|bias)",
-             r"encoder.layers.\1.mixer.out_proj.\2",
-             key,
-         )
-
-     state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
-
-     def key_mapping_decoder_bias(key):
-         return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
-
-     state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
-
-     # Word embedding
-     pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
-     if pad_vocab_size_multiple > 1:
-         word_embeddings = state_dict["embeddings.word_embeddings.weight"]
-         state_dict["embeddings.word_embeddings.weight"] = F.pad(
-             word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
-         )
-         decoder_weight = state_dict["cls.predictions.decoder.weight"]
-         state_dict["cls.predictions.decoder.weight"] = F.pad(
-             decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
-         )
-         # If the vocab was padded, we want to set the decoder bias for those padded indices to be
-         # strongly negative (i.e. the decoder shouldn't predict those indices).
-         # TD [2022-05-09]: I don't think it affects the MLPerf training.
-         decoder_bias = state_dict["cls.predictions.decoder.bias"]
-         state_dict["cls.predictions.decoder.bias"] = F.pad(
-             decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
-         )
-
-     # LayerNorm
-     def key_mapping_layernorm(key):
-         return re.sub(r'^encoder.layers.(\d+).mlp.layernorm.(weight|bias)', r"encoder.layers.\1.norm2.\2", key)
-
-     state_dict = OrderedDict((key_mapping_layernorm(k), v) for k, v in state_dict.items())
-
-     return state_dict
-
-
- v2_model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)
- config = JinaBertConfig(vocab_size=30528, use_qk_norm=False, mlp_type='glu', hidden_act='gelu')
- state_dict = v2_model.state_dict()
- new_state_dict = remap_state_dict(state_dict, config)
- flash_model = BertModel(config)
- flash_model.load_state_dict(new_state_dict)
-
-
- torch.save(new_state_dict, 'converted_weights.bin')
- print(config.to_json_string())
-
-
- """
- tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-embeddings-v2-base-en')
- inp = tokenizer.batch_encode_plus(['Hello world', 'How is the weather today?', 'It is raining a lot in Berlin'], return_tensors='pt', padding=True).to('cuda')
- v2_model.eval()
- flash_model.eval()
- v2_model = v2_model.to('cuda', torch.float16)
- flash_model = flash_model.to('cuda', torch.float16)
- output_v2 = v2_model(**inp)
- output_flash = flash_model(**inp)
- x = output_v2.last_hidden_state
- y = output_flash.last_hidden_state
- print(torch.abs(x - y))
- """
mha.py CHANGED
@@ -514,10 +514,6 @@ class MHA(nn.Module):
              alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
          else:
              alibi_slopes = None
-
-         if isinstance(window_size, list):
-             window_size = tuple(window_size)
-
          if window_size != (-1, -1):
              assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
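The removed lines coerced a `window_size` that arrives as a JSON-deserialized list (as in `config.json` above) into a tuple before the `(-1, -1)` comparison. If that normalization now has to happen in the caller, it reduces to a sketch like the following; `normalize_window_size` is a hypothetical helper, not part of this diff.

```python
def normalize_window_size(window_size):
    """Coerce a JSON-deserialized window_size (a list such as [-1, -1]) to the
    tuple form that MHA compares against, e.g. (-1, -1)."""
    return tuple(window_size) if isinstance(window_size, list) else window_size

assert normalize_window_size([-1, -1]) == (-1, -1)
assert normalize_window_size((256, 256)) == (256, 256)
```
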
 
mlp.py CHANGED
@@ -27,53 +27,6 @@ except ImportError:
      FusedMLP, ParallelFusedMLP = None, None
  
  
- class GLUMLP(nn.Module):
-     def __init__(
-         self,
-         in_features,
-         hidden_features,
-         activation,
-         use_flash_attn,
-         return_residual=False,
-         hidden_dropout_prob=0.1
-     ):
-         super().__init__()
-         self.hidden_features = hidden_features
-         self.gated_layers = nn.Linear(
-             in_features, hidden_features * 2, bias=False
-         )
-         if activation == 'relu':
-             self.act = nn.ReLU()
-         elif activation == 'gelu':
-             self.act = nn.GELU()
-         else:
-             raise ValueError(
-                 f"activation {activation} not supported"
-             )
-         self.wo = nn.Linear(hidden_features, in_features)
-         self.dropout = nn.Dropout(hidden_dropout_prob)
-         self.return_residual = return_residual
-         self.use_flash_attn = use_flash_attn
-         #self.layernorm = nn.LayerNorm(in_features, eps=layer_norm_eps)
-
-     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-         residual_connection = hidden_states
-         # compute the activation
-         hidden_states = self.gated_layers(hidden_states)
-         if self.use_flash_attn:
-             gated = hidden_states[:, : self.hidden_features]
-             non_gated = hidden_states[:, self.hidden_features :]
-         else:
-             gated = hidden_states[:, :, : self.hidden_features]
-             non_gated = hidden_states[:, :, self.hidden_features :]
-         hidden_states = self.act(gated) * non_gated
-         hidden_states = self.dropout(hidden_states)
-         # multiply by the second matrix
-         hidden_states = self.wo(hidden_states)
-         # add the residual connection and post-LN
-         # hidden_states = self.layernorm(hidden_states + residual_connection)
-         return hidden_states if not self.return_residual else (hidden_states, residual_connection)
-
  class Mlp(nn.Module):
      def __init__(
          self,
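For reference, the gated computation that the removed `GLUMLP` implemented boils down to the sketch below. It is simplified to a single `[..., :]` split, which covers both the padded `(batch, seq, dim)` layout and the unpadded `(tokens, dim)` layout that the original handled with separate branches.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

d_model, d_hidden = 8, 16
gated_layers = nn.Linear(d_model, 2 * d_hidden, bias=False)  # produces [gate | value]
wo = nn.Linear(d_hidden, d_model)                            # down-projection

x = torch.randn(2, 5, d_model)                 # (batch, seq, hidden)
h = gated_layers(x)
gate, value = h[..., :d_hidden], h[..., d_hidden:]
out = wo(F.gelu(gate) * value)                 # act(gate) * value, then project back
print(out.shape)                               # torch.Size([2, 5, 8])
```
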
modeling_bert.py CHANGED
@@ -39,7 +39,7 @@ from .bert_padding import (
  from .block import Block
  from .embedding import BertEmbeddings
  from .mha import MHA
- from .mlp import FusedMLP, Mlp, GLUMLP
+ from .mlp import FusedMLP, Mlp
  
  try:
      from flash_attn.ops.fused_dense import FusedDense
@@ -81,23 +81,19 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
          return_residual=return_residual,
          use_alibi=True,
          window_size=window_size,
-         qk_norm=use_qk_norm,
-         checkpointing=False,
+         qk_norm=use_qk_norm
      )
      return mixer_cls
  
  
  def create_mlp_cls(config, layer_idx=None, return_residual=False):
      inner_dim = config.intermediate_size
-     mlp_type = config.mlp_type
-     assert mlp_type in ('mlp', 'fused_mlp', 'glu')
-     if mlp_type == 'fused_mlp':
+     fused_mlp = getattr(config, "fused_mlp", False)
+     if fused_mlp:
          assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], (
              "fused_mlp only " "supports approximate gelu"
          )
-     if mlp_type == 'glu':
-         assert config.hidden_act in ('relu', 'gelu')
-     if mlp_type == 'mlp':
+     if not fused_mlp:
          approximate = (
              "tanh"
              if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
@@ -109,16 +105,7 @@ def create_mlp_cls(config, layer_idx=None, return_residual=False):
              activation=partial(F.gelu, approximate=approximate),
              return_residual=return_residual,
          )
-     elif mlp_type == 'glu':
-         mlp_cls = partial(
-             GLUMLP,
-             hidden_features=inner_dim,
-             activation=config.hidden_act,
-             use_flash_attn=config.use_flash_attn,
-             hidden_dropout_prob=config.hidden_dropout_prob,
-             return_residual=return_residual,
-         )
-     elif mlp_type == 'fused_mlp':
+     else:
          if FusedMLP is None:
              raise ImportError("fused_dense is not installed")
          mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
@@ -132,8 +119,6 @@ def create_mlp_cls(config, layer_idx=None, return_residual=False):
              checkpoint_lvl=mlp_checkpoint_lvl,
              return_residual=return_residual,
          )
-     else:
-         raise NotImplementedError
      return mlp_cls
  
  
@@ -167,7 +152,7 @@ def _init_weights(module, initializer_range=0.02):
          nn.init.normal_(module.weight, std=initializer_range)
          if module.bias is not None:
              nn.init.zeros_(module.bias)
-     elif isinstance(module, nn.Embedding):
+     elif isinstance(module, nn.Embedding) and not getattr(module, "skip_init", False):
          nn.init.normal_(module.weight, std=initializer_range)
          if module.padding_idx is not None:
              nn.init.zeros_(module.weight[module.padding_idx])
@@ -189,6 +174,8 @@ class BertEncoder(nn.Module):
      @gradient_checkpointing.setter
      def gradient_checkpointing(self, value):
          self._grad_checkpointing = value
+         for block in self.layers:
+             block.mixer.checkpointing = value
  
      def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
          """If subset_mask is not None, we only want output for the subset of the sequence.
@@ -200,15 +187,7 @@ class BertEncoder(nn.Module):
              {"key_padding_mask": key_padding_mask.bool()} if key_padding_mask is not None else None
          )
          for layer in self.layers:
-             if self._grad_checkpointing:
-                 hidden_states = torch.utils.checkpoint.checkpoint(
-                     layer,
-                     hidden_states,
-                     use_reentrant=False,
-                     mixer_kwargs=mixer_kwargs
-                 )
-             else:
-                 hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
+             hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
          if subset_mask is not None:
              hidden_states = hidden_states[subset_mask]
          else:
@@ -219,27 +198,11 @@ class BertEncoder(nn.Module):
              mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
              if subset_mask is None:
                  for layer in self.layers:
-                     if self._grad_checkpointing:
-                         hidden_states = torch.utils.checkpoint.checkpoint(
-                             layer,
-                             hidden_states,
-                             use_reentrant=False,
-                             mixer_kwargs=mixer_kwargs
-                         )
-                     else:
-                         hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
+                     hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
                  hidden_states = pad_input(hidden_states, indices, batch, seqlen)
              else:
                  for layer in self.layers[:-1]:
-                     if self._grad_checkpointing:
-                         hidden_states = torch.utils.checkpoint.checkpoint(
-                             layer,
-                             hidden_states,
-                             use_reentrant=False,
-                             mixer_kwargs=mixer_kwargs
-                         )
-                     else:
-                         hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
+                     hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
                  if key_padding_mask is not None:
                      subset_idx = torch.nonzero(
                          subset_mask[key_padding_mask], as_tuple=False
@@ -265,15 +228,7 @@ class BertEncoder(nn.Module):
                      "cu_seqlens_k": cu_seqlens,
                      "max_seqlen_k": max_seqlen_in_batch,
                  }
-                 if self._grad_checkpointing:
-                     torch.utils.checkpoint.checkpoint(
-                         self.layers[-1],
-                         hidden_states_subset,
-                         use_reentrant=False,
-                         mixer_kwargs=mixer_kwargs
-                     )
-                 else:
-                     hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
+                 hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
          return hidden_states
  
  
@@ -396,16 +351,24 @@ class BertModel(BertPreTrainedModel):
          self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
          self.encoder = BertEncoder(config)
          self.pooler = BertPooler(config) if add_pooling_layer else None
+         self.task_type_embeddings = nn.Embedding(config.num_tasks, config.hidden_size)
  
          self.emb_pooler = config.emb_pooler
          self._name_or_path = config._name_or_path
          if self.emb_pooler is not None:
              from transformers import AutoTokenizer
  
-             self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path, trust_remote_code=True)
+             self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
          else:
              self.tokenizer = None
  
+         # We now initialize the task embeddings to 0; We do not use task types during
+         # pretraining. When we start using task types during embedding training,
+         # we want the model to behave exactly as in pretraining (i.e. task types
+         # have no effect).
+         nn.init.zeros_(self.task_type_embeddings.weight)
+         self.task_type_embeddings.skip_init = True
+         # The following code should skip the embeddings layer
          self.apply(partial(_init_weights, initializer_range=config.initializer_range))
  
      def forward(
@@ -413,9 +376,9 @@ class BertModel(BertPreTrainedModel):
          input_ids,
          position_ids=None,
          token_type_ids=None,
+         task_type_ids=None,
          attention_mask=None,
          masked_tokens_mask=None,
-         return_dict=True,
      ):
          """If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
          we only want the output for the masked tokens. This means that we only compute the last
@@ -425,6 +388,8 @@ class BertModel(BertPreTrainedModel):
          hidden_states = self.embeddings(
              input_ids, position_ids=position_ids, token_type_ids=token_type_ids
          )
+         if task_type_ids is not None:
+             hidden_states = hidden_states + self.task_type_embeddings(task_type_ids)
  
          # TD [2022-12:18]: Don't need to force residual in fp32
          # BERT puts embedding LayerNorm before embedding dropout.
@@ -464,9 +429,6 @@ class BertModel(BertPreTrainedModel):
              sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
          pooled_output = self.pooler(pool_input, pool=False) if self.pooler is not None else None
  
-         if not return_dict:
-             return (sequence_output, pooled_output)
-
          return BaseModelOutputWithPoolingAndCrossAttentions(
              last_hidden_state=sequence_output,
              pooler_output=pooled_output,
@@ -522,7 +484,7 @@ class BertModel(BertPreTrainedModel):
              self.emb_pooler = 'mean'
              from transformers import AutoTokenizer
  
-             self.tokenizer = AutoTokenizer.from_pretrained(self._name_or_path, trust_remote_code=True)
+             self.tokenizer = AutoTokenizer.from_pretrained(self._name_or_path)
          if self.emb_pooler != 'mean':
              raise NotImplementedError
  
@@ -723,84 +685,4 @@ class BertForPreTraining(BertPreTrainedModel):
              loss=total_loss,
              prediction_logits=prediction_scores,
              seq_relationship_logits=seq_relationship_score,
-         )
-
-
- class BertForMaskedLM(BertPreTrainedModel):
-     def __init__(self, config: JinaBertConfig):
-         super().__init__(config)
-         # If dense_seq_output, we only need to pass the hidden states for the masked out tokens
-         # (around 15%) to the classifier heads.
-         self.dense_seq_output = getattr(config, "dense_seq_output", False)
-         # If last_layer_subset, we only need the compute the last layer for a subset of tokens
-         # (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
-         self.last_layer_subset = getattr(config, "last_layer_subset", False)
-         if self.last_layer_subset:
-             assert self.dense_seq_output, "last_layer_subset requires dense_seq_output"
-         use_xentropy = getattr(config, "use_xentropy", False)
-         if use_xentropy and CrossEntropyLoss is None:
-             raise ImportError("xentropy_cuda is not installed")
-         loss_cls = (
-             nn.CrossEntropyLoss
-             if not use_xentropy
-             else partial(CrossEntropyLoss, inplace_backward=True)
-         )
-
-         self.bert = BertModel(config)
-         self.cls = BertPreTrainingHeads(config)
-         self.mlm_loss = loss_cls(ignore_index=0)
-
-         # Initialize weights and apply final processing
-         self.apply(partial(_init_weights, initializer_range=config.initializer_range))
-         self.tie_weights()
-
-     def tie_weights(self):
-         self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight
-
-     def get_input_embeddings(self):
-         return self.bert.embeddings.word_embeddings
-
-     def forward(
-         self,
-         input_ids,
-         position_ids=None,
-         token_type_ids=None,
-         attention_mask=None,
-         labels=None
-     ):
-         masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
-         outputs = self.bert(
-             input_ids,
-             position_ids=position_ids,
-             token_type_ids=token_type_ids,
-             attention_mask=attention_mask.bool() if attention_mask is not None else None,
-             masked_tokens_mask=masked_tokens_mask,
-         )
-         sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
-         if self.dense_seq_output and labels is not None:
-             masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
-             if not self.last_layer_subset:
-                 sequence_output = index_first_axis(
-                     rearrange(sequence_output, "b s d -> (b s) d"), masked_token_idx
-                 )
-         prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
-
-         if (
-             self.dense_seq_output and labels is not None
-         ):  # prediction_scores are already flattened
-             masked_lm_loss = self.mlm_loss(
-                 prediction_scores, labels.flatten()[masked_token_idx]
-             ).float()
-         elif labels is not None:
-             masked_lm_loss = self.mlm_loss(
-                 rearrange(prediction_scores, "... v -> (...) v"),
-                 rearrange(labels, "... -> (...)"),
-             ).float()
-         else:
-             raise ValueError('MLM labels must not be None')
-
-         return BertForPreTrainingOutput(
-             loss=masked_lm_loss,
-             prediction_logits=prediction_scores,
-             seq_relationship_logits=seq_relationship_score,
-         )
+         )
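The new `task_type_ids` argument adds a per-token task embedding on top of the regular embeddings; since the table is initialized to zero, it is a no-op until it is trained. A usage sketch, assuming the repository files are importable and a config with `num_tasks > 0` (the token ids and task id are arbitrary examples):

```python
import torch
from configuration_bert import JinaBertConfig  # assumes this repo is importable
from modeling_bert import BertModel

config = JinaBertConfig(vocab_size=30528, num_tasks=6, use_flash_attn=False)
model = BertModel(config).eval()

input_ids = torch.tensor([[101, 7592, 2088, 102]])   # toy token ids
task_type_ids = torch.full_like(input_ids, 2)        # one task id per token
with torch.no_grad():
    out = model(input_ids, task_type_ids=task_type_ids)
print(out.last_hidden_state.shape)                   # (1, 4, hidden_size)
```
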
modeling_lora.py CHANGED
@@ -65,8 +65,6 @@ class LoRAParametrization(nn.Module):
          fan_in_fan_out = layer_type == "embedding"
          self.swap = (lambda x: (x[1], x[0])) if fan_in_fan_out else (lambda x: x)
  
-         # For the officially "correct" LoRA initialization, check here: https://github.com/microsoft/LoRA
-         # TODO: Ensure that the initialization here is correct
          if layer_type == "linear":
              self.lora_A = nn.Parameter(
                  initialized_weights((rank, fan_in), num_adaptions, init="kaiming")
@@ -196,64 +194,30 @@ class LoRAParametrization(nn.Module):
              ),
          )
  
-     @staticmethod
-     def select_task_for_layer(layer: nn.Module, task_idx: Optional[int] = None):
+     @classmethod
+     def select_task_for_layer(cls, layer: nn.Module, task_idx: Optional[int] = None):
          if isinstance(layer, LoRAParametrization):
              layer.current_task = task_idx
  
-     @staticmethod
-     def merge_lora_into_layer(layer: nn.Module):
-         if hasattr(layer, "parametrizations"):
-             for attr_name in layer.parametrizations.keys():
-                 parametrize.remove_parametrizations(layer, attr_name, leave_parametrized=True)
-
  
  class BertLoRA(BertPreTrainedModel):
-     def __init__(self, config: JinaBertConfig, bert: Optional[BertModel] = None, add_pooling_layer=True):
+     def __init__(self, config: JinaBertConfig, bert: Optional[BertModel] = None, add_pooling_layer=True, num_adaptions=1):
          super().__init__(config)
          if bert is None:
              self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
          else:
              self.bert = bert
-         self._is_merged = False
-         self._num_adaptions = config.num_loras
-         self._register_lora(self._num_adaptions)
-         self.main_params_trainable = False
-         self._task_idx = None
-         # By default, we select the first LoRA
-         self.current_task = 0
-
-     @property
-     def main_params_trainable(self):
-         return self._main_params_trainable
-
-     @main_params_trainable.setter
-     def main_params_trainable(self, val: bool):
-         """Whether the main parameters (i.e. those that are not LoRA) should be trainable.
-
-         This method sets the `requires_grad_` attribute of the main weights
-         and controls which parameters are returned in `self.parameters()`.
-
-         :param val: Whether or not to make the parameters trainable.
-         :return: None
-         """
-         self._main_params_trainable = val
+         self._register_lora(num_adaptions)
          for name, param in super().named_parameters():
              if "lora" not in name:
-                 param.requires_grad_(val)
+                 param.requires_grad_(False)
+         self.current_task = 0
  
      @classmethod
-     def from_bert(cls, *args, **kwargs):
+     def from_bert(cls, *args, num_adaptions=1, **kwargs):
          bert = BertModel.from_pretrained(*args, **kwargs)
          config = JinaBertConfig.from_pretrained(*args, **kwargs)
-         return cls(config, bert=bert)
-
-     def merge_lora(self):
-         """Merges currently selected LoRA into main weights."""
-         if self._is_merged:
-             raise Exception('LoRA has already been merged, cannot merge again')
-         self._is_merged = True
-         self.apply(LoRAParametrization.merge_lora_into_layer)
+         return cls(config, bert=bert, num_adaptions=num_adaptions)
  
      @classmethod
      def from_pretrained(
@@ -270,13 +234,7 @@ class BertLoRA(BertPreTrainedModel):
          use_safetensors: bool = None,
          **kwargs,
      ):
-         """
-         TODO: choose between from_bert and super().from_pretrained
-
-         We want to be able to load both a pretrained BertModel, and a trained
-         BertLoRA via this method. To this end, we need to check which of these
-         models we are expected to load.
-         """
+         # TODO: choose between from_bert and super().from_pretrained
          return cls.from_bert(pretrained_model_name_or_path)
  
      def _register_lora(self, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
@@ -292,34 +250,16 @@ class BertLoRA(BertPreTrainedModel):
  
      @property
      def current_task(self):
-         """ Which LoRA is currently selected
-         :return: Integer or None (when LoRA is disabled)
-         """
          return self._task_idx
  
      @current_task.setter
      def current_task(self, task_idx: Union[None, int]):
-         """Set the LoRA that is to be used.
-
-         The LoRA is specified by `task_idx`, which may be an integer >= 0,
-         indexing the available LoRAs. If it is None, no LoRA is used.
-
-         :param task_idx: Which LoRA to use
-         :return:
-         """
-         if self._is_merged:
-             raise Exception('LoRA has been merged, cannot select new task')
-         assert task_idx is None or 0 <= task_idx < self._num_adaptions
-         if self._task_idx != task_idx:
-             # In this case, we need to update the LoRAs everywhere
-             self._task_idx = task_idx
-             self.apply(
-                 partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
-             )
+         self._task_idx = task_idx
+         self.apply(
+             partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
+         )
  
-     def forward(self, *args, current_task: Union[None, int] = -1, **kwargs):
-         if current_task is None or current_task >= 0:
-             self.current_task = current_task
+     def forward(self, *args, **kwargs):
          return self.bert(*args, **kwargs)
  
      def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
@@ -332,5 +272,5 @@ class BertLoRA(BertPreTrainedModel):
          for name, param in super().named_parameters(
              prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate
          ):
-             if "lora" in name or self.main_params_trainable:
+             if "lora" in name:
                  yield name, param
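A sketch of the simplified `BertLoRA` usage after this change: adapters are registered at construction time, the non-LoRA weights are frozen, and the active adapter is selected through `current_task`. The repo id is shown for illustration only and assumes matching weights are available there.

```python
from modeling_lora import BertLoRA  # assumes this repo is importable

# Wrap a pretrained BertModel and attach 4 LoRA adapters; main weights are frozen.
model = BertLoRA.from_bert("jinaai/jina-bert-flash-implementation", num_adaptions=4)

model.current_task = 1  # route all LoRAParametrization layers through adapter 1
lora_params = [name for name, _ in model.named_parameters()]  # only "lora" params remain
print(len(lora_params))
```
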
tokenizer.py ADDED
@@ -0,0 +1,88 @@
+ import torch
+ import numpy as np
+ from transformers import RobertaTokenizer, BatchEncoding, RobertaTokenizerFast
+ import warnings
+
+
+ def get_tokenizer(parent_class):
+     class TokenizerClass(parent_class):
+         def __init__(self, *args, **kwargs):
+             """
+             This class dynamically extends a given tokenizer class from the HF
+             Transformers library (RobertaTokenizer or RobertaTokenizerFast).
+             The task_type_ids are used to pass instruction information to the model.
+             A task_type should either be an integer or a sequence of integers with the same
+             length as the batch size.
+             """
+             super().__init__(*args, **kwargs)
+
+         def __call__(self, *args, task_type=None, **kwargs):
+             batch_encoding = super().__call__(*args, **kwargs)
+             if task_type is not None:
+                 batch_encoding = self._add_task_type_ids(batch_encoding, task_type, kwargs.get('return_tensors'))
+             return batch_encoding
+
+         def _batch_encode_plus(self, *args, task_type=None, **kwargs):
+             batch_encoding = super()._batch_encode_plus(*args, **kwargs)
+             if task_type is not None:
+                 batch_encoding = self._add_task_type_ids(batch_encoding, task_type, kwargs.get('return_tensors'))
+             return batch_encoding
+
+         def _encode_plus(self, *args, task_type=None, **kwargs):
+             batch_encoding = super()._encode_plus(*args, **kwargs)
+             if task_type is not None:
+                 batch_encoding = self._add_task_type_ids(batch_encoding, task_type, kwargs.get('return_tensors'))
+             return batch_encoding
+
+         @classmethod
+         def _add_task_type_ids(cls, batch_encoding, task_type, tensor_type):
+             return BatchEncoding(
+                 {
+                     'task_type_ids': cls._get_task_type_ids(batch_encoding, task_type),
+                     **batch_encoding,
+                 },
+                 tensor_type=tensor_type,
+             )
+
+         @staticmethod
+         def _get_task_type_ids(batch_encoding: BatchEncoding, task_type):
+
+             def apply_task_type(m, x):
+                 x = torch.tensor(x)
+                 assert (
+                     len(x.shape) == 0 or x.shape[0] == m.shape[0]
+                 ), 'The shape of task_type does not match the size of the batch.'
+                 return m * x if len(x.shape) == 0 else m * x[:, None]
+
+             if isinstance(batch_encoding['input_ids'], torch.Tensor):
+                 shape = batch_encoding['input_ids'].shape
+                 return apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
+             else:
+                 try:
+                     shape = torch.tensor(batch_encoding['input_ids']).shape
+                 except:
+                     raise ValueError(
+                         "Unable to create tensor, you should probably "
+                         "activate truncation and/or padding with "
+                         "'padding=True' 'truncation=True' to have batched "
+                         "tensors with the same length."
+                     )
+                 if isinstance(batch_encoding['input_ids'], list):
+                     return (
+                         apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
+                     ).tolist()
+                 elif isinstance(batch_encoding['input_ids'], np.ndarray):
+                     return (
+                         apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
+                     ).numpy()
+                 else:
+                     warnings.warn(
+                         'input_ids is not a torch tensor, numpy array, or list. Returning torch tensor'
+                     )
+                     return apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
+
+     return TokenizerClass
+
+
+ JinaTokenizer = get_tokenizer(RobertaTokenizer)
+ JinaTokenizerFast = get_tokenizer(RobertaTokenizerFast)
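Usage sketch for the added tokenizer classes: passing `task_type` attaches matching `task_type_ids` to the batch, which can then be forwarded to the model. It assumes a RoBERTa-style vocabulary is available under the checkpoint name used; `roberta-base` here is only an illustration.

```python
from tokenizer import JinaTokenizerFast  # assumes this repo is importable

tok = JinaTokenizerFast.from_pretrained("roberta-base")  # any RoBERTa-style vocab
batch = tok(
    ["How is the weather today?", "It is raining a lot in Berlin"],
    padding=True,
    return_tensors="pt",
    task_type=1,              # scalar task id, broadcast over the whole batch
)
print(batch["task_type_ids"])  # same shape as input_ids, filled with 1
# outputs = model(**batch)     # task_type_ids is consumed by BertModel.forward
```
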