Spaces:
Running
Running
File size: 1,625 Bytes
a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import torch
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR, ExponentialLR
def build_optimizer(model, config):
    """Create the optimizer selected by ``config.TRAINER.OPTIMIZER``.

    Args:
        model: Object whose ``parameters()`` are handed to the optimizer.
        config: Config object. Reads ``TRAINER.OPTIMIZER`` and
            ``TRAINER.TRUE_LR``, plus the weight-decay field matching the
            chosen optimizer: ``TRAINER.ADAM_DECAY`` for "adam",
            ``TRAINER.ADAMW_DECAY`` for "adamw".

    Returns:
        A ``torch.optim.Adam`` or ``torch.optim.AdamW`` instance.

    Raises:
        ValueError: If the optimizer name is neither "adam" nor "adamw".
    """
    trainer_cfg = config.TRAINER
    optimizer_name = trainer_cfg.OPTIMIZER
    learning_rate = trainer_cfg.TRUE_LR

    # Each supported name maps to (optimizer class, its weight-decay config field).
    registry = {
        "adam": (torch.optim.Adam, "ADAM_DECAY"),
        "adamw": (torch.optim.AdamW, "ADAMW_DECAY"),
    }
    if optimizer_name not in registry:
        raise ValueError(f"TRAINER.OPTIMIZER = {optimizer_name} is not a valid optimizer!")

    optimizer_cls, decay_field = registry[optimizer_name]
    # Only the decay field for the selected optimizer is read from the config.
    return optimizer_cls(
        model.parameters(),
        lr=learning_rate,
        weight_decay=getattr(trainer_cfg, decay_field),
    )
def build_scheduler(config, optimizer):
    """Create the learning-rate scheduler selected by ``config.TRAINER.SCHEDULER``.

    Supported names and the extra config fields each one reads:
        - "MultiStepLR": ``TRAINER.MSLR_MILESTONES``, ``TRAINER.MSLR_GAMMA``
        - "CosineAnnealing" (alias "CosineAnnealingLR"): ``TRAINER.COSA_TMAX``,
          passed as ``T_max``
        - "ExponentialLR": ``TRAINER.ELR_GAMMA``

    Args:
        config: Config object. Always reads ``TRAINER.SCHEDULER_INTERVAL`` and
            ``TRAINER.SCHEDULER``, plus the fields listed above.
        optimizer: The optimizer whose learning rate is scheduled.

    Returns:
        dict: ``{"interval": TRAINER.SCHEDULER_INTERVAL, "scheduler": lr_scheduler}``.
        The interval is expected to be ``'step'`` or ``'epoch'``. Optional keys
        such as ``'monitor'`` or ``'frequency'`` are not set here.

    Raises:
        NotImplementedError: If the scheduler name is not one of the above.
    """
    # Read the interval first, matching the original access order.
    interval = config.TRAINER.SCHEDULER_INTERVAL
    name = config.TRAINER.SCHEDULER

    if name == "MultiStepLR":
        lr_scheduler = MultiStepLR(
            optimizer,
            config.TRAINER.MSLR_MILESTONES,
            gamma=config.TRAINER.MSLR_GAMMA,
        )
    elif name in ("CosineAnnealing", "CosineAnnealingLR"):
        # The historical key is "CosineAnnealing". The class name is also
        # accepted so it behaves like the other two keys.
        lr_scheduler = CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)
    elif name == "ExponentialLR":
        lr_scheduler = ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)
    else:
        raise NotImplementedError(
            f"TRAINER.SCHEDULER = {name} is not a supported scheduler! "
            "Expected one of: MultiStepLR, CosineAnnealing, ExponentialLR."
        )

    return {"interval": interval, "scheduler": lr_scheduler}
|