log audio to tensorboard
- .gitignore +5 -2
- README.md +2 -2
- ldm_autoencoder_kl.yaml +32 -0
- train_unconditional.py +27 -7
- train_vae.py +214 -0
.gitignore CHANGED
@@ -3,6 +3,9 @@ __pycache__
 .ipynb_checkpoints
 data*
 ddpm-ema-audio-*
-flagged
-build
+flagged/
+build/
 audiodiffusion.egg-info
+lightning_logs/
+taming/
+checkpoints/
README.md CHANGED
@@ -89,7 +89,7 @@ accelerate launch --config_file accelerate_local.yaml \
 train_unconditional.py \
 --dataset_name teticio/audio-diffusion-256 \
 --resolution 256 \
---output_dir
+--output_dir latent-audio-diffusion-256 \
 --num_epochs 100 \
 --train_batch_size 2 \
 --eval_batch_size 2 \
@@ -98,7 +98,7 @@ accelerate launch --config_file accelerate_local.yaml \
 --lr_warmup_steps 500 \
 --mixed_precision no \
 --push_to_hub True \
---hub_model_id audio-diffusion-256 \
+--hub_model_id latent-audio-diffusion-256 \
 --hub_token $(cat $HOME/.huggingface/token)
 ```
ldm_autoencoder_kl.yaml ADDED
@@ -0,0 +1,32 @@
+
+model:
+  base_learning_rate: 4.5e-6
+  target: ldm.models.autoencoder.AutoencoderKL
+  params:
+    monitor: "val/rec_loss"
+    embed_dim: 3
+    lossconfig:
+      target: ldm.modules.losses.LPIPSWithDiscriminator
+      params:
+        disc_start: 50001
+        kl_weight: 0.000001
+        disc_weight: 0.5
+
+    ddconfig:
+      double_z: True
+      z_channels: 3
+      resolution: 256
+      in_channels: 3
+      out_ch: 3
+      ch: 128
+      ch_mult: [ 1,2,4 ]  # num_down = len(ch_mult)-1
+      num_res_blocks: 2
+      attn_resolutions: [ ]
+      dropout: 0.0
+
+lightning:
+  trainer:
+    benchmark: True
+    accumulate_grad_batches: 24
+    accelerator: gpu
+    devices: 1
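As a quick sanity check on the ddconfig above (not part of the commit): a ch_mult of length 3 gives two downsampling stages, so a 256x256 spectrogram is encoded to a 64x64 latent with embed_dim channels. A minimal sketch of that arithmetic:

```python
# Rough latent-shape arithmetic implied by ldm_autoencoder_kl.yaml (illustrative only).
ch_mult = [1, 2, 4]
resolution = 256
embed_dim = 3

num_down = len(ch_mult) - 1                  # as the config comment notes
latent_res = resolution // 2**num_down       # 256 / 4 = 64
print((embed_dim, latent_res, latent_res))   # (3, 64, 64)
```

Note also that accumulate_grad_batches: 24 in the lightning block multiplies the effective batch size by 24 relative to the --batch_size passed to train_vae.py.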
train_unconditional.py CHANGED
@@ -10,7 +10,8 @@ from PIL import Image
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from datasets import load_from_disk, load_dataset
-from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
+from diffusers import (DDPMPipeline, DDPMScheduler, UNet2DModel, LDMPipeline,
+                       DDIMScheduler, VQModel)
 from diffusers.hub_utils import init_git_repo, push_to_hub
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
@@ -40,8 +41,16 @@ def main(args):
     )
 
     if args.from_pretrained is not None:
-        model = DDPMPipeline.from_pretrained(args.from_pretrained).unet
+        #model = DDPMPipeline.from_pretrained(args.from_pretrained).unet
+        pretrained = LDMPipeline.from_pretrained(args.from_pretrained)
+        vqvae = pretrained.vqvae
+        model = pretrained.unet
     else:
+        vqvae = VQModel(sample_size=args.resolution,
+                        in_channels=1,
+                        out_channels=1,
+                        latent_channels=1,
+                        layers_per_block=2)
         model = UNet2DModel(
             sample_size=args.resolution,
             in_channels=1,
@@ -65,7 +74,10 @@ def main(args):
                 "UpBlock2D",
             ),
         )
-    noise_scheduler = DDPMScheduler(num_train_timesteps=1000,
+
+    #noise_scheduler = DDPMScheduler(num_train_timesteps=1000,
+    #                                tensor_format="pt")
+    noise_scheduler = DDIMScheduler(num_train_timesteps=1000,
                                     tensor_format="pt")
     optimizer = torch.optim.AdamW(
         model.parameters(),
@@ -169,14 +181,16 @@
                 device=clean_images.device,
             ).long()
 
+            clean_latents = vqvae.encode(clean_images)["sample"]
             # Add noise to the clean images according to the noise magnitude at each timestep
             # (this is the forward diffusion process)
-            noisy_images = noise_scheduler.add_noise(clean_images, noise,
-                                                     timesteps)
+            noisy_latents = noise_scheduler.add_noise(clean_latents, noise,
+                                                      timesteps)
 
             with accelerator.accumulate(model):
                 # Predict the noise residual
-                noise_pred = model(noisy_images, timesteps)["sample"]
+                latents = model(noisy_latents, timesteps)["sample"]
+                noise_pred = vqvae.decode(latents)["sample"]
                 loss = F.mse_loss(noise_pred, noise)
                 accelerator.backward(loss)
 
@@ -205,9 +219,15 @@
         # Generate sample images for visual inspection
        if accelerator.is_main_process:
            if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
-                pipeline = DDPMPipeline(
+                #pipeline = DDPMPipeline(
+                #    unet=accelerator.unwrap_model(
+                #        ema_model.averaged_model if args.use_ema else model),
+                #    scheduler=noise_scheduler,
+                #)
+                pipeline = LDMPipeline(
                     unet=accelerator.unwrap_model(
                         ema_model.averaged_model if args.use_ema else model),
+                    vqvae=vqvae,
                     scheduler=noise_scheduler,
                 )
 
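In summary, the training-step change above encodes each spectrogram batch with the VQ-VAE, adds noise to the latents, lets the UNet predict from the noisy latents, and decodes the prediction back to image space before taking the MSE against the sampled noise. A minimal, self-contained sketch of that flow (stand-in callables replace the real VQModel, UNet2DModel and scheduler, so names and shapes here are illustrative, not the diffusers API):

```python
import torch
import torch.nn.functional as F

# Stand-ins so the flow runs without diffusers; the real script calls
# vqvae.encode(...)["sample"], noise_scheduler.add_noise(...),
# model(noisy_latents, timesteps)["sample"] and vqvae.decode(...)["sample"].
encode = lambda x: x
decode = lambda z: z
unet = lambda z, t: z
add_noise = lambda z, n, t: z + n

clean_images = torch.randn(2, 1, 256, 256)   # batch of mel spectrograms in [-1, 1]
noise = torch.randn_like(clean_images)
timesteps = torch.randint(0, 1000, (2,))

clean_latents = encode(clean_images)                 # images -> latent space
noisy_latents = add_noise(clean_latents, noise, timesteps)
noise_pred = decode(unet(noisy_latents, timesteps))  # prediction decoded back
loss = F.mse_loss(noise_pred, noise)                 # noise-prediction objective
print(loss.item())
```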
train_vae.py ADDED
@@ -0,0 +1,214 @@
+# pip install -e git+https://github.com/CompVis/stable-diffusion.git@master
+# pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
+# convert_original_stable_diffusion_to_diffusers.py
+
+# TODO
+# grayscale
+# log audio
+# convert to huggingface / train huggingface
+
+import os
+import argparse
+
+import torch
+import torchvision
+import numpy as np
+from PIL import Image
+import pytorch_lightning as pl
+from omegaconf import OmegaConf
+from datasets import load_dataset
+from librosa.util import normalize
+from ldm.util import instantiate_from_config
+from pytorch_lightning.trainer import Trainer
+from torch.utils.data import DataLoader, Dataset
+from pytorch_lightning.callbacks import Callback, ModelCheckpoint
+
+from audiodiffusion.mel import Mel
+
+
+class AudioDiffusion(Dataset):
+
+    def __init__(self, model_id):
+        super().__init__()
+        self.hf_dataset = load_dataset(model_id)['train']
+
+    def __len__(self):
+        return len(self.hf_dataset)
+
+    def __getitem__(self, idx):
+        image = self.hf_dataset[idx]['image'].convert('RGB')
+        image = np.frombuffer(image.tobytes(), dtype="uint8").reshape(
+            (image.height, image.width, 3))
+        image = ((image / 255) * 2 - 1)
+        return {'image': image}
+
+
+class AudioDiffusionDataModule(pl.LightningDataModule):
+
+    def __init__(self, model_id, batch_size):
+        super().__init__()
+        self.batch_size = batch_size
+        self.dataset = AudioDiffusion(model_id)
+        self.num_workers = 1
+
+    def train_dataloader(self):
+        return DataLoader(self.dataset,
+                          batch_size=self.batch_size,
+                          num_workers=self.num_workers)
+
+
+# from https://github.com/CompVis/stable-diffusion/blob/main/main.py
+class ImageLogger(Callback):
+
+    def __init__(self,
+                 batch_frequency,
+                 max_images,
+                 clamp=True,
+                 increase_log_steps=True,
+                 rescale=True,
+                 disabled=False,
+                 log_on_batch_idx=False,
+                 log_first_step=False,
+                 log_images_kwargs=None,
+                 resolution=256,
+                 hop_length=512):
+        super().__init__()
+        self.mel = Mel(x_res=resolution,
+                       y_res=resolution,
+                       hop_length=hop_length)
+        self.rescale = rescale
+        self.batch_freq = batch_frequency
+        self.max_images = max_images
+        self.logger_log_images = {
+            pl.loggers.TensorBoardLogger: self._testtube,
+        }
+        self.log_steps = [
+            2**n for n in range(int(np.log2(self.batch_freq)) + 1)
+        ]
+        if not increase_log_steps:
+            self.log_steps = [self.batch_freq]
+        self.clamp = clamp
+        self.disabled = disabled
+        self.log_on_batch_idx = log_on_batch_idx
+        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
+        self.log_first_step = log_first_step
+
+    #@rank_zero_only
+    def _testtube(self, pl_module, images, batch_idx, split):
+        for k in images:
+            images_ = (images[k] + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
+            grid = torchvision.utils.make_grid(images_)
+
+            tag = f"{split}/{k}"
+            pl_module.logger.experiment.add_image(
+                tag, grid, global_step=pl_module.global_step)
+
+            for _, image in enumerate(images_):
+                image = (images_.numpy() *
+                         255).round().astype("uint8").transpose(0, 2, 3, 1)
+                audio = self.mel.image_to_audio(
+                    Image.fromarray(image[0], mode='RGB').convert('L'))
+                pl_module.logger.experiment.add_audio(
+                    tag + f"/{_}",
+                    normalize(audio),
+                    global_step=pl_module.global_step,
+                    sample_rate=self.mel.get_sample_rate())
+
+    #@rank_zero_only
+    def log_local(self, save_dir, split, images, global_step, current_epoch,
+                  batch_idx):
+        root = os.path.join(save_dir, "images", split)
+        for k in images:
+            grid = torchvision.utils.make_grid(images[k], nrow=4)
+            if self.rescale:
+                grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
+            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
+            grid = grid.numpy()
+            grid = (grid * 255).astype(np.uint8)
+            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
+                k, global_step, current_epoch, batch_idx)
+            path = os.path.join(root, filename)
+            os.makedirs(os.path.split(path)[0], exist_ok=True)
+            Image.fromarray(grid).save(path)
+
+    def log_img(self, pl_module, batch, batch_idx, split="train"):
+        check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
+        if (self.check_frequency(check_idx)
+                and  # batch_idx % self.batch_freq == 0
+                hasattr(pl_module, "log_images") and
+                callable(pl_module.log_images) and self.max_images > 0):
+            logger = type(pl_module.logger)
+
+            is_train = pl_module.training
+            if is_train:
+                pl_module.eval()
+
+            with torch.no_grad():
+                images = pl_module.log_images(batch,
+                                              split=split,
+                                              **self.log_images_kwargs)
+
+            for k in images:
+                N = min(images[k].shape[0], self.max_images)
+                images[k] = images[k][:N]
+                if isinstance(images[k], torch.Tensor):
+                    images[k] = images[k].detach().cpu()
+                    if self.clamp:
+                        images[k] = torch.clamp(images[k], -1., 1.)
+
+            #self.log_local(pl_module.logger.save_dir, split, images,
+            #               pl_module.global_step, pl_module.current_epoch,
+            #               batch_idx)
+
+            logger_log_images = self.logger_log_images.get(
+                logger, lambda *args, **kwargs: None)
+            logger_log_images(pl_module, images, pl_module.global_step, split)
+
+            if is_train:
+                pl_module.train()
+
+    def check_frequency(self, check_idx):
+        if ((check_idx % self.batch_freq) == 0 or
+            (check_idx in self.log_steps)) and (check_idx > 0
+                                                or self.log_first_step):
+            try:
+                self.log_steps.pop(0)
+            except IndexError as e:
+                #print(e)
+                pass
+            return True
+        return False
+
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch,
+                           batch_idx):
+        if not self.disabled and (pl_module.global_step > 0
+                                  or self.log_first_step):
+            self.log_img(pl_module, batch, batch_idx, split="train")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Train VAE using ldm.")
+    parser.add_argument("--batch_size", type=int, default=1)
+    args = parser.parse_args()
+
+    config = OmegaConf.load('ldm_autoencoder_kl.yaml')
+    lightning_config = config.pop("lightning", OmegaConf.create())
+    trainer_config = lightning_config.get("trainer", OmegaConf.create())
+    trainer_opt = argparse.Namespace(**trainer_config)
+    trainer = Trainer.from_argparse_args(
+        trainer_opt,
+        callbacks=[
+            ImageLogger(batch_frequency=1000,
+                        max_images=8,
+                        increase_log_steps=False,
+                        log_on_batch_idx=True),
+            ModelCheckpoint(dirpath='checkpoints',
+                            filename='{epoch:06}',
+                            verbose=True,
+                            save_last=True)
+        ])
+    model = instantiate_from_config(config.model)
+    model.learning_rate = config.model.base_learning_rate
+    data = AudioDiffusionDataModule('teticio/audio-diffusion-256',
+                                    batch_size=args.batch_size)
+    trainer.fit(model, data)
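A minimal sketch (not part of the commit, and untested) of reloading the trained autoencoder from the checkpoint that ModelCheckpoint writes to checkpoints/last.ckpt, assuming the same config file is on hand:

```python
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

# Rebuild the model the same way train_vae.py does, then load the trained weights.
config = OmegaConf.load("ldm_autoencoder_kl.yaml")
model = instantiate_from_config(config.model)

checkpoint = torch.load("checkpoints/last.ckpt", map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])
model.eval()
```

The spectrogram grids and reconstructed audio logged by ImageLogger should appear under lightning_logs/ (hence the new .gitignore entry) and can be browsed with TensorBoard.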