import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from lib.networks import define_G
import argparse
import os
import os.path as osp
import cv2
from tqdm import tqdm
import numpy as np
from PIL import Image

parser = argparse.ArgumentParser(description='neu video body rec')
parser.add_argument('--gid', default=0, type=int, metavar='ID', help='gpu id')
parser.add_argument('--imgpath', default=None, metavar='M',
                    help='root folder of the input images (an imgs/ directory, '
                         'optionally holding digit-named per-camera subfolders)')
args = parser.parse_args()


def crop_image(img, rect):
    # crop rect=(x, y, w, h) from img, zero-padding wherever the rect
    # extends beyond the image borders
    x, y, w, h = rect
    left = abs(x) if x < 0 else 0
    top = abs(y) if y < 0 else 0
    right = abs(img.shape[1] - (x + w)) if x + w >= img.shape[1] else 0
    bottom = abs(img.shape[0] - (y + h)) if y + h >= img.shape[0] else 0
    color = [0, 0, 0, 0] if img.shape[2] == 4 else [0, 0, 0]
    new_img = cv2.copyMakeBorder(img, top, bottom, left, right,
                                 cv2.BORDER_CONSTANT, value=color)
    x = x + left
    y = y + top
    return new_img[y:(y + h), x:(x + w), :]


class EvalDataset(Dataset):
    def __init__(self, root):
        self.root = root
        self.img_files = [osp.join(self.root, f) for f in os.listdir(self.root)
                          if f.split('.')[-1] in ['png', 'jpeg', 'jpg', 'PNG', 'JPG', 'JPEG']]
        # frames are named by integer index (0.png, 1.png, ...)
        self.img_files.sort(key=lambda x: int(osp.basename(x).split('.')[0]))
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        self.person_id = 0

    def __len__(self):
        return len(self.img_files)

    def get_item(self, index):
        img_path = self.img_files[index]
        # the mask is expected in a sibling tree: .../imgs/0.png -> .../masks/0.png
        mask_path = img_path.replace('/imgs/', '/masks/')[:-3] + 'png'
        img_name = os.path.splitext(os.path.basename(img_path))[0]

        im = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        if osp.isfile(mask_path):
            mask = cv2.imread(mask_path)
            bg = ~(mask > 0).all(-1)
            im[bg] = np.zeros(im.shape[-1], dtype=im.dtype)  # zero out the background
        else:
            bg = None  # note: a mask is required by mask_to_bbox below
        H, W = im.shape[:2]
        if im.shape[2] == 4:
            # composite RGBA onto a 0.5 gray background
            im = im / 255.0
            im[:, :, :3] /= im[:, :, 3:] + 1e-8
            im = im[:, :, 3:] * im[:, :, :3] + 0.5 * (1.0 - im[:, :, 3:])
            im = (255.0 * im).astype(np.uint8)

        # derive a square crop rect (x, y, side, side) from the mask
        rects = self.mask_to_bbox(mask)
        if len(rects.shape) == 1:
            rects = rects[None]
            pid = 0
        else:
            # several candidate boxes: keep the largest one
            max_len = 0
            pid = -1
            for ind, rect in enumerate(rects):
                cur_len = (rect[-2] + rect[-1]) // 2
                if max_len < cur_len:
                    max_len = cur_len
                    pid = ind
        rect = rects[pid].tolist()
        im = crop_image(im, rect)
        # the pifuhd normal net takes a 512x512 RGB crop normalized to [-1, 1]
        im_512 = cv2.resize(im, (512, 512))
        image_512 = self.to_tensor(Image.fromarray(im_512[:, :, ::-1]).convert('RGB'))
        return img_name, image_512.unsqueeze(0), bg, H, W, rect

    def __getitem__(self, index):
        return self.get_item(index)

    def mask_to_bbox(self, mask):
        # tight bbox of the foreground, enlarged by 5% and made square
        y_ind, x_ind = np.where((mask > 0).all(-1))
        y1, y2, x1, x2 = y_ind.min(), y_ind.max(), x_ind.min(), x_ind.max()
        h, w = y2 - y1, x2 - x1
        h_, w_ = 1.05 * h, 1.05 * w
        y_, x_ = y1 - (h_ - h) / 2, x1 - (w_ - w) / 2
        length = max(h_, w_)
        rects = np.array([x_, y_, length, length], dtype=np.int32)
        return rects


device = torch.device(args.gid)
# pifuhd's frontal normal predictor: a pix2pixHD-style "global" generator
netF = define_G(3, 3, 64, "global", 4, 9, 1, 3, "instance")
weights = {}
for k, v in torch.load('checkpoints/pifuhd.pt', map_location='cpu')['model_state_dict'].items():
    # keep only the netG.netF.* entries, i.e. the front normal network
    if k[:10] == 'netG.netF.':
        weights[k[10:]] = v
netF.load_state_dict(weights)
netF = netF.to(device)
netF.eval()

# each digit-named subfolder under imgpath holds one camera's frames;
# fall back to imgpath itself when there are no such subfolders
cids = [temp for temp in os.listdir(args.imgpath)
        if osp.isdir(osp.join(args.imgpath, temp)) and temp.isdigit()]
if len(cids) == 0:
    cids = ['.']
for fold in cids:
    save_root = osp.normpath(osp.join(args.imgpath, osp.pardir, 'normals', fold))
    print(save_root)
    os.makedirs(save_root, exist_ok=True)
    dataset = EvalDataset(osp.normpath(osp.join(args.imgpath, fold)))
    writer = None
    with torch.no_grad():
        for i in tqdm(range(len(dataset))):
            img_name, img, bg, H, W, rect = dataset[i]
            if writer is None:
                writer = cv2.VideoWriter(osp.join(save_root, 'video.avi'),
                                         cv2.VideoWriter.fourcc('M', 'J', 'P', 'G'),
                                         30., (W, H))
            x, y, w, h = [float(tmp) for tmp in rect]
            img = img.to(device)
            nml = netF.forward(img)  # predicted normals, already in [-1, 1]
            # pixel coordinates over the full-resolution frame
            gridH, gridW = torch.meshgrid([torch.arange(H).float().to(device),
                                           torch.arange(W).float().to(device)])
            coords = torch.stack([gridW, gridH]).permute(1, 2, 0).unsqueeze(0)
            # map the crop rect onto grid_sample's [-1, 1] range: pixels inside the
            # rect land in [-1, 1]; pixels outside fall beyond that range and are
            # zero-filled thanks to padding_mode='zeros'
            coords[..., 0] = 2.0 * (coords[..., 0] - x) / w - 1.0
            coords[..., 1] = 2.0 * (coords[..., 1] - y) / h - 1.0
            # bilinearly resample the predicted normals back into the original frame
            nml = torch.nn.functional.grid_sample(nml, coords, mode='bilinear',
                                                  padding_mode='zeros', align_corners=True)
            invalid_mask = (torch.norm(nml, dim=1) < 0.0001).detach().cpu().numpy()[0]
            nml = nml.detach().cpu().numpy()[0]
            # save the normal map as an RGB image: [-1,1] -> [0,1] -> [0,255], RGB -> BGR
            nml = (np.transpose(nml, (1, 2, 0)) * 0.5 + 0.5)[:, :, ::-1] * 255.0
            if invalid_mask.sum() > 0:
                nml[invalid_mask] = 0.
            if bg is not None:
                nml[bg] = 0.
            cv2.imwrite(osp.join(save_root, img_name + '.png'), nml.astype(np.uint8))
            writer.write(nml.astype(np.uint8))
    # release per camera folder so each video.avi is finalized properly
    if writer is not None:
        writer.release()
print('done.')
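
# Example invocation (a sketch: the script filename and data root below are
# hypothetical; the expected layout -- an imgs/ tree with a sibling masks/
# tree, outputs written next to them under normals/<camera>/ -- follows the
# code above):
#
#   python generate_normals.py --gid 0 --imgpath /path/to/subject/imgs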