import bisect
import functools
import logging
import numbers
import os
import signal
import sys
import traceback
import warnings

import torch
from pytorch_lightning import seed_everything

LOGGER = logging.getLogger(__name__)


def check_and_warn_input_range(tensor, min_value, max_value, name):
    actual_min = tensor.min()
    actual_max = tensor.max()
    if actual_min < min_value or actual_max > max_value:
        warnings.warn(f"{name} must be in {min_value}..{max_value} range, but it ranges {actual_min}..{actual_max}")


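# Accumulate the values of cur_dict into target in place, under keys prefixed with `prefix`.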
def sum_dict_with_prefix(target, cur_dict, prefix, default=0):
    for k, v in cur_dict.items():
        target_key = prefix + k
        target[target_key] = target.get(target_key, default) + v


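# Element-wise average over a list of dicts that share the same keys
# (the extra 1e-3 in the denominator adds only a negligible bias).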
def average_dicts(dict_list):
    result = {}
    norm = 1e-3
    for dct in dict_list:
        sum_dict_with_prefix(result, dct, '')
        norm += 1
    for k in list(result):
        result[k] /= norm
    return result


def add_prefix_to_keys(dct, prefix):
    return {prefix + k: v for k, v in dct.items()}


def set_requires_grad(module, value):
    for param in module.parameters():
        param.requires_grad = value


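# Recursively flatten a nested dict, joining nested keys (and tuple keys) with underscores.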
def flatten_dict(dct):
    result = {}
    for k, v in dct.items():
        if isinstance(k, tuple):
            k = '_'.join(k)
        if isinstance(v, dict):
            for sub_k, sub_v in flatten_dict(v).items():
                result[f'{k}_{sub_k}'] = sub_v
        else:
            result[k] = v
    return result


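# Linearly interpolates from start_value to end_value between start_iter and end_iter;
# constant outside that interval.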
class LinearRamp:
    def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0):
        self.start_value = start_value
        self.end_value = end_value
        self.start_iter = start_iter
        self.end_iter = end_iter

    def __call__(self, i):
        if i < self.start_iter:
            return self.start_value
        if i >= self.end_iter:
            return self.end_value
        part = (i - self.start_iter) / (self.end_iter - self.start_iter)
        return self.start_value * (1 - part) + self.end_value * part


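# Piecewise-constant schedule: values[j] is returned while start_iters[j - 1] <= i < start_iters[j],
# with values[0] used before the first threshold and values[-1] after the last one.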
class LadderRamp:
    def __init__(self, start_iters, values):
        self.start_iters = start_iters
        self.values = values
        assert len(values) == len(start_iters) + 1, (len(values), len(start_iters))

    def __call__(self, i):
        segment_i = bisect.bisect_right(self.start_iters, i)
        return self.values[segment_i]


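# Factory for the ramp classes above. Illustrative call (values are made up):
#   ramp = get_ramp(kind='ladder', start_iters=[1000, 5000], values=[0.0, 0.5, 1.0])
#   ramp(i)  # -> 0.0 before iteration 1000, 0.5 from 1000 to 4999, 1.0 from 5000 on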
def get_ramp(kind='ladder', **kwargs):
    if kind == 'linear':
        return LinearRamp(**kwargs)
    if kind == 'ladder':
        return LadderRamp(**kwargs)
    raise ValueError(f'Unexpected ramp kind: {kind}')


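# Log the current stack trace when a registered signal is received
# (useful for inspecting a hung run without killing the process).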
def print_traceback_handler(sig, frame):
    LOGGER.warning(f'Received signal {sig}')
    bt = ''.join(traceback.format_stack())
    LOGGER.warning(f'Requested stack trace:\n{bt}')


def register_debug_signal_handlers(sig=None, handler=print_traceback_handler):
    LOGGER.warning(f'Setting signal {sig} handler {handler}')
    signal.signal(sig, handler)


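# Seed all RNGs via pytorch_lightning.seed_everything when the config provides a seed.
# Returns True if a seed was applied, False otherwise.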
def handle_deterministic_config(config):
    seed = dict(config).get('seed', None)
    if seed is None:
        return False

    seed_everything(seed)
    return True


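# Recursively describe the shape of a (possibly nested) structure of tensors,
# dicts, lists/tuples and scalars, mirroring the structure of the input.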
def get_shape(t):
    if torch.is_tensor(t):
        return tuple(t.shape)
    elif isinstance(t, dict):
        return {n: get_shape(q) for n, q in t.items()}
    elif isinstance(t, (list, tuple)):
        return [get_shape(q) for q in t]
    elif isinstance(t, numbers.Number):
        return type(t)
    else:
        raise ValueError('unexpected type {}'.format(type(t)))


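# True if any of the environment variables set by the DDP launcher
# (MASTER_PORT, NODE_RANK, LOCAL_RANK, WORLD_SIZE) is present,
# i.e. this process was spawned as a distributed worker.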
def get_has_ddp_rank():
    master_port = os.environ.get('MASTER_PORT', None)
    node_rank = os.environ.get('NODE_RANK', None)
    local_rank = os.environ.get('LOCAL_RANK', None)
    world_size = os.environ.get('WORLD_SIZE', None)
    has_rank = master_port is not None or node_rank is not None or local_rank is not None or world_size is not None
    return has_rank


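# Decorator for the training entry point. In DDP worker subprocesses (detected via
# TRAINING_PARENT_WORK_DIR, which handle_ddp_parent_process sets in the top-level process),
# it appends hydra.run.dir to sys.argv so workers reuse the parent's Hydra working directory.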
def handle_ddp_subprocess():
    def main_decorator(main_func):
        @functools.wraps(main_func)
        def new_main(*args, **kwargs):
            # Trainer sets MASTER_PORT, NODE_RANK, LOCAL_RANK, WORLD_SIZE
            parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None)
            has_parent = parent_cwd is not None
            has_rank = get_has_ddp_rank()
            assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}'

            if has_parent:
                # we are in the worker
                sys.argv.extend([
                    f'hydra.run.dir={parent_cwd}',
                    # 'hydra/hydra_logging=disabled',
                    # 'hydra/job_logging=disabled'
                ])
            # do nothing if this is a top-level process
            # TRAINING_PARENT_WORK_DIR is set in handle_ddp_parent_process after hydra initialization

            main_func(*args, **kwargs)

        return new_main

    return main_decorator


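# Called after Hydra initialization: records the current working directory in
# TRAINING_PARENT_WORK_DIR so that DDP workers can pick it up. Returns True when
# running inside a worker (the variable was already set), False in the parent.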
def handle_ddp_parent_process():
    parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None)
    has_parent = parent_cwd is not None
    has_rank = get_has_ddp_rank()
    assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}'

    if parent_cwd is None:
        os.environ['TRAINING_PARENT_WORK_DIR'] = os.getcwd()

    return has_parent