# Author: Shanulhaq — commit 2414c16 (verified): "Update app.py"
# Install the required libraries first, e.g.:
#   pip install PyPDF2 transformers torch accelerate streamlit
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import PyPDF2
import streamlit as st
# Function to extract text from PDF
def extract_text_from_pdf(uploaded_file):
    """Return the text of every page of a PDF, one page per line-separated chunk.

    Args:
        uploaded_file: Anything ``PyPDF2.PdfReader`` accepts — a path or a
            binary file-like object such as a Streamlit ``UploadedFile``.

    Returns:
        The page texts joined with newlines. Pages with no extractable text
        contribute an empty string, so the result may be empty (e.g. for a
        scanned PDF without a text layer).
    """
    reader = PyPDF2.PdfReader(uploaded_file)
    # The newline between pages prevents the last word of one page from
    # fusing with the first word of the next. `or ""` is defensive: it keeps
    # the join safe if a page ever yields None instead of a string.
    return "\n".join(page.extract_text() or "" for page in reader.pages)
# Hugging Face model id used for both the tokenizer and the model.
MODEL_NAME = "ricepaper/vi-gemma-2b-RAG"


@st.cache_resource
def _load_tokenizer_and_model(model_name=MODEL_NAME):
    """Load the tokenizer and causal-LM weights once and reuse them.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` keeps the loaded objects alive across reruns (and
    sessions) instead of reloading the weights on every click.

    Returns:
        ``(tokenizer, model)``. The model is loaded in bfloat16 and moved to
        CUDA when a GPU is available, otherwise it stays on the CPU.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    mdl = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
    )
    if torch.cuda.is_available():
        mdl.to("cuda")
    return tok, mdl


# Module-level handles used by the rest of the script.
tokenizer, model = _load_tokenizer_and_model()
# Prompt template with three positional `{}` slots, filled by str.format in
# order: the document context, the user's question, and the text placed
# after "### Response:" (generate_answer passes an empty string there).
prompt = (
    "\n"
    "### Instruction and Input:\n"
    "Based on the following context/document:\n"
    "{}\n"
    "Please answer the question: {}\n"
    "### Response:\n"
    "{}\n"
)
# Function to generate answer based on query and context
def generate_answer(context, query, max_input_tokens=1024, max_new_tokens=500):
    """Answer *query* using *context* with the module-level tokenizer/model.

    Args:
        context: Document text the answer should be based on.
        query: The user's question.
        max_input_tokens: Upper bound on the prompt length in tokens. Only the
            context is shortened to fit, so the question and the
            "### Response:" cue at the end of the prompt are not cut off.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        The generated continuation only (the prompt is not echoed back),
        with surrounding whitespace stripped.
    """
    # Tokens used by everything except the context: instruction, question,
    # response cue and any special tokens the tokenizer adds.
    overhead = len(tokenizer(prompt.format("", query, ""))["input_ids"])
    context_budget = max(max_input_tokens - overhead, 0)
    context_ids = tokenizer(context, add_special_tokens=False)["input_ids"]
    if len(context_ids) > context_budget:
        # Keep the beginning of the document and drop whatever does not fit,
        # instead of letting right-side truncation remove the question.
        context = tokenizer.decode(
            context_ids[:context_budget], skip_special_tokens=True
        )

    inputs = tokenizer(
        prompt.format(context, query, ""),
        return_tensors="pt",
        # Safety net: re-tokenising the trimmed context can shift the
        # count slightly, so still cap the total length.
        truncation=True,
        max_length=max_input_tokens,
    ).to(model.device)  # follow wherever the model was placed (CPU or GPU)

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        no_repeat_ngram_size=5,
    )
    # For a causal LM, generate() returns prompt tokens followed by new
    # tokens; decode only the new ones so the answer doesn't echo the prompt.
    prompt_length = inputs["input_ids"].shape[-1]
    return tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    ).strip()
# Streamlit App
# Streamlit re-executes the whole script top to bottom on every interaction.
st.title("RAG-Based PDF Question Answering Application")

# Upload PDF
uploaded_file = st.file_uploader("Upload a PDF file", type="pdf")
if uploaded_file is None:
    st.info("Please upload a PDF file to get started.")
else:
    try:
        pdf_text = extract_text_from_pdf(uploaded_file)
    except Exception as exc:  # top-level UI boundary: show a message, not a traceback
        st.error(f"Could not read the PDF: {exc}")
        st.stop()  # raises internally; nothing below runs for this rerun

    st.write("Extracted text from PDF:")
    st.text_area("PDF Content", pdf_text, height=200)

    if not pdf_text.strip():
        # e.g. a scanned PDF with no text layer: the model would get no context.
        st.warning("No extractable text was found in this PDF.")
        st.stop()

    # User inputs their question
    query = st.text_input("Enter your question about the PDF content:")
    if st.button("Get Answer"):
        if query.strip():
            # Generation can take a while, especially on CPU.
            with st.spinner("Generating answer..."):
                answer = generate_answer(pdf_text, query)
            st.write("Answer:", answer)
        else:
            st.warning("Please enter a question.")