# Copyright (C) 2024 Apple Inc. All Rights Reserved.
# DepthProEncoder combining patch and image encoders.

from __future__ import annotations

import math
from typing import Iterable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class DepthProEncoder(nn.Module):
    """DepthPro Encoder.

    An encoder aimed at creating multi-resolution encodings from Vision Transformers.
    """

    def __init__(
        self,
        dims_encoder: Iterable[int],
        patch_encoder: nn.Module,
        image_encoder: nn.Module,
        hook_block_ids: Iterable[int],
        decoder_features: int,
    ):
        """Initialize DepthProEncoder.

        The framework
            1. creates an image pyramid,
            2. generates overlapping patches with a sliding window at each pyramid level,
            3. creates batched encodings via vision transformer backbones,
            4. produces multi-resolution encodings.

        Args:
        ----
            dims_encoder: Dimensions of the encoder at different layers.
            patch_encoder: Backbone used for patches.
            image_encoder: Backbone used for the global image encoder.
            hook_block_ids: Indices of the patch-encoder blocks whose intermediate
                features are captured via forward hooks.
            decoder_features: Number of feature outputs in the decoder.

        """
        super().__init__()

        self.dims_encoder = list(dims_encoder)
        self.patch_encoder = patch_encoder
        self.image_encoder = image_encoder
        self.hook_block_ids = list(hook_block_ids)

        patch_encoder_embed_dim = patch_encoder.embed_dim
        image_encoder_embed_dim = image_encoder.embed_dim

        self.out_size = int(
            patch_encoder.patch_embed.img_size[0] // patch_encoder.patch_embed.patch_size[0]
        )
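        # out_size is the number of output tokens per image side from the patch
        # backbone, e.g. 384 // 16 = 24 for a 384-px backbone with 16-px patches
        # (hence the 24x24 patch feature maps referenced in forward()); the
        # exact value depends on the supplied backbone.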

        def _create_project_upsample_block(
            dim_in: int,
            dim_out: int,
            upsample_layers: int,
            dim_int: Optional[int] = None,
        ) -> nn.Module:
            if dim_int is None:
                dim_int = dim_out
            # Projection.
            blocks = [
                nn.Conv2d(
                    in_channels=dim_in,
                    out_channels=dim_int,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=False,
                )
            ]

            # Upsampling.
            blocks += [
                nn.ConvTranspose2d(
                    in_channels=dim_int if i == 0 else dim_out,
                    out_channels=dim_out,
                    kernel_size=2,
                    stride=2,
                    padding=0,
                    bias=False,
                )
                for i in range(upsample_layers)
            ]

            return nn.Sequential(*blocks)

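        # Each projection/upsampling block maps (B, dim_in, H, W) to
        # (B, dim_out, 2**upsample_layers * H, 2**upsample_layers * W): a 1x1
        # projection followed by `upsample_layers` stride-2 transposed
        # convolutions.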
        self.upsample_latent0 = _create_project_upsample_block(
            dim_in=patch_encoder_embed_dim,
            dim_int=self.dims_encoder[0],
            dim_out=decoder_features,
            upsample_layers=3,
        )
        self.upsample_latent1 = _create_project_upsample_block(
            dim_in=patch_encoder_embed_dim, dim_out=self.dims_encoder[0], upsample_layers=2
        )

        self.upsample0 = _create_project_upsample_block(
            dim_in=patch_encoder_embed_dim, dim_out=self.dims_encoder[1], upsample_layers=1
        )
        self.upsample1 = _create_project_upsample_block(
            dim_in=patch_encoder_embed_dim, dim_out=self.dims_encoder[2], upsample_layers=1
        )
        self.upsample2 = _create_project_upsample_block(
            dim_in=patch_encoder_embed_dim, dim_out=self.dims_encoder[3], upsample_layers=1
        )

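        # The global (whole-image) encoding is upsampled 2x and then fused with
        # the lowest-resolution patch encoding along the channel dimension in
        # forward().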
        self.upsample_lowres = nn.ConvTranspose2d(
            in_channels=image_encoder_embed_dim,
            out_channels=self.dims_encoder[3],
            kernel_size=2,
            stride=2,
            padding=0,
            bias=True,
        )
        self.fuse_lowres = nn.Conv2d(
            in_channels=(self.dims_encoder[3] + self.dims_encoder[3]),
            out_channels=self.dims_encoder[3],
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

        # Obtain intermediate outputs of the blocks.
        self.patch_encoder.blocks[self.hook_block_ids[0]].register_forward_hook(
            self._hook0
        )
        self.patch_encoder.blocks[self.hook_block_ids[1]].register_forward_hook(
            self._hook1
        )

    def _hook0(self, model, input, output):
        self.backbone_highres_hook0 = output

    def _hook1(self, model, input, output):
        self.backbone_highres_hook1 = output

    @property
    def img_size(self) -> int:
        """Return the full image size of the SPN network."""
        return self.patch_encoder.patch_embed.img_size[0] * 4

    def _create_pyramid(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Create a 3-level image pyramid."""
        # Original resolution: 1536 by default.
        x0 = x

        # Middle resolution: 768 by default.
        x1 = F.interpolate(
            x, size=None, scale_factor=0.5, mode="bilinear", align_corners=False
        )

        # Low resolution: 384 by default, corresponding to the backbone resolution.
        x2 = F.interpolate(
            x, size=None, scale_factor=0.25, mode="bilinear", align_corners=False
        )

        return x0, x1, x2

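    # split() and merge() implement the sliding-window patching used for the
    # two higher pyramid levels; the lowest level is already at the backbone
    # resolution and is passed through unchanged.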
    def split(self, x: torch.Tensor, overlap_ratio: float = 0.25) -> torch.Tensor:
        """Split the input into small patches with sliding window."""
        patch_size = 384
        patch_stride = int(patch_size * (1 - overlap_ratio))

        image_size = x.shape[-1]
        steps = int(math.ceil((image_size - patch_size) / patch_stride)) + 1

        x_patch_list = []
        for j in range(steps):
            j0 = j * patch_stride
            j1 = j0 + patch_size

            for i in range(steps):
                i0 = i * patch_stride
                i1 = i0 + patch_size
                x_patch_list.append(x[..., j0:j1, i0:i1])

        return torch.cat(x_patch_list, dim=0)

    def merge(self, x: torch.Tensor, batch_size: int, padding: int = 3) -> torch.Tensor:
        """Merge the patched input into an image with sliding window."""
        steps = int(math.sqrt(x.shape[0] // batch_size))

        idx = 0

        output_list = []
        for j in range(steps):
            output_row_list = []
            for i in range(steps):
                output = x[batch_size * idx : batch_size * (idx + 1)]

                if j != 0:
                    output = output[..., padding:, :]
                if i != 0:
                    output = output[..., :, padding:]
                if j != steps - 1:
                    output = output[..., :-padding, :]
                if i != steps - 1:
                    output = output[..., :, :-padding]

                output_row_list.append(output)
                idx += 1

            output_row = torch.cat(output_row_list, dim=-1)
            output_list.append(output_row)

        output = torch.cat(output_list, dim=-2)
        return output

    def reshape_feature(
        self, embeddings: torch.Tensor, width, height, cls_token_offset=1
    ):
        """Discard class token and reshape 1D feature map to a 2D grid."""
        b, hw, c = embeddings.shape

        # Remove class token.
        if cls_token_offset > 0:
            embeddings = embeddings[:, cls_token_offset:, :]

        # Shape: (batch, height, width, dim) -> (batch, dim, height, width)
        embeddings = embeddings.reshape(b, height, width, c).permute(0, 3, 1, 2)
        return embeddings

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Encode input at multiple resolutions.

        Args:
        ----
            x (torch.Tensor): Input image.

        Returns:
        -------
            Multi-resolution encoded features.

        """
        batch_size = x.shape[0]

        # Step 0: create a 3-level image pyramid.
        x0, x1, x2 = self._create_pyramid(x)

        # Step 1: split to create batched overlapped mini-images at the backbone (BeiT/ViT/Dino)
        # resolution.
        # 5x5 @ 384x384 at the highest resolution (1536x1536).
        x0_patches = self.split(x0, overlap_ratio=0.25)
        # 3x3 @ 384x384 at the middle resolution (768x768).
        x1_patches = self.split(x1, overlap_ratio=0.5)
        # 1x1 @ 384x384 at the lowest resolution (384x384).
        x2_patches = x2

        # Concatenate all the sliding window patches and form a batch of size (35=5x5+3x3+1x1).
        x_pyramid_patches = torch.cat(
            (x0_patches, x1_patches, x2_patches),
            dim=0,
        )

        # Step 2: Run the backbone (BeiT) model and get the result of large batch size.
        x_pyramid_encodings = self.patch_encoder(x_pyramid_patches)
        x_pyramid_encodings = self.reshape_feature(
            x_pyramid_encodings, self.out_size, self.out_size
        )

        # Step 3: merging.
        # Merge highres latent encoding.
        x_latent0_encodings = self.reshape_feature(
            self.backbone_highres_hook0,
            self.out_size,
            self.out_size,
        )
        x_latent0_features = self.merge(
            x_latent0_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=3
        )

        x_latent1_encodings = self.reshape_feature(
            self.backbone_highres_hook1,
            self.out_size,
            self.out_size,
        )
        x_latent1_features = self.merge(
            x_latent1_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=3
        )

        # Split the 35 batch size from pyramid encoding back into 5x5+3x3+1x1.
        x0_encodings, x1_encodings, x2_encodings = torch.split(
            x_pyramid_encodings,
            [len(x0_patches), len(x1_patches), len(x2_patches)],
            dim=0,
        )

        # 96x96 feature maps by merging 5x5 @ 24x24 patches with overlaps.
        x0_features = self.merge(x0_encodings, batch_size=batch_size, padding=3)
        # 48x48 feature maps by merging 3x3 @ 24x24 patches with overlaps.
        x1_features = self.merge(x1_encodings, batch_size=batch_size, padding=6)

        # 24x24 feature maps.
        x2_features = x2_encodings

        # Apply the image encoder model.
        x_global_features = self.image_encoder(x2_patches)
        x_global_features = self.reshape_feature(
            x_global_features, self.out_size, self.out_size
        )

        # Upsample feature maps.
        x_latent0_features = self.upsample_latent0(x_latent0_features)
        x_latent1_features = self.upsample_latent1(x_latent1_features)

        x0_features = self.upsample0(x0_features)
        x1_features = self.upsample1(x1_features)
        x2_features = self.upsample2(x2_features)

        x_global_features = self.upsample_lowres(x_global_features)
        x_global_features = self.fuse_lowres(
            torch.cat((x2_features, x_global_features), dim=1)
        )

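        # With the default 1536-px input, the returned maps have spatial sizes
        # 768, 384, 192, 96 and 48 and channel widths decoder_features and
        # dims_encoder[0..3], ordered finest to coarsest.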
        return [
            x_latent0_features,
            x_latent1_features,
            x0_features,
            x1_features,
            x_global_features,
        ]
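

# -----------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module). The real model
# plugs in pretrained ViT backbones; the dummy backbone below only mimics the
# interface DepthProEncoder relies on: `embed_dim`, an indexable `blocks`
# container, `patch_embed.img_size` / `patch_embed.patch_size`, and a forward
# pass that returns the full token sequence including a class token. All
# names, dimensions and hook indices here are illustrative assumptions.
# -----------------------------------------------------------------------------
if __name__ == "__main__":

    class _DummyPatchEmbed(nn.Module):
        def __init__(self):
            super().__init__()
            self.img_size = (384, 384)
            self.patch_size = (16, 16)
            self.proj = nn.Conv2d(3, 64, kernel_size=16, stride=16)

        def forward(self, x):
            # (B, 3, 384, 384) -> (B, 24 * 24, 64)
            return self.proj(x).flatten(2).transpose(1, 2)

    class _DummyViT(nn.Module):
        def __init__(self):
            super().__init__()
            self.embed_dim = 64
            self.patch_embed = _DummyPatchEmbed()
            self.blocks = nn.ModuleList([nn.Identity() for _ in range(4)])
            self.cls_token = nn.Parameter(torch.zeros(1, 1, 64))

        def forward(self, x):
            tokens = self.patch_embed(x)
            cls = self.cls_token.expand(tokens.shape[0], -1, -1)
            tokens = torch.cat((cls, tokens), dim=1)
            for blk in self.blocks:
                tokens = blk(tokens)
            return tokens

    encoder = DepthProEncoder(
        dims_encoder=[32, 64, 128, 128],
        patch_encoder=_DummyViT(),
        image_encoder=_DummyViT(),
        hook_block_ids=[1, 3],
        decoder_features=32,
    )
    # 1536 = 4x the 384-px backbone resolution (see the img_size property).
    features = encoder(torch.rand(1, 3, 1536, 1536))
    print([tuple(f.shape) for f in features])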