Spaces:
Runtime error
Runtime error
"""Data module. | |
Copyright PolyAI Limited. | |
""" | |
import typing | |
from pathlib import Path | |
from typing import List | |
import lightning.pytorch as pl | |
from torch.utils import data | |
from data.collation import GlobalCollater | |
from data.sampler import RandomBucketSampler | |
from data.single_speaker_dataset import QuantizeDataset | |
from utils import breakpoint_on_error | |
class ConcatDataset(data.ConcatDataset):
    """Concatenated dataset that also exposes a flat ``lengths`` list.

    ``lengths`` is built by chaining the ``lengths`` attribute of every
    member dataset, in concatenation order.
    """

    def __init__(self, datasets: typing.Iterable[data.Dataset]) -> None:
        """Concatenate ``datasets`` and gather their ``lengths``.

        Args:
            datasets: Iterable of datasets; each one must expose a
                ``lengths`` attribute that can be passed to ``list.extend``.
        """
        super().__init__(datasets)
        self.lengths = []
        # Walk ``self.datasets`` (the list materialized by the parent class)
        # instead of the raw argument: a generator or iterator argument is
        # already exhausted at this point and would yield no lengths.
        for dataset in self.datasets:
            self.lengths.extend(dataset.lengths)
class DataModule(pl.LightningDataModule):
    """Lightning data module that builds training and validation loaders.

    Training manifests are merged into a single ``ConcatDataset``;
    validation manifests are kept as one dataset per manifest, so each one
    gets its own dataloader.

    Args:
        hp: Hyper-parameter object. This class reads ``n_codes``,
            ``n_semantic_codes``, ``train_bucket_size``, ``batch_size``,
            ``distributed`` and ``nworkers`` from it and also forwards it
            to every dataset it constructs.
        metapath: Manifest paths used for training.
        val_metapath: Manifest paths used for validation.
        world_size: Forwarded to the training sampler as ``world_size``.
        local_rank: Forwarded to the training sampler as ``rank``.
    """

    def __init__(
        self, hp, metapath: List[str], val_metapath: List[str],
        world_size, local_rank
    ):
        super().__init__()
        self.hp = hp
        self.metapath = metapath
        self.val_metapath = val_metapath
        self.world_size = world_size
        self.local_rank = local_rank
        self.collater = GlobalCollater(
            self.hp.n_codes, self.hp.n_semantic_codes)

    def setup(self, stage: str) -> None:
        """Build the datasets required for ``stage``.

        Args:
            stage: ``"fit"`` builds ``self.train_data``. ``"valid"`` (the
                name historically used with this class) or ``"validate"``
                (the name Lightning passes for ``Trainer.validate``)
                rebuilds ``self.val_data`` and ``self.val_data_keys``.
                Any other value is a no-op.

        Raises:
            AssertionError: If a validation stage produces no datasets.
        """
        if stage == "fit":
            self.train_data = self.concatenate_datasets(
                self.metapath, dataset_class=QuantizeDataset
            )

        if stage in ("valid", "validate"):
            # Start from empty lists so repeated calls do not accumulate.
            self.val_data = []
            self.val_data_keys = []
            self.prepare_val_datasets()
            assert len(self.val_data) > 0, \
                "no validation datasets were built from val_metapath"
            assert len(self.val_data_keys) > 0, \
                "no validation dataset keys were collected"

    def concatenate_datasets(
            self, metapaths, dataset_class: typing.Type[QuantizeDataset]):
        """Create one dataset per manifest and concatenate them.

        Args:
            metapaths: Iterable of manifest paths.
            dataset_class: Dataset type instantiated for every manifest.

        Returns:
            A ``ConcatDataset`` over the per-manifest datasets.

        Raises:
            AssertionError: If the ``audios`` directory expected next to a
                manifest does not exist.
        """
        datasets = []
        for metapath in metapaths:
            metapath = Path(metapath)
            # assumption that audios and audios-embeddings
            # are in the same folder as metapath
            datadir = metapath.with_name("audios")
            assert datadir.exists(), \
                f"audio directory not found: {datadir}"
            datasets.append(
                dataset_class(
                    self.hp,
                    metapath,
                    datadir=datadir,
                    speaker_embedding_dir=None,
                )
            )
        return ConcatDataset(datasets)

    def prepare_val_datasets(self):
        """Append one validation dataset and one key per validation manifest.

        The key is the name of the directory that contains the manifest.
        ``self.val_data`` and ``self.val_data_keys`` must already exist;
        ``setup`` creates them before calling this method.
        """
        for manifest in self.val_metapath:
            self.val_data.append(
                self.concatenate_datasets(
                    [manifest], dataset_class=QuantizeDataset)
            )
            self.val_data_keys.append(Path(manifest).parent.name)

        assert len(self.val_data) == len(self.val_data_keys), \
            "validation datasets and keys are out of sync"

    def train_dataloader(self):
        """Return the training dataloader driven by a bucket batch sampler.

        Batches come from ``RandomBucketSampler``, which receives the
        concatenated ``lengths`` of the training data together with the
        bucket size, batch size and distributed settings.
        """
        sampler = RandomBucketSampler(
            self.hp.train_bucket_size,
            self.train_data.lengths,
            self.hp.batch_size,
            drop_last=True,
            distributed=self.hp.distributed,
            world_size=self.world_size,
            rank=self.local_rank,
        )
        return data.DataLoader(
            self.train_data,
            num_workers=self.hp.nworkers,
            batch_sampler=sampler,
            collate_fn=self.collater.collate,
            pin_memory=True
        )

    def val_dataloader(self):
        """Return one unshuffled dataloader per validation dataset.

        The loaders are in the same order as ``self.val_data_keys``.
        """
        return [
            data.DataLoader(
                dataset,
                num_workers=self.hp.nworkers,
                batch_size=int(self.hp.batch_size),
                collate_fn=self.collater.collate,
                shuffle=False,
                pin_memory=True
            )
            for dataset in self.val_data
        ]