File size: 13,160 Bytes
4c01711
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
import fire
import logging
import sys, os
import yaml
import json
import torch
import librosa
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2Processor, Wav2Vec2ForCTC
import transformers
import pandas as pd

# Module logging: one stderr handler with a timestamped format.  Both the
# logger and the handler start at ERROR; transcribe_SA.__init__ adjusts the
# logger level according to its ``verbose`` argument.
formater = logging.Formatter(
    fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
)

console_handler = logging.StreamHandler()
console_handler.setLevel(logging.ERROR)
console_handler.setFormatter(formater)

logger = logging.getLogger(__name__)
logger.setLevel(logging.ERROR)
logger.addHandler(console_handler)


class transcribe_SA():
    """Transcribe speech into phonological-attribute patterns and, optionally, phonemes.

    Wraps a Wav2Vec2 CTC model whose vocabulary holds, besides special tokens,
    a pair of tokens per attribute: ``n_<att>`` (decoded as ``-``) and
    ``p_<att>`` (decoded as ``+``). Each attribute is decoded independently
    from its own ``[<pad>, n_<att>, p_<att>]`` slice of the output logits.
    """

    def __init__(self, model_path, verbose=0):
        """Load the model at ``model_path`` and build the attribute lookup tables.

        Args:
            model_path: Local path accepted by ``from_pretrained``; must exist.
            verbose: 0 -> errors only, 1 -> warnings, anything else -> info.
                Applies to this module's logger (and its handlers) and to
                ``transformers`` logging.

        Raises:
            AssertionError: if more than one CUDA device is visible.
            FileNotFoundError: if ``model_path`` does not exist.
        """
        if verbose == 0:
            level = logging.ERROR
            transformers.logging.set_verbosity_error()
        elif verbose == 1:
            level = logging.WARNING
            transformers.logging.set_verbosity_warning()
        else:
            level = logging.INFO
            transformers.logging.set_verbosity_info()
        logger.setLevel(level)
        # The module-level handler is created at ERROR; unless it is lowered as
        # well, warning/info records are filtered out whatever ``verbose`` says.
        for handler in logger.handlers:
            handler.setLevel(level)

        logger.info('Init Object')
        if torch.cuda.is_available():
            self.accelerate = True
            self.device = torch.device('cuda')
            self.n_devices = torch.cuda.device_count()
            assert self.n_devices == 1, 'Support only single GPU. Please use CUDA_VISIBLE_DEVICES=gpu_index if you have multiple gpus' #Currently support only single gpu
        else:
            self.accelerate = False
            self.device = torch.device('cpu')
            self.n_devices = 1
        self.model_path = model_path
        self.load_model()
        self.get_available_attributes()
        self.get_att_binary_group_indexs()

    def load_model(self):
        """Load the processor and CTC model from ``self.model_path``.

        The model is put in eval mode and moved to ``self.device`` (CPU when
        no device has been selected yet).

        Raises:
            FileNotFoundError: if ``self.model_path`` does not exist locally.
        """
        if not os.path.exists(self.model_path):
            logger.error(f'Model file {self.model_path} is not exist')
            raise FileNotFoundError(self.model_path)

        self.processor = Wav2Vec2Processor.from_pretrained(self.model_path)
        self.model = Wav2Vec2ForCTC.from_pretrained(self.model_path)
        self.model.eval()
        # Actually use the device chosen in __init__.
        self.model.to(getattr(self, 'device', torch.device('cpu')))
        self.pad_token_id = self.processor.tokenizer.pad_token_id
        self.sampling_rate = self.processor.feature_extractor.sampling_rate

    def get_available_attributes(self):
        """Set ``self.att_list`` to the attribute names supported by the model.

        An attribute name is a non-special vocabulary token starting with
        ``p_``, with that prefix stripped. Names are ordered by token id so the
        list, and therefore every ``attributes='all'`` output, is reproducible.

        Raises:
            AttributeError: if the model has not been loaded yet.
        """
        if not hasattr(self, 'model'):
            logger.error('model not loaded, call load_model first!')
            raise AttributeError("model not defined")
        tokenizer = self.processor.tokenizer
        special_tokens = set(tokenizer.all_special_tokens)
        prefix = 'p_'
        ordered_vocab = sorted(tokenizer.get_vocab().items(), key=lambda item: item[1])
        # Strip only the leading prefix: str.replace would also mangle names
        # that contain "p_" internally (e.g. "stop_palatal").
        self.att_list = [
            token[len(prefix):]
            for token, _ in ordered_vocab
            if token.startswith(prefix) and token not in special_tokens
        ]

    def print_availabel_attributes(self):
        """Print the attribute names supported by the loaded model."""
        print(self.att_list)

    def print_available_attributes(self):
        """Correctly spelled alias of ``print_availabel_attributes``."""
        self.print_availabel_attributes()

    def get_att_binary_group_indexs(self):
        """Build ``self.group_ids``: per attribute, sorted ids of [<pad>, n_att, p_att].

        The ids are sorted so that selecting them with a boolean mask keeps the
        same order as the list, which ``decode_att`` relies on.
        """
        tokenizer = self.processor.tokenizer
        self.group_ids = [
            sorted([
                self.pad_token_id,
                tokenizer.convert_tokens_to_ids(f'n_{att}'),
                tokenizer.convert_tokens_to_ids(f'p_{att}'),
            ])
            for att in self.att_list
        ]

    def decode_att(self, logits, att): #Need to lowercase when first read from the user
        """Decode the ``+``/``-`` pattern of one attribute from CTC logits.

        Args:
            logits: Model output indexed as ``[batch, frame, vocab]``; only the
                first batch item is decoded.
            att: Attribute name; must be in ``self.att_list`` (lowercase).

        Returns:
            list[str]: one ``'+'`` (p_att) or ``'-'`` (n_att) per decoded token.

        Raises:
            ValueError: if ``att`` is not supported by the model.
        """
        try:
            i = self.att_list.index(att)
        except ValueError:
            logger.error(f'The given attribute {att} not supported in the given model {self.model_path}')
            raise
        group = self.group_ids[i]
        mask = torch.zeros(logits.size(2), dtype=torch.bool, device=logits.device)
        mask[group] = True
        local_ids = torch.argmax(logits[:, :, mask], dim=-1)
        # Map positions inside the 3-token slice back to full-vocabulary ids;
        # valid because ``group`` is sorted like the masked columns.
        pred_ids = torch.tensor(group, device=local_ids.device)[local_ids].cpu()
        pred = self.processor.batch_decode(pred_ids, spaces_between_special_tokens=True)[0].split()
        symbol = {f'p_{att}': '+', f'n_{att}': '-'}
        return [symbol[token] for token in pred]

    def read_audio_file(self, audio_file):
        """Load ``audio_file`` resampled to the model's sampling rate.

        Raises:
            FileNotFoundError: if ``audio_file`` does not exist.
        """
        if not os.path.exists(audio_file):
            logger.error(f'Audio file {audio_file} is not exist')
            raise FileNotFoundError(audio_file)
        y, _ = librosa.load(audio_file, sr=self.sampling_rate)
        return y

    def get_logits(self, y):
        """Run the model on waveform ``y`` and return its logits as a CPU tensor.

        Inference runs on the model's device; the result is moved back to the
        CPU because the decoding helpers build their masks/matrices there.
        """
        input_values = self.processor(audio=y, sampling_rate=self.sampling_rate, return_tensors="pt").input_values
        with torch.no_grad():
            logits = self.model(input_values.to(self.model.device)).logits
        return logits.cpu()

    def check_identical_phonemes(self, df_p2att):
        """Warn about phonemes whose attribute rows are identical.

        Such phonemes cannot be told apart by the phonological matrix. Groups
        are reported once each, in first-seen order.
        """
        groups = {}  # dict used as an insertion-ordered set of index tuples
        for _, row in df_p2att.iterrows():
            mask = df_p2att.eq(row).all(axis=1)
            indexes = df_p2att[mask].index.values
            if len(indexes) > 1:
                groups.setdefault(tuple(indexes), None)
        if groups:
            logger.warning('The following phonemes has identical phonological features given the phonological features used in the model. If using fixed weight layer, these phonemes will be confused with each other')
            for group in groups:
                logger.warning(','.join(map(str, group)))

    def read_phoneme2att(self, p2att_file):
        """Read a phoneme-to-attribute CSV and build ``self.p2att_map``.

        The CSV's first column is the phoneme index; every other column is an
        attribute with values 0 (-> ``n_<att>``) or 1 (-> ``p_<att>``).
        Columns the model does not support are dropped with a warning before
        checking for indistinguishable phonemes.

        Sets ``self.phoneme_list`` and ``self.p2att_map``.

        Raises:
            FileNotFoundError: if ``p2att_file`` does not exist.
            ValueError: if an attribute value is neither 0 nor 1.
        """
        if not os.path.exists(p2att_file):
            logger.error(f'Phonological matrix file {p2att_file} is not exist')
            raise FileNotFoundError(f'{p2att_file}')

        df_p2att = pd.read_csv(p2att_file, index_col=0)

        not_supported = sorted(set(df_p2att.columns) - set(self.att_list))
        if not_supported:
            logger.warning(f"Attribute/s {','.join(map(str, not_supported))} is not supported by the model {self.model_path} and will be ignored. To get available attributes of the selected model run transcribe --model_path=/path/to/model print_availabel_attributes")
            df_p2att = df_p2att.drop(columns=not_supported)
        # Check only the features that will actually be used.
        self.check_identical_phonemes(df_p2att)

        vocab = self.processor.tokenizer.get_vocab()
        self.phoneme_list = df_p2att.index.values
        self.p2att_map = {}
        for phoneme, row in df_p2att.iterrows():
            tokens = []
            for att, value in row.items():
                if f'p_{att}' not in vocab:
                    logger.warning(f'Attribute {att} is not supported by the model {self.model_path} and will be ignored. To get available attributes of the selected model run transcribe --model_path=/path/to/model print_availabel_attributes')
                    continue
                if value == 0:
                    tokens.append(f'n_{att}')
                elif value == 1:
                    tokens.append(f'p_{att}')
                else:
                    logger.error(f'Invalid value of {value} for attribute {att} of phoneme {phoneme}. Values in the phoneme to attribute map should be either 0 or 1')
                    raise ValueError(f'{value} should be 0 or 1')
            self.p2att_map[phoneme] = tokens

    def create_phoneme_tokenizer(self):
        """Build ``self.phoneme_tokenizer`` with ``<pad>`` = 0 and phonemes 1..N.

        The vocabulary file is written to a private temporary directory.
        Writing it into the current working directory would leave a stray file,
        race with concurrent runs, and fail on read-only directories.
        """
        import tempfile

        vocab_dict = {v: k + 1 for k, v in enumerate(self.phoneme_list)}
        vocab_dict['<pad>'] = 0
        vocab_dict = dict(sorted(vocab_dict.items(), key=lambda x: x[1]))
        with tempfile.TemporaryDirectory() as tmp_dir:
            vocab_file = os.path.join(tmp_dir, 'phoneme_vocab.json')
            with open(vocab_file, 'w', encoding='utf-8') as f:
                json.dump(vocab_dict, f)
            # The tokenizer reads the file in its constructor.
            self.phoneme_tokenizer = Wav2Vec2CTCTokenizer(vocab_file, pad_token="<pad>", word_delimiter_token="")

    def create_phonological_matrix(self):
        """Build the 0/1 matrix mapping phoneme ids (rows) to attribute token ids (columns)."""
        att_tokenizer = self.processor.tokenizer
        ph_tokenizer = self.phoneme_tokenizer
        matrix = torch.zeros((ph_tokenizer.vocab_size, att_tokenizer.vocab_size), dtype=torch.float32)
        matrix[ph_tokenizer.pad_token_id, att_tokenizer.pad_token_id] = 1
        for p in self.phoneme_list:
            row = ph_tokenizer.convert_tokens_to_ids(p)
            for att in self.p2att_map[p]:
                matrix[row, att_tokenizer.convert_tokens_to_ids(att)] = 1
        self.phonological_matrix = matrix

    #This function gets the attribute logits from the output layer and convert to phonemes
    #Input is a sequence of logits (one vector per frame) and output phoneme sequence
    #Note that this is CTC so number of output phonemes is not equal to number of input frames
    def decode_phoneme(self, logits):
        """Decode a phoneme string from attribute logits via the phonological matrix.

        For each attribute group, a log-softmax is taken over that group only.
        The per-group log-probs are summed into one vocabulary-sized vector per
        frame. A phoneme's score is then the sum of the log-probs of its
        attribute tokens, i.e. matrix @ log-probs. Only the first batch item
        is decoded.
        """
        def masked_log_softmax(vector: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
            if mask is not None:
                mask = mask.float()
                while mask.dim() < vector.dim():
                    mask = mask.unsqueeze(1)
                # vector + mask.log() is an easy way to zero out masked elements in logspace, but it
                # results in nans when the whole vector is masked.  We need a very small value instead of a
                # zero in the mask for these cases.  log(1 + 1e-45) is still basically 0, so we can safely
                # just add 1e-45 before calling mask.log().  We use 1e-45 because 1e-46 is so small it
                # becomes 0 - this is just the smallest value we can actually use.
                vector = vector + (mask + 1e-45).log()
            return torch.nn.functional.log_softmax(vector, dim=dim)

        log_probs_per_group = []
        for group in self.group_ids:
            mask = torch.zeros(logits.size(2), dtype=torch.bool, device=logits.device)
            mask[group] = True
            mask = mask.view(1, 1, -1)
            log_probs = masked_log_softmax(vector=logits, mask=mask, dim=-1).masked_fill(~mask, 0)
            log_probs_per_group.append(log_probs)
        log_probs_cat = torch.stack(log_probs_per_group, dim=0).sum(dim=0)
        matrix = self.phonological_matrix.to(log_probs_cat.device)
        log_probs_phoneme = torch.matmul(matrix, log_probs_cat.transpose(1, 2)).transpose(1, 2).type(torch.FloatTensor)
        pred_ids = torch.argmax(log_probs_phoneme, dim=-1)
        return self.phoneme_tokenizer.batch_decode(pred_ids, spaces_between_special_tokens=True)[0]

    def print_human_readable(self, output, with_phoneme=False):
        """Render ``output`` as a '|'-separated text table and return it.

        The first column holds row names, padded to the longest name (and to
        'Phoneme' when ``with_phoneme``). Every symbol is centered in a
        5-character cell. When ``with_phoneme`` is true, a 'Phoneme' row built
        from ``output['Phoneme']['symbols']`` comes first.
        """
        cell_width = 5
        attributes = output['Attributes']
        names = [att['Name'] for att in attributes]
        if with_phoneme:
            names.append('Phoneme')
        name_width = max(len(name) for name in names)

        rows = []
        if with_phoneme:
            rows.append(['Phoneme'.center(name_width)]
                        + [s.center(cell_width) for s in output['Phoneme']['symbols']])
        for att in attributes:
            rows.append([att['Name'].center(name_width)]
                        + [s.center(cell_width) for s in att['Pattern']])
        return ''.join('|'.join(row) + '\n' for row in rows)

    def transcribe(self, audio_file,
                   attributes='all',
                   phonological_matrix_file = None,
                   human_readable = True):
        """Transcribe ``audio_file`` into attribute patterns and, optionally, phonemes.

        Args:
            audio_file: Path to an audio file readable by librosa.
            attributes: ``'all'`` for every supported attribute, or one name or
                a tuple/list of names (case-insensitive). Unsupported names are
                skipped with a warning.
            phonological_matrix_file: Optional phoneme-to-attribute CSV. When
                given, a phoneme sequence is decoded as well.
            human_readable: Return a text table instead of a JSON string.

        Returns:
            str: the table, or the JSON document with keys ``wav_file_path``,
            ``Attributes`` and ``Phoneme``.

        Raises:
            ValueError: if none of the requested attributes is supported.
            FileNotFoundError: if the audio or matrix file does not exist.
        """
        output = {'wav_file_path': audio_file, 'Attributes': [], 'Phoneme': {}}

        if attributes == 'all':
            target_attributes = self.att_list
        else:
            # Fire passes comma-separated CLI values as a tuple; accept lists too.
            requested = attributes if isinstance(attributes, (tuple, list)) else (attributes,)
            requested = [str(att).lower() for att in requested]
            target_attributes = [att for att in requested if att in self.att_list]
            ignored = [att for att in requested if att not in self.att_list]
            if ignored and target_attributes:
                logger.warning(f"Attribute/s {','.join(ignored)} is not supported by the model {self.model_path} and will be ignored. To get available attributes of the selected model run transcribe --model_path=/path/to/model print_availabel_attributes")

        if not target_attributes:
            logger.error(f'None of the given attributes is supported by model {self.model_path}. To get available attributes of the selected model run transcribe --model_path=/path/to/model print_availabel_attributes')
            raise ValueError("Invalid attributes")

        y = self.read_audio_file(audio_file)
        self.logits = self.get_logits(y)
        output['Attributes'] = [
            {'Name': att, 'Pattern': self.decode_att(self.logits, att)}
            for att in target_attributes
        ]

        # One truthiness test for both decoding and rendering, so an empty
        # string can never request a phoneme row that was not produced.
        with_phoneme = bool(phonological_matrix_file)
        if with_phoneme:
            self.read_phoneme2att(phonological_matrix_file)
            self.create_phoneme_tokenizer()
            self.create_phonological_matrix()
            output['Phoneme']['symbols'] = self.decode_phoneme(self.logits).split()

        if human_readable:
            return self.print_human_readable(output, with_phoneme)
        return json.dumps(output, indent=4)


def main():
    """Command-line entry point: expose ``transcribe_SA`` through python-fire."""
    cli_component = transcribe_SA
    fire.Fire(cli_component)


if __name__ == '__main__':
    main()