Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3 | |
# Portions Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import logging | |
import os | |
from functools import partial | |
from types import SimpleNamespace | |
import torch | |
import torch.nn as nn | |
# from pytorch_lightning.utilities import rank_zero_only | |
from .helpers import (EinOpsRearrange, LearnableLogitScaling, Normalize, | |
SelectElement, SelectEOSAndProject) | |
from .multimodal_preprocessors import (AudioPreprocessor, | |
IMUPreprocessor, PadIm2Video, | |
PatchEmbedGeneric, | |
RGBDTPreprocessor, | |
SpatioTemporalPosEmbeddingHelper, | |
TextPreprocessor, | |
ThermalPreprocessor) | |
from .transformer import MultiheadAttention, SimpleTransformer | |
# Canonical string identifiers for each supported modality. Attribute names
# are the upper-cased identifiers (e.g. ``ModalityType.VISION == "vision"``);
# the string values are the keys used in ImageBindModel's per-modality
# ModuleDicts and in the ``inputs`` dict passed to its ``forward``.
ModalityType = SimpleNamespace(
    **{
        modality.upper(): modality
        for modality in (
            "vision",
            "text",
            "audio",
            "thermal",
            "depth",
            "imu",
            "point",
        )
    }
)
class ImageBindModel(nn.Module):
    """ImageBind-style multimodal encoder producing per-modality embeddings.

    For each modality (vision, text, audio, depth, thermal, IMU) the model
    holds four stages, each stored in an ``nn.ModuleDict`` keyed by the
    ``ModalityType`` string:

    1. ``modality_preprocessors`` -- turn the raw input into the keyword
       inputs for the trunk (``"trunk"``) and the head (``"head"``);
    2. ``modality_trunks`` -- a ``SimpleTransformer`` encoder, called with
       ``out_layers=self.out_layers``;
    3. ``modality_heads`` -- layer norm, token selection and a bias-free
       linear projection to ``out_embed_dim``;
    4. ``modality_postprocessors`` -- ``Normalize(dim=-1)``, followed by a
       ``LearnableLogitScaling`` for every modality except vision.

    ``ModalityType.POINT`` has no components in this model.
    """

    def __init__(
        self,
        video_frames=2,
        kernel_size=(2, 14, 14),
        audio_kernel_size=16,
        audio_stride=10,
        out_embed_dim=768,
        vision_embed_dim=1024,
        vision_num_blocks=24,
        vision_num_heads=16,
        audio_embed_dim=768,
        audio_num_blocks=12,
        audio_num_heads=12,
        audio_num_mel_bins=128,
        audio_target_len=204,
        audio_drop_path=0.1,
        text_embed_dim=768,
        text_num_blocks=12,
        text_num_heads=12,
        depth_embed_dim=384,
        depth_kernel_size=16,
        depth_num_blocks=12,
        depth_num_heads=8,
        depth_drop_path=0.0,
        thermal_embed_dim=768,
        thermal_kernel_size=16,
        thermal_num_blocks=12,
        thermal_num_heads=12,
        thermal_drop_path=0.0,
        imu_embed_dim=512,
        imu_kernel_size=8,
        imu_num_blocks=6,
        imu_num_heads=8,
        imu_drop_path=0.7,
        layers=(7, 15, 23, 31),
    ):
        """Create the preprocessors, trunks, heads and postprocessors.

        Args:
            video_frames: temporal size in the vision preprocessor's
                ``img_size`` (``[3, video_frames, 224, 224]``).
            kernel_size: kernel size and stride of the vision ``Conv3d`` stem.
            audio_kernel_size, audio_stride: kernel size / stride of the audio
                ``Conv2d`` stem.
            out_embed_dim: output width of every projection head.
            audio_num_mel_bins, audio_target_len: the audio preprocessor's
                ``img_size`` is ``[1, audio_num_mel_bins, audio_target_len]``.
            depth_kernel_size, thermal_kernel_size: kernel size and stride of
                the depth / thermal ``Conv2d`` stems.
            *_embed_dim, *_num_blocks, *_num_heads: width, number of blocks
                and attention heads of each modality trunk.
            *_drop_path: drop-path rate of the audio / depth / thermal / IMU
                trunks (vision and text trunks always use 0.0).
            imu_kernel_size: accepted but currently unused -- the IMU patch
                size is hard-coded to 8 in ``_create_modality_preprocessors``.
            layers: trunk block indices passed as ``out_layers`` to every
                trunk in ``forward``.
        """
        super().__init__()
        # Tuple default + private list copy: instances never share (or alias
        # the caller's) mutable list, which a list default would cause.
        self.out_layers = list(layers)
        self.modality_preprocessors = self._create_modality_preprocessors(
            video_frames,
            vision_embed_dim,
            kernel_size,
            text_embed_dim,
            audio_embed_dim,
            audio_kernel_size,
            audio_stride,
            audio_num_mel_bins,
            audio_target_len,
            depth_embed_dim,
            depth_kernel_size,
            thermal_embed_dim,
            thermal_kernel_size,
            imu_embed_dim,
        )
        self.modality_trunks = self._create_modality_trunks(
            vision_embed_dim,
            vision_num_blocks,
            vision_num_heads,
            text_embed_dim,
            text_num_blocks,
            text_num_heads,
            audio_embed_dim,
            audio_num_blocks,
            audio_num_heads,
            audio_drop_path,
            depth_embed_dim,
            depth_num_blocks,
            depth_num_heads,
            depth_drop_path,
            thermal_embed_dim,
            thermal_num_blocks,
            thermal_num_heads,
            thermal_drop_path,
            imu_embed_dim,
            imu_num_blocks,
            imu_num_heads,
            imu_drop_path,
        )
        self.modality_heads = self._create_modality_heads(
            out_embed_dim,
            vision_embed_dim,
            text_embed_dim,
            audio_embed_dim,
            depth_embed_dim,
            thermal_embed_dim,
            imu_embed_dim,
        )
        self.modality_postprocessors = self._create_modality_postprocessors(
            out_embed_dim
        )

    def _create_modality_preprocessors(
        self,
        video_frames=2,
        vision_embed_dim=1024,
        kernel_size=(2, 14, 14),
        text_embed_dim=768,
        audio_embed_dim=768,
        audio_kernel_size=16,
        audio_stride=10,
        audio_num_mel_bins=128,
        audio_target_len=204,
        depth_embed_dim=768,
        depth_kernel_size=16,
        thermal_embed_dim=768,
        thermal_kernel_size=16,
        imu_embed_dim=512,
    ):
        """Build one input preprocessor per modality.

        Every patch-based preprocessor uses one class token and a learnable
        ``SpatioTemporalPosEmbeddingHelper`` positional embedding.

        Returns:
            ``nn.ModuleDict`` keyed by ``ModalityType`` strings.
        """
        # Vision: image/video -> 3-D patches via a Conv3d whose stride equals
        # its kernel (non-overlapping patches). Note ntimes is fixed at 2 and
        # is not tied to ``video_frames``.
        rgbt_stem = PatchEmbedGeneric(
            proj_stem=[
                PadIm2Video(pad_type="repeat", ntimes=2),
                nn.Conv3d(
                    in_channels=3,
                    kernel_size=kernel_size,
                    out_channels=vision_embed_dim,
                    stride=kernel_size,
                    bias=False,
                ),
            ]
        )
        rgbt_preprocessor = RGBDTPreprocessor(
            img_size=[3, video_frames, 224, 224],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            rgbt_stem=rgbt_stem,
            depth_stem=None,
        )

        text_preprocessor = TextPreprocessor(
            context_length=77,
            vocab_size=49408,
            embed_dim=text_embed_dim,
            causal_masking=True,
        )

        # Audio: single-channel (mel-bins x time) input; stride may be smaller
        # than the kernel, giving overlapping patches.
        audio_stem = PatchEmbedGeneric(
            proj_stem=[
                nn.Conv2d(
                    in_channels=1,
                    kernel_size=audio_kernel_size,
                    stride=audio_stride,
                    out_channels=audio_embed_dim,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=audio_embed_dim),
        )
        audio_preprocessor = AudioPreprocessor(
            img_size=[1, audio_num_mel_bins, audio_target_len],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            audio_stem=audio_stem,
        )

        depth_stem = PatchEmbedGeneric(
            [
                nn.Conv2d(
                    kernel_size=depth_kernel_size,
                    in_channels=1,
                    out_channels=depth_embed_dim,
                    stride=depth_kernel_size,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=depth_embed_dim),
        )
        depth_preprocessor = RGBDTPreprocessor(
            img_size=[1, 224, 224],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            rgbt_stem=None,
            depth_stem=depth_stem,
        )

        thermal_stem = PatchEmbedGeneric(
            [
                nn.Conv2d(
                    kernel_size=thermal_kernel_size,
                    in_channels=1,
                    out_channels=thermal_embed_dim,
                    stride=thermal_kernel_size,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=thermal_embed_dim),
        )
        thermal_preprocessor = ThermalPreprocessor(
            img_size=[1, 224, 224],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            thermal_stem=thermal_stem,
        )

        # IMU: in_features=48 presumably equals 6 channels (img_size[0]) x
        # patch length 8 (kernel_size below) -- TODO confirm against
        # IMUPreprocessor; both values are hard-coded here.
        imu_stem = PatchEmbedGeneric(
            [
                nn.Linear(
                    in_features=48,
                    out_features=imu_embed_dim,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=imu_embed_dim),
        )
        imu_preprocessor = IMUPreprocessor(
            img_size=[6, 2000],
            num_cls_tokens=1,
            kernel_size=8,
            embed_dim=imu_embed_dim,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            imu_stem=imu_stem,
        )

        modality_preprocessors = {
            ModalityType.VISION: rgbt_preprocessor,
            ModalityType.TEXT: text_preprocessor,
            ModalityType.AUDIO: audio_preprocessor,
            ModalityType.DEPTH: depth_preprocessor,
            ModalityType.THERMAL: thermal_preprocessor,
            ModalityType.IMU: imu_preprocessor,
        }
        return nn.ModuleDict(modality_preprocessors)

    def _create_modality_trunks(
        self,
        vision_embed_dim=1024,
        vision_num_blocks=24,
        vision_num_heads=16,
        text_embed_dim=768,
        text_num_blocks=12,
        text_num_heads=12,
        audio_embed_dim=768,
        audio_num_blocks=12,
        audio_num_heads=12,
        audio_drop_path=0.0,
        depth_embed_dim=768,
        depth_num_blocks=12,
        depth_num_heads=12,
        depth_drop_path=0.0,
        thermal_embed_dim=768,
        thermal_num_blocks=12,
        thermal_num_heads=12,
        thermal_drop_path=0.0,
        imu_embed_dim=512,
        imu_num_blocks=6,
        imu_num_heads=8,
        imu_drop_path=0.7,
    ):
        """Build one ``SimpleTransformer`` trunk per modality.

        Returns:
            ``nn.ModuleDict`` keyed by ``ModalityType`` strings.
        """

        def instantiate_trunk(
            embed_dim, num_blocks, num_heads, pre_transformer_ln, add_bias_kv, drop_path
        ):
            # The pre-layer optionally applies LayerNorm and then swaps the
            # first two axes ("b l d -> l b d"); the post-layer swaps them back,
            # so the trunk's blocks see sequence-first tensors.
            return SimpleTransformer(
                embed_dim=embed_dim,
                num_blocks=num_blocks,
                ffn_dropout_rate=0.0,
                drop_path_rate=drop_path,
                attn_target=partial(
                    MultiheadAttention,
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    bias=True,
                    add_bias_kv=add_bias_kv,
                ),
                pre_transformer_layer=nn.Sequential(
                    (
                        nn.LayerNorm(embed_dim, eps=1e-6)
                        if pre_transformer_ln
                        else nn.Identity()
                    ),
                    EinOpsRearrange("b l d -> l b d"),
                ),
                post_transformer_layer=EinOpsRearrange("l b d -> b l d"),
            )

        modality_trunks = {}
        # Only the vision trunk uses a pre-transformer LayerNorm; vision and
        # text are the only trunks without extra key/value bias and drop-path.
        modality_trunks[ModalityType.VISION] = instantiate_trunk(
            vision_embed_dim,
            vision_num_blocks,
            vision_num_heads,
            pre_transformer_ln=True,
            add_bias_kv=False,
            drop_path=0.0,
        )
        modality_trunks[ModalityType.TEXT] = instantiate_trunk(
            text_embed_dim,
            text_num_blocks,
            text_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=False,
            drop_path=0.0,
        )
        modality_trunks[ModalityType.AUDIO] = instantiate_trunk(
            audio_embed_dim,
            audio_num_blocks,
            audio_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=True,
            drop_path=audio_drop_path,
        )
        modality_trunks[ModalityType.DEPTH] = instantiate_trunk(
            depth_embed_dim,
            depth_num_blocks,
            depth_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=True,
            drop_path=depth_drop_path,
        )
        modality_trunks[ModalityType.THERMAL] = instantiate_trunk(
            thermal_embed_dim,
            thermal_num_blocks,
            thermal_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=True,
            drop_path=thermal_drop_path,
        )
        modality_trunks[ModalityType.IMU] = instantiate_trunk(
            imu_embed_dim,
            imu_num_blocks,
            imu_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=True,
            drop_path=imu_drop_path,
        )
        return nn.ModuleDict(modality_trunks)

    def _create_modality_heads(
        self,
        out_embed_dim,
        vision_embed_dim,
        text_embed_dim,
        audio_embed_dim,
        depth_embed_dim,
        thermal_embed_dim,
        imu_embed_dim,
    ):
        """Build the per-modality projection heads to ``out_embed_dim``.

        Non-text heads apply LayerNorm, keep token 0 (presumably the single
        class token added by the preprocessors -- confirm) and project with a
        bias-free Linear; the IMU head adds Dropout(0.5) before projecting.
        The text head delegates token selection to ``SelectEOSAndProject``.

        Returns:
            ``nn.ModuleDict`` keyed by ``ModalityType`` strings.
        """
        modality_heads = {}
        modality_heads[ModalityType.VISION] = nn.Sequential(
            nn.LayerNorm(normalized_shape=vision_embed_dim, eps=1e-6),
            SelectElement(index=0),
            nn.Linear(vision_embed_dim, out_embed_dim, bias=False),
        )
        modality_heads[ModalityType.TEXT] = SelectEOSAndProject(
            proj=nn.Sequential(
                nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6),
                nn.Linear(text_embed_dim, out_embed_dim, bias=False),
            )
        )
        modality_heads[ModalityType.AUDIO] = nn.Sequential(
            nn.LayerNorm(normalized_shape=audio_embed_dim, eps=1e-6),
            SelectElement(index=0),
            nn.Linear(audio_embed_dim, out_embed_dim, bias=False),
        )
        modality_heads[ModalityType.DEPTH] = nn.Sequential(
            nn.LayerNorm(normalized_shape=depth_embed_dim, eps=1e-6),
            SelectElement(index=0),
            nn.Linear(depth_embed_dim, out_embed_dim, bias=False),
        )
        modality_heads[ModalityType.THERMAL] = nn.Sequential(
            nn.LayerNorm(normalized_shape=thermal_embed_dim, eps=1e-6),
            SelectElement(index=0),
            nn.Linear(thermal_embed_dim, out_embed_dim, bias=False),
        )
        modality_heads[ModalityType.IMU] = nn.Sequential(
            nn.LayerNorm(normalized_shape=imu_embed_dim, eps=1e-6),
            SelectElement(index=0),
            nn.Dropout(p=0.5),
            nn.Linear(imu_embed_dim, out_embed_dim, bias=False),
        )
        return nn.ModuleDict(modality_heads)

    def _create_modality_postprocessors(self, out_embed_dim):
        """Build the per-modality output postprocessors.

        All modalities apply ``Normalize(dim=-1)``. Text adds a learnable
        logit scale (default init); audio, depth, thermal and IMU add a fixed
        scale of 20, 5, 10 and 5 respectively; vision has no scaling.
        ``out_embed_dim`` is accepted but not used here.

        Returns:
            ``nn.ModuleDict`` keyed by ``ModalityType`` strings.
        """
        modality_postprocessors = {}
        modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1)
        modality_postprocessors[ModalityType.TEXT] = nn.Sequential(
            Normalize(dim=-1), LearnableLogitScaling(learnable=True)
        )
        modality_postprocessors[ModalityType.AUDIO] = nn.Sequential(
            Normalize(dim=-1),
            LearnableLogitScaling(logit_scale_init=20.0, learnable=False),
        )
        modality_postprocessors[ModalityType.DEPTH] = nn.Sequential(
            Normalize(dim=-1),
            LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
        )
        modality_postprocessors[ModalityType.THERMAL] = nn.Sequential(
            Normalize(dim=-1),
            LearnableLogitScaling(logit_scale_init=10.0, learnable=False),
        )
        modality_postprocessors[ModalityType.IMU] = nn.Sequential(
            Normalize(dim=-1),
            LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
        )
        return nn.ModuleDict(modality_postprocessors)

    def forward(self, inputs):
        """Encode every modality present in ``inputs``.

        Args:
            inputs: dict mapping a ``ModalityType`` string to an input tensor.
                ``None`` values are skipped. Tensors with 5 or more dims are
                treated as ``(B, S, ...)`` -- ``S`` clips per sample: they are
                flattened to ``B * S`` samples before encoding and the final
                embeddings are averaged over ``S``.

        Returns:
            dict mapping each processed modality key to a tuple
            ``(embedding, full_value)``. ``embedding`` is the postprocessed
            head output; ``full_value`` is the second value returned by the
            trunk when called with ``out_layers`` (presumably the features of
            those blocks -- TODO confirm against SimpleTransformer). Note that
            ``full_value`` is NOT averaged over clips.
        """
        outputs = {}
        for modality_key, modality_value in inputs.items():
            # Must come before any tensor attribute access: the original read
            # ``.ndim`` first, so a None input raised AttributeError instead of
            # being skipped.
            if modality_value is None:
                continue

            # Audio and video inputs may consist of multiple clips.
            reduce_list = modality_value.ndim >= 5
            if reduce_list:
                B, S = modality_value.shape[:2]
                modality_value = modality_value.reshape(
                    B * S, *modality_value.shape[2:]
                )

            # The preprocessor is called with the modality name as keyword and
            # returns a dict with "trunk" and "head" keyword-argument dicts.
            modality_value = self.modality_preprocessors[modality_key](
                **{modality_key: modality_value}
            )
            trunk_inputs = modality_value["trunk"]
            head_inputs = modality_value["head"]
            modality_value, modality_full_value = self.modality_trunks[modality_key](
                **trunk_inputs, out_layers=self.out_layers
            )
            modality_value = self.modality_heads[modality_key](
                modality_value, **head_inputs
            )
            modality_value = self.modality_postprocessors[modality_key](
                modality_value
            )

            if reduce_list:
                # Average the per-clip embeddings back to one per sample.
                modality_value = modality_value.reshape(B, S, -1)
                modality_value = modality_value.mean(dim=1)

            outputs[modality_key] = modality_value, modality_full_value

        return outputs
def imagebind_huge(args):
    """Build the ImageBind-Huge configuration of ``ImageBindModel``.

    Args:
        args: mapping that may contain ``'layers'``, the trunk block indices
            forwarded as ``ImageBindModel(layers=...)``. When the key is
            absent, ``[7, 15, 23, 31]`` is used.

    Returns:
        Tuple ``(model, 1024)``; 1024 is the ``out_embed_dim`` the model is
        built with.
    """
    out_layers = args['layers'] if 'layers' in args else [7, 15, 23, 31]
    model = ImageBindModel(
        out_embed_dim=1024,
        vision_embed_dim=1280,
        vision_num_blocks=32,
        vision_num_heads=16,
        text_embed_dim=1024,
        text_num_blocks=24,
        text_num_heads=16,
        audio_drop_path=0.1,
        imu_drop_path=0.7,
        layers=out_layers,
    )
    return model, 1024
def save_module(module_dict: nn.ModuleDict, module_name: str = "",
                checkpoint_dir: str = "./.checkpoints/full", postfix: str = "_last",
                extension: str = "pth"):
    """Best-effort save of ``module_dict.state_dict()`` to a checkpoint file.

    The file is written to
    ``<checkpoint_dir>/imagebind-<module_name><postfix>.<extension>``.
    If the target cannot be created (e.g. the directory does not exist), a
    warning is logged instead of raising.

    Args:
        module_dict: module whose state dict is saved.
        module_name: name embedded in the file name and log messages.
        checkpoint_dir: destination directory (not created if missing).
        postfix: suffix appended after the module name.
        extension: file extension, without the dot.
    """
    path = os.path.join(checkpoint_dir, f"imagebind-{module_name}{postfix}.{extension}")
    try:
        torch.save(module_dict.state_dict(), path)
    except (FileNotFoundError, RuntimeError) as err:
        # torch.save's default zip-based writer reports a missing parent
        # directory as RuntimeError ("Parent directory ... does not exist"),
        # not FileNotFoundError, so both are caught to keep this best-effort.
        logging.warning(f"Could not save module parameters for {module_name} to {checkpoint_dir}.")
        logging.debug("save_module failed for %s: %s", path, err)
    else:
        logging.info(f"Saved parameters for module {module_name} to {checkpoint_dir}.")
def load_module(module_dict: nn.ModuleDict, module_name: str = "",
                checkpoint_dir: str = "./.checkpoints/full", postfix: str = "_last",
                extension: str = "pth"):
    """Best-effort, non-strict load of a checkpoint into ``module_dict``.

    Reads ``<checkpoint_dir>/imagebind-<module_name><postfix>.<extension>``
    and applies it with ``load_state_dict(..., strict=False)``. A missing
    file is logged as a warning instead of raising. Keys that do not match
    are ignored (as with ``strict=False``) but reported at debug level.

    Args:
        module_dict: module that receives the loaded parameters in place.
        module_name: name embedded in the file name and log messages.
        checkpoint_dir: directory containing the checkpoint.
        postfix: suffix appended after the module name.
        extension: file extension, without the dot.
    """
    path = os.path.join(checkpoint_dir, f"imagebind-{module_name}{postfix}.{extension}")
    try:
        # map_location="cpu": load_state_dict copies values into the existing
        # parameters (which keep their device), so deserialising on CPU is
        # safe and lets GPU-saved checkpoints load on CPU-only hosts.
        # NOTE(review): torch.load unpickles -- only load trusted checkpoints.
        state_dict = torch.load(path, map_location="cpu")
    except FileNotFoundError:
        logging.warning(f"Could not load module parameters for {module_name} from {checkpoint_dir}.")
        return
    incompatible = module_dict.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        logging.debug(
            "load_module %s: missing keys %s, unexpected keys %s",
            module_name, incompatible.missing_keys, incompatible.unexpected_keys,
        )
    logging.info(f"Loaded parameters for module {module_name} from {checkpoint_dir}.")