import streamlit as st
from PyPDF2 import PdfReader
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
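# Dependencies (assumed): streamlit, PyPDF2, transformers, torch;
# accelerate is also needed for device_map="auto" below.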
# App configuration
st.set_page_config(page_title="PDF Chatbot", layout="wide")
# Title and "Created by" section
st.markdown("
📄 PDF RAG Chatbot
", unsafe_allow_html=True)
st.markdown(
"",
unsafe_allow_html=True
)
# Sidebar for PDF file upload
uploaded_files = st.sidebar.file_uploader("Upload PDF files", type="pdf", accept_multiple_files=True)
# Query input
query = st.text_input("Ask a question about the uploaded PDFs:")
# Initialize session state to store conversation
if "conversation" not in st.session_state:
st.session_state.conversation = []
# Extract raw text from every page of the uploaded PDFs
def extract_text_from_pdfs(files):
    pdf_text = ""
    for file in files:
        reader = PdfReader(file)
        for page in reader.pages:
            # extract_text() can return None on image-only pages
            pdf_text += (page.extract_text() or "") + "\n"
    return pdf_text
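# Note: the extracted text of every PDF is stuffed into a single prompt below;
# for large documents a chunking/retrieval step would be needed to stay within
# the model's context window.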
# Load the model and tokenizer once; st.cache_resource keeps them across reruns
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained("himmeow/vi-gemma-2b-RAG")
    model = AutoModelForCausalLM.from_pretrained(
        "himmeow/vi-gemma-2b-RAG",
        device_map="auto",  # places the model on GPU automatically when one is available
        torch_dtype=torch.bfloat16,
    )
    # device_map="auto" already handles device placement, so no manual model.to("cuda")
    return tokenizer, model
# Process and respond to the user query
if st.button("Submit"):
    if uploaded_files and query:
        pdf_text = extract_text_from_pdfs(uploaded_files)
        tokenizer, model = load_model()
        # Prompt template in the instruction/response format used by vi-gemma-2b-RAG
        prompt = """
### Instruction and Input:
Based on the following context/document:
{}
Please answer the question: {}
### Response:
{}
"""
        input_text = prompt.format(pdf_text, query, " ")
        inputs = tokenizer(input_text, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = inputs.to("cuda")
        outputs = model.generate(
            **inputs,
            max_new_tokens=500,
            no_repeat_ngram_size=5,
        )
        # Decode only the newly generated tokens so the answer does not echo the prompt
        answer = tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )
        # Store the conversation, newest first
        st.session_state.conversation.insert(0, {"question": query, "answer": answer})
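    else:
        st.warning("Please upload at least one PDF and enter a question.")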
# Display conversation history
if st.session_state.conversation:
    st.markdown("## Previous Conversations")
    for qa in st.session_state.conversation:
        st.markdown(f"**Q: {qa['question']}**")
        st.markdown(f"**A: {qa['answer']}**")
        st.markdown("---")