Spaces: Runtime error

Sanket committed · commit 3d37b6e · 1 parent: ff4715d
This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -6
- LICENSE +21 -0
- README.md +32 -6
- app.py +204 -0
- e4e/.gitignore +129 -0
- e4e/criteria/__init__.py +0 -0
- e4e/criteria/id_loss.py +47 -0
- e4e/criteria/lpips/__init__.py +0 -0
- e4e/criteria/lpips/lpips.py +35 -0
- e4e/criteria/lpips/networks.py +96 -0
- e4e/criteria/lpips/utils.py +30 -0
- e4e/criteria/moco_loss.py +71 -0
- e4e/criteria/w_norm.py +14 -0
- e4e/datasets/__init__.py +0 -0
- e4e/datasets/gt_res_dataset.py +32 -0
- e4e/datasets/images_dataset.py +33 -0
- e4e/datasets/inference_dataset.py +25 -0
- e4e/editings/ganspace.py +22 -0
- e4e/editings/ganspace_pca/cars_pca.pt +3 -0
- e4e/editings/ganspace_pca/ffhq_pca.pt +3 -0
- e4e/editings/interfacegan_directions/age.pt +3 -0
- e4e/editings/interfacegan_directions/pose.pt +3 -0
- e4e/editings/interfacegan_directions/smile.pt +3 -0
- e4e/editings/latent_editor.py +45 -0
- e4e/editings/sefa.py +46 -0
- e4e/environment/e4e_env.yaml +73 -0
- e4e/metrics/LEC.py +134 -0
- e4e/models/__init__.py +0 -0
- e4e/models/discriminator.py +20 -0
- e4e/models/encoders/__init__.py +0 -0
- e4e/models/encoders/helpers.py +140 -0
- e4e/models/encoders/model_irse.py +84 -0
- e4e/models/encoders/psp_encoders.py +200 -0
- e4e/models/latent_codes_pool.py +55 -0
- e4e/models/psp.py +99 -0
- e4e/models/stylegan2/__init__.py +0 -0
- e4e/models/stylegan2/model.py +678 -0
- e4e/models/stylegan2/op/__init__.py +0 -0
- e4e/models/stylegan2/op/fused_act.py +85 -0
- e4e/models/stylegan2/op/fused_bias_act.cpp +21 -0
- e4e/models/stylegan2/op/fused_bias_act_kernel.cu +99 -0
- e4e/models/stylegan2/op/upfirdn2d.cpp +23 -0
- e4e/models/stylegan2/op/upfirdn2d.py +184 -0
- e4e/models/stylegan2/op/upfirdn2d_kernel.cu +272 -0
- e4e/notebooks/images/car_img.jpg +0 -0
- e4e/notebooks/images/church_img.jpg +0 -0
- e4e/notebooks/images/horse_img.jpg +0 -0
- e4e/notebooks/images/input_img.jpg +0 -0
- e4e/options/__init__.py +0 -0
- e4e/options/train_options.py +84 -0
.gitattributes
CHANGED
@@ -1,6 +1,7 @@
 *.7z filter=lfs diff=lfs merge=lfs -text
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
+*.bin.* filter=lfs diff=lfs merge=lfs -text
 *.bz2 filter=lfs diff=lfs merge=lfs -text
 *.ftz filter=lfs diff=lfs merge=lfs -text
 *.gz filter=lfs diff=lfs merge=lfs -text
@@ -9,13 +10,9 @@
 *.lfs.* filter=lfs diff=lfs merge=lfs -text
 *.model filter=lfs diff=lfs merge=lfs -text
 *.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
 *.onnx filter=lfs diff=lfs merge=lfs -text
 *.ot filter=lfs diff=lfs merge=lfs -text
 *.parquet filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
 *.pb filter=lfs diff=lfs merge=lfs -text
 *.pt filter=lfs diff=lfs merge=lfs -text
 *.pth filter=lfs diff=lfs merge=lfs -text
@@ -24,8 +21,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tar.* filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
 *.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
 *.xz filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
-*.
+*.zstandard filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2021 Min Jin Chong

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
CHANGED
@@ -1,12 +1,38 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: JoJoGAN
+emoji: 🌍
+colorFrom: green
+colorTo: yellow
 sdk: gradio
-sdk_version: 3.
+sdk_version: 3.1.1
 app_file: app.py
 pinned: false
 ---
 
-
+# Configuration
+
+`title`: _string_
+Display title for the Space
+
+`emoji`: _string_
+Space emoji (emoji-only character allowed)
+
+`colorFrom`: _string_
+Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+`colorTo`: _string_
+Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+`sdk`: _string_
+Can be either `gradio` or `streamlit`
+
+`sdk_version` : _string_
+Only applicable for `streamlit` SDK.
+See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
+
+`app_file`: _string_
+Path to your main application file (which contains either `gradio` or `streamlit` Python code).
+Path is relative to the root of the repository.
+
+`pinned`: _boolean_
+Whether the Space stays on top of your list.
app.py
ADDED
@@ -0,0 +1,204 @@
import os
from PIL import Image
import torch
import gradio as gr
import torch
torch.backends.cudnn.benchmark = True
from torchvision import transforms, utils
from util import *
from PIL import Image
import math
import random
import numpy as np
from torch import nn, autograd, optim
from torch.nn import functional as F
from tqdm import tqdm
import lpips
from model import *


#from e4e_projection import projection as e4e_projection

from copy import deepcopy
import imageio

import os
import sys
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms
from argparse import Namespace
from e4e.models.psp import pSp
from util import *
from huggingface_hub import hf_hub_download

device= 'cpu'
model_path_e = hf_hub_download(repo_id="akhaliq/JoJoGAN_e4e_ffhq_encode", filename="e4e_ffhq_encode.pt")
ckpt = torch.load(model_path_e, map_location='cpu')
opts = ckpt['opts']
opts['checkpoint_path'] = model_path_e
opts= Namespace(**opts)
net = pSp(opts, device).eval().to(device)

@ torch.no_grad()
def projection(img, name, device='cuda'):


    transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )
    img = transform(img).unsqueeze(0).to(device)
    images, w_plus = net(img, randomize_noise=False, return_latents=True)
    result_file = {}
    result_file['latent'] = w_plus[0]
    torch.save(result_file, name)
    return w_plus[0]




device = 'cpu'


latent_dim = 512

model_path_s = hf_hub_download(repo_id="akhaliq/jojogan-stylegan2-ffhq-config-f", filename="stylegan2-ffhq-config-f.pt")
original_generator = Generator(1024, latent_dim, 8, 2).to(device)
ckpt = torch.load(model_path_s, map_location=lambda storage, loc: storage)
original_generator.load_state_dict(ckpt["g_ema"], strict=False)
mean_latent = original_generator.mean_latent(10000)

generatorjojo = deepcopy(original_generator)

generatordisney = deepcopy(original_generator)

generatorjinx = deepcopy(original_generator)

generatorcaitlyn = deepcopy(original_generator)

generatoryasuho = deepcopy(original_generator)

generatorarcanemulti = deepcopy(original_generator)

generatorart = deepcopy(original_generator)

generatorspider = deepcopy(original_generator)

generatorsketch = deepcopy(original_generator)


transform = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)




modeljojo = hf_hub_download(repo_id="akhaliq/JoJoGAN-jojo", filename="jojo_preserve_color.pt")


ckptjojo = torch.load(modeljojo, map_location=lambda storage, loc: storage)
generatorjojo.load_state_dict(ckptjojo["g"], strict=False)


modeldisney = hf_hub_download(repo_id="akhaliq/jojogan-disney", filename="disney_preserve_color.pt")

ckptdisney = torch.load(modeldisney, map_location=lambda storage, loc: storage)
generatordisney.load_state_dict(ckptdisney["g"], strict=False)


modeljinx = hf_hub_download(repo_id="akhaliq/jojo-gan-jinx", filename="arcane_jinx_preserve_color.pt")

ckptjinx = torch.load(modeljinx, map_location=lambda storage, loc: storage)
generatorjinx.load_state_dict(ckptjinx["g"], strict=False)


modelcaitlyn = hf_hub_download(repo_id="akhaliq/jojogan-arcane", filename="arcane_caitlyn_preserve_color.pt")

ckptcaitlyn = torch.load(modelcaitlyn, map_location=lambda storage, loc: storage)
generatorcaitlyn.load_state_dict(ckptcaitlyn["g"], strict=False)


modelyasuho = hf_hub_download(repo_id="akhaliq/JoJoGAN-jojo", filename="jojo_yasuho_preserve_color.pt")

ckptyasuho = torch.load(modelyasuho, map_location=lambda storage, loc: storage)
generatoryasuho.load_state_dict(ckptyasuho["g"], strict=False)


model_arcane_multi = hf_hub_download(repo_id="akhaliq/jojogan-arcane", filename="arcane_multi_preserve_color.pt")

ckptarcanemulti = torch.load(model_arcane_multi, map_location=lambda storage, loc: storage)
generatorarcanemulti.load_state_dict(ckptarcanemulti["g"], strict=False)


modelart = hf_hub_download(repo_id="akhaliq/jojo-gan-art", filename="art.pt")

ckptart = torch.load(modelart, map_location=lambda storage, loc: storage)
generatorart.load_state_dict(ckptart["g"], strict=False)


modelSpiderverse = hf_hub_download(repo_id="akhaliq/jojo-gan-spiderverse", filename="Spiderverse-face-500iters-8face.pt")

ckptspider = torch.load(modelSpiderverse, map_location=lambda storage, loc: storage)
generatorspider.load_state_dict(ckptspider["g"], strict=False)

modelSketch = hf_hub_download(repo_id="akhaliq/jojogan-sketch", filename="sketch_multi.pt")

ckptsketch = torch.load(modelSketch, map_location=lambda storage, loc: storage)
generatorsketch.load_state_dict(ckptsketch["g"], strict=False)

def inference(img, model):
    img.save('out.jpg')
    aligned_face = align_face('out.jpg')

    my_w = projection(aligned_face, "test.pt", device).unsqueeze(0)
    if model == 'JoJo':
        with torch.no_grad():
            my_sample = generatorjojo(my_w, input_is_latent=True)
    elif model == 'Disney':
        with torch.no_grad():
            my_sample = generatordisney(my_w, input_is_latent=True)
    elif model == 'Jinx':
        with torch.no_grad():
            my_sample = generatorjinx(my_w, input_is_latent=True)
    elif model == 'Caitlyn':
        with torch.no_grad():
            my_sample = generatorcaitlyn(my_w, input_is_latent=True)
    elif model == 'Yasuho':
        with torch.no_grad():
            my_sample = generatoryasuho(my_w, input_is_latent=True)
    elif model == 'Arcane Multi':
        with torch.no_grad():
            my_sample = generatorarcanemulti(my_w, input_is_latent=True)
    elif model == 'Art':
        with torch.no_grad():
            my_sample = generatorart(my_w, input_is_latent=True)
    elif model == 'Spider-Verse':
        with torch.no_grad():
            my_sample = generatorspider(my_w, input_is_latent=True)
    else:
        with torch.no_grad():
            my_sample = generatorsketch(my_w, input_is_latent=True)


    npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
    imageio.imwrite('filename.jpeg', npimage)
    return 'filename.jpeg'

title = "JoJoGAN"
description = "Gradio Demo for JoJoGAN: One Shot Face Stylization. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."

article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center>"

examples=[['mona.png','Jinx']]
gr.Interface(inference, [gr.inputs.Image(type="pil"),gr.inputs.Dropdown(choices=['JoJo', 'Disney','Jinx','Caitlyn','Yasuho','Arcane Multi','Art','Spider-Verse','Sketch'], type="value", default='JoJo', label="Model")], gr.outputs.Image(type="file"),title=title,description=description,article=article,allow_flagging=False,examples=examples,allow_screenshot=False).launch()
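Note: the script above still uses the legacy `gr.inputs`/`gr.outputs` namespace while the README pins `sdk_version: 3.1.1`; in Gradio 3.x those modules were replaced by top-level components, which is a plausible source of the Space's "Runtime error" status. The following is a hedged sketch of the same interface on the Gradio 3 API; the `inference` function and model names come from app.py, while the import path and refactoring assumption are hypothetical and not part of this commit.

# Hypothetical Gradio 3.x wiring for the same inference() function (not part of this commit).
import gradio as gr

from app import inference  # assumption: app.py refactored so importing it does not call .launch()

demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(type="pil", label="Input face"),           # replaces gr.inputs.Image
        gr.Dropdown(
            choices=['JoJo', 'Disney', 'Jinx', 'Caitlyn', 'Yasuho',
                     'Arcane Multi', 'Art', 'Spider-Verse', 'Sketch'],
            value='JoJo', label="Model"),                    # `default=` became `value=` in 3.x
    ],
    outputs=gr.Image(type="filepath", label="Stylized"),     # replaces gr.outputs.Image(type="file")
    title="JoJoGAN",
    allow_flagging="never",                                  # boolean flag became a string in 3.x
)

if __name__ == "__main__":
    demo.launch()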
e4e/.gitignore
ADDED
@@ -0,0 +1,129 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
e4e/criteria/__init__.py
ADDED
File without changes
e4e/criteria/id_loss.py
ADDED
@@ -0,0 +1,47 @@
import torch
from torch import nn
from configs.paths_config import model_paths
from models.encoders.model_irse import Backbone


class IDLoss(nn.Module):
    def __init__(self):
        super(IDLoss, self).__init__()
        print('Loading ResNet ArcFace')
        self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
        self.facenet.load_state_dict(torch.load(model_paths['ir_se50']))
        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
        self.facenet.eval()
        for module in [self.facenet, self.face_pool]:
            for param in module.parameters():
                param.requires_grad = False

    def extract_feats(self, x):
        x = x[:, :, 35:223, 32:220]  # Crop interesting region
        x = self.face_pool(x)
        x_feats = self.facenet(x)
        return x_feats

    def forward(self, y_hat, y, x):
        n_samples = x.shape[0]
        x_feats = self.extract_feats(x)
        y_feats = self.extract_feats(y)  # Otherwise use the feature from there
        y_hat_feats = self.extract_feats(y_hat)
        y_feats = y_feats.detach()
        loss = 0
        sim_improvement = 0
        id_logs = []
        count = 0
        for i in range(n_samples):
            diff_target = y_hat_feats[i].dot(y_feats[i])
            diff_input = y_hat_feats[i].dot(x_feats[i])
            diff_views = y_feats[i].dot(x_feats[i])
            id_logs.append({'diff_target': float(diff_target),
                            'diff_input': float(diff_input),
                            'diff_views': float(diff_views)})
            loss += 1 - diff_target
            id_diff = float(diff_target) - float(diff_views)
            sim_improvement += id_diff
            count += 1

        return loss / count, sim_improvement / count, id_logs
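IDLoss scores identity preservation as one minus the dot product of ArcFace embeddings, which is cosine distance when the embeddings are unit-normalized (the ir_se50 backbone L2-normalizes its output). A minimal sketch of that arithmetic, with random stand-in embeddings instead of real ArcFace features:

# Illustrative only: random unit vectors stand in for ArcFace embeddings.
import torch
import torch.nn.functional as F

emb_output = F.normalize(torch.randn(4, 512), dim=1)  # embeddings of y_hat (stylized faces)
emb_target = F.normalize(torch.randn(4, 512), dim=1)  # embeddings of y (target faces)

# Per-sample identity loss: 1 - cosine similarity, the quantity the loop in IDLoss.forward accumulates.
per_sample = 1 - (emb_output * emb_target).sum(dim=1)
print(per_sample.mean())  # batch-averaged loss, in [0, 2]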
e4e/criteria/lpips/__init__.py
ADDED
File without changes
e4e/criteria/lpips/lpips.py
ADDED
@@ -0,0 +1,35 @@
import torch
import torch.nn as nn

from criteria.lpips.networks import get_network, LinLayers
from criteria.lpips.utils import get_state_dict


class LPIPS(nn.Module):
    r"""Creates a criterion that measures
    Learned Perceptual Image Patch Similarity (LPIPS).
    Arguments:
        net_type (str): the network type to compare the features:
                        'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
        version (str): the version of LPIPS. Default: 0.1.
    """
    def __init__(self, net_type: str = 'alex', version: str = '0.1'):

        assert version in ['0.1'], 'v0.1 is only supported now'

        super(LPIPS, self).__init__()

        # pretrained network
        self.net = get_network(net_type).to("cuda")

        # linear layers
        self.lin = LinLayers(self.net.n_channels_list).to("cuda")
        self.lin.load_state_dict(get_state_dict(net_type, version))

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        feat_x, feat_y = self.net(x), self.net(y)

        diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
        res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]

        return torch.sum(torch.cat(res, 0)) / x.shape[0]
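This LPIPS wrapper moves both the feature network and the learned linear layers to CUDA unconditionally, so it requires a GPU as written. A hedged usage sketch, assuming a CUDA device and inputs already scaled to [-1, 1]:

# Sketch of using the LPIPS criterion above; assumes a CUDA device is available (illustration only).
import torch
from criteria.lpips.lpips import LPIPS

lpips_loss = LPIPS(net_type='alex').eval()

x = torch.rand(2, 3, 256, 256).cuda() * 2 - 1  # stand-in "generated" batch in [-1, 1]
y = torch.rand(2, 3, 256, 256).cuda() * 2 - 1  # stand-in "target" batch in [-1, 1]

with torch.no_grad():
    print(float(lpips_loss(x, y)))  # scalar perceptual distance, averaged over the batch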
e4e/criteria/lpips/networks.py
ADDED
@@ -0,0 +1,96 @@
from typing import Sequence

from itertools import chain

import torch
import torch.nn as nn
from torchvision import models

from criteria.lpips.utils import normalize_activation


def get_network(net_type: str):
    if net_type == 'alex':
        return AlexNet()
    elif net_type == 'squeeze':
        return SqueezeNet()
    elif net_type == 'vgg':
        return VGG16()
    else:
        raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')


class LinLayers(nn.ModuleList):
    def __init__(self, n_channels_list: Sequence[int]):
        super(LinLayers, self).__init__([
            nn.Sequential(
                nn.Identity(),
                nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
            ) for nc in n_channels_list
        ])

        for param in self.parameters():
            param.requires_grad = False


class BaseNet(nn.Module):
    def __init__(self):
        super(BaseNet, self).__init__()

        # register buffer
        self.register_buffer(
            'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer(
            'std', torch.Tensor([.458, .448, .450])[None, :, None, None])

    def set_requires_grad(self, state: bool):
        for param in chain(self.parameters(), self.buffers()):
            param.requires_grad = state

    def z_score(self, x: torch.Tensor):
        return (x - self.mean) / self.std

    def forward(self, x: torch.Tensor):
        x = self.z_score(x)

        output = []
        for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
            x = layer(x)
            if i in self.target_layers:
                output.append(normalize_activation(x))
            if len(output) == len(self.target_layers):
                break
        return output


class SqueezeNet(BaseNet):
    def __init__(self):
        super(SqueezeNet, self).__init__()

        self.layers = models.squeezenet1_1(True).features
        self.target_layers = [2, 5, 8, 10, 11, 12, 13]
        self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]

        self.set_requires_grad(False)


class AlexNet(BaseNet):
    def __init__(self):
        super(AlexNet, self).__init__()

        self.layers = models.alexnet(True).features
        self.target_layers = [2, 5, 8, 10, 12]
        self.n_channels_list = [64, 192, 384, 256, 256]

        self.set_requires_grad(False)


class VGG16(BaseNet):
    def __init__(self):
        super(VGG16, self).__init__()

        self.layers = models.vgg16(True).features
        self.target_layers = [4, 9, 16, 23, 30]
        self.n_channels_list = [64, 128, 256, 512, 512]

        self.set_requires_grad(False)
e4e/criteria/lpips/utils.py
ADDED
@@ -0,0 +1,30 @@
from collections import OrderedDict

import torch


def normalize_activation(x, eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
    return x / (norm_factor + eps)


def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
    # build url
    url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
        + f'master/lpips/weights/v{version}/{net_type}.pth'

    # download
    old_state_dict = torch.hub.load_state_dict_from_url(
        url, progress=True,
        map_location=None if torch.cuda.is_available() else torch.device('cpu')
    )

    # rename keys
    new_state_dict = OrderedDict()
    for key, val in old_state_dict.items():
        new_key = key
        new_key = new_key.replace('lin', '')
        new_key = new_key.replace('model.', '')
        new_state_dict[new_key] = val

    return new_state_dict
e4e/criteria/moco_loss.py
ADDED
@@ -0,0 +1,71 @@
import torch
from torch import nn
import torch.nn.functional as F

from configs.paths_config import model_paths


class MocoLoss(nn.Module):

    def __init__(self, opts):
        super(MocoLoss, self).__init__()
        print("Loading MOCO model from path: {}".format(model_paths["moco"]))
        self.model = self.__load_model()
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    @staticmethod
    def __load_model():
        import torchvision.models as models
        model = models.__dict__["resnet50"]()
        # freeze all layers but the last fc
        for name, param in model.named_parameters():
            if name not in ['fc.weight', 'fc.bias']:
                param.requires_grad = False
        checkpoint = torch.load(model_paths['moco'], map_location="cpu")
        state_dict = checkpoint['state_dict']
        # rename moco pre-trained keys
        for k in list(state_dict.keys()):
            # retain only encoder_q up to before the embedding layer
            if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
                # remove prefix
                state_dict[k[len("module.encoder_q."):]] = state_dict[k]
            # delete renamed or unused k
            del state_dict[k]
        msg = model.load_state_dict(state_dict, strict=False)
        assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
        # remove output layer
        model = nn.Sequential(*list(model.children())[:-1]).cuda()
        return model

    def extract_feats(self, x):
        x = F.interpolate(x, size=224)
        x_feats = self.model(x)
        x_feats = nn.functional.normalize(x_feats, dim=1)
        x_feats = x_feats.squeeze()
        return x_feats

    def forward(self, y_hat, y, x):
        n_samples = x.shape[0]
        x_feats = self.extract_feats(x)
        y_feats = self.extract_feats(y)
        y_hat_feats = self.extract_feats(y_hat)
        y_feats = y_feats.detach()
        loss = 0
        sim_improvement = 0
        sim_logs = []
        count = 0
        for i in range(n_samples):
            diff_target = y_hat_feats[i].dot(y_feats[i])
            diff_input = y_hat_feats[i].dot(x_feats[i])
            diff_views = y_feats[i].dot(x_feats[i])
            sim_logs.append({'diff_target': float(diff_target),
                             'diff_input': float(diff_input),
                             'diff_views': float(diff_views)})
            loss += 1 - diff_target
            sim_diff = float(diff_target) - float(diff_views)
            sim_improvement += sim_diff
            count += 1

        return loss / count, sim_improvement / count, sim_logs
e4e/criteria/w_norm.py
ADDED
@@ -0,0 +1,14 @@
import torch
from torch import nn


class WNormLoss(nn.Module):

    def __init__(self, start_from_latent_avg=True):
        super(WNormLoss, self).__init__()
        self.start_from_latent_avg = start_from_latent_avg

    def forward(self, latent, latent_avg=None):
        if self.start_from_latent_avg:
            latent = latent - latent_avg
        return torch.sum(latent.norm(2, dim=(1, 2))) / latent.shape[0]
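WNormLoss is the e4e regularizer that pulls W+ codes toward the average latent: it subtracts latent_avg and averages the per-sample L2 norm over the batch. A tiny shape-level sketch with random codes (no generator needed; values are arbitrary):

# Shape-level sketch of the W-norm regularizer above; values are random (illustration only).
import torch
from criteria.w_norm import WNormLoss

w_norm = WNormLoss(start_from_latent_avg=True)

latents = torch.randn(4, 18, 512)   # a batch of W+ codes (18 style vectors for a 1024px generator)
latent_avg = torch.zeros(18, 512)   # stand-in for the generator's mean latent

print(float(w_norm(latents, latent_avg)))  # mean L2 norm of the centered codes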
e4e/datasets/__init__.py
ADDED
File without changes
e4e/datasets/gt_res_dataset.py
ADDED
@@ -0,0 +1,32 @@
#!/usr/bin/python
# encoding: utf-8
import os
from torch.utils.data import Dataset
from PIL import Image
import torch

class GTResDataset(Dataset):

    def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None):
        self.pairs = []
        for f in os.listdir(root_path):
            image_path = os.path.join(root_path, f)
            gt_path = os.path.join(gt_dir, f)
            if f.endswith(".jpg") or f.endswith(".png"):
                self.pairs.append([image_path, gt_path.replace('.png', '.jpg'), None])
        self.transform = transform
        self.transform_train = transform_train

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, index):
        from_path, to_path, _ = self.pairs[index]
        from_im = Image.open(from_path).convert('RGB')
        to_im = Image.open(to_path).convert('RGB')

        if self.transform:
            to_im = self.transform(to_im)
            from_im = self.transform(from_im)

        return from_im, to_im
e4e/datasets/images_dataset.py
ADDED
@@ -0,0 +1,33 @@
from torch.utils.data import Dataset
from PIL import Image
from utils import data_utils


class ImagesDataset(Dataset):

    def __init__(self, source_root, target_root, opts, target_transform=None, source_transform=None):
        self.source_paths = sorted(data_utils.make_dataset(source_root))
        self.target_paths = sorted(data_utils.make_dataset(target_root))
        self.source_transform = source_transform
        self.target_transform = target_transform
        self.opts = opts

    def __len__(self):
        return len(self.source_paths)

    def __getitem__(self, index):
        from_path = self.source_paths[index]
        from_im = Image.open(from_path)
        from_im = from_im.convert('RGB')

        to_path = self.target_paths[index]
        to_im = Image.open(to_path).convert('RGB')
        if self.target_transform:
            to_im = self.target_transform(to_im)

        if self.source_transform:
            from_im = self.source_transform(from_im)
        else:
            from_im = to_im

        return from_im, to_im
e4e/datasets/inference_dataset.py
ADDED
@@ -0,0 +1,25 @@
from torch.utils.data import Dataset
from PIL import Image
from utils import data_utils


class InferenceDataset(Dataset):

    def __init__(self, root, opts, transform=None, preprocess=None):
        self.paths = sorted(data_utils.make_dataset(root))
        self.transform = transform
        self.preprocess = preprocess
        self.opts = opts

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        from_path = self.paths[index]
        if self.preprocess is not None:
            from_im = self.preprocess(from_path)
        else:
            from_im = Image.open(from_path).convert('RGB')
        if self.transform:
            from_im = self.transform(from_im)
        return from_im
e4e/editings/ganspace.py
ADDED
@@ -0,0 +1,22 @@
import torch


def edit(latents, pca, edit_directions):
    edit_latents = []
    for latent in latents:
        for pca_idx, start, end, strength in edit_directions:
            delta = get_delta(pca, latent, pca_idx, strength)
            delta_padded = torch.zeros(latent.shape).to('cuda')
            delta_padded[start:end] += delta.repeat(end - start, 1)
            edit_latents.append(latent + delta_padded)
    return torch.stack(edit_latents)


def get_delta(pca, latent, idx, strength):
    # pca: ganspace checkpoint. latent: (16, 512) w+
    w_centered = latent - pca['mean'].to('cuda')
    lat_comp = pca['comp'].to('cuda')
    lat_std = pca['std'].to('cuda')
    w_coord = torch.sum(w_centered[0].reshape(-1)*lat_comp[idx].reshape(-1)) / lat_std[idx]
    delta = (strength - w_coord)*lat_comp[idx]*lat_std[idx]
    return delta
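get_delta projects a W+ code onto one GANSpace PCA component and moves it to a target coordinate: with component c, per-component std sigma, and current coordinate w_coord = <w - mean, c> / sigma, the shift is (strength - w_coord) * sigma * c. The sketch below reproduces that arithmetic on CPU with a fake PCA checkpoint (the file above hard-codes 'cuda'); the shapes of the fake dict are assumptions for illustration.

# CPU re-implementation of get_delta's arithmetic with a fake PCA dict (illustration only).
import torch

pca = {'mean': torch.zeros(512), 'comp': torch.randn(10, 512), 'std': torch.rand(10) + 0.1}
latent = torch.randn(16, 512)          # one W+ code, as in the comment above
idx, strength = 3, 5.0

w_centered = latent - pca['mean']
w_coord = torch.sum(w_centered[0] * pca['comp'][idx]) / pca['std'][idx]   # current coordinate on component idx
delta = (strength - w_coord) * pca['comp'][idx] * pca['std'][idx]         # shift toward `strength`
print(delta.shape)  # torch.Size([512]) -- broadcast over the chosen style rows by edit()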
e4e/editings/ganspace_pca/cars_pca.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a5c3bae61ecd85de077fbbf103f5f30cf4b7676fe23a8508166eaf2ce73c8392
size 167562
e4e/editings/ganspace_pca/ffhq_pca.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4d7f9df1c96180d9026b9cb8d04753579fbf385f321a9d0e263641601c5e5d36
size 167562
e4e/editings/interfacegan_directions/age.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:50074516b1629707d89b5e43d6b8abd1792212fa3b961a87a11323d6a5222ae0
size 2808
e4e/editings/interfacegan_directions/pose.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:736e0eacc8488fa0b020a2e7bd256b957284c364191dfea693705e5d06d43e7d
size 37624
e4e/editings/interfacegan_directions/smile.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:817a7e732b59dee9eba862bec8bd7e8373568443bc9f9731a21cf9b0356f0653
size 2808
e4e/editings/latent_editor.py
ADDED
@@ -0,0 +1,45 @@
import torch
import sys
sys.path.append(".")
sys.path.append("..")
from editings import ganspace, sefa
from utils.common import tensor2im


class LatentEditor(object):
    def __init__(self, stylegan_generator, is_cars=False):
        self.generator = stylegan_generator
        self.is_cars = is_cars  # Since the cars StyleGAN output is 384x512, there is a need to crop the 512x512 output.

    def apply_ganspace(self, latent, ganspace_pca, edit_directions):
        edit_latents = ganspace.edit(latent, ganspace_pca, edit_directions)
        return self._latents_to_image(edit_latents)

    def apply_interfacegan(self, latent, direction, factor=1, factor_range=None):
        edit_latents = []
        if factor_range is not None:  # Apply a range of editing factors. for example, (-5, 5)
            for f in range(*factor_range):
                edit_latent = latent + f * direction
                edit_latents.append(edit_latent)
            edit_latents = torch.cat(edit_latents)
        else:
            edit_latents = latent + factor * direction
        return self._latents_to_image(edit_latents)

    def apply_sefa(self, latent, indices=[2, 3, 4, 5], **kwargs):
        edit_latents = sefa.edit(self.generator, latent, indices, **kwargs)
        return self._latents_to_image(edit_latents)

    # Currently, in order to apply StyleFlow editings, one should run inference,
    # save the latent codes and load them form the official StyleFlow repository.
    # def apply_styleflow(self):
    #     pass

    def _latents_to_image(self, latents):
        with torch.no_grad():
            images, _ = self.generator([latents], randomize_noise=False, input_is_latent=True)
            if self.is_cars:
                images = images[:, :, 64:448, :]  # 512x512 -> 384x512
        horizontal_concat_image = torch.cat(list(images), 2)
        final_image = tensor2im(horizontal_concat_image)
        return final_image
e4e/editings/sefa.py
ADDED
@@ -0,0 +1,46 @@
import torch
import numpy as np
from tqdm import tqdm


def edit(generator, latents, indices, semantics=1, start_distance=-15.0, end_distance=15.0, num_samples=1, step=11):

    layers, boundaries, values = factorize_weight(generator, indices)
    codes = latents.detach().cpu().numpy()  # (1,18,512)

    # Generate visualization pages.
    distances = np.linspace(start_distance, end_distance, step)
    num_sam = num_samples
    num_sem = semantics

    edited_latents = []
    for sem_id in tqdm(range(num_sem), desc='Semantic ', leave=False):
        boundary = boundaries[sem_id:sem_id + 1]
        for sam_id in tqdm(range(num_sam), desc='Sample ', leave=False):
            code = codes[sam_id:sam_id + 1]
            for col_id, d in enumerate(distances, start=1):
                temp_code = code.copy()
                temp_code[:, layers, :] += boundary * d
                edited_latents.append(torch.from_numpy(temp_code).float().cuda())
    return torch.cat(edited_latents)


def factorize_weight(g_ema, layers='all'):

    weights = []
    if layers == 'all' or 0 in layers:
        weight = g_ema.conv1.conv.modulation.weight.T
        weights.append(weight.cpu().detach().numpy())

    if layers == 'all':
        layers = list(range(g_ema.num_layers - 1))
    else:
        layers = [l - 1 for l in layers if l != 0]

    for idx in layers:
        weight = g_ema.convs[idx].conv.modulation.weight.T
        weights.append(weight.cpu().detach().numpy())
    weight = np.concatenate(weights, axis=1).astype(np.float32)
    weight = weight / np.linalg.norm(weight, axis=0, keepdims=True)
    eigen_values, eigen_vectors = np.linalg.eig(weight.dot(weight.T))
    return layers, eigen_vectors.T, eigen_values
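factorize_weight implements the closed-form SeFa factorization: it stacks the style-modulation weight matrices of the selected layers, unit-normalizes each 512-dim column, and takes the eigenvectors of the resulting Gram matrix as editing directions. A numpy-only sketch of that decomposition, with a random matrix standing in for the concatenated modulation weights and eigh used here because the Gram matrix is symmetric (the file above calls np.linalg.eig on the same matrix):

# Numpy-only sketch of the SeFa step in factorize_weight(); the weight matrix is random (illustration only).
import numpy as np

weight = np.random.randn(512, 3 * 512).astype(np.float32)        # stand-in for stacked modulation weights
weight = weight / np.linalg.norm(weight, axis=0, keepdims=True)  # unit-normalize each column, as above

eigen_values, eigen_vectors = np.linalg.eigh(weight.dot(weight.T))
directions = eigen_vectors.T[::-1]   # rows sorted from largest to smallest eigenvalue
print(directions[0].shape)           # (512,) -- one candidate semantic direction in W space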
e4e/environment/e4e_env.yaml
ADDED
@@ -0,0 +1,73 @@
name: e4e_env
channels:
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - ca-certificates=2020.4.5.1=hecc5488_0
  - certifi=2020.4.5.1=py36h9f0ad1d_0
  - libedit=3.1.20181209=hc058e9b_0
  - libffi=3.2.1=hd88cf55_4
  - libgcc-ng=9.1.0=hdf63c60_0
  - libstdcxx-ng=9.1.0=hdf63c60_0
  - ncurses=6.2=he6710b0_1
  - ninja=1.10.0=hc9558a2_0
  - openssl=1.1.1g=h516909a_0
  - pip=20.0.2=py36_3
  - python=3.6.7=h0371630_0
  - python_abi=3.6=1_cp36m
  - readline=7.0=h7b6447c_5
  - setuptools=46.4.0=py36_0
  - sqlite=3.31.1=h62c20be_1
  - tk=8.6.8=hbc83047_0
  - wheel=0.34.2=py36_0
  - xz=5.2.5=h7b6447c_0
  - zlib=1.2.11=h7b6447c_3
  - pip:
    - absl-py==0.9.0
    - cachetools==4.1.0
    - chardet==3.0.4
    - cycler==0.10.0
    - decorator==4.4.2
    - future==0.18.2
    - google-auth==1.15.0
    - google-auth-oauthlib==0.4.1
    - grpcio==1.29.0
    - idna==2.9
    - imageio==2.8.0
    - importlib-metadata==1.6.0
    - kiwisolver==1.2.0
    - markdown==3.2.2
    - matplotlib==3.2.1
    - mxnet==1.6.0
    - networkx==2.4
    - numpy==1.18.4
    - oauthlib==3.1.0
    - opencv-python==4.2.0.34
    - pillow==7.1.2
    - protobuf==3.12.1
    - pyasn1==0.4.8
    - pyasn1-modules==0.2.8
    - pyparsing==2.4.7
    - python-dateutil==2.8.1
    - pytorch-lightning==0.7.1
    - pywavelets==1.1.1
    - requests==2.23.0
    - requests-oauthlib==1.3.0
    - rsa==4.0
    - scikit-image==0.17.2
    - scipy==1.4.1
    - six==1.15.0
    - tensorboard==2.2.1
    - tensorboard-plugin-wit==1.6.0.post3
    - tensorboardx==1.9
    - tifffile==2020.5.25
    - torch==1.6.0
    - torchvision==0.7.1
    - tqdm==4.46.0
    - urllib3==1.25.9
    - werkzeug==1.0.1
    - zipp==3.1.0
    - pyaml
prefix: ~/anaconda3/envs/e4e_env
e4e/metrics/LEC.py
ADDED
@@ -0,0 +1,134 @@
import sys
import argparse
import torch
import numpy as np
from torch.utils.data import DataLoader

sys.path.append(".")
sys.path.append("..")

from configs import data_configs
from datasets.images_dataset import ImagesDataset
from utils.model_utils import setup_model


class LEC:
    def __init__(self, net, is_cars=False):
        """
        Latent Editing Consistency metric as proposed in the main paper.
        :param net: e4e model loaded over the pSp framework.
        :param is_cars: An indication as to whether or not to crop the middle of the StyleGAN's output images.
        """
        self.net = net
        self.is_cars = is_cars

    def _encode(self, images):
        """
        Encodes the given images into StyleGAN's latent space.
        :param images: Tensor of shape NxCxHxW representing the images to be encoded.
        :return: Tensor of shape NxKx512 representing the latent space embeddings of the given image (in W(K, *) space).
        """
        codes = self.net.encoder(images)
        assert codes.ndim == 3, f"Invalid latent codes shape, should be NxKx512 but is {codes.shape}"
        # normalize with respect to the center of an average face
        if self.net.opts.start_from_latent_avg:
            codes = codes + self.net.latent_avg.repeat(codes.shape[0], 1, 1)
        return codes

    def _generate(self, codes):
        """
        Generate the StyleGAN2 images of the given codes
        :param codes: Tensor of shape NxKx512 representing the StyleGAN's latent codes (in W(K, *) space).
        :return: Tensor of shape NxCxHxW representing the generated images.
        """
        images, _ = self.net.decoder([codes], input_is_latent=True, randomize_noise=False, return_latents=True)
        images = self.net.face_pool(images)
        if self.is_cars:
            images = images[:, :, 32:224, :]
        return images

    @staticmethod
    def _filter_outliers(arr):
        arr = np.array(arr)

        lo = np.percentile(arr, 1, interpolation="lower")
        hi = np.percentile(arr, 99, interpolation="higher")
        return np.extract(
            np.logical_and(lo <= arr, arr <= hi), arr
        )

    def calculate_metric(self, data_loader, edit_function, inverse_edit_function):
        """
        Calculate the LEC metric score.
        :param data_loader: An iterable that returns a tuple of (images, _), similar to the training data loader.
        :param edit_function: A function that receives latent codes and performs a semantically meaningful edit in the
                              latent space.
        :param inverse_edit_function: A function that receives latent codes and performs the inverse edit of the
                                      `edit_function` parameter.
        :return: The LEC metric score.
        """
        distances = []
        with torch.no_grad():
            for batch in data_loader:
                x, _ = batch
                inputs = x.to(device).float()

                codes = self._encode(inputs)
                edited_codes = edit_function(codes)
                edited_image = self._generate(edited_codes)
                edited_image_inversion_codes = self._encode(edited_image)
                inverse_edit_codes = inverse_edit_function(edited_image_inversion_codes)

                dist = (codes - inverse_edit_codes).norm(2, dim=(1, 2)).mean()
                distances.append(dist.to("cpu").numpy())

        distances = self._filter_outliers(distances)
        return distances.mean()


if __name__ == "__main__":
    device = "cuda"

    parser = argparse.ArgumentParser(description="LEC metric calculator")

    parser.add_argument("--batch", type=int, default=8, help="batch size for the models")
    parser.add_argument("--images_dir", type=str, default=None,
                        help="Path to the images directory on which we calculate the LEC score")
    parser.add_argument("ckpt", metavar="CHECKPOINT", help="path to the model checkpoints")

    args = parser.parse_args()
    print(args)

    net, opts = setup_model(args.ckpt, device)
    dataset_args = data_configs.DATASETS[opts.dataset_type]
    transforms_dict = dataset_args['transforms'](opts).get_transforms()

    images_directory = dataset_args['test_source_root'] if args.images_dir is None else args.images_dir
    test_dataset = ImagesDataset(source_root=images_directory,
                                 target_root=images_directory,
                                 source_transform=transforms_dict['transform_source'],
                                 target_transform=transforms_dict['transform_test'],
                                 opts=opts)

    data_loader = DataLoader(test_dataset,
                             batch_size=args.batch,
                             shuffle=False,
                             num_workers=2,
                             drop_last=True)

    print(f'dataset length: {len(test_dataset)}')

    # In the following example, we are using an InterfaceGAN based editing to calculate the LEC metric.
    # Change the provided example according to your domain and needs.
    direction = torch.load('../editings/interfacegan_directions/age.pt').to(device)

    def edit_func_example(codes):
        return codes + 3 * direction


    def inverse_edit_func_example(codes):
        return codes - 3 * direction

    lec = LEC(net, is_cars='car' in opts.dataset_type)
    result = lec.calculate_metric(data_loader, edit_func_example, inverse_edit_func_example)
    print(f"LEC: {result}")
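LEC measures how consistently the encoder round-trips an edit: encode, edit, decode, re-encode, undo the edit, and compare to the original codes with an L2 distance. The class is tied to the full pSp setup, but the data flow can be sketched with dummy encode/decode functions; everything below is illustrative and not part of the commit.

# Data-flow sketch of the LEC round trip with dummy encode/decode functions (illustration only).
import torch

def fake_encode(images):
    # stands in for LEC._encode: map each image to an N x 18 x 512 code (here: a crude slice of pixels)
    flat = images.reshape(images.shape[0], -1)
    return flat[:, :512].unsqueeze(1).repeat(1, 18, 1)

def fake_generate(codes):
    # stands in for LEC._generate: broadcast codes back into an image-shaped tensor
    return codes[:, 0, :3].reshape(-1, 3, 1, 1).expand(-1, 3, 256, 256)

direction = torch.randn(1, 18, 512) * 0.01
edit = lambda c: c + 3 * direction
inverse_edit = lambda c: c - 3 * direction

x = torch.rand(4, 3, 256, 256)
codes = fake_encode(x)
inverse_codes = inverse_edit(fake_encode(fake_generate(edit(codes))))
print(float((codes - inverse_codes).norm(2, dim=(1, 2)).mean()))  # per-image LEC distance, averaged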
e4e/models/__init__.py
ADDED
File without changes
e4e/models/discriminator.py
ADDED
@@ -0,0 +1,20 @@
from torch import nn


class LatentCodesDiscriminator(nn.Module):
    def __init__(self, style_dim, n_mlp):
        super().__init__()

        self.style_dim = style_dim

        layers = []
        for i in range(n_mlp-1):
            layers.append(
                nn.Linear(style_dim, style_dim)
            )
            layers.append(nn.LeakyReLU(0.2))
        layers.append(nn.Linear(512, 1))
        self.mlp = nn.Sequential(*layers)

    def forward(self, w):
        return self.mlp(w)
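LatentCodesDiscriminator is the small MLP that e4e trains adversarially on individual 512-dim w vectors so the encoder's predictions stay close to StyleGAN's native W distribution. Note the final layer hard-codes an input width of 512, so the module only behaves as intended when style_dim == 512. A quick shape check, as an illustration:

# Shape check for the latent-code discriminator above (illustration only).
import torch
from e4e.models.discriminator import LatentCodesDiscriminator

disc = LatentCodesDiscriminator(style_dim=512, n_mlp=4)

w = torch.randn(8, 512)   # a batch of single w vectors sampled from the mapping network
print(disc(w).shape)      # torch.Size([8, 1]) -- one realness score per latent code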
e4e/models/encoders/__init__.py
ADDED
File without changes
e4e/models/encoders/helpers.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from collections import namedtuple
import torch
import torch.nn.functional as F
from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module

"""
ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""


class Flatten(Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


def l2_norm(input, axis=1):
    norm = torch.norm(input, 2, axis, True)
    output = torch.div(input, norm)
    return output


class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
    """ A named tuple describing a ResNet block. """


def get_block(in_channel, depth, num_units, stride=2):
    return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]


def get_blocks(num_layers):
    if num_layers == 50:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=4),
            get_block(in_channel=128, depth=256, num_units=14),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 100:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=13),
            get_block(in_channel=128, depth=256, num_units=30),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 152:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=8),
            get_block(in_channel=128, depth=256, num_units=36),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    else:
        raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
    return blocks


class SEModule(Module):
    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        self.avg_pool = AdaptiveAvgPool2d(1)
        self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
        self.relu = ReLU(inplace=True)
        self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
        self.sigmoid = Sigmoid()

    def forward(self, x):
        module_input = x
        x = self.avg_pool(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return module_input * x


class bottleneck_IR(Module):
    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth)
            )
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
        )

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut


class bottleneck_IR_SE(Module):
    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR_SE, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth)
            )
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth),
            SEModule(depth, 16)
        )

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut


def _upsample_add(x, y):
    """Upsample and add two feature maps.
    Args:
      x: (Variable) top feature map to be upsampled.
      y: (Variable) lateral feature map.
    Returns:
      (Variable) added feature map.
    Note in PyTorch, when input size is odd, the upsampled feature map
    with `F.upsample(..., scale_factor=2, mode='nearest')`
    maybe not equal to the lateral feature map size.
    e.g.
    original input size: [N,_,15,15] ->
    conv2d feature map size: [N,_,8,8] ->
    upsampled feature map size: [N,_,16,16]
    So we choose bilinear upsample which supports arbitrary output sizes.
    """
    _, _, H, W = y.size()
    return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y
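For reference, a small sketch of what get_blocks() above returns (a hedged illustration, not part of the commit):

# Sketch: how get_blocks() describes the IR-50 trunk used by the encoders below.
from e4e.models.encoders.helpers import get_blocks

blocks = get_blocks(50)
print([len(b) for b in blocks])   # [3, 4, 14, 3] bottleneck units per stage
print(blocks[1][0])               # Bottleneck(in_channel=64, depth=128, stride=2)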
e4e/models/encoders/model_irse.py
ADDED
@@ -0,0 +1,84 @@
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
from e4e.models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm

"""
Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""


class Backbone(Module):
    def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
        super(Backbone, self).__init__()
        assert input_size in [112, 224], "input_size should be 112 or 224"
        assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
        assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE
        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
                                      BatchNorm2d(64),
                                      PReLU(64))
        if input_size == 112:
            self.output_layer = Sequential(BatchNorm2d(512),
                                           Dropout(drop_ratio),
                                           Flatten(),
                                           Linear(512 * 7 * 7, 512),
                                           BatchNorm1d(512, affine=affine))
        else:
            self.output_layer = Sequential(BatchNorm2d(512),
                                           Dropout(drop_ratio),
                                           Flatten(),
                                           Linear(512 * 14 * 14, 512),
                                           BatchNorm1d(512, affine=affine))

        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(unit_module(bottleneck.in_channel,
                                           bottleneck.depth,
                                           bottleneck.stride))
        self.body = Sequential(*modules)

    def forward(self, x):
        x = self.input_layer(x)
        x = self.body(x)
        x = self.output_layer(x)
        return l2_norm(x)


def IR_50(input_size):
    """Constructs a ir-50 model."""
    model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
    return model


def IR_101(input_size):
    """Constructs a ir-101 model."""
    model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
    return model


def IR_152(input_size):
    """Constructs a ir-152 model."""
    model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
    return model


def IR_SE_50(input_size):
    """Constructs a ir_se-50 model."""
    model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
    return model


def IR_SE_101(input_size):
    """Constructs a ir_se-101 model."""
    model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
    return model


def IR_SE_152(input_size):
    """Constructs a ir_se-152 model."""
    model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
    return model
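For orientation, a minimal sketch of how these constructors are typically used for identity features (not part of the commit; the 112x112 input size follows the assert in Backbone):

import torch
from e4e.models.encoders.model_irse import IR_SE_50

facenet = IR_SE_50(112).eval()          # ArcFace-style backbone, 512-d output
with torch.no_grad():
    emb = facenet(torch.randn(1, 3, 112, 112))
print(emb.shape)                        # torch.Size([1, 512]), L2-normalized by l2_norm()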
e4e/models/encoders/psp_encoders.py
ADDED
@@ -0,0 +1,200 @@
from enum import Enum
import math
import numpy as np
import torch
from torch import nn
from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module

from e4e.models.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE, _upsample_add
from e4e.models.stylegan2.model import EqualLinear


class ProgressiveStage(Enum):
    WTraining = 0
    Delta1Training = 1
    Delta2Training = 2
    Delta3Training = 3
    Delta4Training = 4
    Delta5Training = 5
    Delta6Training = 6
    Delta7Training = 7
    Delta8Training = 8
    Delta9Training = 9
    Delta10Training = 10
    Delta11Training = 11
    Delta12Training = 12
    Delta13Training = 13
    Delta14Training = 14
    Delta15Training = 15
    Delta16Training = 16
    Delta17Training = 17
    Inference = 18


class GradualStyleBlock(Module):
    def __init__(self, in_c, out_c, spatial):
        super(GradualStyleBlock, self).__init__()
        self.out_c = out_c
        self.spatial = spatial
        num_pools = int(np.log2(spatial))
        modules = []
        modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1),
                    nn.LeakyReLU()]
        for i in range(num_pools - 1):
            modules += [
                Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1),
                nn.LeakyReLU()
            ]
        self.convs = nn.Sequential(*modules)
        self.linear = EqualLinear(out_c, out_c, lr_mul=1)

    def forward(self, x):
        x = self.convs(x)
        x = x.view(-1, self.out_c)
        x = self.linear(x)
        return x


class GradualStyleEncoder(Module):
    def __init__(self, num_layers, mode='ir', opts=None):
        super(GradualStyleEncoder, self).__init__()
        assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE
        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
                                      BatchNorm2d(64),
                                      PReLU(64))
        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(unit_module(bottleneck.in_channel,
                                           bottleneck.depth,
                                           bottleneck.stride))
        self.body = Sequential(*modules)

        self.styles = nn.ModuleList()
        log_size = int(math.log(opts.stylegan_size, 2))
        self.style_count = 2 * log_size - 2
        self.coarse_ind = 3
        self.middle_ind = 7
        for i in range(self.style_count):
            if i < self.coarse_ind:
                style = GradualStyleBlock(512, 512, 16)
            elif i < self.middle_ind:
                style = GradualStyleBlock(512, 512, 32)
            else:
                style = GradualStyleBlock(512, 512, 64)
            self.styles.append(style)
        self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
        self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x = self.input_layer(x)

        latents = []
        modulelist = list(self.body._modules.values())
        for i, l in enumerate(modulelist):
            x = l(x)
            if i == 6:
                c1 = x
            elif i == 20:
                c2 = x
            elif i == 23:
                c3 = x

        for j in range(self.coarse_ind):
            latents.append(self.styles[j](c3))

        p2 = _upsample_add(c3, self.latlayer1(c2))
        for j in range(self.coarse_ind, self.middle_ind):
            latents.append(self.styles[j](p2))

        p1 = _upsample_add(p2, self.latlayer2(c1))
        for j in range(self.middle_ind, self.style_count):
            latents.append(self.styles[j](p1))

        out = torch.stack(latents, dim=1)
        return out


class Encoder4Editing(Module):
    def __init__(self, num_layers, mode='ir', opts=None):
        super(Encoder4Editing, self).__init__()
        assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE
        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
                                      BatchNorm2d(64),
                                      PReLU(64))
        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(unit_module(bottleneck.in_channel,
                                           bottleneck.depth,
                                           bottleneck.stride))
        self.body = Sequential(*modules)

        self.styles = nn.ModuleList()
        log_size = int(math.log(opts.stylegan_size, 2))
        self.style_count = 2 * log_size - 2
        self.coarse_ind = 3
        self.middle_ind = 7

        for i in range(self.style_count):
            if i < self.coarse_ind:
                style = GradualStyleBlock(512, 512, 16)
            elif i < self.middle_ind:
                style = GradualStyleBlock(512, 512, 32)
            else:
                style = GradualStyleBlock(512, 512, 64)
            self.styles.append(style)

        self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
        self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)

        self.progressive_stage = ProgressiveStage.Inference

    def get_deltas_starting_dimensions(self):
        ''' Get a list of the initial dimension of every delta from which it is applied '''
        return list(range(self.style_count))  # Each dimension has a delta applied to it

    def set_progressive_stage(self, new_stage: ProgressiveStage):
        self.progressive_stage = new_stage
        print('Changed progressive stage to: ', new_stage)

    def forward(self, x):
        x = self.input_layer(x)

        modulelist = list(self.body._modules.values())
        for i, l in enumerate(modulelist):
            x = l(x)
            if i == 6:
                c1 = x
            elif i == 20:
                c2 = x
            elif i == 23:
                c3 = x

        # Infer main W and duplicate it
        w0 = self.styles[0](c3)
        w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2)
        stage = self.progressive_stage.value
        features = c3
        for i in range(1, min(stage + 1, self.style_count)):  # Infer additional deltas
            if i == self.coarse_ind:
                p2 = _upsample_add(c3, self.latlayer1(c2))  # FPN's middle features
                features = p2
            elif i == self.middle_ind:
                p1 = _upsample_add(p2, self.latlayer2(c1))  # FPN's fine features
                features = p1
            delta_i = self.styles[i](features)
            w[:, i] += delta_i
        return w
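A minimal sketch of the encoder's output shape (not part of the commit; the opts namespace below is hypothetical and only needs the stylegan_size field read above):

from argparse import Namespace
import torch
from e4e.models.encoders.psp_encoders import Encoder4Editing

opts = Namespace(stylegan_size=1024)            # hypothetical options object
encoder = Encoder4Editing(50, 'ir_se', opts).eval()
with torch.no_grad():
    w = encoder(torch.randn(1, 3, 256, 256))
print(w.shape)                                  # torch.Size([1, 18, 512]): one w per StyleGAN layer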
e4e/models/latent_codes_pool.py
ADDED
@@ -0,0 +1,55 @@
import random
import torch


class LatentCodesPool:
    """This class implements a latent codes buffer that stores previously generated w latent codes.
    This buffer enables us to update discriminators using a history of generated w's
    rather than the ones produced by the latest encoder.
    """

    def __init__(self, pool_size):
        """Initialize the ImagePool class
        Parameters:
            pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
        """
        self.pool_size = pool_size
        if self.pool_size > 0:  # create an empty pool
            self.num_ws = 0
            self.ws = []

    def query(self, ws):
        """Return w's from the pool.
        Parameters:
            ws: the latest generated w's from the generator
        Returns w's from the buffer.
        With 50% probability, the buffer returns the input w's.
        With 50% probability, the buffer returns w's previously stored in the buffer,
        and inserts the current w's into the buffer.
        """
        if self.pool_size == 0:  # if the buffer size is 0, do nothing
            return ws
        return_ws = []
        for w in ws:  # ws.shape: (batch, 512) or (batch, n_latent, 512)
            # w = torch.unsqueeze(image.data, 0)
            if w.ndim == 2:
                i = random.randint(0, len(w) - 1)  # apply a random latent index as a candidate
                w = w[i]
            self.handle_w(w, return_ws)
        return_ws = torch.stack(return_ws, 0)  # collect all the images and return
        return return_ws

    def handle_w(self, w, return_ws):
        if self.num_ws < self.pool_size:  # if the buffer is not full; keep inserting current codes to the buffer
            self.num_ws = self.num_ws + 1
            self.ws.append(w)
            return_ws.append(w)
        else:
            p = random.uniform(0, 1)
            if p > 0.5:  # by 50% chance, the buffer will return a previously stored latent code, and insert the current code into the buffer
                random_id = random.randint(0, self.pool_size - 1)  # randint is inclusive
                tmp = self.ws[random_id].clone()
                self.ws[random_id] = w
                return_ws.append(tmp)
            else:  # by another 50% chance, the buffer will return the current code
                return_ws.append(w)
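A minimal usage sketch for the buffer (not part of the commit; shapes are illustrative):

import torch
from e4e.models.latent_codes_pool import LatentCodesPool

pool = LatentCodesPool(pool_size=50)
fake_w = torch.randn(4, 18, 512)        # a batch of freshly encoded w+ codes
w_for_disc = pool.query(fake_w)         # mix of current and previously stored codes
print(w_for_disc.shape)                 # torch.Size([4, 512]); a random layer is sampled per w+ code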
e4e/models/psp.py
ADDED
@@ -0,0 +1,99 @@
import matplotlib

matplotlib.use('Agg')
import torch
from torch import nn
from e4e.models.encoders import psp_encoders
from e4e.models.stylegan2.model import Generator
from e4e.configs.paths_config import model_paths


def get_keys(d, name):
    if 'state_dict' in d:
        d = d['state_dict']
    d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
    return d_filt


class pSp(nn.Module):

    def __init__(self, opts, device):
        super(pSp, self).__init__()
        self.opts = opts
        self.device = device
        # Define architecture
        self.encoder = self.set_encoder()
        self.decoder = Generator(opts.stylegan_size, 512, 8, channel_multiplier=2)
        self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
        # Load weights if needed
        self.load_weights()

    def set_encoder(self):
        if self.opts.encoder_type == 'GradualStyleEncoder':
            encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts)
        elif self.opts.encoder_type == 'Encoder4Editing':
            encoder = psp_encoders.Encoder4Editing(50, 'ir_se', self.opts)
        else:
            raise Exception('{} is not a valid encoder type'.format(self.opts.encoder_type))
        return encoder

    def load_weights(self):
        if self.opts.checkpoint_path is not None:
            print('Loading e4e over the pSp framework from checkpoint: {}'.format(self.opts.checkpoint_path))
            ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
            self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
            self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
            self.__load_latent_avg(ckpt)
        else:
            print('Loading encoders weights from irse50!')
            encoder_ckpt = torch.load(model_paths['ir_se50'])
            self.encoder.load_state_dict(encoder_ckpt, strict=False)
            print('Loading decoder weights from pretrained!')
            ckpt = torch.load(self.opts.stylegan_weights)
            self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
            self.__load_latent_avg(ckpt, repeat=self.encoder.style_count)

    def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
                inject_latent=None, return_latents=False, alpha=None):
        if input_code:
            codes = x
        else:
            codes = self.encoder(x)
            # normalize with respect to the center of an average face
            if self.opts.start_from_latent_avg:
                if codes.ndim == 2:
                    codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :]
                else:
                    codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)

        if latent_mask is not None:
            for i in latent_mask:
                if inject_latent is not None:
                    if alpha is not None:
                        codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
                    else:
                        codes[:, i] = inject_latent[:, i]
                else:
                    codes[:, i] = 0

        input_is_latent = not input_code
        images, result_latent = self.decoder([codes],
                                             input_is_latent=input_is_latent,
                                             randomize_noise=randomize_noise,
                                             return_latents=return_latents)

        if resize:
            images = self.face_pool(images)

        if return_latents:
            return images, result_latent
        else:
            return images

    def __load_latent_avg(self, ckpt, repeat=None):
        if 'latent_avg' in ckpt:
            self.latent_avg = ckpt['latent_avg'].to(self.device)
            if repeat is not None:
                self.latent_avg = self.latent_avg.repeat(repeat, 1)
        else:
            self.latent_avg = None
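For orientation, a minimal inference sketch for this wrapper (not part of the commit; the checkpoint filename and the exact option fields stored under ckpt['opts'] are assumptions):

import torch
from argparse import Namespace
from e4e.models.psp import pSp

ckpt_path = 'e4e_ffhq_encode.pt'                  # hypothetical checkpoint file
ckpt = torch.load(ckpt_path, map_location='cpu')
opts = ckpt['opts']                               # assumed to be a dict of training options
opts['checkpoint_path'] = ckpt_path
net = pSp(Namespace(**opts), device='cuda').eval().cuda()
with torch.no_grad():
    images, latents = net(torch.randn(1, 3, 256, 256).cuda(),
                          randomize_noise=False, return_latents=True)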
e4e/models/stylegan2/__init__.py
ADDED
File without changes
|
e4e/models/stylegan2/model.py
ADDED
@@ -0,0 +1,678 @@
import math
import random
import torch
from torch import nn
from torch.nn import functional as F

if torch.cuda.is_available():
    from op.fused_act import FusedLeakyReLU, fused_leaky_relu
    from op.upfirdn2d import upfirdn2d
else:
    from op.fused_act_cpu import FusedLeakyReLU, fused_leaky_relu
    from op.upfirdn2d_cpu import upfirdn2d


class PixelNorm(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)


def make_kernel(k):
    k = torch.tensor(k, dtype=torch.float32)

    if k.ndim == 1:
        k = k[None, :] * k[:, None]

    k /= k.sum()

    return k


class Upsample(nn.Module):
    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        kernel = make_kernel(kernel) * (factor ** 2)
        self.register_buffer('kernel', kernel)

        p = kernel.shape[0] - factor

        pad0 = (p + 1) // 2 + factor - 1
        pad1 = p // 2

        self.pad = (pad0, pad1)

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)

        return out


class Downsample(nn.Module):
    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        kernel = make_kernel(kernel)
        self.register_buffer('kernel', kernel)

        p = kernel.shape[0] - factor

        pad0 = (p + 1) // 2
        pad1 = p // 2

        self.pad = (pad0, pad1)

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)

        return out


class Blur(nn.Module):
    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()

        kernel = make_kernel(kernel)

        if upsample_factor > 1:
            kernel = kernel * (upsample_factor ** 2)

        self.register_buffer('kernel', kernel)

        self.pad = pad

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, pad=self.pad)

        return out


class EqualConv2d(nn.Module):
    def __init__(
            self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
    ):
        super().__init__()

        self.weight = nn.Parameter(
            torch.randn(out_channel, in_channel, kernel_size, kernel_size)
        )
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)

        self.stride = stride
        self.padding = padding

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel))

        else:
            self.bias = None

    def forward(self, input):
        out = F.conv2d(
            input,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

        return out

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
            f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
        )


class EqualLinear(nn.Module):
    def __init__(
            self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
    ):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))

        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        if self.activation:
            out = F.linear(input, self.weight * self.scale)
            out = fused_leaky_relu(out, self.bias * self.lr_mul)

        else:
            out = F.linear(
                input, self.weight * self.scale, bias=self.bias * self.lr_mul
            )

        return out

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
        )


class ScaledLeakyReLU(nn.Module):
    def __init__(self, negative_slope=0.2):
        super().__init__()

        self.negative_slope = negative_slope

    def forward(self, input):
        out = F.leaky_relu(input, negative_slope=self.negative_slope)

        return out * math.sqrt(2)


class ModulatedConv2d(nn.Module):
    def __init__(
            self,
            in_channel,
            out_channel,
            kernel_size,
            style_dim,
            demodulate=True,
            upsample=False,
            downsample=False,
            blur_kernel=[1, 3, 3, 1],
    ):
        super().__init__()

        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1

            self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            self.blur = Blur(blur_kernel, pad=(pad0, pad1))

        fan_in = in_channel * kernel_size ** 2
        self.scale = 1 / math.sqrt(fan_in)
        self.padding = kernel_size // 2

        self.weight = nn.Parameter(
            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
        )

        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)

        self.demodulate = demodulate

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
            f'upsample={self.upsample}, downsample={self.downsample})'
        )

    def forward(self, input, style):
        batch, in_channel, height, width = input.shape

        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
        weight = self.scale * self.weight * style

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
            weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)

        weight = weight.view(
            batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
        )

        if self.upsample:
            input = input.view(1, batch * in_channel, height, width)
            weight = weight.view(
                batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
            )
            weight = weight.transpose(1, 2).reshape(
                batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
            )
            out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)
            out = self.blur(out)

        elif self.downsample:
            input = self.blur(input)
            _, _, height, width = input.shape
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        else:
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=self.padding, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        return out


class NoiseInjection(nn.Module):
    def __init__(self):
        super().__init__()

        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, image, noise=None):
        if noise is None:
            batch, _, height, width = image.shape
            noise = image.new_empty(batch, 1, height, width).normal_()

        return image + self.weight * noise


class ConstantInput(nn.Module):
    def __init__(self, channel, size=4):
        super().__init__()

        self.input = nn.Parameter(torch.randn(1, channel, size, size))

    def forward(self, input):
        batch = input.shape[0]
        out = self.input.repeat(batch, 1, 1, 1)

        return out


class StyledConv(nn.Module):
    def __init__(
            self,
            in_channel,
            out_channel,
            kernel_size,
            style_dim,
            upsample=False,
            blur_kernel=[1, 3, 3, 1],
            demodulate=True,
    ):
        super().__init__()

        self.conv = ModulatedConv2d(
            in_channel,
            out_channel,
            kernel_size,
            style_dim,
            upsample=upsample,
            blur_kernel=blur_kernel,
            demodulate=demodulate,
        )

        self.noise = NoiseInjection()
        # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
        # self.activate = ScaledLeakyReLU(0.2)
        self.activate = FusedLeakyReLU(out_channel)

    def forward(self, input, style, noise=None):
        out = self.conv(input, style)
        out = self.noise(out, noise=noise)
        # out = out + self.bias
        out = self.activate(out)

        return out


class ToRGB(nn.Module):
    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        if upsample:
            self.upsample = Upsample(blur_kernel)

        self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, input, style, skip=None):
        out = self.conv(input, style)
        out = out + self.bias

        if skip is not None:
            skip = self.upsample(skip)

            out = out + skip

        return out


class Generator(nn.Module):
    def __init__(
            self,
            size,
            style_dim,
            n_mlp,
            channel_multiplier=2,
            blur_kernel=[1, 3, 3, 1],
            lr_mlp=0.01,
    ):
        super().__init__()

        self.size = size

        self.style_dim = style_dim

        layers = [PixelNorm()]

        for i in range(n_mlp):
            layers.append(
                EqualLinear(
                    style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
                )
            )

        self.style = nn.Sequential(*layers)

        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        self.input = ConstantInput(self.channels[4])
        self.conv1 = StyledConv(
            self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
        )
        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)

        self.log_size = int(math.log(size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1

        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channel = self.channels[4]

        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2 ** res, 2 ** res]
            self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))

        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2 ** i]

            self.convs.append(
                StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                )
            )

            self.convs.append(
                StyledConv(
                    out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
                )
            )

            self.to_rgbs.append(ToRGB(out_channel, style_dim))

            in_channel = out_channel

        self.n_latent = self.log_size * 2 - 2

    def make_noise(self):
        device = self.input.input.device

        noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))

        return noises

    def mean_latent(self, n_latent):
        latent_in = torch.randn(
            n_latent, self.style_dim, device=self.input.input.device
        )
        latent = self.style(latent_in).mean(0, keepdim=True)

        return latent

    def get_latent(self, input):
        return self.style(input)

    def forward(
            self,
            styles,
            return_latents=False,
            return_features=False,
            inject_index=None,
            truncation=1,
            truncation_latent=None,
            input_is_latent=False,
            noise=None,
            randomize_noise=True,
    ):
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
                ]

        if truncation < 1:
            style_t = []

            for style in styles:
                style_t.append(
                    truncation_latent + truncation * (style - truncation_latent)
                )

            styles = style_t

        if len(styles) < 2:
            inject_index = self.n_latent

            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:
                latent = styles[0]

        else:
            if inject_index is None:
                inject_index = random.randint(1, self.n_latent - 1)

            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)

            latent = torch.cat([latent, latent2], 1)

        out = self.input(latent)
        out = self.conv1(out, latent[:, 0], noise=noise[0])

        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(
                self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        ):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)

            i += 2

        image = skip

        if return_latents:
            return image, latent
        elif return_features:
            return image, out
        else:
            return image, None


class ConvLayer(nn.Sequential):
    def __init__(
            self,
            in_channel,
            out_channel,
            kernel_size,
            downsample=False,
            blur_kernel=[1, 3, 3, 1],
            bias=True,
            activate=True,
    ):
        layers = []

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))

            stride = 2
            self.padding = 0

        else:
            stride = 1
            self.padding = kernel_size // 2

        layers.append(
            EqualConv2d(
                in_channel,
                out_channel,
                kernel_size,
                padding=self.padding,
                stride=stride,
                bias=bias and not activate,
            )
        )

        if activate:
            if bias:
                layers.append(FusedLeakyReLU(out_channel))

            else:
                layers.append(ScaledLeakyReLU(0.2))

        super().__init__(*layers)


class ResBlock(nn.Module):
    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)

        self.skip = ConvLayer(
            in_channel, out_channel, 1, downsample=True, activate=False, bias=False
        )

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)

        skip = self.skip(input)
        out = (out + skip) / math.sqrt(2)

        return out


class Discriminator(nn.Module):
    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        convs = [ConvLayer(3, channels[size], 1)]

        log_size = int(math.log(size, 2))

        in_channel = channels[size]

        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]

            convs.append(ResBlock(in_channel, out_channel, blur_kernel))

            in_channel = out_channel

        self.convs = nn.Sequential(*convs)

        self.stddev_group = 4
        self.stddev_feat = 1

        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
        self.final_linear = nn.Sequential(
            EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
            EqualLinear(channels[4], 1),
        )

    def forward(self, input):
        out = self.convs(input)

        batch, channel, height, width = out.shape
        group = min(batch, self.stddev_group)
        stddev = out.view(
            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
        )
        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
        stddev = stddev.repeat(group, 1, height, width)
        out = torch.cat([out, stddev], 1)

        out = self.final_conv(out)

        out = out.view(batch, -1)
        out = self.final_linear(out)

        return out
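A minimal sampling sketch for the generator defined above (not part of the commit; it assumes a CUDA device since the fused ops are compiled as CUDA extensions, and the weights here are random):

import torch
from e4e.models.stylegan2.model import Generator

g = Generator(size=1024, style_dim=512, n_mlp=8).eval().cuda()
z = torch.randn(1, 512, device='cuda')              # latent code in Z
with torch.no_grad():
    img, w_plus = g([z], return_latents=True)       # img: (1, 3, 1024, 1024); w_plus: (1, 18, 512)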
e4e/models/stylegan2/op/__init__.py
ADDED
File without changes
|
e4e/models/stylegan2/op/fused_act.py
ADDED
@@ -0,0 +1,85 @@
import os

import torch
from torch import nn
from torch.autograd import Function
from torch.utils.cpp_extension import load

module_path = os.path.dirname(__file__)
fused = load(
    'fused',
    sources=[
        os.path.join(module_path, 'fused_bias_act.cpp'),
        os.path.join(module_path, 'fused_bias_act_kernel.cu'),
    ],
)


class FusedLeakyReLUFunctionBackward(Function):
    @staticmethod
    def forward(ctx, grad_output, out, negative_slope, scale):
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        empty = grad_output.new_empty(0)

        grad_input = fused.fused_bias_act(
            grad_output, empty, out, 3, 1, negative_slope, scale
        )

        dim = [0]

        if grad_input.ndim > 2:
            dim += list(range(2, grad_input.ndim))

        grad_bias = grad_input.sum(dim).detach()

        return grad_input, grad_bias

    @staticmethod
    def backward(ctx, gradgrad_input, gradgrad_bias):
        out, = ctx.saved_tensors
        gradgrad_out = fused.fused_bias_act(
            gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
        )

        return gradgrad_out, None, None, None


class FusedLeakyReLUFunction(Function):
    @staticmethod
    def forward(ctx, input, bias, negative_slope, scale):
        empty = input.new_empty(0)
        out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        return out

    @staticmethod
    def backward(ctx, grad_output):
        out, = ctx.saved_tensors

        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
            grad_output, out, ctx.negative_slope, ctx.scale
        )

        return grad_input, grad_bias, None, None


class FusedLeakyReLU(nn.Module):
    def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()

        self.bias = nn.Parameter(torch.zeros(channel))
        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)


def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
    return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
e4e/models/stylegan2/op/fused_bias_act.cpp
ADDED
@@ -0,0 +1,21 @@
#include <torch/extension.h>


torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
                                int act, int grad, float alpha, float scale);

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
                             int act, int grad, float alpha, float scale) {
    CHECK_CUDA(input);
    CHECK_CUDA(bias);

    return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
}
e4e/models/stylegan2/op/fused_bias_act_kernel.cu
ADDED
@@ -0,0 +1,99 @@
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html

#include <torch/types.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>

#include <cuda.h>
#include <cuda_runtime.h>


template <typename scalar_t>
static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
    int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
    int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;

    scalar_t zero = 0.0;

    for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
        scalar_t x = p_x[xi];

        if (use_bias) {
            x += p_b[(xi / step_b) % size_b];
        }

        scalar_t ref = use_ref ? p_ref[xi] : zero;

        scalar_t y;

        switch (act * 10 + grad) {
            default:
            case 10: y = x; break;
            case 11: y = x; break;
            case 12: y = 0.0; break;

            case 30: y = (x > 0.0) ? x : x * alpha; break;
            case 31: y = (ref > 0.0) ? x : x * alpha; break;
            case 32: y = 0.0; break;
        }

        out[xi] = y * scale;
    }
}


torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
                                int act, int grad, float alpha, float scale) {
    int curDevice = -1;
    cudaGetDevice(&curDevice);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

    auto x = input.contiguous();
    auto b = bias.contiguous();
    auto ref = refer.contiguous();

    int use_bias = b.numel() ? 1 : 0;
    int use_ref = ref.numel() ? 1 : 0;

    int size_x = x.numel();
    int size_b = b.numel();
    int step_b = 1;

    for (int i = 1 + 1; i < x.dim(); i++) {
        step_b *= x.size(i);
    }

    int loop_x = 4;
    int block_size = 4 * 32;
    int grid_size = (size_x - 1) / (loop_x * block_size) + 1;

    auto y = torch::empty_like(x);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
            y.data_ptr<scalar_t>(),
            x.data_ptr<scalar_t>(),
            b.data_ptr<scalar_t>(),
            ref.data_ptr<scalar_t>(),
            act,
            grad,
            alpha,
            scale,
            loop_x,
            size_x,
            step_b,
            size_b,
            use_bias,
            use_ref
        );
    });

    return y;
}
e4e/models/stylegan2/op/upfirdn2d.cpp
ADDED
@@ -0,0 +1,23 @@
#include <torch/extension.h>


torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
                           int up_x, int up_y, int down_x, int down_y,
                           int pad_x0, int pad_x1, int pad_y0, int pad_y1);

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
                        int up_x, int up_y, int down_x, int down_y,
                        int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
    CHECK_CUDA(input);
    CHECK_CUDA(kernel);

    return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}
e4e/models/stylegan2/op/upfirdn2d.py
ADDED
@@ -0,0 +1,184 @@
+import os
+
+import torch
+from torch.nn import functional as F  # needed by F.pad / F.conv2d in upfirdn2d_native below
+from torch.autograd import Function
+from torch.utils.cpp_extension import load
+
+module_path = os.path.dirname(__file__)
+upfirdn2d_op = load(
+    'upfirdn2d',
+    sources=[
+        os.path.join(module_path, 'upfirdn2d.cpp'),
+        os.path.join(module_path, 'upfirdn2d_kernel.cu'),
+    ],
+)
+
+
+class UpFirDn2dBackward(Function):
+    @staticmethod
+    def forward(
+        ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
+    ):
+        up_x, up_y = up
+        down_x, down_y = down
+        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
+
+        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
+
+        grad_input = upfirdn2d_op.upfirdn2d(
+            grad_output,
+            grad_kernel,
+            down_x,
+            down_y,
+            up_x,
+            up_y,
+            g_pad_x0,
+            g_pad_x1,
+            g_pad_y0,
+            g_pad_y1,
+        )
+        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
+
+        ctx.save_for_backward(kernel)
+
+        pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+        ctx.up_x = up_x
+        ctx.up_y = up_y
+        ctx.down_x = down_x
+        ctx.down_y = down_y
+        ctx.pad_x0 = pad_x0
+        ctx.pad_x1 = pad_x1
+        ctx.pad_y0 = pad_y0
+        ctx.pad_y1 = pad_y1
+        ctx.in_size = in_size
+        ctx.out_size = out_size
+
+        return grad_input
+
+    @staticmethod
+    def backward(ctx, gradgrad_input):
+        kernel, = ctx.saved_tensors
+
+        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
+
+        gradgrad_out = upfirdn2d_op.upfirdn2d(
+            gradgrad_input,
+            kernel,
+            ctx.up_x,
+            ctx.up_y,
+            ctx.down_x,
+            ctx.down_y,
+            ctx.pad_x0,
+            ctx.pad_x1,
+            ctx.pad_y0,
+            ctx.pad_y1,
+        )
+        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
+        gradgrad_out = gradgrad_out.view(
+            ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
+        )
+
+        return gradgrad_out, None, None, None, None, None, None, None, None
+
+
+class UpFirDn2d(Function):
+    @staticmethod
+    def forward(ctx, input, kernel, up, down, pad):
+        up_x, up_y = up
+        down_x, down_y = down
+        pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+        kernel_h, kernel_w = kernel.shape
+        batch, channel, in_h, in_w = input.shape
+        ctx.in_size = input.shape
+
+        input = input.reshape(-1, in_h, in_w, 1)
+
+        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
+
+        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+        ctx.out_size = (out_h, out_w)
+
+        ctx.up = (up_x, up_y)
+        ctx.down = (down_x, down_y)
+        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
+
+        g_pad_x0 = kernel_w - pad_x0 - 1
+        g_pad_y0 = kernel_h - pad_y0 - 1
+        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
+        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
+
+        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
+
+        out = upfirdn2d_op.upfirdn2d(
+            input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
+        )
+        # out = out.view(major, out_h, out_w, minor)
+        out = out.view(-1, channel, out_h, out_w)
+
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        kernel, grad_kernel = ctx.saved_tensors
+
+        grad_input = UpFirDn2dBackward.apply(
+            grad_output,
+            kernel,
+            grad_kernel,
+            ctx.up,
+            ctx.down,
+            ctx.pad,
+            ctx.g_pad,
+            ctx.in_size,
+            ctx.out_size,
+        )
+
+        return grad_input, None, None, None, None
+
+
+def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+    out = UpFirDn2d.apply(
+        input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
+    )
+
+    return out
+
+
+def upfirdn2d_native(
+    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
+):
+    _, in_h, in_w, minor = input.shape
+    kernel_h, kernel_w = kernel.shape
+
+    out = input.view(-1, in_h, 1, in_w, 1, minor)
+    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+    out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+    out = F.pad(
+        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
+    )
+    out = out[
+        :,
+        max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0),
+        max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0),
+        :,
+    ]
+
+    out = out.permute(0, 3, 1, 2)
+    out = out.reshape(
+        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
+    )
+    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+    out = F.conv2d(out, w)
+    out = out.reshape(
+        -1,
+        minor,
+        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+    )
+    out = out.permute(0, 2, 3, 1)
+
+    return out[:, ::down_y, ::down_x, :]
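Usage sketch (illustrative only, not part of the file above): the upfirdn2d wrapper is what StyleGAN2's Upsample/Blur layers call. The snippet assumes the extension compiles, a CUDA device, and that the op package is importable as models.stylegan2.op; the [1, 3, 3, 1] binomial kernel and (2, 1) padding are the standard 2x-upsampling setup, chosen here for illustration.

import torch
from models.stylegan2.op import upfirdn2d  # assumed import path

# Separable 4-tap binomial FIR kernel, normalized, then scaled so brightness
# is preserved after zero-insertion upsampling by a factor of 2.
k = torch.tensor([1.0, 3.0, 3.0, 1.0])
kernel = k[None, :] * k[:, None]
kernel = kernel / kernel.sum() * (2 ** 2)

x = torch.randn(1, 3, 64, 64, device='cuda')
y = upfirdn2d(x, kernel.to(x.device), up=2, down=1, pad=(2, 1))
print(y.shape)  # expected: torch.Size([1, 3, 128, 128])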
e4e/models/stylegan2/op/upfirdn2d_kernel.cu
ADDED
@@ -0,0 +1,272 @@
+// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, visit
+// https://nvlabs.github.io/stylegan2/license.html
+
+#include <torch/types.h>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+
+static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
+    int c = a / b;
+
+    if (c * b > a) {
+        c--;
+    }
+
+    return c;
+}
+
+
+struct UpFirDn2DKernelParams {
+    int up_x;
+    int up_y;
+    int down_x;
+    int down_y;
+    int pad_x0;
+    int pad_x1;
+    int pad_y0;
+    int pad_y1;
+
+    int major_dim;
+    int in_h;
+    int in_w;
+    int minor_dim;
+    int kernel_h;
+    int kernel_w;
+    int out_h;
+    int out_w;
+    int loop_major;
+    int loop_x;
+};
+
+
+template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
+__global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) {
+    const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
+    const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
+
+    __shared__ volatile float sk[kernel_h][kernel_w];
+    __shared__ volatile float sx[tile_in_h][tile_in_w];
+
+    int minor_idx = blockIdx.x;
+    int tile_out_y = minor_idx / p.minor_dim;
+    minor_idx -= tile_out_y * p.minor_dim;
+    tile_out_y *= tile_out_h;
+    int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
+    int major_idx_base = blockIdx.z * p.loop_major;
+
+    if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) {
+        return;
+    }
+
+    for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) {
+        int ky = tap_idx / kernel_w;
+        int kx = tap_idx - ky * kernel_w;
+        scalar_t v = 0.0;
+
+        if (kx < p.kernel_w & ky < p.kernel_h) {
+            v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
+        }
+
+        sk[ky][kx] = v;
+    }
+
+    for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) {
+        for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) {
+            int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
+            int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
+            int tile_in_x = floor_div(tile_mid_x, up_x);
+            int tile_in_y = floor_div(tile_mid_y, up_y);
+
+            __syncthreads();
+
+            for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) {
+                int rel_in_y = in_idx / tile_in_w;
+                int rel_in_x = in_idx - rel_in_y * tile_in_w;
+                int in_x = rel_in_x + tile_in_x;
+                int in_y = rel_in_y + tile_in_y;
+
+                scalar_t v = 0.0;
+
+                if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
+                    v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx];
+                }
+
+                sx[rel_in_y][rel_in_x] = v;
+            }
+
+            __syncthreads();
+            for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) {
+                int rel_out_y = out_idx / tile_out_w;
+                int rel_out_x = out_idx - rel_out_y * tile_out_w;
+                int out_x = rel_out_x + tile_out_x;
+                int out_y = rel_out_y + tile_out_y;
+
+                int mid_x = tile_mid_x + rel_out_x * down_x;
+                int mid_y = tile_mid_y + rel_out_y * down_y;
+                int in_x = floor_div(mid_x, up_x);
+                int in_y = floor_div(mid_y, up_y);
+                int rel_in_x = in_x - tile_in_x;
+                int rel_in_y = in_y - tile_in_y;
+                int kernel_x = (in_x + 1) * up_x - mid_x - 1;
+                int kernel_y = (in_y + 1) * up_y - mid_y - 1;
+
+                scalar_t v = 0.0;
+
+                #pragma unroll
+                for (int y = 0; y < kernel_h / up_y; y++)
+                    #pragma unroll
+                    for (int x = 0; x < kernel_w / up_x; x++)
+                        v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x];
+
+                if (out_x < p.out_w & out_y < p.out_h) {
+                    out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v;
+                }
+            }
+        }
+    }
+}
+
+
+torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
+                           int up_x, int up_y, int down_x, int down_y,
+                           int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
+    int curDevice = -1;
+    cudaGetDevice(&curDevice);
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+    UpFirDn2DKernelParams p;
+
+    auto x = input.contiguous();
+    auto k = kernel.contiguous();
+
+    p.major_dim = x.size(0);
+    p.in_h = x.size(1);
+    p.in_w = x.size(2);
+    p.minor_dim = x.size(3);
+    p.kernel_h = k.size(0);
+    p.kernel_w = k.size(1);
+    p.up_x = up_x;
+    p.up_y = up_y;
+    p.down_x = down_x;
+    p.down_y = down_y;
+    p.pad_x0 = pad_x0;
+    p.pad_x1 = pad_x1;
+    p.pad_y0 = pad_y0;
+    p.pad_y1 = pad_y1;
+
+    p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y;
+    p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x;
+
+    auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
+
+    int mode = -1;
+
+    int tile_out_h;
+    int tile_out_w;
+
+    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
+        mode = 1;
+        tile_out_h = 16;
+        tile_out_w = 64;
+    }
+
+    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) {
+        mode = 2;
+        tile_out_h = 16;
+        tile_out_w = 64;
+    }
+
+    if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
+        mode = 3;
+        tile_out_h = 16;
+        tile_out_w = 64;
+    }
+
+    if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) {
+        mode = 4;
+        tile_out_h = 16;
+        tile_out_w = 64;
+    }
+
+    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) {
+        mode = 5;
+        tile_out_h = 8;
+        tile_out_w = 32;
+    }
+
+    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) {
+        mode = 6;
+        tile_out_h = 8;
+        tile_out_w = 32;
+    }
+
+    dim3 block_size;
+    dim3 grid_size;
+
+    if (tile_out_h > 0 && tile_out_w) {
+        p.loop_major = (p.major_dim - 1) / 16384 + 1;
+        p.loop_x = 1;
+        block_size = dim3(32 * 8, 1, 1);
+        grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
+                         (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
+                         (p.major_dim - 1) / p.loop_major + 1);
+    }
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
+        switch (mode) {
+        case 1:
+            upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
+                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
+            );
+
+            break;
+
+        case 2:
+            upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64><<<grid_size, block_size, 0, stream>>>(
+                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
+            );
+
+            break;
+
+        case 3:
+            upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
+                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
+            );
+
+            break;
+
+        case 4:
+            upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64><<<grid_size, block_size, 0, stream>>>(
+                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
+            );
+
+            break;
+
+        case 5:
+            upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
+                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
+            );
+
+            break;
+
+        case 6:
+            upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
+                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
+            );
+
+            break;
+        }
+    });
+
+    return out;
+}
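Note (editorial, not part of the file above): the extension operates on a (major, in_h, in_w, minor) layout, and its output size (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) / down_y equals the Python wrapper's (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1. A consistency-check sketch against the pure-PyTorch upfirdn2d_native reference in upfirdn2d.py; the import path, sizes, and tolerance are assumptions for illustration.

import torch
from models.stylegan2.op.upfirdn2d import upfirdn2d_op, upfirdn2d_native  # assumed import path

x = torch.randn(6, 16, 16, 1, device='cuda')  # (major, in_h, in_w, minor) layout expected by the op
k = torch.randn(3, 3, device='cuda')

# up=1, down=1, pad=1 on every side -> mode 2 (kernel_h, kernel_w <= 3) in the dispatcher above
out = upfirdn2d_op.upfirdn2d(x, k, 1, 1, 1, 1, 1, 1, 1, 1)
ref = upfirdn2d_native(x, k, 1, 1, 1, 1, 1, 1, 1, 1)
print(out.shape, torch.allclose(out, ref, atol=1e-4))  # expected: torch.Size([6, 16, 16, 1]) True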
e4e/notebooks/images/car_img.jpg
ADDED
e4e/notebooks/images/church_img.jpg
ADDED
e4e/notebooks/images/horse_img.jpg
ADDED
e4e/notebooks/images/input_img.jpg
ADDED
e4e/options/__init__.py
ADDED
File without changes
e4e/options/train_options.py
ADDED
@@ -0,0 +1,84 @@
+from argparse import ArgumentParser
+from configs.paths_config import model_paths
+
+
+class TrainOptions:
+
+    def __init__(self):
+        self.parser = ArgumentParser()
+        self.initialize()
+
+    def initialize(self):
+        self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory')
+        self.parser.add_argument('--dataset_type', default='ffhq_encode', type=str,
+                                 help='Type of dataset/experiment to run')
+        self.parser.add_argument('--encoder_type', default='Encoder4Editing', type=str, help='Which encoder to use')
+
+        self.parser.add_argument('--batch_size', default=4, type=int, help='Batch size for training')
+        self.parser.add_argument('--test_batch_size', default=2, type=int, help='Batch size for testing and inference')
+        self.parser.add_argument('--workers', default=4, type=int, help='Number of train dataloader workers')
+        self.parser.add_argument('--test_workers', default=2, type=int,
+                                 help='Number of test/inference dataloader workers')
+
+        self.parser.add_argument('--learning_rate', default=0.0001, type=float, help='Optimizer learning rate')
+        self.parser.add_argument('--optim_name', default='ranger', type=str, help='Which optimizer to use')
+        self.parser.add_argument('--train_decoder', default=False, type=bool, help='Whether to train the decoder model')
+        self.parser.add_argument('--start_from_latent_avg', action='store_true',
+                                 help='Whether to add average latent vector to generate codes from encoder.')
+        self.parser.add_argument('--lpips_type', default='alex', type=str, help='LPIPS backbone')
+
+        self.parser.add_argument('--lpips_lambda', default=0.8, type=float, help='LPIPS loss multiplier factor')
+        self.parser.add_argument('--id_lambda', default=0.1, type=float, help='ID loss multiplier factor')
+        self.parser.add_argument('--l2_lambda', default=1.0, type=float, help='L2 loss multiplier factor')
+
+        self.parser.add_argument('--stylegan_weights', default=model_paths['stylegan_ffhq'], type=str,
+                                 help='Path to StyleGAN model weights')
+        self.parser.add_argument('--stylegan_size', default=1024, type=int,
+                                 help='size of pretrained StyleGAN Generator')
+        self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to pSp model checkpoint')
+
+        self.parser.add_argument('--max_steps', default=500000, type=int, help='Maximum number of training steps')
+        self.parser.add_argument('--image_interval', default=100, type=int,
+                                 help='Interval for logging train images during training')
+        self.parser.add_argument('--board_interval', default=50, type=int,
+                                 help='Interval for logging metrics to tensorboard')
+        self.parser.add_argument('--val_interval', default=1000, type=int, help='Validation interval')
+        self.parser.add_argument('--save_interval', default=None, type=int, help='Model checkpoint interval')
+
+        # Discriminator flags
+        self.parser.add_argument('--w_discriminator_lambda', default=0, type=float, help='Dw loss multiplier')
+        self.parser.add_argument('--w_discriminator_lr', default=2e-5, type=float, help='Dw learning rate')
+        self.parser.add_argument("--r1", type=float, default=10, help="weight of the r1 regularization")
+        self.parser.add_argument("--d_reg_every", type=int, default=16,
+                                 help="interval for applying r1 regularization")
+        self.parser.add_argument('--use_w_pool', action='store_true',
+                                 help='Whether to store a latent codes pool for the discriminator\'s training')
+        self.parser.add_argument("--w_pool_size", type=int, default=50,
+                                 help="W\'s pool size, depends on --use_w_pool")
+
+        # e4e specific
+        self.parser.add_argument('--delta_norm', type=int, default=2, help="norm type of the deltas")
+        self.parser.add_argument('--delta_norm_lambda', type=float, default=2e-4, help="lambda for delta norm loss")
+
+        # Progressive training
+        self.parser.add_argument('--progressive_steps', nargs='+', type=int, default=None,
+                                 help="The training steps of training new deltas. steps[i] starts the delta_i training")
+        self.parser.add_argument('--progressive_start', type=int, default=None,
+                                 help="The training step to start training the deltas, overrides progressive_steps")
+        self.parser.add_argument('--progressive_step_every', type=int, default=2_000,
+                                 help="Amount of training steps for each progressive step")
+
+        # Save additional training info to enable future training continuation from produced checkpoints
+        self.parser.add_argument('--save_training_data', action='store_true',
+                                 help='Save intermediate training data to resume training from the checkpoint')
+        self.parser.add_argument('--sub_exp_dir', default=None, type=str, help='Name of sub experiment directory')
+        self.parser.add_argument('--keep_optimizer', action='store_true',
+                                 help='Whether to continue from the checkpoint\'s optimizer')
+        self.parser.add_argument('--resume_training_from_ckpt', default=None, type=str,
+                                 help='Path to training checkpoint, works when --save_training_data was set to True')
+        self.parser.add_argument('--update_param_list', nargs='+', type=str, default=None,
+                                 help="Name of training parameters to update the loaded training checkpoint")
+
+    def parse(self):
+        opts = self.parser.parse_args()
+        return opts
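Usage sketch (illustrative only): these flags are consumed by e4e's training entry point, typically invoked along the lines below; the command-line values shown are example settings, not defaults from this file, and the paths are assumptions.

# python scripts/train.py --exp_dir experiments/ffhq_e4e --dataset_type ffhq_encode \
#     --start_from_latent_avg --w_discriminator_lambda 0.1 --progressive_start 20000

from options.train_options import TrainOptions  # assumed import path, as used by the training script

opts = TrainOptions().parse()  # argparse.Namespace holding the defaults defined above
print(opts.encoder_type, opts.batch_size, opts.lpips_lambda)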