"""Gradio app: transcribe clinical audio with Whisper and draft a SOAP note with an LLM."""

import os
from threading import Thread
from typing import Iterator, Optional

# Third-party imports keep their original relative order (``spaces`` ahead of
# ``torch``).
import spaces
import torch
import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    pipeline,
)

MODEL_NAME = "openai/whisper-large-v3-turbo"
BATCH_SIZE = 8
FILE_LIMIT_MB = 5000
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

device = 0 if torch.cuda.is_available() else "cpu"

# Initialize the LLM.
# The names are always bound so that ``generate_soap`` can report a readable
# error instead of failing with a NameError when CUDA is not available.
llm = None
tokenizer = None
if torch.cuda.is_available():
    llm_model_id = "chuanli11/Llama-3.2-3B-Instruct-uncensored"
    llm = AutoModelForCausalLM.from_pretrained(
        llm_model_id,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
    tokenizer.use_default_system_prompt = False

# Initialize the transcription pipeline
pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_NAME,
    chunk_length_s=30,
    device=device,
)


# Function to transcribe audio inputs
@spaces.GPU
def transcribe(inputs: Optional[str], task: str) -> str:
    """Run the speech-recognition pipeline on an audio input and return its text.

    Args:
        inputs: Value of the audio component (configured with
            ``type="filepath"``), or ``None`` when nothing was submitted.
        task: Forwarded to the pipeline as ``generate_kwargs["task"]``; the UI
            offers ``"transcribe"`` and ``"translate"``.

    Returns:
        The ``"text"`` entry of the pipeline result.

    Raises:
        gr.Error: If ``inputs`` is ``None``.
    """
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")

    result = pipe(
        inputs,
        batch_size=BATCH_SIZE,
        generate_kwargs={"task": task},
        return_timestamps=True,
    )
    return result["text"]


# Function to generate SOAP notes using LLM
@spaces.GPU
def generate_soap(
    transcribed_text: str,
    system_prompt: str = "You are a world class clinical assistant.",
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream a SOAP note generated from a transcribed conversation.

    Args:
        transcribed_text: The conversation transcript to convert.
        system_prompt: Content of the system message sent to the model.
        max_new_tokens: Passed to ``llm.generate``.
        temperature: Passed to ``llm.generate``.
        top_p: Passed to ``llm.generate``.
        top_k: Passed to ``llm.generate``.
        repetition_penalty: Passed to ``llm.generate``.

    Yields:
        The text generated so far, re-emitted in full after every streamed
        chunk.

    Raises:
        gr.Error: If the LLM was not loaded (no CUDA device at start-up) or
            the transcript is empty.
    """
    if llm is None or tokenizer is None:
        raise gr.Error("The language model is not loaded because no CUDA device was available at start-up.")
    if not transcribed_text or not transcribed_text.strip():
        raise gr.Error("No transcription provided! Please transcribe audio or enter text before generating a SOAP note.")

    task_prompt = """
    Convert the following transcribed conversation into a clinical SOAP note.
    The text includes dialogue between a physician and a patient. Please clearly distinguish between the physician's and the patient's statements.
    Extract and organize the information into the relevant sections of a SOAP note:
    - Subjective (symptoms and patient statements),
    - Objective (clinical findings and observations),
    - Assessment (diagnosis or potential diagnoses),
    - Plan (treatment and follow-up).
    Ensure the note is concise, clear, and accurately reflects the conversation.
    """

    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"{task_prompt}\n\nTranscribed conversation:\n{transcribed_text}"},
    ]

    # add_generation_prompt=True asks the chat template to append the tokens
    # that open the assistant turn after the user message.
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keeps only the LAST MAX_INPUT_TOKEN_LENGTH tokens.
        # NOTE(review): this drops the start of the prompt (system message and
        # instructions) first -- confirm that is the intended trimming policy.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(llm.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
    }

    # Generation runs in a background thread; this generator consumes the
    # streamer and forwards the accumulated text to the caller.
    worker = Thread(target=llm.generate, kwargs=generate_kwargs)
    worker.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

    # The stream is exhausted; wait for the generation thread to finish.
    worker.join()


# Gradio Interface combining transcription and SOAP note generation
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    with gr.Tab("Clinical SOAP Note from Audio"):
        # Transcription Interface
        audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio Input")
        task_input = gr.Radio(["transcribe", "translate"], label="Task", value="transcribe")
        transcription_output = gr.Textbox(label="Transcription Output")

        # Transcription button
        transcribe_button = gr.Button("Transcribe")

        # SOAP Generation Interface
        transcribed_text_input = gr.Textbox(label="Edit Transcription before SOAP Generation", lines=5)
        system_prompt_input = gr.Textbox(
            label="System Prompt",
            lines=2,
            value="You are a world class clinical assistant.",
        )
        max_new_tokens_input = gr.Slider(label="Max new tokens", minimum=1, maximum=2048, value=1024, step=1)
        temperature_input = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, value=0.6, step=0.1)
        top_p_input = gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, value=0.9, step=0.05)
        top_k_input = gr.Slider(label="Top-k", minimum=1, maximum=1000, value=50, step=1)
        repetition_penalty_input = gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            value=1.2,
            step=0.05,
        )
        soap_output = gr.Textbox(label="Generated SOAP Note Output")

        # SOAP generation button
        generate_soap_button = gr.Button("Generate SOAP Note")

        # Event wiring (done after all components exist; layout order is set
        # by the component creation order above).
        # After a successful transcription, copy the result into the editable
        # box so it can be reviewed before SOAP generation.
        transcribe_button.click(
            fn=transcribe,
            inputs=[audio_input, task_input],
            outputs=transcription_output,
        ).success(
            fn=lambda text: text,
            inputs=transcription_output,
            outputs=transcribed_text_input,
        )

        generate_soap_button.click(
            fn=generate_soap,
            inputs=[
                transcribed_text_input,
                system_prompt_input,
                max_new_tokens_input,
                temperature_input,
                top_p_input,
                top_k_input,
                repetition_penalty_input,
            ],
            outputs=soap_output,
        )

demo.queue().launch(ssr_mode=False)