Update unet/conditional_unet_model.py
Browse files
unet/conditional_unet_model.py
CHANGED
@@ -25,12 +25,13 @@ class UNet2DOutput(BaseOutput):
|
|
25 |
|
26 |
sample: torch.FloatTensor
|
27 |
|
28 |
-
|
29 |
class UNet2DModel(ModelMixin, ConfigMixin):
|
30 |
r"""
|
31 |
A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
|
|
|
32 |
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
33 |
for all models (such as downloading or saving).
|
|
|
34 |
Parameters:
|
35 |
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
36 |
Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
|
@@ -105,7 +106,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
|
105 |
class_embed_type: Optional[str] = None,
|
106 |
num_class_embeds: Optional[int] = None,
|
107 |
num_train_timesteps: Optional[int] = None,
|
108 |
-
set_W_to_weight: Optional[bool] = True
|
109 |
):
|
110 |
super().__init__()
|
111 |
|
@@ -237,6 +238,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
|
237 |
) -> Union[UNet2DOutput, Tuple]:
|
238 |
r"""
|
239 |
The [`UNet2DModel`] forward method.
|
|
|
240 |
Args:
|
241 |
sample (`torch.FloatTensor`):
|
242 |
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
@@ -245,6 +247,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
|
245 |
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
246 |
return_dict (`bool`, *optional*, defaults to `True`):
|
247 |
Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
|
|
|
248 |
Returns:
|
249 |
[`~models.unet_2d.UNet2DOutput`] or `tuple`:
|
250 |
If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
|
|
|
25 |
|
26 |
sample: torch.FloatTensor
|
27 |
|
|
|
28 |
class UNet2DModel(ModelMixin, ConfigMixin):
|
29 |
r"""
|
30 |
A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
|
31 |
+
|
32 |
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
33 |
for all models (such as downloading or saving).
|
34 |
+
|
35 |
Parameters:
|
36 |
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
37 |
Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
|
|
|
106 |
class_embed_type: Optional[str] = None,
|
107 |
num_class_embeds: Optional[int] = None,
|
108 |
num_train_timesteps: Optional[int] = None,
|
109 |
+
set_W_to_weight: Optional[bool] = True
|
110 |
):
|
111 |
super().__init__()
|
112 |
|
|
|
238 |
) -> Union[UNet2DOutput, Tuple]:
|
239 |
r"""
|
240 |
The [`UNet2DModel`] forward method.
|
241 |
+
|
242 |
Args:
|
243 |
sample (`torch.FloatTensor`):
|
244 |
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
|
|
247 |
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
248 |
return_dict (`bool`, *optional*, defaults to `True`):
|
249 |
Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
|
250 |
+
|
251 |
Returns:
|
252 |
[`~models.unet_2d.UNet2DOutput`] or `tuple`:
|
253 |
If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
|