Spaces:
Sleeping
Sleeping
# Install necessary libraries | |
#!pip install PyPDF2 transformers torch accelerate streamlit | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import torch | |
import PyPDF2 | |
import streamlit as st | |
# Function to extract text from PDF
def extract_text_from_pdf(uploaded_file):
    """Return the text of every page of a PDF, concatenated in page order.

    Args:
        uploaded_file: The PDF source handed straight to
            ``PyPDF2.PdfReader``; the app passes the object returned by
            ``st.file_uploader``.

    Returns:
        A single string with the extracted text of all pages. Pages that
        yield no text contribute an empty string.
    """
    reader = PyPDF2.PdfReader(uploaded_file)
    # Defensive: coerce a falsy result (e.g. None from an image-only page on
    # some PyPDF2 versions -- confirm against the installed release) to ""
    # so a single unreadable page cannot raise TypeError.
    return "".join(page.extract_text() or "" for page in reader.pages)
# Hugging Face checkpoint used for both the tokenizer and the causal LM.
MODEL_NAME = "ricepaper/vi-gemma-2b-RAG"


@st.cache_resource
def _load_tokenizer_and_model():
    """Load the tokenizer and model once and reuse them across reruns.

    Streamlit re-executes this script from the top on every widget
    interaction; without caching, the checkpoint would be reloaded each
    time the user types a question or clicks the button.

    Returns:
        A ``(tokenizer, model)`` tuple. The model is loaded in bfloat16 and
        moved to CUDA when a GPU is available.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    loaded_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
    )
    # Move model to GPU if available
    if torch.cuda.is_available():
        loaded_model.to("cuda")
    return loaded_tokenizer, loaded_model


# Module-level names kept so the rest of the script can use them unchanged.
tokenizer, model = _load_tokenizer_and_model()

# Prompt template for the model. The three positional slots are, in order:
# the context/document, the question, and the response (left empty at
# inference time so the model continues after "### Response:").
prompt = """
### Instruction and Input:
Based on the following context/document:
{}
Please answer the question: {}
### Response:
{}
"""
# Function to generate answer based on query and context
def generate_answer(context, query, max_input_tokens=1024, max_new_tokens=500):
    """Generate an answer to *query* grounded in *context*.

    Uses the module-level ``prompt`` template, ``tokenizer`` and ``model``.

    Args:
        context: Document text the answer should be based on.
        query: The user's question.
        max_input_tokens: Upper bound on the number of prompt tokens fed to
            the model.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        The generated answer only (the prompt is not echoed back), with
        surrounding whitespace stripped.
    """
    # The template puts the context first, so plain right-side truncation
    # would cut off the question and the "### Response:" marker for long
    # documents. Measure the prompt without context and give the document
    # only the remaining token budget.
    scaffold_len = len(tokenizer(prompt.format("", query, ""))["input_ids"])
    context_budget = max(max_input_tokens - scaffold_len, 0)

    context_ids = tokenizer(context, add_special_tokens=False)["input_ids"]
    if len(context_ids) > context_budget:
        context = tokenizer.decode(
            context_ids[:context_budget], skip_special_tokens=True
        )

    input_text = prompt.format(context, query, "")
    # Truncation stays on as a safety net: re-tokenising the trimmed context
    # may not reproduce exactly the same token count.
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_tokens,
    )
    # Put the inputs on whatever device the model lives on.
    inputs = inputs.to(model.device)

    # Generate text using the model
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        no_repeat_ngram_size=5,
    )

    # The output sequence starts with the prompt tokens; decode only what
    # the model added so the caller gets the answer, not the whole document.
    prompt_len = inputs["input_ids"].shape[-1]
    answer = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return answer.strip()
# Streamlit App: upload a PDF, show its text, and answer a question about it.
st.title("RAG-Based PDF Question Answering Application")

# Upload PDF
uploaded_file = st.file_uploader("Upload a PDF file", type="pdf")

if uploaded_file is None:
    # Nothing uploaded yet: show the hint and render nothing else.
    st.info("Please upload a PDF file to get started.")
else:
    # Extract the text and show it so the user can see what the model gets.
    pdf_text = extract_text_from_pdf(uploaded_file)
    st.write("Extracted text from PDF:")
    st.text_area("PDF Content", pdf_text, height=200)

    # Question box and button are always rendered once a file is present.
    query = st.text_input("Enter your question about the PDF content:")
    ask_clicked = st.button("Get Answer")

    if ask_clicked and not query.strip():
        # Button pressed with a blank (or whitespace-only) question.
        st.warning("Please enter a question.")
    elif ask_clicked:
        # Generate answer based on extracted PDF text and the query
        answer = generate_answer(pdf_text, query)
        st.write("Answer:", answer)