import torch
from torch import nn

try:
    # Kept so this module still exposes ``f`` as it did before. UNet itself no
    # longer needs torchvision: skip features are cropped by tensor slicing.
    from torchvision.transforms import functional as f
except ImportError:  # torchvision is optional for this module
    f = None


def _double_conv(in_channels: int, out_channels: int, device=None) -> nn.Sequential:
    """Build one U-Net stage: (3x3 conv -> batch norm -> ReLU) applied twice.

    ``padding=1`` with a 3x3 kernel keeps the spatial size unchanged. The layer
    order (and therefore the ``Sequential`` indices 0..5) is fixed so that
    ``state_dict`` keys stay stable.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                  kernel_size=3, stride=1, padding=1, device=device),
        nn.BatchNorm2d(out_channels, device=device),
        nn.ReLU(),
        nn.Conv2d(in_channels=out_channels, out_channels=out_channels,
                  kernel_size=3, stride=1, padding=1, device=device),
        nn.BatchNorm2d(out_channels, device=device),
        nn.ReLU(),
    )


def _center_crop(features: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """Return the central ``height`` x ``width`` window of the last two dims.

    Offsets use ``int(round(diff / 2.0))`` so the selected window matches the
    one ``torchvision.transforms.functional.center_crop`` picks. Callers must
    pass a target that is not larger than ``features`` (always true in UNet,
    where the up-sampled map is never larger than its skip connection).
    """
    top = int(round((features.shape[-2] - height) / 2.0))
    left = int(round((features.shape[-1] - width) / 2.0))
    return features[..., top:top + height, left:left + width]


def _crop_and_concat(skip: torch.Tensor, upsampled: torch.Tensor) -> torch.Tensor:
    """Center-crop ``skip`` to ``upsampled``'s spatial size and stack on channels."""
    cropped = _center_crop(skip, upsampled.shape[2], upsampled.shape[3])
    # Skip features first, then the up-sampled ones: the first conv of each
    # decoder block expects this channel order.
    return torch.cat([cropped, upsampled], dim=1)


class UNet(torch.nn.Module):
    """U-Net with a 4-level encoder, a 1024-channel bottleneck and a 4-level decoder.

    Every stage is a pair of 3x3 convolutions with batch norm and ReLU.
    Down-sampling is 2x2 max pooling; up-sampling is a 2x2 transposed
    convolution with stride 2. Dropout (p=0.5) follows ``block_4``,
    ``block_5`` and ``block_6``.

    Args:
        device: Device on which all parameters are created. ``None`` (the
            default) uses PyTorch's default device.
        in_channels: Number of channels of the input tensor.
        num_classes: Number of output channels of the final 1x1 convolution
            (default 3 for [background, borders, objects]).
    """

    def __init__(self, device=None, in_channels: int = 3, num_classes: int = 3) -> None:
        super().__init__()

        # ---- Encoder (construction order is kept so seeded init is stable) ----
        self.block_1 = _double_conv(in_channels, 64, device)    # -> 64 channels
        self.max_pool_2x2_1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.block_2 = _double_conv(64, 128, device)            # -> 128 channels
        self.max_pool_2x2_2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.block_3 = _double_conv(128, 256, device)           # -> 256 channels
        self.max_pool_2x2_3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.block_4 = _double_conv(256, 512, device)           # -> 512 channels
        self.drop_out_1 = nn.Dropout(p=0.5)
        self.max_pool_2x2_4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # ---- Bottleneck ----
        self.block_5 = _double_conv(512, 1024, device)          # -> 1024 channels
        self.drop_out_2 = nn.Dropout(p=0.5)

        # ---- Decoder ----
        # Each transposed conv halves the channel count. Its output is then
        # concatenated with the matching encoder output (same channel count),
        # so every decoder block receives twice the up-conv's channels,
        # e.g. 512 + 512 = 1024 channels going into block_6.
        self.up_conv_2x2_1 = nn.ConvTranspose2d(
            in_channels=1024, out_channels=512, kernel_size=2, stride=2,
            padding=0, device=device)                           # -> 512 channels
        self.block_6 = _double_conv(1024, 512, device)          # -> 512 channels
        self.drop_out_3 = nn.Dropout(p=0.5)

        self.up_conv_2x2_2 = nn.ConvTranspose2d(
            in_channels=512, out_channels=256, kernel_size=2, stride=2,
            padding=0, device=device)                           # -> 256 channels
        self.block_7 = _double_conv(512, 256, device)           # -> 256 channels

        self.up_conv_2x2_3 = nn.ConvTranspose2d(
            in_channels=256, out_channels=128, kernel_size=2, stride=2,
            padding=0, device=device)                           # -> 128 channels
        self.block_8 = _double_conv(256, 128, device)           # -> 128 channels

        self.up_conv_2x2_4 = nn.ConvTranspose2d(
            in_channels=128, out_channels=64, kernel_size=2, stride=2,
            padding=0, device=device)                           # -> 64 channels
        self.block_9 = _double_conv(128, 64, device)            # -> 64 channels

        # 1x1 conv mapping the 64 feature channels to per-class scores.
        self.last_conv_1x1 = nn.Conv2d(
            in_channels=64, out_channels=num_classes, kernel_size=1, stride=1,
            padding=0, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a batch shaped ``(N, in_channels, H, W)``.

        Returns a tensor with ``num_classes`` channels. When ``H`` or ``W`` is
        not a multiple of 16, the pooled-then-up-sampled maps are smaller than
        the encoder outputs; the encoder outputs are center-cropped to match,
        so the result is spatially smaller than the input.
        """
        # Contracting path: keep each stage's output for its skip connection.
        skip_1 = self.block_1(x)
        skip_2 = self.block_2(self.max_pool_2x2_1(skip_1))
        skip_3 = self.block_3(self.max_pool_2x2_2(skip_2))
        # Dropout is applied before the tensor is used for both the skip
        # connection and the next pooling step.
        skip_4 = self.drop_out_1(self.block_4(self.max_pool_2x2_3(skip_3)))

        bottleneck = self.drop_out_2(self.block_5(self.max_pool_2x2_4(skip_4)))

        # Expanding path: up-sample, merge with the skip connection, convolve.
        decoded = self.up_conv_2x2_1(bottleneck)
        decoded = self.drop_out_3(self.block_6(_crop_and_concat(skip_4, decoded)))

        decoded = self.up_conv_2x2_2(decoded)
        decoded = self.block_7(_crop_and_concat(skip_3, decoded))

        decoded = self.up_conv_2x2_3(decoded)
        decoded = self.block_8(_crop_and_concat(skip_2, decoded))

        decoded = self.up_conv_2x2_4(decoded)
        decoded = self.block_9(_crop_and_concat(skip_1, decoded))

        return self.last_conv_1x1(decoded)