File size: 12,027 Bytes
4c65bff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 |
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Collection of utils to be used by backbones and their components."""
import enum
import inspect
from typing import Iterable, List, Optional, Tuple, Union
class BackboneType(enum.Enum):
    """
    Enumeration of the libraries a backbone can be loaded from.
    """

    # Backbone provided by the `timm` library.
    TIMM = "timm"
    # Backbone provided by the `transformers` library.
    TRANSFORMERS = "transformers"
def verify_out_features_out_indices(
    out_features: Optional[Iterable[str]], out_indices: Optional[Iterable[int]], stage_names: Optional[Iterable[str]]
):
    """
    Verify that out_indices and out_features are valid for the given stage_names.

    Args:
        out_features (`List[str]`, *optional*): The names of the features for the backbone to output.
        out_indices (`List[int]` or `Tuple[int]`, *optional*): The indices of the features for the backbone to
            output. Negative indices are accepted as long as they address an existing stage.
        stage_names (`List[str]`): The names of the stages of the backbone.

    Raises:
        ValueError: If `stage_names` is `None`, if `out_features` is not a list or contains a name that is not in
            `stage_names`, if `out_indices` is not a list or tuple or contains an index outside of `stage_names`,
            or if both are set but do not have the same length or do not refer to the same stages.
    """
    if stage_names is None:
        raise ValueError("Stage_names must be set for transformers backbones")

    if out_features is not None:
        if not isinstance(out_features, (list,)):
            raise ValueError(f"out_features must be a list {type(out_features)}")
        if any(feat not in stage_names for feat in out_features):
            raise ValueError(f"out_features must be a subset of stage_names: {stage_names} got {out_features}")

    if out_indices is not None:
        if not isinstance(out_indices, (list, tuple)):
            raise ValueError(f"out_indices must be a list or tuple, got {type(out_indices)}")
        num_stages = len(stage_names)
        # Reject indices that fall off either end of `stage_names`; an out-of-range negative index would
        # otherwise slip through here and only fail later with a bare IndexError.
        if any(idx >= num_stages or idx < -num_stages for idx in out_indices):
            raise ValueError(f"out_indices must be valid indices for stage_names {stage_names}, got {out_indices}")

    if out_features is not None and out_indices is not None:
        if len(out_features) != len(out_indices):
            raise ValueError("out_features and out_indices should have the same length if both are set")
        if out_features != [stage_names[idx] for idx in out_indices]:
            raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
def _align_output_features_output_indices(
out_features: Optional[List[str]],
out_indices: Optional[Union[List[int], Tuple[int]]],
stage_names: List[str],
):
"""
Finds the corresponding `out_features` and `out_indices` for the given `stage_names`.
The logic is as follows:
- `out_features` not set, `out_indices` set: `out_features` is set to the `out_features` corresponding to the
`out_indices`.
- `out_indices` not set, `out_features` set: `out_indices` is set to the `out_indices` corresponding to the
`out_features`.
- `out_indices` and `out_features` not set: `out_indices` and `out_features` are set to the last stage.
- `out_indices` and `out_features` set: input `out_indices` and `out_features` are returned.
Args:
out_features (`List[str]`): The names of the features for the backbone to output.
out_indices (`List[int]` or `Tuple[int]`): The indices of the features for the backbone to output.
stage_names (`List[str]`): The names of the stages of the backbone.
"""
if out_indices is None and out_features is None:
out_indices = [len(stage_names) - 1]
out_features = [stage_names[-1]]
elif out_indices is None and out_features is not None:
out_indices = [stage_names.index(layer) for layer in out_features]
elif out_features is None and out_indices is not None:
out_features = [stage_names[idx] for idx in out_indices]
return out_features, out_indices
def get_aligned_output_features_output_indices(
    out_features: Optional[List[str]],
    out_indices: Optional[Union[List[int], Tuple[int]]],
    stage_names: List[str],
) -> Tuple[List[str], List[int]]:
    """
    Return `out_features` and `out_indices` that are validated and aligned with each other.

    The inputs are validated, the missing value (if any) is derived from the other one, and the resulting pair is
    validated once more before being returned:
        - only `out_indices` set: `out_features` is set to the stage names corresponding to the `out_indices`.
        - only `out_features` set: `out_indices` is set to the indices corresponding to the `out_features`.
        - neither set: `out_indices` and `out_features` are set to the last stage.
        - both set: they are verified to be aligned.

    Args:
        out_features (`List[str]`): The names of the features for the backbone to output.
        out_indices (`List[int]` or `Tuple[int]`): The indices of the features for the backbone to output.
        stage_names (`List[str]`): The names of the stages of the backbone.

    Returns:
        A `(out_features, out_indices)` tuple.
    """
    # Validate what the caller handed in before deriving anything from it.
    verify_out_features_out_indices(out_features, out_indices, stage_names)

    aligned_pair = _align_output_features_output_indices(out_features, out_indices, stage_names)
    aligned_features, aligned_indices = aligned_pair

    # Validate again: the derived values must be consistent with `stage_names` too.
    verify_out_features_out_indices(aligned_features, aligned_indices, stage_names)
    return aligned_features, aligned_indices
class BackboneMixin:
    """
    Mixin that keeps track of the stages a backbone exposes (`stage_names`, `num_features`) and of the subset of
    stages it should output (`out_features` / `out_indices`).
    """

    backbone_type: Optional[BackboneType] = None

    def _init_timm_backbone(self, config) -> None:
        """
        Initialize the backbone attributes from a timm model. The model must already be stored in
        `self._backbone`, otherwise a `ValueError` is raised.
        """
        if getattr(self, "_backbone", None) is None:
            raise ValueError("self._backbone must be set before calling _init_timm_backbone")

        feature_info = self._backbone.feature_info
        stage_infos = feature_info.info

        # These will disagree with the defaults for the transformers models e.g. for resnet50
        # the transformer model has out_features = ['stem', 'stage1', 'stage2', 'stage3', 'stage4']
        # the timm model has out_features = ['act', 'layer1', 'layer2', 'layer3', 'layer4']
        self.stage_names = [info["module"] for info in stage_infos]
        self.num_features = [info["num_chs"] for info in stage_infos]

        timm_indices = feature_info.out_indices
        timm_features = feature_info.module_name()

        # Only validate here: the values reported by timm are stored as they are.
        verify_out_features_out_indices(
            out_features=timm_features, out_indices=timm_indices, stage_names=self.stage_names
        )
        self._out_features = timm_features
        self._out_indices = timm_indices

    def _init_transformers_backbone(self, config) -> None:
        """
        Initialize the backbone attributes from `config`: `stage_names` is required, `out_features` and
        `out_indices` are optional and aligned with each other.
        """
        stage_names = config.stage_names
        requested_features = getattr(config, "out_features", None)
        requested_indices = getattr(config, "out_indices", None)

        self.stage_names = stage_names
        aligned = get_aligned_output_features_output_indices(
            out_features=requested_features, out_indices=requested_indices, stage_names=stage_names
        )
        self._out_features, self._out_indices = aligned
        # Number of channels for each stage. This is set in the transformer backbone model init
        self.num_features = None

    def _init_backbone(self, config) -> None:
        """
        Initialize the backbone: store `config`, pick the backbone type from `config.use_timm_backbone` (defaults
        to `False`) and run the matching initializer.

        NOTE(review): expected to be called by the base class constructor once the pretrained weights are loaded;
        the caller is not visible in this file.
        """
        self.config = config
        self.use_timm_backbone = getattr(config, "use_timm_backbone", False)

        if self.use_timm_backbone:
            self.backbone_type = BackboneType.TIMM
        else:
            self.backbone_type = BackboneType.TRANSFORMERS

        if self.backbone_type == BackboneType.TIMM:
            self._init_timm_backbone(config)
            return
        if self.backbone_type == BackboneType.TRANSFORMERS:
            self._init_transformers_backbone(config)
            return
        raise ValueError(f"backbone_type {self.backbone_type} not supported.")

    @property
    def out_features(self):
        """The names of the stages the backbone outputs."""
        return self._out_features

    @out_features.setter
    def out_features(self, out_features: List[str]):
        """
        Set the out_features attribute. The out_indices attribute is recomputed so that it matches the new
        out_features.
        """
        aligned = get_aligned_output_features_output_indices(
            out_features=out_features, out_indices=None, stage_names=self.stage_names
        )
        self._out_features, self._out_indices = aligned

    @property
    def out_indices(self):
        """The indices of the stages the backbone outputs."""
        return self._out_indices

    @out_indices.setter
    def out_indices(self, out_indices: Union[Tuple[int], List[int]]):
        """
        Set the out_indices attribute. The out_features attribute is recomputed so that it matches the new
        out_indices.
        """
        aligned = get_aligned_output_features_output_indices(
            out_features=None, out_indices=out_indices, stage_names=self.stage_names
        )
        self._out_features, self._out_indices = aligned

    @property
    def out_feature_channels(self):
        """Mapping from every stage name to its number of channels."""
        # the current backbones will output the number of channels for each stage
        # even if that stage is not in the out_features list.
        channels_by_stage = {}
        for position, stage_name in enumerate(self.stage_names):
            channels_by_stage[stage_name] = self.num_features[position]
        return channels_by_stage

    @property
    def channels(self):
        """Number of channels of each stage listed in `out_features`, in the same order."""
        selected = []
        for feature_name in self.out_features:
            selected.append(self.out_feature_channels[feature_name])
        return selected

    def forward_with_filtered_kwargs(self, *args, **kwargs):
        """
        Call the instance with `args` and only those `kwargs` whose name is a parameter of `self.forward`.
        """
        accepted = inspect.signature(self.forward).parameters
        kept = {name: value for name, value in kwargs.items() if name in accepted}
        return self(*args, **kept)

    def forward(
        self,
        pixel_values,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Placeholder that always raises `NotImplementedError`; derived classes provide the implementation."""
        raise NotImplementedError("This method should be implemented by the derived class.")

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Starts from the dictionary returned by the parent class
        `to_dict()` and exposes its `_out_features` and `_out_indices` entries under the public `out_features` and
        `out_indices` keys.
        """
        serialized = super().to_dict()
        for public_key in ("out_features", "out_indices"):
            serialized[public_key] = serialized.pop(f"_{public_key}")
        return serialized
class BackboneConfigMixin:
    """
    A Mixin to support handling the `out_features` and `out_indices` attributes for the backbone configurations.
    """

    @property
    def out_features(self):
        """The names of the stages the backbone outputs."""
        return self._out_features

    @out_features.setter
    def out_features(self, out_features: List[str]):
        """
        Set the out_features attribute. The out_indices attribute is recomputed so that it matches the new
        out_features.
        """
        aligned = get_aligned_output_features_output_indices(
            out_features=out_features, out_indices=None, stage_names=self.stage_names
        )
        self._out_features, self._out_indices = aligned

    @property
    def out_indices(self):
        """The indices of the stages the backbone outputs."""
        return self._out_indices

    @out_indices.setter
    def out_indices(self, out_indices: Union[Tuple[int], List[int]]):
        """
        Set the out_indices attribute. The out_features attribute is recomputed so that it matches the new
        out_indices.
        """
        aligned = get_aligned_output_features_output_indices(
            out_features=None, out_indices=out_indices, stage_names=self.stage_names
        )
        self._out_features, self._out_indices = aligned

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Starts from the dictionary returned by the parent class
        `to_dict()` and exposes its `_out_features` and `_out_indices` entries under the public `out_features` and
        `out_indices` keys.
        """
        serialized = super().to_dict()
        for public_key in ("out_features", "out_indices"):
            serialized[public_key] = serialized.pop(f"_{public_key}")
        return serialized
|