Spaces:
Paused
Paused
# Copyright (c) OpenMMLab. All rights reserved. | |
import copy | |
import logging | |
import os.path as osp | |
import warnings | |
from abc import ABCMeta, abstractmethod | |
import torch | |
from torch.optim import Optimizer | |
import annotator.mmpkg.mmcv as mmcv | |
from ..parallel import is_module_wrapper | |
from .checkpoint import load_checkpoint | |
from .dist_utils import get_dist_info | |
from .hooks import HOOKS, Hook | |
from .log_buffer import LogBuffer | |
from .priority import Priority, get_priority | |
from .utils import get_time_str | |
class BaseRunner(metaclass=ABCMeta):
    """The base class of Runner, a training helper for PyTorch.

    All subclasses should implement the following APIs:

    - ``run()``
    - ``train()``
    - ``val()``
    - ``save_checkpoint()``

    Args:
        model (:obj:`torch.nn.Module`): The model to be run.
        batch_processor (callable): A callable method that process a data
            batch. The interface of this method should be
            `batch_processor(model, data, train_mode) -> dict`
        optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an
            optimizer (in most cases) or a dict of optimizers (in models that
            requires more than one optimizer, e.g., GAN).
        work_dir (str, optional): The working directory to save checkpoints
            and logs. Defaults to None.
        logger (:obj:`logging.Logger`): Logger used during training. The
            signature default of None exists only for backward
            compatibility; passing anything that is not a
            ``logging.Logger`` raises ``TypeError``.
        meta (dict | None): A dict records some important information such
            as environment info and seed, which will be logged in logger
            hook. Defaults to None.
        max_epochs (int, optional): Total training epochs.
        max_iters (int, optional): Total training iterations.

    Raises:
        TypeError: If ``batch_processor``, ``optimizer``, ``logger``,
            ``meta`` or ``work_dir`` has an unsupported type.
        RuntimeError: If ``batch_processor`` is given and the model also
            defines ``train_step()`` or ``val_step()``.
        ValueError: If both ``max_epochs`` and ``max_iters`` are set.
    """

    def __init__(self,
                 model,
                 batch_processor=None,
                 optimizer=None,
                 work_dir=None,
                 logger=None,
                 meta=None,
                 max_iters=None,
                 max_epochs=None):
        if batch_processor is not None:
            if not callable(batch_processor):
                raise TypeError('batch_processor must be callable, '
                                f'but got {type(batch_processor)}')
            warnings.warn('batch_processor is deprecated, please implement '
                          'train_step() and val_step() in the model instead.')
            # raise an error if `batch_processor` is not None and
            # `model.train_step()` exists.
            if is_module_wrapper(model):
                _model = model.module
            else:
                _model = model
            if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'):
                raise RuntimeError(
                    'batch_processor and model.train_step()/model.val_step() '
                    'cannot be both available.')
        else:
            assert hasattr(model, 'train_step')

        # check the type of `optimizer`
        if isinstance(optimizer, dict):
            for name, optim in optimizer.items():
                if not isinstance(optim, Optimizer):
                    raise TypeError(
                        f'optimizer must be a dict of torch.optim.Optimizers, '
                        f'but optimizer["{name}"] is a {type(optim)}')
        elif not isinstance(optimizer, Optimizer) and optimizer is not None:
            raise TypeError(
                f'optimizer must be a torch.optim.Optimizer object '
                f'or dict or None, but got {type(optimizer)}')

        # check the type of `logger`
        if not isinstance(logger, logging.Logger):
            raise TypeError(f'logger must be a logging.Logger object, '
                            f'but got {type(logger)}')

        # check the type of `meta`
        if meta is not None and not isinstance(meta, dict):
            raise TypeError(
                f'meta must be a dict or None, but got {type(meta)}')

        self.model = model
        self.batch_processor = batch_processor
        self.optimizer = optimizer
        self.logger = logger
        self.meta = meta

        # create work_dir
        if mmcv.is_str(work_dir):
            self.work_dir = osp.abspath(work_dir)
            mmcv.mkdir_or_exist(self.work_dir)
        elif work_dir is None:
            self.work_dir = None
        else:
            raise TypeError('"work_dir" must be a str or None')

        # get model name from the model class (unwrap `.module` if present)
        if hasattr(self.model, 'module'):
            self._model_name = self.model.module.__class__.__name__
        else:
            self._model_name = self.model.__class__.__name__

        self._rank, self._world_size = get_dist_info()
        self.timestamp = get_time_str()
        self.mode = None
        self._hooks = []
        self._epoch = 0
        self._iter = 0
        self._inner_iter = 0

        if max_epochs is not None and max_iters is not None:
            raise ValueError(
                'Only one of `max_epochs` or `max_iters` can be set.')

        self._max_epochs = max_epochs
        self._max_iters = max_iters
        # TODO: Redesign LogBuffer, it is not flexible and elegant enough
        self.log_buffer = LogBuffer()

    # The accessors below must be properties: other methods of this class
    # (`resume`, `get_hook_info`) read them as plain attributes.
    @property
    def model_name(self):
        """str: Name of the model, usually the module class name."""
        return self._model_name

    @property
    def rank(self):
        """int: Rank of current process. (distributed training)"""
        return self._rank

    @property
    def world_size(self):
        """int: Number of processes participating in the job.
        (distributed training)"""
        return self._world_size

    @property
    def hooks(self):
        """list[:obj:`Hook`]: A list of registered hooks."""
        return self._hooks

    @property
    def epoch(self):
        """int: Current epoch."""
        return self._epoch

    @property
    def iter(self):
        """int: Current iteration."""
        return self._iter

    @property
    def inner_iter(self):
        """int: Iteration in an epoch."""
        return self._inner_iter

    @property
    def max_epochs(self):
        """int: Maximum training epochs."""
        return self._max_epochs

    @property
    def max_iters(self):
        """int: Maximum training iterations."""
        return self._max_iters

    @abstractmethod
    def train(self):
        """Run the training procedure. Must be implemented by subclasses."""
        pass

    @abstractmethod
    def val(self):
        """Run the validation procedure. Must be implemented by subclasses."""
        pass

    @abstractmethod
    def run(self, data_loaders, workflow, **kwargs):
        """Run the given workflow. Must be implemented by subclasses."""
        pass

    @abstractmethod
    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl,
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        """Save a checkpoint. Must be implemented by subclasses."""
        pass

    def current_lr(self):
        """Get current learning rates.

        Returns:
            list[float] | dict[str, list[float]]: Current learning rates of all
            param groups. If the runner has a dict of optimizers, this
            method will return a dict.

        Raises:
            RuntimeError: If the runner has neither an optimizer nor a dict
                of optimizers.
        """
        if isinstance(self.optimizer, torch.optim.Optimizer):
            return [group['lr'] for group in self.optimizer.param_groups]
        if isinstance(self.optimizer, dict):
            return {
                name: [group['lr'] for group in optim.param_groups]
                for name, optim in self.optimizer.items()
            }
        raise RuntimeError(
            'lr is not applicable because optimizer does not exist.')

    def current_momentum(self):
        """Get current momentums.

        For each param group the value is ``group['momentum']`` if present,
        otherwise the first element of ``group['betas']``, otherwise 0.

        Returns:
            list[float] | dict[str, list[float]]: Current momentums of all
            param groups. If the runner has a dict of optimizers, this
            method will return a dict.

        Raises:
            RuntimeError: If the runner has no optimizer.
            TypeError: If ``self.optimizer`` is neither an optimizer nor a
                dict of optimizers.
        """

        def _get_momentum(optimizer):
            momentums = []
            for group in optimizer.param_groups:
                if 'momentum' in group:
                    momentums.append(group['momentum'])
                elif 'betas' in group:
                    momentums.append(group['betas'][0])
                else:
                    momentums.append(0)
            return momentums

        if self.optimizer is None:
            raise RuntimeError(
                'momentum is not applicable because optimizer does not exist.')
        if isinstance(self.optimizer, torch.optim.Optimizer):
            return _get_momentum(self.optimizer)
        if isinstance(self.optimizer, dict):
            return {
                name: _get_momentum(optim)
                for name, optim in self.optimizer.items()
            }
        # Previously this case fell through to an unbound local variable.
        raise TypeError('Optimizer should be dict or torch.optim.Optimizer '
                        f'but got {type(self.optimizer)}')

    def register_hook(self, hook, priority='NORMAL'):
        """Register a hook into the hook list.

        The hook will be inserted into a priority queue, with the specified
        priority (See :class:`Priority` for details of priorities).
        For hooks with the same priority, they will be triggered in the same
        order as they are registered.

        Args:
            hook (:obj:`Hook`): The hook to be registered.
            priority (int or str or :obj:`Priority`): Hook priority.
                Lower value means higher priority.

        Raises:
            ValueError: If ``hook`` already has a ``priority`` attribute.
        """
        assert isinstance(hook, Hook)
        if hasattr(hook, 'priority'):
            raise ValueError('"priority" is a reserved attribute for hooks')
        priority = get_priority(priority)
        hook.priority = priority
        # Insert the hook into the sorted list. Scanning from the tail and
        # using `>=` places it after existing hooks of equal priority.
        for i in range(len(self._hooks) - 1, -1, -1):
            if priority >= self._hooks[i].priority:
                self._hooks.insert(i + 1, hook)
                break
        else:
            # No existing hook has a priority value <= the new one.
            self._hooks.insert(0, hook)

    def register_hook_from_cfg(self, hook_cfg):
        """Register a hook from its cfg.

        Args:
            hook_cfg (dict): Hook config. It should have at least keys 'type'
                and 'priority' indicating its type and priority. If
                'priority' is absent, 'NORMAL' is used. The dict passed in
                is not modified.

        Notes:
            The specific hook class to register should not use 'type' and
            'priority' arguments during initialization.
        """
        hook_cfg = hook_cfg.copy()
        priority = hook_cfg.pop('priority', 'NORMAL')
        hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
        self.register_hook(hook, priority=priority)

    def call_hook(self, fn_name):
        """Call all hooks.

        Args:
            fn_name (str): The function name in each hook to be called, such as
                "before_train_epoch".
        """
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

    def get_hook_info(self):
        """Summarise the registered hooks grouped by the stage they run in.

        Returns:
            str: One section per stage in ``Hook.stages`` that has at least
            one hook, listing each hook's priority and class name.
        """
        # Get hooks info in each stage
        stage_hook_map = {stage: [] for stage in Hook.stages}
        for hook in self.hooks:
            try:
                priority = Priority(hook.priority).name
            except ValueError:
                # Not a named Priority member; show the raw value instead.
                priority = hook.priority
            classname = hook.__class__.__name__
            hook_info = f'({priority:<12}) {classname:<35}'
            for trigger_stage in hook.get_triggered_stages():
                stage_hook_map[trigger_stage].append(hook_info)

        stage_hook_infos = []
        for stage in Hook.stages:
            hook_infos = stage_hook_map[stage]
            if len(hook_infos) > 0:
                info = f'{stage}:\n'
                info += '\n'.join(hook_infos)
                info += '\n -------------------- '
                stage_hook_infos.append(info)
        return '\n'.join(stage_hook_infos)

    def load_checkpoint(self,
                        filename,
                        map_location='cpu',
                        strict=False,
                        revise_keys=None):
        """Load a checkpoint into ``self.model``.

        Args:
            filename (str): Checkpoint location, forwarded to the
                module-level ``load_checkpoint`` helper.
            map_location: Forwarded to the helper. Defaults to 'cpu'.
            strict (bool): Forwarded to the helper. Defaults to False.
            revise_keys (list[tuple] | None): (pattern, replacement) pairs
                forwarded to the helper. Defaults to None, which means
                ``[(r'^module\\.', '')]``. The dot is escaped so that only a
                literal ``module.`` prefix matches.

        Returns:
            Whatever the ``load_checkpoint`` helper returns; ``resume()``
            indexes it as a dict with a 'meta' key.
        """
        if revise_keys is None:
            # Build a fresh list per call rather than sharing a mutable
            # default between calls.
            revise_keys = [(r'^module\.', '')]
        return load_checkpoint(
            self.model,
            filename,
            map_location,
            strict,
            self.logger,
            revise_keys=revise_keys)

    def resume(self,
               checkpoint,
               resume_optimizer=True,
               map_location='default'):
        """Resume epoch/iter counters, meta and optimizer from a checkpoint.

        Args:
            checkpoint (str): Checkpoint to load via ``load_checkpoint()``.
            resume_optimizer (bool): Whether to restore optimizer state when
                the checkpoint contains an 'optimizer' entry.
                Defaults to True.
            map_location: 'default' maps storages to the current CUDA device
                when CUDA is available and otherwise uses the
                ``load_checkpoint()`` default; any other value is forwarded
                unchanged.

        Raises:
            TypeError: If optimizer state should be restored but
                ``self.optimizer`` is neither an optimizer nor a dict.
        """
        if map_location == 'default':
            if torch.cuda.is_available():
                device_id = torch.cuda.current_device()
                checkpoint = self.load_checkpoint(
                    checkpoint,
                    map_location=lambda storage, loc: storage.cuda(device_id))
            else:
                checkpoint = self.load_checkpoint(checkpoint)
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        if self.meta is None:
            self.meta = {}
        self.meta.setdefault('hook_msgs', {})
        # load `last_ckpt`, `best_score`, `best_ckpt`, etc. for hook messages
        self.meta['hook_msgs'].update(checkpoint['meta'].get('hook_msgs', {}))

        # Re-calculate the number of iterations when resuming
        # models with different number of GPUs
        if 'config' in checkpoint['meta']:
            config = mmcv.Config.fromstring(
                checkpoint['meta']['config'], file_format='.py')
            previous_gpu_ids = config.get('gpu_ids', None)
            if previous_gpu_ids and len(previous_gpu_ids) > 0 and len(
                    previous_gpu_ids) != self.world_size:
                self._iter = int(self._iter * len(previous_gpu_ids) /
                                 self.world_size)
                self.logger.info('the iteration number is changed due to '
                                 'change of GPU number')

        # resume meta information
        # NOTE(review): this replaces the dict whose 'hook_msgs' was merged
        # above; kept as-is to preserve existing behaviour.
        self.meta = checkpoint['meta']

        if 'optimizer' in checkpoint and resume_optimizer:
            if isinstance(self.optimizer, Optimizer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            elif isinstance(self.optimizer, dict):
                for k in self.optimizer.keys():
                    self.optimizer[k].load_state_dict(
                        checkpoint['optimizer'][k])
            else:
                raise TypeError(
                    'Optimizer should be dict or torch.optim.Optimizer '
                    f'but got {type(self.optimizer)}')

        self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)

    def register_lr_hook(self, lr_config):
        """Register the learning-rate updater hook with VERY_HIGH priority.

        Args:
            lr_config (dict | object | None): None registers nothing. A dict
                must contain 'policy', from which the hook type
                ``<Policy>LrUpdaterHook`` is derived. Anything else is
                registered directly as the hook. A dict passed in is not
                modified.
        """
        if lr_config is None:
            return
        elif isinstance(lr_config, dict):
            assert 'policy' in lr_config
            # Work on a shallow copy so the caller's config keeps 'policy'
            # and can be reused.
            lr_config = lr_config.copy()
            policy_type = lr_config.pop('policy')
            # If the type of policy is all in lower case, e.g., 'cyclic',
            # then its first letter will be capitalized, e.g., to be 'Cyclic'.
            # This is for the convenient usage of Lr updater.
            # Since this is not applicable for `CosineAnnealingLrUpdater`,
            # the string will not be changed if it contains capital letters.
            if policy_type == policy_type.lower():
                policy_type = policy_type.title()
            hook_type = policy_type + 'LrUpdaterHook'
            lr_config['type'] = hook_type
            hook = mmcv.build_from_cfg(lr_config, HOOKS)
        else:
            hook = lr_config
        self.register_hook(hook, priority='VERY_HIGH')

    def register_momentum_hook(self, momentum_config):
        """Register the momentum updater hook with HIGH priority.

        Args:
            momentum_config (dict | object | None): None registers nothing.
                A dict must contain 'policy', from which the hook type
                ``<Policy>MomentumUpdaterHook`` is derived. Anything else is
                registered directly as the hook. A dict passed in is not
                modified.
        """
        if momentum_config is None:
            return
        if isinstance(momentum_config, dict):
            assert 'policy' in momentum_config
            # Shallow copy: do not strip 'policy' from the caller's config.
            momentum_config = momentum_config.copy()
            policy_type = momentum_config.pop('policy')
            # If the type of policy is all in lower case, e.g., 'cyclic',
            # then its first letter will be capitalized, e.g., to be 'Cyclic'.
            # This is for the convenient usage of momentum updater.
            # Since this is not applicable for
            # `CosineAnnealingMomentumUpdater`,
            # the string will not be changed if it contains capital letters.
            if policy_type == policy_type.lower():
                policy_type = policy_type.title()
            hook_type = policy_type + 'MomentumUpdaterHook'
            momentum_config['type'] = hook_type
            hook = mmcv.build_from_cfg(momentum_config, HOOKS)
        else:
            hook = momentum_config
        self.register_hook(hook, priority='HIGH')

    def register_optimizer_hook(self, optimizer_config):
        """Register the optimizer hook with ABOVE_NORMAL priority.

        Args:
            optimizer_config (dict | object | None): None registers nothing.
                A dict is built into a hook, with 'type' defaulting to
                'OptimizerHook'. Anything else is registered directly. A
                dict passed in is not modified.
        """
        if optimizer_config is None:
            return
        if isinstance(optimizer_config, dict):
            optimizer_config = optimizer_config.copy()
            optimizer_config.setdefault('type', 'OptimizerHook')
            hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
        else:
            hook = optimizer_config
        self.register_hook(hook, priority='ABOVE_NORMAL')

    def register_checkpoint_hook(self, checkpoint_config):
        """Register the checkpoint hook with NORMAL priority.

        Args:
            checkpoint_config (dict | object | None): None registers
                nothing. A dict is built into a hook, with 'type' defaulting
                to 'CheckpointHook'. Anything else is registered directly.
                A dict passed in is not modified.
        """
        if checkpoint_config is None:
            return
        if isinstance(checkpoint_config, dict):
            checkpoint_config = checkpoint_config.copy()
            checkpoint_config.setdefault('type', 'CheckpointHook')
            hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
        else:
            hook = checkpoint_config
        self.register_hook(hook, priority='NORMAL')

    def register_logger_hooks(self, log_config):
        """Register every logger hook in ``log_config`` with VERY_LOW priority.

        Args:
            log_config (dict | None): None registers nothing. Otherwise it
                must provide 'interval' and 'hooks'; 'interval' is passed as
                a default argument when building each entry of 'hooks'.
        """
        if log_config is None:
            return
        log_interval = log_config['interval']
        for info in log_config['hooks']:
            logger_hook = mmcv.build_from_cfg(
                info, HOOKS, default_args=dict(interval=log_interval))
            self.register_hook(logger_hook, priority='VERY_LOW')

    def register_timer_hook(self, timer_config):
        """Register the timer hook with LOW priority.

        Args:
            timer_config (dict | object | None): None registers nothing. A
                dict is deep-copied and built into a hook. Anything else is
                registered directly.
        """
        if timer_config is None:
            return
        if isinstance(timer_config, dict):
            timer_config_ = copy.deepcopy(timer_config)
            hook = mmcv.build_from_cfg(timer_config_, HOOKS)
        else:
            hook = timer_config
        self.register_hook(hook, priority='LOW')

    def register_custom_hooks(self, custom_config):
        """Register user-supplied hooks.

        Args:
            custom_config (list | dict | object | None): None registers
                nothing. A non-list value is treated as a single item. Dict
                items go through ``register_hook_from_cfg()``; other items
                are registered directly with NORMAL priority.
        """
        if custom_config is None:
            return

        if not isinstance(custom_config, list):
            custom_config = [custom_config]

        for item in custom_config:
            if isinstance(item, dict):
                self.register_hook_from_cfg(item)
            else:
                self.register_hook(item, priority='NORMAL')

    def register_profiler_hook(self, profiler_config):
        """Register the profiler hook with the default priority.

        Args:
            profiler_config (dict | object | None): None registers nothing.
                A dict is built into a hook, with 'type' defaulting to
                'ProfilerHook'. Anything else is registered directly. A
                dict passed in is not modified.
        """
        if profiler_config is None:
            return
        if isinstance(profiler_config, dict):
            profiler_config = profiler_config.copy()
            profiler_config.setdefault('type', 'ProfilerHook')
            hook = mmcv.build_from_cfg(profiler_config, HOOKS)
        else:
            hook = profiler_config
        self.register_hook(hook)

    def register_training_hooks(self,
                                lr_config,
                                optimizer_config=None,
                                checkpoint_config=None,
                                log_config=None,
                                momentum_config=None,
                                timer_config=dict(type='IterTimerHook'),
                                custom_hooks_config=None):
        """Register default and custom hooks for training.

        Default and custom hooks include:

        +----------------------+-------------------------+
        | Hooks                | Priority                |
        +======================+=========================+
        | LrUpdaterHook        | VERY_HIGH (10)          |
        +----------------------+-------------------------+
        | MomentumUpdaterHook  | HIGH (30)               |
        +----------------------+-------------------------+
        | OptimizerStepperHook | ABOVE_NORMAL (40)       |
        +----------------------+-------------------------+
        | CheckpointSaverHook  | NORMAL (50)             |
        +----------------------+-------------------------+
        | IterTimerHook        | LOW (70)                |
        +----------------------+-------------------------+
        | LoggerHook(s)        | VERY_LOW (90)           |
        +----------------------+-------------------------+
        | CustomHook(s)        | defaults to NORMAL (50) |
        +----------------------+-------------------------+

        If custom hooks have same priority with default hooks, custom hooks
        will be triggered after default hooks.
        """
        # The dict default for `timer_config` is never mutated:
        # `register_timer_hook` deep-copies it before building the hook.
        self.register_lr_hook(lr_config)
        self.register_momentum_hook(momentum_config)
        self.register_optimizer_hook(optimizer_config)
        self.register_checkpoint_hook(checkpoint_config)
        self.register_timer_hook(timer_config)
        self.register_logger_hooks(log_config)
        self.register_custom_hooks(custom_hooks_config)