xuehongyang committed
Commit 83d8d3c
1 Parent(s): 108a0f7
ser

This view is limited to 50 files because it contains too many changes. See raw diff.
- AdaptiveWingLoss/__pycache__/aux.cpython-310.pyc +0 -0
- AdaptiveWingLoss/aux.py +4 -0
- AdaptiveWingLoss/core/__init__.py +0 -0
- AdaptiveWingLoss/core/__pycache__/__init__.cpython-310.pyc +0 -0
- AdaptiveWingLoss/core/__pycache__/coord_conv.cpython-310.pyc +0 -0
- AdaptiveWingLoss/core/__pycache__/models.cpython-310.pyc +0 -0
- AdaptiveWingLoss/core/coord_conv.py +143 -0
- AdaptiveWingLoss/core/dataloader.py +350 -0
- AdaptiveWingLoss/core/evaler.py +125 -0
- AdaptiveWingLoss/core/models.py +239 -0
- AdaptiveWingLoss/utils/__init__.py +0 -0
- AdaptiveWingLoss/utils/utils.py +437 -0
- Deep3DFaceRecon_pytorch/LICENSE +21 -0
- Deep3DFaceRecon_pytorch/README.md +256 -0
- Deep3DFaceRecon_pytorch/data/__init__.py +118 -0
- Deep3DFaceRecon_pytorch/data/base_dataset.py +132 -0
- Deep3DFaceRecon_pytorch/data/flist_dataset.py +129 -0
- Deep3DFaceRecon_pytorch/data/image_folder.py +77 -0
- Deep3DFaceRecon_pytorch/data/template_dataset.py +80 -0
- Deep3DFaceRecon_pytorch/data_preparation.py +57 -0
- Deep3DFaceRecon_pytorch/environment.yml +24 -0
- Deep3DFaceRecon_pytorch/models/__init__.py +69 -0
- Deep3DFaceRecon_pytorch/models/__pycache__/__init__.cpython-310.pyc +0 -0
- Deep3DFaceRecon_pytorch/models/__pycache__/base_model.cpython-310.pyc +0 -0
- Deep3DFaceRecon_pytorch/models/__pycache__/bfm.cpython-310.pyc +0 -0
- Deep3DFaceRecon_pytorch/models/__pycache__/networks.cpython-310.pyc +0 -0
- Deep3DFaceRecon_pytorch/models/arcface_torch/README.md +218 -0
- Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/__init__.py +157 -0
- Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/__pycache__/__init__.cpython-310.pyc +0 -0
- Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/__pycache__/iresnet.cpython-310.pyc +0 -0
- Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-310.pyc +0 -0
- Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/iresnet.py +198 -0
- Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/iresnet2060.py +182 -0
- Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/mobilefacenet.py +160 -0
- Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/vit.py +302 -0
- Deep3DFaceRecon_pytorch/models/arcface_torch/configs/3millions.py +23 -0
- Deep3DFaceRecon_pytorch/models/arcface_torch/configs/__init__.py +0 -0
- Deep3DFaceRecon_pytorch/models/arcface_torch/configs/base.py +59 -0
- Deep3DFaceRecon_pytorch/models/arcface_torch/configs/glint360k_mbf.py +27 -0
- Deep3DFaceRecon_pytorch/models/arcface_torch/configs/glint360k_r100.py +27 -0
- Deep3DFaceRecon_pytorch/models/arcface_torch/configs/glint360k_r50.py +27 -0
- Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv2_mbf.py +27 -0
- Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv2_r100.py +27 -0
- Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv2_r50.py +27 -0
- Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv3_mbf.py +27 -0
- Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv3_r100.py +27 -0
- Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv3_r50.py +27 -0
- Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv3_r50_onegpu.py +27 -0
- Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_conflict_r50.py +28 -0
- Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_conflict_r50_pfc03_filter04.py +28 -0
AdaptiveWingLoss/__pycache__/aux.cpython-310.pyc
ADDED
Binary file (419 Bytes).
AdaptiveWingLoss/aux.py
ADDED
@@ -0,0 +1,4 @@
def detect_landmarks(inputs, model_ft):
    outputs, _ = model_ft(inputs)
    pred_heatmap = outputs[-1][:, :-1, :, :]
    return pred_heatmap[:, 96, :, :], pred_heatmap[:, 97, :, :]
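Reader's note (not part of the committed file): detect_landmarks assumes model_ft is the stacked-hourglass FAN defined in AdaptiveWingLoss/core/models.py with 98 landmark channels plus one boundary channel, so indices 96 and 97 appear to be the two eye-centre heatmaps of the 98-point layout. A minimal sketch of the call pattern, with the module count, input size, and the commented checkpoint path being assumptions:

# Sketch only, not part of the commit. Run from the repo root.
import torch

from AdaptiveWingLoss.aux import detect_landmarks
from AdaptiveWingLoss.core.models import FAN

model_ft = FAN(num_modules=4, num_landmarks=98)  # architecture assumed from core/models.py
# model_ft.load_state_dict(torch.load("landmark_detector.pth", map_location="cpu"))  # hypothetical path
model_ft.eval()

inputs = torch.rand(1, 3, 256, 256)  # FAN expects 256x256 face crops
with torch.no_grad():
    eye_left_hm, eye_right_hm = detect_landmarks(inputs, model_ft)
print(eye_left_hm.shape, eye_right_hm.shape)  # torch.Size([1, 64, 64]) each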
AdaptiveWingLoss/core/__init__.py
ADDED
File without changes
AdaptiveWingLoss/core/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (163 Bytes).
AdaptiveWingLoss/core/__pycache__/coord_conv.cpython-310.pyc
ADDED
Binary file (4.25 kB).
AdaptiveWingLoss/core/__pycache__/models.cpython-310.pyc
ADDED
Binary file (5.82 kB).
AdaptiveWingLoss/core/coord_conv.py
ADDED
@@ -0,0 +1,143 @@
import torch
import torch.nn as nn


class AddCoordsTh(nn.Module):
    def __init__(self, x_dim=64, y_dim=64, with_r=False, with_boundary=False):
        super(AddCoordsTh, self).__init__()
        self.x_dim = x_dim
        self.y_dim = y_dim
        self.with_r = with_r
        self.with_boundary = with_boundary

    def forward(self, input_tensor, heatmap=None):
        """
        input_tensor: (batch, c, x_dim, y_dim)
        """
        batch_size_tensor = input_tensor.shape[0]

        xx_ones = torch.ones([1, self.y_dim], dtype=torch.int32).to(input_tensor.device)
        xx_ones = xx_ones.unsqueeze(-1)

        xx_range = torch.arange(self.x_dim, dtype=torch.int32).unsqueeze(0).to(input_tensor.device)
        xx_range = xx_range.unsqueeze(1)

        xx_channel = torch.matmul(xx_ones.float(), xx_range.float())
        xx_channel = xx_channel.unsqueeze(-1)

        yy_ones = torch.ones([1, self.x_dim], dtype=torch.int32).to(input_tensor.device)
        yy_ones = yy_ones.unsqueeze(1)

        yy_range = torch.arange(self.y_dim, dtype=torch.int32).unsqueeze(0).to(input_tensor.device)
        yy_range = yy_range.unsqueeze(-1)

        yy_channel = torch.matmul(yy_range.float(), yy_ones.float())
        yy_channel = yy_channel.unsqueeze(-1)

        xx_channel = xx_channel.permute(0, 3, 2, 1)
        yy_channel = yy_channel.permute(0, 3, 2, 1)

        xx_channel = xx_channel / (self.x_dim - 1)
        yy_channel = yy_channel / (self.y_dim - 1)

        xx_channel = xx_channel * 2 - 1
        yy_channel = yy_channel * 2 - 1

        xx_channel = xx_channel.repeat(batch_size_tensor, 1, 1, 1)
        yy_channel = yy_channel.repeat(batch_size_tensor, 1, 1, 1)

        if self.with_boundary and type(heatmap) != type(None):
            boundary_channel = torch.clamp(heatmap[:, -1:, :, :], 0.0, 1.0)

            zero_tensor = torch.zeros_like(xx_channel)
            xx_boundary_channel = torch.where(boundary_channel > 0.05, xx_channel, zero_tensor)
            yy_boundary_channel = torch.where(boundary_channel > 0.05, yy_channel, zero_tensor)
        if self.with_boundary and type(heatmap) != type(None):
            xx_boundary_channel = xx_boundary_channel.to(input_tensor.device)
            yy_boundary_channel = yy_boundary_channel.to(input_tensor.device)
        ret = torch.cat([input_tensor, xx_channel, yy_channel], dim=1)

        if self.with_r:
            rr = torch.sqrt(torch.pow(xx_channel, 2) + torch.pow(yy_channel, 2))
            rr = rr / torch.max(rr)
            ret = torch.cat([ret, rr], dim=1)

        if self.with_boundary and type(heatmap) != type(None):
            ret = torch.cat([ret, xx_boundary_channel, yy_boundary_channel], dim=1)
        return ret


class CoordConvTh(nn.Module):
    """CoordConv layer as in the paper."""

    def __init__(self, x_dim, y_dim, with_r, with_boundary, in_channels, first_one=False, *args, **kwargs):
        super(CoordConvTh, self).__init__()
        self.addcoords = AddCoordsTh(x_dim=x_dim, y_dim=y_dim, with_r=with_r, with_boundary=with_boundary)
        in_channels += 2
        if with_r:
            in_channels += 1
        if with_boundary and not first_one:
            in_channels += 2
        self.conv = nn.Conv2d(in_channels=in_channels, *args, **kwargs)

    def forward(self, input_tensor, heatmap=None):
        ret = self.addcoords(input_tensor, heatmap)
        last_channel = ret[:, -2:, :, :]
        ret = self.conv(ret)
        return ret, last_channel


"""
An alternative implementation for PyTorch with auto-infering the x-y dimensions.
"""


class AddCoords(nn.Module):
    def __init__(self, with_r=False):
        super().__init__()
        self.with_r = with_r

    def forward(self, input_tensor):
        """
        Args:
            input_tensor: shape(batch, channel, x_dim, y_dim)
        """
        batch_size, _, x_dim, y_dim = input_tensor.size()

        xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1)
        yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2)

        xx_channel = xx_channel / (x_dim - 1)
        yy_channel = yy_channel / (y_dim - 1)

        xx_channel = xx_channel * 2 - 1
        yy_channel = yy_channel * 2 - 1

        xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
        yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)

        if input_tensor.is_cuda:
            xx_channel = xx_channel.to(input_tensor.device)
            yy_channel = yy_channel.to(input_tensor.device)

        ret = torch.cat([input_tensor, xx_channel.type_as(input_tensor), yy_channel.type_as(input_tensor)], dim=1)

        if self.with_r:
            rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2))
            if input_tensor.is_cuda:
                rr = rr.to(input_tensor.device)
            ret = torch.cat([ret, rr], dim=1)

        return ret


class CoordConv(nn.Module):
    def __init__(self, in_channels, out_channels, with_r=False, **kwargs):
        super().__init__()
        self.addcoords = AddCoords(with_r=with_r)
        self.conv = nn.Conv2d(in_channels + 2, out_channels, **kwargs)

    def forward(self, x):
        ret = self.addcoords(x)
        ret = self.conv(ret)
        return ret
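Reader's note (not part of the committed file): both variants concatenate normalized x/y coordinate channels (optionally a radius channel and boundary-masked coordinates) to the input before a plain convolution, so the layer can condition on absolute position. A minimal sketch exercising them, with all sizes chosen arbitrarily:

# Sketch only, not part of the commit.
import torch

from AdaptiveWingLoss.core.coord_conv import CoordConv, CoordConvTh

x = torch.rand(2, 16, 64, 64)

# Auto-inferred variant: 16 input channels, +2 coordinate channels handled internally.
layer = CoordConv(in_channels=16, out_channels=32, with_r=False, kernel_size=3, padding=1)
print(layer(x).shape)  # torch.Size([2, 32, 64, 64])

# Fixed-size variant used by the hourglass: returns the conv output plus the last two
# appended channels (the boundary-masked coordinates when with_boundary=True).
layer_th = CoordConvTh(x_dim=64, y_dim=64, with_r=True, with_boundary=False,
                       in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
out, last_channel = layer_th(x)
print(out.shape, last_channel.shape)  # torch.Size([2, 32, 64, 64]) torch.Size([2, 2, 64, 64])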
AdaptiveWingLoss/core/dataloader.py
ADDED
@@ -0,0 +1,350 @@
import copy
import glob
import math
import os
import random
import sys

import cv2
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
import torch
from imgaug import augmenters as iaa
from PIL import Image
from scipy import interpolate
from skimage import io
from skimage import transform as ski_transform
from skimage.color import rgb2gray
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision import utils
from torchvision.transforms import Compose
from torchvision.transforms import Lambda
from torchvision.transforms.functional import adjust_brightness
from torchvision.transforms.functional import adjust_contrast
from torchvision.transforms.functional import adjust_hue
from torchvision.transforms.functional import adjust_saturation

from utils.utils import cv_crop
from utils.utils import cv_rotate
from utils.utils import draw_gaussian
from utils.utils import fig2data
from utils.utils import generate_weight_map
from utils.utils import power_transform
from utils.utils import shuffle_lr
from utils.utils import transform


class AddBoundary(object):
    def __init__(self, num_landmarks=68):
        self.num_landmarks = num_landmarks

    def __call__(self, sample):
        landmarks_64 = np.floor(sample["landmarks"] / 4.0)
        if self.num_landmarks == 68:
            boundaries = {}
            boundaries["cheek"] = landmarks_64[0:17]
            boundaries["left_eyebrow"] = landmarks_64[17:22]
            boundaries["right_eyebrow"] = landmarks_64[22:27]
            boundaries["uper_left_eyelid"] = landmarks_64[36:40]
            boundaries["lower_left_eyelid"] = np.array([landmarks_64[i] for i in [36, 41, 40, 39]])
            boundaries["upper_right_eyelid"] = landmarks_64[42:46]
            boundaries["lower_right_eyelid"] = np.array([landmarks_64[i] for i in [42, 47, 46, 45]])
            boundaries["noise"] = landmarks_64[27:31]
            boundaries["noise_bot"] = landmarks_64[31:36]
            boundaries["upper_outer_lip"] = landmarks_64[48:55]
            boundaries["upper_inner_lip"] = np.array([landmarks_64[i] for i in [60, 61, 62, 63, 64]])
            boundaries["lower_outer_lip"] = np.array([landmarks_64[i] for i in [48, 59, 58, 57, 56, 55, 54]])
            boundaries["lower_inner_lip"] = np.array([landmarks_64[i] for i in [60, 67, 66, 65, 64]])
        elif self.num_landmarks == 98:
            boundaries = {}
            boundaries["cheek"] = landmarks_64[0:33]
            boundaries["left_eyebrow"] = landmarks_64[33:38]
            boundaries["right_eyebrow"] = landmarks_64[42:47]
            boundaries["uper_left_eyelid"] = landmarks_64[60:65]
            boundaries["lower_left_eyelid"] = np.array([landmarks_64[i] for i in [60, 67, 66, 65, 64]])
            boundaries["upper_right_eyelid"] = landmarks_64[68:73]
            boundaries["lower_right_eyelid"] = np.array([landmarks_64[i] for i in [68, 75, 74, 73, 72]])
            boundaries["noise"] = landmarks_64[51:55]
            boundaries["noise_bot"] = landmarks_64[55:60]
            boundaries["upper_outer_lip"] = landmarks_64[76:83]
            boundaries["upper_inner_lip"] = np.array([landmarks_64[i] for i in [88, 89, 90, 91, 92]])
            boundaries["lower_outer_lip"] = np.array([landmarks_64[i] for i in [76, 87, 86, 85, 84, 83, 82]])
            boundaries["lower_inner_lip"] = np.array([landmarks_64[i] for i in [88, 95, 94, 93, 92]])
        elif self.num_landmarks == 19:
            boundaries = {}
            boundaries["left_eyebrow"] = landmarks_64[0:3]
            boundaries["right_eyebrow"] = landmarks_64[3:5]
            boundaries["left_eye"] = landmarks_64[6:9]
            boundaries["right_eye"] = landmarks_64[9:12]
            boundaries["noise"] = landmarks_64[12:15]

        elif self.num_landmarks == 29:
            boundaries = {}
            boundaries["upper_left_eyebrow"] = np.stack([landmarks_64[0], landmarks_64[4], landmarks_64[2]], axis=0)
            boundaries["lower_left_eyebrow"] = np.stack([landmarks_64[0], landmarks_64[5], landmarks_64[2]], axis=0)
            boundaries["upper_right_eyebrow"] = np.stack([landmarks_64[1], landmarks_64[6], landmarks_64[3]], axis=0)
            boundaries["lower_right_eyebrow"] = np.stack([landmarks_64[1], landmarks_64[7], landmarks_64[3]], axis=0)
            boundaries["upper_left_eye"] = np.stack([landmarks_64[8], landmarks_64[12], landmarks_64[10]], axis=0)
            boundaries["lower_left_eye"] = np.stack([landmarks_64[8], landmarks_64[13], landmarks_64[10]], axis=0)
            boundaries["upper_right_eye"] = np.stack([landmarks_64[9], landmarks_64[14], landmarks_64[11]], axis=0)
            boundaries["lower_right_eye"] = np.stack([landmarks_64[9], landmarks_64[15], landmarks_64[11]], axis=0)
            boundaries["noise"] = np.stack([landmarks_64[18], landmarks_64[21], landmarks_64[19]], axis=0)
            boundaries["outer_upper_lip"] = np.stack([landmarks_64[22], landmarks_64[24], landmarks_64[23]], axis=0)
            boundaries["inner_upper_lip"] = np.stack([landmarks_64[22], landmarks_64[25], landmarks_64[23]], axis=0)
            boundaries["outer_lower_lip"] = np.stack([landmarks_64[22], landmarks_64[26], landmarks_64[23]], axis=0)
            boundaries["inner_lower_lip"] = np.stack([landmarks_64[22], landmarks_64[27], landmarks_64[23]], axis=0)
        functions = {}

        for key, points in boundaries.items():
            temp = points[0]
            new_points = points[0:1, :]
            for point in points[1:]:
                if point[0] == temp[0] and point[1] == temp[1]:
                    continue
                else:
                    new_points = np.concatenate((new_points, np.expand_dims(point, 0)), axis=0)
                    temp = point
            points = new_points
            if points.shape[0] == 1:
                points = np.concatenate((points, points + 0.001), axis=0)
            k = min(4, points.shape[0])
            functions[key] = interpolate.splprep([points[:, 0], points[:, 1]], k=k - 1, s=0)

        boundary_map = np.zeros((64, 64))

        fig = plt.figure(figsize=[64 / 96.0, 64 / 96.0], dpi=96)

        ax = fig.add_axes([0, 0, 1, 1])

        ax.axis("off")

        ax.imshow(boundary_map, interpolation="nearest", cmap="gray")
        # ax.scatter(landmarks[:, 0], landmarks[:, 1], s=1, marker=',', c='w')

        for key in functions.keys():
            xnew = np.arange(0, 1, 0.01)
            out = interpolate.splev(xnew, functions[key][0], der=0)
            plt.plot(out[0], out[1], ",", linewidth=1, color="w")

        img = fig2data(fig)

        plt.close()

        sigma = 1
        temp = 255 - img[:, :, 1]
        temp = cv2.distanceTransform(temp, cv2.DIST_L2, cv2.DIST_MASK_PRECISE)
        temp = temp.astype(np.float32)
        temp = np.where(temp < 3 * sigma, np.exp(-(temp * temp) / (2 * sigma * sigma)), 0)

        fig = plt.figure(figsize=[64 / 96.0, 64 / 96.0], dpi=96)

        ax = fig.add_axes([0, 0, 1, 1])

        ax.axis("off")
        ax.imshow(temp, cmap="gray")
        plt.close()

        boundary_map = fig2data(fig)

        sample["boundary"] = boundary_map[:, :, 0]

        return sample


class AddWeightMap(object):
    def __call__(self, sample):
        heatmap = sample["heatmap"]
        boundary = sample["boundary"]
        heatmap = np.concatenate((heatmap, np.expand_dims(boundary, axis=0)), 0)
        weight_map = np.zeros_like(heatmap)
        for i in range(heatmap.shape[0]):
            weight_map[i] = generate_weight_map(weight_map[i], heatmap[i])
        sample["weight_map"] = weight_map
        return sample


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, heatmap, landmarks, boundary, weight_map = (
            sample["image"],
            sample["heatmap"],
            sample["landmarks"],
            sample["boundary"],
            sample["weight_map"],
        )

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        if len(image.shape) == 2:
            image = np.expand_dims(image, axis=2)
            image_small = np.expand_dims(image_small, axis=2)
        image = image.transpose((2, 0, 1))
        boundary = np.expand_dims(boundary, axis=2)
        boundary = boundary.transpose((2, 0, 1))
        return {
            "image": torch.from_numpy(image).float().div(255.0),
            "heatmap": torch.from_numpy(heatmap).float(),
            "landmarks": torch.from_numpy(landmarks).float(),
            "boundary": torch.from_numpy(boundary).float().div(255.0),
            "weight_map": torch.from_numpy(weight_map).float(),
        }


class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(
        self,
        img_dir,
        landmarks_dir,
        num_landmarks=68,
        gray_scale=False,
        detect_face=False,
        enhance=False,
        center_shift=0,
        transform=None,
    ):
        """
        Args:
            landmark_dir (string): Path to the mat file with landmarks saved.
            img_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.img_dir = img_dir
        self.landmarks_dir = landmarks_dir
        self.num_lanmdkars = num_landmarks
        self.transform = transform
        self.img_names = glob.glob(self.img_dir + "*.jpg") + glob.glob(self.img_dir + "*.png")
        self.gray_scale = gray_scale
        self.detect_face = detect_face
        self.enhance = enhance
        self.center_shift = center_shift
        if self.detect_face:
            self.face_detector = MTCNN(thresh=[0.5, 0.6, 0.7])

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img_name = self.img_names[idx]
        pil_image = Image.open(img_name)
        if pil_image.mode != "RGB":
            # if input is grayscale image, convert it to 3 channel image
            if self.enhance:
                pil_image = power_transform(pil_image, 0.5)
            temp_image = Image.new("RGB", pil_image.size)
            temp_image.paste(pil_image)
            pil_image = temp_image
        image = np.array(pil_image)
        if self.gray_scale:
            image = rgb2gray(image)
            image = np.expand_dims(image, axis=2)
            image = np.concatenate((image, image, image), axis=2)
            image = image * 255.0
            image = image.astype(np.uint8)
        if not self.detect_face:
            center = [450 // 2, 450 // 2 + 0]
            if self.center_shift != 0:
                center[0] += int(np.random.uniform(-self.center_shift, self.center_shift))
                center[1] += int(np.random.uniform(-self.center_shift, self.center_shift))
            scale = 1.8
        else:
            detected_faces = self.face_detector.detect_image(image)
            if len(detected_faces) > 0:
                box = detected_faces[0]
                left, top, right, bottom, _ = box
                center = [right - (right - left) / 2.0, bottom - (bottom - top) / 2.0]
                center[1] = center[1] - (bottom - top) * 0.12
                scale = (right - left + bottom - top) / 195.0
            else:
                center = [450 // 2, 450 // 2 + 0]
                scale = 1.8
            if self.center_shift != 0:
                shift = self.center * self.center_shift / 450
                center[0] += int(np.random.uniform(-shift, shift))
                center[1] += int(np.random.uniform(-shift, shift))
        base_name = os.path.basename(img_name)
        landmarks_base_name = base_name[:-4] + "_pts.mat"
        landmarks_name = os.path.join(self.landmarks_dir, landmarks_base_name)
        if os.path.isfile(landmarks_name):
            mat_data = sio.loadmat(landmarks_name)
            landmarks = mat_data["pts_2d"]
        elif os.path.isfile(landmarks_name[:-8] + ".pts.npy"):
            landmarks = np.load(landmarks_name[:-8] + ".pts.npy")
        else:
            landmarks = []
            heatmap = []

        if landmarks != []:
            new_image, new_landmarks = cv_crop(image, landmarks, center, scale, 256, self.center_shift)
            tries = 0
            while self.center_shift != 0 and tries < 5 and (np.max(new_landmarks) > 240 or np.min(new_landmarks) < 15):
                center = [450 // 2, 450 // 2 + 0]
                scale += 0.05
                center[0] += int(np.random.uniform(-self.center_shift, self.center_shift))
                center[1] += int(np.random.uniform(-self.center_shift, self.center_shift))

                new_image, new_landmarks = cv_crop(image, landmarks, center, scale, 256, self.center_shift)
                tries += 1
            if np.max(new_landmarks) > 250 or np.min(new_landmarks) < 5:
                center = [450 // 2, 450 // 2 + 0]
                scale = 2.25
                new_image, new_landmarks = cv_crop(image, landmarks, center, scale, 256, 100)
            assert np.min(new_landmarks) > 0 and np.max(new_landmarks) < 256, "Landmarks out of boundary!"
            image = new_image
            landmarks = new_landmarks
            heatmap = np.zeros((self.num_lanmdkars, 64, 64))
            for i in range(self.num_lanmdkars):
                if landmarks[i][0] > 0:
                    heatmap[i] = draw_gaussian(heatmap[i], landmarks[i] / 4.0 + 1, 1)
        sample = {"image": image, "heatmap": heatmap, "landmarks": landmarks}
        if self.transform:
            sample = self.transform(sample)

        return sample


def get_dataset(
    val_img_dir,
    val_landmarks_dir,
    batch_size,
    num_landmarks=68,
    rotation=0,
    scale=0,
    center_shift=0,
    random_flip=False,
    brightness=0,
    contrast=0,
    saturation=0,
    blur=False,
    noise=False,
    jpeg_effect=False,
    random_occlusion=False,
    gray_scale=False,
    detect_face=False,
    enhance=False,
):
    val_transforms = transforms.Compose([AddBoundary(num_landmarks), AddWeightMap(), ToTensor()])

    val_dataset = FaceLandmarksDataset(
        val_img_dir,
        val_landmarks_dir,
        num_landmarks=num_landmarks,
        gray_scale=gray_scale,
        detect_face=detect_face,
        enhance=enhance,
        transform=val_transforms,
    )

    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=6)
    data_loaders = {"val": val_dataloader}
    dataset_sizes = {}
    dataset_sizes["val"] = len(val_dataset)
    return data_loaders, dataset_sizes
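Reader's note (not part of the committed file): get_dataset builds a validation-only pipeline (AddBoundary, then AddWeightMap, then ToTensor) over a directory of face crops with matching _pts.mat or .pts.npy landmark files; the hard-coded center [225, 225] suggests 450x450 input crops. A minimal sketch, with the directory paths as placeholders:

# Sketch only, not part of the commit. Paths are placeholders.
import sys

sys.path.insert(0, "AdaptiveWingLoss")  # dataloader.py uses "from utils.utils import ..." (layout assumption)

from AdaptiveWingLoss.core.dataloader import get_dataset

data_loaders, dataset_sizes = get_dataset(
    val_img_dir="path/to/val_images/",        # note the trailing slash: img_dir + "*.jpg"
    val_landmarks_dir="path/to/val_landmarks/",
    batch_size=8,
    num_landmarks=98,
)
print(dataset_sizes["val"])
for batch in data_loaders["val"]:
    print(batch["image"].shape, batch["heatmap"].shape)  # (B, 3, 256, 256), (B, 98, 64, 64)
    break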
AdaptiveWingLoss/core/evaler.py
ADDED
@@ -0,0 +1,125 @@
import matplotlib

matplotlib.use("Agg")
import math
import torch
import copy
import time
from torch.autograd import Variable
import shutil
from skimage import io
import numpy as np
from utils.utils import fan_NME, show_landmarks, get_preds_fromhm
from PIL import Image, ImageDraw
import os
import sys
import cv2
import matplotlib.pyplot as plt


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def eval_model(
    model, dataloaders, dataset_sizes, writer, use_gpu=True, epoches=5, dataset="val", save_path="./", num_landmarks=68
):
    global_nme = 0
    model.eval()
    for epoch in range(epoches):
        running_loss = 0
        step = 0
        total_nme = 0
        total_count = 0
        fail_count = 0
        nmes = []
        # running_corrects = 0

        # Iterate over data.
        with torch.no_grad():
            for data in dataloaders[dataset]:
                total_runtime = 0
                run_count = 0
                step_start = time.time()
                step += 1
                # get the inputs
                inputs = data["image"].type(torch.FloatTensor)
                labels_heatmap = data["heatmap"].type(torch.FloatTensor)
                labels_boundary = data["boundary"].type(torch.FloatTensor)
                landmarks = data["landmarks"].type(torch.FloatTensor)
                loss_weight_map = data["weight_map"].type(torch.FloatTensor)
                # wrap them in Variable
                if use_gpu:
                    inputs = inputs.to(device)
                    labels_heatmap = labels_heatmap.to(device)
                    labels_boundary = labels_boundary.to(device)
                    loss_weight_map = loss_weight_map.to(device)
                else:
                    inputs, labels_heatmap = Variable(inputs), Variable(labels_heatmap)
                    labels_boundary = Variable(labels_boundary)
                labels = torch.cat((labels_heatmap, labels_boundary), 1)
                single_start = time.time()
                outputs, boundary_channels = model(inputs)
                single_end = time.time()
                total_runtime += time.time() - single_start
                run_count += 1
                step_end = time.time()
                for i in range(inputs.shape[0]):
                    img = inputs[i]
                    img = img.cpu().numpy()
                    img = img.transpose((1, 2, 0)) * 255.0
                    img = img.astype(np.uint8)
                    img = Image.fromarray(img)
                    # pred_heatmap = outputs[-1][i].detach().cpu()[:-1, :, :]
                    pred_heatmap = outputs[-1][:, :-1, :, :][i].detach().cpu()
                    pred_landmarks, _ = get_preds_fromhm(pred_heatmap.unsqueeze(0))
                    pred_landmarks = pred_landmarks.squeeze().numpy()

                    gt_landmarks = data["landmarks"][i].numpy()
                    if num_landmarks == 68:
                        left_eye = np.average(gt_landmarks[36:42], axis=0)
                        right_eye = np.average(gt_landmarks[42:48], axis=0)
                        norm_factor = np.linalg.norm(left_eye - right_eye)
                        # norm_factor = np.linalg.norm(gt_landmarks[36]- gt_landmarks[45])

                    elif num_landmarks == 98:
                        norm_factor = np.linalg.norm(gt_landmarks[60] - gt_landmarks[72])
                    elif num_landmarks == 19:
                        left, top = gt_landmarks[-2, :]
                        right, bottom = gt_landmarks[-1, :]
                        norm_factor = math.sqrt(abs(right - left) * abs(top - bottom))
                        gt_landmarks = gt_landmarks[:-2, :]
                    elif num_landmarks == 29:
                        # norm_factor = np.linalg.norm(gt_landmarks[8]- gt_landmarks[9])
                        norm_factor = np.linalg.norm(gt_landmarks[16] - gt_landmarks[17])
                    single_nme = (
                        np.sum(np.linalg.norm(pred_landmarks * 4 - gt_landmarks, axis=1)) / pred_landmarks.shape[0]
                    ) / norm_factor

                    nmes.append(single_nme)
                    total_count += 1
                    if single_nme > 0.1:
                        fail_count += 1
                if step % 10 == 0:
                    print(
                        "Step {} Time: {:.6f} Input Mean: {:.6f} Output Mean: {:.6f}".format(
                            step, step_end - step_start, torch.mean(labels), torch.mean(outputs[0])
                        )
                    )
                # gt_landmarks = landmarks.numpy()
                # pred_heatmap = outputs[-1].to('cpu').numpy()
                gt_landmarks = landmarks
                batch_nme = fan_NME(outputs[-1][:, :-1, :, :].detach().cpu(), gt_landmarks, num_landmarks)
                # batch_nme = 0
                total_nme += batch_nme
        epoch_nme = total_nme / dataset_sizes["val"]
        global_nme += epoch_nme
        nme_save_path = os.path.join(save_path, "nme_log.npy")
        np.save(nme_save_path, np.array(nmes))
        print(
            "NME: {:.6f} Failure Rate: {:.6f} Total Count: {:.6f} Fail Count: {:.6f}".format(
                epoch_nme, fail_count / total_count, total_count, fail_count
            )
        )
    print("Evaluation done! Average NME: {:.6f}".format(global_nme / epoches))
    print("Everage runtime for a single batch: {:.6f}".format(total_runtime / run_count))
    return model
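Reader's note (not part of the committed file): eval_model iterates the "val" loader, computes a normalized mean error per image (inter-ocular normalization for 68/98 points), and saves the per-image values to nme_log.npy under save_path. A minimal sketch wiring it to the loaders above; the paths are placeholders and checkpoint loading is omitted:

# Sketch only, not part of the commit.
import sys

sys.path.insert(0, "AdaptiveWingLoss")  # evaler.py uses "from utils.utils import ..." (layout assumption)

import torch

from AdaptiveWingLoss.core.dataloader import get_dataset
from AdaptiveWingLoss.core.evaler import device, eval_model
from AdaptiveWingLoss.core.models import FAN

data_loaders, dataset_sizes = get_dataset(
    "path/to/val_images/", "path/to/val_landmarks/", batch_size=8, num_landmarks=98
)
model = FAN(num_modules=4, num_landmarks=98).to(device)
eval_model(model, data_loaders, dataset_sizes, writer=None,
           use_gpu=torch.cuda.is_available(), epoches=1, num_landmarks=98)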
AdaptiveWingLoss/core/models.py
ADDED
@@ -0,0 +1,239 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from AdaptiveWingLoss.core.coord_conv import CoordConvTh


def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False, dilation=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=strd, padding=padding, bias=bias, dilation=dilation)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        # self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        # self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        # out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        # out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ConvBlock(nn.Module):
    def __init__(self, in_planes, out_planes):
        super(ConvBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = conv3x3(in_planes, int(out_planes / 2))
        self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
        self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4), padding=1, dilation=1)
        self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
        self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4), padding=1, dilation=1)

        if in_planes != out_planes:
            self.downsample = nn.Sequential(
                nn.BatchNorm2d(in_planes),
                nn.ReLU(True),
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, bias=False),
            )
        else:
            self.downsample = None

    def forward(self, x):
        residual = x

        out1 = self.bn1(x)
        out1 = F.relu(out1, True)
        out1 = self.conv1(out1)

        out2 = self.bn2(out1)
        out2 = F.relu(out2, True)
        out2 = self.conv2(out2)

        out3 = self.bn3(out2)
        out3 = F.relu(out3, True)
        out3 = self.conv3(out3)

        out3 = torch.cat((out1, out2, out3), 1)

        if self.downsample is not None:
            residual = self.downsample(residual)

        out3 += residual

        return out3


class HourGlass(nn.Module):
    def __init__(self, num_modules, depth, num_features, first_one=False):
        super(HourGlass, self).__init__()
        self.num_modules = num_modules
        self.depth = depth
        self.features = num_features
        self.coordconv = CoordConvTh(
            x_dim=64,
            y_dim=64,
            with_r=True,
            with_boundary=True,
            in_channels=256,
            first_one=first_one,
            out_channels=256,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self._generate_network(self.depth)

    def _generate_network(self, level):
        self.add_module("b1_" + str(level), ConvBlock(256, 256))

        self.add_module("b2_" + str(level), ConvBlock(256, 256))

        if level > 1:
            self._generate_network(level - 1)
        else:
            self.add_module("b2_plus_" + str(level), ConvBlock(256, 256))

        self.add_module("b3_" + str(level), ConvBlock(256, 256))

    def _forward(self, level, inp):
        # Upper branch
        up1 = inp
        up1 = self._modules["b1_" + str(level)](up1)

        # Lower branch
        low1 = F.avg_pool2d(inp, 2, stride=2)
        low1 = self._modules["b2_" + str(level)](low1)

        if level > 1:
            low2 = self._forward(level - 1, low1)
        else:
            low2 = low1
            low2 = self._modules["b2_plus_" + str(level)](low2)

        low3 = low2
        low3 = self._modules["b3_" + str(level)](low3)

        up2 = F.upsample(low3, scale_factor=2, mode="nearest")

        return up1 + up2

    def forward(self, x, heatmap):
        x, last_channel = self.coordconv(x, heatmap)
        return self._forward(self.depth, x), last_channel


class FAN(nn.Module):
    def __init__(self, num_modules=1, end_relu=False, gray_scale=False, num_landmarks=68):
        super(FAN, self).__init__()
        self.num_modules = num_modules
        self.gray_scale = gray_scale
        self.end_relu = end_relu
        self.num_landmarks = num_landmarks

        # Base part
        if self.gray_scale:
            self.conv1 = CoordConvTh(
                x_dim=256,
                y_dim=256,
                with_r=True,
                with_boundary=False,
                in_channels=3,
                out_channels=64,
                kernel_size=7,
                stride=2,
                padding=3,
            )
        else:
            self.conv1 = CoordConvTh(
                x_dim=256,
                y_dim=256,
                with_r=True,
                with_boundary=False,
                in_channels=3,
                out_channels=64,
                kernel_size=7,
                stride=2,
                padding=3,
            )
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = ConvBlock(64, 128)
        self.conv3 = ConvBlock(128, 128)
        self.conv4 = ConvBlock(128, 256)

        # Stacking part
        for hg_module in range(self.num_modules):
            if hg_module == 0:
                first_one = True
            else:
                first_one = False
            self.add_module("m" + str(hg_module), HourGlass(1, 4, 256, first_one))
            self.add_module("top_m_" + str(hg_module), ConvBlock(256, 256))
            self.add_module("conv_last" + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
            self.add_module("bn_end" + str(hg_module), nn.BatchNorm2d(256))
            self.add_module("l" + str(hg_module), nn.Conv2d(256, num_landmarks + 1, kernel_size=1, stride=1, padding=0))

            if hg_module < self.num_modules - 1:
                self.add_module("bl" + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
                self.add_module(
                    "al" + str(hg_module), nn.Conv2d(num_landmarks + 1, 256, kernel_size=1, stride=1, padding=0)
                )

    def forward(self, x):
        x, _ = self.conv1(x)
        x = F.relu(self.bn1(x), True)
        # x = F.relu(self.bn1(self.conv1(x)), True)
        x = F.avg_pool2d(self.conv2(x), 2, stride=2)
        x = self.conv3(x)
        x = self.conv4(x)

        previous = x

        outputs = []
        boundary_channels = []
        tmp_out = None
        for i in range(self.num_modules):
            hg, boundary_channel = self._modules["m" + str(i)](previous, tmp_out)

            ll = hg
            ll = self._modules["top_m_" + str(i)](ll)

            ll = F.relu(self._modules["bn_end" + str(i)](self._modules["conv_last" + str(i)](ll)), True)

            # Predict heatmaps
            tmp_out = self._modules["l" + str(i)](ll)
            if self.end_relu:
                tmp_out = F.relu(tmp_out)  # HACK: Added relu
            outputs.append(tmp_out)
            boundary_channels.append(boundary_channel)

            if i < self.num_modules - 1:
                ll = self._modules["bl" + str(i)](ll)
                tmp_out_ = self._modules["al" + str(i)](tmp_out)
                previous = previous + ll + tmp_out_

        return outputs, boundary_channels
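Reader's note (not part of the committed file): FAN stacks num_modules boundary-aware hourglasses; for a 256x256 input each stage emits num_landmarks + 1 channels at 64x64 (landmark heatmaps plus one boundary map), and each hourglass also returns the two coordinate channels it consumed. A quick shape check under those assumptions:

# Sketch only, not part of the commit: verifying FAN output shapes.
import torch

from AdaptiveWingLoss.core.models import FAN

model = FAN(num_modules=4, num_landmarks=98).eval()
with torch.no_grad():
    outputs, boundary_channels = model(torch.rand(1, 3, 256, 256))
print(len(outputs))                 # 4, one heatmap stack per hourglass module
print(outputs[-1].shape)            # torch.Size([1, 99, 64, 64]) -> 98 landmarks + 1 boundary map
print(boundary_channels[-1].shape)  # torch.Size([1, 2, 64, 64])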
AdaptiveWingLoss/utils/__init__.py
ADDED
File without changes
AdaptiveWingLoss/utils/utils.py
ADDED
@@ -0,0 +1,437 @@
from __future__ import division
from __future__ import print_function

import math
import os
import sys

import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from scipy import ndimage
from skimage import io
from skimage import transform as ski_transform
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision import utils


def _gaussian(
    size=3,
    sigma=0.25,
    amplitude=1,
    normalize=False,
    width=None,
    height=None,
    sigma_horz=None,
    sigma_vert=None,
    mean_horz=0.5,
    mean_vert=0.5,
):
    # handle some defaults
    if width is None:
        width = size
    if height is None:
        height = size
    if sigma_horz is None:
        sigma_horz = sigma
    if sigma_vert is None:
        sigma_vert = sigma
    center_x = mean_horz * width + 0.5
    center_y = mean_vert * height + 0.5
    gauss = np.empty((height, width), dtype=np.float32)
    # generate kernel
    for i in range(height):
        for j in range(width):
            gauss[i][j] = amplitude * math.exp(
                -(
                    math.pow((j + 1 - center_x) / (sigma_horz * width), 2) / 2.0
                    + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0
                )
            )
    if normalize:
        gauss = gauss / np.sum(gauss)
    return gauss


def draw_gaussian(image, point, sigma):
    # Check if the gaussian is inside
    ul = [np.floor(np.floor(point[0]) - 3 * sigma), np.floor(np.floor(point[1]) - 3 * sigma)]
    br = [np.floor(np.floor(point[0]) + 3 * sigma), np.floor(np.floor(point[1]) + 3 * sigma)]
    if ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1:
        return image
    size = 6 * sigma + 1
    g = _gaussian(size)
    g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
    g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
    img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
    img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
    assert g_x[0] > 0 and g_y[1] > 0
    correct = False
    while not correct:
        try:
            image[img_y[0] - 1 : img_y[1], img_x[0] - 1 : img_x[1]] = (
                image[img_y[0] - 1 : img_y[1], img_x[0] - 1 : img_x[1]] + g[g_y[0] - 1 : g_y[1], g_x[0] - 1 : g_x[1]]
            )
            correct = True
        except:
            print(
                "img_x: {}, img_y: {}, g_x:{}, g_y:{}, point:{}, g_shape:{}, ul:{}, br:{}".format(
                    img_x, img_y, g_x, g_y, point, g.shape, ul, br
                )
            )
            ul = [np.floor(np.floor(point[0]) - 3 * sigma), np.floor(np.floor(point[1]) - 3 * sigma)]
            br = [np.floor(np.floor(point[0]) + 3 * sigma), np.floor(np.floor(point[1]) + 3 * sigma)]
            g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
            g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
            img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
            img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
            pass
    image[image > 1] = 1
    return image


def transform(point, center, scale, resolution, rotation=0, invert=False):
    _pt = np.ones(3)
    _pt[0] = point[0]
    _pt[1] = point[1]

    h = 200.0 * scale
    t = np.eye(3)
    t[0, 0] = resolution / h
    t[1, 1] = resolution / h
    t[0, 2] = resolution * (-center[0] / h + 0.5)
    t[1, 2] = resolution * (-center[1] / h + 0.5)

    if rotation != 0:
        rotation = -rotation
        r = np.eye(3)
        ang = rotation * math.pi / 180.0
        s = math.sin(ang)
        c = math.cos(ang)
        r[0][0] = c
        r[0][1] = -s
        r[1][0] = s
        r[1][1] = c

        t_ = np.eye(3)
        t_[0][2] = -resolution / 2.0
        t_[1][2] = -resolution / 2.0
        t_inv = torch.eye(3)
        t_inv[0][2] = resolution / 2.0
        t_inv[1][2] = resolution / 2.0
        t = reduce(np.matmul, [t_inv, r, t_, t])

    if invert:
        t = np.linalg.inv(t)
    new_point = (np.matmul(t, _pt))[0:2]

    return new_point.astype(int)


def cv_crop(image, landmarks, center, scale, resolution=256, center_shift=0):
    new_image = cv2.copyMakeBorder(
        image, center_shift, center_shift, center_shift, center_shift, cv2.BORDER_CONSTANT, value=[0, 0, 0]
    )
    new_landmarks = landmarks.copy()
    if center_shift != 0:
        center[0] += center_shift
        center[1] += center_shift
        new_landmarks = new_landmarks + center_shift
    length = 200 * scale
    top = int(center[1] - length // 2)
    bottom = int(center[1] + length // 2)
    left = int(center[0] - length // 2)
    right = int(center[0] + length // 2)
    y_pad = abs(min(top, new_image.shape[0] - bottom, 0))
    x_pad = abs(min(left, new_image.shape[1] - right, 0))
    top, bottom, left, right = top + y_pad, bottom + y_pad, left + x_pad, right + x_pad
    new_image = cv2.copyMakeBorder(new_image, y_pad, y_pad, x_pad, x_pad, cv2.BORDER_CONSTANT, value=[0, 0, 0])
    new_image = new_image[top:bottom, left:right]
    new_image = cv2.resize(new_image, dsize=(int(resolution), int(resolution)), interpolation=cv2.INTER_LINEAR)
    new_landmarks[:, 0] = (new_landmarks[:, 0] + x_pad - left) * resolution / length
    new_landmarks[:, 1] = (new_landmarks[:, 1] + y_pad - top) * resolution / length
    return new_image, new_landmarks


def cv_rotate(image, landmarks, heatmap, rot, scale, resolution=256):
    img_mat = cv2.getRotationMatrix2D((resolution // 2, resolution // 2), rot, scale)
    ones = np.ones(shape=(landmarks.shape[0], 1))
    stacked_landmarks = np.hstack([landmarks, ones])
    new_landmarks = img_mat.dot(stacked_landmarks.T).T
    if np.max(new_landmarks) > 255 or np.min(new_landmarks) < 0:
        return image, landmarks, heatmap
    else:
        new_image = cv2.warpAffine(image, img_mat, (resolution, resolution))
        if heatmap is not None:
            new_heatmap = np.zeros((heatmap.shape[0], 64, 64))
            for i in range(heatmap.shape[0]):
                if new_landmarks[i][0] > 0:
                    new_heatmap[i] = draw_gaussian(new_heatmap[i], new_landmarks[i] / 4.0 + 1, 1)
        return new_image, new_landmarks, new_heatmap


def show_landmarks(image, heatmap, gt_landmarks, gt_heatmap):
    """Show image with pred_landmarks"""
    pred_landmarks = []
    pred_landmarks, _ = get_preds_fromhm(torch.from_numpy(heatmap).unsqueeze(0))
    pred_landmarks = pred_landmarks.squeeze() * 4

    # pred_landmarks2 = get_preds_fromhm2(heatmap)
    heatmap = np.max(gt_heatmap, axis=0)
    heatmap = heatmap / np.max(heatmap)
    # image = ski_transform.resize(image, (64, 64))*255
    image = image.astype(np.uint8)
    heatmap = np.max(gt_heatmap, axis=0)
    heatmap = ski_transform.resize(heatmap, (image.shape[0], image.shape[1]))
    heatmap *= 255
    heatmap = heatmap.astype(np.uint8)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    plt.imshow(image)
    plt.scatter(gt_landmarks[:, 0], gt_landmarks[:, 1], s=0.5, marker=".", c="g")
    plt.scatter(pred_landmarks[:, 0], pred_landmarks[:, 1], s=0.5, marker=".", c="r")
    plt.pause(0.001)  # pause a bit so that plots are updated


def fan_NME(pred_heatmaps, gt_landmarks, num_landmarks=68):
    """
    Calculate total NME for a batch of data

    Args:
        pred_heatmaps: torch tensor of size [batch, points, height, width]
        gt_landmarks: torch tesnsor of size [batch, points, x, y]

    Returns:
        nme: sum of nme for this batch
    """
    nme = 0
    pred_landmarks, _ = get_preds_fromhm(pred_heatmaps)
    pred_landmarks = pred_landmarks.numpy()
    gt_landmarks = gt_landmarks.numpy()
    for i in range(pred_landmarks.shape[0]):
        pred_landmark = pred_landmarks[i] * 4.0
        gt_landmark = gt_landmarks[i]

        if num_landmarks == 68:
            left_eye = np.average(gt_landmark[36:42], axis=0)
            right_eye = np.average(gt_landmark[42:48], axis=0)
            norm_factor = np.linalg.norm(left_eye - right_eye)
            # norm_factor = np.linalg.norm(gt_landmark[36]- gt_landmark[45])
        elif num_landmarks == 98:
            norm_factor = np.linalg.norm(gt_landmark[60] - gt_landmark[72])
        elif num_landmarks == 19:
            left, top = gt_landmark[-2, :]
            right, bottom = gt_landmark[-1, :]
            norm_factor = math.sqrt(abs(right - left) * abs(top - bottom))
            gt_landmark = gt_landmark[:-2, :]
        elif num_landmarks == 29:
            # norm_factor = np.linalg.norm(gt_landmark[8]- gt_landmark[9])
            norm_factor = np.linalg.norm(gt_landmark[16] - gt_landmark[17])
        nme += (np.sum(np.linalg.norm(pred_landmark - gt_landmark, axis=1)) / pred_landmark.shape[0]) / norm_factor
    return nme


def fan_NME_hm(pred_heatmaps, gt_heatmaps, num_landmarks=68):
    """
    Calculate total NME for a batch of data

    Args:
        pred_heatmaps: torch tensor of size [batch, points, height, width]
        gt_landmarks: torch tesnsor of size [batch, points, x, y]

    Returns:
        nme: sum of nme for this batch
    """
    nme = 0
    pred_landmarks, _ = get_index_fromhm(pred_heatmaps)
    pred_landmarks = pred_landmarks.numpy()
    gt_landmarks = gt_landmarks.numpy()
    for i in range(pred_landmarks.shape[0]):
        pred_landmark = pred_landmarks[i] * 4.0
        gt_landmark = gt_landmarks[i]
        if num_landmarks == 68:
            left_eye = np.average(gt_landmark[36:42], axis=0)
            right_eye = np.average(gt_landmark[42:48], axis=0)
            norm_factor = np.linalg.norm(left_eye - right_eye)
        else:
            norm_factor = np.linalg.norm(gt_landmark[60] - gt_landmark[72])
        nme += (np.sum(np.linalg.norm(pred_landmark - gt_landmark, axis=1)) / pred_landmark.shape[0]) / norm_factor
    return nme


def power_transform(img, power):
    img = np.array(img)
    img_new = np.power((img / 255.0), power) * 255.0
    img_new = img_new.astype(np.uint8)
    img_new = Image.fromarray(img_new)
    return img_new


def get_preds_fromhm(hm, center=None, scale=None, rot=None):
    max, idx = torch.max(hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
    idx += 1
    preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
    preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
    preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)

    for i in range(preds.size(0)):
        for j in range(preds.size(1)):
            hm_ = hm[i, j, :]
            pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
            if pX > 0 and pX < 63 and pY > 0 and pY < 63:
                diff = torch.FloatTensor([hm_[pY, pX + 1] - hm_[pY, pX - 1], hm_[pY + 1, pX] - hm_[pY - 1, pX]])
                preds[i, j].add_(diff.sign_().mul_(0.25))

    preds.add_(-0.5)

    preds_orig = torch.zeros(preds.size())
    if center is not None and scale is not None:
        for i in range(hm.size(0)):
            for j in range(hm.size(1)):
                preds_orig[i, j] = transform(preds[i, j], center, scale, hm.size(2), rot, True)

    return preds, preds_orig


def get_index_fromhm(hm):
    max, idx = torch.max(hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
    preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
    preds[..., 0].remainder_(hm.size(3))
    preds[..., 1].div_(hm.size(2)).floor_()

    for i in range(preds.size(0)):
        for j in range(preds.size(1)):
            hm_ = hm[i, j, :]
            pX, pY = int(preds[i, j, 0]), int(preds[i, j, 1])
            if pX > 0 and pX < 63 and pY > 0 and pY < 63:
                diff = torch.FloatTensor([hm_[pY, pX + 1] - hm_[pY, pX - 1], hm_[pY + 1, pX] - hm_[pY - 1, pX]])
                preds[i, j].add_(diff.sign_().mul_(0.25))

    return preds


def shuffle_lr(parts, num_landmarks=68, pairs=None):
    if num_landmarks == 68:
        if pairs is None:
            pairs = [
                [0, 16], [1, 15], [2, 14], [3, 13], [4, 12], [5, 11], [6, 10], [7, 9],
                [17, 26], [18, 25], [19, 24], [20, 23], [21, 22],
                [36, 45], [37, 44], [38, 43], [39, 42], [41, 46], [40, 47],
                [31, 35], [32, 34],
                [50, 52], [49, 53], [48, 54], [61, 63], [60, 64], [67, 65], [59, 55], [58, 56],
            ]
    elif num_landmarks == 98:
        if pairs is None:
            pairs = [
                [0, 32], [1, 31], [2, 30], [3, 29], [4, 28], [5, 27], [6, 26], [7, 25],
                [8, 24], [9, 23], [10, 22], [11, 21], [12, 20], [13, 19], [14, 18], [15, 17],
                [33, 46], [34, 45], [35, 44], [36, 43], [37, 42],
                [38, 50], [39, 49], [40, 48], [41, 47],
                [60, 72], [61, 71], [62, 70], [63, 69], [64, 68], [65, 75], [66, 74], [67, 73],
                [96, 97], [55, 59], [56, 58],
                [76, 82], [77, 81], [78, 80], [88, 92], [89, 91], [95, 93], [87, 83], [86, 84],
            ]
    elif num_landmarks == 19:
        if pairs is None:
            pairs = [[0, 5], [1, 4], [2, 3], [6, 11], [7, 10], [8, 9], [12, 14], [15, 17]]
    elif num_landmarks == 29:
        if pairs is None:
            pairs = [[0, 1], [4, 6], [5, 7], [2, 3], [8, 9], [12, 14], [16, 17], [13, 15], [10, 11], [18, 19], [22, 23]]
    for matched_p in pairs:
        idx1, idx2 = matched_p[0], matched_p[1]
        tmp = np.copy(parts[idx1])
        np.copyto(parts[idx1], parts[idx2])
        np.copyto(parts[idx2], tmp)
    return parts


def generate_weight_map(weight_map, heatmap):

    k_size = 3
    dilate = ndimage.grey_dilation(heatmap, size=(k_size, k_size))
    weight_map[np.where(dilate > 0.2)] = 1
    return weight_map


def fig2data(fig):
    """
    @brief Convert a Matplotlib figure to a 4D numpy array with RGBA channels and return it
    @param fig a matplotlib figure
    @return a numpy 3D array of RGBA values
    """
    # draw the renderer
    fig.canvas.draw()

    # Get the RGB buffer from the figure
    w, h = fig.canvas.get_width_height()
    buf = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8)
|
433 |
+
buf.shape = (w, h, 3)
|
434 |
+
|
435 |
+
# canvas.tostring_argb give pixmap in ARGB mode. Roll the ALPHA channel to have it in RGBA mode
|
436 |
+
buf = np.roll(buf, 3, axis=2)
|
437 |
+
return buf
|
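The two NME helpers above assume 64x64 heatmaps whose peak locations are mapped back to 256x256 image coordinates (hence the `* 4.0`). Below is a minimal, self-contained sketch of that decode-and-score pipeline written against plain torch/numpy rather than this module's own helpers; the synthetic batch and the 64-to-256 scale factor are assumptions made only for illustration.

```python
# Standalone sketch: decode heatmap peaks to (x, y) and compute an
# inter-ocular-normalized NME, mirroring get_index_fromhm / fan_NME_hm.
import numpy as np
import torch


def decode_heatmaps(hm):
    """Return [batch, points, 2] (x, y) argmax locations of each heatmap."""
    b, p, h, w = hm.shape
    flat_idx = hm.view(b, p, h * w).argmax(dim=2)
    xs = (flat_idx % w).float()
    ys = (flat_idx // w).float()
    return torch.stack([xs, ys], dim=2)


# Fake batch: 2 samples, 68 landmarks, 64x64 heatmaps with one hot pixel each.
gt = torch.randint(0, 64, (2, 68, 2)).float()
hm = torch.zeros(2, 68, 64, 64)
for i in range(2):
    for j in range(68):
        x, y = int(gt[i, j, 0]), int(gt[i, j, 1])
        hm[i, j, y, x] = 1.0

pred = decode_heatmaps(hm).numpy() * 4.0  # back to assumed 256x256 image coordinates
gt_img = gt.numpy() * 4.0

nme = 0.0
for i in range(pred.shape[0]):
    left_eye = gt_img[i, 36:42].mean(axis=0)
    right_eye = gt_img[i, 42:48].mean(axis=0)
    norm = np.linalg.norm(left_eye - right_eye)
    nme += np.linalg.norm(pred[i] - gt_img[i], axis=1).mean() / norm
print("batch NME:", nme)  # close to 0 here because the peaks are exact
```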
Deep3DFaceRecon_pytorch/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 Sicheng Xu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Deep3DFaceRecon_pytorch/README.md
ADDED
@@ -0,0 +1,256 @@
## Accurate 3D Face Reconstruction with Weakly-Supervised Learning: From Single Image to Image Set —— PyTorch implementation ##

<p align="center">
  <img src="/images/example.gif">
</p>

This is an unofficial official pytorch implementation of the following paper:

Y. Deng, J. Yang, S. Xu, D. Chen, Y. Jia, and X. Tong, [Accurate 3D Face Reconstruction with Weakly-Supervised Learning: From Single Image to Image Set](https://arxiv.org/abs/1903.08527), IEEE Computer Vision and Pattern Recognition Workshop (CVPRW) on Analysis and Modeling of Faces and Gestures (AMFG), 2019. (**_Best Paper Award!_**)

The method enforces a hybrid-level weakly-supervised training for CNN-based 3D face reconstruction. It is fast, accurate, and robust to pose and occlusions. It achieves state-of-the-art performance on multiple datasets such as FaceWarehouse, MICC Florence and NoW Challenge.

For the original tensorflow implementation, check this [repo](https://github.com/microsoft/Deep3DFaceReconstruction).

This implementation is written by S. Xu.

## Performance

### ● Reconstruction accuracy

The pytorch implementation achieves lower shape reconstruction error (9% improvement) compared to the [original tensorflow implementation](https://github.com/microsoft/Deep3DFaceReconstruction). Quantitative evaluation (average shape errors in mm) on several benchmarks is as follows:

|Method|FaceWareHouse|MICC Florence | NoW Challenge |
|:----:|:-----------:|:-----------:|:-----------:|
|Deep3DFace Tensorflow | 1.81±0.50 | 1.67±0.50 | 1.54±1.29 |
|**Deep3DFace PyTorch** |**1.64±0.50**|**1.53±0.45**| **1.41±1.21** |

The comparison result with state-of-the-art public 3D face reconstruction methods on the NoW face benchmark is as follows:
|Rank|Method|Median(mm) | Mean(mm) | Std(mm) |
|:----:|:-----------:|:-----------:|:-----------:|:-----------:|
| 1. | [DECA\[Feng et al., SIGGRAPH 2021\]](https://github.com/YadiraF/DECA)|1.09|1.38|1.18|
| **2.** | **Deep3DFace PyTorch**|**1.11**|**1.41**|**1.21**|
| 3. | [RingNet [Sanyal et al., CVPR 2019]](https://github.com/soubhiksanyal/RingNet) | 1.21 | 1.53 | 1.31 |
| 4. | [Deep3DFace [Deng et al., CVPRW 2019]](https://github.com/microsoft/Deep3DFaceReconstruction) | 1.23 | 1.54 | 1.29 |
| 5. | [3DDFA-V2 [Guo et al., ECCV 2020]](https://github.com/cleardusk/3DDFA_V2) | 1.23 | 1.57 | 1.39 |
| 6. | [MGCNet [Shang et al., ECCV 2020]](https://github.com/jiaxiangshang/MGCNet) | 1.31 | 1.87 | 2.63 |
| 7. | [PRNet [Feng et al., ECCV 2018]](https://github.com/YadiraF/PRNet) | 1.50 | 1.98 | 1.88 |
| 8. | [3DMM-CNN [Tran et al., CVPR 2017]](https://github.com/anhttran/3dmm_cnn) | 1.84 | 2.33 | 2.05 |

For more details about the evaluation, check the [NoW Challenge](https://ringnet.is.tue.mpg.de/challenge.html) website.

**_A recent benchmark [REALY](https://www.realy3dface.com/) indicates that our method still has SOTA performance! You can check their paper and website for more details._**

### ● Visual quality
The pytorch implementation achieves better visual consistency with the input images compared to the original tensorflow version.

<p align="center">
  <img src="/images/compare.png">
</p>

### ● Speed
The training speed is on par with the original tensorflow implementation. For more information, see [here](https://github.com/sicxu/Deep3DFaceRecon_pytorch#train-the-face-reconstruction-network).

## Major changes

### ● Differentiable renderer

We use [Nvdiffrast](https://nvlabs.github.io/nvdiffrast/), a pytorch library that provides high-performance primitive operations for rasterization-based differentiable rendering. The original tensorflow implementation used [tf_mesh_renderer](https://github.com/google/tf_mesh_renderer) instead.

### ● Face recognition model

We use [Arcface](https://github.com/deepinsight/insightface/tree/master/recognition/arcface_torch), a state-of-the-art face recognition model, for perceptual loss computation. By contrast, the original tensorflow implementation used [Facenet](https://github.com/davidsandberg/facenet).

### ● Training configuration

Data augmentation is used in the training process, which contains random image shifting, scaling, rotation, and flipping. We also enlarge the training batch size from 5 to 32 to stabilize the training process.

### ● Training data

We use an extra high-quality face image dataset, [FFHQ](https://github.com/NVlabs/ffhq-dataset), to increase the diversity of training data.

## Requirements
**This implementation is only tested under Ubuntu environment with Nvidia GPUs and CUDA installed.**

## Installation
1. Clone the repository and set up a conda environment with all dependencies as follows:
```
git clone https://github.com/sicxu/Deep3DFaceRecon_pytorch.git
cd Deep3DFaceRecon_pytorch
conda env create -f environment.yml
source activate deep3d_pytorch
```

2. Install Nvdiffrast library:
```
git clone https://github.com/NVlabs/nvdiffrast
cd nvdiffrast    # ./Deep3DFaceRecon_pytorch/nvdiffrast
pip install .
```

3. Install Arcface Pytorch:
```
cd ..    # ./Deep3DFaceRecon_pytorch
git clone https://github.com/deepinsight/insightface.git
cp -r ./insightface/recognition/arcface_torch ./models/
```
## Inference with a pre-trained model

### Prepare prerequisite models
1. Our method uses [Basel Face Model 2009 (BFM09)](https://faces.dmi.unibas.ch/bfm/main.php?nav=1-0&id=basel_face_model) to represent 3d faces. Get access to BFM09 using this [link](https://faces.dmi.unibas.ch/bfm/main.php?nav=1-2&id=downloads). After getting the access, download "01_MorphableModel.mat". In addition, we use an Expression Basis provided by [Guo et al.](https://github.com/Juyong/3DFace). Download the Expression Basis (Exp_Pca.bin) using this [link (google drive)](https://drive.google.com/file/d/1bw5Xf8C12pWmcMhNEu6PtsYVZkVucEN6/view?usp=sharing). Organize all files into the following structure:
```
Deep3DFaceRecon_pytorch
│
└─── BFM
    │
    └─── 01_MorphableModel.mat
    │
    └─── Exp_Pca.bin
    |
    └─── ...
```
2. We provide a model trained on a combination of [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html), [LFW](http://vis-www.cs.umass.edu/lfw/), [300WLP](http://www.cbsr.ia.ac.cn/users/xiangyuzhu/projects/3DDFA/main.htm), [IJB-A](https://www.nist.gov/programs-projects/face-challenges), [LS3D-W](https://www.adrianbulat.com/face-alignment), and [FFHQ](https://github.com/NVlabs/ffhq-dataset) datasets. Download the pre-trained model using this [link (google drive)](https://drive.google.com/drive/folders/1liaIxn9smpudjjqMaWWRpP0mXRW_qRPP?usp=sharing) and organize the directory into the following structure:
```
Deep3DFaceRecon_pytorch
│
└─── checkpoints
    │
    └─── <model_name>
        │
        └─── epoch_20.pth
```

### Test with custom images
To reconstruct 3d faces from test images, organize the test image folder as follows:
```
Deep3DFaceRecon_pytorch
│
└─── <folder_to_test_images>
    │
    └─── *.jpg/*.png
    |
    └─── detections
        |
        └─── *.txt
```
The \*.jpg/\*.png files are test images. The \*.txt files are detected 5 facial landmarks with a shape of 5x2, and have the same name as the corresponding images. Check [./datasets/examples](datasets/examples) for a reference.

Then, run the test script:
```
# get reconstruction results of your custom images
python test.py --name=<model_name> --epoch=20 --img_folder=<folder_to_test_images>

# get reconstruction results of example images
python test.py --name=<model_name> --epoch=20 --img_folder=./datasets/examples
```
**_Following [#108](https://github.com/sicxu/Deep3DFaceRecon_pytorch/issues/108), if you don't have an OpenGL environment, you can simply add "--use_opengl False" to use the CUDA context. Make sure you have updated nvdiffrast to the latest version._**

Results will be saved into ./checkpoints/<model_name>/results/<folder_to_test_images>, which contain the following files:
| \*.png | A combination of cropped input image, reconstructed image, and visualization of projected landmarks. |
|:----|:-----------|
| \*.obj | Reconstructed 3d face mesh with predicted color (texture+illumination) in the world coordinate space. Best viewed in Meshlab. |
| \*.mat | Predicted 257-dimensional coefficients and 68 projected 2d facial landmarks. Best viewed in Matlab. |

## Training a model from scratch
### Prepare prerequisite models
1. We rely on [Arcface](https://github.com/deepinsight/insightface/tree/master/recognition/arcface_torch) to extract identity features for loss computation. Download the pre-trained model from Arcface using this [link](https://github.com/deepinsight/insightface/tree/master/recognition/arcface_torch#ms1mv3). By default, we use the resnet50 backbone ([ms1mv3_arcface_r50_fp16](https://onedrive.live.com/?authkey=%21AFZjr283nwZHqbA&id=4A83B6B633B029CC%215583&cid=4A83B6B633B029CC)); organize the downloaded files into the following structure:
```
Deep3DFaceRecon_pytorch
│
└─── checkpoints
    │
    └─── recog_model
        │
        └─── ms1mv3_arcface_r50_fp16
            |
            └─── backbone.pth
```
2. We initialize R-Net using the weights trained on [ImageNet](https://image-net.org/). Download the weights provided by PyTorch using this [link](https://download.pytorch.org/models/resnet50-0676ba61.pth), and organize the file as the following structure:
```
Deep3DFaceRecon_pytorch
│
└─── checkpoints
    │
    └─── init_model
        │
        └─── resnet50-0676ba61.pth
```
3. We provide a landmark detector (tensorflow model) to extract 68 facial landmarks for loss computation. The detector is trained on [300WLP](http://www.cbsr.ia.ac.cn/users/xiangyuzhu/projects/3DDFA/main.htm), [LFW](http://vis-www.cs.umass.edu/lfw/), and [LS3D-W](https://www.adrianbulat.com/face-alignment) datasets. Download the trained model using this [link (google drive)](https://drive.google.com/file/d/1Jl1yy2v7lIJLTRVIpgg2wvxYITI8Dkmw/view?usp=sharing) and organize the file as follows:
```
Deep3DFaceRecon_pytorch
│
└─── checkpoints
    │
    └─── lm_model
        │
        └─── 68lm_detector.pb
```
### Data preparation
1. To train a model with custom images, 5 facial landmarks of each image are needed in advance for an image pre-alignment process. We recommend using [dlib](http://dlib.net/) or [MTCNN](https://github.com/ipazc/mtcnn) to detect these landmarks. Then, organize all files into the following structure:
```
Deep3DFaceRecon_pytorch
│
└─── datasets
    │
    └─── <folder_to_training_images>
        │
        └─── *.png/*.jpg
        |
        └─── detections
            |
            └─── *.txt
```
The \*.txt files contain 5 facial landmarks with a shape of 5x2, and should have the same name as their corresponding images.

2. Generate 68 landmarks and skin attention masks for images using the following script:
```
# preprocess training images
python data_preparation.py --img_folder <folder_to_training_images>

# alternatively, you can preprocess multiple image folders simultaneously
python data_preparation.py --img_folder <folder_to_training_images1> <folder_to_training_images2> <folder_to_training_images3>

# preprocess validation images
python data_preparation.py --img_folder <folder_to_validation_images> --mode=val
```
The script will generate files of landmarks and skin masks, and save them into ./datasets/<folder_to_training_images>. In addition, it also generates a file containing the path of all training data into ./datalist which will then be used in the training script.

### Train the face reconstruction network
Run the following script to train a face reconstruction model using the pre-processed data:
```
# train with single GPU
python train.py --name=<custom_experiment_name> --gpu_ids=0

# train with multiple GPUs
python train.py --name=<custom_experiment_name> --gpu_ids=0,1

# train with other custom settings
python train.py --name=<custom_experiment_name> --gpu_ids=0 --batch_size=32 --n_epochs=20
```
Training logs and model parameters will be saved into ./checkpoints/<custom_experiment_name>.

By default, the script uses a batch size of 32 and will train the model for 20 epochs. For reference, the pre-trained model in this repo was trained with the default setting on an image collection of 300k images. A single iteration takes 0.8~0.9s on a single Tesla M40 GPU. The total training process takes around two days.

To use a trained model, see the [Inference](https://github.com/sicxu/Deep3DFaceRecon_pytorch#inference-with-a-pre-trained-model) section.
## Contact
If you have any questions, please contact the paper authors.

## Citation

Please cite the following paper if this model helps your research:

    @inproceedings{deng2019accurate,
        title={Accurate 3D Face Reconstruction with Weakly-Supervised Learning: From Single Image to Image Set},
        author={Yu Deng and Jiaolong Yang and Sicheng Xu and Dong Chen and Yunde Jia and Xin Tong},
        booktitle={IEEE Computer Vision and Pattern Recognition Workshops},
        year={2019}
    }
##
The face images on this page are from the public [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) dataset released by MMLab, CUHK.

Part of the code in this implementation takes [CUT](https://github.com/taesungp/contrastive-unpaired-translation) as a reference.
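The detections/*.txt files referenced above hold one 5x2 array of facial landmarks per image. A possible way to generate them with the pip package `mtcnn` is sketched below; the five-point order (left eye, right eye, nose, left mouth corner, right mouth corner) and the example image path are assumptions on my part, so compare the output against the files shipped in ./datasets/examples before testing or training.

```python
# Sketch: write a 5x2 landmark file per image, matching the detections/ layout
# described in the README. Not the repository's own preprocessing script.
import os

import numpy as np
from mtcnn import MTCNN
from PIL import Image


def write_detection(img_path, out_dir="detections"):
    img = np.array(Image.open(img_path).convert("RGB"))
    face = MTCNN().detect_faces(img)[0]  # take the first detected face
    k = face["keypoints"]
    pts = np.array(
        [k["left_eye"], k["right_eye"], k["nose"], k["mouth_left"], k["mouth_right"]],
        dtype=np.float32,
    )  # shape (5, 2)
    os.makedirs(out_dir, exist_ok=True)
    name = os.path.splitext(os.path.basename(img_path))[0] + ".txt"
    np.savetxt(os.path.join(out_dir, name), pts)


write_detection("my_test_images/000001.jpg")  # hypothetical path
```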
Deep3DFaceRecon_pytorch/data/__init__.py
ADDED
@@ -0,0 +1,118 @@
"""This package includes all the modules related to data loading and preprocessing

To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
You need to implement four functions:
    -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
    -- <__len__>: return the size of dataset.
    -- <__getitem__>: get a data point from data loader.
    -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.

Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
See our template dataset class 'template_dataset.py' for more details.
"""
import importlib

import numpy as np
import torch.utils.data
from data.base_dataset import BaseDataset


def find_dataset_using_name(dataset_name):
    """Import the module "data/[dataset_name]_dataset.py".

    In the file, the class called DatasetNameDataset() will
    be instantiated. It has to be a subclass of BaseDataset,
    and it is case-insensitive.
    """
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    dataset = None
    target_dataset_name = dataset_name.replace("_", "") + "dataset"
    for name, cls in datasetlib.__dict__.items():
        if name.lower() == target_dataset_name.lower() and issubclass(cls, BaseDataset):
            dataset = cls

    if dataset is None:
        raise NotImplementedError(
            "In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase."
            % (dataset_filename, target_dataset_name)
        )

    return dataset


def get_option_setter(dataset_name):
    """Return the static method <modify_commandline_options> of the dataset class."""
    dataset_class = find_dataset_using_name(dataset_name)
    return dataset_class.modify_commandline_options


def create_dataset(opt, rank=0):
    """Create a dataset given the option.

    This function wraps the class CustomDatasetDataLoader.
    This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from data import create_dataset
        >>> dataset = create_dataset(opt)
    """
    data_loader = CustomDatasetDataLoader(opt, rank=rank)
    dataset = data_loader.load_data()
    return dataset


class CustomDatasetDataLoader:
    """Wrapper class of Dataset class that performs multi-threaded data loading"""

    def __init__(self, opt, rank=0):
        """Initialize this class

        Step 1: create a dataset instance given the name [dataset_mode]
        Step 2: create a multi-threaded data loader.
        """
        self.opt = opt
        dataset_class = find_dataset_using_name(opt.dataset_mode)
        self.dataset = dataset_class(opt)
        self.sampler = None
        print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__))
        if opt.use_ddp and opt.isTrain:
            world_size = opt.world_size
            self.sampler = torch.utils.data.distributed.DistributedSampler(
                self.dataset, num_replicas=world_size, rank=rank, shuffle=not opt.serial_batches
            )
            self.dataloader = torch.utils.data.DataLoader(
                self.dataset,
                sampler=self.sampler,
                num_workers=int(opt.num_threads / world_size),
                batch_size=int(opt.batch_size / world_size),
                drop_last=True,
            )
        else:
            self.dataloader = torch.utils.data.DataLoader(
                self.dataset,
                batch_size=opt.batch_size,
                shuffle=(not opt.serial_batches) and opt.isTrain,
                num_workers=int(opt.num_threads),
                drop_last=True,
            )

    def set_epoch(self, epoch):
        self.dataset.current_epoch = epoch
        if self.sampler is not None:
            self.sampler.set_epoch(epoch)

    def load_data(self):
        return self

    def __len__(self):
        """Return the number of data in the dataset"""
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        """Return a batch of data"""
        for i, data in enumerate(self.dataloader):
            if i * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield data
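As a concrete illustration of the discovery mechanism above, a hypothetical `data/dummy_dataset.py` that `--dataset_mode dummy` would resolve through find_dataset_using_name could look like the following sketch. Only the class name, the base class, and the dict-style return value follow this package's conventions; the tensor payload itself is made up.

```python
# data/dummy_dataset.py -- minimal sketch of a dataset that the loader above can discover.
import torch

from data.base_dataset import BaseDataset


class DummyDataset(BaseDataset):
    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        self.name = "dummy"          # used by the loader's creation print-out
        self.items = torch.arange(100).float()

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        # return a dict of named fields, as flist_dataset.py does
        return {"imgs": self.items[index], "im_paths": str(index)}
```

With that file in place, `create_dataset(opt)` with `opt.dataset_mode == "dummy"` resolves DummyDataset via importlib exactly as shown in find_dataset_using_name.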
Deep3DFaceRecon_pytorch/data/base_dataset.py
ADDED
@@ -0,0 +1,132 @@
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.

It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
"""
import random
from abc import ABC
from abc import abstractmethod

import numpy as np
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image


class BaseDataset(data.Dataset, ABC):
    """This class is an abstract base class (ABC) for datasets.

    To create a subclass, you need to implement the following four functions:
    -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
    -- <__len__>: return the size of dataset.
    -- <__getitem__>: get a data point.
    -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
    """

    def __init__(self, opt):
        """Initialize the class; save the options in the class

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        self.opt = opt
        # self.root = opt.dataroot
        self.current_epoch = 0

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        return parser

    @abstractmethod
    def __len__(self):
        """Return the total number of images in the dataset."""
        return 0

    @abstractmethod
    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- a random integer for data indexing

        Returns:
            a dictionary of data with their names. It usually contains the data itself and its metadata information.
        """
        pass


def get_transform(grayscale=False):
    transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))
    transform_list += [transforms.ToTensor()]
    return transforms.Compose(transform_list)


def get_affine_mat(opt, size):
    shift_x, shift_y, scale, rot_angle, flip = 0.0, 0.0, 1.0, 0.0, False
    rot_rad = 0.0
    w, h = size

    if "shift" in opt.preprocess:
        shift_pixs = int(opt.shift_pixs)
        shift_x = random.randint(-shift_pixs, shift_pixs)
        shift_y = random.randint(-shift_pixs, shift_pixs)
    if "scale" in opt.preprocess:
        scale = 1 + opt.scale_delta * (2 * random.random() - 1)
    if "rot" in opt.preprocess:
        rot_angle = opt.rot_angle * (2 * random.random() - 1)
        rot_rad = -rot_angle * np.pi / 180
    if "flip" in opt.preprocess:
        flip = random.random() > 0.5

    shift_to_origin = np.array([1, 0, -w // 2, 0, 1, -h // 2, 0, 0, 1]).reshape([3, 3])
    flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3])
    shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3])
    rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape(
        [3, 3]
    )
    scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3])
    shift_to_center = np.array([1, 0, w // 2, 0, 1, h // 2, 0, 0, 1]).reshape([3, 3])

    affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin
    affine_inv = np.linalg.inv(affine)
    return affine, affine_inv, flip


def apply_img_affine(img, affine_inv, method=Image.Resampling.BICUBIC):
    return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.Resampling.BICUBIC)


def apply_lm_affine(landmark, affine, flip, size):
    _, h = size
    lm = landmark.copy()
    lm[:, 1] = h - 1 - lm[:, 1]
    lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1)
    lm = lm @ np.transpose(affine)
    lm[:, :2] = lm[:, :2] / lm[:, 2:]
    lm = lm[:, :2]
    lm[:, 1] = h - 1 - lm[:, 1]
    if flip:
        lm_ = lm.copy()
        lm_[:17] = lm[16::-1]
        lm_[17:22] = lm[26:21:-1]
        lm_[22:27] = lm[21:16:-1]
        lm_[31:36] = lm[35:30:-1]
        lm_[36:40] = lm[45:41:-1]
        lm_[40:42] = lm[47:45:-1]
        lm_[42:46] = lm[39:35:-1]
        lm_[46:48] = lm[41:39:-1]
        lm_[48:55] = lm[54:47:-1]
        lm_[55:60] = lm[59:54:-1]
        lm_[60:65] = lm[64:59:-1]
        lm_[65:68] = lm[67:64:-1]
        lm = lm_
    return lm
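The augmentation above composes shift, scale, rotation, and flip into a single 3x3 homogeneous matrix built around the image center, which is then applied inversely to the image and directly to the landmarks. A small sketch of driving it by hand follows; the SimpleNamespace options stand in for the real option parser and are assumptions, and the flip remap is skipped because it assumes 68 landmarks.

```python
# Sketch: build an augmentation matrix and apply it to a couple of landmarks.
from types import SimpleNamespace

import numpy as np

from data.base_dataset import apply_lm_affine, get_affine_mat

# assumed option values, not the repository's defaults
opt = SimpleNamespace(preprocess=["shift", "scale", "rot", "flip"],
                      shift_pixs=10, scale_delta=0.1, rot_angle=10)
size = (224, 224)  # (w, h) of the aligned crop

affine, affine_inv, flip = get_affine_mat(opt, size)

lm = np.array([[112.0, 112.0], [50.0, 80.0]], dtype=np.float32)  # two example points
lm_aug = apply_lm_affine(lm, affine, False, size)  # flip remap assumes 68 points, so skip it here
print(affine.shape, lm_aug.shape)  # (3, 3) (2, 2)
```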
Deep3DFaceRecon_pytorch/data/flist_dataset.py
ADDED
@@ -0,0 +1,129 @@
"""This script defines the custom dataset for Deep3DFaceRecon_pytorch
"""
import json
import os.path
import pickle
import random

import numpy as np
import torch
import util.util as util
from data.base_dataset import apply_img_affine
from data.base_dataset import apply_lm_affine
from data.base_dataset import BaseDataset
from data.base_dataset import get_affine_mat
from data.base_dataset import get_transform
from data.image_folder import make_dataset
from PIL import Image
from scipy.io import loadmat
from scipy.io import savemat
from util.load_mats import load_lm3d
from util.preprocess import align_img
from util.preprocess import estimate_norm


def default_flist_reader(flist):
    """
    flist format: impath label\nimpath label\n ... (same as caffe's filelist)
    """
    imlist = []
    with open(flist, "r") as rf:
        for line in rf.readlines():
            impath = line.strip()
            imlist.append(impath)

    return imlist


def jason_flist_reader(flist):
    with open(flist, "r") as fp:
        info = json.load(fp)
    return info


def parse_label(label):
    return torch.tensor(np.array(label).astype(np.float32))


class FlistDataset(BaseDataset):
    """
    It requires one directory to host training images '/path/to/data/train'.
    You can train the model with the dataset flag '--dataroot /path/to/data'.
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)

        self.lm3d_std = load_lm3d(opt.bfm_folder)

        msk_names = default_flist_reader(opt.flist)
        self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]

        self.size = len(self.msk_paths)
        self.opt = opt

        self.name = "train" if opt.isTrain else "val"
        if "_" in opt.flist:
            self.name += "_" + opt.flist.split(os.sep)[-1].split("_")[0]

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index (int) -- a random integer for data indexing

        Returns a dictionary that contains:
            img (tensor)   -- an image in the input domain
            msk (tensor)   -- its corresponding attention mask
            lm (tensor)    -- its corresponding 3d landmarks
            im_paths (str) -- image paths
            aug_flag (bool) -- a flag telling whether the sample is raw or augmented
        """
        msk_path = self.msk_paths[index % self.size]  # make sure index is within the range
        img_path = msk_path.replace("mask/", "")
        lm_path = ".".join(msk_path.replace("mask", "landmarks").split(".")[:-1]) + ".txt"

        raw_img = Image.open(img_path).convert("RGB")
        raw_msk = Image.open(msk_path).convert("RGB")
        raw_lm = np.loadtxt(lm_path).astype(np.float32)

        _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)

        aug_flag = self.opt.use_aug and self.opt.isTrain
        if aug_flag:
            img, lm, msk = self._augmentation(img, lm, self.opt, msk)

        _, H = img.size
        M = estimate_norm(lm, H)
        transform = get_transform()
        img_tensor = transform(img)
        msk_tensor = transform(msk)[:1, ...]
        lm_tensor = parse_label(lm)
        M_tensor = parse_label(M)

        return {
            "imgs": img_tensor,
            "lms": lm_tensor,
            "msks": msk_tensor,
            "M": M_tensor,
            "im_paths": img_path,
            "aug_flag": aug_flag,
            "dataset": self.name,
        }

    def _augmentation(self, img, lm, opt, msk=None):
        affine, affine_inv, flip = get_affine_mat(opt, img.size)
        img = apply_img_affine(img, affine_inv)
        lm = apply_lm_affine(lm, affine, flip, img.size)
        if msk is not None:
            msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
        return img, lm, msk

    def __len__(self):
        """Return the total number of images in the dataset."""
        return self.size
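FlistDataset reads a list of mask paths and derives the matching image and landmark paths purely by string replacement, as in __getitem__ above. The example path below is made up, but the derivation is the same:

```python
# Illustration of the mask -> image / landmark path convention used above.
msk_path = "ffhq/mask/00001.png"                                            # hypothetical entry from the flist
img_path = msk_path.replace("mask/", "")                                    # ffhq/00001.png
lm_path = ".".join(msk_path.replace("mask", "landmarks").split(".")[:-1]) + ".txt"
print(img_path, lm_path)                                                    # ffhq/00001.png ffhq/landmarks/00001.txt
```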
Deep3DFaceRecon_pytorch/data/image_folder.py
ADDED
@@ -0,0 +1,77 @@
"""A modified image folder class

We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
so that this class can load images from both current directory and its subdirectories.
"""
import os.path

import numpy as np
import torch.utils.data as data
from PIL import Image

IMG_EXTENSIONS = [
    ".jpg", ".JPG", ".jpeg", ".JPEG",
    ".png", ".PNG", ".ppm", ".PPM",
    ".bmp", ".BMP", ".tif", ".TIF", ".tiff", ".TIFF",
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir, max_dataset_size=float("inf")):
    images = []
    assert os.path.isdir(dir) or os.path.islink(dir), "%s is not a valid directory" % dir

    for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    return images[: min(max_dataset_size, len(images))]


def default_loader(path):
    return Image.open(path).convert("RGB")


class ImageFolder(data.Dataset):
    def __init__(self, root, transform=None, return_paths=False, loader=default_loader):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise (
                RuntimeError(
                    "Found 0 images in: " + root + "\n" "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)
                )
            )

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        else:
            return img

    def __len__(self):
        return len(self.imgs)
Deep3DFaceRecon_pytorch/data/template_dataset.py
ADDED
@@ -0,0 +1,80 @@
"""Dataset class template

This module provides a template for users to implement custom datasets.
You can specify '--dataset_mode template' to use this dataset.
The class name should be consistent with both the filename and its dataset_mode option.
The filename should be <dataset_mode>_dataset.py
The class name should be <Dataset_mode>Dataset
You need to implement the following functions:
    -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
    -- <__init__>: Initialize this dataset class.
    -- <__getitem__>: Return a data point and its metadata information.
    -- <__len__>: Return the number of images.
"""
from data.base_dataset import BaseDataset
from data.base_dataset import get_transform

# from data.image_folder import make_dataset
# from PIL import Image


class TemplateDataset(BaseDataset):
    """A template dataset class for you to implement custom datasets."""

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        parser.add_argument("--new_dataset_option", type=float, default=1.0, help="new dataset option")
        parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0)  # specify dataset-specific default values
        return parser

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions

        A few things can be done here.
        - save the options (have been done in BaseDataset)
        - get image paths and meta information of the dataset.
        - define the image transformation.
        """
        # save the option and dataset root
        BaseDataset.__init__(self, opt)
        # get the image paths of your dataset;
        self.image_paths = (
            []
        )  # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
        # define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
        self.transform = get_transform(opt)

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- a random integer for data indexing

        Returns:
            a dictionary of data with their names. It usually contains the data itself and its metadata information.

        Step 1: get a random image path: e.g., path = self.image_paths[index]
        Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
        Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image)
        Step 4: return a data point as a dictionary.
        """
        path = "temp"  # needs to be a string
        data_A = None  # needs to be a tensor
        data_B = None  # needs to be a tensor
        return {"data_A": data_A, "data_B": data_B, "path": path}

    def __len__(self):
        """Return the total number of images."""
        return len(self.image_paths)
Deep3DFaceRecon_pytorch/data_preparation.py
ADDED
@@ -0,0 +1,57 @@
"""This script is the data preparation script for Deep3DFaceRecon_pytorch
"""
import argparse
import os
import warnings

import numpy as np
from util.detect_lm68 import detect_68p
from util.detect_lm68 import load_lm_graph
from util.generate_list import check_list
from util.generate_list import write_list
from util.skin_mask import get_skin_mask

warnings.filterwarnings("ignore")

parser = argparse.ArgumentParser()
parser.add_argument("--data_root", type=str, default="datasets", help="root directory for training data")
parser.add_argument("--img_folder", nargs="+", required=True, help="folders of training images")
parser.add_argument("--mode", type=str, default="train", help="train or val")
opt = parser.parse_args()

os.environ["CUDA_VISIBLE_DEVICES"] = "0"


def data_prepare(folder_list, mode):
    lm_sess, input_op, output_op = load_lm_graph(
        "./checkpoints/lm_model/68lm_detector.pb"
    )  # load a tensorflow version 68-landmark detector

    for img_folder in folder_list:
        detect_68p(img_folder, lm_sess, input_op, output_op)  # detect landmarks for images
        get_skin_mask(img_folder)  # generate skin attention mask for images

    # create files that record path to all training data
    msks_list = []
    for img_folder in folder_list:
        path = os.path.join(img_folder, "mask")
        msks_list += [
            "/".join([img_folder, "mask", i])
            for i in sorted(os.listdir(path))
            if "jpg" in i or "png" in i or "jpeg" in i or "PNG" in i
        ]

    imgs_list = [i.replace("mask/", "") for i in msks_list]
    lms_list = [i.replace("mask", "landmarks") for i in msks_list]
    lms_list = [".".join(i.split(".")[:-1]) + ".txt" for i in lms_list]

    lms_list_final, imgs_list_final, msks_list_final = check_list(
        lms_list, imgs_list, msks_list
    )  # check if the path is valid
    write_list(lms_list_final, imgs_list_final, msks_list_final, mode=mode)  # save files


if __name__ == "__main__":
    print("Datasets:", opt.img_folder)
    data_prepare([os.path.join(opt.data_root, folder) for folder in opt.img_folder], opt.mode)
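Before running the preparation script on a new folder, it can help to verify that every image already has its 5-point detection file, since images without one cannot be pre-aligned. A small helper sketch follows; the folder name is hypothetical and the check only mirrors the image/detections layout described in the README.

```python
# Sanity-check sketch (not part of the repo): every image should have detections/<name>.txt.
import os
import sys

folder = "datasets/my_train_images"  # hypothetical training folder
exts = (".jpg", ".jpeg", ".png", ".PNG")

missing = []
for fname in sorted(os.listdir(folder)):
    if fname.endswith(exts):
        txt = os.path.splitext(fname)[0] + ".txt"
        if not os.path.isfile(os.path.join(folder, "detections", txt)):
            missing.append(fname)

if missing:
    sys.exit("missing 5-point detections for: %s" % ", ".join(missing))
print("all images in %s have detections" % folder)
```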
Deep3DFaceRecon_pytorch/environment.yml
ADDED
@@ -0,0 +1,24 @@
name: deep3d_pytorch
channels:
  - pytorch
  - conda-forge
  - defaults
dependencies:
  - python=3.6
  - pytorch=1.6.0
  - torchvision=0.7.0
  - numpy=1.18.1
  - scikit-image=0.16.2
  - scipy=1.4.1
  - pillow=6.2.1
  - pip=20.0.2
  - ipython=7.13.0
  - yaml=0.1.7
  - pip:
    - matplotlib==2.2.5
    - opencv-python==3.4.9.33
    - tensorboard==1.15.0
    - tensorflow==1.15.0
    - kornia==0.5.5
    - dominate==2.6.0
    - trimesh==3.9.20
Deep3DFaceRecon_pytorch/models/__init__.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This package contains modules related to objective functions, optimizations, and network architectures.
|
2 |
+
|
3 |
+
To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
|
4 |
+
You need to implement the following five functions:
|
5 |
+
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
|
6 |
+
-- <set_input>: unpack data from dataset and apply preprocessing.
|
7 |
+
-- <forward>: produce intermediate results.
|
8 |
+
-- <optimize_parameters>: calculate loss, gradients, and update network weights.
|
9 |
+
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
|
10 |
+
|
11 |
+
In the function <__init__>, you need to define four lists:
|
12 |
+
-- self.loss_names (str list): specify the training losses that you want to plot and save.
|
13 |
+
-- self.model_names (str list): define networks used in our training.
|
14 |
+
-- self.visual_names (str list): specify the images that you want to display and save.
|
15 |
+
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage.
|
16 |
+
|
17 |
+
Now you can use the model class by specifying flag '--model dummy'.
|
18 |
+
See our template model class 'template_model.py' for more details.
|
19 |
+
"""
|
20 |
+
import importlib
|
21 |
+
|
22 |
+
from Deep3DFaceRecon_pytorch.models.base_model import BaseModel
|
23 |
+
|
24 |
+
|
25 |
+
def find_model_using_name(model_name):
|
26 |
+
"""Import the module "models/[model_name]_model.py".
|
27 |
+
|
28 |
+
In the file, the class called DatasetNameModel() will
|
29 |
+
be instantiated. It has to be a subclass of BaseModel,
|
30 |
+
and it is case-insensitive.
|
31 |
+
"""
|
32 |
+
model_filename = "models." + model_name + "_model"
|
33 |
+
modellib = importlib.import_module(model_filename)
|
34 |
+
model = None
|
35 |
+
target_model_name = model_name.replace("_", "") + "model"
|
36 |
+
for name, cls in modellib.__dict__.items():
|
37 |
+
if name.lower() == target_model_name.lower() and issubclass(cls, BaseModel):
|
38 |
+
model = cls
|
39 |
+
|
40 |
+
if model is None:
|
41 |
+
print(
|
42 |
+
"In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase."
|
43 |
+
% (model_filename, target_model_name)
|
44 |
+
)
|
45 |
+
exit(0)
|
46 |
+
|
47 |
+
return model
|
48 |
+
|
49 |
+
|
50 |
+
def get_option_setter(model_name):
|
51 |
+
"""Return the static method <modify_commandline_options> of the model class."""
|
52 |
+
model_class = find_model_using_name(model_name)
|
53 |
+
return model_class.modify_commandline_options
|
54 |
+
|
55 |
+
|
56 |
+
def create_model(opt):
|
57 |
+
"""Create a model given the option.
|
58 |
+
|
59 |
+
This function warps the class CustomDatasetDataLoader.
|
60 |
+
This is the main interface between this package and 'train.py'/'test.py'
|
61 |
+
|
62 |
+
Example:
|
63 |
+
>>> from models import create_model
|
64 |
+
>>> model = create_model(opt)
|
65 |
+
"""
|
66 |
+
model = find_model_using_name(opt.model)
|
67 |
+
instance = model(opt)
|
68 |
+
print("model [%s] was created" % type(instance).__name__)
|
69 |
+
return instance
|
Deep3DFaceRecon_pytorch/models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (3.28 kB). View file
|
|
Deep3DFaceRecon_pytorch/models/__pycache__/base_model.cpython-310.pyc
ADDED
Binary file (12.5 kB). View file
|
|
Deep3DFaceRecon_pytorch/models/__pycache__/bfm.cpython-310.pyc
ADDED
Binary file (9.77 kB). View file
|
|
Deep3DFaceRecon_pytorch/models/__pycache__/networks.cpython-310.pyc
ADDED
Binary file (17.1 kB). View file
|
|
Deep3DFaceRecon_pytorch/models/arcface_torch/README.md
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Distributed Arcface Training in Pytorch
|
2 |
+
|
3 |
+
The "arcface_torch" repository is the official implementation of the ArcFace algorithm. It supports distributed and sparse training with multiple distributed training examples, including several memory-saving techniques such as mixed precision training and gradient checkpointing. It also supports training for ViT models and datasets including WebFace42M and Glint360K, two of the largest open-source datasets. Additionally, the repository comes with a built-in tool for converting to ONNX format, making it easy to submit to MFR evaluation systems.
|
4 |
+
|
5 |
+
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/killing-two-birds-with-one-stone-efficient/face-verification-on-ijb-c)](https://paperswithcode.com/sota/face-verification-on-ijb-c?p=killing-two-birds-with-one-stone-efficient)
|
6 |
+
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/killing-two-birds-with-one-stone-efficient/face-verification-on-ijb-b)](https://paperswithcode.com/sota/face-verification-on-ijb-b?p=killing-two-birds-with-one-stone-efficient)
|
7 |
+
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/killing-two-birds-with-one-stone-efficient/face-verification-on-agedb-30)](https://paperswithcode.com/sota/face-verification-on-agedb-30?p=killing-two-birds-with-one-stone-efficient)
|
8 |
+
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/killing-two-birds-with-one-stone-efficient/face-verification-on-cfp-fp)](https://paperswithcode.com/sota/face-verification-on-cfp-fp?p=killing-two-birds-with-one-stone-efficient)
|
9 |
+
|
10 |
+
## Requirements
|
11 |
+
|
12 |
+
To avail the latest features of PyTorch, we have upgraded to version 1.12.0.
|
13 |
+
|
14 |
+
- Install [PyTorch](https://pytorch.org/get-started/previous-versions/) (torch>=1.12.0).
|
15 |
+
- (Optional) Install [DALI](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/), our doc for [install_dali.md](docs/install_dali.md).
|
16 |
+
- `pip install -r requirement.txt`.
|
17 |
+
|
18 |
+
## How to Training
|
19 |
+
|
20 |
+
To train a model, execute the `train.py` script with the path to the configuration files. The sample commands provided below demonstrate the process of conducting distributed training.
|
21 |
+
|
22 |
+
### 1. To run on one GPU:
|
23 |
+
|
24 |
+
```shell
|
25 |
+
python train_v2.py configs/ms1mv3_r50_onegpu
|
26 |
+
```
|
27 |
+
|
28 |
+
Note:
|
29 |
+
It is not recommended to use a single GPU for training, as this may result in longer training times and suboptimal performance. For best results, we suggest using multiple GPUs or a GPU cluster.
|
30 |
+
|
31 |
+
|
32 |
+
### 2. To run on a machine with 8 GPUs:
|
33 |
+
|
34 |
+
```shell
|
35 |
+
torchrun --nproc_per_node=8 train.py configs/ms1mv3_r50
|
36 |
+
```
|
37 |
+
|
38 |
+
### 3. To run on 2 machines with 8 GPUs each:
|
39 |
+
|
40 |
+
Node 0:
|
41 |
+
|
42 |
+
```shell
|
43 |
+
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=12581 train.py configs/wf42m_pfc02_16gpus_r100
|
44 |
+
```
|
45 |
+
|
46 |
+
Node 1:
|
47 |
+
|
48 |
+
```shell
|
49 |
+
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=12581 train.py configs/wf42m_pfc02_16gpus_r100
|
50 |
+
```
|
51 |
+
|
52 |
+
### 4. Run ViT-B on a machine with a 24k batch size:
|
53 |
+
|
54 |
+
```shell
|
55 |
+
torchrun --nproc_per_node=8 train_v2.py configs/wf42m_pfc03_40epoch_8gpu_vit_b
|
56 |
+
```
|
57 |
+
|
58 |
+
|
59 |
+
## Download Datasets or Prepare Datasets
|
60 |
+
- [MS1MV2](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_#ms1m-arcface-85k-ids58m-images-57) (87k IDs, 5.8M images)
|
61 |
+
- [MS1MV3](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_#ms1m-retinaface) (93k IDs, 5.2M images)
|
62 |
+
- [Glint360K](https://github.com/deepinsight/insightface/tree/master/recognition/partial_fc#4-download) (360k IDs, 17.1M images)
|
63 |
+
- [WebFace42M](docs/prepare_webface42m.md) (2M IDs, 42.5M images)
|
64 |
+
- [Your Dataset, Click Here!](docs/prepare_custom_dataset.md)
|
65 |
+
|
66 |
+
Note:
|
67 |
+
If you want to use DALI for data reading, please use the script `scripts/shuffle_rec.py` to shuffle the InsightFace-style rec file before using it.
|
68 |
+
Example:
|
69 |
+
|
70 |
+
`python scripts/shuffle_rec.py ms1m-retinaface-t1`
|
71 |
+
|
72 |
+
You will get the "shuffled_ms1m-retinaface-t1" folder, where the samples in the "train.rec" file are shuffled.
|
73 |
+
|
74 |
+
|
75 |
+
## Model Zoo
|
76 |
+
|
77 |
+
- The models are available for non-commercial research purposes only.
|
78 |
+
- All models can be found here.
|
79 |
+
- [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw
|
80 |
+
- [OneDrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)
|
81 |
+
|
82 |
+
### Performance on IJB-C and [**ICCV2021-MFR**](https://github.com/deepinsight/insightface/blob/master/challenges/mfr/README.md)
|
83 |
+
|
84 |
+
The ICCV2021-MFR test set consists of non-celebrities, so we can ensure that it has very little overlap with publicly available face
|
85 |
+
recognition training sets such as MS1M and CASIA, which are mostly collected from online celebrities.
|
86 |
+
As a result, we can evaluate the performance of different algorithms fairly.
|
87 |
+
|
88 |
+
For the **ICCV2021-MFR-ALL** set, TAR is measured on the all-to-all 1:1 protocol, with FAR less than 0.000001 (1e-6). The
|
89 |
+
globalised multi-racial test set contains 242,143 identities and 1,624,305 images.
|
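To make the metric concrete, here is a minimal sketch of how TAR at a fixed FAR can be computed from 1:1 verification scores. This is an illustrative assumption, not the official MFR evaluation code; the `tar_at_far` helper and the synthetic score arrays are invented for the example.

```python
import numpy as np

def tar_at_far(genuine_scores, impostor_scores, far=1e-6):
    # Choose the decision threshold from the impostor (non-match) scores so that
    # at most `far` of them are accepted, then measure how many genuine pairs pass.
    impostor_sorted = np.sort(impostor_scores)[::-1]            # descending
    num_false_accepts = max(int(far * len(impostor_sorted)), 1)
    threshold = impostor_sorted[num_false_accepts - 1]
    tar = float(np.mean(genuine_scores >= threshold))
    return tar, threshold

# Toy usage with synthetic cosine similarities.
rng = np.random.default_rng(0)
genuine = rng.normal(0.6, 0.1, 10_000)
impostor = rng.normal(0.1, 0.1, 1_000_000)
tar, thr = tar_at_far(genuine, impostor, far=1e-6)
print(f"TAR@FAR=1e-6: {tar:.4f} (threshold {thr:.3f})")
```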
90 |
+
|
91 |
+
|
92 |
+
#### 1. Training on Single-Host GPU
|
93 |
+
|
94 |
+
| Datasets | Backbone | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | log |
|
95 |
+
|:---------------|:--------------------|:------------|:------------|:------------|:------------------------------------------------------------------------------------------------------------------------------------|
|
96 |
+
| MS1MV2 | mobilefacenet-0.45G | 62.07 | 93.61 | 90.28 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv2_mbf/training.log) |
|
97 |
+
| MS1MV2 | r50 | 75.13 | 95.97 | 94.07 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv2_r50/training.log) |
|
98 |
+
| MS1MV2 | r100 | 78.12 | 96.37 | 94.27 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv2_r100/training.log) |
|
99 |
+
| MS1MV3 | mobilefacenet-0.45G | 63.78 | 94.23 | 91.33 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_mbf/training.log) |
|
100 |
+
| MS1MV3 | r50 | 79.14 | 96.37 | 94.47 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_r50/training.log) |
|
101 |
+
| MS1MV3 | r100 | 81.97 | 96.85 | 95.02 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_r100/training.log) |
|
102 |
+
| Glint360K | mobilefacenet-0.45G | 70.18 | 95.04 | 92.62 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_mbf/training.log) |
|
103 |
+
| Glint360K | r50 | 86.34 | 97.16 | 95.81 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_r50/training.log) |
|
104 |
+
| Glint360k | r100 | 89.52 | 97.55 | 96.38 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_r100/training.log) |
|
105 |
+
| WF4M | r100 | 89.87 | 97.19 | 95.48 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf4m_r100/training.log) |
|
106 |
+
| WF12M-PFC-0.2 | r100 | 94.75 | 97.60 | 95.90 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf12m_pfc02_r100/training.log) |
|
107 |
+
| WF12M-PFC-0.3 | r100 | 94.71 | 97.64 | 96.01 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf12m_pfc03_r100/training.log) |
|
108 |
+
| WF12M | r100 | 94.69 | 97.59 | 95.97 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf12m_r100/training.log) |
|
109 |
+
| WF42M-PFC-0.2 | r100 | 96.27 | 97.70 | 96.31 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf42m_pfc02_r100/training.log) |
|
110 |
+
| WF42M-PFC-0.2 | ViT-T-1.5G | 92.04 | 97.27 | 95.68 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf42m_pfc02_40epoch_8gpu_vit_t/training.log) |
|
111 |
+
| WF42M-PFC-0.3 | ViT-B-11G | 97.16 | 97.91 | 97.05 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/pfc03_wf42m_vit_b_8gpu/training.log) |
|
112 |
+
|
113 |
+
#### 2. Training on Multi-Host GPU
|
114 |
+
|
115 |
+
| Datasets | Backbone(bs*gpus) | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | Throughput | log |
|
116 |
+
|:-----------------|:------------------|:------------|:------------|:------------|:-----------|:-------------------------------------------------------------------------------------------------------------------------------------------|
|
117 |
+
| WF42M-PFC-0.2 | r50(512*8) | 93.83 | 97.53 | 96.16 | ~5900 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r50_bs4k_pfc02/training.log) |
|
118 |
+
| WF42M-PFC-0.2 | r50(512*16) | 93.96 | 97.46 | 96.12 | ~11000 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r50_lr01_pfc02_bs8k_16gpus/training.log) |
|
119 |
+
| WF42M-PFC-0.2 | r50(128*32) | 94.04 | 97.48 | 95.94 | ~17000 | click me |
|
120 |
+
| WF42M-PFC-0.2 | r100(128*16) | 96.28 | 97.80 | 96.57 | ~5200 | click me |
|
121 |
+
| WF42M-PFC-0.2 | r100(256*16) | 96.69 | 97.85 | 96.63 | ~5200 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r100_bs4k_pfc02/training.log) |
|
122 |
+
| WF42M-PFC-0.0018 | r100(512*32) | 93.08 | 97.51 | 95.88 | ~10000 | click me |
|
123 |
+
| WF42M-PFC-0.2 | r100(128*32) | 96.57 | 97.83 | 96.50 | ~9800 | click me |
|
124 |
+
|
125 |
+
`r100(128*32)` means the backbone is r100, the batch size per GPU is 128, and the number of GPUs is 32.
|
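For example, `r100(128*32)` corresponds to a total batch size of 128 × 32 = 4096 samples per iteration.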
126 |
+
|
127 |
+
|
128 |
+
|
129 |
+
#### 3. ViT For Face Recognition
|
130 |
+
|
131 |
+
| Datasets | Backbone(bs*gpus) | FLOPs | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | Throughput | log |
|
132 |
+
|:--------------|:--------------|:------|:------------|:------------|:------------|:-----------|:-----------------------------------------------------------------------------------------------------------------------------|
|
133 |
+
| WF42M-PFC-0.3 | r18(128*32) | 2.6 | 79.13 | 95.77 | 93.36 | - | click me |
|
134 |
+
| WF42M-PFC-0.3 | r50(128*32) | 6.3 | 94.03 | 97.48 | 95.94 | - | click me |
|
135 |
+
| WF42M-PFC-0.3 | r100(128*32) | 12.1 | 96.69 | 97.82 | 96.45 | - | click me |
|
136 |
+
| WF42M-PFC-0.3 | r200(128*32) | 23.5 | 97.70 | 97.97 | 96.93 | - | click me |
|
137 |
+
| WF42M-PFC-0.3 | VIT-T(384*64) | 1.5 | 92.24 | 97.31 | 95.97 | ~35000 | click me |
|
138 |
+
| WF42M-PFC-0.3 | VIT-S(384*64) | 5.7 | 95.87 | 97.73 | 96.57 | ~25000 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/pfc03_wf42m_vit_s_64gpu/training.log) |
|
139 |
+
| WF42M-PFC-0.3 | VIT-B(384*64) | 11.4 | 97.42 | 97.90 | 97.04 | ~13800 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/pfc03_wf42m_vit_b_64gpu/training.log) |
|
140 |
+
| WF42M-PFC-0.3 | VIT-L(384*64) | 25.3 | 97.85 | 98.00 | 97.23 | ~9406 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/pfc03_wf42m_vit_l_64gpu/training.log) |
|
141 |
+
|
142 |
+
`WF42M` means WebFace42M, and `PFC-0.3` means the negative class-center sampling rate is 0.3.
|
143 |
+
|
144 |
+
#### 4. Noisy Datasets
|
145 |
+
|
146 |
+
| Datasets | Backbone | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | log |
|
147 |
+
|:-------------------------|:---------|:------------|:------------|:------------|:---------|
|
148 |
+
| WF12M-Flip(40%) | r50 | 43.87 | 88.35 | 80.78 | click me |
|
149 |
+
| WF12M-Flip(40%)-PFC-0.1* | r50 | 80.20 | 96.11 | 93.79 | click me |
|
150 |
+
| WF12M-Conflict | r50 | 79.93 | 95.30 | 91.56 | click me |
|
151 |
+
| WF12M-Conflict-PFC-0.3* | r50 | 91.68 | 97.28 | 95.75 | click me |
|
152 |
+
|
153 |
+
`WF12M` means WebFace12M, and `PFC-0.1*` denotes additional abnormal inter-class filtering.
|
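As a rough illustration of the filtering idea, the sketch below masks out negative class centers that are abnormally similar to a sample before computing the softmax, since such centers are likely conflicting or noisy identities. This is a simplified sketch under that assumption, not the repository's exact implementation; the `filter_conflicting_centers` helper and its threshold are invented for the example.

```python
import torch

def filter_conflicting_centers(cosine_logits, labels, threshold=0.4):
    # cosine_logits: (batch, num_classes) cosine similarities to all class centers.
    with torch.no_grad():
        positive_mask = torch.zeros_like(cosine_logits, dtype=torch.bool)
        positive_mask.scatter_(1, labels.view(-1, 1), True)
        # Negative centers that look "too similar" to the sample are suspicious.
        conflict = (cosine_logits > threshold) & ~positive_mask
    # Setting their logits to -inf removes them from the softmax for this sample.
    return cosine_logits.masked_fill(conflict, float("-inf"))
```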
154 |
+
|
155 |
+
|
156 |
+
|
157 |
+
## Speed Benchmark
|
158 |
+
<div><img src="https://github.com/anxiangsir/insightface_arcface_log/blob/master/pfc_exp.png" width = "90%" /></div>
|
159 |
+
|
160 |
+
|
161 |
+
**Arcface-Torch** is an efficient tool for training on large-scale face recognition datasets. When the number of classes in the training set exceeds one million, the partial FC sampling strategy maintains the same accuracy while providing several times faster training and lower GPU memory utilization. Partial FC is a sparse variant of the model-parallel architecture for large-scale face recognition: it uses a sparse softmax that dynamically samples a subset of class centers for each training batch. During each iteration, only a sparse portion of the parameters is updated, leading to a significant reduction in GPU memory requirements and computational demands. With the partial FC approach, it is possible to train on datasets with up to 29 million identities, the largest to date. Furthermore, the partial FC method supports multi-machine distributed training and mixed precision training.
|
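The following is a minimal sketch of the sampling idea, not the repository's `PartialFC` module: for each batch, all positive class centers are kept and a random fraction of the remaining (negative) centers is sampled, so the softmax and its gradients only touch that subset. The `sample_class_centers` helper and the toy shapes are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

def sample_class_centers(weight, labels, sample_rate=0.1):
    # weight: (num_classes, dim) class-center matrix; labels: (batch,) class ids.
    num_classes = weight.size(0)
    positives = torch.unique(labels)
    num_sample = max(int(num_classes * sample_rate), positives.numel())
    score = torch.rand(num_classes, device=weight.device)
    score[positives] = 2.0                            # force positives to be kept
    index = torch.topk(score, k=num_sample)[1].sort()[0]
    sub_labels = torch.searchsorted(index, labels)    # remap labels into the subset
    return weight[index], sub_labels

# Toy usage: logits are computed only against the sampled centers.
embeddings = F.normalize(torch.randn(8, 512), dim=1)
weight = F.normalize(torch.randn(100_000, 512), dim=1)
labels = torch.randint(0, 100_000, (8,))
sub_weight, sub_labels = sample_class_centers(weight, labels, sample_rate=0.1)
logits = embeddings @ sub_weight.t()
loss = F.cross_entropy(logits, sub_labels)
```

In the full system this sampling is combined with model parallelism, so each GPU holds and samples only its own shard of class centers.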
162 |
+
|
163 |
+
|
164 |
+
|
165 |
+
For more details, see
|
166 |
+
[speed_benchmark.md](docs/speed_benchmark.md) in docs.
|
167 |
+
|
168 |
+
> 1. Training Speed of Various Parallel Techniques (Samples per Second) on a Tesla V100 32GB x 8 System (Higher is Better)
|
169 |
+
|
170 |
+
`-` means training failed because of GPU memory limitations.
|
171 |
+
|
172 |
+
| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
|
173 |
+
|:--------------------------------|:--------------|:---------------|:---------------|
|
174 |
+
| 125000 | 4681 | 4824 | 5004 |
|
175 |
+
| 1400000 | **1672** | 3043 | 4738 |
|
176 |
+
| 5500000 | **-** | **1389** | 3975 |
|
177 |
+
| 8000000 | **-** | **-** | 3565 |
|
178 |
+
| 16000000 | **-** | **-** | 2679 |
|
179 |
+
| 29000000 | **-** | **-** | **1855** |
|
180 |
+
|
181 |
+
> 2. GPU Memory Utilization of Various Parallel Techniques (MB per GPU) on a Tesla V100 32GB x 8 System (Lower is Better)
|
182 |
+
|
183 |
+
| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
|
184 |
+
|:--------------------------------|:--------------|:---------------|:---------------|
|
185 |
+
| 125000 | 7358 | 5306 | 4868 |
|
186 |
+
| 1400000 | 32252 | 11178 | 6056 |
|
187 |
+
| 5500000 | **-** | 32188 | 9854 |
|
188 |
+
| 8000000 | **-** | **-** | 12310 |
|
189 |
+
| 16000000 | **-** | **-** | 19950 |
|
190 |
+
| 29000000 | **-** | **-** | 32324 |
|
191 |
+
|
192 |
+
|
193 |
+
## Citations
|
194 |
+
|
195 |
+
```
|
196 |
+
@inproceedings{deng2019arcface,
|
197 |
+
title={Arcface: Additive angular margin loss for deep face recognition},
|
198 |
+
author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos},
|
199 |
+
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
|
200 |
+
pages={4690--4699},
|
201 |
+
year={2019}
|
202 |
+
}
|
203 |
+
@inproceedings{An_2022_CVPR,
|
204 |
+
author={An, Xiang and Deng, Jiankang and Guo, Jia and Feng, Ziyong and Zhu, XuHan and Yang, Jing and Liu, Tongliang},
|
205 |
+
title={Killing Two Birds With One Stone: Efficient and Robust Training of Face Recognition CNNs by Partial FC},
|
206 |
+
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
207 |
+
month={June},
|
208 |
+
year={2022},
|
209 |
+
pages={4042-4051}
|
210 |
+
}
|
211 |
+
@inproceedings{zhu2021webface260m,
|
212 |
+
title={Webface260m: A benchmark unveiling the power of million-scale deep face recognition},
|
213 |
+
author={Zhu, Zheng and Huang, Guan and Deng, Jiankang and Ye, Yun and Huang, Junjie and Chen, Xinze and Zhu, Jiagang and Yang, Tian and Lu, Jiwen and Du, Dalong and Zhou, Jie},
|
214 |
+
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
|
215 |
+
pages={10492--10502},
|
216 |
+
year={2021}
|
217 |
+
}
|
218 |
+
```
|
Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/__init__.py
ADDED
@@ -0,0 +1,157 @@
1 |
+
from .iresnet import iresnet100
|
2 |
+
from .iresnet import iresnet18
|
3 |
+
from .iresnet import iresnet200
|
4 |
+
from .iresnet import iresnet34
|
5 |
+
from .iresnet import iresnet50
|
6 |
+
from .mobilefacenet import get_mbf
|
7 |
+
|
8 |
+
|
9 |
+
def get_model(name, **kwargs):
|
10 |
+
# resnet
|
11 |
+
if name == "r18":
|
12 |
+
return iresnet18(False, **kwargs)
|
13 |
+
elif name == "r34":
|
14 |
+
return iresnet34(False, **kwargs)
|
15 |
+
elif name == "r50":
|
16 |
+
return iresnet50(False, **kwargs)
|
17 |
+
elif name == "r100":
|
18 |
+
return iresnet100(False, **kwargs)
|
19 |
+
elif name == "r200":
|
20 |
+
return iresnet200(False, **kwargs)
|
21 |
+
elif name == "r2060":
|
22 |
+
from .iresnet2060 import iresnet2060
|
23 |
+
|
24 |
+
return iresnet2060(False, **kwargs)
|
25 |
+
|
26 |
+
elif name == "mbf":
|
27 |
+
fp16 = kwargs.get("fp16", False)
|
28 |
+
num_features = kwargs.get("num_features", 512)
|
29 |
+
return get_mbf(fp16=fp16, num_features=num_features)
|
30 |
+
|
31 |
+
elif name == "mbf_large":
|
32 |
+
from .mobilefacenet import get_mbf_large
|
33 |
+
|
34 |
+
fp16 = kwargs.get("fp16", False)
|
35 |
+
num_features = kwargs.get("num_features", 512)
|
36 |
+
return get_mbf_large(fp16=fp16, num_features=num_features)
|
37 |
+
|
38 |
+
elif name == "vit_t":
|
39 |
+
num_features = kwargs.get("num_features", 512)
|
40 |
+
from .vit import VisionTransformer
|
41 |
+
|
42 |
+
return VisionTransformer(
|
43 |
+
img_size=112,
|
44 |
+
patch_size=9,
|
45 |
+
num_classes=num_features,
|
46 |
+
embed_dim=256,
|
47 |
+
depth=12,
|
48 |
+
num_heads=8,
|
49 |
+
drop_path_rate=0.1,
|
50 |
+
norm_layer="ln",
|
51 |
+
mask_ratio=0.1,
|
52 |
+
)
|
53 |
+
|
54 |
+
elif name == "vit_t_dp005_mask0": # For WebFace42M
|
55 |
+
num_features = kwargs.get("num_features", 512)
|
56 |
+
from .vit import VisionTransformer
|
57 |
+
|
58 |
+
return VisionTransformer(
|
59 |
+
img_size=112,
|
60 |
+
patch_size=9,
|
61 |
+
num_classes=num_features,
|
62 |
+
embed_dim=256,
|
63 |
+
depth=12,
|
64 |
+
num_heads=8,
|
65 |
+
drop_path_rate=0.05,
|
66 |
+
norm_layer="ln",
|
67 |
+
mask_ratio=0.0,
|
68 |
+
)
|
69 |
+
|
70 |
+
elif name == "vit_s":
|
71 |
+
num_features = kwargs.get("num_features", 512)
|
72 |
+
from .vit import VisionTransformer
|
73 |
+
|
74 |
+
return VisionTransformer(
|
75 |
+
img_size=112,
|
76 |
+
patch_size=9,
|
77 |
+
num_classes=num_features,
|
78 |
+
embed_dim=512,
|
79 |
+
depth=12,
|
80 |
+
num_heads=8,
|
81 |
+
drop_path_rate=0.1,
|
82 |
+
norm_layer="ln",
|
83 |
+
mask_ratio=0.1,
|
84 |
+
)
|
85 |
+
|
86 |
+
elif name == "vit_s_dp005_mask_0": # For WebFace42M
|
87 |
+
num_features = kwargs.get("num_features", 512)
|
88 |
+
from .vit import VisionTransformer
|
89 |
+
|
90 |
+
return VisionTransformer(
|
91 |
+
img_size=112,
|
92 |
+
patch_size=9,
|
93 |
+
num_classes=num_features,
|
94 |
+
embed_dim=512,
|
95 |
+
depth=12,
|
96 |
+
num_heads=8,
|
97 |
+
drop_path_rate=0.05,
|
98 |
+
norm_layer="ln",
|
99 |
+
mask_ratio=0.0,
|
100 |
+
)
|
101 |
+
|
102 |
+
elif name == "vit_b":
|
103 |
+
# this is a feature
|
104 |
+
num_features = kwargs.get("num_features", 512)
|
105 |
+
from .vit import VisionTransformer
|
106 |
+
|
107 |
+
return VisionTransformer(
|
108 |
+
img_size=112,
|
109 |
+
patch_size=9,
|
110 |
+
num_classes=num_features,
|
111 |
+
embed_dim=512,
|
112 |
+
depth=24,
|
113 |
+
num_heads=8,
|
114 |
+
drop_path_rate=0.1,
|
115 |
+
norm_layer="ln",
|
116 |
+
mask_ratio=0.1,
|
117 |
+
using_checkpoint=True,
|
118 |
+
)
|
119 |
+
|
120 |
+
elif name == "vit_b_dp005_mask_005": # For WebFace42M
|
121 |
+
# this is a feature
|
122 |
+
num_features = kwargs.get("num_features", 512)
|
123 |
+
from .vit import VisionTransformer
|
124 |
+
|
125 |
+
return VisionTransformer(
|
126 |
+
img_size=112,
|
127 |
+
patch_size=9,
|
128 |
+
num_classes=num_features,
|
129 |
+
embed_dim=512,
|
130 |
+
depth=24,
|
131 |
+
num_heads=8,
|
132 |
+
drop_path_rate=0.05,
|
133 |
+
norm_layer="ln",
|
134 |
+
mask_ratio=0.05,
|
135 |
+
using_checkpoint=True,
|
136 |
+
)
|
137 |
+
|
138 |
+
elif name == "vit_l_dp005_mask_005": # For WebFace42M
|
139 |
+
# this is a feature
|
140 |
+
num_features = kwargs.get("num_features", 512)
|
141 |
+
from .vit import VisionTransformer
|
142 |
+
|
143 |
+
return VisionTransformer(
|
144 |
+
img_size=112,
|
145 |
+
patch_size=9,
|
146 |
+
num_classes=num_features,
|
147 |
+
embed_dim=768,
|
148 |
+
depth=24,
|
149 |
+
num_heads=8,
|
150 |
+
drop_path_rate=0.05,
|
151 |
+
norm_layer="ln",
|
152 |
+
mask_ratio=0.05,
|
153 |
+
using_checkpoint=True,
|
154 |
+
)
|
155 |
+
|
156 |
+
else:
|
157 |
+
raise ValueError()
|
Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (2.02 kB).
|
|
Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/__pycache__/iresnet.cpython-310.pyc
ADDED
Binary file (5.62 kB).
|
|
Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-310.pyc
ADDED
Binary file (5.96 kB).
|
|
Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/iresnet.py
ADDED
@@ -0,0 +1,198 @@
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from torch.utils.checkpoint import checkpoint
|
4 |
+
|
5 |
+
__all__ = ["iresnet18", "iresnet34", "iresnet50", "iresnet100", "iresnet200"]
|
6 |
+
using_ckpt = False
|
7 |
+
|
8 |
+
|
9 |
+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
10 |
+
"""3x3 convolution with padding"""
|
11 |
+
return nn.Conv2d(
|
12 |
+
in_planes,
|
13 |
+
out_planes,
|
14 |
+
kernel_size=3,
|
15 |
+
stride=stride,
|
16 |
+
padding=dilation,
|
17 |
+
groups=groups,
|
18 |
+
bias=False,
|
19 |
+
dilation=dilation,
|
20 |
+
)
|
21 |
+
|
22 |
+
|
23 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
24 |
+
"""1x1 convolution"""
|
25 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
26 |
+
|
27 |
+
|
28 |
+
class IBasicBlock(nn.Module):
|
29 |
+
expansion = 1
|
30 |
+
|
31 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1):
|
32 |
+
super(IBasicBlock, self).__init__()
|
33 |
+
if groups != 1 or base_width != 64:
|
34 |
+
raise ValueError("BasicBlock only supports groups=1 and base_width=64")
|
35 |
+
if dilation > 1:
|
36 |
+
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
37 |
+
self.bn1 = nn.BatchNorm2d(
|
38 |
+
inplanes,
|
39 |
+
eps=1e-05,
|
40 |
+
)
|
41 |
+
self.conv1 = conv3x3(inplanes, planes)
|
42 |
+
self.bn2 = nn.BatchNorm2d(
|
43 |
+
planes,
|
44 |
+
eps=1e-05,
|
45 |
+
)
|
46 |
+
self.prelu = nn.PReLU(planes)
|
47 |
+
self.conv2 = conv3x3(planes, planes, stride)
|
48 |
+
self.bn3 = nn.BatchNorm2d(
|
49 |
+
planes,
|
50 |
+
eps=1e-05,
|
51 |
+
)
|
52 |
+
self.downsample = downsample
|
53 |
+
self.stride = stride
|
54 |
+
|
55 |
+
def forward_impl(self, x):
|
56 |
+
identity = x
|
57 |
+
out = self.bn1(x)
|
58 |
+
out = self.conv1(out)
|
59 |
+
out = self.bn2(out)
|
60 |
+
out = self.prelu(out)
|
61 |
+
out = self.conv2(out)
|
62 |
+
out = self.bn3(out)
|
63 |
+
if self.downsample is not None:
|
64 |
+
identity = self.downsample(x)
|
65 |
+
out += identity
|
66 |
+
return out
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
if self.training and using_ckpt:
|
70 |
+
return checkpoint(self.forward_impl, x)
|
71 |
+
else:
|
72 |
+
return self.forward_impl(x)
|
73 |
+
|
74 |
+
|
75 |
+
class IResNet(nn.Module):
|
76 |
+
fc_scale = 7 * 7
|
77 |
+
|
78 |
+
def __init__(
|
79 |
+
self,
|
80 |
+
block,
|
81 |
+
layers,
|
82 |
+
dropout=0,
|
83 |
+
num_features=512,
|
84 |
+
zero_init_residual=False,
|
85 |
+
groups=1,
|
86 |
+
width_per_group=64,
|
87 |
+
replace_stride_with_dilation=None,
|
88 |
+
fp16=False,
|
89 |
+
):
|
90 |
+
super(IResNet, self).__init__()
|
91 |
+
self.extra_gflops = 0.0
|
92 |
+
self.fp16 = fp16
|
93 |
+
self.inplanes = 64
|
94 |
+
self.dilation = 1
|
95 |
+
if replace_stride_with_dilation is None:
|
96 |
+
replace_stride_with_dilation = [False, False, False]
|
97 |
+
if len(replace_stride_with_dilation) != 3:
|
98 |
+
raise ValueError(
|
99 |
+
"replace_stride_with_dilation should be None "
|
100 |
+
"or a 3-element tuple, got {}".format(replace_stride_with_dilation)
|
101 |
+
)
|
102 |
+
self.groups = groups
|
103 |
+
self.base_width = width_per_group
|
104 |
+
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
|
105 |
+
self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
|
106 |
+
self.prelu = nn.PReLU(self.inplanes)
|
107 |
+
self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
|
108 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
|
109 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
|
110 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
|
111 |
+
self.bn2 = nn.BatchNorm2d(
|
112 |
+
512 * block.expansion,
|
113 |
+
eps=1e-05,
|
114 |
+
)
|
115 |
+
self.dropout = nn.Dropout(p=dropout, inplace=True)
|
116 |
+
self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
|
117 |
+
self.features = nn.BatchNorm1d(num_features, eps=1e-05)
|
118 |
+
nn.init.constant_(self.features.weight, 1.0)
|
119 |
+
self.features.weight.requires_grad = False
|
120 |
+
|
121 |
+
for m in self.modules():
|
122 |
+
if isinstance(m, nn.Conv2d):
|
123 |
+
nn.init.normal_(m.weight, 0, 0.1)
|
124 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
125 |
+
nn.init.constant_(m.weight, 1)
|
126 |
+
nn.init.constant_(m.bias, 0)
|
127 |
+
|
128 |
+
if zero_init_residual:
|
129 |
+
for m in self.modules():
|
130 |
+
if isinstance(m, IBasicBlock):
|
131 |
+
nn.init.constant_(m.bn2.weight, 0)
|
132 |
+
|
133 |
+
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
134 |
+
downsample = None
|
135 |
+
previous_dilation = self.dilation
|
136 |
+
if dilate:
|
137 |
+
self.dilation *= stride
|
138 |
+
stride = 1
|
139 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
140 |
+
downsample = nn.Sequential(
|
141 |
+
conv1x1(self.inplanes, planes * block.expansion, stride),
|
142 |
+
nn.BatchNorm2d(
|
143 |
+
planes * block.expansion,
|
144 |
+
eps=1e-05,
|
145 |
+
),
|
146 |
+
)
|
147 |
+
layers = []
|
148 |
+
layers.append(block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation))
|
149 |
+
self.inplanes = planes * block.expansion
|
150 |
+
for _ in range(1, blocks):
|
151 |
+
layers.append(
|
152 |
+
block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=self.dilation)
|
153 |
+
)
|
154 |
+
|
155 |
+
return nn.Sequential(*layers)
|
156 |
+
|
157 |
+
def forward(self, x):
|
158 |
+
with torch.cuda.amp.autocast(self.fp16):
|
159 |
+
x = self.conv1(x)
|
160 |
+
x = self.bn1(x)
|
161 |
+
x = self.prelu(x)
|
162 |
+
x = self.layer1(x)
|
163 |
+
x = self.layer2(x)
|
164 |
+
x = self.layer3(x)
|
165 |
+
x = self.layer4(x)
|
166 |
+
x = self.bn2(x)
|
167 |
+
x = torch.flatten(x, 1)
|
168 |
+
x = self.dropout(x)
|
169 |
+
x = self.fc(x.float() if self.fp16 else x)
|
170 |
+
x = self.features(x)
|
171 |
+
return x
|
172 |
+
|
173 |
+
|
174 |
+
def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
|
175 |
+
model = IResNet(block, layers, **kwargs)
|
176 |
+
if pretrained:
|
177 |
+
raise ValueError()
|
178 |
+
return model
|
179 |
+
|
180 |
+
|
181 |
+
def iresnet18(pretrained=False, progress=True, **kwargs):
|
182 |
+
return _iresnet("iresnet18", IBasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
|
183 |
+
|
184 |
+
|
185 |
+
def iresnet34(pretrained=False, progress=True, **kwargs):
|
186 |
+
return _iresnet("iresnet34", IBasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)
|
187 |
+
|
188 |
+
|
189 |
+
def iresnet50(pretrained=False, progress=True, **kwargs):
|
190 |
+
return _iresnet("iresnet50", IBasicBlock, [3, 4, 14, 3], pretrained, progress, **kwargs)
|
191 |
+
|
192 |
+
|
193 |
+
def iresnet100(pretrained=False, progress=True, **kwargs):
|
194 |
+
return _iresnet("iresnet100", IBasicBlock, [3, 13, 30, 3], pretrained, progress, **kwargs)
|
195 |
+
|
196 |
+
|
197 |
+
def iresnet200(pretrained=False, progress=True, **kwargs):
|
198 |
+
return _iresnet("iresnet200", IBasicBlock, [6, 26, 60, 6], pretrained, progress, **kwargs)
|
Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/iresnet2060.py
ADDED
@@ -0,0 +1,182 @@
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
assert torch.__version__ >= "1.8.1"
|
5 |
+
from torch.utils.checkpoint import checkpoint_sequential
|
6 |
+
|
7 |
+
__all__ = ["iresnet2060"]
|
8 |
+
|
9 |
+
|
10 |
+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
11 |
+
"""3x3 convolution with padding"""
|
12 |
+
return nn.Conv2d(
|
13 |
+
in_planes,
|
14 |
+
out_planes,
|
15 |
+
kernel_size=3,
|
16 |
+
stride=stride,
|
17 |
+
padding=dilation,
|
18 |
+
groups=groups,
|
19 |
+
bias=False,
|
20 |
+
dilation=dilation,
|
21 |
+
)
|
22 |
+
|
23 |
+
|
24 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
25 |
+
"""1x1 convolution"""
|
26 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
27 |
+
|
28 |
+
|
29 |
+
class IBasicBlock(nn.Module):
|
30 |
+
expansion = 1
|
31 |
+
|
32 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1):
|
33 |
+
super(IBasicBlock, self).__init__()
|
34 |
+
if groups != 1 or base_width != 64:
|
35 |
+
raise ValueError("BasicBlock only supports groups=1 and base_width=64")
|
36 |
+
if dilation > 1:
|
37 |
+
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
38 |
+
self.bn1 = nn.BatchNorm2d(
|
39 |
+
inplanes,
|
40 |
+
eps=1e-05,
|
41 |
+
)
|
42 |
+
self.conv1 = conv3x3(inplanes, planes)
|
43 |
+
self.bn2 = nn.BatchNorm2d(
|
44 |
+
planes,
|
45 |
+
eps=1e-05,
|
46 |
+
)
|
47 |
+
self.prelu = nn.PReLU(planes)
|
48 |
+
self.conv2 = conv3x3(planes, planes, stride)
|
49 |
+
self.bn3 = nn.BatchNorm2d(
|
50 |
+
planes,
|
51 |
+
eps=1e-05,
|
52 |
+
)
|
53 |
+
self.downsample = downsample
|
54 |
+
self.stride = stride
|
55 |
+
|
56 |
+
def forward(self, x):
|
57 |
+
identity = x
|
58 |
+
out = self.bn1(x)
|
59 |
+
out = self.conv1(out)
|
60 |
+
out = self.bn2(out)
|
61 |
+
out = self.prelu(out)
|
62 |
+
out = self.conv2(out)
|
63 |
+
out = self.bn3(out)
|
64 |
+
if self.downsample is not None:
|
65 |
+
identity = self.downsample(x)
|
66 |
+
out += identity
|
67 |
+
return out
|
68 |
+
|
69 |
+
|
70 |
+
class IResNet(nn.Module):
|
71 |
+
fc_scale = 7 * 7
|
72 |
+
|
73 |
+
def __init__(
|
74 |
+
self,
|
75 |
+
block,
|
76 |
+
layers,
|
77 |
+
dropout=0,
|
78 |
+
num_features=512,
|
79 |
+
zero_init_residual=False,
|
80 |
+
groups=1,
|
81 |
+
width_per_group=64,
|
82 |
+
replace_stride_with_dilation=None,
|
83 |
+
fp16=False,
|
84 |
+
):
|
85 |
+
super(IResNet, self).__init__()
|
86 |
+
self.fp16 = fp16
|
87 |
+
self.inplanes = 64
|
88 |
+
self.dilation = 1
|
89 |
+
if replace_stride_with_dilation is None:
|
90 |
+
replace_stride_with_dilation = [False, False, False]
|
91 |
+
if len(replace_stride_with_dilation) != 3:
|
92 |
+
raise ValueError(
|
93 |
+
"replace_stride_with_dilation should be None "
|
94 |
+
"or a 3-element tuple, got {}".format(replace_stride_with_dilation)
|
95 |
+
)
|
96 |
+
self.groups = groups
|
97 |
+
self.base_width = width_per_group
|
98 |
+
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
|
99 |
+
self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
|
100 |
+
self.prelu = nn.PReLU(self.inplanes)
|
101 |
+
self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
|
102 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
|
103 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
|
104 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
|
105 |
+
self.bn2 = nn.BatchNorm2d(
|
106 |
+
512 * block.expansion,
|
107 |
+
eps=1e-05,
|
108 |
+
)
|
109 |
+
self.dropout = nn.Dropout(p=dropout, inplace=True)
|
110 |
+
self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
|
111 |
+
self.features = nn.BatchNorm1d(num_features, eps=1e-05)
|
112 |
+
nn.init.constant_(self.features.weight, 1.0)
|
113 |
+
self.features.weight.requires_grad = False
|
114 |
+
|
115 |
+
for m in self.modules():
|
116 |
+
if isinstance(m, nn.Conv2d):
|
117 |
+
nn.init.normal_(m.weight, 0, 0.1)
|
118 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
119 |
+
nn.init.constant_(m.weight, 1)
|
120 |
+
nn.init.constant_(m.bias, 0)
|
121 |
+
|
122 |
+
if zero_init_residual:
|
123 |
+
for m in self.modules():
|
124 |
+
if isinstance(m, IBasicBlock):
|
125 |
+
nn.init.constant_(m.bn2.weight, 0)
|
126 |
+
|
127 |
+
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
128 |
+
downsample = None
|
129 |
+
previous_dilation = self.dilation
|
130 |
+
if dilate:
|
131 |
+
self.dilation *= stride
|
132 |
+
stride = 1
|
133 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
134 |
+
downsample = nn.Sequential(
|
135 |
+
conv1x1(self.inplanes, planes * block.expansion, stride),
|
136 |
+
nn.BatchNorm2d(
|
137 |
+
planes * block.expansion,
|
138 |
+
eps=1e-05,
|
139 |
+
),
|
140 |
+
)
|
141 |
+
layers = []
|
142 |
+
layers.append(block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation))
|
143 |
+
self.inplanes = planes * block.expansion
|
144 |
+
for _ in range(1, blocks):
|
145 |
+
layers.append(
|
146 |
+
block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=self.dilation)
|
147 |
+
)
|
148 |
+
|
149 |
+
return nn.Sequential(*layers)
|
150 |
+
|
151 |
+
def checkpoint(self, func, num_seg, x):
|
152 |
+
if self.training:
|
153 |
+
return checkpoint_sequential(func, num_seg, x)
|
154 |
+
else:
|
155 |
+
return func(x)
|
156 |
+
|
157 |
+
def forward(self, x):
|
158 |
+
with torch.cuda.amp.autocast(self.fp16):
|
159 |
+
x = self.conv1(x)
|
160 |
+
x = self.bn1(x)
|
161 |
+
x = self.prelu(x)
|
162 |
+
x = self.layer1(x)
|
163 |
+
x = self.checkpoint(self.layer2, 20, x)
|
164 |
+
x = self.checkpoint(self.layer3, 100, x)
|
165 |
+
x = self.layer4(x)
|
166 |
+
x = self.bn2(x)
|
167 |
+
x = torch.flatten(x, 1)
|
168 |
+
x = self.dropout(x)
|
169 |
+
x = self.fc(x.float() if self.fp16 else x)
|
170 |
+
x = self.features(x)
|
171 |
+
return x
|
172 |
+
|
173 |
+
|
174 |
+
def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
|
175 |
+
model = IResNet(block, layers, **kwargs)
|
176 |
+
if pretrained:
|
177 |
+
raise ValueError()
|
178 |
+
return model
|
179 |
+
|
180 |
+
|
181 |
+
def iresnet2060(pretrained=False, progress=True, **kwargs):
|
182 |
+
return _iresnet("iresnet2060", IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs)
|
Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/mobilefacenet.py
ADDED
@@ -0,0 +1,160 @@
1 |
+
"""
|
2 |
+
Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py
|
3 |
+
Original author cavalleria
|
4 |
+
"""
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from torch.nn import BatchNorm1d
|
8 |
+
from torch.nn import BatchNorm2d
|
9 |
+
from torch.nn import Conv2d
|
10 |
+
from torch.nn import Linear
|
11 |
+
from torch.nn import Module
|
12 |
+
from torch.nn import PReLU
|
13 |
+
from torch.nn import Sequential
|
14 |
+
|
15 |
+
|
16 |
+
class Flatten(Module):
|
17 |
+
def forward(self, x):
|
18 |
+
return x.view(x.size(0), -1)
|
19 |
+
|
20 |
+
|
21 |
+
class ConvBlock(Module):
|
22 |
+
def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
|
23 |
+
super(ConvBlock, self).__init__()
|
24 |
+
self.layers = nn.Sequential(
|
25 |
+
Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False),
|
26 |
+
BatchNorm2d(num_features=out_c),
|
27 |
+
PReLU(num_parameters=out_c),
|
28 |
+
)
|
29 |
+
|
30 |
+
def forward(self, x):
|
31 |
+
return self.layers(x)
|
32 |
+
|
33 |
+
|
34 |
+
class LinearBlock(Module):
|
35 |
+
def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
|
36 |
+
super(LinearBlock, self).__init__()
|
37 |
+
self.layers = nn.Sequential(
|
38 |
+
Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False), BatchNorm2d(num_features=out_c)
|
39 |
+
)
|
40 |
+
|
41 |
+
def forward(self, x):
|
42 |
+
return self.layers(x)
|
43 |
+
|
44 |
+
|
45 |
+
class DepthWise(Module):
|
46 |
+
def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
|
47 |
+
super(DepthWise, self).__init__()
|
48 |
+
self.residual = residual
|
49 |
+
self.layers = nn.Sequential(
|
50 |
+
ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),
|
51 |
+
ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride),
|
52 |
+
LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),
|
53 |
+
)
|
54 |
+
|
55 |
+
def forward(self, x):
|
56 |
+
short_cut = None
|
57 |
+
if self.residual:
|
58 |
+
short_cut = x
|
59 |
+
x = self.layers(x)
|
60 |
+
if self.residual:
|
61 |
+
output = short_cut + x
|
62 |
+
else:
|
63 |
+
output = x
|
64 |
+
return output
|
65 |
+
|
66 |
+
|
67 |
+
class Residual(Module):
|
68 |
+
def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
|
69 |
+
super(Residual, self).__init__()
|
70 |
+
modules = []
|
71 |
+
for _ in range(num_block):
|
72 |
+
modules.append(DepthWise(c, c, True, kernel, stride, padding, groups))
|
73 |
+
self.layers = Sequential(*modules)
|
74 |
+
|
75 |
+
def forward(self, x):
|
76 |
+
return self.layers(x)
|
77 |
+
|
78 |
+
|
79 |
+
class GDC(Module):
|
80 |
+
def __init__(self, embedding_size):
|
81 |
+
super(GDC, self).__init__()
|
82 |
+
self.layers = nn.Sequential(
|
83 |
+
LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)),
|
84 |
+
Flatten(),
|
85 |
+
Linear(512, embedding_size, bias=False),
|
86 |
+
BatchNorm1d(embedding_size),
|
87 |
+
)
|
88 |
+
|
89 |
+
def forward(self, x):
|
90 |
+
return self.layers(x)
|
91 |
+
|
92 |
+
|
93 |
+
class MobileFaceNet(Module):
|
94 |
+
def __init__(self, fp16=False, num_features=512, blocks=(1, 4, 6, 2), scale=2):
|
95 |
+
super(MobileFaceNet, self).__init__()
|
96 |
+
self.scale = scale
|
97 |
+
self.fp16 = fp16
|
98 |
+
self.layers = nn.ModuleList()
|
99 |
+
self.layers.append(ConvBlock(3, 64 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)))
|
100 |
+
if blocks[0] == 1:
|
101 |
+
self.layers.append(
|
102 |
+
ConvBlock(64 * self.scale, 64 * self.scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
|
103 |
+
)
|
104 |
+
else:
|
105 |
+
self.layers.append(
|
106 |
+
Residual(
|
107 |
+
64 * self.scale, num_block=blocks[0], groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)
|
108 |
+
),
|
109 |
+
)
|
110 |
+
|
111 |
+
self.layers.extend(
|
112 |
+
[
|
113 |
+
DepthWise(64 * self.scale, 64 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128),
|
114 |
+
Residual(
|
115 |
+
64 * self.scale, num_block=blocks[1], groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)
|
116 |
+
),
|
117 |
+
DepthWise(64 * self.scale, 128 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256),
|
118 |
+
Residual(
|
119 |
+
128 * self.scale, num_block=blocks[2], groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)
|
120 |
+
),
|
121 |
+
DepthWise(128 * self.scale, 128 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512),
|
122 |
+
Residual(
|
123 |
+
128 * self.scale, num_block=blocks[3], groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)
|
124 |
+
),
|
125 |
+
]
|
126 |
+
)
|
127 |
+
|
128 |
+
self.conv_sep = ConvBlock(128 * self.scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
|
129 |
+
self.features = GDC(num_features)
|
130 |
+
self._initialize_weights()
|
131 |
+
|
132 |
+
def _initialize_weights(self):
|
133 |
+
for m in self.modules():
|
134 |
+
if isinstance(m, nn.Conv2d):
|
135 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
136 |
+
if m.bias is not None:
|
137 |
+
m.bias.data.zero_()
|
138 |
+
elif isinstance(m, nn.BatchNorm2d):
|
139 |
+
m.weight.data.fill_(1)
|
140 |
+
m.bias.data.zero_()
|
141 |
+
elif isinstance(m, nn.Linear):
|
142 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
143 |
+
if m.bias is not None:
|
144 |
+
m.bias.data.zero_()
|
145 |
+
|
146 |
+
def forward(self, x):
|
147 |
+
with torch.cuda.amp.autocast(self.fp16):
|
148 |
+
for func in self.layers:
|
149 |
+
x = func(x)
|
150 |
+
x = self.conv_sep(x.float() if self.fp16 else x)
|
151 |
+
x = self.features(x)
|
152 |
+
return x
|
153 |
+
|
154 |
+
|
155 |
+
def get_mbf(fp16, num_features, blocks=(1, 4, 6, 2), scale=2):
|
156 |
+
return MobileFaceNet(fp16, num_features, blocks, scale=scale)
|
157 |
+
|
158 |
+
|
159 |
+
def get_mbf_large(fp16, num_features, blocks=(2, 8, 12, 4), scale=4):
|
160 |
+
return MobileFaceNet(fp16, num_features, blocks, scale=scale)
|
Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/vit.py
ADDED
@@ -0,0 +1,302 @@
1 |
+
from typing import Callable
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from timm.models.layers import DropPath
|
7 |
+
from timm.models.layers import to_2tuple
|
8 |
+
from timm.models.layers import trunc_normal_
|
9 |
+
|
10 |
+
|
11 |
+
class Mlp(nn.Module):
|
12 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.0):
|
13 |
+
super().__init__()
|
14 |
+
out_features = out_features or in_features
|
15 |
+
hidden_features = hidden_features or in_features
|
16 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
17 |
+
self.act = act_layer()
|
18 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
19 |
+
self.drop = nn.Dropout(drop)
|
20 |
+
|
21 |
+
def forward(self, x):
|
22 |
+
x = self.fc1(x)
|
23 |
+
x = self.act(x)
|
24 |
+
x = self.drop(x)
|
25 |
+
x = self.fc2(x)
|
26 |
+
x = self.drop(x)
|
27 |
+
return x
|
28 |
+
|
29 |
+
|
30 |
+
class VITBatchNorm(nn.Module):
|
31 |
+
def __init__(self, num_features):
|
32 |
+
super().__init__()
|
33 |
+
self.num_features = num_features
|
34 |
+
self.bn = nn.BatchNorm1d(num_features=num_features)
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
return self.bn(x)
|
38 |
+
|
39 |
+
|
40 |
+
class Attention(nn.Module):
|
41 |
+
def __init__(
|
42 |
+
self,
|
43 |
+
dim: int,
|
44 |
+
num_heads: int = 8,
|
45 |
+
qkv_bias: bool = False,
|
46 |
+
qk_scale: Optional[None] = None,
|
47 |
+
attn_drop: float = 0.0,
|
48 |
+
proj_drop: float = 0.0,
|
49 |
+
):
|
50 |
+
super().__init__()
|
51 |
+
self.num_heads = num_heads
|
52 |
+
head_dim = dim // num_heads
|
53 |
+
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
54 |
+
self.scale = qk_scale or head_dim**-0.5
|
55 |
+
|
56 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
57 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
58 |
+
self.proj = nn.Linear(dim, dim)
|
59 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
60 |
+
|
61 |
+
def forward(self, x):
|
62 |
+
|
63 |
+
with torch.cuda.amp.autocast(True):
|
64 |
+
batch_size, num_token, embed_dim = x.shape
|
65 |
+
# qkv is [3,batch_size,num_heads,num_token, embed_dim//num_heads]
|
66 |
+
qkv = (
|
67 |
+
self.qkv(x)
|
68 |
+
.reshape(batch_size, num_token, 3, self.num_heads, embed_dim // self.num_heads)
|
69 |
+
.permute(2, 0, 3, 1, 4)
|
70 |
+
)
|
71 |
+
with torch.cuda.amp.autocast(False):
|
72 |
+
q, k, v = qkv[0].float(), qkv[1].float(), qkv[2].float()
|
73 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
74 |
+
attn = attn.softmax(dim=-1)
|
75 |
+
attn = self.attn_drop(attn)
|
76 |
+
x = (attn @ v).transpose(1, 2).reshape(batch_size, num_token, embed_dim)
|
77 |
+
with torch.cuda.amp.autocast(True):
|
78 |
+
x = self.proj(x)
|
79 |
+
x = self.proj_drop(x)
|
80 |
+
return x
|
81 |
+
|
82 |
+
|
83 |
+
class Block(nn.Module):
|
84 |
+
def __init__(
|
85 |
+
self,
|
86 |
+
dim: int,
|
87 |
+
num_heads: int,
|
88 |
+
num_patches: int,
|
89 |
+
mlp_ratio: float = 4.0,
|
90 |
+
qkv_bias: bool = False,
|
91 |
+
qk_scale: Optional[None] = None,
|
92 |
+
drop: float = 0.0,
|
93 |
+
attn_drop: float = 0.0,
|
94 |
+
drop_path: float = 0.0,
|
95 |
+
act_layer: Callable = nn.ReLU6,
|
96 |
+
norm_layer: str = "ln",
|
97 |
+
patch_n: int = 144,
|
98 |
+
):
|
99 |
+
super().__init__()
|
100 |
+
|
101 |
+
if norm_layer == "bn":
|
102 |
+
self.norm1 = VITBatchNorm(num_features=num_patches)
|
103 |
+
self.norm2 = VITBatchNorm(num_features=num_patches)
|
104 |
+
elif norm_layer == "ln":
|
105 |
+
self.norm1 = nn.LayerNorm(dim)
|
106 |
+
self.norm2 = nn.LayerNorm(dim)
|
107 |
+
|
108 |
+
self.attn = Attention(
|
109 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop
|
110 |
+
)
|
111 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
112 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
113 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
114 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
115 |
+
self.extra_gflops = (num_heads * patch_n * (dim // num_heads) * patch_n * 2) / (1000**3)
|
116 |
+
|
117 |
+
def forward(self, x):
|
118 |
+
x = x + self.drop_path(self.attn(self.norm1(x)))
|
119 |
+
with torch.cuda.amp.autocast(True):
|
120 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
121 |
+
return x
|
122 |
+
|
123 |
+
|
124 |
+
class PatchEmbed(nn.Module):
|
125 |
+
def __init__(self, img_size=108, patch_size=9, in_channels=3, embed_dim=768):
|
126 |
+
super().__init__()
|
127 |
+
img_size = to_2tuple(img_size)
|
128 |
+
patch_size = to_2tuple(patch_size)
|
129 |
+
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
130 |
+
self.img_size = img_size
|
131 |
+
self.patch_size = patch_size
|
132 |
+
self.num_patches = num_patches
|
133 |
+
self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
|
134 |
+
|
135 |
+
def forward(self, x):
|
136 |
+
batch_size, channels, height, width = x.shape
|
137 |
+
assert (
|
138 |
+
height == self.img_size[0] and width == self.img_size[1]
|
139 |
+
), f"Input image size ({height}*{width}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
140 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
141 |
+
return x
|
142 |
+
|
143 |
+
|
144 |
+
class VisionTransformer(nn.Module):
|
145 |
+
"""Vision Transformer with support for patch or hybrid CNN input stage"""
|
146 |
+
|
147 |
+
def __init__(
|
148 |
+
self,
|
149 |
+
img_size: int = 112,
|
150 |
+
patch_size: int = 16,
|
151 |
+
in_channels: int = 3,
|
152 |
+
num_classes: int = 1000,
|
153 |
+
embed_dim: int = 768,
|
154 |
+
depth: int = 12,
|
155 |
+
num_heads: int = 12,
|
156 |
+
mlp_ratio: float = 4.0,
|
157 |
+
qkv_bias: bool = False,
|
158 |
+
qk_scale: Optional[None] = None,
|
159 |
+
drop_rate: float = 0.0,
|
160 |
+
attn_drop_rate: float = 0.0,
|
161 |
+
drop_path_rate: float = 0.0,
|
162 |
+
hybrid_backbone: Optional[None] = None,
|
163 |
+
norm_layer: str = "ln",
|
164 |
+
mask_ratio=0.1,
|
165 |
+
using_checkpoint=False,
|
166 |
+
):
|
167 |
+
super().__init__()
|
168 |
+
self.num_classes = num_classes
|
169 |
+
# num_features for consistency with other models
|
170 |
+
self.num_features = self.embed_dim = embed_dim
|
171 |
+
|
172 |
+
if hybrid_backbone is not None:
|
173 |
+
raise ValueError
|
174 |
+
else:
|
175 |
+
self.patch_embed = PatchEmbed(
|
176 |
+
img_size=img_size, patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim
|
177 |
+
)
|
178 |
+
self.mask_ratio = mask_ratio
|
179 |
+
self.using_checkpoint = using_checkpoint
|
180 |
+
num_patches = self.patch_embed.num_patches
|
181 |
+
self.num_patches = num_patches
|
182 |
+
|
183 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
184 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
185 |
+
|
186 |
+
# stochastic depth decay rule
|
187 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
|
188 |
+
patch_n = (img_size // patch_size) ** 2
|
189 |
+
self.blocks = nn.ModuleList(
|
190 |
+
[
|
191 |
+
Block(
|
192 |
+
dim=embed_dim,
|
193 |
+
num_heads=num_heads,
|
194 |
+
mlp_ratio=mlp_ratio,
|
195 |
+
qkv_bias=qkv_bias,
|
196 |
+
qk_scale=qk_scale,
|
197 |
+
drop=drop_rate,
|
198 |
+
attn_drop=attn_drop_rate,
|
199 |
+
drop_path=dpr[i],
|
200 |
+
norm_layer=norm_layer,
|
201 |
+
num_patches=num_patches,
|
202 |
+
patch_n=patch_n,
|
203 |
+
)
|
204 |
+
for i in range(depth)
|
205 |
+
]
|
206 |
+
)
|
207 |
+
self.extra_gflops = 0.0
|
208 |
+
for _block in self.blocks:
|
209 |
+
self.extra_gflops += _block.extra_gflops
|
210 |
+
|
211 |
+
if norm_layer == "ln":
|
212 |
+
self.norm = nn.LayerNorm(embed_dim)
|
213 |
+
elif norm_layer == "bn":
|
214 |
+
self.norm = VITBatchNorm(self.num_patches)
|
215 |
+
|
216 |
+
# features head
|
217 |
+
self.feature = nn.Sequential(
|
218 |
+
nn.Linear(in_features=embed_dim * num_patches, out_features=embed_dim, bias=False),
|
219 |
+
nn.BatchNorm1d(num_features=embed_dim, eps=2e-5),
|
220 |
+
nn.Linear(in_features=embed_dim, out_features=num_classes, bias=False),
|
221 |
+
nn.BatchNorm1d(num_features=num_classes, eps=2e-5),
|
222 |
+
)
|
223 |
+
|
224 |
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
225 |
+
torch.nn.init.normal_(self.mask_token, std=0.02)
|
226 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
227 |
+
# trunc_normal_(self.cls_token, std=.02)
|
228 |
+
self.apply(self._init_weights)
|
229 |
+
|
230 |
+
def _init_weights(self, m):
|
231 |
+
if isinstance(m, nn.Linear):
|
232 |
+
trunc_normal_(m.weight, std=0.02)
|
233 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
234 |
+
nn.init.constant_(m.bias, 0)
|
235 |
+
elif isinstance(m, nn.LayerNorm):
|
236 |
+
nn.init.constant_(m.bias, 0)
|
237 |
+
nn.init.constant_(m.weight, 1.0)
|
238 |
+
|
239 |
+
@torch.jit.ignore
|
240 |
+
def no_weight_decay(self):
|
241 |
+
return {"pos_embed", "cls_token"}
|
242 |
+
|
243 |
+
def get_classifier(self):
|
244 |
+
return self.head
|
245 |
+
|
246 |
+
def random_masking(self, x, mask_ratio=0.1):
|
247 |
+
"""
|
248 |
+
Perform per-sample random masking by per-sample shuffling.
|
249 |
+
Per-sample shuffling is done by argsort random noise.
|
250 |
+
x: [N, L, D], sequence
|
251 |
+
"""
|
252 |
+
N, L, D = x.size() # batch, length, dim
|
253 |
+
len_keep = int(L * (1 - mask_ratio))
|
254 |
+
|
255 |
+
noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
|
256 |
+
|
257 |
+
# sort noise for each sample
|
258 |
+
# ascend: small is keep, large is remove
|
259 |
+
ids_shuffle = torch.argsort(noise, dim=1)
|
260 |
+
ids_restore = torch.argsort(ids_shuffle, dim=1)
|
261 |
+
|
262 |
+
# keep the first subset
|
263 |
+
ids_keep = ids_shuffle[:, :len_keep]
|
264 |
+
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
|
265 |
+
|
266 |
+
# generate the binary mask: 0 is keep, 1 is remove
|
267 |
+
mask = torch.ones([N, L], device=x.device)
|
268 |
+
mask[:, :len_keep] = 0
|
269 |
+
# unshuffle to get the binary mask
|
270 |
+
mask = torch.gather(mask, dim=1, index=ids_restore)
|
271 |
+
|
272 |
+
return x_masked, mask, ids_restore
|
273 |
+
|
274 |
+
def forward_features(self, x):
|
275 |
+
B = x.shape[0]
|
276 |
+
x = self.patch_embed(x)
|
277 |
+
x = x + self.pos_embed
|
278 |
+
x = self.pos_drop(x)
|
279 |
+
|
280 |
+
if self.training and self.mask_ratio > 0:
|
281 |
+
x, _, ids_restore = self.random_masking(x)
|
282 |
+
|
283 |
+
for func in self.blocks:
|
284 |
+
if self.using_checkpoint and self.training:
|
285 |
+
from torch.utils.checkpoint import checkpoint
|
286 |
+
|
287 |
+
x = checkpoint(func, x)
|
288 |
+
else:
|
289 |
+
x = func(x)
|
290 |
+
x = self.norm(x.float())
|
291 |
+
|
292 |
+
if self.training and self.mask_ratio > 0:
|
293 |
+
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
|
294 |
+
x_ = torch.cat([x[:, :, :], mask_tokens], dim=1) # no cls token
|
295 |
+
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
|
296 |
+
x = x_
|
297 |
+
return torch.reshape(x, (B, self.num_patches * self.embed_dim))
|
298 |
+
|
299 |
+
def forward(self, x):
|
300 |
+
x = self.forward_features(x)
|
301 |
+
x = self.feature(x)
|
302 |
+
return x
|
Deep3DFaceRecon_pytorch/models/arcface_torch/configs/3millions.py
ADDED
@@ -0,0 +1,23 @@
1 |
+
from easydict import EasyDict as edict
|
2 |
+
|
3 |
+
# configs for test speed
|
4 |
+
|
5 |
+
config = edict()
|
6 |
+
config.margin_list = (1.0, 0.0, 0.4)
|
7 |
+
config.network = "mbf"
|
8 |
+
config.resume = False
|
9 |
+
config.output = None
|
10 |
+
config.embedding_size = 512
|
11 |
+
config.sample_rate = 0.1
|
12 |
+
config.fp16 = True
|
13 |
+
config.momentum = 0.9
|
14 |
+
config.weight_decay = 5e-4
|
15 |
+
config.batch_size = 512 # total_batch_size = batch_size * num_gpus
|
16 |
+
config.lr = 0.1 # batch size is 512
|
17 |
+
|
18 |
+
config.rec = "synthetic"
|
19 |
+
config.num_classes = 30 * 10000
|
20 |
+
config.num_image = 100000
|
21 |
+
config.num_epoch = 30
|
22 |
+
config.warmup_epoch = -1
|
23 |
+
config.val_targets = []
|
Deep3DFaceRecon_pytorch/models/arcface_torch/configs/__init__.py
ADDED
File without changes
|
Deep3DFaceRecon_pytorch/models/arcface_torch/configs/base.py
ADDED
@@ -0,0 +1,59 @@
1 |
+
from easydict import EasyDict as edict
|
2 |
+
|
3 |
+
# make training faster
|
4 |
+
# our RAM is 256G
|
5 |
+
# mount -t tmpfs -o size=140G tmpfs /train_tmp
|
6 |
+
|
7 |
+
config = edict()
|
8 |
+
|
9 |
+
# Margin Base Softmax
|
10 |
+
config.margin_list = (1.0, 0.5, 0.0)
|
11 |
+
config.network = "r50"
|
12 |
+
config.resume = False
|
13 |
+
config.save_all_states = False
|
14 |
+
config.output = "ms1mv3_arcface_r50"
|
15 |
+
|
16 |
+
config.embedding_size = 512
|
17 |
+
|
18 |
+
# Partial FC
|
19 |
+
config.sample_rate = 1
|
20 |
+
config.interclass_filtering_threshold = 0
|
21 |
+
|
22 |
+
config.fp16 = False
|
23 |
+
config.batch_size = 128
|
24 |
+
|
25 |
+
# For SGD
|
26 |
+
config.optimizer = "sgd"
|
27 |
+
config.lr = 0.1
|
28 |
+
config.momentum = 0.9
|
29 |
+
config.weight_decay = 5e-4
|
30 |
+
|
31 |
+
# For AdamW
|
32 |
+
# config.optimizer = "adamw"
|
33 |
+
# config.lr = 0.001
|
34 |
+
# config.weight_decay = 0.1
|
35 |
+
|
36 |
+
config.verbose = 2000
|
37 |
+
config.frequent = 10
|
38 |
+
|
39 |
+
# For Large Scale Dataset, such as WebFace42M
|
40 |
+
config.dali = False
|
41 |
+
|
42 |
+
# Gradient ACC
|
43 |
+
config.gradient_acc = 1
|
44 |
+
|
45 |
+
# setup seed
|
46 |
+
config.seed = 2048
|
47 |
+
|
48 |
+
# dataload numworkers
|
49 |
+
config.num_workers = 2
|
50 |
+
|
51 |
+
# WandB Logger
|
52 |
+
config.wandb_key = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
|
53 |
+
config.suffix_run_name = None
|
54 |
+
config.using_wandb = False
|
55 |
+
config.wandb_entity = "entity"
|
56 |
+
config.wandb_project = "project"
|
57 |
+
config.wandb_log_all = True
|
58 |
+
config.save_artifacts = False
|
59 |
+
config.wandb_resume = False # resume wandb run: Only if the you wand t resume the last run that it was interrupted
|
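base.py holds the shared defaults; each remaining file in this directory builds its own edict with dataset- and backbone-specific values. A minimal way to pick one of them up at runtime might look like the sketch below; it assumes the configs package (with the __init__.py added above) is importable, and it is not the repository's own loader utility.

# Illustrative loader sketch (assumes `configs` is on the import path).
import importlib

def load_config(name: str):
    # e.g. name = "glint360k_r50" -> configs/glint360k_r50.py
    module = importlib.import_module(f"configs.{name}")
    return module.config

cfg = load_config("glint360k_r50")
print(cfg.network, cfg.lr, cfg.num_classes)        # r50 0.1 360232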
Deep3DFaceRecon_pytorch/models/arcface_torch/configs/glint360k_mbf.py
ADDED
@@ -0,0 +1,27 @@
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G tmpfs /train_tmp

config = edict()
config.margin_list = (1.0, 0.0, 0.4)
config.network = "mbf"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 1e-4
config.batch_size = 128
config.lr = 0.1
config.verbose = 2000
config.dali = False

config.rec = "/train_tmp/glint360k"
config.num_classes = 360232
config.num_image = 17091657
config.num_epoch = 20
config.warmup_epoch = 0
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
Deep3DFaceRecon_pytorch/models/arcface_torch/configs/glint360k_r100.py
ADDED
@@ -0,0 +1,27 @@
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G tmpfs /train_tmp

config = edict()
config.margin_list = (1.0, 0.0, 0.4)
config.network = "r100"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 1e-4
config.batch_size = 128
config.lr = 0.1
config.verbose = 2000
config.dali = False

config.rec = "/train_tmp/glint360k"
config.num_classes = 360232
config.num_image = 17091657
config.num_epoch = 20
config.warmup_epoch = 0
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
Deep3DFaceRecon_pytorch/models/arcface_torch/configs/glint360k_r50.py
ADDED
@@ -0,0 +1,27 @@
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G tmpfs /train_tmp

config = edict()
config.margin_list = (1.0, 0.0, 0.4)
config.network = "r50"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 1e-4
config.batch_size = 128
config.lr = 0.1
config.verbose = 2000
config.dali = False

config.rec = "/train_tmp/glint360k"
config.num_classes = 360232
config.num_image = 17091657
config.num_epoch = 20
config.warmup_epoch = 0
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv2_mbf.py
ADDED
@@ -0,0 +1,27 @@
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G tmpfs /train_tmp

config = edict()
config.margin_list = (1.0, 0.5, 0.0)
config.network = "mbf"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 1e-4
config.batch_size = 128
config.lr = 0.1
config.verbose = 2000
config.dali = False

config.rec = "/train_tmp/faces_emore"
config.num_classes = 85742
config.num_image = 5822653
config.num_epoch = 40
config.warmup_epoch = 0
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv2_r100.py
ADDED
@@ -0,0 +1,27 @@
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G tmpfs /train_tmp

config = edict()
config.margin_list = (1.0, 0.5, 0.0)
config.network = "r100"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1
config.verbose = 2000
config.dali = False

config.rec = "/train_tmp/faces_emore"
config.num_classes = 85742
config.num_image = 5822653
config.num_epoch = 20
config.warmup_epoch = 0
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv2_r50.py
ADDED
@@ -0,0 +1,27 @@
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G tmpfs /train_tmp

config = edict()
config.margin_list = (1.0, 0.5, 0.0)
config.network = "r50"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1
config.verbose = 2000
config.dali = False

config.rec = "/train_tmp/faces_emore"
config.num_classes = 85742
config.num_image = 5822653
config.num_epoch = 20
config.warmup_epoch = 0
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv3_mbf.py
ADDED
@@ -0,0 +1,27 @@
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G tmpfs /train_tmp

config = edict()
config.margin_list = (1.0, 0.5, 0.0)
config.network = "mbf"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 1e-4
config.batch_size = 128
config.lr = 0.1
config.verbose = 2000
config.dali = False

config.rec = "/train_tmp/ms1m-retinaface-t1"
config.num_classes = 93431
config.num_image = 5179510
config.num_epoch = 40
config.warmup_epoch = 0
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv3_r100.py
ADDED
@@ -0,0 +1,27 @@
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G tmpfs /train_tmp

config = edict()
config.margin_list = (1.0, 0.5, 0.0)
config.network = "r100"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1
config.verbose = 2000
config.dali = False

config.rec = "/train_tmp/ms1m-retinaface-t1"
config.num_classes = 93431
config.num_image = 5179510
config.num_epoch = 20
config.warmup_epoch = 0
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv3_r50.py
ADDED
@@ -0,0 +1,27 @@
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G tmpfs /train_tmp

config = edict()
config.margin_list = (1.0, 0.5, 0.0)
config.network = "r50"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1
config.verbose = 2000
config.dali = False

config.rec = "/train_tmp/ms1m-retinaface-t1"
config.num_classes = 93431
config.num_image = 5179510
config.num_epoch = 20
config.warmup_epoch = 0
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv3_r50_onegpu.py
ADDED
@@ -0,0 +1,27 @@
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G tmpfs /train_tmp

config = edict()
config.margin_list = (1.0, 0.5, 0.0)
config.network = "r50"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.02
config.verbose = 2000
config.dali = False

config.rec = "/train_tmp/ms1m-retinaface-t1"
config.num_classes = 93431
config.num_image = 5179510
config.num_epoch = 20
config.warmup_epoch = 0
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_conflict_r50.py
ADDED
@@ -0,0 +1,28 @@
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G tmpfs /train_tmp

config = edict()
config.margin_list = (1.0, 0.0, 0.4)
config.network = "r50"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.interclass_filtering_threshold = 0
config.fp16 = True
config.weight_decay = 5e-4
config.batch_size = 128
config.optimizer = "sgd"
config.lr = 0.1
config.verbose = 2000
config.dali = False

config.rec = "/train_tmp/WebFace12M_Conflict"
config.num_classes = 1017970
config.num_image = 12720066
config.num_epoch = 20
config.warmup_epoch = config.num_epoch // 10
config.val_targets = []
Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_conflict_r50_pfc03_filter04.py
ADDED
@@ -0,0 +1,28 @@
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G tmpfs /train_tmp

config = edict()
config.margin_list = (1.0, 0.0, 0.4)
config.network = "r50"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 0.3
config.interclass_filtering_threshold = 0.4
config.fp16 = True
config.weight_decay = 5e-4
config.batch_size = 128
config.optimizer = "sgd"
config.lr = 0.1
config.verbose = 2000
config.dali = False

config.rec = "/train_tmp/WebFace12M_Conflict"
config.num_classes = 1017970
config.num_image = 12720066
config.num_epoch = 20
config.warmup_epoch = config.num_epoch // 10
config.val_targets = []
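For the two WebFace12M_Conflict configs above, sample_rate follows the Partial FC setting introduced in base.py (roughly, the fraction of class centers used in each softmax step), and warmup_epoch is derived from num_epoch. The arithmetic below is illustrative only, not repository code.

# Back-of-the-envelope numbers for wf12m_conflict_r50_pfc03_filter04.py.
num_classes = 1017970
sample_rate = 0.3
sampled_centers = round(num_classes * sample_rate)   # ~305,391 class centers per step
warmup_epoch = 20 // 10                              # num_epoch // 10 = 2 warm-up epochs
print(sampled_centers, warmup_epoch)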