hubert-fbeeper / fbeeper_hubert.py
osyvokon's picture
Catch more exceptions
d29896d
raw
history blame
5.1 kB
"""Beep profanity words in audio using one of the Hubert-compatible ASR models.
"""
import argparse
import re
import logging
import warnings
import transformers
import torch
import numpy as np
# soundfile is optional: it is only needed for reading/writing audio files,
# so a failed import degrades to a warning instead of aborting the module.
try:
    import soundfile
except (ImportError, OSError):
    # NOTE(review): OSError is presumably what soundfile raises when its native
    # libsndfile library cannot be loaded — confirm for the version in use.
    warnings.warn("Cannot import soundfile. Reading/writing files will be unavailable")

# Module-level logger named after this module.
log = logging.getLogger(name=__name__)
class HubertBeeper:
    """Silence profane words in speech audio using a CTC speech-recognition model.

    The audio is transcribed with a Hugging Face ``AutoModelForCTC`` checkpoint
    (HuBERT by default). Every whole-word occurrence of a ``PROFANITY`` entry in
    the transcript is mapped back to a range of samples, which is zeroed out.
    """

    # Words to mask; matched as whole words against the lower-cased transcript.
    PROFANITY = ["fuck", "shit", "piss"]

    def __init__(self, model_name="facebook/hubert-large-ls960-ft"):
        """Load the CTC model, feature extractor and tokenizer for ``model_name``.

        Args:
            model_name: Hugging Face model id (or local path) of a CTC checkpoint.
        """
        log.debug("Loading model: %s", model_name)
        self.model_name = model_name
        self.model = transformers.AutoModelForCTC.from_pretrained(model_name)
        # Inference only: switch off training-time behavior such as dropout.
        self.model.eval()
        self.feature_extractor = transformers.AutoFeatureExtractor.from_pretrained(model_name)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
        self.processor = transformers.Wav2Vec2Processor(
            feature_extractor=self.feature_extractor, tokenizer=self.tokenizer)

    def asr(self, waveform, sample_rate):
        """Run the acoustic model on a single channel of audio.

        Args:
            waveform: Audio samples of one channel.
            sample_rate: Sampling rate of ``waveform`` in Hz, forwarded to the
                feature extractor.

        Returns:
            The raw model output for a batch containing just ``waveform``
            (so ``output.logits`` has a batch dimension of size 1).
        """
        features = self.processor(
            [waveform], sampling_rate=sample_rate, return_tensors="pt")
        # No gradients are needed for inference; skipping autograd bookkeeping
        # avoids holding the whole activation graph in memory.
        with torch.no_grad():
            output = self.model(features.input_values)
        return output

    def f_beep(self, sound_file_path: str) -> np.ndarray:
        """Read an audio file and return its samples with profanity silenced.

        Args:
            sound_file_path: Path to an audio file readable by ``soundfile``.

        Returns:
            The file's samples with every detected profane word zeroed out.
        """
        wav, sample_rate = soundfile.read(sound_file_path)
        _, result_wav = self.f_beep_waveform(wav, sample_rate)
        return result_wav

    def f_beep_waveform(self, wav: np.ndarray, sample_rate: int) -> "tuple[str, np.ndarray]":
        """Silence profane words in ``wav`` in place.

        Args:
            wav: Audio samples, either ``(samples,)`` or ``(samples, channels)``.
                This array is modified in place.
            sample_rate: Sampling rate of ``wav`` in Hz.

        Returns:
            ``(text, wav)``: the lower-cased transcript and the same, now masked,
            array that was passed in.
        """
        # The acoustic model consumes a single channel: average multi-channel
        # audio for recognition, but mask all channels of the original below.
        mono = np.mean(wav, axis=1) if np.ndim(wav) > 1 else wav
        model_output = self.asr(mono, sample_rate)
        text, spans = find_words_in_audio(model_output, self.processor, self.PROFANITY)

        # Number of audio samples covered by one model output frame.
        number_of_frames = model_output.logits.shape[1]
        frame_size = len(wav) / number_of_frames

        # Mask offensive parts of the audio
        for frame_begin, frame_end in spans:
            begin = round(frame_begin * frame_size)
            # find_spans reports -1 when the word runs to the end of the
            # transcript; that means "until the end of the audio", not a
            # negative index (which would yield an empty, unmasked range).
            end = len(wav) if frame_end < 0 else round(frame_end * frame_size)
            self.generate_beep(wav, begin, end)
        return text, wav

    def generate_beep(self, wav, begin, end):
        """Mask samples ``[begin, end)`` of ``wav`` in place.

        Despite the name, the region is silenced (set to zero), not replaced by
        a tone. ``wav`` is expected to be a NumPy array and ``begin`` to be
        non-negative; an ``end`` past the array length is clipped.
        """
        # Silence sounds better than beeps
        wav[begin:end] = 0
def find_words_in_audio(model_output, processor, words):
    """Transcribe a CTC model output and locate every occurrence of ``words``.

    Args:
        model_output: Model output with a ``logits`` tensor whose first axis is
            the batch and last axis the vocabulary; only batch item 0 is decoded.
        processor: Processor whose ``tokenizer.get_vocab()`` maps tokens to ids.
        words: Words to search for, matched case-insensitively as whole words.

    Returns:
        ``(text, spans)``: the lower-cased greedy transcript, and a list of
        ``(start_frame, end_frame)`` tuples as produced by ``find_spans``
        (``end_frame`` is -1 when a match reaches the end of the transcript).
    """
    # Greedy decoding: most likely token per frame for the first batch item.
    token_ids = model_output.logits.argmax(dim=-1)[0]
    vocab = processor.tokenizer.get_vocab()
    text, offsets = decode_output_with_offsets(token_ids, vocab)
    text = text.lower()
    log.debug("ASR text: %s", text)

    # The transcript is lower-cased, so search words must be lower-cased too,
    # otherwise an entry like "Shit" could never match.
    result_spans = []
    for word in words:
        result_spans.extend(find_spans(text, offsets, word.lower()))
    log.debug("Spans: %s", result_spans)
    return text, result_spans
def find_spans(text, offsets, word):
    """Locate whole-word occurrences of ``word`` in ``text`` as frame spans.

    Args:
        text: Decoded transcript.
        offsets: Starting frame index of every character of ``text``.
        word: Literal word to find (regex metacharacters are escaped).

    Returns:
        List of ``(start_frame, end_frame)`` tuples. ``start_frame`` is the frame
        of the word's first character. ``end_frame`` is the frame of the character
        two positions past the word (i.e. skipping the following delimiter), or
        -1 when that position is beyond the end of the transcript.
    """
    matcher = re.compile(r"\b" + re.escape(word) + r"\b")
    char_count = len(offsets)

    def span_of(found):
        # Skip the delimiter right after the word so the span also covers the gap.
        after = found.end() + 1
        tail = offsets[after] if after < char_count else -1
        return offsets[found.start()], tail

    return [span_of(found) for found in matcher.finditer(text)]
def decode_output_with_offsets(decoded_token_ids, vocab, blank_id=0):
    """Greedy-CTC-decode per-frame token ids into text plus per-character frame offsets.

    Standard CTC rule: first collapse runs of the same token, then drop blanks.
    A blank therefore separates two emissions of the same token, so
    "l <blank> l" decodes to "ll" (needed for words like "piss" or "hello").

    Args:
        decoded_token_ids (Iterable[int]): one token id per audio frame; items may
            be Python ints or 0-d tensors. Its length should equal the number of
            audio frames.
        vocab (Dict[str, int]): model's vocabulary.
        blank_id (int): id of the CTC blank token (0 by default).

    Returns:
        Tuple[str, List[int]], where `str` is the decoded text with the word
        delimiter "|" replaced by a space, and `List[int]` is the starting frame
        index of every character in that text.
    """
    token_by_index = {v: k for k, v in vocab.items()}
    prev_id = None
    chars = []
    offsets = []
    for frame, token_id in enumerate(decoded_token_ids):
        token_id = int(token_id)
        # Collapse repeats before removing blanks: because prev_id also tracks
        # blanks, a token seen again after a blank counts as a new emission.
        if token_id == prev_id:
            continue
        prev_id = token_id
        if token_id == blank_id:
            continue
        token = token_by_index[token_id]
        chars.append(token)
        # Multi-character tokens (e.g. special tokens) get one offset per
        # character so that text and offsets stay index-aligned.
        offsets.extend([frame] * len(token))
    text = "".join(chars).replace("|", " ")
    assert len(text) == len(offsets)
    return text, offsets
def main():
    """Command-line entry point: silence profanity in an audio file and save it."""
    parser = argparse.ArgumentParser()
    parser.add_argument("input", help="audio file to process")
    parser.add_argument("-o", "--output", help="output path (default: result.wav)")
    parser.add_argument("-v", "--verbose", action="store_true", help="enable debug logging")
    parser.add_argument("--model", default="facebook/hubert-large-ls960-ft",
                        help="Hugging Face CTC model id")
    args = parser.parse_args()
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

    beeper = HubertBeeper(args.model)
    # Read the file here rather than via f_beep so the result can be written
    # back at the input's own sample rate instead of an assumed 16 kHz.
    wav, sample_rate = soundfile.read(args.input)
    _, result = beeper.f_beep_waveform(wav, sample_rate)

    output = args.output or "result.wav"
    soundfile.write(output, result, sample_rate)
    print(f"Saved to {output}")


if __name__ == "__main__":
    main()