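"""RAG PDF Chatbot.

A Streamlit app that extracts text from uploaded PDFs and answers
questions about them with the himmeow/vi-gemma-2b-RAG causal language
model.

Run locally with: streamlit run app.py
"""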
import streamlit as st
from PyPDF2 import PdfReader
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Cache the model and tokenizer so Streamlit reruns don't reload them from disk
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained("himmeow/vi-gemma-2b-RAG")
    model = AutoModelForCausalLM.from_pretrained(
        "himmeow/vi-gemma-2b-RAG",
        device_map="auto",          # let accelerate place the model on the available device
        torch_dtype=torch.bfloat16
    )
    # device_map="auto" already dispatches the model to the GPU when one is
    # available, so a manual model.to("cuda") is redundant and can conflict
    # with accelerate's placement.
    return tokenizer, model

# Cache the extracted text so identical uploads aren't re-parsed on every rerun
@st.cache_data
def extract_text_from_pdfs(files):
    pdf_text = ""
    for file in files:
        reader = PdfReader(file)
        for page in reader.pages:
            # extract_text() can return None for image-only pages; guard with "or ''"
            pdf_text += (page.extract_text() or "") + "\n"
    return pdf_text

# Load the model and tokenizer
tokenizer, model = load_model()
# Sidebar for PDF file upload
st.sidebar.title("πŸ“‚ Upload PDFs")
uploaded_files = st.sidebar.file_uploader("Upload PDF files", type="pdf", accept_multiple_files=True)
# Initialize session state
if "history" not in st.session_state:
st.session_state.history = []
# Extract text from PDFs and maintain session state
if uploaded_files:
    # extract_text_from_pdfs is cached, so recomputing here is cheap and keeps
    # the stored text in sync when the set of uploaded files changes
    st.session_state.pdf_text = extract_text_from_pdfs(uploaded_files)
# Main interface
st.title("πŸ’¬ RAG PDF Chatbot")
st.markdown("Ask questions based on the uploaded PDF documents.")
# Input for user query
query = st.text_input("Enter your question:")
# Process and respond to user query
if st.button("Submit"):
    if uploaded_files and query:
        with st.spinner("Generating response..."):
            # Build the prompt from the extracted PDF text and the user's question
            prompt = """
### Instruction and Input:
Based on the following context/document:
{}
Please answer the question: {}
### Response:
""".format(st.session_state.pdf_text, query)

            # Tokenize the prompt
            inputs = tokenizer(prompt, return_tensors="pt")

            # Move the encoded inputs to the GPU if one is available
            if torch.cuda.is_available():
                inputs = inputs.to("cuda")

            # Generate a response
            outputs = model.generate(
                **inputs,
                max_new_tokens=500,      # Limit tokens to speed up generation
                no_repeat_ngram_size=3,  # Avoid repetition
                do_sample=True,          # Sampling for variability
                temperature=0.7          # Control randomness
            )

            # Decode only the newly generated tokens so the response does not
            # echo the prompt (including the full PDF text) back to the user
            prompt_length = inputs["input_ids"].shape[-1]
            response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
            st.session_state.history.append({"question": query, "answer": response})
    else:
        st.warning("Please upload at least one PDF and enter a question before submitting.")
# Display chat history
if st.session_state.history:
    for i, qa in enumerate(reversed(st.session_state.history), 1):
        st.markdown(f"**Q{i}:** {qa['question']}")
        st.markdown(f"**A{i}:** {qa['answer']}")
# Footer with author information
st.sidebar.markdown("### Created by: [Engr. Hamesh Raj](https://www.linkedin.com/in/datascientisthameshraj/)")
st.sidebar.markdown("## πŸ—‚οΈ RAG PDF Chatbot")