Update unet/conditional_unet_model.py
Browse files
unet/conditional_unet_model.py
CHANGED
@@ -17,6 +17,7 @@ from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D, get_down_block
|
|
17 |
class UNet2DOutput(BaseOutput):
|
18 |
"""
|
19 |
The output of [`UNet2DModel`].
|
|
|
20 |
Args:
|
21 |
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
22 |
The hidden states output from the last layer of the model.
|
@@ -28,8 +29,10 @@ class UNet2DOutput(BaseOutput):
|
|
28 |
class UNet2DModel(ModelMixin, ConfigMixin):
|
29 |
r"""
|
30 |
A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
|
|
|
31 |
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
|
32 |
for all models (such as downloading or saving).
|
|
|
33 |
Parameters:
|
34 |
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
35 |
Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
|
@@ -236,6 +239,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
|
236 |
) -> Union[UNet2DOutput, Tuple]:
|
237 |
r"""
|
238 |
The [`UNet2DModel`] forward method.
|
|
|
239 |
Args:
|
240 |
sample (`torch.FloatTensor`):
|
241 |
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
@@ -244,6 +248,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
|
244 |
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
245 |
return_dict (`bool`, *optional*, defaults to `True`):
|
246 |
Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
|
|
|
247 |
Returns:
|
248 |
[`~models.unet_2d.UNet2DOutput`] or `tuple`:
|
249 |
If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
|
@@ -433,8 +438,10 @@ class ClassConditionedUnetForShapes3D(ModelMixin, ConfigMixin):
|
|
433 |
class MultiLabelConditionalUNet2DModelForShapes3D(ModelMixin, ConfigMixin):
|
434 |
r"""
|
435 |
A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
|
|
|
436 |
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
|
437 |
for all models (such as downloading or saving).
|
|
|
438 |
Parameters:
|
439 |
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
440 |
Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
|
@@ -661,6 +668,7 @@ class MultiLabelConditionalUNet2DModelForShapes3D(ModelMixin, ConfigMixin):
|
|
661 |
) -> Union[UNet2DOutput, Tuple]:
|
662 |
r"""
|
663 |
The [`MultiLabelConditionalUNet2DModelForShapes3D`] forward method.
|
|
|
664 |
Args:
|
665 |
sample (`torch.FloatTensor`):
|
666 |
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
@@ -669,6 +677,7 @@ class MultiLabelConditionalUNet2DModelForShapes3D(ModelMixin, ConfigMixin):
|
|
669 |
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
670 |
return_dict (`bool`, *optional*, defaults to `True`):
|
671 |
Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
|
|
|
672 |
Returns:
|
673 |
[`~models.unet_2d.UNet2DOutput`] or `tuple`:
|
674 |
If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
|
|
|
17 |
class UNet2DOutput(BaseOutput):
|
18 |
"""
|
19 |
The output of [`UNet2DModel`].
|
20 |
+
|
21 |
Args:
|
22 |
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
23 |
The hidden states output from the last layer of the model.
|
|
|
29 |
class UNet2DModel(ModelMixin, ConfigMixin):
|
30 |
r"""
|
31 |
A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
|
32 |
+
|
33 |
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
|
34 |
for all models (such as downloading or saving).
|
35 |
+
|
36 |
Parameters:
|
37 |
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
38 |
Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
|
|
|
239 |
) -> Union[UNet2DOutput, Tuple]:
|
240 |
r"""
|
241 |
The [`UNet2DModel`] forward method.
|
242 |
+
|
243 |
Args:
|
244 |
sample (`torch.FloatTensor`):
|
245 |
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
|
|
248 |
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
249 |
return_dict (`bool`, *optional*, defaults to `True`):
|
250 |
Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
|
251 |
+
|
252 |
Returns:
|
253 |
[`~models.unet_2d.UNet2DOutput`] or `tuple`:
|
254 |
If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
|
|
|
438 |
class MultiLabelConditionalUNet2DModelForShapes3D(ModelMixin, ConfigMixin):
|
439 |
r"""
|
440 |
A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
|
441 |
+
|
442 |
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
|
443 |
for all models (such as downloading or saving).
|
444 |
+
|
445 |
Parameters:
|
446 |
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
447 |
Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
|
|
|
668 |
) -> Union[UNet2DOutput, Tuple]:
|
669 |
r"""
|
670 |
The [`MultiLabelConditionalUNet2DModelForShapes3D`] forward method.
|
671 |
+
|
672 |
Args:
|
673 |
sample (`torch.FloatTensor`):
|
674 |
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
|
|
677 |
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
678 |
return_dict (`bool`, *optional*, defaults to `True`):
|
679 |
Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
|
680 |
+
|
681 |
Returns:
|
682 |
[`~models.unet_2d.UNet2DOutput`] or `tuple`:
|
683 |
If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
|