Upload 14 files
Browse files- scripts/__init__.py +0 -0
- scripts/__pycache__/__init__.cpython-310.pyc +0 -0
- scripts/__pycache__/reactor_faceswap.cpython-310.pyc +0 -0
- scripts/__pycache__/reactor_logger.cpython-310.pyc +0 -0
- scripts/__pycache__/reactor_swapper.cpython-310.pyc +0 -0
- scripts/__pycache__/reactor_version.cpython-310.pyc +0 -0
- scripts/r_archs/__pycache__/codeformer_arch.cpython-310.pyc +0 -0
- scripts/r_archs/__pycache__/vqgan_arch.cpython-310.pyc +0 -0
- scripts/r_archs/codeformer_arch.py +278 -0
- scripts/r_archs/vqgan_arch.py +437 -0
- scripts/reactor_faceswap.py +126 -0
- scripts/reactor_logger.py +47 -0
- scripts/reactor_swapper.py +301 -0
- scripts/reactor_version.py +13 -0
scripts/__init__.py
ADDED
File without changes
|
scripts/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (175 Bytes). View file
|
|
scripts/__pycache__/reactor_faceswap.cpython-310.pyc
ADDED
Binary file (3.59 kB). View file
|
|
scripts/__pycache__/reactor_logger.cpython-310.pyc
ADDED
Binary file (1.4 kB). View file
|
|
scripts/__pycache__/reactor_swapper.cpython-310.pyc
ADDED
Binary file (7.24 kB). View file
|
|
scripts/__pycache__/reactor_version.cpython-310.pyc
ADDED
Binary file (534 Bytes). View file
|
|
scripts/r_archs/__pycache__/codeformer_arch.cpython-310.pyc
ADDED
Binary file (9.24 kB). View file
|
|
scripts/r_archs/__pycache__/vqgan_arch.cpython-310.pyc
ADDED
Binary file (11.2 kB). View file
|
|
scripts/r_archs/codeformer_arch.py
ADDED
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from torch import nn, Tensor
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from typing import Optional, List
|
7 |
+
|
8 |
+
from scripts.r_archs.vqgan_arch import *
|
9 |
+
from r_basicsr.utils import get_root_logger
|
10 |
+
from r_basicsr.utils.registry import ARCH_REGISTRY
|
11 |
+
|
12 |
+
|
13 |
+
def calc_mean_std(feat, eps=1e-5):
|
14 |
+
"""Calculate mean and std for adaptive_instance_normalization.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
feat (Tensor): 4D tensor.
|
18 |
+
eps (float): A small value added to the variance to avoid
|
19 |
+
divide-by-zero. Default: 1e-5.
|
20 |
+
"""
|
21 |
+
size = feat.size()
|
22 |
+
assert len(size) == 4, 'The input feature should be 4D tensor.'
|
23 |
+
b, c = size[:2]
|
24 |
+
feat_var = feat.view(b, c, -1).var(dim=2) + eps
|
25 |
+
feat_std = feat_var.sqrt().view(b, c, 1, 1)
|
26 |
+
feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
|
27 |
+
return feat_mean, feat_std
|
28 |
+
|
29 |
+
|
30 |
+
def adaptive_instance_normalization(content_feat, style_feat):
|
31 |
+
"""Adaptive instance normalization.
|
32 |
+
|
33 |
+
Adjust the reference features to have the similar color and illuminations
|
34 |
+
as those in the degradate features.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
content_feat (Tensor): The reference feature.
|
38 |
+
style_feat (Tensor): The degradate features.
|
39 |
+
"""
|
40 |
+
size = content_feat.size()
|
41 |
+
style_mean, style_std = calc_mean_std(style_feat)
|
42 |
+
content_mean, content_std = calc_mean_std(content_feat)
|
43 |
+
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
44 |
+
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
45 |
+
|
46 |
+
|
47 |
+
class PositionEmbeddingSine(nn.Module):
|
48 |
+
"""
|
49 |
+
This is a more standard version of the position embedding, very similar to the one
|
50 |
+
used by the Attention is all you need paper, generalized to work on images.
|
51 |
+
"""
|
52 |
+
|
53 |
+
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
54 |
+
super().__init__()
|
55 |
+
self.num_pos_feats = num_pos_feats
|
56 |
+
self.temperature = temperature
|
57 |
+
self.normalize = normalize
|
58 |
+
if scale is not None and normalize is False:
|
59 |
+
raise ValueError("normalize should be True if scale is passed")
|
60 |
+
if scale is None:
|
61 |
+
scale = 2 * math.pi
|
62 |
+
self.scale = scale
|
63 |
+
|
64 |
+
def forward(self, x, mask=None):
|
65 |
+
if mask is None:
|
66 |
+
mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
|
67 |
+
not_mask = ~mask
|
68 |
+
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
69 |
+
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
70 |
+
if self.normalize:
|
71 |
+
eps = 1e-6
|
72 |
+
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
73 |
+
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
74 |
+
|
75 |
+
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
76 |
+
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
77 |
+
|
78 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
79 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
80 |
+
pos_x = torch.stack(
|
81 |
+
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
|
82 |
+
).flatten(3)
|
83 |
+
pos_y = torch.stack(
|
84 |
+
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
|
85 |
+
).flatten(3)
|
86 |
+
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
87 |
+
return pos
|
88 |
+
|
89 |
+
def _get_activation_fn(activation):
|
90 |
+
"""Return an activation function given a string"""
|
91 |
+
if activation == "relu":
|
92 |
+
return F.relu
|
93 |
+
if activation == "gelu":
|
94 |
+
return F.gelu
|
95 |
+
if activation == "glu":
|
96 |
+
return F.glu
|
97 |
+
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
98 |
+
|
99 |
+
|
100 |
+
class TransformerSALayer(nn.Module):
|
101 |
+
def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
|
102 |
+
super().__init__()
|
103 |
+
self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
|
104 |
+
# Implementation of Feedforward model - MLP
|
105 |
+
self.linear1 = nn.Linear(embed_dim, dim_mlp)
|
106 |
+
self.dropout = nn.Dropout(dropout)
|
107 |
+
self.linear2 = nn.Linear(dim_mlp, embed_dim)
|
108 |
+
|
109 |
+
self.norm1 = nn.LayerNorm(embed_dim)
|
110 |
+
self.norm2 = nn.LayerNorm(embed_dim)
|
111 |
+
self.dropout1 = nn.Dropout(dropout)
|
112 |
+
self.dropout2 = nn.Dropout(dropout)
|
113 |
+
|
114 |
+
self.activation = _get_activation_fn(activation)
|
115 |
+
|
116 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
117 |
+
return tensor if pos is None else tensor + pos
|
118 |
+
|
119 |
+
def forward(self, tgt,
|
120 |
+
tgt_mask: Optional[Tensor] = None,
|
121 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
122 |
+
query_pos: Optional[Tensor] = None):
|
123 |
+
|
124 |
+
# self attention
|
125 |
+
tgt2 = self.norm1(tgt)
|
126 |
+
q = k = self.with_pos_embed(tgt2, query_pos)
|
127 |
+
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
|
128 |
+
key_padding_mask=tgt_key_padding_mask)[0]
|
129 |
+
tgt = tgt + self.dropout1(tgt2)
|
130 |
+
|
131 |
+
# ffn
|
132 |
+
tgt2 = self.norm2(tgt)
|
133 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
134 |
+
tgt = tgt + self.dropout2(tgt2)
|
135 |
+
return tgt
|
136 |
+
|
137 |
+
class Fuse_sft_block(nn.Module):
|
138 |
+
def __init__(self, in_ch, out_ch):
|
139 |
+
super().__init__()
|
140 |
+
self.encode_enc = ResBlock(2*in_ch, out_ch)
|
141 |
+
|
142 |
+
self.scale = nn.Sequential(
|
143 |
+
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
|
144 |
+
nn.LeakyReLU(0.2, True),
|
145 |
+
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
|
146 |
+
|
147 |
+
self.shift = nn.Sequential(
|
148 |
+
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
|
149 |
+
nn.LeakyReLU(0.2, True),
|
150 |
+
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
|
151 |
+
|
152 |
+
def forward(self, enc_feat, dec_feat, w=1):
|
153 |
+
enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
|
154 |
+
scale = self.scale(enc_feat)
|
155 |
+
shift = self.shift(enc_feat)
|
156 |
+
residual = w * (dec_feat * scale + shift)
|
157 |
+
out = dec_feat + residual
|
158 |
+
return out
|
159 |
+
|
160 |
+
|
161 |
+
@ARCH_REGISTRY.register()
|
162 |
+
class CodeFormer(VQAutoEncoder):
|
163 |
+
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
|
164 |
+
codebook_size=1024, latent_size=256,
|
165 |
+
connect_list=['32', '64', '128', '256'],
|
166 |
+
fix_modules=['quantize','generator']):
|
167 |
+
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
|
168 |
+
|
169 |
+
if fix_modules is not None:
|
170 |
+
for module in fix_modules:
|
171 |
+
for param in getattr(self, module).parameters():
|
172 |
+
param.requires_grad = False
|
173 |
+
|
174 |
+
self.connect_list = connect_list
|
175 |
+
self.n_layers = n_layers
|
176 |
+
self.dim_embd = dim_embd
|
177 |
+
self.dim_mlp = dim_embd*2
|
178 |
+
|
179 |
+
self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
|
180 |
+
self.feat_emb = nn.Linear(256, self.dim_embd)
|
181 |
+
|
182 |
+
# transformer
|
183 |
+
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
|
184 |
+
for _ in range(self.n_layers)])
|
185 |
+
|
186 |
+
# logits_predict head
|
187 |
+
self.idx_pred_layer = nn.Sequential(
|
188 |
+
nn.LayerNorm(dim_embd),
|
189 |
+
nn.Linear(dim_embd, codebook_size, bias=False))
|
190 |
+
|
191 |
+
self.channels = {
|
192 |
+
'16': 512,
|
193 |
+
'32': 256,
|
194 |
+
'64': 256,
|
195 |
+
'128': 128,
|
196 |
+
'256': 128,
|
197 |
+
'512': 64,
|
198 |
+
}
|
199 |
+
|
200 |
+
# after second residual block for > 16, before attn layer for ==16
|
201 |
+
self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
|
202 |
+
# after first residual block for > 16, before attn layer for ==16
|
203 |
+
self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
|
204 |
+
|
205 |
+
# fuse_convs_dict
|
206 |
+
self.fuse_convs_dict = nn.ModuleDict()
|
207 |
+
for f_size in self.connect_list:
|
208 |
+
in_ch = self.channels[f_size]
|
209 |
+
self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
|
210 |
+
|
211 |
+
def _init_weights(self, module):
|
212 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
213 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
214 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
215 |
+
module.bias.data.zero_()
|
216 |
+
elif isinstance(module, nn.LayerNorm):
|
217 |
+
module.bias.data.zero_()
|
218 |
+
module.weight.data.fill_(1.0)
|
219 |
+
|
220 |
+
def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
|
221 |
+
# ################### Encoder #####################
|
222 |
+
enc_feat_dict = {}
|
223 |
+
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
|
224 |
+
for i, block in enumerate(self.encoder.blocks):
|
225 |
+
x = block(x)
|
226 |
+
if i in out_list:
|
227 |
+
enc_feat_dict[str(x.shape[-1])] = x.clone()
|
228 |
+
|
229 |
+
lq_feat = x
|
230 |
+
# ################# Transformer ###################
|
231 |
+
# quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
|
232 |
+
pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
|
233 |
+
# BCHW -> BC(HW) -> (HW)BC
|
234 |
+
feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
|
235 |
+
query_emb = feat_emb
|
236 |
+
# Transformer encoder
|
237 |
+
for layer in self.ft_layers:
|
238 |
+
query_emb = layer(query_emb, query_pos=pos_emb)
|
239 |
+
|
240 |
+
# output logits
|
241 |
+
logits = self.idx_pred_layer(query_emb) # (hw)bn
|
242 |
+
logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
|
243 |
+
|
244 |
+
if code_only: # for training stage II
|
245 |
+
# logits doesn't need softmax before cross_entropy loss
|
246 |
+
return logits, lq_feat
|
247 |
+
|
248 |
+
# ################# Quantization ###################
|
249 |
+
# if self.training:
|
250 |
+
# quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
|
251 |
+
# # b(hw)c -> bc(hw) -> bchw
|
252 |
+
# quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
|
253 |
+
# ------------
|
254 |
+
soft_one_hot = F.softmax(logits, dim=2)
|
255 |
+
_, top_idx = torch.topk(soft_one_hot, 1, dim=2)
|
256 |
+
quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
|
257 |
+
# preserve gradients
|
258 |
+
# quant_feat = lq_feat + (quant_feat - lq_feat).detach()
|
259 |
+
|
260 |
+
if detach_16:
|
261 |
+
quant_feat = quant_feat.detach() # for training stage III
|
262 |
+
if adain:
|
263 |
+
quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
|
264 |
+
|
265 |
+
# ################## Generator ####################
|
266 |
+
x = quant_feat
|
267 |
+
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
|
268 |
+
|
269 |
+
for i, block in enumerate(self.generator.blocks):
|
270 |
+
x = block(x)
|
271 |
+
if i in fuse_list: # fuse after i-th block
|
272 |
+
f_size = str(x.shape[-1])
|
273 |
+
if w>0:
|
274 |
+
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
|
275 |
+
out = x
|
276 |
+
# logits doesn't need softmax before cross_entropy loss
|
277 |
+
return out, logits, lq_feat
|
278 |
+
|
scripts/r_archs/vqgan_arch.py
ADDED
@@ -0,0 +1,437 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
|
3 |
+
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
|
4 |
+
|
5 |
+
'''
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import copy
|
11 |
+
from r_basicsr.utils import get_root_logger
|
12 |
+
from r_basicsr.utils.registry import ARCH_REGISTRY
|
13 |
+
|
14 |
+
|
15 |
+
def normalize(in_channels):
|
16 |
+
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
17 |
+
|
18 |
+
|
19 |
+
@torch.jit.script
|
20 |
+
def swish(x):
|
21 |
+
return x*torch.sigmoid(x)
|
22 |
+
|
23 |
+
|
24 |
+
# Define VQVAE classes
|
25 |
+
class VectorQuantizer(nn.Module):
|
26 |
+
def __init__(self, codebook_size, emb_dim, beta):
|
27 |
+
super(VectorQuantizer, self).__init__()
|
28 |
+
self.codebook_size = codebook_size # number of embeddings
|
29 |
+
self.emb_dim = emb_dim # dimension of embedding
|
30 |
+
self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
|
31 |
+
self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
|
32 |
+
self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
|
33 |
+
|
34 |
+
def forward(self, z):
|
35 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
36 |
+
z = z.permute(0, 2, 3, 1).contiguous()
|
37 |
+
z_flattened = z.view(-1, self.emb_dim)
|
38 |
+
|
39 |
+
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
40 |
+
d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
|
41 |
+
2 * torch.matmul(z_flattened, self.embedding.weight.t())
|
42 |
+
|
43 |
+
mean_distance = torch.mean(d)
|
44 |
+
# find closest encodings
|
45 |
+
# min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
|
46 |
+
min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
|
47 |
+
# [0-1], higher score, higher confidence
|
48 |
+
min_encoding_scores = torch.exp(-min_encoding_scores/10)
|
49 |
+
|
50 |
+
min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
|
51 |
+
min_encodings.scatter_(1, min_encoding_indices, 1)
|
52 |
+
|
53 |
+
# get quantized latent vectors
|
54 |
+
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
|
55 |
+
# compute loss for embedding
|
56 |
+
loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
|
57 |
+
# preserve gradients
|
58 |
+
z_q = z + (z_q - z).detach()
|
59 |
+
|
60 |
+
# perplexity
|
61 |
+
e_mean = torch.mean(min_encodings, dim=0)
|
62 |
+
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
|
63 |
+
# reshape back to match original input shape
|
64 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
65 |
+
|
66 |
+
return z_q, loss, {
|
67 |
+
"perplexity": perplexity,
|
68 |
+
"min_encodings": min_encodings,
|
69 |
+
"min_encoding_indices": min_encoding_indices,
|
70 |
+
"min_encoding_scores": min_encoding_scores,
|
71 |
+
"mean_distance": mean_distance
|
72 |
+
}
|
73 |
+
|
74 |
+
def get_codebook_feat(self, indices, shape):
|
75 |
+
# input indices: batch*token_num -> (batch*token_num)*1
|
76 |
+
# shape: batch, height, width, channel
|
77 |
+
indices = indices.view(-1,1)
|
78 |
+
min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
|
79 |
+
min_encodings.scatter_(1, indices, 1)
|
80 |
+
# get quantized latent vectors
|
81 |
+
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
|
82 |
+
|
83 |
+
if shape is not None: # reshape back to match original input shape
|
84 |
+
z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
|
85 |
+
|
86 |
+
return z_q
|
87 |
+
|
88 |
+
|
89 |
+
class GumbelQuantizer(nn.Module):
|
90 |
+
def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
|
91 |
+
super().__init__()
|
92 |
+
self.codebook_size = codebook_size # number of embeddings
|
93 |
+
self.emb_dim = emb_dim # dimension of embedding
|
94 |
+
self.straight_through = straight_through
|
95 |
+
self.temperature = temp_init
|
96 |
+
self.kl_weight = kl_weight
|
97 |
+
self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
|
98 |
+
self.embed = nn.Embedding(codebook_size, emb_dim)
|
99 |
+
|
100 |
+
def forward(self, z):
|
101 |
+
hard = self.straight_through if self.training else True
|
102 |
+
|
103 |
+
logits = self.proj(z)
|
104 |
+
|
105 |
+
soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
|
106 |
+
|
107 |
+
z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
|
108 |
+
|
109 |
+
# + kl divergence to the prior loss
|
110 |
+
qy = F.softmax(logits, dim=1)
|
111 |
+
diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
|
112 |
+
min_encoding_indices = soft_one_hot.argmax(dim=1)
|
113 |
+
|
114 |
+
return z_q, diff, {
|
115 |
+
"min_encoding_indices": min_encoding_indices
|
116 |
+
}
|
117 |
+
|
118 |
+
|
119 |
+
class Downsample(nn.Module):
|
120 |
+
def __init__(self, in_channels):
|
121 |
+
super().__init__()
|
122 |
+
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
123 |
+
|
124 |
+
def forward(self, x):
|
125 |
+
pad = (0, 1, 0, 1)
|
126 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
127 |
+
x = self.conv(x)
|
128 |
+
return x
|
129 |
+
|
130 |
+
|
131 |
+
class Upsample(nn.Module):
|
132 |
+
def __init__(self, in_channels):
|
133 |
+
super().__init__()
|
134 |
+
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
135 |
+
|
136 |
+
def forward(self, x):
|
137 |
+
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
138 |
+
x = self.conv(x)
|
139 |
+
|
140 |
+
return x
|
141 |
+
|
142 |
+
|
143 |
+
class ResBlock(nn.Module):
|
144 |
+
def __init__(self, in_channels, out_channels=None):
|
145 |
+
super(ResBlock, self).__init__()
|
146 |
+
self.in_channels = in_channels
|
147 |
+
self.out_channels = in_channels if out_channels is None else out_channels
|
148 |
+
self.norm1 = normalize(in_channels)
|
149 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
150 |
+
self.norm2 = normalize(out_channels)
|
151 |
+
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
152 |
+
if self.in_channels != self.out_channels:
|
153 |
+
self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
154 |
+
|
155 |
+
def forward(self, x_in):
|
156 |
+
x = x_in
|
157 |
+
x = self.norm1(x)
|
158 |
+
x = swish(x)
|
159 |
+
x = self.conv1(x)
|
160 |
+
x = self.norm2(x)
|
161 |
+
x = swish(x)
|
162 |
+
x = self.conv2(x)
|
163 |
+
if self.in_channels != self.out_channels:
|
164 |
+
x_in = self.conv_out(x_in)
|
165 |
+
|
166 |
+
return x + x_in
|
167 |
+
|
168 |
+
|
169 |
+
class AttnBlock(nn.Module):
|
170 |
+
def __init__(self, in_channels):
|
171 |
+
super().__init__()
|
172 |
+
self.in_channels = in_channels
|
173 |
+
|
174 |
+
self.norm = normalize(in_channels)
|
175 |
+
self.q = torch.nn.Conv2d(
|
176 |
+
in_channels,
|
177 |
+
in_channels,
|
178 |
+
kernel_size=1,
|
179 |
+
stride=1,
|
180 |
+
padding=0
|
181 |
+
)
|
182 |
+
self.k = torch.nn.Conv2d(
|
183 |
+
in_channels,
|
184 |
+
in_channels,
|
185 |
+
kernel_size=1,
|
186 |
+
stride=1,
|
187 |
+
padding=0
|
188 |
+
)
|
189 |
+
self.v = torch.nn.Conv2d(
|
190 |
+
in_channels,
|
191 |
+
in_channels,
|
192 |
+
kernel_size=1,
|
193 |
+
stride=1,
|
194 |
+
padding=0
|
195 |
+
)
|
196 |
+
self.proj_out = torch.nn.Conv2d(
|
197 |
+
in_channels,
|
198 |
+
in_channels,
|
199 |
+
kernel_size=1,
|
200 |
+
stride=1,
|
201 |
+
padding=0
|
202 |
+
)
|
203 |
+
|
204 |
+
def forward(self, x):
|
205 |
+
h_ = x
|
206 |
+
h_ = self.norm(h_)
|
207 |
+
q = self.q(h_)
|
208 |
+
k = self.k(h_)
|
209 |
+
v = self.v(h_)
|
210 |
+
|
211 |
+
# compute attention
|
212 |
+
b, c, h, w = q.shape
|
213 |
+
q = q.reshape(b, c, h*w)
|
214 |
+
q = q.permute(0, 2, 1)
|
215 |
+
k = k.reshape(b, c, h*w)
|
216 |
+
w_ = torch.bmm(q, k)
|
217 |
+
w_ = w_ * (int(c)**(-0.5))
|
218 |
+
w_ = F.softmax(w_, dim=2)
|
219 |
+
|
220 |
+
# attend to values
|
221 |
+
v = v.reshape(b, c, h*w)
|
222 |
+
w_ = w_.permute(0, 2, 1)
|
223 |
+
h_ = torch.bmm(v, w_)
|
224 |
+
h_ = h_.reshape(b, c, h, w)
|
225 |
+
|
226 |
+
h_ = self.proj_out(h_)
|
227 |
+
|
228 |
+
return x+h_
|
229 |
+
|
230 |
+
|
231 |
+
class Encoder(nn.Module):
|
232 |
+
def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
|
233 |
+
super().__init__()
|
234 |
+
self.nf = nf
|
235 |
+
self.num_resolutions = len(ch_mult)
|
236 |
+
self.num_res_blocks = num_res_blocks
|
237 |
+
self.resolution = resolution
|
238 |
+
self.attn_resolutions = attn_resolutions
|
239 |
+
|
240 |
+
curr_res = self.resolution
|
241 |
+
in_ch_mult = (1,)+tuple(ch_mult)
|
242 |
+
|
243 |
+
blocks = []
|
244 |
+
# initial convultion
|
245 |
+
blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
|
246 |
+
|
247 |
+
# residual and downsampling blocks, with attention on smaller res (16x16)
|
248 |
+
for i in range(self.num_resolutions):
|
249 |
+
block_in_ch = nf * in_ch_mult[i]
|
250 |
+
block_out_ch = nf * ch_mult[i]
|
251 |
+
for _ in range(self.num_res_blocks):
|
252 |
+
blocks.append(ResBlock(block_in_ch, block_out_ch))
|
253 |
+
block_in_ch = block_out_ch
|
254 |
+
if curr_res in attn_resolutions:
|
255 |
+
blocks.append(AttnBlock(block_in_ch))
|
256 |
+
|
257 |
+
if i != self.num_resolutions - 1:
|
258 |
+
blocks.append(Downsample(block_in_ch))
|
259 |
+
curr_res = curr_res // 2
|
260 |
+
|
261 |
+
# non-local attention block
|
262 |
+
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
263 |
+
blocks.append(AttnBlock(block_in_ch))
|
264 |
+
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
265 |
+
|
266 |
+
# normalise and convert to latent size
|
267 |
+
blocks.append(normalize(block_in_ch))
|
268 |
+
blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
|
269 |
+
self.blocks = nn.ModuleList(blocks)
|
270 |
+
|
271 |
+
def forward(self, x):
|
272 |
+
for block in self.blocks:
|
273 |
+
x = block(x)
|
274 |
+
|
275 |
+
return x
|
276 |
+
|
277 |
+
|
278 |
+
class Generator(nn.Module):
|
279 |
+
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
|
280 |
+
super().__init__()
|
281 |
+
self.nf = nf
|
282 |
+
self.ch_mult = ch_mult
|
283 |
+
self.num_resolutions = len(self.ch_mult)
|
284 |
+
self.num_res_blocks = res_blocks
|
285 |
+
self.resolution = img_size
|
286 |
+
self.attn_resolutions = attn_resolutions
|
287 |
+
self.in_channels = emb_dim
|
288 |
+
self.out_channels = 3
|
289 |
+
block_in_ch = self.nf * self.ch_mult[-1]
|
290 |
+
curr_res = self.resolution // 2 ** (self.num_resolutions-1)
|
291 |
+
|
292 |
+
blocks = []
|
293 |
+
# initial conv
|
294 |
+
blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
|
295 |
+
|
296 |
+
# non-local attention block
|
297 |
+
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
298 |
+
blocks.append(AttnBlock(block_in_ch))
|
299 |
+
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
300 |
+
|
301 |
+
for i in reversed(range(self.num_resolutions)):
|
302 |
+
block_out_ch = self.nf * self.ch_mult[i]
|
303 |
+
|
304 |
+
for _ in range(self.num_res_blocks):
|
305 |
+
blocks.append(ResBlock(block_in_ch, block_out_ch))
|
306 |
+
block_in_ch = block_out_ch
|
307 |
+
|
308 |
+
if curr_res in self.attn_resolutions:
|
309 |
+
blocks.append(AttnBlock(block_in_ch))
|
310 |
+
|
311 |
+
if i != 0:
|
312 |
+
blocks.append(Upsample(block_in_ch))
|
313 |
+
curr_res = curr_res * 2
|
314 |
+
|
315 |
+
blocks.append(normalize(block_in_ch))
|
316 |
+
blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
|
317 |
+
|
318 |
+
self.blocks = nn.ModuleList(blocks)
|
319 |
+
|
320 |
+
|
321 |
+
def forward(self, x):
|
322 |
+
for block in self.blocks:
|
323 |
+
x = block(x)
|
324 |
+
|
325 |
+
return x
|
326 |
+
|
327 |
+
|
328 |
+
@ARCH_REGISTRY.register()
|
329 |
+
class VQAutoEncoder(nn.Module):
|
330 |
+
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
|
331 |
+
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
|
332 |
+
super().__init__()
|
333 |
+
logger = get_root_logger()
|
334 |
+
self.in_channels = 3
|
335 |
+
self.nf = nf
|
336 |
+
self.n_blocks = res_blocks
|
337 |
+
self.codebook_size = codebook_size
|
338 |
+
self.embed_dim = emb_dim
|
339 |
+
self.ch_mult = ch_mult
|
340 |
+
self.resolution = img_size
|
341 |
+
self.attn_resolutions = attn_resolutions
|
342 |
+
self.quantizer_type = quantizer
|
343 |
+
self.encoder = Encoder(
|
344 |
+
self.in_channels,
|
345 |
+
self.nf,
|
346 |
+
self.embed_dim,
|
347 |
+
self.ch_mult,
|
348 |
+
self.n_blocks,
|
349 |
+
self.resolution,
|
350 |
+
self.attn_resolutions
|
351 |
+
)
|
352 |
+
if self.quantizer_type == "nearest":
|
353 |
+
self.beta = beta #0.25
|
354 |
+
self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
|
355 |
+
elif self.quantizer_type == "gumbel":
|
356 |
+
self.gumbel_num_hiddens = emb_dim
|
357 |
+
self.straight_through = gumbel_straight_through
|
358 |
+
self.kl_weight = gumbel_kl_weight
|
359 |
+
self.quantize = GumbelQuantizer(
|
360 |
+
self.codebook_size,
|
361 |
+
self.embed_dim,
|
362 |
+
self.gumbel_num_hiddens,
|
363 |
+
self.straight_through,
|
364 |
+
self.kl_weight
|
365 |
+
)
|
366 |
+
self.generator = Generator(
|
367 |
+
self.nf,
|
368 |
+
self.embed_dim,
|
369 |
+
self.ch_mult,
|
370 |
+
self.n_blocks,
|
371 |
+
self.resolution,
|
372 |
+
self.attn_resolutions
|
373 |
+
)
|
374 |
+
|
375 |
+
if model_path is not None:
|
376 |
+
chkpt = torch.load(model_path, map_location='cpu')
|
377 |
+
if 'params_ema' in chkpt:
|
378 |
+
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema'])
|
379 |
+
logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
|
380 |
+
elif 'params' in chkpt:
|
381 |
+
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
|
382 |
+
logger.info(f'vqgan is loaded from: {model_path} [params]')
|
383 |
+
else:
|
384 |
+
raise ValueError(f'Wrong params!')
|
385 |
+
|
386 |
+
|
387 |
+
def forward(self, x):
|
388 |
+
x = self.encoder(x)
|
389 |
+
quant, codebook_loss, quant_stats = self.quantize(x)
|
390 |
+
x = self.generator(quant)
|
391 |
+
return x, codebook_loss, quant_stats
|
392 |
+
|
393 |
+
|
394 |
+
|
395 |
+
# patch based discriminator
|
396 |
+
@ARCH_REGISTRY.register()
|
397 |
+
class VQGANDiscriminator(nn.Module):
|
398 |
+
def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
|
399 |
+
super().__init__()
|
400 |
+
|
401 |
+
layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
|
402 |
+
ndf_mult = 1
|
403 |
+
ndf_mult_prev = 1
|
404 |
+
for n in range(1, n_layers): # gradually increase the number of filters
|
405 |
+
ndf_mult_prev = ndf_mult
|
406 |
+
ndf_mult = min(2 ** n, 8)
|
407 |
+
layers += [
|
408 |
+
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
|
409 |
+
nn.BatchNorm2d(ndf * ndf_mult),
|
410 |
+
nn.LeakyReLU(0.2, True)
|
411 |
+
]
|
412 |
+
|
413 |
+
ndf_mult_prev = ndf_mult
|
414 |
+
ndf_mult = min(2 ** n_layers, 8)
|
415 |
+
|
416 |
+
layers += [
|
417 |
+
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
|
418 |
+
nn.BatchNorm2d(ndf * ndf_mult),
|
419 |
+
nn.LeakyReLU(0.2, True)
|
420 |
+
]
|
421 |
+
|
422 |
+
layers += [
|
423 |
+
nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
|
424 |
+
self.main = nn.Sequential(*layers)
|
425 |
+
|
426 |
+
if model_path is not None:
|
427 |
+
chkpt = torch.load(model_path, map_location='cpu')
|
428 |
+
if 'params_d' in chkpt:
|
429 |
+
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
|
430 |
+
elif 'params' in chkpt:
|
431 |
+
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
|
432 |
+
else:
|
433 |
+
raise ValueError(f'Wrong params!')
|
434 |
+
|
435 |
+
def forward(self, x):
|
436 |
+
return self.main(x)
|
437 |
+
|
scripts/reactor_faceswap.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, glob
|
2 |
+
|
3 |
+
from PIL import Image
|
4 |
+
|
5 |
+
import modules.scripts as scripts
|
6 |
+
# from modules.upscaler import Upscaler, UpscalerData
|
7 |
+
from modules import scripts, scripts_postprocessing
|
8 |
+
from modules.processing import (
|
9 |
+
StableDiffusionProcessing,
|
10 |
+
StableDiffusionProcessingImg2Img,
|
11 |
+
)
|
12 |
+
from modules.shared import state
|
13 |
+
from scripts.reactor_logger import logger
|
14 |
+
from scripts.reactor_swapper import swap_face, get_current_faces_model, analyze_faces, half_det_size
|
15 |
+
import folder_paths
|
16 |
+
import comfy.model_management as model_management
|
17 |
+
|
18 |
+
|
19 |
+
def get_models():
|
20 |
+
models_path = os.path.join(folder_paths.models_dir,"insightface/*")
|
21 |
+
models = glob.glob(models_path)
|
22 |
+
models = [x for x in models if x.endswith(".onnx") or x.endswith(".pth")]
|
23 |
+
return models
|
24 |
+
|
25 |
+
|
26 |
+
class FaceSwapScript(scripts.Script):
|
27 |
+
|
28 |
+
def process(
|
29 |
+
self,
|
30 |
+
p: StableDiffusionProcessing,
|
31 |
+
img,
|
32 |
+
enable,
|
33 |
+
source_faces_index,
|
34 |
+
faces_index,
|
35 |
+
model,
|
36 |
+
swap_in_source,
|
37 |
+
swap_in_generated,
|
38 |
+
gender_source,
|
39 |
+
gender_target,
|
40 |
+
face_model,
|
41 |
+
):
|
42 |
+
self.enable = enable
|
43 |
+
if self.enable:
|
44 |
+
|
45 |
+
self.source = img
|
46 |
+
self.swap_in_generated = swap_in_generated
|
47 |
+
self.gender_source = gender_source
|
48 |
+
self.gender_target = gender_target
|
49 |
+
self.model = model
|
50 |
+
self.face_model = face_model
|
51 |
+
self.source_faces_index = [
|
52 |
+
int(x) for x in source_faces_index.strip(",").split(",") if x.isnumeric()
|
53 |
+
]
|
54 |
+
self.faces_index = [
|
55 |
+
int(x) for x in faces_index.strip(",").split(",") if x.isnumeric()
|
56 |
+
]
|
57 |
+
if len(self.source_faces_index) == 0:
|
58 |
+
self.source_faces_index = [0]
|
59 |
+
if len(self.faces_index) == 0:
|
60 |
+
self.faces_index = [0]
|
61 |
+
|
62 |
+
if self.gender_source is None or self.gender_source == "no":
|
63 |
+
self.gender_source = 0
|
64 |
+
elif self.gender_source == "female":
|
65 |
+
self.gender_source = 1
|
66 |
+
elif self.gender_source == "male":
|
67 |
+
self.gender_source = 2
|
68 |
+
|
69 |
+
if self.gender_target is None or self.gender_target == "no":
|
70 |
+
self.gender_target = 0
|
71 |
+
elif self.gender_target == "female":
|
72 |
+
self.gender_target = 1
|
73 |
+
elif self.gender_target == "male":
|
74 |
+
self.gender_target = 2
|
75 |
+
|
76 |
+
# if self.source is not None:
|
77 |
+
if isinstance(p, StableDiffusionProcessingImg2Img) and swap_in_source:
|
78 |
+
logger.status(f"Working: source face index %s, target face index %s", self.source_faces_index, self.faces_index)
|
79 |
+
|
80 |
+
for i in range(len(p.init_images)):
|
81 |
+
if state.interrupted or model_management.processing_interrupted():
|
82 |
+
logger.status("Interrupted by User")
|
83 |
+
break
|
84 |
+
if len(p.init_images) > 1:
|
85 |
+
logger.status(f"Swap in %s", i)
|
86 |
+
result = swap_face(
|
87 |
+
self.source,
|
88 |
+
p.init_images[i],
|
89 |
+
source_faces_index=self.source_faces_index,
|
90 |
+
faces_index=self.faces_index,
|
91 |
+
model=self.model,
|
92 |
+
gender_source=self.gender_source,
|
93 |
+
gender_target=self.gender_target,
|
94 |
+
face_model=self.face_model,
|
95 |
+
)
|
96 |
+
p.init_images[i] = result
|
97 |
+
logger.status("--Done!--")
|
98 |
+
# else:
|
99 |
+
# logger.error(f"Please provide a source face")
|
100 |
+
|
101 |
+
def postprocess_batch(self, p, *args, **kwargs):
|
102 |
+
if self.enable:
|
103 |
+
images = kwargs["images"]
|
104 |
+
|
105 |
+
def postprocess_image(self, p, script_pp: scripts.PostprocessImageArgs, *args):
|
106 |
+
if self.enable and self.swap_in_generated:
|
107 |
+
if self.source is not None:
|
108 |
+
logger.status(f"Working: source face index %s, target face index %s", self.source_faces_index, self.faces_index)
|
109 |
+
image: Image.Image = script_pp.image
|
110 |
+
result = swap_face(
|
111 |
+
self.source,
|
112 |
+
image,
|
113 |
+
source_faces_index=self.source_faces_index,
|
114 |
+
faces_index=self.faces_index,
|
115 |
+
model=self.model,
|
116 |
+
upscale_options=self.upscale_options,
|
117 |
+
gender_source=self.gender_source,
|
118 |
+
gender_target=self.gender_target,
|
119 |
+
)
|
120 |
+
try:
|
121 |
+
pp = scripts_postprocessing.PostprocessedImage(result)
|
122 |
+
pp.info = {}
|
123 |
+
p.extra_generation_params.update(pp.info)
|
124 |
+
script_pp.image = pp.image
|
125 |
+
except:
|
126 |
+
logger.error(f"Cannot create a result image")
|
scripts/reactor_logger.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import copy
|
3 |
+
import sys
|
4 |
+
|
5 |
+
from modules import shared
|
6 |
+
from reactor_utils import addLoggingLevel
|
7 |
+
|
8 |
+
|
9 |
+
class ColoredFormatter(logging.Formatter):
|
10 |
+
COLORS = {
|
11 |
+
"DEBUG": "\033[0;36m", # CYAN
|
12 |
+
"STATUS": "\033[38;5;173m", # Calm ORANGE
|
13 |
+
"INFO": "\033[0;32m", # GREEN
|
14 |
+
"WARNING": "\033[0;33m", # YELLOW
|
15 |
+
"ERROR": "\033[0;31m", # RED
|
16 |
+
"CRITICAL": "\033[0;37;41m", # WHITE ON RED
|
17 |
+
"RESET": "\033[0m", # RESET COLOR
|
18 |
+
}
|
19 |
+
|
20 |
+
def format(self, record):
|
21 |
+
colored_record = copy.copy(record)
|
22 |
+
levelname = colored_record.levelname
|
23 |
+
seq = self.COLORS.get(levelname, self.COLORS["RESET"])
|
24 |
+
colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}"
|
25 |
+
return super().format(colored_record)
|
26 |
+
|
27 |
+
|
28 |
+
# Create a new logger
|
29 |
+
logger = logging.getLogger("ReActor")
|
30 |
+
logger.propagate = False
|
31 |
+
|
32 |
+
# Add Custom Level
|
33 |
+
# logging.addLevelName(logging.INFO, "STATUS")
|
34 |
+
addLoggingLevel("STATUS", logging.INFO + 5)
|
35 |
+
|
36 |
+
# Add handler if we don't have one.
|
37 |
+
if not logger.handlers:
|
38 |
+
handler = logging.StreamHandler(sys.stdout)
|
39 |
+
handler.setFormatter(
|
40 |
+
ColoredFormatter("[%(name)s] %(asctime)s - %(levelname)s - %(message)s",datefmt="%H:%M:%S")
|
41 |
+
)
|
42 |
+
logger.addHandler(handler)
|
43 |
+
|
44 |
+
# Configure logger
|
45 |
+
loglevel_string = getattr(shared.cmd_opts, "reactor_loglevel", "INFO")
|
46 |
+
loglevel = getattr(logging, loglevel_string.upper(), "info")
|
47 |
+
logger.setLevel(loglevel)
|
scripts/reactor_swapper.py
ADDED
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import os
|
3 |
+
import shutil
|
4 |
+
from dataclasses import dataclass
|
5 |
+
from typing import List, Union
|
6 |
+
|
7 |
+
import cv2
|
8 |
+
import numpy as np
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
import insightface
|
12 |
+
from insightface.app.common import Face
|
13 |
+
try:
|
14 |
+
import torch.cuda as cuda
|
15 |
+
except:
|
16 |
+
cuda = None
|
17 |
+
|
18 |
+
from scripts.reactor_logger import logger
|
19 |
+
from reactor_utils import move_path, get_image_md5hash
|
20 |
+
import folder_paths
|
21 |
+
|
22 |
+
import warnings
|
23 |
+
|
24 |
+
np.warnings = warnings
|
25 |
+
np.warnings.filterwarnings('ignore')
|
26 |
+
|
27 |
+
if cuda is not None:
|
28 |
+
if cuda.is_available():
|
29 |
+
providers = ["CUDAExecutionProvider"]
|
30 |
+
else:
|
31 |
+
providers = ["CPUExecutionProvider"]
|
32 |
+
else:
|
33 |
+
providers = ["CPUExecutionProvider"]
|
34 |
+
|
35 |
+
models_path_old = os.path.join(os.path.dirname(os.path.dirname(__file__)), "models")
|
36 |
+
insightface_path_old = os.path.join(models_path_old, "insightface")
|
37 |
+
insightface_models_path_old = os.path.join(insightface_path_old, "models")
|
38 |
+
|
39 |
+
models_path = folder_paths.models_dir
|
40 |
+
insightface_path = os.path.join(models_path, "insightface")
|
41 |
+
insightface_models_path = os.path.join(insightface_path, "models")
|
42 |
+
|
43 |
+
if os.path.exists(models_path_old):
|
44 |
+
move_path(insightface_models_path_old, insightface_models_path)
|
45 |
+
move_path(insightface_path_old, insightface_path)
|
46 |
+
move_path(models_path_old, models_path)
|
47 |
+
if os.path.exists(insightface_path) and os.path.exists(insightface_path_old):
|
48 |
+
shutil.rmtree(insightface_path_old)
|
49 |
+
shutil.rmtree(models_path_old)
|
50 |
+
|
51 |
+
|
52 |
+
FS_MODEL = None
|
53 |
+
CURRENT_FS_MODEL_PATH = None
|
54 |
+
|
55 |
+
ANALYSIS_MODEL = None
|
56 |
+
|
57 |
+
SOURCE_FACES = None
|
58 |
+
SOURCE_IMAGE_HASH = None
|
59 |
+
TARGET_FACES = None
|
60 |
+
TARGET_IMAGE_HASH = None
|
61 |
+
|
62 |
+
def get_current_faces_model():
|
63 |
+
global SOURCE_FACES
|
64 |
+
return SOURCE_FACES
|
65 |
+
|
66 |
+
def getAnalysisModel():
|
67 |
+
global ANALYSIS_MODEL
|
68 |
+
if ANALYSIS_MODEL is None:
|
69 |
+
ANALYSIS_MODEL = insightface.app.FaceAnalysis(
|
70 |
+
name="buffalo_l", providers=providers, root=insightface_path
|
71 |
+
)
|
72 |
+
return ANALYSIS_MODEL
|
73 |
+
|
74 |
+
|
75 |
+
def getFaceSwapModel(model_path: str):
|
76 |
+
global FS_MODEL
|
77 |
+
global CURRENT_FS_MODEL_PATH
|
78 |
+
if CURRENT_FS_MODEL_PATH is None or CURRENT_FS_MODEL_PATH != model_path:
|
79 |
+
CURRENT_FS_MODEL_PATH = model_path
|
80 |
+
FS_MODEL = insightface.model_zoo.get_model(model_path, providers=providers)
|
81 |
+
|
82 |
+
return FS_MODEL
|
83 |
+
|
84 |
+
|
85 |
+
def get_face_gender(
|
86 |
+
face,
|
87 |
+
face_index,
|
88 |
+
gender_condition,
|
89 |
+
operated: str
|
90 |
+
):
|
91 |
+
gender = [
|
92 |
+
x.sex
|
93 |
+
for x in face
|
94 |
+
]
|
95 |
+
gender.reverse()
|
96 |
+
# If index is outside of bounds, return None, avoid exception
|
97 |
+
if face_index >= len(gender):
|
98 |
+
logger.status("Requested face index (%s) is out of bounds (max available index is %s)", face_index, len(gender))
|
99 |
+
return None, 0
|
100 |
+
face_gender = gender[face_index]
|
101 |
+
logger.status("%s Face %s: Detected Gender -%s-", operated, face_index, face_gender)
|
102 |
+
if (gender_condition == 1 and face_gender == "F") or (gender_condition == 2 and face_gender == "M"):
|
103 |
+
logger.status("OK - Detected Gender matches Condition")
|
104 |
+
try:
|
105 |
+
return sorted(face, key=lambda x: x.bbox[0])[face_index], 0
|
106 |
+
except IndexError:
|
107 |
+
return None, 0
|
108 |
+
else:
|
109 |
+
logger.status("WRONG - Detected Gender doesn't match Condition")
|
110 |
+
return sorted(face, key=lambda x: x.bbox[0])[face_index], 1
|
111 |
+
|
112 |
+
|
113 |
+
def half_det_size(det_size):
|
114 |
+
logger.status("Trying to halve 'det_size' parameter")
|
115 |
+
return (det_size[0] // 2, det_size[1] // 2)
|
116 |
+
|
117 |
+
def analyze_faces(img_data: np.ndarray, det_size=(640, 640)):
|
118 |
+
face_analyser = copy.deepcopy(getAnalysisModel())
|
119 |
+
face_analyser.prepare(ctx_id=0, det_size=det_size)
|
120 |
+
return face_analyser.get(img_data)
|
121 |
+
|
122 |
+
def get_face_single(img_data: np.ndarray, face, face_index=0, det_size=(640, 640), gender_source=0, gender_target=0):
|
123 |
+
|
124 |
+
buffalo_path = os.path.join(insightface_models_path, "buffalo_l.zip")
|
125 |
+
if os.path.exists(buffalo_path):
|
126 |
+
os.remove(buffalo_path)
|
127 |
+
|
128 |
+
if gender_source != 0:
|
129 |
+
if len(face) == 0 and det_size[0] > 320 and det_size[1] > 320:
|
130 |
+
det_size_half = half_det_size(det_size)
|
131 |
+
return get_face_single(img_data, analyze_faces(img_data, det_size_half), face_index, det_size_half, gender_source, gender_target)
|
132 |
+
return get_face_gender(face,face_index,gender_source,"Source")
|
133 |
+
|
134 |
+
if gender_target != 0:
|
135 |
+
if len(face) == 0 and det_size[0] > 320 and det_size[1] > 320:
|
136 |
+
det_size_half = half_det_size(det_size)
|
137 |
+
return get_face_single(img_data, analyze_faces(img_data, det_size_half), face_index, det_size_half, gender_source, gender_target)
|
138 |
+
return get_face_gender(face,face_index,gender_target,"Target")
|
139 |
+
|
140 |
+
if len(face) == 0 and det_size[0] > 320 and det_size[1] > 320:
|
141 |
+
det_size_half = half_det_size(det_size)
|
142 |
+
return get_face_single(img_data, analyze_faces(img_data, det_size_half), face_index, det_size_half, gender_source, gender_target)
|
143 |
+
|
144 |
+
try:
|
145 |
+
return sorted(face, key=lambda x: x.bbox[0])[face_index], 0
|
146 |
+
except IndexError:
|
147 |
+
return None, 0
|
148 |
+
|
149 |
+
|
150 |
+
def swap_face(
|
151 |
+
source_img: Union[Image.Image, None],
|
152 |
+
target_img: Image.Image,
|
153 |
+
model: Union[str, None] = None,
|
154 |
+
source_faces_index: List[int] = [0],
|
155 |
+
faces_index: List[int] = [0],
|
156 |
+
gender_source: int = 0,
|
157 |
+
gender_target: int = 0,
|
158 |
+
face_model: Union[Face, None] = None,
|
159 |
+
):
|
160 |
+
global SOURCE_FACES, SOURCE_IMAGE_HASH, TARGET_FACES, TARGET_IMAGE_HASH
|
161 |
+
result_image = target_img
|
162 |
+
|
163 |
+
if model is not None:
|
164 |
+
|
165 |
+
if isinstance(source_img, str): # source_img is a base64 string
|
166 |
+
import base64, io
|
167 |
+
if 'base64,' in source_img: # check if the base64 string has a data URL scheme
|
168 |
+
# split the base64 string to get the actual base64 encoded image data
|
169 |
+
base64_data = source_img.split('base64,')[-1]
|
170 |
+
# decode base64 string to bytes
|
171 |
+
img_bytes = base64.b64decode(base64_data)
|
172 |
+
else:
|
173 |
+
# if no data URL scheme, just decode
|
174 |
+
img_bytes = base64.b64decode(source_img)
|
175 |
+
|
176 |
+
source_img = Image.open(io.BytesIO(img_bytes))
|
177 |
+
|
178 |
+
target_img = cv2.cvtColor(np.array(target_img), cv2.COLOR_RGB2BGR)
|
179 |
+
|
180 |
+
if source_img is not None:
|
181 |
+
|
182 |
+
source_img = cv2.cvtColor(np.array(source_img), cv2.COLOR_RGB2BGR)
|
183 |
+
|
184 |
+
source_image_md5hash = get_image_md5hash(source_img)
|
185 |
+
|
186 |
+
if SOURCE_IMAGE_HASH is None:
|
187 |
+
SOURCE_IMAGE_HASH = source_image_md5hash
|
188 |
+
source_image_same = False
|
189 |
+
else:
|
190 |
+
source_image_same = True if SOURCE_IMAGE_HASH == source_image_md5hash else False
|
191 |
+
if not source_image_same:
|
192 |
+
SOURCE_IMAGE_HASH = source_image_md5hash
|
193 |
+
|
194 |
+
logger.info("Source Image MD5 Hash = %s", SOURCE_IMAGE_HASH)
|
195 |
+
logger.info("Source Image the Same? %s", source_image_same)
|
196 |
+
|
197 |
+
if SOURCE_FACES is None or not source_image_same:
|
198 |
+
logger.status("Analyzing Source Image...")
|
199 |
+
source_faces = analyze_faces(source_img)
|
200 |
+
SOURCE_FACES = source_faces
|
201 |
+
elif source_image_same:
|
202 |
+
logger.status("Using Hashed Source Face(s) Model...")
|
203 |
+
source_faces = SOURCE_FACES
|
204 |
+
|
205 |
+
elif face_model is not None:
|
206 |
+
|
207 |
+
source_faces_index = [0]
|
208 |
+
logger.status("Using Loaded Source Face Model...")
|
209 |
+
source_face_model = [face_model]
|
210 |
+
source_faces = source_face_model
|
211 |
+
|
212 |
+
else:
|
213 |
+
logger.error("Cannot detect any Source")
|
214 |
+
|
215 |
+
if source_faces is not None:
|
216 |
+
|
217 |
+
target_image_md5hash = get_image_md5hash(target_img)
|
218 |
+
|
219 |
+
if TARGET_IMAGE_HASH is None:
|
220 |
+
TARGET_IMAGE_HASH = target_image_md5hash
|
221 |
+
target_image_same = False
|
222 |
+
else:
|
223 |
+
target_image_same = True if TARGET_IMAGE_HASH == target_image_md5hash else False
|
224 |
+
if not target_image_same:
|
225 |
+
TARGET_IMAGE_HASH = target_image_md5hash
|
226 |
+
|
227 |
+
logger.info("Target Image MD5 Hash = %s", TARGET_IMAGE_HASH)
|
228 |
+
logger.info("Target Image the Same? %s", target_image_same)
|
229 |
+
|
230 |
+
if TARGET_FACES is None or not target_image_same:
|
231 |
+
logger.status("Analyzing Target Image...")
|
232 |
+
target_faces = analyze_faces(target_img)
|
233 |
+
TARGET_FACES = target_faces
|
234 |
+
elif target_image_same:
|
235 |
+
logger.status("Using Hashed Target Face(s) Model...")
|
236 |
+
target_faces = TARGET_FACES
|
237 |
+
|
238 |
+
# No use in trying to swap faces if no faces are found, enhancement
|
239 |
+
if len(target_faces) == 0:
|
240 |
+
logger.status("Cannot detect any Target, skipping swapping...")
|
241 |
+
return result_image
|
242 |
+
|
243 |
+
if source_img is not None:
|
244 |
+
# separated management of wrong_gender between source and target, enhancement
|
245 |
+
source_face, src_wrong_gender = get_face_single(source_img, source_faces, face_index=source_faces_index[0], gender_source=gender_source)
|
246 |
+
else:
|
247 |
+
source_face = sorted(source_faces, key=lambda x: x.bbox[0])[source_faces_index[0]]
|
248 |
+
src_wrong_gender = 0
|
249 |
+
|
250 |
+
if len(source_faces_index) != 0 and len(source_faces_index) != 1 and len(source_faces_index) != len(faces_index):
|
251 |
+
logger.status(f'Source Faces must have no entries (default=0), one entry, or same number of entries as target faces.')
|
252 |
+
elif source_face is not None:
|
253 |
+
result = target_img
|
254 |
+
model_path = model_path = os.path.join(insightface_path, model)
|
255 |
+
face_swapper = getFaceSwapModel(model_path)
|
256 |
+
|
257 |
+
source_face_idx = 0
|
258 |
+
|
259 |
+
for face_num in faces_index:
|
260 |
+
# No use in trying to swap faces if no further faces are found, enhancement
|
261 |
+
if face_num >= len(target_faces):
|
262 |
+
logger.status("Checked all existing target faces, skipping swapping...")
|
263 |
+
break
|
264 |
+
|
265 |
+
if len(source_faces_index) > 1 and source_face_idx > 0:
|
266 |
+
source_face, src_wrong_gender = get_face_single(source_img, source_faces, face_index=source_faces_index[source_face_idx], gender_source=gender_source)
|
267 |
+
source_face_idx += 1
|
268 |
+
|
269 |
+
if source_face is not None and src_wrong_gender == 0:
|
270 |
+
target_face, wrong_gender = get_face_single(target_img, target_faces, face_index=face_num, gender_target=gender_target)
|
271 |
+
if target_face is not None and wrong_gender == 0:
|
272 |
+
logger.status(f"Swapping...")
|
273 |
+
result = face_swapper.get(result, target_face, source_face)
|
274 |
+
elif wrong_gender == 1:
|
275 |
+
wrong_gender = 0
|
276 |
+
# Keep searching for other faces if wrong gender is detected, enhancement
|
277 |
+
#if source_face_idx == len(source_faces_index):
|
278 |
+
# result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
|
279 |
+
# return result_image
|
280 |
+
logger.status("Wrong target gender detected")
|
281 |
+
continue
|
282 |
+
else:
|
283 |
+
logger.status(f"No target face found for {face_num}")
|
284 |
+
elif src_wrong_gender == 1:
|
285 |
+
src_wrong_gender = 0
|
286 |
+
# Keep searching for other faces if wrong gender is detected, enhancement
|
287 |
+
#if source_face_idx == len(source_faces_index):
|
288 |
+
# result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
|
289 |
+
# return result_image
|
290 |
+
logger.status("Wrong source gender detected")
|
291 |
+
continue
|
292 |
+
else:
|
293 |
+
logger.status(f"No source face found for face number {source_face_idx}.")
|
294 |
+
|
295 |
+
result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
|
296 |
+
|
297 |
+
else:
|
298 |
+
logger.status("No source face(s) in the provided Index")
|
299 |
+
else:
|
300 |
+
logger.status("No source face(s) found")
|
301 |
+
return result_image
|
scripts/reactor_version.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
app_title = "ReActor Node for ComfyUI"
|
2 |
+
version_flag = "v0.4.1-b12"
|
3 |
+
|
4 |
+
COLORS = {
|
5 |
+
"CYAN": "\033[0;36m", # CYAN
|
6 |
+
"ORANGE": "\033[38;5;173m", # Calm ORANGE
|
7 |
+
"GREEN": "\033[0;32m", # GREEN
|
8 |
+
"YELLOW": "\033[0;33m", # YELLOW
|
9 |
+
"RED": "\033[0;91m", # RED
|
10 |
+
"0": "\033[0m", # RESET COLOR
|
11 |
+
}
|
12 |
+
|
13 |
+
print(f"{COLORS['YELLOW']}[ReActor]{COLORS['0']} - {COLORS['ORANGE']}STATUS{COLORS['0']} - {COLORS['GREEN']}Running {version_flag} in ComfyUI{COLORS['0']}")
|