# # Copyright (c) OpenMMLab. All rights reserved. | |
# import logging | |
# from abc import ABCMeta, abstractmethod | |
# | |
# import torch.nn as nn | |
# | |
# from .utils import load_checkpoint | |
# | |
# | |
# class BaseBackbone(nn.Module, metaclass=ABCMeta): | |
# """Base backbone. | |
# | |
# This class defines the basic functions of a backbone. Any backbone that | |
# inherits this class should at least define its own `forward` function. | |
# """ | |
# | |
# def init_weights(self, pretrained=None): | |
# """Init backbone weights. | |
# | |
# Args: | |
# pretrained (str | None): If pretrained is a string, then it | |
# initializes backbone weights by loading the pretrained | |
# checkpoint. If pretrained is None, then it follows default | |
# initializer or customized initializer in subclasses. | |
# """ | |
# if isinstance(pretrained, str): | |
# logger = logging.getLogger() | |
# load_checkpoint(self, pretrained, strict=False, logger=logger) | |
# elif pretrained is None: | |
# # use default initializer or customized initializer in subclasses | |
# pass | |
# else: | |
# raise TypeError('pretrained must be a str or None.' | |
# f' But received {type(pretrained)}.') | |
# | |
# @abstractmethod | |
# def forward(self, x): | |
# """Forward function. | |
# | |
# Args: | |
# x (Tensor | tuple[Tensor]): x could be a torch.Tensor or a tuple of | |
# torch.Tensor, containing input data for forward computation. | |
# """ | |
# Copyright (c) OpenMMLab. All rights reserved. | |
import logging | |
from abc import ABCMeta, abstractmethod | |
import torch.nn as nn | |
from .utils import load_checkpoint | |
# from mmcv_custom.checkpoint import load_checkpoint | |
class BaseBackbone(nn.Module, metaclass=ABCMeta):
    """Common parent class for backbone networks.

    It provides shared weight initialisation through ``init_weights`` and a
    placeholder ``forward``. Concrete backbones are expected to override
    ``forward`` with their own computation.
    """

    def init_weights(self, pretrained=None, patch_padding='pad'):
        """Initialise the backbone, optionally from a pretrained checkpoint.

        Args:
            pretrained (str | None): When a string, it is handed to
                ``load_checkpoint`` to load weights into this module. When
                None, nothing is loaded, so the default initializer or a
                customized initializer in a subclass stays in effect.
            patch_padding (str): Forwarded unchanged to ``load_checkpoint``
                as its ``patch_padding`` keyword; the accepted values are
                defined by that helper. Defaults to ``'pad'``.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if pretrained is None:
            # Nothing to load: keep the default or subclass-specific
            # initialisation.
            return

        if not isinstance(pretrained, str):
            received = type(pretrained)
            raise TypeError('pretrained must be a str or None.'
                            f' But received {received}.')

        # The checkpoint is loaded with strict=False and reports through the
        # root logger.
        load_checkpoint(
            self,
            pretrained,
            strict=False,
            logger=logging.getLogger(),
            patch_padding=patch_padding,
        )

    def forward(self, x):
        """Placeholder forward pass, meant to be overridden by subclasses.

        Args:
            x (Tensor | tuple[Tensor]): x could be a torch.Tensor or a tuple
                of torch.Tensor, containing input data for forward
                computation.

        Returns:
            None: The base implementation performs no computation.
        """
        # NOTE(review): this method is not decorated with @abstractmethod, so
        # a subclass that forgets to override it silently yields None instead
        # of failing at instantiation -- confirm this is intentional.
        return None