# File: audio-diffusion/scripts/train_vae.py
# Last commit shown: "update readme" (01c4a98), user teticio
# File size: 6.75 kB
import os
import argparse
import torch
import torchvision
import numpy as np
from PIL import Image
import pytorch_lightning as pl
from omegaconf import OmegaConf
from librosa.util import normalize
from ldm.util import instantiate_from_config
from pytorch_lightning.trainer import Trainer
from torch.utils.data import DataLoader, Dataset
from datasets import load_from_disk, load_dataset
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from pytorch_lightning.utilities.distributed import rank_zero_only
from audiodiffusion.mel import Mel
from audiodiffusion.utils import convert_ldm_to_hf_vae
class AudioDiffusion(Dataset):
    """Torch dataset serving the 'image' column of a Hugging Face dataset.

    Every item is a dict with the single key 'image', holding a NumPy array
    of shape (height, width, channels) whose uint8 pixel values have been
    rescaled from [0, 255] to [-1, 1].
    """

    def __init__(self, model_id, channels=3):
        """Load the 'train' split identified by ``model_id``.

        Args:
            model_id: If this path exists on disk it is read with
                ``load_from_disk``; otherwise it is handed to
                ``load_dataset``.
            channels: Channel count used when reshaping the pixel buffer.
                A value of 3 additionally converts each image to RGB.
        """
        super().__init__()
        self.channels = channels
        loader = load_from_disk if os.path.exists(model_id) else load_dataset
        self.hf_dataset = loader(model_id)['train']

    def __len__(self):
        """Return the number of rows in the loaded split."""
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        """Return ``{'image': array}`` for row ``idx``, scaled to [-1, 1]."""
        picture = self.hf_dataset[idx]['image']
        if self.channels == 3:
            picture = picture.convert('RGB')
        raw = np.frombuffer(picture.tobytes(), dtype="uint8")
        pixels = raw.reshape((picture.height, picture.width, self.channels))
        # uint8 [0, 255] -> float [0, 1] -> [-1, 1]
        scaled = (pixels / 255) * 2 - 1
        return {'image': scaled}
class AudioDiffusionDataModule(pl.LightningDataModule):
    """Lightning data module wrapping an ``AudioDiffusion`` training dataset."""

    def __init__(self, model_id, batch_size, channels, num_workers=1):
        """Build the underlying dataset and remember the loader settings.

        Args:
            model_id: Dataset path or name, forwarded to ``AudioDiffusion``.
            batch_size: Batch size used by ``train_dataloader``.
            channels: Channel count, forwarded to ``AudioDiffusion``.
            num_workers: Number of DataLoader worker processes. Defaults to
                1, the value that was previously hard-coded.
        """
        super().__init__()
        self.batch_size = batch_size
        self.dataset = AudioDiffusion(model_id=model_id, channels=channels)
        self.num_workers = num_workers

    def train_dataloader(self):
        """Return the training ``DataLoader`` over the wrapped dataset."""
        # NOTE(review): ``shuffle`` is not passed, so the DataLoader default
        # applies and rows are served in dataset order - confirm intended.
        return DataLoader(self.dataset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers)
class ImageLogger(Callback):
    """Callback that periodically logs model images and audio decoded from them.

    Every ``every`` training batches the module's ``log_images`` output is
    written to the logger as an image grid, and each individual image is
    turned into audio with ``Mel.image_to_audio`` and logged as well.
    """

    def __init__(self, every=1000, channels=3, resolution=256, hop_length=512):
        """Configure logging frequency and the Mel converter.

        Args:
            every: Log once every this many training batches.
            channels: Channel count of the logged images; 3 means the images
                are treated as RGB and collapsed to greyscale before audio
                conversion, otherwise the first channel is used.
            resolution: Passed to ``Mel`` as both ``x_res`` and ``y_res``.
            hop_length: Passed to ``Mel`` as ``hop_length``.
        """
        super().__init__()
        self.mel = Mel(x_res=resolution,
                       y_res=resolution,
                       hop_length=hop_length)
        self.every = every
        self.channels = channels

    @rank_zero_only
    def log_images_and_audios(self, pl_module, batch):
        """Log image grids and per-image audio for ``batch``.

        Args:
            pl_module: Module providing ``log_images``, ``logger`` and
                ``global_step``.
            batch: Batch forwarded unchanged to ``pl_module.log_images``.
        """
        # Run inference in eval mode, then put the module back in whatever
        # mode it was in - even if log_images raises.
        was_training = pl_module.training
        pl_module.eval()
        try:
            with torch.no_grad():
                images = pl_module.log_images(batch, split='train')
        finally:
            pl_module.train(was_training)

        # presumably a TensorBoard SummaryWriter (add_image / add_audio) -
        # confirm against the configured logger.
        writer = pl_module.logger.experiment
        step = pl_module.global_step

        for key, tensor in images.items():
            tensor = torch.clamp(tensor.detach().cpu(), -1., 1.)
            tensor = (tensor + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
            tag = f"train/{key}"
            writer.add_image(tag,
                             torchvision.utils.make_grid(tensor),
                             global_step=step)

            # Move the channel axis last and quantise to uint8 so each
            # frame can be handed to PIL: (N, C, H, W) -> (N, H, W, C).
            frames = (tensor.numpy() *
                      255).round().astype("uint8").transpose(0, 2, 3, 1)
            for index, frame in enumerate(frames):
                if self.channels == 3:
                    picture = Image.fromarray(frame, mode='RGB').convert('L')
                else:
                    # frame is (H, W, C): take the first channel plane.
                    # (frame[0] would select the first *row* instead.)
                    picture = Image.fromarray(frame[:, :, 0])
                audio = self.mel.image_to_audio(picture)
                writer.add_audio(tag + f"/{index}",
                                 normalize(audio),
                                 global_step=step,
                                 sample_rate=self.mel.get_sample_rate())

    def on_train_batch_end(self, trainer, pl_module, outputs, batch,
                           batch_idx):
        """Trigger logging on every ``self.every``-th batch (1-based count)."""
        if (batch_idx + 1) % self.every != 0:
            return
        self.log_images_and_audios(pl_module, batch)
class HFModelCheckpoint(ModelCheckpoint):
    """``ModelCheckpoint`` that also runs ``convert_ldm_to_hf_vae`` each epoch."""

    def __init__(self, ldm_config, hf_checkpoint, *args, **kwargs):
        """Store conversion settings; other arguments go to ``ModelCheckpoint``.

        Args:
            ldm_config: Configuration object passed to the converter.
            hf_checkpoint: Destination passed to the converter.
            *args: Positional arguments for ``ModelCheckpoint``.
            **kwargs: Keyword arguments for ``ModelCheckpoint``.
        """
        super().__init__(*args, **kwargs)
        self.ldm_config = ldm_config
        self.hf_checkpoint = hf_checkpoint

    def on_train_epoch_end(self, trainer, pl_module):
        """Run the parent epoch-end hook, then convert the epoch's checkpoint."""
        # The file name is resolved BEFORE the parent hook runs; keep this
        # order. NOTE(review): presumably the resolved name could differ once
        # the file exists on disk - confirm against ModelCheckpoint.
        epoch_metrics = {'epoch': trainer.current_epoch}
        checkpoint_path = self._get_metric_interpolated_filepath_name(
            epoch_metrics, trainer)

        super().on_train_epoch_end(trainer, pl_module)

        convert_ldm_to_hf_vae(checkpoint_path, self.ldm_config,
                              self.hf_checkpoint)
if __name__ == "__main__":
    # Command-line entry point: build the ldm model from its config, wrap the
    # dataset in a Lightning data module and train with image/audio logging
    # and Hugging Face checkpoint conversion.
    parser = argparse.ArgumentParser(description="Train VAE using ldm.")
    # Required: with no dataset name the script previously failed later with
    # an opaque TypeError from os.path.exists(None).
    parser.add_argument("-d",
                        "--dataset_name",
                        type=str,
                        default=None,
                        required=True,
                        help="local dataset directory or dataset name to load")
    parser.add_argument("-b",
                        "--batch_size",
                        type=int,
                        default=1,
                        help="training batch size")
    parser.add_argument("-c",
                        "--ldm_config_file",
                        type=str,
                        default="config/ldm_autoencoder_kl.yaml",
                        help="YAML configuration file for the ldm model")
    parser.add_argument("--ldm_checkpoint_dir",
                        type=str,
                        default="models/ldm-autoencoder-kl",
                        help="directory where ldm checkpoints are written")
    parser.add_argument("--hf_checkpoint_dir",
                        type=str,
                        default="models/autoencoder-kl",
                        help="destination for the converted Hugging Face VAE")
    parser.add_argument("-r",
                        "--resume_from_checkpoint",
                        type=str,
                        default=None,
                        help="checkpoint to resume training from")
    parser.add_argument("-g",
                        "--gradient_accumulation_steps",
                        type=int,
                        default=1,
                        help="number of batches to accumulate gradients over")
    parser.add_argument("--resolution",
                        type=int,
                        default=256,
                        help="x and y resolution used for audio logging")
    parser.add_argument("--hop_length",
                        type=int,
                        default=512,
                        help="hop length used for audio logging")
    parser.add_argument("--save_images_batches",
                        type=int,
                        default=1000,
                        help="log images and audio every this many batches")
    args = parser.parse_args()

    config = OmegaConf.load(args.ldm_config_file)
    model = instantiate_from_config(config.model)
    model.learning_rate = config.model.base_learning_rate

    data = AudioDiffusionDataModule(
        model_id=args.dataset_name,
        batch_size=args.batch_size,
        channels=config.model.params.ddconfig.in_channels)

    # Trainer options come from the optional "lightning.trainer" section of
    # the config; the command-line accumulation setting overrides it.
    lightning_config = config.pop("lightning", OmegaConf.create())
    trainer_config = lightning_config.get("trainer", OmegaConf.create())
    trainer_config.accumulate_grad_batches = args.gradient_accumulation_steps
    trainer_opt = argparse.Namespace(**trainer_config)

    image_logger = ImageLogger(every=args.save_images_batches,
                               channels=config.model.params.ddconfig.out_ch,
                               resolution=args.resolution,
                               hop_length=args.hop_length)
    checkpointer = HFModelCheckpoint(ldm_config=config,
                                     hf_checkpoint=args.hf_checkpoint_dir,
                                     dirpath=args.ldm_checkpoint_dir,
                                     filename='{epoch:06}',
                                     verbose=True,
                                     save_last=True)
    trainer = Trainer.from_argparse_args(
        trainer_opt,
        resume_from_checkpoint=args.resume_from_checkpoint,
        callbacks=[image_logger, checkpointer])
    trainer.fit(model, data)