teticio committed
Commit 8f292f9
Parent: 9a9737e

update train_unconditional for latent diffusion

README.md CHANGED
@@ -89,7 +89,7 @@ accelerate launch --config_file config/accelerate_local.yaml \
 scripts/train_unconditional.py \
 --dataset_name teticio/audio-diffusion-256 \
 --resolution 256 \
---output_dir latent-audio-diffusion-256 \
+--output_dir audio-diffusion-256 \
 --num_epochs 100 \
 --train_batch_size 2 \
 --eval_batch_size 2 \
@@ -98,7 +98,7 @@ accelerate launch --config_file config/accelerate_local.yaml \
 --lr_warmup_steps 500 \
 --mixed_precision no \
 --push_to_hub True \
---hub_model_id latent-audio-diffusion-256 \
+--hub_model_id audio-diffusion-256 \
 --hub_token $(cat $HOME/.huggingface/token)
 ```
 #### Run training on SageMaker.
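Aside on the `--push_to_hub` block above: once training has pushed the weights under `--hub_model_id`, they can be pulled back with `DDPMPipeline.from_pretrained`. A minimal sketch, assuming the final hub id resolves to `<username>/audio-diffusion-256` (the placeholder id and the `.images` output field are assumptions about your account and diffusers version, not something this commit pins down):

```python
from diffusers import DDPMPipeline

# Hypothetical hub id: --hub_model_id audio-diffusion-256 publishes
# the pipeline under your own account.
pipeline = DDPMPipeline.from_pretrained("<username>/audio-diffusion-256")

# Generate one mel-spectrogram image; .images is a list of PIL images.
output = pipeline(batch_size=1)
output.images[0].save("sample.png")
```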
scripts/train_unconditional.py CHANGED
@@ -48,8 +48,9 @@ def main(args):
         model = DDPMPipeline.from_pretrained(args.from_pretrained).unet
     else:
         model = UNet2DModel(
-            in_channels=1,
-            out_channels=1,
+            sample_size=args.resolution if args.vae is None else 64,
+            in_channels=1 if args.vae is None else 3,
+            out_channels=1 if args.vae is None else 3,
             layers_per_block=2,
             block_out_channels=(128, 128, 256, 256, 512, 512),
             down_block_types=(
@@ -114,7 +115,7 @@ def main(args):
     def transforms(examples):
         if args.vae is not None:
             images = [
-                augmentations(image).convert("RGB")
+                augmentations(image.convert("RGB"))
                 for image in examples["image"]
             ]
         else:
@@ -173,6 +174,13 @@ def main(args):
         model.train()
         for step, batch in enumerate(train_dataloader):
             clean_images = batch["input"]
+
+            if args.vae is not None:
+                vqvae.to(clean_images.device)
+                with torch.no_grad():
+                    clean_images = vqvae.encode(
+                        clean_images).latent_dist.sample()
+
             # Sample noise that we'll add to the images
             noise = torch.randn(clean_images.shape).to(clean_images.device)
             bsz = clean_images.shape[0]
@@ -184,11 +192,6 @@ def main(args):
                 device=clean_images.device,
             ).long()
 
-            if args.vae is not None:
-                with torch.no_grad():
-                    clean_images = vqvae.encode(
-                        clean_images).latent_dist.sample()
-
             # Add noise to the clean images according to the noise magnitude at each timestep
             # (this is the forward diffusion process)
             noisy_images = noise_scheduler.add_noise(clean_images, noise,
@@ -196,8 +199,7 @@ def main(args):
 
             with accelerator.accumulate(model):
                 # Predict the noise residual
-                images = model(noisy_images, timesteps)["sample"]
-                noise_pred = vqvae.decode(images)["sample"]
+                noise_pred = model(noisy_images, timesteps)["sample"]
                 loss = F.mse_loss(noise_pred, noise)
                 accelerator.backward(loss)
 
@@ -209,13 +211,6 @@ def main(args):
                 ema_model.step(model)
                 optimizer.zero_grad()
 
-            if args.vae is not None:
-                with torch.no_grad():
-                    images = [
-                        image.convert('L')
-                        for image in vqvae.decode(images)["sample"]
-                    ]
-
             if accelerator.sync_gradients:
                 progress_bar.update(1)
                 global_step += 1
@@ -239,14 +234,16 @@ def main(args):
         if args.vae is not None:
             pipeline = LDMPipeline(
                 unet=accelerator.unwrap_model(
-                    ema_model.averaged_model if args.use_ema else model),
+                    ema_model.averaged_model if args.use_ema else model
+                ),
                 vqvae=vqvae,
                 scheduler=noise_scheduler,
             )
         else:
             pipeline = DDPMPipeline(
                 unet=accelerator.unwrap_model(
-                    ema_model.averaged_model if args.use_ema else model),
+                    ema_model.averaged_model if args.use_ema else model
+                ),
                 scheduler=noise_scheduler,
             )
 
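Taken together, the hunks above move the VQ-VAE encode in front of the noising step and drop the decode from the loss, so diffusion now runs entirely in latent space. A minimal, self-contained sketch of the training step the diff converges on; the stand-in models and shapes (3-channel 64x64 latents from 256x256 spectrograms) are assumptions consistent with the diff, not pinned by it:

```python
import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DModel

# Stand-in models; the real script loads the autoencoder from args.vae.
vqvae = AutoencoderKL(in_channels=3, out_channels=3, latent_channels=3,
                      down_block_types=("DownEncoderBlock2D",) * 3,
                      up_block_types=("UpDecoderBlock2D",) * 3,
                      block_out_channels=(64, 128, 256))  # 256x256 -> 64x64 latents
model = UNet2DModel(sample_size=64, in_channels=3, out_channels=3)
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

clean_images = torch.randn(2, 3, 256, 256)  # stand-in batch of spectrogram images

# Encode to latents once, before noising: diffusion happens in latent space.
with torch.no_grad():
    clean_images = vqvae.encode(clean_images).latent_dist.sample()

noise = torch.randn(clean_images.shape)
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                          (clean_images.shape[0],)).long()
noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

# The UNet predicts the noise residual directly in latent space;
# no vqvae.decode appears anywhere in the loss.
noise_pred = model(noisy_images, timesteps)["sample"]
loss = F.mse_loss(noise_pred, noise)
```

At save time the same UNet and VQ-VAE are bundled into an `LDMPipeline` (last hunk), so sampling denoises in latent space and decodes the result back to spectrogram images.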
scripts/train_vae.py CHANGED
@@ -4,9 +4,7 @@
 
 # TODO
 # grayscale
-# add vae to train_uncond (no_grad)
 # update README
-# merge in changes to train_unconditional
 
 import os
 import argparse