import datetime
import math
import os

import numpy as np
import torch
import torchaudio
from funasr import AutoModel
from pyannote.audio import Audio, Pipeline
from pyannote.core import Segment
# Load models
model = AutoModel(
    model="FunAudioLLM/SenseVoiceSmall",
    # vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    # vad_kwargs={"max_single_segment_time": 30000},
    hub="hf",
    device="cuda" if torch.cuda.is_available() else "cpu",
)

pyannote_pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1", use_auth_token=os.getenv("HF_TOKEN")
)
if torch.cuda.is_available():
    pyannote_pipeline.to(torch.device("cuda"))
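
# Note: pyannote/speaker-diarization-3.1 is a gated model on the Hugging Face Hub;
# HF_TOKEN must belong to an account that has accepted its user conditions, or the
# pipeline load above may fail.
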
# Emoji dictionaries and formatting functions
emo_dict = {
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
}

event_dict = {
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|Cry|>": "😭",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "🤧",
}

emoji_dict = {
    "<|nospeech|><|Event_UNK|>": "❓",
    "<|zh|>": "",
    "<|en|>": "",
    "<|yue|>": "",
    "<|ja|>": "",
    "<|ko|>": "",
    "<|nospeech|>": "",
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
    "<|Cry|>": "😭",
    "<|EMO_UNKNOWN|>": "",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "😷",
    "<|Sing|>": "",
    "<|Speech_Noise|>": "",
    "<|withitn|>": "",
    "<|woitn|>": "",
    "<|GBG|>": "",
    "<|Event_UNK|>": "",
}

lang_dict = {
    "<|zh|>": "<|lang|>",
    "<|en|>": "<|lang|>",
    "<|yue|>": "<|lang|>",
    "<|ja|>": "<|lang|>",
    "<|ko|>": "<|lang|>",
    "<|nospeech|>": "<|lang|>",
}

emo_set = {"😊", "😔", "😡", "😰", "🤢", "😮"}
event_set = {"🎼", "👏", "😀", "😭", "🤧", "😷"}

def clean_and_emoji_annotate_speech(text):
    # Helper function to get the first emoji from a string that belongs to a given set
    def get_emoji(s, emoji_set):
        return next((char for char in s if char in emoji_set), None)

    # Helper function to format text with emojis based on special tokens
    def format_text_with_emojis(s):
        # Count occurrences of special tokens
        sptk_dict = {sptk: s.count(sptk) for sptk in emoji_dict}
        # Remove all special tokens from the text
        for sptk in emoji_dict:
            s = s.replace(sptk, "")
        # Determine the dominant emotion
        emo = "<|NEUTRAL|>"
        for e in emo_dict:
            if sptk_dict.get(e, 0) > sptk_dict.get(emo, 0):
                emo = e
        # Add event emojis at the beginning and the emotion emoji at the end
        s = (
            "".join(event_dict[e] for e in event_dict if sptk_dict.get(e, 0) > 0)
            + s
            + emo_dict[emo]
        )
        # Remove spaces around emojis
        for emoji in emo_set.union(event_set):
            s = s.replace(f" {emoji}", emoji).replace(f"{emoji} ", emoji)
        return s.strip()

    # Replace special tags and language markers
    text = text.replace("<|nospeech|><|Event_UNK|>", "❓")
    for lang, replacement in lang_dict.items():
        text = text.replace(lang, replacement)
    # Process each language segment
    segments = [
        format_text_with_emojis(segment.strip()) for segment in text.split("<|lang|>")
    ]
    formatted_segments = []
    prev_event = prev_emotion = None
    # Combine segments, avoiding duplicate emojis
    for segment in segments:
        if not segment:
            continue
        current_event = get_emoji(segment, event_set)
        current_emotion = get_emoji(segment, emo_set)
        # Strip a leading event emoji so it is not repeated when segments are joined
        if current_event is not None:
            segment = segment[1:] if segment.startswith(current_event) else segment
        # Move the emotion emoji to the end if it differs from the previous one
        if current_emotion is not None and current_emotion != prev_emotion:
            segment = segment.replace(current_emotion, "") + current_emotion
        formatted_segments.append(segment.strip())
        prev_event, prev_emotion = current_event, current_emotion
    # Join segments and drop stray "The." artifacts from the model output
    result = " ".join(formatted_segments).replace("The.", "").strip()
    return result
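
# Sketch of the behaviour above on a hypothetical SenseVoice rich-transcription
# string (the tag sequence is illustrative, not real model output):
#   clean_and_emoji_annotate_speech("<|yue|><|HAPPY|><|Speech|><|withitn|>你好")
#   -> "你好😊"
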
def time_to_seconds(time_str):
    h, m, s = time_str.split(":")
    return round(int(h) * 3600 + int(m) * 60 + float(s), 9)
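
# e.g. time_to_seconds("01:02:03.5") -> 1 * 3600 + 2 * 60 + 3.5 = 3723.5
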
def parse_time(time_str):
    # Remove 's' if present at the end of the string
    time_str = time_str.rstrip("s")
    # Split the time string into hours, minutes, and seconds
    parts = time_str.split(":")
    if len(parts) == 3:
        h, m, s = parts
    elif len(parts) == 2:
        h = "0"
        m, s = parts
    else:
        h = m = "0"
        s = parts[0]
    return int(h) * 3600 + int(m) * 60 + float(s)
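
# e.g. parse_time("04.444s") -> 4.444, parse_time("01:05.500") -> 65.5,
#      parse_time("01:02:03.500") -> 3723.5
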
def format_time(seconds, use_short_format=True, always_use_seconds=False):
    if isinstance(seconds, datetime.timedelta):
        seconds = seconds.total_seconds()
    minutes, seconds = divmod(seconds, 60)
    hours, minutes = divmod(int(minutes), 60)
    if always_use_seconds or (use_short_format and hours == 0 and minutes == 0):
        return f"{seconds:06.3f}s"
    elif use_short_format and hours == 0:
        return f"{minutes:02d}:{seconds:06.3f}"
    else:
        return f"{hours:02d}:{minutes:02d}:{seconds:06.3f}"
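
# e.g. format_time(4.444) -> "04.444s", format_time(65.5) -> "01:05.500",
#      format_time(3723.5, use_short_format=False) -> "01:02:03.500";
# parse_time above inverts each of these formats, which is how
# generate_diarization round-trips its human-readable timestamps.
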
def generate_diarization(audio_path):
    # Get the Hugging Face token from the environment variable
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise ValueError(
            "HF_TOKEN environment variable is not set. Please set it with your Hugging Face token."
        )
    # Initialize the audio processor
    audio = Audio(sample_rate=16000, mono=True)
    # Load the pretrained pipeline
    pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1", use_auth_token=hf_token
    )
    # Send pipeline to GPU if available
    if torch.cuda.is_available():
        pipeline.to(torch.device("cuda"))
    # Use only the provided audio_path
    file_path = audio_path
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Could not find the audio file at: {file_path}")
    print(f"Using audio file: {file_path}")
    # Process the audio file
    waveform, sample_rate = audio(file_path)
    # Create a dictionary with the audio information
    file = {"waveform": waveform, "sample_rate": sample_rate, "uri": "mtr"}
    # Run the diarization
    output = pipeline(file)
    # Save results in human-readable format
    diarization_segments = []
    txt_file = "mtr_dn.txt"
    with open(txt_file, "w") as f:
        current_speaker = None
        current_start = None
        current_end = None
        for turn, _, speaker in output.itertracks(yield_label=True):
            if speaker != current_speaker:
                if current_speaker is not None:
                    start_time = format_time(current_start)
                    end_time = format_time(current_end)
                    duration = format_time(current_end - current_start)
                    line = (
                        f"{start_time} - {end_time} ({duration}): {current_speaker}\n"
                    )
                    f.write(line)
                    print(line.strip())
                    diarization_segments.append(
                        (
                            parse_time(start_time),
                            parse_time(end_time),
                            parse_time(duration),
                            current_speaker,
                        )
                    )
                current_speaker = speaker
                current_start = turn.start
                current_end = turn.end
            else:
                current_end = turn.end
        # Write the last segment
        if current_speaker is not None:
            start_time = format_time(current_start)
            end_time = format_time(current_end)
            duration = format_time(current_end - current_start)
            line = f"{start_time} - {end_time} ({duration}): {current_speaker}\n"
            f.write(line)
            print(line.strip())
            diarization_segments.append(
                (
                    parse_time(start_time),
                    parse_time(end_time),
                    parse_time(duration),
                    current_speaker,
                )
            )
    print(f"\nHuman-readable diarization results saved to {txt_file}")
    return diarization_segments
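
# Each returned tuple is (start_seconds, end_seconds, duration_seconds, speaker_label),
# e.g. (0.031, 4.444, 4.413, "SPEAKER_00"); the values here are purely illustrative.
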
def process_audio(audio_path, language="yue", fs=16000):
    # Generate diarization segments
    diarization_segments = generate_diarization(audio_path)
    # Load and preprocess audio
    waveform, sample_rate = torchaudio.load(audio_path)
    if sample_rate != fs:
        resampler = torchaudio.transforms.Resample(sample_rate, fs)
        waveform = resampler(waveform)
    input_wav = waveform.mean(0).numpy()
    # Use the long timestamp format once total diarized speech reaches one minute
    total_duration = sum(duration for _, _, duration, _ in diarization_segments)
    use_long_format = total_duration >= 60
    # Process the audio in chunks based on diarization segments
    results = []
    for start_time, end_time, duration, speaker in diarization_segments:
        start_seconds = start_time
        end_seconds = end_time
        # Convert time to sample indices
        start_sample = int(start_seconds * fs)
        end_sample = int(end_seconds * fs)
        chunk = input_wav[start_sample:end_sample]
        try:
            text = model.generate(
                input=chunk,
                cache={},
                language=language,
                use_itn=True,
                batch_size_s=500,
                merge_vad=True,
            )
            text = text[0]["text"]
            # Print the text before clean_and_emoji_annotate_speech
            print(f"Text before clean_and_emoji_annotate_speech: {text}")
            text = clean_and_emoji_annotate_speech(text)
            # Handle empty transcriptions
            if not text.strip():
                text = "[inaudible]"
            results.append((speaker, start_time, end_time, duration, text))
        except AssertionError as e:
            if "choose a window size" in str(e):
                print(
                    f"Warning: Audio segment too short to process. Skipping. Error: {e}"
                )
                results.append((speaker, start_time, end_time, duration, "[too short]"))
            else:
                raise
    # Format the results
    formatted_text = ""
    for speaker, start, end, duration, text in results:
        start_str = (
            format_time(start, use_short_format=False)
            if use_long_format
            else format_time(start, use_short_format=True)
        )
        end_str = (
            format_time(end, use_short_format=False)
            if use_long_format
            else format_time(end, use_short_format=True)
        )
        duration_str = format_time(
            duration, use_short_format=True
        )  # Always use the short format for durations
        speaker_num = "1" if speaker == "SPEAKER_00" else "2"  # assumes two speakers
        line = f"{start_str} - {end_str} ({duration_str}) Speaker {speaker_num}: {text}"
        formatted_text += line + "\n"
        print(f"Debug: Formatted line: {line}")
    print("Debug: Full formatted text:")
    print(formatted_text)
    return formatted_text.strip()
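
# Illustrative output line for a clip under one minute (timestamps and text are
# hypothetical):
#   01.234s - 05.678s (04.444s) Speaker 1: 你好😊
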
if __name__ == "__main__":
    audio_path = "example/mtr.mp3"  # Replace with your audio file path
    language = "yue"  # Set language to Cantonese
    # Option to run only diarization
    diarization_only = False  # Set this to True if you want only diarization
    if diarization_only:
        diarization_segments = generate_diarization(audio_path)
        # You can add code here to save or process the diarization results as needed
    else:
        result = process_audio(audio_path, language)
        # Save the result to mtr.txt
        output_path = "mtr.txt"
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(result)
        print(f"Diarization and transcription result has been saved to {output_path}")