Spaces:
Runtime error
Runtime error
"""Beep profanity words in audio using one of the Hubert-compatible ASR models. | |
""" | |
import argparse | |
import re | |
import logging | |
import warnings | |
import transformers | |
import torch | |
import numpy as np | |
try: | |
import soundfile | |
except (ImportError, OSError): | |
warnings.warn("Cannot import soundfile. Reading/writing files will be unavailable") | |
log = logging.getLogger(__name__) | |
class HubertBeeper:
    """Mute profanity in speech audio using a CTC speech-recognition model.

    The audio is transcribed, the frames at which a word from ``PROFANITY``
    was recognised are located, and the matching samples are silenced.
    """

    # Words that are masked when they show up in the (lower-cased) transcript.
    PROFANITY = ["fuck", "shit", "piss"]

    def __init__(self, model_name="facebook/hubert-large-ls960-ft"):
        """Load the model, feature extractor and tokenizer for `model_name`.

        Args:
            model_name: name or path understood by ``from_pretrained``.
        """
        log.debug("Loading model: %s", model_name)
        self.model_name = model_name
        self.model = transformers.AutoModelForCTC.from_pretrained(model_name)
        self.model.eval()
        self.feature_extractor = transformers.AutoFeatureExtractor.from_pretrained(model_name)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
        self.processor = transformers.Wav2Vec2Processor(
            feature_extractor=self.feature_extractor, tokenizer=self.tokenizer)

    def asr(self, waveform, sample_rate):
        """Run the acoustic model on a single waveform.

        Args:
            waveform: audio samples of one recording.
            sample_rate: sampling rate of `waveform` in Hz.

        Returns:
            The raw model output; callers read its ``logits`` attribute.
        """
        features = self.processor(
            [waveform], sampling_rate=sample_rate, return_tensors="pt")
        # Inference only: no need to record an autograd graph.
        with torch.no_grad():
            return self.model(features.input_values)

    def f_beep(self, sound_file_path: str) -> np.ndarray:
        """Read an audio file and return its samples with profanity muted.

        Args:
            sound_file_path: path of the audio file to process.

        Returns:
            The processed waveform.
        """
        wav, sample_rate = soundfile.read(sound_file_path)
        _, result_wav = self.f_beep_waveform(wav, sample_rate)
        return result_wav

    def f_beep_waveform(self, wav: np.ndarray, sample_rate: int) -> "tuple[str, np.ndarray]":
        """Mute every recognised profanity word in `wav`.

        Args:
            wav: audio samples; modified in place.
            sample_rate: sampling rate of `wav` in Hz.

        Returns:
            ``(text, wav)``: the recognised text and the same `wav` object
            with the offending regions silenced.
        """
        model_output = self.asr(wav, sample_rate)
        text, spans = find_words_in_audio(model_output, self.processor, self.PROFANITY)

        # Number of audio samples covered by one model output frame.
        number_of_frames = model_output.logits.shape[1]
        frame_size = len(wav) / number_of_frames

        # Mask offensive parts of the audio
        for frame_begin, frame_end in spans:
            begin = round(frame_begin * frame_size)
            # A negative end frame marks a word that runs to the end of the
            # transcript: silence through the end of the recording.
            if frame_end < 0:
                end = len(wav)
            else:
                end = round(frame_end * frame_size)
            self.generate_beep(wav, begin, end)
        return text, wav

    def generate_beep(self, wav, begin, end):
        """Silence samples ``begin`` (inclusive) to ``end`` (exclusive).

        Modifies waveform in place. The range is clipped to the buffer; an
        empty or inverted range (including a negative `end`) is a no-op.

        Args:
            wav: array of audio samples.
            begin: index of the first sample to silence.
            end: index one past the last sample to silence.
        """
        begin = max(begin, 0)
        end = min(end, len(wav))
        if end <= begin:
            return
        # Silence sounds better than beeps
        wav[begin:end] = 0
def find_words_in_audio(model_output, processor, words):
    """Locate every occurrence of `words` in the recognised speech.

    The most likely token of each frame is taken from the first item of
    ``model_output.logits``, decoded to lower-case text, and searched for
    each entry of `words`.

    Args:
        model_output: model output exposing a ``logits`` tensor.
        processor: processor whose tokenizer supplies the vocabulary.
        words: iterable of words to search for.

    Returns:
        ``(text, spans)``: the lower-cased text and a list of
        ``(start_frame, end_frame)`` tuples for all matches.
    """
    best_token_ids = model_output.logits.argmax(dim=-1)[0]
    decoded, offsets = decode_output_with_offsets(
        best_token_ids, processor.tokenizer.get_vocab())
    text = decoded.lower()
    log.debug("ASR text: %s", text)

    result_spans = []
    for word in words:
        result_spans.extend(find_spans(text, offsets, word))
    log.debug("Spans: %s", result_spans)
    return text, result_spans
def find_spans(text, offsets, word):
    """Return the frame span of every whole-word occurrence of `word`.

    Args:
        text: decoded text.
        offsets: starting frame index of each character of `text`.
        word: word to look for; matched literally, at word boundaries.

    Returns:
        List of ``(start_frame, end_frame)`` tuples. ``start_frame`` is the
        frame of the word's first character; ``end_frame`` is the frame of
        the character one position past the character that follows the
        word, or ``-1`` when `text` has no such character.
    """
    word_re = re.compile(r"\b" + re.escape(word) + r"\b")
    offset_count = len(offsets)

    spans = []
    for hit in word_re.finditer(text):
        # Skip over the character right after the word (normally the space).
        after = hit.end() + 1
        if after < offset_count:
            end_frame = offsets[after]
        else:
            end_frame = -1
        spans.append((offsets[hit.start()], end_frame))
    return spans
def decode_output_with_offsets(decoded_token_ids, vocab, blank_id=0):
    """Given list of decoded tokens, return text and
    time offsets that correspond to each character in the text.

    Consecutive frames carrying the same token are merged into one
    character. A blank frame separates repeats, so ``l <blank> l`` decodes
    to ``"ll"`` rather than a single ``"l"``. Blank frames and tokens that
    are not a single character (e.g. ``<s>``) produce no output, which keeps
    the text and the offsets aligned one-to-one.

    Args:
        decoded_token_ids (Iterable[int]): token ids, one per audio frame.
            Elements may be plain ints or zero-dimensional tensors/arrays.
        vocab (Dict[str, int]): model's vocabulary.
        blank_id (int): id of the blank/padding token.

    Returns:
        Tuple[str, List[int]], where
            `str` is a decoded text (``"|"`` is rendered as a space),
            `List[int]` is a starting frame indexes for
                every character in text.
    """
    token_by_index = {index: token for token, index in vocab.items()}
    prev_token_id = None
    chars = []
    offsets = []
    for frame, token_id in enumerate(decoded_token_ids):
        token_id = int(token_id)
        if token_id == prev_token_id:
            # Same token held over several frames: already emitted.
            continue
        # Remember blanks too, so a letter repeated after a blank is kept.
        prev_token_id = token_id
        if token_id == blank_id:
            continue
        token = token_by_index[token_id]
        if len(token) != 1:
            # Special multi-character token: not part of the transcript.
            continue
        chars.append(" " if token == "|" else token)
        offsets.append(frame)
    return "".join(chars), offsets
def main():
    """Command-line entry point: mute profanity in an audio file.

    Reads the input file, silences the recognised profanity and writes the
    result to ``--output`` (``result.wav`` when not given), using the same
    sample rate as the input file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("input")
    parser.add_argument("-o", "--output")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--model", default="facebook/hubert-large-ls960-ft")
    args = parser.parse_args()
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

    beeper = HubertBeeper(args.model)
    # Read the file here so its real sample rate is available when writing,
    # instead of assuming a fixed 16 kHz.
    wav, sample_rate = soundfile.read(args.input)
    _, result = beeper.f_beep_waveform(wav, sample_rate)

    output = args.output or "result.wav"
    soundfile.write(output, result, sample_rate)
    print(f"Saved to {output}")


if __name__ == "__main__":
    main()