Spaces:
Sleeping
Sleeping
# Copyright (c) Tencent Inc. All rights reserved. | |
from typing import List | |
import torch | |
import torch.nn as nn | |
from torch import Tensor | |
import torch.nn.functional as F | |
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule, Linear | |
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig | |
from mmengine.model import BaseModule | |
from mmyolo.registry import MODELS | |
from mmyolo.models.layers import CSPLayerWithTwoConv | |
class MaxSigmoidAttnBlock(BaseModule):
    """Max-sigmoid attention block.

    The input feature map is projected by a convolution and then gated, per
    head and per spatial location, by a sigmoid of the largest similarity
    between the (optionally re-embedded) feature map and the guide
    embeddings.

    Args:
        in_channels (int): Channels of the input feature map.
        out_channels (int): Channels produced by ``project_conv``.
        guide_channels (int): Size of the last dimension of ``guide``.
        embed_channels (int): Channels of the space in which the feature map
            and the guide are compared.
        kernel_size (int): Kernel size of ``project_conv``. Defaults to 3.
        padding (int): Padding of ``project_conv``. Defaults to 1.
        num_heads (int): Number of attention heads. Defaults to 1.
        use_depthwise (bool): Build ``project_conv`` with
            ``DepthwiseSeparableConvModule`` instead of ``ConvModule``.
            Defaults to False.
        with_scale (bool): If True, multiply the sigmoid gate by a learnable
            per-head scale initialised to one; otherwise the scale is the
            constant ``1.0``. Defaults to False.
        conv_cfg (OptConfigType): Passed to the convolution modules.
        norm_cfg (ConfigType): Passed to the convolution modules.
        init_cfg (OptMultiConfig): Passed to ``BaseModule``.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 guide_channels: int,
                 embed_channels: int,
                 kernel_size: int = 3,
                 padding: int = 1,
                 num_heads: int = 1,
                 use_depthwise: bool = False,
                 with_scale: bool = False,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(type='BN',
                                             momentum=0.03,
                                             eps=0.001),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule

        assert (out_channels % num_heads == 0 and
                embed_channels % num_heads == 0), \
            'out_channels and embed_channels should be divisible by num_heads.'
        self.num_heads = num_heads
        # ``head_channels`` is only used to split the outputs of ``guide_fc``
        # and ``embed_conv``, both of which have ``embed_channels`` channels,
        # so it must be derived from ``embed_channels``. Deriving it from
        # ``out_channels`` makes the reshapes in ``forward`` fail whenever
        # ``embed_channels != out_channels``; the two agree when they are
        # equal.
        self.head_channels = embed_channels // num_heads

        # The 1x1 embedding conv is skipped when the input already has
        # ``embed_channels`` channels.
        self.embed_conv = ConvModule(
            in_channels,
            embed_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None) if embed_channels != in_channels else None
        self.guide_fc = Linear(guide_channels, embed_channels)
        self.bias = nn.Parameter(torch.zeros(num_heads))
        if with_scale:
            self.scale = nn.Parameter(torch.ones(1, num_heads, 1, 1))
        else:
            self.scale = 1.0

        self.project_conv = conv(in_channels,
                                 out_channels,
                                 kernel_size,
                                 stride=1,
                                 padding=padding,
                                 conv_cfg=conv_cfg,
                                 norm_cfg=norm_cfg,
                                 act_cfg=None)

    def forward(self, x: Tensor, guide: Tensor) -> Tensor:
        """Forward process.

        Args:
            x (Tensor): Feature map of shape ``(B, in_channels, H, W)``.
            guide (Tensor): Guide embeddings whose last dimension is
                ``guide_channels``; after ``guide_fc`` they are viewed as
                ``(B, N, num_heads, head_channels)``.

        Returns:
            Tensor: The gated projection of ``x``, reshaped to
            ``(B, -1, H, W)``.
        """
        B, _, H, W = x.shape

        guide = self.guide_fc(guide)
        guide = guide.reshape(B, -1, self.num_heads, self.head_channels)
        embed = self.embed_conv(x) if self.embed_conv is not None else x
        embed = embed.reshape(B, self.num_heads, self.head_channels, H, W)

        # Similarity of every location with every guide entry, per head:
        # (B, heads, H, W, N).
        attn_weight = torch.einsum('bmchw,bnmc->bmhwn', embed, guide)
        # Keep only the best-matching guide entry for each location.
        attn_weight = attn_weight.max(dim=-1)[0]
        attn_weight = attn_weight / (self.head_channels**0.5)
        attn_weight = attn_weight + self.bias[None, :, None, None]
        attn_weight = attn_weight.sigmoid() * self.scale

        x = self.project_conv(x)
        x = x.reshape(B, self.num_heads, -1, H, W)
        # Broadcast the (B, heads, H, W) gate over the channels of each head.
        x = x * attn_weight.unsqueeze(2)
        x = x.reshape(B, -1, H, W)
        return x
class MaxSigmoidCSPLayerWithTwoConv(CSPLayerWithTwoConv):
    """CSP layer (two-conv variant) with an extra max-sigmoid attention branch.

    On top of the parent layer's branches, one more branch is produced by a
    ``MaxSigmoidAttnBlock`` applied to the output of the last branch, and
    ``final_conv`` is rebuilt to accept the wider concatenation.

    Args:
        in_channels (int): Forwarded to the parent layer.
        out_channels (int): Forwarded to the parent layer and used as the
            output width of ``final_conv``.
        guide_channels (int): Forwarded to the attention block.
        embed_channels (int): Forwarded to the attention block.
        num_heads (int): Forwarded to the attention block. Defaults to 1.
        expand_ratio (float): Forwarded to the parent layer. Defaults to 0.5.
        num_blocks (int): Forwarded to the parent layer. Defaults to 1.
        with_scale (bool): Forwarded to the attention block.
            Defaults to False.
        add_identity (bool): Forwarded to the parent layer. Defaults to True.
        conv_cfg (OptConfigType): Config shared by all conv modules.
        norm_cfg (ConfigType): Config shared by all conv modules.
        act_cfg (ConfigType): Activation config of the parent layer and of
            ``final_conv``.
        init_cfg (OptMultiConfig): Forwarded to the parent layer.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            guide_channels: int,
            embed_channels: int,
            num_heads: int = 1,
            expand_ratio: float = 0.5,
            num_blocks: int = 1,
            with_scale: bool = False,
            add_identity: bool = True,  # shortcut
            conv_cfg: OptConfigType = None,
            norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
            act_cfg: ConfigType = dict(type='SiLU', inplace=True),
            init_cfg: OptMultiConfig = None) -> None:
        super().__init__(in_channels=in_channels,
                         out_channels=out_channels,
                         expand_ratio=expand_ratio,
                         num_blocks=num_blocks,
                         add_identity=add_identity,
                         conv_cfg=conv_cfg,
                         norm_cfg=norm_cfg,
                         act_cfg=act_cfg,
                         init_cfg=init_cfg)

        shared_cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg)
        branch_width = self.mid_channels

        # Concatenated input of the final conv: the two halves of the split,
        # one output per block, and the attention branch.
        fused_width = (3 + num_blocks) * branch_width
        self.final_conv = ConvModule(fused_width,
                                     out_channels,
                                     1,
                                     act_cfg=act_cfg,
                                     **shared_cfg)

        self.attn_block = MaxSigmoidAttnBlock(branch_width,
                                              branch_width,
                                              guide_channels=guide_channels,
                                              embed_channels=embed_channels,
                                              num_heads=num_heads,
                                              with_scale=with_scale,
                                              **shared_cfg)

    def forward(self, x: Tensor, guide: Tensor) -> Tensor:
        """Forward process.

        Args:
            x (Tensor): Input feature map.
            guide (Tensor): Guide embeddings handed to the attention block.

        Returns:
            Tensor: Output of ``final_conv`` over all concatenated branches.
        """
        halves = self.main_conv(x).split(
            (self.mid_channels, self.mid_channels), 1)
        branches = list(halves)
        # Each block consumes the most recently produced branch.
        for block in self.blocks:
            branches.append(block(branches[-1]))
        branches.append(self.attn_block(branches[-1], guide))
        return self.final_conv(torch.cat(branches, 1))
class ImagePoolingAttentionModule(nn.Module):
    """Update text features by attending over pooled multi-level image features.

    Every image level is projected to ``embed_channels`` with a 1x1 conv and
    max-pooled to ``pool_size x pool_size``. The pooled patches of all levels
    are the keys and values of a multi-head attention whose queries are the
    text features. The attention output is projected back to
    ``text_channels``, scaled, and added to the text features.

    Args:
        image_channels (List[int]): Channels of each image feature level.
        text_channels (int): Size of the last dimension of the text features.
        embed_channels (int): Width of the attention space; must be divisible
            by ``num_heads``.
        with_scale (bool): If True, the attention output is multiplied by a
            learnable scalar initialised to zero; otherwise by the constant
            ``1.0``. Defaults to False.
        num_feats (int): Number of image feature levels expected by
            ``forward``. Defaults to 3.
        num_heads (int): Number of attention heads. Defaults to 8.
        pool_size (int): Side length of the pooled grid per level.
            Defaults to 3.
    """

    def __init__(self,
                 image_channels: List[int],
                 text_channels: int,
                 embed_channels: int,
                 with_scale: bool = False,
                 num_feats: int = 3,
                 num_heads: int = 8,
                 pool_size: int = 3):
        super().__init__()

        # Without this check the floor division below yields a
        # ``head_channels`` that does not tile ``embed_channels``, and the
        # reshapes in ``forward`` either fail or regroup elements wrongly.
        assert embed_channels % num_heads == 0, \
            'embed_channels should be divisible by num_heads.'

        self.text_channels = text_channels
        self.embed_channels = embed_channels
        self.num_heads = num_heads
        self.num_feats = num_feats
        self.head_channels = embed_channels // num_heads
        self.pool_size = pool_size

        if with_scale:
            self.scale = nn.Parameter(torch.tensor([0.]), requires_grad=True)
        else:
            self.scale = 1.0
        self.projections = nn.ModuleList([
            ConvModule(in_channels, embed_channels, 1, act_cfg=None)
            for in_channels in image_channels
        ])
        self.query = nn.Sequential(nn.LayerNorm(text_channels),
                                   Linear(text_channels, embed_channels))
        self.key = nn.Sequential(nn.LayerNorm(embed_channels),
                                 Linear(embed_channels, embed_channels))
        self.value = nn.Sequential(nn.LayerNorm(embed_channels),
                                   Linear(embed_channels, embed_channels))
        self.proj = Linear(embed_channels, text_channels)

        self.image_pools = nn.ModuleList([
            nn.AdaptiveMaxPool2d((pool_size, pool_size))
            for _ in range(num_feats)
        ])

    def forward(self, text_features: Tensor,
                image_features: List[Tensor]) -> Tensor:
        """Forward process.

        Args:
            text_features (Tensor): Text features whose last dimension is
                ``text_channels``.
            image_features (List[Tensor]): ``num_feats`` feature maps, one
                per level, each with batch size ``B`` in dimension 0.

        Returns:
            Tensor: ``text_features`` plus the scaled attention output.
        """
        # Validate before indexing so a wrong number of levels is reported
        # as such rather than as an IndexError.
        assert len(image_features) == self.num_feats
        B = image_features[0].shape[0]
        num_patches = self.pool_size**2

        # Per level: (B, embed_channels, pool, pool) -> (B, -1, num_patches).
        mlvl_image_features = [
            pool(proj(x)).view(B, -1, num_patches)
            for (x, proj, pool
                 ) in zip(image_features, self.projections, self.image_pools)
        ]
        # Stack the patches of all levels along the token dimension and move
        # the channels last.
        mlvl_image_features = torch.cat(mlvl_image_features,
                                        dim=-1).transpose(1, 2)

        q = self.query(text_features)
        k = self.key(mlvl_image_features)
        v = self.value(mlvl_image_features)

        q = q.reshape(B, -1, self.num_heads, self.head_channels)
        k = k.reshape(B, -1, self.num_heads, self.head_channels)
        v = v.reshape(B, -1, self.num_heads, self.head_channels)

        # Scaled dot-product attention, softmax over the image patches.
        attn_weight = torch.einsum('bnmc,bkmc->bmnk', q, k)
        attn_weight = attn_weight / (self.head_channels**0.5)
        attn_weight = F.softmax(attn_weight, dim=-1)

        x = torch.einsum('bmnk,bkmc->bnmc', attn_weight, v)
        x = self.proj(x.reshape(B, -1, self.embed_channels))
        return x * self.scale + text_features
class VanillaSigmoidBlock(BaseModule):
    """Guide-free gating block: a projection conv followed by self-gating.

    The constructor and ``forward`` take the same arguments as
    ``MaxSigmoidAttnBlock``, but the guide is never read: the projected
    feature map is multiplied by its own sigmoid.

    Args:
        in_channels (int): Channels of the input feature map.
        out_channels (int): Channels produced by ``project_conv``.
        guide_channels (int): Accepted but unused.
        embed_channels (int): Only checked for divisibility by ``num_heads``.
        kernel_size (int): Kernel size of ``project_conv``. Defaults to 3.
        padding (int): Padding of ``project_conv``. Defaults to 1.
        num_heads (int): Used for the divisibility check and to derive
            ``head_channels``. Defaults to 1.
        use_depthwise (bool): Build ``project_conv`` with
            ``DepthwiseSeparableConvModule`` instead of ``ConvModule``.
            Defaults to False.
        with_scale (bool): Accepted but unused. Defaults to False.
        conv_cfg (OptConfigType): Passed to ``project_conv``.
        norm_cfg (ConfigType): Passed to ``project_conv``.
        init_cfg (OptMultiConfig): Passed to ``BaseModule``.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 guide_channels: int,
                 embed_channels: int,
                 kernel_size: int = 3,
                 padding: int = 1,
                 num_heads: int = 1,
                 use_depthwise: bool = False,
                 with_scale: bool = False,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(type='BN',
                                             momentum=0.03,
                                             eps=0.001),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)

        if use_depthwise:
            conv_type = DepthwiseSeparableConvModule
        else:
            conv_type = ConvModule

        heads_divide_out = out_channels % num_heads == 0
        heads_divide_embed = embed_channels % num_heads == 0
        assert heads_divide_out and heads_divide_embed, \
            'out_channels and embed_channels should be divisible by num_heads.'

        self.num_heads = num_heads
        self.head_channels = out_channels // num_heads

        self.project_conv = conv_type(in_channels,
                                      out_channels,
                                      kernel_size,
                                      stride=1,
                                      padding=padding,
                                      conv_cfg=conv_cfg,
                                      norm_cfg=norm_cfg,
                                      act_cfg=None)

    def forward(self, x: Tensor, guide: Tensor) -> Tensor:
        """Forward process.

        Args:
            x (Tensor): Input feature map.
            guide (Tensor): Ignored.

        Returns:
            Tensor: ``project_conv(x)`` multiplied by its own sigmoid.
        """
        projected = self.project_conv(x)
        gate = projected.sigmoid()
        return projected * gate
class EfficientCSPLayerWithTwoConv(CSPLayerWithTwoConv):
    """CSP layer (two-conv variant) with an extra ``VanillaSigmoidBlock`` branch.

    On top of the parent layer's branches, one more branch is produced by a
    ``VanillaSigmoidBlock`` applied to the output of the last branch, and
    ``final_conv`` is rebuilt to accept the wider concatenation.

    Args:
        in_channels (int): Forwarded to the parent layer.
        out_channels (int): Forwarded to the parent layer and used as the
            output width of ``final_conv``.
        guide_channels (int): Forwarded to the extra block.
        embed_channels (int): Forwarded to the extra block.
        num_heads (int): Forwarded to the extra block. Defaults to 1.
        expand_ratio (float): Forwarded to the parent layer. Defaults to 0.5.
        num_blocks (int): Forwarded to the parent layer. Defaults to 1.
        with_scale (bool): Forwarded to the extra block. Defaults to False.
        add_identity (bool): Forwarded to the parent layer. Defaults to True.
        conv_cfg (OptConfigType): Config shared by all conv modules.
        norm_cfg (ConfigType): Config shared by all conv modules.
        act_cfg (ConfigType): Activation config of the parent layer and of
            ``final_conv``.
        init_cfg (OptMultiConfig): Forwarded to the parent layer.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            guide_channels: int,
            embed_channels: int,
            num_heads: int = 1,
            expand_ratio: float = 0.5,
            num_blocks: int = 1,
            with_scale: bool = False,
            add_identity: bool = True,  # shortcut
            conv_cfg: OptConfigType = None,
            norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
            act_cfg: ConfigType = dict(type='SiLU', inplace=True),
            init_cfg: OptMultiConfig = None) -> None:
        super().__init__(in_channels=in_channels,
                         out_channels=out_channels,
                         expand_ratio=expand_ratio,
                         num_blocks=num_blocks,
                         add_identity=add_identity,
                         conv_cfg=conv_cfg,
                         norm_cfg=norm_cfg,
                         act_cfg=act_cfg,
                         init_cfg=init_cfg)

        shared_cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg)
        branch_width = self.mid_channels

        # Concatenated input of the final conv: the two halves of the split,
        # one output per block, and the extra gated branch.
        fused_width = (3 + num_blocks) * branch_width
        self.final_conv = ConvModule(fused_width,
                                     out_channels,
                                     1,
                                     act_cfg=act_cfg,
                                     **shared_cfg)

        self.attn_block = VanillaSigmoidBlock(branch_width,
                                              branch_width,
                                              guide_channels=guide_channels,
                                              embed_channels=embed_channels,
                                              num_heads=num_heads,
                                              with_scale=with_scale,
                                              **shared_cfg)

    def forward(self, x: Tensor, guide: Tensor) -> Tensor:
        """Forward process.

        Args:
            x (Tensor): Input feature map.
            guide (Tensor): Passed through to the extra block.

        Returns:
            Tensor: Output of ``final_conv`` over all concatenated branches.
        """
        halves = self.main_conv(x).split(
            (self.mid_channels, self.mid_channels), 1)
        branches = list(halves)
        # Each block consumes the most recently produced branch.
        for block in self.blocks:
            branches.append(block(branches[-1]))
        branches.append(self.attn_block(branches[-1], guide))
        return self.final_conv(torch.cat(branches, 1))