Paul Engstler
Initial commit
92f0e98
raw
history blame contribute delete
No virus
3.54 kB
import os
import torch
from monai.networks.nets import DenseNet121, DenseNet169, DenseNet201, DenseNet264
from backbones.unet3d import UNet3D
import utils.config
def _freeze_layers_if_any(model, hparams):
if len(hparams.frozen_layers) == 0:
return model
for (name, param) in model.named_parameters():
if any([name.startswith(to_freeze_name) for to_freeze_name in hparams.frozen_layers]):
param.requires_grad = False
return model
def _replace_inplace_operations(model):
# Grad-CAM compatibility
for module in model.modules():
if hasattr(module, "inplace"):
setattr(module, "inplace", False)
return model
def _build_densenet(hparams, in_channels, out_channels):
    """Instantiate the MONAI DenseNet variant named by ``hparams.model_name``.

    Raises:
        ValueError: the name is not one of the four supported variants.
    """
    variants = {
        "DenseNet121": DenseNet121,
        "DenseNet169": DenseNet169,
        "DenseNet201": DenseNet201,
        "DenseNet264": DenseNet264,
    }
    try:
        net_cls = variants[hparams.model_name]
    except KeyError:
        raise ValueError(f"Unknown DenseNet: {hparams.model_name}") from None
    backbone = net_cls(
        spatial_dims=hparams.input_dim,
        in_channels=in_channels,
        out_channels=out_channels,
        dropout_prob=hparams.dropout,
        act=("relu", {"inplace": False}),  # inplace has to be set to False to enable use of Grad-CAM
    )
    # ensure activation maps are not shrunk too much
    backbone.features.transition2.pool = torch.nn.Identity()
    backbone.features.transition3.pool = torch.nn.Identity()
    return backbone


def _build_torchvision_resnet(model_name, out_channels):
    """Load a torchvision ResNet-family model via ``torch.hub`` with a resized head.

    A trailing ``"-pretrained"`` (case-insensitive) on ``model_name`` requests
    pretrained weights; the suffix is stripped before the hub lookup.
    """
    # if you use pre-trained models, please add "pretrained_resnet" to the transforms hyperparameter
    suffix = "-pretrained"
    arch = model_name.lower()
    pretrained = arch.endswith(suffix)
    if pretrained:
        # The hub only exposes bare architecture names (e.g. "resnet18"); the
        # suffixed name would not resolve to an entrypoint.
        arch = arch[: -len(suffix)]
    backbone = torch.hub.load('pytorch/vision:v0.10.0', arch, pretrained=pretrained)
    # Reset the final fully connected layer to the expected number of classes.
    # Assigning ``fc.out_features`` alone would not resize the weight matrix, so
    # a new Linear layer is created instead.
    backbone.fc = torch.nn.Linear(backbone.fc.in_features, out_channels)
    # NOTE(review): ``in_channels`` is not applied to this branch; the stem conv
    # keeps torchvision's default input channel count -- confirm inputs match.
    return backbone


def _load_models_genesis_weights(backbone):
    """Load the Models Genesis checkpoint into ``backbone`` (non-strict).

    The checkpoint path is ``data_sl/<MODELS_GENESIS_PATH>`` from
    ``utils.config.globals``; tensors are mapped to the CPU.
    """
    weight_path = os.path.join('data_sl', utils.config.globals["MODELS_GENESIS_PATH"])
    checkpoint = torch.load(weight_path, map_location=torch.device('cpu'))
    # Drop every "module." occurrence from the keys -- presumably the prefix added
    # when the checkpoint was saved from a DataParallel-wrapped model.
    state_dict = {
        key.replace("module.", ""): value
        for key, value in checkpoint['state_dict'].items()
    }
    # strict=False tolerates missing/unexpected keys (shape mismatches still raise).
    backbone.load_state_dict(state_dict, strict=False)


def get_backbone(hparams):
    """Build the backbone network described by ``hparams``.

    Attributes read from ``hparams``: ``model_name``, ``mask``, ``input_dim``,
    ``coordinates``, ``num_classes``, ``loss`` and, depending on the branch,
    ``dropout``, ``input_size``. ``frozen_layers`` is also read, by
    ``_freeze_layers_if_any``.

    Supported ``model_name`` values:
        * ``"DenseNet121"``, ``"DenseNet169"``, ``"DenseNet201"``,
          ``"DenseNet264"``: MONAI DenseNets with the pooling of transition
          blocks 2 and 3 replaced by identity.
        * Any name starting with ``"resne"`` (case-insensitive): a torchvision
          model from ``torch.hub``. A ``"-pretrained"`` suffix loads pretrained
          weights.
        * ``"ModelsGenesis"``: ``UNet3D`` initialised from the Models Genesis
          checkpoint.
        * ``"UNet3D"``: the same architecture without pretrained weights.

    Returns:
        The backbone ``torch.nn.Module``, with every module's ``inplace`` flag
        disabled and the configured layers frozen.

    Raises:
        ValueError: ``model_name`` starts with ``"DenseNet"`` but is not a
            supported variant.
        NotImplementedError: ``model_name`` matches no supported architecture.
    """
    # Base channel, +1 when mask == 'channel', +input_dim when coordinates is
    # truthy (presumably one coordinate channel per spatial axis -- confirm).
    in_channels = 1 + (hparams.mask == 'channel') + hparams.input_dim * hparams.coordinates
    # One output fewer for ordinal regression (presumably K-1 threshold logits
    # for K classes -- confirm against the loss implementation).
    out_channels = hparams.num_classes - (hparams.loss == 'ordinal_regression')

    name = hparams.model_name
    if name.startswith('DenseNet'):
        backbone = _build_densenet(hparams, in_channels, out_channels)
    elif name.lower().startswith("resne"):
        backbone = _build_torchvision_resnet(name, out_channels)
    elif name in ('ModelsGenesis', 'UNet3D'):
        # UNet3D is the architecture of Models Genesis minus the pretraining.
        backbone = UNet3D(
            in_channels=in_channels,
            input_size=hparams.input_size,
            n_class=out_channels,
        )
        if name == 'ModelsGenesis':
            _load_models_genesis_weights(backbone)
    else:
        raise NotImplementedError(f"Unsupported model_name: {name}")

    backbone = _replace_inplace_operations(backbone)
    backbone = _freeze_layers_if_any(backbone, hparams)
    return backbone