Spaces:
Runtime error
Runtime error
File size: 7,317 Bytes
b334e29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from annotator.uniformer.mmcv.fileio import FileClient
from ..dist_utils import allreduce_params, master_only
from .hook import HOOKS, Hook
@HOOKS.register_module()
class CheckpointHook(Hook):
    """Save checkpoints periodically.

    Args:
        interval (int): The saving period. If ``by_epoch=True``, interval
            indicates epochs, otherwise it indicates iterations.
            Default: -1, which means "never".
        by_epoch (bool): Saving checkpoints by epoch or by iteration.
            Default: True.
        save_optimizer (bool): Whether to save optimizer state_dict in the
            checkpoint. It is usually used for resuming experiments.
            Default: True.
        out_dir (str, optional): The root directory to save checkpoints. If
            not specified, ``runner.work_dir`` will be used by default. If
            specified, the ``out_dir`` will be the concatenation of
            ``out_dir`` and the last level directory of ``runner.work_dir``.
            `Changed in version 1.3.16.`
        max_keep_ckpts (int, optional): The maximum checkpoints to keep.
            In some cases we want only the latest few checkpoints and would
            like to delete old ones to save the disk space. Old checkpoints
            are only pruned when ``interval`` is positive.
            Default: -1, which means unlimited.
        save_last (bool, optional): Whether to force the last checkpoint to
            be saved regardless of interval. Default: True.
        sync_buffer (bool, optional): Whether to synchronize buffers in
            different gpus. Default: False.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmcv.fileio.FileClient` for details.
            Default: None.
            `New in version 1.3.16.`
        **kwargs: Extra keyword arguments forwarded verbatim to
            ``runner.save_checkpoint``. The keys ``create_symlink`` and
            ``filename_tmpl`` are additionally read by this hook.

    .. warning::
        Before v1.3.16, the ``out_dir`` argument indicates the path where the
        checkpoint is stored. However, since v1.3.16, ``out_dir`` indicates
        the root directory and the final path to save checkpoint is the
        concatenation of ``out_dir`` and the last level directory of
        ``runner.work_dir``. Suppose the value of ``out_dir`` is "/path/of/A"
        and the value of ``runner.work_dir`` is "/path/of/B", then the final
        path will be "/path/of/A/B".
    """

    def __init__(self,
                 interval=-1,
                 by_epoch=True,
                 save_optimizer=True,
                 out_dir=None,
                 max_keep_ckpts=-1,
                 save_last=True,
                 sync_buffer=False,
                 file_client_args=None,
                 **kwargs):
        self.interval = interval
        self.by_epoch = by_epoch
        self.save_optimizer = save_optimizer
        self.out_dir = out_dir
        self.max_keep_ckpts = max_keep_ckpts
        self.save_last = save_last
        # Remaining keyword arguments are passed on to
        # ``runner.save_checkpoint`` in ``_save_checkpoint``.
        self.args = kwargs
        self.sync_buffer = sync_buffer
        self.file_client_args = file_client_args

    def before_run(self, runner):
        """Resolve the output directory and file client before training.

        Sets ``self.out_dir``, ``self.file_client`` and
        ``self.args['create_symlink']``.

        Args:
            runner: The runner driving training; ``work_dir`` and ``logger``
                are read from it.
        """
        if not self.out_dir:
            self.out_dir = runner.work_dir

        self.file_client = FileClient.infer_client(self.file_client_args,
                                                   self.out_dir)

        # if `self.out_dir` is not equal to `runner.work_dir`, it means that
        # `self.out_dir` is set so the final `self.out_dir` is the
        # concatenation of `self.out_dir` and the last level directory of
        # `runner.work_dir`
        if self.out_dir != runner.work_dir:
            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
            self.out_dir = self.file_client.join_path(self.out_dir, basename)

        runner.logger.info((f'Checkpoints will be saved to {self.out_dir} by '
                            f'{self.file_client.name}.'))

        # disable the create_symlink option because some file backends do not
        # allow to create a symlink
        if 'create_symlink' in self.args:
            if self.args[
                    'create_symlink'] and not self.file_client.allow_symlink:
                self.args['create_symlink'] = False
                # The first fragment ends with a space so that the
                # concatenated message reads "... changed to be False ...".
                warnings.warn(
                    ('create_symlink is set as True by the user but is changed '
                     'to be False because creating symbolic link is not '
                     f'allowed in {self.file_client.name}'))
        else:
            self.args['create_symlink'] = self.file_client.allow_symlink

    def after_train_epoch(self, runner):
        """Save a checkpoint after an epoch when saving by epoch.

        Does nothing when ``self.by_epoch`` is False.

        Args:
            runner: The runner driving training.
        """
        if not self.by_epoch:
            return

        # save checkpoint for following cases:
        # 1. every ``self.interval`` epochs
        # 2. reach the last epoch of training
        if self.every_n_epochs(
                runner, self.interval) or (self.save_last
                                           and self.is_last_epoch(runner)):
            runner.logger.info(
                f'Saving checkpoint at {runner.epoch + 1} epochs')
            if self.sync_buffer:
                allreduce_params(runner.model.buffers())
            self._save_checkpoint(runner)

    @master_only
    def _save_checkpoint(self, runner):
        """Save the current checkpoint and delete unwanted checkpoint.

        Wrapped by ``master_only`` (imported from ``dist_utils``); presumably
        this restricts execution to the master process - confirm there.

        Args:
            runner: The runner driving training; ``save_checkpoint``,
                ``meta`` and ``epoch``/``iter`` are used.
        """
        runner.save_checkpoint(
            self.out_dir, save_optimizer=self.save_optimizer, **self.args)

        # The default filename template and the 1-based step number of the
        # checkpoint just written depend on the saving mode.
        if self.by_epoch:
            default_tmpl = 'epoch_{}.pth'
            current_ckpt = runner.epoch + 1
        else:
            default_tmpl = 'iter_{}.pth'
            current_ckpt = runner.iter + 1
        filename_tmpl = self.args.get('filename_tmpl', default_tmpl)

        if runner.meta is not None:
            # Record the path of the newest checkpoint in the runner meta.
            runner.meta.setdefault('hook_msgs', dict())
            runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path(
                self.out_dir, filename_tmpl.format(current_ckpt))

        # remove other checkpoints
        # ``interval`` must be positive here: ``range`` raises ValueError for
        # a zero step, and a negative interval yields an empty range anyway.
        if self.max_keep_ckpts > 0 and self.interval > 0:
            redundant_ckpts = range(
                current_ckpt - self.max_keep_ckpts * self.interval, 0,
                -self.interval)
            for _step in redundant_ckpts:
                ckpt_path = self.file_client.join_path(
                    self.out_dir, filename_tmpl.format(_step))
                if self.file_client.isfile(ckpt_path):
                    self.file_client.remove(ckpt_path)
                else:
                    # Stop at the first missing file while walking from the
                    # newest candidate to the oldest.
                    break

    def after_train_iter(self, runner):
        """Save a checkpoint after an iteration when saving by iteration.

        Does nothing when ``self.by_epoch`` is True.

        Args:
            runner: The runner driving training.
        """
        if self.by_epoch:
            return

        # save checkpoint for following cases:
        # 1. every ``self.interval`` iterations
        # 2. reach the last iteration of training
        if self.every_n_iters(
                runner, self.interval) or (self.save_last
                                           and self.is_last_iter(runner)):
            runner.logger.info(
                f'Saving checkpoint at {runner.iter + 1} iterations')
            if self.sync_buffer:
                allreduce_params(runner.model.buffers())
            self._save_checkpoint(runner)
|