dawn17 committed on
Commit
bcc0f94
1 Parent(s): e0678c3

Upload 35 files

app.py ADDED
@@ -0,0 +1,29 @@
+ import os
+
+ import gradio as gr
+
+ from src.run.unet.inference import ResUnetInfer
+
+
+ infer = ResUnetInfer(
+     model_path="./checkpoint/resunet/decoder.pt",
+     config_path="./src/models/unet/config/resnet_config.yml",
+ )
+
+ demo = gr.Interface(
+     fn=infer.infer,
+     inputs=[
+         gr.Image(
+             shape=(224, 224),
+             label="Input Image",
+             value="./sample/bird_plane.jpeg",
+         )
+     ],
+     outputs=[
+         gr.Image(),
+     ],
+     examples=[[os.path.join("./sample/", f)] for f in os.listdir("./sample/")],
+ )
+
+
+ demo.launch()
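For reference, `infer.infer` receives the Gradio image as a NumPy RGB array (HWC, uint8) and returns the red-mask overlay as another array, so the same entry point can be exercised without the UI. A minimal sketch, assuming the requirements are installed and the paths above exist:

```python
# Hypothetical local check of the Space's entry point, bypassing Gradio.
from src.run.unet.inference import ResUnetInfer

infer = ResUnetInfer(
    model_path="./checkpoint/resunet/decoder.pt",
    config_path="./src/models/unet/config/resnet_config.yml",
)
image = ResUnetInfer.load_image_as_array("./sample/dog.jpeg")  # HWC uint8 RGB
overlay = infer.infer(image)                                   # HWC uint8 overlay
print(overlay.shape, overlay.dtype)
```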
checkpoint/resunet/decoder.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:df2780f1ec58f0a9653c951b341102097ef20a8bbd9cd9aba2ea8e789876b9ae
+ size 189285667
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ torch
+ torchinfo
+ easydict
+ gradio
+ torchvision
+ numpy
+ grad-cam
+ Pillow
+ albumentations
+ tqdm
+ opencv-python
+ matplotlib
sample/bird_plane.jpeg ADDED
sample/dog.jpeg ADDED
sample/group.webp ADDED
sample/horse_person_cycle.jpeg ADDED
sample/mask.jpeg ADDED
sample/people.jpeg ADDED
sample/titanic.jpeg ADDED
src/datasets/__init__.py ADDED
File without changes
src/datasets/coco/README.md ADDED
@@ -0,0 +1,6 @@
+ # Coco Dataset Sample
+
+ ![Image1](samples/people.png)
+ ![Image2](samples/giraffe.png)
+ ![Image3](samples/airplane.png)
+ ![Image4](samples/zebra.png)
src/datasets/coco/dataset.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/datasets/coco/dataset.py ADDED
@@ -0,0 +1,137 @@
+ import os.path
+ from typing import Any, Callable, List, Optional, Tuple
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ from PIL import Image
+ from torchvision.datasets import VisionDataset
+
+
+ class CocoDetection(VisionDataset):
+     def __init__(
+         self,
+         root: str,
+         annFile: str,
+         class_names: Optional[List] = None,
+         transform: Optional[Callable] = None,
+         target_transform: Optional[Callable] = None,
+         transforms: Optional[Callable] = None,
+     ) -> None:
+         super().__init__(root, transforms, transform, target_transform)
+         from pycocotools.coco import COCO
+
+         self.coco = COCO(annFile)
+
+         if class_names is not None:
+             cat_ids = self._get_category_ids_from_name(category_names=class_names)
+             self.ids = list(
+                 sorted((self._get_img_ids_for_category_ids(category_ids=cat_ids)))
+             )
+
+         else:
+             cat_ids = self.coco.getCatIds()
+             self.ids = list(sorted(self.coco.imgs.keys()))
+
+         self.cat2idx = {cat_id: idx + 1 for idx, cat_id in enumerate(cat_ids)}
+         self.cat2idx[0] = 0
+
+     def _load_image(self, id: int) -> Image.Image:
+         path = self.coco.loadImgs(id)[0]["file_name"]
+         return Image.open(os.path.join(self.root, path)).convert("RGB")
+
+     def _load_target(self, id: int) -> List[Any]:
+         return self.coco.loadAnns(self.coco.getAnnIds(id))
+
+     def __getitem__(self, index: int) -> Tuple[Any, Any]:
+         id = self.ids[index]
+         image = self._load_image(id)
+         mask = self._load_target(id)
+
+         mask = self._get_mask_in_channels(image, mask)
+
+         if self.transform is not None:
+             image = self.transform(image=np.array(image))["image"]
+
+         if self.target_transform is not None:
+             mask = self.target_transform(image=mask)["image"]
+
+         return image, (mask != 0).int()
+
+     def __len__(self) -> int:
+         return len(self.ids)
+
+     def _get_all_classes(self):
+         catIDs = self.coco.getCatIds()
+         return self.coco.loadCats(catIDs)
+
+     def _get_category_info_from_ids(self, ids: list):
+         all_cat = self._get_all_classes()
+         return [category for category in all_cat if category["id"] in ids]
+
+     def _get_category_ids_from_name(self, category_names: list):
+         return self.coco.getCatIds(catNms=category_names)
+
+     def _get_img_ids_for_category_ids(self, category_ids: list):
+         img_ids = []
+
+         for catIds in category_ids:
+             img_ids.extend(self.coco.getImgIds(catIds=catIds))
+
+         return img_ids
+
+     def _get_img_ids_for_category_names(self, category_names: list):
+         img_ids = []
+         category_ids = self._get_category_ids_from_name(category_names=category_names)
+
+         for catIds in category_ids:
+             img_ids.extend(self.coco.getImgIds(catIds=catIds))
+
+         return img_ids
+
+     def _get_all_category_ids_in_img_id(self, img_id: int) -> List:
+         target = self._load_target(img_id)
+         return list({annotation["category_id"] for annotation in target})
+
+     def _get_mask_aggregated(self, image: Image, annotations: List) -> np.array:
+         w, h = image.size
+         mask = np.zeros((h, w))
+
+         for annotation in annotations:
+             category_id = annotation["category_id"]
+
+             if category_id in self.cat2idx:
+                 pixel_value = self.cat2idx[category_id]
+                 mask = np.maximum(self.coco.annToMask(annotation) * pixel_value, mask)
+
+         return mask
+
+     def _get_mask_in_channels(self, image: Image, annotations: List) -> np.array:
+         w, h = image.size
+         mask = np.zeros((len(self.cat2idx), h, w))
+
+         for annotation in annotations:
+             category_id = annotation["category_id"]
+
+             if category_id in self.cat2idx:
+                 pixel_value = self.cat2idx[category_id]
+                 mask[pixel_value] = np.maximum(
+                     self.coco.annToMask(annotation), mask[pixel_value]
+                 )
+
+         # [h, w, channels]
+         mask = np.transpose(mask, (1, 2, 0))
+         return mask
+
+     def _plot_image_and_mask(self, index):
+         image, mask = self.__getitem__(index)
+
+         # Create a figure with two subplots side by side
+         fig, axs = plt.subplots(1, 2, figsize=(7, 3))
+
+         axs[0].imshow(image.permute(1, 2, 0))
+         axs[0].set_title("Image")
+
+         axs[1].imshow(mask.sum(0, keepdim=True).permute(1, 2, 0))
+         axs[1].set_title("Mask")
+
+         plt.show()
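A minimal usage sketch for this dataset class; the COCO paths are hypothetical and the albumentations transforms mirror the ones used elsewhere in this repo:

```python
# Hypothetical example: person-only masks from a local COCO 2017 copy.
import albumentations as A
from albumentations.pytorch import ToTensorV2
from src.datasets.coco.dataset import CocoDetection

transform = A.Compose([A.Resize(224, 224), A.Normalize(), ToTensorV2()])
target_transform = A.Compose([A.Resize(224, 224), ToTensorV2()])

dataset = CocoDetection(
    root="./data/coco/val2017",                                  # assumed path
    annFile="./data/coco/annotations/instances_val2017.json",    # assumed path
    class_names=["person"],
    transform=transform,
    target_transform=target_transform,
)
image, mask = dataset[0]  # image: [3, 224, 224]; mask: one binarized channel per selected class
```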
src/datasets/coco/samples/airplane.png ADDED
src/datasets/coco/samples/giraffe.png ADDED
src/datasets/coco/samples/people.png ADDED
src/datasets/coco/samples/zebra.png ADDED
src/models/unet/__init__.py ADDED
File without changes
src/models/unet/config/carvana_config.yml ADDED
@@ -0,0 +1,81 @@
+ # Input (1, 512, 512)
+ # Output (64, 512, 512)
+ decoder_config:
+   block5: # (1024, 32, 32)
+     in_channels: 1024
+     kernel_size: 3
+     out_channels: 1024
+     padding:
+       - 1
+       - 1
+     stride: 1 # (1024, 32, 32)
+   block4: # (1024, 32, 32)
+     in_channels: 1024
+     kernel_size: 2
+     out_channels: 512
+     padding:
+       - 0
+       - 1
+     stride: 2 # (512, 64, 64)
+   block3: # (512, 64, 64)
+     in_channels: 512
+     kernel_size: 2
+     out_channels: 256
+     padding:
+       - 0
+       - 1
+     stride: 2 # (256, 128, 128)
+   block2: # (256, 128, 128)
+     in_channels: 256
+     kernel_size: 2
+     out_channels: 128
+     padding:
+       - 0
+       - 1
+     stride: 2 # (128, 256, 256)
+   block1: # (128, 256, 256)
+     in_channels: 128
+     kernel_size: 2
+     out_channels: 64
+     padding:
+       - 0
+       - 1
+     stride: 2 # (64, 512, 512)
+ encoder_config:
+   block1: # (1, 512, 512)
+     all_padding: true
+     in_channels: 1
+     maxpool: true
+     n_layers: 2
+     out_channels: 64 # (64, 256, 256)
+   block2: # (64, 256, 256)
+     all_padding: true
+     in_channels: 64
+     maxpool: true
+     n_layers: 2
+     out_channels: 128 # (128, 128, 128)
+   block3: # (128, 128, 128)
+     all_padding: true
+     in_channels: 128
+     maxpool: true
+     n_layers: 2
+     out_channels: 256 # (256, 64, 64)
+   block4: # (256, 64, 64)
+     all_padding: true
+     in_channels: 256
+     maxpool: true
+     n_layers: 2
+     out_channels: 512 # (512, 32, 32)
+   block5: # (512, 32, 32)
+     all_padding: true
+     in_channels: 512
+     maxpool: false
+     n_layers: 2
+     out_channels: 512 # (512, 32, 32)
+   block6: # (512, 32, 32)
+     all_padding: true
+     in_channels: 512
+     maxpool: false
+     n_layers: 2
+     out_channels: 1024 # (1024, 32, 32)
+ nclasses: 2
src/models/unet/config/paper_config.yml ADDED
@@ -0,0 +1,60 @@
+ # Original UNet Paper Configuration
+ # Input shape [1, 572, 572]
+ # Output shape [64, 388, 388]
+ decoder_config:
+   block4: # [1024, 28, 28]
+     in_channels: 1024
+     kernel_size: 2
+     out_channels: 512
+     padding: [0, 0]
+     stride: 2 # [512, 52, 52]
+   block3: # [512, 52, 52]
+     in_channels: 512
+     kernel_size: 2
+     out_channels: 256
+     padding: [0, 0]
+     stride: 2 # [256, 100, 100]
+   block2: # [256, 100, 100]
+     in_channels: 256
+     kernel_size: 2
+     out_channels: 128
+     padding: [0, 0]
+     stride: 2 # [128, 196, 196]
+   block1: # [128, 196, 196]
+     in_channels: 128
+     kernel_size: 2
+     out_channels: 64
+     padding: [0, 0]
+     stride: 2 # [64, 388, 388]
+ encoder_config:
+   block1: # [1, 572, 572]
+     all_padding: false
+     in_channels: 1
+     maxpool: true
+     n_layers: 2
+     out_channels: 64 # [64, 568/2, 568/2] = [64, 284, 284]
+   block2: # [64, 568/2, 568/2] = [64, 284, 284]
+     all_padding: false
+     in_channels: 64
+     maxpool: true
+     n_layers: 2
+     out_channels: 128 # [128, 280/2, 280/2] = [128, 140, 140]
+   block3: # [128, 280/2, 280/2] = [128, 140, 140]
+     all_padding: false
+     in_channels: 128
+     maxpool: true
+     n_layers: 2
+     out_channels: 256 # [256, 136/2, 136/2] = [256, 68, 68]
+   block4: # [256, 136/2, 136/2] = [256, 68, 68]
+     all_padding: false
+     in_channels: 256
+     maxpool: true
+     n_layers: 2
+     out_channels: 512 # [512, 64/2, 64/2] = [512, 32, 32]
+   block5: # [512, 64/2, 64/2] = [512, 32, 32]
+     all_padding: false
+     in_channels: 512
+     maxpool: false
+     n_layers: 2
+     out_channels: 1024 # [1024, 28, 28]
+ nclasses: 2
src/models/unet/config/resnet_config.yml ADDED
@@ -0,0 +1,32 @@
+ # UNet with a pretrained ResNet50 encoder (see src/models/unet/resunet.py)
+ # Input shape [3, 224, 224]
+ # Output: a single-channel mask upsampled back to the input size
+ decoder_config:
+   block4: # [2048, 16, 16]
+     in_channels: 2048
+     kernel_size: 2
+     out_channels: 1024
+     padding: [0, 0]
+     stride: 2 # [1024, 28, 28]
+   block3: # [1024, 28, 28]
+     in_channels: 1024
+     kernel_size: 2
+     out_channels: 512
+     padding: [0, 0]
+     stride: 2 # [512, 52, 52]
+   block2: # [512, 52, 52]
+     in_channels: 512
+     kernel_size: 2
+     out_channels: 128
+     padding: [0, 0]
+     stride: 2 # [256, 100, 100]
+   block1: # [256, 100, 100]
+     in_channels: 128
+     kernel_size: 2
+     out_channels: 64
+     padding: [0, 0]
+     stride: 2 # [128, 196, 196]
+ nclasses: 1
+ input_size: [224, 224]
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
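These configs are plain YAML loaded with `yaml.safe_load` and wrapped in an `EasyDict` for attribute access, exactly as the inference and model entry points further down do. A small sketch of that pattern:

```python
# Load a config the same way inference.py / resunet.py do.
import yaml
from easydict import EasyDict

with open("./src/models/unet/config/resnet_config.yml", "r") as file:
    config = EasyDict(yaml.safe_load(file))

print(config.nclasses, config.input_size)        # 1 [224, 224]
print(config.decoder_config.block4.in_channels)  # 2048
```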
src/models/unet/decoder/__init__.py ADDED
@@ -0,0 +1 @@
+ from .decoder import Decoder as CustomDecoder
src/models/unet/decoder/decoder.py ADDED
@@ -0,0 +1,76 @@
+ import torch
+ import torch.nn as nn
+
+
+ class DecoderLayer(nn.Module):
+     def __init__(
+         self, in_channels, out_channels, kernel_size=2, stride=2, padding=[0, 0]
+     ):
+         super(DecoderLayer, self).__init__()
+         self.up_conv = nn.ConvTranspose2d(
+             in_channels=in_channels,
+             out_channels=in_channels // 2,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding[0],
+         )
+
+         self.bn1 = nn.BatchNorm2d(in_channels)
+
+         self.conv = nn.Sequential(
+             *[
+                 self._conv_relu_layer(
+                     in_channels=in_channels if i == 0 else out_channels,
+                     out_channels=out_channels,
+                     padding=padding[1],
+                 )
+                 for i in range(2)
+             ]
+         )
+
+     def _conv_relu_layer(self, in_channels, out_channels, padding=0):
+         return nn.Sequential(
+             nn.Conv2d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=3,
+                 padding=padding,
+             ),
+             nn.ReLU(),
+             nn.BatchNorm2d(out_channels),
+         )
+
+     @staticmethod
+     def crop_cat(x, encoder_output):
+         delta = (encoder_output.shape[-1] - x.shape[-1]) // 2
+         encoder_output = encoder_output[
+             :, :, delta : delta + x.shape[-1], delta : delta + x.shape[-1]
+         ]
+         return torch.cat((encoder_output, x), dim=1)
+
+     def forward(self, x, encoder_output):
+         x = self.crop_cat(self.up_conv(x), encoder_output)
+         x = self.bn1(x)
+         return self.conv(x)
+
+
+ class Decoder(nn.Module):
+     def __init__(self, config):
+         super(Decoder, self).__init__()
+         self.decoder = nn.ModuleDict(
+             {
+                 name: DecoderLayer(
+                     in_channels=block["in_channels"],
+                     out_channels=block["out_channels"],
+                     kernel_size=block["kernel_size"],
+                     stride=block["stride"],
+                     padding=block["padding"],
+                 )
+                 for name, block in config.items()
+             }
+         )
+
+     def forward(self, x, encoder_output):
+         for name, block in self.decoder.items():
+             x = block(x, encoder_output[name])
+         return x
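A quick shape check for a single `DecoderLayer`, matching the bottleneck-to-block4 step of the paper configuration (shapes taken from the example notebook in this upload); a sketch, not part of the module:

```python
# One decoder step: upsample, crop-and-concat the skip, then two 3x3 convs.
import torch
from src.models.unet.decoder.decoder import DecoderLayer

layer = DecoderLayer(in_channels=1024, out_channels=512)
x = torch.rand(1, 1024, 28, 28)    # bottleneck features
skip = torch.rand(1, 512, 64, 64)  # matching encoder output
out = layer(x, skip)
print(out.shape)                   # torch.Size([1, 512, 52, 52])
```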
src/models/unet/encoder/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .encoder import Encoder as CustomEncoder
+ from .resnet import Encoder as ResnetEncoder
src/models/unet/encoder/encoder.py ADDED
@@ -0,0 +1,80 @@
+ import torch.nn as nn
+
+
+ """
+ downsampling blocks
+ (first half of the 'U' in UNet)
+ [ENCODER]
+ """
+
+
+ class EncoderLayer(nn.Module):
+     def __init__(
+         self,
+         in_channels=1,
+         out_channels=64,
+         n_layers=2,
+         all_padding=False,
+         maxpool=True,
+     ):
+         super(EncoderLayer, self).__init__()
+
+         f_in_channel = lambda layer: in_channels if layer == 0 else out_channels
+         f_padding = lambda layer: 1 if layer >= 2 or all_padding else 0
+
+         self.layer = nn.Sequential(
+             *[
+                 self._conv_relu_layer(
+                     in_channels=f_in_channel(i),
+                     out_channels=out_channels,
+                     padding=f_padding(i),
+                 )
+                 for i in range(n_layers)
+             ]
+         )
+         self.maxpool = maxpool
+
+     def _conv_relu_layer(self, in_channels, out_channels, padding=0):
+         return nn.Sequential(
+             nn.Conv2d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=3,
+                 padding=padding,
+             ),
+             nn.ReLU(),
+             nn.BatchNorm2d(out_channels),
+         )
+
+     def forward(self, x):
+         return self.layer(x)
+
+
+ class Encoder(nn.Module):
+     def __init__(self, config):
+         super(Encoder, self).__init__()
+         self.encoder = nn.ModuleDict(
+             {
+                 name: EncoderLayer(
+                     in_channels=block["in_channels"],
+                     out_channels=block["out_channels"],
+                     n_layers=block["n_layers"],
+                     all_padding=block["all_padding"],
+                     maxpool=block["maxpool"],
+                 )
+                 for name, block in config.items()
+             }
+         )
+         self.maxpool = nn.MaxPool2d(2)
+
+     def forward(self, x):
+         output = dict()
+
+         for i, (block_name, block) in enumerate(self.encoder.items()):
+             x = block(x)
+             output[block_name] = x
+
+             if block.maxpool:
+                 x = self.maxpool(x)
+
+         return x, output
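The encoder returns both the bottleneck tensor and a dict of per-block activations for the skip connections. A sketch that mirrors the summary in the example notebook (paper config, 1x572x572 input):

```python
# Bottleneck plus per-block skip activations from the custom encoder.
import torch
import yaml
from easydict import EasyDict
from src.models.unet.encoder.encoder import Encoder

with open("./src/models/unet/config/paper_config.yml", "r") as file:
    config = EasyDict(yaml.safe_load(file))

encoder = Encoder(config.encoder_config)
x, skips = encoder(torch.rand(1, 1, 572, 572))
print(x.shape)                                            # torch.Size([1, 1024, 28, 28])
print({name: tuple(feat.shape[1:]) for name, feat in skips.items()})
```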
src/models/unet/encoder/resnet.py ADDED
@@ -0,0 +1,30 @@
+ from torchvision.models import resnet50, ResNet50_Weights
+ import torch.nn as nn
+
+
+ class Encoder(nn.Module):
+     def __init__(self):
+         super(Encoder, self).__init__()
+         resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
+
+         for param in resnet.parameters():
+             param.requires_grad_(False)
+
+         self.stages = nn.ModuleDict(
+             {
+                 "block1": nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu),
+                 "block2": nn.Sequential(resnet.maxpool, resnet.layer1),
+                 "block3": resnet.layer2,
+                 "block4": resnet.layer3,
+                 "block5": resnet.layer4,
+             }
+         )
+
+     def forward(self, x):
+         stages = {}
+
+         for name, stage in self.stages.items():
+             x = stage(x)
+             stages[name] = x
+
+         return x, stages
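For reference, the frozen ResNet50 backbone's stage shapes at the 224x224 input size used by `resnet_config.yml`; a small sketch (running it downloads the ImageNet weights):

```python
# Stage outputs of the frozen ResNet50 encoder for a 224x224 RGB input.
import torch
from src.models.unet.encoder.resnet import Encoder

encoder = Encoder()
with torch.no_grad():
    x, stages = encoder(torch.rand(1, 3, 224, 224))

for name, feat in stages.items():
    print(name, tuple(feat.shape[1:]))
# block1 (64, 112, 112), block2 (256, 56, 56), block3 (512, 28, 28),
# block4 (1024, 14, 14), block5 (2048, 7, 7)
```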
src/models/unet/example/model_sample.ipynb ADDED
@@ -0,0 +1,532 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "310eb987-37b7-4620-b533-089644fbb440",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import torch\n",
11
+ "import torch.functional as F\n",
12
+ "import torch.nn as nn\n",
13
+ "import yaml\n",
14
+ "from easydict import EasyDict\n",
15
+ "from torchinfo import summary"
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "code",
20
+ "execution_count": 2,
21
+ "id": "f8cff897-df8f-4e6d-893b-321805699e1b",
22
+ "metadata": {},
23
+ "outputs": [],
24
+ "source": [
25
+ "config_path = \"./config/paper_config.yml\"\n",
26
+ "\n",
27
+ "with open(config_path, \"r\") as file:\n",
28
+ " yaml_data = yaml.safe_load(file)\n",
29
+ "\n",
30
+ "config = EasyDict(yaml_data)"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "markdown",
35
+ "id": "ca66846e-d2b4-4dd2-83eb-eee746c26c74",
36
+ "metadata": {},
37
+ "source": [
38
+ "# Encoder "
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": 3,
44
+ "id": "975a6f86-68ff-4fda-b7d8-acf453addade",
45
+ "metadata": {},
46
+ "outputs": [
47
+ {
48
+ "data": {
49
+ "text/plain": [
50
+ "==========================================================================================\n",
51
+ "Layer (type:depth-idx) Output Shape Param #\n",
52
+ "==========================================================================================\n",
53
+ "EncoderLayer [64, 568, 568] --\n",
54
+ "├─Sequential: 1-1 [64, 568, 568] --\n",
55
+ "│ └─Sequential: 2-1 [64, 570, 570] --\n",
56
+ "│ │ └─Conv2d: 3-1 [64, 570, 570] 640\n",
57
+ "│ │ └─ReLU: 3-2 [64, 570, 570] --\n",
58
+ "│ └─Sequential: 2-2 [64, 568, 568] --\n",
59
+ "│ │ └─Conv2d: 3-3 [64, 568, 568] 36,928\n",
60
+ "│ │ └─ReLU: 3-4 [64, 568, 568] --\n",
61
+ "==========================================================================================\n",
62
+ "Total params: 37,568\n",
63
+ "Trainable params: 37,568\n",
64
+ "Non-trainable params: 0\n",
65
+ "Total mult-adds (G): 1.37\n",
66
+ "==========================================================================================\n",
67
+ "Input size (MB): 1.31\n",
68
+ "Forward/backward pass size (MB): 331.53\n",
69
+ "Params size (MB): 0.15\n",
70
+ "Estimated Total Size (MB): 332.99\n",
71
+ "=========================================================================================="
72
+ ]
73
+ },
74
+ "execution_count": 3,
75
+ "metadata": {},
76
+ "output_type": "execute_result"
77
+ }
78
+ ],
79
+ "source": [
80
+ "\"\"\"\n",
81
+ "downsampling blocks \n",
82
+ "(first half of the 'U' in UNet) \n",
83
+ "[ENCODER]\n",
84
+ "\"\"\"\n",
85
+ "\n",
86
+ "\n",
87
+ "class EncoderLayer(nn.Module):\n",
88
+ " def __init__(\n",
89
+ " self,\n",
90
+ " in_channels=1,\n",
91
+ " out_channels=64,\n",
92
+ " n_layers=2,\n",
93
+ " all_padding=False,\n",
94
+ " maxpool=True,\n",
95
+ " ):\n",
96
+ " super(EncoderLayer, self).__init__()\n",
97
+ "\n",
98
+ " f_in_channel = lambda layer: in_channels if layer == 0 else out_channels\n",
99
+ " f_padding = lambda layer: 1 if layer >= 2 or all_padding else 0\n",
100
+ "\n",
101
+ " self.layer = nn.Sequential(\n",
102
+ " *[\n",
103
+ " self._conv_relu_layer(\n",
104
+ " in_channels=f_in_channel(i),\n",
105
+ " out_channels=out_channels,\n",
106
+ " padding=f_padding(i),\n",
107
+ " )\n",
108
+ " for i in range(n_layers)\n",
109
+ " ]\n",
110
+ " )\n",
111
+ " self.maxpool = maxpool\n",
112
+ "\n",
113
+ " def _conv_relu_layer(self, in_channels, out_channels, padding=0):\n",
114
+ " return nn.Sequential(\n",
115
+ " nn.Conv2d(\n",
116
+ " in_channels=in_channels,\n",
117
+ " out_channels=out_channels,\n",
118
+ " kernel_size=3,\n",
119
+ " padding=padding,\n",
120
+ " ),\n",
121
+ " nn.ReLU(),\n",
122
+ " )\n",
123
+ "\n",
124
+ " def forward(self, x):\n",
125
+ " return self.layer(x)\n",
126
+ "\n",
127
+ "\n",
128
+ "summary(\n",
129
+ " EncoderLayer(in_channels=1, out_channels=64, n_layers=2, all_padding=False).cuda(),\n",
130
+ " input_size=(1, 572, 572),\n",
131
+ ")"
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "code",
136
+ "execution_count": 4,
137
+ "id": "4eb7eedd-6530-44e2-9486-fbd8f39fd0ad",
138
+ "metadata": {},
139
+ "outputs": [
140
+ {
141
+ "data": {
142
+ "text/plain": [
143
+ "==========================================================================================\n",
144
+ "Layer (type:depth-idx) Output Shape Param #\n",
145
+ "==========================================================================================\n",
146
+ "Encoder [1024, 28, 28] --\n",
147
+ "├─ModuleDict: 1-9 -- (recursive)\n",
148
+ "│ └─EncoderLayer: 2-1 [64, 568, 568] --\n",
149
+ "│ │ └─Sequential: 3-1 [64, 568, 568] 37,568\n",
150
+ "├─MaxPool2d: 1-2 [64, 284, 284] --\n",
151
+ "├─ModuleDict: 1-9 -- (recursive)\n",
152
+ "│ └─EncoderLayer: 2-2 [128, 280, 280] --\n",
153
+ "│ │ └─Sequential: 3-2 [128, 280, 280] 221,440\n",
154
+ "├─MaxPool2d: 1-4 [128, 140, 140] --\n",
155
+ "├─ModuleDict: 1-9 -- (recursive)\n",
156
+ "│ └─EncoderLayer: 2-3 [256, 136, 136] --\n",
157
+ "│ │ └─Sequential: 3-3 [256, 136, 136] 885,248\n",
158
+ "├─MaxPool2d: 1-6 [256, 68, 68] --\n",
159
+ "├─ModuleDict: 1-9 -- (recursive)\n",
160
+ "│ └─EncoderLayer: 2-4 [512, 64, 64] --\n",
161
+ "│ │ └─Sequential: 3-4 [512, 64, 64] 3,539,968\n",
162
+ "├─MaxPool2d: 1-8 [512, 32, 32] --\n",
163
+ "├─ModuleDict: 1-9 -- (recursive)\n",
164
+ "│ └─EncoderLayer: 2-5 [512, 28, 28] --\n",
165
+ "│ │ └─Sequential: 3-5 [512, 28, 28] 4,719,616\n",
166
+ "│ └─EncoderLayer: 2-6 [1024, 28, 28] --\n",
167
+ "│ │ └─Sequential: 3-6 [1024, 28, 28] 14,157,824\n",
168
+ "==========================================================================================\n",
169
+ "Total params: 23,561,664\n",
170
+ "Trainable params: 23,561,664\n",
171
+ "Non-trainable params: 0\n",
172
+ "Total mult-adds (G): 633.51\n",
173
+ "==========================================================================================\n",
174
+ "Input size (MB): 1.31\n",
175
+ "Forward/backward pass size (MB): 624.49\n",
176
+ "Params size (MB): 94.25\n",
177
+ "Estimated Total Size (MB): 720.05\n",
178
+ "=========================================================================================="
179
+ ]
180
+ },
181
+ "execution_count": 4,
182
+ "metadata": {},
183
+ "output_type": "execute_result"
184
+ }
185
+ ],
186
+ "source": [
187
+ "class Encoder(nn.Module):\n",
188
+ " def __init__(self, config):\n",
189
+ " super(Encoder, self).__init__()\n",
190
+ " self.encoder = nn.ModuleDict(\n",
191
+ " {\n",
192
+ " name: EncoderLayer(\n",
193
+ " in_channels=block[\"in_channels\"],\n",
194
+ " out_channels=block[\"out_channels\"],\n",
195
+ " n_layers=block[\"n_layers\"],\n",
196
+ " all_padding=block[\"all_padding\"],\n",
197
+ " maxpool=block[\"maxpool\"],\n",
198
+ " )\n",
199
+ " for name, block in config.items()\n",
200
+ " }\n",
201
+ " )\n",
202
+ " self.maxpool = nn.MaxPool2d(2)\n",
203
+ "\n",
204
+ " def forward(self, x):\n",
205
+ " output = dict()\n",
206
+ "\n",
207
+ " for i, (block_name, block) in enumerate(self.encoder.items()):\n",
208
+ " x = block(x)\n",
209
+ " output[block_name] = x\n",
210
+ "\n",
211
+ " if block.maxpool:\n",
212
+ " x = self.maxpool(x)\n",
213
+ "\n",
214
+ " return x, output\n",
215
+ "\n",
216
+ "\n",
217
+ "summary(\n",
218
+ " Encoder(config.encoder_config).cuda(),\n",
219
+ " input_size=(1, 572, 572),\n",
220
+ ")"
221
+ ]
222
+ },
223
+ {
224
+ "cell_type": "markdown",
225
+ "id": "a7ad06cb-61a2-4a66-ba58-f29d402a81f2",
226
+ "metadata": {},
227
+ "source": [
228
+ "# Decoder"
229
+ ]
230
+ },
231
+ {
232
+ "cell_type": "code",
233
+ "execution_count": 5,
234
+ "id": "735322d0-0dc3-4137-b906-ac7e54c43a79",
235
+ "metadata": {},
236
+ "outputs": [
237
+ {
238
+ "data": {
239
+ "text/plain": [
240
+ "==========================================================================================\n",
241
+ "Layer (type:depth-idx) Output Shape Param #\n",
242
+ "==========================================================================================\n",
243
+ "DecoderLayer [1, 512, 52, 52] --\n",
244
+ "├─ConvTranspose2d: 1-1 [1, 512, 56, 56] 2,097,664\n",
245
+ "├─Sequential: 1-2 [1, 512, 52, 52] --\n",
246
+ "│ ���─Sequential: 2-1 [1, 512, 54, 54] --\n",
247
+ "│ │ └─Conv2d: 3-1 [1, 512, 54, 54] 4,719,104\n",
248
+ "│ │ └─ReLU: 3-2 [1, 512, 54, 54] --\n",
249
+ "│ └─Sequential: 2-2 [1, 512, 52, 52] --\n",
250
+ "│ │ └─Conv2d: 3-3 [1, 512, 52, 52] 2,359,808\n",
251
+ "│ │ └─ReLU: 3-4 [1, 512, 52, 52] --\n",
252
+ "==========================================================================================\n",
253
+ "Total params: 9,176,576\n",
254
+ "Trainable params: 9,176,576\n",
255
+ "Non-trainable params: 0\n",
256
+ "Total mult-adds (G): 26.72\n",
257
+ "==========================================================================================\n",
258
+ "Input size (MB): 11.60\n",
259
+ "Forward/backward pass size (MB): 35.86\n",
260
+ "Params size (MB): 36.71\n",
261
+ "Estimated Total Size (MB): 84.17\n",
262
+ "=========================================================================================="
263
+ ]
264
+ },
265
+ "execution_count": 5,
266
+ "metadata": {},
267
+ "output_type": "execute_result"
268
+ }
269
+ ],
270
+ "source": [
271
+ "class DecoderLayer(nn.Module):\n",
272
+ " def __init__(\n",
273
+ " self, in_channels, out_channels, kernel_size=2, stride=2, padding=[0, 0]\n",
274
+ " ):\n",
275
+ " super(DecoderLayer, self).__init__()\n",
276
+ " self.up_conv = nn.ConvTranspose2d(\n",
277
+ " in_channels=in_channels,\n",
278
+ " out_channels=in_channels // 2,\n",
279
+ " kernel_size=kernel_size,\n",
280
+ " stride=stride,\n",
281
+ " padding=padding[0],\n",
282
+ " )\n",
283
+ "\n",
284
+ " self.conv = nn.Sequential(\n",
285
+ " *[\n",
286
+ " self._conv_relu_layer(\n",
287
+ " in_channels=in_channels if i == 0 else out_channels,\n",
288
+ " out_channels=out_channels,\n",
289
+ " padding=padding[1],\n",
290
+ " )\n",
291
+ " for i in range(2)\n",
292
+ " ]\n",
293
+ " )\n",
294
+ "\n",
295
+ " def _conv_relu_layer(self, in_channels, out_channels, padding=0):\n",
296
+ " return nn.Sequential(\n",
297
+ " nn.Conv2d(\n",
298
+ " in_channels=in_channels,\n",
299
+ " out_channels=out_channels,\n",
300
+ " kernel_size=3,\n",
301
+ " padding=padding,\n",
302
+ " ),\n",
303
+ " nn.ReLU(),\n",
304
+ " )\n",
305
+ "\n",
306
+ " @staticmethod\n",
307
+ " def crop_cat(x, encoder_output):\n",
308
+ " delta = (encoder_output.shape[-1] - x.shape[-1]) // 2\n",
309
+ " encoder_output = encoder_output[\n",
310
+ " :, :, delta : delta + x.shape[-1], delta : delta + x.shape[-1]\n",
311
+ " ]\n",
312
+ " return torch.cat((encoder_output, x), dim=1)\n",
313
+ "\n",
314
+ " def forward(self, x, encoder_output):\n",
315
+ " x = self.crop_cat(self.up_conv(x), encoder_output)\n",
316
+ " return self.conv(x)\n",
317
+ "\n",
318
+ "\n",
319
+ "# summary\n",
320
+ "input_data = [torch.rand((1, 1024, 28, 28)), torch.rand((1, 512, 64, 64))]\n",
321
+ "summary(\n",
322
+ " DecoderLayer(in_channels=1024, out_channels=512),\n",
323
+ " input_data=input_data,\n",
324
+ ")"
325
+ ]
326
+ },
327
+ {
328
+ "cell_type": "code",
329
+ "execution_count": 6,
330
+ "id": "3795e85d-ff83-457c-9c12-af6cc6e2830c",
331
+ "metadata": {},
332
+ "outputs": [
333
+ {
334
+ "data": {
335
+ "text/plain": [
336
+ "==========================================================================================\n",
337
+ "Layer (type:depth-idx) Output Shape Param #\n",
338
+ "==========================================================================================\n",
339
+ "Decoder [1, 64, 388, 388] --\n",
340
+ "├─ModuleDict: 1-1 -- --\n",
341
+ "│ └─DecoderLayer: 2-1 [1, 1024, 28, 28] --\n",
342
+ "│ │ └─ConvTranspose2d: 3-1 [1, 512, 28, 28] 4,719,104\n",
343
+ "│ │ └─Sequential: 3-2 [1, 1024, 28, 28] 18,876,416\n",
344
+ "│ └─DecoderLayer: 2-2 [1, 512, 52, 52] --\n",
345
+ "│ │ └─ConvTranspose2d: 3-3 [1, 512, 56, 56] 2,097,664\n",
346
+ "│ │ └─Sequential: 3-4 [1, 512, 52, 52] 7,078,912\n",
347
+ "│ └─DecoderLayer: 2-3 [1, 256, 100, 100] --\n",
348
+ "│ │ └─ConvTranspose2d: 3-5 [1, 256, 104, 104] 524,544\n",
349
+ "│ │ └─Sequential: 3-6 [1, 256, 100, 100] 1,769,984\n",
350
+ "│ └─DecoderLayer: 2-4 [1, 128, 196, 196] --\n",
351
+ "│ │ └─ConvTranspose2d: 3-7 [1, 128, 200, 200] 131,200\n",
352
+ "│ │ └─Sequential: 3-8 [1, 128, 196, 196] 442,624\n",
353
+ "│ └─DecoderLayer: 2-5 [1, 64, 388, 388] --\n",
354
+ "│ │ └─ConvTranspose2d: 3-9 [1, 64, 392, 392] 32,832\n",
355
+ "│ │ └─Sequential: 3-10 [1, 64, 388, 388] 110,720\n",
356
+ "==========================================================================================\n",
357
+ "Total params: 35,784,000\n",
358
+ "Trainable params: 35,784,000\n",
359
+ "Non-trainable params: 0\n",
360
+ "Total mult-adds (G): 113.38\n",
361
+ "==========================================================================================\n",
362
+ "Input size (MB): 158.09\n",
363
+ "Forward/backward pass size (MB): 469.93\n",
364
+ "Params size (MB): 143.14\n",
365
+ "Estimated Total Size (MB): 771.16\n",
366
+ "=========================================================================================="
367
+ ]
368
+ },
369
+ "execution_count": 6,
370
+ "metadata": {},
371
+ "output_type": "execute_result"
372
+ }
373
+ ],
374
+ "source": [
375
+ "class Decoder(nn.Module):\n",
376
+ " def __init__(self, config):\n",
377
+ " super(Decoder, self).__init__()\n",
378
+ " self.decoder = nn.ModuleDict(\n",
379
+ " {\n",
380
+ " name: DecoderLayer(\n",
381
+ " in_channels=block[\"in_channels\"],\n",
382
+ " out_channels=block[\"out_channels\"],\n",
383
+ " kernel_size=block[\"kernel_size\"],\n",
384
+ " stride=block[\"stride\"],\n",
385
+ " padding=block[\"padding\"],\n",
386
+ " )\n",
387
+ " for name, block in config.items()\n",
388
+ " }\n",
389
+ " )\n",
390
+ "\n",
391
+ " def forward(self, x, encoder_output):\n",
392
+ " for name, block in self.decoder.items():\n",
393
+ " x = block(x, encoder_output[name])\n",
394
+ " return x\n",
395
+ "\n",
396
+ "\n",
397
+ "# summary\n",
398
+ "encoder_input = torch.rand((1, 1, 572, 572), device=\"cuda\")\n",
399
+ "x, encoder_output = Encoder(config.encoder_config).cuda()(encoder_input)\n",
400
+ "\n",
401
+ "input_data = [x, encoder_output]\n",
402
+ "summary(\n",
403
+ " Decoder(config.decoder_config).cuda(),\n",
404
+ " input_data=input_data,\n",
405
+ ")"
406
+ ]
407
+ },
408
+ {
409
+ "cell_type": "markdown",
410
+ "id": "6cd06e02-abd4-4537-8bce-5a15c4ad4f85",
411
+ "metadata": {},
412
+ "source": [
413
+ "# UNet"
414
+ ]
415
+ },
416
+ {
417
+ "cell_type": "code",
418
+ "execution_count": 7,
419
+ "id": "24fd0355-3603-4a55-b827-068eda70b78a",
420
+ "metadata": {},
421
+ "outputs": [
422
+ {
423
+ "data": {
424
+ "text/plain": [
425
+ "===============================================================================================\n",
426
+ "Layer (type:depth-idx) Output Shape Param #\n",
427
+ "===============================================================================================\n",
428
+ "UNet [1, 2, 388, 388] --\n",
429
+ "├─Encoder: 1-1 [1, 1024, 28, 28] --\n",
430
+ "│ └─ModuleDict: 2-9 -- (recursive)\n",
431
+ "│ │ └─EncoderLayer: 3-1 [1, 64, 568, 568] 37,568\n",
432
+ "│ └─MaxPool2d: 2-2 [1, 64, 284, 284] --\n",
433
+ "│ └─ModuleDict: 2-9 -- (recursive)\n",
434
+ "│ │ └─EncoderLayer: 3-2 [1, 128, 280, 280] 221,440\n",
435
+ "│ └─MaxPool2d: 2-4 [1, 128, 140, 140] --\n",
436
+ "│ └─ModuleDict: 2-9 -- (recursive)\n",
437
+ "│ │ └─EncoderLayer: 3-3 [1, 256, 136, 136] 885,248\n",
438
+ "│ └─MaxPool2d: 2-6 [1, 256, 68, 68] --\n",
439
+ "│ └─ModuleDict: 2-9 -- (recursive)\n",
440
+ "│ │ └─EncoderLayer: 3-4 [1, 512, 64, 64] 3,539,968\n",
441
+ "│ └─MaxPool2d: 2-8 [1, 512, 32, 32] --\n",
442
+ "│ └─ModuleDict: 2-9 -- (recursive)\n",
443
+ "│ │ └─EncoderLayer: 3-5 [1, 512, 28, 28] 4,719,616\n",
444
+ "│ │ └─EncoderLayer: 3-6 [1, 1024, 28, 28] 14,157,824\n",
445
+ "├─Decoder: 1-2 [1, 64, 388, 388] --\n",
446
+ "│ └─ModuleDict: 2-10 -- --\n",
447
+ "│ │ └─DecoderLayer: 3-7 [1, 1024, 28, 28] 23,595,520\n",
448
+ "│ │ └─DecoderLayer: 3-8 [1, 512, 52, 52] 9,176,576\n",
449
+ "│ │ └─DecoderLayer: 3-9 [1, 256, 100, 100] 2,294,528\n",
450
+ "│ │ └─DecoderLayer: 3-10 [1, 128, 196, 196] 573,824\n",
451
+ "│ │ └─DecoderLayer: 3-11 [1, 64, 388, 388] 143,552\n",
452
+ "├─Conv2d: 1-3 [1, 2, 388, 388] 130\n",
453
+ "===============================================================================================\n",
454
+ "Total params: 59,345,794\n",
455
+ "Trainable params: 59,345,794\n",
456
+ "Non-trainable params: 0\n",
457
+ "Total mult-adds (G): 189.38\n",
458
+ "===============================================================================================\n",
459
+ "Input size (MB): 1.31\n",
460
+ "Forward/backward pass size (MB): 1096.83\n",
461
+ "Params size (MB): 237.38\n",
462
+ "Estimated Total Size (MB): 1335.52\n",
463
+ "==============================================================================================="
464
+ ]
465
+ },
466
+ "execution_count": 7,
467
+ "metadata": {},
468
+ "output_type": "execute_result"
469
+ }
470
+ ],
471
+ "source": [
472
+ "class UNet(nn.Module):\n",
473
+ " def __init__(self, encoder_config, decoder_config, nclasses):\n",
474
+ " super(UNet, self).__init__()\n",
475
+ " self.encoder = Encoder(config=encoder_config)\n",
476
+ " self.decoder = Decoder(config=decoder_config)\n",
477
+ "\n",
478
+ " self.output = nn.Conv2d(\n",
479
+ " in_channels=decoder_config[\"block1\"][\"out_channels\"],\n",
480
+ " out_channels=nclasses,\n",
481
+ " kernel_size=1,\n",
482
+ " )\n",
483
+ "\n",
484
+ " def forward(self, x):\n",
485
+ " x, encoder_step_output = self.encoder(x)\n",
486
+ " x = self.decoder(x, encoder_step_output)\n",
487
+ " return self.output(x)\n",
488
+ "\n",
489
+ "\n",
490
+ "summary(\n",
491
+ " UNet(\n",
492
+ " config[\"encoder_config\"], config[\"decoder_config\"], nclasses=config[\"nclasses\"]\n",
493
+ " ),\n",
494
+ " input_data=torch.rand((1, 1, 572, 572)),\n",
495
+ ")"
496
+ ]
497
+ },
498
+ {
499
+ "cell_type": "code",
500
+ "execution_count": 13,
501
+ "id": "550824e4-2151-4c0b-8a12-383fa092b4ac",
502
+ "metadata": {},
503
+ "outputs": [],
504
+ "source": [
505
+ "# # if config is a dict\n",
506
+ "# with open('custom_config.yml', 'w') as outfile:\n",
507
+ "# yaml.dump(config, outfile, sort_keys=False)"
508
+ ]
509
+ }
510
+ ],
511
+ "metadata": {
512
+ "kernelspec": {
513
+ "display_name": "Python 3 (ipykernel)",
514
+ "language": "python",
515
+ "name": "python3"
516
+ },
517
+ "language_info": {
518
+ "codemirror_mode": {
519
+ "name": "ipython",
520
+ "version": 3
521
+ },
522
+ "file_extension": ".py",
523
+ "mimetype": "text/x-python",
524
+ "name": "python",
525
+ "nbconvert_exporter": "python",
526
+ "pygments_lexer": "ipython3",
527
+ "version": "3.8.10"
528
+ }
529
+ },
530
+ "nbformat": 4,
531
+ "nbformat_minor": 5
532
+ }
src/models/unet/resunet.py ADDED
@@ -0,0 +1,66 @@
+ import torch.nn as nn
+ from .encoder import ResnetEncoder as Encoder
+ from .decoder import CustomDecoder as Decoder
+
+
+ class UNet(nn.Module):
+     def __init__(self, decoder_config, nclasses, input_shape=(224, 224)):
+         super(UNet, self).__init__()
+         self.encoder = Encoder()
+         self.decoder = Decoder(config=decoder_config)
+
+         self.output = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=decoder_config["block1"]["out_channels"],
+                 out_channels=nclasses,
+                 kernel_size=1,
+             ),
+             nn.UpsamplingBilinear2d(size=input_shape),
+         )
+
+     def forward(self, x):
+         x, encoder_step_output = self.encoder(x)
+         x = self.decoder(x, encoder_step_output)
+         x = self.output(x)
+         return x
+
+
+ if __name__ == "__main__":
+     import torch
+     import yaml
+     from easydict import EasyDict
+     from torchinfo import summary
+
+     # load config
+     config_path = "./config/resnet_config.yml"
+
+     with open(config_path, "r") as file:
+         yaml_data = yaml.safe_load(file)
+
+     config = EasyDict(yaml_data)
+
+     # input shape
+     input_shape = (224, 224)
+
+     # device
+     use_cuda = torch.cuda.is_available()
+     device = torch.device("cuda" if use_cuda else "cpu")
+
+     # model definition
+     model = UNet(
+         decoder_config=config["decoder_config"], nclasses=1, input_shape=input_shape
+     ).to(device)
+
+     summary(
+         model,
+         input_data=torch.rand((1, 3, input_shape[0], input_shape[1])),
+         device=device,
+     )
+
+     # load weights (if any)
+     model_path = None
+
+     if model_path is not None:
+         checkpoint = torch.load(model_path, map_location=device)
+         model.decoder.load_state_dict(checkpoint["decoder_state_dict"], strict=False)
+         model.output.load_state_dict(checkpoint["output_state_dict"], strict=False)
src/run/unet/example/binary_segmentation_resunet.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/run/unet/inference.py ADDED
@@ -0,0 +1,111 @@
+ import os
+
+ import albumentations as A
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import yaml
+ from albumentations.pytorch import ToTensorV2
+ from easydict import EasyDict
+ from PIL import Image
+
+ from src.models.unet.resunet import UNet as Model
+
+
+ class ResUnetInfer:
+     def __init__(self, model_path, config_path):
+         use_cuda = torch.cuda.is_available()
+         self.device = torch.device("cuda" if use_cuda else "cpu")
+
+         self.config = self.load_config(config_path=config_path)
+         self.model = self.load_model(model_path=model_path)
+
+         self.transform = A.Compose(
+             [
+                 A.Resize(self.config.input_size[0], self.config.input_size[1]),
+                 A.Normalize(
+                     mean=self.config.mean,
+                     std=self.config.std,
+                     max_pixel_value=255,
+                 ),
+                 ToTensorV2(),
+             ]
+         )
+
+     def load_model(self, model_path):
+         model = Model(
+             decoder_config=self.config.decoder_config, nclasses=self.config.nclasses
+         ).to(self.device)
+
+         if os.path.isfile(model_path):
+             checkpoint = torch.load(model_path, map_location=self.device)
+             model.decoder.load_state_dict(
+                 checkpoint["decoder_state_dict"], strict=False
+             )
+             model.output.load_state_dict(checkpoint["output_state_dict"], strict=False)
+
+         return model
+
+     def load_config(self, config_path):
+         with open(config_path, "r") as file:
+             yaml_data = yaml.safe_load(file)
+
+         return EasyDict(yaml_data)
+
+     def infer(self, image, image_weight=0.01):
+         self.model.eval()
+         input_tensor = self.transform(image=image)["image"].unsqueeze(0)
+
+         # get mask
+         with torch.no_grad():
+             """
+             output_tensor = [batch, 1, 224, 224]
+             batch = 1
+             """
+             output_tensor = self.model(input_tensor.to(self.device))
+
+         mask = torch.sigmoid(output_tensor)
+         mask = nn.UpsamplingBilinear2d(size=(image.shape[0], image.shape[1]))(mask)
+         mask = mask.squeeze(0)
+
+         # add zeros for green and blue channels
+         # our mask will be red in colour
+         zero_channels = torch.zeros((2, image.shape[0], image.shape[1]), device=self.device)
+         mask = torch.cat([mask, zero_channels], dim=0)
+         mask = mask.permute(1, 2, 0).cpu().numpy()
+         mask = np.uint8(255 * mask)
+
+         # overlap image and mask
+         mask = (1 - image_weight) * mask + image_weight * image
+         mask = mask / np.max(mask)
+         return np.uint8(255 * mask)
+
+     @staticmethod
+     def load_image_as_array(image_path):
+         # Load a PIL image
+         pil_image = Image.open(image_path)
+
+         # Convert PIL image to NumPy array
+         return np.array(pil_image.convert("RGB"))
+
+     @staticmethod
+     def plot_array(array: np.array, figsize=(10, 10)):
+         plt.figure(figsize=figsize)
+         plt.imshow(array)
+         plt.show()
+
+     @staticmethod
+     def save_numpy_as_image(numpy_array, image_path):
+         """
+         Saves a NumPy array as an image.
+         Args:
+             numpy_array (numpy.ndarray): The NumPy array to be saved as an image.
+             image_path (str): The path where the image will be saved.
+         """
+         # Convert the NumPy array to a PIL image
+         image = Image.fromarray(numpy_array)
+
+         # Save the PIL image to the specified path
+         image.save(image_path)
+
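A hedged offline sketch that runs the same inference path over the bundled sample images and writes the overlays to disk (the `./outputs` directory is an assumption, not part of this upload):

```python
# Hypothetical batch run over the sample images shipped with this Space.
import os
from src.run.unet.inference import ResUnetInfer

infer = ResUnetInfer(
    model_path="./checkpoint/resunet/decoder.pt",
    config_path="./src/models/unet/config/resnet_config.yml",
)

os.makedirs("./outputs", exist_ok=True)  # assumed output directory
for name in os.listdir("./sample"):
    image = ResUnetInfer.load_image_as_array(os.path.join("./sample", name))
    overlay = infer.infer(image)
    out_name = os.path.splitext(name)[0] + "_mask.png"
    ResUnetInfer.save_numpy_as_image(overlay, os.path.join("./outputs", out_name))
```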
src/unet/__init__.py ADDED
File without changes
src/unet/config/carvana_config.yml ADDED
@@ -0,0 +1,81 @@
+ # Input (1, 512, 512)
+ # Output (64, 512, 512)
+ decoder_config:
+   block5: # (1024, 32, 32)
+     in_channels: 1024
+     kernel_size: 3
+     out_channels: 1024
+     padding:
+       - 1
+       - 1
+     stride: 1 # (1024, 32, 32)
+   block4: # (1024, 32, 32)
+     in_channels: 1024
+     kernel_size: 2
+     out_channels: 512
+     padding:
+       - 0
+       - 1
+     stride: 2 # (512, 64, 64)
+   block3: # (512, 64, 64)
+     in_channels: 512
+     kernel_size: 2
+     out_channels: 256
+     padding:
+       - 0
+       - 1
+     stride: 2 # (256, 128, 128)
+   block2: # (256, 128, 128)
+     in_channels: 256
+     kernel_size: 2
+     out_channels: 128
+     padding:
+       - 0
+       - 1
+     stride: 2 # (128, 256, 256)
+   block1: # (128, 256, 256)
+     in_channels: 128
+     kernel_size: 2
+     out_channels: 64
+     padding:
+       - 0
+       - 1
+     stride: 2 # (64, 512, 512)
+ encoder_config:
+   block1: # (1, 512, 512)
+     all_padding: true
+     in_channels: 1
+     maxpool: true
+     n_layers: 2
+     out_channels: 64 # (64, 256, 256)
+   block2: # (64, 256, 256)
+     all_padding: true
+     in_channels: 64
+     maxpool: true
+     n_layers: 2
+     out_channels: 128 # (128, 128, 128)
+   block3: # (128, 128, 128)
+     all_padding: true
+     in_channels: 128
+     maxpool: true
+     n_layers: 2
+     out_channels: 256 # (256, 64, 64)
+   block4: # (256, 64, 64)
+     all_padding: true
+     in_channels: 256
+     maxpool: true
+     n_layers: 2
+     out_channels: 512 # (512, 32, 32)
+   block5: # (512, 32, 32)
+     all_padding: true
+     in_channels: 512
+     maxpool: false
+     n_layers: 2
+     out_channels: 512 # (512, 32, 32)
+   block6: # (512, 32, 32)
+     all_padding: true
+     in_channels: 512
+     maxpool: false
+     n_layers: 2
+     out_channels: 1024 # (1024, 32, 32)
+ nclasses: 2
src/unet/config/paper_config.yml ADDED
@@ -0,0 +1,60 @@
+ # Original UNet Paper Configuration
+ # Input shape [1, 572, 572]
+ # Output shape [64, 388, 388]
+ decoder_config:
+   block4: # [1024, 28, 28]
+     in_channels: 1024
+     kernel_size: 2
+     out_channels: 512
+     padding: [0, 0]
+     stride: 2 # [512, 52, 52]
+   block3: # [512, 52, 52]
+     in_channels: 512
+     kernel_size: 2
+     out_channels: 256
+     padding: [0, 0]
+     stride: 2 # [256, 100, 100]
+   block2: # [256, 100, 100]
+     in_channels: 256
+     kernel_size: 2
+     out_channels: 128
+     padding: [0, 0]
+     stride: 2 # [128, 196, 196]
+   block1: # [128, 196, 196]
+     in_channels: 128
+     kernel_size: 2
+     out_channels: 64
+     padding: [0, 0]
+     stride: 2 # [64, 388, 388]
+ encoder_config:
+   block1: # [1, 572, 572]
+     all_padding: false
+     in_channels: 1
+     maxpool: true
+     n_layers: 2
+     out_channels: 64 # [64, 568/2, 568/2] = [64, 284, 284]
+   block2: # [64, 568/2, 568/2] = [64, 284, 284]
+     all_padding: false
+     in_channels: 64
+     maxpool: true
+     n_layers: 2
+     out_channels: 128 # [128, 280/2, 280/2] = [128, 140, 140]
+   block3: # [128, 280/2, 280/2] = [128, 140, 140]
+     all_padding: false
+     in_channels: 128
+     maxpool: true
+     n_layers: 2
+     out_channels: 256 # [256, 136/2, 136/2] = [256, 68, 68]
+   block4: # [256, 136/2, 136/2] = [256, 68, 68]
+     all_padding: false
+     in_channels: 256
+     maxpool: true
+     n_layers: 2
+     out_channels: 512 # [512, 64/2, 64/2] = [512, 32, 32]
+   block5: # [512, 64/2, 64/2] = [512, 32, 32]
+     all_padding: false
+     in_channels: 512
+     maxpool: false
+     n_layers: 2
+     out_channels: 1024 # [1024, 28, 28]
+ nclasses: 2
src/unet/model.py ADDED
@@ -0,0 +1,175 @@
+ import torch
+ import torch.nn as nn
+
+
+ """
+ downsampling blocks
+ (first half of the 'U' in UNet)
+ [ENCODER]
+ """
+
+
+ class EncoderLayer(nn.Module):
+     def __init__(
+         self,
+         in_channels=1,
+         out_channels=64,
+         n_layers=2,
+         all_padding=False,
+         maxpool=True,
+     ):
+         super(EncoderLayer, self).__init__()
+
+         f_in_channel = lambda layer: in_channels if layer == 0 else out_channels
+         f_padding = lambda layer: 1 if layer >= 2 or all_padding else 0
+
+         self.layer = nn.Sequential(
+             *[
+                 self._conv_relu_layer(
+                     in_channels=f_in_channel(i),
+                     out_channels=out_channels,
+                     padding=f_padding(i),
+                 )
+                 for i in range(n_layers)
+             ]
+         )
+         self.maxpool = maxpool
+
+     def _conv_relu_layer(self, in_channels, out_channels, padding=0):
+         return nn.Sequential(
+             nn.Conv2d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=3,
+                 padding=padding,
+             ),
+             nn.ReLU(),
+         )
+
+     def forward(self, x):
+         return self.layer(x)
+
+
+ class Encoder(nn.Module):
+     def __init__(self, config):
+         super(Encoder, self).__init__()
+         self.encoder = nn.ModuleDict(
+             {
+                 name: EncoderLayer(
+                     in_channels=block["in_channels"],
+                     out_channels=block["out_channels"],
+                     n_layers=block["n_layers"],
+                     all_padding=block["all_padding"],
+                     maxpool=block["maxpool"],
+                 )
+                 for name, block in config.items()
+             }
+         )
+         self.maxpool = nn.MaxPool2d(2)
+
+     def forward(self, x):
+         output = dict()
+
+         for i, (block_name, block) in enumerate(self.encoder.items()):
+             x = block(x)
+             output[block_name] = x
+
+             if block.maxpool:
+                 x = self.maxpool(x)
+
+         return x, output
+
+
+ """
+ upsampling blocks
+ (second half of the 'U' in UNet)
+ [DECODER]
+ """
+
+
+ class DecoderLayer(nn.Module):
+     def __init__(
+         self, in_channels, out_channels, kernel_size=2, stride=2, padding=[0, 0]
+     ):
+         super(DecoderLayer, self).__init__()
+         self.up_conv = nn.ConvTranspose2d(
+             in_channels=in_channels,
+             out_channels=in_channels // 2,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding[0],
+         )
+
+         self.conv = nn.Sequential(
+             *[
+                 self._conv_relu_layer(
+                     in_channels=in_channels if i == 0 else out_channels,
+                     out_channels=out_channels,
+                     padding=padding[1],
+                 )
+                 for i in range(2)
+             ]
+         )
+
+     def _conv_relu_layer(self, in_channels, out_channels, padding=0):
+         return nn.Sequential(
+             nn.Conv2d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=3,
+                 padding=padding,
+             ),
+             nn.ReLU(),
+         )
+
+     @staticmethod
+     def crop_cat(x, encoder_output):
+         delta = (encoder_output.shape[-1] - x.shape[-1]) // 2
+         encoder_output = encoder_output[
+             :, :, delta : delta + x.shape[-1], delta : delta + x.shape[-1]
+         ]
+         return torch.cat((encoder_output, x), dim=1)
+
+     def forward(self, x, encoder_output):
+         x = self.crop_cat(self.up_conv(x), encoder_output)
+         return self.conv(x)
+
+
+ class Decoder(nn.Module):
+     def __init__(self, config):
+         super(Decoder, self).__init__()
+         self.decoder = nn.ModuleDict(
+             {
+                 name: DecoderLayer(
+                     in_channels=block["in_channels"],
+                     out_channels=block["out_channels"],
+                     kernel_size=block["kernel_size"],
+                     stride=block["stride"],
+                     padding=block["padding"],
+                 )
+                 for name, block in config.items()
+             }
+         )
+
+     def forward(self, x, encoder_output):
+         for name, block in self.decoder.items():
+             x = block(x, encoder_output[name])
+         return x
+
+
+ class UNet(nn.Module):
+     def __init__(self, encoder_config, decoder_config, nclasses):
+         super(UNet, self).__init__()
+         self.encoder = Encoder(config=encoder_config)
+         self.decoder = Decoder(config=decoder_config)
+
+         self.output = nn.Conv2d(
+             in_channels=decoder_config["block1"]["out_channels"],
+             out_channels=nclasses,
+             kernel_size=1,
+         )
+
+     def forward(self, x):
+         x, encoder_step_output = self.encoder(x)
+         x = self.decoder(x, encoder_step_output)
+         return self.output(x)
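A construction sketch for this paper-style UNet, matching the shapes reported in the example notebook (config path as in this upload):

```python
# Paper-configuration UNet: 1x572x572 input -> 2-class 388x388 logits.
import torch
import yaml
from easydict import EasyDict
from src.unet.model import UNet

with open("./src/unet/config/paper_config.yml", "r") as file:
    config = EasyDict(yaml.safe_load(file))

model = UNet(config.encoder_config, config.decoder_config, nclasses=config.nclasses)
with torch.no_grad():
    logits = model(torch.rand(1, 1, 572, 572))
print(logits.shape)  # torch.Size([1, 2, 388, 388])
```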