# NOTE(review): removed non-code artifacts (file-viewer gutter, commit hashes,
# "Runtime error" banners) that were pasted above the module docstring and made
# the file unparseable.
"""Beep profanity words in audio using one of the Hubert-compatible ASR models.
"""
import argparse
import re
import logging
import warnings
import transformers
import torch
import numpy as np
try:
import soundfile
except (ImportError, OSError):
warnings.warn("Cannot import soundfile. Reading/writing files will be unavailable")
log = logging.getLogger(__name__)
class HubertBeeper:
    """Censor profanity in speech audio.

    Runs a CTC ASR model (HuBERT by default) over a waveform, finds
    profane words in the transcript, maps their character positions back
    to audio frames, and silences those regions of the waveform.
    """

    # Words to censor; matched as whole words on the lower-cased transcript.
    PROFANITY = ["fuck", "shit", "piss"]

    def __init__(self, model_name="facebook/hubert-large-ls960-ft"):
        """Load the CTC model, feature extractor and tokenizer.

        Args:
            model_name: Hugging Face model id of a CTC-compatible ASR model.
        """
        log.debug("Loading model: %s", model_name)
        self.model_name = model_name
        self.model = transformers.AutoModelForCTC.from_pretrained(model_name)
        self.model.eval()
        self.feature_extractor = transformers.AutoFeatureExtractor.from_pretrained(model_name)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
        self.processor = transformers.Wav2Vec2Processor(
            feature_extractor=self.feature_extractor, tokenizer=self.tokenizer)

    def asr(self, waveform, sample_rate):
        """Run the model on a single waveform and return the raw model output."""
        features = self.processor([waveform], sampling_rate=sample_rate)
        features = torch.tensor(features.input_values)
        # Inference only: no_grad avoids building the autograd graph.
        with torch.no_grad():
            output = self.model(features)
        return output

    def f_beep(self, sound_file_path: str) -> np.ndarray:
        """Censor profanity in an audio file and return the edited waveform."""
        wav, sample_rate = soundfile.read(sound_file_path)
        text, result_wav = self.f_beep_waveform(wav, sample_rate)
        return result_wav

    def f_beep_waveform(self, wav: np.ndarray, sample_rate: int):
        """Censor profanity in a waveform.  Modifies `wav` in place.

        Args:
            wav: 1-D audio samples.
            sample_rate: sample rate of `wav` in Hz.

        Returns:
            Tuple of (recognized text, censored waveform).
        """
        model_output = self.asr(wav, sample_rate)
        text, spans = find_words_in_audio(model_output, self.processor, self.PROFANITY)
        number_of_frames = model_output.logits.shape[1]
        # Audio samples covered by one CTC output frame.
        frame_size = len(wav) / number_of_frames
        # Mask offensive parts of the audio.
        for frame_begin, frame_end in spans:
            begin = round(frame_begin * frame_size)
            # A negative end frame is a sentinel for "runs to the end of the
            # audio" (word at the very end of the transcript); previously it
            # produced a negative sample index and an empty, silent no-op range.
            end = len(wav) if frame_end < 0 else round(frame_end * frame_size)
            self.generate_beep(wav, begin, end)
        return text, wav

    def generate_beep(self, wav, begin, end):
        """Silence wav[begin:end] in place.

        Silence sounds better than an actual beep tone.
        """
        # Slice assignment is a single vectorized write; out-of-range ends
        # are clamped by numpy slicing semantics.
        wav[begin:end] = 0
def find_words_in_audio(model_output, processor, words):
    """Return the decoded transcript and every frame span matching `words`.

    Args:
        model_output: CTC model output carrying per-frame `logits`.
        processor: processor whose tokenizer provides the vocabulary.
        words: iterable of words to locate in the transcript.

    Returns:
        Tuple of (lower-cased transcript, list of (start_frame, end_frame)).
    """
    predicted_ids = model_output.logits.argmax(dim=-1)[0]
    vocab = processor.tokenizer.get_vocab()
    text, offsets = decode_output_with_offsets(predicted_ids, vocab)
    text = text.lower()
    log.debug("ASR text: %s", text)
    result_spans = [
        span
        for word in words
        for span in find_spans(text, offsets, word)
    ]
    log.debug("Spans: %s", result_spans)
    return text, result_spans
def find_spans(text, offsets, word):
    """Return (start_frame, end_frame) spans for whole-word matches of `word`.

    Args:
        text: decoded transcript; one character per entry of `offsets`.
        offsets: starting frame index of every character of `text`.
        word: word to locate, matched at word boundaries.

    Returns:
        List of (start_frame, end_frame) tuples.  `end_frame` is the frame
        where the character following the word starts (exclusive bound).
    """
    spans = []
    pattern = r"\b" + re.escape(word) + r"\b"
    for match in re.finditer(pattern, text):
        start_frame = offsets[match.start()]
        # match.end() already points at the first character AFTER the word;
        # the previous `+ 1` extended every span one character too far.
        end = match.end()
        if end < len(offsets):
            end_frame = offsets[end]
        else:
            # Word runs to the end of the transcript.  The true end frame is
            # unknown here, so extend one frame past the last character's
            # start; the previous `-1` sentinel made the caller compute a
            # negative sample index and skip the beep entirely.
            end_frame = offsets[-1] + 1
        spans.append((start_frame, end_frame))
    return spans
def decode_output_with_offsets(decoded_token_ids, vocab):
    """Given list of decoded tokens, return text and
    time offsets that correspond to each character in the text.

    Args:
        decoded_token_ids (List[int]): list of token ids (plain ints or
            scalar tensors).  The length of the list should be equal to
            the number of audio frames.
        vocab (Dict[str, int]): model's vocabulary.

    Returns:
        Tuple[str, List[int]], where
            `str` is a decoded text,
            `List[int]` is a starting frame indexes for
            every character in text.
    """
    token_by_index = {v: k for k, v in vocab.items()}
    prev_token = None
    result_string = []
    result_offset = []
    for i, token_id in enumerate(decoded_token_ids):
        # int() accepts both plain ints and 0-dim tensors (.item() only
        # handled tensors).
        token_id = int(token_id)
        # Assumes the CTC blank/<pad> token has id 0, as in the standard
        # Wav2Vec2/HuBERT vocabularies.
        if token_id == 0:
            # A blank separates genuine character repetitions: reset the
            # duplicate-collapse state so "h<blank>h" decodes to "hh".
            # Previously prev_token survived the blank and the second "h"
            # was wrongly dropped.
            prev_token = None
            continue
        token = token_by_index[token_id]
        if token == prev_token:
            # Same token emitted across consecutive frames: CTC collapse.
            continue
        result_string.append(token)
        result_offset.append(i)
        prev_token = token
    # "|" is the word delimiter in Wav2Vec2-style vocabularies.
    result_string = "".join(result_string).replace("|", " ")
    assert len(result_string) == len(result_offset)
    return result_string, result_offset
def main():
    """CLI entry point: censor an input audio file and write the result."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("input")
    parser.add_argument("-o", "--output")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--model", default="facebook/hubert-large-ls960-ft")
    args = parser.parse_args()
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
    beeper = HubertBeeper(args.model)
    # Read the file here (instead of via f_beep) so the original sample rate
    # is available for writing.  The previous hard-coded 16000 corrupted the
    # playback speed of any input recorded at a different rate.
    wav, sample_rate = soundfile.read(args.input)
    # NOTE(review): the default model was trained on 16 kHz audio; inputs at
    # other rates are passed to the feature extractor as-is — confirm
    # resampling upstream if non-16 kHz inputs are expected.
    text, result = beeper.f_beep_waveform(wav, sample_rate)
    output = args.output or "result.wav"
    soundfile.write(output, result, sample_rate)
    print(f"Saved to {output}")


if __name__ == "__main__":
    main()