Spaces:
Runtime error
Runtime error
fix when save_model_epochs != save_images_epochs
Browse files
scripts/train_unconditional.py
CHANGED
@@ -176,11 +176,11 @@ def main(args):
|
|
176 |
|
177 |
if args.push_to_hub:
|
178 |
if args.hub_model_id is None:
|
179 |
-
repo_name = get_full_repo_name(Path(
|
180 |
token=args.hub_token)
|
181 |
else:
|
182 |
repo_name = args.hub_model_id
|
183 |
-
repo = Repository(
|
184 |
|
185 |
if accelerator.is_main_process:
|
186 |
run = os.path.split(__file__)[-1].split(".")[0]
|
@@ -270,9 +270,9 @@ def main(args):
|
|
270 |
|
271 |
# Generate sample images for visual inspection
|
272 |
if accelerator.is_main_process:
|
273 |
-
if (
|
274 |
epoch + 1
|
275 |
-
) % args.
|
276 |
pipeline = AudioDiffusionPipeline(
|
277 |
vqvae=vqvae,
|
278 |
unet=accelerator.unwrap_model(
|
@@ -280,15 +280,17 @@ def main(args):
|
|
280 |
mel=mel,
|
281 |
scheduler=noise_scheduler,
|
282 |
)
|
283 |
-
|
|
|
|
|
|
|
|
|
284 |
|
285 |
# save the model
|
286 |
if args.push_to_hub:
|
287 |
repo.push_to_hub(commit_message=f"Epoch {epoch}",
|
288 |
blocking=False,
|
289 |
auto_lfs_prune=True)
|
290 |
-
else:
|
291 |
-
pipeline.save_pretrained(output_dir)
|
292 |
|
293 |
if (epoch + 1) % args.save_images_epochs == 0:
|
294 |
generator = torch.Generator(
|
|
|
176 |
|
177 |
if args.push_to_hub:
|
178 |
if args.hub_model_id is None:
|
179 |
+
repo_name = get_full_repo_name(Path(output_dir).name,
|
180 |
token=args.hub_token)
|
181 |
else:
|
182 |
repo_name = args.hub_model_id
|
183 |
+
repo = Repository(output_dir, clone_from=repo_name)
|
184 |
|
185 |
if accelerator.is_main_process:
|
186 |
run = os.path.split(__file__)[-1].split(".")[0]
|
|
|
270 |
|
271 |
# Generate sample images for visual inspection
|
272 |
if accelerator.is_main_process:
|
273 |
+
if (epoch + 1) % args.save_model_epochs == 0 or (
|
274 |
epoch + 1
|
275 |
+
) % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
|
276 |
pipeline = AudioDiffusionPipeline(
|
277 |
vqvae=vqvae,
|
278 |
unet=accelerator.unwrap_model(
|
|
|
280 |
mel=mel,
|
281 |
scheduler=noise_scheduler,
|
282 |
)
|
283 |
+
|
284 |
+
if (
|
285 |
+
epoch + 1
|
286 |
+
) % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
|
287 |
+
pipeline.save_pretrained(output_dir)
|
288 |
|
289 |
# save the model
|
290 |
if args.push_to_hub:
|
291 |
repo.push_to_hub(commit_message=f"Epoch {epoch}",
|
292 |
blocking=False,
|
293 |
auto_lfs_prune=True)
|
|
|
|
|
294 |
|
295 |
if (epoch + 1) % args.save_images_epochs == 0:
|
296 |
generator = torch.Generator(
|