Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import pipeline | |
import PyPDF2 | |
import markdown | |
import matplotlib.pyplot as plt | |
import io | |
import base64 | |
import torch | |
from fpdf import FPDF | |
import os | |
import tempfile | |
import glob | |
# Question-answering checkpoints offered in the UI.
# Key: name shown in the model dropdown. Value: Hugging Face Hub model id.
models = {
    "distilbert-base-uncased-distilled-squad":
        "distilbert-base-uncased-distilled-squad",
    "roberta-base-squad2":
        "deepset/roberta-base-squad2",
    "bert-large-uncased-whole-word-masking-finetuned-squad":
        "bert-large-uncased-whole-word-masking-finetuned-squad",
    "albert-base-v2":
        "twmkn9/albert-base-v2-squad2",
    "xlm-roberta-large-squad2":
        "deepset/xlm-roberta-large-squad2",
}

# Cache of already-built pipelines, keyed by the same names as `models`.
# It starts empty; entries are added on first use by load_model().
loaded_models = {}

# Torch device handle: CUDA when a GPU is visible, otherwise the CPU.
device = (
    torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("cpu")
)
def load_model(model_name):
    """Return the question-answering pipeline for *model_name*, building it once.

    Pipelines are cached in ``loaded_models``; a cache miss constructs a new
    ``transformers`` pipeline from the Hub id stored in ``models``.

    Args:
        model_name: A key of ``models``.

    Returns:
        The cached pipeline object for that model.

    Raises:
        KeyError: If *model_name* is not a key of ``models``.
    """
    try:
        return loaded_models[model_name]
    except KeyError:
        pass  # Not cached yet: build it below.

    # Pipeline device index: 0 selects the first GPU, -1 selects the CPU.
    device_index = 0 if torch.cuda.is_available() else -1
    qa_pipeline = pipeline(
        "question-answering",
        model=models[model_name],
        device=device_index,
    )
    loaded_models[model_name] = qa_pipeline
    return qa_pipeline
def generate_score_chart(score):
    """Render a one-bar confidence chart and return it as base64-encoded PNG.

    Args:
        score: Bar height; the y-axis is fixed to the range 0..1.

    Returns:
        The PNG image bytes encoded as a base64 ``str``.
    """
    figure, axes = plt.subplots(figsize=(6, 4))
    axes.bar(["Confidence Score"], [score], color='skyblue')
    axes.set_ylim(0, 1)
    axes.set_ylabel("Score")
    axes.set_title("Confidence Score")

    png_buffer = io.BytesIO()
    figure.savefig(png_buffer, format='png')
    plt.close(figure)  # release the figure so pyplot does not keep it alive

    encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode()
def highlight_relevant_text(context, start, end):
    """Return *context* with the slice ``[start:end]`` wrapped in a ``<mark>`` tag.

    Args:
        context: The full text.
        start: Index of the first highlighted character.
        end: Index one past the last highlighted character.

    Returns:
        The text with a yellow-background ``<mark>`` element around the slice.
        The text itself is inserted unchanged (no HTML escaping is applied).
    """
    before = context[:start]
    marked = context[start:end]
    after = context[end:]
    pieces = (
        before,
        '<mark style="background-color: yellow;">',
        marked,
        '</mark>',
        after,
    )
    return "".join(pieces)
def find_system_font(font_dirs=None):
    """Locate a NotoSans TrueType font file on the local system.

    Directories are searched recursively, in order; the first directory that
    contains any ``NotoSans*.ttf`` file supplies the result. Within that
    directory the choice is deterministic: ``NotoSans-Regular.ttf`` is
    preferred, then any other ``*-Regular.ttf`` face, then the
    alphabetically first match.

    Args:
        font_dirs: Optional iterable of directories to search. Defaults to
            ``/usr/share/fonts`` and ``/usr/local/share/fonts``.

    Returns:
        Path of the selected ``.ttf`` file.

    Raises:
        FileNotFoundError: If no directory contains a NotoSans ``.ttf`` file.
    """
    if font_dirs is None:
        font_dirs = ["/usr/share/fonts", "/usr/local/share/fonts"]

    for font_dir in font_dirs:
        pattern = os.path.join(font_dir, "**", "NotoSans*.ttf")
        # glob returns matches in arbitrary order; sort so the same font is
        # picked on every run instead of whichever variant comes back first.
        candidates = sorted(glob.glob(pattern, recursive=True))
        if not candidates:
            continue

        # Body text wants the upright regular face, not Bold/Italic/Mono.
        for path in candidates:
            if os.path.basename(path) == "NotoSans-Regular.ttf":
                return path
        for path in candidates:
            if path.endswith("-Regular.ttf"):
                return path
        return candidates[0]

    raise FileNotFoundError("No suitable TTF font file found in system font directories.")
def generate_pdf_report(question, answer, score, score_explanation, score_chart, highlighted_context):
    """Build a one-document PDF report and return it as an in-memory buffer.

    The report lists the question, answer, confidence score, score
    explanation and the highlighted context as text, followed by the
    confidence chart image.

    Args:
        question: The question text.
        answer: The answer text.
        score: The confidence score, inserted into the text as given.
        score_explanation: Explanation of the score.
        score_chart: Base64-encoded PNG image of the score chart.
        highlighted_context: Context text; written verbatim, so any markup
            it contains appears literally in the PDF.

    Returns:
        An ``io.BytesIO`` holding the PDF, positioned at offset 0.

    Raises:
        FileNotFoundError: Propagated from ``find_system_font`` when no
            suitable font is installed.
    """
    pdf = FPDF()
    pdf.add_page()

    # A Unicode TTF font is required so non-Latin-1 text can be rendered.
    font_path = find_system_font()
    pdf.add_font("NotoSans", "", font_path)

    # (font size, text) for each paragraph, in output order.
    sections = (
        (12, f"Question: {question}"),
        (12, f"Answer: {answer}"),
        (12, f"Confidence Score: {score}"),
        (12, f"Score Explanation: {score_explanation}"),
        (12, "Highlighted Context:"),
        (10, highlighted_context),
    )
    for font_size, text in sections:
        pdf.set_font("NotoSans", size=font_size)
        pdf.multi_cell(0, 10, text)
        pdf.ln()

    # The chart is handed to FPDF through a temporary PNG file. The file is
    # created with delete=False so it can be closed before FPDF opens it, and
    # the finally clause removes it even when image() or output() raises.
    tmpfile = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
    try:
        with tmpfile:
            tmpfile.write(base64.b64decode(score_chart))
        pdf.image(tmpfile.name, x=10, y=pdf.get_y(), w=100)

        # Save PDF to memory
        pdf_output = io.BytesIO()
        pdf.output(pdf_output)
    finally:
        os.remove(tmpfile.name)

    pdf_output.seek(0)
    return pdf_output
def _read_document_text(file):
    """Return the text content of an uploaded document ("" when *file* is None).

    ``.pdf`` files are parsed with PyPDF2, ``.md`` files are decoded and
    converted to HTML with ``markdown``, and anything else is decoded as
    UTF-8 text. Extension matching is case-insensitive.
    """
    if file is None:
        return ""

    # The upload is read by path: either the object's ``.name`` attribute or,
    # when there is none, the value itself.
    # NOTE(review): assumes a nameless value is a filesystem path string.
    path = getattr(file, "name", file)
    lowered = str(path).lower()

    if lowered.endswith(".pdf"):
        reader = PyPDF2.PdfReader(path)
        # extract_text() can yield None/"" for pages without a text layer.
        return "".join(page.extract_text() or "" for page in reader.pages)

    with open(path, "rb") as handle:
        # Undecodable bytes become U+FFFD instead of aborting the request.
        text = handle.read().decode("utf-8", errors="replace")

    if lowered.endswith(".md"):
        return markdown.markdown(text)
    return text


def answer_question(model_name, file, question, status):
    """Answer *question* from the uploaded document and build all UI outputs.

    Args:
        model_name: Key of ``models`` selecting the QA model.
        file: Uploaded document (``.pdf``, ``.md`` or plain text), or None.
        question: The question text.
        status: Unused; kept so existing callers keep working.

    Returns:
        A 6-tuple ``(highlighted_context, score, score_explanation,
        score_chart, pdf_report, status)`` where *score* is formatted to two
        decimals, *score_chart* is a base64 PNG string and *pdf_report* is an
        ``io.BytesIO``. When the question or the document text is empty, the
        model is not loaded and the tuple is ``("", "", "", None, None,
        <message>)``.
    """
    # Validate inputs before the (potentially slow) model load: the QA
    # pipeline cannot answer with an empty question or an empty context.
    if not question or not question.strip():
        return "", "", "", None, None, "Please enter a question."

    context = _read_document_text(file)
    if not context.strip():
        return "", "", "", None, None, "Please upload a document that contains text."

    model = load_model(model_name)
    result = model(question=question, context=context)
    answer = result['answer']
    score = result['score']

    # Highlight relevant text
    highlighted_context = highlight_relevant_text(context, result['start'], result['end'])

    # Generate the score chart
    score_chart = generate_score_chart(score)

    # Explain score
    score_explanation = f"The confidence score ranges from 0 to 1, where a higher score indicates higher confidence in the answer's correctness. In this case, the score is {score:.2f}. A score closer to 1 implies the model is very confident about the answer."

    # Generate the PDF report
    pdf_report = generate_pdf_report(question, answer, f"{score:.2f}", score_explanation, score_chart, highlighted_context)

    status = "Model loaded"
    return highlighted_context, f"{score:.2f}", score_explanation, score_chart, pdf_report, status
def _write_temp_file(data, suffix):
    """Write *data* (bytes) to a new named temporary file and return its path.

    The file is intentionally not deleted here: the path is handed to Gradio
    after the callback returns, so it must still exist at that point.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as handle:
        handle.write(data)
    return handle.name


def on_submit(model_name, file, question):
    """Gradio click handler: run the QA flow and adapt results for the UI.

    ``answer_question`` returns the chart as a base64 string and the report
    as an in-memory buffer. The Image and File output components are given
    file paths instead, so both are written to temporary files here.
    """
    highlighted, score, explanation, chart, pdf_report, status = answer_question(
        model_name, file, question, status="Loading model..."
    )

    chart_path = None
    if chart:
        chart_path = _write_temp_file(base64.b64decode(chart), ".png")

    pdf_path = None
    if pdf_report is not None:
        pdf_path = _write_temp_file(pdf_report.getvalue(), ".pdf")

    return highlighted, score, explanation, chart_path, pdf_path, status


# Define the Gradio interface
with gr.Blocks() as interface:
    gr.Markdown(
        """
# Question Answering System
Upload a document (text, PDF, or Markdown) and ask questions to get answers based on the context.
**Supported File Types**: `.txt`, `.pdf`, `.md`
""")
    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(models.keys()),
            label="Select Model",
            value="distilbert-base-uncased-distilled-squad"
        )
    with gr.Row():
        # Extensions need a leading dot; a bare word is only valid for the
        # built-in categories such as "text".
        file_input = gr.File(label="Upload Document", file_types=["text", ".pdf", ".md"])
        question_input = gr.Textbox(lines=2, placeholder="Enter your question here...", label="Question")
    with gr.Row():
        answer_output = gr.HTML(label="Highlighted Answer")
        score_output = gr.Textbox(label="Confidence Score")
        explanation_output = gr.Textbox(label="Score Explanation")
        chart_output = gr.Image(label="Score Chart")
        pdf_output = gr.File(label="Download PDF Report")
    with gr.Row():
        submit_button = gr.Button("Submit")
        status_output = gr.Markdown(value="")

    # Event wiring must happen inside the Blocks context.
    submit_button.click(
        on_submit,
        inputs=[model_dropdown, file_input, question_input],
        outputs=[answer_output, score_output, explanation_output, chart_output, pdf_output, status_output]
    )

# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch(share=True)