REC-MV_preprocess / RobustVideoMatting /inference_itw_rotate.py
mambazjp's picture
Upload 57 files
8b79d57
import torch
from model import MattingNetwork
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
import glob
import os
import cv2
import pdb
import argparse
class ItwDataset(Dataset):
    """Frames of an image sequence stored as ``.png`` / ``.jpg`` files.

    The files found directly inside ``input_pth`` are sorted by path, then
    every ``step``-th one is kept. Each item is the decoded frame, optionally
    rotated, as a float tensor scaled to ``[0, 1]``.
    """

    def __init__(self, input_pth: str, step: int, rotate: str):
        """Collect the frame paths.

        Args:
            input_pth: Directory containing the frames.
            step: Stride used to subsample the sorted frame list.
            rotate: ``'+90'``, ``'-90'`` or ``'180'`` to rotate each frame;
                any other value leaves the frames untouched.
        """
        super().__init__()
        frame_paths = []
        for pattern in ('*.png', '*.jpg'):
            frame_paths.extend(glob.glob(os.path.join(input_pth, pattern)))
        frame_paths.sort()
        self.input_pth_list = frame_paths[::step]
        self.rotate = rotate

    def __len__(self) -> int:
        """Return the number of frames kept after subsampling."""
        return len(self.input_pth_list)

    def __getitem__(self, index: int) -> dict:
        """Load one frame.

        Returns:
            A dict with ``'img'``, a ``(1, C, H, W)`` float tensor in
            ``[0, 1]``, and ``'file_name'``, the file's base name without
            its extension.

        Raises:
            IndexError: If ``index`` is out of range (this is also what ends
                plain ``for item in dataset`` iteration).
            FileNotFoundError: If the image cannot be read or decoded.
        """
        render_path = self.input_pth_list[index]

        img = cv2.imread(render_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than deep inside torch.
            raise FileNotFoundError(f'Could not read image: {render_path}')

        # Built per call so that cv2 is only touched when a frame is loaded.
        rotate_codes = {
            '+90': cv2.ROTATE_90_CLOCKWISE,
            '-90': cv2.ROTATE_90_COUNTERCLOCKWISE,
            '180': cv2.ROTATE_180,
        }
        rotate_code = rotate_codes.get(self.rotate)
        # ROTATE_90_CLOCKWISE is 0, so test against None, not truthiness.
        if rotate_code is not None:
            img = cv2.rotate(img, rotate_code)

        # NOTE(review): cv2.imread yields channels in BGR order and they are
        # passed on unchanged -- confirm whether the consumer expects RGB.
        tensor = torch.from_numpy(img).permute(2, 0, 1) / 255.
        tensor = tensor.unsqueeze(0)  # add the batch dimension

        file_name = os.path.splitext(os.path.basename(render_path))[0]
        return {
            'img': tensor,
            'file_name': file_name,
        }
def resolve_device(spec):
    """Turn the ``--device`` command-line value into a ``torch.device``.

    A bare index such as ``'0'`` selects that CUDA device (``cuda:0``). Any
    other value (``'cpu'``, ``'cuda'``, ``'cuda:1'``, ...) is handed to
    ``torch.device`` unchanged, so the default ``'cpu'`` is valid.

    Args:
        spec: Device string from the command line.

    Returns:
        The corresponding ``torch.device``.
    """
    spec = str(spec).strip()
    if spec.isdigit():
        return torch.device(f'cuda:{spec}')
    return torch.device(spec)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Predict an alpha mask for every frame in a directory.')
    parser.add_argument('--input_pth', type=str, required=True,
                        help='directory containing the input .png/.jpg frames')
    parser.add_argument('--output_pth', type=str, required=True,
                        help='directory the mask images are written to')
    parser.add_argument('--device', type=str, default='cpu',
                        help="CUDA device index (e.g. '0') or a torch device "
                             "string (e.g. 'cpu', 'cuda:1')")
    parser.add_argument('--step', type=int, default=1,
                        help='process every step-th frame')
    parser.add_argument('--rotate', type=str, default='',
                        help="rotate frames before inference: '+90', '-90' or '180'")
    args = parser.parse_args()

    device = resolve_device(args.device)
    downsample_ratio = 0.4

    model = MattingNetwork(variant='mobilenetv3').eval().to(device)  # Or variant="resnet50"
    # map_location lets the checkpoint load on hosts without the device it
    # was saved from (e.g. a CPU-only machine).
    model.load_state_dict(
        torch.load('./checkpoint/rvm_mobilenetv3.pth', map_location=device))
    rec = [None] * 4  # Initial recurrent states are None

    frame_dataset = ItwDataset(args.input_pth, args.step, args.rotate)
    os.makedirs(args.output_pth, exist_ok=True)

    with torch.no_grad():
        for data in frame_dataset:
            save_img_pth = os.path.join(args.output_pth, data['file_name'] + '.png')
            if os.path.exists(save_img_pth):
                # Resume support: frames with an existing output are skipped,
                # so they do not update the recurrent state either.
                print(save_img_pth + ' exists!')
                continue

            fgr, pha, *rec = model(data['img'].to(device), *rec, downsample_ratio)

            # Binarise the alpha matte to {0, 255} and replicate it to three
            # channels; cast to uint8 explicitly so the image writer gets an
            # 8-bit array.
            mask_infer = (torch.round(pha.repeat(1, 3, 1, 1)) * 255).to(torch.uint8)
            mask_infer = mask_infer.squeeze(0).permute(1, 2, 0).cpu().numpy()

            cv2.imwrite(save_img_pth, mask_infer)
            print(data['file_name'])