QuintW's picture
Upload 1350 files
3f9c56c
raw
history blame
5.27 kB
import bisect
import functools
import logging
import numbers
import os
import signal
import sys
import traceback
import warnings
import torch
from pytorch_lightning import seed_everything
LOGGER = logging.getLogger(__name__)
def check_and_warn_input_range(tensor, min_value, max_value, name):
actual_min = tensor.min()
actual_max = tensor.max()
if actual_min < min_value or actual_max > max_value:
warnings.warn(f"{name} must be in {min_value}..{max_value} range, but it ranges {actual_min}..{actual_max}")
def sum_dict_with_prefix(target, cur_dict, prefix, default=0):
for k, v in cur_dict.items():
target_key = prefix + k
target[target_key] = target.get(target_key, default) + v
def average_dicts(dict_list):
result = {}
norm = 1e-3
for dct in dict_list:
sum_dict_with_prefix(result, dct, '')
norm += 1
for k in list(result):
result[k] /= norm
return result
def add_prefix_to_keys(dct, prefix):
return {prefix + k: v for k, v in dct.items()}
def set_requires_grad(module, value):
for param in module.parameters():
param.requires_grad = value
def flatten_dict(dct):
result = {}
for k, v in dct.items():
if isinstance(k, tuple):
k = '_'.join(k)
if isinstance(v, dict):
for sub_k, sub_v in flatten_dict(v).items():
result[f'{k}_{sub_k}'] = sub_v
else:
result[k] = v
return result
class LinearRamp:
def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0):
self.start_value = start_value
self.end_value = end_value
self.start_iter = start_iter
self.end_iter = end_iter
def __call__(self, i):
if i < self.start_iter:
return self.start_value
if i >= self.end_iter:
return self.end_value
part = (i - self.start_iter) / (self.end_iter - self.start_iter)
return self.start_value * (1 - part) + self.end_value * part
class LadderRamp:
def __init__(self, start_iters, values):
self.start_iters = start_iters
self.values = values
assert len(values) == len(start_iters) + 1, (len(values), len(start_iters))
def __call__(self, i):
segment_i = bisect.bisect_right(self.start_iters, i)
return self.values[segment_i]
def get_ramp(kind='ladder', **kwargs):
if kind == 'linear':
return LinearRamp(**kwargs)
if kind == 'ladder':
return LadderRamp(**kwargs)
raise ValueError(f'Unexpected ramp kind: {kind}')
def print_traceback_handler(sig, frame):
LOGGER.warning(f'Received signal {sig}')
bt = ''.join(traceback.format_stack())
LOGGER.warning(f'Requested stack trace:\n{bt}')
def register_debug_signal_handlers(sig=None, handler=print_traceback_handler):
LOGGER.warning(f'Setting signal {sig} handler {handler}')
signal.signal(sig, handler)
def handle_deterministic_config(config):
seed = dict(config).get('seed', None)
if seed is None:
return False
seed_everything(seed)
return True
def get_shape(t):
if torch.is_tensor(t):
return tuple(t.shape)
elif isinstance(t, dict):
return {n: get_shape(q) for n, q in t.items()}
elif isinstance(t, (list, tuple)):
return [get_shape(q) for q in t]
elif isinstance(t, numbers.Number):
return type(t)
else:
raise ValueError('unexpected type {}'.format(type(t)))
def get_has_ddp_rank():
master_port = os.environ.get('MASTER_PORT', None)
node_rank = os.environ.get('NODE_RANK', None)
local_rank = os.environ.get('LOCAL_RANK', None)
world_size = os.environ.get('WORLD_SIZE', None)
has_rank = master_port is not None or node_rank is not None or local_rank is not None or world_size is not None
return has_rank
def handle_ddp_subprocess():
def main_decorator(main_func):
@functools.wraps(main_func)
def new_main(*args, **kwargs):
# Trainer sets MASTER_PORT, NODE_RANK, LOCAL_RANK, WORLD_SIZE
parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None)
has_parent = parent_cwd is not None
has_rank = get_has_ddp_rank()
assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}'
if has_parent:
# we are in the worker
sys.argv.extend([
f'hydra.run.dir={parent_cwd}',
# 'hydra/hydra_logging=disabled',
# 'hydra/job_logging=disabled'
])
# do nothing if this is a top-level process
# TRAINING_PARENT_WORK_DIR is set in handle_ddp_parent_process after hydra initialization
main_func(*args, **kwargs)
return new_main
return main_decorator
def handle_ddp_parent_process():
parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None)
has_parent = parent_cwd is not None
has_rank = get_has_ddp_rank()
assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}'
if parent_cwd is None:
os.environ['TRAINING_PARENT_WORK_DIR'] = os.getcwd()
return has_parent