# deploy-s2s-api / data/dataset.py
# Uploaded by 3v324v23 in commit ad48e75 ("Add application file"), 9.53 kB.
# Copyright 2023 (authors: Feiteng Li)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
modified from lhotse/dataset/speech_synthesis.py
"""
import torch
import math
import h5py
from tokenizers import Tokenizer
from typing import Union, List
import numpy as np
from tqdm import tqdm
_pad = '_'
_punctuation = ',.!?-~…'
_letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ '
symbols = [_pad] + list(_punctuation) + list(_letters)
language_dict = {
'en': 0,
'zh': 1,
'ja': 2,
}
def seq2phone(tokens: Union[List, np.ndarray]):
    """
    Map a sequence of phoneme IDs back to its phoneme string.

    Every ID is used as an index into the module-level ``symbols`` table and
    the looked-up characters are concatenated in order.

    :param tokens: phoneme IDs (list or numpy array of integers)
    :return: the reconstructed phoneme string
    """
    return "".join(symbols[token_id] for token_id in tokens)
class DynamicBatchSampler(torch.utils.data.Sampler):
    """Group sample indices into variable-size batches bounded by a token budget.

    Indices drawn from ``sampler`` are routed into ``num_buckets`` length
    buckets so that samples of similar length end up in the same batch. A
    bucket is emitted as a batch as soon as adding the next sample would make
    it exceed ``max_sentences`` samples or ``max_tokens`` padded tokens, where
    the padded token count of a batch is ``len(batch) * longest sample``.
    Whatever is still held in the buckets once ``sampler`` is exhausted is
    re-packed with the same rules, ignoring bucket boundaries.
    """

    def __init__(self, sampler, num_tokens_fn, num_buckets=100, min_size=0, max_size=1000,
                 max_tokens=None, max_sentences=None, drop_last=False):
        """
        :param sampler: iterable of dataset indices; must provide ``set_epoch``
            if :meth:`set_epoch` is called on this object
        :param num_tokens_fn: function mapping an index to the length of that sample
        :param num_buckets: number of length buckets; samples of similar length
            share a bucket and therefore a batch
        :param min_size: minimum sample length; shorter samples are skipped. Also the
            lower edge of the range that is split into buckets
        :param max_size: maximum sample length; longer samples are skipped. Also the
            upper edge of the range that is split into buckets
        :param max_tokens: upper bound on ``len(batch) * longest sample in batch``
        :param max_sentences: upper bound on the number of samples per batch. At least
            one of ``max_tokens`` / ``max_sentences`` must be given; together they
            determine the final batch size
        :param drop_last: drop the final, possibly under-filled, leftover batch
        """
        super(DynamicBatchSampler, self).__init__(sampler)
        assert max_tokens is not None or max_sentences is not None, \
            "max_tokens and max_sentences should not be null at the same time, please specify one parameter at least"
        # Only compare when max_tokens was given: ``max_size <= None`` raises TypeError,
        # which made the max_sentences-only configuration unusable.
        if max_tokens is not None:
            assert max_size <= max_tokens, "max_size should be smaller than max tokens"
        self.sampler = sampler
        self.num_tokens_fn = num_tokens_fn
        self.num_buckets = num_buckets
        self.min_size = min_size
        self.max_size = max_size
        self.max_tokens = max_tokens if max_tokens is not None else float('Inf')
        self.max_sentences = max_sentences if max_sentences is not None else float('Inf')
        self.drop_last = drop_last

    def set_epoch(self, epoch):
        """Forward the epoch number to the wrapped sampler."""
        self.sampler.set_epoch(epoch)

    def is_batch_full(self, num_tokens, batch):
        """Return True if ``batch`` must be emitted before another sample is added.

        :param num_tokens: padded token count the batch would have after adding
            the candidate sample
        :param batch: indices already in the batch
        """
        if len(batch) == 0:
            return False
        if len(batch) >= self.max_sentences:
            return True
        return num_tokens > self.max_tokens

    def __iter__(self):
        buckets = [[] for _ in range(self.num_buckets)]
        # Longest sample length currently held in each bucket.
        sample_len = [0] * self.num_buckets
        for idx in self.sampler:
            idx_length = self.num_tokens_fn(idx)
            if not (self.min_size <= idx_length <= self.max_size):
                print("sentence at index {} of size {} is outside [{}, {}], the sentence is ignored".format(
                    idx, idx_length, self.min_size, self.max_size))
                continue
            # Map [min_size, max_size] linearly onto bucket ids; the "+ 1" keeps
            # idx_length == max_size strictly below num_buckets.
            index_buckets = math.floor((idx_length - self.min_size) / (self.max_size - self.min_size + 1)
                                       * self.num_buckets)
            sample_len[index_buckets] = max(sample_len[index_buckets], idx_length)
            num_tokens = (len(buckets[index_buckets]) + 1) * sample_len[index_buckets]
            if self.is_batch_full(num_tokens, buckets[index_buckets]):
                yield buckets[index_buckets]
                buckets[index_buckets] = []
                # The current sample opens the new batch, so its length is the new
                # maximum; resetting to 0 would forget it and let batches exceed max_tokens.
                sample_len[index_buckets] = idx_length
            buckets[index_buckets].append(idx)

        # Re-pack whatever is left in the buckets (in bucket order) with the same rules.
        leftover_batch = []
        leftover_sample_len = 0
        leftover = [idx for bucket in buckets for idx in bucket]
        for idx in leftover:
            idx_length = self.num_tokens_fn(idx)
            leftover_sample_len = max(leftover_sample_len, idx_length)
            num_tokens = (len(leftover_batch) + 1) * leftover_sample_len
            if self.is_batch_full(num_tokens, leftover_batch):
                yield leftover_batch
                leftover_batch = []
                # Same reasoning as above: the current sample starts the new batch.
                leftover_sample_len = idx_length
            leftover_batch.append(idx)
        if leftover_batch and not self.drop_last:
            yield leftover_batch

    def __len__(self):
        # The number of batches depends on the sampled lengths and is not known in
        # advance. Raise the same TypeError len() produced when this returned None,
        # but with an explanatory message.
        raise TypeError("DynamicBatchSampler has no fixed length; do not call len() on it or on its DataLoader")
class AudioDataset(torch.utils.data.Dataset):
    """Dataset of pre-tokenized audio / phoneme pairs stored in one HDF5 file.

    Each annotation line is ``|``-separated; the first four fields are read as
    ``h5 key | duration | language | text`` and the last field is discarded.
    Each item is looked up by its key in the HDF5 archive, whose entries are
    read through their ``audio`` and ``text`` datasets.
    """

    def __init__(self, h5_path, ann_path, tokenizer_path):
        """
        :param h5_path: path of the HDF5 archive; opened lazily on first access
        :param ann_path: path of the UTF-8 annotation file
        :param tokenizer_path: path of a ``tokenizers`` JSON file used to encode
            the phoneme string (or raw text) of each item
        """
        self.h5_path = h5_path
        with open(ann_path, 'r', encoding='utf-8') as f:
            self.h5_paths, self.durations, self.langs, self.texts = self._parse_annotations(f)
        self.tokenizer = Tokenizer.from_file(tokenizer_path)
        # h5py file handle, created on first use by the ``archive`` property.
        self._archive = None

    @staticmethod
    def _parse_annotations(lines):
        """Parse ``|``-separated annotation lines into four parallel lists.

        Blank lines are skipped: a single blank line would otherwise truncate
        every column to one entry through ``zip`` and make the parse fail.

        :param lines: iterable of annotation lines
        :return: ``(h5_paths, durations, langs, texts)``; durations are floats
        """
        rows = [line.split("|") for line in lines if line.strip()]
        if not rows:
            return [], [], [], []
        # Transpose rows into columns (zip truncates to the shortest row).
        columns = list(zip(*rows))
        # Drop the last field, which also carries the line's trailing newline.
        del columns[-1]
        return (list(columns[0]), [float(dur) for dur in columns[1]],
                list(columns[2]), list(columns[3]))

    def __len__(self):
        return len(self.h5_paths)

    def get_dur(self, idx):
        """Return the annotated duration of item ``idx`` (used as its batching length)."""
        return self.durations[idx]

    @property
    def archive(self):
        """The HDF5 file, opened read-only on first access.

        NOTE(review): the lazy open is presumably so each DataLoader worker
        opens its own handle after fork — confirm.
        """
        if self._archive is None:
            self._archive = h5py.File(self.h5_path, "r")
        return self._archive

    def __getitem__(self, idx):
        """Load one utterance and BPE-encode its phonemes.

        :param idx: integer index into the annotation lists
        :return: dict with the utterance key, raw text, the stored audio token
            array and its length, the BPE token ids of the phoneme string (or of
            the raw text when no phonemes are stored) and the language id
        """
        archive = self.archive
        h5_path = self.h5_paths[idx]
        sub = archive[h5_path]
        audio_tokens = sub['audio'][()]
        phone_tokens = sub['text'][()]
        lang = self.langs[idx]
        text = self.texts[idx]
        # Turn stored phoneme IDs back into a string (spaces written as "_")
        # so the BPE tokenizer can re-encode it.
        # tokenization should be done within dataloader
        phones = seq2phone(phone_tokens).replace(" ", "_")
        # Fall back to the raw text when the stored phoneme sequence is empty.
        cptpho_tokens = self.tokenizer.encode(phones if phones else text).ids
        assert len(cptpho_tokens)
        return {
            'utt_id': h5_path,
            'text': text,
            'audio': None,
            'audio_lens': None,
            'audio_features': audio_tokens,
            # Size of the last axis for a 2-D array
            # (assumes a (codebooks, frames) layout — TODO confirm).
            'audio_features_lens': len(audio_tokens.T),
            'text_tokens': np.array(cptpho_tokens),
            'text_tokens_lens': len(cptpho_tokens),
            'language': language_dict[lang],
        }
def collate(batch, *, audio_pad_id=-1, text_pad_id=3):
    """Collate dataset items into padded int64 tensors.

    Each item's audio token array is transposed so frames come first and
    written into a tensor right-padded with ``audio_pad_id``; text token
    sequences are right-padded with ``text_pad_id``.

    :param batch: list of item dicts with the keys produced by
        ``AudioDataset.__getitem__``. ``audio_features`` is assumed to be 2-D
        (codebooks, frames) — TODO confirm
    :param audio_pad_id: fill value for padded audio positions (default -1)
    :param text_pad_id: fill value for padded text positions (default 3, the
        ``[PAD]`` token id used by this pipeline)
    :return: dict; ``audio_features`` has shape (batch, max_frames, codebooks),
        ``text_tokens`` has shape (batch, max_text_len), and ``*_lens`` and
        ``languages`` are 1-D int64 tensors. ``utt_id``, ``text``, ``audio`` and
        ``audio_lens`` are plain lists.
    """
    audio_features_lens_s = [b['audio_features_lens'] for b in batch]
    text_tokens_lens_s = [b['text_tokens_lens'] for b in batch]
    # Take the codebook count from the data (axis 0 before transposition)
    # instead of hard-coding 8.
    num_codebooks = np.shape(batch[0]['audio_features'])[0]

    audio_features_s = torch.full(
        (len(batch), max(audio_features_lens_s), num_codebooks), audio_pad_id, dtype=torch.int64)
    text_tokens_s = torch.full(
        (len(batch), max(text_tokens_lens_s)), text_pad_id, dtype=torch.int64)

    for i, b in enumerate(batch):
        # np.asarray(..., dtype=int64) also accepts dtypes torch.from_numpy cannot
        # handle directly (e.g. uint16), and the transpose is a strided view.
        audio = np.asarray(b['audio_features'], dtype=np.int64).T
        audio_features_s[i, :b['audio_features_lens'], :] = torch.from_numpy(audio)
        text = np.asarray(b['text_tokens'], dtype=np.int64)
        text_tokens_s[i, :b['text_tokens_lens']] = torch.from_numpy(text)

    collated = {
        'utt_id': [b['utt_id'] for b in batch],
        'text': [b['text'] for b in batch],
        'audio': [b['audio'] for b in batch],
        'audio_lens': [b['audio_lens'] for b in batch],
        'audio_features': audio_features_s,
        'audio_features_lens': torch.tensor(audio_features_lens_s, dtype=torch.int64),
        'text_tokens': text_tokens_s,
        'text_tokens_lens': torch.tensor(text_tokens_lens_s, dtype=torch.int64),
        'languages': torch.tensor([b['language'] for b in batch], dtype=torch.int64),
    }
    return collated
def create_dataloader(data_dir="/root/valle/egs/mix", n_gpus=1, rank=0, num_workers=0, num_buckets=10, max_duration=120):
    """Build the training DataLoader for the dataset stored in ``data_dir``.

    ``data_dir`` must contain ``audio_sum.hdf5``, ``audio_ann_sum.txt`` and
    ``bpe_69.json``. Indices are sharded over ``n_gpus`` replicas by a shuffling
    ``DistributedSampler`` and grouped into duration-bucketed dynamic batches.
    ``len(batch) * longest duration`` stays within ``max_duration``, and items
    with a duration above 20 (in the annotation file's units) are skipped.
    ``max_duration`` must therefore be at least 20, because DynamicBatchSampler
    asserts ``max_size <= max_tokens``.

    :return: a ``torch.utils.data.DataLoader`` whose batches are built by ``collate``
    """
    dataset = AudioDataset(
        h5_path=f"{data_dir}/audio_sum.hdf5",
        ann_path=f"{data_dir}/audio_ann_sum.txt",
        tokenizer_path=f"{data_dir}/bpe_69.json",
    )
    shard_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=n_gpus, rank=rank, shuffle=True,
    )
    batch_sampler = DynamicBatchSampler(
        shard_sampler,
        dataset.get_dur,
        num_buckets=num_buckets,
        max_size=20,
        max_tokens=max_duration,
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        collate_fn=collate,
        num_workers=num_workers,
    )