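# Phone-aid / Phonemize.py
# Author: Mostafa Shahin
# Commit message: First Commit
# Commit: 4c01711
# (Residue from the web file viewer this was copied from: "raw",
#  "history blame", "No virus", "3.97 kB".)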
from datasets import load_from_disk
from dp.phonemizer import Phonemizer
from speechbrain.pretrained import GraphemeToPhoneme
import cmudict
import re
import fire
import torch
from os.path import join
# With several GPUs, worker processes must be spawned rather than forked,
# because CUDA cannot be re-initialised in a forked child.
# Spawned children re-import this module. Under stdlib multiprocessing their
# start method is already set at that point, so an unconditional
# set_start_method('spawn') raises RuntimeError("context has already been
# set"). Only (re)set the method when it is not already 'spawn'.
# NOTE(review): datasets' `map(num_proc=...)` may use its own multiprocessing
# backend; confirm this setting actually reaches those workers.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    if torch.multiprocessing.get_start_method(allow_none=True) != 'spawn':
        torch.multiprocessing.set_start_method('spawn', force=True)
class phonemization:
    """Add phoneme-transcription columns to a dataset saved with ``datasets``.

    Three back ends are available, selected by name in :meth:`run`:

    * ``dp``  -- a ``dp.phonemizer.Phonemizer`` checkpoint loaded from
      ``models/d_phonemizer/en_us_cmudict_forward.pt``;
    * ``sb``  -- a SpeechBrain ``GraphemeToPhoneme`` model loaded from
      ``models/sb_phonemizer`` (on CUDA when available);
    * ``cmu`` -- CMU dictionary lookup with stress digits removed, falling
      back to another phonemizer for words not in the dictionary.

    Every phonemizer callable used here takes one text argument and returns
    an iterable of phoneme strings.
    """

    # Punctuation replaced by a space during normalization. The hyphen must
    # be escaped: unescaped, "!-;" is a character *range* (U+0021..U+003B)
    # that also matches digits, apostrophes, quotes, parentheses, etc.
    CHARS_TO_IGNORE_REGEX = r'[,?.!\-;:"]'

    def __init__(self):
        """Load all three phonemizer back ends.

        Model paths are relative to the current working directory.
        """
        self.chars_to_ignore_regex = self.CHARS_TO_IGNORE_REGEX
        self.dp_phonemizer_model_path = join('models', 'd_phonemizer', 'en_us_cmudict_forward.pt')
        self.sb_phonemizer_model_path = join('models', 'sb_phonemizer')
        self.cmu_dict = cmudict.dict()
        self.dp_phonemizer = Phonemizer.from_checkpoint(self.dp_phonemizer_model_path)
        if torch.cuda.is_available():
            self.sb_phonemizer = GraphemeToPhoneme.from_hparams(
                self.sb_phonemizer_model_path, run_opts={"device": "cuda"})
        else:
            self.sb_phonemizer = GraphemeToPhoneme.from_hparams(self.sb_phonemizer_model_path)
        # When True, phonemize_batch reads 'text_norm' instead of 'text'.
        # run() sets this from its ``normalize`` argument.
        self.normalize = False

    def dp_phonemize(self, text):
        """Phonemize ``text`` with the ``dp`` model.

        ``[`` and ``]`` in the model output are replaced by spaces and the
        result is split on whitespace.

        Returns:
            List of phoneme strings.
        """
        output = self.dp_phonemizer(text, lang='en_us', expand_acronyms=False)
        return output.replace('[', ' ').replace(']', ' ').split()

    def cmu_phonemize(self, text, fallback_phonemizer=None):
        """Phonemize ``text`` word by word using the CMU dictionary.

        Args:
            text: Whitespace-separated words. Each word is looked up verbatim
                (case and punctuation included).
            fallback_phonemizer: Callable ``f(word) -> list[str]`` used for
                words that are not in the dictionary. Defaults to this
                instance's :meth:`dp_phonemize`.

        Returns:
            Lower-cased phonemes of all words, concatenated in order. For
            dictionary words the first listed pronunciation is used, with
            stress digits removed (``AH0`` -> ``ah``).
        """
        if fallback_phonemizer is None:
            fallback_phonemizer = self.dp_phonemize
        phonemes = []
        for word in text.split():
            pronunciations = self.cmu_dict.get(word)
            if pronunciations:
                first = ' '.join(pronunciations[0])
                phonemes.extend(re.sub('[0-9]', '', first).split())
            else:
                phonemes.extend(fallback_phonemizer(word))
        return [p.lower() for p in phonemes]

    def sb_phonemize(self, text):
        """Return the SpeechBrain model's output for ``text``.

        The output is presumably a list of phoneme strings; phonemize_batch
        joins it with spaces.
        """
        return self.sb_phonemizer(text)

    def remove_special_characters(self, text):
        """Replace ignored punctuation with spaces, lower-case, add one trailing space.

        NOTE(review): the purpose of the trailing space is unclear. It is
        kept for compatibility, and phonemize_batch strips it from its output.
        """
        return re.sub(self.chars_to_ignore_regex, ' ', text).lower() + " "

    def replace_multiple_spaces_with_single_space(self, input_string):
        """Collapse every run of whitespace in ``input_string`` to a single space."""
        return re.sub(r'\s+', ' ', input_string)

    def phonemize_batch(self, batch, phonamizer_fn=None, suffix=''):
        """Add a ``phoneme{suffix}`` field to one example and return it.

        ``batch`` is a single example dict, because run() calls ``map``
        without ``batched=True``. The source text is ``batch['text_norm']``
        when ``self.normalize`` is set, otherwise ``batch['text']``. It is
        lower-cased before phonemization.

        Args:
            batch: Example dict. It is mutated in place and also returned.
            phonamizer_fn: Callable ``f(text) -> iterable of str``. Defaults
                to this instance's :meth:`dp_phonemize`.
            suffix: Appended to the output column name ``phoneme``.
        """
        if phonamizer_fn is None:
            phonamizer_fn = self.dp_phonemize
        text = batch['text_norm'] if self.normalize else batch['text']
        phoneme_str = ' '.join(phonamizer_fn(text.lower())).lower()
        phoneme_str = self.replace_multiple_spaces_with_single_space(phoneme_str)
        batch[f'phoneme{suffix}'] = phoneme_str.strip()
        return batch

    def remove_special_characters_batch(self, batch):
        """Add ``text_norm`` = remove_special_characters(``batch['text']``) to one example."""
        batch["text_norm"] = self.remove_special_characters(batch["text"])
        return batch

    def run(self,
            dataset_path,
            output_path,
            phonemizers='dp,sb,cmu',
            normalize=True,
            nproc=1,
            sb_cache_file='/g/data/iv96/mostafa/cache_sb'):
        """Phonemize a saved dataset and save the result.

        Args:
            dataset_path: Directory readable by ``datasets.load_from_disk``.
            output_path: Directory the augmented dataset is saved to.
            phonemizers: Comma-separated string (or sequence) of back-end
                names among ``dp``, ``sb``, ``cmu``. They run in the given
                order, and each adds a ``phoneme_<name>`` column.
            normalize: If True, first add a ``text_norm`` column and
                phonemize it instead of the raw ``text`` column.
            nproc: ``num_proc`` for every ``map`` call. The exception is
                ``sb`` with CUDA available, which uses one process per GPU.
            sb_cache_file: Explicit cache file for the ``sb`` pass, or a falsy
                value to let ``datasets`` choose a fingerprint-based one.
                NOTE: an existing file at this path is reloaded as-is, even
                if it was produced from a different input dataset.

        Raises:
            ValueError: If an unknown phonemizer name is requested.
        """
        # Fire may deliver a comma-separated CLI value as a tuple.
        if isinstance(phonemizers, str):
            requested = phonemizers.split(',')
        else:
            requested = list(phonemizers)
        requested = [str(name).strip() for name in requested if str(name).strip()]

        backends = {'dp': self.dp_phonemize,
                    'sb': self.sb_phonemize,
                    'cmu': self.cmu_phonemize}
        unknown = [name for name in requested if name not in backends]
        if unknown:
            raise ValueError(
                f"Unknown phonemizer(s) {unknown}; expected names from {sorted(backends)}")

        # Must be set before map() so worker processes (which receive a copy
        # of self) also phonemize 'text_norm' when it exists.
        self.normalize = normalize
        data = load_from_disk(dataset_path)
        if normalize:
            data = data.map(self.remove_special_characters_batch, num_proc=nproc)

        for name in requested:
            map_kwargs = {'fn_kwargs': {'phonamizer_fn': backends[name], 'suffix': f'_{name}'},
                          'num_proc': nproc}
            if name == 'sb':
                # Local override only: later phonemizers keep the caller's nproc.
                # NOTE(review): every worker appears to use the default "cuda"
                # device; there is no per-process GPU selection here.
                if torch.cuda.is_available():
                    map_kwargs['num_proc'] = torch.cuda.device_count()
                if sb_cache_file:
                    map_kwargs['cache_file_name'] = sb_cache_file
                    map_kwargs['load_from_cache_file'] = True
            data = data.map(self.phonemize_batch, **map_kwargs)
        data.save_to_disk(output_path)
if __name__ == '__main__':
    # Command-line entry point: Fire builds a `phonemization` instance, then
    # dispatches to the requested method, e.g.
    #   python Phonemize.py run --dataset_path=IN --output_path=OUT
    fire.Fire(component=phonemization)