# SRMNet_real_world_denoising / main_test_SRMNet.py
# Author: 52Hz
# Last commit: 715282a ("Update main_test_SRMNet.py")
# File size: 676 bytes; malware scan: no virus detected
import cv2
from collections import OrderedDict
import torch
def save_img(filepath, img):
    """Write an RGB image array to ``filepath`` using OpenCV.

    Args:
        filepath: Destination path; the image format is chosen by OpenCV from
            the file extension.
        img: Image array in RGB channel order (presumably a numpy array of
            shape (H, W, 3) or (H, W); confirm against callers). Single-channel
            images are written unchanged.

    Returns:
        The boolean returned by ``cv2.imwrite``: False means the image was not
        written (e.g. missing directory or unsupported extension).
    """
    # Grayscale images have no channel order to swap, and cv2.cvtColor with
    # COLOR_RGB2BGR rejects single-channel input, so pass them through as-is.
    is_single_channel = img.ndim == 2 or (img.ndim == 3 and img.shape[2] == 1)
    if not is_single_channel:
        # OpenCV expects BGR order when writing colour images.
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    # cv2.imwrite reports failure by returning False instead of raising;
    # hand that flag back so callers can detect a failed write.
    return cv2.imwrite(filepath, img)
def load_checkpoint(model, weights, map_location=None):
    """Load the ``"state_dict"`` entry of a saved checkpoint into ``model``.

    The state dict is first loaded as-is. If that fails with a key mismatch
    and the saved keys carry a ``module.`` prefix (e.g. from saving a model
    wrapped in ``nn.DataParallel``), the prefix is stripped and loading is
    retried.

    Args:
        model: The ``torch.nn.Module`` to load parameters into (modified in
            place).
        weights: Path (or file-like object) accepted by ``torch.load``.
        map_location: Forwarded to ``torch.load``; e.g. ``"cpu"`` to load a
            GPU-saved checkpoint on a machine without CUDA. Defaults to
            ``None``, which is ``torch.load``'s own default.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: If the keys still do not match the model after any
            ``module.`` prefix has been removed.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(weights, map_location=map_location)
    state_dict = checkpoint["state_dict"]
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        prefix = "module."
        # If no key has the prefix, stripping cannot help; surface the real
        # mismatch instead of mangling every key name.
        if not any(k.startswith(prefix) for k in state_dict):
            raise
        new_state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items()
        )
        model.load_state_dict(new_state_dict)
def setup(args):
save_dir = 'result/'
folder = 'test/'
return folder, save_dir