teticio committed
Commit 29c19d2
1 Parent(s): d7ba4b7

remove unnecessary check

Files changed (1)
  1. scripts/train_unconditional.py +32 -32
scripts/train_unconditional.py CHANGED
@@ -1,11 +1,8 @@
 # based on https://github.com/huggingface/diffusers/blob/main/examples/train_unconditional.py
 
-# TODO
-# Migrate to diffusers
-# from diffusers.hub_utils import Repository
-
 import argparse
 import os
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
@@ -14,13 +11,14 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from datasets import load_from_disk, load_dataset
 from diffusers import (
-    #AudioDiffusionPipeline,
+    AudioDiffusionPipeline,
+    Mel,
     DDPMScheduler,
     UNet2DModel,
     DDIMScheduler,
     AutoencoderKL,
 )
-from diffusers.hub_utils import init_git_repo, push_to_hub
+from huggingface_hub import HfFolder, Repository, whoami
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
 from torchvision.transforms import (
@@ -32,12 +30,21 @@ import numpy as np
 from tqdm.auto import tqdm
 from librosa.util import normalize
 
-#from diffusers import Mel, AudioDiffusionPipeline
-from audiodiffusion import Mel, AudioDiffusionPipeline
-
 logger = get_logger(__name__)
 
 
+def get_full_repo_name(model_id: str,
+                       organization: Optional[str] = None,
+                       token: Optional[str] = None):
+    if token is None:
+        token = HfFolder.get_token()
+    if organization is None:
+        username = whoami(token)["name"]
+        return f"{username}/{model_id}"
+    else:
+        return f"{organization}/{model_id}"
+
+
 def main(args):
     output_dir = os.environ.get("SM_MODEL_DIR", None) or args.output_dir
     logging_dir = os.path.join(output_dir, args.logging_dir)
@@ -94,8 +101,7 @@ def main(args):
         try:
             vqvae = AutoencoderKL.from_pretrained(args.vae)
         except EnvironmentError:
-            vqvae = AudioDiffusionPipeline.from_pretrained(
-                args.vae).vqvae
+            vqvae = AudioDiffusionPipeline.from_pretrained(args.vae).vqvae
         # Determine latent resolution
         with torch.no_grad():
             latent_resolution = (vqvae.encode(
@@ -169,7 +175,12 @@ def main(args):
     )
 
     if args.push_to_hub:
-        repo = init_git_repo(args, at_init=True)
+        if args.hub_model_id is None:
+            repo_name = get_full_repo_name(Path(args.output_dir).name,
+                                           token=args.hub_token)
+        else:
+            repo_name = args.hub_model_id
+        repo = Repository(args.output_dir, clone_from=repo_name)
 
     if accelerator.is_main_process:
         run = os.path.split(__file__)[-1].split(".")[0]
@@ -265,24 +276,17 @@ def main(args):
             pipeline = AudioDiffusionPipeline(
                 vqvae=vqvae,
                 unet=accelerator.unwrap_model(
-                    ema_model.averaged_model if args.use_ema else model
-                ),
+                    ema_model.averaged_model if args.use_ema else model),
                 mel=mel,
                 scheduler=noise_scheduler,
            )
+            pipeline.save_pretrained(args.output_dir)
 
             # save the model
             if args.push_to_hub:
-                try:
-                    push_to_hub(
-                        args,
-                        pipeline,
-                        repo,
-                        commit_message=f"Epoch {epoch}",
-                        blocking=False,
-                    )
-                except NameError:  # current version of diffusers has a little bug
-                    pass
+                repo.push_to_hub(commit_message=f"Epoch {epoch}",
+                                 blocking=False,
+                                 auto_lfs_prune=True)
             else:
                 pipeline.save_pretrained(output_dir)
 
@@ -290,11 +294,10 @@ def main(args):
                 generator = torch.Generator(
                     device=clean_images.device).manual_seed(42)
                 # run pipeline in inference (sample random noise and denoise)
-                images, (sample_rate, audios) = pipeline(
-                    generator=generator,
-                    batch_size=args.eval_batch_size,
-                    return_dict=False
-                )
+                images, (sample_rate,
+                         audios) = pipeline(generator=generator,
+                                            batch_size=args.eval_batch_size,
+                                            return_dict=False)
 
                 # denormalize the images and save to tensorboard
                 images = np.array([
@@ -390,8 +393,5 @@ if __name__ == "__main__":
         raise ValueError(
             "You must specify either a dataset name from the hub or a train data directory."
         )
-    if args.dataset_name is not None and args.dataset_name == args.hub_model_id:
-        raise ValueError(
-            "The local dataset name must be different from the hub model id.")
 
     main(args)
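
This commit also completes the migration flagged by the removed TODO: the deprecated `diffusers.hub_utils` helpers (`init_git_repo`, `push_to_hub`) are swapped for `huggingface_hub`'s `Repository` class, and `Mel`/`AudioDiffusionPipeline` are now imported from `diffusers` itself. Below is a minimal sketch of the new push flow in isolation, assuming a token has been cached with `huggingface-cli login`; the model id and output directory are placeholders, not values from the script.

# Minimal sketch of the Repository-based Hub push flow used in this commit.
# Assumes `huggingface-cli login` has cached a token; "my-model" is a placeholder.
from pathlib import Path
from typing import Optional

from huggingface_hub import HfFolder, Repository, whoami


def get_full_repo_name(model_id: str,
                       organization: Optional[str] = None,
                       token: Optional[str] = None):
    # Resolve "<user-or-org>/<model_id>" from the locally cached token
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    return f"{organization}/{model_id}"


output_dir = "my-model"  # placeholder; the script uses args.output_dir
repo = Repository(output_dir,
                  clone_from=get_full_repo_name(Path(output_dir).name))
# ... training writes checkpoints via pipeline.save_pretrained(output_dir) ...
repo.push_to_hub(commit_message="Epoch 0",  # the script uses f"Epoch {epoch}"
                 blocking=False,            # return immediately, push in background
                 auto_lfs_prune=True)       # prune local LFS blobs once pushed

Because `Repository` wraps git/git-lfs on the local clone, `blocking=False` lets the epoch's upload proceed in the background while training continues, which is why the script can push every epoch without stalling.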