"""This script defines the face reconstruction model for Deep3DFaceRecon_pytorch
"""
import numpy as np
import torch
from src.face3d.models.base_model import BaseModel
from src.face3d.models import networks
from src.face3d.models.bfm import ParametricFaceModel
from src.face3d.models.losses import perceptual_loss, photo_loss, reg_loss, reflectance_loss, landmark_loss
from src.face3d.util import util
from src.face3d.util.nvdiffrast import MeshRenderer
# from src.face3d.util.preprocess import estimate_norm_torch
import trimesh
from scipy.io import savemat


class FaceReconModel(BaseModel):

    @staticmethod
    def modify_commandline_options(parser, is_train=False):
        """Configure options specific to the face reconstruction model
        """
        # net structure and parameters
        parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='network structure')
        parser.add_argument('--init_path', type=str, default='./checkpoints/init_model/resnet50-0676ba61.pth')
        parser.add_argument('--use_last_fc', type=util.str2bool, nargs='?', const=True, default=False, help='zero initialize the last fc')
        parser.add_argument('--bfm_folder', type=str, default='./checkpoints/BFM_Fitting/')
        parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model')

        # renderer parameters
        parser.add_argument('--focal', type=float, default=1015.)
        parser.add_argument('--center', type=float, default=112.)
        parser.add_argument('--camera_d', type=float, default=10.)
        parser.add_argument('--z_near', type=float, default=5.)
        parser.add_argument('--z_far', type=float, default=15.)

        if is_train:
            # training parameters
            parser.add_argument('--net_recog', type=str, default='r50', choices=['r18', 'r34', 'r50'], help='face recog network structure')
            parser.add_argument('--net_recog_path', type=str, default='checkpoints/recog_model/ms1mv3_arcface_r50_fp16/backbone.pth')
            parser.add_argument('--use_crop_face', type=util.str2bool, nargs='?', const=True, default=False, help='use crop mask for photo loss')
            parser.add_argument('--use_predef_M', type=util.str2bool, nargs='?', const=True, default=False, help='use predefined M for predicted face')

            # augmentation parameters
            parser.add_argument('--shift_pixs', type=float, default=10., help='shift pixels')
            parser.add_argument('--scale_delta', type=float, default=0.1, help='delta scale factor')
            parser.add_argument('--rot_angle', type=float, default=10., help='rot angles, degree')

            # loss weights
            parser.add_argument('--w_feat', type=float, default=0.2, help='weight for feat loss')
            parser.add_argument('--w_color', type=float, default=1.92, help='weight for color loss')
            parser.add_argument('--w_reg', type=float, default=3.0e-4, help='weight for reg loss')
            parser.add_argument('--w_id', type=float, default=1.0, help='weight for id_reg loss')
            parser.add_argument('--w_exp', type=float, default=0.8, help='weight for exp_reg loss')
            parser.add_argument('--w_tex', type=float, default=1.7e-2, help='weight for tex_reg loss')
            parser.add_argument('--w_gamma', type=float, default=10.0, help='weight for gamma loss')
            parser.add_argument('--w_lm', type=float, default=1.6e-3, help='weight for lm loss')
            parser.add_argument('--w_reflc', type=float, default=5.0, help='weight for reflc loss')

        opt, _ = parser.parse_known_args()
        parser.set_defaults(
            focal=1015., center=112., camera_d=10., use_last_fc=False, z_near=5., z_far=15.
        )
        if is_train:
            parser.set_defaults(
                use_crop_face=True, use_predef_M=False
            )
        return parser
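
    # A minimal usage sketch (hypothetical driver code, not part of this repo's
    # documented API; the surrounding pipeline normally builds `opt` for us, and
    # BaseModel expects further fields on it beyond the ones registered above):
    #
    #   import argparse
    #   parser = FaceReconModel.modify_commandline_options(argparse.ArgumentParser(), is_train=False)
    #   opt, _ = parser.parse_known_args([])   # all defaults
    #   model = FaceReconModel(opt)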

    def __init__(self, opt):
        """Initialize this model class.

        Parameters:
            opt -- training/test options

        A few things can be done here.
        - (required) call the initialization function of BaseModel
        - define loss function, visualization images, model names, and optimizers
        """
        BaseModel.__init__(self, opt)  # call the initialization method of BaseModel

        self.visual_names = ['output_vis']
        self.model_names = ['net_recon']
        self.parallel_names = self.model_names + ['renderer']

        # the coefficient-regression backbone; self.net_recon is referenced by
        # model_names above and by the optimizer below
        self.net_recon = networks.define_net_recon(
            net_recon=opt.net_recon, use_last_fc=opt.use_last_fc, init_path=opt.init_path
        )

        self.facemodel = ParametricFaceModel(
            bfm_folder=opt.bfm_folder, camera_distance=opt.camera_d, focal=opt.focal, center=opt.center,
            is_train=self.isTrain, default_name=opt.bfm_model
        )

        fov = 2 * np.arctan(opt.center / opt.focal) * 180 / np.pi
        self.renderer = MeshRenderer(
            rasterize_fov=fov, znear=opt.z_near, zfar=opt.z_far, rasterize_size=int(2 * opt.center)
        )
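
        # Sanity check on the pinhole geometry above: the half image extent is
        # `center` pixels and the focal length is `focal` pixels, so with the
        # defaults fov = 2 * atan(112 / 1015) * 180 / pi ~= 12.6 degrees,
        # rasterized at 2 * center = 224 px.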

        if self.isTrain:
            self.loss_names = ['all', 'feat', 'color', 'lm', 'reg', 'gamma', 'reflc']

            self.net_recog = networks.define_net_recog(
                net_recog=opt.net_recog, pretrained_path=opt.net_recog_path
            )
            # loss func name: (compute_%s_loss) % loss_name
            self.compute_feat_loss = perceptual_loss
            self.compute_color_loss = photo_loss
            self.compute_lm_loss = landmark_loss
            self.compute_reg_loss = reg_loss
            self.compute_reflc_loss = reflectance_loss

            self.optimizer = torch.optim.Adam(self.net_recon.parameters(), lr=opt.lr)
            self.optimizers = [self.optimizer]
            self.parallel_names += ['net_recog']
        # Our program will automatically call <model.setup> to define schedulers, load networks, and print networks

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input: a dictionary that contains the data itself and its metadata information.
        """
        self.input_img = input['imgs'].to(self.device)
        self.atten_mask = input['msks'].to(self.device) if 'msks' in input else None
        self.gt_lm = input['lms'].to(self.device) if 'lms' in input else None
        self.trans_m = input['M'].to(self.device) if 'M' in input else None
        self.image_paths = input['im_paths'] if 'im_paths' in input else None
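
    # Expected batch layout (inferred from how the tensors are consumed in this
    # file; the exact shapes are assumptions, not a documented contract):
    #   input['imgs']     float tensor (B, 3, H, W), RGB in [0, 1]
    #   input['msks']     attention/skin masks aligned with imgs (optional)
    #   input['lms']      ground-truth landmarks (B, 68, 2) (optional)
    #   input['M']        similarity transforms for the recognition net (optional)
    #   input['im_paths'] list of source image paths (optional)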

    def forward(self, output_coeff, device):
        self.facemodel.to(device)
        self.pred_vertex, self.pred_tex, self.pred_color, self.pred_lm = \
            self.facemodel.compute_for_render(output_coeff)
        self.pred_mask, _, self.pred_face = self.renderer(
            self.pred_vertex, self.facemodel.face_buf, feat=self.pred_color)

        self.pred_coeffs_dict = self.facemodel.split_coeff(output_coeff)
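
    # forward() leaves its results on the instance; a sketch of each attribute's
    # role, inferred from how it is consumed later in this file:
    #   pred_vertex  - reconstructed mesh vertices in camera space (see save_mesh)
    #   pred_tex     - per-vertex texture, fed to the reflectance loss
    #   pred_color   - per-vertex colors rasterized into pred_face
    #   pred_lm      - projected 68 facial landmarks
    #   pred_mask    - rendered face-region mask used to composite visuals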

    def compute_losses(self):
        """Compute the weighted training losses; called in every training iteration
        (the actual weight update happens in optimize_parameters)."""
        assert not self.net_recog.training

        trans_m = self.trans_m
        if not self.opt.use_predef_M:
            trans_m = estimate_norm_torch(self.pred_lm, self.input_img.shape[-2])

        pred_feat = self.net_recog(self.pred_face, trans_m)
        gt_feat = self.net_recog(self.input_img, self.trans_m)
        self.loss_feat = self.opt.w_feat * self.compute_feat_loss(pred_feat, gt_feat)

        face_mask = self.pred_mask
        if self.opt.use_crop_face:
            face_mask, _, _ = self.renderer(self.pred_vertex, self.facemodel.front_face_buf)

        face_mask = face_mask.detach()
        self.loss_color = self.opt.w_color * self.compute_color_loss(
            self.pred_face, self.input_img, self.atten_mask * face_mask)

        loss_reg, loss_gamma = self.compute_reg_loss(self.pred_coeffs_dict, self.opt)
        self.loss_reg = self.opt.w_reg * loss_reg
        self.loss_gamma = self.opt.w_gamma * loss_gamma

        self.loss_lm = self.opt.w_lm * self.compute_lm_loss(self.pred_lm, self.gt_lm)

        self.loss_reflc = self.opt.w_reflc * self.compute_reflc_loss(self.pred_tex, self.facemodel.skin_mask)

        self.loss_all = self.loss_feat + self.loss_color + self.loss_reg + self.loss_gamma \
                        + self.loss_lm + self.loss_reflc
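
    # With the default weights registered in modify_commandline_options, the
    # total objective expands to (a restatement of the sum above, not new behavior):
    #   loss_all = 0.2 * L_feat + 1.92 * L_color + 3.0e-4 * L_reg
    #              + 10.0 * L_gamma + 1.6e-3 * L_lm + 5.0 * L_reflc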

    def optimize_parameters(self, isTrain=True):
        """Run a forward pass, compute losses, and update network weights; called in every training iteration."""
        # forward() takes the predicted coefficients explicitly, so regress them
        # from the current input batch before rendering.
        output_coeff = self.net_recon(self.input_img)
        self.forward(output_coeff, self.device)
        self.compute_losses()
        if isTrain:
            self.optimizer.zero_grad()
            self.loss_all.backward()
            self.optimizer.step()
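
    # A minimal training-loop sketch (hypothetical; `loader` and the batch dict
    # keys are assumptions based on set_input above):
    #
    #   for batch in loader:
    #       model.set_input(batch)          # expects 'imgs', 'lms', 'msks', 'M'
    #       model.optimize_parameters()     # forward + losses + Adam step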

    def compute_visuals(self):
        with torch.no_grad():
            input_img_numpy = 255. * self.input_img.detach().cpu().permute(0, 2, 3, 1).numpy()
            output_vis = self.pred_face * self.pred_mask + (1 - self.pred_mask) * self.input_img
            output_vis_numpy_raw = 255. * output_vis.detach().cpu().permute(0, 2, 3, 1).numpy()

            if self.gt_lm is not None:
                gt_lm_numpy = self.gt_lm.cpu().numpy()
                pred_lm_numpy = self.pred_lm.detach().cpu().numpy()
                output_vis_numpy = util.draw_landmarks(output_vis_numpy_raw, gt_lm_numpy, 'b')
                output_vis_numpy = util.draw_landmarks(output_vis_numpy, pred_lm_numpy, 'r')

                output_vis_numpy = np.concatenate((input_img_numpy,
                                    output_vis_numpy_raw, output_vis_numpy), axis=-2)
            else:
                output_vis_numpy = np.concatenate((input_img_numpy,
                                    output_vis_numpy_raw), axis=-2)

            self.output_vis = torch.tensor(
                output_vis_numpy / 255., dtype=torch.float32
            ).permute(0, 3, 1, 2).to(self.device)

    def save_mesh(self, name):
        recon_shape = self.pred_vertex  # get reconstructed shape
        recon_shape[..., -1] = 10 - recon_shape[..., -1]  # from camera space to world space (the camera sits at z = camera_d = 10)
        recon_shape = recon_shape.cpu().numpy()[0]
        recon_color = self.pred_color
        recon_color = recon_color.cpu().numpy()[0]
        tri = self.facemodel.face_buf.cpu().numpy()
        mesh = trimesh.Trimesh(vertices=recon_shape, faces=tri, vertex_colors=np.clip(255. * recon_color, 0, 255).astype(np.uint8))
        mesh.export(name)
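
    # e.g. model.save_mesh('results/face.obj') - trimesh picks the export format
    # from the file extension, so .obj or .ply both work here (usage sketch; the
    # path is hypothetical, not one used by this repo).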

    def save_coeff(self, name):
        pred_coeffs = {key: self.pred_coeffs_dict[key].cpu().numpy() for key in self.pred_coeffs_dict}
        pred_lm = self.pred_lm.cpu().numpy()
        pred_lm = np.stack([pred_lm[:, :, 0], self.input_img.shape[2] - 1 - pred_lm[:, :, 1]], axis=2)  # transfer to image coordinate (flip y: the image origin is top-left)
        pred_coeffs['lm68'] = pred_lm
        savemat(name, pred_coeffs)
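
# A minimal inference sketch (hypothetical driver code; assumes `opt` was built
# via modify_commandline_options and the BFM/checkpoint files are in place):
#
#   model = FaceReconModel(opt)
#   model.set_input({'imgs': imgs})                      # (B, 3, 224, 224) in [0, 1]
#   with torch.no_grad():
#       output_coeff = model.net_recon(model.input_img)  # regress 3DMM coefficients
#       model.forward(output_coeff, model.device)        # reconstruct and render
#   model.save_mesh('face.obj')
#   model.save_coeff('coeffs.mat')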