Spaces:
Runtime error
Runtime error
import timeit | |
import numpy as np | |
import os | |
import os.path as osp | |
import shutil | |
import copy | |
import torch | |
import torch.nn as nn | |
import torch.distributed as dist | |
from .cfg_holder import cfg_unique_holder as cfguh | |
from . import sync | |
def print_log(*console_info):
    """Print a message on local rank 0 and append it to the log file on global rank 0.

    All arguments are converted with ``str`` and joined with single spaces.
    The log file path is read from ``cfg.train.log_file``; if that lookup
    fails for any reason, ``cfg.eval.log_file`` is tried instead. If both
    lookups fail, the message is only printed. A lookup that succeeds but
    yields ``None`` disables file logging (no further fallback).
    """
    grank, lrank, _ = sync.get_rank('all')
    # Only one process per node prints to the console.
    if lrank != 0:
        return
    console_info = ' '.join(str(i) for i in console_info)
    print(console_info)
    # Only one process in the whole job writes to the log file.
    if grank != 0:
        return

    log_file = None
    for section in ('train', 'eval'):
        try:
            log_file = getattr(cfguh().cfg, section).log_file
            break
        except Exception:
            # Best-effort: a missing config/section/key just means "try the
            # next section". Narrowed from a bare except so Ctrl-C and
            # SystemExit are no longer swallowed.
            continue
    else:
        return

    if log_file is not None:
        with open(log_file, 'a') as f:
            f.write(console_info + '\n')
class distributed_log_manager(object):
    """Accumulate weighted per-key statistics and report their means.

    ``accumulate`` keeps, per key, a running sum of ``value * n`` and a
    running sum of ``n``. ``get_mean_value_dict`` turns these into means and,
    when DDP is active, averages the means over all processes with
    ``all_reduce``. On global rank 0, if ``cfg.train.log_tensorboard`` is
    set, values are also written to TensorBoard via tensorboardX.
    """

    def __init__(self):
        self.sum = {}  # key -> running sum of (value * n)
        self.cnt = {}  # key -> running sum of n
        self.time_check = timeit.default_timer()  # start of the current reporting window
        cfgt = cfguh().cfg.train
        self.ddp = sync.is_ddp()
        self.grank, self.lrank, _ = sync.get_rank('all')
        self.gwsize = sync.get_world_size('global')
        # TensorBoard events are written by global rank 0 only.
        use_tensorboard = cfgt.get('log_tensorboard', False) and (self.grank == 0)
        self.tb = None
        if use_tensorboard:
            import tensorboardX
            monitoring_dir = osp.join(cfguh().cfg.train.log_dir, 'tensorboard')
            self.tb = tensorboardX.SummaryWriter(monitoring_dir)

    def accumulate(self, n, **data):
        """Add ``value * n`` to each key's sum and ``n`` to its count.

        Args:
            n: Non-negative weight for this update (presumably a batch/sample
                count -- confirm with callers). If a key only ever receives
                ``n == 0``, computing its mean later divides by zero.
            **data: Statistic name -> value.

        Raises:
            ValueError: If ``n`` is negative.
        """
        if n < 0:
            raise ValueError('accumulate() requires n >= 0, got {}'.format(n))
        for itemn, di in data.items():
            if itemn in self.sum:
                self.sum[itemn] += di * n
                self.cnt[itemn] += n
            else:
                self.sum[itemn] = di * n
                self.cnt[itemn] = n

    def get_mean_value_dict(self):
        """Return ``{key: sum / cnt}``; under DDP, averaged over all processes.

        Keys are packed in sorted order so every rank builds the reduction
        tensor identically. This assumes all ranks accumulated the same set
        of keys; otherwise ``all_reduce`` mixes unrelated values.
        """
        keys = sorted(self.sum)
        value_gather = [self.sum[k] / self.cnt[k] for k in keys]
        # Use the local GPU when CUDA exists (presumably needed by an NCCL
        # all_reduce); stay on CPU otherwise so CPU-only runs do not crash.
        device = self.lrank if torch.cuda.is_available() else 'cpu'
        value_gather_tensor = torch.FloatTensor(value_gather).to(device)
        if self.ddp:
            dist.all_reduce(value_gather_tensor, op=dist.ReduceOp.SUM)
            value_gather_tensor /= self.gwsize
        return {k: value_gather_tensor[i].item() for i, k in enumerate(keys)}

    def tensorboard_log(self, step, data, mode='train', **extra):
        """Write scalars to TensorBoard; no-op when no writer was created.

        In ``'train'`` mode, ``extra['epochn']`` is required and
        ``extra['lr']`` is optional. Keys starting with ``'loss'`` go under
        ``loss/``, the key ``'Loss'`` is written as-is, and everything else
        goes under ``other/``. In ``'eval'`` mode, a dict is written under
        ``eval/<key>``, and any other value is written as the scalar ``eval``.
        """
        if self.tb is None:
            return
        if mode == 'train':
            self.tb.add_scalar('other/epochn', extra['epochn'], step)
            if extra.get('lr') is not None:
                self.tb.add_scalar('other/lr', extra['lr'], step)
            for itemn, di in data.items():
                if itemn.startswith('loss'):
                    self.tb.add_scalar('loss/' + itemn, di, step)
                elif itemn == 'Loss':
                    self.tb.add_scalar('Loss', di, step)
                else:
                    self.tb.add_scalar('other/' + itemn, di, step)
        elif mode == 'eval':
            if isinstance(data, dict):
                for itemn, di in data.items():
                    self.tb.add_scalar('eval/' + itemn, di, step)
            else:
                self.tb.add_scalar('eval', data, step)

    def train_summary(self, itern, epochn, samplen, lr, tbstep=None):
        """Build a one-line training summary and log the means to TensorBoard.

        The line contains the iteration, epoch, sample count, the LR (if not
        None), ``Loss`` (if accumulated), every key starting with ``'loss'``,
        and the seconds elapsed since construction or the last ``clear()``.
        Fields are joined with ``' , '``.

        Args:
            tbstep: TensorBoard step; defaults to ``itern``.
        """
        console_info = [
            'Iter:{}'.format(itern),
            'Epoch:{}'.format(epochn),
            'Sample:{}'.format(samplen),
        ]
        if lr is not None:
            console_info.append('LR:{:.4E}'.format(lr))

        mean = self.get_mean_value_dict()
        tbstep = itern if tbstep is None else tbstep
        self.tensorboard_log(
            tbstep, mean, mode='train',
            itern=itern, epochn=epochn, lr=lr)

        # 'Loss' is listed first when present; a missing key no longer raises KeyError.
        loss = mean.pop('Loss', None)
        if loss is not None:
            console_info.append('Loss:{:.4f}'.format(loss))
        console_info += [
            '{}:{:.4f}'.format(itemn, mean[itemn])
            for itemn in sorted(mean)
            if itemn.startswith('loss')
        ]
        console_info.append('Time:{:.2f}s'.format(
            timeit.default_timer() - self.time_check))
        return ' , '.join(console_info)

    def clear(self):
        """Reset all accumulated statistics and restart the timer."""
        self.sum = {}
        self.cnt = {}
        self.time_check = timeit.default_timer()

    def tensorboard_close(self):
        """Close the TensorBoard writer, if one was created."""
        if self.tb is not None:
            self.tb.close()
# ----- also include some small utils ----- | |
def torch_to_numpy(*argv):
    """Recursively convert torch tensors to numpy arrays.

    With a single argument, that argument is converted. With several, they
    are converted together and returned as a list. At least one argument is
    required.

    Conversion rules:
    - tensors become ``numpy.ndarray`` (detached, on CPU);
    - lists and tuples become lists of converted items;
    - dicts keep their keys, and their values are converted;
    - any other value is returned unchanged.
    """
    data = list(argv) if len(argv) > 1 else argv[0]
    if isinstance(data, torch.Tensor):
        # Detach first so the device copy is not recorded by autograd.
        return data.detach().cpu().numpy()
    if isinstance(data, (list, tuple)):
        return [torch_to_numpy(di) for di in data]
    if isinstance(data, dict):
        return {ni: torch_to_numpy(di) for ni, di in data.items()}
    return data