import torch
from torch import nn
from torchvision.transforms import functional as F


class UNet(nn.Module):
    def __init__(self, device, in_channels: int = 3, num_classes: int = 3) -> None:
        super().__init__()
        # Contracting path: each block doubles the channel count, each 2x2 max-pool halves the spatial size.
        self.block_1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=3, stride=1, padding=1, device=device),  # -> channels = 64
            nn.BatchNorm2d(64, device=device),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, device=device),  # -> channels = 64
            nn.BatchNorm2d(64, device=device),
            nn.ReLU()
        )
        self.max_pool_2x2_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.block_2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, device=device),  # -> channels = 128
            nn.BatchNorm2d(128, device=device),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, device=device),  # -> channels = 128
            nn.BatchNorm2d(128, device=device),
            nn.ReLU()
        )
        self.max_pool_2x2_2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.block_3 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, device=device),  # -> channels = 256
            nn.BatchNorm2d(256, device=device),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, device=device),  # -> channels = 256
            nn.BatchNorm2d(256, device=device),
            nn.ReLU()
        )
        self.max_pool_2x2_3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.block_4 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1, device=device),  # -> channels = 512
            nn.BatchNorm2d(512, device=device),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1, device=device),  # -> channels = 512
            nn.BatchNorm2d(512, device=device),
            nn.ReLU()
        )
        self.drop_out_1 = nn.Dropout(p=0.5)
        self.max_pool_2x2_4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.block_5 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=1, device=device),  # -> channels = 1024
            nn.BatchNorm2d(1024, device=device),
            nn.ReLU(),
            nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, stride=1, padding=1, device=device),  # -> channels = 1024
            nn.BatchNorm2d(1024, device=device),
            nn.ReLU()
        )
        self.drop_out_2 = nn.Dropout(p=0.5)
        # Expanding path: each transposed convolution halves the channel count and doubles the spatial size.
        self.up_conv_2x2_1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2, padding=0, device=device)  # -> channels = 512
        # After upsampling, the tensor is concatenated with the 512-channel output of block_4,
        # so the input to block_6 has 512 + 512 = 1024 channels.
        self.block_6 = nn.Sequential(
            nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, stride=1, padding=1, device=device),  # -> channels = 512
            nn.BatchNorm2d(512, device=device),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1, device=device),  # -> channels = 512
            nn.BatchNorm2d(512, device=device),
            nn.ReLU()
        )
        self.drop_out_3 = nn.Dropout(p=0.5)
        self.up_conv_2x2_2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2, padding=0, device=device)  # -> channels = 256
        # Same skip-connection scheme as up_conv_2x2_1: 256 + 256 = 512 input channels.
        self.block_7 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=1, device=device),  # -> channels = 256
            nn.BatchNorm2d(256, device=device),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, device=device),  # -> channels = 256
            nn.BatchNorm2d(256, device=device),
            nn.ReLU()
        )
        self.up_conv_2x2_3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2, padding=0, device=device)  # -> channels = 128
        # Same skip-connection scheme: 128 + 128 = 256 input channels.
        self.block_8 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=1, device=device),  # -> channels = 128
            nn.BatchNorm2d(128, device=device),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, device=device),  # -> channels = 128
            nn.BatchNorm2d(128, device=device),
            nn.ReLU()
        )
        self.up_conv_2x2_4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2, padding=0, device=device)  # -> channels = 64
        # Same skip-connection scheme: 64 + 64 = 128 input channels.
        self.block_9 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1, device=device),  # -> channels = 64
            nn.BatchNorm2d(64, device=device),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, device=device),  # -> channels = 64
            nn.BatchNorm2d(64, device=device),
            nn.ReLU()
        )
        self.last_conv_1x1 = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1, stride=1, padding=0, device=device)  # -> channels = num_classes (default 3: [background, borders, objects])
    def forward(self, x):
        # Contracting path.
        block_1_result = self.block_1(x)
        block_2_result = self.block_2(self.max_pool_2x2_1(block_1_result))
        block_3_result = self.block_3(self.max_pool_2x2_2(block_2_result))
        block_4_result = self.block_4(self.max_pool_2x2_3(block_3_result))
        block_4_result = self.drop_out_1(block_4_result)
        block_5_result = self.block_5(self.max_pool_2x2_4(block_4_result))
        block_5_result = self.drop_out_2(block_5_result)
        # Expanding path: upsample, center-crop the skip connection to match, then concatenate along channels.
        up_conv_1_result = self.up_conv_2x2_1(block_5_result)
        block_4_result = F.center_crop(block_4_result, [up_conv_1_result.shape[2], up_conv_1_result.shape[3]])
        concat_1_result = torch.cat([block_4_result, up_conv_1_result], dim=1)
        block_6_result = self.block_6(concat_1_result)
        block_6_result = self.drop_out_3(block_6_result)
        up_conv_2_result = self.up_conv_2x2_2(block_6_result)
        block_3_result = F.center_crop(block_3_result, [up_conv_2_result.shape[2], up_conv_2_result.shape[3]])
        concat_2_result = torch.cat([block_3_result, up_conv_2_result], dim=1)
        block_7_result = self.block_7(concat_2_result)
        up_conv_3_result = self.up_conv_2x2_3(block_7_result)
        block_2_result = F.center_crop(block_2_result, [up_conv_3_result.shape[2], up_conv_3_result.shape[3]])
        concat_3_result = torch.cat([block_2_result, up_conv_3_result], dim=1)
        block_8_result = self.block_8(concat_3_result)
        up_conv_4_result = self.up_conv_2x2_4(block_8_result)
        block_1_result = F.center_crop(block_1_result, [up_conv_4_result.shape[2], up_conv_4_result.shape[3]])
        concat_4_result = torch.cat([block_1_result, up_conv_4_result], dim=1)
        block_9_result = self.block_9(concat_4_result)
        # Final 1x1 convolution maps 64 channels to per-pixel class logits.
        last_block_result = self.last_conv_1x1(block_9_result)
        return last_block_result
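

if __name__ == "__main__":
    # Minimal smoke test -- an illustrative sketch, not part of the original upload.
    # It assumes CPU execution and a 3-channel 256x256 input; swap in "cuda" if available.
    device = torch.device("cpu")
    model = UNet(device=device)
    dummy = torch.randn(1, 3, 256, 256, device=device)  # batch of one RGB image
    logits = model(dummy)
    # With padding=1 convolutions the spatial size is preserved end to end, so the
    # center_crop calls are no-ops here and the output should be [1, num_classes, 256, 256].
    print(logits.shape)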