from collections import OrderedDict

import cv2
import torch


def save_img(filepath, img):
    """Write an RGB image to ``filepath``.

    OpenCV's ``imwrite`` expects BGR channel order, so the image is converted
    from RGB to BGR before it is written.

    Args:
        filepath: Destination path; OpenCV chooses the encoder from its
            extension.
        img: Image array in RGB channel order (presumably HxWx3 — confirm
            against callers).

    Returns:
        bool: The result of ``cv2.imwrite``; ``False`` means OpenCV could not
        write the file, so callers that care should check it.
    """
    return cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))


def load_checkpoint(model, weights, map_location=None):
    """Load ``checkpoint["state_dict"]`` from ``weights`` into ``model``.

    Checkpoints saved from a model wrapped in ``nn.DataParallel`` /
    ``DistributedDataParallel`` have their keys prefixed with ``module.``.
    If the state dict cannot be loaded as-is, that prefix is stripped from
    the keys that carry it and the load is retried.

    Args:
        model: Module to load the weights into (modified in place).
        weights: Anything ``torch.load`` accepts (path or file-like object).
        map_location: Forwarded to ``torch.load``; e.g. ``"cpu"`` to load a
            GPU-saved checkpoint on a CPU-only machine. ``None`` keeps
            ``torch.load``'s default behaviour.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: If the weights still do not match the model after the
            prefix has been stripped.
    """
    checkpoint = torch.load(weights, map_location=map_location)
    state_dict = checkpoint["state_dict"]
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # load_state_dict reports missing/unexpected keys with RuntimeError;
        # retry assuming the keys carry a DataParallel "module." prefix.
        prefix = "module."
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            # Strip only when present so unprefixed keys are not truncated.
            name = k[len(prefix):] if k.startswith(prefix) else k
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict)


def setup(args):
    """Return the fixed ``(folder, save_dir)`` paths used by the script.

    Args:
        args: Accepted for interface compatibility; currently unused.

    Returns:
        tuple[str, str]: ``("test/", "result/")`` — input folder, output dir.
    """
    save_dir = 'result/'
    folder = 'test/'
    return folder, save_dir