remove unnecessary check
scripts/train_unconditional.py  +32 -32
--- a/scripts/train_unconditional.py
+++ b/scripts/train_unconditional.py
@@ -1,11 +1,8 @@
 # based on https://github.com/huggingface/diffusers/blob/main/examples/train_unconditional.py
 
-# TODO
-# Migrate to diffusers
-# from diffusers.hub_utils import Repository
-
 import argparse
 import os
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
@@ -14,13 +11,14 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from datasets import load_from_disk, load_dataset
 from diffusers import (
-
+    AudioDiffusionPipeline,
+    Mel,
     DDPMScheduler,
     UNet2DModel,
     DDIMScheduler,
     AutoencoderKL,
 )
-from diffusers.hub_utils import init_git_repo, push_to_hub
+from huggingface_hub import HfFolder, Repository, whoami
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
 from torchvision.transforms import (
@@ -32,12 +30,21 @@ import numpy as np
 from tqdm.auto import tqdm
 from librosa.util import normalize
 
-#from diffusers import Mel, AudioDiffusionPipeline
-from audiodiffusion import Mel, AudioDiffusionPipeline
-
 logger = get_logger(__name__)
 
 
+def get_full_repo_name(model_id: str,
+                       organization: Optional[str] = None,
+                       token: Optional[str] = None):
+    if token is None:
+        token = HfFolder.get_token()
+    if organization is None:
+        username = whoami(token)["name"]
+        return f"{username}/{model_id}"
+    else:
+        return f"{organization}/{model_id}"
+
+
 def main(args):
     output_dir = os.environ.get("SM_MODEL_DIR", None) or args.output_dir
     logging_dir = os.path.join(output_dir, args.logging_dir)
@@ -94,8 +101,7 @@ def main(args):
     try:
         vqvae = AutoencoderKL.from_pretrained(args.vae)
     except EnvironmentError:
-        vqvae = AudioDiffusionPipeline.from_pretrained(
-            args.vae).vqvae
+        vqvae = AudioDiffusionPipeline.from_pretrained(args.vae).vqvae
     # Determine latent resolution
     with torch.no_grad():
         latent_resolution = (vqvae.encode(
@@ -169,7 +175,12 @@ def main(args):
     )
 
     if args.push_to_hub:
-        repo = init_git_repo(args, at_init=True)
+        if args.hub_model_id is None:
+            repo_name = get_full_repo_name(Path(args.output_dir).name,
+                                           token=args.hub_token)
+        else:
+            repo_name = args.hub_model_id
+        repo = Repository(args.output_dir, clone_from=repo_name)
 
     if accelerator.is_main_process:
         run = os.path.split(__file__)[-1].split(".")[0]
@@ -265,24 +276,17 @@ def main(args):
                 pipeline = AudioDiffusionPipeline(
                     vqvae=vqvae,
                     unet=accelerator.unwrap_model(
-                        ema_model.averaged_model if args.use_ema else model
-                    ),
+                        ema_model.averaged_model if args.use_ema else model),
                     mel=mel,
                     scheduler=noise_scheduler,
                 )
+                pipeline.save_pretrained(args.output_dir)
 
                 # save the model
                 if args.push_to_hub:
-                    try:
-                        push_to_hub(
-                            args,
-                            pipeline,
-                            repo,
-                            commit_message=f"Epoch {epoch}",
-                            blocking=False,
-                        )
-                    except NameError:  # current version of diffusers has a little bug
-                        pass
+                    repo.push_to_hub(commit_message=f"Epoch {epoch}",
+                                     blocking=False,
+                                     auto_lfs_prune=True)
                 else:
                     pipeline.save_pretrained(output_dir)
 
@@ -290,11 +294,10 @@ def main(args):
                 generator = torch.Generator(
                     device=clean_images.device).manual_seed(42)
                 # run pipeline in inference (sample random noise and denoise)
-                images, (sample_rate,
-                         audios) = pipeline(
-                             generator=generator,
-                             batch_size=args.eval_batch_size,
-                )
+                images, (sample_rate,
+                         audios) = pipeline(generator=generator,
+                                            batch_size=args.eval_batch_size,
+                                            return_dict=False)
 
                 # denormalize the images and save to tensorboard
                 images = np.array([
@@ -390,8 +393,5 @@ if __name__ == "__main__":
         raise ValueError(
             "You must specify either a dataset name from the hub or a train data directory."
         )
-    if args.dataset_name is not None and args.dataset_name == args.hub_model_id:
-        raise ValueError(
-            "The local dataset name must be different from the hub model id.")
 
     main(args)
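Note on the Hub push flow: the commit replaces the old diffusers.hub_utils push helpers with huggingface_hub's Repository, which manages a local git clone of the Hub repo. A minimal sketch of the same flow in isolation, assuming a hypothetical output directory "ddpm-audio-256" and the get_full_repo_name helper defined in the diff above:

    from pathlib import Path
    from huggingface_hub import Repository

    output_dir = "ddpm-audio-256"  # hypothetical local training output
    repo_name = get_full_repo_name(Path(output_dir).name)
    # Clone (or create) the Hub repo into output_dir so pushes are plain git commits.
    repo = Repository(output_dir, clone_from=repo_name)
    # ... training writes checkpoints into output_dir via pipeline.save_pretrained() ...
    repo.push_to_hub(commit_message="Epoch 0",
                     blocking=False,       # return immediately, upload in the background
                     auto_lfs_prune=True)  # drop local LFS blobs once they are pushed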
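Note on return_dict=False: it makes the pipeline call return a plain tuple instead of an output object, which is exactly what the unpacking in the diff relies on: a list of spectrogram images plus a (sample_rate, audios) pair. A sketch of the same call outside the training loop (the checkpoint id is illustrative):

    import torch
    from diffusers import AudioDiffusionPipeline

    pipeline = AudioDiffusionPipeline.from_pretrained("teticio/audio-diffusion-256")
    generator = torch.Generator().manual_seed(42)
    images, (sample_rate, audios) = pipeline(generator=generator,
                                             batch_size=1,
                                             return_dict=False)
    # images[0] is a PIL spectrogram image; audios[0] is its decoded waveform at sample_rate.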