# LN3Diff / utils/torch_utils/inference_matt.py
# NIRVANALAN
# release file
# 87c126b
# https://github.com/xinntao/facexlib/blob/master/inference/inference_matting.py
import argparse
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from tqdm import tqdm, trange

from facexlib.matting import init_matting_model
from facexlib.utils import img2tensor
def matt_single(args):
    """Predict an alpha matte for one image; save the matte and the foreground.

    The matting model comes from facexlib's ``init_matting_model`` and the
    input tensor is moved to CUDA with ``.cuda()``, so a GPU is required.

    Args:
        args: namespace providing
            ``img_path``  -- input image path, read with ``cv2.imread``;
            ``save_path`` -- output path for the 8-bit matte. The foreground,
            composited over a white background, is written beside it with an
            ``_fg`` suffix inserted before the extension
            (``test_matting.png`` -> ``test_matting_fg.png``).

    Raises:
        FileNotFoundError: ``args.img_path`` could not be read as an image.
        OSError: OpenCV reported failure writing either output image.
    """
    modnet = init_matting_model()

    # read image; cv2.imread signals failure by returning None, not raising
    raw = cv2.imread(args.img_path)
    if raw is None:
        raise FileNotFoundError(f'could not read image: {args.img_path}')
    img = raw / 255.  # scaled to [0, 1], OpenCV (BGR) channel order

    # unify image channels to 3
    if img.ndim == 2:
        img = img[:, :, None]
    if img.shape[2] == 1:
        img = np.repeat(img, 3, axis=2)
    elif img.shape[2] == 4:
        img = img[:, :, 0:3]

    # HWC BGR array -> 1xCxHxW RGB tensor, normalized from [0, 1] to [-1, 1]
    img_t = img2tensor(img, bgr2rgb=True, float32=True)
    normalize(img_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
    img_t = img_t.unsqueeze(0).cuda()

    # resize image for input: when both sides are below ref_size, or both
    # are above it, scale so the shorter side equals ref_size. Then round
    # both sides down to a multiple of 32, but never below 32 (a 0-sized
    # target would make F.interpolate fail for very thin images).
    _, _, im_h, im_w = img_t.shape
    ref_size = 512
    if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
        if im_w >= im_h:
            im_rh = ref_size
            im_rw = int(im_w / im_h * ref_size)
        else:
            im_rw = ref_size
            im_rh = int(im_h / im_w * ref_size)
    else:
        im_rh = im_h
        im_rw = im_w
    im_rw = max(32, im_rw - im_rw % 32)
    im_rh = max(32, im_rh - im_rh % 32)
    img_t = F.interpolate(img_t, size=(im_rh, im_rw), mode='area')

    # inference only: no autograd graph is needed
    with torch.no_grad():
        _, _, matte = modnet(img_t, True)
        # resize the matte back to the original resolution
        matte = F.interpolate(matte, size=(im_h, im_w), mode='area')
    matte = matte[0][0].cpu().numpy()

    if not cv2.imwrite(args.save_path, (matte * 255).astype('uint8')):
        raise OSError(f'could not write matte: {args.save_path}')

    # get foreground: blend the image over white, using the matte as alpha
    alpha = matte[:, :, None]
    foreground = img * alpha + (1 - alpha)
    foreground_u8 = np.clip(np.rint(foreground * 255), 0, 255).astype(np.uint8)

    # insert "_fg" before the extension; str.replace('.png', ...) would be a
    # no-op for other extensions and overwrite the matte just written
    save_path = Path(args.save_path)
    fg_path = save_path.with_name(f'{save_path.stem}_fg{save_path.suffix}')
    if not cv2.imwrite(str(fg_path), foreground_u8):
        raise OSError(f'could not write foreground: {fg_path}')
def matt_directory(args,
                   tgt_dir_path='/mnt/lustre/share/yslan/ffhq/unzipped_ffhq_matte/',
                   src_root=None):
    """Predict an 8-bit alpha matte for every ``*.png`` below a directory.

    Written for extracting FFHQ image mattes. The layout of the source tree
    is mirrored under ``tgt_dir_path``: the image ``src_root/a/b.png`` gets
    its matte written to ``tgt_dir_path/a/b.png``. The model input is moved
    to CUDA with ``.cuda()``, so a GPU is required.

    Args:
        args: namespace providing ``img_dir_path``, the directory searched
            recursively for ``*.png`` files.
        tgt_dir_path: root directory for the output mattes.
        src_root: directory the output layout is taken relative to. Defaults
            to ``args.img_dir_path``; every image found must lie below it.

    Raises:
        FileNotFoundError: an image could not be read.
        ValueError: an image does not lie below ``src_root``.
        OSError: OpenCV reported failure writing a matte.
    """
    modnet = init_matting_model()
    img_dir = Path(args.img_dir_path)
    # default to the searched directory, so relative_to() cannot fail for
    # images found by rglob() below
    src_root = img_dir if src_root is None else Path(src_root)
    tgt_root = Path(tgt_dir_path)
    ref_size = 512

    all_imgs = sorted(img_dir.rglob('*.png'))
    print('all imgs: ', len(all_imgs))

    for img_path in tqdm(all_imgs):
        # read image; cv2.imread signals failure by returning None, not raising
        raw = cv2.imread(str(img_path))
        if raw is None:
            raise FileNotFoundError(f'could not read image: {img_path}')
        img = raw / 255.  # scaled to [0, 1], OpenCV (BGR) channel order

        tgt_save_path = tgt_root / img_path.relative_to(src_root)
        tgt_save_path.parent.mkdir(parents=True, exist_ok=True)

        # unify image channels to 3
        if img.ndim == 2:
            img = img[:, :, None]
        if img.shape[2] == 1:
            img = np.repeat(img, 3, axis=2)
        elif img.shape[2] == 4:
            img = img[:, :, 0:3]

        # HWC BGR array -> 1xCxHxW RGB tensor, normalized to [-1, 1]
        img_t = img2tensor(img, bgr2rgb=True, float32=True)
        normalize(img_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
        img_t = img_t.unsqueeze(0).cuda()

        # resize image for input: when both sides are below ref_size, or both
        # above it, scale so the shorter side equals ref_size; then round both
        # sides down to a multiple of 32, never below 32 (avoids a 0-sized
        # interpolation target for very thin images)
        _, _, im_h, im_w = img_t.shape
        if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
            if im_w >= im_h:
                im_rh = ref_size
                im_rw = int(im_w / im_h * ref_size)
            else:
                im_rw = ref_size
                im_rh = int(im_h / im_w * ref_size)
        else:
            im_rh = im_h
            im_rw = im_w
        im_rw = max(32, im_rw - im_rw % 32)
        im_rh = max(32, im_rh - im_rh % 32)
        img_t = F.interpolate(img_t, size=(im_rh, im_rw), mode='area')

        # inference only: no autograd graph is needed
        with torch.no_grad():
            _, _, matte = modnet(img_t, True)
            # resize the matte back to the original resolution
            matte = F.interpolate(matte, size=(im_h, im_w), mode='area')
        matte = matte[0][0].cpu().numpy()

        # check imwrite's status explicitly; an `assert` would vanish under -O
        if not cv2.imwrite(str(tgt_save_path), (matte * 255).astype('uint8')):
            raise OSError(f'could not write matte: {tgt_save_path}')
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Alpha matting with the facexlib matting model.')
    parser.add_argument('--img_path', type=str, default='assets/test.jpg',
                        help='input image for --mode single')
    parser.add_argument('--save_path', type=str, default='test_matting.png',
                        help='output matte path for --mode single')
    parser.add_argument('--img_dir_path', type=str, default='assets',
                        help='directory searched recursively for *.png in --mode directory')
    # select the entry point on the command line instead of commenting code
    # in and out; 'directory' is the default to keep the previous behaviour
    parser.add_argument('--mode', choices=('single', 'directory'),
                        default='directory',
                        help='matte one image or a whole directory tree')
    args = parser.parse_args()
    if args.mode == 'single':
        matt_single(args)
    else:
        matt_directory(args)