Commit
•
7c4a80c
1
Parent(s):
98c3cd2
lora bugfix (#16)
Browse files- fix: lora bug (4c504d33aca884a998533b089cb905b597a82467)
Co-authored-by: Jack Min Ong <[email protected]>
- modeling_lora.py +13 -8
modeling_lora.py
CHANGED
@@ -11,7 +11,7 @@ from torch import nn
|
|
11 |
from torch.nn import Parameter
|
12 |
from transformers import PretrainedConfig
|
13 |
|
14 |
-
from .modeling_xlm_roberta import XLMRobertaFlashConfig, XLMRobertaModel
|
15 |
|
16 |
|
17 |
LORA_NO_UPDATE = '__lora_no_update__'
|
@@ -210,13 +210,19 @@ class LoRAParametrization(nn.Module):
|
|
210 |
layer.current_task = task_idx
|
211 |
|
212 |
|
213 |
-
class XLMRobertaLoRA(
|
214 |
def __init__(
|
215 |
self,
|
216 |
config: XLMRobertaFlashConfig,
|
|
|
217 |
):
|
218 |
super().__init__(config)
|
219 |
|
|
|
|
|
|
|
|
|
|
|
220 |
self._lora_adaptations = config.lora_adaptations
|
221 |
if (
|
222 |
not isinstance(self._lora_adaptations, list)
|
@@ -231,7 +237,6 @@ class XLMRobertaLoRA(XLMRobertaModel):
|
|
231 |
self._rank = config.lora_rank
|
232 |
self._dropout_p = config.lora_dropout_p
|
233 |
self._alpha = config.lora_alpha
|
234 |
-
|
235 |
self._register_lora(
|
236 |
num_adaptations=len(self._lora_adaptations),
|
237 |
rank=self._rank,
|
@@ -284,9 +289,8 @@ class XLMRobertaLoRA(XLMRobertaModel):
|
|
284 |
pretrained_model_name_or_path, *model_args, **kwargs
|
285 |
)
|
286 |
else:
|
287 |
-
|
288 |
-
|
289 |
-
return cls(config)
|
290 |
|
291 |
def _register_lora(self, num_adaptations, rank, dropout_p, alpha):
|
292 |
self.apply(
|
@@ -331,7 +335,8 @@ class XLMRobertaLoRA(XLMRobertaModel):
|
|
331 |
def forward(self, *args, task: Union[str, None] = LORA_NO_UPDATE, **kwargs):
|
332 |
if task != LORA_NO_UPDATE:
|
333 |
self.current_task = task
|
334 |
-
|
|
|
335 |
|
336 |
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
|
337 |
for _, param in self.named_parameters(recurse=recurse):
|
@@ -373,4 +378,4 @@ class XLMRobertaLoRA(XLMRobertaModel):
|
|
373 |
)
|
374 |
self.current_task = task
|
375 |
|
376 |
-
return
|
|
|
11 |
from torch.nn import Parameter
|
12 |
from transformers import PretrainedConfig
|
13 |
|
14 |
+
from .modeling_xlm_roberta import XLMRobertaFlashConfig, XLMRobertaModel, XLMRobertaPreTrainedModel
|
15 |
|
16 |
|
17 |
LORA_NO_UPDATE = '__lora_no_update__'
|
|
|
210 |
layer.current_task = task_idx
|
211 |
|
212 |
|
213 |
+
class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
214 |
def __init__(
|
215 |
self,
|
216 |
config: XLMRobertaFlashConfig,
|
217 |
+
roberta: Optional[XLMRobertaModel] = None
|
218 |
):
|
219 |
super().__init__(config)
|
220 |
|
221 |
+
if roberta is None:
|
222 |
+
self.roberta = XLMRobertaModel(config)
|
223 |
+
else:
|
224 |
+
self.roberta = roberta
|
225 |
+
|
226 |
self._lora_adaptations = config.lora_adaptations
|
227 |
if (
|
228 |
not isinstance(self._lora_adaptations, list)
|
|
|
237 |
self._rank = config.lora_rank
|
238 |
self._dropout_p = config.lora_dropout_p
|
239 |
self._alpha = config.lora_alpha
|
|
|
240 |
self._register_lora(
|
241 |
num_adaptations=len(self._lora_adaptations),
|
242 |
rank=self._rank,
|
|
|
289 |
pretrained_model_name_or_path, *model_args, **kwargs
|
290 |
)
|
291 |
else:
|
292 |
+
roberta = XLMRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
293 |
+
return cls(config, roberta=roberta)
|
|
|
294 |
|
295 |
def _register_lora(self, num_adaptations, rank, dropout_p, alpha):
|
296 |
self.apply(
|
|
|
335 |
def forward(self, *args, task: Union[str, None] = LORA_NO_UPDATE, **kwargs):
|
336 |
if task != LORA_NO_UPDATE:
|
337 |
self.current_task = task
|
338 |
+
|
339 |
+
return self.roberta(*args, **kwargs)
|
340 |
|
341 |
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
|
342 |
for _, param in self.named_parameters(recurse=recurse):
|
|
|
378 |
)
|
379 |
self.current_task = task
|
380 |
|
381 |
+
return self.roberta.encode(*args, **kwargs)
|