import fire
import logging
import sys, os
import yaml
import json
import torch
import librosa
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2Processor, Wav2Vec2ForCTC
import transformers
import pandas as pd

logger = logging.getLogger(__name__)

# Setup logging
logger.setLevel(logging.ERROR)
console_handler = logging.StreamHandler()
formatter = logging.Formatter(
    fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
)
console_handler.setFormatter(formatter)
console_handler.setLevel(logging.ERROR)
logger.addHandler(console_handler)


class transcribe_SA:

    def __init__(self, model_path, verbose=0):
        if verbose == 0:
            logger.setLevel(logging.ERROR)
            transformers.logging.set_verbosity_error()
        elif verbose == 1:
            logger.setLevel(logging.WARNING)
            transformers.logging.set_verbosity_warning()
        else:
            logger.setLevel(logging.INFO)
            transformers.logging.set_verbosity_info()

        logger.info('Init Object')

        if torch.cuda.is_available():
            self.accelerate = True
            self.device = torch.device('cuda')
            self.n_devices = torch.cuda.device_count()
            # Currently supports only a single GPU
            assert self.n_devices == 1, 'Only a single GPU is supported. Please use CUDA_VISIBLE_DEVICES=gpu_index if you have multiple GPUs'
        else:
            self.device = torch.device('cpu')
            self.n_devices = 1

        self.model_path = model_path
        self.load_model()
        self.get_available_attributes()
        self.get_att_binary_group_indexs()

    def load_model(self):
        if not os.path.exists(self.model_path):
            logger.error(f'Model file {self.model_path} does not exist')
            raise FileNotFoundError(self.model_path)
        self.processor = Wav2Vec2Processor.from_pretrained(self.model_path)
        self.model = Wav2Vec2ForCTC.from_pretrained(self.model_path)
        self.model.to(self.device)
        self.pad_token_id = self.processor.tokenizer.pad_token_id
        self.sampling_rate = self.processor.feature_extractor.sampling_rate

    def get_available_attributes(self):
        if not hasattr(self, 'model'):
            logger.error('Model not loaded, call load_model first!')
            raise AttributeError("model not defined")
        att_list = set(self.processor.tokenizer.get_vocab().keys()) - set(self.processor.tokenizer.all_special_tokens)
        att_list = [p.replace('p_', '') for p in att_list if p[0] == 'p']
        self.att_list = att_list

    def print_available_attributes(self):
        print(self.att_list)

    def get_att_binary_group_indexs(self):
        # Each group contains the token ids of [pad, n_att, p_att] sorted by their token ids
        self.group_ids = []
        for i, att in enumerate(self.att_list):
            n_indx = self.processor.tokenizer.convert_tokens_to_ids(f'n_{att}')
            p_indx = self.processor.tokenizer.convert_tokens_to_ids(f'p_{att}')
            self.group_ids.append(sorted([self.pad_token_id, n_indx, p_indx]))

    def decode_att(self, logits, att):
        # Attribute names need to be lowercased when first read from the user
        mask = torch.zeros(logits.size()[2], dtype=torch.bool)
        try:
            i = self.att_list.index(att)
        except ValueError:
            logger.error(f'The given attribute {att} is not supported by the given model {self.model_path}')
            raise
        mask[self.group_ids[i]] = True
        logits_g = logits[:, :, mask]
        pred_ids = torch.argmax(logits_g, dim=-1)
        # Map the masked indices back to the original token ids before decoding
        pred_ids = pred_ids.cpu().apply_(lambda x: self.group_ids[i][x])
        pred = self.processor.batch_decode(pred_ids, spaces_between_special_tokens=True)[0].split()
        return list(map(lambda x: {f'p_{att}': '+', f'n_{att}': '-'}[x], pred))
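    # Illustrative sketch of the lower-level flow (file names and the 'voiced' attribute
    # below are placeholders, not part of this repo):
    #
    #   sa = transcribe_SA(model_path='/path/to/model')
    #   y = sa.read_audio_file('speech.wav')     # resampled to the model's sampling rate
    #   logits = sa.get_logits(y)                # shape (1, frames, vocab_size)
    #   sa.decode_att(logits, 'voiced')          # e.g. ['+', '-', '+', ...]
    #
    # The p_/n_ token pairs actually present in the model's vocabulary determine which
    # attribute names are valid.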
    def read_audio_file(self, audio_file):
        if not os.path.exists(audio_file):
            logger.error(f'Audio file {audio_file} does not exist')
            raise FileNotFoundError(audio_file)
        y, _ = librosa.load(audio_file, sr=self.sampling_rate)
        return y

    def get_logits(self, y):
        input_values = self.processor(audio=y, sampling_rate=self.sampling_rate, return_tensors="pt").input_values
        input_values = input_values.to(self.device)
        with torch.no_grad():
            logits = self.model(input_values).logits
        # Keep logits on the CPU; all decoding below runs on the CPU
        return logits.cpu()

    def check_identical_phonemes(self, df_p2att):
        identical_phonemes = []
        for index, row in df_p2att.iterrows():
            mask = df_p2att.eq(row).all(axis=1)
            indexes = df_p2att[mask].index.values
            if len(indexes) > 1:
                identical_phonemes.append(tuple(indexes))
        if identical_phonemes:
            logger.warning('The following phonemes have identical phonological features given the features used in the model. If using a fixed weight layer, these phonemes will be confused with each other')
            identical_phonemes = set(identical_phonemes)
            for x in identical_phonemes:
                logger.warning(f"{','.join(x)}")

    def read_phoneme2att(self, p2att_file):
        if not os.path.exists(p2att_file):
            logger.error(f'Phonological matrix file {p2att_file} does not exist')
            raise FileNotFoundError(f'{p2att_file}')
        df_p2att = pd.read_csv(p2att_file, index_col=0)
        self.check_identical_phonemes(df_p2att)
        not_supported = set(df_p2att.columns) - set(self.att_list)
        if not_supported:
            logger.warning(f"Attribute/s {','.join(not_supported)} not supported by the model {self.model_path} and will be ignored. To get the available attributes of the selected model run: transcribe --model_path=/path/to/model print_available_attributes")
            df_p2att = df_p2att.drop(columns=not_supported)
        self.phoneme_list = df_p2att.index.values
        self.p2att_map = {}
        for phoneme, r in df_p2att.iterrows():
            self.p2att_map[phoneme] = []
            for att in r.index.values:
                if f'p_{att}' not in self.processor.tokenizer.get_vocab():
                    logger.warning(f'Attribute {att} is not supported by the model {self.model_path} and will be ignored. To get the available attributes of the selected model run: transcribe --model_path=/path/to/model print_available_attributes')
                    continue
                value = r[att]
                if value == 0:
                    self.p2att_map[phoneme].append(f'n_{att}')
                elif value == 1:
                    self.p2att_map[phoneme].append(f'p_{att}')
                else:
                    logger.error(f'Invalid value of {value} for attribute {att} of phoneme {phoneme}. Values in the phoneme-to-attribute map should be either 0 or 1')
                    raise ValueError(f'{value} should be 0 or 1')
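    # Illustrative layout of the phoneme-to-attribute CSV read above (phoneme symbols and
    # attribute names are placeholders; columns must match attributes supported by the model):
    #
    #     ,voiced,nasal,continuant
    #    b,1,0,0
    #    m,1,1,0
    #    s,0,0,1
    #
    # Rows are phonemes (the index column), columns are attributes, and every cell must be
    # 0 or 1, as enforced in read_phoneme2att().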
    def create_phoneme_tokenizer(self):
        vocab_list = self.phoneme_list
        vocab_dict = {v: k + 1 for k, v in enumerate(vocab_list)}
        vocab_dict['<pad>'] = 0
        vocab_dict = dict(sorted(vocab_dict.items(), key=lambda x: x[1]))
        vocab_file = 'phoneme_vocab.json'
        with open(vocab_file, 'w') as f:
            json.dump(vocab_dict, f)
        # Build a CTC tokenizer over the phoneme vocabulary
        self.phoneme_tokenizer = Wav2Vec2CTCTokenizer(vocab_file, pad_token="<pad>", word_delimiter_token="")

    def create_phonological_matrix(self):
        self.phonological_matrix = torch.zeros((self.phoneme_tokenizer.vocab_size, self.processor.tokenizer.vocab_size)).type(torch.FloatTensor)
        self.phonological_matrix[self.phoneme_tokenizer.pad_token_id, self.processor.tokenizer.pad_token_id] = 1
        for p in self.phoneme_list:
            for att in self.p2att_map[p]:
                self.phonological_matrix[self.phoneme_tokenizer.convert_tokens_to_ids(p), self.processor.tokenizer.convert_tokens_to_ids(att)] = 1
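    # Shape sketch (illustrative): with a phoneme vocabulary of size P (phonemes + pad) and
    # an attribute-token vocabulary of size V, self.phonological_matrix is a (P, V) 0/1
    # matrix. decode_phoneme() below multiplies it with per-frame attribute log-probabilities
    # of shape (batch, frames, V) to obtain phoneme scores of shape (batch, frames, P),
    # which are then decoded with the CTC phoneme tokenizer.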
    # This function gets the attribute logits from the output layer and converts them to phonemes.
    # The input is a sequence of logits (one vector per frame) and the output is a phoneme sequence.
    # Note that this is CTC, so the number of output phonemes is not equal to the number of input frames.
    def decode_phoneme(self, logits):

        def masked_log_softmax(vector: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
            if mask is not None:
                mask = mask.float()
                while mask.dim() < vector.dim():
                    mask = mask.unsqueeze(1)
                # vector + mask.log() is an easy way to zero out masked elements in logspace, but it
                # results in nans when the whole vector is masked. We need a very small value instead of a
                # zero in the mask for these cases. log(1 + 1e-45) is still basically 0, so we can safely
                # just add 1e-45 before calling mask.log(). We use 1e-45 because 1e-46 is so small it
                # becomes 0 - this is just the smallest value we can actually use.
                vector = vector + (mask + 1e-45).log()
            return torch.nn.functional.log_softmax(vector, dim=dim)

        log_probs_all_masked = []
        for i in range(len(self.att_list)):
            mask = torch.zeros(logits.size()[2], dtype=torch.bool)
            mask[self.group_ids[i]] = True
            mask.unsqueeze_(0).unsqueeze_(0)
            log_probs = masked_log_softmax(vector=logits, mask=mask, dim=-1).masked_fill(~mask, 0)
            log_probs_all_masked.append(log_probs)
        log_probs_cat = torch.stack(log_probs_all_masked, dim=0).sum(dim=0)
        log_probs_phoneme = torch.matmul(self.phonological_matrix, log_probs_cat.transpose(1, 2)).transpose(1, 2).type(torch.FloatTensor)
        pred_ids = torch.argmax(log_probs_phoneme, dim=-1)
        pred = self.phoneme_tokenizer.batch_decode(pred_ids, spaces_between_special_tokens=True)[0]
        return pred

    def print_human_readable(self, output, with_phoneme=False):
        column_widths = []
        rows = []
        if with_phoneme:
            column_widths.append(max([len(att['Name']) for att in output['Attributes']] + [len('Phoneme')]))
            column_widths.extend([5] * max([len(att['Pattern']) for att in output['Attributes']] + [len(output['Phoneme']['symbols'])]))
            rows.append(('Phoneme'.center(column_widths[0]), *[s.center(column_widths[j + 1]) for j, s in enumerate(output['Phoneme']['symbols'])]))
        else:
            column_widths.append(max([len(att['Name']) for att in output['Attributes']]))
            column_widths.extend([5] * max([len(att['Pattern']) for att in output['Attributes']]))
        for i in range(len(output['Attributes'])):
            att = output['Attributes'][i]
            rows.append((att['Name'].center(column_widths[0]), *[s.center(column_widths[j + 1]) for j, s in enumerate(att['Pattern'])]))
        out_string = ''
        for row in rows:
            out_string += '|'.join(row)
            out_string += '\n'
        return out_string

    def transcribe(self, audio_file, attributes='all', phonological_matrix_file=None, human_readable=True):
        output = {}
        output['wav_file_path'] = audio_file
        output['Attributes'] = []
        output['Phoneme'] = {}

        if attributes == 'all':
            target_attributes = self.att_list
        else:
            attributes = attributes if isinstance(attributes, tuple) else (attributes,)
            target_attributes = [att.lower() for att in attributes if att.lower() in self.att_list]
        if not target_attributes:
            logger.error(f'None of the given attributes is supported by the model {self.model_path}. To get the available attributes of the selected model run: transcribe --model_path=/path/to/model print_available_attributes')
            raise ValueError("Invalid attributes")

        # Process audio
        y = self.read_audio_file(audio_file)
        self.logits = self.get_logits(y)

        for att in target_attributes:
            output['Attributes'].append({'Name': att, 'Pattern': self.decode_att(self.logits, att)})

        if phonological_matrix_file:
            self.read_phoneme2att(phonological_matrix_file)
            self.create_phoneme_tokenizer()
            self.create_phonological_matrix()
            output['Phoneme']['symbols'] = self.decode_phoneme(self.logits).split()

        json_string = json.dumps(output, indent=4)
        if human_readable:
            return self.print_human_readable(output, phonological_matrix_file is not None)
        else:
            return json_string


def main():
    fire.Fire(transcribe_SA)


if __name__ == '__main__':
    main()
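# Example invocations through the Fire CLI (illustrative; the script name, model path and
# attribute name below are placeholders):
#
#   # List the attributes the selected model can transcribe
#   python transcribe_sa.py --model_path=/path/to/model print_available_attributes
#
#   # Per-attribute +/- patterns for a single recording
#   python transcribe_sa.py --model_path=/path/to/model transcribe --audio_file=speech.wav --attributes=voiced
#
#   # Also produce a phoneme tier by supplying a phoneme-to-attribute CSV
#   python transcribe_sa.py --model_path=/path/to/model transcribe --audio_file=speech.wav --phonological_matrix_file=p2att.csv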