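"""Streamlit app: question answering over an uploaded PDF with vi-gemma-2b-RAG.

The app extracts text from a user-uploaded PDF with PyMuPDF, then prompts the
himmeow/vi-gemma-2b-RAG model to answer a question grounded in that text.

Run locally with: streamlit run app.py
Dependencies: streamlit, transformers, torch, PyMuPDF, and accelerate
(required by device_map="auto").
"""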
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import fitz # PyMuPDF for PDF handling

# Load the model and tokenizer once per session; @st.cache_resource keeps a
# single copy across Streamlit reruns.
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained("himmeow/vi-gemma-2b-RAG")
    model = AutoModelForCausalLM.from_pretrained(
        "himmeow/vi-gemma-2b-RAG",
        device_map="auto",          # accelerate places the weights on GPU/CPU
        torch_dtype=torch.bfloat16
    )
    # No manual model.to("cuda") here: device_map="auto" already dispatches
    # the weights, and moving a dispatched model can raise a RuntimeError.
    return tokenizer, model

# Extract text from the uploaded PDF file using PyMuPDF
def extract_text_from_pdf(pdf_file):
    doc = fitz.open(stream=pdf_file.read(), filetype="pdf")
    text = ""
    for page_num in range(doc.page_count):
        page = doc.load_page(page_num)
        text += page.get_text("text")  # plain-text extraction for this page
    doc.close()  # release the document handle
    return text
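
# Note: get_text("text") reads the PDF's embedded text layer, so scanned or
# image-only PDFs come back as empty strings; those would need OCR first.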

# Generate a response from the model for a document/question pair
def generate_response(input_text, query, tokenizer, model):
    # Format the input prompt for the model
    prompt = f"""
### Instruction and Input:
Based on the following context/document:
{input_text}
Please answer the question: {query}
### Response:
"""
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    # Move the inputs to wherever device_map placed the model (GPU or CPU)
    input_ids = input_ids.to(model.device)
    outputs = model.generate(
        input_ids=input_ids,
        max_new_tokens=500,        # cap on generated tokens
        no_repeat_ngram_size=5     # discourage verbatim repetition
    )
    # Decode the generated output into readable text
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
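
# Optional sketch (not part of the original app): a long PDF can exceed the
# model's context window, so the document text could be trimmed to a token
# budget before prompting. The helper name and the 4000-token budget are
# illustrative assumptions, not values from the model card.
def truncate_to_token_budget(text, tokenizer, max_tokens=4000):
    # Tokenize with truncation, then decode the surviving tokens back to text.
    token_ids = tokenizer(text, truncation=True, max_length=max_tokens).input_ids
    return tokenizer.decode(token_ids, skip_special_tokens=True)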

# Streamlit app main function
def main():
    st.title("PDF Question Answering with vi-gemma-2b-RAG")

    # File uploader widget for PDF files
    pdf_file = st.file_uploader("Upload a PDF file", type=["pdf"])

    if pdf_file is not None:
        with st.spinner("Reading the PDF..."):
            # Extract text from the uploaded PDF
            pdf_text = extract_text_from_pdf(pdf_file)
        st.text_area("Extracted Text", pdf_text, height=300)

        # Text input for the user's question
        query = st.text_input("Enter your question:")

        if st.button("Get Answer"):
            if query.strip() == "":
                st.warning("Please enter a question.")
            else:
                with st.spinner("Generating response..."):
                    # Load the model and tokenizer
                    tokenizer, model = load_model()
                    # Generate the response using the model
                    try:
                        response = generate_response(pdf_text, query, tokenizer, model)
                        st.text_area("Response", response, height=200)
                    except Exception as e:
                        st.error(f"Error generating response: {e}")


if __name__ == "__main__":
    main()