import argparse
import glob
import os
import pdb

import cv2
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset

from model import MattingNetwork

# Maps the --rotate value to the matching cv2 rotation flag; any other value
# (including the default '') leaves the frame unrotated.
_ROTATIONS = {
    '+90': cv2.ROTATE_90_CLOCKWISE,
    '-90': cv2.ROTATE_90_COUNTERCLOCKWISE,
    '180': cv2.ROTATE_180,
}


class ItwDataset(Dataset):
    """Folder of still frames served one at a time as model-ready tensors.

    Collects every ``*.png`` and ``*.jpg`` file directly inside
    ``input_pth``, sorts the paths, and keeps every ``step``-th one.

    Args:
        input_pth: Directory containing the frames.
        step: Subsampling stride over the sorted file list.
        rotate: ``'+90'`` (clockwise), ``'-90'`` (counter-clockwise),
            ``'180'``; any other value means no rotation.
    """

    def __init__(self, input_pth, step, rotate):
        self.input_pth_list = sorted(
            glob.glob(os.path.join(input_pth, '*.png'))
            + glob.glob(os.path.join(input_pth, '*.jpg'))
        )[::step]
        self.rotate = rotate

    def __len__(self):
        return len(self.input_pth_list)

    def __getitem__(self, index):
        """Load, rotate and normalise frame ``index``.

        Returns:
            dict with
              ``'img'``: float tensor of shape (1, 3, H, W), RGB, in [0, 1];
              ``'file_name'``: the file's base name without its extension.

        Raises:
            IOError: if OpenCV cannot decode the file.
        """
        render_path = self.input_pth_list[index]
        img = cv2.imread(render_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError(f'Failed to read image: {render_path}')

        flag = _ROTATIONS.get(self.rotate)
        if flag is not None:
            img = cv2.rotate(img, flag)

        # cv2 decodes to BGR; the matting network expects RGB channels.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = torch.from_numpy(img).permute(2, 0, 1) / 255.

        return {
            'img': img.unsqueeze(0),  # add the batch dimension
            'file_name': os.path.splitext(os.path.basename(render_path))[0],
        }


def _resolve_device(device_arg):
    """Convert the --device value into a ``torch.device``.

    A bare GPU index such as ``'0'`` means ``cuda:0`` (the form the script
    always accepted); anything else (``'cpu'``, ``'cuda'``, ``'cuda:1'``)
    is handed to ``torch.device`` unchanged.
    """
    if device_arg.isdigit():
        return torch.device(f'cuda:{device_arg}')
    return torch.device(device_arg)


def main():
    """Run the matting network over a frame folder and write binary masks."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_pth', type=str, required=True)
    parser.add_argument('--output_pth', type=str, required=True)
    parser.add_argument('--device', type=str, default='cpu',
                        help="'cpu', a GPU index like '0', or a torch device string")
    parser.add_argument('--step', type=int, default=1)
    parser.add_argument('--rotate', type=str, default='',
                        choices=['', '+90', '-90', '180'])
    args = parser.parse_args()

    device = _resolve_device(args.device)
    downsample_ratio = 0.4

    model = MattingNetwork(variant='mobilenetv3').eval().to(device)  # Or variant="resnet50"
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    model.load_state_dict(
        torch.load('./checkpoint/rvm_mobilenetv3.pth', map_location=device))
    rec = [None] * 4  # Initial recurrent states are None

    frame_dataset = ItwDataset(args.input_pth, args.step, args.rotate)
    os.makedirs(args.output_pth, exist_ok=True)

    with torch.no_grad():
        for data in frame_dataset:
            save_img_pth = os.path.join(args.output_pth, data['file_name'] + '.png')
            if os.path.exists(save_img_pth):
                # NOTE: skipped frames do not advance the recurrent state `rec`.
                print(save_img_pth + ' exists!')
                continue

            fgr, pha, *rec = model(data['img'].to(device), *rec, downsample_ratio)

            # Binarise alpha to {0, 255} and replicate it to 3 channels (H, W, 3).
            mask = (torch.round(pha) * 255).to(torch.uint8)
            mask = mask.repeat(1, 3, 1, 1).squeeze(0).permute(1, 2, 0).cpu().numpy()
            if not cv2.imwrite(save_img_pth, mask):
                raise IOError(f'Failed to write mask: {save_img_pth}')
            print(data['file_name'])


if __name__ == '__main__':
    main()