Spaces: Running on Zero
File size: 5,398 Bytes
import spaces
import torch
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
from typing import Iterator
import os
MODEL_NAME = "openai/whisper-large-v3-turbo"
BATCH_SIZE = 8
FILE_LIMIT_MB = 5000
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
device = 0 if torch.cuda.is_available() else "cpu"

# Initialize the LLM (only loaded when a CUDA device is available; generate_soap assumes this block has run)
if torch.cuda.is_available():
    llm_model_id = "chuanli11/Llama-3.2-3B-Instruct-uncensored"
    llm = AutoModelForCausalLM.from_pretrained(llm_model_id, torch_dtype=torch.float16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
    tokenizer.use_default_system_prompt = False

# Initialize the transcription pipeline
pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_NAME,
    chunk_length_s=30,
    device=device,
)

# Function to transcribe audio inputs
@spaces.GPU
def transcribe(inputs, task):
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
    text = pipe(inputs, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)["text"]
    return text
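
# Hypothetical local sanity check (not wired into the UI): the function takes an
# audio file path and a task name and returns the concatenated transcript, e.g.
#   text = transcribe("sample.wav", "transcribe")
#   print(text)
# "sample.wav" is only an illustrative placeholder path.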
# Function to generate SOAP notes using LLM
@spaces.GPU
def generate_soap(
    transcribed_text: str,
    system_prompt: str = "You are a world class clinical assistant.",
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
task_prompt = """
Convert the following transcribed conversation into a clinical SOAP note.
The text includes dialogue between a physician and a patient. Please clearly distinguish between the physician's and the patient's statements.
Extract and organize the information into the relevant sections of a SOAP note:
- Subjective (symptoms and patient statements),
- Objective (clinical findings and observations),
- Assessment (diagnosis or potential diagnoses),
- Plan (treatment and follow-up).
Ensure the note is concise, clear, and accurately reflects the conversation.
"""
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"{task_prompt}\n\nTranscribed conversation:\n{transcribed_text}"},
    ]
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(llm.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Run generation in a background thread so partial output can be streamed as it is produced
    t = Thread(target=llm.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
# Gradio Interface combining transcription and SOAP note generation
demo = gr.Blocks(theme=gr.themes.Ocean())
with demo:
    with gr.Tab("Clinical SOAP Note from Audio"):
        # Transcription Interface
        audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio Input")
        task_input = gr.Radio(["transcribe", "translate"], label="Task", value="transcribe")
        transcription_output = gr.Textbox(label="Transcription Output")

        # Transcription button
        transcribe_button = gr.Button("Transcribe")
        transcribe_button.click(fn=transcribe, inputs=[audio_input, task_input], outputs=transcription_output)

        # SOAP Generation Interface
        transcribed_text_input = gr.Textbox(label="Edit Transcription before SOAP Generation", lines=5)
        system_prompt_input = gr.Textbox(label="System Prompt", lines=2, value="You are a world class clinical assistant.")
        max_new_tokens_input = gr.Slider(label="Max new tokens", minimum=1, maximum=2048, value=1024, step=1)
        temperature_input = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, value=0.6, step=0.1)
        top_p_input = gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, value=0.9, step=0.05)
        top_k_input = gr.Slider(label="Top-k", minimum=1, maximum=1000, value=50, step=1)
        repetition_penalty_input = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, value=1.2, step=0.05)
        soap_output = gr.Textbox(label="Generated SOAP Note Output")

        # SOAP generation button
        generate_soap_button = gr.Button("Generate SOAP Note")
        generate_soap_button.click(
            fn=generate_soap,
            inputs=[
                transcribed_text_input,
                system_prompt_input,
                max_new_tokens_input,
                temperature_input,
                top_p_input,
                top_k_input,
                repetition_penalty_input,
            ],
            outputs=soap_output,
        )

demo.queue().launch(ssr_mode=False)
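
# Local-run note (an assumption, not part of the Space configuration): with
# `pip install gradio spaces transformers torch accelerate`, running this script
# directly serves the same UI on Gradio's default local address, http://127.0.0.1:7860.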