thecho7's picture
LFS dat
c426e13
from functools import partial
import torch
from timm.models.efficientnet import tf_efficientnet_b3_ns, tf_efficientnet_b5_ns
from torch import nn
from torch.nn import Dropout2d, Conv2d
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.pooling import AdaptiveAvgPool2d
from torch.nn.modules.upsampling import UpsamplingBilinear2d
# Per-backbone configuration read by EfficientUnetClassifier.
#   features        - channels of the deepest encoder map (input width of the fc head)
#   filters         - channels of the five encoder feature maps, shallowest to deepest
#                     (stem, blocks[:2], blocks[2:3], blocks[3:5], head)
#   decoder_filters - output channels of the four decoder stages, shallowest first
#   init_op         - zero-argument factory for the timm backbone (pretrained=True,
#                     drop_path_rate=0.2)
encoder_params = dict(
    tf_efficientnet_b3_ns=dict(
        features=1536,
        filters=[40, 32, 48, 136, 1536],
        decoder_filters=[64, 128, 256, 256],
        init_op=partial(tf_efficientnet_b3_ns, pretrained=True, drop_path_rate=0.2),
    ),
    tf_efficientnet_b5_ns=dict(
        features=2048,
        filters=[48, 40, 64, 176, 2048],
        decoder_filters=[64, 128, 256, 256],
        init_op=partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.2),
    ),
)
class DecoderBlock(nn.Module):
    """Upsample by 2 (nn.Upsample default mode), then 3x3 conv + ReLU.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Sequential indices (0: upsample, 1: conv, 2: relu) determine the
        # state_dict keys ("layer.1.weight", ...), so keep this layout stable
        # for loading saved weights.
        stages = [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        out = self.layer(x)
        return out
class ConcatBottleneck(nn.Module):
    """Fuse decoder features with an encoder skip connection.

    ``dec`` and ``enc`` are concatenated along dim 1 and passed through a
    3x3 conv + ReLU, so ``in_channels`` must equal the sum of their dim-1 sizes.

    Args:
        in_channels: combined channel count of ``dec`` and ``enc``.
        out_channels: channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # Registered as "seq.0" so state_dict keys stay "seq.0.weight"/"seq.0.bias".
        self.seq = nn.Sequential(conv, nn.ReLU(inplace=True))

    def forward(self, dec, enc):
        fused = torch.cat((dec, enc), dim=1)
        return self.seq(fused)
class Decoder(nn.Module):
    """U-Net style decoder that walks from the deepest encoder map upwards.

    Each step upsamples with a decoder stage, then fuses the result with the
    next-shallower encoder map through a bottleneck. An optional final block
    and Dropout2d follow.

    Args:
        decoder_filters: output channels of each decoder stage, index 0 being
            the shallowest stage.
        filters: channels of the encoder feature maps, shallowest to deepest;
            len(decoder_filters) + 1 entries are expected (the deepest one feeds
            the deepest decoder stage, the others are skip connections).
        upsample_filters: if truthy, the final block is
            ``decoder_block(decoder_filters[0], out_channels=upsample_filters)``;
            otherwise a parameter-free bilinear x2 upsample.
        decoder_block: factory ``(in_channels, out_channels) -> nn.Module`` per stage.
        bottleneck: factory ``(in_channels, out_channels) -> nn.Module`` whose
            forward takes ``(decoder_features, encoder_skip)``.
        dropout: Dropout2d probability applied to the output; disabled when <= 0.
    """

    def __init__(self, decoder_filters, filters, upsample_filters=None,
                 decoder_block=DecoderBlock, bottleneck=ConcatBottleneck, dropout=0):
        super().__init__()
        self.decoder_filters = decoder_filters
        self.filters = filters
        self.decoder_block = decoder_block
        self.decoder_stages = nn.ModuleList(
            [self._get_decoder(idx) for idx in range(len(decoder_filters))])
        # Built deepest-first: bottleneck i fuses a stage output of f channels
        # with the encoder skip of filters[-i - 2] channels.
        self.bottlenecks = nn.ModuleList([bottleneck(self.filters[-i - 2] + f, f)
                                          for i, f in enumerate(reversed(decoder_filters))])
        self.dropout = Dropout2d(dropout) if dropout > 0 else None
        if upsample_filters:
            self.last_block = decoder_block(decoder_filters[0], out_channels=upsample_filters)
        else:
            self.last_block = UpsamplingBilinear2d(scale_factor=2)

    def forward(self, encoder_results: list):
        """Decode ``encoder_results``, which must be ordered deepest map first."""
        x = encoder_results[0]
        for idx, bottleneck in enumerate(self.bottlenecks):
            # decoder_stages is stored shallowest-first, so index it from the end.
            x = self.decoder_stages[-(idx + 1)](x)
            x = bottleneck(x, encoder_results[idx + 1])
        # Explicit None checks: module truthiness is unreliable because
        # container modules define __len__ and an empty one is falsy.
        if self.last_block is not None:
            x = self.last_block(x)
        if self.dropout is not None:
            x = self.dropout(x)
        return x

    def _get_decoder(self, layer):
        """Build stage ``layer``: next-deeper channels -> decoder_filters[layer]."""
        idx = layer + 1
        if idx == len(self.decoder_filters):
            # The deepest stage consumes the deepest encoder map directly.
            in_channels = self.filters[idx]
        else:
            in_channels = self.decoder_filters[idx]
        return self.decoder_block(in_channels, self.decoder_filters[max(layer, 0)])
def _initialize_weights(module):
for m in module.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear):
m.weight.data = nn.init.kaiming_normal_(m.weight.data)
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
class EfficientUnetClassifier(nn.Module):
    """EfficientNet backbone with a classification head and a U-Net segmentation head.

    Args:
        encoder: key into ``encoder_params`` selecting the backbone.
        dropout_rate: dropout probability applied to the pooled features before fc.

    ``forward`` returns ``(logits, seg)``. ``logits`` has shape (batch, 1) and is
    computed from the deepest encoder map. ``seg`` is a single-channel map
    produced by the decoder followed by a 1x1 conv.
    """

    def __init__(self, encoder, dropout_rate=0.5) -> None:
        super().__init__()
        cfg = encoder_params[encoder]
        decoder_filters = cfg["decoder_filters"]
        # Heads are created and initialised BEFORE the backbone is attached, so
        # _initialize_weights never touches the pretrained encoder weights.
        # Attribute order also fixes parameters()/state_dict() ordering.
        self.decoder = Decoder(decoder_filters=decoder_filters, filters=cfg["filters"])
        self.avg_pool = AdaptiveAvgPool2d((1, 1))
        self.dropout = Dropout(dropout_rate)
        self.fc = Linear(cfg["features"], 1)
        self.final = Conv2d(decoder_filters[0], out_channels=1, kernel_size=1, bias=False)
        _initialize_weights(self)
        self.encoder = cfg["init_op"]()

    def get_encoder_features(self, x):
        """Return the five encoder feature maps, deepest first.

        NOTE(review): relies on the timm EfficientNet exposing conv_stem/bn1/act1,
        blocks, conv_head/bn2/act2 — confirm when upgrading timm.
        """
        enc = self.encoder
        x = enc.act1(enc.bn1(enc.conv_stem(x)))
        feats = [x]
        # Block groups whose outputs serve as skip connections.
        for group in (slice(None, 2), slice(2, 3), slice(3, 5)):
            x = enc.blocks[group](x)
            feats.append(x)
        x = enc.act2(enc.bn2(enc.conv_head(enc.blocks[5:](x))))
        feats.append(x)
        return feats[::-1]

    def forward(self, x):
        feats = self.get_encoder_features(x)
        seg = self.final(self.decoder(feats))
        pooled = self.avg_pool(feats[0]).flatten(1)
        logits = self.fc(self.dropout(pooled))
        return logits, seg
if __name__ == '__main__':
    # Smoke test: build the b5 variant (its init_op passes pretrained=True, so
    # timm may need to fetch weights) and run one forward pass on random data.
    model = EfficientUnetClassifier("tf_efficientnet_b5_ns")
    model.eval()
    with torch.no_grad():
        # Named `dummy` rather than `input` to avoid shadowing the builtin.
        dummy = torch.rand(4, 3, 224, 224)
        print(model(dummy))