# Source: JoJoGan-powerhow2 / e4e/models/discriminator.py
# (commit 3d37b6e by Sanket; 496 bytes)
from torch import nn
class LatentCodesDiscriminator(nn.Module):
    """MLP that maps a latent code of width ``style_dim`` to a single score.

    The network is ``n_mlp - 1`` blocks of ``Linear(style_dim, style_dim)``
    followed by ``LeakyReLU(0.2)``, then a final ``Linear(style_dim, 1)``.
    No output activation is applied, so the score is an unbounded logit.
    Presumably used as a discriminator on e4e latent codes (per the class
    name) -- TODO confirm against the training code.

    Args:
        style_dim: Width of the input latent code and of every hidden layer.
        n_mlp: Total number of ``Linear`` layers, including the final
            scoring layer. Values ``<= 1`` yield only the final layer.
    """

    def __init__(self, style_dim, n_mlp):
        super().__init__()
        self.style_dim = style_dim

        layers = []
        for _ in range(n_mlp - 1):
            layers.append(nn.Linear(style_dim, style_dim))
            layers.append(nn.LeakyReLU(0.2))
        # Final projection to one score. This previously hard-coded 512 as the
        # input width, which broke any style_dim != 512. The layer order is
        # unchanged, so state_dict keys (mlp.0, mlp.1, ...) still match
        # existing checkpoints.
        layers.append(nn.Linear(style_dim, 1))
        self.mlp = nn.Sequential(*layers)

    def forward(self, w):
        """Score latent code(s) ``w``.

        ``nn.Linear`` acts on the last dimension, so ``w`` must have trailing
        size ``style_dim``. Leading dimensions are preserved, and the trailing
        dimension of the output is 1.
        """
        return self.mlp(w)