Update unet/mv_unet.py

unet/mv_unet.py  CHANGED  (+0 −84)
@@ -39,55 +39,6 @@ def get_camera(
     return torch.from_numpy(np.stack(cameras, axis=0)).float()  # [num_frames, 16]
 
 
-def checkpoint(func, inputs, params, flag):
-    """
-    Evaluate a function without caching intermediate activations, allowing for
-    reduced memory at the expense of extra compute in the backward pass.
-    :param func: the function to evaluate.
-    :param inputs: the argument sequence to pass to `func`.
-    :param params: a sequence of parameters `func` depends on but does not
-                   explicitly take as arguments.
-    :param flag: if False, disable gradient checkpointing.
-    """
-    if flag:
-        args = tuple(inputs) + tuple(params)
-        return CheckpointFunction.apply(func, len(inputs), *args)
-    else:
-        return func(*inputs)
-
-
-class CheckpointFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, run_function, length, *args):
-        ctx.run_function = run_function
-        ctx.input_tensors = list(args[:length])
-        ctx.input_params = list(args[length:])
-
-        with torch.no_grad():
-            output_tensors = ctx.run_function(*ctx.input_tensors)
-        return output_tensors
-
-    @staticmethod
-    def backward(ctx, *output_grads):
-        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
-        with torch.enable_grad():
-            # Fixes a bug where the first op in run_function modifies the
-            # Tensor storage in place, which is not allowed for detach()'d
-            # Tensors.
-            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
-            output_tensors = ctx.run_function(*shallow_copies)
-        input_grads = torch.autograd.grad(
-            output_tensors,
-            ctx.input_tensors + ctx.input_params,
-            output_grads,
-            allow_unused=True,
-        )
-        del ctx.input_tensors
-        del ctx.input_params
-        del output_tensors
-        return (None, None) + input_grads
-
-
 def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
     """
     Create sinusoidal timestep embeddings.
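Note: the deleted pair is a hand-rolled re-entrant checkpoint, equivalent in spirit to PyTorch's built-in torch.utils.checkpoint utility. A minimal compatibility sketch for downstream code that still wants the old memory-for-compute trade (the wrapper name is illustrative, not part of mv_unet.py):

    import torch.utils.checkpoint as torch_ckpt

    def checkpoint_compat(func, inputs, flag):
        # Same interface as the removed helper, minus `params`: the
        # non-reentrant implementation tracks parameter gradients itself,
        # so they no longer need to be threaded through explicitly.
        if flag:
            return torch_ckpt.checkpoint(func, *inputs, use_reentrant=False)
        return func(*inputs)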
@@ -286,7 +237,6 @@ class BasicTransformerBlock3D(nn.Module):
         context_dim,
         dropout=0.0,
         gated_ff=True,
-        checkpoint=True,
         ip_dim=0,
         ip_weight=1,
     ):
@@ -313,14 +263,8 @@ class BasicTransformerBlock3D(nn.Module):
         self.norm1 = nn.LayerNorm(dim)
         self.norm2 = nn.LayerNorm(dim)
         self.norm3 = nn.LayerNorm(dim)
-        self.checkpoint = checkpoint
 
     def forward(self, x, context=None, num_frames=1):
-        return checkpoint(
-            self._forward, (x, context, num_frames), self.parameters(), self.checkpoint
-        )
-
-    def _forward(self, x, context=None, num_frames=1):
         x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
         x = self.attn1(self.norm1(x), context=None) + x
         x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
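Note: with the in-block wrapper gone, _forward is folded into forward and the attention path runs directly; callers that still need activation recomputation can opt in at the call site. A hedged sketch, assuming an already-constructed block and illustrative shapes of (batch * num_frames, tokens, channels):

    import torch
    from torch.utils.checkpoint import checkpoint

    x = torch.randn(2 * 4, 256, 320, requires_grad=True)
    # Non-reentrant mode accepts keyword arguments such as num_frames.
    out = checkpoint(block, x, None, num_frames=4, use_reentrant=False)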
@@ -341,7 +285,6 @@ class SpatialTransformer3D(nn.Module):
         dropout=0.0,
         ip_dim=0,
         ip_weight=1,
-        use_checkpoint=True,
     ):
         super().__init__()
 
@@ -362,7 +305,6 @@ class SpatialTransformer3D(nn.Module):
                     d_head,
                     context_dim=context_dim[d],
                     dropout=dropout,
-                    checkpoint=use_checkpoint,
                     ip_dim=ip_dim,
                     ip_weight=ip_weight,
                 )
@@ -581,7 +523,6 @@ class ResBlock(nn.Module):
         convolution instead of a smaller 1x1 convolution to change the
         channels in the skip connection.
     :param dims: determines if the signal is 1D, 2D, or 3D.
-    :param use_checkpoint: if True, use gradient checkpointing on this module.
     :param up: if True, use this block for upsampling.
     :param down: if True, use this block for downsampling.
     """
@@ -595,7 +536,6 @@ class ResBlock(nn.Module):
         use_conv=False,
         use_scale_shift_norm=False,
         dims=2,
-        use_checkpoint=False,
         up=False,
         down=False,
     ):
@@ -605,7 +545,6 @@ class ResBlock(nn.Module):
         self.dropout = dropout
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
-        self.use_checkpoint = use_checkpoint
         self.use_scale_shift_norm = use_scale_shift_norm
 
         self.in_layers = nn.Sequential(
@@ -651,17 +590,6 @@ class ResBlock(nn.Module):
             self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
 
     def forward(self, x, emb):
-        """
-        Apply the block to a Tensor, conditioned on a timestep embedding.
-        :param x: an [N x C x ...] Tensor of features.
-        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
-        :return: an [N x C x ...] Tensor of outputs.
-        """
-        return checkpoint(
-            self._forward, (x, emb), self.parameters(), self.use_checkpoint
-        )
-
-    def _forward(self, x, emb):
         if self.updown:
             in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
             h = in_rest(x)
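Note: folding _forward into forward leaves ResBlock's external contract unchanged; per the removed docstring, x is an [N x C x ...] feature map and emb an [N x emb_channels] timestep embedding. A usage sketch with made-up sizes (the constructor values are illustrative, not taken from this diff):

    import torch

    block = ResBlock(channels=320, emb_channels=1280, dropout=0.0)
    x = torch.randn(2, 320, 32, 32)   # [N, C, H, W]
    emb = torch.randn(2, 1280)        # [N, emb_channels]
    h = block(x, emb)                 # [N, C, H, W], same call as before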
@@ -702,7 +630,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
     :param dims: determines if the signal is 1D, 2D, or 3D.
     :param num_classes: if specified (as an int), then this model will be
         class-conditional with `num_classes` classes.
-    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
     :param num_heads: the number of attention heads in each attention layer.
     :param num_heads_channels: if specified, ignore num_heads and instead use
         a fixed channel width per attention head.
@@ -728,7 +655,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
         conv_resample=True,
         dims=2,
         num_classes=None,
-        use_checkpoint=False,
         num_heads=-1,
         num_head_channels=-1,
         num_heads_upsample=-1,
@@ -794,7 +720,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
         self.channel_mult = channel_mult
         self.conv_resample = conv_resample
         self.num_classes = num_classes
-        self.use_checkpoint = use_checkpoint
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
@@ -868,7 +793,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                         dropout,
                         out_channels=mult * model_channels,
                         dims=dims,
-                        use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
                     )
                 ]
@@ -888,7 +812,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                             dim_head,
                             context_dim=context_dim,
                             depth=transformer_depth,
-                            use_checkpoint=use_checkpoint,
                             ip_dim=self.ip_dim,
                             ip_weight=self.ip_weight,
                         )
@@ -906,7 +829,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                             dropout,
                             out_channels=out_ch,
                             dims=dims,
-                            use_checkpoint=use_checkpoint,
                             use_scale_shift_norm=use_scale_shift_norm,
                             down=True,
                         )
@@ -933,7 +855,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                 time_embed_dim,
                 dropout,
                 dims=dims,
-                use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
             ),
             SpatialTransformer3D(
@@ -942,7 +863,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                 dim_head,
                 context_dim=context_dim,
                 depth=transformer_depth,
-                use_checkpoint=use_checkpoint,
                 ip_dim=self.ip_dim,
                 ip_weight=self.ip_weight,
             ),
@@ -951,7 +871,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                 time_embed_dim,
                 dropout,
                 dims=dims,
-                use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
             ),
         )
@@ -968,7 +887,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                         dropout,
                         out_channels=model_channels * mult,
                         dims=dims,
-                        use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
                     )
                 ]
@@ -988,7 +906,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                             dim_head,
                             context_dim=context_dim,
                             depth=transformer_depth,
-                            use_checkpoint=use_checkpoint,
                             ip_dim=self.ip_dim,
                             ip_weight=self.ip_weight,
                         )
@@ -1002,7 +919,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                             dropout,
                             out_channels=out_ch,
                             dims=dims,
-                            use_checkpoint=use_checkpoint,
                             use_scale_shift_norm=use_scale_shift_norm,
                             up=True,
                         )
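Note: since use_checkpoint disappears from every constructor signature, direct instantiations that still pass the old keyword will now raise a TypeError. A defensive sketch for configs written against the previous revision (the dict values and the import path are hypothetical):

    from unet.mv_unet import MultiViewUNetModel

    config = {
        "image_size": 32,
        "in_channels": 4,
        "model_channels": 320,
        "out_channels": 4,
        "num_res_blocks": 2,
        "attention_resolutions": [4, 2, 1],
        "use_checkpoint": True,  # stale key from the old signature
    }
    config.pop("use_checkpoint", None)  # drop it before calling the new __init__
    model = MultiViewUNetModel(**config)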