Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
from mmcv.cnn import (Linear, build_activation_layer, build_norm_layer, | |
xavier_init) | |
from .builder import TRANSFORMER | |
class MultiheadAttention(nn.Module): | |
"""A warpper for torch.nn.MultiheadAttention. | |
This module implements MultiheadAttention with residual connection, | |
and positional encoding used in DETR is also passed as input. | |
Args: | |
embed_dims (int): The embedding dimension. | |
num_heads (int): Parallel attention heads. Same as | |
`nn.MultiheadAttention`. | |
dropout (float): A Dropout layer on attn_output_weights. Default 0.0. | |
""" | |
def __init__(self, embed_dims, num_heads, dropout=0.0): | |
super(MultiheadAttention, self).__init__() | |
assert embed_dims % num_heads == 0, 'embed_dims must be ' \ | |
f'divisible by num_heads. got {embed_dims} and {num_heads}.' | |
self.embed_dims = embed_dims | |
self.num_heads = num_heads | |
self.dropout = dropout | |
self.attn = nn.MultiheadAttention(embed_dims, num_heads, dropout) | |
self.dropout = nn.Dropout(dropout) | |
def forward(self, | |
x, | |
key=None, | |
value=None, | |
residual=None, | |
query_pos=None, | |
key_pos=None, | |
attn_mask=None, | |
key_padding_mask=None): | |
"""Forward function for `MultiheadAttention`. | |
Args: | |
x (Tensor): The input query with shape [num_query, bs, | |
embed_dims]. Same in `nn.MultiheadAttention.forward`. | |
key (Tensor): The key tensor with shape [num_key, bs, | |
embed_dims]. Same in `nn.MultiheadAttention.forward`. | |
Default None. If None, the `query` will be used. | |
value (Tensor): The value tensor with same shape as `key`. | |
Same in `nn.MultiheadAttention.forward`. Default None. | |
If None, the `key` will be used. | |
residual (Tensor): The tensor used for addition, with the | |
same shape as `x`. Default None. If None, `x` will be used. | |
query_pos (Tensor): The positional encoding for query, with | |
the same shape as `x`. Default None. If not None, it will | |
be added to `x` before forward function. | |
key_pos (Tensor): The positional encoding for `key`, with the | |
same shape as `key`. Default None. If not None, it will | |
be added to `key` before forward function. If None, and | |
`query_pos` has the same shape as `key`, then `query_pos` | |
will be used for `key_pos`. | |
attn_mask (Tensor): ByteTensor mask with shape [num_query, | |
num_key]. Same in `nn.MultiheadAttention.forward`. | |
Default None. | |
key_padding_mask (Tensor): ByteTensor with shape [bs, num_key]. | |
Same in `nn.MultiheadAttention.forward`. Default None. | |
Returns: | |
Tensor: forwarded results with shape [num_query, bs, embed_dims]. | |
""" | |
query = x | |
if key is None: | |
key = query | |
if value is None: | |
value = key | |
if residual is None: | |
residual = x | |
if key_pos is None: | |
if query_pos is not None and key is not None: | |
if query_pos.shape == key.shape: | |
key_pos = query_pos | |
if query_pos is not None: | |
query = query + query_pos | |
if key_pos is not None: | |
key = key + key_pos | |
out = self.attn( | |
query, | |
key, | |
value=value, | |
attn_mask=attn_mask, | |
key_padding_mask=key_padding_mask)[0] | |
return residual + self.dropout(out) | |
def __repr__(self): | |
"""str: a string that describes the module""" | |
repr_str = self.__class__.__name__ | |
repr_str += f'(embed_dims={self.embed_dims}, ' | |
repr_str += f'num_heads={self.num_heads}, ' | |
repr_str += f'dropout={self.dropout})' | |
return repr_str | |
class FFN(nn.Module): | |
"""Implements feed-forward networks (FFNs) with residual connection. | |
Args: | |
embed_dims (int): The feature dimension. Same as | |
`MultiheadAttention`. | |
feedforward_channels (int): The hidden dimension of FFNs. | |
num_fcs (int, optional): The number of fully-connected layers in | |
FFNs. Defaults to 2. | |
act_cfg (dict, optional): The activation config for FFNs. | |
dropout (float, optional): Probability of an element to be | |
zeroed. Default 0.0. | |
add_residual (bool, optional): Add resudual connection. | |
Defaults to True. | |
""" | |
def __init__(self, | |
embed_dims, | |
feedforward_channels, | |
num_fcs=2, | |
act_cfg=dict(type='ReLU', inplace=True), | |
dropout=0.0, | |
add_residual=True): | |
super(FFN, self).__init__() | |
assert num_fcs >= 2, 'num_fcs should be no less ' \ | |
f'than 2. got {num_fcs}.' | |
self.embed_dims = embed_dims | |
self.feedforward_channels = feedforward_channels | |
self.num_fcs = num_fcs | |
self.act_cfg = act_cfg | |
self.dropout = dropout | |
self.activate = build_activation_layer(act_cfg) | |
layers = nn.ModuleList() | |
in_channels = embed_dims | |
for _ in range(num_fcs - 1): | |
layers.append( | |
nn.Sequential( | |
Linear(in_channels, feedforward_channels), self.activate, | |
nn.Dropout(dropout))) | |
in_channels = feedforward_channels | |
layers.append(Linear(feedforward_channels, embed_dims)) | |
self.layers = nn.Sequential(*layers) | |
self.dropout = nn.Dropout(dropout) | |
self.add_residual = add_residual | |
def forward(self, x, residual=None): | |
"""Forward function for `FFN`.""" | |
out = self.layers(x) | |
if not self.add_residual: | |
return out | |
if residual is None: | |
residual = x | |
return residual + self.dropout(out) | |
def __repr__(self): | |
"""str: a string that describes the module""" | |
repr_str = self.__class__.__name__ | |
repr_str += f'(embed_dims={self.embed_dims}, ' | |
repr_str += f'feedforward_channels={self.feedforward_channels}, ' | |
repr_str += f'num_fcs={self.num_fcs}, ' | |
repr_str += f'act_cfg={self.act_cfg}, ' | |
repr_str += f'dropout={self.dropout}, ' | |
repr_str += f'add_residual={self.add_residual})' | |
return repr_str | |
class TransformerEncoderLayer(nn.Module): | |
"""Implements one encoder layer in DETR transformer. | |
Args: | |
embed_dims (int): The feature dimension. Same as `FFN`. | |
num_heads (int): Parallel attention heads. | |
feedforward_channels (int): The hidden dimension for FFNs. | |
dropout (float): Probability of an element to be zeroed. Default 0.0. | |
order (tuple[str]): The order for encoder layer. Valid examples are | |
('selfattn', 'norm', 'ffn', 'norm') and ('norm', 'selfattn', | |
'norm', 'ffn'). Default ('selfattn', 'norm', 'ffn', 'norm'). | |
act_cfg (dict): The activation config for FFNs. Default ReLU. | |
norm_cfg (dict): Config dict for normalization layer. Default | |
layer normalization. | |
num_fcs (int): The number of fully-connected layers for FFNs. | |
Default 2. | |
""" | |
def __init__(self, | |
embed_dims, | |
num_heads, | |
feedforward_channels, | |
dropout=0.0, | |
order=('selfattn', 'norm', 'ffn', 'norm'), | |
act_cfg=dict(type='ReLU', inplace=True), | |
norm_cfg=dict(type='LN'), | |
num_fcs=2): | |
super(TransformerEncoderLayer, self).__init__() | |
assert isinstance(order, tuple) and len(order) == 4 | |
assert set(order) == set(['selfattn', 'norm', 'ffn']) | |
self.embed_dims = embed_dims | |
self.num_heads = num_heads | |
self.feedforward_channels = feedforward_channels | |
self.dropout = dropout | |
self.order = order | |
self.act_cfg = act_cfg | |
self.norm_cfg = norm_cfg | |
self.num_fcs = num_fcs | |
self.pre_norm = order[0] == 'norm' | |
self.self_attn = MultiheadAttention(embed_dims, num_heads, dropout) | |
self.ffn = FFN(embed_dims, feedforward_channels, num_fcs, act_cfg, | |
dropout) | |
self.norms = nn.ModuleList() | |
self.norms.append(build_norm_layer(norm_cfg, embed_dims)[1]) | |
self.norms.append(build_norm_layer(norm_cfg, embed_dims)[1]) | |
def forward(self, x, pos=None, attn_mask=None, key_padding_mask=None): | |
"""Forward function for `TransformerEncoderLayer`. | |
Args: | |
x (Tensor): The input query with shape [num_key, bs, | |
embed_dims]. Same in `MultiheadAttention.forward`. | |
pos (Tensor): The positional encoding for query. Default None. | |
Same as `query_pos` in `MultiheadAttention.forward`. | |
attn_mask (Tensor): ByteTensor mask with shape [num_key, | |
num_key]. Same in `MultiheadAttention.forward`. Default None. | |
key_padding_mask (Tensor): ByteTensor with shape [bs, num_key]. | |
Same in `MultiheadAttention.forward`. Default None. | |
Returns: | |
Tensor: forwarded results with shape [num_key, bs, embed_dims]. | |
""" | |
norm_cnt = 0 | |
inp_residual = x | |
for layer in self.order: | |
if layer == 'selfattn': | |
# self attention | |
query = key = value = x | |
x = self.self_attn( | |
query, | |
key, | |
value, | |
inp_residual if self.pre_norm else None, | |
query_pos=pos, | |
key_pos=pos, | |
attn_mask=attn_mask, | |
key_padding_mask=key_padding_mask) | |
inp_residual = x | |
elif layer == 'norm': | |
x = self.norms[norm_cnt](x) | |
norm_cnt += 1 | |
elif layer == 'ffn': | |
x = self.ffn(x, inp_residual if self.pre_norm else None) | |
return x | |
def __repr__(self): | |
"""str: a string that describes the module""" | |
repr_str = self.__class__.__name__ | |
repr_str += f'(embed_dims={self.embed_dims}, ' | |
repr_str += f'num_heads={self.num_heads}, ' | |
repr_str += f'feedforward_channels={self.feedforward_channels}, ' | |
repr_str += f'dropout={self.dropout}, ' | |
repr_str += f'order={self.order}, ' | |
repr_str += f'act_cfg={self.act_cfg}, ' | |
repr_str += f'norm_cfg={self.norm_cfg}, ' | |
repr_str += f'num_fcs={self.num_fcs})' | |
return repr_str | |
class TransformerDecoderLayer(nn.Module): | |
"""Implements one decoder layer in DETR transformer. | |
Args: | |
embed_dims (int): The feature dimension. Same as | |
`TransformerEncoderLayer`. | |
num_heads (int): Parallel attention heads. | |
feedforward_channels (int): Same as `TransformerEncoderLayer`. | |
dropout (float): Same as `TransformerEncoderLayer`. Default 0.0. | |
order (tuple[str]): The order for decoder layer. Valid examples are | |
('selfattn', 'norm', 'multiheadattn', 'norm', 'ffn', 'norm') and | |
('norm', 'selfattn', 'norm', 'multiheadattn', 'norm', 'ffn'). | |
Default the former. | |
act_cfg (dict): Same as `TransformerEncoderLayer`. Default ReLU. | |
norm_cfg (dict): Config dict for normalization layer. Default | |
layer normalization. | |
num_fcs (int): The number of fully-connected layers in FFNs. | |
""" | |
def __init__(self, | |
embed_dims, | |
num_heads, | |
feedforward_channels, | |
dropout=0.0, | |
order=('selfattn', 'norm', 'multiheadattn', 'norm', 'ffn', | |
'norm'), | |
act_cfg=dict(type='ReLU', inplace=True), | |
norm_cfg=dict(type='LN'), | |
num_fcs=2): | |
super(TransformerDecoderLayer, self).__init__() | |
assert isinstance(order, tuple) and len(order) == 6 | |
assert set(order) == set(['selfattn', 'norm', 'multiheadattn', 'ffn']) | |
self.embed_dims = embed_dims | |
self.num_heads = num_heads | |
self.feedforward_channels = feedforward_channels | |
self.dropout = dropout | |
self.order = order | |
self.act_cfg = act_cfg | |
self.norm_cfg = norm_cfg | |
self.num_fcs = num_fcs | |
self.pre_norm = order[0] == 'norm' | |
self.self_attn = MultiheadAttention(embed_dims, num_heads, dropout) | |
self.multihead_attn = MultiheadAttention(embed_dims, num_heads, | |
dropout) | |
self.ffn = FFN(embed_dims, feedforward_channels, num_fcs, act_cfg, | |
dropout) | |
self.norms = nn.ModuleList() | |
# 3 norm layers in official DETR's TransformerDecoderLayer | |
for _ in range(3): | |
self.norms.append(build_norm_layer(norm_cfg, embed_dims)[1]) | |
def forward(self, | |
x, | |
memory, | |
memory_pos=None, | |
query_pos=None, | |
memory_attn_mask=None, | |
target_attn_mask=None, | |
memory_key_padding_mask=None, | |
target_key_padding_mask=None): | |
"""Forward function for `TransformerDecoderLayer`. | |
Args: | |
x (Tensor): Input query with shape [num_query, bs, embed_dims]. | |
memory (Tensor): Tensor got from `TransformerEncoder`, with shape | |
[num_key, bs, embed_dims]. | |
memory_pos (Tensor): The positional encoding for `memory`. Default | |
None. Same as `key_pos` in `MultiheadAttention.forward`. | |
query_pos (Tensor): The positional encoding for `query`. Default | |
None. Same as `query_pos` in `MultiheadAttention.forward`. | |
memory_attn_mask (Tensor): ByteTensor mask for `memory`, with | |
shape [num_key, num_key]. Same as `attn_mask` in | |
`MultiheadAttention.forward`. Default None. | |
target_attn_mask (Tensor): ByteTensor mask for `x`, with shape | |
[num_query, num_query]. Same as `attn_mask` in | |
`MultiheadAttention.forward`. Default None. | |
memory_key_padding_mask (Tensor): ByteTensor for `memory`, with | |
shape [bs, num_key]. Same as `key_padding_mask` in | |
`MultiheadAttention.forward`. Default None. | |
target_key_padding_mask (Tensor): ByteTensor for `x`, with shape | |
[bs, num_query]. Same as `key_padding_mask` in | |
`MultiheadAttention.forward`. Default None. | |
Returns: | |
Tensor: forwarded results with shape [num_query, bs, embed_dims]. | |
""" | |
norm_cnt = 0 | |
inp_residual = x | |
for layer in self.order: | |
if layer == 'selfattn': | |
query = key = value = x | |
x = self.self_attn( | |
query, | |
key, | |
value, | |
inp_residual if self.pre_norm else None, | |
query_pos, | |
key_pos=query_pos, | |
attn_mask=target_attn_mask, | |
key_padding_mask=target_key_padding_mask) | |
inp_residual = x | |
elif layer == 'norm': | |
x = self.norms[norm_cnt](x) | |
norm_cnt += 1 | |
elif layer == 'multiheadattn': | |
query = x | |
key = value = memory | |
x = self.multihead_attn( | |
query, | |
key, | |
value, | |
inp_residual if self.pre_norm else None, | |
query_pos, | |
key_pos=memory_pos, | |
attn_mask=memory_attn_mask, | |
key_padding_mask=memory_key_padding_mask) | |
inp_residual = x | |
elif layer == 'ffn': | |
x = self.ffn(x, inp_residual if self.pre_norm else None) | |
return x | |
def __repr__(self): | |
"""str: a string that describes the module""" | |
repr_str = self.__class__.__name__ | |
repr_str += f'(embed_dims={self.embed_dims}, ' | |
repr_str += f'num_heads={self.num_heads}, ' | |
repr_str += f'feedforward_channels={self.feedforward_channels}, ' | |
repr_str += f'dropout={self.dropout}, ' | |
repr_str += f'order={self.order}, ' | |
repr_str += f'act_cfg={self.act_cfg}, ' | |
repr_str += f'norm_cfg={self.norm_cfg}, ' | |
repr_str += f'num_fcs={self.num_fcs})' | |
return repr_str | |
class TransformerEncoder(nn.Module): | |
"""Implements the encoder in DETR transformer. | |
Args: | |
num_layers (int): The number of `TransformerEncoderLayer`. | |
embed_dims (int): Same as `TransformerEncoderLayer`. | |
num_heads (int): Same as `TransformerEncoderLayer`. | |
feedforward_channels (int): Same as `TransformerEncoderLayer`. | |
dropout (float): Same as `TransformerEncoderLayer`. Default 0.0. | |
order (tuple[str]): Same as `TransformerEncoderLayer`. | |
act_cfg (dict): Same as `TransformerEncoderLayer`. Default ReLU. | |
norm_cfg (dict): Same as `TransformerEncoderLayer`. Default | |
layer normalization. | |
num_fcs (int): Same as `TransformerEncoderLayer`. Default 2. | |
""" | |
def __init__(self, | |
num_layers, | |
embed_dims, | |
num_heads, | |
feedforward_channels, | |
dropout=0.0, | |
order=('selfattn', 'norm', 'ffn', 'norm'), | |
act_cfg=dict(type='ReLU', inplace=True), | |
norm_cfg=dict(type='LN'), | |
num_fcs=2): | |
super(TransformerEncoder, self).__init__() | |
assert isinstance(order, tuple) and len(order) == 4 | |
assert set(order) == set(['selfattn', 'norm', 'ffn']) | |
self.num_layers = num_layers | |
self.embed_dims = embed_dims | |
self.num_heads = num_heads | |
self.feedforward_channels = feedforward_channels | |
self.dropout = dropout | |
self.order = order | |
self.act_cfg = act_cfg | |
self.norm_cfg = norm_cfg | |
self.num_fcs = num_fcs | |
self.pre_norm = order[0] == 'norm' | |
self.layers = nn.ModuleList() | |
for _ in range(num_layers): | |
self.layers.append( | |
TransformerEncoderLayer(embed_dims, num_heads, | |
feedforward_channels, dropout, order, | |
act_cfg, norm_cfg, num_fcs)) | |
self.norm = build_norm_layer(norm_cfg, | |
embed_dims)[1] if self.pre_norm else None | |
def forward(self, x, pos=None, attn_mask=None, key_padding_mask=None): | |
"""Forward function for `TransformerEncoder`. | |
Args: | |
x (Tensor): Input query. Same in `TransformerEncoderLayer.forward`. | |
pos (Tensor): Positional encoding for query. Default None. | |
Same in `TransformerEncoderLayer.forward`. | |
attn_mask (Tensor): ByteTensor attention mask. Default None. | |
Same in `TransformerEncoderLayer.forward`. | |
key_padding_mask (Tensor): Same in | |
`TransformerEncoderLayer.forward`. Default None. | |
Returns: | |
Tensor: Results with shape [num_key, bs, embed_dims]. | |
""" | |
for layer in self.layers: | |
x = layer(x, pos, attn_mask, key_padding_mask) | |
if self.norm is not None: | |
x = self.norm(x) | |
return x | |
def __repr__(self): | |
"""str: a string that describes the module""" | |
repr_str = self.__class__.__name__ | |
repr_str += f'(num_layers={self.num_layers}, ' | |
repr_str += f'embed_dims={self.embed_dims}, ' | |
repr_str += f'num_heads={self.num_heads}, ' | |
repr_str += f'feedforward_channels={self.feedforward_channels}, ' | |
repr_str += f'dropout={self.dropout}, ' | |
repr_str += f'order={self.order}, ' | |
repr_str += f'act_cfg={self.act_cfg}, ' | |
repr_str += f'norm_cfg={self.norm_cfg}, ' | |
repr_str += f'num_fcs={self.num_fcs})' | |
return repr_str | |
class TransformerDecoder(nn.Module): | |
"""Implements the decoder in DETR transformer. | |
Args: | |
num_layers (int): The number of `TransformerDecoderLayer`. | |
embed_dims (int): Same as `TransformerDecoderLayer`. | |
num_heads (int): Same as `TransformerDecoderLayer`. | |
feedforward_channels (int): Same as `TransformerDecoderLayer`. | |
dropout (float): Same as `TransformerDecoderLayer`. Default 0.0. | |
order (tuple[str]): Same as `TransformerDecoderLayer`. | |
act_cfg (dict): Same as `TransformerDecoderLayer`. Default ReLU. | |
norm_cfg (dict): Same as `TransformerDecoderLayer`. Default | |
layer normalization. | |
num_fcs (int): Same as `TransformerDecoderLayer`. Default 2. | |
""" | |
def __init__(self, | |
num_layers, | |
embed_dims, | |
num_heads, | |
feedforward_channels, | |
dropout=0.0, | |
order=('selfattn', 'norm', 'multiheadattn', 'norm', 'ffn', | |
'norm'), | |
act_cfg=dict(type='ReLU', inplace=True), | |
norm_cfg=dict(type='LN'), | |
num_fcs=2, | |
return_intermediate=False): | |
super(TransformerDecoder, self).__init__() | |
assert isinstance(order, tuple) and len(order) == 6 | |
assert set(order) == set(['selfattn', 'norm', 'multiheadattn', 'ffn']) | |
self.num_layers = num_layers | |
self.embed_dims = embed_dims | |
self.num_heads = num_heads | |
self.feedforward_channels = feedforward_channels | |
self.dropout = dropout | |
self.order = order | |
self.act_cfg = act_cfg | |
self.norm_cfg = norm_cfg | |
self.num_fcs = num_fcs | |
self.return_intermediate = return_intermediate | |
self.layers = nn.ModuleList() | |
for _ in range(num_layers): | |
self.layers.append( | |
TransformerDecoderLayer(embed_dims, num_heads, | |
feedforward_channels, dropout, order, | |
act_cfg, norm_cfg, num_fcs)) | |
self.norm = build_norm_layer(norm_cfg, embed_dims)[1] | |
def forward(self, | |
x, | |
memory, | |
memory_pos=None, | |
query_pos=None, | |
memory_attn_mask=None, | |
target_attn_mask=None, | |
memory_key_padding_mask=None, | |
target_key_padding_mask=None): | |
"""Forward function for `TransformerDecoder`. | |
Args: | |
x (Tensor): Input query. Same in `TransformerDecoderLayer.forward`. | |
memory (Tensor): Same in `TransformerDecoderLayer.forward`. | |
memory_pos (Tensor): Same in `TransformerDecoderLayer.forward`. | |
Default None. | |
query_pos (Tensor): Same in `TransformerDecoderLayer.forward`. | |
Default None. | |
memory_attn_mask (Tensor): Same in | |
`TransformerDecoderLayer.forward`. Default None. | |
target_attn_mask (Tensor): Same in | |
`TransformerDecoderLayer.forward`. Default None. | |
memory_key_padding_mask (Tensor): Same in | |
`TransformerDecoderLayer.forward`. Default None. | |
target_key_padding_mask (Tensor): Same in | |
`TransformerDecoderLayer.forward`. Default None. | |
Returns: | |
Tensor: Results with shape [num_query, bs, embed_dims]. | |
""" | |
intermediate = [] | |
for layer in self.layers: | |
x = layer(x, memory, memory_pos, query_pos, memory_attn_mask, | |
target_attn_mask, memory_key_padding_mask, | |
target_key_padding_mask) | |
if self.return_intermediate: | |
intermediate.append(self.norm(x)) | |
if self.norm is not None: | |
x = self.norm(x) | |
if self.return_intermediate: | |
intermediate.pop() | |
intermediate.append(x) | |
if self.return_intermediate: | |
return torch.stack(intermediate) | |
return x.unsqueeze(0) | |
def __repr__(self): | |
"""str: a string that describes the module""" | |
repr_str = self.__class__.__name__ | |
repr_str += f'(num_layers={self.num_layers}, ' | |
repr_str += f'embed_dims={self.embed_dims}, ' | |
repr_str += f'num_heads={self.num_heads}, ' | |
repr_str += f'feedforward_channels={self.feedforward_channels}, ' | |
repr_str += f'dropout={self.dropout}, ' | |
repr_str += f'order={self.order}, ' | |
repr_str += f'act_cfg={self.act_cfg}, ' | |
repr_str += f'norm_cfg={self.norm_cfg}, ' | |
repr_str += f'num_fcs={self.num_fcs}, ' | |
repr_str += f'return_intermediate={self.return_intermediate})' | |
return repr_str | |
class Transformer(nn.Module): | |
"""Implements the DETR transformer. | |
Following the official DETR implementation, this module copy-paste | |
from torch.nn.Transformer with modifications: | |
* positional encodings are passed in MultiheadAttention | |
* extra LN at the end of encoder is removed | |
* decoder returns a stack of activations from all decoding layers | |
See `paper: End-to-End Object Detection with Transformers | |
<https://arxiv.org/pdf/2005.12872>`_ for details. | |
Args: | |
embed_dims (int): The feature dimension. | |
num_heads (int): Parallel attention heads. Same as | |
`nn.MultiheadAttention`. | |
num_encoder_layers (int): Number of `TransformerEncoderLayer`. | |
num_decoder_layers (int): Number of `TransformerDecoderLayer`. | |
feedforward_channels (int): The hidden dimension for FFNs used in both | |
encoder and decoder. | |
dropout (float): Probability of an element to be zeroed. Default 0.0. | |
act_cfg (dict): Activation config for FFNs used in both encoder | |
and decoder. Default ReLU. | |
norm_cfg (dict): Config dict for normalization used in both encoder | |
and decoder. Default layer normalization. | |
num_fcs (int): The number of fully-connected layers in FFNs, which is | |
used for both encoder and decoder. | |
pre_norm (bool): Whether the normalization layer is ordered | |
first in the encoder and decoder. Default False. | |
return_intermediate_dec (bool): Whether to return the intermediate | |
output from each TransformerDecoderLayer or only the last | |
TransformerDecoderLayer. Default False. If False, the returned | |
`hs` has shape [num_decoder_layers, bs, num_query, embed_dims]. | |
If True, the returned `hs` will have shape [1, bs, num_query, | |
embed_dims]. | |
""" | |
def __init__(self, | |
embed_dims=512, | |
num_heads=8, | |
num_encoder_layers=6, | |
num_decoder_layers=6, | |
feedforward_channels=2048, | |
dropout=0.0, | |
act_cfg=dict(type='ReLU', inplace=True), | |
norm_cfg=dict(type='LN'), | |
num_fcs=2, | |
pre_norm=False, | |
return_intermediate_dec=False): | |
super(Transformer, self).__init__() | |
self.embed_dims = embed_dims | |
self.num_heads = num_heads | |
self.num_encoder_layers = num_encoder_layers | |
self.num_decoder_layers = num_decoder_layers | |
self.feedforward_channels = feedforward_channels | |
self.dropout = dropout | |
self.act_cfg = act_cfg | |
self.norm_cfg = norm_cfg | |
self.num_fcs = num_fcs | |
self.pre_norm = pre_norm | |
self.return_intermediate_dec = return_intermediate_dec | |
if self.pre_norm: | |
encoder_order = ('norm', 'selfattn', 'norm', 'ffn') | |
decoder_order = ('norm', 'selfattn', 'norm', 'multiheadattn', | |
'norm', 'ffn') | |
else: | |
encoder_order = ('selfattn', 'norm', 'ffn', 'norm') | |
decoder_order = ('selfattn', 'norm', 'multiheadattn', 'norm', | |
'ffn', 'norm') | |
self.encoder = TransformerEncoder(num_encoder_layers, embed_dims, | |
num_heads, feedforward_channels, | |
dropout, encoder_order, act_cfg, | |
norm_cfg, num_fcs) | |
self.decoder = TransformerDecoder(num_decoder_layers, embed_dims, | |
num_heads, feedforward_channels, | |
dropout, decoder_order, act_cfg, | |
norm_cfg, num_fcs, | |
return_intermediate_dec) | |
def init_weights(self, distribution='uniform'): | |
"""Initialize the transformer weights.""" | |
# follow the official DETR to init parameters | |
for m in self.modules(): | |
if hasattr(m, 'weight') and m.weight.dim() > 1: | |
xavier_init(m, distribution=distribution) | |
def forward(self, x, mask, query_embed, pos_embed): | |
"""Forward function for `Transformer`. | |
Args: | |
x (Tensor): Input query with shape [bs, c, h, w] where | |
c = embed_dims. | |
mask (Tensor): The key_padding_mask used for encoder and decoder, | |
with shape [bs, h, w]. | |
query_embed (Tensor): The query embedding for decoder, with shape | |
[num_query, c]. | |
pos_embed (Tensor): The positional encoding for encoder and | |
decoder, with the same shape as `x`. | |
Returns: | |
tuple[Tensor]: results of decoder containing the following tensor. | |
- out_dec: Output from decoder. If return_intermediate_dec \ | |
is True output has shape [num_dec_layers, bs, | |
num_query, embed_dims], else has shape [1, bs, \ | |
num_query, embed_dims]. | |
- memory: Output results from encoder, with shape \ | |
[bs, embed_dims, h, w]. | |
""" | |
bs, c, h, w = x.shape | |
x = x.flatten(2).permute(2, 0, 1) # [bs, c, h, w] -> [h*w, bs, c] | |
pos_embed = pos_embed.flatten(2).permute(2, 0, 1) | |
query_embed = query_embed.unsqueeze(1).repeat( | |
1, bs, 1) # [num_query, dim] -> [num_query, bs, dim] | |
mask = mask.flatten(1) # [bs, h, w] -> [bs, h*w] | |
memory = self.encoder( | |
x, pos=pos_embed, attn_mask=None, key_padding_mask=mask) | |
target = torch.zeros_like(query_embed) | |
# out_dec: [num_layers, num_query, bs, dim] | |
out_dec = self.decoder( | |
target, | |
memory, | |
memory_pos=pos_embed, | |
query_pos=query_embed, | |
memory_attn_mask=None, | |
target_attn_mask=None, | |
memory_key_padding_mask=mask, | |
target_key_padding_mask=None) | |
out_dec = out_dec.transpose(1, 2) | |
memory = memory.permute(1, 2, 0).reshape(bs, c, h, w) | |
return out_dec, memory | |
def __repr__(self): | |
"""str: a string that describes the module""" | |
repr_str = self.__class__.__name__ | |
repr_str += f'(embed_dims={self.embed_dims}, ' | |
repr_str += f'num_heads={self.num_heads}, ' | |
repr_str += f'num_encoder_layers={self.num_encoder_layers}, ' | |
repr_str += f'num_decoder_layers={self.num_decoder_layers}, ' | |
repr_str += f'feedforward_channels={self.feedforward_channels}, ' | |
repr_str += f'dropout={self.dropout}, ' | |
repr_str += f'act_cfg={self.act_cfg}, ' | |
repr_str += f'norm_cfg={self.norm_cfg}, ' | |
repr_str += f'num_fcs={self.num_fcs}, ' | |
repr_str += f'pre_norm={self.pre_norm}, ' | |
repr_str += f'return_intermediate_dec={self.return_intermediate_dec})' | |
return repr_str | |
class DynamicConv(nn.Module): | |
"""Implements Dynamic Convolution. | |
This module generate parameters for each sample and | |
use bmm to implement 1*1 convolution. Code is modified | |
from the `official github repo <https://github.com/PeizeSun/ | |
SparseR-CNN/blob/main/projects/SparseRCNN/sparsercnn/head.py#L258>`_ . | |
Args: | |
in_channels (int): The input feature channel. | |
Defaults to 256. | |
feat_channels (int): The inner feature channel. | |
Defaults to 64. | |
out_channels (int, optional): The output feature channel. | |
When not specified, it will be set to `in_channels` | |
by default | |
input_feat_shape (int): The shape of input feature. | |
Defaults to 7. | |
act_cfg (dict): The activation config for DynamicConv. | |
norm_cfg (dict): Config dict for normalization layer. Default | |
layer normalization. | |
""" | |
def __init__(self, | |
in_channels=256, | |
feat_channels=64, | |
out_channels=None, | |
input_feat_shape=7, | |
act_cfg=dict(type='ReLU', inplace=True), | |
norm_cfg=dict(type='LN')): | |
super(DynamicConv, self).__init__() | |
self.in_channels = in_channels | |
self.feat_channels = feat_channels | |
self.out_channels_raw = out_channels | |
self.input_feat_shape = input_feat_shape | |
self.act_cfg = act_cfg | |
self.norm_cfg = norm_cfg | |
self.out_channels = out_channels if out_channels else in_channels | |
self.num_params_in = self.in_channels * self.feat_channels | |
self.num_params_out = self.out_channels * self.feat_channels | |
self.dynamic_layer = nn.Linear( | |
self.in_channels, self.num_params_in + self.num_params_out) | |
self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1] | |
self.norm_out = build_norm_layer(norm_cfg, self.out_channels)[1] | |
self.activation = build_activation_layer(act_cfg) | |
num_output = self.out_channels * input_feat_shape**2 | |
self.fc_layer = nn.Linear(num_output, self.out_channels) | |
self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1] | |
def forward(self, param_feature, input_feature): | |
"""Forward function for `DynamicConv`. | |
Args: | |
param_feature (Tensor): The feature can be used | |
to generate the parameter, has shape | |
(num_all_proposals, in_channels). | |
input_feature (Tensor): Feature that | |
interact with parameters, has shape | |
(num_all_proposals, in_channels, H, W). | |
Returns: | |
Tensor: The output feature has shape | |
(num_all_proposals, out_channels). | |
""" | |
num_proposals = param_feature.size(0) | |
input_feature = input_feature.view(num_proposals, self.in_channels, | |
-1).permute(2, 0, 1) | |
input_feature = input_feature.permute(1, 0, 2) | |
parameters = self.dynamic_layer(param_feature) | |
param_in = parameters[:, :self.num_params_in].view( | |
-1, self.in_channels, self.feat_channels) | |
param_out = parameters[:, -self.num_params_out:].view( | |
-1, self.feat_channels, self.out_channels) | |
# input_feature has shape (num_all_proposals, H*W, in_channels) | |
# param_in has shape (num_all_proposals, in_channels, feat_channels) | |
# feature has shape (num_all_proposals, H*W, feat_channels) | |
features = torch.bmm(input_feature, param_in) | |
features = self.norm_in(features) | |
features = self.activation(features) | |
# param_out has shape (batch_size, feat_channels, out_channels) | |
features = torch.bmm(features, param_out) | |
features = self.norm_out(features) | |
features = self.activation(features) | |
features = features.flatten(1) | |
features = self.fc_layer(features) | |
features = self.fc_norm(features) | |
features = self.activation(features) | |
return features | |
def __repr__(self): | |
"""str: a string that describes the module""" | |
repr_str = self.__class__.__name__ | |
repr_str += f'(in_channels={self.in_channels}, ' | |
repr_str += f'feat_channels={self.feat_channels}, ' | |
repr_str += f'out_channels={self.out_channels_raw}, ' | |
repr_str += f'input_feat_shape={self.input_feat_shape}, ' | |
repr_str += f'act_cfg={self.act_cfg}, ' | |
repr_str += f'norm_cfg={self.norm_cfg})' | |
return repr_str | |