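# Streamlit app: a simple RAG-style PDF chatbot. Text is extracted from
# uploaded PDFs and stuffed into the prompt of the himmeow/vi-gemma-2b-RAG
# model, which then answers the user's question based on that context.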
import streamlit as st
from PyPDF2 import PdfReader
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
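# Dependencies: streamlit, PyPDF2, transformers, torch; device_map="auto"
# below additionally requires the accelerate package to be installed.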
# Cache the model and tokenizer so they are loaded only once per session
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained("himmeow/vi-gemma-2b-RAG")
    model = AutoModelForCausalLM.from_pretrained(
        "himmeow/vi-gemma-2b-RAG",
        device_map="auto",  # places the model on the GPU automatically when one is available
        torch_dtype=torch.bfloat16
    )
    return tokenizer, model
# Cache the text extraction so PDFs are not re-parsed on every rerun
@st.cache_data
def extract_text_from_pdfs(files):
    pdf_text = ""
    for file in files:
        reader = PdfReader(file)
        for page in reader.pages:
            # extract_text() can return None for image-only pages
            pdf_text += (page.extract_text() or "") + "\n"
    return pdf_text
# Load the model and tokenizer
tokenizer, model = load_model()
# Sidebar for PDF file upload
st.sidebar.title("📂 Upload PDFs")
uploaded_files = st.sidebar.file_uploader("Upload PDF files", type="pdf", accept_multiple_files=True)
# Initialize session state
if "history" not in st.session_state:
    st.session_state.history = []
# Extract text from PDFs and keep it in session state
if uploaded_files:
    if "pdf_text" not in st.session_state:
        st.session_state.pdf_text = extract_text_from_pdfs(uploaded_files)
# Main interface
st.title("💬 RAG PDF Chatbot")
st.markdown("Ask questions based on the uploaded PDF documents.")
# Input for user query
query = st.text_input("Enter your question:")
# Process and respond to the user query
if st.button("Submit"):
    if uploaded_files and query:
        with st.spinner("Generating response..."):
            # Prepare the input data
            prompt = """
### Instruction and Input:
Based on the following context/document:
{}
Please answer the question: {}
### Response:
""".format(st.session_state.pdf_text, query)
            # Encode the input text
            inputs = tokenizer(prompt, return_tensors="pt")
            # Move the inputs to the same device as the model
            inputs = inputs.to(model.device)
            # Generate text using the model
            outputs = model.generate(
                **inputs,
                max_new_tokens=500,      # Limit tokens to speed up generation
                no_repeat_ngram_size=3,  # Avoid repetition
                do_sample=True,          # Sampling for variability
                temperature=0.7          # Control randomness
            )
            # Decode only the newly generated tokens so the prompt is not echoed back
            prompt_length = inputs["input_ids"].shape[1]
            response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
            st.session_state.history.append({"question": query, "answer": response})
    else:
        st.warning("Please upload at least one PDF and enter a question.")
# Display chat history, most recent first
if st.session_state.history:
    for i, qa in enumerate(reversed(st.session_state.history), 1):
        st.markdown(f"**Q{i}:** {qa['question']}")
        st.markdown(f"**A{i}:** {qa['answer']}")
# Footer with author information
st.sidebar.markdown("### Created by: [Engr. Hamesh Raj](https://www.linkedin.com/in/datascientisthameshraj/)")
st.sidebar.markdown("## 🗂️ RAG PDF Chatbot")