teticio commited on
Commit
c190f5b
1 Parent(s): 0a4662e
Files changed (1) hide show
  1. src/train_unconditional.py +7 -7
src/train_unconditional.py CHANGED
@@ -207,6 +207,13 @@ def main(args):
207
  # Generate sample images for visual inspection
208
  if accelerator.is_main_process:
209
  if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
 
 
 
 
 
 
 
210
  # save the model
211
  if args.push_to_hub:
212
  try:
@@ -222,13 +229,6 @@ def main(args):
222
  else:
223
  pipeline.save_pretrained(output_dir)
224
 
225
- pipeline = DDPMPipeline(
226
- unet=accelerator.unwrap_model(
227
- ema_model.averaged_model if args.use_ema else model
228
- ),
229
- scheduler=noise_scheduler,
230
- )
231
-
232
  generator = torch.manual_seed(0)
233
  # run pipeline in inference (sample random noise and denoise)
234
  images = pipeline(
 
207
  # Generate sample images for visual inspection
208
  if accelerator.is_main_process:
209
  if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
210
+ pipeline = DDPMPipeline(
211
+ unet=accelerator.unwrap_model(
212
+ ema_model.averaged_model if args.use_ema else model
213
+ ),
214
+ scheduler=noise_scheduler,
215
+ )
216
+
217
  # save the model
218
  if args.push_to_hub:
219
  try:
 
229
  else:
230
  pipeline.save_pretrained(output_dir)
231
 
 
 
 
 
 
 
 
232
  generator = torch.manual_seed(0)
233
  # run pipeline in inference (sample random noise and denoise)
234
  images = pipeline(