|
import gradio as gr |
|
from utils import reset_patient, set_patient, ask_another_question |
|
from regular_rag import qa_tool_regular_rag |
|
from graph_rag import qa_tool_graph_rag |
|
import logging |
|
|
|
# Module-wide logging: INFO and above, with a timestamp/name/level prefix.
# basicConfig only installs a handler if the root logger has none yet.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
def qa_tool(user_question, method):
    """Answer *user_question* with the retrieval pipeline chosen in the UI.

    Dispatches on the radio-button label in *method* to
    ``qa_tool_regular_rag`` or ``qa_tool_graph_rag``. Each pipeline is
    unpacked as a 4-tuple
    ``(answer, images, ask_another_visible, change_patient_visible)``; this
    function forwards those values and appends an update that hides the
    knowledge-graph plot.

    Args:
        user_question: Text entered in the question box.
        method: Selected label, ``"Regular RAG"`` or ``"Graph-RAG"``. Gradio
            passes ``None`` when no radio option is selected.

    Returns:
        A 5-tuple matching the ``get_answer_btn.click`` outputs:
        ``(answer, images, ask_another_visible, change_patient_visible,
        graph_update)``. For an unknown or missing *method*, the answer is a
        prompt to pick a method and the other components are left unchanged.
    """
    logger.info("Method selected: %s", method)

    pipelines = {
        "Regular RAG": qa_tool_regular_rag,
        "Graph-RAG": qa_tool_graph_rag,
    }
    pipeline = pipelines.get(method)

    if pipeline is None:
        # Without this guard the function would return None, which Gradio
        # cannot distribute over the five output components.
        logger.warning("Unknown or missing method: %r", method)
        return (
            "Please select a method (Regular RAG or Graph-RAG) before asking a question.",
            gr.update(),  # leave the gallery unchanged
            gr.update(),  # leave "Ask Another Question" visibility unchanged
            gr.update(),  # leave "Set Another Patient" visibility unchanged
            gr.update(visible=False),
        )

    answer, images, ask_another_visible, change_patient_visible = pipeline(user_question)
    # Message text is identical to the former per-branch logs,
    # e.g. "Regular RAG answer generated".
    logger.info("%s answer generated", method)
    logger.info("%s images: %s", method, images)

    # The knowledge-graph plot is hidden for both pipelines.
    return answer, images, ask_another_visible, change_patient_visible, gr.update(visible=False)
|
|
|
|
|
with gr.Blocks() as app:
    gr.Markdown("# Clinical Diagram QA Tool")

    # Step 1: patient description entry, visible on first load.
    with gr.Group() as patient_input:
        patient_desc = gr.Textbox(label="Patient Description")
        set_patient_btn = gr.Button("Set Patient")

    # Step 2: question/answer panel, hidden initially. It is presumably
    # revealed by ``set_patient``, which lists ``qa_interface`` among its
    # outputs. Initial visibility is passed to the constructors instead of
    # assigning ``.visible`` after creation, so the starting state does not
    # depend on when Gradio reads component attributes.
    with gr.Group(visible=False) as qa_interface:
        qa_desc = gr.Markdown()
        question_input = gr.Textbox(label="Enter your question")
        # A default selection guarantees qa_tool receives a known method;
        # with no selection Gradio would pass None.
        method_choice = gr.Radio(
            ["Regular RAG", "Graph-RAG"],
            label="Select Method",
            value="Regular RAG",
        )
        get_answer_btn = gr.Button("Get Answer")
        answer_output = gr.Textbox(label="Answer")
        image_output = gr.Gallery(label="Relevant Images", show_label=True)
        graph_output = gr.Plot(label="Knowledge Graph Visualization")
        ask_another_question_btn = gr.Button("Ask Another Question", visible=False)
        change_patient_btn = gr.Button("Set Another Patient", visible=False)

    # Event wiring. Each outputs list must line up with the tuple returned by
    # its handler (handlers live in utils / this module).
    set_patient_btn.click(
        set_patient,
        inputs=[patient_desc],
        outputs=[
            qa_desc,
            qa_interface,
            patient_input,
            question_input,
            answer_output,
            image_output,
            graph_output,
            ask_another_question_btn,
            change_patient_btn,
        ],
    )

    get_answer_btn.click(
        qa_tool,
        inputs=[question_input, method_choice],
        outputs=[answer_output, image_output, ask_another_question_btn, change_patient_btn, graph_output],
    )

    ask_another_question_btn.click(
        ask_another_question,
        outputs=[question_input, answer_output, image_output, graph_output, get_answer_btn, ask_another_question_btn],
    )

    change_patient_btn.click(
        reset_patient,
        outputs=[
            patient_desc,
            qa_desc,
            patient_input,
            qa_interface,
            change_patient_btn,
            question_input,
            answer_output,
            image_output,
            graph_output,
        ],
    )
|
|
|
if __name__ == "__main__":
    # Script entry point (e.g. `python app.py`): note the startup in the log,
    # then hand control to the Gradio server.
    startup_message = "Starting the application"
    logger.info(startup_message)
    app.launch()
|
|