r""" |
|
Low Ranking Adaptation for LLMs scheme. |
|
|
|
โโโโโโโโโโโโโโโโโโโโโ |
|
โ h โ |
|
โโโโโโโโโโโโโโโโโโโโโ |
|
โฒ |
|
| |
|
+ |
|
/ \ |
|
โโโโโโโโโโโโโโโโโโโ โญโโโโโโโโโโโโโโโโฎ Matrix initialization: |
|
โ โ \ B / B = 0 |
|
โ pretrained โ \ r*d / A = N(0, sigma^2) |
|
โ weights โ โฐโโโโโโโโโโฏ |
|
โ โ | r | r - rank |
|
โ W e R^(d*d) โ | โโโโโโโถ | |
|
โ โ โญโโโโโโโโโโฎ |
|
โโโโโโโโโโโโโโโโโโโ / A \ |
|
โฒ / d*r \ |
|
\ โฐโโโโโโโโโโโโโโโโฏ |
|
\ โฒ |
|
\ / |
|
\ / |
|
โโโโโโโโโโโโโโโโโโโโโ |
|
โ x โ |
|
โโโโโโโโโโโโโโโโโโโโโ |
|
|
|
With LoRA (Low-Rank Adaptation: https://arxiv.org/abs/2106.09685), instead of learning a weight update of size d*d,
we freeze the pretrained weights and learn two much smaller matrices of size d*r and r*d that store the weight updates
for the pretrained weights. The number of trainable parameters is reduced drastically (depending on the rank, of
course), yet the product of the d*r and r*d matrices is again a d*d matrix, which can be added to the frozen
pretrained weights and thus fine-tunes the model.
|
|
|
The goal of this approach is to move the weight updates into a separate matrix, which is decomposed into
two matrices of a lower rank.
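
For example, with d = 4096 and rank r = 8, a full weight update would require 4096 * 4096 ≈ 16.8M trainable
parameters, while the LoRA branch needs only d*r + r*d = 2 * 4096 * 8 = 65,536 of them (a 256x reduction).
Schematically, the adapted layer computes

    h = x @ W + (x @ A) @ B * (alpha / r)

where W (d*d) stays frozen and only A (d*r) and B (r*d) receive gradients; the alpha/r scaling is explained below.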
|
""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
import math |
|
from typing import Dict, List |
|
|
|
import lit_llama.model as llama |
|
|
|
from contextlib import contextmanager |
|
from dataclasses import dataclass |
|
|
|
|
|
class LoRALayer(): |
|
def __init__( |
|
self, |
|
r: int, |
|
lora_alpha: int, |
|
lora_dropout: float, |
|
merge_weights: bool, |
|
): |
|
"""Store LoRA specific attributes in a class. |
|
|
|
Args: |
|
            r: rank of the weight update matrices. For LoRA to be worthwhile, the rank should be much smaller than
                the rank of the model's weight matrices. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
|
lora_alpha: alpha is needed for scaling updates as alpha/r |
|
"This scaling helps to reduce the need to retune hyperparameters when we vary r" |
|
https://arxiv.org/pdf/2106.09685.pdf (section 4.1) |
|
lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A) |
|
            merge_weights: whether we want to merge the pretrained weights and the LoRA weight updates. This is useful
                if one wants to use the fine-tuned model as a standalone one (without storing LoRA weights separately),
                and it also reduces overhead during inference.
|
""" |
|
self.r = r |
|
self.lora_alpha = lora_alpha |
|
|
|
if lora_dropout > 0.: |
|
self.lora_dropout = nn.Dropout(p=lora_dropout) |
|
else: |
|
self.lora_dropout = lambda x: x |
|
|
|
self.merged = False |
|
self.merge_weights = merge_weights |
|
|
|
|
|
class MergedLinear(nn.Linear, LoRALayer): |
|
|
|
def __init__( |
|
self, |
|
|
|
in_features: int, |
|
out_features: int, |
|
|
|
r: int = 0, |
|
lora_alpha: int = 1, |
|
lora_dropout: float = 0., |
|
enable_lora: List[bool] = [False], |
|
fan_in_fan_out: bool = False, |
|
merge_weights: bool = True, |
|
**kwargs |
|
): |
|
"""LoRA wrapper around linear class that is used for calculation of q, k and v matrices. |
|
|
|
This class has three weight matrices: |
|
1. Pretrained weights are stored as `self.weight` (because of the nn.Linear inheritance) |
|
2. LoRA A matrix as `self.lora_A` |
|
3. LoRA B matrix as `self.lora_B` |
|
Only LoRA's A and B matrices are updated, pretrained weights stay frozen. |
|
|
|
Args: |
|
in_features: number of input features of the pretrained weights |
|
out_features: number of output features of the pretrained weights |
|
            r: rank of the weight update matrices. For LoRA to be worthwhile, the rank should be much smaller than
                the rank of the model's weight matrices. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
|
lora_alpha: alpha is needed for scaling updates as alpha/r |
|
"This scaling helps to reduce the need to retune hyperparameters when we vary r" |
|
https://arxiv.org/pdf/2106.09685.pdf (section 4.1) |
|
lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A) |
|
            enable_lora: the MergedLinear class is used for the attention mechanism, where q, k and v are calculated with
                a single weight matrix. If we don't want to apply LoRA to all three (query, key and value) we can set
                the corresponding entries to False. For example, to apply LoRA only to `query` and `value` but keep
                `key` without weight updates we should pass `[True, False, True]`
|
fan_in_fan_out: set this to True if the layer to replace stores weight like (fan_in, fan_out). For example, gpt-2 uses |
|
`Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True` |
|
https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora.py#LL53C9-L53C112 |
|
            merge_weights: whether we want to merge the pretrained weights and the LoRA weight updates. This is useful
                if one wants to use the fine-tuned model as a standalone one (without storing LoRA weights separately),
                and it also reduces overhead during inference.
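
        Example (illustrative values; in this file the layer is created inside CausalSelfAttention below)::

            qkv = MergedLinear(
                in_features=4096,
                out_features=3 * 4096,
                r=8,
                lora_alpha=16,
                lora_dropout=0.05,
                enable_lora=[True, False, True],  # LoRA for query and value projections only
                fan_in_fan_out=False,
                merge_weights=True,
                bias=False,
            )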
|
""" |
|
nn.Linear.__init__(self, in_features, out_features, **kwargs) |
|
LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, |
|
merge_weights=merge_weights) |
|
assert out_features % len(enable_lora) == 0, \ |
|
'The length of enable_lora must divide out_features' |
|
self.enable_lora = enable_lora |
|
        self.fan_in_fan_out = fan_in_fan_out

        # Register the trainable low-rank matrices only if LoRA is enabled (r > 0)
        # for at least one of the query/key/value projections.
        if r > 0 and any(enable_lora):
|
self.lora_A = nn.Parameter( |
|
self.weight.new_zeros((r * sum(enable_lora), in_features))) |
|
self.lora_B = nn.Parameter( |
|
self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r)) |
|
            )
            # lora_A stacks one (r x in_features) matrix per enabled projection and lora_B stacks one
            # (out_features // len(enable_lora) x r) matrix per enabled projection; the weight update
            # for each enabled projection is B @ A.

            # Scale the weight updates by alpha / r (section 4.1 of the paper).
            self.scaling = self.lora_alpha / self.r

            # Freeze the pretrained weight matrix; only lora_A and lora_B receive gradients.
            self.weight.requires_grad = False
|

            # Build a boolean mask over the out_features dimension that marks which output slices
            # (query / key / value) receive LoRA updates; zero_pad() uses it to scatter the computed
            # updates into the right positions.
|
self.lora_ind = self.weight.new_zeros( |
|
(out_features, ), dtype=torch.bool |
|
).view(len(enable_lora), -1) |
|
self.lora_ind[enable_lora, :] = True |
|
self.lora_ind = self.lora_ind.view(-1) |
|
self.reset_parameters() |
|
if fan_in_fan_out: |
|
self.weight.data = self.weight.data.T |
|
|
|
def reset_parameters(self): |
|
"""Reset all the weights, even including pretrained ones.""" |
|
nn.Linear.reset_parameters(self) |
|
        if hasattr(self, 'lora_A'):
            # Initialize A the same way as the default for nn.Linear and B to zeros,
            # so the initial weight update (B @ A) is zero.
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
nn.init.zeros_(self.lora_B) |
|
|
|
def zero_pad(self, x: torch.Tensor) -> torch.Tensor: |
|
"""Properly pad weight updates with zeros. |
|
|
|
If, based on `self.enable_lora`, we want to fine-tune queries and values, but not keys, |
|
then the weights update should be: |
|
|
|
        [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW],
         [...................................],
         [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW]]
            ↑              ↑            ↑
        ________________________________________
        |    query    |    key    |   value    |
        ----------------------------------------
|
|
|
Args: |
|
x: tensor with weights update that will be padded with zeros if necessary |
|
|
|
Returns: |
|
A tensor with weight updates and zeros for deselected q, k or v |
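
        Example (illustrative, assuming n_embd = 4096 and enable_lora = [True, False, True]):
            the incoming update has 2 * 4096 features (query and value blocks only); zero_pad scatters
            them into a 3 * 4096 wide tensor and leaves the key block filled with zeros.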
|
""" |
|

        # Allocate a zero tensor that spans the full out_features width and scatter the incoming
        # updates into the columns selected by self.lora_ind; columns of disabled projections
        # (e.g. `key`) stay zero. The transposes put the dimension to be padded last and restore
        # the original layout afterwards.
|
x = x.transpose(0, 1) |
|
result = x.new_zeros((*x.shape[:-1], self.out_features)) |
|
result = result.view(-1, self.out_features) |
|
result[:, self.lora_ind] = x.reshape( |
|
-1, self.out_features // len(self.enable_lora) * sum(self.enable_lora) |
|
) |
|
return result.view((*x.shape[:-1], self.out_features)).transpose(0, 1) |
|
|
|
def train(self, mode: bool = True): |
|
"""Set the module into train or eval mode if `mode` is True of False respectively. |
|
|
|
For train mode (train(True)) if weights are merged we need to subtract weights updates (LoRA_A @ LoRA_B) from |
|
pretrained weights so we can continue training LoRA's matrices A and B and keep pretrained weights frozen. |
|
|
|
For eval mode (train(False)) if weights are not merged we need to add weight updates to pretrained weights in |
|
order to reduce computational overhead during inference. |
|
|
|
Args: |
|
mode: if True the module will be set into train mode (affects Dropout and BatchNorm), if False - eval mode. |
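
        Example (illustrative, assuming `merge_weights=True` and LoRA enabled):
            calling `layer.train()` subtracts the LoRA update from `self.weight` (unmerged state), while
            `layer.eval()` adds it back in (merged state) so inference is a single matrix multiplication.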
|
|
|
""" |
|
def T(w): |
|
return w.T if self.fan_in_fan_out else w |
|

        # nn.Linear.train() switches this module and its submodules (including the LoRA dropout)
        # between train and eval mode.
        nn.Linear.train(self, mode)

        # train(True): unmerge the LoRA update unless it is already unmerged.
        # train(False): merge the LoRA update unless it is already merged.
        should = self.merged if mode else not self.merged

        # The weight update is reconstructed as delta_w = B @ A via a grouped 1d-convolution
        # (one group per enabled projection), scaled by alpha / r, zero-padded to the full
        # out_features width and then subtracted from (train) or added to (eval) the weights.
        if self.merge_weights and should:
|
if self.r > 0 and any(self.enable_lora): |
|
delta_w = F.conv1d( |
|
self.lora_A.data.unsqueeze(0), |
|
self.lora_B.data.unsqueeze(-1), |
|
groups=sum(self.enable_lora) |
|
).squeeze(0) |
|
|
|
sign = -1 if mode else 1 |
|
self.weight.data += sign * self.zero_pad(T(delta_w * self.scaling)) |
|
self.merged = not mode |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
"""Do the forward pass. |
|
|
|
If LoRA's weights are merged with pretrained ones then it's a simple matrix multiplication. |
|
If not, then multiply pretrained weights with input, apply LoRA on input and do summation. |
|
|
|
Args: |
|
x: input tensor of shape (batch_size, context_length, embedding_size) |
|
|
|
Returns: |
|
Output tensor of shape (batch_size, context_length, 3 * embedding_size) |
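
        Schematically, in the unmerged case (per enabled projection)::

            y = x @ W.T + zero_pad(dropout(x) @ A.T @ B.T) * (alpha / r)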
|
""" |
|
def T(w): |
|
return w.T if self.fan_in_fan_out else w |
|

        # If the LoRA update is already merged into self.weight this is a plain linear layer;
        # otherwise run the frozen linear layer and add the scaled LoRA branch
        # (dropout -> A -> B -> zero_pad) to its output.
|
if self.merged: |
|
return F.linear(x, T(self.weight), bias=self.bias) |
|
else: |
|
|
|
result = F.linear(x, T(self.weight), bias=self.bias) |
|
if self.r > 0: |
|
after_A = F.linear(self.lora_dropout(x), self.lora_A) |
|

                # Apply the stacked B matrices with a grouped 1d-convolution: each enabled
                # projection forms its own group and only sees its own slice of lora_B.
|
after_B = F.conv1d( |
|
after_A.transpose(-2, -1), |
|
self.lora_B.unsqueeze(-1), |
|
groups=sum(self.enable_lora) |
|
).transpose(-2, -1) |
|
result += self.zero_pad(after_B) * self.scaling |
|
return result |
|
|
|
|
|
def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None: |
|
"""Freeze all modules except LoRA's and depending on 'bias' value unfreezes bias weights. |
|
|
|
Args: |
|
model: model with LoRA layers |
|
bias: |
|
``"none"``: all bias weights will be frozen, |
|
``"lora_only"``: only bias weight for LoRA layers will be unfrozen, |
|
``"all"``: all bias weights will be unfrozen. |
|
|
|
Raises: |
|
NotImplementedError: if `bias` not in ["none", "lora_only", "all"] |
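
    Example (illustrative)::

        mark_only_lora_as_trainable(model, bias='none')
        optimizer = torch.optim.AdamW(
            (p for p in model.parameters() if p.requires_grad), lr=3e-4
        )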
|
""" |
|
|
|
for n, p in model.named_parameters(): |
|
if 'lora_' not in n: |
|
p.requires_grad = False |
|
|
|
|
|
if bias == 'none': |
|
return |
|
elif bias == 'all': |
|
for n, p in model.named_parameters(): |
|
if 'bias' in n: |
|
p.requires_grad = True |
|
elif bias == 'lora_only': |
|
for m in model.modules(): |
|
if isinstance(m, LoRALayer) and \ |
|
hasattr(m, 'bias') and \ |
|
m.bias is not None: |
|
m.bias.requires_grad = True |
|
else: |
|
raise NotImplementedError |
|
|
|
|
|
def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]: |
|
"""Return state_dict with weights of LoRA's A and B matrices and with biases depending on the `bias` value. |
|
|
|
Args: |
|
model: model with LoRA layers |
|
bias: |
|
``"none"``: state dict will not store bias weights, |
|
``"lora_only"``: state dict will store bias weights only from LoRA layers, |
|
``"all"``: state dict will store all bias weights. |
|
|
|
Returns: |
|
Weights and biases of LoRA layers |
|
|
|
Raises: |
|
NotImplementedError: if `bias` not in ["none", "lora_only", "all"] |
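
    Example (illustrative)::

        checkpoint = lora_state_dict(model, bias='none')
        torch.save(checkpoint, 'lora_weights.pth')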
|
""" |
|
my_state_dict = model.state_dict() |
|
if bias == 'none': |
|
return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k} |
|
elif bias == 'all': |
|
return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or 'bias' in k} |
|
elif bias == 'lora_only': |
|
to_return = {} |
|
for k in my_state_dict: |
|
if 'lora_' in k: |
|
to_return[k] = my_state_dict[k] |
|
bias_name = k.split('lora_')[0]+'bias' |
|
if bias_name in my_state_dict: |
|
to_return[bias_name] = my_state_dict[bias_name] |
|
return to_return |
|
else: |
|
raise NotImplementedError |
|
|
|
|
|
@dataclass |
|
class LoRAConfig: |
|
r: float = 0.0 |
|
alpha: float = 1.0 |
|
dropout: float = 0.0 |
|
|
|
|
|
class CausalSelfAttention(llama.CausalSelfAttention): |
|
lora_config = None |
|
|
|
def __init__(self, config: llama.LLaMAConfig) -> None: |
|
"""Causal self-attention with calculating qkv matrices with a single matrix* and Low Ranking Adaptation for |
|
parameter-efficient fine-tuning. |
|
|
|
*Instead of creating multiple heads and concatenating the result (in addition to creating separate matrices for |
|
query, key and value for each head) we can do this in a single pass with a single weight matrix. |
|
|
|
Args: |
|
config: |
|
``"block_size"``: size of the context of the model, |
|
``"vocab_size"``: number of unique tokens, |
|
``"padded_vocab_size"``: padded size of the vocabulary to the nearest multiple of 64 (leads to a greater performance), |
|
``"n_layer"``: number of transformer blocks (self-attention + MLP), |
|
``"n_head"``: number of heads in multi-head attention mechanism, |
|
``"n_embd"``: size of the embedding: vector representation of each token. |
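
        Example (illustrative values; the field names are the ones documented above)::

            config = llama.LLaMAConfig(block_size=2048, n_layer=32, n_head=32, n_embd=4096)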
|
""" |
|

        # Skip llama.CausalSelfAttention.__init__ and call nn.Module.__init__ directly so that
        # c_attn below is created as a LoRA MergedLinear instead of a plain nn.Linear.
        nn.Module.__init__(self)
        assert config.n_embd % config.n_head == 0

        # query, key and value projections for all heads, computed in one batched matmul;
        # LoRA is applied to the query and value projections only (enable_lora=[True, False, True]).
|
self.c_attn = MergedLinear( |
|
in_features=config.n_embd, |
|
out_features=3 * config.n_embd, |
|
r=self.lora_config.r, |
|
lora_alpha=self.lora_config.alpha, |
|
lora_dropout=self.lora_config.dropout, |
|
enable_lora=[True, False, True], |
|
fan_in_fan_out = False, |
|
merge_weights=True, |
|
bias=False) |
|
|
|
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) |
|
|
|
self.n_head = config.n_head |
|
self.n_embd = config.n_embd |
|
self.block_size = config.block_size |
|
self.rope_cache = None |
|
|
|
|
|
@contextmanager |
|
def lora(r: int, alpha: float, dropout: float, enabled: bool = True):
|
"""Apply context manager under which you can instantiate the model with LoRA. |
|
|
|
In a nutshell the code inside this function forces to use LoRA variant of causal self-attention |
|
instead of the original one (without LoRA). |
|
|
|
Args: |
|
        r: rank of the weight update matrices. For LoRA to be worthwhile, the rank should be much smaller than
            the rank of the model's weight matrices. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
|
alpha: alpha is needed for scaling updates as alpha/r |
|
"This scaling helps to reduce the need to retune hyperparameters when we vary r" |
|
https://arxiv.org/pdf/2106.09685.pdf (section 4.1) |
|
dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A) |
|
enabled: enables/disables LoRA |
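
    Example (illustrative; assumes lit_llama's model constructor)::

        with lora(r=8, alpha=16, dropout=0.05, enabled=True):
            model = llama.LLaMA.from_name("7B")
        mark_only_lora_as_trainable(model)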
|
""" |
|
if not enabled: |
|
yield |
|
return |
|
|
|
    CausalSelfAttention.lora_config = LoRAConfig(r=r, alpha=alpha, dropout=dropout)

    # While inside the context manager, point lit_llama at the LoRA variant of causal
    # self-attention so that models instantiated here pick it up.
    causal_self_attention = llama.CausalSelfAttention
    llama.CausalSelfAttention = CausalSelfAttention
    yield

    # On exit, restore the original attention class and drop the LoRA config.
    llama.CausalSelfAttention = causal_self_attention
    CausalSelfAttention.lora_config = None
|
|