Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp | |
import warnings | |
from annotator.uniformer.mmcv.fileio import FileClient | |
from ..dist_utils import allreduce_params, master_only | |
from .hook import HOOKS, Hook | |
# NOTE(review): ``HOOKS`` and ``master_only`` are imported by this module but
# are not used in the visible code. Confirm whether a registration decorator
# on the class and/or a ``master_only`` decorator on ``_save_checkpoint`` was
# lost before relying on registry lookup or single-rank saving.
class CheckpointHook(Hook):
    """Save checkpoints periodically.

    Args:
        interval (int): The saving period. If ``by_epoch=True``, interval
            indicates epochs, otherwise it indicates iterations.
            Default: -1, which means "never".
        by_epoch (bool): Saving checkpoints by epoch or by iteration.
            Default: True.
        save_optimizer (bool): Whether to save optimizer state_dict in the
            checkpoint. It is usually used for resuming experiments.
            Default: True.
        out_dir (str, optional): The root directory to save checkpoints. If
            not specified, ``runner.work_dir`` will be used by default. If
            specified, the ``out_dir`` will be the concatenation of
            ``out_dir`` and the last level directory of ``runner.work_dir``.
            `Changed in version 1.3.16.`
        max_keep_ckpts (int, optional): The maximum checkpoints to keep.
            In some cases we want only the latest few checkpoints and would
            like to delete old ones to save the disk space.
            Default: -1, which means unlimited.
        save_last (bool, optional): Whether to force the last checkpoint to
            be saved regardless of interval. Default: True.
        sync_buffer (bool, optional): Whether to synchronize buffers in
            different gpus. Default: False.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmcv.fileio.FileClient` for details.
            Default: None.
            `New in version 1.3.16.`
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``runner.save_checkpoint`` (this hook itself reads
            ``filename_tmpl`` and ``create_symlink`` from them).

    .. warning::
        Before v1.3.16, the ``out_dir`` argument indicates the path where the
        checkpoint is stored. However, since v1.3.16, ``out_dir`` indicates
        the root directory and the final path to save checkpoint is the
        concatenation of ``out_dir`` and the last level directory of
        ``runner.work_dir``. Suppose the value of ``out_dir`` is "/path/of/A"
        and the value of ``runner.work_dir`` is "/path/of/B", then the final
        path will be "/path/of/A/B".
    """

    def __init__(self,
                 interval=-1,
                 by_epoch=True,
                 save_optimizer=True,
                 out_dir=None,
                 max_keep_ckpts=-1,
                 save_last=True,
                 sync_buffer=False,
                 file_client_args=None,
                 **kwargs):
        self.interval = interval
        self.by_epoch = by_epoch
        self.save_optimizer = save_optimizer
        self.out_dir = out_dir
        self.max_keep_ckpts = max_keep_ckpts
        self.save_last = save_last
        # Passed through to ``runner.save_checkpoint``; ``before_run`` may
        # add or override the ``create_symlink`` entry.
        self.args = kwargs
        self.sync_buffer = sync_buffer
        self.file_client_args = file_client_args

    def before_run(self, runner):
        """Resolve the output directory and file backend before the run.

        Sets ``self.out_dir`` and ``self.file_client`` and normalises the
        ``create_symlink`` entry of ``self.args`` against what the backend
        allows.

        Args:
            runner: The runner; ``work_dir`` and ``logger`` are read.
        """
        if not self.out_dir:
            self.out_dir = runner.work_dir

        self.file_client = FileClient.infer_client(self.file_client_args,
                                                   self.out_dir)

        # If ``self.out_dir`` differs from ``runner.work_dir`` it was set
        # explicitly, so the final directory is ``out_dir`` joined with the
        # last path component of ``runner.work_dir``.
        if self.out_dir != runner.work_dir:
            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
            self.out_dir = self.file_client.join_path(self.out_dir, basename)

        runner.logger.info((f'Checkpoints will be saved to {self.out_dir} by '
                            f'{self.file_client.name}.'))

        # Some file backends cannot create symlinks: default the option to
        # whatever the backend supports, and override an explicit ``True``
        # (with a warning) when the backend does not allow it.
        if 'create_symlink' not in self.args:
            self.args['create_symlink'] = self.file_client.allow_symlink
        elif (self.args['create_symlink']
              and not self.file_client.allow_symlink):
            self.args['create_symlink'] = False
            warnings.warn(
                'create_symlink is set as True by the user but is changed '
                'to be False because creating symbolic link is not '
                f'allowed in {self.file_client.name}')

    def after_train_epoch(self, runner):
        """Save a checkpoint after an epoch when saving by epoch.

        A checkpoint is written every ``self.interval`` epochs and, when
        ``self.save_last`` is set, at the last epoch of training.
        """
        if not self.by_epoch:
            return

        if self.every_n_epochs(runner, self.interval) or (
                self.save_last and self.is_last_epoch(runner)):
            runner.logger.info(
                f'Saving checkpoint at {runner.epoch + 1} epochs')
            if self.sync_buffer:
                allreduce_params(runner.model.buffers())
            self._save_checkpoint(runner)

    def _save_checkpoint(self, runner):
        """Save the current checkpoint and delete unwanted checkpoint."""
        runner.save_checkpoint(
            self.out_dir, save_optimizer=self.save_optimizer, **self.args)

        # Default filename template and 1-based step of the checkpoint that
        # was just written.
        if self.by_epoch:
            default_tmpl = 'epoch_{}.pth'
            current_ckpt = runner.epoch + 1
        else:
            default_tmpl = 'iter_{}.pth'
            current_ckpt = runner.iter + 1
        filename_tmpl = self.args.get('filename_tmpl', default_tmpl)

        if runner.meta is not None:
            # Record the path of the newest checkpoint in the runner meta.
            runner.meta.setdefault('hook_msgs', dict())
            runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path(
                self.out_dir, filename_tmpl.format(current_ckpt))

        # Remove older checkpoints. A non-positive interval has no regular
        # save cadence to walk back over (and a zero step would make
        # ``range`` raise ``ValueError``), so pruning is skipped then.
        if self.max_keep_ckpts > 0 and self.interval > 0:
            redundant_ckpts = range(
                current_ckpt - self.max_keep_ckpts * self.interval, 0,
                -self.interval)
            for step in redundant_ckpts:
                ckpt_path = self.file_client.join_path(
                    self.out_dir, filename_tmpl.format(step))
                if not self.file_client.isfile(ckpt_path):
                    # Walking from newest to oldest: stop at the first
                    # checkpoint that does not exist.
                    break
                self.file_client.remove(ckpt_path)

    def after_train_iter(self, runner):
        """Save a checkpoint after an iteration when saving by iteration.

        A checkpoint is written every ``self.interval`` iterations and, when
        ``self.save_last`` is set, at the last iteration of training.
        """
        if self.by_epoch:
            return

        if self.every_n_iters(runner, self.interval) or (
                self.save_last and self.is_last_iter(runner)):
            runner.logger.info(
                f'Saving checkpoint at {runner.iter + 1} iterations')
            if self.sync_buffer:
                allreduce_params(runner.model.buffers())
            self._save_checkpoint(runner)