# NOTE: "Spaces: Runtime error" -- web-page residue captured with this file; not part of the module.
from typing import Optional

import fvcore.nn.weight_init as weight_init
import torch
import torch.nn as nn
import torch.nn.functional as F

from .msdeformattn import PositionEmbeddingSine, _get_clones, _get_activation_fn
from lib.model_zoo.common.get_model import get_model, register
##########
# helper #
##########
def with_pos_embed(x, pos):
    """Return ``x + pos``, or ``x`` unchanged when ``pos`` is None."""
    if pos is None:
        return x
    return x + pos
##############
# One Former #
##############
class Transformer(nn.Module):
    """Encoder-decoder transformer applied to a flattened 2-D feature map.

    The encoder processes the flattened map; the decoder then attends to the
    encoder output ("memory"), starting either from zeros or from a repeated
    task token.

    Args:
        d_model: channel width of every layer.
        nhead: number of attention heads.
        num_encoder_layers: number of encoder layers.
        num_decoder_layers: number of decoder layers.
        dim_feedforward: hidden width of the per-layer feed-forward network.
        dropout: dropout probability handed to every layer.
        activation: activation name, resolved by ``_get_activation_fn``.
        normalize_before: use the pre-norm layer variant; also adds a final
            ``LayerNorm`` to the encoder.
        return_intermediate_dec: make the decoder return every layer's output.
    """

    def __init__(self,
                 d_model=512,
                 nhead=8,
                 num_encoder_layers=6,
                 num_decoder_layers=6,
                 dim_feedforward=2048,
                 dropout=0.1,
                 activation="relu",
                 normalize_before=False,
                 return_intermediate_dec=False,):
        super().__init__()
        # Both layer types take the same positional configuration.
        layer_args = (d_model, nhead, dim_feedforward, dropout, activation, normalize_before)

        # The encoder only gets a closing norm in the pre-norm configuration.
        self.encoder = TransformerEncoder(
            TransformerEncoderLayer(*layer_args),
            num_encoder_layers,
            nn.LayerNorm(d_model) if normalize_before else None,
        )
        # The decoder always gets a closing norm.
        self.decoder = TransformerDecoder(
            TransformerDecoderLayer(*layer_args),
            num_decoder_layers,
            nn.LayerNorm(d_model),
            return_intermediate=return_intermediate_dec,
        )

        self._reset_parameters()
        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        """Xavier-initialise every parameter with more than one dimension."""
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def forward(self, src, mask, query_embed, pos_embed, task_token=None):
        """Encode ``src`` and decode one output per query.

        Args:
            src: feature map of shape NxCxHxW.
            mask: optional key-padding mask, flattened from dim 1 onward
                before use; may be None.
            query_embed: query embeddings with the batch axis missing; they
                are repeated across the batch.
            pos_embed: positional encoding, flattened the same way as ``src``.
            task_token: optional tensor repeated along dim 0 once per query to
                seed the decoder; zeros are used when it is None.

        Returns:
            A tuple ``(hs, memory)``: the decoder output with dims 1 and 2
            swapped, and the encoder output reshaped back to NxCxHxW.
        """
        bs, c, h, w = src.shape

        # Flatten NxCxHxW to HWxNxC.
        src_seq = src.flatten(2).permute(2, 0, 1)
        pos_seq = pos_embed.flatten(2).permute(2, 0, 1)
        queries = query_embed.unsqueeze(1).repeat(1, bs, 1)
        padding_mask = None if mask is None else mask.flatten(1)

        if task_token is not None:
            tgt = task_token.repeat(queries.shape[0], 1, 1)
        else:
            tgt = torch.zeros_like(queries)

        memory = self.encoder(src_seq, src_key_padding_mask=padding_mask, pos=pos_seq)
        hs = self.decoder(
            tgt,
            memory,
            memory_key_padding_mask=padding_mask,
            pos=pos_seq,
            query_pos=queries,
        )
        restored = memory.permute(1, 2, 0).view(bs, c, h, w)
        return hs.transpose(1, 2), restored
class TransformerEncoder(nn.Module):
    """Sequential stack of encoder layers with an optional closing norm.

    Args:
        encoder_layer: prototype layer, replicated by ``_get_clones``.
        num_layers: number of copies to stack.
        norm: optional module applied once to the final output.
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, mask=None, src_key_padding_mask=None, pos=None,):
        """Feed ``src`` through every layer in order, then through ``norm`` if set."""
        hidden = src
        for encoder_layer in self.layers:
            hidden = encoder_layer(
                hidden,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                pos=pos,
            )
        if self.norm is None:
            return hidden
        return self.norm(hidden)
class TransformerDecoder(nn.Module):
    """Sequential stack of decoder layers with an optional output norm.

    Args:
        decoder_layer: prototype layer, replicated by ``_get_clones``.
        num_layers: number of copies to stack.
        norm: optional module applied to the returned outputs; may be None.
        return_intermediate: when True, ``forward`` returns the (normalised)
            output of every layer instead of only the last one.
    """

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def _normalized(self, output):
        """Apply ``self.norm`` when one is configured, else pass through."""
        return output if self.norm is None else self.norm(output)

    def forward(
            self,
            tgt,
            memory,
            tgt_mask=None,
            memory_mask=None,
            tgt_key_padding_mask=None,
            memory_key_padding_mask=None,
            pos=None,
            query_pos=None,):
        """Run ``tgt`` through every layer, attending to ``memory``.

        All mask and positional arguments are forwarded unchanged to each
        layer.

        Returns:
            With ``return_intermediate``: the per-layer outputs stacked along
            a new leading axis. Otherwise: the final output with a leading
            axis of size 1.
        """
        output = tgt
        intermediate = []
        for layer in self.layers:
            output = layer(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
                query_pos=query_pos,
            )
            if self.return_intermediate:
                # The norm is optional: calling ``self.norm`` directly would
                # fail with ``norm=None``.
                intermediate.append(self._normalized(output))

        if self.return_intermediate:
            if not intermediate:
                # No layers at all: still return one (normalised) entry so the
                # result keeps its leading axis instead of failing in stack().
                intermediate.append(self._normalized(output))
            # The last entry already is the normalised final output, so no
            # second normalisation pass is needed.
            return torch.stack(intermediate)
        return self._normalized(output).unsqueeze(0)
class TransformerEncoderLayer(nn.Module):
    """One encoder layer: self-attention followed by a feed-forward network.

    Each sub-block is wrapped in a residual connection with dropout.
    ``normalize_before`` selects whether LayerNorm runs before each sub-block
    (pre-norm) or after the residual addition (post-norm).

    Args:
        d_model: channel width.
        nhead: number of attention heads.
        dim_feedforward: hidden width of the feed-forward network.
        dropout: dropout probability for attention and all dropout modules.
        activation: activation name, resolved by ``_get_activation_fn``.
        normalize_before: choose the pre-norm variant when True.
    """

    def __init__(
            self,
            d_model,
            nhead,
            dim_feedforward=2048,
            dropout=0.1,
            activation="relu",
            normalize_before=False, ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Feed-forward network.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, x, pos):
        """Return ``x + pos``, or ``x`` unchanged when ``pos`` is None."""
        if pos is None:
            return x
        return x + pos

    def forward_post(
            self,
            src,
            src_mask=None,
            src_key_padding_mask=None,
            pos=None,):
        """Post-norm pass: each sub-block is normalised after its residual add."""
        qk = self.with_pos_embed(src, pos)
        attn_out, _ = self.self_attn(
            qk, qk, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )
        src = self.norm1(src + self.dropout1(attn_out))

        hidden = self.dropout(self.activation(self.linear1(src)))
        src = self.norm2(src + self.dropout2(self.linear2(hidden)))
        return src

    def forward_pre(
            self,
            src,
            src_mask=None,
            src_key_padding_mask=None,
            pos=None,):
        """Pre-norm pass: each sub-block sees a normalised copy of its input."""
        normed = self.norm1(src)
        qk = self.with_pos_embed(normed, pos)
        attn_out, _ = self.self_attn(
            qk, qk, value=normed, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )
        src = src + self.dropout1(attn_out)

        normed = self.norm2(src)
        hidden = self.dropout(self.activation(self.linear1(normed)))
        return src + self.dropout2(self.linear2(hidden))

    def forward(
            self,
            src,
            src_mask=None,
            src_key_padding_mask=None,
            pos=None,):
        """Dispatch to the pre- or post-norm pass according to ``normalize_before``."""
        run = self.forward_pre if self.normalize_before else self.forward_post
        return run(src, src_mask, src_key_padding_mask, pos)
class TransformerDecoderLayer(nn.Module):
    """One decoder layer: self-attention, cross-attention, then feed-forward.

    Each of the three sub-blocks is wrapped in a residual connection with
    dropout. ``normalize_before`` selects pre-norm or post-norm placement of
    the LayerNorms.

    Args:
        d_model: channel width.
        nhead: number of attention heads.
        dim_feedforward: hidden width of the feed-forward network.
        dropout: dropout probability for attention and all dropout modules.
        activation: activation name, resolved by ``_get_activation_fn``.
        normalize_before: choose the pre-norm variant when True.
    """

    def __init__(
            self,
            d_model,
            nhead,
            dim_feedforward=2048,
            dropout=0.1,
            activation="relu",
            normalize_before=False,):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Feed-forward network.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, x, pos):
        """Return ``x + pos``, or ``x`` unchanged when ``pos`` is None."""
        if pos is None:
            return x
        return x + pos

    def forward_post(
            self,
            tgt,
            memory,
            tgt_mask=None,
            memory_mask=None,
            tgt_key_padding_mask=None,
            memory_key_padding_mask=None,
            pos=None,
            query_pos=None,):
        """Post-norm pass: each sub-block is normalised after its residual add."""
        # Self-attention over the queries.
        qk = self.with_pos_embed(tgt, query_pos)
        self_out, _ = self.self_attn(
            qk, qk, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )
        tgt = self.norm1(tgt + self.dropout1(self_out))

        # Cross-attention from the queries into the memory.
        cross_out, _ = self.multihead_attn(
            query=self.with_pos_embed(tgt, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )
        tgt = self.norm2(tgt + self.dropout2(cross_out))

        # Feed-forward.
        hidden = self.dropout(self.activation(self.linear1(tgt)))
        tgt = self.norm3(tgt + self.dropout3(self.linear2(hidden)))
        return tgt

    def forward_pre(
            self,
            tgt,
            memory,
            tgt_mask=None,
            memory_mask=None,
            tgt_key_padding_mask=None,
            memory_key_padding_mask=None,
            pos=None,
            query_pos=None,):
        """Pre-norm pass: each sub-block sees a normalised copy of its input."""
        # Self-attention over the queries.
        normed = self.norm1(tgt)
        qk = self.with_pos_embed(normed, query_pos)
        self_out, _ = self.self_attn(
            qk, qk, value=normed, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )
        tgt = tgt + self.dropout1(self_out)

        # Cross-attention from the queries into the memory.
        normed = self.norm2(tgt)
        cross_out, _ = self.multihead_attn(
            query=self.with_pos_embed(normed, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )
        tgt = tgt + self.dropout2(cross_out)

        # Feed-forward.
        normed = self.norm3(tgt)
        hidden = self.dropout(self.activation(self.linear1(normed)))
        return tgt + self.dropout3(self.linear2(hidden))

    def forward(
            self,
            tgt,
            memory,
            tgt_mask=None,
            memory_mask=None,
            tgt_key_padding_mask=None,
            memory_key_padding_mask=None,
            pos=None,
            query_pos=None, ):
        """Dispatch to the pre- or post-norm pass according to ``normalize_before``."""
        run = self.forward_pre if self.normalize_before else self.forward_post
        return run(
            tgt,
            memory,
            tgt_mask,
            memory_mask,
            tgt_key_padding_mask,
            memory_key_padding_mask,
            pos,
            query_pos,
        )
class SelfAttentionLayer(nn.Module):
    """Residual self-attention block taking batch-first input.

    ``nn.MultiheadAttention`` is built here without ``batch_first``, so it
    works on sequence-first tensors. This layer transposes dims 0 and 1
    around the attention call, which makes its own input and output
    batch-first. Both the post-norm and the pre-norm pass use the same
    layout.

    Args:
        d_model: channel width.
        nhead: number of attention heads.
        dropout: dropout probability for attention and the residual branch.
        activation: activation name, resolved by ``_get_activation_fn``
            (stored on the module, not used in the forward pass).
        normalize_before: choose the pre-norm variant when True.
    """

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def _attend(self, x, query_pos, attn_mask, key_padding_mask):
        """Self-attend over batch-first ``x`` and return a batch-first result."""
        # Swap to the sequence-first layout the attention module expects.
        qk = self.with_pos_embed(x, query_pos).transpose(0, 1)
        attended = self.self_attn(
            qk, qk, value=x.transpose(0, 1),
            attn_mask=attn_mask, key_padding_mask=key_padding_mask)[0]
        return attended.transpose(0, 1)

    def forward_post(self, tgt,
                     tgt_mask=None,
                     tgt_key_padding_mask=None,
                     query_pos=None):
        """Post-norm pass: attend, add the residual, then normalise."""
        tgt2 = self._attend(tgt, query_pos, tgt_mask, tgt_key_padding_mask)
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)
        return tgt

    def forward_pre(self, tgt,
                    tgt_mask=None,
                    tgt_key_padding_mask=None,
                    query_pos=None):
        """Pre-norm pass: normalise, attend, then add the residual.

        Uses the same batch-first layout as ``forward_post``; without the
        transposes attention would run across the batch axis.
        """
        tgt2 = self._attend(self.norm(tgt), query_pos, tgt_mask, tgt_key_padding_mask)
        tgt = tgt + self.dropout(tgt2)
        return tgt

    def forward(self, tgt,
                tgt_mask=None,
                tgt_key_padding_mask=None,
                query_pos=None):
        """Dispatch to the pre- or post-norm pass according to ``normalize_before``."""
        if self.normalize_before:
            return self.forward_pre(tgt, tgt_mask,
                                    tgt_key_padding_mask, query_pos)
        return self.forward_post(tgt, tgt_mask,
                                 tgt_key_padding_mask, query_pos)
class CrossAttentionLayer(nn.Module):
    """Residual cross-attention block taking batch-first input.

    ``nn.MultiheadAttention`` is built here without ``batch_first``, so it
    works on sequence-first tensors. This layer transposes dims 0 and 1
    around the attention call, which makes ``tgt``, ``memory`` and the
    output batch-first. Both the post-norm and the pre-norm pass use the
    same layout.

    Args:
        d_model: channel width.
        nhead: number of attention heads.
        dropout: dropout probability for attention and the residual branch.
        activation: activation name, resolved by ``_get_activation_fn``
            (stored on the module, not used in the forward pass).
        normalize_before: choose the pre-norm variant when True.
    """

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def _attend(self, queries, memory, memory_mask, memory_key_padding_mask, pos, query_pos):
        """Attend from batch-first ``queries`` into batch-first ``memory``."""
        # Swap to the sequence-first layout the attention module expects.
        attended = self.multihead_attn(
            query=self.with_pos_embed(queries, query_pos).transpose(0, 1),
            key=self.with_pos_embed(memory, pos).transpose(0, 1),
            value=memory.transpose(0, 1),
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask)[0]
        return attended.transpose(0, 1)

    def forward_post(self, tgt, memory,
                     memory_mask=None,
                     memory_key_padding_mask=None,
                     pos=None,
                     query_pos=None):
        """Post-norm pass: attend, add the residual, then normalise."""
        tgt2 = self._attend(tgt, memory, memory_mask,
                            memory_key_padding_mask, pos, query_pos)
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)
        return tgt

    def forward_pre(self, tgt, memory,
                    memory_mask=None,
                    memory_key_padding_mask=None,
                    pos=None,
                    query_pos=None):
        """Pre-norm pass: normalise, attend, then add the residual.

        Uses the same batch-first layout as ``forward_post``; without the
        transposes the attention module would read dim 0 as the sequence
        axis.
        """
        tgt2 = self._attend(self.norm(tgt), memory, memory_mask,
                            memory_key_padding_mask, pos, query_pos)
        tgt = tgt + self.dropout(tgt2)
        return tgt

    def forward(self, tgt, memory,
                memory_mask=None,
                memory_key_padding_mask=None,
                pos=None,
                query_pos=None):
        """Dispatch to the pre- or post-norm pass according to ``normalize_before``."""
        if self.normalize_before:
            return self.forward_pre(tgt, memory, memory_mask,
                                    memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, memory_mask,
                                 memory_key_padding_mask, pos, query_pos)
class FFNLayer(nn.Module):
    """Residual two-layer feed-forward block with pre- or post-norm.

    Args:
        d_model: input and output channel width.
        dim_feedforward: hidden width between the two linear layers.
        dropout: dropout probability; the single dropout module is applied
            both inside the network and on the residual branch.
        activation: activation name, resolved by ``_get_activation_fn``.
        normalize_before: choose the pre-norm variant when True.
    """

    def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        # Feed-forward network.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm = nn.LayerNorm(d_model)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise every parameter with more than one dimension."""
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def with_pos_embed(self, tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward_post(self, tgt):
        """Post-norm pass: feed-forward, add the residual, then normalise."""
        hidden = self.dropout(self.activation(self.linear1(tgt)))
        update = self.linear2(hidden)
        return self.norm(tgt + self.dropout(update))

    def forward_pre(self, tgt):
        """Pre-norm pass: normalise, feed-forward, then add the residual."""
        normed = self.norm(tgt)
        hidden = self.dropout(self.activation(self.linear1(normed)))
        update = self.linear2(hidden)
        return tgt + self.dropout(update)

    def forward(self, tgt):
        """Dispatch to the pre- or post-norm pass according to ``normalize_before``."""
        run = self.forward_pre if self.normalize_before else self.forward_post
        return run(tgt)
class MLP(nn.Module):
    """Plain multi-layer perceptron (FFN): linear layers with ReLU in between.

    Args:
        input_dim: width of the input features.
        hidden_dim: width of every hidden layer.
        output_dim: width of the output features.
        num_layers: number of linear layers; ReLU follows all but the last.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        # Widths at every layer boundary: input, the hidden widths, output.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        )

    def forward(self, x):
        """Apply each linear layer in turn; the last one has no activation."""
        last = self.num_layers - 1
        for index, layer in enumerate(self.layers):
            x = layer(x)
            if index < last:
                x = F.relu(x)
        return x
class Seet_OneFormer_TDecoder(nn.Module):
    """Masked-attention transformer decoder driven by learnable queries plus a task token.

    A class transformer first produces initial query features from
    ``mask_features``; the task token is appended as one extra query. The
    queries are then refined by ``dec_layers`` rounds of cross-attention
    (cycling over three feature levels), self-attention and feed-forward
    layers. After every round a prediction head produces class logits, mask
    logits and the attention mask for the next round.

    Args:
        in_channels: channel count of the incoming multi-scale features.
        mask_classification: must be truthy (asserted).
        num_classes: number of classes; the class head emits one extra logit.
        hidden_dim: transformer width.
        num_queries: number of rows in the query embedding table.
        nheads: number of attention heads.
        dropout: dropout of the class transformer.
        dim_feedforward: hidden width of all feed-forward networks.
        enc_layers: encoder depth of the class transformer.
        is_train: stored on the module; not read in this class.
        dec_layers: number of refinement rounds.
        class_dec_layers: decoder depth of the class transformer.
        pre_norm: use pre-norm layers throughout.
        mask_dim: output width of the mask embedding MLP.
        enforce_input_project: always add a 1x1 input projection per level.
        use_task_norm: normalise the task token and feed it to the class
            transformer as the initial decoder target.
    """

    def __init__(
            self,
            in_channels,
            mask_classification,
            num_classes,
            hidden_dim,
            num_queries,
            nheads,
            dropout,
            dim_feedforward,
            enc_layers,
            is_train,
            dec_layers,
            class_dec_layers,
            pre_norm,
            mask_dim,
            enforce_input_project,
            use_task_norm,):
        super().__init__()

        assert mask_classification, "Only support mask classification model"
        self.mask_classification = mask_classification
        self.is_train = is_train
        self.use_task_norm = use_task_norm

        # positional encoding
        N_steps = hidden_dim // 2
        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)

        self.class_transformer = Transformer(
            d_model=hidden_dim,
            dropout=dropout,
            nhead=nheads,
            dim_feedforward=dim_feedforward,
            num_encoder_layers=enc_layers,
            num_decoder_layers=class_dec_layers,
            normalize_before=pre_norm,
            return_intermediate_dec=False,
        )

        # define Transformer decoder here
        self.num_heads = nheads
        self.num_layers = dec_layers
        self.transformer_self_attention_layers = nn.ModuleList()
        self.transformer_cross_attention_layers = nn.ModuleList()
        self.transformer_ffn_layers = nn.ModuleList()

        for _ in range(self.num_layers):
            self.transformer_self_attention_layers.append(
                SelfAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )
            self.transformer_cross_attention_layers.append(
                CrossAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )
            self.transformer_ffn_layers.append(
                FFNLayer(
                    d_model=hidden_dim,
                    dim_feedforward=dim_feedforward,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

        self.decoder_norm = nn.LayerNorm(hidden_dim)

        self.num_queries = num_queries
        # learnable query p.e.
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        # level embedding (we always use 3 scales)
        self.num_feature_levels = 3
        self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
        self.input_proj = nn.ModuleList()
        for _ in range(self.num_feature_levels):
            if in_channels != hidden_dim or enforce_input_project:
                self.input_proj.append(nn.Conv2d(in_channels, hidden_dim, kernel_size=1))
                weight_init.c2_xavier_fill(self.input_proj[-1])
            else:
                # Channels already match: identity projection.
                self.input_proj.append(nn.Sequential())

        self.class_input_proj = nn.Conv2d(in_channels, hidden_dim, kernel_size=1)
        weight_init.c2_xavier_fill(self.class_input_proj)

        # output FFNs
        if self.mask_classification:
            self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)

    def forward(self, x, mask_features, tasks):
        """Refine the queries against the multi-scale features and predict.

        Args:
            x: list of exactly ``num_feature_levels`` feature maps.
            mask_features: feature map used by the class transformer and by
                the mask prediction head.
            tasks: task token; assumed to be (batch, hidden_dim) -- TODO
                confirm against the caller.

        Returns:
            Dict with ``'pred_logits'`` and ``'pred_masks'`` taken from the
            prediction made after the last refinement round.
        """
        # x is a list of multi-scale feature
        assert len(x) == self.num_feature_levels
        src = []
        pos = []
        size_list = []

        for i in range(self.num_feature_levels):
            size_list.append(x[i].shape[-2:])
            pos.append(self.pe_layer(x[i], None).flatten(2))
            src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])

            # Move the flattened spatial axis in front of the channel axis,
            # giving the batch-first layout used by the attention layers.
            pos[-1] = pos[-1].transpose(1, 2)
            src[-1] = src[-1].transpose(1, 2)

        bs, _, _ = src[0].shape

        query_embed = self.query_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
        # Leading singleton axis: the layout the class transformer repeats
        # along dim 0 to build one target per query.
        tasks = tasks.unsqueeze(0)
        if self.use_task_norm:
            tasks = self.decoder_norm(tasks)

        # NOTE(review): the positional encoding is passed in the `src` slot and
        # the projected features in the `pos_embed` slot of Transformer.forward;
        # kept exactly as written -- confirm this is intended.
        feats = self.pe_layer(mask_features, None)
        out_t, _ = self.class_transformer(
            feats, None,
            self.query_embed.weight[:-1],
            self.class_input_proj(mask_features),
            tasks if self.use_task_norm else None)
        out_t = out_t[0]

        # ``out_t`` is batch-first while ``tasks`` carries its singleton axis
        # first, so swap the first two axes before appending the task token as
        # one extra query. For batch size 1 the transpose is an identity; for
        # larger batches the concatenation would otherwise fail on mismatched
        # sizes.
        out = torch.cat([out_t, tasks.transpose(0, 1)], dim=1)

        output = out.clone()

        predictions_class = []
        predictions_mask = []

        # prediction heads on learnable query features
        outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads(
            output, mask_features, attn_mask_target_size=size_list[0])
        predictions_class.append(outputs_class)
        predictions_mask.append(outputs_mask)

        for i in range(self.num_layers):
            level_index = i % self.num_feature_levels
            # A query whose mask blocks every position gets its row cleared so
            # it can attend everywhere instead of nowhere.
            attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False

            output = self.transformer_cross_attention_layers[i](
                output, src[level_index],
                memory_mask=attn_mask,
                memory_key_padding_mask=None,
                pos=pos[level_index], query_pos=query_embed, )

            output = self.transformer_self_attention_layers[i](
                output, tgt_mask=None,
                tgt_key_padding_mask=None,
                query_pos=query_embed, )

            # FFN
            output = self.transformer_ffn_layers[i](output)

            # The mask for the next round is sized for the next feature level.
            outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads(
                output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels])
            predictions_class.append(outputs_class)
            predictions_mask.append(outputs_mask)

        assert len(predictions_class) == self.num_layers + 1

        out = {
            'pred_logits': predictions_class[-1],
            'pred_masks': predictions_mask[-1],}

        return out

    def forward_prediction_heads(self, output, mask_features, attn_mask_target_size):
        """Predict class logits, mask logits and the next attention mask.

        Args:
            output: query features, (batch, queries, channels) as required by
                the einsum below.
            mask_features: (batch, channels, height, width) feature map.
            attn_mask_target_size: spatial size the mask logits are resized to
                before being thresholded into the attention mask.

        Returns:
            ``(outputs_class, outputs_mask, attn_mask)``; ``attn_mask`` is a
            detached boolean tensor with one copy per head folded into the
            leading axis.
        """
        decoder_output = self.decoder_norm(output)
        outputs_class = self.class_embed(decoder_output)
        mask_embed = self.mask_embed(decoder_output)
        outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)

        attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)
        # True marks positions whose mask probability is below 0.5; the mask is
        # repeated once per head and the head axis is merged into the batch axis.
        attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
        attn_mask = attn_mask.detach()

        return outputs_class, outputs_mask, attn_mask