giulio98 commited on
Commit
b62e9f9
1 Parent(s): c1b0a9e

Update unet/conditional_unet_model.py

Browse files
Files changed (1) hide show
  1. unet/conditional_unet_model.py +5 -2
unet/conditional_unet_model.py CHANGED
@@ -25,12 +25,13 @@ class UNet2DOutput(BaseOutput):
25
 
26
  sample: torch.FloatTensor
27
 
28
-
29
  class UNet2DModel(ModelMixin, ConfigMixin):
30
  r"""
31
  A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
 
32
  This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
33
  for all models (such as downloading or saving).
 
34
  Parameters:
35
  sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
36
  Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
@@ -105,7 +106,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
105
  class_embed_type: Optional[str] = None,
106
  num_class_embeds: Optional[int] = None,
107
  num_train_timesteps: Optional[int] = None,
108
- set_W_to_weight: Optional[bool] = True,
109
  ):
110
  super().__init__()
111
 
@@ -237,6 +238,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
237
  ) -> Union[UNet2DOutput, Tuple]:
238
  r"""
239
  The [`UNet2DModel`] forward method.
 
240
  Args:
241
  sample (`torch.FloatTensor`):
242
  The noisy input tensor with the following shape `(batch, channel, height, width)`.
@@ -245,6 +247,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
245
  Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
246
  return_dict (`bool`, *optional*, defaults to `True`):
247
  Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
 
248
  Returns:
249
  [`~models.unet_2d.UNet2DOutput`] or `tuple`:
250
  If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
 
25
 
26
  sample: torch.FloatTensor
27
 
 
28
  class UNet2DModel(ModelMixin, ConfigMixin):
29
  r"""
30
  A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
31
+
32
  This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
33
  for all models (such as downloading or saving).
34
+
35
  Parameters:
36
  sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
37
  Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
 
106
  class_embed_type: Optional[str] = None,
107
  num_class_embeds: Optional[int] = None,
108
  num_train_timesteps: Optional[int] = None,
109
+ set_W_to_weight: Optional[bool] = True
110
  ):
111
  super().__init__()
112
 
 
238
  ) -> Union[UNet2DOutput, Tuple]:
239
  r"""
240
  The [`UNet2DModel`] forward method.
241
+
242
  Args:
243
  sample (`torch.FloatTensor`):
244
  The noisy input tensor with the following shape `(batch, channel, height, width)`.
 
247
  Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
248
  return_dict (`bool`, *optional*, defaults to `True`):
249
  Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
250
+
251
  Returns:
252
  [`~models.unet_2d.UNet2DOutput`] or `tuple`:
253
  If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is