import streamlit as st
from PyPDF2 import PdfReader
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Initialize the tokenizer and model from the saved checkpoint
tokenizer = AutoTokenizer.from_pretrained("himmeow/vi-gemma-2b-RAG")
model = AutoModelForCausalLM.from_pretrained(
    "himmeow/vi-gemma-2b-RAG",
    device_map="auto",  # device_map="auto" already places the model on a GPU when one is available
    torch_dtype=torch.bfloat16,
)
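# Note (assumption): Streamlit re-executes this script on every interaction, so the
# module-level load above reloads the checkpoint on each rerun. A cached loader,
# sketched here with st.cache_resource, would keep a single copy in memory:
#
# @st.cache_resource
# def load_model():
#     tok = AutoTokenizer.from_pretrained("himmeow/vi-gemma-2b-RAG")
#     mdl = AutoModelForCausalLM.from_pretrained(
#         "himmeow/vi-gemma-2b-RAG", device_map="auto", torch_dtype=torch.bfloat16
#     )
#     return tok, mdl
#
# tokenizer, model = load_model()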
# Set up the Streamlit app layout
st.set_page_config(page_title="RAG PDF Chatbot", layout="wide")
# Sidebar with file upload and app title with creator details
st.sidebar.title("📁 PDF Upload")
uploaded_files = st.sidebar.file_uploader("Upload PDF files", type=["pdf"], accept_multiple_files=True)
st.sidebar.markdown("""
### Created by: [Engr. Hamesh Raj](https://www.linkedin.com/in/datascientisthameshraj/)
""")
# Main title
st.markdown("""
📜 RAG PDF Chatbot
""", unsafe_allow_html=True)
# Input field for user queries
query = st.text_input("Enter your query here:")
submit_button = st.button("Submit")
# Initialize chat history
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []
# Function to extract text from PDF files
def extract_text_from_pdfs(files):
    text = ""
    for uploaded_file in files:
        reader = PdfReader(uploaded_file)
        for page in reader.pages:
            # extract_text() can return None for pages with no extractable text
            text += (page.extract_text() or "") + "\n"
    return text
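# Note (assumption): concatenating whole PDFs can easily exceed the model's context
# window. A simple mitigation is to truncate the encoded prompt, e.g.:
#
#     inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)
#
# A fuller RAG setup would instead chunk the text and retrieve only the relevant passages.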
# Handle the query submission
if submit_button and query:
    if not uploaded_files:
        st.warning("Please upload at least one PDF file before submitting a query.")
    else:
        # Extract text from the uploaded PDFs
        pdf_text = extract_text_from_pdfs(uploaded_files)

        # Prepare the input prompt
        prompt = f"""
### Instruction and Input:
Based on the following context/document:
{pdf_text}
Please answer the question: {query}

### Response:
"""

        # Encode the input text and move it to the model's device
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # Generate the response
        outputs = model.generate(
            **inputs,
            max_new_tokens=500,
            no_repeat_ngram_size=5,
        )

        # Decode only the newly generated tokens (skip the echoed prompt)
        response = tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[-1]:],
            skip_special_tokens=True,
        )

        # Update the chat history
        st.session_state.chat_history.append((query, response))
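# Note (assumption): with no sampling flags, model.generate above typically decodes
# greedily unless the checkpoint's generation_config says otherwise; passing
# do_sample=True together with temperature/top_p would give more varied answers.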
# Display chat history
if st.session_state.chat_history:
    for i, (q, a) in enumerate(st.session_state.chat_history):
        st.markdown(f"**Question {i + 1}:** {q}")
        st.markdown(f"**Answer:** {a}")
        st.write("---")
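# To run locally (assuming this file is saved as app.py):
#
#     streamlit run app.py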