"""Pronunciation-scoring API.

Transcribes an uploaded audio clip with the ``facebook/wav2vec2-base-960h``
CTC model, then scores it by the fraction of transcribed words that appear in
an English word list, scaled to the range 0-10.
"""
import io
import re
import warnings

import librosa
import pyarrow as pa
import requests
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from starlette.concurrency import run_in_threadpool
from starlette.formparsers import MultiPartParser
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer

# NOTE(review): in Starlette this class attribute is the in-memory spool
# threshold for uploaded files (larger files spill to a temporary file), not a
# hard cap on upload size -- confirm against the installed Starlette version
# before relying on it as a limit.
MultiPartParser.max_file_size = 200 * 1024 * 1024

# Wav2Vec2 checkpoints are trained on 16 kHz speech; input audio must be
# resampled to this rate or recognition quality degrades silently.
TARGET_SAMPLE_RATE = 16_000
WORD_LIST_URL = "https://raw.githubusercontent.com/dwyl/english-words/master/words_alpha.txt"
# Seconds to wait for the word-list server before failing startup.
WORD_LIST_TIMEOUT = 30
# Compiled once; used for every spell-check call.
_WORD_RE = re.compile(r'\b\w+\b')

# Initialize FastAPI app
app = FastAPI()

# Load Wav2Vec2 tokenizer and model once at import time.
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
model.eval()


def download_word_list(url=WORD_LIST_URL, timeout=WORD_LIST_TIMEOUT):
    """Download a newline/whitespace-separated word list and return it as a set.

    Args:
        url: Location of the plain-text word list.
        timeout: Seconds to wait for the HTTP request.

    Returns:
        A set of the whitespace-separated tokens in the response body.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection failures or timeouts.
    """
    print("Downloading English word list...")
    response = requests.get(url, timeout=timeout)
    # Without this check an error page's text would be split into "words".
    response.raise_for_status()
    words = set(response.text.split())
    print("Word list downloaded.")
    return words


english_words = download_word_list()


def count_spelled_words(text, word_list):
    """Count words in *text* that are / are not present in *word_list*.

    Words are ``\\w+`` runs of the lowercased text.

    Returns:
        A tuple ``(incorrect, correct)`` of word counts.
    """
    print("Counting spelled words...")
    words = _WORD_RE.findall(text.lower())
    correct = sum(1 for word in words if word in word_list)
    incorrect = len(words) - correct
    print("Spelling check complete.")
    return incorrect, correct


def apply_spell_check(item, word_list):
    """Attach spelling counts to a single record or to a pyarrow batch.

    For a ``dict``, ``item['transcription']`` is checked and the keys
    ``incorrect_words`` / ``correct_words`` are set on the dict in place.
    Otherwise *item* is treated as a table-like batch (presumably a
    ``pyarrow.Table`` -- confirm with callers) whose ``transcription`` column
    is checked. Two int64 columns are appended to it, and the new table is
    returned.
    """
    print("Applying spell check...")
    if isinstance(item, dict):
        # This is a single item
        incorrect, correct = count_spelled_words(item['transcription'], word_list)
        item['incorrect_words'] = incorrect
        item['correct_words'] = correct
        print("Spell check applied to single item.")
        return item

    # This is likely a batch. Convert pyarrow scalars to Python strings, since
    # pyarrow StringScalar has no .lower().
    column = item['transcription']
    texts = column.to_pylist() if hasattr(column, 'to_pylist') else list(column)
    # Null transcriptions count as empty text rather than crashing.
    results = [count_spelled_words(text or "", word_list) for text in texts]
    # Build lists explicitly: zip(*[]) cannot be unpacked for an empty batch.
    incorrect_counts = [incorrect for incorrect, _ in results]
    correct_counts = [correct for _, correct in results]
    item = item.append_column('incorrect_words', pa.array(incorrect_counts, type=pa.int64()))
    item = item.append_column('correct_words', pa.array(correct_counts, type=pa.int64()))
    print("Spell check applied to batch of items.")
    return item


def _load_audio(contents):
    """Decode uploaded bytes into a mono waveform resampled to TARGET_SAMPLE_RATE.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded or contain no samples.
    """
    try:
        audio, _ = librosa.load(io.BytesIO(contents), sr=TARGET_SAMPLE_RATE)
    except Exception as exc:  # decoder errors differ by librosa backend
        raise HTTPException(status_code=400, detail="Could not decode audio file.") from exc
    if audio.size == 0:
        raise HTTPException(status_code=400, detail="Audio file contains no samples.")
    return audio


def _transcribe(audio):
    """Run the Wav2Vec2 model with argmax decoding and return lowercase text."""
    print("Tokenizing audio...")
    input_values = tokenizer(audio, return_tensors="pt").input_values
    print("Performing inference with Wav2Vec2 model...")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(input_values).logits
    print("Getting predictions...")
    prediction = torch.argmax(logits, dim=-1)
    print("Decoding predictions...")
    transcription = tokenizer.batch_decode(prediction)[0]
    return transcription.lower()


def _score(incorrect, correct):
    """Return the share of correctly spelled words scaled to 0-10, rounded to 2 dp.

    Returns 0.0 when there are no words (e.g. silent audio) instead of
    dividing by zero.
    """
    total = incorrect + correct
    if total == 0:
        return 0.0
    return round(correct / total * 10, 2)


def _score_audio_bytes(contents):
    """Decode, transcribe, and score raw audio bytes; return the response payload."""
    audio = _load_audio(contents)
    transcription = _transcribe(audio)
    print("Decoded transcription:", transcription)
    incorrect, correct = count_spelled_words(transcription, english_words)
    print("Spelling check - Incorrect words:", incorrect, ", Correct words:", correct)
    score = _score(incorrect, correct)
    print("Pronunciation score for", transcription, ":", score)
    return {
        "transcription": transcription,
        "pronunciation_score": score,
    }


# FastAPI routes
@app.get('/')
async def root():
    """Return a static welcome message."""
    return "Welcome to the pronunciation scoring API!"


@app.post('/pronunciation_scoring')
async def unscripted_root(audio_file: UploadFile):
    """Score an uploaded audio clip.

    Returns:
        ``{"transcription": str, "pronunciation_score": float}``.

    Raises:
        HTTPException: 400 for empty or undecodable uploads.
    """
    print("Pronunciation Scoring")
    contents = await audio_file.read()
    # Log only the size; printing raw bytes of a large upload floods the log.
    print("Received", len(contents), "bytes")
    if not contents:
        raise HTTPException(status_code=400, detail="Uploaded audio file is empty.")
    # Decoding and model inference are CPU-bound; run them in a worker thread
    # so they do not block the event loop for other requests.
    result = await run_in_threadpool(_score_audio_bytes, contents)
    print("Pronunciation scoring process complete.")
    return result