Spaces:
Sleeping
Sleeping
File size: 7,735 Bytes
ec47fb5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import torch
from torch import nn
from torchvision.transforms import functional as f
class UNet(torch.nn.Module):
    """U-Net style encoder/decoder producing one score map per class.

    Layout:
      * Four encoder stages, each ``[conv3x3 -> BatchNorm -> ReLU] x 2`` followed by
        a 2x2 max-pool.
      * A bottleneck stage of the same form.
      * Four decoder stages, each a 2x2 stride-2 transposed conv, concatenation with
        the matching encoder output (center-cropped if needed), then
        ``[conv3x3 -> BatchNorm -> ReLU] x 2``.
      * A final 1x1 conv to ``num_classes`` channels (no activation applied).

    Dropout is applied after encoder stage 4, after the bottleneck, and after the
    first decoder stage.

    Args:
        device: Device on which every layer's parameters/buffers are created (passed
            as ``device=`` to each layer). ``None`` uses PyTorch's default device.
        in_channels: Number of channels of the input tensor.
        num_classes: Number of output channels (raw, un-activated scores).
        base_channels: Width of the first stage; stage ``k`` (0-based, k=0..4) uses
            ``base_channels * 2**k`` channels. The default 64 reproduces the
            64-128-256-512-1024 layout.
        dropout: Drop probability shared by the three dropout layers.

    Raises:
        ValueError: If ``base_channels`` is smaller than 1.
    """

    def __init__(
        self,
        device=None,
        in_channels: int = 3,
        num_classes: int = 3,
        base_channels: int = 64,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()
        if base_channels < 1:
            raise ValueError(f"base_channels must be >= 1, got {base_channels}")
        c1, c2, c3, c4, c5 = (base_channels * 2**k for k in range(5))

        # NOTE: modules are created and registered in exactly the original order so
        # that seeded weight initialisation, named_children() order and state_dict
        # keys (e.g. "block_1.0.weight", "block_6.4.running_var") are unchanged.

        # Contracting path.
        self.block_1 = self._double_conv(in_channels, c1, device)
        self.max_pool_2x2_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.block_2 = self._double_conv(c1, c2, device)
        self.max_pool_2x2_2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.block_3 = self._double_conv(c2, c3, device)
        self.max_pool_2x2_3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.block_4 = self._double_conv(c3, c4, device)
        self.drop_out_1 = nn.Dropout(p=dropout)
        self.max_pool_2x2_4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottleneck.
        self.block_5 = self._double_conv(c4, c5, device)
        self.drop_out_2 = nn.Dropout(p=dropout)

        # Expanding path: each up-conv halves the channel count, and its output is
        # concatenated with the encoder output of the same width, so every decoder
        # block takes 2 * out_channels input channels.
        self.up_conv_2x2_1 = self._up_conv(c5, c4, device)
        self.block_6 = self._double_conv(2 * c4, c4, device)
        self.drop_out_3 = nn.Dropout(p=dropout)
        self.up_conv_2x2_2 = self._up_conv(c4, c3, device)
        self.block_7 = self._double_conv(2 * c3, c3, device)
        self.up_conv_2x2_3 = self._up_conv(c3, c2, device)
        self.block_8 = self._double_conv(2 * c2, c2, device)
        self.up_conv_2x2_4 = self._up_conv(c2, c1, device)
        self.block_9 = self._double_conv(2 * c1, c1, device)

        # Per-pixel projection to num_classes channels.
        self.last_conv_1x1 = nn.Conv2d(
            in_channels=c1, out_channels=num_classes,
            kernel_size=1, stride=1, padding=0, device=device,
        )

    @staticmethod
    def _double_conv(in_ch: int, out_ch: int, device) -> nn.Sequential:
        """Build ``[conv3x3(pad=1) -> BatchNorm -> ReLU] x 2`` (indices 0..5)."""
        return nn.Sequential(
            nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3,
                      stride=1, padding=1, device=device),
            nn.BatchNorm2d(out_ch, device=device),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3,
                      stride=1, padding=1, device=device),
            nn.BatchNorm2d(out_ch, device=device),
            nn.ReLU(),
        )

    @staticmethod
    def _up_conv(in_ch: int, out_ch: int, device) -> nn.ConvTranspose2d:
        """Build a 2x2 stride-2 transposed conv (doubles H and W)."""
        return nn.ConvTranspose2d(in_channels=in_ch, out_channels=out_ch,
                                  kernel_size=2, stride=2, padding=0, device=device)

    @staticmethod
    def _merge(upsampled, skip):
        """Center-crop ``skip`` to ``upsampled``'s H x W, then concat [skip, upsampled] on dim 1."""
        skip = f.center_crop(skip, [upsampled.shape[2], upsampled.shape[3]])
        return torch.cat([skip, upsampled], dim=1)

    def forward(self, x):
        """Run the network.

        Args:
            x: Tensor of shape (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, num_classes, H', W') with raw scores. When H and W are
            multiples of 16, H' == H and W' == W. Otherwise max-pooling floors the
            sizes, encoder features are center-cropped to match, and H', W' may be
            smaller than H, W.
        """
        block_1_result = self.block_1(x)
        block_2_result = self.block_2(self.max_pool_2x2_1(block_1_result))
        block_3_result = self.block_3(self.max_pool_2x2_2(block_2_result))
        # Dropout is applied before block_4's output is reused, so it affects both
        # the bottleneck input and the skip connection into block_6.
        block_4_result = self.drop_out_1(self.block_4(self.max_pool_2x2_3(block_3_result)))
        block_5_result = self.drop_out_2(self.block_5(self.max_pool_2x2_4(block_4_result)))

        block_6_result = self.drop_out_3(
            self.block_6(self._merge(self.up_conv_2x2_1(block_5_result), block_4_result))
        )
        block_7_result = self.block_7(self._merge(self.up_conv_2x2_2(block_6_result), block_3_result))
        block_8_result = self.block_8(self._merge(self.up_conv_2x2_3(block_7_result), block_2_result))
        block_9_result = self.block_9(self._merge(self.up_conv_2x2_4(block_8_result), block_1_result))
        return self.last_conv_1x1(block_9_result)
|