ZiqianLiu committed on
Commit
55f6076
1 Parent(s): cb2f529

Upload 14 files

scripts/__init__.py ADDED
File without changes
scripts/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (175 Bytes).
 
scripts/__pycache__/reactor_faceswap.cpython-310.pyc ADDED
Binary file (3.59 kB).
 
scripts/__pycache__/reactor_logger.cpython-310.pyc ADDED
Binary file (1.4 kB).
 
scripts/__pycache__/reactor_swapper.cpython-310.pyc ADDED
Binary file (7.24 kB).
 
scripts/__pycache__/reactor_version.cpython-310.pyc ADDED
Binary file (534 Bytes).
 
scripts/r_archs/__pycache__/codeformer_arch.cpython-310.pyc ADDED
Binary file (9.24 kB).
 
scripts/r_archs/__pycache__/vqgan_arch.cpython-310.pyc ADDED
Binary file (11.2 kB).
 
scripts/r_archs/codeformer_arch.py ADDED
@@ -0,0 +1,278 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn, Tensor
5
+ import torch.nn.functional as F
6
+ from typing import Optional, List
7
+
8
+ from scripts.r_archs.vqgan_arch import *
9
+ from r_basicsr.utils import get_root_logger
10
+ from r_basicsr.utils.registry import ARCH_REGISTRY
11
+
12
+
13
+ def calc_mean_std(feat, eps=1e-5):
14
+ """Calculate mean and std for adaptive_instance_normalization.
15
+
16
+ Args:
17
+ feat (Tensor): 4D tensor.
18
+ eps (float): A small value added to the variance to avoid
19
+ divide-by-zero. Default: 1e-5.
20
+ """
21
+ size = feat.size()
22
+ assert len(size) == 4, 'The input feature should be 4D tensor.'
23
+ b, c = size[:2]
24
+ feat_var = feat.view(b, c, -1).var(dim=2) + eps
25
+ feat_std = feat_var.sqrt().view(b, c, 1, 1)
26
+ feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
27
+ return feat_mean, feat_std
28
+
29
+
30
+ def adaptive_instance_normalization(content_feat, style_feat):
31
+ """Adaptive instance normalization.
32
+
33
+ Adjust the reference features to have a similar color and illumination
34
+ as the degraded features.
35
+
36
+ Args:
37
+ content_feat (Tensor): The reference feature.
38
+ style_feat (Tensor): The degraded features.
39
+ """
40
+ size = content_feat.size()
41
+ style_mean, style_std = calc_mean_std(style_feat)
42
+ content_mean, content_std = calc_mean_std(content_feat)
43
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
44
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
45
+
46
+
47
+ class PositionEmbeddingSine(nn.Module):
48
+ """
49
+ This is a more standard version of the position embedding, very similar to the one
50
+ used by the Attention is all you need paper, generalized to work on images.
51
+ """
52
+
53
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
54
+ super().__init__()
55
+ self.num_pos_feats = num_pos_feats
56
+ self.temperature = temperature
57
+ self.normalize = normalize
58
+ if scale is not None and normalize is False:
59
+ raise ValueError("normalize should be True if scale is passed")
60
+ if scale is None:
61
+ scale = 2 * math.pi
62
+ self.scale = scale
63
+
64
+ def forward(self, x, mask=None):
65
+ if mask is None:
66
+ mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
67
+ not_mask = ~mask
68
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
69
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
70
+ if self.normalize:
71
+ eps = 1e-6
72
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
73
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
74
+
75
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
76
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
77
+
78
+ pos_x = x_embed[:, :, :, None] / dim_t
79
+ pos_y = y_embed[:, :, :, None] / dim_t
80
+ pos_x = torch.stack(
81
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
82
+ ).flatten(3)
83
+ pos_y = torch.stack(
84
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
85
+ ).flatten(3)
86
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
87
+ return pos
88
+
89
+ def _get_activation_fn(activation):
90
+ """Return an activation function given a string"""
91
+ if activation == "relu":
92
+ return F.relu
93
+ if activation == "gelu":
94
+ return F.gelu
95
+ if activation == "glu":
96
+ return F.glu
97
+ raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
98
+
99
+
100
+ class TransformerSALayer(nn.Module):
101
+ def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
102
+ super().__init__()
103
+ self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
104
+ # Implementation of Feedforward model - MLP
105
+ self.linear1 = nn.Linear(embed_dim, dim_mlp)
106
+ self.dropout = nn.Dropout(dropout)
107
+ self.linear2 = nn.Linear(dim_mlp, embed_dim)
108
+
109
+ self.norm1 = nn.LayerNorm(embed_dim)
110
+ self.norm2 = nn.LayerNorm(embed_dim)
111
+ self.dropout1 = nn.Dropout(dropout)
112
+ self.dropout2 = nn.Dropout(dropout)
113
+
114
+ self.activation = _get_activation_fn(activation)
115
+
116
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
117
+ return tensor if pos is None else tensor + pos
118
+
119
+ def forward(self, tgt,
120
+ tgt_mask: Optional[Tensor] = None,
121
+ tgt_key_padding_mask: Optional[Tensor] = None,
122
+ query_pos: Optional[Tensor] = None):
123
+
124
+ # self attention
125
+ tgt2 = self.norm1(tgt)
126
+ q = k = self.with_pos_embed(tgt2, query_pos)
127
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
128
+ key_padding_mask=tgt_key_padding_mask)[0]
129
+ tgt = tgt + self.dropout1(tgt2)
130
+
131
+ # ffn
132
+ tgt2 = self.norm2(tgt)
133
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
134
+ tgt = tgt + self.dropout2(tgt2)
135
+ return tgt
136
+
137
+ class Fuse_sft_block(nn.Module):
138
+ def __init__(self, in_ch, out_ch):
139
+ super().__init__()
140
+ self.encode_enc = ResBlock(2*in_ch, out_ch)
141
+
142
+ self.scale = nn.Sequential(
143
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
144
+ nn.LeakyReLU(0.2, True),
145
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
146
+
147
+ self.shift = nn.Sequential(
148
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
149
+ nn.LeakyReLU(0.2, True),
150
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
151
+
152
+ def forward(self, enc_feat, dec_feat, w=1):
153
+ enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
154
+ scale = self.scale(enc_feat)
155
+ shift = self.shift(enc_feat)
156
+ residual = w * (dec_feat * scale + shift)
157
+ out = dec_feat + residual
158
+ return out
159
+
160
+
161
+ @ARCH_REGISTRY.register()
162
+ class CodeFormer(VQAutoEncoder):
163
+ def __init__(self, dim_embd=512, n_head=8, n_layers=9,
164
+ codebook_size=1024, latent_size=256,
165
+ connect_list=['32', '64', '128', '256'],
166
+ fix_modules=['quantize','generator']):
167
+ super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest', 2, [16], codebook_size)
168
+
169
+ if fix_modules is not None:
170
+ for module in fix_modules:
171
+ for param in getattr(self, module).parameters():
172
+ param.requires_grad = False
173
+
174
+ self.connect_list = connect_list
175
+ self.n_layers = n_layers
176
+ self.dim_embd = dim_embd
177
+ self.dim_mlp = dim_embd*2
178
+
179
+ self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
180
+ self.feat_emb = nn.Linear(256, self.dim_embd)
181
+
182
+ # transformer
183
+ self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
184
+ for _ in range(self.n_layers)])
185
+
186
+ # logits_predict head
187
+ self.idx_pred_layer = nn.Sequential(
188
+ nn.LayerNorm(dim_embd),
189
+ nn.Linear(dim_embd, codebook_size, bias=False))
190
+
191
+ self.channels = {
192
+ '16': 512,
193
+ '32': 256,
194
+ '64': 256,
195
+ '128': 128,
196
+ '256': 128,
197
+ '512': 64,
198
+ }
199
+
200
+ # after second residual block for > 16, before attn layer for ==16
201
+ self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
202
+ # after first residual block for > 16, before attn layer for ==16
203
+ self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
204
+
205
+ # fuse_convs_dict
206
+ self.fuse_convs_dict = nn.ModuleDict()
207
+ for f_size in self.connect_list:
208
+ in_ch = self.channels[f_size]
209
+ self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
210
+
211
+ def _init_weights(self, module):
212
+ if isinstance(module, (nn.Linear, nn.Embedding)):
213
+ module.weight.data.normal_(mean=0.0, std=0.02)
214
+ if isinstance(module, nn.Linear) and module.bias is not None:
215
+ module.bias.data.zero_()
216
+ elif isinstance(module, nn.LayerNorm):
217
+ module.bias.data.zero_()
218
+ module.weight.data.fill_(1.0)
219
+
220
+ def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
221
+ # ################### Encoder #####################
222
+ enc_feat_dict = {}
223
+ out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
224
+ for i, block in enumerate(self.encoder.blocks):
225
+ x = block(x)
226
+ if i in out_list:
227
+ enc_feat_dict[str(x.shape[-1])] = x.clone()
228
+
229
+ lq_feat = x
230
+ # ################# Transformer ###################
231
+ # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
232
+ pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
233
+ # BCHW -> BC(HW) -> (HW)BC
234
+ feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
235
+ query_emb = feat_emb
236
+ # Transformer encoder
237
+ for layer in self.ft_layers:
238
+ query_emb = layer(query_emb, query_pos=pos_emb)
239
+
240
+ # output logits
241
+ logits = self.idx_pred_layer(query_emb) # (hw)bn
242
+ logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
243
+
244
+ if code_only: # for training stage II
245
+ # logits doesn't need softmax before cross_entropy loss
246
+ return logits, lq_feat
247
+
248
+ # ################# Quantization ###################
249
+ # if self.training:
250
+ # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
251
+ # # b(hw)c -> bc(hw) -> bchw
252
+ # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
253
+ # ------------
254
+ soft_one_hot = F.softmax(logits, dim=2)
255
+ _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
256
+ quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
257
+ # preserve gradients
258
+ # quant_feat = lq_feat + (quant_feat - lq_feat).detach()
259
+
260
+ if detach_16:
261
+ quant_feat = quant_feat.detach() # for training stage III
262
+ if adain:
263
+ quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
264
+
265
+ # ################## Generator ####################
266
+ x = quant_feat
267
+ fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
268
+
269
+ for i, block in enumerate(self.generator.blocks):
270
+ x = block(x)
271
+ if i in fuse_list: # fuse after i-th block
272
+ f_size = str(x.shape[-1])
273
+ if w>0:
274
+ x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
275
+ out = x
276
+ # logits doesn't need softmax before cross_entropy loss
277
+ return out, logits, lq_feat
278
+
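Note: the forward pass above encodes the input to a 16x16 latent, lets the transformer predict codebook indices, and decodes through the (frozen) generator, fusing encoder features wherever the fidelity weight w is non-zero. A minimal usage sketch, assuming the r_basicsr dependency is installed and the module is importable from this repo layout; the 512x512 input size is simply the img_size=512 default passed to the parent class, not a documented requirement:

import torch
from scripts.r_archs.codeformer_arch import CodeFormer

net = CodeFormer()                      # defaults: dim_embd=512, n_layers=9, codebook_size=1024
net.eval()
x = torch.randn(1, 3, 512, 512)         # assumed input resolution (img_size=512)
with torch.no_grad():
    # w=0 gives a pure codebook reconstruction; w=1 fuses encoder features most strongly
    out, logits, lq_feat = net(x, w=0.5, adain=True)
print(out.shape, logits.shape)          # torch.Size([1, 3, 512, 512]) torch.Size([1, 256, 1024])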
scripts/r_archs/vqgan_arch.py ADDED
@@ -0,0 +1,437 @@
1
+ '''
2
+ VQGAN code, adapted from the original created by the Unleashing Transformers authors:
3
+ https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
4
+
5
+ '''
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import copy
11
+ from r_basicsr.utils import get_root_logger
12
+ from r_basicsr.utils.registry import ARCH_REGISTRY
13
+
14
+
15
+ def normalize(in_channels):
16
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
17
+
18
+
19
+ @torch.jit.script
20
+ def swish(x):
21
+ return x*torch.sigmoid(x)
22
+
23
+
24
+ # Define VQVAE classes
25
+ class VectorQuantizer(nn.Module):
26
+ def __init__(self, codebook_size, emb_dim, beta):
27
+ super(VectorQuantizer, self).__init__()
28
+ self.codebook_size = codebook_size # number of embeddings
29
+ self.emb_dim = emb_dim # dimension of embedding
30
+ self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
31
+ self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
32
+ self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
33
+
34
+ def forward(self, z):
35
+ # reshape z -> (batch, height, width, channel) and flatten
36
+ z = z.permute(0, 2, 3, 1).contiguous()
37
+ z_flattened = z.view(-1, self.emb_dim)
38
+
39
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
40
+ d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
41
+ 2 * torch.matmul(z_flattened, self.embedding.weight.t())
42
+
43
+ mean_distance = torch.mean(d)
44
+ # find closest encodings
45
+ # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
46
+ min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
47
+ # [0-1], higher score, higher confidence
48
+ min_encoding_scores = torch.exp(-min_encoding_scores/10)
49
+
50
+ min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
51
+ min_encodings.scatter_(1, min_encoding_indices, 1)
52
+
53
+ # get quantized latent vectors
54
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
55
+ # compute loss for embedding
56
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
57
+ # preserve gradients
58
+ z_q = z + (z_q - z).detach()
59
+
60
+ # perplexity
61
+ e_mean = torch.mean(min_encodings, dim=0)
62
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
63
+ # reshape back to match original input shape
64
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
65
+
66
+ return z_q, loss, {
67
+ "perplexity": perplexity,
68
+ "min_encodings": min_encodings,
69
+ "min_encoding_indices": min_encoding_indices,
70
+ "min_encoding_scores": min_encoding_scores,
71
+ "mean_distance": mean_distance
72
+ }
73
+
74
+ def get_codebook_feat(self, indices, shape):
75
+ # input indices: batch*token_num -> (batch*token_num)*1
76
+ # shape: batch, height, width, channel
77
+ indices = indices.view(-1,1)
78
+ min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
79
+ min_encodings.scatter_(1, indices, 1)
80
+ # get quantized latent vectors
81
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
82
+
83
+ if shape is not None: # reshape back to match original input shape
84
+ z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
85
+
86
+ return z_q
87
+
88
+
89
+ class GumbelQuantizer(nn.Module):
90
+ def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
91
+ super().__init__()
92
+ self.codebook_size = codebook_size # number of embeddings
93
+ self.emb_dim = emb_dim # dimension of embedding
94
+ self.straight_through = straight_through
95
+ self.temperature = temp_init
96
+ self.kl_weight = kl_weight
97
+ self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
98
+ self.embed = nn.Embedding(codebook_size, emb_dim)
99
+
100
+ def forward(self, z):
101
+ hard = self.straight_through if self.training else True
102
+
103
+ logits = self.proj(z)
104
+
105
+ soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
106
+
107
+ z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
108
+
109
+ # + kl divergence to the prior loss
110
+ qy = F.softmax(logits, dim=1)
111
+ diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
112
+ min_encoding_indices = soft_one_hot.argmax(dim=1)
113
+
114
+ return z_q, diff, {
115
+ "min_encoding_indices": min_encoding_indices
116
+ }
117
+
118
+
119
+ class Downsample(nn.Module):
120
+ def __init__(self, in_channels):
121
+ super().__init__()
122
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
123
+
124
+ def forward(self, x):
125
+ pad = (0, 1, 0, 1)
126
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
127
+ x = self.conv(x)
128
+ return x
129
+
130
+
131
+ class Upsample(nn.Module):
132
+ def __init__(self, in_channels):
133
+ super().__init__()
134
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
135
+
136
+ def forward(self, x):
137
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
138
+ x = self.conv(x)
139
+
140
+ return x
141
+
142
+
143
+ class ResBlock(nn.Module):
144
+ def __init__(self, in_channels, out_channels=None):
145
+ super(ResBlock, self).__init__()
146
+ self.in_channels = in_channels
147
+ self.out_channels = in_channels if out_channels is None else out_channels
148
+ self.norm1 = normalize(in_channels)
149
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
150
+ self.norm2 = normalize(out_channels)
151
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
152
+ if self.in_channels != self.out_channels:
153
+ self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
154
+
155
+ def forward(self, x_in):
156
+ x = x_in
157
+ x = self.norm1(x)
158
+ x = swish(x)
159
+ x = self.conv1(x)
160
+ x = self.norm2(x)
161
+ x = swish(x)
162
+ x = self.conv2(x)
163
+ if self.in_channels != self.out_channels:
164
+ x_in = self.conv_out(x_in)
165
+
166
+ return x + x_in
167
+
168
+
169
+ class AttnBlock(nn.Module):
170
+ def __init__(self, in_channels):
171
+ super().__init__()
172
+ self.in_channels = in_channels
173
+
174
+ self.norm = normalize(in_channels)
175
+ self.q = torch.nn.Conv2d(
176
+ in_channels,
177
+ in_channels,
178
+ kernel_size=1,
179
+ stride=1,
180
+ padding=0
181
+ )
182
+ self.k = torch.nn.Conv2d(
183
+ in_channels,
184
+ in_channels,
185
+ kernel_size=1,
186
+ stride=1,
187
+ padding=0
188
+ )
189
+ self.v = torch.nn.Conv2d(
190
+ in_channels,
191
+ in_channels,
192
+ kernel_size=1,
193
+ stride=1,
194
+ padding=0
195
+ )
196
+ self.proj_out = torch.nn.Conv2d(
197
+ in_channels,
198
+ in_channels,
199
+ kernel_size=1,
200
+ stride=1,
201
+ padding=0
202
+ )
203
+
204
+ def forward(self, x):
205
+ h_ = x
206
+ h_ = self.norm(h_)
207
+ q = self.q(h_)
208
+ k = self.k(h_)
209
+ v = self.v(h_)
210
+
211
+ # compute attention
212
+ b, c, h, w = q.shape
213
+ q = q.reshape(b, c, h*w)
214
+ q = q.permute(0, 2, 1)
215
+ k = k.reshape(b, c, h*w)
216
+ w_ = torch.bmm(q, k)
217
+ w_ = w_ * (int(c)**(-0.5))
218
+ w_ = F.softmax(w_, dim=2)
219
+
220
+ # attend to values
221
+ v = v.reshape(b, c, h*w)
222
+ w_ = w_.permute(0, 2, 1)
223
+ h_ = torch.bmm(v, w_)
224
+ h_ = h_.reshape(b, c, h, w)
225
+
226
+ h_ = self.proj_out(h_)
227
+
228
+ return x+h_
229
+
230
+
231
+ class Encoder(nn.Module):
232
+ def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
233
+ super().__init__()
234
+ self.nf = nf
235
+ self.num_resolutions = len(ch_mult)
236
+ self.num_res_blocks = num_res_blocks
237
+ self.resolution = resolution
238
+ self.attn_resolutions = attn_resolutions
239
+
240
+ curr_res = self.resolution
241
+ in_ch_mult = (1,)+tuple(ch_mult)
242
+
243
+ blocks = []
244
+ # initial convolution
245
+ blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
246
+
247
+ # residual and downsampling blocks, with attention on smaller res (16x16)
248
+ for i in range(self.num_resolutions):
249
+ block_in_ch = nf * in_ch_mult[i]
250
+ block_out_ch = nf * ch_mult[i]
251
+ for _ in range(self.num_res_blocks):
252
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
253
+ block_in_ch = block_out_ch
254
+ if curr_res in attn_resolutions:
255
+ blocks.append(AttnBlock(block_in_ch))
256
+
257
+ if i != self.num_resolutions - 1:
258
+ blocks.append(Downsample(block_in_ch))
259
+ curr_res = curr_res // 2
260
+
261
+ # non-local attention block
262
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
263
+ blocks.append(AttnBlock(block_in_ch))
264
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
265
+
266
+ # normalise and convert to latent size
267
+ blocks.append(normalize(block_in_ch))
268
+ blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
269
+ self.blocks = nn.ModuleList(blocks)
270
+
271
+ def forward(self, x):
272
+ for block in self.blocks:
273
+ x = block(x)
274
+
275
+ return x
276
+
277
+
278
+ class Generator(nn.Module):
279
+ def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
280
+ super().__init__()
281
+ self.nf = nf
282
+ self.ch_mult = ch_mult
283
+ self.num_resolutions = len(self.ch_mult)
284
+ self.num_res_blocks = res_blocks
285
+ self.resolution = img_size
286
+ self.attn_resolutions = attn_resolutions
287
+ self.in_channels = emb_dim
288
+ self.out_channels = 3
289
+ block_in_ch = self.nf * self.ch_mult[-1]
290
+ curr_res = self.resolution // 2 ** (self.num_resolutions-1)
291
+
292
+ blocks = []
293
+ # initial conv
294
+ blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
295
+
296
+ # non-local attention block
297
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
298
+ blocks.append(AttnBlock(block_in_ch))
299
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
300
+
301
+ for i in reversed(range(self.num_resolutions)):
302
+ block_out_ch = self.nf * self.ch_mult[i]
303
+
304
+ for _ in range(self.num_res_blocks):
305
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
306
+ block_in_ch = block_out_ch
307
+
308
+ if curr_res in self.attn_resolutions:
309
+ blocks.append(AttnBlock(block_in_ch))
310
+
311
+ if i != 0:
312
+ blocks.append(Upsample(block_in_ch))
313
+ curr_res = curr_res * 2
314
+
315
+ blocks.append(normalize(block_in_ch))
316
+ blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
317
+
318
+ self.blocks = nn.ModuleList(blocks)
319
+
320
+
321
+ def forward(self, x):
322
+ for block in self.blocks:
323
+ x = block(x)
324
+
325
+ return x
326
+
327
+
328
+ @ARCH_REGISTRY.register()
329
+ class VQAutoEncoder(nn.Module):
330
+ def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
331
+ beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
332
+ super().__init__()
333
+ logger = get_root_logger()
334
+ self.in_channels = 3
335
+ self.nf = nf
336
+ self.n_blocks = res_blocks
337
+ self.codebook_size = codebook_size
338
+ self.embed_dim = emb_dim
339
+ self.ch_mult = ch_mult
340
+ self.resolution = img_size
341
+ self.attn_resolutions = attn_resolutions
342
+ self.quantizer_type = quantizer
343
+ self.encoder = Encoder(
344
+ self.in_channels,
345
+ self.nf,
346
+ self.embed_dim,
347
+ self.ch_mult,
348
+ self.n_blocks,
349
+ self.resolution,
350
+ self.attn_resolutions
351
+ )
352
+ if self.quantizer_type == "nearest":
353
+ self.beta = beta #0.25
354
+ self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
355
+ elif self.quantizer_type == "gumbel":
356
+ self.gumbel_num_hiddens = emb_dim
357
+ self.straight_through = gumbel_straight_through
358
+ self.kl_weight = gumbel_kl_weight
359
+ self.quantize = GumbelQuantizer(
360
+ self.codebook_size,
361
+ self.embed_dim,
362
+ self.gumbel_num_hiddens,
363
+ self.straight_through,
364
+ self.kl_weight
365
+ )
366
+ self.generator = Generator(
367
+ self.nf,
368
+ self.embed_dim,
369
+ self.ch_mult,
370
+ self.n_blocks,
371
+ self.resolution,
372
+ self.attn_resolutions
373
+ )
374
+
375
+ if model_path is not None:
376
+ chkpt = torch.load(model_path, map_location='cpu')
377
+ if 'params_ema' in chkpt:
378
+ self.load_state_dict(chkpt['params_ema'])
379
+ logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
380
+ elif 'params' in chkpt:
381
+ self.load_state_dict(chkpt['params'])
382
+ logger.info(f'vqgan is loaded from: {model_path} [params]')
383
+ else:
384
+ raise ValueError(f"Unexpected checkpoint in {model_path}: expected a 'params_ema' or 'params' key")
385
+
386
+
387
+ def forward(self, x):
388
+ x = self.encoder(x)
389
+ quant, codebook_loss, quant_stats = self.quantize(x)
390
+ x = self.generator(quant)
391
+ return x, codebook_loss, quant_stats
392
+
393
+
394
+
395
+ # patch based discriminator
396
+ @ARCH_REGISTRY.register()
397
+ class VQGANDiscriminator(nn.Module):
398
+ def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
399
+ super().__init__()
400
+
401
+ layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
402
+ ndf_mult = 1
403
+ ndf_mult_prev = 1
404
+ for n in range(1, n_layers): # gradually increase the number of filters
405
+ ndf_mult_prev = ndf_mult
406
+ ndf_mult = min(2 ** n, 8)
407
+ layers += [
408
+ nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
409
+ nn.BatchNorm2d(ndf * ndf_mult),
410
+ nn.LeakyReLU(0.2, True)
411
+ ]
412
+
413
+ ndf_mult_prev = ndf_mult
414
+ ndf_mult = min(2 ** n_layers, 8)
415
+
416
+ layers += [
417
+ nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
418
+ nn.BatchNorm2d(ndf * ndf_mult),
419
+ nn.LeakyReLU(0.2, True)
420
+ ]
421
+
422
+ layers += [
423
+ nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
424
+ self.main = nn.Sequential(*layers)
425
+
426
+ if model_path is not None:
427
+ chkpt = torch.load(model_path, map_location='cpu')
428
+ if 'params_d' in chkpt:
429
+ self.load_state_dict(chkpt['params_d'])
430
+ elif 'params' in chkpt:
431
+ self.load_state_dict(chkpt['params'])
432
+ else:
433
+ raise ValueError(f"Unexpected checkpoint in {model_path}: expected a 'params_d' or 'params' key")
434
+
435
+ def forward(self, x):
436
+ return self.main(x)
437
+
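Note: VQAutoEncoder pairs the convolutional Encoder/Generator with either a nearest-neighbour VectorQuantizer (straight-through gradients, commitment term weighted by beta) or a GumbelQuantizer. A minimal round-trip sketch under the same configuration CodeFormer passes to this class above; the 512x512 input size is an assumption chosen so the latent lands at 16x16:

import torch
from scripts.r_archs.vqgan_arch import VQAutoEncoder

vqgan = VQAutoEncoder(512, 64, [1, 2, 2, 4, 4, 8], 'nearest', 2, [16], 1024)
vqgan.eval()
x = torch.randn(1, 3, 512, 512)
with torch.no_grad():
    recon, codebook_loss, stats = vqgan(x)    # stats also carries perplexity and encoding indices
print(recon.shape, stats["min_encoding_indices"].shape)   # (1, 3, 512, 512) and (256, 1)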
scripts/reactor_faceswap.py ADDED
@@ -0,0 +1,126 @@
1
+ import os, glob
2
+
3
+ from PIL import Image
4
+
5
+ import modules.scripts as scripts
6
+ # from modules.upscaler import Upscaler, UpscalerData
7
+ from modules import scripts, scripts_postprocessing
8
+ from modules.processing import (
9
+ StableDiffusionProcessing,
10
+ StableDiffusionProcessingImg2Img,
11
+ )
12
+ from modules.shared import state
13
+ from scripts.reactor_logger import logger
14
+ from scripts.reactor_swapper import swap_face, get_current_faces_model, analyze_faces, half_det_size
15
+ import folder_paths
16
+ import comfy.model_management as model_management
17
+
18
+
19
+ def get_models():
20
+ models_path = os.path.join(folder_paths.models_dir,"insightface/*")
21
+ models = glob.glob(models_path)
22
+ models = [x for x in models if x.endswith(".onnx") or x.endswith(".pth")]
23
+ return models
24
+
25
+
26
+ class FaceSwapScript(scripts.Script):
27
+
28
+ def process(
29
+ self,
30
+ p: StableDiffusionProcessing,
31
+ img,
32
+ enable,
33
+ source_faces_index,
34
+ faces_index,
35
+ model,
36
+ swap_in_source,
37
+ swap_in_generated,
38
+ gender_source,
39
+ gender_target,
40
+ face_model,
41
+ ):
42
+ self.enable = enable
43
+ if self.enable:
44
+
45
+ self.source = img
46
+ self.swap_in_generated = swap_in_generated
47
+ self.gender_source = gender_source
48
+ self.gender_target = gender_target
49
+ self.model = model
50
+ self.face_model = face_model
51
+ self.source_faces_index = [
52
+ int(x) for x in source_faces_index.strip(",").split(",") if x.isnumeric()
53
+ ]
54
+ self.faces_index = [
55
+ int(x) for x in faces_index.strip(",").split(",") if x.isnumeric()
56
+ ]
57
+ if len(self.source_faces_index) == 0:
58
+ self.source_faces_index = [0]
59
+ if len(self.faces_index) == 0:
60
+ self.faces_index = [0]
61
+
62
+ if self.gender_source is None or self.gender_source == "no":
63
+ self.gender_source = 0
64
+ elif self.gender_source == "female":
65
+ self.gender_source = 1
66
+ elif self.gender_source == "male":
67
+ self.gender_source = 2
68
+
69
+ if self.gender_target is None or self.gender_target == "no":
70
+ self.gender_target = 0
71
+ elif self.gender_target == "female":
72
+ self.gender_target = 1
73
+ elif self.gender_target == "male":
74
+ self.gender_target = 2
75
+
76
+ # if self.source is not None:
77
+ if isinstance(p, StableDiffusionProcessingImg2Img) and swap_in_source:
78
+ logger.status(f"Working: source face index %s, target face index %s", self.source_faces_index, self.faces_index)
79
+
80
+ for i in range(len(p.init_images)):
81
+ if state.interrupted or model_management.processing_interrupted():
82
+ logger.status("Interrupted by User")
83
+ break
84
+ if len(p.init_images) > 1:
85
+ logger.status(f"Swap in %s", i)
86
+ result = swap_face(
87
+ self.source,
88
+ p.init_images[i],
89
+ source_faces_index=self.source_faces_index,
90
+ faces_index=self.faces_index,
91
+ model=self.model,
92
+ gender_source=self.gender_source,
93
+ gender_target=self.gender_target,
94
+ face_model=self.face_model,
95
+ )
96
+ p.init_images[i] = result
97
+ logger.status("--Done!--")
98
+ # else:
99
+ # logger.error(f"Please provide a source face")
100
+
101
+ def postprocess_batch(self, p, *args, **kwargs):
102
+ if self.enable:
103
+ images = kwargs["images"]
104
+
105
+ def postprocess_image(self, p, script_pp: scripts.PostprocessImageArgs, *args):
106
+ if self.enable and self.swap_in_generated:
107
+ if self.source is not None:
108
+ logger.status(f"Working: source face index %s, target face index %s", self.source_faces_index, self.faces_index)
109
+ image: Image.Image = script_pp.image
110
+ result = swap_face(
111
+ self.source,
112
+ image,
113
+ source_faces_index=self.source_faces_index,
114
+ faces_index=self.faces_index,
115
+ model=self.model,
116
+ # upscale_options=self.upscale_options,  # disabled: never set in process() and not a swap_face() parameter
117
+ gender_source=self.gender_source,
118
+ gender_target=self.gender_target,
119
+ )
120
+ try:
121
+ pp = scripts_postprocessing.PostprocessedImage(result)
122
+ pp.info = {}
123
+ p.extra_generation_params.update(pp.info)
124
+ script_pp.image = pp.image
125
+ except Exception:
126
+ logger.error(f"Cannot create a result image")
scripts/reactor_logger.py ADDED
@@ -0,0 +1,47 @@
1
+ import logging
2
+ import copy
3
+ import sys
4
+
5
+ from modules import shared
6
+ from reactor_utils import addLoggingLevel
7
+
8
+
9
+ class ColoredFormatter(logging.Formatter):
10
+ COLORS = {
11
+ "DEBUG": "\033[0;36m", # CYAN
12
+ "STATUS": "\033[38;5;173m", # Calm ORANGE
13
+ "INFO": "\033[0;32m", # GREEN
14
+ "WARNING": "\033[0;33m", # YELLOW
15
+ "ERROR": "\033[0;31m", # RED
16
+ "CRITICAL": "\033[0;37;41m", # WHITE ON RED
17
+ "RESET": "\033[0m", # RESET COLOR
18
+ }
19
+
20
+ def format(self, record):
21
+ colored_record = copy.copy(record)
22
+ levelname = colored_record.levelname
23
+ seq = self.COLORS.get(levelname, self.COLORS["RESET"])
24
+ colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}"
25
+ return super().format(colored_record)
26
+
27
+
28
+ # Create a new logger
29
+ logger = logging.getLogger("ReActor")
30
+ logger.propagate = False
31
+
32
+ # Add Custom Level
33
+ # logging.addLevelName(logging.INFO, "STATUS")
34
+ addLoggingLevel("STATUS", logging.INFO + 5)
35
+
36
+ # Add handler if we don't have one.
37
+ if not logger.handlers:
38
+ handler = logging.StreamHandler(sys.stdout)
39
+ handler.setFormatter(
40
+ ColoredFormatter("[%(name)s] %(asctime)s - %(levelname)s - %(message)s",datefmt="%H:%M:%S")
41
+ )
42
+ logger.addHandler(handler)
43
+
44
+ # Configure logger
45
+ loglevel_string = getattr(shared.cmd_opts, "reactor_loglevel", "INFO")
46
+ loglevel = getattr(logging, loglevel_string.upper(), logging.INFO)
47
+ logger.setLevel(loglevel)
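Note: the logger above relies on reactor_utils.addLoggingLevel to register the custom STATUS level at INFO + 5 and then colours level names with ANSI escape codes. A minimal sketch of registering such a level with the standard library alone; the helper below is assumed to be roughly equivalent to addLoggingLevel, not its actual implementation:

import logging

STATUS = logging.INFO + 5                          # same numeric slot the script uses

def add_status_level():
    logging.addLevelName(STATUS, "STATUS")
    def status(self, msg, *args, **kwargs):
        if self.isEnabledFor(STATUS):
            self._log(STATUS, msg, args, **kwargs)
    logging.Logger.status = status                 # every logger gains a .status() method

add_status_level()
logging.basicConfig(level=STATUS, format="[%(name)s] %(levelname)s - %(message)s")
logging.getLogger("demo").status("custom level works: %s", True)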
scripts/reactor_swapper.py ADDED
@@ -0,0 +1,301 @@
1
+ import copy
2
+ import os
3
+ import shutil
4
+ from dataclasses import dataclass
5
+ from typing import List, Union
6
+
7
+ import cv2
8
+ import numpy as np
9
+ from PIL import Image
10
+
11
+ import insightface
12
+ from insightface.app.common import Face
13
+ try:
14
+ import torch.cuda as cuda
15
+ except Exception:
16
+ cuda = None
17
+
18
+ from scripts.reactor_logger import logger
19
+ from reactor_utils import move_path, get_image_md5hash
20
+ import folder_paths
21
+
22
+ import warnings
23
+
24
+ np.warnings = warnings
25
+ np.warnings.filterwarnings('ignore')
26
+
27
+ if cuda is not None:
28
+ if cuda.is_available():
29
+ providers = ["CUDAExecutionProvider"]
30
+ else:
31
+ providers = ["CPUExecutionProvider"]
32
+ else:
33
+ providers = ["CPUExecutionProvider"]
34
+
35
+ models_path_old = os.path.join(os.path.dirname(os.path.dirname(__file__)), "models")
36
+ insightface_path_old = os.path.join(models_path_old, "insightface")
37
+ insightface_models_path_old = os.path.join(insightface_path_old, "models")
38
+
39
+ models_path = folder_paths.models_dir
40
+ insightface_path = os.path.join(models_path, "insightface")
41
+ insightface_models_path = os.path.join(insightface_path, "models")
42
+
43
+ if os.path.exists(models_path_old):
44
+ move_path(insightface_models_path_old, insightface_models_path)
45
+ move_path(insightface_path_old, insightface_path)
46
+ move_path(models_path_old, models_path)
47
+ if os.path.exists(insightface_path) and os.path.exists(insightface_path_old):
48
+ shutil.rmtree(insightface_path_old)
49
+ shutil.rmtree(models_path_old)
50
+
51
+
52
+ FS_MODEL = None
53
+ CURRENT_FS_MODEL_PATH = None
54
+
55
+ ANALYSIS_MODEL = None
56
+
57
+ SOURCE_FACES = None
58
+ SOURCE_IMAGE_HASH = None
59
+ TARGET_FACES = None
60
+ TARGET_IMAGE_HASH = None
61
+
62
+ def get_current_faces_model():
63
+ global SOURCE_FACES
64
+ return SOURCE_FACES
65
+
66
+ def getAnalysisModel():
67
+ global ANALYSIS_MODEL
68
+ if ANALYSIS_MODEL is None:
69
+ ANALYSIS_MODEL = insightface.app.FaceAnalysis(
70
+ name="buffalo_l", providers=providers, root=insightface_path
71
+ )
72
+ return ANALYSIS_MODEL
73
+
74
+
75
+ def getFaceSwapModel(model_path: str):
76
+ global FS_MODEL
77
+ global CURRENT_FS_MODEL_PATH
78
+ if CURRENT_FS_MODEL_PATH is None or CURRENT_FS_MODEL_PATH != model_path:
79
+ CURRENT_FS_MODEL_PATH = model_path
80
+ FS_MODEL = insightface.model_zoo.get_model(model_path, providers=providers)
81
+
82
+ return FS_MODEL
83
+
84
+
85
+ def get_face_gender(
86
+ face,
87
+ face_index,
88
+ gender_condition,
89
+ operated: str
90
+ ):
91
+ gender = [
92
+ x.sex
93
+ for x in face
94
+ ]
95
+ gender.reverse()
96
+ # If index is outside of bounds, return None, avoid exception
97
+ if face_index >= len(gender):
98
+ logger.status("Requested face index (%s) is out of bounds (max available index is %s)", face_index, len(gender))
99
+ return None, 0
100
+ face_gender = gender[face_index]
101
+ logger.status("%s Face %s: Detected Gender -%s-", operated, face_index, face_gender)
102
+ if (gender_condition == 1 and face_gender == "F") or (gender_condition == 2 and face_gender == "M"):
103
+ logger.status("OK - Detected Gender matches Condition")
104
+ try:
105
+ return sorted(face, key=lambda x: x.bbox[0])[face_index], 0
106
+ except IndexError:
107
+ return None, 0
108
+ else:
109
+ logger.status("WRONG - Detected Gender doesn't match Condition")
110
+ return sorted(face, key=lambda x: x.bbox[0])[face_index], 1
111
+
112
+
113
+ def half_det_size(det_size):
114
+ logger.status("Trying to halve 'det_size' parameter")
115
+ return (det_size[0] // 2, det_size[1] // 2)
116
+
117
+ def analyze_faces(img_data: np.ndarray, det_size=(640, 640)):
118
+ face_analyser = copy.deepcopy(getAnalysisModel())
119
+ face_analyser.prepare(ctx_id=0, det_size=det_size)
120
+ return face_analyser.get(img_data)
121
+
122
+ def get_face_single(img_data: np.ndarray, face, face_index=0, det_size=(640, 640), gender_source=0, gender_target=0):
123
+
124
+ buffalo_path = os.path.join(insightface_models_path, "buffalo_l.zip")
125
+ if os.path.exists(buffalo_path):
126
+ os.remove(buffalo_path)
127
+
128
+ if gender_source != 0:
129
+ if len(face) == 0 and det_size[0] > 320 and det_size[1] > 320:
130
+ det_size_half = half_det_size(det_size)
131
+ return get_face_single(img_data, analyze_faces(img_data, det_size_half), face_index, det_size_half, gender_source, gender_target)
132
+ return get_face_gender(face,face_index,gender_source,"Source")
133
+
134
+ if gender_target != 0:
135
+ if len(face) == 0 and det_size[0] > 320 and det_size[1] > 320:
136
+ det_size_half = half_det_size(det_size)
137
+ return get_face_single(img_data, analyze_faces(img_data, det_size_half), face_index, det_size_half, gender_source, gender_target)
138
+ return get_face_gender(face,face_index,gender_target,"Target")
139
+
140
+ if len(face) == 0 and det_size[0] > 320 and det_size[1] > 320:
141
+ det_size_half = half_det_size(det_size)
142
+ return get_face_single(img_data, analyze_faces(img_data, det_size_half), face_index, det_size_half, gender_source, gender_target)
143
+
144
+ try:
145
+ return sorted(face, key=lambda x: x.bbox[0])[face_index], 0
146
+ except IndexError:
147
+ return None, 0
148
+
149
+
150
+ def swap_face(
151
+ source_img: Union[Image.Image, None],
152
+ target_img: Image.Image,
153
+ model: Union[str, None] = None,
154
+ source_faces_index: List[int] = [0],
155
+ faces_index: List[int] = [0],
156
+ gender_source: int = 0,
157
+ gender_target: int = 0,
158
+ face_model: Union[Face, None] = None,
159
+ ):
160
+ global SOURCE_FACES, SOURCE_IMAGE_HASH, TARGET_FACES, TARGET_IMAGE_HASH
161
+ result_image = target_img
162
+
163
+ if model is not None:
164
+
165
+ if isinstance(source_img, str): # source_img is a base64 string
166
+ import base64, io
167
+ if 'base64,' in source_img: # check if the base64 string has a data URL scheme
168
+ # split the base64 string to get the actual base64 encoded image data
169
+ base64_data = source_img.split('base64,')[-1]
170
+ # decode base64 string to bytes
171
+ img_bytes = base64.b64decode(base64_data)
172
+ else:
173
+ # if no data URL scheme, just decode
174
+ img_bytes = base64.b64decode(source_img)
175
+
176
+ source_img = Image.open(io.BytesIO(img_bytes))
177
+
178
+ target_img = cv2.cvtColor(np.array(target_img), cv2.COLOR_RGB2BGR)
179
+
180
+ if source_img is not None:
181
+
182
+ source_img = cv2.cvtColor(np.array(source_img), cv2.COLOR_RGB2BGR)
183
+
184
+ source_image_md5hash = get_image_md5hash(source_img)
185
+
186
+ if SOURCE_IMAGE_HASH is None:
187
+ SOURCE_IMAGE_HASH = source_image_md5hash
188
+ source_image_same = False
189
+ else:
190
+ source_image_same = True if SOURCE_IMAGE_HASH == source_image_md5hash else False
191
+ if not source_image_same:
192
+ SOURCE_IMAGE_HASH = source_image_md5hash
193
+
194
+ logger.info("Source Image MD5 Hash = %s", SOURCE_IMAGE_HASH)
195
+ logger.info("Source Image the Same? %s", source_image_same)
196
+
197
+ if SOURCE_FACES is None or not source_image_same:
198
+ logger.status("Analyzing Source Image...")
199
+ source_faces = analyze_faces(source_img)
200
+ SOURCE_FACES = source_faces
201
+ elif source_image_same:
202
+ logger.status("Using Hashed Source Face(s) Model...")
203
+ source_faces = SOURCE_FACES
204
+
205
+ elif face_model is not None:
206
+
207
+ source_faces_index = [0]
208
+ logger.status("Using Loaded Source Face Model...")
209
+ source_face_model = [face_model]
210
+ source_faces = source_face_model
211
+
212
+ else:
213
+ logger.error("Cannot detect any Source")
214
+
215
+ if source_faces is not None:
216
+
217
+ target_image_md5hash = get_image_md5hash(target_img)
218
+
219
+ if TARGET_IMAGE_HASH is None:
220
+ TARGET_IMAGE_HASH = target_image_md5hash
221
+ target_image_same = False
222
+ else:
223
+ target_image_same = TARGET_IMAGE_HASH == target_image_md5hash
224
+ if not target_image_same:
225
+ TARGET_IMAGE_HASH = target_image_md5hash
226
+
227
+ logger.info("Target Image MD5 Hash = %s", TARGET_IMAGE_HASH)
228
+ logger.info("Target Image the Same? %s", target_image_same)
229
+
230
+ if TARGET_FACES is None or not target_image_same:
231
+ logger.status("Analyzing Target Image...")
232
+ target_faces = analyze_faces(target_img)
233
+ TARGET_FACES = target_faces
234
+ elif target_image_same:
235
+ logger.status("Using Hashed Target Face(s) Model...")
236
+ target_faces = TARGET_FACES
237
+
238
+ # No use in trying to swap faces if no faces are found, enhancement
239
+ if len(target_faces) == 0:
240
+ logger.status("Cannot detect any Target, skipping swapping...")
241
+ return result_image
242
+
243
+ if source_img is not None:
244
+ # separated management of wrong_gender between source and target, enhancement
245
+ source_face, src_wrong_gender = get_face_single(source_img, source_faces, face_index=source_faces_index[0], gender_source=gender_source)
246
+ else:
247
+ source_face = sorted(source_faces, key=lambda x: x.bbox[0])[source_faces_index[0]]
248
+ src_wrong_gender = 0
249
+
250
+ if len(source_faces_index) != 0 and len(source_faces_index) != 1 and len(source_faces_index) != len(faces_index):
251
+ logger.status('Source Faces must have no entries (default=0), one entry, or the same number of entries as the target faces.')
252
+ elif source_face is not None:
253
+ result = target_img
254
+ model_path = os.path.join(insightface_path, model)
255
+ face_swapper = getFaceSwapModel(model_path)
256
+
257
+ source_face_idx = 0
258
+
259
+ for face_num in faces_index:
260
+ # No use in trying to swap faces if no further faces are found, enhancement
261
+ if face_num >= len(target_faces):
262
+ logger.status("Checked all existing target faces, skipping swapping...")
263
+ break
264
+
265
+ if len(source_faces_index) > 1 and source_face_idx > 0:
266
+ source_face, src_wrong_gender = get_face_single(source_img, source_faces, face_index=source_faces_index[source_face_idx], gender_source=gender_source)
267
+ source_face_idx += 1
268
+
269
+ if source_face is not None and src_wrong_gender == 0:
270
+ target_face, wrong_gender = get_face_single(target_img, target_faces, face_index=face_num, gender_target=gender_target)
271
+ if target_face is not None and wrong_gender == 0:
272
+ logger.status(f"Swapping...")
273
+ result = face_swapper.get(result, target_face, source_face)
274
+ elif wrong_gender == 1:
275
+ wrong_gender = 0
276
+ # Keep searching for other faces if wrong gender is detected, enhancement
277
+ #if source_face_idx == len(source_faces_index):
278
+ # result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
279
+ # return result_image
280
+ logger.status("Wrong target gender detected")
281
+ continue
282
+ else:
283
+ logger.status(f"No target face found for {face_num}")
284
+ elif src_wrong_gender == 1:
285
+ src_wrong_gender = 0
286
+ # Keep searching for other faces if wrong gender is detected, enhancement
287
+ #if source_face_idx == len(source_faces_index):
288
+ # result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
289
+ # return result_image
290
+ logger.status("Wrong source gender detected")
291
+ continue
292
+ else:
293
+ logger.status(f"No source face found for face number {source_face_idx}.")
294
+
295
+ result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
296
+
297
+ else:
298
+ logger.status("No source face(s) in the provided Index")
299
+ else:
300
+ logger.status("No source face(s) found")
301
+ return result_image
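Note: swap_face converts both PIL images to BGR numpy arrays, caches the InsightFace analysis results keyed by an MD5 hash of the image (the SOURCE_FACES/TARGET_FACES globals above), and then runs the inswapper model once per requested target index. A minimal sketch of that hash-based caching in isolation; get_image_md5hash lives in reactor_utils, so a plain hashlib stand-in is assumed here:

import hashlib
import numpy as np

_FACE_CACHE = {}   # md5 digest of the image buffer -> face analysis result

def image_md5(img: np.ndarray) -> str:
    # assumption: equivalent in spirit to reactor_utils.get_image_md5hash
    return hashlib.md5(np.ascontiguousarray(img).tobytes()).hexdigest()

def cached_analyze(img: np.ndarray, analyze):
    key = image_md5(img)
    if key not in _FACE_CACHE:                 # re-analyze only when the image actually changed
        _FACE_CACHE[key] = analyze(img)        # e.g. analyze_faces(img) defined above
    return _FACE_CACHE[key]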
scripts/reactor_version.py ADDED
@@ -0,0 +1,13 @@
1
+ app_title = "ReActor Node for ComfyUI"
2
+ version_flag = "v0.4.1-b12"
3
+
4
+ COLORS = {
5
+ "CYAN": "\033[0;36m", # CYAN
6
+ "ORANGE": "\033[38;5;173m", # Calm ORANGE
7
+ "GREEN": "\033[0;32m", # GREEN
8
+ "YELLOW": "\033[0;33m", # YELLOW
9
+ "RED": "\033[0;91m", # RED
10
+ "0": "\033[0m", # RESET COLOR
11
+ }
12
+
13
+ print(f"{COLORS['YELLOW']}[ReActor]{COLORS['0']} - {COLORS['ORANGE']}STATUS{COLORS['0']} - {COLORS['GREEN']}Running {version_flag} in ComfyUI{COLORS['0']}")