Spaces:
Runtime error
Runtime error
import numpy as np | |
import torch.nn as nn | |
from models.model_blocks import ResBlock | |
class Discriminator(nn.Module): | |
def __init__(self, input_nc, ndf=64, n_layers=6): | |
super(Discriminator, self).__init__() | |
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=3, stride=1, padding=1)] | |
for i in range(n_layers): | |
if i >= 3: | |
sequence += [ResBlock(512, 512, down_sample=True, norm=False)] | |
else: | |
mult = 2**i | |
sequence += [ResBlock(ndf * mult, ndf * mult * 2, down_sample=True, norm=False)] | |
sequence += [ | |
nn.Conv2d(512, 512, kernel_size=4, stride=1, padding=0), | |
nn.LeakyReLU(0.2, inplace=True), | |
nn.Conv2d(512, 2, kernel_size=1, stride=1, padding=0), | |
nn.LeakyReLU(0.2, inplace=True), | |
] | |
self.sequence = nn.Sequential(*sequence) | |
def forward(self, input): | |
return self.sequence(input) | |