|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Mostly copy-paste from timm library. |
|
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py |
|
""" |
|
from copy import deepcopy |
|
import math |
|
from functools import partial |
|
from sympy import flatten |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch import Tensor, pixel_shuffle |
|
|
|
from einops import rearrange, repeat |
|
from einops.layers.torch import Rearrange |
|
from torch.nn.modules import GELU, LayerNorm |
|
|
|
|
|
|
|
from .utils import trunc_normal_ |
|
|
|
from pdb import set_trace as st |
|
|
|
try: |
|
from xformers.ops import memory_efficient_attention, unbind, fmha |
|
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp |
|
|
|
XFORMERS_AVAILABLE = True |
|
except ImportError: |
|
|
|
XFORMERS_AVAILABLE = False |
|
|
|
|
|
class Attention(nn.Module): |
|
|
|
def __init__(self, |
|
dim, |
|
num_heads=8, |
|
qkv_bias=False, |
|
qk_scale=None, |
|
attn_drop=0., |
|
proj_drop=0.): |
|
super().__init__() |
|
self.num_heads = num_heads |
|
head_dim = dim // num_heads |
|
self.scale = qk_scale or head_dim**-0.5 |
|
|
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) |
|
self.attn_drop = nn.Dropout(attn_drop) |
|
self.proj = nn.Linear(dim, dim) |
|
self.proj_drop = nn.Dropout(proj_drop) |
|
|
|
def forward(self, x): |
|
B, N, C = x.shape |
|
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, |
|
C // self.num_heads).permute(2, 0, 3, 1, 4) |
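        # qkv: (3, B, num_heads, N, head_dim)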
|
q, k, v = qkv[0], qkv[1], qkv[2] |
|
|
|
attn = (q @ k.transpose(-2, -1)) * self.scale |
|
attn = attn.softmax(dim=-1) |
|
attn = self.attn_drop(attn) |
|
|
|
x = (attn @ v).transpose(1, 2).reshape(B, N, C) |
|
x = self.proj(x) |
|
x = self.proj_drop(x) |
|
|
|
        return x, attn
|
|
|
|
|
class MemEffAttention(Attention): |
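    """Same as Attention, but uses xFormers memory_efficient_attention when available.
    """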
|
|
|
def forward(self, x: Tensor, attn_bias=None) -> Tensor: |
|
if not XFORMERS_AVAILABLE: |
|
assert attn_bias is None, "xFormers is required for nested tensors usage" |
|
            x, _attn = super().forward(x)
            return x
|
|
|
B, N, C = x.shape |
|
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) |
|
|
|
q, k, v = unbind(qkv, 2) |
|
|
|
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) |
|
|
|
x = x.reshape([B, N, C]) |
|
|
|
x = self.proj(x) |
|
x = self.proj_drop(x) |
|
return x |
|
|
|
|
|
|
|
class CrossAttention(nn.Module): |
|
|
|
def __init__(self, |
|
dim, |
|
num_heads=8, |
|
qkv_bias=False, |
|
qk_scale=None, |
|
attn_drop=0., |
|
proj_drop=0.): |
|
super().__init__() |
|
self.num_heads = num_heads |
|
head_dim = dim // num_heads |
|
|
|
self.scale = qk_scale or head_dim**-0.5 |
|
|
|
self.wq = nn.Linear(dim, dim, bias=qkv_bias) |
|
self.wk = nn.Linear(dim, dim, bias=qkv_bias) |
|
self.wv = nn.Linear(dim, dim, bias=qkv_bias) |
|
self.attn_drop = nn.Dropout(attn_drop) |
|
self.proj = nn.Linear(dim, dim) |
|
self.proj_drop = nn.Dropout(proj_drop) |
|
|
|
def forward(self, x): |
|
|
|
B, N, C = x.shape |
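        # Only the first (CLS) token forms the query; it attends over all N
        # tokens, so the output has shape (B, 1, C).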
|
        q = self.wq(x[:, 0:1, ...]).reshape(
            B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        k = self.wk(x).reshape(
            B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        v = self.wv(x).reshape(
            B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, 1, C)
|
x = self.proj(x) |
|
x = self.proj_drop(x) |
|
return x |
|
|
|
|
|
class Conv3D_Aware_CrossAttention(nn.Module): |
|
|
|
def __init__(self, |
|
dim, |
|
num_heads=8, |
|
qkv_bias=False, |
|
qk_scale=None, |
|
attn_drop=0., |
|
proj_drop=0.): |
|
super().__init__() |
|
self.num_heads = num_heads |
|
head_dim = dim // num_heads |
|
|
|
self.scale = qk_scale or head_dim**-0.5 |
|
|
|
self.wq = nn.Linear(dim, dim, bias=qkv_bias) |
|
self.wk = nn.Linear(dim, dim, bias=qkv_bias) |
|
self.wv = nn.Linear(dim, dim, bias=qkv_bias) |
|
self.attn_drop = nn.Dropout(attn_drop) |
|
self.proj = nn.Linear(dim, dim) |
|
self.proj_drop = nn.Dropout(proj_drop) |
|
|
|
def forward(self, x): |
|
|
|
B, group_size, N, C = x.shape |
|
p = int(N**0.5) |
|
assert p**2 == N, 'check input dim, no [cls] needed here' |
|
assert group_size == 3, 'designed for triplane here' |
|
|
|
x = x.reshape(B, group_size, p, p, C) |
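        # For each token on plane i, the query is that single token and the
        # context (keys/values) is the matching row of plane (i+1)%3 plus the
        # matching column of plane (i+2)%3, i.e. 2*p context tokens per query.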
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        q_x = torch.empty(B * group_size * N, 1, C, device=x.device)
        k_x = torch.empty(B * group_size * N, 2 * p, C, device=x.device)
        v_x = torch.empty_like(k_x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
index_i, index_j = torch.meshgrid(torch.arange(0, p), |
|
torch.arange(0, p), |
|
indexing='ij') |
|
index_mesh_grid = torch.stack([index_i, index_j], 0).to( |
|
x.device).unsqueeze(0).repeat_interleave(B, |
|
0).reshape(B, 2, p, |
|
p) |
|
|
|
for i in range(group_size): |
|
q_x[B * i * N:B * (i + 1) * N] = x[:, i:i + 1].permute( |
|
0, 2, 3, 1, 4).reshape(B * N, 1, C) |
|
|
|
|
|
plane_yz = x[:, (i + 1) % group_size:(i + 1) % group_size + |
|
1] |
|
plane_zx = x[:, (i + 2) % group_size:(i + 2) % group_size + 1] |
|
|
|
assert plane_yz.shape == plane_zx.shape == ( |
|
B, 1, p, p, C), 'check sub plane dimensions' |
|
|
|
pooling_plane_yz = torch.gather( |
|
plane_yz, |
|
dim=2, |
|
index=index_mesh_grid[:, 0:1].reshape(B, 1, N, 1, 1).expand( |
|
-1, -1, -1, p, |
|
C)).permute(0, 2, 1, 3, 4) |
|
pooling_plane_zx = torch.gather( |
|
plane_zx, |
|
dim=3, |
|
index=index_mesh_grid[:, 1:2].reshape(B, 1, 1, N, 1).expand( |
|
-1, -1, p, -1, |
|
C)).permute(0, 3, 1, 2, 4) |
|
|
|
k_x[B * i * N:B * (i + 1) * |
|
N] = v_x[B * i * N:B * (i + 1) * N] = torch.cat( |
|
[pooling_plane_yz, pooling_plane_zx], |
|
dim=2).reshape(B * N, 2 * p, |
|
C) |
|
|
|
|
|
|
|
|
|
|
|
q = self.wq(q_x).reshape(B * group_size * N, 1, |
|
self.num_heads, C // self.num_heads).permute( |
|
0, 2, 1, |
|
3) |
|
k = self.wk(k_x).reshape(B * group_size * N, 2 * p, self.num_heads, |
|
C // self.num_heads).permute(0, 2, 1, 3) |
|
v = self.wv(v_x).reshape(B * group_size * N, 2 * p, self.num_heads, |
|
C // self.num_heads).permute(0, 2, 1, 3) |
|
|
|
attn = (q @ k.transpose( |
|
-2, -1)) * self.scale |
|
attn = attn.softmax(dim=-1) |
|
attn = self.attn_drop(attn) |
|
|
|
x = (attn @ v).transpose(1, 2).reshape( |
|
B * 3 * N, 1, |
|
C) |
|
x = self.proj(x) |
|
x = self.proj_drop(x) |
|
|
|
|
|
x = x.reshape(B, 3, N, C) |
|
|
|
return x |
|
|
|
|
|
class xformer_Conv3D_Aware_CrossAttention(nn.Module): |
|
|
|
def __init__(self, |
|
dim, |
|
num_heads=8, |
|
qkv_bias=False, |
|
qk_scale=None, |
|
attn_drop=0., |
|
proj_drop=0.): |
|
super().__init__() |
|
|
|
|
|
|
|
self.num_heads = num_heads |
|
self.wq = nn.Linear(dim, dim * 1, bias=qkv_bias) |
|
self.w_kv = nn.Linear(dim, dim * 2, bias=qkv_bias) |
|
self.attn_drop = nn.Dropout(attn_drop) |
|
|
|
self.proj = nn.Linear(dim, dim) |
|
self.proj_drop = nn.Dropout(proj_drop) |
|
|
|
self.index_mesh_grid = None |
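        # The (row, col) index grid is built lazily on the first forward pass
        # and cached with batch size 1 for reuse.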
|
|
|
def forward(self, x, attn_bias=None): |
|
|
|
B, group_size, N, C = x.shape |
|
p = int(N**0.5) |
|
assert p**2 == N, 'check input dim, no [cls] needed here' |
|
assert group_size == 3, 'designed for triplane here' |
|
|
|
x = x.reshape(B, group_size, p, p, C) |
|
|
|
q_x = torch.empty(B * group_size * N, 1, C, device=x.device) |
|
context = torch.empty(B * group_size * N, 2 * p, C, |
|
device=x.device) |
|
|
|
if self.index_mesh_grid is None: |
|
index_i, index_j = torch.meshgrid(torch.arange(0, p), |
|
torch.arange(0, p), |
|
indexing='ij') |
|
index_mesh_grid = torch.stack([index_i, index_j], 0).to( |
|
x.device).unsqueeze(0).repeat_interleave(B, 0).reshape( |
|
B, 2, p, p) |
|
self.index_mesh_grid = index_mesh_grid[0:1] |
|
else: |
|
index_mesh_grid = self.index_mesh_grid.clone().repeat_interleave( |
|
B, 0) |
|
assert index_mesh_grid.shape == ( |
|
B, 2, p, p), 'check index_mesh_grid dimension' |
|
|
|
for i in range(group_size): |
|
q_x[B * i * N:B * (i + 1) * N] = x[:, i:i + 1].permute( |
|
0, 2, 3, 1, 4).reshape(B * N, 1, C) |
|
|
|
|
|
plane_yz = x[:, (i + 1) % group_size:(i + 1) % group_size + |
|
1] |
|
plane_zx = x[:, (i + 2) % group_size:(i + 2) % group_size + 1] |
|
|
|
assert plane_yz.shape == plane_zx.shape == ( |
|
B, 1, p, p, C), 'check sub plane dimensions' |
|
|
|
pooling_plane_yz = torch.gather( |
|
plane_yz, |
|
dim=2, |
|
index=index_mesh_grid[:, 0:1].reshape(B, 1, N, 1, 1).expand( |
|
-1, -1, -1, p, |
|
C)).permute(0, 2, 1, 3, 4) |
|
pooling_plane_zx = torch.gather( |
|
plane_zx, |
|
dim=3, |
|
index=index_mesh_grid[:, 1:2].reshape(B, 1, 1, N, 1).expand( |
|
-1, -1, p, -1, |
|
C)).permute(0, 3, 1, 2, 4) |
|
|
|
context[B * i * N:B * (i + 1) * N] = torch.cat( |
|
[pooling_plane_yz, pooling_plane_zx], |
|
dim=2).reshape(B * N, 2 * p, |
|
C) |
|
|
|
|
|
|
|
q = self.wq(q_x).reshape(B * group_size * N, 1, self.num_heads, |
|
C // self.num_heads) |
|
|
|
kv = self.w_kv(context).reshape(B * group_size * N, 2 * p, 2, |
|
self.num_heads, C // self.num_heads) |
|
k, v = unbind(kv, 2) |
|
|
|
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) |
|
|
|
x = x.transpose(1, 2).reshape([B * 3 * N, 1, C]).reshape(B, 3, N, C) |
|
|
|
x = self.proj(x) |
|
x = self.proj_drop(x) |
|
|
|
return x |
|
|
|
|
|
class xformer_Conv3D_Aware_CrossAttention_xygrid( |
|
xformer_Conv3D_Aware_CrossAttention): |
|
"""implementation wise clearer, but yields identical results with xformer_Conv3D_Aware_CrossAttention |
|
""" |
|
|
|
def __init__(self, |
|
dim, |
|
num_heads=8, |
|
qkv_bias=False, |
|
qk_scale=None, |
|
attn_drop=0.0, |
|
proj_drop=0.0): |
|
super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, |
|
proj_drop) |
|
|
|
def forward(self, x, attn_bias=None): |
|
|
|
B, group_size, N, C = x.shape |
|
p = int(N**0.5) |
|
assert p**2 == N, 'check input dim, no [cls] needed here' |
|
assert group_size == 3, 'designed for triplane here' |
|
|
|
x = x.reshape(B, group_size, p, p, C) |
|
|
|
q_x = torch.empty(B * group_size * N, 1, C, device=x.device) |
|
context = torch.empty(B * group_size * N, 2 * p, C, |
|
device=x.device) |
|
|
|
if self.index_mesh_grid is None: |
|
index_u, index_v = torch.meshgrid( |
|
torch.arange(0, p), torch.arange(0, p), |
|
indexing='xy') |
|
index_mesh_grid = torch.stack([index_u, index_v], 0).to( |
|
x.device).unsqueeze(0).repeat_interleave(B, 0).reshape( |
|
B, 2, p, p) |
|
self.index_mesh_grid = index_mesh_grid[0:1] |
|
else: |
|
index_mesh_grid = self.index_mesh_grid.clone().repeat_interleave( |
|
B, 0) |
|
assert index_mesh_grid.shape == ( |
|
B, 2, p, p), 'check index_mesh_grid dimension' |
|
|
|
for i in range(group_size): |
|
q_x[B * i * N:B * (i + 1) * N] = x[:, i:i + 1].permute( |
|
0, 2, 3, 1, 4).reshape(B * N, 1, C) |
|
|
|
|
|
plane_yz = x[:, (i + 1) % group_size:(i + 1) % group_size + |
|
1] |
|
plane_zx = x[:, (i + 2) % group_size:(i + 2) % group_size + 1] |
|
|
|
assert plane_yz.shape == plane_zx.shape == ( |
|
B, 1, p, p, C), 'check sub plane dimensions' |
|
|
|
pooling_plane_yz = torch.gather( |
|
plane_yz, |
|
dim=2, |
|
index=index_mesh_grid[:, 1:2].reshape(B, 1, N, 1, 1).expand( |
|
-1, -1, -1, p, |
|
C)).permute(0, 2, 1, 3, 4) |
|
pooling_plane_zx = torch.gather( |
|
plane_zx, |
|
dim=3, |
|
index=index_mesh_grid[:, 0:1].reshape(B, 1, 1, N, 1).expand( |
|
-1, -1, p, -1, |
|
C)).permute(0, 3, 1, 2, 4) |
|
|
|
context[B * i * N:B * (i + 1) * N] = torch.cat( |
|
[pooling_plane_yz, pooling_plane_zx], |
|
dim=2).reshape(B * N, 2 * p, |
|
C) |
|
|
|
|
|
q = self.wq(q_x).reshape(B * group_size * N, 1, self.num_heads, |
|
C // self.num_heads) |
|
|
|
kv = self.w_kv(context).reshape(B * group_size * N, 2 * p, 2, |
|
self.num_heads, C // self.num_heads) |
|
k, v = unbind(kv, 2) |
|
|
|
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) |
|
|
|
x = x.transpose(1, 2).reshape([B * 3 * N, 1, C]).reshape(B, 3, N, C) |
|
|
|
x = self.proj(x) |
|
x = self.proj_drop(x) |
|
|
|
return x |
|
|
|
|
|
class xformer_Conv3D_Aware_CrossAttention_xygrid_withinC( |
|
xformer_Conv3D_Aware_CrossAttention_xygrid): |
|
|
|
def __init__(self, |
|
dim, |
|
num_heads=8, |
|
qkv_bias=False, |
|
qk_scale=None, |
|
attn_drop=0, |
|
proj_drop=0): |
|
super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, |
|
proj_drop) |
|
|
|
def forward(self, x, attn_bias=None): |
|
|
|
B, N, C = x.shape |
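        # Interpret the channel dim as (C // 3, 3): the trailing factor indexes
        # the three planes, which are moved out to form a (B, 3, N, C // 3)
        # tensor for the plane-aware cross-attention, then folded back below.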
|
        x = x.reshape(B, N, C // 3, 3).permute(0, 3, 1, 2)
|
x_out = super().forward(x, attn_bias) |
|
x_out = x_out.permute(0, 2, 3, 1) |
|
x_out = x_out.reshape(*x_out.shape[:2], -1) |
|
return x_out.contiguous() |
|
|
|
class self_cross_attn(nn.Module): |
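    """Runs a (DINO) self-attention followed by a 3D-aware cross-attention,
    with a residual connection around the self-attention.
    """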
|
def __init__(self, dino_attn, cross_attn, *args, **kwargs) -> None: |
|
super().__init__(*args, **kwargs) |
|
self.dino_attn = dino_attn |
|
self.cross_attn = cross_attn |
|
|
|
def forward(self, x_norm): |
|
y = self.dino_attn(x_norm) + x_norm |
|
return self.cross_attn(y) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RodinRollOutConv3D(nn.Module): |
|
"""implementation wise clearer, but yields identical results with xformer_Conv3D_Aware_CrossAttention |
|
""" |
|
|
|
def __init__(self, in_chans, out_chans=None): |
|
super().__init__() |
|
if out_chans is None: |
|
out_chans = in_chans |
|
|
|
self.out_chans = out_chans // 3 |
|
|
|
self.roll_out_convs = nn.Conv2d(in_chans, |
|
self.out_chans, |
|
kernel_size=3, |
|
padding=1) |
|
|
|
def forward(self, x): |
|
|
|
|
|
B, C3, p, p = x.shape |
|
C = C3 // 3 |
|
group_size = C3 // C |
|
assert group_size == 3 |
|
|
|
x = x.reshape(B, 3, C, p, p) |
|
|
|
roll_out_x = torch.empty(B, group_size * C, p, 3 * p, |
|
device=x.device) |
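        # Width-tile i holds [plane_i ; axis-pooled plane_{i+1} ; axis-pooled plane_{i+2}]
        # stacked along the channel dim (pooling = mean over one spatial axis,
        # broadcast back), so one shared conv sees all three planes.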
|
|
|
for i in range(group_size): |
|
plane_xy = x[:, i] |
|
|
|
|
|
plane_yz_pooling = x[:, (i + 1) % group_size].mean( |
|
dim=-1, keepdim=True).repeat_interleave( |
|
p, dim=-1) |
|
plane_zx_pooling = x[:, (i + 2) % group_size].mean( |
|
dim=-2, keepdim=True).repeat_interleave( |
|
p, dim=-2) |
|
|
|
roll_out_x[..., i * p:(i + 1) * p] = torch.cat( |
|
[plane_xy, plane_yz_pooling, plane_zx_pooling], |
|
1) |
|
|
|
x = self.roll_out_convs(roll_out_x) |
|
|
|
x = x.reshape(B, self.out_chans, p, 3, p) |
|
x = x.permute(0, 3, 1, 2, 4).reshape(B, 3 * self.out_chans, p, |
|
p) |
|
|
|
return x |
|
|
|
|
|
class RodinRollOutConv3D_GroupConv(nn.Module): |
|
"""implementation wise clearer, but yields identical results with xformer_Conv3D_Aware_CrossAttention |
|
""" |
|
|
|
def __init__(self, |
|
in_chans, |
|
out_chans=None, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1): |
|
super().__init__() |
|
if out_chans is None: |
|
out_chans = in_chans |
|
|
|
self.roll_out_convs = nn.Conv2d( |
|
in_chans * 3, |
|
out_chans, |
|
kernel_size=kernel_size, |
|
groups=3, |
|
stride=stride, |
|
padding=padding) |
|
|
|
|
|
def forward(self, x): |
|
|
|
|
|
B, C3, p, p = x.shape |
|
C = C3 // 3 |
|
group_size = C3 // C |
|
assert group_size == 3 |
|
|
|
x = x.reshape(B, 3, C, p, p) |
|
|
|
roll_out_x = torch.empty(B, group_size * C * 3, p, p, |
|
device=x.device) |
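        # Channel group i holds [plane_i ; axis-pooled plane_{i+1} ; axis-pooled plane_{i+2}];
        # groups=3 in the conv keeps the three plane groups separate.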
|
|
|
for i in range(group_size): |
|
plane_xy = x[:, i] |
|
|
|
|
|
plane_yz_pooling = x[:, (i + 1) % group_size].mean( |
|
dim=-1, keepdim=True).repeat_interleave( |
|
p, dim=-1) |
|
plane_zx_pooling = x[:, (i + 2) % group_size].mean( |
|
dim=-2, keepdim=True).repeat_interleave( |
|
p, dim=-2) |
|
|
|
roll_out_x[:, i * 3 * C:(i + 1) * 3 * C] = torch.cat( |
|
[plane_xy, plane_yz_pooling, plane_zx_pooling], |
|
1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
x = self.roll_out_convs(roll_out_x) |
|
|
|
return x |
|
|
|
|
|
class RodinRollOut_GroupConv_noConv3D(nn.Module): |
|
"""only roll out and do Conv on individual planes |
|
""" |
|
|
|
def __init__(self, |
|
in_chans, |
|
out_chans=None, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1): |
|
super().__init__() |
|
if out_chans is None: |
|
out_chans = in_chans |
|
|
|
self.roll_out_inplane_conv = nn.Conv2d( |
|
in_chans, |
|
out_chans, |
|
kernel_size=kernel_size, |
|
groups=3, |
|
stride=stride, |
|
padding=padding) |
|
|
|
def forward(self, x): |
|
x = self.roll_out_inplane_conv(x) |
|
return x |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RodinConv3D_SynthesisLayer_mlp_unshuffle_as_residual(nn.Module): |
|
|
|
def __init__(self, in_chans, out_chans) -> None: |
|
super().__init__() |
|
|
|
self.act = nn.LeakyReLU(inplace=True) |
|
self.conv = nn.Sequential( |
|
RodinRollOutConv3D_GroupConv(in_chans, out_chans), |
|
nn.LeakyReLU(inplace=True), |
|
) |
|
|
|
self.out_chans = out_chans |
|
if in_chans != out_chans: |
|
|
|
self.short_cut = nn.Linear( |
|
in_chans // 3, |
|
out_chans // 3 * 4 * 4, |
|
bias=True) |
|
|
|
|
|
else: |
|
self.short_cut = None |
|
|
|
def shortcut_unpatchify_triplane(self, |
|
x, |
|
p=None, |
|
unpatchify_out_chans=None): |
|
"""separate triplane version; x shape: B (3*257) 768 |
|
""" |
|
|
|
assert self.short_cut is not None |
|
|
|
|
|
B, C3, h, w = x.shape |
|
assert h == w |
|
L = h * w |
|
x = x.reshape(B, C3 // 3, 3, L).permute(0, 2, 3, |
|
1) |
|
|
|
x = self.short_cut(x) |
|
|
|
        # Assumption: the 4*4 factor in self.short_cut corresponds to a 4x
        # spatial unpatchify per token, with out_chans grouped per plane.
        if p is None:
            p = 4
        if unpatchify_out_chans is None:
            unpatchify_out_chans = self.out_chans // 3

        x = x.reshape(shape=(B, 3, h, w, p, p, unpatchify_out_chans))
        x = torch.einsum('ndhwpqc->ndchpwq', x)
        x = x.reshape(shape=(B, 3 * unpatchify_out_chans, h * p, w * p))
|
return x |
|
|
|
def forward(self, feats): |
|
|
|
if self.short_cut is not None: |
|
res_feats = self.shortcut_unpatchify_triplane(feats) |
|
else: |
|
res_feats = feats |
|
|
|
|
|
|
|
feats = res_feats + self.conv(feats) |
|
return self.act(feats) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RodinConv3D_SynthesisLayer(nn.Module): |
|
|
|
def __init__(self, in_chans, out_chans) -> None: |
|
super().__init__() |
|
|
|
|
|
|
|
self.act = nn.LeakyReLU(inplace=True) |
|
self.conv = nn.Sequential( |
|
RodinRollOutConv3D_GroupConv(in_chans, out_chans), |
|
nn.LeakyReLU(inplace=True), |
|
) |
|
|
|
if in_chans != out_chans: |
|
self.short_cut = RodinRollOutConv3D_GroupConv(in_chans, out_chans) |
|
else: |
|
self.short_cut = None |
|
|
|
def forward(self, feats): |
|
feats_out = self.conv(feats) |
|
if self.short_cut is not None: |
|
|
|
feats_out = self.short_cut( |
|
feats |
|
) + feats_out |
|
|
|
else: |
|
feats_out = feats_out + feats |
|
return feats_out |
|
|
|
|
|
class RodinRollOutConv3DSR2X(nn.Module): |
|
|
|
def __init__(self, in_chans, **kwargs) -> None: |
|
super().__init__() |
|
self.conv3D = RodinRollOutConv3D_GroupConv(in_chans) |
|
|
|
self.act = nn.LeakyReLU(inplace=True) |
|
self.input_resolution = 224 |
|
|
|
def forward(self, x): |
|
|
|
B, C3, p, p = x.shape |
|
C = C3 // 3 |
|
group_size = C3 // C |
|
|
|
assert group_size == 3 |
|
|
|
|
|
assert group_size == 3, 'designed for triplane here' |
|
|
|
x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, |
|
p) |
|
|
|
if x.shape[-1] != self.input_resolution: |
|
x = torch.nn.functional.interpolate(x, |
|
size=(self.input_resolution, |
|
self.input_resolution), |
|
mode='bilinear', |
|
align_corners=False, |
|
antialias=True) |
|
|
|
x = x + self.conv3D(x) |
|
|
|
return x |
|
|
|
|
|
class RodinRollOutConv3DSR4X_lite(nn.Module): |
|
|
|
def __init__(self, in_chans, input_resolutiopn=256, **kwargs) -> None: |
|
super().__init__() |
|
self.conv3D_0 = RodinRollOutConv3D_GroupConv(in_chans) |
|
self.conv3D_1 = RodinRollOutConv3D_GroupConv(in_chans) |
|
|
|
self.act = nn.LeakyReLU(inplace=True) |
|
self.input_resolution = input_resolutiopn |
|
|
|
def forward(self, x): |
|
|
|
B, C3, p, p = x.shape |
|
C = C3 // 3 |
|
group_size = C3 // C |
|
|
|
assert group_size == 3 |
|
|
|
|
|
assert group_size == 3, 'designed for triplane here' |
|
|
|
x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, |
|
p) |
|
|
|
if x.shape[-1] != self.input_resolution: |
|
x = torch.nn.functional.interpolate(x, |
|
size=(self.input_resolution, |
|
self.input_resolution), |
|
mode='bilinear', |
|
align_corners=False, |
|
antialias=True) |
|
|
|
|
|
|
|
|
|
|
|
x = x + self.act(self.conv3D_0(x)) |
|
x = x + self.act(self.conv3D_1(x)) |
|
|
|
|
|
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RodinConv3D4X_lite_mlp_as_residual(nn.Module): |
|
"""lite 4X version, with MLP unshuffle to change the dimention |
|
""" |
|
|
|
def __init__(self, |
|
in_chans, |
|
out_chans, |
|
input_resolution=256, |
|
interp_mode='bilinear', |
|
bcg_triplane=False) -> None: |
|
super().__init__() |
|
|
|
self.interp_mode = interp_mode |
|
|
|
self.act = nn.LeakyReLU(inplace=True) |
|
|
|
self.conv3D_0 = RodinRollOutConv3D_GroupConv(in_chans, out_chans) |
|
self.conv3D_1 = RodinRollOutConv3D_GroupConv(out_chans, out_chans) |
|
self.bcg_triplane = bcg_triplane |
|
if bcg_triplane: |
|
self.conv3D_1_bg = RodinRollOutConv3D_GroupConv( |
|
out_chans, out_chans) |
|
|
|
self.act = nn.LeakyReLU(inplace=True) |
|
self.input_resolution = input_resolution |
|
|
|
self.out_chans = out_chans |
|
if in_chans != out_chans: |
|
self.short_cut = nn.Linear( |
|
in_chans // 3, |
|
out_chans // 3, |
|
bias=True) |
|
else: |
|
self.short_cut = None |
|
|
|
def shortcut_unpatchify_triplane(self, x, p=None): |
|
"""separate triplane version; x shape: B (3*257) 768 |
|
""" |
|
|
|
assert self.short_cut is not None |
|
|
|
B, C3, h, w = x.shape |
|
assert h == w |
|
L = h * w |
|
x = x.reshape(B, C3 // 3, 3, L).permute(0, 2, 3, |
|
1) |
|
|
|
x = self.short_cut(x) |
|
|
|
x = x.permute(0, 1, 3, 2) |
|
x = x.reshape(shape=(B, self.out_chans, h, w)) |
|
|
|
|
|
if w != self.input_resolution: |
|
x = torch.nn.functional.interpolate( |
|
x, |
|
size=(self.input_resolution, self.input_resolution), |
|
mode='bilinear', |
|
align_corners=False, |
|
antialias=True) |
|
|
|
return x |
|
|
|
def interpolate(self, feats): |
|
if self.interp_mode == 'bilinear': |
|
return torch.nn.functional.interpolate( |
|
feats, |
|
size=(self.input_resolution, self.input_resolution), |
|
mode='bilinear', |
|
align_corners=False, |
|
antialias=True) |
|
else: |
|
return torch.nn.functional.interpolate( |
|
feats, |
|
size=(self.input_resolution, self.input_resolution), |
|
mode='nearest', |
|
) |
|
|
|
def forward(self, x): |
|
|
|
|
|
B, C3, p, p = x.shape |
|
C = C3 // 3 |
|
|
|
if self.short_cut is not None: |
|
res_feats = self.shortcut_unpatchify_triplane(x) |
|
else: |
|
res_feats = x |
|
if res_feats.shape[-1] != self.input_resolution: |
|
res_feats = self.interpolate(res_feats) |
|
"""following forward code copied from lite4x version |
|
""" |
|
x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, |
|
p) |
|
|
|
if x.shape[-1] != self.input_resolution: |
|
x = self.interpolate(x) |
|
|
|
x0 = res_feats + self.act(self.conv3D_0(x)) |
|
x = x0 + self.act(self.conv3D_1(x0)) |
|
if self.bcg_triplane: |
|
x_bcg = x0 + self.act(self.conv3D_1_bg(x0)) |
|
return torch.cat([x, x_bcg], 1) |
|
else: |
|
return x |
|
|
|
|
|
class RodinConv3D4X_lite_mlp_as_residual_litev2( |
|
RodinConv3D4X_lite_mlp_as_residual): |
|
|
|
def __init__(self, |
|
in_chans, |
|
out_chans, |
|
num_feat=128, |
|
input_resolution=256, |
|
interp_mode='bilinear', |
|
bcg_triplane=False) -> None: |
|
super().__init__(in_chans, out_chans, input_resolution, interp_mode, |
|
bcg_triplane) |
|
|
|
self.conv3D_0 = RodinRollOutConv3D_GroupConv(in_chans, in_chans) |
|
self.conv_before_upsample = RodinRollOut_GroupConv_noConv3D( |
|
in_chans, num_feat * 3) |
|
self.conv3D_1 = RodinRollOut_GroupConv_noConv3D( |
|
num_feat * 3, num_feat * 3) |
|
self.conv_last = RodinRollOut_GroupConv_noConv3D( |
|
num_feat * 3, out_chans) |
|
self.short_cut = None |
|
|
|
def forward(self, x): |
|
|
|
|
|
B, C3, p, p = x.shape |
|
C = C3 // 3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""following forward code copied from lite4x version |
|
""" |
|
x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, |
|
p) |
|
|
|
x = x + self.conv3D_0(x) |
|
x = self.act(self.conv_before_upsample(x)) |
|
|
|
|
|
x = self.conv_last(self.act(self.conv3D_1(self.interpolate(x)))) |
|
|
|
return x |
|
|
|
|
|
class RodinConv3D4X_lite_mlp_as_residual_lite( |
|
RodinConv3D4X_lite_mlp_as_residual): |
|
|
|
def __init__(self, |
|
in_chans, |
|
out_chans, |
|
input_resolution=256, |
|
interp_mode='bilinear') -> None: |
|
super().__init__(in_chans, out_chans, input_resolution, interp_mode) |
|
"""replace the first Rodin Conv 3D with ordinary rollout conv to save memory |
|
""" |
|
self.conv3D_0 = RodinRollOut_GroupConv_noConv3D(in_chans, out_chans) |
|
|
|
|
|
class SR3D(nn.Module): |
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs) -> None: |
|
super().__init__(*args, **kwargs) |
|
|
|
|
|
class RodinConv3D4X_lite_mlp_as_residual_improved(nn.Module): |
|
|
|
def __init__(self, |
|
in_chans, |
|
num_feat, |
|
out_chans, |
|
input_resolution=256) -> None: |
|
super().__init__() |
|
|
|
assert in_chans == 4 * out_chans |
|
assert num_feat == 2 * out_chans |
|
self.input_resolution = input_resolution |
|
|
|
|
|
self.upscale = 4 |
|
|
|
self.conv_after_body = RodinRollOutConv3D_GroupConv( |
|
in_chans, in_chans, 3, 1, 1) |
|
self.conv_before_upsample = nn.Sequential( |
|
RodinRollOutConv3D_GroupConv(in_chans, num_feat, 3, 1, 1), |
|
nn.LeakyReLU(inplace=True)) |
|
self.conv_up1 = RodinRollOutConv3D_GroupConv(num_feat, num_feat, 3, 1, |
|
1) |
|
if self.upscale == 4: |
|
self.conv_up2 = RodinRollOutConv3D_GroupConv( |
|
num_feat, num_feat, 3, 1, 1) |
|
self.conv_hr = RodinRollOutConv3D_GroupConv(num_feat, num_feat, 3, 1, |
|
1) |
|
self.conv_last = RodinRollOutConv3D_GroupConv(num_feat, out_chans, 3, |
|
1, 1) |
|
|
|
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) |
|
|
|
def forward(self, x): |
|
|
|
|
|
B, C3, p, p = x.shape |
|
C = C3 // 3 |
|
"""following forward code copied from lite4x version |
|
""" |
|
x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, |
|
p) |
|
|
|
|
|
x = self.conv_after_body(x) + x |
|
x = self.conv_before_upsample(x) |
|
x = self.lrelu( |
|
self.conv_up1( |
|
torch.nn.functional.interpolate( |
|
x, |
|
scale_factor=2, |
|
mode='nearest', |
|
|
|
|
|
))) |
|
if self.upscale == 4: |
|
x = self.lrelu( |
|
self.conv_up2( |
|
torch.nn.functional.interpolate( |
|
x, |
|
scale_factor=2, |
|
mode='nearest', |
|
|
|
|
|
))) |
|
x = self.conv_last(self.lrelu(self.conv_hr(x))) |
|
|
|
assert x.shape[-1] == self.input_resolution |
|
|
|
return x |
|
|
|
|
|
class RodinConv3D4X_lite_improved_lint_withresidual(nn.Module): |
|
|
|
def __init__(self, |
|
in_chans, |
|
num_feat, |
|
out_chans, |
|
input_resolution=256) -> None: |
|
super().__init__() |
|
|
|
assert in_chans == 4 * out_chans |
|
assert num_feat == 2 * out_chans |
|
self.input_resolution = input_resolution |
|
|
|
|
|
self.upscale = 4 |
|
|
|
self.conv_after_body = RodinRollOutConv3D_GroupConv( |
|
in_chans, in_chans, 3, 1, 1) |
|
self.conv_before_upsample = nn.Sequential( |
|
RodinRollOutConv3D_GroupConv(in_chans, num_feat, 3, 1, 1), |
|
nn.LeakyReLU(inplace=True)) |
|
self.conv_up1 = RodinRollOutConv3D_GroupConv(num_feat, num_feat, 3, 1, |
|
1) |
|
if self.upscale == 4: |
|
self.conv_up2 = RodinRollOutConv3D_GroupConv( |
|
num_feat, num_feat, 3, 1, 1) |
|
self.conv_hr = RodinRollOutConv3D_GroupConv(num_feat, num_feat, 3, 1, |
|
1) |
|
self.conv_last = RodinRollOutConv3D_GroupConv(num_feat, out_chans, 3, |
|
1, 1) |
|
|
|
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) |
|
|
|
def forward(self, x): |
|
|
|
|
|
B, C3, p, p = x.shape |
|
C = C3 // 3 |
|
"""following forward code copied from lite4x version |
|
""" |
|
x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, |
|
p) |
|
|
|
|
|
x = self.conv_after_body(x) + x |
|
x = self.conv_before_upsample(x) |
|
x = self.lrelu( |
|
self.conv_up1( |
|
torch.nn.functional.interpolate( |
|
x, |
|
scale_factor=2, |
|
mode='nearest', |
|
|
|
|
|
))) |
|
if self.upscale == 4: |
|
x = self.lrelu( |
|
self.conv_up2( |
|
torch.nn.functional.interpolate( |
|
x, |
|
scale_factor=2, |
|
mode='nearest', |
|
|
|
|
|
))) |
|
x = self.conv_last(self.lrelu(self.conv_hr(x) + x)) |
|
|
|
assert x.shape[-1] == self.input_resolution |
|
|
|
return x |
|
|
|
|
|
class RodinRollOutConv3DSR_FlexibleChannels(nn.Module): |
|
|
|
def __init__(self, |
|
in_chans, |
|
num_out_ch=96, |
|
input_resolution=256, |
|
**kwargs) -> None: |
|
super().__init__() |
|
|
|
self.block0 = RodinConv3D_SynthesisLayer(in_chans, |
|
num_out_ch) |
|
self.block1 = RodinConv3D_SynthesisLayer(num_out_ch, num_out_ch) |
|
|
|
self.input_resolution = input_resolution |
|
|
|
def forward(self, x): |
|
|
|
B, C3, p, p = x.shape |
|
C = C3 // 3 |
|
|
|
|
|
x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, |
|
p) |
|
|
|
if x.shape[-1] != self.input_resolution: |
|
x = torch.nn.functional.interpolate(x, |
|
size=(self.input_resolution, |
|
self.input_resolution), |
|
mode='bilinear', |
|
align_corners=False, |
|
antialias=True) |
|
|
|
x = self.block0(x) |
|
x = self.block1(x) |
|
|
|
return x |
|
|
|
|
|
|
|
class RodinRollOutConv3DSR4X(nn.Module): |
|
|
|
|
|
def __init__(self, in_chans, **kwargs) -> None: |
|
super().__init__() |
|
|
|
|
|
|
|
self.block0 = RodinConv3D_SynthesisLayer(in_chans, 96) |
|
self.block1 = RodinConv3D_SynthesisLayer( |
|
96, 96) |
|
|
|
self.input_resolution = 64 |
|
|
|
def forward(self, x): |
|
|
|
B, C3, p, p = x.shape |
|
C = C3 // 3 |
|
|
|
|
|
x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, |
|
p) |
|
|
|
if x.shape[-1] != self.input_resolution: |
|
x = torch.nn.functional.interpolate(x, |
|
size=(self.input_resolution, |
|
self.input_resolution), |
|
mode='bilinear', |
|
align_corners=False, |
|
antialias=True) |
|
|
|
x = self.block0(x) |
|
x = self.block1(x) |
|
|
|
return x |
|
|
|
|
|
class Upsample3D(nn.Module): |
|
"""Upsample module. |
|
Args: |
|
        scale (int): Scale factor. Supported scales: powers of two (2^n).
|
num_feat (int): Channel number of intermediate features. |
|
""" |
|
|
|
def __init__(self, scale, num_feat): |
|
super().__init__() |
|
|
|
m_convs = [] |
|
m_pixelshuffle = [] |
|
|
|
assert (scale & (scale - 1)) == 0, 'scale = 2^n' |
|
self.scale = scale |
|
|
|
for _ in range(int(math.log(scale, 2))): |
|
m_convs.append( |
|
RodinRollOutConv3D_GroupConv(num_feat, 4 * num_feat, 3, 1, 1)) |
|
m_pixelshuffle.append(nn.PixelShuffle(2)) |
|
|
|
self.m_convs = nn.ModuleList(m_convs) |
|
self.m_pixelshuffle = nn.ModuleList(m_pixelshuffle) |
|
|
|
|
|
def forward(self, x): |
|
for scale_idx in range(int(math.log(self.scale, 2))): |
|
x = self.m_convs[scale_idx](x) |
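            # Move the plane dim into the batch so PixelShuffle(2) upsamples
            # each plane independently, then fold the planes back into channels.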
|
|
|
|
|
x = x.reshape(x.shape[0] * 3, x.shape[1] // 3, *x.shape[2:]) |
|
x = self.m_pixelshuffle[scale_idx](x) |
|
x = x.reshape(x.shape[0] // 3, x.shape[1] * 3, *x.shape[2:]) |
|
|
|
return x |
|
|
|
|
|
class RodinConv3DPixelUnshuffleUpsample(nn.Module): |
|
|
|
def __init__(self, |
|
output_dim, |
|
num_feat=32 * 6, |
|
num_out_ch=32 * 3, |
|
sr_ratio=4, |
|
*args, |
|
**kwargs) -> None: |
|
super().__init__() |
|
|
|
self.conv_after_body = RodinRollOutConv3D_GroupConv( |
|
output_dim, output_dim, 3, 1, 1) |
|
self.conv_before_upsample = nn.Sequential( |
|
RodinRollOutConv3D_GroupConv(output_dim, num_feat, 3, 1, 1), |
|
nn.LeakyReLU(inplace=True)) |
|
self.upsample = Upsample3D(sr_ratio, num_feat) |
|
self.conv_last = RodinRollOutConv3D_GroupConv(num_feat, num_out_ch, 3, |
|
1, 1) |
|
|
|
|
|
def forward(self, x, input_skip_connection=True, *args, **kwargs): |
|
|
|
if input_skip_connection: |
|
x = self.conv_after_body(x) + x |
|
else: |
|
x = self.conv_after_body(x) |
|
|
|
x = self.conv_before_upsample(x) |
|
x = self.upsample(x) |
|
x = self.conv_last(x) |
|
return x |
|
|
|
|
|
class RodinConv3DPixelUnshuffleUpsample_improvedVersion(nn.Module): |
|
|
|
def __init__( |
|
self, |
|
output_dim, |
|
num_out_ch=32 * 3, |
|
sr_ratio=4, |
|
input_resolution=256, |
|
) -> None: |
|
super().__init__() |
|
|
|
self.input_resolution = input_resolution |
|
|
|
|
|
|
|
self.upsample = Upsample3D(sr_ratio, output_dim) |
|
self.conv_last = RodinRollOutConv3D_GroupConv(output_dim, num_out_ch, |
|
3, 1, 1) |
|
|
|
def forward(self, x, bilinear_upsample=True): |
|
|
|
B, C3, p, p = x.shape |
|
C = C3 // 3 |
|
group_size = C3 // C |
|
|
|
assert group_size == 3, 'designed for triplane here' |
|
|
|
x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, |
|
p) |
|
|
|
if bilinear_upsample and x.shape[-1] != self.input_resolution: |
|
x_bilinear_upsample = torch.nn.functional.interpolate( |
|
x, |
|
size=(self.input_resolution, self.input_resolution), |
|
mode='bilinear', |
|
align_corners=False, |
|
antialias=True) |
|
x = self.upsample(x) + x_bilinear_upsample |
|
else: |
|
|
|
x = self.upsample(x) |
|
|
|
x = self.conv_last(x) |
|
|
|
return x |
|
|
|
|
|
class RodinConv3DPixelUnshuffleUpsample_improvedVersion2(nn.Module): |
|
"""removed nearest neighbour residual conenctions, add a conv layer residual conenction |
|
""" |
|
|
|
def __init__( |
|
self, |
|
output_dim, |
|
num_out_ch=32 * 3, |
|
sr_ratio=4, |
|
input_resolution=256, |
|
) -> None: |
|
super().__init__() |
|
|
|
self.input_resolution = input_resolution |
|
|
|
self.conv_after_body = RodinRollOutConv3D_GroupConv( |
|
output_dim, num_out_ch, 3, 1, 1) |
|
self.upsample = Upsample3D(sr_ratio, output_dim) |
|
self.conv_last = RodinRollOutConv3D_GroupConv(output_dim, num_out_ch, |
|
3, 1, 1) |
|
|
|
def forward(self, x, input_skip_connection=True): |
|
|
|
B, C3, p, p = x.shape |
|
C = C3 // 3 |
|
group_size = C3 // C |
|
|
|
assert group_size == 3, 'designed for triplane here' |
|
|
|
x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, |
|
p) |
|
|
|
if input_skip_connection: |
|
x = self.conv_after_body(x) + x |
|
else: |
|
x = self.conv_after_body(x) |
|
|
|
x = self.upsample(x) |
|
x = self.conv_last(x) |
|
|
|
return x |
|
|
|
|
|
class CLSCrossAttentionBlock(nn.Module): |
|
|
|
def __init__(self, |
|
dim, |
|
num_heads, |
|
mlp_ratio=4., |
|
qkv_bias=False, |
|
qk_scale=None, |
|
drop=0., |
|
attn_drop=0., |
|
drop_path=0., |
|
act_layer=nn.GELU, |
|
norm_layer=nn.LayerNorm, |
|
has_mlp=False): |
|
super().__init__() |
|
self.norm1 = norm_layer(dim) |
|
self.attn = CrossAttention(dim, |
|
num_heads=num_heads, |
|
qkv_bias=qkv_bias, |
|
qk_scale=qk_scale, |
|
attn_drop=attn_drop, |
|
proj_drop=drop) |
|
|
|
self.drop_path = DropPath( |
|
drop_path) if drop_path > 0. else nn.Identity() |
|
self.has_mlp = has_mlp |
|
if has_mlp: |
|
self.norm2 = norm_layer(dim) |
|
mlp_hidden_dim = int(dim * mlp_ratio) |
|
self.mlp = Mlp(in_features=dim, |
|
hidden_features=mlp_hidden_dim, |
|
act_layer=act_layer, |
|
drop=drop) |
|
|
|
def forward(self, x): |
|
x = x[:, 0:1, ...] + self.drop_path(self.attn(self.norm1(x))) |
|
if self.has_mlp: |
|
x = x + self.drop_path(self.mlp(self.norm2(x))) |
|
|
|
return x |
|
|
|
|
|
class Conv3DCrossAttentionBlock(nn.Module): |
|
|
|
def __init__(self, |
|
dim, |
|
num_heads, |
|
mlp_ratio=4., |
|
qkv_bias=False, |
|
qk_scale=None, |
|
drop=0., |
|
attn_drop=0., |
|
drop_path=0., |
|
act_layer=nn.GELU, |
|
norm_layer=nn.LayerNorm, |
|
has_mlp=False): |
|
super().__init__() |
|
self.norm1 = norm_layer(dim) |
|
self.attn = Conv3D_Aware_CrossAttention(dim, |
|
num_heads=num_heads, |
|
qkv_bias=qkv_bias, |
|
qk_scale=qk_scale, |
|
attn_drop=attn_drop, |
|
proj_drop=drop) |
|
|
|
self.drop_path = DropPath( |
|
drop_path) if drop_path > 0. else nn.Identity() |
|
self.has_mlp = has_mlp |
|
if has_mlp: |
|
self.norm2 = norm_layer(dim) |
|
mlp_hidden_dim = int(dim * mlp_ratio) |
|
self.mlp = Mlp(in_features=dim, |
|
hidden_features=mlp_hidden_dim, |
|
act_layer=act_layer, |
|
drop=drop) |
|
|
|
def forward(self, x): |
|
x = x + self.drop_path(self.attn(self.norm1(x))) |
|
if self.has_mlp: |
|
x = x + self.drop_path(self.mlp(self.norm2(x))) |
|
|
|
return x |
|
|
|
|
|
class Conv3DCrossAttentionBlockXformerMHA(Conv3DCrossAttentionBlock): |
|
|
|
def __init__(self, |
|
dim, |
|
num_heads, |
|
mlp_ratio=4, |
|
qkv_bias=False, |
|
qk_scale=None, |
|
drop=0, |
|
attn_drop=0, |
|
drop_path=0, |
|
act_layer=nn.GELU, |
|
norm_layer=nn.LayerNorm, |
|
has_mlp=False): |
|
super().__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, |
|
attn_drop, drop_path, act_layer, norm_layer, has_mlp) |
|
|
|
self.attn = xformer_Conv3D_Aware_CrossAttention_xygrid( |
|
dim, |
|
num_heads=num_heads, |
|
qkv_bias=qkv_bias, |
|
qk_scale=qk_scale, |
|
attn_drop=attn_drop, |
|
proj_drop=drop) |
|
|
|
|
|
class Conv3DCrossAttentionBlockXformerMHANested( |
|
Conv3DCrossAttentionBlockXformerMHA): |
|
|
|
def __init__(self, |
|
dim, |
|
num_heads, |
|
mlp_ratio=4, |
|
qkv_bias=False, |
|
qk_scale=None, |
|
drop=0., |
|
attn_drop=0., |
|
drop_path=0., |
|
act_layer=nn.GELU, |
|
norm_layer=nn.LayerNorm, |
|
has_mlp=False): |
|
super().__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, |
|
attn_drop, drop_path, act_layer, norm_layer, has_mlp) |
|
"""for in-place replaing the internal attn in Dino ViT. |
|
""" |
|
|
|
def forward(self, x): |
|
Bx3, N, C = x.shape |
|
B, group_size = Bx3 // 3, 3 |
|
x = x.reshape(B, group_size, N, C) |
|
x = super().forward(x) |
|
return x.reshape(B * group_size, N, |
|
C) |
|
|
|
|
|
class Conv3DCrossAttentionBlockXformerMHANested_withinC( |
|
Conv3DCrossAttentionBlockXformerMHANested): |
|
|
|
def __init__(self, |
|
dim, |
|
num_heads, |
|
mlp_ratio=4, |
|
qkv_bias=False, |
|
qk_scale=None, |
|
drop=0, |
|
attn_drop=0, |
|
drop_path=0, |
|
act_layer=nn.GELU, |
|
norm_layer=nn.LayerNorm, |
|
has_mlp=False): |
|
super().__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, |
|
attn_drop, drop_path, act_layer, norm_layer, has_mlp) |
|
self.attn = xformer_Conv3D_Aware_CrossAttention_xygrid_withinC( |
|
dim, |
|
num_heads=num_heads, |
|
qkv_bias=qkv_bias, |
|
qk_scale=qk_scale, |
|
attn_drop=attn_drop, |
|
proj_drop=drop) |
|
|
|
def forward(self, x): |
|
|
|
x = x + self.drop_path(self.attn(self.norm1(x))) |
|
if self.has_mlp: |
|
x = x + self.drop_path(self.mlp(self.norm2(x))) |
|
|
|
return x |
|
|
|
|
|
class TriplaneFusionBlock(nn.Module): |
|
"""4 ViT blocks + 1 CrossAttentionBlock |
|
""" |
|
|
|
def __init__(self, |
|
vit_blks, |
|
num_heads, |
|
embed_dim, |
|
use_fusion_blk=True, |
|
cross_attention_blk=CLSCrossAttentionBlock, |
|
*args, |
|
**kwargs) -> None: |
|
super().__init__(*args, **kwargs) |
|
|
|
self.num_branches = 3 |
|
self.vit_blks = vit_blks |
|
|
|
if use_fusion_blk: |
|
self.fusion = nn.ModuleList() |
|
|
|
|
|
nh = num_heads |
|
dim = embed_dim |
|
|
|
mlp_ratio = 4 |
|
qkv_bias = True |
|
norm_layer = partial(nn.LayerNorm, eps=1e-6) |
|
drop_path_rate = 0.3 |
|
attn_drop = proj_drop = 0.0 |
|
qk_scale = None |
|
|
|
for d in range(self.num_branches): |
|
self.fusion.append( |
|
cross_attention_blk( |
|
dim=dim, |
|
num_heads=nh, |
|
mlp_ratio=mlp_ratio, |
|
qkv_bias=qkv_bias, |
|
qk_scale=qk_scale, |
|
|
|
drop=proj_drop, |
|
attn_drop=attn_drop, |
|
drop_path=drop_path_rate, |
|
norm_layer=norm_layer, |
|
has_mlp=False)) |
|
else: |
|
self.fusion = None |
|
|
|
def forward(self, x): |
|
|
|
"""x: B 3 N C, where N = H*W tokens |
|
""" |
|
|
|
|
|
|
|
|
|
B, group_size, N, C = x.shape |
|
assert group_size == 3, 'triplane' |
|
x = x.view(B * group_size, N, C) |
|
|
|
for blk in self.vit_blks: |
|
x = blk(x) |
|
|
|
if self.fusion is None: |
|
return x.view(B, group_size, N, C) |
|
|
|
|
|
|
|
|
|
|
|
outs_b = x.chunk(chunks=3, |
|
dim=0) |
|
|
|
|
|
proj_cls_token = [x[:, 0:1] for x in outs_b] |
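        # Cross-plane fusion: each branch's CLS token attends over the patch
        # tokens of the next plane, then is re-attached to its own patch tokens.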
|
|
|
outs = [] |
|
for i in range(self.num_branches): |
|
tmp = torch.cat( |
|
(proj_cls_token[i], outs_b[(i + 1) % self.num_branches][:, 1:, |
|
...]), |
|
dim=1) |
|
tmp = self.fusion[i](tmp) |
|
|
|
reverted_proj_cls_token = tmp[:, 0:1, ...] |
|
tmp = torch.cat((reverted_proj_cls_token, outs_b[i][:, 1:, ...]), |
|
dim=1) |
|
outs.append(tmp) |
|
|
|
outs = torch.stack(outs, 1) |
|
return outs |
|
|
|
|
|
class TriplaneFusionBlockv2(nn.Module): |
|
"""4 ViT blocks + 1 CrossAttentionBlock |
|
""" |
|
|
|
def __init__(self, |
|
vit_blks, |
|
num_heads, |
|
embed_dim, |
|
use_fusion_blk=True, |
|
fusion_ca_blk=Conv3DCrossAttentionBlock, |
|
*args, |
|
**kwargs) -> None: |
|
super().__init__(*args, **kwargs) |
|
|
|
self.num_branches = 3 |
|
self.vit_blks = vit_blks |
|
|
|
if use_fusion_blk: |
|
|
|
|
|
|
|
nh = num_heads |
|
dim = embed_dim |
|
|
|
mlp_ratio = 4 |
|
qkv_bias = True |
|
norm_layer = partial(nn.LayerNorm, eps=1e-6) |
|
drop_path_rate = 0.3 |
|
attn_drop = proj_drop = 0.0 |
|
qk_scale = None |
|
|
|
|
|
self.fusion = fusion_ca_blk( |
|
dim=dim, |
|
num_heads=nh, |
|
mlp_ratio=mlp_ratio, |
|
qkv_bias=qkv_bias, |
|
qk_scale=qk_scale, |
|
|
|
drop=proj_drop, |
|
attn_drop=attn_drop, |
|
drop_path=drop_path_rate, |
|
norm_layer=norm_layer, |
|
has_mlp=False) |
|
else: |
|
self.fusion = None |
|
|
|
def forward(self, x): |
|
|
|
"""x: B 3 N C, where N = H*W tokens |
|
""" |
|
|
|
|
|
|
|
|
|
B, group_size, N, C = x.shape |
|
assert group_size == 3, 'triplane' |
|
x = x.reshape(B * group_size, N, C) |
|
|
|
for blk in self.vit_blks: |
|
x = blk(x) |
|
|
|
if self.fusion is None: |
|
return x.reshape(B, group_size, N, C) |
|
|
|
x = x.reshape(B, group_size, N, C) |
|
|
|
return self.fusion(x) |
|
|
|
|
|
class TriplaneFusionBlockv3(TriplaneFusionBlockv2): |
|
|
|
def __init__(self, |
|
vit_blks, |
|
num_heads, |
|
embed_dim, |
|
use_fusion_blk=True, |
|
fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHA, |
|
*args, |
|
**kwargs) -> None: |
|
super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk, |
|
fusion_ca_blk, *args, **kwargs) |
|
|
|
|
|
class TriplaneFusionBlockv4(TriplaneFusionBlockv3): |
|
|
|
def __init__(self, |
|
vit_blks, |
|
num_heads, |
|
embed_dim, |
|
use_fusion_blk=True, |
|
fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHA, |
|
*args, |
|
**kwargs) -> None: |
|
super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk, |
|
fusion_ca_blk, *args, **kwargs) |
|
"""OOM? directly replace the atten here |
|
""" |
|
|
|
assert len(vit_blks) == 2 |
|
|
|
del self.vit_blks[1].attn, self.vit_blks[1].ls1, self.vit_blks[1].norm1 |
|
|
|
def ffn_residual_func(self, tx_blk, x: Tensor) -> Tensor: |
|
return tx_blk.ls2( |
|
tx_blk.mlp(tx_blk.norm2(x)) |
|
) |
|
|
|
def forward(self, x): |
|
"""x: B 3 N C, where N = H*W tokens |
|
""" |
|
assert self.fusion is not None |
|
|
|
B, group_size, N, C = x.shape |
|
x = x.reshape(B * group_size, N, C) |
|
|
|
|
|
x = self.vit_blks[0](x) |
|
|
|
|
|
x = x + self.fusion(x.reshape(B, group_size, N, C)).reshape( |
|
B * group_size, N, C) |
|
x = x + self.ffn_residual_func(self.vit_blks[1], x) |
|
return x.reshape(B, group_size, N, C) |
|
|
|
|
|
class TriplaneFusionBlockv4_nested(nn.Module): |
|
|
|
def __init__(self, |
|
vit_blks, |
|
num_heads, |
|
embed_dim, |
|
use_fusion_blk=True, |
|
fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, |
|
*args, |
|
**kwargs) -> None: |
|
super().__init__() |
|
|
|
self.num_branches = 3 |
|
self.vit_blks = vit_blks |
|
|
|
assert use_fusion_blk |
|
|
|
assert len(vit_blks) == 2 |
|
|
|
|
|
        del self.vit_blks[1].attn
|
|
|
|
|
nh = num_heads |
|
dim = embed_dim |
|
|
|
mlp_ratio = 4 |
|
qkv_bias = True |
|
norm_layer = partial(nn.LayerNorm, eps=1e-6) |
|
drop_path_rate = 0.3 |
|
attn_drop = proj_drop = 0.0 |
|
qk_scale = None |
|
|
|
self.vit_blks[1].attn = fusion_ca_blk( |
|
dim=dim, |
|
num_heads=nh, |
|
mlp_ratio=mlp_ratio, |
|
qkv_bias=qkv_bias, |
|
qk_scale=qk_scale, |
|
|
|
drop=proj_drop, |
|
attn_drop=attn_drop, |
|
drop_path=drop_path_rate, |
|
norm_layer=norm_layer, |
|
has_mlp=False) |
|
|
|
def forward(self, x): |
|
"""x: B 3 N C, where N = H*W tokens |
|
""" |
|
|
|
|
|
|
|
|
|
B, group_size, N, C = x.shape |
|
assert group_size == 3, 'triplane' |
|
x = x.reshape(B * group_size, N, C) |
|
|
|
for blk in self.vit_blks: |
|
x = blk(x) |
|
|
|
|
|
return x.reshape(B, group_size, N, C) |
|
|
|
|
|
class TriplaneFusionBlockv4_nested_init_from_dino(nn.Module): |
|
|
|
def __init__(self, |
|
vit_blks, |
|
num_heads, |
|
embed_dim, |
|
use_fusion_blk=True, |
|
fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, |
|
init_from_dino=True, |
|
*args, |
|
**kwargs) -> None: |
|
super().__init__() |
|
|
|
self.num_branches = 3 |
|
self.vit_blks = vit_blks |
|
|
|
assert use_fusion_blk |
|
|
|
assert len(vit_blks) == 2 |
|
|
|
|
|
nh = num_heads |
|
dim = embed_dim |
|
|
|
mlp_ratio = 4 |
|
qkv_bias = True |
|
norm_layer = partial(nn.LayerNorm, eps=1e-6) |
|
drop_path_rate = 0.3 |
|
attn_drop = proj_drop = 0.0 |
|
qk_scale = None |
|
|
|
attn_3d = fusion_ca_blk( |
|
dim=dim, |
|
num_heads=nh, |
|
mlp_ratio=mlp_ratio, |
|
qkv_bias=qkv_bias, |
|
qk_scale=qk_scale, |
|
|
|
drop=proj_drop, |
|
attn_drop=attn_drop, |
|
drop_path=drop_path_rate, |
|
norm_layer=norm_layer, |
|
has_mlp=False) |
|
|
|
|
|
if init_from_dino: |
|
merged_qkv_linear = self.vit_blks[1].attn.qkv |
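            # DINO's qkv Linear stacks [q; k; v] along the output dim, so the
            # first `dim` rows initialise wq and the remaining 2*dim rows
            # initialise the fused w_kv of the 3D-aware attention.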
|
attn_3d.attn.proj.load_state_dict( |
|
self.vit_blks[1].attn.proj.state_dict()) |
|
|
|
|
|
            attn_3d.attn.wq.weight.data = merged_qkv_linear.weight.data[:dim, :]
            attn_3d.attn.w_kv.weight.data = merged_qkv_linear.weight.data[dim:, :]
|
|
|
|
|
if qkv_bias: |
|
attn_3d.attn.wq.bias.data = merged_qkv_linear.bias.data[:dim] |
|
attn_3d.attn.w_kv.bias.data = merged_qkv_linear.bias.data[dim:] |
|
|
|
del self.vit_blks[1].attn |
|
|
|
self.vit_blks[1].attn = attn_3d |
|
|
|
def forward(self, x): |
|
"""x: B 3 N C, where N = H*W tokens |
|
""" |
|
|
|
|
|
|
|
|
|
B, group_size, N, C = x.shape |
|
assert group_size == 3, 'triplane' |
|
x = x.reshape(B * group_size, N, C) |
|
|
|
for blk in self.vit_blks: |
|
x = blk(x) |
|
|
|
|
|
return x.reshape(B, group_size, N, C) |
|
|
|
|
|
class TriplaneFusionBlockv4_nested_init_from_dino_lite(nn.Module): |
|
|
|
def __init__(self, |
|
vit_blks, |
|
num_heads, |
|
embed_dim, |
|
use_fusion_blk=True, |
|
fusion_ca_blk=None, |
|
*args, |
|
**kwargs) -> None: |
|
super().__init__() |
|
|
|
self.num_branches = 3 |
|
self.vit_blks = vit_blks |
|
|
|
assert use_fusion_blk |
|
|
|
assert len(vit_blks) == 2 |
|
|
|
|
|
nh = num_heads |
|
dim = embed_dim |
|
|
|
mlp_ratio = 4 |
|
qkv_bias = True |
|
norm_layer = partial(nn.LayerNorm, eps=1e-6) |
|
drop_path_rate = 0.3 |
|
attn_drop = proj_drop = 0.0 |
|
qk_scale = None |
|
|
|
attn_3d = xformer_Conv3D_Aware_CrossAttention_xygrid_withinC( |
|
dim, |
|
num_heads=num_heads, |
|
qkv_bias=qkv_bias, |
|
qk_scale=qk_scale, |
|
attn_drop=attn_drop, |
|
proj_drop=proj_drop) |
|
|
|
del self.vit_blks[1].attn |
|
|
|
self.vit_blks[1].attn = attn_3d |
|
|
|
def forward(self, x): |
|
"""x: B N C, where N = H*W tokens. Just raw ViT forward pass |
|
""" |
|
|
|
|
|
B, N, C = x.shape |
|
|
|
for blk in self.vit_blks: |
|
x = blk(x) |
|
|
|
return x |
|
|
|
class TriplaneFusionBlockv4_nested_init_from_dino_lite_merge(nn.Module): |
|
|
|
def __init__(self, |
|
vit_blks, |
|
num_heads, |
|
embed_dim, |
|
use_fusion_blk=True, |
|
fusion_ca_blk=None, |
|
*args, |
|
**kwargs) -> None: |
|
super().__init__() |
|
|
|
self.vit_blks = vit_blks |
|
|
|
assert use_fusion_blk |
|
assert len(vit_blks) == 2 |
|
|
|
|
|
nh = num_heads |
|
dim = embed_dim |
|
qkv_bias = True |
|
attn_drop = proj_drop = 0.0 |
|
qk_scale = None |
|
|
|
if False: |
|
for blk in self.vit_blks: |
|
attn_3d = xformer_Conv3D_Aware_CrossAttention_xygrid_withinC( |
|
dim, |
|
num_heads=num_heads, |
|
qkv_bias=qkv_bias, |
|
qk_scale=qk_scale, |
|
attn_drop=attn_drop, |
|
proj_drop=proj_drop) |
|
blk.attn = self_cross_attn(blk.attn, attn_3d) |
|
|
|
def forward(self, x): |
|
"""x: B N C, where N = H*W tokens. Just raw ViT forward pass |
|
""" |
|
|
|
|
|
B, N, C = x.shape |
|
|
|
for blk in self.vit_blks: |
|
x = blk(x) |
|
|
|
return x |
|
|
|
class TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C(TriplaneFusionBlockv4_nested_init_from_dino_lite_merge): |
|
|
|
def __init__(self, vit_blks, num_heads, embed_dim, use_fusion_blk=True, fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, init_from_dino=True, *args, **kwargs) -> None: |
|
super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk, fusion_ca_blk, init_from_dino, *args, **kwargs) |
|
|
|
|
|
def forward(self, x): |
|
"""x: B 3 N C, where N = H*W tokens |
|
""" |
|
|
|
|
|
|
|
|
|
B, group_size, N, C = x.shape |
|
x = x.reshape(B, group_size*N, C) |
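        # Concatenate the tokens of all three planes along the sequence so the
        # standard self-attention in vit_blks attends across planes.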
|
|
|
for blk in self.vit_blks: |
|
x = blk(x) |
|
|
|
x = x.reshape(B, group_size, N, C) |
|
|
|
return x |
|
|
|
class TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C_withrollout(TriplaneFusionBlockv4_nested_init_from_dino_lite_merge): |
|
|
|
def __init__(self, vit_blks, num_heads, embed_dim, use_fusion_blk=True, fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, init_from_dino=True, *args, **kwargs) -> None: |
|
super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk, fusion_ca_blk, init_from_dino, *args, **kwargs) |
|
|
|
|
|
def forward(self, x): |
|
"""x: B 3 N C, where N = H*W tokens |
|
""" |
|
|
|
|
|
|
|
|
|
B, group_size, N, C = x.shape |
|
x = x.reshape(B*group_size, N, C) |
|
x = self.vit_blks[0](x) |
|
|
|
x = x.reshape(B,group_size*N, C) |
|
x = self.vit_blks[1](x) |
|
|
|
x = x.reshape(B, group_size, N, C) |
|
|
|
return x |
|
|
|
|
|
class TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_add3DAttn(TriplaneFusionBlockv4_nested_init_from_dino): |
|
|
|
def __init__(self, vit_blks, num_heads, embed_dim, use_fusion_blk=True, fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, init_from_dino=True, *args, **kwargs) -> None: |
|
super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk, fusion_ca_blk, init_from_dino, *args, **kwargs) |
|
|
|
|
|
def forward(self, x): |
|
"""x: B 3 N C, where N = H*W tokens |
|
""" |
|
|
|
B, group_size, N, C = x.shape |
|
x = x.reshape(B, group_size*N, C) |
|
x = self.vit_blks[0](x) |
|
|
|
|
|
x = x.reshape(B, group_size, N, C).reshape(B*group_size, N, C) |
|
x = self.vit_blks[1](x) |
|
return x.reshape(B, group_size, N, C) |
|
|
|
|
|
|
|
|
class TriplaneFusionBlockv5_ldm_addCA(nn.Module): |
|
|
|
def __init__(self, |
|
vit_blks, |
|
num_heads, |
|
embed_dim, |
|
use_fusion_blk=True, |
|
fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, |
|
*args, |
|
**kwargs) -> None: |
|
super().__init__() |
|
|
|
self.num_branches = 3 |
|
self.vit_blks = vit_blks |
|
|
|
assert use_fusion_blk |
|
|
|
assert len(vit_blks) == 2 |
|
|
|
|
|
|
|
|
|
self.norm_for_atten_3d = deepcopy(self.vit_blks[1].norm1) |
|
|
|
|
|
nh = num_heads |
|
dim = embed_dim |
|
|
|
mlp_ratio = 4 |
|
qkv_bias = True |
|
norm_layer = partial(nn.LayerNorm, eps=1e-6) |
|
drop_path_rate = 0.3 |
|
attn_drop = proj_drop = 0.0 |
|
qk_scale = None |
|
|
|
self.attn_3d = xformer_Conv3D_Aware_CrossAttention_xygrid( |
|
dim, |
|
num_heads=num_heads, |
|
qkv_bias=qkv_bias, |
|
qk_scale=qk_scale, |
|
attn_drop=attn_drop, |
|
proj_drop=proj_drop) |
|
|
|
def forward(self, x): |
|
"""x: B 3 N C, where N = H*W tokens |
|
""" |
|
|
|
|
|
|
|
|
|
B, group_size, N, C = x.shape |
|
assert group_size == 3, 'triplane' |
|
|
|
flatten_token = lambda x: x.reshape(B * group_size, N, C) |
|
unflatten_token = lambda x: x.reshape(B, group_size, N, C) |
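        # ViT block 0 -> residual 3D-aware cross-plane attention -> ViT block 1.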
|
|
|
x = flatten_token(x) |
|
x = self.vit_blks[0](x) |
|
|
|
x = unflatten_token(x) |
|
x = self.attn_3d(self.norm_for_atten_3d(x)) + x |
|
|
|
x = flatten_token(x) |
|
x = self.vit_blks[1](x) |
|
|
|
return unflatten_token(x) |
|
|
|
|
|
class TriplaneFusionBlockv6_ldm_addCA_Init3DAttnfrom2D( |
|
TriplaneFusionBlockv5_ldm_addCA): |
|
|
|
def __init__(self, |
|
vit_blks, |
|
num_heads, |
|
embed_dim, |
|
use_fusion_blk=True, |
|
fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, |
|
*args, |
|
**kwargs) -> None: |
|
super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk, |
|
fusion_ca_blk, *args, **kwargs) |
|
|
|
def forward(self, x): |
|
"""x: B 3 N C, where N = H*W tokens |
|
""" |
|
|
|
|
|
|
|
|
|
B, group_size, N, C = x.shape |
|
assert group_size == 3, 'triplane' |
|
|
|
flatten_token = lambda x: x.reshape(B * group_size, N, C) |
|
unflatten_token = lambda x: x.reshape(B, group_size, N, C) |
|
|
|
x = flatten_token(x) |
|
x = self.vit_blks[0](x) |
|
|
|
x = unflatten_token(x) |
|
x = self.attn_3d(self.norm_for_atten_3d(x)) + x |
|
|
|
x = flatten_token(x) |
|
x = self.vit_blks[1](x) |
|
|
|
return unflatten_token(x) |
|
|
|
|
|
class TriplaneFusionBlockv5_ldm_add_dualCA(nn.Module): |
|
|
|
def __init__(self, |
|
vit_blks, |
|
num_heads, |
|
embed_dim, |
|
use_fusion_blk=True, |
|
fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, |
|
*args, |
|
**kwargs) -> None: |
|
super().__init__() |
|
|
|
self.num_branches = 3 |
|
self.vit_blks = vit_blks |
|
|
|
assert use_fusion_blk |
|
|
|
assert len(vit_blks) == 2 |
|
|
|
|
|
|
|
|
|
self.norm_for_atten_3d_0 = deepcopy(self.vit_blks[0].norm1) |
|
self.norm_for_atten_3d_1 = deepcopy(self.vit_blks[1].norm1) |
|
|
|
|
|
nh = num_heads |
|
dim = embed_dim |
|
|
|
mlp_ratio = 4 |
|
qkv_bias = True |
|
norm_layer = partial(nn.LayerNorm, eps=1e-6) |
|
drop_path_rate = 0.3 |
|
attn_drop = proj_drop = 0.0 |
|
qk_scale = None |
|
|
|
self.attn_3d_0 = xformer_Conv3D_Aware_CrossAttention_xygrid( |
|
dim, |
|
num_heads=num_heads, |
|
qkv_bias=qkv_bias, |
|
qk_scale=qk_scale, |
|
attn_drop=attn_drop, |
|
proj_drop=proj_drop) |
|
|
|
self.attn_3d_1 = deepcopy(self.attn_3d_0) |
|
|
|
def forward(self, x): |
|
"""x: B 3 N C, where N = H*W tokens |
|
""" |
|
|
|
|
|
|
|
|
|
B, group_size, N, C = x.shape |
|
assert group_size == 3, 'triplane' |
|
|
|
flatten_token = lambda x: x.reshape(B * group_size, N, C) |
|
unflatten_token = lambda x: x.reshape(B, group_size, N, C) |
|
|
|
x = flatten_token(x) |
|
x = self.vit_blks[0](x) |
|
|
|
x = unflatten_token(x) |
|
x = self.attn_3d_0(self.norm_for_atten_3d_0(x)) + x |
|
|
|
x = flatten_token(x) |
|
x = self.vit_blks[1](x) |
|
|
|
x = unflatten_token(x) |
|
x = self.attn_3d_1(self.norm_for_atten_3d_1(x)) + x |
|
|
|
return unflatten_token(x) |
|
|
|
|
|
def drop_path(x, drop_prob: float = 0., training: bool = False): |
|
if drop_prob == 0. or not training: |
|
return x |
|
keep_prob = 1 - drop_prob |
|
shape = (x.shape[0], ) + (1, ) * ( |
|
x.ndim - 1) |
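    # Per-sample stochastic depth: the (B, 1, ..., 1) mask drops whole samples
    # and rescales the kept ones by 1 / keep_prob.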
|
random_tensor = keep_prob + torch.rand( |
|
shape, dtype=x.dtype, device=x.device) |
|
random_tensor.floor_() |
|
output = x.div(keep_prob) * random_tensor |
|
return output |
|
|
|
|
|
class DropPath(nn.Module): |
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). |
|
""" |
|
|
|
def __init__(self, drop_prob=None): |
|
super(DropPath, self).__init__() |
|
self.drop_prob = drop_prob |
|
|
|
def forward(self, x): |
|
return drop_path(x, self.drop_prob, self.training) |
|
|
|
|
|
class Mlp(nn.Module): |
|
|
|
def __init__(self, |
|
in_features, |
|
hidden_features=None, |
|
out_features=None, |
|
act_layer=nn.GELU, |
|
drop=0.): |
|
super().__init__() |
|
out_features = out_features or in_features |
|
hidden_features = hidden_features or in_features |
|
self.fc1 = nn.Linear(in_features, hidden_features) |
|
self.act = act_layer() |
|
self.fc2 = nn.Linear(hidden_features, out_features) |
|
self.drop = nn.Dropout(drop) |
|
|
|
def forward(self, x): |
|
x = self.fc1(x) |
|
x = self.act(x) |
|
x = self.drop(x) |
|
x = self.fc2(x) |
|
x = self.drop(x) |
|
return x |
|
|
|
|
|
class Block(nn.Module): |
|
|
|
def __init__(self, |
|
dim, |
|
num_heads, |
|
mlp_ratio=4., |
|
qkv_bias=False, |
|
qk_scale=None, |
|
drop=0., |
|
attn_drop=0., |
|
drop_path=0., |
|
act_layer=nn.GELU, |
|
norm_layer=nn.LayerNorm): |
|
super().__init__() |
|
self.norm1 = norm_layer(dim) |
|
self.attn = Attention(dim, |
|
num_heads=num_heads, |
|
qkv_bias=qkv_bias, |
|
qk_scale=qk_scale, |
|
attn_drop=attn_drop, |
|
proj_drop=drop) |
|
self.drop_path = DropPath( |
|
drop_path) if drop_path > 0. else nn.Identity() |
|
self.norm2 = norm_layer(dim) |
|
mlp_hidden_dim = int(dim * mlp_ratio) |
|
self.mlp = Mlp(in_features=dim, |
|
hidden_features=mlp_hidden_dim, |
|
act_layer=act_layer, |
|
drop=drop) |
|
|
|
def forward(self, x, return_attention=False): |
|
y, attn = self.attn(self.norm1(x)) |
|
if return_attention: |
|
return attn |
|
x = x + self.drop_path(y) |
|
x = x + self.drop_path(self.mlp(self.norm2(x))) |
|
return x |
|
|
|
|
|
class PatchEmbed(nn.Module): |
|
""" Image to Patch Embedding |
|
""" |
|
|
|
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): |
|
super().__init__() |
|
num_patches = (img_size // patch_size) * (img_size // patch_size) |
|
self.img_size = img_size |
|
self.patch_size = patch_size |
|
self.num_patches = num_patches |
|
|
|
self.proj = nn.Conv2d(in_chans, |
|
embed_dim, |
|
kernel_size=patch_size, |
|
stride=patch_size) |
|
|
|
def forward(self, x): |
|
B, C, H, W = x.shape |
|
x = self.proj(x).flatten(2).transpose(1, 2) |
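        # (B, C, H, W) -> (B, num_patches, embed_dim)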
|
return x |
|
|
|
|
|
class VisionTransformer(nn.Module): |
|
""" Vision Transformer """ |
|
|
|
def __init__(self, |
|
img_size=[224], |
|
patch_size=16, |
|
in_chans=3, |
|
num_classes=0, |
|
embed_dim=768, |
|
depth=12, |
|
num_heads=12, |
|
mlp_ratio=4., |
|
qkv_bias=False, |
|
qk_scale=None, |
|
drop_rate=0., |
|
attn_drop_rate=0., |
|
drop_path_rate=0., |
|
norm_layer='nn.LayerNorm', |
|
patch_embedding=True, |
|
cls_token=True, |
|
pixel_unshuffle=False, |
|
**kwargs): |
|
super().__init__() |
|
self.num_features = self.embed_dim = embed_dim |
|
self.patch_size = patch_size |
|
|
|
|
|
norm_layer = partial(nn.LayerNorm, eps=1e-6) |
|
|
|
if patch_embedding: |
|
self.patch_embed = PatchEmbed(img_size=img_size[0], |
|
patch_size=patch_size, |
|
in_chans=in_chans, |
|
embed_dim=embed_dim) |
|
num_patches = self.patch_embed.num_patches |
|
self.img_size = self.patch_embed.img_size |
|
else: |
|
self.patch_embed = None |
|
self.img_size = img_size[0] |
|
num_patches = (img_size[0] // patch_size) * (img_size[0] // |
|
patch_size) |
|
|
|
if cls_token: |
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) |
|
self.pos_embed = nn.Parameter( |
|
torch.zeros(1, num_patches + 1, embed_dim)) |
|
else: |
|
self.cls_token = None |
|
self.pos_embed = nn.Parameter( |
|
torch.zeros(1, num_patches, embed_dim)) |
|
|
|
self.pos_drop = nn.Dropout(p=drop_rate) |
|
|
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth) |
|
] |
|
self.blocks = nn.ModuleList([ |
|
Block(dim=embed_dim, |
|
num_heads=num_heads, |
|
mlp_ratio=mlp_ratio, |
|
qkv_bias=qkv_bias, |
|
qk_scale=qk_scale, |
|
drop=drop_rate, |
|
attn_drop=attn_drop_rate, |
|
drop_path=dpr[i], |
|
norm_layer=norm_layer) for i in range(depth) |
|
]) |
|
self.norm = norm_layer(embed_dim) |
|
|
|
|
|
self.head = nn.Linear( |
|
embed_dim, num_classes) if num_classes > 0 else nn.Identity() |
|
|
|
trunc_normal_(self.pos_embed, std=.02) |
|
if cls_token: |
|
trunc_normal_(self.cls_token, std=.02) |
|
self.apply(self._init_weights) |
|
|
|
|
|
|
|
|
|
|
|
|
|
def _init_weights(self, m): |
|
if isinstance(m, nn.Linear): |
|
trunc_normal_(m.weight, std=.02) |
|
if isinstance(m, nn.Linear) and m.bias is not None: |
|
nn.init.constant_(m.bias, 0) |
|
elif isinstance(m, nn.LayerNorm): |
|
nn.init.constant_(m.bias, 0) |
|
nn.init.constant_(m.weight, 1.0) |
|
|
|
def interpolate_pos_encoding(self, x, w, h): |
|
npatch = x.shape[1] - 1 |
|
N = self.pos_embed.shape[1] - 1 |
|
if npatch == N and w == h: |
|
return self.pos_embed |
|
patch_pos_embed = self.pos_embed[:, 1:] |
|
dim = x.shape[-1] |
|
w0 = w // self.patch_size |
|
h0 = h // self.patch_size |
|
|
|
|
|
w0, h0 = w0 + 0.1, h0 + 0.1 |
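        # Small offset to avoid floating-point rounding issues when computing
        # the interpolation scale factor (as in the original DINO code).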
|
|
|
patch_pos_embed = nn.functional.interpolate( |
|
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), |
|
dim).permute(0, 3, 1, 2), |
|
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), |
|
mode='bicubic', |
|
) |
|
assert int(w0) == patch_pos_embed.shape[-2] and int( |
|
h0) == patch_pos_embed.shape[-1] |
|
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
|
|
if self.cls_token is not None: |
|
class_pos_embed = self.pos_embed[:, 0] |
|
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), |
|
dim=1) |
|
return patch_pos_embed |
|
|
|
def prepare_tokens(self, x): |
|
B, nc, w, h = x.shape |
|
x = self.patch_embed(x) |
|
|
|
|
|
cls_tokens = self.cls_token.expand(B, -1, -1) |
|
x = torch.cat((cls_tokens, x), dim=1) |
|
|
|
|
|
x = x + self.interpolate_pos_encoding(x, w, h) |
|
|
|
return self.pos_drop(x) |
|
|
|
def forward(self, x): |
|
x = self.prepare_tokens(x) |
|
for blk in self.blocks: |
|
x = blk(x) |
|
x = self.norm(x) |
|
return x[:, 1:] |
|
|
|
|
|
def get_last_selfattention(self, x): |
|
x = self.prepare_tokens(x) |
|
for i, blk in enumerate(self.blocks): |
|
if i < len(self.blocks) - 1: |
|
x = blk(x) |
|
else: |
|
|
|
return blk(x, return_attention=True) |
|
|
|
def get_intermediate_layers(self, x, n=1): |
|
x = self.prepare_tokens(x) |
|
|
|
output = [] |
|
for i, blk in enumerate(self.blocks): |
|
x = blk(x) |
|
if len(self.blocks) - i <= n: |
|
output.append(self.norm(x)) |
|
return output |
|
|
|
|
|
def vit_tiny(patch_size=16, **kwargs): |
|
model = VisionTransformer(patch_size=patch_size, |
|
embed_dim=192, |
|
depth=12, |
|
num_heads=3, |
|
mlp_ratio=4, |
|
qkv_bias=True, |
|
norm_layer=partial(nn.LayerNorm, eps=1e-6), |
|
**kwargs) |
|
return model |
|
|
|
|
|
def vit_small(patch_size=16, **kwargs): |
|
model = VisionTransformer( |
|
patch_size=patch_size, |
|
embed_dim=384, |
|
depth=12, |
|
num_heads=6, |
|
mlp_ratio=4, |
|
qkv_bias=True, |
|
norm_layer=partial(nn.LayerNorm, eps=1e-6), |
|
**kwargs) |
|
return model |
|
|
|
|
|
def vit_base(patch_size=16, **kwargs): |
|
model = VisionTransformer(patch_size=patch_size, |
|
embed_dim=768, |
|
depth=12, |
|
num_heads=12, |
|
mlp_ratio=4, |
|
qkv_bias=True, |
|
norm_layer=partial(nn.LayerNorm, eps=1e-6), |
|
**kwargs) |
|
return model |
|
|
|
|
|
vits = vit_small |
|
vitb = vit_base |
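
# Minimal usage sketch (hypothetical example, not part of the original file):
#   model = vit_small(patch_size=16)                # DINO-style ViT-S/16
#   tokens = model(torch.randn(2, 3, 224, 224))     # -> (2, 196, 384) patch tokens (CLS dropped)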
|
|