giulio98 commited on
Commit
19527dd
1 Parent(s): e835170

Update unet/conditional_unet_model.py

Browse files
Files changed (1) hide show
  1. unet/conditional_unet_model.py +9 -0
unet/conditional_unet_model.py CHANGED
@@ -17,6 +17,7 @@ from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D, get_down_block
17
  class UNet2DOutput(BaseOutput):
18
  """
19
  The output of [`UNet2DModel`].
 
20
  Args:
21
  sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
22
  The hidden states output from the last layer of the model.
@@ -28,8 +29,10 @@ class UNet2DOutput(BaseOutput):
28
  class UNet2DModel(ModelMixin, ConfigMixin):
29
  r"""
30
  A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
 
31
  This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
32
  for all models (such as downloading or saving).
 
33
  Parameters:
34
  sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
35
  Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
@@ -236,6 +239,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
236
  ) -> Union[UNet2DOutput, Tuple]:
237
  r"""
238
  The [`UNet2DModel`] forward method.
 
239
  Args:
240
  sample (`torch.FloatTensor`):
241
  The noisy input tensor with the following shape `(batch, channel, height, width)`.
@@ -244,6 +248,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
244
  Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
245
  return_dict (`bool`, *optional*, defaults to `True`):
246
  Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
 
247
  Returns:
248
  [`~models.unet_2d.UNet2DOutput`] or `tuple`:
249
  If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
@@ -433,8 +438,10 @@ class ClassConditionedUnetForShapes3D(ModelMixin, ConfigMixin):
433
  class MultiLabelConditionalUNet2DModelForShapes3D(ModelMixin, ConfigMixin):
434
  r"""
435
  A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
 
436
  This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
437
  for all models (such as downloading or saving).
 
438
  Parameters:
439
  sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
440
  Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
@@ -661,6 +668,7 @@ class MultiLabelConditionalUNet2DModelForShapes3D(ModelMixin, ConfigMixin):
661
  ) -> Union[UNet2DOutput, Tuple]:
662
  r"""
663
  The [`UNet2DModel`] forward method.
 
664
  Args:
665
  sample (`torch.FloatTensor`):
666
  The noisy input tensor with the following shape `(batch, channel, height, width)`.
@@ -669,6 +677,7 @@ class MultiLabelConditionalUNet2DModelForShapes3D(ModelMixin, ConfigMixin):
669
  Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
670
  return_dict (`bool`, *optional*, defaults to `True`):
671
  Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
 
672
  Returns:
673
  [`~models.unet_2d.UNet2DOutput`] or `tuple`:
674
  If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
 
17
  class UNet2DOutput(BaseOutput):
18
  """
19
  The output of [`UNet2DModel`].
20
+
21
  Args:
22
  sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
23
  The hidden states output from the last layer of the model.
 
29
  class UNet2DModel(ModelMixin, ConfigMixin):
30
  r"""
31
  A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
32
+
33
  This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
34
  for all models (such as downloading or saving).
35
+
36
  Parameters:
37
  sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
38
  Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
 
239
  ) -> Union[UNet2DOutput, Tuple]:
240
  r"""
241
  The [`UNet2DModel`] forward method.
242
+
243
  Args:
244
  sample (`torch.FloatTensor`):
245
  The noisy input tensor with the following shape `(batch, channel, height, width)`.
 
248
  Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
249
  return_dict (`bool`, *optional*, defaults to `True`):
250
  Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
251
+
252
  Returns:
253
  [`~models.unet_2d.UNet2DOutput`] or `tuple`:
254
  If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
 
438
  class MultiLabelConditionalUNet2DModelForShapes3D(ModelMixin, ConfigMixin):
439
  r"""
440
  A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
441
+
442
  This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
443
  for all models (such as downloading or saving).
444
+
445
  Parameters:
446
  sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
447
  Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
 
668
  ) -> Union[UNet2DOutput, Tuple]:
669
  r"""
670
  The [`UNet2DModel`] forward method.
671
+
672
  Args:
673
  sample (`torch.FloatTensor`):
674
  The noisy input tensor with the following shape `(batch, channel, height, width)`.
 
677
  Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
678
  return_dict (`bool`, *optional*, defaults to `True`):
679
  Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
680
+
681
  Returns:
682
  [`~models.unet_2d.UNet2DOutput`] or `tuple`:
683
  If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is