import argparse
import os
import os.path as osp

import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm

from lib.networks import define_G

parser = argparse.ArgumentParser(description='neu video body rec')
parser.add_argument('--gid', default=0, type=int, metavar='ID',
                    help='gpu id')
parser.add_argument('--imgpath', default=None, metavar='PATH',
                    help='path to the folder of input images')
args = parser.parse_args()
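
# Example invocation (the script name is illustrative; use the actual file name):
#   python gen_normals.py --gid 0 --imgpath /path/to/subject/imgs
# Images are read from numeric camera sub-folders of --imgpath (or from --imgpath
# itself if none exist); masks, if any, are looked up by replacing '/imgs/' with
# '/masks/' in each image path, and the predicted normal maps are written to
# <imgpath>/../normals/<camera>/.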


def crop_image(img, rect):
    """Crop img to rect = (x, y, w, h), zero-padding wherever the rect leaves the image."""
    x, y, w, h = rect

    # Padding needed on each side when the rect extends past the image border.
    left = abs(x) if x < 0 else 0
    top = abs(y) if y < 0 else 0
    right = abs(img.shape[1] - (x + w)) if x + w >= img.shape[1] else 0
    bottom = abs(img.shape[0] - (y + h)) if y + h >= img.shape[0] else 0

    # Pad with transparent black for RGBA images, opaque black otherwise.
    if img.shape[2] == 4:
        color = [0, 0, 0, 0]
    else:
        color = [0, 0, 0]
    new_img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)

    # Shift the rect into the padded image's coordinate frame.
    x = x + left
    y = y + top

    return new_img[y:(y + h), x:(x + w), :]
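
# For instance (illustrative numbers): with img of shape (100, 100, 3) and
# rect = (-10, -10, 50, 50), crop_image pads 10 black pixels on the top and left
# and returns a 50x50 crop whose top and left 10-pixel borders are that padding.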


class EvalDataset(Dataset):
    def __init__(self, root):
        self.root = root

        # Collect image files and sort them by their integer frame index
        # (file names are expected to be numeric, e.g. 000000.jpg).
        self.img_files = [osp.join(self.root, f) for f in os.listdir(self.root)
                          if f.split('.')[-1] in ['png', 'jpeg', 'jpg', 'PNG', 'JPG', 'JPEG']]
        self.img_files.sort(key=lambda x: int(osp.basename(x).split('.')[0]))

        # Convert to a tensor and normalize from [0, 1] to [-1, 1].
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        self.person_id = 0

    def __len__(self):
        return len(self.img_files)

    def get_item(self, index):
        img_path = self.img_files[index]
        # Masks live in a sibling '/masks/' tree and are stored as .png.
        mask_path = self.img_files[index].replace('/imgs/', '/masks/')[:-3] + 'png'

        img_name = os.path.splitext(os.path.basename(img_path))[0]
        im = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)

        if osp.isfile(mask_path):
            mask = cv2.imread(mask_path)
            # Background = pixels that are not positive in every mask channel.
            bg = ~(mask > 0).all(-1)
            im[bg] = np.zeros(im.shape[-1], dtype=im.dtype)
        else:
            # No mask file: fall back to using the image itself as the mask, so that
            # non-black pixels count as foreground in mask_to_bbox below.
            bg = None
            mask = im[:, :, :3]
        H, W = im.shape[:2]

        if im.shape[2] == 4:
            # Un-premultiply the alpha channel, then composite onto a 50% grey background.
            im = im / 255.0
            im[:, :, :3] /= im[:, :, 3:] + 1e-8
            im = im[:, :, 3:] * im[:, :, :3] + 0.5 * (1.0 - im[:, :, 3:])
            im = (255.0 * im).astype(np.uint8)
        h, w = im.shape[:2]

        rects = self.mask_to_bbox(mask)

        # mask_to_bbox returns a single rect; the else-branch keeps the largest
        # rect should it ever return one rect per person.
        if len(rects.shape) == 1:
            rects = rects[None]
            pid = 0
        else:
            max_len = 0
            pid = -1
            for ind, rect in enumerate(rects):
                cur_len = (rect[-2] + rect[-1]) // 2
                if max_len < cur_len:
                    max_len = cur_len
                    pid = ind

        rect = rects[pid].tolist()
        im = crop_image(im, rect)
        im_512 = cv2.resize(im, (512, 512))
        # OpenCV BGR -> PIL RGB.
        image_512 = Image.fromarray(im_512[:, :, ::-1]).convert('RGB')

        image_512 = self.to_tensor(image_512)
        return (img_name, image_512.unsqueeze(0), bg, H, W, rect)

    def __getitem__(self, index):
        return self.get_item(index)

    def mask_to_bbox(self, mask):
        # Tight bounding box of the foreground, enlarged by 5% and made square.
        y_ind, x_ind = np.where((mask > 0).all(-1))
        y1, y2, x1, x2 = y_ind.min(), y_ind.max(), x_ind.min(), x_ind.max()
        h, w = y2 - y1, x2 - x1
        h_, w_ = 1.05 * h, 1.05 * w
        y_, x_ = y1 - (h_ - h) / 2, x1 - (w_ - w) / 2
        length = max(h_, w_)
        rects = np.array([x_, y_, length, length], dtype=np.int32)
        return rects
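
    # For instance (illustrative numbers): a foreground spanning rows 10..90 and
    # columns 20..60 gives h = 80, w = 40, so the padded square rect becomes
    # [x, y, w, h] = [19, 8, 84, 84].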


device = torch.device(args.gid)

# Normal-map predictor: the 'netG.netF.' sub-network of the PIFuHD checkpoint.
netF = define_G(3, 3, 64, "global", 4, 9, 1, 3, "instance")

weights = {}
for k, v in torch.load('checkpoints/pifuhd.pt', map_location='cpu')['model_state_dict'].items():
    if k[:10] == 'netG.netF.':
        weights[k[10:]] = v

netF.load_state_dict(weights)
netF = netF.to(device)
netF.eval()

# Camera sub-folders are named by digits; if none exist, treat imgpath itself as the only camera.
cids = [temp for temp in os.listdir(args.imgpath) if osp.isdir(osp.join(args.imgpath, temp)) and temp.isdigit()]
if len(cids) == 0:
    cids = ['.']

for fold in cids:
    save_root = osp.normpath(osp.join(args.imgpath, osp.pardir, 'normals', fold))
    print(save_root)
    os.makedirs(save_root, exist_ok=True)
    dataset = EvalDataset(osp.normpath(osp.join(args.imgpath, fold)))
    writer = None
    with torch.no_grad():
        for i in tqdm(range(len(dataset))):
            img_name, img, bg, H, W, rect = dataset[i]
            if writer is None:
                writer = cv2.VideoWriter(osp.join(save_root, 'video.avi'),
                                         cv2.VideoWriter.fourcc('M', 'J', 'P', 'G'), 30., (W, H))
            x, y, w, h = [float(tmp) for tmp in rect]

            # Predict the normal map on the 512x512 crop.
            img = img.to(device)
            nml = netF.forward(img)

            # Build a pixel grid over the full frame and map it into the crop's
            # normalized [-1, 1] coordinates so grid_sample pastes the prediction
            # back at the original resolution.
            gridH, gridW = torch.meshgrid([torch.arange(H).float().to(device), torch.arange(W).float().to(device)])
            coords = torch.stack([gridW, gridH]).permute(1, 2, 0).unsqueeze(0)
            coords[..., 0] = 2.0 * (coords[..., 0] - x) / w - 1.0
            coords[..., 1] = 2.0 * (coords[..., 1] - y) / h - 1.0
            nml = torch.nn.functional.grid_sample(nml, coords, mode='bilinear', padding_mode='zeros', align_corners=True)

            # Pixels with (near-)zero normal length come from the zero padding outside the crop.
            invalid_mask = (torch.norm(nml, dim=1) < 0.0001).detach().cpu().numpy()[0]
            nml = nml.detach().cpu().numpy()[0]

            # [-1, 1] normals -> [0, 255] BGR image.
            nml = (np.transpose(nml, (1, 2, 0)) * 0.5 + 0.5)[:, :, ::-1] * 255.0
            if invalid_mask.sum() > 0:
                nml[invalid_mask] = 0.
            if bg is not None:
                nml[bg] = 0.

            cv2.imwrite(osp.join(save_root, img_name + '.png'), nml.astype(np.uint8))
            writer.write(nml.astype(np.uint8))

    if writer is not None:
        writer.release()

print('done.')
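
# The PNGs above store the predicted normals as (n * 0.5 + 0.5)[:, :, ::-1] * 255,
# i.e. [-1, 1] components remapped to [0, 255] with the channel order flipped for
# OpenCV. A minimal sketch for reading one back into unit-length vectors
# (the file name below is illustrative):
#
#   nml_img = cv2.imread('normals/0/000000.png').astype(np.float32)
#   nml = (nml_img[:, :, ::-1] / 255.0 - 0.5) * 2.0              # undo the flip and the [0, 255] remap
#   valid = np.linalg.norm(nml, axis=-1) > 1e-4                  # masked pixels were written as 0
#   nml[valid] /= np.linalg.norm(nml[valid], axis=-1, keepdims=True)  # renormalize after uint8 quantization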