Spaces:
Running
Running
File size: 2,322 Bytes
4a3ab35 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import datetime
import os
import torch
import matplotlib
matplotlib.use('Agg')
import scipy.signal
from matplotlib import pyplot as plt
from torch.utils.tensorboard import SummaryWriter
class LossHistory:
    """Record per-epoch training/validation losses.

    Each call to :meth:`append_loss` appends the values to two text files,
    logs them as TensorBoard scalars and regenerates a PNG plot of the full
    loss history. All artefacts live in a timestamped sub-directory of the
    ``log_dir`` given to the constructor.
    """

    def __init__(self, log_dir, model, input_shape):
        """Create the log directory and the TensorBoard writer.

        Args:
            log_dir: Parent directory; a ``loss_<timestamp>`` sub-directory
                is created inside it.
            model: Model passed to ``SummaryWriter.add_graph``.
            input_shape: Indexable with at least two entries; the dummy
                input used for graph tracing is shaped
                ``(2, 3, input_shape[0], input_shape[1])`` (presumably
                height and width -- confirm against the caller).
        """
        time_str = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
        self.log_dir = os.path.join(log_dir, "loss_" + time_str)
        self.losses = []
        self.val_loss = []

        # exist_ok avoids a crash when two loggers are created within the
        # same second (the timestamp has one-second resolution).
        os.makedirs(self.log_dir, exist_ok=True)
        self.writer = SummaryWriter(self.log_dir)
        try:
            dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
            self.writer.add_graph(model, dummy_input)
        except Exception:
            # Graph export is best-effort: tracing can fail for many
            # model-specific reasons and must not prevent loss logging.
            # Unlike a bare ``except``, KeyboardInterrupt/SystemExit
            # still propagate.
            pass

    @staticmethod
    def _smoothing_window(num_points):
        """Return the Savitzky-Golay window for ``num_points`` samples.

        The window is 5 for fewer than 25 points and 15 otherwise. ``None``
        is returned when there are fewer points than the window, in which
        case no smoothed curve can be computed.
        """
        window = 5 if num_points < 25 else 15
        return window if num_points >= window else None

    def append_loss(self, epoch, loss, val_loss):
        """Record one epoch's losses and refresh all outputs.

        Args:
            epoch: Step index used for the TensorBoard scalars.
            loss: Training loss for this epoch.
            val_loss: Validation loss for this epoch.
        """
        # Recreate the directory if it was removed while running.
        os.makedirs(self.log_dir, exist_ok=True)

        self.losses.append(loss)
        self.val_loss.append(val_loss)

        with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
            f.write(str(loss) + "\n")
        with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
            f.write(str(val_loss) + "\n")

        self.writer.add_scalar('loss', loss, epoch)
        self.writer.add_scalar('val_loss', val_loss, epoch)
        self.loss_plot()

    def loss_plot(self):
        """Save ``epoch_loss.png`` with raw and, if possible, smoothed curves."""
        iters = range(len(self.losses))

        plt.figure()
        try:
            plt.plot(iters, self.losses, 'red', linewidth=2,
                     label='train loss')
            plt.plot(iters, self.val_loss, 'coral', linewidth=2,
                     label='val loss')

            window = self._smoothing_window(len(self.losses))
            if window is not None:
                try:
                    plt.plot(iters,
                             scipy.signal.savgol_filter(self.losses, window, 3),
                             'green', linestyle='--', linewidth=2,
                             label='smooth train loss')
                    plt.plot(iters,
                             scipy.signal.savgol_filter(self.val_loss, window, 3),
                             '#8B4513', linestyle='--', linewidth=2,
                             label='smooth val loss')
                except Exception:
                    # Smoothed curves are optional decoration; keep the
                    # raw curves if the filter rejects the data.
                    pass

            plt.grid(True)
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.legend(loc="upper right")

            plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))
        finally:
            # Always release figures, even when plotting or saving fails,
            # so repeated calls do not accumulate open figures.
            plt.close("all")
|