import spaces
import torch
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
from typing import Iterator
import os

MODEL_NAME = "openai/whisper-large-v3-turbo"
BATCH_SIZE = 8
FILE_LIMIT_MB = 5000
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Use the first CUDA device for the ASR pipeline when available, otherwise fall back to CPU
device = 0 if torch.cuda.is_available() else "cpu"

# Initialize the LLM (float16, sharded across available GPUs via device_map="auto")
if torch.cuda.is_available():
    llm_model_id = "chuanli11/Llama-3.2-3B-Instruct-uncensored"
    llm = AutoModelForCausalLM.from_pretrained(llm_model_id, torch_dtype=torch.float16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
    tokenizer.use_default_system_prompt = False
else:
    # Without a GPU the LLM is skipped; generate_soap raises a clear error instead of a NameError
    llm = None
    tokenizer = None

# Initialize the transcription pipeline; chunk_length_s=30 splits long audio into 30-second windows
pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_NAME,
    chunk_length_s=30,
    device=device,
)
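# The ASR pipeline accepts a file path, raw bytes, or a numpy array and handles
# resampling and feature extraction internally.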

# Transcribe (or translate to English) audio inputs with the Whisper pipeline
@spaces.GPU
def transcribe(inputs, task):
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
    text = pipe(inputs, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)["text"]
    return text

# Function to generate SOAP notes using LLM
@spaces.GPU
def generate_soap(
    transcribed_text: str,
    system_prompt: str = "You are a world class clinical assistant.",
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    if llm is None or tokenizer is None:
        raise gr.Error("SOAP note generation requires a GPU; the language model was not loaded.")
    task_prompt = """
    Convert the following transcribed conversation into a clinical SOAP note.
    The text includes dialogue between a physician and a patient. Please clearly distinguish between the physician's and the patient's statements.
    Extract and organize the information into the relevant sections of a SOAP note:
    - Subjective (symptoms and patient statements),
    - Objective (clinical findings and observations),
    - Assessment (diagnosis or potential diagnoses),
    - Plan (treatment and follow-up).
    Ensure the note is concise, clear, and accurately reflects the conversation.
    """
    
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"{task_prompt}\n\nTranscribed conversation:\n{transcribed_text}"}
    ]

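    # Render the conversation with the model's chat template; the left-truncation below
    # keeps the most recent tokens (the transcript tail) when the prompt is too long.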
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(llm.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
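    # llm.generate blocks until generation finishes, so it runs on a background thread
    # while the streamer yields decoded tokens here for incremental display in the UI.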
    t = Thread(target=llm.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

# Gradio Interface combining transcription and SOAP note generation
demo = gr.Blocks(theme=gr.themes.Ocean())

with demo:
    with gr.Tab("Clinical SOAP Note from Audio"):
        # Transcription Interface
        audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio Input")
        task_input = gr.Radio(["transcribe", "translate"], label="Task", value="transcribe")
        transcription_output = gr.Textbox(label="Transcription Output")
        
        # Transcription button
        transcribe_button = gr.Button("Transcribe")
        transcribe_button.click(fn=transcribe, inputs=[audio_input, task_input], outputs=transcription_output)

        # SOAP Generation Interface
        transcribed_text_input = gr.Textbox(label="Edit Transcription before SOAP Generation", lines=5)
        # Mirror new transcriptions into the editable box so they can be revised before SOAP generation
        transcription_output.change(fn=lambda text: text, inputs=transcription_output, outputs=transcribed_text_input)
        system_prompt_input = gr.Textbox(label="System Prompt", lines=2, value="You are a world class clinical assistant.")
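        # Sampling controls; these map one-to-one onto generate_soap's parameters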
        max_new_tokens_input = gr.Slider(label="Max new tokens", minimum=1, maximum=2048, value=1024, step=1)
        temperature_input = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, value=0.6, step=0.1)
        top_p_input = gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, value=0.9, step=0.05)
        top_k_input = gr.Slider(label="Top-k", minimum=1, maximum=1000, value=50, step=1)
        repetition_penalty_input = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, value=1.2, step=0.05)

        soap_output = gr.Textbox(label="Generated SOAP Note Output")

        # SOAP generation button
        generate_soap_button = gr.Button("Generate SOAP Note")
        generate_soap_button.click(
            fn=generate_soap,
            inputs=[
                transcribed_text_input,
                system_prompt_input,
                max_new_tokens_input,
                temperature_input,
                top_p_input,
                top_k_input,
                repetition_penalty_input
            ],
            outputs=soap_output
        )

# queue() enables request queuing so the streamed (generator) SOAP output updates incrementally;
# ssr_mode=False disables server-side rendering
demo.queue().launch(ssr_mode=False)