Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict | |
import numpy as np | |
class LogBuffer:
    """Accumulate per-key values and compute count-weighted averages.

    Values are recorded with :meth:`update` and reduced with
    :meth:`average`, which writes one weighted mean per key into
    :attr:`output` and sets :attr:`ready` to ``True``.
    """

    def __init__(self):
        # key -> list of every value passed to update(), in call order
        self.val_history = OrderedDict()
        # key -> list of the ``count`` recorded alongside each value
        self.n_history = OrderedDict()
        # key -> most recent weighted average produced by average()
        self.output = OrderedDict()
        # True once average() has populated ``output``; reset by clear_output()
        self.ready = False

    def clear(self):
        """Drop all recorded history and any computed output."""
        self.val_history.clear()
        self.n_history.clear()
        self.clear_output()

    def clear_output(self):
        """Drop computed averages but keep the recorded history."""
        self.output.clear()
        self.ready = False

    def update(self, vars, count=1):
        """Record one value per key, each weighted by ``count``.

        Args:
            vars (dict): Mapping of key -> value to append to the history.
                Keys seen for the first time get fresh history lists.
            count (int): Weight stored alongside every value in ``vars``
                (e.g. the number of samples the value was computed over).

        Raises:
            AssertionError: If ``vars`` is not a ``dict``.
        """
        assert isinstance(vars, dict)
        for key, var in vars.items():
            if key not in self.val_history:
                self.val_history[key] = []
                self.n_history[key] = []
            self.val_history[key].append(var)
            self.n_history[key].append(count)

    def average(self, n=0):
        """Average latest n values or all values.

        For every key, computes ``sum(value * count) / sum(count)`` over the
        last ``n`` recorded entries and stores it in ``self.output[key]``.

        Args:
            n (int): Number of most recent entries to average per key.
                ``0`` means all recorded entries.

        Raises:
            AssertionError: If ``n`` is negative.
        """
        assert n >= 0
        for key in self.val_history:
            # [-0:] slices the whole list, so n=0 averages every entry.
            values = np.array(self.val_history[key][-n:])
            nums = np.array(self.n_history[key][-n:])
            avg = np.sum(values * nums) / np.sum(nums)
            self.output[key] = avg
        self.ready = True