import streamlit as st
from PyPDF2 import PdfReader
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Cache the model and tokenizer loading
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained("himmeow/vi-gemma-2b-RAG")
    model = AutoModelForCausalLM.from_pretrained(
        "himmeow/vi-gemma-2b-RAG",
        device_map="auto",
        torch_dtype=torch.bfloat16
    )
    # device_map="auto" already places the model on the GPU when one is
    # available, so no manual model.to("cuda") call is needed (and calling
    # .to() on an accelerate-dispatched model can raise an error).
    return tokenizer, model

# Cache the text extraction from PDFs
@st.cache_data
def extract_text_from_pdfs(files):
    pdf_text = ""
    for file in files:
        reader = PdfReader(file)
        for page in reader.pages:
            # extract_text() can return None for image-only pages
            pdf_text += (page.extract_text() or "") + "\n"
    return pdf_text

# Load the model and tokenizer
tokenizer, model = load_model()

# Sidebar for PDF file upload
st.sidebar.title("📂 Upload PDFs")
uploaded_files = st.sidebar.file_uploader("Upload PDF files", type="pdf", accept_multiple_files=True)

# Initialize session state
if "history" not in st.session_state:
    st.session_state.history = []

# Extract text from the uploaded PDFs; re-run when the upload set changes
# (the @st.cache_data decorator avoids re-parsing unchanged files)
if uploaded_files:
    st.session_state.pdf_text = extract_text_from_pdfs(uploaded_files)

# Main interface
st.title("💬 RAG PDF Chatbot")
st.markdown("Ask questions based on the uploaded PDF documents.")

# Input for user query
query = st.text_input("Enter your question:")

# Process and respond to the user query
if st.button("Submit"):
    if uploaded_files and query:
        with st.spinner("Generating response..."):
            # Prepare the input prompt
            prompt = """
### Instruction and Input:
Based on the following context/document:
{}
Please answer the question: {}

### Response:
""".format(st.session_state.pdf_text, query)

            # Tokenize the prompt
            inputs = tokenizer(prompt, return_tensors="pt")

            # Move the input tensors to the GPU if one is available
            if torch.cuda.is_available():
                inputs = inputs.to("cuda")

            # Generate text using the model
            outputs = model.generate(
                **inputs,
                max_new_tokens=500,      # Limit tokens to speed up generation
                no_repeat_ngram_size=3,  # Avoid repetition
                do_sample=True,          # Sampling for variability
                temperature=0.7          # Control randomness
            )

            # Decode only the newly generated tokens, skipping the echoed prompt
            response = tokenizer.decode(
                outputs[0][inputs["input_ids"].shape[1]:],
                skip_special_tokens=True
            )
            st.session_state.history.append({"question": query, "answer": response})
    else:
        st.warning("Please upload at least one PDF and enter a question.")

# Display chat history (most recent first)
if st.session_state.history:
    for i, qa in enumerate(reversed(st.session_state.history), 1):
        st.markdown(f"**Q{i}:** {qa['question']}")
        st.markdown(f"**A{i}:** {qa['answer']}")

# Footer with author information
st.sidebar.markdown("### Created by: [Engr. Hamesh Raj](https://www.linkedin.com/in/datascientisthameshraj/)")
st.sidebar.markdown("## 🗂️ RAG PDF Chatbot")
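
# ---------------------------------------------------------------------------
# Usage note (a sketch, assuming the file is saved as app.py; adjust the
# filename to match your project). Install the dependencies and launch the
# app with Streamlit:
#
#   pip install streamlit PyPDF2 transformers torch accelerate
#   streamlit run app.py
#
# accelerate is required for device_map="auto". Without a GPU the model falls
# back to CPU, where bfloat16 generation may be slow or unsupported on older
# torch builds; dropping the torch_dtype argument is a safe CPU fallback.
# ---------------------------------------------------------------------------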