from dataclasses import dataclass

import loralib as lora


@dataclass
class LoraConfig:
    r: int
    lora_alpha: float
    lora_dropout: float = 0.0
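
# Example: LoraConfig(r=8, lora_alpha=16.0) adapts each weight W as
# W + (lora_alpha / r) * B @ A, with A and B the trainable low-rank factors.
# The r=8 / alpha=16 values are illustrative, not defaults of this module.
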

def setup_lora(model, lora_config):
    # Replace the token embedding layer with a LoRA-augmented embedding
    model.embeddings = lora.Embedding(
        num_embeddings=model.embeddings.num_embeddings,
        embedding_dim=model.embeddings.embedding_dim,
        padding_idx=model.embeddings.padding_idx,
        r=lora_config.r,
        lora_alpha=lora_config.lora_alpha,
    )

    # Do the same for the codebook embeddings
    model.codebook_embeddings = lora.Embedding(
        num_embeddings=model.codebook_embeddings.num_embeddings,
        embedding_dim=model.codebook_embeddings.embedding_dim,
        padding_idx=model.codebook_embeddings.padding_idx,
        r=lora_config.r,
        lora_alpha=lora_config.lora_alpha,
    )

    # Collect (module, attribute) pairs for every Linear to swap out,
    # starting with the output head...
    linears = [(model, "output")]

    # ...then the attention and feed-forward projections of every layer
    for layer in model.layers:
        linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
        linears.extend(
            [
                (layer.feed_forward, "w1"),
                (layer.feed_forward, "w2"),
                (layer.feed_forward, "w3"),
            ]
        )

    # Dual-AR models carry a second "fast" transformer stack; give its
    # embeddings, output head, and per-layer projections the same treatment
    if hasattr(model, "fast_layers"):
        model.fast_embeddings = lora.Embedding(
            num_embeddings=model.fast_embeddings.num_embeddings,
            embedding_dim=model.fast_embeddings.embedding_dim,
            padding_idx=model.fast_embeddings.padding_idx,
            r=lora_config.r,
            lora_alpha=lora_config.lora_alpha,
        )

        linears.append((model, "fast_output"))

        for layer in model.fast_layers:
            linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
            linears.extend(
                [
                    (layer.feed_forward, "w1"),
                    (layer.feed_forward, "w2"),
                    (layer.feed_forward, "w3"),
                ]
            )

    # Swap every collected Linear for a LoRA-augmented one. The new layers are
    # freshly initialized, so the base weights are expected to be loaded from
    # a checkpoint afterwards.
    for module, layer in linears:
        updated_linear = lora.Linear(
            in_features=getattr(module, layer).in_features,
            out_features=getattr(module, layer).out_features,
            # nn.Linear's `bias` argument is a bool, not the bias tensor itself
            bias=getattr(module, layer).bias is not None,
            r=lora_config.r,
            lora_alpha=lora_config.lora_alpha,
            lora_dropout=lora_config.lora_dropout,
        )
        setattr(module, layer, updated_linear)

    # Freeze everything except the LoRA parameters
    lora.mark_only_lora_as_trainable(model, bias="none")

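
# Usage sketch (hedged: `build_model` and `base_weights` are stand-ins, not
# part of this module; any model exposing the attributes used above works):
#
#     model = build_model()
#     setup_lora(model, LoraConfig(r=8, lora_alpha=16.0))
#     model.load_state_dict(base_weights, strict=False)  # restore base weights
#     # ...fine-tune: only the lora_A / lora_B tensors receive gradients...
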

def get_merged_state_dict(model):
    # Switching to eval mode makes loralib fold the LoRA deltas into the base
    # weights (its layers merge on train(False) when merge_weights is enabled)
    model.eval()

    # With the deltas merged, the LoRA parameters themselves can be dropped,
    # leaving a state dict that loads into the plain (non-LoRA) model
    state_dict = model.state_dict()
    for name in list(state_dict.keys()):
        if "lora" in name:
            state_dict.pop(name)

    return state_dict
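

# The self-check below is an illustrative addition, not part of the original
# module: it builds a toy model with the attribute layout setup_lora expects,
# applies LoRA, and verifies that only LoRA parameters train and that the
# merged state dict contains no LoRA tensors.
def _lora_smoke_test():
    from torch import nn

    class ToyBlock(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.attention = nn.Module()
            self.attention.wqkv = nn.Linear(dim, 3 * dim, bias=False)
            self.attention.wo = nn.Linear(dim, dim, bias=False)
            self.feed_forward = nn.Module()
            self.feed_forward.w1 = nn.Linear(dim, 4 * dim, bias=False)
            self.feed_forward.w2 = nn.Linear(4 * dim, dim, bias=False)
            self.feed_forward.w3 = nn.Linear(dim, 4 * dim, bias=False)

    class ToyModel(nn.Module):
        def __init__(self, vocab=32, dim=16, n_layers=2):
            super().__init__()
            self.embeddings = nn.Embedding(vocab, dim, padding_idx=0)
            self.codebook_embeddings = nn.Embedding(vocab, dim, padding_idx=0)
            self.layers = nn.ModuleList([ToyBlock(dim) for _ in range(n_layers)])
            self.output = nn.Linear(dim, vocab, bias=False)

    model = ToyModel()
    setup_lora(model, LoraConfig(r=4, lora_alpha=8.0))

    trainable = [n for n, p in model.named_parameters() if p.requires_grad]
    assert trainable and all("lora_" in n for n in trainable)

    merged = get_merged_state_dict(model)
    assert not any("lora" in k for k in merged)


if __name__ == "__main__":
    _lora_smoke_test()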