import torch
import torch.nn as nn
import torch.nn.functional as F

from .BasePIFuNet import BasePIFuNet
from .MLP import MLP
from .DepthNormalizer import DepthNormalizer
from .HGFilters import HGFilter
from ..net_util import init_net
from ..networks import define_G
|
|
|
class HGPIFuNetwNML(BasePIFuNet):
    '''
    HGPIFuNetwNML uses a stacked hourglass network as the image encoder,
    optionally conditioned on predicted front/back normal maps.
    '''
|
|
|
    def __init__(self,
                 opt,
                 projection_mode='orthogonal',
                 criteria={'occ': nn.MSELoss()}
                 ):
        super(HGPIFuNetwNML, self).__init__(
            projection_mode=projection_mode,
            criteria=criteria)

        self.name = 'hg_pifu'

        # Input channels: RGB, plus front/back normal maps when enabled.
        in_ch = 3
        try:
            if opt.use_front_normal:
                in_ch += 3
            if opt.use_back_normal:
                in_ch += 3
        except AttributeError:
            pass
        self.opt = opt
|
        self.image_filter = HGFilter(opt.num_stack, opt.hg_depth, in_ch, opt.hg_dim,
                                     opt.norm, opt.hg_down, False)

        self.mlp = MLP(
            filter_channels=self.opt.mlp_dim,
            merge_layer=self.opt.merge_layer,
            res_layers=self.opt.mlp_res_layers,
            norm=self.opt.mlp_norm,
            last_op=nn.Sigmoid())

        self.spatial_enc = DepthNormalizer(opt)

        # Per-forward state populated by filter() / query() / calc_normal().
        self.im_feat_list = []
        self.tmpx = None
        self.normx = None
        self.phi = None
        self.nmls = None
        self.labels_nml = None

        self.intermediate_preds_list = []

        init_net(self)
|
|
|
        # Optional generators that predict front/back normal maps from the
        # input image; their outputs are concatenated as extra input channels.
        self.netF = None
        self.netB = None
        try:
            if opt.use_front_normal:
                self.netF = define_G(3, 3, 64, "global", 4, 9, 1, 3, "instance")
            if opt.use_back_normal:
                self.netB = define_G(3, 3, 64, "global", 4, 9, 1, 3, "instance")
        except AttributeError:
            pass
        self.nmlF = None
        self.nmlB = None
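
    # A minimal usage sketch (assumption: `opt` provides the fields used
    # above, e.g. num_stack, hg_depth, hg_dim, norm, hg_down, mlp_dim,
    # merge_layer, mlp_res_layers, mlp_norm):
    #
    #   net = HGPIFuNetwNML(opt).to(device)
    #   net.filter(images)               # images: [B, 3, H, W]
    #   net.query(points, calibs)        # points: [B, 3, N]
    #   preds = net.get_preds()          # [B, C, N] occupancy in [0, 1]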
|
|
|
    def loadFromHGHPIFu(self, net):
        '''
        Copy compatible weights from a pretrained HGPIFu network.
        Parameters whose names or shapes do not match are left untouched
        and their top-level module names are reported.
        '''
        # Image filter: keep only weights present in both models.
        model_dict = self.image_filter.state_dict()
        pretrained_dict = {k: v for k, v in net.image_filter.state_dict().items() if k in model_dict}

        for k, v in pretrained_dict.items():
            if v.size() == model_dict[k].size():
                model_dict[k] = v

        not_initialized = set()

        for k, v in model_dict.items():
            if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
                not_initialized.add(k.split('.')[0])

        print('not initialized', sorted(not_initialized))
        self.image_filter.load_state_dict(model_dict)

        # MLP: same partial-loading procedure.
        model_dict = self.mlp.state_dict()
        pretrained_dict = {k: v for k, v in net.mlp.state_dict().items() if k in model_dict}

        for k, v in pretrained_dict.items():
            if v.size() == model_dict[k].size():
                model_dict[k] = v

        not_initialized = set()

        for k, v in model_dict.items():
            if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
                not_initialized.add(k.split('.')[0])

        print('not initialized', sorted(not_initialized))
        self.mlp.load_state_dict(model_dict)
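
    # Sketch (assumption: `pretrained` is an HGPIFu instance whose
    # image_filter and mlp layouts largely match this network's):
    #
    #   net.loadFromHGHPIFu(pretrained)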
|
|
|
    def filter(self, images):
        '''
        Apply a fully convolutional network to the input images.
        The resulting feature maps are stored in self.im_feat_list.
        args:
            images: [B, C, H, W]
        '''
        nmls = []
        # Predicted normal maps are treated as fixed inputs here, so no
        # gradient is propagated through netF/netB.
        with torch.no_grad():
            if self.netF is not None:
                self.nmlF = self.netF(images).detach()
                nmls.append(self.nmlF)
            if self.netB is not None:
                self.nmlB = self.netB(images).detach()
                nmls.append(self.nmlB)
            if len(nmls) != 0:
                nmls = torch.cat(nmls, 1)
                if images.size()[2:] != nmls.size()[2:]:
                    nmls = F.interpolate(nmls, size=images.size()[2:],
                                         mode='bilinear', align_corners=True)
                images = torch.cat([images, nmls], 1)

        self.im_feat_list, self.normx = self.image_filter(images)

        # At test time, only the last stack's features are needed.
        if not self.training:
            self.im_feat_list = [self.im_feat_list[-1]]
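
    # Sketch: after filter(), get_im_feat() returns the last stack's
    # feature map (spatial size depends on hg_down):
    #
    #   net.filter(images)           # images: [B, C, H, W]
    #   feat = net.get_im_feat()     # [B, hg_dim, H', W']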
|
|
|
    def query(self, points, calibs, transforms=None, labels=None, update_pred=True, update_phi=True):
        '''
        Given 3D points, obtain their 2D projections using the camera
        matrices, then predict occupancy from the sampled image features.
        filter() needs to be called beforehand.
        The prediction is stored in self.preds.
        args:
            points: [B, 3, N] 3d points in world space
            calibs: [B, 3, 4] calibration matrices for each image
            transforms: [B, 2, 3] image space coordinate transforms
            labels: [B, C, N] ground truth labels (for supervision only)
            update_pred: if True, store predictions in self.preds
            update_phi: if True, store the MLP's intermediate feature in self.phi
        '''
        xyz = self.projection(points, calibs, transforms)
        xy = xyz[:, :2, :]

        # Mask out points that project outside the [-1, 1] bounding box.
        in_bb = (xyz >= -1) & (xyz <= 1)
        in_bb = in_bb[:, 0, :] & in_bb[:, 1, :] & in_bb[:, 2, :]
        in_bb = in_bb[:, None, :].detach().float()

        if labels is not None:
            self.labels = in_bb * labels

        sp_feat = self.spatial_enc(xyz, calibs=calibs)

        intermediate_preds_list = []

        phi = None
        for i, im_feat in enumerate(self.im_feat_list):
            # Concatenate sampled image features with the spatial encoding.
            point_local_feat_list = [self.index(im_feat, xy), sp_feat]
            point_local_feat = torch.cat(point_local_feat_list, 1)
            pred, phi = self.mlp(point_local_feat)
            pred = in_bb * pred

            intermediate_preds_list.append(pred)

        if update_phi:
            self.phi = phi

        if update_pred:
            self.intermediate_preds_list = intermediate_preds_list
            self.preds = self.intermediate_preds_list[-1]
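
    # Sketch (assumption: filter() was already called on the same batch):
    #
    #   net.query(points, calibs, labels=occ_gt)   # points: [B, 3, N]
    #   occ = net.get_preds()                      # [B, C, N], in [0, 1]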
|
|
|
    def calc_normal(self, points, calibs, transforms=None, labels=None, delta=0.01, fd_type='forward'):
        '''
        Return surface normals in 'model' space.
        Normals are computed from the last stack only.
        Note that the current implementation uses forward differences
        regardless of fd_type.
        args:
            points: [B, 3, N] 3d points in world space
            calibs: [B, 3, 4] calibration matrices for each image
            transforms: [B, 2, 3] image space coordinate transforms
            delta: perturbation for finite difference
            fd_type: finite difference type (forward/backward/central)
        '''
        # Perturb each point along x, y, and z by delta.
        pdx = points.clone()
        pdx[:, 0, :] += delta
        pdy = points.clone()
        pdy[:, 1, :] += delta
        pdz = points.clone()
        pdz[:, 2, :] += delta

        if labels is not None:
            self.labels_nml = labels

        # Evaluate the field at the original and perturbed points in one pass.
        points_all = torch.stack([points, pdx, pdy, pdz], 3)
        points_all = points_all.view(*points.size()[:2], -1)
        xyz = self.projection(points_all, calibs, transforms)
        xy = xyz[:, :2, :]

        im_feat = self.im_feat_list[-1]
        sp_feat = self.spatial_enc(xyz, calibs=calibs)

        point_local_feat_list = [self.index(im_feat, xy), sp_feat]
        point_local_feat = torch.cat(point_local_feat_list, 1)

        pred = self.mlp(point_local_feat)[0]

        pred = pred.view(*pred.size()[:2], -1, 4)  # [B, C, N, 4]

        # Forward differences of the occupancy field.
        dfdx = pred[:, :, :, 1] - pred[:, :, :, 0]
        dfdy = pred[:, :, :, 2] - pred[:, :, :, 0]
        dfdz = pred[:, :, :, 3] - pred[:, :, :, 0]

        # The normal points against the field gradient.
        nml = -torch.cat([dfdx, dfdy, dfdz], 1)
        nml = F.normalize(nml, dim=1, eps=1e-8)

        self.nmls = nml
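
    # The normal above is a finite-difference approximation of
    # n = -grad(f) / ||grad(f)||, with each partial derivative estimated as
    # df/dx ~= (f(p + delta*ex) - f(p)) / delta; the 1/delta factor cancels
    # under normalization, so it is omitted.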
|
|
|
    def get_im_feat(self):
        '''
        Return the image feature map from the last stack.
        return:
            [B, C, H, W]
        '''
        return self.im_feat_list[-1]
|
|
|
|
|
    def get_error(self, gamma):
        '''
        Return the loss given the ground truth labels and predictions.
        Note that criteria['occ'] must accept (pred, label, gamma); the
        default nn.MSELoss does not, so pass a compatible criterion when
        gamma weighting is used.
        '''
        error = {}
        error['Err(occ)'] = 0
        for preds in self.intermediate_preds_list:
            error['Err(occ)'] += self.criteria['occ'](preds, self.labels, gamma)

        # Average the occupancy loss over all hourglass stacks.
        error['Err(occ)'] /= len(self.intermediate_preds_list)

        if self.nmls is not None and self.labels_nml is not None:
            error['Err(nml)'] = self.criteria['nml'](self.nmls, self.labels_nml)

        return error
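
    # Sketch (assumption: a gamma-weighted occupancy loss; the criterion
    # actually used in training is defined outside this class):
    #
    #   def occ_loss(pred, label, gamma):
    #       # hypothetical example: weight positives by gamma, negatives by 1 - gamma
    #       w = label * gamma + (1 - label) * (1 - gamma)
    #       return (w * (pred - label) ** 2).mean()
    #
    #   net = HGPIFuNetwNML(opt, criteria={'occ': occ_loss})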
|
|
|
    def forward(self, images, points, calibs, labels, gamma, points_nml=None, labels_nml=None):
        # Encode the images, evaluate occupancy at the sampled points, and
        # optionally supervise surface normals.
        self.filter(images)
        self.query(points, calibs, labels=labels)
        if points_nml is not None and labels_nml is not None:
            self.calc_normal(points_nml, calibs, labels=labels_nml)
        res = self.get_preds()

        err = self.get_error(gamma)

        return err, res
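
# A minimal training-step sketch (assumptions: `opt`, `loader`, `device`,
# and `optimizer` are set up elsewhere, and the criteria accept the gamma
# argument as noted in get_error):
#
#   net = HGPIFuNetwNML(opt, criteria={'occ': occ_loss}).to(device)
#   for images, points, calibs, labels in loader:
#       err, res = net(images, points, calibs, labels, gamma=0.5)
#       optimizer.zero_grad()
#       err['Err(occ)'].backward()
#       optimizer.step()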
|
|