import cv2 | |
from collections import OrderedDict | |
import torch | |
def save_img(filepath, img):
    """Write an RGB (or grayscale) image array to *filepath* using OpenCV.

    OpenCV's ``imwrite`` expects BGR channel order, so colour images are
    converted from RGB to BGR first. Single-channel images (``H x W`` or
    ``H x W x 1``) are written unchanged, because ``cv2.cvtColor`` with
    ``COLOR_RGB2BGR`` rejects inputs that do not have 3 or 4 channels.

    Args:
        filepath: Destination path; OpenCV chooses the encoder from its
            file extension.
        img: Image array, assumed to be RGB-ordered when it has 3/4 channels.

    Returns:
        bool: the success flag reported by ``cv2.imwrite``. OpenCV signals
        a failed write (bad path, unsupported extension) through this
        return value rather than by raising, so callers may check it.
    """
    ndim = getattr(img, "ndim", None)
    is_single_channel = ndim == 2 or (ndim == 3 and img.shape[2] == 1)
    # Colour conversion is only valid (and only needed) for multi-channel data.
    out = img if is_single_channel else cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    return cv2.imwrite(filepath, out)
def load_checkpoint(model, weights, map_location=None):
    """Load the ``"state_dict"`` entry of a checkpoint file into *model*.

    A strict ``load_state_dict`` is attempted first. If it fails because of a
    key mismatch, the ``module.`` prefix (typically added when the model
    was saved while wrapped in ``nn.DataParallel`` /
    ``DistributedDataParallel``) is removed from every key that carries it,
    and the load is retried.

    Args:
        model: Module whose ``load_state_dict`` receives the weights.
        weights: Path or file-like object passed to ``torch.load``.
        map_location: Forwarded to ``torch.load``. For example, ``"cpu"`` lets
            a GPU-saved checkpoint load on a machine without CUDA. ``None``
            (the default) keeps the previous behaviour.

    Raises:
        KeyError: if the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: if the weights still do not match *model* after the
            prefix has been removed.
    """
    checkpoint = torch.load(weights, map_location=map_location)
    state_dict = checkpoint["state_dict"]
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # Key mismatch: retry without the wrapper prefix. Only strip it where
        # it is actually present; blindly slicing k[7:] would mangle names
        # that never had the prefix.
        prefix = "module."
        stripped = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items()
        )
        model.load_state_dict(stripped)
def setup(args): | |
save_dir = 'result/' | |
folder = 'test/' | |
return folder, save_dir | |