import datetime
import math
import os

import numpy as np
import torch
import torchaudio
from funasr import AutoModel
from pyannote.audio import Audio, Pipeline
from pyannote.core import Segment

# Load models
model = AutoModel(
    model="FunAudioLLM/SenseVoiceSmall",
    # vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    # vad_kwargs={"max_single_segment_time": 30000},
    hub="hf",
    device="cuda" if torch.cuda.is_available() else "cpu",
)

pyannote_pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1", use_auth_token=os.getenv("HF_TOKEN")
)
if torch.cuda.is_available():
    pyannote_pipeline.to(torch.device("cuda"))

# Emoji dictionaries and formatting functions
emo_dict = {
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
}

event_dict = {
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|Cry|>": "😭",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "🤧",
}

emoji_dict = {
    "<|nospeech|><|Event_UNK|>": "❓",
    "<|zh|>": "",
    "<|en|>": "",
    "<|yue|>": "",
    "<|ja|>": "",
    "<|ko|>": "",
    "<|nospeech|>": "",
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
    "<|Cry|>": "😭",
    "<|EMO_UNKNOWN|>": "",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "😷",
    "<|Sing|>": "",
    "<|Speech_Noise|>": "",
    "<|withitn|>": "",
    "<|woitn|>": "",
    "<|GBG|>": "",
    "<|Event_UNK|>": "",
}

lang_dict = {
    "<|zh|>": "<|lang|>",
    "<|en|>": "<|lang|>",
    "<|yue|>": "<|lang|>",
    "<|ja|>": "<|lang|>",
    "<|ko|>": "<|lang|>",
    "<|nospeech|>": "<|lang|>",
}

emo_set = {"😊", "😔", "😡", "😰", "🤢", "😮"}
event_set = {"🎼", "👏", "😀", "😭", "🤧", "😷"}


def clean_and_emoji_annotate_speech(text):
    # Helper function to get the first emoji from a string that belongs to a given set
    def get_emoji(s, emoji_set):
        return next((char for char in s if char in emoji_set), None)

    # Helper function to format text with emojis based on special tokens
    def format_text_with_emojis(s):
        # Count occurrences of special tokens
        sptk_dict = {sptk: s.count(sptk) for sptk in emoji_dict}
        # Remove all special tokens from the text
        for sptk in emoji_dict:
            s = s.replace(sptk, "")
        # Determine the dominant emotion
        emo = "<|NEUTRAL|>"
        for e in emo_dict:
            if sptk_dict.get(e, 0) > sptk_dict.get(emo, 0):
                emo = e
        # Add event emojis at the beginning and the emotion emoji at the end
        s = (
            "".join(event_dict[e] for e in event_dict if sptk_dict.get(e, 0) > 0)
            + s
            + emo_dict[emo]
        )
        # Remove spaces around emojis
        for emoji in emo_set.union(event_set):
            s = s.replace(f" {emoji}", emoji).replace(f"{emoji} ", emoji)
        return s.strip()

    # Replace special tags and language markers
    text = text.replace("<|nospeech|><|Event_UNK|>", "❓")
    for lang, replacement in lang_dict.items():
        text = text.replace(lang, replacement)

    # Process each language segment
    segments = [
        format_text_with_emojis(segment.strip()) for segment in text.split("<|lang|>")
    ]

    formatted_segments = []
    prev_event = prev_emotion = None

    # Combine segments, avoiding duplicate emojis
    for segment in segments:
        if not segment:
            continue
        current_event = get_emoji(segment, event_set)
        current_emotion = get_emoji(segment, emo_set)

        # Drop the leading event emoji, if present
        if current_event is not None:
            segment = segment[1:] if segment.startswith(current_event) else segment

        # Move the emotion emoji to the end if it differs from the previous segment's
        if current_emotion is not None and current_emotion != prev_emotion:
            segment = segment.replace(current_emotion, "") + current_emotion

        formatted_segments.append(segment.strip())
        prev_event, prev_emotion = current_event, current_emotion

    # Join segments and drop the spurious "The." the model sometimes emits
    result = " ".join(formatted_segments).replace("The.", "").strip()
    return result
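# Illustrative example (hypothetical token string, not real SenseVoice output):
# the language and ITN tags are stripped, <|Speech|> maps to an empty event marker,
# and the dominant emotion is appended as a trailing emoji.
#
#     clean_and_emoji_annotate_speech("<|yue|><|HAPPY|><|Speech|><|withitn|>你好")
#     # -> "你好😊"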
def time_to_seconds(time_str):
    h, m, s = time_str.split(":")
    return round(int(h) * 3600 + int(m) * 60 + float(s), 9)


def parse_time(time_str):
    # Remove 's' if present at the end of the string
    time_str = time_str.rstrip("s")
    # Split the time string into hours, minutes, and seconds
    parts = time_str.split(":")
    if len(parts) == 3:
        h, m, s = parts
    elif len(parts) == 2:
        h = "0"
        m, s = parts
    else:
        h = m = "0"
        s = parts[0]
    return int(h) * 3600 + int(m) * 60 + float(s)


def format_time(seconds, use_short_format=True, always_use_seconds=False):
    if isinstance(seconds, datetime.timedelta):
        seconds = seconds.total_seconds()
    minutes, seconds = divmod(seconds, 60)
    hours, minutes = divmod(int(minutes), 60)
    if always_use_seconds or (use_short_format and hours == 0 and minutes == 0):
        return f"{seconds:06.3f}s"
    elif use_short_format and hours == 0:
        return f"{minutes:02d}:{seconds:06.3f}"
    else:
        return f"{hours:02d}:{minutes:02d}:{seconds:06.3f}"
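# Quick reference for the time helpers above (values are illustrative):
#
#     format_time(5.25)                            # -> "05.250s"       (under a minute)
#     format_time(75.5)                            # -> "01:15.500"     (under an hour)
#     format_time(3675.5, use_short_format=False)  # -> "01:01:15.500"  (full HH:MM:SS form)
#     parse_time("01:15.500")                      # -> 75.5
#     parse_time("05.250s")                        # -> 5.25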
def generate_diarization(audio_path):
    # Get the Hugging Face token from the environment variable
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise ValueError(
            "HF_TOKEN environment variable is not set. Please set it with your Hugging Face token."
        )

    # Initialize the audio processor
    audio = Audio(sample_rate=16000, mono=True)

    # Load the pretrained pipeline
    pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1", use_auth_token=hf_token
    )

    # Send pipeline to GPU if available
    if torch.cuda.is_available():
        pipeline.to(torch.device("cuda"))

    # Use only the provided audio_path
    file_path = audio_path
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Could not find the audio file at: {file_path}")
    print(f"Using audio file: {file_path}")

    # Process the audio file
    waveform, sample_rate = audio(file_path)

    # Create a dictionary with the audio information
    file = {"waveform": waveform, "sample_rate": sample_rate, "uri": "mtr"}

    # Run the diarization
    output = pipeline(file)

    # Save results in human-readable format, merging consecutive turns by the same speaker
    diarization_segments = []
    txt_file = "mtr_dn.txt"
    with open(txt_file, "w") as f:
        current_speaker = None
        current_start = None
        current_end = None

        for turn, _, speaker in output.itertracks(yield_label=True):
            if speaker != current_speaker:
                if current_speaker is not None:
                    start_time = format_time(current_start)
                    end_time = format_time(current_end)
                    duration = format_time(current_end - current_start)
                    line = (
                        f"{start_time} - {end_time} ({duration}): {current_speaker}\n"
                    )
                    f.write(line)
                    print(line.strip())
                    diarization_segments.append(
                        (
                            parse_time(start_time),
                            parse_time(end_time),
                            parse_time(duration),
                            current_speaker,
                        )
                    )
                current_speaker = speaker
                current_start = turn.start
                current_end = turn.end
            else:
                current_end = turn.end

        # Write the last segment
        if current_speaker is not None:
            start_time = format_time(current_start)
            end_time = format_time(current_end)
            duration = format_time(current_end - current_start)
            line = f"{start_time} - {end_time} ({duration}): {current_speaker}\n"
            f.write(line)
            print(line.strip())
            diarization_segments.append(
                (
                    parse_time(start_time),
                    parse_time(end_time),
                    parse_time(duration),
                    current_speaker,
                )
            )

    print(f"\nHuman-readable diarization results saved to {txt_file}")
    return diarization_segments


def process_audio(audio_path, language="yue", fs=16000):
    # Generate diarization segments
    diarization_segments = generate_diarization(audio_path)

    # Load and preprocess audio
    waveform, sample_rate = torchaudio.load(audio_path)
    if sample_rate != fs:
        resampler = torchaudio.transforms.Resample(sample_rate, fs)
        waveform = resampler(waveform)
    input_wav = waveform.mean(0).numpy()

    # Use the long timestamp format when the total speech duration reaches one minute
    total_duration = sum(duration for _, _, duration, _ in diarization_segments)
    use_long_format = total_duration >= 60

    # Process the audio in chunks based on diarization segments
    results = []
    for start_time, end_time, duration, speaker in diarization_segments:
        start_seconds = start_time
        end_seconds = end_time

        # Convert time to sample indices
        start_sample = int(start_seconds * fs)
        end_sample = int(end_seconds * fs)
        chunk = input_wav[start_sample:end_sample]

        try:
            text = model.generate(
                input=chunk,
                cache={},
                language=language,
                use_itn=True,
                batch_size_s=500,
                merge_vad=True,
            )
            text = text[0]["text"]

            # Print the raw text before clean_and_emoji_annotate_speech
            print(f"Text before clean_and_emoji_annotate_speech: {text}")
            text = clean_and_emoji_annotate_speech(text)

            # Handle empty transcriptions
            if not text.strip():
                text = "[inaudible]"

            results.append((speaker, start_time, end_time, duration, text))
        except AssertionError as e:
            if "choose a window size" in str(e):
                print(
                    f"Warning: Audio segment too short to process. Skipping. Error: {e}"
                )
                results.append((speaker, start_time, end_time, duration, "[too short]"))
            else:
                raise

    # Format the results
    formatted_text = ""
    for speaker, start, end, duration, text in results:
        start_str = (
            format_time(start, use_short_format=False)
            if use_long_format
            else format_time(start, use_short_format=True)
        )
        end_str = (
            format_time(end, use_short_format=False)
            if use_long_format
            else format_time(end, use_short_format=True)
        )
        # Always use the short format for durations
        duration_str = format_time(duration, use_short_format=True)
        # Map pyannote's labels to human-friendly numbers (assumes at most two speakers)
        speaker_num = "1" if speaker == "SPEAKER_00" else "2"
        line = f"{start_str} - {end_str} ({duration_str}) Speaker {speaker_num}: {text}"
        formatted_text += line + "\n"
        print(f"Debug: Formatted line: {line}")

    print("Debug: Full formatted text:")
    print(formatted_text)
    return formatted_text.strip()


if __name__ == "__main__":
    audio_path = "example/mtr.mp3"  # Replace with your audio file path
    language = "yue"  # Set language to Cantonese

    # Option to run only diarization
    diarization_only = False  # Set this to True if you want only diarization

    if diarization_only:
        diarization_segments = generate_diarization(audio_path)
        # You can add code here to save or process the diarization results as needed
    else:
        result = process_audio(audio_path, language)

        # Save the result to mtr.txt
        output_path = "mtr.txt"
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(result)
        print(f"Diarization and transcription result has been saved to {output_path}")
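# Expected shape of the lines written to mtr.txt when the total speech duration is a
# minute or more (timestamps, speaker labels, and text below are illustrative only):
#
#     00:01:05.123 - 00:01:09.456 (04.333s) Speaker 1: 你好😊
#     00:01:10.000 - 00:02:15.500 (01:05.500) Speaker 2: [inaudible]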