File size: 5,099 Bytes
7165c71
 
 
 
 
 
f350aa0
7165c71
 
 
 
f350aa0
 
d29896d
f350aa0
7165c71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
"""Beep profanity words in audio using one of the Hubert-compatible ASR models.
"""

import argparse
import re
import logging
import warnings

import transformers
import torch
import numpy as np
try:
    import soundfile
except (ImportError, OSError):
    warnings.warn("Cannot import soundfile. Reading/writing files will be unavailable")


log = logging.getLogger(__name__)


class HubertBeeper:
    """Censor profanity in audio by silencing the matching regions,
    using a Hubert-compatible CTC ASR model for word-level alignment.
    """

    # Words to censor; matched case-insensitively as whole words.
    PROFANITY = ["fuck", "shit", "piss"]

    def __init__(self, model_name="facebook/hubert-large-ls960-ft"):
        """Load the CTC model, feature extractor and tokenizer.

        Args:
            model_name (str): Hugging Face model id of a Hubert-compatible
                CTC model.
        """
        log.debug("Loading model: %s", model_name)
        self.model_name = model_name
        self.model = transformers.AutoModelForCTC.from_pretrained(model_name)
        self.model.eval()
        self.feature_extractor = transformers.AutoFeatureExtractor.from_pretrained(model_name)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
        self.processor = transformers.Wav2Vec2Processor(
            feature_extractor=self.feature_extractor, tokenizer=self.tokenizer)

    def asr(self, waveform, sample_rate):
        """Run the model on a single waveform.

        Returns the raw model output; its `.logits` has shape
        (1, frames, vocab_size).
        """
        features = self.processor([waveform], sampling_rate=sample_rate)
        features = torch.tensor(features.input_values)
        # Inference only: disable autograd to save memory and time.
        with torch.no_grad():
            output = self.model(features)
        return output

    def f_beep(self, sound_file_path: str) -> "np.ndarray":
        """Read an audio file, censor profanity, and return the waveform."""
        wav, sample_rate = soundfile.read(sound_file_path)
        text, result_wav = self.f_beep_waveform(wav, sample_rate)
        return result_wav

    def f_beep_waveform(self, wav: "np.ndarray", sample_rate: int) -> "tuple[str, np.ndarray]":
        """Censor profanity in `wav` in place.

        Args:
            wav: mono waveform samples.
            sample_rate: sampling rate of `wav` in Hz.

        Returns:
            (text, wav): the recognized transcript and the censored waveform
            (same array, modified in place).
        """
        model_output = self.asr(wav, sample_rate)
        text, spans = find_words_in_audio(model_output, self.processor, self.PROFANITY)
        # Each logit frame corresponds to a fixed-size slice of the waveform.
        number_of_frames = model_output.logits.shape[1]
        frame_size = len(wav) / number_of_frames

        # Mask offensive parts of the audio
        for frame_begin, frame_end in spans:
            begin = round(frame_begin * frame_size)
            # A negative end frame is a sentinel for "word runs to the end
            # of the audio"; previously it produced an empty range and the
            # trailing word was left uncensored.
            end = len(wav) if frame_end < 0 else round(frame_end * frame_size)
            self.generate_beep(wav, begin, end)

        return text, wav

    def generate_beep(self, wav, begin, end):
        """Censor the [begin, end) sample range of `wav`.

        Modifies the waveform in place.
        """
        # Silence sounds better than beeps; vectorized slice assignment
        # instead of a per-sample Python loop.
        wav[begin:end] = 0

def find_words_in_audio(model_output, processor, words):
    """Return the decoded transcript and every frame span matching `words`.

    Args:
        model_output: model output carrying `.logits` of shape
            (1, frames, vocab_size).
        processor: processor whose tokenizer provides the vocabulary.
        words: iterable of words to locate in the transcript.

    Returns:
        Tuple of (lower-cased transcript, list of (start_frame, end_frame)).
    """
    token_ids = model_output.logits.argmax(dim=-1)[0]
    vocab = processor.tokenizer.get_vocab()
    transcript, offsets = decode_output_with_offsets(token_ids, vocab)
    transcript = transcript.lower()
    log.debug("ASR text: %s", transcript)

    matched_spans = [
        span
        for word in words
        for span in find_spans(transcript, offsets, word)
    ]
    log.debug("Spans: %s", matched_spans)

    return transcript, matched_spans


def find_spans(text, offsets, word):
    """Return all frame spans that correspond to the given `word`.

    Args:
        text (str): decoded transcript, one entry in `offsets` per character.
        offsets (List[int]): starting frame index of every character of `text`.
        word (str): word to search for; matched at word boundaries.

    Returns:
        List[Tuple[int, int]]: (start_frame, end_frame) per match,
        end exclusive.
    """
    spans = []
    pattern = r"\b" + re.escape(word) + r"\b"
    for match in re.finditer(pattern, text):
        start_frame = offsets[match.start()]
        # The word's audio lasts until the character following it starts.
        # (The previous version indexed one character too far, extending
        # the span past the word, and returned a -1 sentinel for a match
        # at the very end of the transcript.)
        end_index = match.end()
        if end_index < len(offsets):
            end_frame = offsets[end_index]
        else:
            # Word ends the transcript: extend one frame past the last
            # character; this never exceeds the total frame count.
            end_frame = offsets[-1] + 1
        spans.append((start_frame, end_frame))
    return spans


def decode_output_with_offsets(decoded_token_ids, vocab):
    """CTC-decode per-frame token ids into text and per-character offsets.

    Args:
        decoded_token_ids (Sequence): one token id per audio frame
            (argmax of the CTC logits); elements may be plain ints or
            0-dim tensors.
        vocab (Dict[str, int]): model's vocabulary. Token id 0 is assumed
            to be the CTC blank/pad token and "|" the word delimiter
            (Wav2Vec2/Hubert convention).

    Returns:
        Tuple[str, List[int]], where
            `str` is the decoded text,
            `List[int]` holds the starting frame index of
            every character in the text.
    """
    token_by_index = {v: k for k, v in vocab.items()}
    prev_token = None
    result_string = []
    result_offset = []
    for i, token_id in enumerate(decoded_token_ids):
        token_id = int(token_id)  # accepts plain ints as well as tensor scalars
        if token_id == 0:
            # CTC blank: emit nothing, but reset the collapse state so a
            # genuine repeated character separated by a blank is preserved
            # (e.g. "l", blank, "l" must decode to "ll", not "l").
            prev_token = None
            continue

        token = token_by_index[token_id]
        if token == prev_token:
            # CTC rule: collapse consecutive repeats of the same token.
            continue

        result_string.append(token)
        result_offset.append(i)
        prev_token = token

    result_string = "".join(result_string).replace("|", " ")
    assert len(result_string) == len(result_offset)
    return result_string, result_offset



def main():
    """Command-line entry point: beep profanity in a file and save the result."""
    parser = argparse.ArgumentParser()
    parser.add_argument("input")
    parser.add_argument("-o", "--output")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--model", default="facebook/hubert-large-ls960-ft")
    args = parser.parse_args()

    level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(level=level)

    censored = HubertBeeper(args.model).f_beep(args.input)

    destination = "result.wav" if not args.output else args.output
    soundfile.write(destination, censored, 16000)
    print(f"Saved to {destination}")


if __name__ == "__main__":
    main()