File size: 952 Bytes
83d8d3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import numpy as np
import torch.nn as nn

from models.model_blocks import ResBlock


class Discriminator(nn.Module):
    """Convolutional discriminator built from a stack of down-sampling ResBlocks.

    Layer layout (held in ``self.sequence``):

    1. A 3x3 stride-1 conv mapping ``input_nc`` -> ``ndf`` channels.
    2. ``n_layers`` stages of ``ResBlock(in, out, down_sample=True, norm=False)``.
       The width doubles on each of the first three stages
       (``ndf`` -> ``2*ndf`` -> ``4*ndf`` -> ``8*ndf``) and then stays constant.
       With the default ``ndf=64`` the capped width is 512, which matches the
       previously hard-coded layout exactly.
    3. A 4x4 stride-1 unpadded conv (width unchanged), then LeakyReLU(0.2).
    4. A 1x1 conv to 2 output channels, then LeakyReLU(0.2).

    Args:
        input_nc: Number of channels of the input tensor.
        ndf: Width of the stem conv; every later width is a multiple of it.
        n_layers: Number of ResBlock stages.
    """

    def __init__(self, input_nc, ndf=64, n_layers=6):
        super(Discriminator, self).__init__()
        # Number of initial stages that double the channel width before it is capped.
        n_widening = 3

        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=3, stride=1, padding=1)]
        channels = ndf
        for i in range(n_layers):
            # Track the real width instead of assuming 512, so non-default ndf
            # values and n_layers < 3 produce matching in/out channel counts.
            out_channels = channels * 2 if i < n_widening else channels
            sequence += [ResBlock(channels, out_channels, down_sample=True, norm=False)]
            channels = out_channels
        sequence += [
            # The head is sized from whatever width the last stage produced.
            nn.Conv2d(channels, channels, kernel_size=4, stride=1, padding=0),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(channels, 2, kernel_size=1, stride=1, padding=0),
            # NOTE(review): the 2-channel output also passes through a LeakyReLU;
            # kept as-is, but confirm this is intended for the loss in use.
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.sequence = nn.Sequential(*sequence)

    def forward(self, input):
        """Run ``input`` through the layer stack and return the 2-channel output.

        ``input`` is presumably an image batch of shape (N, input_nc, H, W).
        TODO confirm against callers.
        """
        return self.sequence(input)