from functools import partial

import torch
from timm.models.efficientnet import tf_efficientnet_b3_ns, tf_efficientnet_b5_ns
from torch import nn
from torch.nn import Dropout2d, Conv2d
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.pooling import AdaptiveAvgPool2d
from torch.nn.modules.upsampling import UpsamplingBilinear2d

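# Per-backbone configuration used below. For each timm encoder:
#   features        - channel count of the final encoder feature map (input to the fc head)
#   filters         - channels of the five skip feature maps collected in
#                     EfficientUnetClassifier.get_encoder_features (stem, three block groups, conv head)
#   decoder_filters - output widths of the four decoder stages, listed from the
#                     highest-resolution stage to the lowest
#   init_op         - constructor for the pretrained backbone (drop_path_rate enables stochastic depth)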
encoder_params = {
    "tf_efficientnet_b3_ns": {
        "features": 1536,
        "filters": [40, 32, 48, 136, 1536],
        "decoder_filters": [64, 128, 256, 256],
        "init_op": partial(tf_efficientnet_b3_ns, pretrained=True, drop_path_rate=0.2)
    },
    "tf_efficientnet_b5_ns": {
        "features": 2048,
        "filters": [48, 40, 64, 176, 2048],
        "decoder_filters": [64, 128, 256, 256],
        "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.2)
    },
}


class DecoderBlock(nn.Module):
    """Upsamples by a factor of 2 (nearest neighbour), then refines with a 3x3 conv + ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.layer(x)


class ConcatBottleneck(nn.Module):
    """Fuses decoder features with the matching encoder skip: channel concat, then 3x3 conv + ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, dec, enc):
        x = torch.cat([dec, enc], dim=1)
        return self.seq(x)


class Decoder(nn.Module):
    """U-Net style decoder: each stage upsamples, then a bottleneck merges the corresponding encoder skip."""

    def __init__(self, decoder_filters, filters, upsample_filters=None,
                 decoder_block=DecoderBlock, bottleneck=ConcatBottleneck, dropout=0):
        super().__init__()
        self.decoder_filters = decoder_filters
        self.filters = filters
        self.decoder_block = decoder_block
        # Stage i takes features of width decoder_filters[i + 1] (the deeper level's
        # bottleneck output), except the deepest stage, which takes the encoder head.
        self.decoder_stages = nn.ModuleList([self._get_decoder(idx) for idx in range(len(decoder_filters))])
        # Bottlenecks are built deepest-first to match the traversal order in forward().
        self.bottlenecks = nn.ModuleList([bottleneck(self.filters[-i - 2] + f, f)
                                          for i, f in enumerate(reversed(decoder_filters))])
        self.dropout = Dropout2d(dropout) if dropout > 0 else None
        if upsample_filters:
            self.last_block = decoder_block(decoder_filters[0], out_channels=upsample_filters)
        else:
            self.last_block = UpsamplingBilinear2d(scale_factor=2)

    def forward(self, encoder_results: list):
        # encoder_results is ordered deepest-first (lowest resolution at index 0).
        x = encoder_results[0]
        for idx, bottleneck in enumerate(self.bottlenecks):
            rev_idx = -(idx + 1)
            x = self.decoder_stages[rev_idx](x)
            x = bottleneck(x, encoder_results[-rev_idx])
        if self.last_block:
            x = self.last_block(x)
        if self.dropout:
            x = self.dropout(x)
        return x

    def _get_decoder(self, layer):
        idx = layer + 1
        if idx == len(self.decoder_filters):
            in_channels = self.filters[idx]
        else:
            in_channels = self.decoder_filters[idx]
        return self.decoder_block(in_channels, self.decoder_filters[layer])
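
# Channel bookkeeping, worked through from the constants above for "tf_efficientnet_b5_ns"
# (filters = [48, 40, 64, 176, 2048], decoder_filters = [64, 128, 256, 256]):
#
#   stage 3: 2048 -> 256, bottleneck cat with 176-ch skip -> 432 -> 256
#   stage 2:  256 -> 256, bottleneck cat with  64-ch skip -> 320 -> 256
#   stage 1:  256 -> 128, bottleneck cat with  40-ch skip -> 168 -> 128
#   stage 0:  128 ->  64, bottleneck cat with  48-ch skip -> 112 ->  64
#
# followed by the final 2x bilinear upsample, so the decoder output has
# decoder_filters[0] = 64 channels at the input resolution (for inputs divisible by 32).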


def _initialize_weights(module):
    # Kaiming-init conv/linear layers and reset batch-norm affine parameters.
    for m in module.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            m.weight.data = nn.init.kaiming_normal_(m.weight.data)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()


class EfficientUnetClassifier(nn.Module):
    """EfficientNet encoder with a U-Net decoder head (segmentation map) and an fc head (per-image logit)."""

    def __init__(self, encoder, dropout_rate=0.5) -> None:
        super().__init__()
        self.decoder = Decoder(decoder_filters=encoder_params[encoder]["decoder_filters"],
                               filters=encoder_params[encoder]["filters"])
        self.avg_pool = AdaptiveAvgPool2d((1, 1))
        self.dropout = Dropout(dropout_rate)
        self.fc = Linear(encoder_params[encoder]["features"], 1)
        self.final = Conv2d(encoder_params[encoder]["decoder_filters"][0], out_channels=1, kernel_size=1, bias=False)
        # Initialize decoder/head weights first, then build the encoder so its
        # pretrained weights are not overwritten by the Kaiming init above.
        _initialize_weights(self)
        self.encoder = encoder_params[encoder]["init_op"]()

    def get_encoder_features(self, x):
        # Collect five skip feature maps: the stem output, three intermediate block
        # groups, and the conv head.
        encoder_results = []
        x = self.encoder.conv_stem(x)
        x = self.encoder.bn1(x)
        x = self.encoder.act1(x)
        encoder_results.append(x)
        x = self.encoder.blocks[:2](x)
        encoder_results.append(x)
        x = self.encoder.blocks[2:3](x)
        encoder_results.append(x)
        x = self.encoder.blocks[3:5](x)
        encoder_results.append(x)
        x = self.encoder.blocks[5:](x)
        x = self.encoder.conv_head(x)
        x = self.encoder.bn2(x)
        x = self.encoder.act2(x)
        encoder_results.append(x)
        # Deepest (lowest-resolution) features first, matching Decoder.forward and filters[::-1].
        encoder_results = list(reversed(encoder_results))
        return encoder_results

    def forward(self, x):
        encoder_results = self.get_encoder_features(x)
        # Segmentation head: decode the skip features and project to a 1-channel map.
        seg = self.final(self.decoder(encoder_results))
        # Classification head: global-average-pool the deepest encoder features.
        x = encoder_results[0]
        x = self.avg_pool(x).flatten(1)
        x = self.dropout(x)
        x = self.fc(x)
        return x, seg


if __name__ == '__main__':
    model = EfficientUnetClassifier("tf_efficientnet_b5_ns")
    model.eval()
    with torch.no_grad():
        dummy_input = torch.rand(4, 3, 224, 224)
        print(model(dummy_input))
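        # Quick shape check (a sketch assuming a 224x224 input: the encoder reduces
        # resolution by 32x and the decoder path applies five 2x upsamples, so the
        # segmentation map should come back at full resolution).
        cls_out, seg_out = model(dummy_input)
        assert cls_out.shape == (4, 1)
        assert seg_out.shape == (4, 1, 224, 224)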