Spaces:
Runtime error
Runtime error
import sys | |
import warnings | |
from bisect import bisect_right | |
import torch | |
import torch.nn as nn | |
from torch.optim import lr_scheduler | |
import threestudio | |
def get_scheduler(name):
    """Look up a learning-rate scheduler class in ``torch.optim.lr_scheduler``.

    Args:
        name: Attribute name of the scheduler class, e.g. ``"StepLR"``.

    Returns:
        The attribute ``lr_scheduler.<name>`` itself (the class, not an instance).

    Raises:
        NotImplementedError: If ``torch.optim.lr_scheduler`` has no attribute
            called ``name``.
    """
    if not hasattr(lr_scheduler, name):
        # Keep NotImplementedError for existing callers, but report *which*
        # name was missing so config typos are easy to track down.
        raise NotImplementedError(
            f"Scheduler {name!r} not found in torch.optim.lr_scheduler"
        )
    return getattr(lr_scheduler, name)
def getattr_recursive(m, attr):
    """Resolve a dot-separated attribute path starting from ``m``.

    ``getattr_recursive(obj, "a.b.c")`` is equivalent to ``obj.a.b.c``.

    Args:
        m: Root object to start the lookup from.
        attr: Dot-separated attribute path.

    Returns:
        The object found at the end of the path.

    Raises:
        AttributeError: If any component of the path cannot be resolved
            (including empty components, e.g. ``""`` or ``"a."``).
    """
    head, sep, rest = attr.partition(".")
    child = getattr(m, head)
    # No separator means ``head`` was the last component of the path.
    return getattr_recursive(child, rest) if sep else child
def get_parameters(model, name):
    """Return the parameters found at the dotted attribute path ``name`` on ``model``.

    Args:
        model: Root object to search (resolved via ``getattr_recursive``).
        name: Dot-separated attribute path of a submodule or parameter.

    Returns:
        ``target.parameters()`` when the target is an ``nn.Module``; the target
        itself (not wrapped in a list) when it is an ``nn.Parameter``; an empty
        list for anything else.
    """
    target = getattr_recursive(model, name)
    if not isinstance(target, (nn.Module, nn.Parameter)):
        # Any other attribute (e.g. None or a plain, non-Parameter tensor)
        # contributes no parameters.
        return []
    # Module takes precedence, matching the original check order.
    return target.parameters() if isinstance(target, nn.Module) else target
def parse_optimizer(config, model):
    """Instantiate the optimizer described by ``config`` for ``model``.

    Args:
        config: Optimizer config. ``name`` selects the optimizer class and
            ``args`` is passed as keyword arguments to its constructor. If the
            config has a ``params`` attribute, each ``(path, group_args)`` item
            in it becomes one parameter group built from
            ``get_parameters(model, path)``, tagged with ``"name": path`` and
            merged with ``group_args``.
        model: Object providing the parameters. ``model.parameters()`` is used
            when no ``params`` groups are configured.

    Returns:
        The constructed optimizer instance.

    The optimizer class is looked up in ``apex.optimizers`` for ``"FusedAdam"``,
    in ``threestudio.systems.optimizers`` for ``"Adan"``, and in ``torch.optim``
    for every other name.
    """
    if not hasattr(config, "params"):
        params = model.parameters()
    else:
        params = [
            {"params": get_parameters(model, group_name), "name": group_name, **group_args}
            for group_name, group_args in config.params.items()
        ]
        threestudio.debug(f"Specify optimizer params: {config.params}")

    optimizer_name = config.name
    if optimizer_name == "FusedAdam":
        # Optional dependency: imported only when this optimizer is requested.
        import apex

        optimizer_cls = getattr(apex.optimizers, optimizer_name)
    elif optimizer_name == "Adan":
        from threestudio.systems import optimizers

        optimizer_cls = getattr(optimizers, optimizer_name)
    else:
        optimizer_cls = getattr(torch.optim, optimizer_name)
    return optimizer_cls(params, **config.args)
def parse_scheduler(config, optimizer):
    """Build a scheduler entry ``{"scheduler": ..., "interval": ...}`` from a config.

    Args:
        config: Scheduler config supporting attribute access and ``.get``
            (presumably an OmegaConf ``DictConfig`` — TODO confirm). Reads
            ``name`` and optional ``interval`` (default ``"epoch"``), plus:
              * ``SequentialLR``: ``schedulers`` (child configs) and ``milestones``;
              * ``ChainedScheduler``: ``schedulers`` (child configs);
              * any other name: ``args``, passed as keyword arguments to the
                ``torch.optim.lr_scheduler`` class of that name.
        optimizer: Optimizer that every (child) scheduler is attached to.

    Returns:
        A dict with the constructed scheduler under ``"scheduler"`` and this
        config's interval under ``"interval"``.

    Raises:
        ValueError: If ``interval`` is neither ``"epoch"`` nor ``"step"``.
        NotImplementedError: If a non-composite ``name`` is not found in
            ``torch.optim.lr_scheduler`` (raised by ``get_scheduler``).
    """
    interval = config.get("interval", "epoch")
    # Explicit check rather than ``assert``: assertions are stripped under
    # ``python -O``, which would let an invalid interval through silently.
    if interval not in ("epoch", "step"):
        raise ValueError(
            f"Scheduler interval must be 'epoch' or 'step', got {interval!r}"
        )

    if config.name in ("SequentialLR", "ChainedScheduler"):
        # Composite schedulers: build each child recursively and keep only the
        # scheduler object. A child's own "interval" is validated by the
        # recursive call but otherwise ignored in favour of this config's.
        children = [
            parse_scheduler(conf, optimizer)["scheduler"]
            for conf in config.schedulers
        ]
        if config.name == "SequentialLR":
            scheduler = lr_scheduler.SequentialLR(
                optimizer, children, milestones=config.milestones
            )
        else:
            scheduler = lr_scheduler.ChainedScheduler(children)
    else:
        scheduler = get_scheduler(config.name)(optimizer, **config.args)

    return {"scheduler": scheduler, "interval": interval}