File size: 3,535 Bytes
92f0e98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
import torch

from monai.networks.nets import DenseNet121, DenseNet169, DenseNet201, DenseNet264
from backbones.unet3d import UNet3D

import utils.config

def _freeze_layers_if_any(model, hparams):
    if len(hparams.frozen_layers) == 0:
        return model

    for (name, param) in model.named_parameters():
        if any([name.startswith(to_freeze_name) for to_freeze_name in hparams.frozen_layers]):
            param.requires_grad = False

    return model

def _replace_inplace_operations(model):
    # Grad-CAM compatibility
    for module in model.modules():
        if hasattr(module, "inplace"):
            setattr(module, "inplace", False)
    return model

def get_backbone(hparams):
    """Build the backbone network selected by ``hparams.model_name``.

    Supported names: ``DenseNet121/169/201/264`` (MONAI), torchvision ResNets
    (optionally suffixed ``-pretrained``), ``ModelsGenesis`` (pretrained UNet3D
    encoder), and ``UNet3D`` (same architecture, random init).

    Args:
        hparams: hyperparameter namespace; reads ``model_name``, ``mask``,
            ``input_dim``, ``coordinates``, ``num_classes``, ``loss``,
            ``dropout`` and ``input_size``.

    Returns:
        A ``torch.nn.Module`` with in-place ops disabled (Grad-CAM) and any
        configured layers frozen.

    Raises:
        ValueError: for an unrecognized DenseNet variant.
        NotImplementedError: for any other unknown ``model_name``.
    """
    # One channel for the image, plus an optional mask channel, plus one
    # coordinate channel per spatial dimension when coordinates are enabled.
    in_channels = 1 + (hparams.mask == 'channel') + hparams.input_dim * hparams.coordinates

    # Ordinal regression predicts num_classes - 1 thresholds instead of
    # one logit per class.
    out_channels = hparams.num_classes - (hparams.loss == 'ordinal_regression')

    if hparams.model_name.startswith('DenseNet'):
        densenets = {
            "DenseNet121": DenseNet121,
            "DenseNet169": DenseNet169,
            "DenseNet201": DenseNet201,
            "DenseNet264": DenseNet264,
        }
        try:
            net_selection = densenets[hparams.model_name]
        except KeyError:
            raise ValueError(f"Unknown DenseNet: {hparams.model_name}") from None

        backbone = net_selection(
            spatial_dims=hparams.input_dim,
            in_channels=in_channels,
            out_channels=out_channels,
            dropout_prob=hparams.dropout,
            act=("relu", {"inplace": False})  # inplace has to be set to False to enable use of Grad-CAM
        )

        # ensure activation maps are not shrunk too much
        backbone.features.transition2.pool = torch.nn.Identity()
        backbone.features.transition3.pool = torch.nn.Identity()

    elif hparams.model_name.lower().startswith("resne"):
        # if you use pre-trained models, please add "pretrained_resnet" to the transforms hyperparameter
        pretrained = hparams.model_name.lower().endswith('-pretrained')
        # torch.hub only knows the bare architecture name (e.g. "resnet18"),
        # so strip the "-pretrained" marker before looking the model up.
        hub_name = hparams.model_name[:-len('-pretrained')] if pretrained else hparams.model_name
        backbone = torch.hub.load('pytorch/vision:v0.10.0', hub_name, pretrained=pretrained)

        # BUGFIX: assigning to ``fc.out_features`` only changed the attribute,
        # not the weight matrix — the head kept emitting 1000 logits. Replace
        # the final fully connected layer with one of the expected size.
        backbone.fc = torch.nn.Linear(backbone.fc.in_features, out_channels)

    elif hparams.model_name == 'ModelsGenesis':
        backbone = UNet3D(
            in_channels=in_channels,
            input_size=hparams.input_size,
            n_class=out_channels
        )

        weight_dir = os.path.join('data_sl', utils.config.globals["MODELS_GENESIS_PATH"])

        checkpoint = torch.load(weight_dir, map_location=torch.device('cpu'))
        state_dict = checkpoint['state_dict']

        # The checkpoint was saved from a DataParallel model; drop the
        # "module." prefix so keys match the unwrapped network.
        unparalled_state_dict = {
            key.replace("module.", ""): value for key, value in state_dict.items()
        }

        # strict=False: the classification head is newly initialized and has
        # no counterpart in the pretraining checkpoint.
        backbone.load_state_dict(unparalled_state_dict, strict=False)

    elif hparams.model_name == 'UNet3D':
        # this is the architecture of Models Genesis minus the pretraining
        backbone = UNet3D(
            in_channels=in_channels,
            input_size=hparams.input_size,
            n_class=out_channels
        )
    else:
        raise NotImplementedError

    backbone = _replace_inplace_operations(backbone)
    backbone = _freeze_layers_if_any(backbone, hparams)

    return backbone