"""Streamlit app: question answering over an uploaded PDF using vi-gemma-2b-RAG."""

import streamlit as st
import torch
from PyPDF2 import PdfReader
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the tokenizer and model once at module import (i.e. at app startup).
tokenizer = AutoTokenizer.from_pretrained("himmeow/vi-gemma-2b-RAG")
model = AutoModelForCausalLM.from_pretrained(
    "himmeow/vi-gemma-2b-RAG",
    device_map="auto",           # accelerate places weights on GPU/CPU automatically;
    torch_dtype=torch.bfloat16,  # a manual model.to("cuda") is redundant here and
)                                # raises with device_map in recent accelerate versions.


def extract_text_from_pdf(pdf_file):
    """Return the concatenated text of every page of *pdf_file*.

    Accepts either a filesystem path or a binary file-like object (such as a
    Streamlit ``UploadedFile``) — ``PdfReader`` handles both directly, so no
    ``open()`` call is needed (the original ``open(uploaded_file, "rb")``
    raised a TypeError on uploads).
    """
    reader = PdfReader(pdf_file)
    # extract_text() may return None on pages with no extractable text
    # (e.g. scanned images); substitute "" so the join never fails.
    return "\n".join((page.extract_text() or "") for page in reader.pages)


# --- Streamlit UI ---
st.write("**Created by: Engr. Hamesh Raj** [LinkedIn](https://www.linkedin.com/in/datascientisthameshraj/)")
st.title("📄 PDF Question Answering")

# Sidebar for PDF upload
uploaded_file = st.sidebar.file_uploader("Upload a PDF file", type="pdf")

if uploaded_file is not None:
    # Extract and preview the text of the uploaded PDF.
    pdf_text = extract_text_from_pdf(uploaded_file)
    st.text_area("Extracted PDF Text", pdf_text, height=200)

    # Input field for the user's question.
    user_query = st.text_input("Enter your question:")

    if st.button("Submit") and user_query:
        # Ground the answer in the document: include the extracted PDF text
        # as context in the prompt (the original omitted it, so the model
        # answered without ever seeing the PDF).
        input_text = f"{pdf_text}\n\n{user_query}\n\n### Response:\n"

        # Tokenize; move tensors to GPU when one is available so they sit on
        # the same device as the (device-mapped) model's first layer.
        inputs = tokenizer(input_text, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = inputs.to("cuda")

        outputs = model.generate(
            **inputs,
            max_new_tokens=150,      # cap the length of the generated answer
            no_repeat_ngram_size=5,  # prevent repetition of 5-gram phrases
        )
        # generate() returns prompt + continuation; slice off the prompt
        # tokens so only the model's answer is shown to the user.
        answer = tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[-1]:],
            skip_special_tokens=True,
        ).strip()

        # Append to chat history FIRST, then render the history once below.
        # (The original printed the new Q/A directly and again in the history
        # loop — duplicating it — and numbered turns with len(st.session_state),
        # which counts session keys, not conversation turns.)
        if "history" not in st.session_state:
            st.session_state.history = []
        st.session_state.history.append({
            "question": user_query,
            "answer": answer,
        })

# Display the chat history (empty list when no turns yet).
for i, qa in enumerate(st.session_state.get("history", [])):
    st.write(f"**Q{i + 1}: {qa['question']}**")
    st.write(f"**A{i + 1}: {qa['answer']}**")