File size: 35,566 Bytes
2cd560a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 |
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer, trunc_normal_init
from mmcv.cnn.bricks.transformer import build_dropout
try:
from torch.cuda.amp import autocast
WITH_AUTOCAST = True
except ImportError:
WITH_AUTOCAST = False
def get_grid_index(init_grid_size, map_size, device):
    """Locate every cell of the initial grid inside the feature map.

    The flattened positions ``0 .. H*W-1`` of the feature map are arranged as
    a one-channel image and resampled to the initial grid resolution with
    nearest-neighbour interpolation, so that every initial-grid cell receives
    the flattened index of the feature-map pixel it lands on.

    Note:
        [H_init, W_init]: shape of initial grid
        [H, W]: shape of feature map
        N_init: numbers of initial token (H_init * W_init)

    Args:
        init_grid_size (list[int] or tuple[int]): initial grid resolution in
            format [H_init, W_init].
        map_size (list[int] or tuple[int]): feature map resolution in format
            [H, W].
        device: the device of output

    Returns:
        idx (torch.LongTensor[N_init]): index in flattened feature map. The
            result carries no batch dimension.
    """
    grid_h, grid_w = init_grid_size
    map_h, map_w = map_size
    positions = torch.arange(map_h * map_w, device=device)
    # F.interpolate expects a floating point [B, C, H, W] input.
    positions = positions.reshape(1, 1, map_h, map_w).float()
    resampled = F.interpolate(positions, [grid_h, grid_w], mode='nearest')
    return resampled.long().flatten()
def index_points(points, idx):
    """Sample features following the index.

    Note:
        B: batch size
        N: point number
        C: channel number of each point
        Ns: sampled point number

    Args:
        points (torch.Tensor[B, N, C]): input points data
        idx (torch.LongTensor[B, Ns]): sample index. Additional trailing
            index dimensions are allowed and are preserved in the output.

    Returns:
        new_points (torch.Tensor[B, Ns, C]): indexed points data
    """
    batch_size = points.shape[0]
    # Create the batch index directly on the target device (no CPU
    # allocation followed by a copy) and shape it [B, 1, ..., 1] so that
    # advanced indexing broadcasts it against ``idx`` instead of needing a
    # materialised, repeated copy.
    singleton_dims = [1] * (idx.dim() - 1)
    batch_indices = torch.arange(
        batch_size, dtype=torch.long,
        device=points.device).view(batch_size, *singleton_dims)
    return points[batch_indices, idx, :]
def token2map(token_dict):
    """Transform vision tokens to feature map. This function only works when
    the resolution of the feature map is not higher than the initial grid
    structure.

    Every pixel of the output map is the average of the tokens that the
    initial-grid cells falling on that pixel belong to.

    Note:
        B: batch size
        C: channel number of each token
        [H, W]: shape of feature map
        N_init: numbers of initial token

    Args:
        token_dict (dict): dict for token information. The keys read here are
            ``x`` (Tensor[B, N, C]), ``map_size`` ([H, W]),
            ``init_grid_size`` ([H_init, W_init]) and ``idx_token``
            (LongTensor[B, N_init]).

    Returns:
        x_out (Tensor[B, C, H, W]): feature map.
    """

    def _fp32_matmul(lhs, rhs):
        # Run the (sparse) product outside of autocast when it is available.
        if WITH_AUTOCAST:
            with autocast(enabled=False):
                return lhs @ rhs
        return lhs @ rhs

    x = token_dict['x']
    H, W = token_dict['map_size']
    H_init, W_init = token_dict['init_grid_size']
    idx_token = token_dict['idx_token']
    B, N, C = x.shape
    N_init = H_init * W_init
    device = x.device

    if N_init == N and N == H * W:
        # for the initial tokens with grid structure, just reshape
        return x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()

    # for each initial grid, get the corresponding index in
    # the flattened feature map.
    idx_hw = get_grid_index([H_init, W_init], [H, W],
                            device=device)[None, :].expand(B, -1)
    idx_batch = torch.arange(B, device=device)[:, None].expand(B, N_init)
    value = x.new_ones(B * N_init)

    # choose the way with fewer flops.
    if N_init < N * H * W:
        # use sparse matrix multiplication
        # Flops: B * N_init * (C+2)
        idx_hw = idx_hw + idx_batch * H * W
        idx_tokens = idx_token + idx_batch * N
        coor = torch.stack([idx_hw, idx_tokens], dim=0).reshape(2, B * N_init)

        # torch.sparse do not support gradient for
        # sparse tensor, so we detach it
        value = value.detach().to(torch.float32)

        # build a sparse matrix with the shape [B * H * W, B * N];
        # torch.sparse_coo_tensor replaces the deprecated
        # torch.sparse.FloatTensor(indices, values, size) constructor.
        size = (B * H * W, B * N)
        A = torch.sparse_coo_tensor(coor, value, size)

        # normalize the weight for each row
        ones = x.new_ones(B * N, 1).to(torch.float32)
        all_weight = _fp32_matmul(A, ones) + 1e-6
        value = value / all_weight[idx_hw.reshape(-1), 0]

        # update the matrix with normalize weight
        A = torch.sparse_coo_tensor(coor, value, size)

        # sparse matrix multiplication
        x_out = _fp32_matmul(A, x.reshape(B * N, C).to(
            torch.float32))  # [B*H*W, C]
    else:
        # use dense matrix multiplication
        # Flops: B * N * H * W * (C+2)
        coor = torch.stack([idx_batch, idx_hw, idx_token],
                           dim=0).reshape(3, B * N_init)

        # build a matrix with shape [B, H*W, N]
        A = torch.sparse_coo_tensor(coor, value, (B, H * W, N)).to_dense()
        # normalize the weight
        A = A / (A.sum(dim=-1, keepdim=True) + 1e-6)

        x_out = A @ x  # [B, H*W, C]

    x_out = x_out.type(x.dtype)
    x_out = x_out.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
    return x_out
def map2token(feature_map, token_dict):
    """Gather a feature map back into vision tokens.

    Each token becomes the average of the feature-map pixels that its
    initial-grid cells fall on. This function only works when the resolution
    of the feature map is not higher than the initial grid structure.

    Note:
        B: batch size
        C: channel number
        [H, W]: shape of feature map
        N_init: numbers of initial token

    Args:
        feature_map (Tensor[B, C, H, W]): feature map.
        token_dict (dict): dict for token information. The keys read here are
            ``idx_token``, ``token_num`` and ``init_grid_size``.

    Returns:
        out (Tensor[B, N, C]): token features.
    """

    def _matmul_no_autocast(lhs, rhs):
        # The sparse products are evaluated with autocast switched off
        # whenever autocast could be imported.
        if WITH_AUTOCAST:
            with autocast(enabled=False):
                return lhs @ rhs
        return lhs @ rhs

    idx_token = token_dict['idx_token']
    N = token_dict['token_num']
    H_init, W_init = token_dict['init_grid_size']
    N_init = H_init * W_init
    B, C, H, W = feature_map.shape
    device = feature_map.device

    # Tokens that still form the regular grid only need a reshape.
    if N_init == N and N == H * W:
        return feature_map.flatten(2).permute(0, 2, 1).contiguous()

    grid_index = get_grid_index([H_init, W_init], [H, W], device=device)
    idx_hw = grid_index[None, :].expand(B, -1)
    idx_batch = torch.arange(B, device=device)[:, None].expand(B, N_init)
    value = feature_map.new_ones(B * N_init)

    # Pick the cheaper of the sparse and the dense formulation.
    if N_init < N * H * W:
        # Sparse formulation, Flops: B * N_init * (C+2)
        idx_token = idx_token + idx_batch * N
        idx_hw = idx_hw + idx_batch * H * W
        indices = torch.stack([idx_token, idx_hw], dim=0).reshape(2, -1)
        mat_size = (B * N, B * H * W)

        # sparse mm has no gradient w.r.t. the sparse operand
        value = value.detach().to(torch.float32)

        # [B*N, B*H*W] assignment matrix, then row-normalised
        A = torch.sparse_coo_tensor(indices, value, mat_size)
        unit = torch.ones([B * H * W, 1], device=device, dtype=torch.float32)
        all_weight = _matmul_no_autocast(A, unit) + 1e-6
        value = value / all_weight[idx_token.reshape(-1), 0]
        A = torch.sparse_coo_tensor(indices, value, mat_size)

        # out: [B*N, C]
        pixels = feature_map.permute(0, 2, 3, 1).contiguous().reshape(
            B * H * W, C).float()
        out = _matmul_no_autocast(A, pixels)
    else:
        # Dense formulation, Flops: B * N * H * W * (C+2)
        indices = torch.stack([idx_batch, idx_token, idx_hw],
                              dim=0).reshape(3, -1)
        # detached to keep the training time down
        value = value.detach()
        A = torch.sparse_coo_tensor(indices, value, (B, N, H * W)).to_dense()
        # row-normalise
        A = A / (A.sum(dim=-1, keepdim=True) + 1e-6)
        pixels = feature_map.permute(0, 2, 3, 1).reshape(B, H * W,
                                                         C).contiguous()
        out = A @ pixels

    out = out.type(feature_map.dtype)
    return out.reshape(B, N, C)
def token_interp(target_dict, source_dict):
    """Transform token features between different distribution.

    Every target token is the weighted average of the source tokens that
    share initial-grid cells with it.

    Note:
        B: batch size
        N: token number
        C: channel number

    Args:
        target_dict (dict): dict for target token information. Reads
            ``idx_token``, ``token_num`` and, when present, ``agg_weight``
            (one weight per initial-grid cell; uniform weights are used when
            the key is missing or ``None``).
        source_dict (dict): dict for source token information. Reads ``x``
            and ``idx_token``.

    Returns:
        x_out (Tensor[B, N, C]): token features.
    """

    def _fp32_matmul(lhs, rhs):
        # Run the (sparse) product outside of autocast when it is available.
        if WITH_AUTOCAST:
            with autocast(enabled=False):
                return lhs @ rhs
        return lhs @ rhs

    x_s = source_dict['x']
    idx_token_s = source_dict['idx_token']
    idx_token_t = target_dict['idx_token']
    T = target_dict['token_num']
    B, S, C = x_s.shape
    N_init = idx_token_s.shape[1]

    weight = target_dict.get('agg_weight')
    if weight is None:
        weight = x_s.new_ones(B, N_init, 1)
    weight = weight.reshape(-1)

    # choose the way with fewer flops.
    if N_init < T * S:
        # use sparse matrix multiplication
        # Flops: B * N_init * (C+2)
        batch_offset = torch.arange(B, device=x_s.device)[:, None]
        idx_token_t = idx_token_t + batch_offset * T
        idx_token_s = idx_token_s + batch_offset * S
        coor = torch.stack([idx_token_t, idx_token_s],
                           dim=0).reshape(2, B * N_init)

        # torch.sparse does not support grad for sparse matrix
        weight = weight.detach().to(torch.float32)

        # build a matrix with shape [B*T, B*S]; torch.sparse_coo_tensor
        # replaces the deprecated torch.sparse.FloatTensor constructor.
        size = (B * T, B * S)
        A = torch.sparse_coo_tensor(coor, weight, size)

        # normalize the matrix
        ones = x_s.new_ones(B * S, 1).to(torch.float32)
        all_weight = _fp32_matmul(A, ones) + 1e-6
        weight = weight / all_weight[idx_token_t.reshape(-1), 0]
        A = torch.sparse_coo_tensor(coor, weight, size)

        # sparse matmul
        x_out = _fp32_matmul(A, x_s.reshape(B * S, C).to(torch.float32))
    else:
        # use dense matrix multiplication
        # Flops: B * T * S * (C+2)
        idx_batch = torch.arange(
            B, device=x_s.device)[:, None].expand(B, N_init)
        coor = torch.stack([idx_batch, idx_token_t, idx_token_s],
                           dim=0).reshape(3, B * N_init)
        weight = weight.detach()  # detach to reduce training time

        # build a matrix with shape [B, T, S]; duplicated coordinates are
        # summed when the sparse tensor is densified.
        A = torch.sparse_coo_tensor(coor, weight, (B, T, S)).to_dense()
        # normalize the matrix
        A = A / (A.sum(dim=-1, keepdim=True) + 1e-6)

        # dense matmul
        x_out = A @ x_s

    x_out = x_out.reshape(B, T, C).type(x_s.dtype)
    return x_out
def cluster_dpc_knn(token_dict, cluster_num, k=5, token_mask=None):
    """Cluster tokens with DPC-KNN algorithm.

    Note:
        B: batch size
        N: token number
        C: channel number

    Args:
        token_dict (dict): dict for token information. Only ``x``
            (Tensor[B, N, C]) is read.
        cluster_num (int): cluster number
        k (int): number of the nearest neighbor used for local density.
            Values larger than the token number N are clamped to N.
        token_mask (Tensor[B, N]): mask indicating which token is the
            padded empty token. Non-zero value means the token is meaningful,
            zero value means the token is an empty token. If set to None, all
            tokens are regarded as meaningful.

    Return:
        idx_cluster (Tensor[B, N]): cluster index of each token.
        cluster_num (int): actual cluster number. In this function, it equals
            to the input cluster number.
    """
    with torch.no_grad():
        x = token_dict['x']
        B, N, C = x.shape
        # torch.topk raises when asked for more neighbours than there are
        # tokens, so never look at more than N of them.
        k = min(k, N)

        dist_matrix = torch.cdist(x, x) / (C**0.5)

        if token_mask is not None:
            token_mask = token_mask > 0
            # in order to not affect the local density, the
            # distance between empty tokens and any other
            # tokens should be the maximal distance.
            dist_matrix = \
                dist_matrix * token_mask[:, None, :] + \
                (dist_matrix.max() + 1) * (~token_mask[:, None, :])

        # get local density
        dist_nearest, _ = torch.topk(
            dist_matrix, k=k, dim=-1, largest=False)
        density = (-(dist_nearest**2).mean(dim=-1)).exp()
        # add a little noise to ensure no tokens have the same density.
        density = density + torch.rand(
            density.shape, device=density.device, dtype=density.dtype) * 1e-6

        if token_mask is not None:
            # the density of empty token should be 0
            density = density * token_mask

        # get distance indicator: for every token, the distance to the
        # closest token with a higher density (or the maximal distance when
        # no such token exists).
        mask = density[:, None, :] > density[:, :, None]
        mask = mask.type(x.dtype)
        dist_max = dist_matrix.flatten(1).max(dim=-1)[0][:, None, None]
        dist, _ = (dist_matrix * mask + dist_max * (1 - mask)).min(dim=-1)

        # select clustering center according to score
        score = dist * density
        _, index_down = torch.topk(score, k=cluster_num, dim=-1)

        # assign tokens to the nearest center; the rows of the distance
        # matrix belonging to the centers are picked with the same batch
        # index that is needed for the self-assignment below.
        idx_batch = torch.arange(
            B, device=x.device)[:, None].expand(B, cluster_num)
        dist_to_centers = dist_matrix[idx_batch, index_down]  # [B, K, N]
        idx_cluster = dist_to_centers.argmin(dim=1)

        # make sure cluster center merge to itself
        idx_tmp = torch.arange(
            cluster_num, device=x.device)[None, :].expand(B, cluster_num)
        idx_cluster[idx_batch.reshape(-1),
                    index_down.reshape(-1)] = idx_tmp.reshape(-1)

    return idx_cluster, cluster_num
def merge_tokens(token_dict, idx_cluster, cluster_num, token_weight=None):
    """Merge tokens in the same cluster to a single cluster. Implemented by
    torch.index_add(). Flops: B*N*(C+2)

    Note:
        B: batch size
        N: token number
        C: channel number

    Args:
        token_dict (dict): dict for input token information. Reads ``x``,
            ``idx_token``, ``agg_weight``, ``map_size`` and
            ``init_grid_size``.
        idx_cluster (Tensor[B, N]): cluster index of each token.
        cluster_num (int): cluster number
        token_weight (Tensor[B, N, 1]): weight for each token. Uniform
            weights are used when set to None.

    Return:
        out_dict (dict): dict for output token information. ``agg_weight``
            is rescaled so that its maximum over the initial grid is 1 for
            every sample.
    """
    x = token_dict['x']
    idx_token = token_dict['idx_token']
    agg_weight = token_dict['agg_weight']

    B, N, C = x.shape
    if token_weight is None:
        token_weight = x.new_ones(B, N, 1)

    # flatten (batch, cluster) into a single row index
    idx_batch = torch.arange(B, device=x.device)[:, None]
    idx = idx_cluster + idx_batch * cluster_num
    flat_idx = idx.reshape(B * N)

    # total weight of every cluster, used to normalise the token weights
    all_weight = token_weight.new_zeros(B * cluster_num, 1)
    all_weight.index_add_(
        dim=0, index=flat_idx, source=token_weight.reshape(B * N, 1))
    all_weight = all_weight + 1e-6
    norm_weight = token_weight / all_weight[idx]

    # average token features
    x_merged = x.new_zeros(B * cluster_num, C)
    source = x * norm_weight
    x_merged.index_add_(
        dim=0, index=flat_idx, source=source.reshape(B * N, C).type(x.dtype))
    x_merged = x_merged.reshape(B, cluster_num, C)

    # follow every initial-grid cell from its old token to the new cluster
    idx_token_new = torch.gather(idx_cluster, 1, idx_token)
    weight_t = torch.gather(norm_weight, 1, idx_token[..., None])
    agg_weight_new = agg_weight * weight_t

    # Rescale so the largest aggregation weight of each sample is 1. The
    # result of this division used to be discarded, which let the weights
    # shrink with every merge. A zero maximum is left untouched to avoid
    # producing NaN.
    max_weight = agg_weight_new.max(dim=1, keepdim=True)[0]
    max_weight = torch.where(max_weight > 0, max_weight,
                             torch.ones_like(max_weight))
    agg_weight_new = agg_weight_new / max_weight

    out_dict = {}
    out_dict['x'] = x_merged
    out_dict['token_num'] = cluster_num
    out_dict['map_size'] = token_dict['map_size']
    out_dict['init_grid_size'] = token_dict['init_grid_size']
    out_dict['idx_token'] = idx_token_new
    out_dict['agg_weight'] = agg_weight_new
    return out_dict
class MLP(nn.Module):
    """Feed-forward network of TCFormer with a depthwise conv in the middle.

    The data flow is ``fc1 -> depthwise conv -> activation -> dropout ->
    fc2 -> dropout``.

    Args:
        in_features (int): The feature dimension.
        hidden_features (int, optional): The hidden dimension of FFNs.
            Defaults: The same as in_features.
        out_features (int, optional): The output feature dimension.
            Defaults: The same as in_features.
        act_layer (nn.Module, optional): The activation config for FFNs.
            Default: nn.GELU.
        drop (float, optional): drop out rate. Default: 0.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        hidden_dim = hidden_features or in_features
        out_dim = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.dwconv = DWConv(hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.drop = nn.Dropout(drop)

    def init_weights(self):
        """Initialise linear, layer-norm and conv sub-modules."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                trunc_normal_init(module, std=.02, bias=0.)
                continue
            if isinstance(module, nn.LayerNorm):
                nn.init.constant_(module.bias, 0)
                nn.init.constant_(module.weight, 1.0)
                continue
            if isinstance(module, nn.Conv2d):
                kernel_h, kernel_w = module.kernel_size
                fan_out = (kernel_h * kernel_w *
                           module.out_channels) // module.groups
                module.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
                if module.bias is not None:
                    module.bias.data.zero_()

    def forward(self, x, H, W):
        """Apply the FFN to tokens ``x`` laid out on an ``H x W`` grid."""
        hidden = self.dwconv(self.fc1(x), H, W)
        hidden = self.drop(self.act(hidden))
        return self.drop(self.fc2(hidden))
class DWConv(nn.Module):
    """3x3 depthwise convolution applied to regular grid-based tokens.

    Args:
        dim (int): The feature dimension.
    """

    def __init__(self, dim=768):
        super().__init__()
        self.dwconv = nn.Conv2d(
            dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        """Lay the tokens out as a ``[B, C, H, W]`` map, convolve it and
        flatten the result back to ``[B, N, C]``."""
        batch, _, channels = x.shape
        feature_map = x.transpose(1, 2).view(batch, channels, H, W)
        feature_map = self.dwconv(feature_map)
        return feature_map.flatten(2).transpose(1, 2)
class TCFormerRegularAttention(nn.Module):
    """Multi-head attention with optional spatial reduction of the keys and
    values, for regular grid-based tokens.

    Args:
        dim (int): The feature dimension of tokens,
        num_heads (int): Parallel attention heads.
        qkv_bias (bool): enable bias for qkv if True. Default: False.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after attention process.
            Default: 0.0.
        sr_ratio (int): The ratio of spatial reduction of Spatial Reduction
            Attention. Default: 1.
        use_sr_conv (bool): If True, use a conv layer for spatial reduction.
            If False, use a pooling process for spatial reduction. Defaults:
            True.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.,
        proj_drop=0.,
        sr_ratio=1,
        use_sr_conv=True,
    ):
        super().__init__()
        assert dim % num_heads == 0, \
            f'dim {dim} should be divided by num_heads {num_heads}.'

        self.dim = dim
        self.num_heads = num_heads
        self.scale = qk_scale or (dim // num_heads)**-0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        self.use_sr_conv = use_sr_conv
        # the reduction conv and its norm only exist when they are used
        if sr_ratio > 1 and self.use_sr_conv:
            self.sr = nn.Conv2d(
                dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def init_weights(self):
        """Initialise linear, layer-norm and conv sub-modules."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                trunc_normal_init(module, std=.02, bias=0.)
                continue
            if isinstance(module, nn.LayerNorm):
                nn.init.constant_(module.bias, 0)
                nn.init.constant_(module.weight, 1.0)
                continue
            if isinstance(module, nn.Conv2d):
                kernel_h, kernel_w = module.kernel_size
                fan_out = (kernel_h * kernel_w *
                           module.out_channels) // module.groups
                module.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
                if module.bias is not None:
                    module.bias.data.zero_()

    def _reduce_tokens(self, x, H, W):
        """Spatially reduce tokens ``x`` ([B, N, C]) by ``sr_ratio``."""
        B, _, C = x.shape
        feature_map = x.permute(0, 2, 1).reshape(B, C, H, W)
        if self.use_sr_conv:
            reduced = self.sr(feature_map).reshape(B, C, -1)
            return self.norm(reduced.permute(0, 2, 1).contiguous())
        reduced = F.avg_pool2d(
            feature_map, kernel_size=self.sr_ratio, stride=self.sr_ratio)
        return reduced.reshape(B, C, -1).permute(0, 2, 1).contiguous()

    def forward(self, x, H, W):
        """Attend over tokens ``x`` ([B, N, C]) laid out on an ``H x W``
        grid and return tokens of the same shape."""
        B, N, C = x.shape
        heads = self.num_heads
        head_dim = C // heads

        q = self.q(x).reshape(B, N, heads, head_dim).permute(0, 2, 1, 3)

        kv_tokens = self._reduce_tokens(x, H, W) if self.sr_ratio > 1 else x
        kv = self.kv(kv_tokens).reshape(B, -1, 2, heads, head_dim)
        kv = kv.permute(2, 0, 3, 1, 4).contiguous()
        k, v = kv[0], kv[1]

        attn = (q * self.scale) @ k.transpose(-2, -1)
        attn = self.attn_drop(attn.softmax(dim=-1))

        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(out))
class TCFormerRegularBlock(nn.Module):
    """Pre-norm transformer block (attention + FFN, both with a residual
    connection) for regular grid-based tokens.

    Args:
        dim (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        mlp_ratio (float): The expansion ratio for the FFNs. Default: 4.
        qkv_bias (bool): enable bias for qkv if True. Default: False.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        drop (float): Dropout layers after attention process and in FFN.
            Default: 0.0.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        drop_path (float, optional): The drop path rate of transformer
            block. Default: 0.0
        act_layer (nn.Module, optional): The activation config for FFNs.
            Default: nn.GELU.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        sr_ratio (int): The ratio of spatial reduction of Spatial Reduction
            Attention. Default: 1.
        use_sr_conv (bool): If True, use a conv layer for spatial reduction.
            If False, use a pooling process for spatial reduction. Defaults:
            True.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_cfg=dict(type='LN'),
                 sr_ratio=1,
                 use_sr_conv=True):
        super().__init__()
        # attention branch
        _, self.norm1 = build_norm_layer(norm_cfg, dim)
        self.attn = TCFormerRegularAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            sr_ratio=sr_ratio,
            use_sr_conv=use_sr_conv)
        self.drop_path = build_dropout(
            dict(type='DropPath', drop_prob=drop_path))

        # feed-forward branch
        _, self.norm2 = build_norm_layer(norm_cfg, dim)
        self.mlp = MLP(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop)

    def init_weights(self):
        """Initialise linear, layer-norm and conv sub-modules."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                trunc_normal_init(module, std=.02, bias=0.)
                continue
            if isinstance(module, nn.LayerNorm):
                nn.init.constant_(module.bias, 0)
                nn.init.constant_(module.weight, 1.0)
                continue
            if isinstance(module, nn.Conv2d):
                kernel_h, kernel_w = module.kernel_size
                fan_out = (kernel_h * kernel_w *
                           module.out_channels) // module.groups
                module.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
                if module.bias is not None:
                    module.bias.data.zero_()

    def forward(self, x, H, W):
        """Run attention and FFN, each on the normalised input and each
        added back to its input through drop-path."""
        attn_out = self.attn(self.norm1(x), H, W)
        x = x + self.drop_path(attn_out)
        mlp_out = self.mlp(self.norm2(x), H, W)
        x = x + self.drop_path(mlp_out)
        return x
class TokenConv(nn.Conv2d):
    """2D convolution that operates on dynamic tokens.

    The tokens are rendered to a feature map, convolved, and gathered back
    to tokens. A 1x1 skip branch links the input tokens directly to the
    output tokens to reserve detail tokens.

    Args:
        **kwargs: keyword arguments of ``nn.Conv2d``. ``in_channels`` and
            ``out_channels`` must be given by keyword; ``groups`` (default 1)
            is shared with the skip branch.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.skip = nn.Conv1d(
            in_channels=kwargs['in_channels'],
            out_channels=kwargs['out_channels'],
            kernel_size=1,
            bias=False,
            groups=kwargs.get('groups', 1))

    def forward(self, token_dict):
        """Return the convolved tokens (Tensor[B, N, C_out]) for the tokens
        described by ``token_dict``."""
        tokens = token_dict['x']
        # Conv1d works on [B, C, N], tokens are stored as [B, N, C]
        skipped = self.skip(tokens.permute(0, 2, 1)).permute(0, 2, 1)
        feature_map = super().forward(token2map(token_dict))
        return skipped + map2token(feature_map, token_dict)
class TCMLP(nn.Module):
    """FFN with Depthwise Conv for dynamic tokens.

    Args:
        in_features (int): The feature dimension.
        hidden_features (int, optional): The hidden dimension of FFNs.
            Defaults: The same as in_features.
        out_features (int, optional): The output feature dimension.
            Defaults: The same as in_features.
        act_layer (nn.Module, optional): The activation config for FFNs.
            Default: nn.GELU.
        drop (float, optional): drop out rate. Default: 0.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = TokenConv(
            in_channels=hidden_features,
            out_channels=hidden_features,
            kernel_size=3,
            padding=1,
            stride=1,
            bias=True,
            groups=hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def init_weights(self):
        """init weights."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                trunc_normal_init(m, std=.02, bias=0.)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
            elif isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                fan_out //= m.groups
                m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, token_dict):
        """Apply the FFN to dynamic tokens.

        Args:
            token_dict (dict): dict for token information. It is not
                modified; the hidden features are placed in a shallow copy.

        Returns:
            x (Tensor[B, N, out_features]): output token features.
        """
        # Work on a shallow copy so that the caller's dict does not end up
        # holding the hidden-width features after the call.
        hidden_dict = token_dict.copy()
        hidden_dict['x'] = self.fc1(token_dict['x'])
        x = self.dwconv(hidden_dict)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
class TCFormerDynamicAttention(TCFormerRegularAttention):
    """Spatial Reduction Attention for dynamic tokens."""

    def forward(self, q_dict, kv_dict):
        """Attention process for dynamic tokens.

        Dynamic tokens are represented by a dict with the following keys:
            x (torch.Tensor[B, N, C]): token features.
            token_num(int): token number.
            map_size(list[int] or tuple[int]): feature map resolution in
                format [H, W].
            init_grid_size(list[int] or tuple[int]): initial grid resolution
                in format [H_init, W_init].
            idx_token(torch.LongTensor[B, N_init]): indicates which token
                the initial grid belongs to.
            agg_weight(torch.LongTensor[B, N_init] or None): weight for
                aggregation. Indicates the weight of each token in its
                cluster. If set to None, uniform weight is used.

        An optional ``token_score`` entry of ``kv_dict`` is added to the
        attention logits of the corresponding key; zeros are used when the
        entry is missing.

        Note:
            B: batch size
            N: token number
            C: channel number
            Ns: sampled point number
            [H_init, W_init]: shape of initial grid
            [H, W]: shape of feature map
            N_init: numbers of initial token

        Args:
            q_dict (dict): dict for query token information
            kv_dict (dict): dict for key and value token information

        Return:
            x (torch.Tensor[B, N, C]): output token features.
        """
        q = q_dict['x']
        kv = kv_dict['x']
        B, Nq, C = q.shape
        Nkv = kv.shape[1]
        heads = self.num_heads
        head_dim = C // heads

        if 'token_score' in kv_dict:
            conf_kv = kv_dict['token_score']
        else:
            conf_kv = kv.new_zeros(B, Nkv, 1)

        q = self.q(q).reshape(B, Nq, heads, head_dim)
        q = q.permute(0, 2, 1, 3).contiguous()

        if self.sr_ratio > 1:
            # Render features and scores together to one map, so that both
            # go through the same token-to-map transform.
            map_dict = kv_dict.copy()
            map_dict['x'] = torch.cat([kv, conf_kv], dim=-1)
            map_dict['map_size'] = q_dict['map_size']
            rendered = token2map(map_dict)
            kv, conf_kv = rendered[:, :C], rendered[:, C:]

            if self.use_sr_conv:
                kv = self.sr(kv)
                kv = kv.reshape(B, C, -1).permute(0, 2, 1).contiguous()
                kv = self.norm(kv)
            else:
                kv = F.avg_pool2d(
                    kv, kernel_size=self.sr_ratio, stride=self.sr_ratio)
                kv = kv.reshape(B, C, -1).permute(0, 2, 1).contiguous()

            # the scores are always reduced by average pooling
            conf_kv = F.avg_pool2d(
                conf_kv, kernel_size=self.sr_ratio, stride=self.sr_ratio)
            conf_kv = conf_kv.reshape(B, 1, -1).permute(0, 2, 1).contiguous()

        kv = self.kv(kv).reshape(B, -1, 2, heads, head_dim)
        kv = kv.permute(2, 0, 3, 1, 4).contiguous()
        k, v = kv[0], kv[1]

        attn = (q * self.scale) @ k.transpose(-2, -1)
        # broadcast the per-key score over heads and queries
        attn = attn + conf_kv.squeeze(-1)[:, None, None, :]
        attn = self.attn_drop(attn.softmax(dim=-1))

        x = (attn @ v).transpose(1, 2).reshape(B, Nq, C)
        return self.proj_drop(self.proj(x))
# Transformer block for dynamic tokens
class TCFormerDynamicBlock(TCFormerRegularBlock):
    """Transformer block for dynamic tokens.

    Args:
        dim (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        mlp_ratio (float): The expansion ratio for the FFNs. Default: 4.
        qkv_bias (bool): enable bias for qkv if True. Default: False.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        drop (float): Dropout layers after attention process and in FFN.
            Default: 0.0.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        drop_path (float, optional): The drop path rate of transformer
            block. Default: 0.0
        act_layer (nn.Module, optional): The activation config for FFNs.
            Default: nn.GELU.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        sr_ratio (int): The ratio of spatial reduction of Spatial Reduction
            Attention. Default: 1.
        use_sr_conv (bool): If True, use a conv layer for spatial reduction.
            If False, use a pooling process for spatial reduction. Defaults:
            True.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_cfg=dict(type='LN'),
                 sr_ratio=1,
                 use_sr_conv=True):
        # Deliberately skip TCFormerRegularBlock.__init__ (it would build
        # the regular attention and MLP) and initialise nn.Module directly.
        super(TCFormerRegularBlock, self).__init__()
        self.norm1 = build_norm_layer(norm_cfg, dim)[1]
        self.attn = TCFormerDynamicAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            sr_ratio=sr_ratio,
            use_sr_conv=use_sr_conv)
        self.drop_path = build_dropout(
            dict(type='DropPath', drop_prob=drop_path))

        self.norm2 = build_norm_layer(norm_cfg, dim)[1]
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = TCMLP(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop)

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (dict or tuple[dict] or list[dict]): input dynamic
                token information. If a single dict is provided, it's
                regarded as query and key, value. If a tuple or list
                of dict is provided, the first one is regarded as query
                and the second one is regarded as key, value.

        Return:
            q_dict (dict): dict for output token information. This is the
                query dict that was passed in, with ``x`` replaced by the
                output features. A separate key/value dict is left
                unmodified.
        """
        if isinstance(inputs, (tuple, list)):
            q_dict, kv_dict = inputs
        else:
            q_dict, kv_dict = inputs, None

        x = q_dict['x']
        # Read the raw key/value features before ``q_dict`` is touched, so
        # that passing the same dict twice does not normalise them twice.
        kv_x = None if kv_dict is None else kv_dict['x']

        # norm1
        q_dict['x'] = self.norm1(x)
        if kv_dict is None:
            kv_dict = q_dict
        else:
            # Normalise into a shallow copy: the caller's key/value dict
            # must not be left holding the normalised features.
            kv_dict = kv_dict.copy()
            kv_dict['x'] = self.norm1(kv_x)

        # attn
        x = x + self.drop_path(self.attn(q_dict, kv_dict))

        # mlp
        q_dict['x'] = self.norm2(x)
        x = x + self.drop_path(self.mlp(q_dict))
        q_dict['x'] = x

        return q_dict
|