HienK64BKHN commited on
Commit
ec47fb5
1 Parent(s): 8262048

Upload 7 files

Browse files
Unet.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torchvision.transforms import functional as f
4
+
5
+ class UNet(torch.nn.Module):
6
+ def __init__(self, device, in_channels: int = 3, num_classes: int = 3) -> None:
7
+ super().__init__()
8
+ self.block_1 = nn.Sequential(
9
+ nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=3, stride=1, padding=1, device=device), #-> Channels = 64
10
+ nn.BatchNorm2d(64, device=device),
11
+ nn.ReLU(),
12
+ nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, device=device), # -> channels = 64
13
+ nn.BatchNorm2d(64, device=device),
14
+ nn.ReLU()
15
+ )
16
+
17
+ self.max_pool_2x2_1 = nn.MaxPool2d(kernel_size=2, stride=2)
18
+
19
+ self.block_2 = nn.Sequential(
20
+ nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, device=device), # -> channels = 128
21
+ nn.BatchNorm2d(128, device=device),
22
+ nn.ReLU(),
23
+ nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, device=device), # -> channels = 128
24
+ nn.BatchNorm2d(128, device=device),
25
+ nn.ReLU()
26
+ )
27
+
28
+ self.max_pool_2x2_2 = nn.MaxPool2d(kernel_size=2, stride=2)
29
+
30
+ self.block_3 = nn.Sequential(
31
+ nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, device=device), # -> channels = 256
32
+ nn.BatchNorm2d(256, device=device),
33
+ nn.ReLU(),
34
+ nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, device=device), # -> channels = 256
35
+ nn.BatchNorm2d(256, device=device),
36
+ nn.ReLU()
37
+ )
38
+
39
+ self.max_pool_2x2_3 = nn.MaxPool2d(kernel_size=2, stride=2)
40
+
41
+ self.block_4 = nn.Sequential(
42
+ nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1, device=device), # -> channels = 512
43
+ nn.BatchNorm2d(512, device=device),
44
+ nn.ReLU(),
45
+ nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1, device=device), # -> channels = 512
46
+ nn.BatchNorm2d(512, device=device),
47
+ nn.ReLU()
48
+ )
49
+
50
+ self.drop_out_1 = nn.Dropout(p=0.5)
51
+
52
+ self.max_pool_2x2_4 = nn.MaxPool2d(kernel_size=2, stride=2)
53
+
54
+ self.block_5 = nn.Sequential(
55
+ nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=1, device=device), # -> channels = 1024
56
+ nn.BatchNorm2d(1024, device=device),
57
+ nn.ReLU(),
58
+ nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, stride=1, padding=1, device=device), # -> channels = 1024
59
+ nn.BatchNorm2d(1024, device=device),
60
+ nn.ReLU()
61
+ )
62
+
63
+ self.drop_out_2 = nn.Dropout(p=0.5)
64
+
65
+ self.up_conv_2x2_1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2, padding=0, device=device) # -> channels = 512
66
+
67
+ #after up_sampled, the tensor will be concatenate with the output of the block_4 which is a 512-channels tensor
68
+ # so that the tensor to put in the block 6 will be a (512 + 512)-channels = 1024-channels tensor
69
+
70
+ self.block_6 = nn.Sequential(
71
+ nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, padding=1, stride=1, device=device), # -> channels = 512
72
+ nn.BatchNorm2d(512, device=device),
73
+ nn.ReLU(),
74
+ nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1, device=device), # -> channels = 512
75
+ nn.BatchNorm2d(512, device=device),
76
+ nn.ReLU()
77
+ )
78
+
79
+ self.drop_out_3 = nn.Dropout(p=0.5)
80
+
81
+ self.up_conv_2x2_2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2, padding=0, device=device) # -> channels = 256
82
+ #The same as up_conv_2x2_1
83
+
84
+ self.block_7 = nn.Sequential(
85
+ nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1, stride=1, device=device), # -> channels = 256
86
+ nn.BatchNorm2d(256, device=device),
87
+ nn.ReLU(),
88
+ nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, device=device), # -> channels = 256
89
+ nn.BatchNorm2d(256, device=device),
90
+ nn.ReLU()
91
+ )
92
+
93
+ self.up_conv_2x2_3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2, padding=0, device=device) # -> channels = 128
94
+ #The same as up_conv_2x2_1
95
+
96
+ self.block_8 = nn.Sequential(
97
+ nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, padding=1, stride=1, device=device), # -> channels = 128
98
+ nn.BatchNorm2d(128, device=device),
99
+ nn.ReLU(),
100
+ nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, device=device), # -> channels = 128
101
+ nn.BatchNorm2d(128, device=device),
102
+ nn.ReLU()
103
+ )
104
+
105
+ self.up_conv_2x2_4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2, padding=0, device=device) # -> channels = 64
106
+ #The same as up_conv_2x2_1
107
+
108
+ self.block_9 = nn.Sequential(
109
+ nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1, stride=1, device=device), # -> channels = 64
110
+ nn.BatchNorm2d(64, device=device),
111
+ nn.ReLU(),
112
+ nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, device=device), # -> channels = 64
113
+ nn.BatchNorm2d(64, device=device),
114
+ nn.ReLU()
115
+ )
116
+
117
+ self.last_conv_1x1 = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1, stride=1, padding=0, device=device) # -> channels = num_classes (default = 3 for [background, borders, objects])
118
+
119
+ def forward(self, x):
120
+ block_1_result = self.block_1(x)
121
+ block_2_result = self.block_2(self.max_pool_2x2_1(block_1_result))
122
+ block_3_result = self.block_3(self.max_pool_2x2_2(block_2_result))
123
+ block_4_result = self.block_4(self.max_pool_2x2_3(block_3_result))
124
+ block_4_result = self.drop_out_1(block_4_result)
125
+ block_5_result = self.block_5(self.max_pool_2x2_4(block_4_result))
126
+ block_5_result = self.drop_out_2(block_5_result)
127
+
128
+ up_conv_1_result = self.up_conv_2x2_1(block_5_result)
129
+ block_4_result = f.center_crop(block_4_result, [up_conv_1_result.shape[2], up_conv_1_result.shape[3]])
130
+ concat_1_result = torch.cat([block_4_result, up_conv_1_result], axis=1)
131
+
132
+ block_6_result = self.block_6(concat_1_result)
133
+ block_6_result = self.drop_out_3(block_6_result)
134
+
135
+ up_conv_2_result = self.up_conv_2x2_2(block_6_result)
136
+ block_3_result = f.center_crop(block_3_result, [up_conv_2_result.shape[2], up_conv_2_result.shape[3]])
137
+ concat_2_result = torch.cat([block_3_result, up_conv_2_result], axis=1)
138
+
139
+ block_7_result = self.block_7(concat_2_result)
140
+
141
+ up_conv_3_result = self.up_conv_2x2_3(block_7_result)
142
+ block_2_result = f.center_crop(block_2_result, [up_conv_3_result.shape[2], up_conv_3_result.shape[3]])
143
+ concat_3_result = torch.cat([block_2_result, up_conv_3_result], axis=1)
144
+
145
+ block_8_result = self.block_8(concat_3_result)
146
+
147
+ up_conv_4_result = self.up_conv_2x2_4(block_8_result)
148
+ block_1_result = f.center_crop(block_1_result, [up_conv_4_result.shape[2], up_conv_4_result.shape[3]])
149
+ concat_4_result = torch.cat([block_1_result, up_conv_4_result], axis=1)
150
+
151
+ block_9_result = self.block_9(concat_4_result)
152
+
153
+ last_block_result = self.last_conv_1x1(block_9_result)
154
+
155
+ return last_block_result
156
+
app.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from Unet import UNet
4
+ import torchvision
5
+ from torchvision.transforms import functional as f
6
+ import os
7
+ from timeit import default_timer as timer
8
+
9
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
10
+
11
+ model = UNet(device=device, in_channels=3, num_classes=3)
12
+ model.load_state_dict(torch.load("./data/models/Unet_v1.pth"))
13
+
14
+ image_transforms = torchvision.transforms.Compose([
15
+ torchvision.transforms.Resize(size=(128, 128)),
16
+ torchvision.transforms.ToTensor()
17
+ ])
18
+
19
+ def predict(img):
20
+ start_time = timer()
21
+
22
+ img_transformed = image_transforms(img).to(device)
23
+
24
+ model.eval()
25
+ with torch.inference_mode():
26
+ y_logits = model(img_transformed.unsqueeze(dim=0)).squeeze(dim=0)
27
+ predicted_label = torch.argmax(y_logits, dim=0).to('cpu')
28
+
29
+ for i in range(3):
30
+ for j in range(128):
31
+ for z in range(128):
32
+ img_transformed[i][j][z] = predicted_label[j][z]
33
+
34
+ img_transformed = f.to_pil_image(img_transformed)
35
+
36
+ return img_transformed, round((timer() - start_time), 3)
37
+
38
+ title = "Animal Segmentation"
39
+ description = "An UNet* feature extractor computer vision model to segment animal in an image.\nModel works more precisely on an image that only contains just one animal."
40
+ article = "U-Net: Convolutional Networks for Biomedical Image Segmentation (https://arxiv.org/abs/1505.04597)"
41
+
42
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
43
+
44
+ demo = gr.Interface(fn=predict, # mapping function from input to output
45
+ inputs=gr.Image(type="pil"), # what are the inputs?
46
+ outputs=[gr.Pil(label="Segmentation"), # what are the outputs?
47
+ gr.Number(label="Prediction time (s)")], # our fn has two outputs, therefore we have two outputs
48
+ # Create examples list from "examples/" directory
49
+ examples=example_list,
50
+ title=title,
51
+ description=description,
52
+ article=article)
53
+
54
+ demo.launch()
data/models/Unet_v1.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:969e0f4e4c3e88c255842409ca8d40b856a49305cdc11d3bd86046f7c017a7ab
3
+ size 124265073
examples/Abyssinian_13.jpg ADDED
examples/Abyssinian_177.jpg ADDED
examples/american_bulldog_173.jpg ADDED
examples/american_pit_bull_terrier_5.jpg ADDED