Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import copy | |
import warnings | |
import torch | |
import torch.nn as nn | |
from annotator.uniformer.mmcv import ConfigDict, deprecated_api_warning | |
from annotator.uniformer.mmcv.cnn import Linear, build_activation_layer, build_norm_layer | |
from annotator.uniformer.mmcv.runner.base_module import BaseModule, ModuleList, Sequential | |
from annotator.uniformer.mmcv.utils import build_from_cfg | |
from .drop import build_dropout | |
from .registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING, | |
TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE) | |
# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file | |
try: | |
from annotator.uniformer.mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention # noqa F401 | |
warnings.warn( | |
ImportWarning( | |
'``MultiScaleDeformableAttention`` has been moved to ' | |
'``mmcv.ops.multi_scale_deform_attn``, please change original path ' # noqa E501 | |
'``from annotator.uniformer.mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention`` ' # noqa E501 | |
'to ``from annotator.uniformer.mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention`` ' # noqa E501 | |
)) | |
except ImportError: | |
warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from ' | |
'``mmcv.ops.multi_scale_deform_attn``, ' | |
'You should install ``mmcv-full`` if you need this module. ') | |
def build_positional_encoding(cfg, default_args=None):
    """Build a positional-encoding module from a config.

    Args:
        cfg (dict): Config of the module, resolved against the
            ``POSITIONAL_ENCODING`` registry.
        default_args (dict, optional): Extra arguments handed through to
            ``build_from_cfg``. Defaults to None.

    Returns:
        Whatever ``build_from_cfg`` constructs for ``cfg``.
    """
    registry = POSITIONAL_ENCODING
    return build_from_cfg(cfg, registry, default_args)
def build_attention(cfg, default_args=None):
    """Build an attention module from a config.

    Args:
        cfg (dict): Config of the attention, resolved against the
            ``ATTENTION`` registry.
        default_args (dict, optional): Extra arguments handed through to
            ``build_from_cfg``. Defaults to None.

    Returns:
        Whatever ``build_from_cfg`` constructs for ``cfg``.
    """
    attention = build_from_cfg(cfg, ATTENTION, default_args)
    return attention
def build_feedforward_network(cfg, default_args=None):
    """Build a feed-forward network (FFN) from a config.

    Args:
        cfg (dict): Config of the FFN, resolved against the
            ``FEEDFORWARD_NETWORK`` registry.
        default_args (dict, optional): Extra arguments handed through to
            ``build_from_cfg``. Defaults to None.

    Returns:
        Whatever ``build_from_cfg`` constructs for ``cfg``.
    """
    registry, extra = FEEDFORWARD_NETWORK, default_args
    return build_from_cfg(cfg, registry, extra)
def build_transformer_layer(cfg, default_args=None):
    """Build a single transformer layer from a config.

    Args:
        cfg (dict): Config of the layer, resolved against the
            ``TRANSFORMER_LAYER`` registry.
        default_args (dict, optional): Extra arguments handed through to
            ``build_from_cfg``. Defaults to None.

    Returns:
        Whatever ``build_from_cfg`` constructs for ``cfg``.
    """
    layer = build_from_cfg(cfg, TRANSFORMER_LAYER, default_args)
    return layer
def build_transformer_layer_sequence(cfg, default_args=None):
    """Build a transformer encoder/decoder (a layer sequence) from a config.

    Args:
        cfg (dict): Config of the sequence, resolved against the
            ``TRANSFORMER_LAYER_SEQUENCE`` registry.
        default_args (dict, optional): Extra arguments handed through to
            ``build_from_cfg``. Defaults to None.

    Returns:
        Whatever ``build_from_cfg`` constructs for ``cfg``.
    """
    registry = TRANSFORMER_LAYER_SEQUENCE
    sequence = build_from_cfg(cfg, registry, default_args)
    return sequence
@ATTENTION.register_module()
class MultiheadAttention(BaseModule):
    """A wrapper for ``torch.nn.MultiheadAttention``.

    This module implements MultiheadAttention with an identity connection.
    Positional encodings are also passed as inputs.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Number of parallel attention heads.
        attn_drop (float): Dropout probability passed to
            ``nn.MultiheadAttention`` (applied to attn_output_weights).
            Default: 0.0.
        proj_drop (float): Dropout probability applied after
            ``nn.MultiheadAttention``. Default: 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout layer used when
            adding the shortcut. A falsy value disables it (identity).
            Default: dict(type='Dropout', drop_prob=0.).
        init_cfg (obj:`mmcv.ConfigDict`): The config for initialization.
            Default: None.
        batch_first (bool): When True, key, query and value have shape
            (batch, n, embed_dim), otherwise (n, batch, embed_dim).
            Default: False.
        **kwargs: Forwarded to ``nn.MultiheadAttention``. The deprecated
            ``dropout`` key is consumed here instead: it overrides
            ``attn_drop`` and ``dropout_layer['drop_prob']``.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=dict(type='Dropout', drop_prob=0.),
                 init_cfg=None,
                 batch_first=False,
                 **kwargs):
        super(MultiheadAttention, self).__init__(init_cfg)
        if 'dropout' in kwargs:
            warnings.warn('The arguments `dropout` in MultiheadAttention '
                          'has been deprecated, now you can separately '
                          'set `attn_drop`(float), proj_drop(float), '
                          'and `dropout_layer`(dict) ')
            attn_drop = kwargs.pop('dropout')
            if dropout_layer:
                # Copy before writing: ``dropout_layer`` may be the shared
                # default dict or a config owned by the caller. Mutating
                # either would leak this drop_prob into later instances.
                dropout_layer = copy.deepcopy(dropout_layer)
                dropout_layer['drop_prob'] = attn_drop

        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.batch_first = batch_first

        self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
                                          **kwargs)

        self.proj_drop = nn.Dropout(proj_drop)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else nn.Identity()

    @deprecated_api_warning({'residual': 'identity'},
                            cls_name='MultiheadAttention')
    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_pos=None,
                attn_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `MultiheadAttention`.

        **kwargs allow passing a more general data flow when combining
        with other operations in `transformerlayer`. They are ignored here.

        Args:
            query (Tensor): The input query with shape [num_queries, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_queries, embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims].
                If None, the ``query`` will be used. Defaults to None.
            value (Tensor): The value tensor with the same shape as `key`.
                If None, the `key` (before adding `key_pos`) will be used.
                Defaults to None.
            identity (Tensor): This tensor, with the same shape as `query`,
                is added to the output as the identity (shortcut) link.
                If None, `query` (before adding `query_pos`) is used.
                Defaults to None. The deprecated name ``residual`` is
                accepted and mapped to ``identity``.
            query_pos (Tensor): The positional encoding for query, with
                the same shape as `query`. If not None, it is added to
                `query` before the attention. Defaults to None.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`. If not None, it is added to `key`
                before the attention. If None and `query_pos` has the same
                shape as `key`, then `query_pos` is used for `key_pos`.
                Defaults to None.
            attn_mask (Tensor): Mask with shape [num_queries, num_keys].
                Same as in `nn.MultiheadAttention.forward`.
                Defaults to None.
            key_padding_mask (Tensor): Mask with shape [bs, num_keys].
                Defaults to None.

        Returns:
            Tensor: Forwarded results with shape
            [num_queries, bs, embed_dims] if self.batch_first is False,
            else [bs, num_queries, embed_dims].
        """
        if key is None:
            key = query
        if value is None:
            value = key
        if identity is None:
            identity = query
        if key_pos is None:
            if query_pos is not None:
                # use query_pos if key_pos is not available
                if query_pos.shape == key.shape:
                    key_pos = query_pos
                else:
                    warnings.warn(f'position encoding of key is '
                                  f'missing in {self.__class__.__name__}.')
        if query_pos is not None:
            query = query + query_pos
        if key_pos is not None:
            key = key + key_pos

        # The dataflow ('key', 'query', 'value') of
        # ``torch.nn.MultiheadAttention`` is (num_query, batch,
        # embed_dims). Convert batch_first inputs (batch, num_query,
        # embed_dims) to num_query_first (num_query, batch, embed_dims),
        # then convert ``attn_output`` back to batch_first.
        if self.batch_first:
            query = query.transpose(0, 1)
            key = key.transpose(0, 1)
            value = value.transpose(0, 1)

        out = self.attn(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask)[0]

        if self.batch_first:
            out = out.transpose(0, 1)

        return identity + self.dropout_layer(self.proj_drop(out))
@FEEDFORWARD_NETWORK.register_module()
class FFN(BaseModule):
    """Implements feed-forward networks (FFNs) with identity connection.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`. Default: 256.
        feedforward_channels (int): The hidden dimension of FFNs.
            Default: 1024.
        num_fcs (int, optional): The number of fully-connected layers in
            FFNs. Must be at least 2. Default: 2.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='ReLU', inplace=True).
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default: 0.0. The deprecated name ``dropout``
            is accepted and mapped to ``ffn_drop``.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut. A falsy value disables it (identity).
            Default: None.
        add_identity (bool, optional): Whether to add the
            identity connection. Default: `True`. The deprecated name
            ``add_residual`` is accepted and mapped to ``add_identity``.
        init_cfg (obj:`mmcv.ConfigDict`): The config for initialization.
            Default: None.
        **kwargs: Accepted for config compatibility; ignored.
    """

    @deprecated_api_warning(
        {
            'dropout': 'ffn_drop',
            'add_residual': 'add_identity'
        },
        cls_name='FFN')
    def __init__(self,
                 embed_dims=256,
                 feedforward_channels=1024,
                 num_fcs=2,
                 act_cfg=dict(type='ReLU', inplace=True),
                 ffn_drop=0.,
                 dropout_layer=None,
                 add_identity=True,
                 init_cfg=None,
                 **kwargs):
        super(FFN, self).__init__(init_cfg)
        assert num_fcs >= 2, 'num_fcs should be no less ' \
            f'than 2. got {num_fcs}.'
        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.num_fcs = num_fcs
        self.act_cfg = act_cfg
        # A single activation module instance is reused by every hidden
        # block below.
        self.activate = build_activation_layer(act_cfg)

        layers = []
        in_channels = embed_dims
        # (num_fcs - 1) hidden blocks: Linear -> activation -> Dropout.
        for _ in range(num_fcs - 1):
            layers.append(
                Sequential(
                    Linear(in_channels, feedforward_channels), self.activate,
                    nn.Dropout(ffn_drop)))
            in_channels = feedforward_channels
        # Final projection back to embed_dims (no activation).
        layers.append(Linear(feedforward_channels, embed_dims))
        layers.append(nn.Dropout(ffn_drop))
        self.layers = Sequential(*layers)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else torch.nn.Identity()
        self.add_identity = add_identity

    @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN')
    def forward(self, x, identity=None):
        """Forward function for `FFN`.

        Args:
            x (Tensor): Input features. They are fed to
                ``Linear(embed_dims, ...)``, so the last dimension is
                expected to be ``embed_dims``.
            identity (Tensor, optional): Tensor added to the output as the
                shortcut. If None, ``x`` is used. Ignored when
                ``add_identity`` is False. The deprecated name
                ``residual`` is accepted. Defaults to None.

        Returns:
            Tensor: ``dropout_layer(layers(x))``, plus ``identity`` when
            ``add_identity`` is True.
        """
        out = self.layers(x)
        if not self.add_identity:
            return self.dropout_layer(out)
        if identity is None:
            identity = x
        return identity + self.dropout_layer(out)
@TRANSFORMER_LAYER.register_module()
class BaseTransformerLayer(BaseModule):
    """Base `TransformerLayer` for vision transformer.

    It can be built from `mmcv.ConfigDict` and supports flexible
    customization. For example, it can use any number of `FFN` or `LN`
    layers, and different kinds of `attention`, by specifying a list of
    `ConfigDict` named `attn_cfgs`. It supports `prenorm` when `norm` is
    the first element of `operation_order`. More details about `prenorm`:
    `On Layer Normalization in the Transformer Architecture
    <https://arxiv.org/abs/2002.04745>`_ .

    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None):
            Configs for `self_attention` or `cross_attention` modules.
            The order of the configs in the list should be consistent with
            the corresponding attentions in operation_order.
            If it is a dict, all of the attention modules in
            operation_order will be built with this config. Default: None.
        ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None):
            Configs for FFN. The order of the configs in the list should be
            consistent with the corresponding ffn in operation_order.
            If it is a dict, all of the FFN modules in operation_order
            will be built with this config. A missing ``embed_dims`` is
            filled in from the first attention module.
        operation_order (tuple[str]): The execution order of operations
            in the transformer, such as ('self_attn', 'norm', 'ffn',
            'norm'). Only 'self_attn', 'cross_attn', 'norm' and 'ffn' are
            allowed. `prenorm` is used when the first element is `norm`.
            Default: None (must be given in practice).
        norm_cfg (dict): Config dict for the normalization layer.
            Default: dict(type='LN').
        init_cfg (obj:`mmcv.ConfigDict`): The config for initialization.
            Default: None.
        batch_first (bool): Key, query and value have shape
            (batch, n, embed_dim) if True, else (n, batch, embed_dim).
            Default: False.
        **kwargs: Only the deprecated FFN arguments ``feedforward_channels``,
            ``ffn_dropout`` and ``ffn_num_fcs`` are read; they are copied
            into ``ffn_cfgs`` with a warning.
    """

    def __init__(self,
                 attn_cfgs=None,
                 ffn_cfgs=dict(
                     type='FFN',
                     embed_dims=256,
                     feedforward_channels=1024,
                     num_fcs=2,
                     ffn_drop=0.,
                     act_cfg=dict(type='ReLU', inplace=True),
                 ),
                 operation_order=None,
                 norm_cfg=dict(type='LN'),
                 init_cfg=None,
                 batch_first=False,
                 **kwargs):
        # Work on a private copy: the default ``ffn_cfgs`` dict is shared by
        # every instance, and a caller-supplied config must not be modified
        # by the writes below (deprecated args, ``embed_dims``).
        ffn_cfgs = copy.deepcopy(ffn_cfgs)

        deprecated_args = dict(
            feedforward_channels='feedforward_channels',
            ffn_dropout='ffn_drop',
            ffn_num_fcs='num_fcs')
        for ori_name, new_name in deprecated_args.items():
            if ori_name in kwargs:
                warnings.warn(
                    f'The arguments `{ori_name}` in BaseTransformerLayer '
                    f'has been deprecated, now you should set `{new_name}` '
                    f'and other FFN related arguments '
                    f'to a dict named `ffn_cfgs`. ')
                ffn_cfgs[new_name] = kwargs[ori_name]

        super(BaseTransformerLayer, self).__init__(init_cfg)

        self.batch_first = batch_first

        valid_ops = ['self_attn', 'norm', 'ffn', 'cross_attn']
        assert set(operation_order) <= set(valid_ops), \
            f'The operation_order of {self.__class__.__name__} should ' \
            f'only contain operation types from {valid_ops}'

        num_attn = operation_order.count('self_attn') + operation_order.count(
            'cross_attn')
        if isinstance(attn_cfgs, dict):
            attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
        else:
            assert num_attn == len(attn_cfgs), f'The length ' \
                f'of attn_cfg {len(attn_cfgs)} is ' \
                f'not consistent with the number of attention ' \
                f'in operation_order {operation_order}.'
            # ``batch_first`` is written into each config below; copy so the
            # caller's configs are left untouched.
            attn_cfgs = [copy.deepcopy(cfg) for cfg in attn_cfgs]

        self.num_attn = num_attn
        self.operation_order = operation_order
        self.norm_cfg = norm_cfg
        self.pre_norm = operation_order[0] == 'norm'

        self.attentions = ModuleList()
        index = 0
        for operation_name in operation_order:
            if operation_name in ['self_attn', 'cross_attn']:
                if 'batch_first' in attn_cfgs[index]:
                    assert self.batch_first == attn_cfgs[index]['batch_first']
                else:
                    attn_cfgs[index]['batch_first'] = self.batch_first
                attention = build_attention(attn_cfgs[index])
                # Some custom attentions used as `self_attn`
                # or `cross_attn` can have different behavior.
                attention.operation_name = operation_name
                self.attentions.append(attention)
                index += 1

        self.embed_dims = self.attentions[0].embed_dims

        self.ffns = ModuleList()
        num_ffns = operation_order.count('ffn')
        if isinstance(ffn_cfgs, dict):
            ffn_cfgs = ConfigDict(ffn_cfgs)
            ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
        assert len(ffn_cfgs) == num_ffns
        for ffn_index in range(num_ffns):
            # Index the per-FFN config: ``ffn_cfgs`` is a list at this point.
            if 'embed_dims' not in ffn_cfgs[ffn_index]:
                ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims
            else:
                assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
            self.ffns.append(
                build_feedforward_network(ffn_cfgs[ffn_index],
                                          dict(type='FFN')))

        self.norms = ModuleList()
        num_norms = operation_order.count('norm')
        for _ in range(num_norms):
            self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])

    def forward(self,
                query,
                key=None,
                value=None,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `BaseTransformerLayer`.

        **kwargs contains some specific arguments of attentions. They are
        forwarded unchanged to every attention module.

        Args:
            query (Tensor): The input query with shape
                [num_queries, bs, embed_dims] if
                self.batch_first is False, else
                [bs, num_queries, embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims]. Only used by `cross_attn`.
            value (Tensor): The value tensor with the same shape as `key`.
                Only used by `cross_attn`.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            attn_masks (List[Tensor] | Tensor | None): 2D Tensor(s) used in
                the calculation of the corresponding attention. A list must
                have one entry per attention in `operation_order`. A
                single Tensor is copied for every attention. Default: None.
            query_key_padding_mask (Tensor): Mask for `query`, with
                shape [bs, num_queries]. Only used in `self_attn` layers.
                Defaults to None.
            key_padding_mask (Tensor): Mask for `key`, with
                shape [bs, num_keys]. Only used in `cross_attn` layers.
                Default: None.

        Returns:
            Tensor: Forwarded results with the same layout as `query`
            ([num_queries, bs, embed_dims] when batch_first is False).
        """
        norm_index = 0
        attn_index = 0
        ffn_index = 0
        identity = query
        if attn_masks is None:
            attn_masks = [None for _ in range(self.num_attn)]
        elif isinstance(attn_masks, torch.Tensor):
            attn_masks = [
                copy.deepcopy(attn_masks) for _ in range(self.num_attn)
            ]
            warnings.warn(f'Use same attn_mask in all attentions in '
                          f'{self.__class__.__name__} ')
        else:
            assert len(attn_masks) == self.num_attn, f'The length of ' \
                f'attn_masks {len(attn_masks)} must be equal ' \
                f'to the number of attention in ' \
                f'operation_order {self.num_attn}'

        # In pre-norm mode `query` has already been normalized when an
        # attention/FFN runs, so the un-normalized `identity` is passed as
        # the shortcut. In post-norm mode None is passed, and
        # MultiheadAttention/FFN then fall back to their own input.
        for layer in self.operation_order:
            if layer == 'self_attn':
                temp_key = temp_value = query
                query = self.attentions[attn_index](
                    query,
                    temp_key,
                    temp_value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=query_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=query_key_padding_mask,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'norm':
                query = self.norms[norm_index](query)
                norm_index += 1

            elif layer == 'cross_attn':
                query = self.attentions[attn_index](
                    query,
                    key,
                    value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=key_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=key_padding_mask,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'ffn':
                query = self.ffns[ffn_index](
                    query, identity if self.pre_norm else None)
                ffn_index += 1

        return query
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class TransformerLayerSequence(BaseModule):
    """Base class for TransformerEncoder and TransformerDecoder in vision
    transformer.

    As the base class of Encoder and Decoder in vision transformer, it
    supports customization such as specifying different kinds of
    `transformer_layer` in `transformer_coder`.

    Args:
        transformerlayers (list[obj:`mmcv.ConfigDict`] |
            obj:`mmcv.ConfigDict`): Config of the transformer layers
            in TransformerCoder. If it is obj:`mmcv.ConfigDict`,
            it is repeated `num_layers` times to form a
            list[`mmcv.ConfigDict`]. A list must have exactly
            `num_layers` entries. Default: None.
        num_layers (int): The number of `TransformerLayer`. Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The config for initialization.
            Default: None.
    """

    def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
        super(TransformerLayerSequence, self).__init__(init_cfg)
        if isinstance(transformerlayers, dict):
            # One independent config per layer, so layers never share state.
            transformerlayers = [
                copy.deepcopy(transformerlayers) for _ in range(num_layers)
            ]
        else:
            assert isinstance(transformerlayers, list) and \
                len(transformerlayers) == num_layers
        self.num_layers = num_layers
        self.layers = ModuleList()
        for i in range(num_layers):
            self.layers.append(build_transformer_layer(transformerlayers[i]))
        # Sequence-level attributes mirror those of the first layer.
        self.embed_dims = self.layers[0].embed_dims
        self.pre_norm = self.layers[0].pre_norm

    def forward(self,
                query,
                key,
                value,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `TransformerCoder`.

        Runs every layer in order, feeding each layer's output in as the
        next layer's `query`. All other arguments are passed unchanged to
        every layer.

        Args:
            query (Tensor): Input query with shape
                `(num_queries, bs, embed_dims)` (layout may differ when the
                layers use batch_first).
            key (Tensor): The key tensor with shape
                `(num_keys, bs, embed_dims)`.
            value (Tensor): The value tensor with shape
                `(num_keys, bs, embed_dims)`.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            attn_masks (List[Tensor], optional): Each element is a 2D Tensor
                used in the calculation of the corresponding attention in
                operation_order. Default: None.
            query_key_padding_mask (Tensor): Mask for `query`, with
                shape [bs, num_queries]. Only used in self-attention.
                Default: None.
            key_padding_mask (Tensor): Mask for `key`, with
                shape [bs, num_keys]. Default: None.

        Returns:
            Tensor: Results with shape [num_queries, bs, embed_dims].
        """
        for layer in self.layers:
            query = layer(
                query,
                key,
                value,
                query_pos=query_pos,
                key_pos=key_pos,
                attn_masks=attn_masks,
                query_key_padding_mask=query_key_padding_mask,
                key_padding_mask=key_padding_mask,
                **kwargs)
        return query