File size: 2,898 Bytes
1388c9c
 
 
 
 
 
 
 
 
 
 
 
 
96a87c7
 
1388c9c
96a87c7
 
1388c9c
 
 
 
96a87c7
1388c9c
96a87c7
 
 
1388c9c
96a87c7
1388c9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image, to_tensor
import glob
from PIL import Image
import tqdm
import gc

class TestModel(torch.nn.Module):
    """Small image-to-image convolutional network.

    Layout: a 3->16 stem conv, BatchNorm, three residual 16->16 convs, a second
    BatchNorm, and a 16->3 projection. Every conv is 3x3 with stride 1,
    padding 1 and no bias, so the spatial size of the input is preserved.
    The output is clamped to [-1, 1].
    """

    def __init__(self):
        super().__init__()

        def conv3x3(in_ch, out_ch):
            # Shared factory so every conv has identical hyper-parameters.
            return torch.nn.Conv2d(
                in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False
            )

        # Creation order is kept as start, conv1..conv3, final, bn1, bn2 so
        # seeded random init and state_dict key order stay the same.
        self.start = conv3x3(3, 16)
        self.conv1 = conv3x3(16, 16)
        self.conv2 = conv3x3(16, 16)
        self.conv3 = conv3x3(16, 16)
        self.final = conv3x3(16, 3)
        self.bn1 = torch.nn.BatchNorm2d(16)
        self.bn2 = torch.nn.BatchNorm2d(16)

    def forward(self, x):
        hidden = self.bn1(self.start(x))
        # NOTE(review): there is no activation between the convs, so this
        # residual stack is affine end to end; confirm that is intended.
        for block in (self.conv1, self.conv2, self.conv3):
            hidden = block(hidden) + hidden
        projected = self.final(self.bn2(hidden))
        return projected.clamp(-1, 1)
    
class DS(Dataset):
    """Image dataset yielding random square crops scaled to [-1, 1].

    Files are collected from ``<root>/*`` and sorted, so index 0 (used by
    ``gettest``) is stable across runs and filesystems.

    Args:
        root: directory containing the images. Defaults to ``"./15k"``.
        crop_size: crop size passed to ``transforms.RandomCrop`` (int or
            ``(h, w)``). Defaults to 256.
    """

    def __init__(self, root="./15k", crop_size=256):
        super().__init__()
        # glob order is filesystem-dependent; sort it so g[0] is deterministic.
        self.g = sorted(glob.glob(f"{root}/*"))
        self.trans = transforms.Compose([
            # pad_if_needed: images smaller than the crop are zero-padded
            # instead of making RandomCrop raise mid-epoch.
            transforms.RandomCrop(crop_size, pad_if_needed=True),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.g)

    @staticmethod
    def _load_rgb(path):
        """Open *path*, return an RGB copy, and close the underlying file."""
        with Image.open(path) as img:
            # convert() returns a new, fully loaded image, so it stays valid
            # after the file is closed.
            return img.convert("RGB")

    @staticmethod
    def _to_signed_unit(t):
        """Map a tensor from [0, 1] (ToTensor's output range) to [-1, 1]."""
        return t * 2 - 1

    def __getitem__(self, idx):
        """Return a random (3, crop, crop) crop of image *idx* in [-1, 1]."""
        x = self.trans(self._load_rgb(self.g[idx]))
        # ToTensor already scales uint8 pixels to [0, 1]; dividing by 127.5
        # would assume [0, 255] and squeeze everything near -1.
        return self._to_signed_unit(x)

    def gettest(self):
        """Return the first image, uncropped, as a (3, H, W) tensor in [-1, 1]."""
        return self._to_signed_unit(to_tensor(self._load_rgb(self.g[0])))
    
def main():
    """Train TestModel to reconstruct DS() crops and save a sample each epoch.

    Writes an initial snapshot ``./test/test.png``. It then runs ``num_epochs``
    passes over the data with MSE(output, input) loss. After each epoch it
    prints the last batch loss and saves ``./test/<epoch>.png``, which is the
    model's reconstruction of ``dataset.gettest()``.
    """
    import os  # only the script entry point needs filesystem helpers

    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size = 64
    num_epochs = 10
    sample_dir = "./test"

    model = TestModel().to(device)
    dataset = DS()
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    criterion = torch.nn.MSELoss()
    optim = torch.optim.Adam(model.parameters(), lr=1e-4)
    model.train()

    # Image.save does not create parent directories.
    os.makedirs(sample_dir, exist_ok=True)

    def log(tag):
        """Save the model's reconstruction of the test image as <tag>.png."""
        model.eval()
        try:
            with torch.no_grad():  # inference only: skip building a graph
                x = dataset.gettest().to(device).unsqueeze(0)
                out = model(x)
            # The model output is clamped to [-1, 1]; map to [0, 1] for PIL.
            image = to_pil_image(((out[0] + 1) / 2).cpu())
            image.save(os.path.join(sample_dir, f"{tag}.png"))
        finally:
            # Restore training mode even if loading/saving the sample fails.
            model.train()

    log("test")

    for epoch_idx in range(num_epochs):
        loss = None
        for step, batch in enumerate(tqdm.tqdm(dataloader)):
            batch = batch.to(device)
            optim.zero_grad()
            out = model(batch)
            loss = criterion(out, batch)
            loss.backward()
            optim.step()
            if step % 100 == 0:
                # Periodic cleanup kept from the original script;
                # empty_cache() is a no-op when CUDA was never initialised.
                gc.collect()
                torch.cuda.empty_cache()
        print("EPOCH", epoch_idx)
        if loss is not None:
            # .item() prints a plain float instead of a tensor repr.
            print("LAST LOSS", loss.item())
        log(epoch_idx)
        
        
if __name__ == "__main__":  # train only when run as a script, not on import
    main()