diff --git a/.gitattributes b/.gitattributes
index 818d649bf21cdef29b21f885c8f770f9baa1714e..957b2579c6ef20995a09efd9a17f8fd90606f5ed 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -1,6 +1,7 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
+*.bin.* filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
@@ -9,13 +10,9 @@
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
@@ -24,8 +21,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
+*.zstandard filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..464c587b66b0cdb32019704a37e90e9a4252c531
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 Min Jin Chong
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index ebaf1fe08d91e6028e030ca148b0fd75830e2ed1..98158fd9dbff2afc2f0d207cfbd825bf48a31844 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,38 @@
---
-title: JoJoGan Powerhow2
-emoji: 📚
-colorFrom: red
-colorTo: blue
+title: JoJoGAN
+emoji: 🌍
+colorFrom: green
+colorTo: yellow
sdk: gradio
-sdk_version: 3.2
+sdk_version: 3.1.1
app_file: app.py
pinned: false
---
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# Configuration
+
+`title`: _string_
+Display title for the Space
+
+`emoji`: _string_
+Space emoji (emoji-only character allowed)
+
+`colorFrom`: _string_
+Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+`colorTo`: _string_
+Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+`sdk`: _string_
+Can be either `gradio` or `streamlit`
+
+`sdk_version`: _string_
+Only applicable for `streamlit` SDK.
+See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
+
+`app_file`: _string_
+Path to your main application file (which contains either `gradio` or `streamlit` Python code).
+Path is relative to the root of the repository.
+
+`pinned`: _boolean_
+Whether the Space stays on top of your list.
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..df2814cae8ab12b97c33c34c03a6498eb703d0e9
--- /dev/null
+++ b/app.py
@@ -0,0 +1,204 @@
+import os
+import sys
+import math
+import random
+from argparse import Namespace
+from copy import deepcopy
+
+import numpy as np
+import imageio
+from PIL import Image
+from tqdm import tqdm
+
+import torch
+from torch import nn, autograd, optim
+from torch.nn import functional as F
+from torchvision import transforms, utils
+import lpips
+import gradio as gr
+from huggingface_hub import hf_hub_download
+
+from model import *
+from util import *
+from e4e.models.psp import pSp
+# from e4e_projection import projection as e4e_projection  # replaced by the local projection() below
+
+torch.backends.cudnn.benchmark = True
+
+device = 'cpu'
+
+# Load the pretrained e4e FFHQ encoder (pSp framework) from the Hugging Face Hub.
+model_path_e = hf_hub_download(repo_id="akhaliq/JoJoGAN_e4e_ffhq_encode", filename="e4e_ffhq_encode.pt")
+ckpt = torch.load(model_path_e, map_location='cpu')
+opts = ckpt['opts']
+opts['checkpoint_path'] = model_path_e
+opts = Namespace(**opts)
+net = pSp(opts, device).eval().to(device)
+
+@torch.no_grad()
+def projection(img, name, device='cuda'):
+    # Encode a face image into the e4e W+ latent space and cache the latent to `name`.
+ transform = transforms.Compose(
+ [
+ transforms.Resize(256),
+ transforms.CenterCrop(256),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
+ ]
+ )
+ img = transform(img).unsqueeze(0).to(device)
+ images, w_plus = net(img, randomize_noise=False, return_latents=True)
+ result_file = {}
+ result_file['latent'] = w_plus[0]
+ torch.save(result_file, name)
+ return w_plus[0]
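+# Usage sketch (aligned_face is a PIL image; 'my_face.pt' is a hypothetical cache path):
+#   w_plus = projection(aligned_face, 'my_face.pt', device)  # -> tensor of shape (18, 512)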
+
+
+
+
+device = 'cpu'
+
+
+latent_dim = 512
+
+model_path_s = hf_hub_download(repo_id="akhaliq/jojogan-stylegan2-ffhq-config-f", filename="stylegan2-ffhq-config-f.pt")
+original_generator = Generator(1024, latent_dim, 8, 2).to(device)
+ckpt = torch.load(model_path_s, map_location=lambda storage, loc: storage)
+original_generator.load_state_dict(ckpt["g_ema"], strict=False)
+mean_latent = original_generator.mean_latent(10000)
+
+generatorjojo = deepcopy(original_generator)
+
+generatordisney = deepcopy(original_generator)
+
+generatorjinx = deepcopy(original_generator)
+
+generatorcaitlyn = deepcopy(original_generator)
+
+generatoryasuho = deepcopy(original_generator)
+
+generatorarcanemulti = deepcopy(original_generator)
+
+generatorart = deepcopy(original_generator)
+
+generatorspider = deepcopy(original_generator)
+
+generatorsketch = deepcopy(original_generator)
+
+
+transform = transforms.Compose(
+ [
+ transforms.Resize((1024, 1024)),
+ transforms.ToTensor(),
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+ ]
+)
+
+
+
+
+modeljojo = hf_hub_download(repo_id="akhaliq/JoJoGAN-jojo", filename="jojo_preserve_color.pt")
+
+
+ckptjojo = torch.load(modeljojo, map_location=lambda storage, loc: storage)
+generatorjojo.load_state_dict(ckptjojo["g"], strict=False)
+
+
+modeldisney = hf_hub_download(repo_id="akhaliq/jojogan-disney", filename="disney_preserve_color.pt")
+
+ckptdisney = torch.load(modeldisney, map_location=lambda storage, loc: storage)
+generatordisney.load_state_dict(ckptdisney["g"], strict=False)
+
+
+modeljinx = hf_hub_download(repo_id="akhaliq/jojo-gan-jinx", filename="arcane_jinx_preserve_color.pt")
+
+ckptjinx = torch.load(modeljinx, map_location=lambda storage, loc: storage)
+generatorjinx.load_state_dict(ckptjinx["g"], strict=False)
+
+
+modelcaitlyn = hf_hub_download(repo_id="akhaliq/jojogan-arcane", filename="arcane_caitlyn_preserve_color.pt")
+
+ckptcaitlyn = torch.load(modelcaitlyn, map_location=lambda storage, loc: storage)
+generatorcaitlyn.load_state_dict(ckptcaitlyn["g"], strict=False)
+
+
+modelyasuho = hf_hub_download(repo_id="akhaliq/JoJoGAN-jojo", filename="jojo_yasuho_preserve_color.pt")
+
+ckptyasuho = torch.load(modelyasuho, map_location=lambda storage, loc: storage)
+generatoryasuho.load_state_dict(ckptyasuho["g"], strict=False)
+
+
+model_arcane_multi = hf_hub_download(repo_id="akhaliq/jojogan-arcane", filename="arcane_multi_preserve_color.pt")
+
+ckptarcanemulti = torch.load(model_arcane_multi, map_location=lambda storage, loc: storage)
+generatorarcanemulti.load_state_dict(ckptarcanemulti["g"], strict=False)
+
+
+modelart = hf_hub_download(repo_id="akhaliq/jojo-gan-art", filename="art.pt")
+
+ckptart = torch.load(modelart, map_location=lambda storage, loc: storage)
+generatorart.load_state_dict(ckptart["g"], strict=False)
+
+
+modelSpiderverse = hf_hub_download(repo_id="akhaliq/jojo-gan-spiderverse", filename="Spiderverse-face-500iters-8face.pt")
+
+ckptspider = torch.load(modelSpiderverse, map_location=lambda storage, loc: storage)
+generatorspider.load_state_dict(ckptspider["g"], strict=False)
+
+modelSketch = hf_hub_download(repo_id="akhaliq/jojogan-sketch", filename="sketch_multi.pt")
+
+ckptsketch = torch.load(modelSketch, map_location=lambda storage, loc: storage)
+generatorsketch.load_state_dict(ckptsketch["g"], strict=False)
+
+def inference(img, model):
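+    # Pipeline sketch: save the uploaded PIL image, align/crop the face with align_face
+    # (from util), project it into e4e W+ space, then decode with the fine-tuned
+    # StyleGAN2 generator selected by `model`.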
+ img.save('out.jpg')
+ aligned_face = align_face('out.jpg')
+
+ my_w = projection(aligned_face, "test.pt", device).unsqueeze(0)
+ if model == 'JoJo':
+ with torch.no_grad():
+ my_sample = generatorjojo(my_w, input_is_latent=True)
+ elif model == 'Disney':
+ with torch.no_grad():
+ my_sample = generatordisney(my_w, input_is_latent=True)
+ elif model == 'Jinx':
+ with torch.no_grad():
+ my_sample = generatorjinx(my_w, input_is_latent=True)
+ elif model == 'Caitlyn':
+ with torch.no_grad():
+ my_sample = generatorcaitlyn(my_w, input_is_latent=True)
+ elif model == 'Yasuho':
+ with torch.no_grad():
+ my_sample = generatoryasuho(my_w, input_is_latent=True)
+ elif model == 'Arcane Multi':
+ with torch.no_grad():
+ my_sample = generatorarcanemulti(my_w, input_is_latent=True)
+ elif model == 'Art':
+ with torch.no_grad():
+ my_sample = generatorart(my_w, input_is_latent=True)
+ elif model == 'Spider-Verse':
+ with torch.no_grad():
+ my_sample = generatorspider(my_w, input_is_latent=True)
+ else:
+ with torch.no_grad():
+ my_sample = generatorsketch(my_w, input_is_latent=True)
+
+
+ npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
+ imageio.imwrite('filename.jpeg', npimage)
+ return 'filename.jpeg'
+
+title = "JoJoGAN"
+description = "Gradio Demo for JoJoGAN: One Shot Face Stylization. To use it, simply upload your image or click one of the examples to load it. Read more at the links below."
+
+article = "JoJoGAN: One Shot Face Stylization | Github Repo Pytorch"
+
+examples = [['mona.png', 'Jinx']]
+gr.Interface(
+    inference,
+    [gr.inputs.Image(type="pil"),
+     gr.inputs.Dropdown(choices=['JoJo', 'Disney', 'Jinx', 'Caitlyn', 'Yasuho',
+                                 'Arcane Multi', 'Art', 'Spider-Verse', 'Sketch'],
+                        type="value", default='JoJo', label="Model")],
+    gr.outputs.Image(type="file"),
+    title=title, description=description, article=article,
+    allow_flagging=False, examples=examples, allow_screenshot=False,
+).launch()
diff --git a/e4e/.gitignore b/e4e/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..b6e47617de110dea7ca47e087ff1347cc2646eda
--- /dev/null
+++ b/e4e/.gitignore
@@ -0,0 +1,129 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
diff --git a/e4e/criteria/__init__.py b/e4e/criteria/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/e4e/criteria/id_loss.py b/e4e/criteria/id_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..bab806172eff18c0630536ae96817508c3197b8b
--- /dev/null
+++ b/e4e/criteria/id_loss.py
@@ -0,0 +1,47 @@
+import torch
+from torch import nn
+from configs.paths_config import model_paths
+from models.encoders.model_irse import Backbone
+
+
+class IDLoss(nn.Module):
+ def __init__(self):
+ super(IDLoss, self).__init__()
+ print('Loading ResNet ArcFace')
+ self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
+ self.facenet.load_state_dict(torch.load(model_paths['ir_se50']))
+ self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
+ self.facenet.eval()
+ for module in [self.facenet, self.face_pool]:
+ for param in module.parameters():
+ param.requires_grad = False
+
+ def extract_feats(self, x):
+ x = x[:, :, 35:223, 32:220] # Crop interesting region
+ x = self.face_pool(x)
+ x_feats = self.facenet(x)
+ return x_feats
+
+ def forward(self, y_hat, y, x):
+ n_samples = x.shape[0]
+ x_feats = self.extract_feats(x)
+        y_feats = self.extract_feats(y)  # features of the target images
+ y_hat_feats = self.extract_feats(y_hat)
+ y_feats = y_feats.detach()
+ loss = 0
+ sim_improvement = 0
+ id_logs = []
+ count = 0
+ for i in range(n_samples):
+ diff_target = y_hat_feats[i].dot(y_feats[i])
+ diff_input = y_hat_feats[i].dot(x_feats[i])
+ diff_views = y_feats[i].dot(x_feats[i])
+ id_logs.append({'diff_target': float(diff_target),
+ 'diff_input': float(diff_input),
+ 'diff_views': float(diff_views)})
+ loss += 1 - diff_target
+ id_diff = float(diff_target) - float(diff_views)
+ sim_improvement += id_diff
+ count += 1
+
+ return loss / count, sim_improvement / count, id_logs
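+# Note: the Backbone in models/encoders/model_irse.py returns l2-normalized embeddings,
+# so each dot product above is a cosine similarity and the per-sample loss is 1 - cos(y_hat, y).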
diff --git a/e4e/criteria/lpips/__init__.py b/e4e/criteria/lpips/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/e4e/criteria/lpips/lpips.py b/e4e/criteria/lpips/lpips.py
new file mode 100644
index 0000000000000000000000000000000000000000..1add6acc84c1c04cfcb536cf31ec5acdf24b716b
--- /dev/null
+++ b/e4e/criteria/lpips/lpips.py
@@ -0,0 +1,35 @@
+import torch
+import torch.nn as nn
+
+from criteria.lpips.networks import get_network, LinLayers
+from criteria.lpips.utils import get_state_dict
+
+
+class LPIPS(nn.Module):
+ r"""Creates a criterion that measures
+ Learned Perceptual Image Patch Similarity (LPIPS).
+ Arguments:
+ net_type (str): the network type to compare the features:
+ 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
+ version (str): the version of LPIPS. Default: 0.1.
+ """
+ def __init__(self, net_type: str = 'alex', version: str = '0.1'):
+
+ assert version in ['0.1'], 'v0.1 is only supported now'
+
+ super(LPIPS, self).__init__()
+
+ # pretrained network
+ self.net = get_network(net_type).to("cuda")
+
+ # linear layers
+ self.lin = LinLayers(self.net.n_channels_list).to("cuda")
+ self.lin.load_state_dict(get_state_dict(net_type, version))
+
+ def forward(self, x: torch.Tensor, y: torch.Tensor):
+ feat_x, feat_y = self.net(x), self.net(y)
+
+ diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
+ res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
+
+ return torch.sum(torch.cat(res, 0)) / x.shape[0]
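+# Minimal usage sketch (a GPU is assumed, matching the hard-coded .to("cuda") above):
+#   criterion = LPIPS(net_type='vgg')
+#   distance = criterion(img_a, img_b)  # img_a, img_b: Nx3xHxW tensors scaled to [-1, 1]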
diff --git a/e4e/criteria/lpips/networks.py b/e4e/criteria/lpips/networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a0d13ad2d560278f16586da68d3a5eadb26e746
--- /dev/null
+++ b/e4e/criteria/lpips/networks.py
@@ -0,0 +1,96 @@
+from typing import Sequence
+
+from itertools import chain
+
+import torch
+import torch.nn as nn
+from torchvision import models
+
+from criteria.lpips.utils import normalize_activation
+
+
+def get_network(net_type: str):
+ if net_type == 'alex':
+ return AlexNet()
+ elif net_type == 'squeeze':
+ return SqueezeNet()
+ elif net_type == 'vgg':
+ return VGG16()
+ else:
+ raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
+
+
+class LinLayers(nn.ModuleList):
+ def __init__(self, n_channels_list: Sequence[int]):
+ super(LinLayers, self).__init__([
+ nn.Sequential(
+ nn.Identity(),
+ nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
+ ) for nc in n_channels_list
+ ])
+
+ for param in self.parameters():
+ param.requires_grad = False
+
+
+class BaseNet(nn.Module):
+ def __init__(self):
+ super(BaseNet, self).__init__()
+
+ # register buffer
+ self.register_buffer(
+ 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
+ self.register_buffer(
+ 'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
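+        # These are the LPIPS "scaling layer" constants: inputs (expected in [-1, 1])
+        # are shifted/scaled before being fed to the ImageNet-pretrained backbone.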
+
+ def set_requires_grad(self, state: bool):
+ for param in chain(self.parameters(), self.buffers()):
+ param.requires_grad = state
+
+ def z_score(self, x: torch.Tensor):
+ return (x - self.mean) / self.std
+
+ def forward(self, x: torch.Tensor):
+ x = self.z_score(x)
+
+ output = []
+ for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
+ x = layer(x)
+ if i in self.target_layers:
+ output.append(normalize_activation(x))
+ if len(output) == len(self.target_layers):
+ break
+ return output
+
+
+class SqueezeNet(BaseNet):
+ def __init__(self):
+ super(SqueezeNet, self).__init__()
+
+ self.layers = models.squeezenet1_1(True).features
+ self.target_layers = [2, 5, 8, 10, 11, 12, 13]
+ self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
+
+ self.set_requires_grad(False)
+
+
+class AlexNet(BaseNet):
+ def __init__(self):
+ super(AlexNet, self).__init__()
+
+ self.layers = models.alexnet(True).features
+ self.target_layers = [2, 5, 8, 10, 12]
+ self.n_channels_list = [64, 192, 384, 256, 256]
+
+ self.set_requires_grad(False)
+
+
+class VGG16(BaseNet):
+ def __init__(self):
+ super(VGG16, self).__init__()
+
+ self.layers = models.vgg16(True).features
+ self.target_layers = [4, 9, 16, 23, 30]
+ self.n_channels_list = [64, 128, 256, 512, 512]
+
+ self.set_requires_grad(False)
\ No newline at end of file
diff --git a/e4e/criteria/lpips/utils.py b/e4e/criteria/lpips/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d15a0983775810ef6239c561c67939b2b9ee3b5
--- /dev/null
+++ b/e4e/criteria/lpips/utils.py
@@ -0,0 +1,30 @@
+from collections import OrderedDict
+
+import torch
+
+
+def normalize_activation(x, eps=1e-10):
+ norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
+ return x / (norm_factor + eps)
+
+
+def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
+ # build url
+ url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
+ + f'master/lpips/weights/v{version}/{net_type}.pth'
+
+ # download
+ old_state_dict = torch.hub.load_state_dict_from_url(
+ url, progress=True,
+ map_location=None if torch.cuda.is_available() else torch.device('cpu')
+ )
+
+ # rename keys
+ new_state_dict = OrderedDict()
+ for key, val in old_state_dict.items():
+ new_key = key
+ new_key = new_key.replace('lin', '')
+ new_key = new_key.replace('model.', '')
+ new_state_dict[new_key] = val
+
+ return new_state_dict
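+# The renaming above turns keys such as 'lin0.model.1.weight' from the official LPIPS
+# weight files into '0.1.weight', matching the LinLayers ModuleList in networks.py.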
diff --git a/e4e/criteria/moco_loss.py b/e4e/criteria/moco_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fb13fbd426202cff9014c876c85b0d5c4ec6a9d
--- /dev/null
+++ b/e4e/criteria/moco_loss.py
@@ -0,0 +1,71 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from configs.paths_config import model_paths
+
+
+class MocoLoss(nn.Module):
+
+ def __init__(self, opts):
+ super(MocoLoss, self).__init__()
+ print("Loading MOCO model from path: {}".format(model_paths["moco"]))
+ self.model = self.__load_model()
+ self.model.eval()
+ for param in self.model.parameters():
+ param.requires_grad = False
+
+ @staticmethod
+ def __load_model():
+ import torchvision.models as models
+ model = models.__dict__["resnet50"]()
+ # freeze all layers but the last fc
+ for name, param in model.named_parameters():
+ if name not in ['fc.weight', 'fc.bias']:
+ param.requires_grad = False
+ checkpoint = torch.load(model_paths['moco'], map_location="cpu")
+ state_dict = checkpoint['state_dict']
+ # rename moco pre-trained keys
+ for k in list(state_dict.keys()):
+ # retain only encoder_q up to before the embedding layer
+ if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
+ # remove prefix
+ state_dict[k[len("module.encoder_q."):]] = state_dict[k]
+ # delete renamed or unused k
+ del state_dict[k]
+ msg = model.load_state_dict(state_dict, strict=False)
+ assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
+ # remove output layer
+ model = nn.Sequential(*list(model.children())[:-1]).cuda()
+ return model
+
+ def extract_feats(self, x):
+ x = F.interpolate(x, size=224)
+ x_feats = self.model(x)
+ x_feats = nn.functional.normalize(x_feats, dim=1)
+ x_feats = x_feats.squeeze()
+ return x_feats
+
+ def forward(self, y_hat, y, x):
+ n_samples = x.shape[0]
+ x_feats = self.extract_feats(x)
+ y_feats = self.extract_feats(y)
+ y_hat_feats = self.extract_feats(y_hat)
+ y_feats = y_feats.detach()
+ loss = 0
+ sim_improvement = 0
+ sim_logs = []
+ count = 0
+ for i in range(n_samples):
+ diff_target = y_hat_feats[i].dot(y_feats[i])
+ diff_input = y_hat_feats[i].dot(x_feats[i])
+ diff_views = y_feats[i].dot(x_feats[i])
+ sim_logs.append({'diff_target': float(diff_target),
+ 'diff_input': float(diff_input),
+ 'diff_views': float(diff_views)})
+ loss += 1 - diff_target
+ sim_diff = float(diff_target) - float(diff_views)
+ sim_improvement += sim_diff
+ count += 1
+
+ return loss / count, sim_improvement / count, sim_logs
diff --git a/e4e/criteria/w_norm.py b/e4e/criteria/w_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..a45ab6f67d8a3f7051be4b7236fa2f38446fd2c1
--- /dev/null
+++ b/e4e/criteria/w_norm.py
@@ -0,0 +1,14 @@
+import torch
+from torch import nn
+
+
+class WNormLoss(nn.Module):
+
+ def __init__(self, start_from_latent_avg=True):
+ super(WNormLoss, self).__init__()
+ self.start_from_latent_avg = start_from_latent_avg
+
+ def forward(self, latent, latent_avg=None):
+ if self.start_from_latent_avg:
+ latent = latent - latent_avg
+ return torch.sum(latent.norm(2, dim=(1, 2))) / latent.shape[0]
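+# i.e. the batch mean of the L2 norm of each latent code, optionally measured
+# relative to the generator's average latent.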
diff --git a/e4e/datasets/__init__.py b/e4e/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/e4e/datasets/gt_res_dataset.py b/e4e/datasets/gt_res_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0beacfee5335aa10aa7e8b7cabe206d7f9a56f7
--- /dev/null
+++ b/e4e/datasets/gt_res_dataset.py
@@ -0,0 +1,32 @@
+#!/usr/bin/python
+# encoding: utf-8
+import os
+from torch.utils.data import Dataset
+from PIL import Image
+import torch
+
+class GTResDataset(Dataset):
+
+ def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None):
+ self.pairs = []
+ for f in os.listdir(root_path):
+ image_path = os.path.join(root_path, f)
+ gt_path = os.path.join(gt_dir, f)
+ if f.endswith(".jpg") or f.endswith(".png"):
+ self.pairs.append([image_path, gt_path.replace('.png', '.jpg'), None])
+ self.transform = transform
+ self.transform_train = transform_train
+
+ def __len__(self):
+ return len(self.pairs)
+
+ def __getitem__(self, index):
+ from_path, to_path, _ = self.pairs[index]
+ from_im = Image.open(from_path).convert('RGB')
+ to_im = Image.open(to_path).convert('RGB')
+
+ if self.transform:
+ to_im = self.transform(to_im)
+ from_im = self.transform(from_im)
+
+ return from_im, to_im
diff --git a/e4e/datasets/images_dataset.py b/e4e/datasets/images_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..00c54c7db944569a749af4c6f0c4d99fcc37f9cc
--- /dev/null
+++ b/e4e/datasets/images_dataset.py
@@ -0,0 +1,33 @@
+from torch.utils.data import Dataset
+from PIL import Image
+from utils import data_utils
+
+
+class ImagesDataset(Dataset):
+
+ def __init__(self, source_root, target_root, opts, target_transform=None, source_transform=None):
+ self.source_paths = sorted(data_utils.make_dataset(source_root))
+ self.target_paths = sorted(data_utils.make_dataset(target_root))
+ self.source_transform = source_transform
+ self.target_transform = target_transform
+ self.opts = opts
+
+ def __len__(self):
+ return len(self.source_paths)
+
+ def __getitem__(self, index):
+ from_path = self.source_paths[index]
+ from_im = Image.open(from_path)
+ from_im = from_im.convert('RGB')
+
+ to_path = self.target_paths[index]
+ to_im = Image.open(to_path).convert('RGB')
+ if self.target_transform:
+ to_im = self.target_transform(to_im)
+
+ if self.source_transform:
+ from_im = self.source_transform(from_im)
+ else:
+ from_im = to_im
+
+ return from_im, to_im
diff --git a/e4e/datasets/inference_dataset.py b/e4e/datasets/inference_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb577d7b538d634f27013c2784d2ea32143154cb
--- /dev/null
+++ b/e4e/datasets/inference_dataset.py
@@ -0,0 +1,25 @@
+from torch.utils.data import Dataset
+from PIL import Image
+from utils import data_utils
+
+
+class InferenceDataset(Dataset):
+
+ def __init__(self, root, opts, transform=None, preprocess=None):
+ self.paths = sorted(data_utils.make_dataset(root))
+ self.transform = transform
+ self.preprocess = preprocess
+ self.opts = opts
+
+ def __len__(self):
+ return len(self.paths)
+
+ def __getitem__(self, index):
+ from_path = self.paths[index]
+ if self.preprocess is not None:
+ from_im = self.preprocess(from_path)
+ else:
+ from_im = Image.open(from_path).convert('RGB')
+ if self.transform:
+ from_im = self.transform(from_im)
+ return from_im
diff --git a/e4e/editings/ganspace.py b/e4e/editings/ganspace.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c286a421280c542e9776a75e64bb65409da8fc7
--- /dev/null
+++ b/e4e/editings/ganspace.py
@@ -0,0 +1,22 @@
+import torch
+
+
+def edit(latents, pca, edit_directions):
+ edit_latents = []
+ for latent in latents:
+ for pca_idx, start, end, strength in edit_directions:
+ delta = get_delta(pca, latent, pca_idx, strength)
+ delta_padded = torch.zeros(latent.shape).to('cuda')
+ delta_padded[start:end] += delta.repeat(end - start, 1)
+ edit_latents.append(latent + delta_padded)
+ return torch.stack(edit_latents)
+
+
+def get_delta(pca, latent, idx, strength):
+ # pca: ganspace checkpoint. latent: (16, 512) w+
+ w_centered = latent - pca['mean'].to('cuda')
+ lat_comp = pca['comp'].to('cuda')
+ lat_std = pca['std'].to('cuda')
+ w_coord = torch.sum(w_centered[0].reshape(-1)*lat_comp[idx].reshape(-1)) / lat_std[idx]
+ delta = (strength - w_coord)*lat_comp[idx]*lat_std[idx]
+ return delta
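+# Example (hypothetical values): edit(latents, pca, [(10, 3, 6, 5.0)]) moves each latent so
+# that its coordinate along PCA component 10 becomes 5.0, applied to W+ layers 3..5 only.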
diff --git a/e4e/editings/ganspace_pca/cars_pca.pt b/e4e/editings/ganspace_pca/cars_pca.pt
new file mode 100644
index 0000000000000000000000000000000000000000..41c2618317f92be5089f99e1f566e9a45650b1bb
--- /dev/null
+++ b/e4e/editings/ganspace_pca/cars_pca.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a5c3bae61ecd85de077fbbf103f5f30cf4b7676fe23a8508166eaf2ce73c8392
+size 167562
diff --git a/e4e/editings/ganspace_pca/ffhq_pca.pt b/e4e/editings/ganspace_pca/ffhq_pca.pt
new file mode 100644
index 0000000000000000000000000000000000000000..8c8be273036803a6845ad067c8f659867343932d
--- /dev/null
+++ b/e4e/editings/ganspace_pca/ffhq_pca.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4d7f9df1c96180d9026b9cb8d04753579fbf385f321a9d0e263641601c5e5d36
+size 167562
diff --git a/e4e/editings/interfacegan_directions/age.pt b/e4e/editings/interfacegan_directions/age.pt
new file mode 100644
index 0000000000000000000000000000000000000000..64cdd22d071c643c59ce94d58334f09f647e8a83
--- /dev/null
+++ b/e4e/editings/interfacegan_directions/age.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:50074516b1629707d89b5e43d6b8abd1792212fa3b961a87a11323d6a5222ae0
+size 2808
diff --git a/e4e/editings/interfacegan_directions/pose.pt b/e4e/editings/interfacegan_directions/pose.pt
new file mode 100644
index 0000000000000000000000000000000000000000..2b6ceffe285303e7b2b09287167dba965283570b
--- /dev/null
+++ b/e4e/editings/interfacegan_directions/pose.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:736e0eacc8488fa0b020a2e7bd256b957284c364191dfea693705e5d06d43e7d
+size 37624
diff --git a/e4e/editings/interfacegan_directions/smile.pt b/e4e/editings/interfacegan_directions/smile.pt
new file mode 100644
index 0000000000000000000000000000000000000000..eeedc44689954510ce2c3bb585f9f9968ee06825
--- /dev/null
+++ b/e4e/editings/interfacegan_directions/smile.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:817a7e732b59dee9eba862bec8bd7e8373568443bc9f9731a21cf9b0356f0653
+size 2808
diff --git a/e4e/editings/latent_editor.py b/e4e/editings/latent_editor.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bebca2f5c86f71b58fa1f30d24bfcb0da06d88f
--- /dev/null
+++ b/e4e/editings/latent_editor.py
@@ -0,0 +1,45 @@
+import torch
+import sys
+sys.path.append(".")
+sys.path.append("..")
+from editings import ganspace, sefa
+from utils.common import tensor2im
+
+
+class LatentEditor(object):
+ def __init__(self, stylegan_generator, is_cars=False):
+ self.generator = stylegan_generator
+ self.is_cars = is_cars # Since the cars StyleGAN output is 384x512, there is a need to crop the 512x512 output.
+
+ def apply_ganspace(self, latent, ganspace_pca, edit_directions):
+ edit_latents = ganspace.edit(latent, ganspace_pca, edit_directions)
+ return self._latents_to_image(edit_latents)
+
+ def apply_interfacegan(self, latent, direction, factor=1, factor_range=None):
+ edit_latents = []
+        if factor_range is not None:  # Apply a range of editing factors, for example (-5, 5)
+ for f in range(*factor_range):
+ edit_latent = latent + f * direction
+ edit_latents.append(edit_latent)
+ edit_latents = torch.cat(edit_latents)
+ else:
+ edit_latents = latent + factor * direction
+ return self._latents_to_image(edit_latents)
+
+ def apply_sefa(self, latent, indices=[2, 3, 4, 5], **kwargs):
+ edit_latents = sefa.edit(self.generator, latent, indices, **kwargs)
+ return self._latents_to_image(edit_latents)
+
+ # Currently, in order to apply StyleFlow editings, one should run inference,
+    # save the latent codes and load them from the official StyleFlow repository.
+ # def apply_styleflow(self):
+ # pass
+
+ def _latents_to_image(self, latents):
+ with torch.no_grad():
+ images, _ = self.generator([latents], randomize_noise=False, input_is_latent=True)
+ if self.is_cars:
+ images = images[:, :, 64:448, :] # 512x512 -> 384x512
+ horizontal_concat_image = torch.cat(list(images), 2)
+ final_image = tensor2im(horizontal_concat_image)
+ return final_image
diff --git a/e4e/editings/sefa.py b/e4e/editings/sefa.py
new file mode 100644
index 0000000000000000000000000000000000000000..db7083ce463b765a7cf452807883a3b85fb63fa5
--- /dev/null
+++ b/e4e/editings/sefa.py
@@ -0,0 +1,46 @@
+import torch
+import numpy as np
+from tqdm import tqdm
+
+
+def edit(generator, latents, indices, semantics=1, start_distance=-15.0, end_distance=15.0, num_samples=1, step=11):
+
+ layers, boundaries, values = factorize_weight(generator, indices)
+ codes = latents.detach().cpu().numpy() # (1,18,512)
+
+ # Generate visualization pages.
+ distances = np.linspace(start_distance, end_distance, step)
+ num_sam = num_samples
+ num_sem = semantics
+
+ edited_latents = []
+ for sem_id in tqdm(range(num_sem), desc='Semantic ', leave=False):
+ boundary = boundaries[sem_id:sem_id + 1]
+ for sam_id in tqdm(range(num_sam), desc='Sample ', leave=False):
+ code = codes[sam_id:sam_id + 1]
+ for col_id, d in enumerate(distances, start=1):
+ temp_code = code.copy()
+ temp_code[:, layers, :] += boundary * d
+ edited_latents.append(torch.from_numpy(temp_code).float().cuda())
+ return torch.cat(edited_latents)
+
+
+def factorize_weight(g_ema, layers='all'):
+
+ weights = []
+ if layers == 'all' or 0 in layers:
+ weight = g_ema.conv1.conv.modulation.weight.T
+ weights.append(weight.cpu().detach().numpy())
+
+ if layers == 'all':
+ layers = list(range(g_ema.num_layers - 1))
+ else:
+ layers = [l - 1 for l in layers if l != 0]
+
+ for idx in layers:
+ weight = g_ema.convs[idx].conv.modulation.weight.T
+ weights.append(weight.cpu().detach().numpy())
+ weight = np.concatenate(weights, axis=1).astype(np.float32)
+ weight = weight / np.linalg.norm(weight, axis=0, keepdims=True)
+ eigen_values, eigen_vectors = np.linalg.eig(weight.dot(weight.T))
+ return layers, eigen_vectors.T, eigen_values
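+# SeFa in brief: the returned boundaries are the eigenvectors of weight @ weight.T, where
+# `weight` concatenates the column-normalized style-modulation weights of the chosen layers;
+# edit() then adds a scaled boundary to the selected W+ layers of each latent code.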
diff --git a/e4e/environment/e4e_env.yaml b/e4e/environment/e4e_env.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4f537615ebb47afd74b5a9856fb9cbea2e0c4bf4
--- /dev/null
+++ b/e4e/environment/e4e_env.yaml
@@ -0,0 +1,73 @@
+name: e4e_env
+channels:
+ - conda-forge
+ - defaults
+dependencies:
+ - _libgcc_mutex=0.1=main
+ - ca-certificates=2020.4.5.1=hecc5488_0
+ - certifi=2020.4.5.1=py36h9f0ad1d_0
+ - libedit=3.1.20181209=hc058e9b_0
+ - libffi=3.2.1=hd88cf55_4
+ - libgcc-ng=9.1.0=hdf63c60_0
+ - libstdcxx-ng=9.1.0=hdf63c60_0
+ - ncurses=6.2=he6710b0_1
+ - ninja=1.10.0=hc9558a2_0
+ - openssl=1.1.1g=h516909a_0
+ - pip=20.0.2=py36_3
+ - python=3.6.7=h0371630_0
+ - python_abi=3.6=1_cp36m
+ - readline=7.0=h7b6447c_5
+ - setuptools=46.4.0=py36_0
+ - sqlite=3.31.1=h62c20be_1
+ - tk=8.6.8=hbc83047_0
+ - wheel=0.34.2=py36_0
+ - xz=5.2.5=h7b6447c_0
+ - zlib=1.2.11=h7b6447c_3
+ - pip:
+ - absl-py==0.9.0
+ - cachetools==4.1.0
+ - chardet==3.0.4
+ - cycler==0.10.0
+ - decorator==4.4.2
+ - future==0.18.2
+ - google-auth==1.15.0
+ - google-auth-oauthlib==0.4.1
+ - grpcio==1.29.0
+ - idna==2.9
+ - imageio==2.8.0
+ - importlib-metadata==1.6.0
+ - kiwisolver==1.2.0
+ - markdown==3.2.2
+ - matplotlib==3.2.1
+ - mxnet==1.6.0
+ - networkx==2.4
+ - numpy==1.18.4
+ - oauthlib==3.1.0
+ - opencv-python==4.2.0.34
+ - pillow==7.1.2
+ - protobuf==3.12.1
+ - pyasn1==0.4.8
+ - pyasn1-modules==0.2.8
+ - pyparsing==2.4.7
+ - python-dateutil==2.8.1
+ - pytorch-lightning==0.7.1
+ - pywavelets==1.1.1
+ - requests==2.23.0
+ - requests-oauthlib==1.3.0
+ - rsa==4.0
+ - scikit-image==0.17.2
+ - scipy==1.4.1
+ - six==1.15.0
+ - tensorboard==2.2.1
+ - tensorboard-plugin-wit==1.6.0.post3
+ - tensorboardx==1.9
+ - tifffile==2020.5.25
+ - torch==1.6.0
+ - torchvision==0.7.1
+ - tqdm==4.46.0
+ - urllib3==1.25.9
+ - werkzeug==1.0.1
+ - zipp==3.1.0
+ - pyaml
+prefix: ~/anaconda3/envs/e4e_env
+
diff --git a/e4e/metrics/LEC.py b/e4e/metrics/LEC.py
new file mode 100644
index 0000000000000000000000000000000000000000..3eef2d2f00a4d757a56b6e845a8fde16aab306ab
--- /dev/null
+++ b/e4e/metrics/LEC.py
@@ -0,0 +1,134 @@
+import sys
+import argparse
+import torch
+import numpy as np
+from torch.utils.data import DataLoader
+
+sys.path.append(".")
+sys.path.append("..")
+
+from configs import data_configs
+from datasets.images_dataset import ImagesDataset
+from utils.model_utils import setup_model
+
+
+class LEC:
+ def __init__(self, net, is_cars=False):
+ """
+ Latent Editing Consistency metric as proposed in the main paper.
+ :param net: e4e model loaded over the pSp framework.
+        :param is_cars: Whether to crop the middle of the StyleGAN's output images (needed for the cars model).
+ """
+ self.net = net
+ self.is_cars = is_cars
+
+ def _encode(self, images):
+ """
+ Encodes the given images into StyleGAN's latent space.
+ :param images: Tensor of shape NxCxHxW representing the images to be encoded.
+ :return: Tensor of shape NxKx512 representing the latent space embeddings of the given image (in W(K, *) space).
+ """
+ codes = self.net.encoder(images)
+ assert codes.ndim == 3, f"Invalid latent codes shape, should be NxKx512 but is {codes.shape}"
+ # normalize with respect to the center of an average face
+ if self.net.opts.start_from_latent_avg:
+ codes = codes + self.net.latent_avg.repeat(codes.shape[0], 1, 1)
+ return codes
+
+ def _generate(self, codes):
+ """
+ Generate the StyleGAN2 images of the given codes
+ :param codes: Tensor of shape NxKx512 representing the StyleGAN's latent codes (in W(K, *) space).
+ :return: Tensor of shape NxCxHxW representing the generated images.
+ """
+ images, _ = self.net.decoder([codes], input_is_latent=True, randomize_noise=False, return_latents=True)
+ images = self.net.face_pool(images)
+ if self.is_cars:
+ images = images[:, :, 32:224, :]
+ return images
+
+ @staticmethod
+ def _filter_outliers(arr):
+ arr = np.array(arr)
+
+ lo = np.percentile(arr, 1, interpolation="lower")
+ hi = np.percentile(arr, 99, interpolation="higher")
+ return np.extract(
+ np.logical_and(lo <= arr, arr <= hi), arr
+ )
+
+ def calculate_metric(self, data_loader, edit_function, inverse_edit_function):
+ """
+ Calculate the LEC metric score.
+ :param data_loader: An iterable that returns a tuple of (images, _), similar to the training data loader.
+ :param edit_function: A function that receives latent codes and performs a semantically meaningful edit in the
+ latent space.
+ :param inverse_edit_function: A function that receives latent codes and performs the inverse edit of the
+ `edit_function` parameter.
+ :return: The LEC metric score.
+ """
+ distances = []
+ with torch.no_grad():
+ for batch in data_loader:
+ x, _ = batch
+ inputs = x.to(device).float()
+
+ codes = self._encode(inputs)
+ edited_codes = edit_function(codes)
+ edited_image = self._generate(edited_codes)
+ edited_image_inversion_codes = self._encode(edited_image)
+ inverse_edit_codes = inverse_edit_function(edited_image_inversion_codes)
+
+ dist = (codes - inverse_edit_codes).norm(2, dim=(1, 2)).mean()
+ distances.append(dist.to("cpu").numpy())
+
+ distances = self._filter_outliers(distances)
+ return distances.mean()
+
+
+if __name__ == "__main__":
+ device = "cuda"
+
+ parser = argparse.ArgumentParser(description="LEC metric calculator")
+
+ parser.add_argument("--batch", type=int, default=8, help="batch size for the models")
+ parser.add_argument("--images_dir", type=str, default=None,
+ help="Path to the images directory on which we calculate the LEC score")
+ parser.add_argument("ckpt", metavar="CHECKPOINT", help="path to the model checkpoints")
+
+ args = parser.parse_args()
+ print(args)
+
+ net, opts = setup_model(args.ckpt, device)
+ dataset_args = data_configs.DATASETS[opts.dataset_type]
+ transforms_dict = dataset_args['transforms'](opts).get_transforms()
+
+ images_directory = dataset_args['test_source_root'] if args.images_dir is None else args.images_dir
+ test_dataset = ImagesDataset(source_root=images_directory,
+ target_root=images_directory,
+ source_transform=transforms_dict['transform_source'],
+ target_transform=transforms_dict['transform_test'],
+ opts=opts)
+
+ data_loader = DataLoader(test_dataset,
+ batch_size=args.batch,
+ shuffle=False,
+ num_workers=2,
+ drop_last=True)
+
+ print(f'dataset length: {len(test_dataset)}')
+
+ # In the following example, we are using an InterfaceGAN based editing to calculate the LEC metric.
+ # Change the provided example according to your domain and needs.
+ direction = torch.load('../editings/interfacegan_directions/age.pt').to(device)
+
+ def edit_func_example(codes):
+ return codes + 3 * direction
+
+
+ def inverse_edit_func_example(codes):
+ return codes - 3 * direction
+
+ lec = LEC(net, is_cars='car' in opts.dataset_type)
+ result = lec.calculate_metric(data_loader, edit_func_example, inverse_edit_func_example)
+ print(f"LEC: {result}")
diff --git a/e4e/models/__init__.py b/e4e/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/e4e/models/discriminator.py b/e4e/models/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..16bf3722c7f2e35cdc9bd177a33ed0975e67200d
--- /dev/null
+++ b/e4e/models/discriminator.py
@@ -0,0 +1,20 @@
+from torch import nn
+
+
+class LatentCodesDiscriminator(nn.Module):
+ def __init__(self, style_dim, n_mlp):
+ super().__init__()
+
+ self.style_dim = style_dim
+
+ layers = []
+ for i in range(n_mlp-1):
+ layers.append(
+ nn.Linear(style_dim, style_dim)
+ )
+ layers.append(nn.LeakyReLU(0.2))
+ layers.append(nn.Linear(512, 1))
+ self.mlp = nn.Sequential(*layers)
+
+ def forward(self, w):
+ return self.mlp(w)
diff --git a/e4e/models/encoders/__init__.py b/e4e/models/encoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/e4e/models/encoders/helpers.py b/e4e/models/encoders/helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4a58b34ea5ca6912fe53c63dede0a8696f5c024
--- /dev/null
+++ b/e4e/models/encoders/helpers.py
@@ -0,0 +1,140 @@
+from collections import namedtuple
+import torch
+import torch.nn.functional as F
+from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
+
+"""
+ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
+"""
+
+
+class Flatten(Module):
+ def forward(self, input):
+ return input.view(input.size(0), -1)
+
+
+def l2_norm(input, axis=1):
+ norm = torch.norm(input, 2, axis, True)
+ output = torch.div(input, norm)
+ return output
+
+
+class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
+ """ A named tuple describing a ResNet block. """
+
+
+def get_block(in_channel, depth, num_units, stride=2):
+ return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
+
+
+def get_blocks(num_layers):
+ if num_layers == 50:
+ blocks = [
+ get_block(in_channel=64, depth=64, num_units=3),
+ get_block(in_channel=64, depth=128, num_units=4),
+ get_block(in_channel=128, depth=256, num_units=14),
+ get_block(in_channel=256, depth=512, num_units=3)
+ ]
+ elif num_layers == 100:
+ blocks = [
+ get_block(in_channel=64, depth=64, num_units=3),
+ get_block(in_channel=64, depth=128, num_units=13),
+ get_block(in_channel=128, depth=256, num_units=30),
+ get_block(in_channel=256, depth=512, num_units=3)
+ ]
+ elif num_layers == 152:
+ blocks = [
+ get_block(in_channel=64, depth=64, num_units=3),
+ get_block(in_channel=64, depth=128, num_units=8),
+ get_block(in_channel=128, depth=256, num_units=36),
+ get_block(in_channel=256, depth=512, num_units=3)
+ ]
+ else:
+ raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
+ return blocks
+
+
+class SEModule(Module):
+ def __init__(self, channels, reduction):
+ super(SEModule, self).__init__()
+ self.avg_pool = AdaptiveAvgPool2d(1)
+ self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
+ self.relu = ReLU(inplace=True)
+ self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
+ self.sigmoid = Sigmoid()
+
+ def forward(self, x):
+ module_input = x
+ x = self.avg_pool(x)
+ x = self.fc1(x)
+ x = self.relu(x)
+ x = self.fc2(x)
+ x = self.sigmoid(x)
+ return module_input * x
+
+
+class bottleneck_IR(Module):
+ def __init__(self, in_channel, depth, stride):
+ super(bottleneck_IR, self).__init__()
+ if in_channel == depth:
+ self.shortcut_layer = MaxPool2d(1, stride)
+ else:
+ self.shortcut_layer = Sequential(
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
+ BatchNorm2d(depth)
+ )
+ self.res_layer = Sequential(
+ BatchNorm2d(in_channel),
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
+ )
+
+ def forward(self, x):
+ shortcut = self.shortcut_layer(x)
+ res = self.res_layer(x)
+ return res + shortcut
+
+
+class bottleneck_IR_SE(Module):
+ def __init__(self, in_channel, depth, stride):
+ super(bottleneck_IR_SE, self).__init__()
+ if in_channel == depth:
+ self.shortcut_layer = MaxPool2d(1, stride)
+ else:
+ self.shortcut_layer = Sequential(
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
+ BatchNorm2d(depth)
+ )
+ self.res_layer = Sequential(
+ BatchNorm2d(in_channel),
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
+ PReLU(depth),
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
+ BatchNorm2d(depth),
+ SEModule(depth, 16)
+ )
+
+ def forward(self, x):
+ shortcut = self.shortcut_layer(x)
+ res = self.res_layer(x)
+ return res + shortcut
+
+
+def _upsample_add(x, y):
+ """Upsample and add two feature maps.
+ Args:
+ x: (Variable) top feature map to be upsampled.
+ y: (Variable) lateral feature map.
+ Returns:
+ (Variable) added feature map.
+    Note that in PyTorch, when the input size is odd, the feature map upsampled
+    with `F.upsample(..., scale_factor=2, mode='nearest')`
+    may not match the lateral feature map size.
+ e.g.
+ original input size: [N,_,15,15] ->
+ conv2d feature map size: [N,_,8,8] ->
+ upsampled feature map size: [N,_,16,16]
+ So we choose bilinear upsample which supports arbitrary output sizes.
+ """
+ _, _, H, W = y.size()
+ return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y
diff --git a/e4e/models/encoders/model_irse.py b/e4e/models/encoders/model_irse.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a94d67542f961ff6533f0335cf4cb0fa54024fb
--- /dev/null
+++ b/e4e/models/encoders/model_irse.py
@@ -0,0 +1,84 @@
+from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
+from e4e.models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
+
+"""
+Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
+"""
+
+
+class Backbone(Module):
+ def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
+ super(Backbone, self).__init__()
+ assert input_size in [112, 224], "input_size should be 112 or 224"
+ assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
+ assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
+ blocks = get_blocks(num_layers)
+ if mode == 'ir':
+ unit_module = bottleneck_IR
+ elif mode == 'ir_se':
+ unit_module = bottleneck_IR_SE
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
+ BatchNorm2d(64),
+ PReLU(64))
+ if input_size == 112:
+ self.output_layer = Sequential(BatchNorm2d(512),
+ Dropout(drop_ratio),
+ Flatten(),
+ Linear(512 * 7 * 7, 512),
+ BatchNorm1d(512, affine=affine))
+ else:
+ self.output_layer = Sequential(BatchNorm2d(512),
+ Dropout(drop_ratio),
+ Flatten(),
+ Linear(512 * 14 * 14, 512),
+ BatchNorm1d(512, affine=affine))
+
+ modules = []
+ for block in blocks:
+ for bottleneck in block:
+ modules.append(unit_module(bottleneck.in_channel,
+ bottleneck.depth,
+ bottleneck.stride))
+ self.body = Sequential(*modules)
+
+ def forward(self, x):
+ x = self.input_layer(x)
+ x = self.body(x)
+ x = self.output_layer(x)
+ return l2_norm(x)
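+# The backbone therefore outputs a 512-d, L2-normalized identity embedding per image,
+# which is what criteria/id_loss.py compares via dot products.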
+
+
+def IR_50(input_size):
+ """Constructs a ir-50 model."""
+ model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
+ return model
+
+
+def IR_101(input_size):
+ """Constructs a ir-101 model."""
+ model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
+ return model
+
+
+def IR_152(input_size):
+ """Constructs a ir-152 model."""
+ model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
+ return model
+
+
+def IR_SE_50(input_size):
+ """Constructs a ir_se-50 model."""
+ model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
+ return model
+
+
+def IR_SE_101(input_size):
+ """Constructs a ir_se-101 model."""
+ model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
+ return model
+
+
+def IR_SE_152(input_size):
+ """Constructs a ir_se-152 model."""
+ model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
+ return model
diff --git a/e4e/models/encoders/psp_encoders.py b/e4e/models/encoders/psp_encoders.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc49acd11f062cbd29f839ee3c04bce7fa84f479
--- /dev/null
+++ b/e4e/models/encoders/psp_encoders.py
@@ -0,0 +1,200 @@
+from enum import Enum
+import math
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module
+
+from e4e.models.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE, _upsample_add
+from e4e.models.stylegan2.model import EqualLinear
+
+
+class ProgressiveStage(Enum):
+ WTraining = 0
+ Delta1Training = 1
+ Delta2Training = 2
+ Delta3Training = 3
+ Delta4Training = 4
+ Delta5Training = 5
+ Delta6Training = 6
+ Delta7Training = 7
+ Delta8Training = 8
+ Delta9Training = 9
+ Delta10Training = 10
+ Delta11Training = 11
+ Delta12Training = 12
+ Delta13Training = 13
+ Delta14Training = 14
+ Delta15Training = 15
+ Delta16Training = 16
+ Delta17Training = 17
+ Inference = 18
+
+
+class GradualStyleBlock(Module):
+ def __init__(self, in_c, out_c, spatial):
+ super(GradualStyleBlock, self).__init__()
+ self.out_c = out_c
+ self.spatial = spatial
+ num_pools = int(np.log2(spatial))
+ modules = []
+ modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1),
+ nn.LeakyReLU()]
+ for i in range(num_pools - 1):
+ modules += [
+ Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1),
+ nn.LeakyReLU()
+ ]
+ self.convs = nn.Sequential(*modules)
+ self.linear = EqualLinear(out_c, out_c, lr_mul=1)
+
+ def forward(self, x):
+ x = self.convs(x)
+ x = x.view(-1, self.out_c)
+ x = self.linear(x)
+ return x
+
+
+class GradualStyleEncoder(Module):
+ def __init__(self, num_layers, mode='ir', opts=None):
+ super(GradualStyleEncoder, self).__init__()
+        assert num_layers in [50, 100, 152], 'num_layers should be 50, 100, or 152'
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
+ blocks = get_blocks(num_layers)
+ if mode == 'ir':
+ unit_module = bottleneck_IR
+ elif mode == 'ir_se':
+ unit_module = bottleneck_IR_SE
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
+ BatchNorm2d(64),
+ PReLU(64))
+ modules = []
+ for block in blocks:
+ for bottleneck in block:
+ modules.append(unit_module(bottleneck.in_channel,
+ bottleneck.depth,
+ bottleneck.stride))
+ self.body = Sequential(*modules)
+
+ self.styles = nn.ModuleList()
+ log_size = int(math.log(opts.stylegan_size, 2))
+ self.style_count = 2 * log_size - 2
+ self.coarse_ind = 3
+ self.middle_ind = 7
+ for i in range(self.style_count):
+ if i < self.coarse_ind:
+ style = GradualStyleBlock(512, 512, 16)
+ elif i < self.middle_ind:
+ style = GradualStyleBlock(512, 512, 32)
+ else:
+ style = GradualStyleBlock(512, 512, 64)
+ self.styles.append(style)
+ self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
+ self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x):
+ x = self.input_layer(x)
+
+ latents = []
+ modulelist = list(self.body._modules.values())
+ for i, l in enumerate(modulelist):
+ x = l(x)
+ if i == 6:
+ c1 = x
+ elif i == 20:
+ c2 = x
+ elif i == 23:
+ c3 = x
+
+ for j in range(self.coarse_ind):
+ latents.append(self.styles[j](c3))
+
+ p2 = _upsample_add(c3, self.latlayer1(c2))
+ for j in range(self.coarse_ind, self.middle_ind):
+ latents.append(self.styles[j](p2))
+
+ p1 = _upsample_add(p2, self.latlayer2(c1))
+ for j in range(self.middle_ind, self.style_count):
+ latents.append(self.styles[j](p1))
+
+ out = torch.stack(latents, dim=1)
+ return out
+
+
+class Encoder4Editing(Module):
+ def __init__(self, num_layers, mode='ir', opts=None):
+ super(Encoder4Editing, self).__init__()
+        assert num_layers in [50, 100, 152], 'num_layers should be 50, 100, or 152'
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
+ blocks = get_blocks(num_layers)
+ if mode == 'ir':
+ unit_module = bottleneck_IR
+ elif mode == 'ir_se':
+ unit_module = bottleneck_IR_SE
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
+ BatchNorm2d(64),
+ PReLU(64))
+ modules = []
+ for block in blocks:
+ for bottleneck in block:
+ modules.append(unit_module(bottleneck.in_channel,
+ bottleneck.depth,
+ bottleneck.stride))
+ self.body = Sequential(*modules)
+
+ self.styles = nn.ModuleList()
+ log_size = int(math.log(opts.stylegan_size, 2))
+ self.style_count = 2 * log_size - 2
+ self.coarse_ind = 3
+ self.middle_ind = 7
+
+ for i in range(self.style_count):
+ if i < self.coarse_ind:
+ style = GradualStyleBlock(512, 512, 16)
+ elif i < self.middle_ind:
+ style = GradualStyleBlock(512, 512, 32)
+ else:
+ style = GradualStyleBlock(512, 512, 64)
+ self.styles.append(style)
+
+ self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
+ self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
+
+ self.progressive_stage = ProgressiveStage.Inference
+
+ def get_deltas_starting_dimensions(self):
+ ''' Get a list of the initial dimension of every delta from which it is applied '''
+ return list(range(self.style_count)) # Each dimension has a delta applied to it
+
+ def set_progressive_stage(self, new_stage: ProgressiveStage):
+ self.progressive_stage = new_stage
+ print('Changed progressive stage to: ', new_stage)
+
+ def forward(self, x):
+ x = self.input_layer(x)
+
+ modulelist = list(self.body._modules.values())
+ for i, l in enumerate(modulelist):
+ x = l(x)
+ if i == 6:
+ c1 = x
+ elif i == 20:
+ c2 = x
+ elif i == 23:
+ c3 = x
+
+ # Infer main W and duplicate it
+ w0 = self.styles[0](c3)
+ w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2)
+ stage = self.progressive_stage.value
+ features = c3
+ for i in range(1, min(stage + 1, self.style_count)): # Infer additional deltas
+ if i == self.coarse_ind:
+ p2 = _upsample_add(c3, self.latlayer1(c2)) # FPN's middle features
+ features = p2
+ elif i == self.middle_ind:
+ p1 = _upsample_add(p2, self.latlayer2(c1)) # FPN's fine features
+ features = p1
+ delta_i = self.styles[i](features)
+ w[:, i] += delta_i
+ return w
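+# In words: the coarsest features predict a single base code w0 that is copied to every style
+# input; per-layer offsets (deltas) are then predicted from progressively finer FPN features
+# and added only for the layers unlocked by the current ProgressiveStage.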
diff --git a/e4e/models/latent_codes_pool.py b/e4e/models/latent_codes_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..0281d4b5e80f8eb26e824fa35b4f908dcb6634e6
--- /dev/null
+++ b/e4e/models/latent_codes_pool.py
@@ -0,0 +1,55 @@
+import random
+import torch
+
+
+class LatentCodesPool:
+ """This class implements latent codes buffer that stores previously generated w latent codes.
+ This buffer enables us to update discriminators using a history of generated w's
+ rather than the ones produced by the latest encoder.
+ """
+
+ def __init__(self, pool_size):
+ """Initialize the ImagePool class
+ Parameters:
+ pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
+ """
+ self.pool_size = pool_size
+ if self.pool_size > 0: # create an empty pool
+ self.num_ws = 0
+ self.ws = []
+
+ def query(self, ws):
+ """Return w's from the pool.
+ Parameters:
+ ws: the latest generated w's from the generator
+ Returns w's from the buffer.
+        Once the buffer is full, each w is returned unchanged with probability 0.5;
+        otherwise a previously stored w is returned and the current w replaces it
+        in the buffer.
+ """
+ if self.pool_size == 0: # if the buffer size is 0, do nothing
+ return ws
+ return_ws = []
+ for w in ws: # ws.shape: (batch, 512) or (batch, n_latent, 512)
+ # w = torch.unsqueeze(image.data, 0)
+ if w.ndim == 2:
+ i = random.randint(0, len(w) - 1) # apply a random latent index as a candidate
+ w = w[i]
+ self.handle_w(w, return_ws)
+        return_ws = torch.stack(return_ws, 0)  # stack the collected w's and return
+ return return_ws
+
+ def handle_w(self, w, return_ws):
+ if self.num_ws < self.pool_size: # if the buffer is not full; keep inserting current codes to the buffer
+ self.num_ws = self.num_ws + 1
+ self.ws.append(w)
+ return_ws.append(w)
+ else:
+ p = random.uniform(0, 1)
+ if p > 0.5: # by 50% chance, the buffer will return a previously stored latent code, and insert the current code into the buffer
+ random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
+ tmp = self.ws[random_id].clone()
+ self.ws[random_id] = w
+ return_ws.append(tmp)
+            else:  # by another 50% chance, the buffer returns the current latent code
+ return_ws.append(w)
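
For context, a small usage sketch of LatentCodesPool as it might be driven during discriminator training; the pool size and tensor shape below are illustrative, not values taken from the training code:

    import torch
    from e4e.models.latent_codes_pool import LatentCodesPool

    pool = LatentCodesPool(pool_size=50)
    fake_w = torch.randn(4, 512)        # a batch of encoder-produced latents (illustrative)
    w_for_disc = pool.query(fake_w)     # 50/50 mix of fresh and previously stored codes
    print(w_for_disc.shape)             # torch.Size([4, 512])
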
diff --git a/e4e/models/psp.py b/e4e/models/psp.py
new file mode 100644
index 0000000000000000000000000000000000000000..36c0b2b7b3fdd28bc32272d0d8fcff24e4848355
--- /dev/null
+++ b/e4e/models/psp.py
@@ -0,0 +1,99 @@
+import matplotlib
+
+matplotlib.use('Agg')
+import torch
+from torch import nn
+from e4e.models.encoders import psp_encoders
+from e4e.models.stylegan2.model import Generator
+from e4e.configs.paths_config import model_paths
+
+
+def get_keys(d, name):
+ if 'state_dict' in d:
+ d = d['state_dict']
+ d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
+ return d_filt
+
+
+class pSp(nn.Module):
+
+ def __init__(self, opts, device):
+ super(pSp, self).__init__()
+ self.opts = opts
+ self.device = device
+ # Define architecture
+ self.encoder = self.set_encoder()
+ self.decoder = Generator(opts.stylegan_size, 512, 8, channel_multiplier=2)
+ self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
+ # Load weights if needed
+ self.load_weights()
+
+ def set_encoder(self):
+ if self.opts.encoder_type == 'GradualStyleEncoder':
+ encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts)
+ elif self.opts.encoder_type == 'Encoder4Editing':
+ encoder = psp_encoders.Encoder4Editing(50, 'ir_se', self.opts)
+ else:
+            raise Exception('{} is not a valid encoder type'.format(self.opts.encoder_type))
+ return encoder
+
+ def load_weights(self):
+ if self.opts.checkpoint_path is not None:
+ print('Loading e4e over the pSp framework from checkpoint: {}'.format(self.opts.checkpoint_path))
+ ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
+ self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
+ self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
+ self.__load_latent_avg(ckpt)
+ else:
+ print('Loading encoders weights from irse50!')
+ encoder_ckpt = torch.load(model_paths['ir_se50'])
+ self.encoder.load_state_dict(encoder_ckpt, strict=False)
+ print('Loading decoder weights from pretrained!')
+ ckpt = torch.load(self.opts.stylegan_weights)
+ self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
+ self.__load_latent_avg(ckpt, repeat=self.encoder.style_count)
+
+ def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
+ inject_latent=None, return_latents=False, alpha=None):
+ if input_code:
+ codes = x
+ else:
+ codes = self.encoder(x)
+ # normalize with respect to the center of an average face
+ if self.opts.start_from_latent_avg:
+ if codes.ndim == 2:
+ codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :]
+ else:
+ codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)
+
+ if latent_mask is not None:
+ for i in latent_mask:
+ if inject_latent is not None:
+ if alpha is not None:
+ codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
+ else:
+ codes[:, i] = inject_latent[:, i]
+ else:
+ codes[:, i] = 0
+
+ input_is_latent = not input_code
+ images, result_latent = self.decoder([codes],
+ input_is_latent=input_is_latent,
+ randomize_noise=randomize_noise,
+ return_latents=return_latents)
+
+ if resize:
+ images = self.face_pool(images)
+
+ if return_latents:
+ return images, result_latent
+ else:
+ return images
+
+ def __load_latent_avg(self, ckpt, repeat=None):
+ if 'latent_avg' in ckpt:
+ self.latent_avg = ckpt['latent_avg'].to(self.device)
+ if repeat is not None:
+ self.latent_avg = self.latent_avg.repeat(repeat, 1)
+ else:
+ self.latent_avg = None
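
A hedged sketch of driving the pSp wrapper above at inference time; the checkpoint path and option fields are placeholders (in practice the 'opts' dict stored inside the checkpoint is reused), so treat this as an outline rather than a runnable recipe:

    import torch
    from argparse import Namespace
    from e4e.models.psp import pSp

    opts = Namespace(encoder_type='Encoder4Editing', stylegan_size=1024,
                     checkpoint_path='e4e_ffhq_encode.pt',   # placeholder path
                     start_from_latent_avg=True)
    net = pSp(opts, device='cpu').eval()

    x = torch.randn(1, 3, 256, 256)      # stands in for an aligned, normalized face crop
    images, latents = net(x, resize=True, randomize_noise=False, return_latents=True)
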
diff --git a/e4e/models/stylegan2/__init__.py b/e4e/models/stylegan2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/e4e/models/stylegan2/model.py b/e4e/models/stylegan2/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcb12af85669ab6fd7f79cb14ddbdf80b2fbd83d
--- /dev/null
+++ b/e4e/models/stylegan2/model.py
@@ -0,0 +1,678 @@
+import math
+import random
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+if torch.cuda.is_available():
+ from op.fused_act import FusedLeakyReLU, fused_leaky_relu
+ from op.upfirdn2d import upfirdn2d
+else:
+ from op.fused_act_cpu import FusedLeakyReLU, fused_leaky_relu
+ from op.upfirdn2d_cpu import upfirdn2d
+
+
+class PixelNorm(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, input):
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
+
+
+def make_kernel(k):
+ k = torch.tensor(k, dtype=torch.float32)
+
+ if k.ndim == 1:
+ k = k[None, :] * k[:, None]
+
+ k /= k.sum()
+
+ return k
+
+
+class Upsample(nn.Module):
+ def __init__(self, kernel, factor=2):
+ super().__init__()
+
+ self.factor = factor
+ kernel = make_kernel(kernel) * (factor ** 2)
+ self.register_buffer('kernel', kernel)
+
+ p = kernel.shape[0] - factor
+
+ pad0 = (p + 1) // 2 + factor - 1
+ pad1 = p // 2
+
+ self.pad = (pad0, pad1)
+
+ def forward(self, input):
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
+
+ return out
+
+
+class Downsample(nn.Module):
+ def __init__(self, kernel, factor=2):
+ super().__init__()
+
+ self.factor = factor
+ kernel = make_kernel(kernel)
+ self.register_buffer('kernel', kernel)
+
+ p = kernel.shape[0] - factor
+
+ pad0 = (p + 1) // 2
+ pad1 = p // 2
+
+ self.pad = (pad0, pad1)
+
+ def forward(self, input):
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
+
+ return out
+
+
+class Blur(nn.Module):
+ def __init__(self, kernel, pad, upsample_factor=1):
+ super().__init__()
+
+ kernel = make_kernel(kernel)
+
+ if upsample_factor > 1:
+ kernel = kernel * (upsample_factor ** 2)
+
+ self.register_buffer('kernel', kernel)
+
+ self.pad = pad
+
+ def forward(self, input):
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
+
+ return out
+
+
+class EqualConv2d(nn.Module):
+ def __init__(
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
+ ):
+ super().__init__()
+
+ self.weight = nn.Parameter(
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
+ )
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
+
+ self.stride = stride
+ self.padding = padding
+
+ if bias:
+ self.bias = nn.Parameter(torch.zeros(out_channel))
+
+ else:
+ self.bias = None
+
+ def forward(self, input):
+ out = F.conv2d(
+ input,
+ self.weight * self.scale,
+ bias=self.bias,
+ stride=self.stride,
+ padding=self.padding,
+ )
+
+ return out
+
+ def __repr__(self):
+ return (
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
+ )
+
+
+class EqualLinear(nn.Module):
+ def __init__(
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
+ ):
+ super().__init__()
+
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
+
+ if bias:
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
+
+ else:
+ self.bias = None
+
+ self.activation = activation
+
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
+ self.lr_mul = lr_mul
+
+ def forward(self, input):
+ if self.activation:
+ out = F.linear(input, self.weight * self.scale)
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
+
+ else:
+ out = F.linear(
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
+ )
+
+ return out
+
+ def __repr__(self):
+ return (
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
+ )
+
+
+class ScaledLeakyReLU(nn.Module):
+ def __init__(self, negative_slope=0.2):
+ super().__init__()
+
+ self.negative_slope = negative_slope
+
+ def forward(self, input):
+ out = F.leaky_relu(input, negative_slope=self.negative_slope)
+
+ return out * math.sqrt(2)
+
+
+class ModulatedConv2d(nn.Module):
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ kernel_size,
+ style_dim,
+ demodulate=True,
+ upsample=False,
+ downsample=False,
+ blur_kernel=[1, 3, 3, 1],
+ ):
+ super().__init__()
+
+ self.eps = 1e-8
+ self.kernel_size = kernel_size
+ self.in_channel = in_channel
+ self.out_channel = out_channel
+ self.upsample = upsample
+ self.downsample = downsample
+
+ if upsample:
+ factor = 2
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
+ pad0 = (p + 1) // 2 + factor - 1
+ pad1 = p // 2 + 1
+
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
+
+ if downsample:
+ factor = 2
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
+ pad0 = (p + 1) // 2
+ pad1 = p // 2
+
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
+
+ fan_in = in_channel * kernel_size ** 2
+ self.scale = 1 / math.sqrt(fan_in)
+ self.padding = kernel_size // 2
+
+ self.weight = nn.Parameter(
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
+ )
+
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
+
+ self.demodulate = demodulate
+
+ def __repr__(self):
+ return (
+ f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
+ f'upsample={self.upsample}, downsample={self.downsample})'
+ )
+
+ def forward(self, input, style):
+ batch, in_channel, height, width = input.shape
+
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
+ weight = self.scale * self.weight * style
+
+ if self.demodulate:
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
+
+ weight = weight.view(
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
+ )
+
+ if self.upsample:
+ input = input.view(1, batch * in_channel, height, width)
+ weight = weight.view(
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
+ )
+ weight = weight.transpose(1, 2).reshape(
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
+ )
+ out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
+ _, _, height, width = out.shape
+ out = out.view(batch, self.out_channel, height, width)
+ out = self.blur(out)
+
+ elif self.downsample:
+ input = self.blur(input)
+ _, _, height, width = input.shape
+ input = input.view(1, batch * in_channel, height, width)
+ out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
+ _, _, height, width = out.shape
+ out = out.view(batch, self.out_channel, height, width)
+
+ else:
+ input = input.view(1, batch * in_channel, height, width)
+ out = F.conv2d(input, weight, padding=self.padding, groups=batch)
+ _, _, height, width = out.shape
+ out = out.view(batch, self.out_channel, height, width)
+
+ return out
+
+
+class NoiseInjection(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ self.weight = nn.Parameter(torch.zeros(1))
+
+ def forward(self, image, noise=None):
+ if noise is None:
+ batch, _, height, width = image.shape
+ noise = image.new_empty(batch, 1, height, width).normal_()
+
+ return image + self.weight * noise
+
+
+class ConstantInput(nn.Module):
+ def __init__(self, channel, size=4):
+ super().__init__()
+
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
+
+ def forward(self, input):
+ batch = input.shape[0]
+ out = self.input.repeat(batch, 1, 1, 1)
+
+ return out
+
+
+class StyledConv(nn.Module):
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ kernel_size,
+ style_dim,
+ upsample=False,
+ blur_kernel=[1, 3, 3, 1],
+ demodulate=True,
+ ):
+ super().__init__()
+
+ self.conv = ModulatedConv2d(
+ in_channel,
+ out_channel,
+ kernel_size,
+ style_dim,
+ upsample=upsample,
+ blur_kernel=blur_kernel,
+ demodulate=demodulate,
+ )
+
+ self.noise = NoiseInjection()
+ # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
+ # self.activate = ScaledLeakyReLU(0.2)
+ self.activate = FusedLeakyReLU(out_channel)
+
+ def forward(self, input, style, noise=None):
+ out = self.conv(input, style)
+ out = self.noise(out, noise=noise)
+ # out = out + self.bias
+ out = self.activate(out)
+
+ return out
+
+
+class ToRGB(nn.Module):
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
+ super().__init__()
+
+ if upsample:
+ self.upsample = Upsample(blur_kernel)
+
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
+
+ def forward(self, input, style, skip=None):
+ out = self.conv(input, style)
+ out = out + self.bias
+
+ if skip is not None:
+ skip = self.upsample(skip)
+
+ out = out + skip
+
+ return out
+
+
+class Generator(nn.Module):
+ def __init__(
+ self,
+ size,
+ style_dim,
+ n_mlp,
+ channel_multiplier=2,
+ blur_kernel=[1, 3, 3, 1],
+ lr_mlp=0.01,
+ ):
+ super().__init__()
+
+ self.size = size
+
+ self.style_dim = style_dim
+
+ layers = [PixelNorm()]
+
+ for i in range(n_mlp):
+ layers.append(
+ EqualLinear(
+ style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
+ )
+ )
+
+ self.style = nn.Sequential(*layers)
+
+ self.channels = {
+ 4: 512,
+ 8: 512,
+ 16: 512,
+ 32: 512,
+ 64: 256 * channel_multiplier,
+ 128: 128 * channel_multiplier,
+ 256: 64 * channel_multiplier,
+ 512: 32 * channel_multiplier,
+ 1024: 16 * channel_multiplier,
+ }
+
+ self.input = ConstantInput(self.channels[4])
+ self.conv1 = StyledConv(
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
+ )
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
+
+ self.log_size = int(math.log(size, 2))
+ self.num_layers = (self.log_size - 2) * 2 + 1
+
+ self.convs = nn.ModuleList()
+ self.upsamples = nn.ModuleList()
+ self.to_rgbs = nn.ModuleList()
+ self.noises = nn.Module()
+
+ in_channel = self.channels[4]
+
+ for layer_idx in range(self.num_layers):
+ res = (layer_idx + 5) // 2
+ shape = [1, 1, 2 ** res, 2 ** res]
+ self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
+
+ for i in range(3, self.log_size + 1):
+ out_channel = self.channels[2 ** i]
+
+ self.convs.append(
+ StyledConv(
+ in_channel,
+ out_channel,
+ 3,
+ style_dim,
+ upsample=True,
+ blur_kernel=blur_kernel,
+ )
+ )
+
+ self.convs.append(
+ StyledConv(
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
+ )
+ )
+
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
+
+ in_channel = out_channel
+
+ self.n_latent = self.log_size * 2 - 2
+
+ def make_noise(self):
+ device = self.input.input.device
+
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
+
+ for i in range(3, self.log_size + 1):
+ for _ in range(2):
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
+
+ return noises
+
+ def mean_latent(self, n_latent):
+ latent_in = torch.randn(
+ n_latent, self.style_dim, device=self.input.input.device
+ )
+ latent = self.style(latent_in).mean(0, keepdim=True)
+
+ return latent
+
+ def get_latent(self, input):
+ return self.style(input)
+
+ def forward(
+ self,
+ styles,
+ return_latents=False,
+ return_features=False,
+ inject_index=None,
+ truncation=1,
+ truncation_latent=None,
+ input_is_latent=False,
+ noise=None,
+ randomize_noise=True,
+ ):
+ if not input_is_latent:
+ styles = [self.style(s) for s in styles]
+
+ if noise is None:
+ if randomize_noise:
+ noise = [None] * self.num_layers
+ else:
+ noise = [
+ getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
+ ]
+
+ if truncation < 1:
+ style_t = []
+
+ for style in styles:
+ style_t.append(
+ truncation_latent + truncation * (style - truncation_latent)
+ )
+
+ styles = style_t
+
+ if len(styles) < 2:
+ inject_index = self.n_latent
+
+ if styles[0].ndim < 3:
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+ else:
+ latent = styles[0]
+
+ else:
+ if inject_index is None:
+ inject_index = random.randint(1, self.n_latent - 1)
+
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
+
+ latent = torch.cat([latent, latent2], 1)
+
+ out = self.input(latent)
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
+
+ skip = self.to_rgb1(out, latent[:, 1])
+
+ i = 1
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
+ ):
+ out = conv1(out, latent[:, i], noise=noise1)
+ out = conv2(out, latent[:, i + 1], noise=noise2)
+ skip = to_rgb(out, latent[:, i + 2], skip)
+
+ i += 2
+
+ image = skip
+
+ if return_latents:
+ return image, latent
+ elif return_features:
+ return image, out
+ else:
+ return image, None
+
+
+class ConvLayer(nn.Sequential):
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ kernel_size,
+ downsample=False,
+ blur_kernel=[1, 3, 3, 1],
+ bias=True,
+ activate=True,
+ ):
+ layers = []
+
+ if downsample:
+ factor = 2
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
+ pad0 = (p + 1) // 2
+ pad1 = p // 2
+
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
+
+ stride = 2
+ self.padding = 0
+
+ else:
+ stride = 1
+ self.padding = kernel_size // 2
+
+ layers.append(
+ EqualConv2d(
+ in_channel,
+ out_channel,
+ kernel_size,
+ padding=self.padding,
+ stride=stride,
+ bias=bias and not activate,
+ )
+ )
+
+ if activate:
+ if bias:
+ layers.append(FusedLeakyReLU(out_channel))
+
+ else:
+ layers.append(ScaledLeakyReLU(0.2))
+
+ super().__init__(*layers)
+
+
+class ResBlock(nn.Module):
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
+ super().__init__()
+
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
+
+ self.skip = ConvLayer(
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
+ )
+
+ def forward(self, input):
+ out = self.conv1(input)
+ out = self.conv2(out)
+
+ skip = self.skip(input)
+ out = (out + skip) / math.sqrt(2)
+
+ return out
+
+
+class Discriminator(nn.Module):
+ def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
+ super().__init__()
+
+ channels = {
+ 4: 512,
+ 8: 512,
+ 16: 512,
+ 32: 512,
+ 64: 256 * channel_multiplier,
+ 128: 128 * channel_multiplier,
+ 256: 64 * channel_multiplier,
+ 512: 32 * channel_multiplier,
+ 1024: 16 * channel_multiplier,
+ }
+
+ convs = [ConvLayer(3, channels[size], 1)]
+
+ log_size = int(math.log(size, 2))
+
+ in_channel = channels[size]
+
+ for i in range(log_size, 2, -1):
+ out_channel = channels[2 ** (i - 1)]
+
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
+
+ in_channel = out_channel
+
+ self.convs = nn.Sequential(*convs)
+
+ self.stddev_group = 4
+ self.stddev_feat = 1
+
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
+ self.final_linear = nn.Sequential(
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
+ EqualLinear(channels[4], 1),
+ )
+
+ def forward(self, input):
+ out = self.convs(input)
+
+ batch, channel, height, width = out.shape
+ group = min(batch, self.stddev_group)
+ stddev = out.view(
+ group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
+ )
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
+ stddev = stddev.repeat(group, 1, height, width)
+ out = torch.cat([out, stddev], 1)
+
+ out = self.final_conv(out)
+
+ out = out.view(batch, -1)
+ out = self.final_linear(out)
+
+ return out
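
For orientation, a short sketch of using the Generator above directly, assuming the fused CUDA ops (or their CPU fallbacks) import cleanly; sizes are illustrative and no pretrained weights are loaded:

    import torch
    from e4e.models.stylegan2.model import Generator

    g = Generator(size=1024, style_dim=512, n_mlp=8)   # FFHQ-style 1024px config
    z = [torch.randn(1, 512)]                          # one latent in Z space
    img, w_plus = g(z, return_latents=True)            # z -> W via the MLP, then synthesis
    print(img.shape, w_plus.shape)                     # (1, 3, 1024, 1024), (1, 18, 512)
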
diff --git a/e4e/models/stylegan2/op/__init__.py b/e4e/models/stylegan2/op/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/e4e/models/stylegan2/op/fused_act.py b/e4e/models/stylegan2/op/fused_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..973a84fffde53668d31397da5fb993bbc95f7be0
--- /dev/null
+++ b/e4e/models/stylegan2/op/fused_act.py
@@ -0,0 +1,85 @@
+import os
+
+import torch
+from torch import nn
+from torch.autograd import Function
+from torch.utils.cpp_extension import load
+
+module_path = os.path.dirname(__file__)
+fused = load(
+ 'fused',
+ sources=[
+ os.path.join(module_path, 'fused_bias_act.cpp'),
+ os.path.join(module_path, 'fused_bias_act_kernel.cu'),
+ ],
+)
+
+
+class FusedLeakyReLUFunctionBackward(Function):
+ @staticmethod
+ def forward(ctx, grad_output, out, negative_slope, scale):
+ ctx.save_for_backward(out)
+ ctx.negative_slope = negative_slope
+ ctx.scale = scale
+
+ empty = grad_output.new_empty(0)
+
+ grad_input = fused.fused_bias_act(
+ grad_output, empty, out, 3, 1, negative_slope, scale
+ )
+
+ dim = [0]
+
+ if grad_input.ndim > 2:
+ dim += list(range(2, grad_input.ndim))
+
+ grad_bias = grad_input.sum(dim).detach()
+
+ return grad_input, grad_bias
+
+ @staticmethod
+ def backward(ctx, gradgrad_input, gradgrad_bias):
+ out, = ctx.saved_tensors
+ gradgrad_out = fused.fused_bias_act(
+ gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
+ )
+
+ return gradgrad_out, None, None, None
+
+
+class FusedLeakyReLUFunction(Function):
+ @staticmethod
+ def forward(ctx, input, bias, negative_slope, scale):
+ empty = input.new_empty(0)
+ out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
+ ctx.save_for_backward(out)
+ ctx.negative_slope = negative_slope
+ ctx.scale = scale
+
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ out, = ctx.saved_tensors
+
+ grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
+ grad_output, out, ctx.negative_slope, ctx.scale
+ )
+
+ return grad_input, grad_bias, None, None
+
+
+class FusedLeakyReLU(nn.Module):
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
+ super().__init__()
+
+ self.bias = nn.Parameter(torch.zeros(channel))
+ self.negative_slope = negative_slope
+ self.scale = scale
+
+ def forward(self, input):
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
+
+
+def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
+ return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
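
The fused op above performs a per-channel bias add followed by a scaled leaky ReLU in a single CUDA kernel. As a plain PyTorch reference of the same math (useful for sanity-checking the CPU fallback path that model.py selects when CUDA is unavailable), here is a sketch of the intended behaviour, not a drop-in replacement for the extension:

    import torch
    import torch.nn.functional as F

    def fused_leaky_relu_reference(x, bias, negative_slope=0.2, scale=2 ** 0.5):
        # broadcast the per-channel bias over the remaining dimensions
        bias = bias.view(1, -1, *([1] * (x.ndim - 2)))
        return F.leaky_relu(x + bias, negative_slope) * scale
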
diff --git a/e4e/models/stylegan2/op/fused_bias_act.cpp b/e4e/models/stylegan2/op/fused_bias_act.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..02be898f970bcc8ea297867fcaa4e71b24b3d949
--- /dev/null
+++ b/e4e/models/stylegan2/op/fused_bias_act.cpp
@@ -0,0 +1,21 @@
+#include <torch/extension.h>
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale) {
+ CHECK_CUDA(input);
+ CHECK_CUDA(bias);
+
+ return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
+}
\ No newline at end of file
diff --git a/e4e/models/stylegan2/op/fused_bias_act_kernel.cu b/e4e/models/stylegan2/op/fused_bias_act_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..c9fa56fea7ede7072dc8925cfb0148f136eb85b8
--- /dev/null
+++ b/e4e/models/stylegan2/op/fused_bias_act_kernel.cu
@@ -0,0 +1,99 @@
+// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, visit
+// https://nvlabs.github.io/stylegan2/license.html
+
+#include <torch/types.h>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+
+template <typename scalar_t>
+static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
+ int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
+ int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
+
+ scalar_t zero = 0.0;
+
+ for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
+ scalar_t x = p_x[xi];
+
+ if (use_bias) {
+ x += p_b[(xi / step_b) % size_b];
+ }
+
+ scalar_t ref = use_ref ? p_ref[xi] : zero;
+
+ scalar_t y;
+
+ switch (act * 10 + grad) {
+ default:
+ case 10: y = x; break;
+ case 11: y = x; break;
+ case 12: y = 0.0; break;
+
+ case 30: y = (x > 0.0) ? x : x * alpha; break;
+ case 31: y = (ref > 0.0) ? x : x * alpha; break;
+ case 32: y = 0.0; break;
+ }
+
+ out[xi] = y * scale;
+ }
+}
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale) {
+ int curDevice = -1;
+ cudaGetDevice(&curDevice);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+ auto x = input.contiguous();
+ auto b = bias.contiguous();
+ auto ref = refer.contiguous();
+
+ int use_bias = b.numel() ? 1 : 0;
+ int use_ref = ref.numel() ? 1 : 0;
+
+ int size_x = x.numel();
+ int size_b = b.numel();
+ int step_b = 1;
+
+ for (int i = 1 + 1; i < x.dim(); i++) {
+ step_b *= x.size(i);
+ }
+
+ int loop_x = 4;
+ int block_size = 4 * 32;
+ int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
+
+ auto y = torch::empty_like(x);
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
+        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
+            y.data_ptr<scalar_t>(),
+            x.data_ptr<scalar_t>(),
+            b.data_ptr<scalar_t>(),
+            ref.data_ptr<scalar_t>(),
+ act,
+ grad,
+ alpha,
+ scale,
+ loop_x,
+ size_x,
+ step_b,
+ size_b,
+ use_bias,
+ use_ref
+ );
+ });
+
+ return y;
+}
\ No newline at end of file
diff --git a/e4e/models/stylegan2/op/upfirdn2d.cpp b/e4e/models/stylegan2/op/upfirdn2d.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d2e633dc896433c205e18bc3e455539192ff968e
--- /dev/null
+++ b/e4e/models/stylegan2/op/upfirdn2d.cpp
@@ -0,0 +1,23 @@
+#include <torch/extension.h>
+
+
+torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
+ int up_x, int up_y, int down_x, int down_y,
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
+ int up_x, int up_y, int down_x, int down_y,
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
+ CHECK_CUDA(input);
+ CHECK_CUDA(kernel);
+
+ return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
+}
\ No newline at end of file
diff --git a/e4e/models/stylegan2/op/upfirdn2d.py b/e4e/models/stylegan2/op/upfirdn2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bc5a1e331c2bbb1893ac748cfd0f144ff0651b4
--- /dev/null
+++ b/e4e/models/stylegan2/op/upfirdn2d.py
@@ -0,0 +1,184 @@
+import os
+
+import torch
+from torch.nn import functional as F  # needed by upfirdn2d_native below
+from torch.autograd import Function
+from torch.utils.cpp_extension import load
+
+module_path = os.path.dirname(__file__)
+upfirdn2d_op = load(
+ 'upfirdn2d',
+ sources=[
+ os.path.join(module_path, 'upfirdn2d.cpp'),
+ os.path.join(module_path, 'upfirdn2d_kernel.cu'),
+ ],
+)
+
+
+class UpFirDn2dBackward(Function):
+ @staticmethod
+ def forward(
+ ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
+ ):
+ up_x, up_y = up
+ down_x, down_y = down
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
+
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
+
+ grad_input = upfirdn2d_op.upfirdn2d(
+ grad_output,
+ grad_kernel,
+ down_x,
+ down_y,
+ up_x,
+ up_y,
+ g_pad_x0,
+ g_pad_x1,
+ g_pad_y0,
+ g_pad_y1,
+ )
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
+
+ ctx.save_for_backward(kernel)
+
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+ ctx.up_x = up_x
+ ctx.up_y = up_y
+ ctx.down_x = down_x
+ ctx.down_y = down_y
+ ctx.pad_x0 = pad_x0
+ ctx.pad_x1 = pad_x1
+ ctx.pad_y0 = pad_y0
+ ctx.pad_y1 = pad_y1
+ ctx.in_size = in_size
+ ctx.out_size = out_size
+
+ return grad_input
+
+ @staticmethod
+ def backward(ctx, gradgrad_input):
+ kernel, = ctx.saved_tensors
+
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
+
+ gradgrad_out = upfirdn2d_op.upfirdn2d(
+ gradgrad_input,
+ kernel,
+ ctx.up_x,
+ ctx.up_y,
+ ctx.down_x,
+ ctx.down_y,
+ ctx.pad_x0,
+ ctx.pad_x1,
+ ctx.pad_y0,
+ ctx.pad_y1,
+ )
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
+ gradgrad_out = gradgrad_out.view(
+ ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
+ )
+
+ return gradgrad_out, None, None, None, None, None, None, None, None
+
+
+class UpFirDn2d(Function):
+ @staticmethod
+ def forward(ctx, input, kernel, up, down, pad):
+ up_x, up_y = up
+ down_x, down_y = down
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+ kernel_h, kernel_w = kernel.shape
+ batch, channel, in_h, in_w = input.shape
+ ctx.in_size = input.shape
+
+ input = input.reshape(-1, in_h, in_w, 1)
+
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
+
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+ ctx.out_size = (out_h, out_w)
+
+ ctx.up = (up_x, up_y)
+ ctx.down = (down_x, down_y)
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
+
+ g_pad_x0 = kernel_w - pad_x0 - 1
+ g_pad_y0 = kernel_h - pad_y0 - 1
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
+
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
+
+ out = upfirdn2d_op.upfirdn2d(
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
+ )
+ # out = out.view(major, out_h, out_w, minor)
+ out = out.view(-1, channel, out_h, out_w)
+
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ kernel, grad_kernel = ctx.saved_tensors
+
+ grad_input = UpFirDn2dBackward.apply(
+ grad_output,
+ kernel,
+ grad_kernel,
+ ctx.up,
+ ctx.down,
+ ctx.pad,
+ ctx.g_pad,
+ ctx.in_size,
+ ctx.out_size,
+ )
+
+ return grad_input, None, None, None, None
+
+
+def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+ out = UpFirDn2d.apply(
+ input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
+ )
+
+ return out
+
+
+def upfirdn2d_native(
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
+):
+ _, in_h, in_w, minor = input.shape
+ kernel_h, kernel_w = kernel.shape
+
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+ out = F.pad(
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
+ )
+ out = out[
+ :,
+ max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0),
+ max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0),
+ :,
+ ]
+
+ out = out.permute(0, 3, 1, 2)
+ out = out.reshape(
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
+ )
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+ out = F.conv2d(out, w)
+ out = out.reshape(
+ -1,
+ minor,
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+ )
+ out = out.permute(0, 2, 3, 1)
+
+ return out[:, ::down_y, ::down_x, :]
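
upfirdn2d performs upsample, FIR filter, and downsample in one pass; the Blur, Upsample, and Downsample modules in model.py only vary the up/down factors and padding. A tiny sketch of calling the Python wrapper with the separable [1, 3, 3, 1] kernel used throughout (assumes the CUDA extension builds and a GPU is present):

    import torch
    from e4e.models.stylegan2.op.upfirdn2d import upfirdn2d

    k = torch.tensor([1., 3., 3., 1.])
    k = k[None, :] * k[:, None]          # outer product -> 2D blur kernel
    k = k / k.sum()

    x = torch.randn(1, 3, 8, 8).cuda()
    y = upfirdn2d(x, k.cuda(), up=2, down=1, pad=(2, 1))   # 2x upsample + blur
    print(y.shape)                                          # torch.Size([1, 3, 16, 16])
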
diff --git a/e4e/models/stylegan2/op/upfirdn2d_kernel.cu b/e4e/models/stylegan2/op/upfirdn2d_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..2a710aa6adc3d43ac93136a1814e3c39970e1c7e
--- /dev/null
+++ b/e4e/models/stylegan2/op/upfirdn2d_kernel.cu
@@ -0,0 +1,272 @@
+// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, visit
+// https://nvlabs.github.io/stylegan2/license.html
+
+#include <torch/types.h>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+
+static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
+ int c = a / b;
+
+ if (c * b > a) {
+ c--;
+ }
+
+ return c;
+}
+
+
+struct UpFirDn2DKernelParams {
+ int up_x;
+ int up_y;
+ int down_x;
+ int down_y;
+ int pad_x0;
+ int pad_x1;
+ int pad_y0;
+ int pad_y1;
+
+ int major_dim;
+ int in_h;
+ int in_w;
+ int minor_dim;
+ int kernel_h;
+ int kernel_w;
+ int out_h;
+ int out_w;
+ int loop_major;
+ int loop_x;
+};
+
+
+template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
+__global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) {
+ const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
+ const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
+
+ __shared__ volatile float sk[kernel_h][kernel_w];
+ __shared__ volatile float sx[tile_in_h][tile_in_w];
+
+ int minor_idx = blockIdx.x;
+ int tile_out_y = minor_idx / p.minor_dim;
+ minor_idx -= tile_out_y * p.minor_dim;
+ tile_out_y *= tile_out_h;
+ int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
+ int major_idx_base = blockIdx.z * p.loop_major;
+
+ if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) {
+ return;
+ }
+
+ for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) {
+ int ky = tap_idx / kernel_w;
+ int kx = tap_idx - ky * kernel_w;
+ scalar_t v = 0.0;
+
+ if (kx < p.kernel_w & ky < p.kernel_h) {
+ v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
+ }
+
+ sk[ky][kx] = v;
+ }
+
+ for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) {
+ for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) {
+ int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
+ int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
+ int tile_in_x = floor_div(tile_mid_x, up_x);
+ int tile_in_y = floor_div(tile_mid_y, up_y);
+
+ __syncthreads();
+
+ for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) {
+ int rel_in_y = in_idx / tile_in_w;
+ int rel_in_x = in_idx - rel_in_y * tile_in_w;
+ int in_x = rel_in_x + tile_in_x;
+ int in_y = rel_in_y + tile_in_y;
+
+ scalar_t v = 0.0;
+
+ if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
+ v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx];
+ }
+
+ sx[rel_in_y][rel_in_x] = v;
+ }
+
+ __syncthreads();
+ for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) {
+ int rel_out_y = out_idx / tile_out_w;
+ int rel_out_x = out_idx - rel_out_y * tile_out_w;
+ int out_x = rel_out_x + tile_out_x;
+ int out_y = rel_out_y + tile_out_y;
+
+ int mid_x = tile_mid_x + rel_out_x * down_x;
+ int mid_y = tile_mid_y + rel_out_y * down_y;
+ int in_x = floor_div(mid_x, up_x);
+ int in_y = floor_div(mid_y, up_y);
+ int rel_in_x = in_x - tile_in_x;
+ int rel_in_y = in_y - tile_in_y;
+ int kernel_x = (in_x + 1) * up_x - mid_x - 1;
+ int kernel_y = (in_y + 1) * up_y - mid_y - 1;
+
+ scalar_t v = 0.0;
+
+ #pragma unroll
+ for (int y = 0; y < kernel_h / up_y; y++)
+ #pragma unroll
+ for (int x = 0; x < kernel_w / up_x; x++)
+ v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x];
+
+ if (out_x < p.out_w & out_y < p.out_h) {
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v;
+ }
+ }
+ }
+ }
+}
+
+
+torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
+ int up_x, int up_y, int down_x, int down_y,
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
+ int curDevice = -1;
+ cudaGetDevice(&curDevice);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+ UpFirDn2DKernelParams p;
+
+ auto x = input.contiguous();
+ auto k = kernel.contiguous();
+
+ p.major_dim = x.size(0);
+ p.in_h = x.size(1);
+ p.in_w = x.size(2);
+ p.minor_dim = x.size(3);
+ p.kernel_h = k.size(0);
+ p.kernel_w = k.size(1);
+ p.up_x = up_x;
+ p.up_y = up_y;
+ p.down_x = down_x;
+ p.down_y = down_y;
+ p.pad_x0 = pad_x0;
+ p.pad_x1 = pad_x1;
+ p.pad_y0 = pad_y0;
+ p.pad_y1 = pad_y1;
+
+ p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y;
+ p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x;
+
+ auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
+
+ int mode = -1;
+
+ int tile_out_h;
+ int tile_out_w;
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
+ mode = 1;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) {
+ mode = 2;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
+ mode = 3;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) {
+ mode = 4;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) {
+ mode = 5;
+ tile_out_h = 8;
+ tile_out_w = 32;
+ }
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) {
+ mode = 6;
+ tile_out_h = 8;
+ tile_out_w = 32;
+ }
+
+ dim3 block_size;
+ dim3 grid_size;
+
+ if (tile_out_h > 0 && tile_out_w) {
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
+ p.loop_x = 1;
+ block_size = dim3(32 * 8, 1, 1);
+ grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
+ (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
+ (p.major_dim - 1) / p.loop_major + 1);
+ }
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
+ switch (mode) {
+ case 1:
+            upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
+                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
+ );
+
+ break;
+
+ case 2:
+            upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64><<<grid_size, block_size, 0, stream>>>(
+                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
+ );
+
+ break;
+
+ case 3:
+            upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
+                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
+ );
+
+ break;
+
+ case 4:
+            upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64><<<grid_size, block_size, 0, stream>>>(
+                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
+ );
+
+ break;
+
+ case 5:
+            upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
+                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
+ );
+
+ break;
+
+ case 6:
+            upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32><<<grid_size, block_size, 0, stream>>>(
+                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
+ );
+
+ break;
+ }
+ });
+
+ return out;
+}
\ No newline at end of file
diff --git a/e4e/notebooks/images/car_img.jpg b/e4e/notebooks/images/car_img.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..162d13ddc3a7496a160925098fa9bb31d42cfd2a
Binary files /dev/null and b/e4e/notebooks/images/car_img.jpg differ
diff --git a/e4e/notebooks/images/church_img.jpg b/e4e/notebooks/images/church_img.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2282837b5406496f9fd3180dde8b58b288ab88cd
Binary files /dev/null and b/e4e/notebooks/images/church_img.jpg differ
diff --git a/e4e/notebooks/images/horse_img.jpg b/e4e/notebooks/images/horse_img.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..510f4b98169528fe0d03b03683907baa3dcb0ca2
Binary files /dev/null and b/e4e/notebooks/images/horse_img.jpg differ
diff --git a/e4e/notebooks/images/input_img.jpg b/e4e/notebooks/images/input_img.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6534b669166946d63c5468f71a18b502eba7efb3
Binary files /dev/null and b/e4e/notebooks/images/input_img.jpg differ
diff --git a/e4e/options/__init__.py b/e4e/options/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/e4e/options/train_options.py b/e4e/options/train_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..583ea1423fdc9a649cd7044d74d554bf0ac2bf51
--- /dev/null
+++ b/e4e/options/train_options.py
@@ -0,0 +1,84 @@
+from argparse import ArgumentParser
+from configs.paths_config import model_paths
+
+
+class TrainOptions:
+
+ def __init__(self):
+ self.parser = ArgumentParser()
+ self.initialize()
+
+ def initialize(self):
+ self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory')
+ self.parser.add_argument('--dataset_type', default='ffhq_encode', type=str,
+ help='Type of dataset/experiment to run')
+ self.parser.add_argument('--encoder_type', default='Encoder4Editing', type=str, help='Which encoder to use')
+
+ self.parser.add_argument('--batch_size', default=4, type=int, help='Batch size for training')
+ self.parser.add_argument('--test_batch_size', default=2, type=int, help='Batch size for testing and inference')
+ self.parser.add_argument('--workers', default=4, type=int, help='Number of train dataloader workers')
+ self.parser.add_argument('--test_workers', default=2, type=int,
+ help='Number of test/inference dataloader workers')
+
+ self.parser.add_argument('--learning_rate', default=0.0001, type=float, help='Optimizer learning rate')
+ self.parser.add_argument('--optim_name', default='ranger', type=str, help='Which optimizer to use')
+ self.parser.add_argument('--train_decoder', default=False, type=bool, help='Whether to train the decoder model')
+ self.parser.add_argument('--start_from_latent_avg', action='store_true',
+ help='Whether to add average latent vector to generate codes from encoder.')
+ self.parser.add_argument('--lpips_type', default='alex', type=str, help='LPIPS backbone')
+
+ self.parser.add_argument('--lpips_lambda', default=0.8, type=float, help='LPIPS loss multiplier factor')
+ self.parser.add_argument('--id_lambda', default=0.1, type=float, help='ID loss multiplier factor')
+ self.parser.add_argument('--l2_lambda', default=1.0, type=float, help='L2 loss multiplier factor')
+
+ self.parser.add_argument('--stylegan_weights', default=model_paths['stylegan_ffhq'], type=str,
+ help='Path to StyleGAN model weights')
+ self.parser.add_argument('--stylegan_size', default=1024, type=int,
+ help='size of pretrained StyleGAN Generator')
+ self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to pSp model checkpoint')
+
+ self.parser.add_argument('--max_steps', default=500000, type=int, help='Maximum number of training steps')
+ self.parser.add_argument('--image_interval', default=100, type=int,
+ help='Interval for logging train images during training')
+ self.parser.add_argument('--board_interval', default=50, type=int,
+ help='Interval for logging metrics to tensorboard')
+ self.parser.add_argument('--val_interval', default=1000, type=int, help='Validation interval')
+ self.parser.add_argument('--save_interval', default=None, type=int, help='Model checkpoint interval')
+
+ # Discriminator flags
+ self.parser.add_argument('--w_discriminator_lambda', default=0, type=float, help='Dw loss multiplier')
+ self.parser.add_argument('--w_discriminator_lr', default=2e-5, type=float, help='Dw learning rate')
+ self.parser.add_argument("--r1", type=float, default=10, help="weight of the r1 regularization")
+ self.parser.add_argument("--d_reg_every", type=int, default=16,
+ help="interval for applying r1 regularization")
+ self.parser.add_argument('--use_w_pool', action='store_true',
+                                 help='Whether to store a latent codes pool for the discriminator\'s training')
+ self.parser.add_argument("--w_pool_size", type=int, default=50,
+ help="W\'s pool size, depends on --use_w_pool")
+
+ # e4e specific
+ self.parser.add_argument('--delta_norm', type=int, default=2, help="norm type of the deltas")
+ self.parser.add_argument('--delta_norm_lambda', type=float, default=2e-4, help="lambda for delta norm loss")
+
+ # Progressive training
+ self.parser.add_argument('--progressive_steps', nargs='+', type=int, default=None,
+                                 help="Training steps at which new deltas are introduced; steps[i] starts training delta_i")
+ self.parser.add_argument('--progressive_start', type=int, default=None,
+ help="The training step to start training the deltas, overrides progressive_steps")
+ self.parser.add_argument('--progressive_step_every', type=int, default=2_000,
+ help="Amount of training steps for each progressive step")
+
+ # Save additional training info to enable future training continuation from produced checkpoints
+ self.parser.add_argument('--save_training_data', action='store_true',
+ help='Save intermediate training data to resume training from the checkpoint')
+ self.parser.add_argument('--sub_exp_dir', default=None, type=str, help='Name of sub experiment directory')
+ self.parser.add_argument('--keep_optimizer', action='store_true',
+ help='Whether to continue from the checkpoint\'s optimizer')
+ self.parser.add_argument('--resume_training_from_ckpt', default=None, type=str,
+ help='Path to training checkpoint, works when --save_training_data was set to True')
+ self.parser.add_argument('--update_param_list', nargs='+', type=str, default=None,
+ help="Name of training parameters to update the loaded training checkpoint")
+
+ def parse(self):
+ opts = self.parser.parse_args()
+ return opts
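
As a usage sketch, the options above can be exercised without launching a full run; the argv values are placeholders and this assumes it is executed from inside the e4e directory so that configs.paths_config resolves:

    import sys
    from options.train_options import TrainOptions

    sys.argv = ['train.py', '--exp_dir', 'experiments/e4e_ffhq',   # placeholder paths/names
                '--dataset_type', 'ffhq_encode',
                '--start_from_latent_avg',
                '--w_discriminator_lambda', '0.1',
                '--progressive_start', '20000']
    opts = TrainOptions().parse()
    print(opts.progressive_step_every)   # 2000 (the default declared above)
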
diff --git a/e4e/scripts/calc_losses_on_images.py b/e4e/scripts/calc_losses_on_images.py
new file mode 100644
index 0000000000000000000000000000000000000000..32b6bcee854da7ae357daf82bd986f30db9fb72c
--- /dev/null
+++ b/e4e/scripts/calc_losses_on_images.py
@@ -0,0 +1,87 @@
+from argparse import ArgumentParser
+import os
+import json
+import sys
+from tqdm import tqdm
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+import torchvision.transforms as transforms
+
+sys.path.append(".")
+sys.path.append("..")
+
+from criteria.lpips.lpips import LPIPS
+from datasets.gt_res_dataset import GTResDataset
+
+
+def parse_args():
+ parser = ArgumentParser(add_help=False)
+ parser.add_argument('--mode', type=str, default='lpips', choices=['lpips', 'l2'])
+ parser.add_argument('--data_path', type=str, default='results')
+ parser.add_argument('--gt_path', type=str, default='gt_images')
+ parser.add_argument('--workers', type=int, default=4)
+ parser.add_argument('--batch_size', type=int, default=4)
+ parser.add_argument('--is_cars', action='store_true')
+ args = parser.parse_args()
+ return args
+
+
+def run(args):
+ resize_dims = (256, 256)
+ if args.is_cars:
+ resize_dims = (192, 256)
+ transform = transforms.Compose([transforms.Resize(resize_dims),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
+
+ print('Loading dataset')
+ dataset = GTResDataset(root_path=args.data_path,
+ gt_dir=args.gt_path,
+ transform=transform)
+
+ dataloader = DataLoader(dataset,
+ batch_size=args.batch_size,
+ shuffle=False,
+ num_workers=int(args.workers),
+ drop_last=True)
+
+ if args.mode == 'lpips':
+ loss_func = LPIPS(net_type='alex')
+ elif args.mode == 'l2':
+ loss_func = torch.nn.MSELoss()
+ else:
+ raise Exception('Not a valid mode!')
+ loss_func.cuda()
+
+ global_i = 0
+ scores_dict = {}
+ all_scores = []
+ for result_batch, gt_batch in tqdm(dataloader):
+ for i in range(args.batch_size):
+ loss = float(loss_func(result_batch[i:i + 1].cuda(), gt_batch[i:i + 1].cuda()))
+ all_scores.append(loss)
+ im_path = dataset.pairs[global_i][0]
+ scores_dict[os.path.basename(im_path)] = loss
+ global_i += 1
+
+ all_scores = list(scores_dict.values())
+ mean = np.mean(all_scores)
+ std = np.std(all_scores)
+ result_str = 'Average loss is {:.2f}+-{:.2f}'.format(mean, std)
+ print('Finished with ', args.data_path)
+ print(result_str)
+
+ out_path = os.path.join(os.path.dirname(args.data_path), 'inference_metrics')
+ if not os.path.exists(out_path):
+ os.makedirs(out_path)
+
+ with open(os.path.join(out_path, 'stat_{}.txt'.format(args.mode)), 'w') as f:
+ f.write(result_str)
+ with open(os.path.join(out_path, 'scores_{}.json'.format(args.mode)), 'w') as f:
+ json.dump(scores_dict, f)
+
+
+if __name__ == '__main__':
+ args = parse_args()
+ run(args)
diff --git a/e4e/scripts/inference.py b/e4e/scripts/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..185b9b34db85dcd97b9793bd5dbfc9d1ca046549
--- /dev/null
+++ b/e4e/scripts/inference.py
@@ -0,0 +1,133 @@
+import argparse
+
+import torch
+import numpy as np
+import sys
+import os
+import dlib
+
+sys.path.append(".")
+sys.path.append("..")
+
+from configs import data_configs, paths_config
+from datasets.inference_dataset import InferenceDataset
+from torch.utils.data import DataLoader
+from utils.model_utils import setup_model
+from utils.common import tensor2im
+from utils.alignment import align_face
+from PIL import Image
+
+
+def main(args):
+ net, opts = setup_model(args.ckpt, device)
+ is_cars = 'cars_' in opts.dataset_type
+ generator = net.decoder
+ generator.eval()
+ args, data_loader = setup_data_loader(args, opts)
+
+ # Check if latents exist
+ latents_file_path = os.path.join(args.save_dir, 'latents.pt')
+ if os.path.exists(latents_file_path):
+ latent_codes = torch.load(latents_file_path).to(device)
+ else:
+ latent_codes = get_all_latents(net, data_loader, args.n_sample, is_cars=is_cars)
+ torch.save(latent_codes, latents_file_path)
+
+ if not args.latents_only:
+ generate_inversions(args, generator, latent_codes, is_cars=is_cars)
+
+
+def setup_data_loader(args, opts):
+ dataset_args = data_configs.DATASETS[opts.dataset_type]
+ transforms_dict = dataset_args['transforms'](opts).get_transforms()
+ images_path = args.images_dir if args.images_dir is not None else dataset_args['test_source_root']
+ print(f"images path: {images_path}")
+ align_function = None
+ if args.align:
+ align_function = run_alignment
+ test_dataset = InferenceDataset(root=images_path,
+ transform=transforms_dict['transform_test'],
+ preprocess=align_function,
+ opts=opts)
+
+ data_loader = DataLoader(test_dataset,
+ batch_size=args.batch,
+ shuffle=False,
+ num_workers=2,
+ drop_last=True)
+
+ print(f'dataset length: {len(test_dataset)}')
+
+ if args.n_sample is None:
+ args.n_sample = len(test_dataset)
+ return args, data_loader
+
+
+def get_latents(net, x, is_cars=False):
+ codes = net.encoder(x)
+ if net.opts.start_from_latent_avg:
+ if codes.ndim == 2:
+ codes = codes + net.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :]
+ else:
+ codes = codes + net.latent_avg.repeat(codes.shape[0], 1, 1)
+ if codes.shape[1] == 18 and is_cars:
+ codes = codes[:, :16, :]
+ return codes
+
+
+def get_all_latents(net, data_loader, n_images=None, is_cars=False):
+ all_latents = []
+ i = 0
+ with torch.no_grad():
+ for batch in data_loader:
+ if n_images is not None and i > n_images:
+ break
+ x = batch
+ inputs = x.to(device).float()
+ latents = get_latents(net, inputs, is_cars)
+ all_latents.append(latents)
+ i += len(latents)
+ return torch.cat(all_latents)
+
+
+def save_image(img, save_dir, idx):
+ result = tensor2im(img)
+ im_save_path = os.path.join(save_dir, f"{idx:05d}.jpg")
+ Image.fromarray(np.array(result)).save(im_save_path)
+
+
+@torch.no_grad()
+def generate_inversions(args, g, latent_codes, is_cars):
+ print('Saving inversion images')
+ inversions_directory_path = os.path.join(args.save_dir, 'inversions')
+ os.makedirs(inversions_directory_path, exist_ok=True)
+ for i in range(args.n_sample):
+ imgs, _ = g([latent_codes[i].unsqueeze(0)], input_is_latent=True, randomize_noise=False, return_latents=True)
+ if is_cars:
+ imgs = imgs[:, :, 64:448, :]
+ save_image(imgs[0], inversions_directory_path, i + 1)
+
+
+def run_alignment(image_path):
+ predictor = dlib.shape_predictor(paths_config.model_paths['shape_predictor'])
+ aligned_image = align_face(filepath=image_path, predictor=predictor)
+ print("Aligned image has shape: {}".format(aligned_image.size))
+ return aligned_image
+
+
+if __name__ == "__main__":
+ device = "cuda"
+
+ parser = argparse.ArgumentParser(description="Inference")
+ parser.add_argument("--images_dir", type=str, default=None,
+ help="The directory of the images to be inverted")
+ parser.add_argument("--save_dir", type=str, default=None,
+                        help="The directory to save the latent codes and inversion images (default: images_dir)")
+ parser.add_argument("--batch", type=int, default=1, help="batch size for the generator")
+ parser.add_argument("--n_sample", type=int, default=None, help="number of the samples to infer.")
+ parser.add_argument("--latents_only", action="store_true", help="infer only the latent codes of the directory")
+ parser.add_argument("--align", action="store_true", help="align face images before inference")
+ parser.add_argument("ckpt", metavar="CHECKPOINT", help="path to generator checkpoint")
+
+ args = parser.parse_args()
+ main(args)
diff --git a/e4e/scripts/train.py b/e4e/scripts/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..d885cfde49a0b21140e663e475918698d5e51ee3
--- /dev/null
+++ b/e4e/scripts/train.py
@@ -0,0 +1,88 @@
+"""
+This file runs the main training/val loop
+"""
+import os
+import json
+import math
+import sys
+import pprint
+import torch
+from argparse import Namespace
+
+sys.path.append(".")
+sys.path.append("..")
+
+from options.train_options import TrainOptions
+from training.coach import Coach
+
+
+def main():
+ opts = TrainOptions().parse()
+ previous_train_ckpt = None
+ if opts.resume_training_from_ckpt:
+ opts, previous_train_ckpt = load_train_checkpoint(opts)
+ else:
+ setup_progressive_steps(opts)
+ create_initial_experiment_dir(opts)
+
+ coach = Coach(opts, previous_train_ckpt)
+ coach.train()
+
+
+def load_train_checkpoint(opts):
+ train_ckpt_path = opts.resume_training_from_ckpt
+ previous_train_ckpt = torch.load(opts.resume_training_from_ckpt, map_location='cpu')
+ new_opts_dict = vars(opts)
+ opts = previous_train_ckpt['opts']
+ opts['resume_training_from_ckpt'] = train_ckpt_path
+ update_new_configs(opts, new_opts_dict)
+ pprint.pprint(opts)
+ opts = Namespace(**opts)
+ if opts.sub_exp_dir is not None:
+ sub_exp_dir = opts.sub_exp_dir
+ opts.exp_dir = os.path.join(opts.exp_dir, sub_exp_dir)
+ create_initial_experiment_dir(opts)
+ return opts, previous_train_ckpt
+
+
+def setup_progressive_steps(opts):
+ log_size = int(math.log(opts.stylegan_size, 2))
+ num_style_layers = 2*log_size - 2
+ num_deltas = num_style_layers - 1
+ if opts.progressive_start is not None: # If progressive delta training
+ opts.progressive_steps = [0]
+ next_progressive_step = opts.progressive_start
+ for i in range(num_deltas):
+ opts.progressive_steps.append(next_progressive_step)
+ next_progressive_step += opts.progressive_step_every
+
+ assert opts.progressive_steps is None or is_valid_progressive_steps(opts, num_style_layers), \
+ "Invalid progressive training input"
+
+
+def is_valid_progressive_steps(opts, num_style_layers):
+ return len(opts.progressive_steps) == num_style_layers and opts.progressive_steps[0] == 0
+
+
+def create_initial_experiment_dir(opts):
+ if os.path.exists(opts.exp_dir):
+ raise Exception('Oops... {} already exists'.format(opts.exp_dir))
+ os.makedirs(opts.exp_dir)
+
+ opts_dict = vars(opts)
+ pprint.pprint(opts_dict)
+ with open(os.path.join(opts.exp_dir, 'opt.json'), 'w') as f:
+ json.dump(opts_dict, f, indent=4, sort_keys=True)
+
+
+def update_new_configs(ckpt_opts, new_opts):
+ for k, v in new_opts.items():
+ if k not in ckpt_opts:
+ ckpt_opts[k] = v
+ if new_opts['update_param_list']:
+ for param in new_opts['update_param_list']:
+ ckpt_opts[param] = new_opts[param]
+
+
+if __name__ == '__main__':
+ main()
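
setup_progressive_steps above derives one schedule entry per style layer: for a 1024px generator, log_size = 10, so there are 2*10 - 2 = 18 style layers and 17 deltas; with --progressive_start 20000 and the default --progressive_step_every 2000, the schedule becomes [0, 20000, 22000, ..., 52000]. A worked check of that arithmetic (the start step is an example value, not a recommended setting):

    import math

    stylegan_size = 1024
    progressive_start = 20_000         # example value for --progressive_start
    progressive_step_every = 2_000     # default from train_options.py

    log_size = int(math.log(stylegan_size, 2))    # 10
    num_style_layers = 2 * log_size - 2           # 18
    num_deltas = num_style_layers - 1             # 17

    steps = [0] + [progressive_start + i * progressive_step_every for i in range(num_deltas)]
    assert len(steps) == num_style_layers and steps[0] == 0
    print(steps[:4], steps[-1])        # [0, 20000, 22000, 24000] 52000
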
diff --git a/e4e/training/__init__.py b/e4e/training/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/e4e/training/coach.py b/e4e/training/coach.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c99da79e699c9362e02c289cd1425848d331d0b
--- /dev/null
+++ b/e4e/training/coach.py
@@ -0,0 +1,437 @@
+import os
+import random
+import matplotlib
+import matplotlib.pyplot as plt
+
+matplotlib.use('Agg')
+
+import torch
+from torch import nn, autograd
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+import torch.nn.functional as F
+
+from utils import common, train_utils
+from criteria import id_loss, moco_loss
+from configs import data_configs
+from datasets.images_dataset import ImagesDataset
+from criteria.lpips.lpips import LPIPS
+from models.psp import pSp
+from models.latent_codes_pool import LatentCodesPool
+from models.discriminator import LatentCodesDiscriminator
+from models.encoders.psp_encoders import ProgressiveStage
+from training.ranger import Ranger
+
+random.seed(0)
+torch.manual_seed(0)
+
+
+class Coach:
+ def __init__(self, opts, prev_train_checkpoint=None):
+ self.opts = opts
+
+ self.global_step = 0
+
+ self.device = 'cuda:0'
+ self.opts.device = self.device
+ # Initialize network
+ self.net = pSp(self.opts).to(self.device)
+
+ # Initialize loss
+ if self.opts.lpips_lambda > 0:
+ self.lpips_loss = LPIPS(net_type=self.opts.lpips_type).to(self.device).eval()
+ if self.opts.id_lambda > 0:
+ if 'ffhq' in self.opts.dataset_type or 'celeb' in self.opts.dataset_type:
+ self.id_loss = id_loss.IDLoss().to(self.device).eval()
+ else:
+ self.id_loss = moco_loss.MocoLoss(opts).to(self.device).eval()
+ self.mse_loss = nn.MSELoss().to(self.device).eval()
+
+ # Initialize optimizer
+ self.optimizer = self.configure_optimizers()
+
+ # Initialize discriminator
+ if self.opts.w_discriminator_lambda > 0:
+ self.discriminator = LatentCodesDiscriminator(512, 4).to(self.device)
+ self.discriminator_optimizer = torch.optim.Adam(list(self.discriminator.parameters()),
+ lr=opts.w_discriminator_lr)
+ self.real_w_pool = LatentCodesPool(self.opts.w_pool_size)
+ self.fake_w_pool = LatentCodesPool(self.opts.w_pool_size)
+
+ # Initialize dataset
+ self.train_dataset, self.test_dataset = self.configure_datasets()
+ self.train_dataloader = DataLoader(self.train_dataset,
+ batch_size=self.opts.batch_size,
+ shuffle=True,
+ num_workers=int(self.opts.workers),
+ drop_last=True)
+ self.test_dataloader = DataLoader(self.test_dataset,
+ batch_size=self.opts.test_batch_size,
+ shuffle=False,
+ num_workers=int(self.opts.test_workers),
+ drop_last=True)
+
+ # Initialize logger
+ log_dir = os.path.join(opts.exp_dir, 'logs')
+ os.makedirs(log_dir, exist_ok=True)
+ self.logger = SummaryWriter(log_dir=log_dir)
+
+ # Initialize checkpoint dir
+ self.checkpoint_dir = os.path.join(opts.exp_dir, 'checkpoints')
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
+ self.best_val_loss = None
+ if self.opts.save_interval is None:
+ self.opts.save_interval = self.opts.max_steps
+
+ if prev_train_checkpoint is not None:
+ self.load_from_train_checkpoint(prev_train_checkpoint)
+ prev_train_checkpoint = None
+
+ def load_from_train_checkpoint(self, ckpt):
+ print('Loading previous training data...')
+ self.global_step = ckpt['global_step'] + 1
+ self.best_val_loss = ckpt['best_val_loss']
+ self.net.load_state_dict(ckpt['state_dict'])
+
+ if self.opts.keep_optimizer:
+ self.optimizer.load_state_dict(ckpt['optimizer'])
+ if self.opts.w_discriminator_lambda > 0:
+ self.discriminator.load_state_dict(ckpt['discriminator_state_dict'])
+ self.discriminator_optimizer.load_state_dict(ckpt['discriminator_optimizer_state_dict'])
+ if self.opts.progressive_steps:
+ self.check_for_progressive_training_update(is_resume_from_ckpt=True)
+ print(f'Resuming training from step {self.global_step}')
+
+ def train(self):
+ self.net.train()
+ if self.opts.progressive_steps:
+ self.check_for_progressive_training_update()
+ while self.global_step < self.opts.max_steps:
+ for batch_idx, batch in enumerate(self.train_dataloader):
+ loss_dict = {}
+ if self.is_training_discriminator():
+ loss_dict = self.train_discriminator(batch)
+ x, y, y_hat, latent = self.forward(batch)
+ loss, encoder_loss_dict, id_logs = self.calc_loss(x, y, y_hat, latent)
+ loss_dict = {**loss_dict, **encoder_loss_dict}
+ self.optimizer.zero_grad()
+ loss.backward()
+ self.optimizer.step()
+
+ # Logging related
+ if self.global_step % self.opts.image_interval == 0 or (
+ self.global_step < 1000 and self.global_step % 25 == 0):
+ self.parse_and_log_images(id_logs, x, y, y_hat, title='images/train/faces')
+ if self.global_step % self.opts.board_interval == 0:
+ self.print_metrics(loss_dict, prefix='train')
+ self.log_metrics(loss_dict, prefix='train')
+
+ # Validation related
+ val_loss_dict = None
+ if self.global_step % self.opts.val_interval == 0 or self.global_step == self.opts.max_steps:
+ val_loss_dict = self.validate()
+ if val_loss_dict and (self.best_val_loss is None or val_loss_dict['loss'] < self.best_val_loss):
+ self.best_val_loss = val_loss_dict['loss']
+ self.checkpoint_me(val_loss_dict, is_best=True)
+
+ if self.global_step % self.opts.save_interval == 0 or self.global_step == self.opts.max_steps:
+ if val_loss_dict is not None:
+ self.checkpoint_me(val_loss_dict, is_best=False)
+ else:
+ self.checkpoint_me(loss_dict, is_best=False)
+
+ if self.global_step == self.opts.max_steps:
+ print('OMG, finished training!')
+ break
+
+ self.global_step += 1
+ if self.opts.progressive_steps:
+ self.check_for_progressive_training_update()
+
+ def check_for_progressive_training_update(self, is_resume_from_ckpt=False):
+ for i in range(len(self.opts.progressive_steps)):
+ if is_resume_from_ckpt and self.global_step >= self.opts.progressive_steps[i]: # Case checkpoint
+ self.net.encoder.set_progressive_stage(ProgressiveStage(i))
+ if self.global_step == self.opts.progressive_steps[i]: # Case training reached progressive step
+ self.net.encoder.set_progressive_stage(ProgressiveStage(i))
+
+ def validate(self):
+ self.net.eval()
+ agg_loss_dict = []
+ for batch_idx, batch in enumerate(self.test_dataloader):
+ cur_loss_dict = {}
+ if self.is_training_discriminator():
+ cur_loss_dict = self.validate_discriminator(batch)
+ with torch.no_grad():
+ x, y, y_hat, latent = self.forward(batch)
+ loss, cur_encoder_loss_dict, id_logs = self.calc_loss(x, y, y_hat, latent)
+ cur_loss_dict = {**cur_loss_dict, **cur_encoder_loss_dict}
+ agg_loss_dict.append(cur_loss_dict)
+
+ # Logging related
+ self.parse_and_log_images(id_logs, x, y, y_hat,
+ title='images/test/faces',
+ subscript='{:04d}'.format(batch_idx))
+
+ # For first step just do sanity test on small amount of data
+ if self.global_step == 0 and batch_idx >= 4:
+ self.net.train()
+ return None # Do not log, inaccurate in first batch
+
+ loss_dict = train_utils.aggregate_loss_dict(agg_loss_dict)
+ self.log_metrics(loss_dict, prefix='test')
+ self.print_metrics(loss_dict, prefix='test')
+
+ self.net.train()
+ return loss_dict
+
+ def checkpoint_me(self, loss_dict, is_best):
+ save_name = 'best_model.pt' if is_best else 'iteration_{}.pt'.format(self.global_step)
+ save_dict = self.__get_save_dict()
+ checkpoint_path = os.path.join(self.checkpoint_dir, save_name)
+ torch.save(save_dict, checkpoint_path)
+ with open(os.path.join(self.checkpoint_dir, 'timestamp.txt'), 'a') as f:
+ if is_best:
+ f.write(
+ '**Best**: Step - {}, Loss - {:.3f} \n{}\n'.format(self.global_step, self.best_val_loss, loss_dict))
+ else:
+ f.write('Step - {}, \n{}\n'.format(self.global_step, loss_dict))
+
+ def configure_optimizers(self):
+ params = list(self.net.encoder.parameters())
+ if self.opts.train_decoder:
+ params += list(self.net.decoder.parameters())
+ else:
+ self.requires_grad(self.net.decoder, False)
+ if self.opts.optim_name == 'adam':
+ optimizer = torch.optim.Adam(params, lr=self.opts.learning_rate)
+ else:
+ optimizer = Ranger(params, lr=self.opts.learning_rate)
+ return optimizer
+
+ def configure_datasets(self):
+ if self.opts.dataset_type not in data_configs.DATASETS.keys():
+            raise Exception('{} is not a valid dataset_type'.format(self.opts.dataset_type))
+ print('Loading dataset for {}'.format(self.opts.dataset_type))
+ dataset_args = data_configs.DATASETS[self.opts.dataset_type]
+ transforms_dict = dataset_args['transforms'](self.opts).get_transforms()
+ train_dataset = ImagesDataset(source_root=dataset_args['train_source_root'],
+ target_root=dataset_args['train_target_root'],
+ source_transform=transforms_dict['transform_source'],
+ target_transform=transforms_dict['transform_gt_train'],
+ opts=self.opts)
+ test_dataset = ImagesDataset(source_root=dataset_args['test_source_root'],
+ target_root=dataset_args['test_target_root'],
+ source_transform=transforms_dict['transform_source'],
+ target_transform=transforms_dict['transform_test'],
+ opts=self.opts)
+ print("Number of training samples: {}".format(len(train_dataset)))
+ print("Number of test samples: {}".format(len(test_dataset)))
+ return train_dataset, test_dataset
+
+ def calc_loss(self, x, y, y_hat, latent):
+ loss_dict = {}
+ loss = 0.0
+ id_logs = None
+ if self.is_training_discriminator(): # Adversarial loss
+ loss_disc = 0.
+ dims_to_discriminate = self.get_dims_to_discriminate() if self.is_progressive_training() else \
+ list(range(self.net.decoder.n_latent))
+
+ for i in dims_to_discriminate:
+ w = latent[:, i, :]
+ fake_pred = self.discriminator(w)
+ loss_disc += F.softplus(-fake_pred).mean()
+ loss_disc /= len(dims_to_discriminate)
+ loss_dict['encoder_discriminator_loss'] = float(loss_disc)
+ loss += self.opts.w_discriminator_lambda * loss_disc
+
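+        # e4e delta regularization: penalize the offset of every unlocked w_i from w_0
+        # (using the `delta_norm` norm) so the predicted W+ code stays close to a single w
+        # and inversions remain editable while expressiveness is added progressively.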
+ if self.opts.progressive_steps and self.net.encoder.progressive_stage.value != 18: # delta regularization loss
+ total_delta_loss = 0
+ deltas_latent_dims = self.net.encoder.get_deltas_starting_dimensions()
+
+ first_w = latent[:, 0, :]
+ for i in range(1, self.net.encoder.progressive_stage.value + 1):
+ curr_dim = deltas_latent_dims[i]
+ delta = latent[:, curr_dim, :] - first_w
+ delta_loss = torch.norm(delta, self.opts.delta_norm, dim=1).mean()
+ loss_dict[f"delta{i}_loss"] = float(delta_loss)
+ total_delta_loss += delta_loss
+ loss_dict['total_delta_loss'] = float(total_delta_loss)
+ loss += self.opts.delta_norm_lambda * total_delta_loss
+
+ if self.opts.id_lambda > 0: # Similarity loss
+ loss_id, sim_improvement, id_logs = self.id_loss(y_hat, y, x)
+ loss_dict['loss_id'] = float(loss_id)
+ loss_dict['id_improve'] = float(sim_improvement)
+ loss += loss_id * self.opts.id_lambda
+ if self.opts.l2_lambda > 0:
+ loss_l2 = F.mse_loss(y_hat, y)
+ loss_dict['loss_l2'] = float(loss_l2)
+ loss += loss_l2 * self.opts.l2_lambda
+ if self.opts.lpips_lambda > 0:
+ loss_lpips = self.lpips_loss(y_hat, y)
+ loss_dict['loss_lpips'] = float(loss_lpips)
+ loss += loss_lpips * self.opts.lpips_lambda
+ loss_dict['loss'] = float(loss)
+ return loss, loss_dict, id_logs
+
+ def forward(self, batch):
+ x, y = batch
+ x, y = x.to(self.device).float(), y.to(self.device).float()
+ y_hat, latent = self.net.forward(x, return_latents=True)
+ if self.opts.dataset_type == "cars_encode":
+ y_hat = y_hat[:, :, 32:224, :]
+ return x, y, y_hat, latent
+
+ def log_metrics(self, metrics_dict, prefix):
+ for key, value in metrics_dict.items():
+ self.logger.add_scalar('{}/{}'.format(prefix, key), value, self.global_step)
+
+ def print_metrics(self, metrics_dict, prefix):
+ print('Metrics for {}, step {}'.format(prefix, self.global_step))
+ for key, value in metrics_dict.items():
+ print('\t{} = '.format(key), value)
+
+ def parse_and_log_images(self, id_logs, x, y, y_hat, title, subscript=None, display_count=2):
+ im_data = []
+ for i in range(display_count):
+ cur_im_data = {
+ 'input_face': common.log_input_image(x[i], self.opts),
+ 'target_face': common.tensor2im(y[i]),
+ 'output_face': common.tensor2im(y_hat[i]),
+ }
+ if id_logs is not None:
+ for key in id_logs[i]:
+ cur_im_data[key] = id_logs[i][key]
+ im_data.append(cur_im_data)
+ self.log_images(title, im_data=im_data, subscript=subscript)
+
+ def log_images(self, name, im_data, subscript=None, log_latest=False):
+ fig = common.vis_faces(im_data)
+ step = self.global_step
+ if log_latest:
+ step = 0
+ if subscript:
+ path = os.path.join(self.logger.log_dir, name, '{}_{:04d}.jpg'.format(subscript, step))
+ else:
+ path = os.path.join(self.logger.log_dir, name, '{:04d}.jpg'.format(step))
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ fig.savefig(path)
+ plt.close(fig)
+
+ def __get_save_dict(self):
+ save_dict = {
+ 'state_dict': self.net.state_dict(),
+ 'opts': vars(self.opts)
+ }
+ # save the latent avg in state_dict for inference if truncation of w was used during training
+ if self.opts.start_from_latent_avg:
+ save_dict['latent_avg'] = self.net.latent_avg
+
+ if self.opts.save_training_data: # Save necessary information to enable training continuation from checkpoint
+ save_dict['global_step'] = self.global_step
+ save_dict['optimizer'] = self.optimizer.state_dict()
+ save_dict['best_val_loss'] = self.best_val_loss
+ if self.opts.w_discriminator_lambda > 0:
+ save_dict['discriminator_state_dict'] = self.discriminator.state_dict()
+ save_dict['discriminator_optimizer_state_dict'] = self.discriminator_optimizer.state_dict()
+ return save_dict
+
+ def get_dims_to_discriminate(self):
+ deltas_starting_dimensions = self.net.encoder.get_deltas_starting_dimensions()
+ return deltas_starting_dimensions[:self.net.encoder.progressive_stage.value + 1]
+
+ def is_progressive_training(self):
+ return self.opts.progressive_steps is not None
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Discriminator ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+ def is_training_discriminator(self):
+ return self.opts.w_discriminator_lambda > 0
+
+ @staticmethod
+ def discriminator_loss(real_pred, fake_pred, loss_dict):
+ real_loss = F.softplus(-real_pred).mean()
+ fake_loss = F.softplus(fake_pred).mean()
+
+ loss_dict['d_real_loss'] = float(real_loss)
+ loss_dict['d_fake_loss'] = float(fake_loss)
+
+ return real_loss + fake_loss
+
+ @staticmethod
+ def discriminator_r1_loss(real_pred, real_w):
+ grad_real, = autograd.grad(
+ outputs=real_pred.sum(), inputs=real_w, create_graph=True
+ )
+ grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
+
+ return grad_penalty
+
+ @staticmethod
+ def requires_grad(model, flag=True):
+ for p in model.parameters():
+ p.requires_grad = flag
+
+ def train_discriminator(self, batch):
+ loss_dict = {}
+ x, _ = batch
+ x = x.to(self.device).float()
+ self.requires_grad(self.discriminator, True)
+
+ with torch.no_grad():
+ real_w, fake_w = self.sample_real_and_fake_latents(x)
+ real_pred = self.discriminator(real_w)
+ fake_pred = self.discriminator(fake_w)
+ loss = self.discriminator_loss(real_pred, fake_pred, loss_dict)
+ loss_dict['discriminator_loss'] = float(loss)
+
+ self.discriminator_optimizer.zero_grad()
+ loss.backward()
+ self.discriminator_optimizer.step()
+
+ # r1 regularization
+ d_regularize = self.global_step % self.opts.d_reg_every == 0
+ if d_regularize:
+ real_w = real_w.detach()
+ real_w.requires_grad = True
+ real_pred = self.discriminator(real_w)
+ r1_loss = self.discriminator_r1_loss(real_pred, real_w)
+
+ self.discriminator.zero_grad()
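+            # Lazy R1 regularization: the penalty runs only every `d_reg_every` steps, so it
+            # is scaled by d_reg_every; the `0 * real_pred[0]` term keeps real_pred in the
+            # autograd graph (a StyleGAN2 trick) without changing the loss value.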
+ r1_final_loss = self.opts.r1 / 2 * r1_loss * self.opts.d_reg_every + 0 * real_pred[0]
+ r1_final_loss.backward()
+ self.discriminator_optimizer.step()
+ loss_dict['discriminator_r1_loss'] = float(r1_final_loss)
+
+ # Reset to previous state
+ self.requires_grad(self.discriminator, False)
+
+ return loss_dict
+
+ def validate_discriminator(self, test_batch):
+ with torch.no_grad():
+ loss_dict = {}
+ x, _ = test_batch
+ x = x.to(self.device).float()
+ real_w, fake_w = self.sample_real_and_fake_latents(x)
+ real_pred = self.discriminator(real_w)
+ fake_pred = self.discriminator(fake_w)
+ loss = self.discriminator_loss(real_pred, fake_pred, loss_dict)
+ loss_dict['discriminator_loss'] = float(loss)
+ return loss_dict
+
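+    # "Real" latents are obtained by mapping random z through the generator's style MLP,
+    # while "fake" latents come from the encoder; the latent discriminator is trained to
+    # tell them apart, nudging encoder outputs toward the W distribution.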
+ def sample_real_and_fake_latents(self, x):
+ sample_z = torch.randn(self.opts.batch_size, 512, device=self.device)
+ real_w = self.net.decoder.get_latent(sample_z)
+ fake_w = self.net.encoder(x)
+ if self.is_progressive_training(): # When progressive training, feed only unique w's
+ dims_to_discriminate = self.get_dims_to_discriminate()
+ fake_w = fake_w[:, dims_to_discriminate, :]
+ if self.opts.use_w_pool:
+ real_w = self.real_w_pool.query(real_w)
+ fake_w = self.fake_w_pool.query(fake_w)
+ if fake_w.ndim == 3:
+ fake_w = fake_w[:, 0, :]
+ return real_w, fake_w
diff --git a/e4e/training/ranger.py b/e4e/training/ranger.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d63264dda6df0ee40cac143440f0b5f8977a9ad
--- /dev/null
+++ b/e4e/training/ranger.py
@@ -0,0 +1,164 @@
+# Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer.
+
+# https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
+# and/or
+# https://github.com/lessw2020/Best-Deep-Learning-Optimizers
+
+# Ranger has now been used to capture 12 records on the FastAI leaderboard.
+
+# This version = 20.4.11
+
+# Credits:
+# Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization
+# RAdam --> https://github.com/LiyuanLucasLiu/RAdam
+# Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code.
+# Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610
+
+# summary of changes:
+# 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init.
+# full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights),
+# supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues.
+# changes 8/31/19 - fix references to *self*.N_sma_threshold;
+# changed eps to 1e-5 as better default than 1e-8.
+
+import math
+import torch
+from torch.optim.optimizer import Optimizer
+
+
+class Ranger(Optimizer):
+
+ def __init__(self, params, lr=1e-3, # lr
+ alpha=0.5, k=6, N_sma_threshhold=5, # Ranger options
+ betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam options
+ use_gc=True, gc_conv_only=False
+ # Gradient centralization on or off, applied to conv layers only or conv + fc layers
+ ):
+
+ # parameter checks
+ if not 0.0 <= alpha <= 1.0:
+ raise ValueError(f'Invalid slow update rate: {alpha}')
+ if not 1 <= k:
+ raise ValueError(f'Invalid lookahead steps: {k}')
+ if not lr > 0:
+ raise ValueError(f'Invalid Learning Rate: {lr}')
+ if not eps > 0:
+ raise ValueError(f'Invalid eps: {eps}')
+
+ # parameter comments:
+ # beta1 (momentum) of .95 seems to work better than .90...
+ # N_sma_threshold of 5 seems better in testing than 4.
+ # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you.
+
+ # prep defaults and init torch.optim base
+ defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold,
+ eps=eps, weight_decay=weight_decay)
+ super().__init__(params, defaults)
+
+ # adjustable threshold
+ self.N_sma_threshhold = N_sma_threshhold
+
+ # look ahead params
+
+ self.alpha = alpha
+ self.k = k
+
+ # radam buffer for state
+ self.radam_buffer = [[None, None, None] for ind in range(10)]
+
+ # gc on or off
+ self.use_gc = use_gc
+
+ # level of gradient centralization
+ self.gc_gradient_threshold = 3 if gc_conv_only else 1
+
+ def __setstate__(self, state):
+ super(Ranger, self).__setstate__(state)
+
+ def step(self, closure=None):
+ loss = None
+
+ # Evaluate averages and grad, update param tensors
+ for group in self.param_groups:
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad.data.float()
+
+ if grad.is_sparse:
+ raise RuntimeError('Ranger optimizer does not support sparse gradients')
+
+ p_data_fp32 = p.data.float()
+
+ state = self.state[p] # get state dict for this param
+
+ if len(state) == 0: # if first time to run...init dictionary with our desired entries
+ # if self.first_run_check==0:
+ # self.first_run_check=1
+ # print("Initializing slow buffer...should not see this at load from saved model!")
+ state['step'] = 0
+ state['exp_avg'] = torch.zeros_like(p_data_fp32)
+ state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
+
+ # look ahead weight storage now in state dict
+ state['slow_buffer'] = torch.empty_like(p.data)
+ state['slow_buffer'].copy_(p.data)
+
+ else:
+ state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
+ state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
+
+ # begin computations
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+ beta1, beta2 = group['betas']
+
+ # GC operation for Conv layers and FC layers
+ if grad.dim() > self.gc_gradient_threshold:
+ grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True))
+
+ state['step'] += 1
+
+ # compute variance mov avg
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+                # compute mean moving avg
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+
+ buffered = self.radam_buffer[int(state['step'] % 10)]
+
+ if state['step'] == buffered[0]:
+ N_sma, step_size = buffered[1], buffered[2]
+ else:
+ buffered[0] = state['step']
+ beta2_t = beta2 ** state['step']
+ N_sma_max = 2 / (1 - beta2) - 1
+ N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
+ buffered[1] = N_sma
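+                    # RAdam: N_sma approximates the SMA length (rho_t). Above the threshold,
+                    # the rectified, bias-corrected step size below is used; otherwise the
+                    # update falls back to an unadapted momentum step.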
+ if N_sma > self.N_sma_threshhold:
+ step_size = math.sqrt(
+ (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
+ N_sma_max - 2)) / (1 - beta1 ** state['step'])
+ else:
+ step_size = 1.0 / (1 - beta1 ** state['step'])
+ buffered[2] = step_size
+
+ if group['weight_decay'] != 0:
+                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
+
+ # apply lr
+ if N_sma > self.N_sma_threshhold:
+ denom = exp_avg_sq.sqrt().add_(group['eps'])
+                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
+                else:
+                    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
+
+ p.data.copy_(p_data_fp32)
+
+ # integrated look ahead...
+ # we do it at the param level instead of group level
+ if state['step'] % group['k'] == 0:
+ slow_p = state['slow_buffer'] # get access to slow param tensor
+                    slow_p.add_(p.data - slow_p, alpha=self.alpha)  # (fast weights - slow weights) * alpha
+ p.data.copy_(slow_p) # copy interpolated weights to RAdam param tensor
+
+ return loss
\ No newline at end of file
diff --git a/e4e/utils/__init__.py b/e4e/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/e4e/utils/alignment.py b/e4e/utils/alignment.py
new file mode 100644
index 0000000000000000000000000000000000000000..a02798f0f7c9fdcc319f7884a491b9e6580cc8aa
--- /dev/null
+++ b/e4e/utils/alignment.py
@@ -0,0 +1,115 @@
+import numpy as np
+import PIL
+import PIL.Image
+import scipy
+import scipy.ndimage
+import dlib
+
+
+def get_landmark(filepath, predictor):
+ """get landmark with dlib
+ :return: np.array shape=(68, 2)
+ """
+ detector = dlib.get_frontal_face_detector()
+
+ img = dlib.load_rgb_image(filepath)
+    dets = detector(img, 1)
+    assert len(dets) > 0, "Face not detected, try another face image"
+
+ for k, d in enumerate(dets):
+ shape = predictor(img, d)
+
+ t = list(shape.parts())
+ a = []
+ for tt in t:
+ a.append([tt.x, tt.y])
+ lm = np.array(a)
+ return lm
+
+
+def align_face(filepath, predictor):
+ """
+ :param filepath: str
+ :return: PIL Image
+ """
+
+ lm = get_landmark(filepath, predictor)
+
+ lm_chin = lm[0: 17] # left-right
+ lm_eyebrow_left = lm[17: 22] # left-right
+ lm_eyebrow_right = lm[22: 27] # left-right
+ lm_nose = lm[27: 31] # top-down
+ lm_nostrils = lm[31: 36] # top-down
+ lm_eye_left = lm[36: 42] # left-clockwise
+ lm_eye_right = lm[42: 48] # left-clockwise
+ lm_mouth_outer = lm[48: 60] # left-clockwise
+ lm_mouth_inner = lm[60: 68] # left-clockwise
+
+ # Calculate auxiliary vectors.
+ eye_left = np.mean(lm_eye_left, axis=0)
+ eye_right = np.mean(lm_eye_right, axis=0)
+ eye_avg = (eye_left + eye_right) * 0.5
+ eye_to_eye = eye_right - eye_left
+ mouth_left = lm_mouth_outer[0]
+ mouth_right = lm_mouth_outer[6]
+ mouth_avg = (mouth_left + mouth_right) * 0.5
+ eye_to_mouth = mouth_avg - eye_avg
+
+ # Choose oriented crop rectangle.
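+    # (FFHQ-style alignment: the crop is oriented along the eye-to-eye axis and scaled
+    # relative to the eye-to-eye and eye-to-mouth distances, so faces come out centered
+    # and at a consistent scale.)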
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+ x /= np.hypot(*x)
+ x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
+ y = np.flipud(x) * [-1, 1]
+ c = eye_avg + eye_to_mouth * 0.1
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+ qsize = np.hypot(*x) * 2
+
+ # read image
+ img = PIL.Image.open(filepath)
+
+ output_size = 256
+ transform_size = 256
+ enable_padding = True
+
+ # Shrink.
+ shrink = int(np.floor(qsize / output_size * 0.5))
+ if shrink > 1:
+ rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
+ img = img.resize(rsize, PIL.Image.ANTIALIAS)
+ quad /= shrink
+ qsize /= shrink
+
+ # Crop.
+ border = max(int(np.rint(qsize * 0.1)), 3)
+ crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))))
+ crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
+ min(crop[3] + border, img.size[1]))
+ if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
+ img = img.crop(crop)
+ quad -= crop[0:2]
+
+ # Pad.
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))))
+ pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
+ max(pad[3] - img.size[1] + border, 0))
+ if enable_padding and max(pad) > border - 4:
+ pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
+ img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
+ h, w, _ = img.shape
+ y, x, _ = np.ogrid[:h, :w, :1]
+ mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
+ 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))
+ blur = qsize * 0.02
+ img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+ img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
+ img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
+ quad += pad[:2]
+
+ # Transform.
+ img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
+ if output_size < transform_size:
+ img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
+
+ # Return aligned image.
+ return img
diff --git a/e4e/utils/common.py b/e4e/utils/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..b19e18ddcb78b06678fa18e4a76da44fc511b789
--- /dev/null
+++ b/e4e/utils/common.py
@@ -0,0 +1,55 @@
+from PIL import Image
+import matplotlib.pyplot as plt
+
+
+# Log images
+def log_input_image(x, opts):
+ return tensor2im(x)
+
+
+def tensor2im(var):
+ # var shape: (3, H, W)
+ var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy()
+ var = ((var + 1) / 2)
+ var[var < 0] = 0
+ var[var > 1] = 1
+ var = var * 255
+ return Image.fromarray(var.astype('uint8'))
+
+
+def vis_faces(log_hooks):
+ display_count = len(log_hooks)
+ fig = plt.figure(figsize=(8, 4 * display_count))
+ gs = fig.add_gridspec(display_count, 3)
+ for i in range(display_count):
+ hooks_dict = log_hooks[i]
+ fig.add_subplot(gs[i, 0])
+ if 'diff_input' in hooks_dict:
+ vis_faces_with_id(hooks_dict, fig, gs, i)
+ else:
+ vis_faces_no_id(hooks_dict, fig, gs, i)
+ plt.tight_layout()
+ return fig
+
+
+def vis_faces_with_id(hooks_dict, fig, gs, i):
+ plt.imshow(hooks_dict['input_face'])
+ plt.title('Input\nOut Sim={:.2f}'.format(float(hooks_dict['diff_input'])))
+ fig.add_subplot(gs[i, 1])
+ plt.imshow(hooks_dict['target_face'])
+ plt.title('Target\nIn={:.2f}, Out={:.2f}'.format(float(hooks_dict['diff_views']),
+ float(hooks_dict['diff_target'])))
+ fig.add_subplot(gs[i, 2])
+ plt.imshow(hooks_dict['output_face'])
+ plt.title('Output\n Target Sim={:.2f}'.format(float(hooks_dict['diff_target'])))
+
+
+def vis_faces_no_id(hooks_dict, fig, gs, i):
+ plt.imshow(hooks_dict['input_face'], cmap="gray")
+ plt.title('Input')
+ fig.add_subplot(gs[i, 1])
+ plt.imshow(hooks_dict['target_face'])
+ plt.title('Target')
+ fig.add_subplot(gs[i, 2])
+ plt.imshow(hooks_dict['output_face'])
+ plt.title('Output')
diff --git a/e4e/utils/data_utils.py b/e4e/utils/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1ba79f4a2d5cc2b97dce76d87bf6e7cdebbc257
--- /dev/null
+++ b/e4e/utils/data_utils.py
@@ -0,0 +1,25 @@
+"""
+Code adapted from pix2pixHD:
+https://github.com/NVIDIA/pix2pixHD/blob/master/data/image_folder.py
+"""
+import os
+
+IMG_EXTENSIONS = [
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
+]
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def make_dataset(dir):
+ images = []
+ assert os.path.isdir(dir), '%s is not a valid directory' % dir
+ for root, _, fnames in sorted(os.walk(dir)):
+ for fname in fnames:
+ if is_image_file(fname):
+ path = os.path.join(root, fname)
+ images.append(path)
+ return images
diff --git a/e4e/utils/model_utils.py b/e4e/utils/model_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e51e95578f72b3218d6d832e3b604193cb68c1d7
--- /dev/null
+++ b/e4e/utils/model_utils.py
@@ -0,0 +1,35 @@
+import torch
+import argparse
+from models.psp import pSp
+from models.encoders.psp_encoders import Encoder4Editing
+
+
+def setup_model(checkpoint_path, device='cuda'):
+ ckpt = torch.load(checkpoint_path, map_location='cpu')
+ opts = ckpt['opts']
+
+ opts['checkpoint_path'] = checkpoint_path
+ opts['device'] = device
+ opts = argparse.Namespace(**opts)
+
+ net = pSp(opts)
+ net.eval()
+ net = net.to(device)
+ return net, opts
+
+
+def load_e4e_standalone(checkpoint_path, device='cuda'):
+ ckpt = torch.load(checkpoint_path, map_location='cpu')
+ opts = argparse.Namespace(**ckpt['opts'])
+ e4e = Encoder4Editing(50, 'ir_se', opts)
+ e4e_dict = {k.replace('encoder.', ''): v for k, v in ckpt['state_dict'].items() if k.startswith('encoder.')}
+ e4e.load_state_dict(e4e_dict)
+ e4e.eval()
+ e4e = e4e.to(device)
+ latent_avg = ckpt['latent_avg'].to(device)
+
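+    # The encoder predicts offsets from the average latent, so a forward hook adds
+    # latent_avg back onto every output before returning it.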
+ def add_latent_avg(model, inputs, outputs):
+ return outputs + latent_avg.repeat(outputs.shape[0], 1, 1)
+
+ e4e.register_forward_hook(add_latent_avg)
+ return e4e
diff --git a/e4e/utils/train_utils.py b/e4e/utils/train_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c55177f7442010bc1fcc64de3d142585c22adc0
--- /dev/null
+++ b/e4e/utils/train_utils.py
@@ -0,0 +1,13 @@
+
+def aggregate_loss_dict(agg_loss_dict):
+ mean_vals = {}
+ for output in agg_loss_dict:
+ for key in output:
+ mean_vals[key] = mean_vals.setdefault(key, []) + [output[key]]
+ for key in mean_vals:
+ if len(mean_vals[key]) > 0:
+ mean_vals[key] = sum(mean_vals[key]) / len(mean_vals[key])
+ else:
+ print('{} has no value'.format(key))
+ mean_vals[key] = 0
+ return mean_vals
diff --git a/e4e_projection.py b/e4e_projection.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf5bc1b7301e068626460b3ac23fe44f49238d79
--- /dev/null
+++ b/e4e_projection.py
@@ -0,0 +1,38 @@
+import os
+import sys
+import numpy as np
+from PIL import Image
+import torch
+import torchvision.transforms as transforms
+from argparse import Namespace
+from e4e.models.psp import pSp
+from util import *
+
+
+@torch.no_grad()
+def projection(img, name, device='cuda'):
+    model_path = 'e4e_ffhq_encode.pt'
+    ckpt = torch.load(model_path, map_location='cpu')
+    opts = ckpt['opts']
+    opts['checkpoint_path'] = model_path
+    opts = Namespace(**opts)
+    net = pSp(opts, device).eval().to(device)
+
+ transform = transforms.Compose(
+ [
+ transforms.Resize(256),
+ transforms.CenterCrop(256),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
+ ]
+ )
+
+ img = transform(img).unsqueeze(0).to(device)
+ images, w_plus = net(img, randomize_noise=False, return_latents=True)
+ result_file = {}
+ result_file['latent'] = w_plus[0]
+ torch.save(result_file, name)
+ return w_plus[0]
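+
+
+# Hypothetical usage sketch (illustrative file names; assumes e4e_ffhq_encode.pt is in the
+# working directory and `img` is a roughly aligned PIL face image):
+#
+#   from PIL import Image
+#   face = Image.open('iu.jpeg').convert('RGB')
+#   w_plus = projection(face, 'iu_inversion.pt', device='cuda')  # also saves the latent to disk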
diff --git a/elon.png b/elon.png
new file mode 100644
index 0000000000000000000000000000000000000000..272bbbceff04d64c6eabd3f99c25350095d3c33e
Binary files /dev/null and b/elon.png differ
diff --git a/iu.jpeg b/iu.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..d89550f47ff9e351739770bb1561b320a3232cef
Binary files /dev/null and b/iu.jpeg differ
diff --git a/model.py b/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..497bf78d57c54d58cd3b55f26c718be2470a04f1
--- /dev/null
+++ b/model.py
@@ -0,0 +1,688 @@
+import math
+import random
+import functools
+import operator
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.autograd import Function
+
+from op import conv2d_gradfix
+if torch.cuda.is_available():
+ from op.fused_act import FusedLeakyReLU, fused_leaky_relu
+ from op.upfirdn2d import upfirdn2d
+else:
+ from op.fused_act_cpu import FusedLeakyReLU, fused_leaky_relu
+ from op.upfirdn2d_cpu import upfirdn2d
+
+
+class PixelNorm(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, input):
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
+
+
+def make_kernel(k):
+ k = torch.tensor(k, dtype=torch.float32)
+
+ if k.ndim == 1:
+ k = k[None, :] * k[:, None]
+
+ k /= k.sum()
+
+ return k
+
+
+class Upsample(nn.Module):
+ def __init__(self, kernel, factor=2):
+ super().__init__()
+
+ self.factor = factor
+ kernel = make_kernel(kernel) * (factor ** 2)
+ self.register_buffer("kernel", kernel)
+
+ p = kernel.shape[0] - factor
+
+ pad0 = (p + 1) // 2 + factor - 1
+ pad1 = p // 2
+
+ self.pad = (pad0, pad1)
+
+ def forward(self, input):
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
+
+ return out
+
+
+class Downsample(nn.Module):
+ def __init__(self, kernel, factor=2):
+ super().__init__()
+
+ self.factor = factor
+ kernel = make_kernel(kernel)
+ self.register_buffer("kernel", kernel)
+
+ p = kernel.shape[0] - factor
+
+ pad0 = (p + 1) // 2
+ pad1 = p // 2
+
+ self.pad = (pad0, pad1)
+
+ def forward(self, input):
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
+
+ return out
+
+
+class Blur(nn.Module):
+ def __init__(self, kernel, pad, upsample_factor=1):
+ super().__init__()
+
+ kernel = make_kernel(kernel)
+
+ if upsample_factor > 1:
+ kernel = kernel * (upsample_factor ** 2)
+
+ self.register_buffer("kernel", kernel)
+
+ self.pad = pad
+
+ def forward(self, input):
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
+
+ return out
+
+
+class EqualConv2d(nn.Module):
+ def __init__(
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
+ ):
+ super().__init__()
+
+ self.weight = nn.Parameter(
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
+ )
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
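+        # Equalized learning rate: weights are drawn from N(0, 1) and rescaled by
+        # 1/sqrt(fan_in) at runtime instead of at initialization (ProGAN/StyleGAN trick).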
+
+ self.stride = stride
+ self.padding = padding
+
+ if bias:
+ self.bias = nn.Parameter(torch.zeros(out_channel))
+
+ else:
+ self.bias = None
+
+ def forward(self, input):
+ out = conv2d_gradfix.conv2d(
+ input,
+ self.weight * self.scale,
+ bias=self.bias,
+ stride=self.stride,
+ padding=self.padding,
+ )
+
+ return out
+
+ def __repr__(self):
+ return (
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
+ f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
+ )
+
+
+class EqualLinear(nn.Module):
+ def __init__(
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
+ ):
+ super().__init__()
+
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
+
+ if bias:
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
+
+ else:
+ self.bias = None
+
+ self.activation = activation
+
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
+ self.lr_mul = lr_mul
+
+ def forward(self, input):
+ if self.activation:
+ out = F.linear(input, self.weight * self.scale)
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
+
+ else:
+ out = F.linear(
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
+ )
+
+ return out
+
+ def __repr__(self):
+ return (
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
+ )
+
+
+class ModulatedConv2d(nn.Module):
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ kernel_size,
+ style_dim,
+ demodulate=True,
+ upsample=False,
+ downsample=False,
+ blur_kernel=[1, 3, 3, 1],
+ fused=True,
+ ):
+ super().__init__()
+
+ self.eps = 1e-8
+ self.kernel_size = kernel_size
+ self.in_channel = in_channel
+ self.out_channel = out_channel
+ self.upsample = upsample
+ self.downsample = downsample
+
+ if upsample:
+ factor = 2
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
+ pad0 = (p + 1) // 2 + factor - 1
+ pad1 = p // 2 + 1
+
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
+
+ if downsample:
+ factor = 2
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
+ pad0 = (p + 1) // 2
+ pad1 = p // 2
+
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
+
+ fan_in = in_channel * kernel_size ** 2
+ self.scale = 1 / math.sqrt(fan_in)
+ self.padding = kernel_size // 2
+
+ self.weight = nn.Parameter(
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
+ )
+
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
+
+ self.demodulate = demodulate
+ self.fused = fused
+
+ def __repr__(self):
+ return (
+ f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
+ f"upsample={self.upsample}, downsample={self.downsample})"
+ )
+
+ def forward(self, input, style):
+ batch, in_channel, height, width = input.shape
+
+ if not self.fused:
+ weight = self.scale * self.weight.squeeze(0)
+ style = self.modulation(style)
+
+ if self.demodulate:
+ w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1)
+ dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt()
+
+ input = input * style.reshape(batch, in_channel, 1, 1)
+
+ if self.upsample:
+ weight = weight.transpose(0, 1)
+ out = conv2d_gradfix.conv_transpose2d(
+ input, weight, padding=0, stride=2
+ )
+ out = self.blur(out)
+
+ elif self.downsample:
+ input = self.blur(input)
+ out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2)
+
+ else:
+ out = conv2d_gradfix.conv2d(input, weight, padding=self.padding)
+
+ if self.demodulate:
+ out = out * dcoefs.view(batch, -1, 1, 1)
+
+ return out
+
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
+ weight = self.scale * self.weight * style
+
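+        # StyleGAN2 weight demodulation: after scaling the weights by the per-sample style,
+        # each output filter is divided by sqrt(sum of its squared modulated weights + eps),
+        # which stands in for instance normalization of the activations.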
+ if self.demodulate:
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
+
+ weight = weight.view(
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
+ )
+
+ if self.upsample:
+ input = input.view(1, batch * in_channel, height, width)
+ weight = weight.view(
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
+ )
+ weight = weight.transpose(1, 2).reshape(
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
+ )
+ out = conv2d_gradfix.conv_transpose2d(
+ input, weight, padding=0, stride=2, groups=batch
+ )
+ _, _, height, width = out.shape
+ out = out.view(batch, self.out_channel, height, width)
+ out = self.blur(out)
+
+ elif self.downsample:
+ input = self.blur(input)
+ _, _, height, width = input.shape
+ input = input.view(1, batch * in_channel, height, width)
+ out = conv2d_gradfix.conv2d(
+ input, weight, padding=0, stride=2, groups=batch
+ )
+ _, _, height, width = out.shape
+ out = out.view(batch, self.out_channel, height, width)
+
+ else:
+ input = input.view(1, batch * in_channel, height, width)
+ out = conv2d_gradfix.conv2d(
+ input, weight, padding=self.padding, groups=batch
+ )
+ _, _, height, width = out.shape
+ out = out.view(batch, self.out_channel, height, width)
+
+ return out
+
+
+class NoiseInjection(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ self.weight = nn.Parameter(torch.zeros(1))
+
+ def forward(self, image, noise=None):
+ if noise is None:
+ batch, _, height, width = image.shape
+ noise = image.new_empty(batch, 1, height, width).normal_()
+
+ return image + self.weight * noise
+
+
+class ConstantInput(nn.Module):
+ def __init__(self, channel, size=4):
+ super().__init__()
+
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
+
+ def forward(self, input):
+ batch = input.shape[0]
+ out = self.input.repeat(batch, 1, 1, 1)
+
+ return out
+
+
+class StyledConv(nn.Module):
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ kernel_size,
+ style_dim,
+ upsample=False,
+ blur_kernel=[1, 3, 3, 1],
+ demodulate=True,
+ ):
+ super().__init__()
+
+ self.conv = ModulatedConv2d(
+ in_channel,
+ out_channel,
+ kernel_size,
+ style_dim,
+ upsample=upsample,
+ blur_kernel=blur_kernel,
+ demodulate=demodulate,
+ )
+
+ self.noise = NoiseInjection()
+ # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
+ # self.activate = ScaledLeakyReLU(0.2)
+ self.activate = FusedLeakyReLU(out_channel)
+
+ def forward(self, input, style, noise=None):
+ out = self.conv(input, style)
+ out = self.noise(out, noise=noise)
+ # out = out + self.bias
+ out = self.activate(out)
+
+ return out
+
+
+class ToRGB(nn.Module):
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
+ super().__init__()
+
+ if upsample:
+ self.upsample = Upsample(blur_kernel)
+
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
+
+ def forward(self, input, style, skip=None):
+ out = self.conv(input, style)
+ out = out + self.bias
+
+ if skip is not None:
+ skip = self.upsample(skip)
+
+ out = out + skip
+
+ return out
+
+
+class Generator(nn.Module):
+ def __init__(
+ self,
+ size,
+ style_dim,
+ n_mlp,
+ channel_multiplier=2,
+ blur_kernel=[1, 3, 3, 1],
+ lr_mlp=0.01,
+ ):
+ super().__init__()
+
+ self.size = size
+
+ self.style_dim = style_dim
+
+ layers = [PixelNorm()]
+
+ for i in range(n_mlp):
+ layers.append(
+ EqualLinear(
+ style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
+ )
+ )
+
+ self.style = nn.Sequential(*layers)
+
+ self.channels = {
+ 4: 512,
+ 8: 512,
+ 16: 512,
+ 32: 512,
+ 64: 256 * channel_multiplier,
+ 128: 128 * channel_multiplier,
+ 256: 64 * channel_multiplier,
+ 512: 32 * channel_multiplier,
+ 1024: 16 * channel_multiplier,
+ }
+
+ self.input = ConstantInput(self.channels[4])
+ self.conv1 = StyledConv(
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
+ )
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
+
+ self.log_size = int(math.log(size, 2))
+ self.num_layers = (self.log_size - 2) * 2 + 1
+
+ self.convs = nn.ModuleList()
+ self.upsamples = nn.ModuleList()
+ self.to_rgbs = nn.ModuleList()
+ self.noises = nn.Module()
+
+ in_channel = self.channels[4]
+
+ for layer_idx in range(self.num_layers):
+ res = (layer_idx + 5) // 2
+ shape = [1, 1, 2 ** res, 2 ** res]
+ self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))
+
+ for i in range(3, self.log_size + 1):
+ out_channel = self.channels[2 ** i]
+
+ self.convs.append(
+ StyledConv(
+ in_channel,
+ out_channel,
+ 3,
+ style_dim,
+ upsample=True,
+ blur_kernel=blur_kernel,
+ )
+ )
+
+ self.convs.append(
+ StyledConv(
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
+ )
+ )
+
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
+
+ in_channel = out_channel
+
+ self.n_latent = self.log_size * 2 - 2
+
+ def make_noise(self):
+ device = self.input.input.device
+
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
+
+ for i in range(3, self.log_size + 1):
+ for _ in range(2):
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
+
+ return noises
+
+ @torch.no_grad()
+ def mean_latent(self, n_latent):
+ latent_in = torch.randn(
+ n_latent, self.style_dim, device=self.input.input.device
+ )
+ latent = self.style(latent_in).mean(0, keepdim=True)
+
+ return latent
+
+ @torch.no_grad()
+ def get_latent(self, input):
+ return self.style(input)
+
+ def forward(
+ self,
+ styles,
+ return_latents=False,
+ inject_index=None,
+ truncation=1,
+ truncation_latent=None,
+ input_is_latent=False,
+ noise=None,
+ randomize_noise=True,
+ ):
+
+ if noise is None:
+ if randomize_noise:
+ noise = [None] * self.num_layers
+ else:
+ noise = [
+ getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
+ ]
+
+ if not input_is_latent:
+ styles = [self.style(s) for s in styles]
+
+ if truncation < 1:
+ style_t = []
+
+ for style in styles:
+ style_t.append(
+ truncation_latent + truncation * (style - truncation_latent)
+ )
+
+ styles = style_t
+ latent = styles[0].unsqueeze(1).repeat(1, self.n_latent, 1)
+ else:
+ latent = styles
+
+ out = self.input(latent)
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
+
+ skip = self.to_rgb1(out, latent[:, 1])
+
+ i = 1
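+        # Skip-generator synthesis: each resolution applies two styled convs (the first one
+        # upsamples) and a ToRGB whose output is added to the upsampled running `skip`
+        # image, advancing two latent indices per resolution.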
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
+ ):
+ out = conv1(out, latent[:, i], noise=noise1)
+ out = conv2(out, latent[:, i + 1], noise=noise2)
+ skip = to_rgb(out, latent[:, i + 2], skip)
+
+ i += 2
+
+ image = skip
+
+ return image
+
+
+class ConvLayer(nn.Sequential):
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ kernel_size,
+ downsample=False,
+ blur_kernel=[1, 3, 3, 1],
+ bias=True,
+ activate=True,
+ ):
+ layers = []
+
+ if downsample:
+ factor = 2
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
+ pad0 = (p + 1) // 2
+ pad1 = p // 2
+
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
+
+ stride = 2
+ self.padding = 0
+
+ else:
+ stride = 1
+ self.padding = kernel_size // 2
+
+ layers.append(
+ EqualConv2d(
+ in_channel,
+ out_channel,
+ kernel_size,
+ padding=self.padding,
+ stride=stride,
+ bias=bias and not activate,
+ )
+ )
+
+ if activate:
+ layers.append(FusedLeakyReLU(out_channel, bias=bias))
+
+ super().__init__(*layers)
+
+
+class ResBlock(nn.Module):
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
+ super().__init__()
+
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
+
+ self.skip = ConvLayer(
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
+ )
+
+ def forward(self, input):
+ out = self.conv1(input)
+ out = self.conv2(out)
+
+ skip = self.skip(input)
+ out = (out + skip) / math.sqrt(2)
+
+ return out
+
+
+class Discriminator(nn.Module):
+ def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
+ super().__init__()
+
+ channels = {
+ 4: 512,
+ 8: 512,
+ 16: 512,
+ 32: 512,
+ 64: 256 * channel_multiplier,
+ 128: 128 * channel_multiplier,
+ 256: 64 * channel_multiplier,
+ 512: 32 * channel_multiplier,
+ 1024: 16 * channel_multiplier,
+ }
+
+ convs = [ConvLayer(3, channels[size], 1)]
+
+ log_size = int(math.log(size, 2))
+
+ in_channel = channels[size]
+
+ for i in range(log_size, 2, -1):
+ out_channel = channels[2 ** (i - 1)]
+
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
+
+ in_channel = out_channel
+
+ self.convs = nn.Sequential(*convs)
+
+ self.stddev_group = 4
+ self.stddev_feat = 1
+
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
+ self.final_linear = nn.Sequential(
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
+ EqualLinear(channels[4], 1),
+ )
+
+ def forward(self, input):
+ out = self.convs(input)
+
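+        # Minibatch standard deviation: compute per-group feature std over the batch,
+        # average it to one value per group and append it as an extra feature map, helping
+        # the discriminator detect a collapse in sample diversity.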
+ batch, channel, height, width = out.shape
+ group = min(batch, self.stddev_group)
+ stddev = out.view(
+ group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
+ )
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
+ stddev = stddev.repeat(group, 1, height, width)
+ out = torch.cat([out, stddev], 1)
+
+ out = self.final_conv(out)
+
+ out = out.view(batch, -1)
+ out = self.final_linear(out)
+
+ return out
+
diff --git a/mona.png b/mona.png
new file mode 100644
index 0000000000000000000000000000000000000000..95c5c5a09c73b343cd2d1911816267db3ed79619
Binary files /dev/null and b/mona.png differ
diff --git a/packages.txt b/packages.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ace9fff330d5ce42b11d0e450d7cde99ccecfa77
--- /dev/null
+++ b/packages.txt
@@ -0,0 +1,4 @@
+ffmpeg
+libsm6
+libxext6
+cmake
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0379116faae6fe42b29fd9ea44800c6893fac9e1
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,10 @@
+tqdm
+gdown
+scikit-learn==0.22
+scipy
+lpips
+opencv-python-headless
+torch
+torchvision
+imageio
+dlib
\ No newline at end of file
diff --git a/util.py b/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed7f1124405c955bdb46b0bd2cec68d47653d100
--- /dev/null
+++ b/util.py
@@ -0,0 +1,220 @@
+from matplotlib import pyplot as plt
+import torch
+import torch.nn.functional as F
+import os
+import cv2
+import dlib
+from PIL import Image
+import numpy as np
+import math
+import torchvision
+import scipy
+import scipy.ndimage
+import torchvision.transforms as transforms
+
+from huggingface_hub import hf_hub_download
+
+
+shape_predictor_path = hf_hub_download(repo_id="akhaliq/jojogan_dlib", filename="shape_predictor_68_face_landmarks.dat")
+
+
+google_drive_paths = {
+ "models/stylegan2-ffhq-config-f.pt": "https://drive.google.com/uc?id=1Yr7KuD959btpmcKGAUsbAk5rPjX2MytK",
+ "models/dlibshape_predictor_68_face_landmarks.dat": "https://drive.google.com/uc?id=11BDmNKS1zxSZxkgsEvQoKgFd8J264jKp",
+ "models/e4e_ffhq_encode.pt": "https://drive.google.com/uc?id=1o6ijA3PkcewZvwJJ73dJ0fxhndn0nnh7",
+ "models/restyle_psp_ffhq_encode.pt": "https://drive.google.com/uc?id=1nbxCIVw9H3YnQsoIPykNEFwWJnHVHlVd",
+ "models/arcane_caitlyn.pt": "https://drive.google.com/uc?id=1gOsDTiTPcENiFOrhmkkxJcTURykW1dRc",
+ "models/arcane_caitlyn_preserve_color.pt": "https://drive.google.com/uc?id=1cUTyjU-q98P75a8THCaO545RTwpVV-aH",
+ "models/arcane_jinx_preserve_color.pt": "https://drive.google.com/uc?id=1jElwHxaYPod5Itdy18izJk49K1nl4ney",
+ "models/arcane_jinx.pt": "https://drive.google.com/uc?id=1quQ8vPjYpUiXM4k1_KIwP4EccOefPpG_",
+ "models/disney.pt": "https://drive.google.com/uc?id=1zbE2upakFUAx8ximYnLofFwfT8MilqJA",
+ "models/disney_preserve_color.pt": "https://drive.google.com/uc?id=1Bnh02DjfvN_Wm8c4JdOiNV4q9J7Z_tsi",
+ "models/jojo.pt": "https://drive.google.com/uc?id=13cR2xjIBj8Ga5jMO7gtxzIJj2PDsBYK4",
+ "models/jojo_preserve_color.pt": "https://drive.google.com/uc?id=1ZRwYLRytCEKi__eT2Zxv1IlV6BGVQ_K2",
+ "models/jojo_yasuho.pt": "https://drive.google.com/uc?id=1grZT3Gz1DLzFoJchAmoj3LoM9ew9ROX_",
+ "models/jojo_yasuho_preserve_color.pt": "https://drive.google.com/uc?id=1SKBu1h0iRNyeKBnya_3BBmLr4pkPeg_L",
+ "models/supergirl.pt": "https://drive.google.com/uc?id=1L0y9IYgzLNzB-33xTpXpecsKU-t9DpVC",
+ "models/supergirl_preserve_color.pt": "https://drive.google.com/uc?id=1VmKGuvThWHym7YuayXxjv0fSn32lfDpE",
+}
+
+@torch.no_grad()
+def load_model(generator, model_file_path):
+ ensure_checkpoint_exists(model_file_path)
+ ckpt = torch.load(model_file_path, map_location=lambda storage, loc: storage)
+ generator.load_state_dict(ckpt["g_ema"], strict=False)
+ return generator.mean_latent(50000)
+
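+# Illustrative sketch (assumed usage, not a fixed API): loading one of the finetuned
+# generators with the helper above. `Generator` comes from model.py; 'models/jojo.pt'
+# is fetched through google_drive_paths if it is not already on disk.
+#
+#   from model import Generator
+#   generator = Generator(1024, 512, 8, 2).eval().to('cuda')
+#   mean_latent = load_model(generator, 'models/jojo.pt')  # loads g_ema weights, returns the mean W
+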
+def ensure_checkpoint_exists(model_weights_filename):
+ if not os.path.isfile(model_weights_filename) and (
+ model_weights_filename in google_drive_paths
+ ):
+ gdrive_url = google_drive_paths[model_weights_filename]
+ try:
+ from gdown import download as drive_download
+
+ drive_download(gdrive_url, model_weights_filename, quiet=False)
+ except ModuleNotFoundError:
+ print(
+ "gdown module not found.",
+ "pip3 install gdown or, manually download the checkpoint file:",
+ gdrive_url
+ )
+
+ if not os.path.isfile(model_weights_filename) and (
+ model_weights_filename not in google_drive_paths
+ ):
+ print(
+ model_weights_filename,
+ " not found, you may need to manually download the model weights."
+ )
+
+# given a list of filenames, load the inverted style code
+@torch.no_grad()
+def load_source(files, generator, device='cuda'):
+ sources = []
+
+ for file in files:
+ source = torch.load(f'./inversion_codes/{file}.pt')['latent'].to(device)
+
+ if source.size(0) != 1:
+ source = source.unsqueeze(0)
+
+ if source.ndim == 3:
+ source = generator.get_latent(source, truncation=1, is_latent=True)
+ source = list2style(source)
+
+ sources.append(source)
+
+ sources = torch.cat(sources, 0)
+ if type(sources) is not list:
+ sources = style2list(sources)
+
+ return sources
+
+def display_image(image, size=None, mode='nearest', unnorm=False, title=''):
+ # image is [3,h,w] or [1,3,h,w] tensor [0,1]
+ if not isinstance(image, torch.Tensor):
+ image = transforms.ToTensor()(image).unsqueeze(0)
+ if image.is_cuda:
+ image = image.cpu()
+ if size is not None and image.size(-1) != size:
+ image = F.interpolate(image, size=(size,size), mode=mode)
+ if image.dim() == 4:
+ image = image[0]
+ image = image.permute(1, 2, 0).detach().numpy()
+ plt.figure()
+ plt.title(title)
+ plt.axis('off')
+ plt.imshow(image)
+
+def get_landmark(filepath, predictor):
+ """get landmark with dlib
+ :return: np.array shape=(68, 2)
+ """
+ detector = dlib.get_frontal_face_detector()
+
+ img = dlib.load_rgb_image(filepath)
+ dets = detector(img, 1)
+ assert len(dets) > 0, "Face not detected, try another face image"
+
+ for k, d in enumerate(dets):
+ shape = predictor(img, d)
+
+ t = list(shape.parts())
+ a = []
+ for tt in t:
+ a.append([tt.x, tt.y])
+ lm = np.array(a)
+ return lm
+
+
+def align_face(filepath, output_size=256, transform_size=1024, enable_padding=True):
+
+ """
+ :param filepath: str
+ :return: PIL Image
+ """
+ predictor = dlib.shape_predictor(shape_predictor_path)
+ lm = get_landmark(filepath, predictor)
+
+ lm_chin = lm[0: 17] # left-right
+ lm_eyebrow_left = lm[17: 22] # left-right
+ lm_eyebrow_right = lm[22: 27] # left-right
+ lm_nose = lm[27: 31] # top-down
+ lm_nostrils = lm[31: 36] # top-down
+ lm_eye_left = lm[36: 42] # left-clockwise
+ lm_eye_right = lm[42: 48] # left-clockwise
+ lm_mouth_outer = lm[48: 60] # left-clockwise
+ lm_mouth_inner = lm[60: 68] # left-clockwise
+
+ # Calculate auxiliary vectors.
+ eye_left = np.mean(lm_eye_left, axis=0)
+ eye_right = np.mean(lm_eye_right, axis=0)
+ eye_avg = (eye_left + eye_right) * 0.5
+ eye_to_eye = eye_right - eye_left
+ mouth_left = lm_mouth_outer[0]
+ mouth_right = lm_mouth_outer[6]
+ mouth_avg = (mouth_left + mouth_right) * 0.5
+ eye_to_mouth = mouth_avg - eye_avg
+
+ # Choose oriented crop rectangle.
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+ x /= np.hypot(*x)
+ x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
+ y = np.flipud(x) * [-1, 1]
+ c = eye_avg + eye_to_mouth * 0.1
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+ qsize = np.hypot(*x) * 2
+
+ # read image
+ img = Image.open(filepath)
+
+ transform_size = output_size
+ enable_padding = True
+
+ # Shrink.
+ shrink = int(np.floor(qsize / output_size * 0.5))
+ if shrink > 1:
+ rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
+ img = img.resize(rsize, Image.ANTIALIAS)
+ quad /= shrink
+ qsize /= shrink
+
+ # Crop.
+ border = max(int(np.rint(qsize * 0.1)), 3)
+ crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))))
+ crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
+ min(crop[3] + border, img.size[1]))
+ if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
+ img = img.crop(crop)
+ quad -= crop[0:2]
+
+ # Pad.
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))))
+ pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
+ max(pad[3] - img.size[1] + border, 0))
+ if enable_padding and max(pad) > border - 4:
+ pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
+ img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
+ h, w, _ = img.shape
+ y, x, _ = np.ogrid[:h, :w, :1]
+ mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
+ 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))
+ blur = qsize * 0.02
+ img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+ img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
+ img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
+ quad += pad[:2]
+
+ # Transform.
+ img = img.transform((transform_size, transform_size), Image.QUAD, (quad + 0.5).flatten(), Image.BILINEAR)
+ if output_size < transform_size:
+ img = img.resize((output_size, output_size), Image.ANTIALIAS)
+
+ # Return aligned image.
+ return img
+
+def strip_path_extension(path):
+ return os.path.splitext(path)[0]