from dataclasses import dataclass

import loralib as lora


@dataclass
class LoraConfig:
    r: int
    lora_alpha: float
    lora_dropout: float = 0.0


def setup_lora(model, lora_config):
    # Replace the token and codebook embedding layers with LoRA embeddings
    model.embeddings = lora.Embedding(
        num_embeddings=model.embeddings.num_embeddings,
        embedding_dim=model.embeddings.embedding_dim,
        padding_idx=model.embeddings.padding_idx,
        r=lora_config.r,
        lora_alpha=lora_config.lora_alpha,
    )
    model.codebook_embeddings = lora.Embedding(
        num_embeddings=model.codebook_embeddings.num_embeddings,
        embedding_dim=model.codebook_embeddings.embedding_dim,
        padding_idx=model.codebook_embeddings.padding_idx,
        r=lora_config.r,
        lora_alpha=lora_config.lora_alpha,
    )

    # Queue the output head for LoRA replacement (performed below)
    linears = [(model, "output")]

    # Queue every attention and feed-forward linear layer as well
    for layer in model.layers:
        linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
        linears.extend(
            [
                (layer.feed_forward, "w1"),
                (layer.feed_forward, "w2"),
                (layer.feed_forward, "w3"),
            ]
        )
if hasattr(model, "fast_layers"):
model.fast_embeddings = lora.Embedding(
num_embeddings=model.fast_embeddings.num_embeddings,
embedding_dim=model.fast_embeddings.embedding_dim,
padding_idx=model.fast_embeddings.padding_idx,
r=lora_config.r,
lora_alpha=lora_config.lora_alpha,
)
# Dual-AR model
linears.append((model, "fast_output"))
for layer in model.fast_layers:
linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
linears.extend(
[
(layer.feed_forward, "w1"),
(layer.feed_forward, "w2"),
(layer.feed_forward, "w3"),
]
)

    # Swap every collected linear layer for a LoRA linear layer
    for module, name in linears:
        old_linear = getattr(module, name)
        updated_linear = lora.Linear(
            in_features=old_linear.in_features,
            out_features=old_linear.out_features,
            # nn.Linear.bias is a tensor or None; lora.Linear expects a bool
            bias=old_linear.bias is not None,
            r=lora_config.r,
            lora_alpha=lora_config.lora_alpha,
            lora_dropout=lora_config.lora_dropout,
        )
        setattr(module, name, updated_linear)

    # Mark only the LoRA layers as trainable
    lora.mark_only_lora_as_trainable(model, bias="none")


def get_merged_state_dict(model):
    # Switching to eval mode makes loralib merge the LoRA weights into the
    # base weights
    model.eval()

    # Drop the now-redundant LoRA parameters so the state dict matches the
    # original, non-LoRA architecture
    state_dict = model.state_dict()
    for name in list(state_dict.keys()):
        if "lora" in name:
            state_dict.pop(name)

    return state_dict
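

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It builds a
# tiny hypothetical model with the attribute layout setup_lora() expects,
# applies LoRA, and exports a merged, LoRA-free state dict. All class names
# and dimensions below are made-up placeholders for the real model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch import nn

    class _Attention(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.wqkv = nn.Linear(dim, 3 * dim, bias=False)
            self.wo = nn.Linear(dim, dim, bias=False)

    class _FeedForward(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.w1 = nn.Linear(dim, 4 * dim, bias=False)
            self.w2 = nn.Linear(4 * dim, dim, bias=False)
            self.w3 = nn.Linear(dim, 4 * dim, bias=False)

    class _Block(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.attention = _Attention(dim)
            self.feed_forward = _FeedForward(dim)

    class _DummyModel(nn.Module):
        def __init__(self, dim=32, vocab=100, codebook=64, n_layers=2):
            super().__init__()
            self.embeddings = nn.Embedding(vocab, dim, padding_idx=0)
            self.codebook_embeddings = nn.Embedding(codebook, dim, padding_idx=0)
            self.layers = nn.ModuleList([_Block(dim) for _ in range(n_layers)])
            self.output = nn.Linear(dim, vocab, bias=False)

    model = _DummyModel()
    setup_lora(model, LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05))

    # After setup, only LoRA parameters require gradients.
    trainable = [n for n, p in model.named_parameters() if p.requires_grad]
    print(f"{len(trainable)} trainable tensors, e.g. {trainable[0]}")

    # The merged state dict contains no LoRA tensors, so it can be loaded back
    # into the original, non-LoRA architecture.
    merged = get_merged_state_dict(model)
    assert not any("lora" in key for key in merged)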