Jackmin108 committed
Commit c35a42b
1 Parent(s): 7af97e7

fix: when sentences is one

Signed-off-by: Meow <[email protected]>
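
A minimal, standalone sketch of the pitfall the commit title refers to: iterating over a bare Python string walks over its characters, so a single sentence has to be wrapped in a list before a task instruction is prepended to each entry. The instruction text and helper below are placeholders, not the model's actual prompts or API.

```python
# Hypothetical helper mirroring the instruction-prepending step in encode().
# The instruction string is a placeholder, not the model's real prompt.
task_instructions = {"query": "Represent the query for retrieval:"}


def prepend_instruction(sentences, task_type="query"):
    if isinstance(sentences, str):
        # A bare str would be iterated character by character, so wrap it first.
        sentences = [sentences]
    return [task_instructions[task_type] + " " + s for s in sentences]


print(prepend_instruction("How does LoRA adaptation work?"))
print(prepend_instruction(["first query", "second query"]))
```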

Files changed (1)
  1. modeling_lora.py +18 -9
modeling_lora.py CHANGED
@@ -11,8 +11,11 @@ from torch.nn import Parameter
 from torch.nn import functional as F
 from transformers import PretrainedConfig
 
-from .modeling_xlm_roberta import (XLMRobertaFlashConfig, XLMRobertaModel,
-                                   XLMRobertaPreTrainedModel)
+from .modeling_xlm_roberta import (
+    XLMRobertaFlashConfig,
+    XLMRobertaModel,
+    XLMRobertaPreTrainedModel,
+)
 
 
 def initialized_weights(
@@ -241,6 +244,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
     """
     A wrapper class around the Jina XLM-RoBERTa model that integrates LoRA (Low-Rank Adaptation) adapters.
     """
+
     def __init__(
         self, config: XLMRobertaFlashConfig, roberta: Optional[XLMRobertaModel] = None
     ):
@@ -262,7 +266,9 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         if (
             not isinstance(self._task_instructions, dict)
             or len(self._task_instructions) != len(self._lora_adaptations)
-            or not all([v in self._lora_adaptations for v in self._task_instructions.keys()])
+            or not all(
+                [v in self._lora_adaptations for v in self._task_instructions.keys()]
+            )
         ):
             raise ValueError(
                 f"`task_instructions` must be a dict and contain the same number of elements "
@@ -325,11 +331,11 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         config = XLMRobertaFlashConfig.from_pretrained(
             pretrained_model_name_or_path, *model_args, **kwargs
         )
-        if config.load_trained_adapters: # checkpoint already contains LoRA adapters
+        if config.load_trained_adapters:  # checkpoint already contains LoRA adapters
             return super().from_pretrained(
                 pretrained_model_name_or_path, *model_args, **kwargs
             )
-        else: # initializing new adapters
+        else:  # initializing new adapters
             roberta = XLMRobertaModel.from_pretrained(
                 pretrained_model_name_or_path, *model_args, **kwargs
             )
@@ -387,14 +393,17 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
                 f"Alternatively, don't pass the `task_type` argument to disable LoRA."
            )
         adapter_mask = None
+        sentences = list(sentences) if isinstance(sentences, str) else sentences
         if task_type:
             task_id = self._adaptation_map[task_type]
-            num_examples = 1 if isinstance(sentences, str) else len(sentences)
             adapter_mask = torch.full(
-                (num_examples,), task_id, dtype=torch.int32, device=self.device
+                (len(sentences),), task_id, dtype=torch.int32, device=self.device
             )
-            if task_type in ['query', 'passage']:
-                sentences = [self._task_instructions[task_type] + ' ' + sentence for sentence in sentences]
+            if task_type in ["query", "passage"]:
+                sentences = [
+                    self._task_instructions[task_type] + " " + sentence
+                    for sentence in sentences
+                ]
         return self.roberta.encode(
             sentences, *args, adapter_mask=adapter_mask, **kwargs
         )
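
As a rough illustration of the reworked mask construction in the last hunk, the adapter mask now carries one task id per input sentence; a minimal sketch, with a placeholder task-id mapping and a CPU tensor:

```python
import torch

# Placeholder mapping from task name to LoRA adapter id (illustrative only).
adaptation_map = {"query": 0, "passage": 1}

sentences = ["How does LoRA adaptation work?"]  # a single sentence, already wrapped in a list
task_id = adaptation_map["query"]

# One adapter id per input row, matching len(sentences).
adapter_mask = torch.full((len(sentences),), task_id, dtype=torch.int32)
print(adapter_mask)  # tensor([0], dtype=torch.int32)
```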