# Copyright (c) OpenMMLab. All rights reserved.
import copy

from mmcv.runner.optimizer import (
    OPTIMIZER_BUILDERS as MMCV_OPTIMIZER_BUILDERS)
from mmcv.utils import Registry, build_from_cfg

# Registry for optimizer constructors defined in this package, created as a
# child of mmcv's optimizer-builder registry.
OPTIMIZER_BUILDERS = Registry(
    'optimizer builder', parent=MMCV_OPTIMIZER_BUILDERS)


def build_optimizer_constructor(cfg):
    """Build an optimizer constructor from a config dict.

    The local ``OPTIMIZER_BUILDERS`` registry is searched first. If the
    requested type is not registered there, mmcv's ``OPTIMIZER_BUILDERS``
    registry is tried next.

    Args:
        cfg (dict): Constructor config. It must contain the key ``type``.
            The whole dict is forwarded to ``build_from_cfg``.

    Returns:
        The constructor object built by ``build_from_cfg`` from whichever
        registry contains ``cfg['type']``.

    Raises:
        KeyError: If ``cfg`` has no ``type`` key, or if the type is not
            registered in either registry.
    """
    constructor_type = cfg.get('type')
    if constructor_type is None:
        # Fail early with a clear message rather than handing None to the
        # registry membership checks below.
        raise KeyError('`cfg` must contain the key "type"')

    if constructor_type in OPTIMIZER_BUILDERS:
        return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
    # Fall back to mmcv's registry explicitly. Presumably an unscoped lookup
    # in the child registry does not search its parent; confirm against
    # mmcv's Registry.get.
    if constructor_type in MMCV_OPTIMIZER_BUILDERS:
        return build_from_cfg(cfg, MMCV_OPTIMIZER_BUILDERS)
    raise KeyError(f'{constructor_type} is not registered '
                   'in the optimizer builder registry.')


def build_optimizer(model, cfg):
    """Build an optimizer for ``model`` from an optimizer config.

    ``cfg`` is deep-copied, so the caller's config is never mutated. Two
    keys are removed from the copy:

    * ``constructor``: the registry type of the optimizer constructor.
      Defaults to ``'DefaultOptimizerConstructor'``.
    * ``paramwise_cfg``: forwarded to the constructor. Defaults to ``None``.

    The remaining keys are passed to the constructor as ``optimizer_cfg``.
    The constructor is then called with ``model``.

    Args:
        model: The model whose parameters the optimizer will manage.
        cfg (dict): Optimizer config, optionally including the
            ``constructor`` and ``paramwise_cfg`` keys described above.

    Returns:
        Whatever the constructor returns when called with ``model``.

    Raises:
        KeyError: Propagated from ``build_optimizer_constructor`` if the
            constructor type is not registered.
    """
    optimizer_cfg = copy.deepcopy(cfg)
    constructor_type = optimizer_cfg.pop('constructor',
                                         'DefaultOptimizerConstructor')
    paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
    optim_constructor = build_optimizer_constructor(
        dict(
            type=constructor_type,
            optimizer_cfg=optimizer_cfg,
            paramwise_cfg=paramwise_cfg))
    optimizer = optim_constructor(model)
    return optimizer