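"""Mesh reconstruction script for a two-level (coarse global + fine local) PIFu-style network.

Loads a trained HGPIFuMRNet checkpoint, runs it over an evaluation dataset, and writes one
vertex-colored OBJ (plus a PNG preview of the network inputs) per sample under
`<results_path>/<name>/recon/`. Vertex colors come either from predicted surface normals
(`gen_mesh`) or from sampling the input image (`gen_mesh_imgColor`).

Intended to be driven through `reconWrapper`, which parses options via `lib.options.BaseOptions`.
"""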
import os
import sys

# make the repository root importable when this script is run directly
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

import cv2
import numpy as np
import torch
from numpy.linalg import inv
from tqdm import tqdm

from lib.options import BaseOptions
from lib.mesh_util import save_obj_mesh_with_color, reconstruction
from lib.data import EvalWPoseDataset, EvalDataset
from lib.model import HGPIFuNetwNML, HGPIFuMRNet
from lib.geometry import index

parser = BaseOptions()


def gen_mesh(res, net, cuda, data, save_path, thresh=0.5, use_octree=True, components=False):
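    """Reconstruct a mesh from one evaluation sample and color it with predicted surface normals.

    Runs the global (512-resolution) and local image filters, extracts the occupancy iso-surface
    at `thresh` inside the sample's bounding box, and writes a colored OBJ next to a PNG preview
    of the network inputs. `components` is accepted but unused in this function.
    """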
    image_tensor_global = data['img_512'].to(device=cuda)
    image_tensor = data['img'].to(device=cuda)
    calib_tensor = data['calib'].to(device=cuda)

    # run the coarse (global) and fine (local) image filters
    net.filter_global(image_tensor_global)
    net.filter_local(image_tensor[:, None])

    # if the global network predicted front/back normal maps, append them to the preview stack
    try:
        if net.netG.netF is not None:
            image_tensor_global = torch.cat([image_tensor_global, net.netG.nmlF], 0)
        if net.netG.netB is not None:
            image_tensor_global = torch.cat([image_tensor_global, net.netG.nmlB], 0)
    except Exception:
        pass

    b_min = data['b_min']
    b_max = data['b_max']
    try:
        # save a side-by-side preview of the network inputs (undo [-1, 1] normalization, RGB -> BGR)
        save_img_path = save_path[:-4] + '.png'
        save_img_list = []
        for v in range(image_tensor_global.shape[0]):
            save_img = (np.transpose(image_tensor_global[v].detach().cpu().numpy(), (1, 2, 0)) * 0.5 + 0.5)[:, :, ::-1] * 255.0
            save_img_list.append(save_img)
        save_img = np.concatenate(save_img_list, axis=1)
        cv2.imwrite(save_img_path, save_img)

        # extract the iso-surface of the predicted occupancy field
        verts, faces, _, _ = reconstruction(
            net, cuda, calib_tensor, res, b_min, b_max, thresh, use_octree=use_octree, num_samples=50000)
        verts_tensor = torch.from_numpy(verts.T).unsqueeze(0).to(device=cuda).float()

        # color each vertex with its predicted surface normal, processed in chunks to bound memory
        color = np.zeros(verts.shape)
        interval = 50000
        for left in range(0, len(color), interval):
            right = min(left + interval, len(color))
            net.calc_normal(verts_tensor[:, None, :, left:right], calib_tensor[:, None], calib_tensor)
            nml = net.nmls.detach().cpu().numpy()[0] * 0.5 + 0.5
            color[left:right] = nml.T

        save_obj_mesh_with_color(save_path, verts, faces, color)
    except Exception as e:
        print(e)


def gen_mesh_imgColor(res, net, cuda, data, save_path, thresh=0.5, use_octree=True, components=False):
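    """Reconstruct a mesh from one evaluation sample and color it by sampling the input image.

    Same pipeline as `gen_mesh`, but vertex colors are taken by projecting the reconstructed
    vertices into the input image instead of using predicted normals. If the sample carries a
    `calib_world` matrix, the mesh is transformed back into world coordinates before saving.
    """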
    image_tensor_global = data['img_512'].to(device=cuda)
    image_tensor = data['img'].to(device=cuda)
    calib_tensor = data['calib'].to(device=cuda)

    # run the coarse (global) and fine (local) image filters
    net.filter_global(image_tensor_global)
    net.filter_local(image_tensor[:, None])

    # if the global network predicted front/back normal maps, append them to the preview stack
    try:
        if net.netG.netF is not None:
            image_tensor_global = torch.cat([image_tensor_global, net.netG.nmlF], 0)
        if net.netG.netB is not None:
            image_tensor_global = torch.cat([image_tensor_global, net.netG.nmlB], 0)
    except Exception:
        pass

    b_min = data['b_min']
    b_max = data['b_max']
    try:
        # save a side-by-side preview of the network inputs (undo [-1, 1] normalization, RGB -> BGR)
        save_img_path = save_path[:-4] + '.png'
        save_img_list = []
        for v in range(image_tensor_global.shape[0]):
            save_img = (np.transpose(image_tensor_global[v].detach().cpu().numpy(), (1, 2, 0)) * 0.5 + 0.5)[:, :, ::-1] * 255.0
            save_img_list.append(save_img)
        save_img = np.concatenate(save_img_list, axis=1)
        cv2.imwrite(save_img_path, save_img)

        # extract the iso-surface of the predicted occupancy field
        verts, faces, _, _ = reconstruction(
            net, cuda, calib_tensor, res, b_min, b_max, thresh, use_octree=use_octree, num_samples=100000)
        verts_tensor = torch.from_numpy(verts.T).unsqueeze(0).to(device=cuda).float()

        # sample per-vertex colors from the input image at the projected vertex locations
        xyz_tensor = net.projection(verts_tensor, calib_tensor[:1])
        uv = xyz_tensor[:, :2, :]
        color = index(image_tensor[:1], uv).detach().cpu().numpy()[0].T
        color = color * 0.5 + 0.5

        # if a world-space calibration is provided, move the mesh back into world coordinates
        if 'calib_world' in data:
            calib_world = data['calib_world'].numpy()[0]
            verts = np.matmul(np.concatenate([verts, np.ones_like(verts[:, :1])], 1), inv(calib_world).T)[:, :3]

        save_obj_mesh_with_color(save_path, verts, faces, color)
    except Exception as e:
        print(e)


def recon(opt, use_rect=False):
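    """Load a trained multi-level checkpoint and reconstruct a mesh for every evaluation sample.

    `use_rect` selects `EvalDataset` over the default `EvalWPoseDataset` (both from `lib.data`,
    differing in how the evaluation images are prepared). The command-line `dataroot`,
    `resolution`, `results_path`, and `loadSize` override the values stored in the checkpoint.
    """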
    # resolve which checkpoint to load: an explicit path, the latest snapshot, or a specific epoch
    state_dict_path = None
    if opt.load_netMR_checkpoint_path is not None:
        state_dict_path = opt.load_netMR_checkpoint_path
    elif opt.resume_epoch < 0:
        state_dict_path = '%s/%s_train_latest' % (opt.checkpoints_path, opt.name)
        opt.resume_epoch = 0
    else:
        state_dict_path = '%s/%s_train_epoch_%d' % (opt.checkpoints_path, opt.name, opt.resume_epoch)

    start_id = opt.start_id
    end_id = opt.end_id

    cuda = torch.device('cuda:%d' % opt.gpu_id if torch.cuda.is_available() else 'cpu')

    state_dict = None
    if state_dict_path is not None and os.path.exists(state_dict_path):
        print('Resuming from ', state_dict_path)
        state_dict = torch.load(state_dict_path, map_location=cuda)
        print('Warning: opt is overwritten.')
        # the checkpoint carries its own training options; keep only the evaluation-time overrides
        dataroot = opt.dataroot
        resolution = opt.resolution
        results_path = opt.results_path
        loadSize = opt.loadSize

        opt = state_dict['opt']
        opt.dataroot = dataroot
        opt.resolution = resolution
        opt.results_path = results_path
        opt.loadSize = loadSize
    else:
        raise Exception('failed loading state dict: %s' % state_dict_path)

    # select the evaluation dataset; EvalWPoseDataset expects 2D pose (keypoint) input,
    # EvalDataset works from image crop rectangles (see lib.data)
    if use_rect:
        test_dataset = EvalDataset(opt)
    else:
        test_dataset = EvalWPoseDataset(opt)

    print('test data size: ', len(test_dataset))
    projection_mode = test_dataset.projection_mode

    # build the coarse network from its stored options, then wrap it in the multi-resolution net
    opt_netG = state_dict['opt_netG']
    netG = HGPIFuNetwNML(opt_netG, projection_mode).to(device=cuda)
    netMR = HGPIFuMRNet(opt, netG, projection_mode).to(device=cuda)

    def set_eval():
        netG.eval()

    netMR.load_state_dict(state_dict['model_state_dict'])

    os.makedirs(opt.checkpoints_path, exist_ok=True)
    os.makedirs(opt.results_path, exist_ok=True)
    os.makedirs('%s/%s/recon' % (opt.results_path, opt.name), exist_ok=True)

    if start_id < 0:
        start_id = 0
    if end_id < 0:
        end_id = len(test_dataset)

    with torch.no_grad():
        set_eval()

        print('generate mesh (test) ...')
        for i in tqdm(range(start_id, end_id)):
            if i >= len(test_dataset):
                break

            # single-person path; the disabled branch below reconstructs each detected person separately
            if True:
                test_data = test_dataset[i]

                save_path = '%s/%s/recon/result_%s_%d.obj' % (opt.results_path, opt.name, test_data['name'], opt.resolution)

                print(save_path)
                gen_mesh(opt.resolution, netMR, cuda, test_data, save_path, components=opt.use_compose)
            else:
                for j in range(test_dataset.get_n_person(i)):
                    test_dataset.person_id = j
                    test_data = test_dataset[i]
                    save_path = '%s/%s/recon/result_%s_%d.obj' % (opt.results_path, opt.name, test_data['name'], j)
                    gen_mesh(opt.resolution, netMR, cuda, test_data, save_path, components=opt.use_compose)


def reconWrapper(args=None, use_rect=False):
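    """Parse command-line arguments (or an explicit `args` list) with `BaseOptions` and run `recon`.

    The exact flags are defined in `lib.options.BaseOptions`; the attribute names used above
    (e.g. `dataroot`, `results_path`, `resolution`, `load_netMR_checkpoint_path`, `start_id`,
    `end_id`) suggest the corresponding options, but check `BaseOptions` for the real flag names.
    """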
    opt = parser.parse(args)
    recon(opt, use_rect)


if __name__ == '__main__':
    reconWrapper()