Files changed (8)
  1. block.py +1 -1
  2. configuration_xlm_roberta.py +2 -0
  3. embedding.py +4 -3
  4. mha.py +11 -3
  5. mlp.py +4 -3
  6. modeling_lora.py +61 -79
  7. modeling_xlm_roberta.py +18 -20
  8. rotary.py +44 -21
block.py CHANGED
@@ -233,7 +233,7 @@ class Block(nn.Module):
                     is_rms_norm=isinstance(self.norm1, RMSNorm),
                 )
             if not isinstance(self.mlp, nn.Identity):
-                mlp_out = self.mlp(hidden_states)
+                mlp_out = self.mlp(hidden_states, task_type=mixer_kwargs.get('task_type'))
                 if self.return_residual:  # mlp out is actually a pair here
                     mlp_out, hidden_states = mlp_out
                 if not self.fused_dropout_add_ln:
configuration_xlm_roberta.py CHANGED
@@ -23,6 +23,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         use_cache=True,
         classifier_dropout=None,
         lora_adaptations=None,
+        lora_prompts=None,
         lora_rank=4,
         lora_dropout_p=0.0,
         lora_alpha=1,
@@ -55,6 +56,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         self.classifier_dropout = classifier_dropout
         self.load_trained_adapters = load_trained_adapters
         self.lora_adaptations = lora_adaptations
+        self.lora_prompts = lora_prompts
         self.lora_rank = lora_rank
         self.lora_dropout_p = lora_dropout_p
         self.lora_alpha = lora_alpha
embedding.py CHANGED
@@ -40,14 +40,15 @@ class XLMRobertaEmbeddings(nn.Module):
         if self.type_vocab_size > 0:
             self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
 
-    def forward(self, input_ids, position_ids=None, token_type_ids=None):
+    def forward(self, input_ids, position_ids=None, token_type_ids=None, task_type=None):
         """
         input_ids: (batch, seqlen)
         position_ids: (batch, seqlen)
         token_type_ids: (batch, seqlen)
         """
         batch_size, seqlen = input_ids.shape
-        embeddings = self.word_embeddings(input_ids)
+        lora_kwargs = {'task_type': task_type} if task_type is not None else {}
+        embeddings = self.word_embeddings(input_ids, **lora_kwargs)
         if self.max_position_embeddings > 0:
             if position_ids is None:
                 position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=self.word_embeddings.padding_idx).to(input_ids.device)
@@ -57,6 +58,6 @@ class XLMRobertaEmbeddings(nn.Module):
         if self.type_vocab_size > 0:
             if token_type_ids is None:
                 token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
-            token_type_embeddings = self.token_type_embeddings(token_type_ids)
+            token_type_embeddings = self.token_type_embeddings(token_type_ids, **lora_kwargs)
             embeddings = embeddings + token_type_embeddings
         return embeddings
mha.py CHANGED
@@ -450,6 +450,7 @@ class MHA(nn.Module):
 
         if fused_bias_fc and FusedDense is None:
             raise ImportError("fused_dense is not installed")
+
         linear_cls = nn.Linear if not fused_bias_fc else FusedDense
         linear_resid_cls = (
             LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
@@ -589,6 +590,7 @@ class MHA(nn.Module):
         max_seqlen=None,
         mixer_subset=None,
         inference_params=None,
+        task_type=None,
         **kwargs,
     ):
         """
@@ -643,10 +645,14 @@ class MHA(nn.Module):
         batch, seqlen = x.shape[:2]
         if not self.cross_attn and self.num_heads_kv == self.num_heads:
             assert x_kv is None and mixer_subset is None
+            lora_kwargs = {'task_type': task_type} if task_type is not None else {}
             if not self.return_residual:
-                qkv = self.Wqkv(x)
+                qkv = self.Wqkv(x, **lora_kwargs)
             else:
-                qkv, x = self.Wqkv(x)
+                if lora_kwargs:
+                    lora_kwargs['residual'] = True
+                qkv, x = self.Wqkv(x, **lora_kwargs)
+
             if self.dwconv:
                 qkv = rearrange(
                     self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
@@ -731,5 +737,7 @@ class MHA(nn.Module):
                 context = self._update_kvcache_attention(q, kv, inference_params)
             else:
                 context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
-        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
+
+        lora_kwargs.pop('residual', None)
+        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"), **lora_kwargs)
         return out if not self.return_residual else (out, x)
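
Aside (not part of the diff): the `lora_kwargs = {...} if task_type is not None else {}` idiom that mha.py, mlp.py, embedding.py, and the pooler now share only forwards the extra argument when a task is selected, so layers whose `forward` was never patched keep their stock signature. A toy sketch:

import torch
from torch import nn

def project(layer: nn.Linear, x: torch.Tensor, task_type=None):
    # Only pass task_type through when it is set; a plain nn.Linear would
    # reject the unknown keyword, while a LoRA-patched layer expects it.
    lora_kwargs = {'task_type': task_type} if task_type is not None else {}
    return layer(x, **lora_kwargs)

y = project(nn.Linear(4, 4), torch.randn(2, 4))  # works without any LoRA patching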
mlp.py CHANGED
@@ -47,10 +47,11 @@ class Mlp(nn.Module):
         self.activation = activation
         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
 
-    def forward(self, x):
-        y = self.fc1(x)
+    def forward(self, x, task_type=None):
+        lora_kwargs = {'task_type': task_type} if task_type is not None else {}
+        y = self.fc1(x, **lora_kwargs)
         y = self.activation(y)
-        y = self.fc2(y)
+        y = self.fc2(y, **lora_kwargs)
         return y if not self.return_residual else (y, x)
 
 
modeling_lora.py CHANGED
@@ -9,14 +9,12 @@ import torch
 import torch.nn.utils.parametrize as parametrize
 from torch import nn
 from torch.nn import Parameter
+from torch.nn import functional as F
 from transformers import PretrainedConfig
 
 from .modeling_xlm_roberta import XLMRobertaFlashConfig, XLMRobertaModel, XLMRobertaPreTrainedModel
 
 
-LORA_NO_UPDATE = '__lora_no_update__'
-
-
 def initialized_weights(
     shape: Tuple[int], num_adaptations: int, init: str = "kaiming"
 ) -> torch.Tensor:
@@ -91,22 +89,19 @@ class LoRAParametrization(nn.Module):
             torch.ones(self.swap((1, fan_in)), dtype=self.lora_A.dtype),
             persistent=False,
         )
-        self.forward_fn = lambda x: x
-        self.current_task = None
 
     def _dropout(self, A):
         # to mimic the original implementation: A @ dropout(x), we do (A * dropout(ones)) @ x
         return A * self.lora_dropout(self.lora_dropout_mask)
 
-    def lora_forward(self, X):
-        assert self.current_task is not None
+    def lora_forward(self, X, current_task):
         return (
             X
             + torch.matmul(
                 *self.swap(
                     (
-                        self.lora_B[self.current_task],
-                        self.dropout_fn(self.lora_A[self.current_task]),
+                        self.lora_B[current_task],
+                        self.dropout_fn(self.lora_A[current_task]),
                     )
                 )
             ).view(X.shape)
@@ -114,19 +109,7 @@ class LoRAParametrization(nn.Module):
         )
 
     def forward(self, X):
-        return self.forward_fn(X)
-
-    @property
-    def current_task(self):
-        return self._current_task
-
-    @current_task.setter
-    def current_task(self, task: Union[None, int]):
-        self._current_task = task
-        if task is None:
-            self.forward_fn = lambda x: x
-        else:
-            self.forward_fn = self.lora_forward
+        return X
 
     @classmethod
     def from_linear(
@@ -178,6 +161,7 @@ class LoRAParametrization(nn.Module):
         rank: int,
         dropout_p: float,
         alpha: float,
+        adaptation_map: dict,
     ):
         if isinstance(layer, nn.Linear):
             parametrize.register_parametrization(
@@ -191,6 +175,22 @@ class LoRAParametrization(nn.Module):
                     alpha=alpha,
                 ),
             )
+
+            def new_forward(self, input, task_type, residual=False):
+                task_idx = adaptation_map[task_type] if task_type else None
+                if task_idx is not None:
+                    weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_idx)
+                else:
+                    weights = self.weight
+
+                out = F.linear(input, weights, self.bias)
+
+                if residual:
+                    return out, input
+                return out
+
+            layer.forward = new_forward.__get__(layer, layer.__class__)
+
         elif isinstance(layer, nn.Embedding):
             parametrize.register_parametrization(
                 layer,
@@ -204,10 +204,20 @@ class LoRAParametrization(nn.Module):
                 ),
             )
 
-    @staticmethod
-    def select_task_for_layer(layer: nn.Module, task_idx: Optional[int] = None):
-        if isinstance(layer, LoRAParametrization):
-            layer.current_task = task_idx
+            def new_forward(self, input, task_type):
+                task_idx = adaptation_map[task_type] if task_type else None
+                if task_idx is not None:
+                    weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_idx)
+                else:
+                    weights = self.weight
+
+                out = F.embedding(
+                    input, weights, self.padding_idx, self.max_norm,
+                    self.norm_type, self.scale_grad_by_freq, self.sparse)
+
+                return out
+
+            layer.forward = new_forward.__get__(layer, layer.__class__)
 
 
 class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
@@ -231,6 +241,16 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
             raise ValueError(
                 f'`lora_adaptations` must be a list and contain at least one element'
             )
+        self._lora_prompts = config.lora_prompts
+        if (
+            not isinstance(self._lora_prompts, dict)
+            or len(self._lora_prompts) != len(self._lora_adaptations)
+            or not all([v in self._lora_adaptations for v in self._lora_prompts.keys()])
+        ):
+            raise ValueError(
+                f'`lora_prompts` must be a dict and contain the same number of elements '
+                f'as `lora_adaptations` with all keys in `lora_prompts` present in `lora_adaptations`.'
+            )
         self._adaptation_map = {
             name: idx for idx, name in enumerate(self._lora_adaptations)
         }
@@ -244,9 +264,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
             alpha=self._alpha,
         )
         self.main_params_trainable = config.lora_main_params_trainable
-        self._task_idx = None
-        # By default, disable LoRA until it's specified which adapter/task to use
-        self.current_task = None
+
 
     @property
     def main_params_trainable(self):
@@ -300,42 +318,11 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
                 rank=rank,
                 dropout_p=dropout_p,
                 alpha=alpha,
+                adaptation_map=self._adaptation_map,
             )
         )
 
-    @property
-    def current_task(self):
-        """Which LoRA is currently selected
-        :return: Integer or None (when LoRA is disabled)
-        """
-        return self._task_idx
-
-    @current_task.setter
-    def current_task(self, task_name: Union[None, str]):
-        """Set the LoRA that is to be used.
-        The LoRA is specified by `task_idx`, which may be an integer >= 0,
-        indexing the available LoRAs. If it is None, no LoRA is used.
-        :param task_name: Which LoRA to use
-        :return:
-        """
-        if task_name and task_name not in self._lora_adaptations:
-            raise ValueError(
-                f"Unsupported task '{task_name}'. "
-                f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
-                f"Alternatively, set `task` to `None` if you want to disable LoRA."
-            )
-        task_idx = self._adaptation_map[task_name] if task_name else None
-        if self._task_idx != task_idx:
-            # In this case, we need to update the LoRAs everywhere
-            self._task_idx = task_idx
-            self.apply(
-                partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
-            )
-
-    def forward(self, *args, task: Union[str, None] = LORA_NO_UPDATE, **kwargs):
-        if task != LORA_NO_UPDATE:
-            self.current_task = task
-
+    def forward(self, *args, **kwargs):
         return self.roberta(*args, **kwargs)
 
     def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
@@ -355,27 +342,22 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
     def encode(
         self,
         *args,
-        task: Union[str, None] = LORA_NO_UPDATE,
+        task_type: Optional[str] = None,
         **kwargs,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
         Computes sentence embeddings
 
-        task(`str`, *optional*, defaults to `LORA_NO_UPDATE`):
-            Specifies the task for which the encoding is intended. This parameter controls the
-            use of specialized LoRA adapters that are tuned for specific tasks. If `task` is set
-            to `LORA_NO_UPDATE`, there will be no update to the current task, retaining the
-            existing adapter configuration. If `task` is explicitly set to `None`, all LoRA
-            adapters are disabled, and the model reverts to its original, general-purpose weights.
-            If `task` is set to a specific LoRA adaptation, that adaptation is activated.
+        task_type(`str`, *optional*, defaults to `None`):
+            Specifies the task for which the encoding is intended. If `task_type` is not provide,
+            all LoRA adapters are disabled, and the model reverts to its original,
+            general-purpose weights.
         """
-        if task != LORA_NO_UPDATE:
-            if not task:
-                warnings.warn(
-                    f"Task-specific embeddings are disabled. To enable, specify the `task` "
-                    f"argument with one of the supported tasks: {', '.join(self.config.lora_adaptations)}",
-                    category=UserWarning,
-                )
-            self.current_task = task
+        if task_type and task_type not in self._lora_adaptations:
+            raise ValueError(
+                f"Unsupported task '{task_type}'. "
+                f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
+                f"Alternatively, don't pass the `task_type` argument to disable LoRA."
+            )
 
-        return self.roberta.encode(*args, **kwargs)
+        return self.roberta.encode(*args, task_type=task_type, **kwargs)
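
Aside (a sketch under assumptions, not the PR's classes): the rewritten modeling_lora.py combines two mechanisms, computing the effective weight as W + B[task] @ A[task] via `lora_forward`, and binding a task-aware `new_forward` to each layer instance with `__get__` so the adapter is chosen per call rather than via module-wide state:

import torch
from torch import nn
from torch.nn import functional as F

num_tasks, rank, fan_in, fan_out = 2, 4, 8, 8
layer = nn.Linear(fan_in, fan_out)
lora_A = torch.zeros(num_tasks, rank, fan_in)   # zero-init: adapters start as a no-op
lora_B = torch.randn(num_tasks, fan_out, rank)
adaptation_map = {'task_a': 0, 'task_b': 1}     # hypothetical task names

def new_forward(self, input, task_type=None, residual=False):
    # Per-call adapter selection: no global "current task" state to mutate.
    task_idx = adaptation_map[task_type] if task_type else None
    weight = self.weight if task_idx is None else self.weight + lora_B[task_idx] @ lora_A[task_idx]
    out = F.linear(input, weight, self.bias)
    return (out, input) if residual else out

layer.forward = new_forward.__get__(layer, nn.Linear)  # bound method on this instance only
x = torch.randn(3, fan_in)
assert torch.allclose(layer(x, task_type='task_a'), layer(x))  # identical while A is zero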
modeling_xlm_roberta.py CHANGED
@@ -21,7 +21,7 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from einops import rearrange
-from transformers import PretrainedConfig
+from transformers import PretrainedConfig, AutoTokenizer
 from transformers.modeling_utils import PreTrainedModel
 from transformers.modeling_outputs import MaskedLMOutput,SequenceClassifierOutput
 from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead
@@ -204,7 +204,7 @@ class XLMRobertaEncoder(nn.Module):
     def gradient_checkpointing(self, value):
         self._grad_checkpointing = value
 
-    def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
+    def forward(self, hidden_states, key_padding_mask=None, subset_mask=None, task_type=None):
         """If subset_mask is not None, we only want output for the subset of the sequence.
         This means that we only compute the last layer output for these tokens.
         subset_mask: (batch, seqlen), dtype=torch.bool
@@ -215,6 +215,7 @@ class XLMRobertaEncoder(nn.Module):
                 if key_padding_mask is not None
                 else None
             )
+            mixer_kwargs['task_type'] = task_type
             for layer in self.layers:
                 if self._grad_checkpointing:
                     hidden_states = torch.utils.checkpoint.checkpoint(
@@ -232,7 +233,7 @@ class XLMRobertaEncoder(nn.Module):
             hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
                 hidden_states, key_padding_mask
             )
-            mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
+            mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch, "task_type": task_type}
             if subset_mask is None:
                 for layer in self.layers:
                     if self._grad_checkpointing:
@@ -309,11 +310,13 @@ class XLMRobertaPooler(nn.Module):
         self.dense = linear_cls(config.hidden_size, config.hidden_size)
         self.activation = nn.Tanh()
 
-    def forward(self, hidden_states, pool=True):
+    def forward(self, hidden_states, pool=True, task_type=None):
         # We "pool" the model by simply taking the hidden state corresponding
         # to the first token.
+        lora_kwargs = {'task_type': task_type} if task_type is not None else {}
+
         first_token_tensor = hidden_states[:, 0] if pool else hidden_states
-        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.dense(first_token_tensor, **lora_kwargs)
         pooled_output = self.activation(pooled_output)
         return pooled_output
 
@@ -440,7 +443,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
 
         self.apply(partial(_init_weights, initializer_range=config.initializer_range))
-
+        self.tokenizer = AutoTokenizer.from_pretrained(self.name_or_path, trust_remote_code=True)
 
     @torch.inference_mode()
     def encode(
@@ -454,6 +457,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         device: Optional[torch.device] = None,
         normalize_embeddings: bool = False,
         truncate_dim: Optional[int] = None,
+        task_type: Optional[str] = None,
         **tokenizer_kwargs,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
@@ -492,12 +496,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
             If convert_to_tensor, a stacked tensor is returned.
             If convert_to_numpy, a numpy matrix is returned.
         """
-        from transformers import AutoTokenizer
-
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            self.name_or_path, trust_remote_code=True
-        )
-
         is_training = self.training
         self.eval()
 
@@ -544,14 +542,14 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
             )
         else:
            range_iter = range(0, len(sentences), batch_size)
-
+        lora_kwargs = {'task_type': task_type} if task_type is not None else {}
         for i in range_iter:
             encoded_input = self.tokenizer(
                 sentences[i : i + batch_size],
                 return_tensors='pt',
                 **tokenizer_kwargs,
             ).to(self.device)
-            token_embs = self.forward(**encoded_input)[0]
+            token_embs = self.forward(**encoded_input, **lora_kwargs)[0]
 
             # Accumulate in fp32 to avoid overflow
             token_embs = token_embs.float()
@@ -639,7 +637,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
             layer output for these tokens.
         masked_tokens_mask: (batch, seqlen), dtype=torch.bool
         """
-
+        task_type = kwargs.pop('task_type', None)
         if kwargs:
             for key, value in kwargs.items():
                 if value is not None:
@@ -653,7 +651,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         )
 
         hidden_states = self.embeddings(
-            input_ids, position_ids=position_ids, token_type_ids=token_type_ids
+            input_ids, position_ids=position_ids, token_type_ids=token_type_ids, task_type=task_type
         )
         # TD [2022-12:18]: Don't need to force residual in fp32
         # BERT puts embedding LayerNorm before embedding dropout.
@@ -677,12 +675,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
             subset_mask = None
 
         sequence_output = self.encoder(
-            hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask
+            hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask, task_type=task_type
        )
 
         if masked_tokens_mask is None:
             pooled_output = (
-                self.pooler(sequence_output) if self.pooler is not None else None
+                self.pooler(sequence_output, task_type=task_type) if self.pooler is not None else None
             )
         else:
             # TD [2022-03-01]: the indexing here is very tricky.
@@ -696,7 +694,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
             pool_input = sequence_output[first_col_mask[subset_mask]]
             sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
             pooled_output = (
-                self.pooler(pool_input, pool=False) if self.pooler is not None else None
+                self.pooler(pool_input, pool=False, task_type=task_type) if self.pooler is not None else None
             )
 
         if not return_dict:
@@ -1278,4 +1276,4 @@ class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
             logits=logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-        )
+        )
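
Aside (usage sketch, not part of the diff): with the adapter now selected per call, inference looks roughly like the following; the model id and task name are placeholders, and valid names are whatever `config.lora_adaptations` lists for the loaded checkpoint:

from transformers import AutoModel

# Placeholder hub id; trust_remote_code pulls in this modeling code.
model = AutoModel.from_pretrained('org/xlm-roberta-flash-lora', trust_remote_code=True)

# Hypothetical adapter name; it must appear in config.lora_adaptations.
task_embeddings = model.encode(['A sample sentence.'], task_type='query')

# Without task_type, all LoRA adapters stay disabled.
plain_embeddings = model.encode(['A sample sentence.'])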
rotary.py CHANGED
@@ -6,11 +6,13 @@ from typing import Optional, Tuple, Union
 
 import torch
 from einops import rearrange, repeat
-try:
-    from flash_attn.ops.triton.rotary import apply_rotary
-except ImportError:
-    def apply_rotary(*args, **kwargs):
-        raise RuntimeError('RoPE requires flash-attention to be installed')
+
+if torch.cuda.is_available():
+    try:
+        from flash_attn.ops.triton.rotary import apply_rotary
+    except ImportError:
+        def apply_rotary(*args, **kwargs):
+            raise RuntimeError('RoPE requires flash-attention to be installed')
 
 
 def rotate_half(x, interleaved=False):
@@ -29,6 +31,10 @@ def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
     """
     ro_dim = cos.shape[-1] * 2
     assert ro_dim <= x.shape[-1]
+    cos, sin = (
+        cos[:x.shape[1]],
+        sin[:x.shape[1]],
+    )
     cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
     sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
     return torch.cat(
@@ -60,6 +66,7 @@ class ApplyRotaryEmb(torch.autograd.Function):
             interleaved=interleaved,
             inplace=inplace,
         )
+
         if isinstance(seqlen_offsets, int):
             ctx.save_for_backward(cos, sin, cu_seqlens)  # Can't save int with save_for_backward
             ctx.seqlen_offsets = seqlen_offsets
@@ -82,6 +89,7 @@ class ApplyRotaryEmb(torch.autograd.Function):
         # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
         if not ctx.interleaved and not ctx.inplace:
             do = do.clone()
+
         dx = apply_rotary(
             do,
             cos,
@@ -150,21 +158,37 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
         # batch, seqlen, three, nheads, headdim = qkv.shape
         assert qkv.shape[-3] == 3
         if cos_k is None and sin_k is None and qkv.is_contiguous():
-            # Call 1 kernel instead of 2 kernels
-            # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
-            # dimensions, we get the same tensor
-            qk = rearrange(qkv[..., :2, :, :], "... t h d -> ... (t h) d")
-            # qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
-            apply_rotary(
-                qk,
-                cos,
-                sin,
-                seqlen_offsets=seqlen_offsets,
-                interleaved=interleaved,
-                inplace=True,
-                cu_seqlens=cu_seqlens,
-                max_seqlen=max_seqlen,
-            )
+
+            if torch.cuda.is_available():
+                # Call 1 kernel instead of 2 kernels
+                # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
+                # dimensions, we get the same tensor
+                qk = rearrange(qkv[..., :2, :, :], "... t h d -> ... (t h) d")
+                # qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
+                apply_rotary(
+                    qk,
+                    cos,
+                    sin,
+                    seqlen_offsets=seqlen_offsets,
+                    interleaved=interleaved,
+                    inplace=True,
+                    cu_seqlens=cu_seqlens,
+                    max_seqlen=max_seqlen,
+                )
+            else:
+                q_rot = apply_rotary_emb_torch(
+                    qkv[:, :, 0],
+                    cos,
+                    sin,
+                    interleaved=interleaved,
+                )
+                k_rot = apply_rotary_emb_torch(
+                    qkv[:, :, 1],
+                    cos,
+                    sin,
+                    interleaved=interleaved,
+                )
+                qkv = torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)
         else:
             cos_k = cos if cos_k is None else cos_k
             sin_k = sin if sin_k is None else sin_k
@@ -228,7 +252,6 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
             sin_k = sin if sin_k is None else sin_k
             dq, dk = dqkv[..., 0, :, :], dqkv[..., 1, :, :]
             apply_rotary(
-
                 dq,
                 cos,
                 sin,
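
Aside (a self-contained sketch mirroring, not reusing, the module's code): the new CPU branch falls back to `apply_rotary_emb_torch`, which rotates channel pairs by position-dependent angles and therefore preserves vector norms, something a quick check can confirm:

import torch

def rotate_half(x):
    # Non-interleaved convention: pair channel i with channel i + headdim // 2.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rotary_torch(x, cos, sin):
    # x: (batch, seqlen, nheads, headdim); cos/sin: (seqlen, headdim // 2)
    ro_dim = cos.shape[-1] * 2
    cos = cos[: x.shape[1]].repeat(1, 2)[None, :, None, :]
    sin = sin[: x.shape[1]].repeat(1, 2)[None, :, None, :]
    rotated = x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim]) * sin
    return torch.cat([rotated, x[..., ro_dim:]], dim=-1)

batch, seqlen, nheads, headdim = 2, 16, 4, 32
x = torch.randn(batch, seqlen, nheads, headdim)
inv_freq = 1.0 / (10000 ** (torch.arange(0, headdim, 2) / headdim))
freqs = torch.outer(torch.arange(seqlen).float(), inv_freq)   # (seqlen, headdim // 2)
out = rotary_torch(x, freqs.cos(), freqs.sin())
assert torch.allclose(out.norm(dim=-1), x.norm(dim=-1), atol=1e-5)  # rotation keeps norms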