teticio committed
Commit 3e8b723
1 Parent(s): 58bc92a

log audio to tensorboard

Files changed (5):
  1. .gitignore +5 -2
  2. README.md +2 -2
  3. ldm_autoencoder_kl.yaml +32 -0
  4. train_unconditional.py +27 -7
  5. train_vae.py +214 -0
.gitignore CHANGED
@@ -3,6 +3,9 @@ __pycache__
 .ipynb_checkpoints
 data*
 ddpm-ema-audio-*
-flagged
-build
+flagged/
+build/
 audiodiffusion.egg-info
+lightning_logs/
+taming/
+checkpoints/
README.md CHANGED
@@ -89,7 +89,7 @@ accelerate launch --config_file accelerate_local.yaml \
   train_unconditional.py \
   --dataset_name teticio/audio-diffusion-256 \
   --resolution 256 \
-  --output_dir ddpm-ema-audio-256 \
+  --output_dir latent-audio-diffusion-256 \
   --num_epochs 100 \
   --train_batch_size 2 \
   --eval_batch_size 2 \
@@ -98,7 +98,7 @@ accelerate launch --config_file accelerate_local.yaml \
   --lr_warmup_steps 500 \
   --mixed_precision no \
   --push_to_hub True \
-  --hub_model_id audio-diffusion-256 \
+  --hub_model_id latent-audio-diffusion-256 \
   --hub_token $(cat $HOME/.huggingface/token)
 ```

ldm_autoencoder_kl.yaml ADDED
@@ -0,0 +1,32 @@
+
+model:
+  base_learning_rate: 4.5e-6
+  target: ldm.models.autoencoder.AutoencoderKL
+  params:
+    monitor: "val/rec_loss"
+    embed_dim: 3
+    lossconfig:
+      target: ldm.modules.losses.LPIPSWithDiscriminator
+      params:
+        disc_start: 50001
+        kl_weight: 0.000001
+        disc_weight: 0.5
+
+    ddconfig:
+      double_z: True
+      z_channels: 3
+      resolution: 256
+      in_channels: 3
+      out_ch: 3
+      ch: 128
+      ch_mult: [ 1,2,4 ]  # num_down = len(ch_mult)-1
+      num_res_blocks: 2
+      attn_resolutions: [ ]
+      dropout: 0.0
+
+lightning:
+  trainer:
+    benchmark: True
+    accumulate_grad_batches: 24
+    accelerator: gpu
+    devices: 1
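
This config follows the CompVis latent-diffusion AutoencoderKL recipe. For orientation, a minimal sketch of how the new train_vae.py below consumes it (assuming the ldm package installed via the pip commands noted at the top of that script):

```python
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

# Load the config and split off the `lightning` section, which configures the Trainer
config = OmegaConf.load("ldm_autoencoder_kl.yaml")
lightning_config = config.pop("lightning", OmegaConf.create())

# Build the AutoencoderKL named by model.target, passing model.params to its constructor
model = instantiate_from_config(config.model)
model.learning_rate = config.model.base_learning_rate
```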
train_unconditional.py CHANGED
@@ -10,7 +10,8 @@ from PIL import Image
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from datasets import load_from_disk, load_dataset
-from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
+from diffusers import (DDPMPipeline, DDPMScheduler, UNet2DModel, LDMPipeline,
+                       DDIMScheduler, VQModel)
 from diffusers.hub_utils import init_git_repo, push_to_hub
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
@@ -40,8 +41,16 @@ def main(args):
     )
 
     if args.from_pretrained is not None:
-        model = DDPMPipeline.from_pretrained(args.from_pretrained).unet
+        #model = DDPMPipeline.from_pretrained(args.from_pretrained).unet
+        pretrained = LDMPipeline.from_pretrained(args.from_pretrained)
+        vqvae = pretrained.vqvae
+        model = pretrained.unet
     else:
+        vqvae = VQModel(sample_size=args.resolution,
+                        in_channels=1,
+                        out_channels=1,
+                        latent_channels=1,
+                        layers_per_block=2)
         model = UNet2DModel(
             sample_size=args.resolution,
             in_channels=1,
@@ -65,7 +74,10 @@ def main(args):
                 "UpBlock2D",
             ),
         )
-    noise_scheduler = DDPMScheduler(num_train_timesteps=1000,
+
+    #noise_scheduler = DDPMScheduler(num_train_timesteps=1000,
+    #                                tensor_format="pt")
+    noise_scheduler = DDIMScheduler(num_train_timesteps=1000,
                                     tensor_format="pt")
     optimizer = torch.optim.AdamW(
         model.parameters(),
@@ -169,14 +181,16 @@ def main(args):
                     device=clean_images.device,
                 ).long()
 
+                clean_latents = vqvae.encode(clean_images)["sample"]
                 # Add noise to the clean images according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
-                noisy_images = noise_scheduler.add_noise(clean_images, noise,
-                                                         timesteps)
+                noisy_latents = noise_scheduler.add_noise(clean_latents, noise,
+                                                          timesteps)
 
                 with accelerator.accumulate(model):
                     # Predict the noise residual
-                    noise_pred = model(noisy_images, timesteps)["sample"]
+                    latents = model(noisy_latents, timesteps)["sample"]
+                    noise_pred = vqvae.decode(latents)["sample"]
                     loss = F.mse_loss(noise_pred, noise)
                     accelerator.backward(loss)
 
@@ -205,9 +219,15 @@ def main(args):
             # Generate sample images for visual inspection
             if accelerator.is_main_process:
                 if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
-                    pipeline = DDPMPipeline(
+                    #pipeline = DDPMPipeline(
+                    #    unet=accelerator.unwrap_model(
+                    #        ema_model.averaged_model if args.use_ema else model),
+                    #    scheduler=noise_scheduler,
+                    #)
+                    pipeline = LDMPipeline(
                         unet=accelerator.unwrap_model(
                             ema_model.averaged_model if args.use_ema else model),
+                        vqvae=vqvae,
                         scheduler=noise_scheduler,
                     )
 
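Taken together, these hunks move the diffusion into the VQ-VAE latent space: images are encoded to latents, noise is added and predicted there, and the prediction is decoded before the MSE against the sampled noise. A condensed sketch of the modified training step, with variable names as in the diff (not a drop-in replacement for the full loop):

```python
# Encode the clean spectrogram images into latents and diffuse there
clean_latents = vqvae.encode(clean_images)["sample"]
noisy_latents = noise_scheduler.add_noise(clean_latents, noise, timesteps)

with accelerator.accumulate(model):
    # The UNet predicts in latent space; decode before comparing with the noise
    latents = model(noisy_latents, timesteps)["sample"]
    noise_pred = vqvae.decode(latents)["sample"]
    loss = F.mse_loss(noise_pred, noise)
    accelerator.backward(loss)
```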
train_vae.py ADDED
@@ -0,0 +1,214 @@
+# pip install -e git+https://github.com/CompVis/stable-diffusion.git@master
+# pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
+# convert_original_stable_diffusion_to_diffusers.py
+
+# TODO
+# grayscale
+# log audio
+# convert to huggingface / train huggingface
+
+import os
+import argparse
+
+import torch
+import torchvision
+import numpy as np
+from PIL import Image
+import pytorch_lightning as pl
+from omegaconf import OmegaConf
+from datasets import load_dataset
+from librosa.util import normalize
+from ldm.util import instantiate_from_config
+from pytorch_lightning.trainer import Trainer
+from torch.utils.data import DataLoader, Dataset
+from pytorch_lightning.callbacks import Callback, ModelCheckpoint
+
+from audiodiffusion.mel import Mel
+
+
+class AudioDiffusion(Dataset):
+
+    def __init__(self, model_id):
+        super().__init__()
+        self.hf_dataset = load_dataset(model_id)['train']
+
+    def __len__(self):
+        return len(self.hf_dataset)
+
+    def __getitem__(self, idx):
+        image = self.hf_dataset[idx]['image'].convert('RGB')
+        image = np.frombuffer(image.tobytes(), dtype="uint8").reshape(
+            (image.height, image.width, 3))
+        image = ((image / 255) * 2 - 1)
+        return {'image': image}
+
+
+class AudioDiffusionDataModule(pl.LightningDataModule):
+
+    def __init__(self, model_id, batch_size):
+        super().__init__()
+        self.batch_size = batch_size
+        self.dataset = AudioDiffusion(model_id)
+        self.num_workers = 1
+
+    def train_dataloader(self):
+        return DataLoader(self.dataset,
+                          batch_size=self.batch_size,
+                          num_workers=self.num_workers)
+
+
+# from https://github.com/CompVis/stable-diffusion/blob/main/main.py
+class ImageLogger(Callback):
+
+    def __init__(self,
+                 batch_frequency,
+                 max_images,
+                 clamp=True,
+                 increase_log_steps=True,
+                 rescale=True,
+                 disabled=False,
+                 log_on_batch_idx=False,
+                 log_first_step=False,
+                 log_images_kwargs=None,
+                 resolution=256,
+                 hop_length=512):
+        super().__init__()
+        self.mel = Mel(x_res=resolution,
+                       y_res=resolution,
+                       hop_length=hop_length)
+        self.rescale = rescale
+        self.batch_freq = batch_frequency
+        self.max_images = max_images
+        self.logger_log_images = {
+            pl.loggers.TensorBoardLogger: self._testtube,
+        }
+        self.log_steps = [
+            2**n for n in range(int(np.log2(self.batch_freq)) + 1)
+        ]
+        if not increase_log_steps:
+            self.log_steps = [self.batch_freq]
+        self.clamp = clamp
+        self.disabled = disabled
+        self.log_on_batch_idx = log_on_batch_idx
+        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
+        self.log_first_step = log_first_step
+
+    #@rank_zero_only
+    def _testtube(self, pl_module, images, batch_idx, split):
+        for k in images:
+            images_ = (images[k] + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
+            grid = torchvision.utils.make_grid(images_)
+
+            tag = f"{split}/{k}"
+            pl_module.logger.experiment.add_image(
+                tag, grid, global_step=pl_module.global_step)
+
+            for _, image in enumerate(images_):
+                image = (images_.numpy() *
+                         255).round().astype("uint8").transpose(0, 2, 3, 1)
+                audio = self.mel.image_to_audio(
+                    Image.fromarray(image[0], mode='RGB').convert('L'))
+                pl_module.logger.experiment.add_audio(
+                    tag + f"/{_}",
+                    normalize(audio),
+                    global_step=pl_module.global_step,
+                    sample_rate=self.mel.get_sample_rate())
+
+    #@rank_zero_only
+    def log_local(self, save_dir, split, images, global_step, current_epoch,
+                  batch_idx):
+        root = os.path.join(save_dir, "images", split)
+        for k in images:
+            grid = torchvision.utils.make_grid(images[k], nrow=4)
+            if self.rescale:
+                grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
+            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
+            grid = grid.numpy()
+            grid = (grid * 255).astype(np.uint8)
+            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
+                k, global_step, current_epoch, batch_idx)
+            path = os.path.join(root, filename)
+            os.makedirs(os.path.split(path)[0], exist_ok=True)
+            Image.fromarray(grid).save(path)
+
+    def log_img(self, pl_module, batch, batch_idx, split="train"):
+        check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
+        if (self.check_frequency(check_idx)
+                and  # batch_idx % self.batch_freq == 0
+                hasattr(pl_module, "log_images") and
+                callable(pl_module.log_images) and self.max_images > 0):
+            logger = type(pl_module.logger)
+
+            is_train = pl_module.training
+            if is_train:
+                pl_module.eval()
+
+            with torch.no_grad():
+                images = pl_module.log_images(batch,
+                                              split=split,
+                                              **self.log_images_kwargs)
+
+            for k in images:
+                N = min(images[k].shape[0], self.max_images)
+                images[k] = images[k][:N]
+                if isinstance(images[k], torch.Tensor):
+                    images[k] = images[k].detach().cpu()
+                    if self.clamp:
+                        images[k] = torch.clamp(images[k], -1., 1.)
+
+            #self.log_local(pl_module.logger.save_dir, split, images,
+            #               pl_module.global_step, pl_module.current_epoch,
+            #               batch_idx)
+
+            logger_log_images = self.logger_log_images.get(
+                logger, lambda *args, **kwargs: None)
+            logger_log_images(pl_module, images, pl_module.global_step, split)
+
+            if is_train:
+                pl_module.train()
+
+    def check_frequency(self, check_idx):
+        if ((check_idx % self.batch_freq) == 0 or
+            (check_idx in self.log_steps)) and (check_idx > 0
+                                                or self.log_first_step):
+            try:
+                self.log_steps.pop(0)
+            except IndexError as e:
+                #print(e)
+                pass
+            return True
+        return False
+
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch,
+                           batch_idx):
+        if not self.disabled and (pl_module.global_step > 0
+                                  or self.log_first_step):
+            self.log_img(pl_module, batch, batch_idx, split="train")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Train VAE using ldm.")
+    parser.add_argument("--batch_size", type=int, default=1)
+    args = parser.parse_args()
+
+    config = OmegaConf.load('ldm_autoencoder_kl.yaml')
+    lightning_config = config.pop("lightning", OmegaConf.create())
+    trainer_config = lightning_config.get("trainer", OmegaConf.create())
+    trainer_opt = argparse.Namespace(**trainer_config)
+    trainer = Trainer.from_argparse_args(
+        trainer_opt,
+        callbacks=[
+            ImageLogger(batch_frequency=1000,
+                        max_images=8,
+                        increase_log_steps=False,
+                        log_on_batch_idx=True),
+            ModelCheckpoint(dirpath='checkpoints',
+                            filename='{epoch:06}',
+                            verbose=True,
+                            save_last=True)
+        ])
+    model = instantiate_from_config(config.model)
+    model.learning_rate = config.model.base_learning_rate
+    data = AudioDiffusionDataModule('teticio/audio-diffusion-256',
+                                    batch_size=args.batch_size)
+    trainer.fit(model, data)
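
The audio logging named in the commit message boils down to the add_audio call in ImageLogger._testtube: spectrogram images are converted back to waveforms with Mel.image_to_audio and written to TensorBoard. A minimal standalone sketch of the same idea outside the Lightning callback (the SummaryWriter and the spectrogram file here are illustrative stand-ins):

```python
from PIL import Image
from librosa.util import normalize
from torch.utils.tensorboard import SummaryWriter

from audiodiffusion.mel import Mel

mel = Mel(x_res=256, y_res=256, hop_length=512)
writer = SummaryWriter()  # the callback uses pl_module.logger.experiment instead

# Hypothetical grayscale mel-spectrogram image saved elsewhere
spectrogram = Image.open("example_spectrogram.png").convert("L")
audio = mel.image_to_audio(spectrogram)

# Write the reconstructed waveform to TensorBoard, as the callback does per image
writer.add_audio("train/reconstructions/0",
                 normalize(audio),
                 global_step=0,
                 sample_rate=mel.get_sample_rate())
writer.close()
```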