|
import torch |
|
from model import MattingNetwork |
|
from torch.utils.data import DataLoader |
|
from torch.utils.data.dataset import Dataset |
|
import glob |
|
import os |
|
import cv2 |
|
import pdb |
|
import argparse |
|
|
|
class ItwDataset(Dataset):
    """Dataset over the ``.png`` / ``.jpg`` frames found directly in a folder.

    Paths are sorted, sub-sampled with ``step`` and each frame is optionally
    rotated when it is loaded.
    """

    def __init__(self, input_pth, step, rotate):
        """Collect the frame paths.

        Args:
            input_pth: Directory searched (non-recursively) for ``*.png`` and
                ``*.jpg`` files.
            step: Keep every ``step``-th entry of the sorted path list.
            rotate: ``'+90'``, ``'-90'`` or ``'180'`` to rotate every frame;
                any other value leaves the frames untouched.
        """
        paths = []
        for pattern in ('*.png', '*.jpg'):
            paths.extend(glob.glob(os.path.join(input_pth, pattern)))
        paths.sort()
        self.input_pth_list = paths[::step]
        self.rotate = rotate

    def __len__(self):
        """Return the number of frames kept after sub-sampling."""
        return len(self.input_pth_list)

    def __getitem__(self, index):
        """Load one frame.

        Returns:
            A dict with ``'img'`` (float tensor built from the decoded image,
            permuted to channel-first, scaled by 1/255 and given a leading
            dimension of size 1) and ``'file_name'`` (base name of the source
            file without its extension).

        Raises:
            IndexError: If ``index`` is out of range. Plain ``for`` loops over
                the dataset rely on this to terminate.
            OSError: If the image file cannot be decoded.
        """
        render_path = self.input_pth_list[index]

        img = cv2.imread(render_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than deeper down.
            raise OSError(f'Could not read image: {render_path}')

        # Built per call so that cv2 is only touched when a frame is loaded.
        rotate_code = {
            '+90': cv2.ROTATE_90_CLOCKWISE,
            '-90': cv2.ROTATE_90_COUNTERCLOCKWISE,
            '180': cv2.ROTATE_180,
        }.get(self.rotate)
        # Compare against None: one of the OpenCV rotate codes may be 0.
        if rotate_code is not None:
            img = cv2.rotate(img, rotate_code)

        # NOTE(review): cv2.imread yields channels in BGR order and no
        # conversion is done here -- confirm the consumer does not expect RGB.
        img = torch.from_numpy(img)
        img = img.permute(2, 0, 1) / 255.
        img = img.unsqueeze(0)

        return {
            'img': img,
            'file_name': os.path.splitext(os.path.basename(render_path))[0],
        }
|
|
|
def resolve_device(name):
    """Turn the ``--device`` value into a ``torch.device``.

    A bare index such as ``'0'`` selects ``cuda:0`` (the historical meaning of
    the flag). Anything else (``'cpu'``, ``'cuda'``, ``'cuda:1'``, ...) is
    passed to ``torch.device`` unchanged, so the default ``'cpu'`` is valid.
    """
    name = str(name).strip()
    if name.isdigit():
        return torch.device(f'cuda:{name}')
    return torch.device(name)


def parse_args(argv=None):
    """Parse the command-line options (``argv=None`` reads ``sys.argv``)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_pth', type=str, required=True,
                        help='folder containing the input .png/.jpg frames')
    parser.add_argument('--output_pth', type=str, required=True,
                        help='folder the mask .png files are written to')
    parser.add_argument('--device', type=str, default='cpu',
                        help="'cpu', a CUDA index such as '0', or any "
                             "torch device string")
    parser.add_argument('--step', type=int, default=1,
                        help='process every step-th frame')
    parser.add_argument('--rotate', type=str, default='',
                        help="'+90', '-90' or '180'; anything else disables "
                             "rotation")
    return parser.parse_args(argv)


def main(argv=None):
    """Run the matting network over a folder of frames and save binary masks.

    For every frame, a three-channel PNG whose pixels are 0 or 255 (the
    rounded ``pha`` output of the model) is written to ``--output_pth`` under
    the frame's base name. Frames whose output already exists are skipped.

    Raises:
        OSError: If a mask cannot be written.
    """
    args = parse_args(argv)
    device = resolve_device(args.device)
    downsample_ratio = 0.4

    model = MattingNetwork(variant='mobilenetv3').eval().to(device)
    # map_location lets the checkpoint load whatever device it was saved from.
    model.load_state_dict(
        torch.load('./checkpoint/rvm_mobilenetv3.pth', map_location=device))

    # Values returned by the previous model call are fed into the next one
    # (presumably recurrent state -- confirm against MattingNetwork).
    rec = [None] * 4
    frame_dataset = ItwDataset(args.input_pth, args.step, args.rotate)

    os.makedirs(args.output_pth, exist_ok=True)

    for data in frame_dataset:
        save_img_pth = os.path.join(args.output_pth, data['file_name'] + '.png')
        if os.path.exists(save_img_pth):
            # Skipped frames are not run through the model, so `rec` is not
            # advanced for them.
            print(save_img_pth + ' exists!')
            continue

        with torch.no_grad():
            fgr, pha, *rec = model(data['img'].to(device), *rec, downsample_ratio)

        # Binarise to {0, 255}. The clamp keeps the uint8 cast well defined
        # even if the network output were to leave [0, 1].
        mask_infer = (torch.round(pha) * 255).clamp(0, 255).to(torch.uint8)
        # Repeat along dim 1 to obtain three channels, then go channel-last.
        mask_infer = mask_infer.repeat(1, 3, 1, 1)
        mask_infer = mask_infer.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()

        # cv2.imwrite reports failure through its return value.
        if not cv2.imwrite(save_img_pth, mask_infer):
            raise OSError(f'Could not write mask: {save_img_pth}')
        print(data['file_name'])


if __name__ == '__main__':
    main()