#!/usr/bin/env python3
# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from functools import partial
from types import SimpleNamespace

import torch
import torch.nn as nn

# from pytorch_lightning.utilities import rank_zero_only
from .helpers import (EinOpsRearrange, LearnableLogitScaling, Normalize,
                      SelectElement, SelectEOSAndProject)
from .multimodal_preprocessors import (AudioPreprocessor, IMUPreprocessor,
                                       PadIm2Video, PatchEmbedGeneric,
                                       RGBDTPreprocessor,
                                       SpatioTemporalPosEmbeddingHelper,
                                       TextPreprocessor, ThermalPreprocessor)
from .transformer import MultiheadAttention, SimpleTransformer

# String keys used to index every per-modality ModuleDict in the model and
# the ``inputs`` / ``outputs`` dicts of ``ImageBindModel.forward``.
ModalityType = SimpleNamespace(
    VISION="vision",
    TEXT="text",
    AUDIO="audio",
    THERMAL="thermal",
    DEPTH="depth",
    IMU="imu",
    POINT="point",
)


class ImageBindModel(nn.Module):
    """Multi-modal encoder built from per-modality pipelines.

    For each of the vision, text, audio, depth, thermal and IMU modalities
    the model holds four stages, each stored in an ``nn.ModuleDict`` keyed
    by the ``ModalityType`` string:

    1. ``modality_preprocessors`` - turns raw input into trunk/head kwargs,
    2. ``modality_trunks``        - a ``SimpleTransformer``,
    3. ``modality_heads``         - projection to ``out_embed_dim``,
    4. ``modality_postprocessors``- normalisation (and logit scaling).
    """

    def __init__(
        self,
        video_frames=2,
        kernel_size=(2, 14, 14),
        audio_kernel_size=16,
        audio_stride=10,
        out_embed_dim=768,
        vision_embed_dim=1024,
        vision_num_blocks=24,
        vision_num_heads=16,
        audio_embed_dim=768,
        audio_num_blocks=12,
        audio_num_heads=12,
        audio_num_mel_bins=128,
        audio_target_len=204,
        audio_drop_path=0.1,
        text_embed_dim=768,
        text_num_blocks=12,
        text_num_heads=12,
        depth_embed_dim=384,
        depth_kernel_size=16,
        depth_num_blocks=12,
        depth_num_heads=8,
        depth_drop_path=0.0,
        thermal_embed_dim=768,
        thermal_kernel_size=16,
        thermal_num_blocks=12,
        thermal_num_heads=12,
        thermal_drop_path=0.0,
        imu_embed_dim=512,
        imu_kernel_size=8,
        imu_num_blocks=6,
        imu_num_heads=8,
        imu_drop_path=0.7,
        layers=None,
    ):
        """Build all per-modality stages.

        Args:
            layers: Value forwarded to every trunk as ``out_layers`` during
                ``forward``. ``None`` (the default) selects
                ``[7, 15, 23, 31]``.

        The remaining arguments are sizes/hyper-parameters forwarded to the
        ``_create_modality_*`` factory methods below.
        """
        super().__init__()

        # A fresh list per instance: a list literal in the signature would be
        # shared (and mutable) across every model constructed with defaults.
        self.out_layers = [7, 15, 23, 31] if layers is None else layers

        # NOTE(review): ``imu_kernel_size`` is accepted for interface
        # compatibility but is not forwarded anywhere; the IMU preprocessor
        # below uses a fixed kernel size of 8.
        self.modality_preprocessors = self._create_modality_preprocessors(
            video_frames,
            vision_embed_dim,
            kernel_size,
            text_embed_dim,
            audio_embed_dim,
            audio_kernel_size,
            audio_stride,
            audio_num_mel_bins,
            audio_target_len,
            depth_embed_dim,
            depth_kernel_size,
            thermal_embed_dim,
            thermal_kernel_size,
            imu_embed_dim,
        )

        self.modality_trunks = self._create_modality_trunks(
            vision_embed_dim,
            vision_num_blocks,
            vision_num_heads,
            text_embed_dim,
            text_num_blocks,
            text_num_heads,
            audio_embed_dim,
            audio_num_blocks,
            audio_num_heads,
            audio_drop_path,
            depth_embed_dim,
            depth_num_blocks,
            depth_num_heads,
            depth_drop_path,
            thermal_embed_dim,
            thermal_num_blocks,
            thermal_num_heads,
            thermal_drop_path,
            imu_embed_dim,
            imu_num_blocks,
            imu_num_heads,
            imu_drop_path,
        )

        self.modality_heads = self._create_modality_heads(
            out_embed_dim,
            vision_embed_dim,
            text_embed_dim,
            audio_embed_dim,
            depth_embed_dim,
            thermal_embed_dim,
            imu_embed_dim,
        )

        self.modality_postprocessors = self._create_modality_postprocessors(
            out_embed_dim
        )

    def _create_modality_preprocessors(
        self,
        video_frames=2,
        vision_embed_dim=1024,
        kernel_size=(2, 14, 14),
        text_embed_dim=768,
        audio_embed_dim=768,
        audio_kernel_size=16,
        audio_stride=10,
        audio_num_mel_bins=128,
        audio_target_len=204,
        depth_embed_dim=768,
        depth_kernel_size=16,
        thermal_embed_dim=768,
        thermal_kernel_size=16,
        imu_embed_dim=512,
    ):
        """Return an ``nn.ModuleDict`` of input preprocessors, one per modality.

        Each preprocessor wraps a patch-embedding stem (a conv or linear
        projection) together with a learnable positional embedding and one
        class token; text uses a token-embedding ``TextPreprocessor`` instead.
        """
        # Vision: inputs are padded to a 2-frame clip, then patchified by a
        # non-overlapping 3D convolution (stride == kernel size).
        rgbt_stem = PatchEmbedGeneric(
            proj_stem=[
                PadIm2Video(pad_type="repeat", ntimes=2),
                nn.Conv3d(
                    in_channels=3,
                    kernel_size=kernel_size,
                    out_channels=vision_embed_dim,
                    stride=kernel_size,
                    bias=False,
                ),
            ]
        )
        rgbt_preprocessor = RGBDTPreprocessor(
            img_size=[3, video_frames, 224, 224],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            rgbt_stem=rgbt_stem,
            depth_stem=None,
        )

        text_preprocessor = TextPreprocessor(
            context_length=77,
            vocab_size=49408,
            embed_dim=text_embed_dim,
            causal_masking=True,
        )

        # Audio: single-channel 2D convolution whose stride is configured
        # independently of its kernel size, followed by LayerNorm.
        audio_stem = PatchEmbedGeneric(
            proj_stem=[
                nn.Conv2d(
                    in_channels=1,
                    kernel_size=audio_kernel_size,
                    stride=audio_stride,
                    out_channels=audio_embed_dim,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=audio_embed_dim),
        )
        audio_preprocessor = AudioPreprocessor(
            img_size=[1, audio_num_mel_bins, audio_target_len],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            audio_stem=audio_stem,
        )

        # Depth: single-channel, non-overlapping 2D patches.
        depth_stem = PatchEmbedGeneric(
            [
                nn.Conv2d(
                    kernel_size=depth_kernel_size,
                    in_channels=1,
                    out_channels=depth_embed_dim,
                    stride=depth_kernel_size,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=depth_embed_dim),
        )
        depth_preprocessor = RGBDTPreprocessor(
            img_size=[1, 224, 224],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            rgbt_stem=None,
            depth_stem=depth_stem,
        )

        # Thermal: same stem layout as depth.
        thermal_stem = PatchEmbedGeneric(
            [
                nn.Conv2d(
                    kernel_size=thermal_kernel_size,
                    in_channels=1,
                    out_channels=thermal_embed_dim,
                    stride=thermal_kernel_size,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=thermal_embed_dim),
        )
        thermal_preprocessor = ThermalPreprocessor(
            img_size=[1, 224, 224],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            thermal_stem=thermal_stem,
        )

        # IMU: linear projection of 48-feature vectors.
        # NOTE(review): 48 presumably equals 6 channels x kernel_size 8 from
        # the IMUPreprocessor arguments below - confirm before changing either.
        imu_stem = PatchEmbedGeneric(
            [
                nn.Linear(
                    in_features=48,
                    out_features=imu_embed_dim,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=imu_embed_dim),
        )
        imu_preprocessor = IMUPreprocessor(
            img_size=[6, 2000],
            num_cls_tokens=1,
            kernel_size=8,
            embed_dim=imu_embed_dim,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            imu_stem=imu_stem,
        )

        modality_preprocessors = {
            ModalityType.VISION: rgbt_preprocessor,
            ModalityType.TEXT: text_preprocessor,
            ModalityType.AUDIO: audio_preprocessor,
            ModalityType.DEPTH: depth_preprocessor,
            ModalityType.THERMAL: thermal_preprocessor,
            ModalityType.IMU: imu_preprocessor,
        }

        return nn.ModuleDict(modality_preprocessors)

    def _create_modality_trunks(
        self,
        vision_embed_dim=1024,
        vision_num_blocks=24,
        vision_num_heads=16,
        text_embed_dim=768,
        text_num_blocks=12,
        text_num_heads=12,
        audio_embed_dim=768,
        audio_num_blocks=12,
        audio_num_heads=12,
        audio_drop_path=0.0,
        depth_embed_dim=768,
        depth_num_blocks=12,
        depth_num_heads=12,
        depth_drop_path=0.0,
        thermal_embed_dim=768,
        thermal_num_blocks=12,
        thermal_num_heads=12,
        thermal_drop_path=0.0,
        imu_embed_dim=512,
        imu_num_blocks=6,
        imu_num_heads=8,
        imu_drop_path=0.7,
    ):
        """Return an ``nn.ModuleDict`` with one ``SimpleTransformer`` per modality.

        Vision and text trunks use no drop path and no key/value bias; vision
        alone applies a LayerNorm before the transformer. The remaining
        modalities use ``add_bias_kv=True`` and their configured drop-path
        rate.
        """

        def instantiate_trunk(
            embed_dim, num_blocks, num_heads, pre_transformer_ln, add_bias_kv, drop_path
        ):
            # The rearrange layers switch between "b l d" and "l b d" layouts
            # around the transformer blocks.
            return SimpleTransformer(
                embed_dim=embed_dim,
                num_blocks=num_blocks,
                ffn_dropout_rate=0.0,
                drop_path_rate=drop_path,
                attn_target=partial(
                    MultiheadAttention,
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    bias=True,
                    add_bias_kv=add_bias_kv,
                ),
                pre_transformer_layer=nn.Sequential(
                    nn.LayerNorm(embed_dim, eps=1e-6)
                    if pre_transformer_ln
                    else nn.Identity(),
                    EinOpsRearrange("b l d -> l b d"),
                ),
                post_transformer_layer=EinOpsRearrange("l b d -> b l d"),
            )

        modality_trunks = {}
        modality_trunks[ModalityType.VISION] = instantiate_trunk(
            vision_embed_dim,
            vision_num_blocks,
            vision_num_heads,
            pre_transformer_ln=True,
            add_bias_kv=False,
            drop_path=0.0,
        )
        modality_trunks[ModalityType.TEXT] = instantiate_trunk(
            text_embed_dim,
            text_num_blocks,
            text_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=False,
            drop_path=0.0,
        )
        modality_trunks[ModalityType.AUDIO] = instantiate_trunk(
            audio_embed_dim,
            audio_num_blocks,
            audio_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=True,
            drop_path=audio_drop_path,
        )
        modality_trunks[ModalityType.DEPTH] = instantiate_trunk(
            depth_embed_dim,
            depth_num_blocks,
            depth_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=True,
            drop_path=depth_drop_path,
        )
        modality_trunks[ModalityType.THERMAL] = instantiate_trunk(
            thermal_embed_dim,
            thermal_num_blocks,
            thermal_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=True,
            drop_path=thermal_drop_path,
        )
        modality_trunks[ModalityType.IMU] = instantiate_trunk(
            imu_embed_dim,
            imu_num_blocks,
            imu_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=True,
            drop_path=imu_drop_path,
        )

        return nn.ModuleDict(modality_trunks)

    def _create_modality_heads(
        self,
        out_embed_dim,
        vision_embed_dim,
        text_embed_dim,
        audio_embed_dim,
        depth_embed_dim,
        thermal_embed_dim,
        imu_embed_dim,
    ):
        """Return an ``nn.ModuleDict`` of heads projecting to ``out_embed_dim``.

        Every head is LayerNorm followed by a bias-free linear projection.
        Non-text heads pick element 0 via ``SelectElement`` before projecting;
        text wraps its projection in ``SelectEOSAndProject``; IMU additionally
        applies dropout with p=0.5.
        """
        modality_heads = {}

        modality_heads[ModalityType.VISION] = nn.Sequential(
            nn.LayerNorm(normalized_shape=vision_embed_dim, eps=1e-6),
            SelectElement(index=0),
            nn.Linear(vision_embed_dim, out_embed_dim, bias=False),
        )

        modality_heads[ModalityType.TEXT] = SelectEOSAndProject(
            proj=nn.Sequential(
                nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6),
                nn.Linear(text_embed_dim, out_embed_dim, bias=False),
            )
        )

        modality_heads[ModalityType.AUDIO] = nn.Sequential(
            nn.LayerNorm(normalized_shape=audio_embed_dim, eps=1e-6),
            SelectElement(index=0),
            nn.Linear(audio_embed_dim, out_embed_dim, bias=False),
        )

        modality_heads[ModalityType.DEPTH] = nn.Sequential(
            nn.LayerNorm(normalized_shape=depth_embed_dim, eps=1e-6),
            SelectElement(index=0),
            nn.Linear(depth_embed_dim, out_embed_dim, bias=False),
        )

        modality_heads[ModalityType.THERMAL] = nn.Sequential(
            nn.LayerNorm(normalized_shape=thermal_embed_dim, eps=1e-6),
            SelectElement(index=0),
            nn.Linear(thermal_embed_dim, out_embed_dim, bias=False),
        )

        modality_heads[ModalityType.IMU] = nn.Sequential(
            nn.LayerNorm(normalized_shape=imu_embed_dim, eps=1e-6),
            SelectElement(index=0),
            nn.Dropout(p=0.5),
            nn.Linear(imu_embed_dim, out_embed_dim, bias=False),
        )

        return nn.ModuleDict(modality_heads)

    def _create_modality_postprocessors(self, out_embed_dim):
        """Return an ``nn.ModuleDict`` of output postprocessors.

        Vision output is only normalised along the last dimension. Every other
        modality is normalised and then passed through
        ``LearnableLogitScaling``; only the text scale is learnable, the
        others use fixed initial scales. ``out_embed_dim`` is not used here.
        """
        modality_postprocessors = {}

        modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1)
        modality_postprocessors[ModalityType.TEXT] = nn.Sequential(
            Normalize(dim=-1), LearnableLogitScaling(learnable=True)
        )
        modality_postprocessors[ModalityType.AUDIO] = nn.Sequential(
            Normalize(dim=-1),
            LearnableLogitScaling(logit_scale_init=20.0, learnable=False),
        )
        modality_postprocessors[ModalityType.DEPTH] = nn.Sequential(
            Normalize(dim=-1),
            LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
        )
        modality_postprocessors[ModalityType.THERMAL] = nn.Sequential(
            Normalize(dim=-1),
            LearnableLogitScaling(logit_scale_init=10.0, learnable=False),
        )
        modality_postprocessors[ModalityType.IMU] = nn.Sequential(
            Normalize(dim=-1),
            LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
        )

        return nn.ModuleDict(modality_postprocessors)

    def forward(self, inputs):
        """Encode every modality present in ``inputs``.

        Args:
            inputs: Mapping from a ``ModalityType`` string to that modality's
                input tensor. Entries whose value is ``None`` are skipped.

        Returns:
            Dict mapping each processed modality key to a tuple
            ``(embedding, full_value)``. ``embedding`` is the postprocessed
            head output; ``full_value`` is the second value returned by the
            trunk when called with ``out_layers=self.out_layers``
            (presumably the intermediate features of those layers - confirm
            against ``SimpleTransformer``).
        """
        outputs = {}
        for modality_key, modality_value in inputs.items():
            # Skip absent modalities up front. The guard has to precede the
            # ``.ndim`` access below, otherwise a ``None`` entry raises
            # AttributeError before it can be ignored.
            if modality_value is None:
                continue

            # Audio and Video inputs consist of multiple clips
            reduce_list = modality_value.ndim >= 5
            if reduce_list:
                # Fold the clip dimension into the batch dimension.
                B, S = modality_value.shape[:2]
                modality_value = modality_value.reshape(
                    B * S, *modality_value.shape[2:]
                )

            modality_value = self.modality_preprocessors[modality_key](
                **{modality_key: modality_value}
            )
            trunk_inputs = modality_value["trunk"]
            head_inputs = modality_value["head"]
            modality_value, modality_full_value = self.modality_trunks[modality_key](
                **trunk_inputs, out_layers=self.out_layers
            )
            modality_value = self.modality_heads[modality_key](
                modality_value, **head_inputs
            )
            modality_value = self.modality_postprocessors[modality_key](
                modality_value
            )

            if reduce_list:
                # Restore the clip dimension and average over the clips.
                modality_value = modality_value.reshape(B, S, -1)
                modality_value = modality_value.mean(dim=1)

            outputs[modality_key] = modality_value, modality_full_value

        return outputs


def imagebind_huge(args):
    """Build the "huge" ``ImageBindModel`` configuration.

    Args:
        args: Mapping of options. If it contains the key ``'layers'`` that
            value is passed to the model as ``layers``; otherwise
            ``[7, 15, 23, 31]`` is used.

    Returns:
        Tuple ``(model, 1024)``; 1024 is the same value passed to the model
        as ``out_embed_dim``.
    """
    if 'layers' in args:
        layers = args['layers']
    else:
        layers = [7, 15, 23, 31]

    return ImageBindModel(
        vision_embed_dim=1280,
        vision_num_blocks=32,
        vision_num_heads=16,
        text_embed_dim=1024,
        text_num_blocks=24,
        text_num_heads=16,
        out_embed_dim=1024,
        audio_drop_path=0.1,
        imu_drop_path=0.7,
        layers=layers,
    ), 1024


def save_module(module_dict: nn.ModuleDict, module_name: str = "",
                checkpoint_dir: str = "./.checkpoints/full", postfix: str = "_last",
                extension: str = "pth"):
    """Save ``module_dict.state_dict()`` to a checkpoint file (best effort).

    The file is written to
    ``<checkpoint_dir>/imagebind-<module_name><postfix>.<extension>``.
    A ``FileNotFoundError`` (e.g. a missing directory) is logged as a warning
    instead of being raised; other exceptions propagate.
    """
    checkpoint_path = os.path.join(
        checkpoint_dir, f"imagebind-{module_name}{postfix}.{extension}"
    )
    try:
        torch.save(module_dict.state_dict(), checkpoint_path)
        logging.info(f"Saved parameters for module {module_name} to {checkpoint_dir}.")
    except FileNotFoundError:
        logging.warning(f"Could not save module parameters for {module_name} to {checkpoint_dir}.")


def load_module(module_dict: nn.ModuleDict, module_name: str = "",
                checkpoint_dir: str = "./.checkpoints/full", postfix: str = "_last",
                extension: str = "pth"):
    """Load parameters into ``module_dict`` from a checkpoint file (best effort).

    The file is read from
    ``<checkpoint_dir>/imagebind-<module_name><postfix>.<extension>`` and
    applied with ``strict=False``, so missing or unexpected keys are
    tolerated. A ``FileNotFoundError`` is logged as a warning instead of
    being raised; other exceptions propagate.
    """
    checkpoint_path = os.path.join(
        checkpoint_dir, f"imagebind-{module_name}{postfix}.{extension}"
    )
    try:
        # Deserialize onto the CPU so a checkpoint written from a GPU process
        # can still be read on a host without that device; load_state_dict
        # then copies the values into the module's existing parameters.
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        module_dict.load_state_dict(state_dict, strict=False)
        logging.info(f"Loaded parameters for module {module_name} from {checkpoint_dir}.")
    except FileNotFoundError:
        logging.warning(f"Could not load module parameters for {module_name} from {checkpoint_dir}.")