Spaces:
Running
on
Zero
Running
on
Zero
"""Copyright (C) 2024 Apple Inc. All Rights Reserved.

Dense Prediction Transformer Decoder architecture.

Implements a variant of Vision Transformers for Dense Prediction, https://arxiv.org/abs/2103.13413
"""
from __future__ import annotations | |
from typing import Iterable | |
import torch | |
from torch import nn | |
class MultiresConvDecoder(nn.Module):
    """Decoder for multi-resolution encodings.

    Projects every encoder level to ``dim_decoder`` channels and fuses the
    levels from the lowest resolution up to the highest (level 0).
    """

    def __init__(
        self,
        dims_encoder: Iterable[int],
        dim_decoder: int,
    ):
        """Initialize multiresolution convolutional decoder.

        Args:
        ----
            dims_encoder: Expected dims at each level from the encoder, ordered
                from the highest resolution (level 0) to the lowest.
            dim_decoder: Dim of decoder features.

        Raises:
        ------
            ValueError: If ``dims_encoder`` is empty.

        """
        super().__init__()
        self.dims_encoder = list(dims_encoder)
        self.dim_decoder = dim_decoder
        self.dim_out = dim_decoder

        num_encoders = len(self.dims_encoder)
        if num_encoders == 0:
            # Level 0 is indexed unconditionally below; fail with a clear message
            # instead of an IndexError.
            raise ValueError("dims_encoder must contain at least one level.")

        # At the highest resolution, i.e. level 0, we apply projection w/ 1x1 convolution
        # when the dimensions mismatch. Otherwise we do not do anything, which is
        # the default behavior of monodepth.
        conv0 = (
            nn.Conv2d(self.dims_encoder[0], dim_decoder, kernel_size=1, bias=False)
            if self.dims_encoder[0] != dim_decoder
            else nn.Identity()
        )
        # Every lower-resolution level is always projected with a 3x3 convolution.
        convs = [conv0] + [
            nn.Conv2d(
                dim,
                dim_decoder,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            )
            for dim in self.dims_encoder[1:]
        ]
        self.convs = nn.ModuleList(convs)

        # All levels except level 0 are built with deconv=True (a stride-2
        # transposed conv in FeatureFusionBlock2d), presumably so each fused
        # output matches the next higher-resolution level -- this assumes
        # adjacent encoder levels differ by 2x in spatial size.
        self.fusions = nn.ModuleList(
            [
                FeatureFusionBlock2d(
                    num_features=dim_decoder,
                    deconv=(i != 0),
                    batch_norm=False,
                )
                for i in range(num_encoders)
            ]
        )

    def forward(
        self, encodings: list[torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Decode the multi-resolution encodings.

        Args:
        ----
            encodings: One feature map per level, ordered from the highest
                resolution (index 0) to the lowest (index -1). Its length must
                equal ``len(self.dims_encoder)``.

        Returns:
        -------
            A tuple ``(features, lowres_features)``: the fused decoder output at
            level 0, and the projected (pre-fusion) features of the lowest level.

        Raises:
        ------
            ValueError: If the number of encodings differs from the number of
                encoder levels this decoder was built for.

        """
        num_levels = len(encodings)
        num_encoders = len(self.dims_encoder)
        if num_levels != num_encoders:
            raise ValueError(
                f"Got encoder output levels={num_levels}, expected levels={num_encoders}."
            )

        # Project features of different encoder dims to the same decoder dim.
        # Fuse features from the lowest resolution (num_levels-1)
        # to the highest (0).
        features = self.convs[-1](encodings[-1])
        # Keep the projected lowest-resolution features before any fusion.
        lowres_features = features
        features = self.fusions[-1](features)
        for i in range(num_levels - 2, -1, -1):
            features_i = self.convs[i](encodings[i])
            features = self.fusions[i](features, features_i)
        return features, lowres_features
class ResidualBlock(nn.Module):
    """Residual block computing ``shortcut(x) + residual(x)``.

    Generic residual formulation after
    He et al. - Identity Mappings in Deep Residual Networks (2016),
    https://arxiv.org/abs/1603.05027
    Both branches are supplied by the caller, so the block can be customized
    via factory functions.
    """

    def __init__(self, residual: nn.Module, shortcut: nn.Module | None = None) -> None:
        """Store the residual branch and the optional shortcut branch.

        Args:
        ----
            residual: Module that produces the additive update from the input.
            shortcut: Optional module applied to the input before the addition;
                when ``None`` the input itself is used (identity shortcut).

        """
        super().__init__()
        self.residual = residual
        self.shortcut = shortcut

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``shortcut(x) + residual(x)``, or ``x + residual(x)`` without a shortcut."""
        # The residual branch runs before the shortcut branch.
        update = self.residual(x)
        skip = x if self.shortcut is None else self.shortcut(x)
        return skip + update
class FeatureFusionBlock2d(nn.Module):
    """Feature fusion for DPT.

    Optionally adds residual-processed skip features to the main input, then
    refines the sum with a second residual block, an optional 2x transposed-conv
    upsampling, and a final 1x1 convolution.
    """

    def __init__(
        self,
        num_features: int,
        deconv: bool = False,
        batch_norm: bool = False,
    ):
        """Initialize feature fusion block.

        Args:
        ----
            num_features: Input and output dimensions.
            deconv: Whether to use deconv before the final output conv.
            batch_norm: Whether to use batch normalization in resnet blocks.

        """
        super().__init__()

        self.resnet1 = self._residual_block(num_features, batch_norm)
        self.resnet2 = self._residual_block(num_features, batch_norm)

        self.use_deconv = deconv
        if deconv:
            # kernel_size=2, stride=2, padding=0 gives an output of exactly
            # twice the input height and width.
            self.deconv = nn.ConvTranspose2d(
                in_channels=num_features,
                out_channels=num_features,
                kernel_size=2,
                stride=2,
                padding=0,
                bias=False,
            )

        self.out_conv = nn.Conv2d(
            num_features,
            num_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

        # FloatFunctional is PyTorch's quantization-friendly wrapper for tensor
        # ops; in a non-quantized model its add() is an ordinary float add.
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x0: torch.Tensor, x1: torch.Tensor | None = None) -> torch.Tensor:
        """Process and fuse input features.

        Args:
        ----
            x0: Main input features (in MultiresConvDecoder, the running output
                from the lower-resolution level).
            x1: Optional skip features. When given, they pass through
                ``resnet1`` and are added to ``x0``.

        Returns:
        -------
            The fused features after ``resnet2``, the optional deconv, and
            ``out_conv``.

        """
        x = x0

        if x1 is not None:
            res = self.resnet1(x1)
            x = self.skip_add.add(x, res)

        x = self.resnet2(x)

        if self.use_deconv:
            x = self.deconv(x)
        x = self.out_conv(x)

        return x

    @staticmethod
    def _residual_block(num_features: int, batch_norm: bool) -> ResidualBlock:
        """Create a residual block of two (ReLU, 3x3 conv[, BatchNorm]) units.

        This is a static method because it is invoked as
        ``self._residual_block(num_features, batch_norm)`` and needs no
        instance state.

        Args:
        ----
            num_features: Channel count of the block's input and output.
            batch_norm: Whether to append BatchNorm2d after each conv.

        Returns:
        -------
            A ``ResidualBlock`` with an identity shortcut.

        """

        def _create_block(dim: int, batch_norm: bool) -> list[nn.Module]:
            layers = [
                # Not in-place, so the block input is left intact for the
                # ``x + delta_x`` shortcut addition in ResidualBlock.
                nn.ReLU(False),
                nn.Conv2d(
                    dim,
                    dim,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    # A following BatchNorm provides its own shift, making a
                    # conv bias redundant.
                    bias=not batch_norm,
                ),
            ]
            if batch_norm:
                layers.append(nn.BatchNorm2d(dim))
            return layers

        # Sequential names children by index, so this layer order fixes the
        # state-dict keys; keep it stable if pretrained checkpoints are loaded.
        residual = nn.Sequential(
            *_create_block(dim=num_features, batch_norm=batch_norm),
            *_create_block(dim=num_features, batch_norm=batch_norm),
        )
        return ResidualBlock(residual)