config.json CHANGED
@@ -3,8 +3,12 @@
     "AutoConfig": "configuration_xlm_roberta.XLMRobertaFlashConfig",
     "AutoModel": "modeling_xlm_roberta.XLMRobertaModel",
     "AutoModelForPreTraining": "modeling_xlm_roberta.XLMRobertaForPreTraining",
-    "AutoModelForMaskedLM": "modeling_xlm_roberta.XLMRobertaForMaskedLM"
+    "AutoModelForMaskedLM": "modeling_xlm_roberta.XLMRobertaForMaskedLM",
+    "AutoModelForSequenceClassification": "modeling_xlm_roberta.XLMRobertaForSequenceClassification"
   },
+  "architectures": [
+    "XLMRobertaModel"
+  ],
   "attention_probs_dropout_prob": 0.1,
   "bos_token_id": 0,
   "eos_token_id": 2,
configuration_xlm_roberta.py CHANGED
@@ -1,4 +1,5 @@
 from transformers import PretrainedConfig
+import torch

 class XLMRobertaFlashConfig(PretrainedConfig):
     def __init__(
@@ -21,10 +22,16 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         position_embedding_type="absolute",
         use_cache=True,
         classifier_dropout=None,
+        num_loras=1,
+        load_trained_adapters=False,
+        use_flash_attn=True,
+        torch_dtype=None,
+        emb_pooler=None,
         **kwargs,
     ):
         super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

+
         self.vocab_size = vocab_size
         self.hidden_size = hidden_size
         self.num_hidden_layers = num_hidden_layers
@@ -39,4 +46,12 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         self.layer_norm_eps = layer_norm_eps
         self.position_embedding_type = position_embedding_type
         self.use_cache = use_cache
-        self.classifier_dropout = classifier_dropout
+        self.classifier_dropout = classifier_dropout
+        self.num_loras = num_loras
+        self.load_trained_adapters = load_trained_adapters
+        self.use_flash_attn = use_flash_attn
+        self.emb_pooler = emb_pooler
+        if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
+            self.torch_dtype = getattr(torch, torch_dtype)
+        else:
+            self.torch_dtype = torch_dtype
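
A small sketch of how the new config fields behave when the config is built directly; the local import path is an assumption (the file normally ships inside the model repository), and the field values are chosen purely for illustration:

import torch
from configuration_xlm_roberta import XLMRobertaFlashConfig  # assumed local import path

cfg = XLMRobertaFlashConfig(use_flash_attn=False, num_loras=2, torch_dtype="bfloat16")
print(cfg.torch_dtype)   # torch.bfloat16 -- a dtype name string is resolved to a real torch.dtype
print(cfg.num_loras)     # 2
cfg2 = XLMRobertaFlashConfig(torch_dtype=None)
print(cfg2.torch_dtype)  # None -- anything that is not a valid dtype name is stored unchanged
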
convert_roberta_weights_to_flash.py CHANGED
@@ -1,10 +1,11 @@
 import re
 from collections import OrderedDict
 from transformers import PretrainedConfig
-from transformers import XLMRobertaForMaskedLM
+from transformers import XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification

 from .configuration_xlm_roberta import XLMRobertaFlashConfig as BertConfig
-from .modeling_xlm_roberta import XLMRobertaForMaskedLM as BertModel
+from .modeling_xlm_roberta import XLMRobertaForMaskedLM as FlashXLMRobertaForMaskedLM
+from .modeling_xlm_roberta import XLMRobertaForSequenceClassification as FlashXLMRobertaForSequenceClassification
 import torch

 import click
@@ -137,14 +138,23 @@ def remap_state_dict(state_dict, config: PretrainedConfig):

 @click.command()
 @click.option('--model_name', default='FacebookAI/xlm-roberta-base', help='model name')
+@click.option('--revision', default='main', help='revision')
+@click.option('--task', default='masked_lm', help='task')
 @click.option('--output', default='converted_roberta_weights.bin', help='model name')
-def main(model_name, output):
-    roberta_model = XLMRobertaForMaskedLM.from_pretrained(model_name)
+def main(model_name, revision, task, output):
+
+    if task == 'masked_lm':
+        roberta_model = XLMRobertaForMaskedLM.from_pretrained(model_name, revision=revision)
+    elif task == 'sequence_classification':
+        roberta_model = XLMRobertaForSequenceClassification.from_pretrained(model_name, revision=revision, num_labels=1)
     config = BertConfig.from_dict(roberta_model.config.to_dict())
     state_dict = roberta_model.state_dict()
     new_state_dict = remap_state_dict(state_dict, config)
-
-    flash_model = BertModel(config)
+
+    if task == 'masked_lm':
+        flash_model = FlashXLMRobertaForMaskedLM(config)
+    elif task == 'sequence_classification':
+        flash_model = FlashXLMRobertaForSequenceClassification(config)

     for k, v in flash_model.state_dict().items():
         if k not in new_state_dict:
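
The converter is a click CLI; because it uses relative imports it has to be run as a module inside its package. A hedged sketch of what the new sequence-classification path does (the package name `xlm_roberta_flash` in the comment is an assumption, not taken from this diff):

# python -m xlm_roberta_flash.convert_roberta_weights_to_flash \
#     --model_name FacebookAI/xlm-roberta-base \
#     --task sequence_classification \
#     --output converted_roberta_weights.bin
#
# which boils down to roughly:
from transformers import XLMRobertaForSequenceClassification

roberta_model = XLMRobertaForSequenceClassification.from_pretrained(
    "FacebookAI/xlm-roberta-base", revision="main", num_labels=1
)
state_dict = roberta_model.state_dict()
# ...then remapped with remap_state_dict() and loaded into
# FlashXLMRobertaForSequenceClassification, as in main() above.
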
mha.py CHANGED
@@ -10,8 +10,6 @@ import torch
 import torch.nn as nn
 from einops import rearrange, repeat

-from flash_attn.utils.distributed import get_dim_for_local_rank
-
 try:
     from flash_attn import (
         flash_attn_kvpacked_func,
modeling_lora.py ADDED
@@ -0,0 +1,325 @@
+import math
+import os
+from functools import partial
+from typing import Iterator, Optional, Tuple, Union
+
+import torch
+import torch.nn.utils.parametrize as parametrize
+from torch import nn
+from torch.nn import Parameter
+from transformers import PretrainedConfig
+
+from .modeling_xlm_roberta import XLMRobertaModel, XLMRobertaPreTrainedModel, XLMRobertaFlashConfig
+
+
+def initialized_weights(
+    shape: Tuple[int], num_adaptions: int, init: str = "kaiming"
+) -> torch.Tensor:
+    weight_data = []
+    for _ in range(num_adaptions):
+        new_adaption = torch.zeros(shape)
+        if init == "kaiming":
+            nn.init.kaiming_uniform_(new_adaption, a=math.sqrt(5))
+        elif init == "normal":
+            nn.init.normal_(new_adaption)
+        else:
+            raise NotImplementedError
+        weight_data.append(new_adaption)
+    return torch.stack(weight_data, dim=0)
+
+
+class LoRAParametrization(nn.Module):
+    """
+    This LoRA implementation was inspired by https://github.com/cccntu/minLoRA
+    The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy
+    Permission is hereby granted, free of charge, to any person obtaining a copy of this software
+    and associated documentation files (the "Software"), to deal in the Software without restriction,
+    including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
+    and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
+    subject to the following conditions:
+    The above copyright notice and this permission notice shall be included in all copies or substantial
+    portions of the Software.
+    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
+    LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+    IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+    SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+    """
+    def __init__(
+        self,
+        fan_in: int,
+        fan_out: int,
+        layer_type: str = "linear",
+        num_adaptions: int = 1,
+        rank: int = 4,
+        lora_dropout_p: float = 0.0,
+        lora_alpha: float = 1,
+    ):
+        super().__init__()
+        # if weight is stored as (fan_out, fan_in), the memory layout of A & B follows (W + BA)x
+        # otherwise, it's x(W + AB). This allows us to tie the weights between linear layers and embeddings
+        fan_in_fan_out = layer_type == "embedding"
+        self.swap = (lambda x: (x[1], x[0])) if fan_in_fan_out else (lambda x: x)
+
+        if layer_type == "linear":
+            self.lora_A = nn.Parameter(
+                initialized_weights((rank, fan_in), num_adaptions, init="kaiming")
+            )
+            self.lora_B = nn.Parameter(torch.zeros((num_adaptions, fan_out, rank)))
+        elif layer_type == "embedding":
+            self.lora_A = nn.Parameter(torch.zeros((num_adaptions, fan_in, rank)))
+            self.lora_B = nn.Parameter(
+                initialized_weights(
+                    (rank, fan_out), num_adaptions=num_adaptions, init="normal"
+                )
+            )
+        else:
+            raise NotImplementedError
+
+        self.lora_alpha, self.rank = lora_alpha, rank
+        self.scaling = lora_alpha / rank
+        self.lora_dropout = (
+            nn.Dropout(p=lora_dropout_p) if lora_dropout_p > 0 else lambda x: x
+        )
+        self.dropout_fn = self._dropout if lora_dropout_p > 0 else lambda x: x
+        self.register_buffer(
+            "lora_dropout_mask",
+            torch.ones(self.swap((1, fan_in)), dtype=self.lora_A.dtype),
+            persistent=False,
+        )
+        self.forward_fn = lambda x: x
+        self.current_task = None
+
+    def _dropout(self, A):
+        # to mimic the original implementation: A @ dropout(x), we do (A * dropout(ones)) @ x
+        return A * self.lora_dropout(self.lora_dropout_mask)
+
+    def lora_forward(self, X):
+        assert self.current_task is not None
+        return (
+            X
+            + torch.matmul(
+                *self.swap(
+                    (
+                        self.lora_B[self.current_task],
+                        self.dropout_fn(self.lora_A[self.current_task]),
+                    )
+                )
+            ).view(X.shape)
+            * self.scaling
+        )
+
+    def forward(self, X):
+        return self.forward_fn(X)
+
+    @property
+    def current_task(self):
+        return self._current_task
+
+    @current_task.setter
+    def current_task(self, task: Union[None, int]):
+        self._current_task = task
+        if task is None:
+            self.forward_fn = lambda x: x
+        else:
+            self.forward_fn = self.lora_forward
+
+    @classmethod
+    def from_linear(
+        cls,
+        layer: nn.Module,
+        num_adaptions: int = 1,
+        rank: int = 4,
+        lora_dropout_p: float = 0.0,
+        lora_alpha: int = 1,
+    ):
+        assert isinstance(layer, nn.Linear)
+        fan_out, fan_in = layer.weight.shape
+        return cls(
+            fan_in,
+            fan_out,
+            num_adaptions=num_adaptions,
+            layer_type="linear",
+            rank=rank,
+            lora_dropout_p=lora_dropout_p,
+            lora_alpha=lora_alpha,
+        )
+
+    @classmethod
+    def from_embedding(
+        cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1
+    ):
+        assert isinstance(layer, nn.Embedding)
+        fan_in, fan_out = layer.weight.shape
+        return cls(
+            fan_in,
+            fan_out,
+            num_adaptions=num_adaptions,
+            layer_type="embedding",
+            rank=rank,
+            lora_dropout_p=lora_dropout_p,
+            lora_alpha=lora_alpha,
+        )
+
+    @classmethod
+    def add_to_layer(
+        cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1
+    ):
+        if isinstance(layer, nn.Linear):
+            parametrize.register_parametrization(
+                layer,
+                "weight",
+                cls.from_linear(
+                    layer,
+                    num_adaptions=num_adaptions,
+                    rank=rank,
+                    lora_dropout_p=lora_dropout_p,
+                    lora_alpha=lora_alpha,
+                ),
+            )
+        elif isinstance(layer, nn.Embedding):
+            parametrize.register_parametrization(
+                layer,
+                "weight",
+                cls.from_embedding(
+                    layer,
+                    num_adaptions=num_adaptions,
+                    rank=rank,
+                    lora_dropout_p=lora_dropout_p,
+                    lora_alpha=lora_alpha,
+                ),
+            )
+
+    @staticmethod
+    def select_task_for_layer(layer: nn.Module, task_idx: Optional[int] = None):
+        if isinstance(layer, LoRAParametrization):
+            layer.current_task = task_idx
+
+    @staticmethod
+    def merge_lora_into_layer(layer: nn.Module):
+        if hasattr(layer, "parametrizations"):
+            for attr_name in layer.parametrizations.keys():
+                parametrize.remove_parametrizations(layer, attr_name, leave_parametrized=True)
+
+
+class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
+    def __init__(self, config: XLMRobertaFlashConfig, roberta: Optional[XLMRobertaModel] = None, add_pooling_layer=True):
+        super().__init__(config)
+
+        if roberta is None:
+            self.roberta = XLMRobertaModel(config, add_pooling_layer=add_pooling_layer)
+        else:
+            self.roberta = roberta
+
+        self._is_merged = False
+        self._num_adaptions = config.num_loras
+        self._register_lora(self._num_adaptions)
+
+        self.main_params_trainable = False
+        self._task_idx = None
+        # By default, we select the first LoRA
+        self.current_task = 0
+
+    @property
+    def main_params_trainable(self):
+        return self._main_params_trainable
+
+    @main_params_trainable.setter
+    def main_params_trainable(self, val: bool):
+        """Whether the main parameters (i.e. those that are not LoRA) should be trainable.
+        This method sets the `requires_grad_` attribute of the main weights
+        and controls which parameters are returned in `self.parameters()`.
+        :param val: Whether or not to make the parameters trainable.
+        :return: None
+        """
+        self._main_params_trainable = val
+        for name, param in super().named_parameters():
+            if "lora" not in name:
+                param.requires_grad_(val)
+
+    def merge_lora(self):
+        """Merges currently selected LoRA into main weights."""
+        if self._is_merged:
+            raise Exception('LoRA has already been merged, cannot merge again')
+        self._is_merged = True
+        self.apply(LoRAParametrization.merge_lora_into_layer)
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+        *model_args,
+        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
+        cache_dir: Optional[Union[str, os.PathLike]] = None,
+        ignore_mismatched_sizes: bool = False,
+        force_download: bool = False,
+        local_files_only: bool = False,
+        token: Optional[Union[str, bool]] = None,
+        revision: str = "main",
+        use_safetensors: bool = None,
+        **kwargs,
+    ):
+        config = XLMRobertaFlashConfig.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+        if config.load_trained_adapters:
+            return super().from_pretrained(
+                pretrained_model_name_or_path,
+                *model_args,
+                **kwargs
+            )
+        else:
+            roberta = XLMRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+            return cls(config, roberta=roberta)
+
+    def _register_lora(self, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
+        self.apply(
+            partial(
+                LoRAParametrization.add_to_layer,
+                num_adaptions=num_adaptions,
+                rank=rank,
+                lora_dropout_p=lora_dropout_p,
+                lora_alpha=lora_alpha,
+            )
+        )
+
+    @property
+    def current_task(self):
+        """ Which LoRA is currently selected
+        :return: Integer or None (when LoRA is disabled)
+        """
+        return self._task_idx
+
+    @current_task.setter
+    def current_task(self, task_idx: Union[None, int]):
+        """Set the LoRA that is to be used.
+        The LoRA is specified by `task_idx`, which may be an integer >= 0,
+        indexing the available LoRAs. If it is None, no LoRA is used.
+        :param task_idx: Which LoRA to use
+        :return:
+        """
+        if self._is_merged:
+            raise Exception('LoRA has been merged, cannot select new task')
+        assert task_idx is None or 0 <= task_idx < self._num_adaptions
+        if self._task_idx != task_idx:
+            # In this case, we need to update the LoRAs everywhere
+            self._task_idx = task_idx
+            self.apply(
+                partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
+            )
+
+    def forward(self, *args, current_task: Union[None, int] = -1, **kwargs):
+        if current_task is None or current_task >= 0:
+            self.current_task = current_task
+        return self.roberta(*args, **kwargs)
+
+    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
+        for _, param in self.named_parameters(recurse=recurse):
+            yield param
+
+    def named_parameters(
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
+    ) -> Iterator[Tuple[str, Parameter]]:
+        for name, param in super().named_parameters(
+            prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate
+        ):
+            if "lora" in name or self.main_params_trainable:
+                yield name, param
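
A short usage sketch of the adapter API this file adds. The checkpoint id is a placeholder and the import path assumes the repo files are importable locally; the number of available adapters comes from `config.num_loras`:

import torch
from modeling_lora import XLMRobertaLoRA  # assumed local import

# "org/xlm-roberta-flash" is a placeholder checkpoint id.
model = XLMRobertaLoRA.from_pretrained("org/xlm-roberta-flash")

model.current_task = 0                       # route through the first adapter (the default)
out = model(input_ids=torch.tensor([[0, 581, 2]]))

model.current_task = None                    # bypass all adapters, use only the base weights
model.current_task = 0                       # (indices >= config.num_loras raise an assertion)
model.merge_lora()                           # bake the selected adapter into the main weights

By default only the LoRA parameters are trainable; setting `model.main_params_trainable = True` also unfreezes the backbone and makes `parameters()` yield the main weights again.
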
modeling_xlm_roberta.py CHANGED
@@ -1,6 +1,5 @@
 # This implementation was adopted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/bert.py
 # Commit id: abbc1311731867310635f9edc2a9ec18317c8c48
-
 # Copyright (c) 2022, Tri Dao.
 # This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
 # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
@@ -8,20 +7,23 @@

 # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py

+import importlib.util
 import logging
 import re
 from collections import OrderedDict
 from collections.abc import Sequence
 from functools import partial
+import numpy as np

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from einops import rearrange
 from transformers import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
-from transformers.modeling_outputs import MaskedLMOutput
+from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
 from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead

 from transformers.models.bert.modeling_bert import (
@@ -29,7 +31,7 @@ from transformers.models.bert.modeling_bert import (
     BertForPreTrainingOutput,
 )

-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

 from .xlm_padding import (
     index_first_axis,
@@ -61,12 +63,30 @@ try:
 except ImportError:
     CrossEntropyLoss = None

+try:
+    from tqdm.autonotebook import trange
+except ImportError:
+    trange = None
+

 logger = logging.getLogger(__name__)


+def get_use_flash_attn(config: XLMRobertaFlashConfig):
+    if not getattr(config, "use_flash_attn", False):
+        return False
+    if not torch.cuda.is_available():
+        return False
+    if importlib.util.find_spec("flash_attn") is None:
+        logger.warning(
+            'flash_attn is not installed. Using PyTorch native attention implementation.'
+        )
+        return False
+    return True
+
+
 def create_mixer_cls(config, cross_attn=False, return_residual=False):
-    use_flash_attn = getattr(config, "use_flash_attn", False)
+    use_flash_attn = get_use_flash_attn(config)
     fused_bias_fc = getattr(config, "fused_bias_fc", False)
     rotary_kwargs = {}
     if config.position_embedding_type == "rotary":
@@ -169,7 +189,7 @@ def _init_weights(module, initializer_range=0.02):
 class XLMRobertaEncoder(nn.Module):
     def __init__(self, config: XLMRobertaFlashConfig):
         super().__init__()
-        self.use_flash_attn = getattr(config, "use_flash_attn", False)
+        self.use_flash_attn = get_use_flash_attn(config)
         self.layers = nn.ModuleList(
             [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
         )
@@ -376,6 +396,17 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
         if isinstance(module, XLMRobertaEncoder):
             module.gradient_checkpointing = value

+    @classmethod
+    def from_pretrained(
+        cls,
+        *args,
+        **kwargs,
+    ):
+        if not 'torch_dtype' in kwargs:
+            kwargs['torch_dtype'] = 'auto'
+        return super().from_pretrained(*args, **kwargs)
+
+

 class XLMRobertaModel(XLMRobertaPreTrainedModel):
     def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
@@ -409,6 +440,169 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):

         self.apply(partial(_init_weights, initializer_range=config.initializer_range))

+
+    @torch.inference_mode()
+    def encode(
+        self: 'XLMRobertaModel',
+        sentences: Union[str, List[str]],
+        batch_size: int = 32,
+        show_progress_bar: Optional[bool] = None,
+        output_value: str = 'sentence_embedding',
+        convert_to_numpy: bool = True,
+        convert_to_tensor: bool = False,
+        device: Optional[torch.device] = None,
+        normalize_embeddings: bool = False,
+        **tokenizer_kwargs,
+    ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
+        """
+        Computes sentence embeddings
+        Args:
+            sentences(`str` or `List[str]`):
+                Sentence or sentences to be encoded
+            batch_size(`int`, *optional*, defaults to 32):
+                Batch size for the computation
+            show_progress_bar(`bool`, *optional*, defaults to None):
+                Show a progress bar when encoding sentences.
+                If set to None, progress bar is only shown when
+                `logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
+            output_value(`str`, *optional*, defaults to 'sentence_embedding'):
+                Default sentence_embedding, to get sentence embeddings.
+                Can be set to token_embeddings to get wordpiece token embeddings.
+                Set to None, to get all output values
+            convert_to_numpy(`bool`, *optional*, defaults to True):
+                If true, the output is a list of numpy vectors.
+                Else, it is a list of pytorch tensors.
+            convert_to_tensor(`bool`, *optional*, defaults to False):
+                If true, you get one large tensor as return.
+                Overwrites any setting from convert_to_numpy
+            device(`torch.device`, *optional*, defaults to None):
+                Which torch.device to use for the computation
+            normalize_embeddings(`bool`, *optional*, defaults to False):
+                If set to true, returned vectors will have length 1. In that case, the
+                faster dot-product (util.dot_score) instead of cosine similarity can
+                be used.
+            tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
+                Keyword arguments for the tokenizer
+        Returns:
+            By default, a list of tensors is returned.
+            If convert_to_tensor, a stacked tensor is returned.
+            If convert_to_numpy, a numpy matrix is returned.
+        """
+        from transformers import AutoTokenizer
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.name_or_path, trust_remote_code=True
+        )
+
+        is_training = self.training
+        self.eval()
+
+        if show_progress_bar is None:
+            show_progress_bar = (
+                logger.getEffectiveLevel() == logging.INFO
+                or logger.getEffectiveLevel() == logging.DEBUG
+            )
+
+        if convert_to_tensor:
+            convert_to_numpy = False
+
+        if output_value != 'sentence_embedding':
+            convert_to_tensor = False
+            convert_to_numpy = False
+
+        input_was_string = False
+        if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
+            sentences = [sentences]
+            input_was_string = True
+
+        if device is not None:
+            self.to(device)
+
+        permutation = np.argsort([-len(i) for i in sentences])
+        inverse_permutation = np.argsort(permutation)
+        sentences = [sentences[idx] for idx in permutation]
+
+        tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
+        tokenizer_kwargs['max_length'] = tokenizer_kwargs.get(
+            'max_length', self.tokenizer.init_kwargs.get('model_max_length', 8192)
+        )
+        tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
+
+        all_embeddings = []
+
+        if trange is not None:
+            range_iter = trange(
+                0,
+                len(sentences),
+                batch_size,
+                desc="Encoding",
+                disable=not show_progress_bar,
+            )
+        else:
+            range_iter = range(0, len(sentences), batch_size)
+
+        for i in range_iter:
+            encoded_input = self.tokenizer(
+                sentences[i : i + batch_size],
+                return_tensors='pt',
+                **tokenizer_kwargs,
+            ).to(self.device)
+            token_embs = self.forward(**encoded_input)[0]
+
+            # Accumulate in fp32 to avoid overflow
+            token_embs = token_embs.float()
+
+            if output_value == 'token_embeddings':
+                raise NotImplementedError
+            elif output_value is None:
+                raise NotImplementedError
+            else:
+                if self.config.emb_pooler == 'cls':
+                    embeddings = self.cls_pooling(
+                        token_embs, encoded_input['attention_mask']
+                    )
+                else:
+                    embeddings = self.mean_pooling(
+                        token_embs, encoded_input['attention_mask']
+                    )
+
+                if normalize_embeddings:
+                    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+
+                if convert_to_numpy:
+                    embeddings = embeddings.cpu()
+            all_embeddings.extend(embeddings)
+
+        all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
+
+        if convert_to_tensor:
+            all_embeddings = torch.stack(all_embeddings)
+        elif convert_to_numpy:
+            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
+
+        if input_was_string:
+            all_embeddings = all_embeddings[0]
+
+        self.train(is_training)
+        return all_embeddings
+
+    def mean_pooling(
+        self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
+    ):
+        input_mask_expanded = (
+            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+        )
+        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
+            input_mask_expanded.sum(1), min=1e-9
+        )
+
+    def cls_pooling(
+        self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
+    ):
+        return token_embeddings[:, 0]
+
+
     def forward(
         self,
         input_ids,
@@ -946,3 +1140,117 @@ def inv_remap_state_dict(state_dict, config: PretrainedConfig):
     )

     return state_dict
+
+
+# Copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead with Roberta->XLMRoberta
+class XLMRobertaClassificationHead(nn.Module):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        classifier_dropout = (
+            config.classifier_dropout
+            if config.classifier_dropout is not None
+            else config.hidden_dropout_prob
+        )
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+    def forward(self, features, **kwargs):
+        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
+        x = self.dropout(x)
+        x = self.dense(x)
+        x = torch.tanh(x)
+        x = self.dropout(x)
+        x = self.out_proj(x)
+        return x
+
+
+# Copied from transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA
+class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.config = config
+
+        self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
+        self.classifier = XLMRobertaClassificationHead(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        outputs = self.roberta(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = outputs[0]
+        logits = self.classifier(sequence_output)
+
+        loss = None
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (
+                    labels.dtype == torch.long or labels.dtype == torch.int
+                ):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
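
With the new `encode()` method the backbone can produce sentence embeddings directly. A hedged usage sketch; the repo id is a placeholder, and `config.emb_pooler` decides between mean pooling (default) and CLS pooling:

from transformers import AutoModel

# Placeholder repo id; trust_remote_code loads the modeling code shown in this diff.
model = AutoModel.from_pretrained("org/xlm-roberta-flash", trust_remote_code=True)

embeddings = model.encode(
    ["How is the weather today?", "Wie ist das Wetter heute?"],
    batch_size=2,
    convert_to_numpy=True,       # default: returns a numpy matrix
    normalize_embeddings=True,   # unit-length vectors, so dot product equals cosine similarity
)
print(embeddings.shape)          # (2, hidden_size)

Note that the overridden `from_pretrained` on `XLMRobertaPreTrainedModel` now defaults `torch_dtype` to 'auto', so checkpoints load in the dtype recorded in their config unless the caller passes an explicit `torch_dtype`.
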
modeling_xlm_roberta_for_glue.py ADDED
@@ -0,0 +1,109 @@
+from typing import Optional, Union, Tuple
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
+from transformers.modeling_outputs import SequenceClassifierOutput, QuestionAnsweringModelOutput, TokenClassifierOutput
+
+from .modeling_xlm_roberta import XLMRobertaPreTrainedModel, XLMRobertaModel
+from .configuration_xlm_roberta import XLMRobertaFlashConfig
+
+
+class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
+    def __init__(self, config: XLMRobertaFlashConfig):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.config = config
+
+        self.roberta = XLMRobertaModel(config)
+        classifier_dropout = (
+            config.classifier_dropout
+            if config.classifier_dropout is not None
+            else config.hidden_dropout_prob
+        )
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        assert head_mask is None
+        assert inputs_embeds is None
+        assert output_attentions is None
+        assert output_hidden_states is None
+        assert return_dict
+        outputs = self.roberta(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        pooled_output = outputs[1]
+
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (
+                    labels.dtype == torch.long or labels.dtype == torch.int
+                ):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
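
Unlike the head added to modeling_xlm_roberta.py, which reads the first token of the sequence output with `add_pooling_layer=False`, this GLUE variant classifies the backbone's pooler output (`outputs[1]`) and its forward asserts `return_dict`. A minimal sketch under those assumptions (small label count chosen purely for illustration; the default config still builds a full-size backbone):

import torch
# Assumes the repo files are importable as local modules.
from configuration_xlm_roberta import XLMRobertaFlashConfig
from modeling_xlm_roberta_for_glue import XLMRobertaForSequenceClassification

config = XLMRobertaFlashConfig(num_labels=2, use_flash_attn=False)
model = XLMRobertaForSequenceClassification(config)

out = model(
    input_ids=torch.tensor([[0, 581, 2]]),
    attention_mask=torch.tensor([[1, 1, 1]]),
    labels=torch.tensor([1]),
    return_dict=True,                 # required: forward() asserts return_dict
)
print(out.loss, out.logits.shape)     # CrossEntropyLoss for 2 labels, logits of shape (1, 2)
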