junjuice0 commited on
Commit
1388c9c
1 Parent(s): d58cf13

Upload testvae.py

Browse files
Files changed (1) hide show
  1. testvae.py +99 -0
testvae.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset, DataLoader
3
+ import torchvision
4
+ from torchvision import transforms
5
+ from torchvision.transforms.functional import to_pil_image, to_tensor
6
+ import glob
7
+ from PIL import Image
8
+ import tqdm
9
+ import gc
10
+
11
+ class TestModel(torch.nn.Module):
12
+ def __init__(self):
13
+ super().__init__()
14
+ self.conv1 = torch.nn.Conv2d(3, 16, 5, 1, 2, bias=False)
15
+ self.conv2 = torch.nn.Conv2d(16, 16, 3, 1, 1, bias=False)
16
+ self.conv3 = torch.nn.Conv2d(16, 3, 3, 1, 1, bias=True)
17
+ self.bn1 = torch.nn.BatchNorm2d(16)
18
+ self.bn2 = torch.nn.BatchNorm2d(16)
19
+
20
+ def forward(self, x):
21
+ x = self.conv1(x)
22
+ x = self.bn1(x)
23
+ x = self.conv2(x)
24
+ x = self.bn2(x)
25
+ x = self.conv3(x)
26
+ x = torch.clamp(x, -1, 1)
27
+ return x
28
+
29
+ class DS(Dataset):
30
+ def __init__(self):
31
+ super().__init__()
32
+ self.g = glob.glob("./15k/*")
33
+ self.trans = transforms.Compose([
34
+ transforms.RandomCrop((256, 256)),
35
+ transforms.ToTensor()
36
+ ])
37
+
38
+ def __len__(self):
39
+ return len(self.g)
40
+
41
+ def __getitem__(self, idx):
42
+ x = self.g[idx]
43
+ x = Image.open(x)
44
+ x = x.convert("RGB")
45
+ x = self.trans(x)
46
+ x = x / 127.5 - 1
47
+ return x
48
+
49
+ def gettest(self):
50
+ x = self.g[0]
51
+ x = Image.open(x)
52
+ x = x.convert("RGB")
53
+ x = to_tensor(x)
54
+ x = x / 127.5 - 1
55
+ return x
56
+
57
+ def main():
58
+ device = "cuda" if torch.cuda.is_available() else "cpu"
59
+ bacth_size = 64
60
+ epoch = 10
61
+
62
+ model = TestModel()
63
+ dataset = DS()
64
+ datalaoder = DataLoader(dataset, batch_size=bacth_size, shuffle=True)
65
+ criterion = torch.nn.MSELoss()
66
+ kl = torch.nn.KLDivLoss(size_average=False)
67
+ optim = torch.optim.Adam(model.parameters(recurse=True), lr=1e-4)
68
+ criterion = criterion.to(device)
69
+ model = model.to(device)
70
+ model.train()
71
+
72
+ def log(l):
73
+ model.eval()
74
+ x = dataset.gettest().to(device)
75
+ x = x.unsqueeze(0)
76
+ out = model(x)
77
+ to_pil_image((out[0] + 1)/2).save("./test/" + str(l) + ".png")
78
+ model.train()
79
+
80
+ log("test")
81
+
82
+ for i in range(epoch):
83
+ for j, k in enumerate(tqdm.tqdm(datalaoder)):
84
+ k = k.to(device)
85
+ model.zero_grad()
86
+ out = model(k)
87
+ loss = criterion(out, k)# + kl(((out + 1)/2).log(), (k + 1)/2)
88
+ loss.backward()
89
+ optim.step()
90
+ if j % 100 == 0:
91
+ gc.collect()
92
+ torch.cuda.empty_cache()
93
+ print("EPOCH", i)
94
+ print("LAST LOSS", loss)
95
+ log(i)
96
+
97
+
98
+ if __name__ == "__main__":
99
+ main()