Spaces:
Running
Running
""" | |
This file contains some useful functions for train / val. | |
""" | |
import os | |
import numpy as np | |
import torch | |
################# | |
## image utils ## | |
################# | |
def convert_image(input_tensor, axis):
    """Replicate an image three times along ``axis``.

    Used to turn single-channel images into 3-channel ones by joining three
    identical copies along the channel axis.

    Args:
        input_tensor: Array-like image accepted by ``np.concatenate``.
        axis: Axis along which the three copies are joined.

    Returns:
        np.ndarray whose extent along ``axis`` is three times the input's.
    """
    copies = (input_tensor,) * 3
    return np.concatenate(copies, axis)
###################### | |
## checkpoint utils ## | |
###################### | |
def get_latest_checkpoint(
    checkpoint_root, checkpoint_name, device=torch.device("cuda")
):
    """Load a checkpoint from ``checkpoint_root``, by name or the latest one.

    Args:
        checkpoint_root: Directory containing the checkpoint files.
        checkpoint_name: File name of the checkpoint to load. If ``None``,
            the ``.tar`` file that sorts last by name is loaded.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The object deserialized by ``torch.load``.

    Raises:
        FileNotFoundError: If ``checkpoint_name`` is ``None`` and the
            directory contains no ``.tar`` files (or does not exist), or if
            the requested file does not exist.
    """
    if checkpoint_name is None:
        # os.listdir does not expand wildcards, so filter by suffix instead of
        # listing "<root>/*.tar". "Latest" means last in lexicographic order,
        # the same ordering remove_old_checkpoints uses; this assumes file
        # names sort chronologically (e.g. zero-padded step numbers) -- TODO
        # confirm with the code that writes checkpoints.
        candidates = sorted(
            name for name in os.listdir(checkpoint_root) if name.endswith(".tar")
        )
        if not candidates:
            raise FileNotFoundError(
                "No .tar checkpoints found in %s" % checkpoint_root
            )
        checkpoint_name = candidates[-1]
    return torch.load(
        os.path.join(checkpoint_root, checkpoint_name), map_location=device
    )
def remove_old_checkpoints(checkpoint_root, max_ckpt=15):
    """Delete all but the last ``max_ckpt`` ``.tar`` checkpoints.

    Checkpoints are ordered by file name (lexicographically) and the ones
    that sort first are deleted. Files without a ``.tar`` suffix are never
    touched.

    Args:
        checkpoint_root: Directory containing the checkpoint files.
        max_ckpt: Number of checkpoints to keep; ``0`` deletes every
            ``.tar`` file in the directory.

    Raises:
        ValueError: If ``max_ckpt`` is negative.
    """
    if max_ckpt < 0:
        raise ValueError("max_ckpt must be non-negative, got %r" % (max_ckpt,))
    checkpoint_list = sorted(
        name for name in os.listdir(checkpoint_root) if name.endswith(".tar")
    )
    # Use an explicit count rather than checkpoint_list[:-max_ckpt]: that
    # slice becomes checkpoint_list[:0] (deletes nothing) when max_ckpt == 0.
    # Clamp at 0 so a short list never yields a negative slice bound.
    num_remove = max(len(checkpoint_list) - max_ckpt, 0)
    for name in checkpoint_list[:num_remove]:
        full_name = os.path.join(checkpoint_root, name)
        os.remove(full_name)
        print("[Debug] Remove outdated checkpoint %s" % (full_name))
def adapt_checkpoint(state_dict):
    """Return a new dict with one leading ``"module."`` removed from each key.

    Presumably used to load weights saved from a model wrapped in
    ``nn.DataParallel``/``DistributedDataParallel`` into an unwrapped model
    (TODO confirm at call sites). Only a single prefix occurrence is
    stripped, key order is preserved, and values are passed through
    unchanged. If stripping makes two keys equal, the later one wins.

    Args:
        state_dict: Mapping with string keys.

    Returns:
        A plain ``dict`` with the adjusted keys.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
################ | |
## HDF5 utils ## | |
################ | |
def parse_h5_data(h5_data):
    """Copy every entry of an HDF5-like mapping into a NumPy array.

    Args:
        h5_data: Object exposing ``keys()`` and item access; presumably an
            ``h5py.File``/``h5py.Group`` (TODO confirm at call sites).

    Returns:
        dict mapping each key to ``np.array(h5_data[key])``.
    """
    return {name: np.array(h5_data[name]) for name in h5_data.keys()}