# RAG-CHAT-BOt / app.py
# Gradio demo: upload a PDF and ask questions about its content,
# answered by the himmeow/vi-gemma-2b-RAG causal language model.
import gradio as gr
from PyPDF2 import PdfReader
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Load the tokenizer and model.
# device_map="auto" lets accelerate place the weights on the GPU when one
# is available (falling back to CPU), so no manual .to("cuda") is needed —
# calling .to() on an accelerate-dispatched model raises a RuntimeError.
tokenizer = AutoTokenizer.from_pretrained("himmeow/vi-gemma-2b-RAG")
model = AutoModelForCausalLM.from_pretrained(
    "himmeow/vi-gemma-2b-RAG",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
# Define the prompt format for the model.
# The three {} placeholders are filled by generate_response() in order:
# (1) the extracted PDF text, (2) the user's question,
# (3) an empty slot the model completes with its answer.
prompt = """
### Instruction and Input:
Based on the following context/document:
{}
Please answer the question: {}
### Response:
{}
"""
def extract_text_from_pdf(pdf):
    """Extract all text from an uploaded PDF.

    Args:
        pdf: A file path or file-like object accepted by PyPDF2.PdfReader.

    Returns:
        str: The text of every page, each followed by a newline, or a
        fallback message when nothing could be extracted (e.g. a
        scanned/image-only PDF).
    """
    reader = PdfReader(pdf)
    # Iterate pages directly (no range(len(...))); pages with no
    # extractable text return None/"" and are skipped.
    page_texts = [page.extract_text() for page in reader.pages]
    pdf_Text = "".join(text + "\n" for text in page_texts if text)
    if not pdf_Text.strip():
        pdf_Text = "The PDF contains no extractable text."
    print("Extracted Text:\n", pdf_Text)  # Debugging statement
    return pdf_Text
def generate_response(pdf, query):
    """Answer *query* from the text of the uploaded *pdf*.

    Args:
        pdf: The uploaded PDF (as handed over by gr.File), or None when
            the user submitted without uploading a file.
        query: The user's question as a string.

    Returns:
        str: The model's generated answer, or an explanatory message when
        the PDF is missing/unreadable or generation fails.
    """
    # gr.File yields None if nothing was uploaded; PdfReader(None) would crash.
    if pdf is None:
        return "Please upload a PDF file."
    pdf_Text = extract_text_from_pdf(pdf)
    if not pdf_Text.strip():
        return "The PDF appears to be empty or unreadable."
    input_text = prompt.format(pdf_Text, query, " ")
    print("Input Text for Model:\n", input_text)  # Debugging statement
    input_ids = tokenizer(input_text, return_tensors="pt")
    if torch.cuda.is_available():
        input_ids = input_ids.to("cuda")
    try:
        outputs = model.generate(
            **input_ids,
            max_new_tokens=500,
            no_repeat_ngram_size=5,
        )
        # skip_special_tokens keeps <bos>/<eos>/padding markers out of the
        # user-facing answer.
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Top-level boundary: report failure to the UI instead of crashing
        # the Gradio worker; log the cause for debugging.
        response = "An error occurred while generating the response."
        print("Error:", e)
    print("Generated Response:\n", response)  # Debugging statement
    return response
# Gradio interface: a PDF file input plus a free-text question box, wired
# to generate_response; the model's answer is shown as plain text.
iface = gr.Interface(
    fn=generate_response,
    inputs=[gr.File(label="Upload PDF"), gr.Textbox(label="Ask a question")],
    outputs="text",
    title="PDF Question Answering with vi-gemma-2b-RAG",
    description="Upload a PDF and ask a question based on its content. The model will generate a response."
)
# Start the Gradio server (blocking call; default host/port).
iface.launch()