# Prompt-Free-Diffusion / lib/log_service.py
# Pushed by 3v324v23 — commit 515f781 ("code pushed")
import timeit
import numpy as np
import os
import os.path as osp
import shutil
import copy
import torch
import torch.nn as nn
import torch.distributed as dist
from .cfg_holder import cfg_unique_holder as cfguh
from . import sync
def print_log(*console_info):
    """Print a message and append it to the configured log file.

    The message is every positional argument converted with ``str`` and
    joined with single spaces. It is printed only by processes whose local
    rank is 0, and appended to the log file only by the process whose global
    rank is 0.

    The log file path is read from ``cfg.train.log_file``. If that lookup
    fails, ``cfg.eval.log_file`` is tried. If both fail, or the resolved path
    is None, nothing is written to disk (best-effort logging).
    """
    grank, lrank, _ = sync.get_rank('all')
    if lrank != 0:
        return
    message = ' '.join(str(i) for i in console_info)
    print(message)
    if grank != 0:
        return

    # Best-effort lookup: the config may have no train/eval section or no
    # log_file entry. `except Exception` (not a bare except) so Ctrl-C and
    # SystemExit still propagate.
    log_file = None
    try:
        log_file = cfguh().cfg.train.log_file
    except Exception:
        try:
            log_file = cfguh().cfg.eval.log_file
        except Exception:
            return
    if log_file is not None:
        # Explicit encoding so non-ASCII messages do not depend on the locale.
        with open(log_file, 'a', encoding='utf-8') as f:
            f.write(message + '\n')
class distributed_log_manager(object):
    """Accumulate per-item training statistics and report their means.

    ``accumulate`` keeps count-weighted running sums per item name.
    ``get_mean_value_dict`` turns them into per-item means and, under DDP,
    averages those means across ranks. When ``cfg.train.log_tensorboard`` is
    true, the global rank-0 process also mirrors the values to TensorBoard.
    """

    def __init__(self):
        self.sum = {}   # item name -> sum of (value * n)
        self.cnt = {}   # item name -> sum of n
        self.time_check = timeit.default_timer()

        cfgt = cfguh().cfg.train
        self.ddp = sync.is_ddp()
        self.grank, self.lrank, _ = sync.get_rank('all')
        self.gwsize = sync.get_world_size('global')

        # Only the global rank-0 process writes TensorBoard events.
        use_tensorboard = cfgt.get('log_tensorboard', False) and (self.grank == 0)
        self.tb = None
        if use_tensorboard:
            import tensorboardX
            monitoring_dir = osp.join(cfgt.log_dir, 'tensorboard')
            self.tb = tensorboardX.SummaryWriter(monitoring_dir)

    def accumulate(self, n, **data):
        """Add ``n`` samples' worth of each keyword value to the running sums.

        Args:
            n: weight (sample count) applied to every value; must be >= 0.
            **data: item name -> value; ``value * n`` is added to that item's
                sum and ``n`` to its count.

        Raises:
            ValueError: if ``n`` is negative.
        """
        if n < 0:
            raise ValueError('n must be non-negative, got {}'.format(n))
        for itemn, di in data.items():
            if itemn in self.sum:
                self.sum[itemn] += di * n
                self.cnt[itemn] += n
            else:
                self.sum[itemn] = di * n
                self.cnt[itemn] = n

    def get_mean_value_dict(self):
        """Return ``{item name: mean}`` for every accumulated item.

        Under DDP, the per-rank means are summed with ``all_reduce`` and
        divided by the global world size. Each rank's mean therefore counts
        equally, regardless of how many samples it saw. Values are packed by
        sorted item name, so all ranks must hold the same set of names.
        """
        keys = sorted(self.sum.keys())
        value_gather = [self.sum[k] / self.cnt[k] for k in keys]
        # Keep the tensor on the local GPU as before (presumably required by
        # the DDP backend). Fall back to CPU when CUDA is unavailable instead
        # of failing on `.to(<int>)`.
        device = self.lrank if torch.cuda.is_available() else 'cpu'
        value_gather_tensor = torch.FloatTensor(value_gather).to(device)
        if self.ddp:
            dist.all_reduce(value_gather_tensor, op=dist.ReduceOp.SUM)
            value_gather_tensor /= self.gwsize
        return {k: value_gather_tensor[idx].item() for idx, k in enumerate(keys)}

    def tensorboard_log(self, step, data, mode='train', **extra):
        """Write scalars to TensorBoard. Does nothing when TensorBoard is off.

        mode='train': ``data`` maps item name -> value, and
        ``extra['epochn']`` is required. ``extra['lr']`` is logged when
        present and not None. Items whose name starts with 'loss' go under
        'loss/', the item 'Loss' is logged as 'Loss', and everything else
        goes under 'other/'.

        mode='eval': ``data`` is either a dict (each entry is logged as
        'eval/<name>') or a single scalar logged as 'eval'.

        Any other mode writes nothing.
        """
        if self.tb is None:
            return
        if mode == 'train':
            self.tb.add_scalar('other/epochn', extra['epochn'], step)
            if extra.get('lr') is not None:
                self.tb.add_scalar('other/lr', extra['lr'], step)
            for itemn, di in data.items():
                if itemn.startswith('loss'):
                    self.tb.add_scalar('loss/' + itemn, di, step)
                elif itemn == 'Loss':
                    self.tb.add_scalar('Loss', di, step)
                else:
                    self.tb.add_scalar('other/' + itemn, di, step)
        elif mode == 'eval':
            if isinstance(data, dict):
                for itemn, di in data.items():
                    self.tb.add_scalar('eval/' + itemn, di, step)
            else:
                self.tb.add_scalar('eval', data, step)

    def train_summary(self, itern, epochn, samplen, lr, tbstep=None):
        """Log current means to TensorBoard and return a one-line summary.

        The returned string joins, with ' , ', the fields Iter, Epoch, Sample,
        LR (only if ``lr`` is not None), Loss (only if a 'Loss' item was
        accumulated), every item whose name starts with 'loss' (sorted), and
        the elapsed time since the last ``clear``. All items, not only the
        loss ones, are sent to TensorBoard at step ``tbstep``, which defaults
        to ``itern``.
        """
        console_info = [
            'Iter:{}'.format(itern),
            'Epoch:{}'.format(epochn),
            'Sample:{}'.format(samplen),
        ]
        if lr is not None:
            console_info.append('LR:{:.4E}'.format(lr))

        mean = self.get_mean_value_dict()
        tbstep = itern if tbstep is None else tbstep
        self.tensorboard_log(
            tbstep, mean, mode='train',
            itern=itern, epochn=epochn, lr=lr)

        # 'Loss' (if present) is printed first. It is optional so that a run
        # that never accumulates it does not crash with KeyError.
        mean_info = []
        if 'Loss' in mean:
            mean_info.append('Loss:{:.4f}'.format(mean.pop('Loss')))
        mean_info += [
            '{}:{:.4f}'.format(itemn, mean[itemn])
            for itemn in sorted(mean)
            if itemn.startswith('loss')]
        console_info += mean_info
        console_info.append('Time:{:.2f}s'.format(
            timeit.default_timer() - self.time_check))
        return ' , '.join(console_info)

    def clear(self):
        """Drop all accumulated values and restart the elapsed-time clock."""
        self.sum = {}
        self.cnt = {}
        self.time_check = timeit.default_timer()

    def tensorboard_close(self):
        """Close the TensorBoard writer if one was opened."""
        if self.tb is not None:
            self.tb.close()
# ----- also include some small utils -----
def torch_to_numpy(*argv):
    """Recursively convert torch tensors to NumPy arrays.

    With one positional argument, that argument is converted. With several,
    they are gathered into a list and a list is returned. Tensors are moved
    to CPU and detached before ``.numpy()``. Lists and tuples are walked
    element-wise and always come back as lists. Dicts are walked value-wise
    and come back as dicts with the same keys. Any other object is returned
    unchanged. Calling with no arguments raises IndexError.
    """
    payload = argv[0] if len(argv) <= 1 else list(argv)

    if isinstance(payload, torch.Tensor):
        return payload.to('cpu').detach().numpy()
    if isinstance(payload, (list, tuple)):
        return [torch_to_numpy(element) for element in payload]
    if isinstance(payload, dict):
        return {name: torch_to_numpy(value) for name, value in payload.items()}
    return payload