configuration_xlm_roberta.py CHANGED
@@ -5,6 +5,9 @@ from transformers import PretrainedConfig


 class XLMRobertaFlashConfig(PretrainedConfig):
+
+    model_type = "xlm-roberta"
+
     def __init__(
         self,
         vocab_size: int = 250002,
@@ -25,9 +28,10 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         position_embedding_type: str = "rotary",
         rotary_emb_base: float = 10000.0,
         use_cache: bool = True,
+        use_reentrant: bool = False,
         classifier_dropout: Optional[float] = None,
         lora_adaptations: Optional[List[str]] = None,
-        lora_prompts: Optional[Dict[str, str]] = None,
+        task_instructions: Optional[Dict[str, str]] = None,
         lora_rank: int = 4,
         lora_dropout_p: float = 0.0,
         lora_alpha: int = 1,
@@ -62,6 +66,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         position_embedding_type (str): Type of position embeddings. Options are 'absolute', 'alibi', or 'rotary'.
         rotary_emb_base (float): Base for rotary embeddings.
         use_cache (bool): Whether or not the model should return the last key/values attentions (not used by all models).
+        use_reentrant (bool): Whether or not the model should enable the 'use_reentrant' flag in gradient checkpointing.
         classifier_dropout (Optional[float]): The dropout ratio for the classification head.
         lora_adaptations (Optional[List[str]]): LoRA adaptations configuration.
         lora_prompts (Optional[Dict[str, str]]): LoRA prompts configuration.
@@ -100,10 +105,11 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         self.position_embedding_type = position_embedding_type
         self.rotary_emb_base = rotary_emb_base
         self.use_cache = use_cache
+        self.use_reentrant = use_reentrant
         self.classifier_dropout = classifier_dropout
         self.load_trained_adapters = load_trained_adapters
         self.lora_adaptations = lora_adaptations
-        self.lora_prompts = lora_prompts
+        self.task_instructions = task_instructions
         self.lora_rank = lora_rank
         self.lora_dropout_p = lora_dropout_p
         self.lora_alpha = lora_alpha
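
For orientation, here is a minimal sketch of how the updated config might be constructed after this change. It assumes the module is importable as `configuration_xlm_roberta`; the adaptation names and instruction strings are placeholders, not values taken from this diff.

# Sketch only: placeholder adaptation names and instruction strings.
from configuration_xlm_roberta import XLMRobertaFlashConfig

config = XLMRobertaFlashConfig(
    use_reentrant=False,  # new flag, later forwarded to torch.utils.checkpoint
    lora_adaptations=["retrieval.query", "retrieval.passage"],  # placeholders
    task_instructions={  # replaces the old `lora_prompts`; keys must match lora_adaptations
        "retrieval.query": "Represent the query: ",
        "retrieval.passage": "Represent the passage: ",
    },
)
assert config.model_type == "xlm-roberta"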
mha.py CHANGED
@@ -463,6 +463,7 @@ class MHA(nn.Module):
                 scale_base=rotary_emb_scale_base,
                 interleaved=rotary_emb_interleaved,
                 device=device,
+                use_flash_attn=use_flash_attn,
             )

         if fused_bias_fc and FusedDense is None:
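
This one-line change is plumbing: the `use_flash_attn` flag MHA already receives is now forwarded to its rotary module. A sketch of the call as now wired up, assuming the repo's `rotary` module is importable directly and using a placeholder head dimension:

# Sketch of the RotaryEmbedding construction as made by MHA after this change.
from rotary import RotaryEmbedding

rotary_emb = RotaryEmbedding(
    64,                    # rotary_emb_dim (placeholder)
    base=10000.0,          # rotary_emb_base
    scale_base=None,       # rotary_emb_scale_base
    interleaved=False,     # rotary_emb_interleaved
    device=None,
    use_flash_attn=False,  # newly forwarded; False selects the pure-PyTorch rotary path
)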
modeling_lora.py CHANGED
@@ -11,6 +11,7 @@ from torch.nn import Parameter
 from torch.nn import functional as F
 from transformers import PretrainedConfig

+from .rotary import RotaryEmbedding
 from .modeling_xlm_roberta import (XLMRobertaFlashConfig, XLMRobertaModel,
                                    XLMRobertaPreTrainedModel)

@@ -164,7 +165,6 @@ class LoRAParametrization(nn.Module):
     ):
         """
         Registering LoRA adapters to all embedding and linear layers.
-
         Additionally, we implement a custom forward function for LoRA parametrization.
         This function modifies the layer's forward pass to optionally use task-specific
         parameters. When a `task_id` is provided, it employs a LoRA parametrization
@@ -241,6 +241,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
     """
     A wrapper class around the Jina XLM-RoBERTa model that integrates LoRA (Low-Rank Adaptation) adapters.
     """
+
    def __init__(
        self, config: XLMRobertaFlashConfig, roberta: Optional[XLMRobertaModel] = None
    ):
@@ -258,15 +259,17 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
             raise ValueError(
                 f"`lora_adaptations` must be a list and contain at least one element"
             )
-        self._lora_prompts = config.lora_prompts
+        self._task_instructions = config.task_instructions
         if (
-            not isinstance(self._lora_prompts, dict)
-            or len(self._lora_prompts) != len(self._lora_adaptations)
-            or not all([v in self._lora_adaptations for v in self._lora_prompts.keys()])
+            not isinstance(self._task_instructions, dict)
+            or len(self._task_instructions) != len(self._lora_adaptations)
+            or not all(
+                [v in self._lora_adaptations for v in self._task_instructions.keys()]
+            )
         ):
             raise ValueError(
-                f"`lora_prompts` must be a dict and contain the same number of elements "
-                f"as `lora_adaptations` with all keys in `lora_prompts` present in `lora_adaptations`."
+                f"`task_instructions` must be a dict and contain the same number of elements "
+                f"as `lora_adaptations` with all keys in `task_instructions` present in `lora_adaptations`."
             )
         self._adaptation_map = {
             name: idx for idx, name in enumerate(self._lora_adaptations)
@@ -322,16 +325,13 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         use_safetensors: bool = None,
         **kwargs,
     ):
-        config = XLMRobertaFlashConfig.from_pretrained(
-            pretrained_model_name_or_path, *model_args, **kwargs
-        )
         if config.load_trained_adapters:  # checkpoint already contains LoRA adapters
             return super().from_pretrained(
-                pretrained_model_name_or_path, *model_args, **kwargs
+                pretrained_model_name_or_path, *model_args, use_flash_attn=config.use_flash_attn, **kwargs
             )
         else:  # initializing new adapters
             roberta = XLMRobertaModel.from_pretrained(
-                pretrained_model_name_or_path, *model_args, **kwargs
+                pretrained_model_name_or_path, *model_args, use_flash_attn=config.use_flash_attn, **kwargs
             )
             return cls(config, roberta=roberta)

@@ -372,7 +372,6 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
         Computes sentence embeddings.
-
         sentences(`str` or `List[str]`):
             Sentence or sentences to be encoded
         task_type(`str`, *optional*, defaults to `None`):
@@ -393,6 +392,10 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         adapter_mask = torch.full(
             (num_examples,), task_id, dtype=torch.int32, device=self.device
         )
+        if isinstance(sentences, str):
+            sentences = self._task_instructions[task_type] + sentences
+        else:
+            sentences = [self._task_instructions[task_type] + sentence for sentence in sentences]
         return self.roberta.encode(
             sentences, *args, adapter_mask=adapter_mask, **kwargs
         )
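
The net effect on `encode()` is easiest to see in isolation. A standalone sketch of the new prepending step (task name and instruction text are placeholders, not taken from this diff):

# What encode() now does before tokenization, shown standalone.
task_instructions = {"retrieval.query": "Represent the query: "}  # placeholder
task_type = "retrieval.query"
sentences = ["How do rotary embeddings work?"]

sentences = [task_instructions[task_type] + s for s in sentences]
# -> ["Represent the query: How do rotary embeddings work?"]
# The adapter_mask built from the same task_type then routes every example
# through the matching LoRA adapter inside roberta.encode(...).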
modeling_xlm_roberta.py CHANGED
@@ -30,6 +30,7 @@ from transformers.models.bert.modeling_bert import (
 from transformers.models.xlm_roberta.modeling_xlm_roberta import \
     XLMRobertaLMHead

+from .rotary import RotaryEmbedding
 from .block import Block
 from .configuration_xlm_roberta import XLMRobertaFlashConfig
 from .embedding import XLMRobertaEmbeddings
@@ -63,9 +64,7 @@ logger = logging.getLogger(__name__)


 def get_use_flash_attn(config: XLMRobertaFlashConfig):
-    if not getattr(config, "use_flash_attn", False):
-        return False
-    if not torch.cuda.is_available():
+    if not getattr(config, "use_flash_attn", False) or not torch.cuda.is_available():
         return False
     if importlib.util.find_spec("flash_attn") is None:
         logger.warning(
@@ -181,6 +180,7 @@ class XLMRobertaEncoder(nn.Module):
     def __init__(self, config: XLMRobertaFlashConfig):
         super().__init__()
         self.use_flash_attn = get_use_flash_attn(config)
+        self.use_reentrant = config.use_reentrant
         self.layers = nn.ModuleList(
             [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
         )
@@ -210,7 +210,7 @@ class XLMRobertaEncoder(nn.Module):
                     hidden_states = torch.utils.checkpoint.checkpoint(
                         layer,
                         hidden_states,
-                        use_reentrant=False,
+                        use_reentrant=self.use_reentrant,
                         mixer_kwargs=mixer_kwargs,
                     )
                 else:
@@ -234,7 +234,7 @@ class XLMRobertaEncoder(nn.Module):
                     hidden_states = torch.utils.checkpoint.checkpoint(
                         layer,
                         hidden_states,
-                        use_reentrant=False,
+                        use_reentrant=self.use_reentrant,
                         mixer_kwargs=mixer_kwargs,
                     )
                 else:
@@ -246,7 +246,7 @@ class XLMRobertaEncoder(nn.Module):
                     hidden_states = torch.utils.checkpoint.checkpoint(
                         layer,
                         hidden_states,
-                        use_reentrant=False,
+                        use_reentrant=self.use_reentrant,
                         mixer_kwargs=mixer_kwargs,
                     )
                 else:
@@ -284,7 +284,7 @@ class XLMRobertaEncoder(nn.Module):
                 torch.utils.checkpoint.checkpoint(
                     self.layers[-1],
                     hidden_states_subset,
-                    use_reentrant=False,
+                    use_reentrant=self.use_reentrant,
                     mixer_kwargs=mixer_kwargs,
                 )
             else:
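
Since `use_reentrant` is now read from the config instead of being hard-coded to False, here is a small self-contained sketch of the checkpoint call the encoder makes, with a toy layer standing in for a transformer block:

# Toy illustration of torch.utils.checkpoint with a configurable use_reentrant flag.
import torch
import torch.utils.checkpoint

layer = torch.nn.Linear(16, 16)
hidden_states = torch.randn(2, 16, requires_grad=True)
use_reentrant = False  # would come from config.use_reentrant

out = torch.utils.checkpoint.checkpoint(
    layer,
    hidden_states,
    use_reentrant=use_reentrant,
)
out.sum().backward()  # activations of `layer` are recomputed here instead of stored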
rotary.py CHANGED
@@ -4,7 +4,6 @@

 # Copyright (c) 2023, Tri Dao.

-import math
 from typing import Optional, Tuple, Union

 import torch
@@ -16,7 +15,10 @@ if torch.cuda.is_available():
     except ImportError:

         def apply_rotary(*args, **kwargs):
-            raise RuntimeError("RoPE requires flash-attention to be installed")
+            raise RuntimeError(
+                "FlashAttention is not installed. To proceed with training, please install FlashAttention. "
+                "For inference, you have two options: either install FlashAttention or disable it by setting use_flash_attn=False when loading the model."
+            )


 def rotate_half(x, interleaved=False):
@@ -169,12 +171,13 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
         seqlen_offsets: Union[int, torch.Tensor] = 0,
         cu_seqlens: Optional[torch.Tensor] = None,
         max_seqlen: Optional[int] = None,
+        use_flash_attn: bool = True,
     ):
         # batch, seqlen, three, nheads, headdim = qkv.shape
         assert qkv.shape[-3] == 3
         if cos_k is None and sin_k is None and qkv.is_contiguous():

-            if torch.cuda.is_available():
+            if use_flash_attn:
                 # Call 1 kernel instead of 2 kernels
                 # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
                 # dimensions, we get the same tensor
@@ -288,7 +291,7 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
             cu_seqlens=cu_seqlens,
             max_seqlen=ctx.max_seqlen,
         )
-        return dqkv, None, None, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None, None, None


 def apply_rotary_emb_qkv_(
@@ -301,6 +304,7 @@ def apply_rotary_emb_qkv_(
     seqlen_offsets: Union[int, torch.Tensor] = 0,
     cu_seqlens: Optional[torch.Tensor] = None,
     max_seqlen: Optional[int] = None,
+    use_flash_attn=True,
 ):
     """
     Arguments:
@@ -321,7 +325,7 @@ def apply_rotary_emb_qkv_(
     Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
     """
     return ApplyRotaryEmbQKV_.apply(
-        qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, cu_seqlens, max_seqlen
+        qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, cu_seqlens, max_seqlen, use_flash_attn,
     )


@@ -443,6 +447,7 @@ class RotaryEmbedding(torch.nn.Module):
         scale_base=None,
         pos_idx_in_fp32=True,
         device=None,
+        use_flash_attn=True,
     ):
         """
         interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
@@ -462,6 +467,7 @@ class RotaryEmbedding(torch.nn.Module):
         self.dim = dim
         self._base = float(base)
         self.pos_idx_in_fp32 = pos_idx_in_fp32
+        self.use_flash_attn = use_flash_attn
         # Generate and save the inverse frequency buffer (non trainable)
         inv_freq = self._compute_inv_freq(device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
@@ -588,6 +594,7 @@ class RotaryEmbedding(torch.nn.Module):
                     seqlen_offsets=seqlen_offset,
                     cu_seqlens=cu_seqlens,
                     max_seqlen=max_seqlen,
+                    use_flash_attn=self.use_flash_attn,
                 )
             else:
                 return apply_rotary_emb_qkv_(
@@ -600,6 +607,7 @@ class RotaryEmbedding(torch.nn.Module):
                     seqlen_offsets=seqlen_offset,
                     cu_seqlens=cu_seqlens,
                     max_seqlen=max_seqlen,
+                    use_flash_attn=self.use_flash_attn,
                 )
         else:
             q = qkv
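
Taken together, these changes let rotary embeddings run without the Triton kernel. A rough usage sketch, assuming the module is importable as `rotary` and that the pure-PyTorch fallback covers CPU tensors:

# Rough sketch: applying rotary embeddings with flash-attention disabled.
import torch
from rotary import RotaryEmbedding

rotary_emb = RotaryEmbedding(dim=32, use_flash_attn=False)
qkv = torch.randn(2, 128, 3, 8, 32)  # (batch, seqlen, 3, nheads, headdim)
with torch.no_grad():
    qkv = rotary_emb(qkv)  # rotates Q and K without calling apply_rotary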