# REC-MV_preprocess/pifuhd/lib/model/HGPIFuNetwNML.py
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2

from .BasePIFuNet import BasePIFuNet
from .MLP import MLP
from .DepthNormalizer import DepthNormalizer
from .HGFilters import HGFilter
from ..net_util import init_net
from ..networks import define_G
class HGPIFuNetwNML(BasePIFuNet):
    '''
    HGPIFu uses a stacked hourglass network as the image encoder.
    '''

    def __init__(self,
                 opt,
                 projection_mode='orthogonal',
                 criteria={'occ': nn.MSELoss()}
                 ):
        super(HGPIFuNetwNML, self).__init__(
            projection_mode=projection_mode,
            criteria=criteria)

        self.name = 'hg_pifu'

        # input channels: RGB, plus 3 channels each for the optional
        # front/back normal maps
        in_ch = 3
        try:
            if opt.use_front_normal:
                in_ch += 3
            if opt.use_back_normal:
                in_ch += 3
        except AttributeError:  # options without the normal flags: RGB only
            pass

        self.opt = opt
        self.image_filter = HGFilter(opt.num_stack, opt.hg_depth, in_ch, opt.hg_dim,
                                     opt.norm, opt.hg_down, False)

        self.mlp = MLP(
            filter_channels=self.opt.mlp_dim,
            merge_layer=self.opt.merge_layer,
            res_layers=self.opt.mlp_res_layers,
            norm=self.opt.mlp_norm,
            last_op=nn.Sigmoid())

        self.spatial_enc = DepthNormalizer(opt)

        self.im_feat_list = []
        self.tmpx = None
        self.normx = None
        self.phi = None

        self.intermediate_preds_list = []

        init_net(self)

        self.netF = None
        self.netB = None
        try:
            if opt.use_front_normal:
                self.netF = define_G(3, 3, 64, "global", 4, 9, 1, 3, "instance")
            if opt.use_back_normal:
                self.netB = define_G(3, 3, 64, "global", 4, 9, 1, 3, "instance")
        except AttributeError:  # the normal networks are optional
            pass

        self.nmlF = None
        self.nmlB = None
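    # Note: netF/netB, when enabled, are image-to-image generators
    # (define_G with the "global" architecture, a pix2pixHD-style generator)
    # that predict front and back normal maps from the RGB input; they are
    # run frozen inside filter() below.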
    def loadFromHGHPIFu(self, net):
        '''
        Initialize the image filter and MLP from a pretrained HGPIFu network,
        copying only the parameters whose names and shapes match.
        '''
        hgnet = net.image_filter
        model_dict = self.image_filter.state_dict()
        pretrained_dict = {k: v for k, v in hgnet.state_dict().items() if k in model_dict}

        for k, v in pretrained_dict.items():
            if v.size() == model_dict[k].size():
                model_dict[k] = v

        not_initialized = set()

        for k, v in model_dict.items():
            if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
                not_initialized.add(k.split('.')[0])

        print('not initialized', sorted(not_initialized))
        self.image_filter.load_state_dict(model_dict)

        model_dict = self.mlp.state_dict()
        pretrained_dict = {k: v for k, v in net.mlp.state_dict().items() if k in model_dict}

        for k, v in pretrained_dict.items():
            if v.size() == model_dict[k].size():
                model_dict[k] = v

        not_initialized = set()

        for k, v in model_dict.items():
            if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
                not_initialized.add(k.split('.')[0])

        print('not initialized', sorted(not_initialized))
        self.mlp.load_state_dict(model_dict)
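    # Usage sketch (hypothetical; `pretrained_net` and `opt` are placeholders):
    # given a compatible pretrained HGPIFu model, this network can be
    # warm-started with
    #
    #   net = HGPIFuNetwNML(opt)
    #   net.loadFromHGHPIFu(pretrained_net)
    #
    # Layers whose names or shapes do not match (e.g. the first conv when
    # in_ch differs) keep their init_net() initialization and are reported
    # as 'not initialized'.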
    def filter(self, images):
        '''
        Apply a fully convolutional network to the input images.
        The resulting features are stored in self.im_feat_list.
        args:
            images: [B, C, H, W]
        '''
        nmls = []
        # if you wish to train jointly, remove detach etc.
        with torch.no_grad():
            if self.netF is not None:
                self.nmlF = self.netF.forward(images).detach()
                nmls.append(self.nmlF)
            if self.netB is not None:
                self.nmlB = self.netB.forward(images).detach()
                nmls.append(self.nmlB)
        if len(nmls) != 0:
            nmls = torch.cat(nmls, 1)
            # resize the normal maps to the image resolution before concatenating
            if images.size()[2:] != nmls.size()[2:]:
                nmls = nn.Upsample(size=images.size()[2:], mode='bilinear', align_corners=True)(nmls)
            images = torch.cat([images, nmls], 1)

        self.im_feat_list, self.normx = self.image_filter(images)

        # at eval time, keep only the last stack's features to save memory
        if not self.training:
            self.im_feat_list = [self.im_feat_list[-1]]
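    # With both normal networks enabled, the tensor passed to the hourglass
    # filter has 9 channels (RGB + front normals + back normals), matching
    # the in_ch computed in __init__.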
    def query(self, points, calibs, transforms=None, labels=None, update_pred=True, update_phi=True):
        '''
        Given 3D points, compute their 2D projections using the camera
        matrices and evaluate the occupancy at each point.
        filter() needs to be called beforehand.
        The prediction is stored in self.preds.
        args:
            points: [B, 3, N] 3d points in world space
            calibs: [B, 3, 4] calibration matrices for each image
            transforms: [B, 2, 3] image space coordinate transforms
            labels: [B, C, N] ground truth labels (for supervision only)
        return:
            [B, C, N] prediction
        '''
        xyz = self.projection(points, calibs, transforms)
        xy = xyz[:, :2, :]

        # points outside the [-1, 1] bounding box are masked out as outside
        in_bb = (xyz >= -1) & (xyz <= 1)
        in_bb = in_bb[:, 0, :] & in_bb[:, 1, :] & in_bb[:, 2, :]
        in_bb = in_bb[:, None, :].detach().float()

        if labels is not None:
            self.labels = in_bb * labels

        sp_feat = self.spatial_enc(xyz, calibs=calibs)

        intermediate_preds_list = []

        phi = None
        for i, im_feat in enumerate(self.im_feat_list):
            point_local_feat_list = [self.index(im_feat, xy), sp_feat]
            point_local_feat = torch.cat(point_local_feat_list, 1)
            pred, phi = self.mlp(point_local_feat)
            pred = in_bb * pred
            intermediate_preds_list.append(pred)

        if update_phi:
            self.phi = phi

        if update_pred:
            self.intermediate_preds_list = intermediate_preds_list
            self.preds = self.intermediate_preds_list[-1]
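    # Each hourglass stack produces its own occupancy prediction. During
    # training all of them are kept for deep supervision in get_error();
    # at test time filter() retains only the last stack, so the loop above
    # runs once.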
    def calc_normal(self, points, calibs, transforms=None, labels=None, delta=0.01, fd_type='forward'):
        '''
        Return surface normals in 'model' space.
        Normals are computed from the last stack only.
        Note that the current implementation uses forward differences.
        args:
            points: [B, 3, N] 3d points in world space
            calibs: [B, 3, 4] calibration matrices for each image
            transforms: [B, 2, 3] image space coordinate transforms
            delta: perturbation for finite difference
            fd_type: finite difference type (forward/backward/central)
        '''
        pdx = points.clone()
        pdx[:, 0, :] += delta
        pdy = points.clone()
        pdy[:, 1, :] += delta
        pdz = points.clone()
        pdz[:, 2, :] += delta

        if labels is not None:
            self.labels_nml = labels

        points_all = torch.stack([points, pdx, pdy, pdz], 3)
        points_all = points_all.view(*points.size()[:2], -1)
        xyz = self.projection(points_all, calibs, transforms)
        xy = xyz[:, :2, :]

        im_feat = self.im_feat_list[-1]
        sp_feat = self.spatial_enc(xyz, calibs=calibs)

        point_local_feat_list = [self.index(im_feat, xy), sp_feat]
        point_local_feat = torch.cat(point_local_feat_list, 1)

        pred = self.mlp(point_local_feat)[0]

        pred = pred.view(*pred.size()[:2], -1, 4)  # (B, 1, N, 4)

        # division by delta is omitted since the normal is normalized anyway
        dfdx = pred[:, :, :, 1] - pred[:, :, :, 0]
        dfdy = pred[:, :, :, 2] - pred[:, :, :, 0]
        dfdz = pred[:, :, :, 3] - pred[:, :, :, 0]

        nml = -torch.cat([dfdx, dfdy, dfdz], 1)
        nml = F.normalize(nml, dim=1, eps=1e-8)

        self.nmls = nml
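    # The four evaluations per point are packed along the last axis as
    # (f(p), f(p+dx), f(p+dy), f(p+dz)), so the differences above form the
    # (unscaled) forward-difference gradient of the occupancy field; negating
    # it makes the normal point outward, since occupancy increases toward the
    # inside of the surface.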
    def get_im_feat(self):
        '''
        Return the image features from the last stack.
        return:
            [B, C, H, W]
        '''
        return self.im_feat_list[-1]
    def get_error(self, gamma):
        '''
        Return the loss given the ground truth labels and predictions.
        '''
        error = {}
        error['Err(occ)'] = 0
        # note: the criterion is called as criterion(pred, label, gamma), so
        # the default nn.MSELoss would not accept the extra gamma argument;
        # a custom occupancy loss is expected here
        for preds in self.intermediate_preds_list:
            error['Err(occ)'] += self.criteria['occ'](preds, self.labels, gamma)

        error['Err(occ)'] /= len(self.intermediate_preds_list)

        if self.nmls is not None and self.labels_nml is not None:
            error['Err(nml)'] = self.criteria['nml'](self.nmls, self.labels_nml)

        return error
    def forward(self, images, points, calibs, labels, gamma, points_nml=None, labels_nml=None):
        self.filter(images)
        self.query(points, calibs, labels=labels)
        if points_nml is not None and labels_nml is not None:
            self.calc_normal(points_nml, calibs, labels=labels_nml)
        res = self.get_preds()

        err = self.get_error(gamma)

        return err, res
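
# ---------------------------------------------------------------------------
# Inference sketch (illustrative only; `opt`, `device`, and the tensor values
# below are placeholders, and real calibration matrices come from the dataset):
#
#   net = HGPIFuNetwNML(opt).to(device).eval()
#   images = torch.randn(1, 3, 512, 512, device=device)    # [B, C, H, W]
#   points = torch.rand(1, 3, 8000, device=device) * 2 - 1 # [B, 3, N]
#   calibs = torch.eye(4, device=device)[:3].unsqueeze(0)  # [B, 3, 4]
#
#   with torch.no_grad():
#       net.filter(images)            # cache image features (and normals)
#       net.query(points, calibs)     # evaluate occupancy at the points
#       occ = net.get_preds()         # [B, 1, N] in/out probabilities
# ---------------------------------------------------------------------------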