# NOTE: the following lines are Hugging Face Space page chrome captured when
# this file was downloaded; kept as comments so the file parses as Python.
# datascientist22's picture
# Update app.py
# 6feb14e verified
# raw
# history blame
# 2.83 kB
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from PyPDF2 import PdfReader
# Initialize the tokenizer and model from the saved checkpoint.
# Cached with st.cache_resource: Streamlit re-executes this script on every
# user interaction, and without caching the multi-GB checkpoint would be
# reloaded each time.
@st.cache_resource
def load_model_and_tokenizer():
    """Return ``(tokenizer, model)`` for the vi-gemma-2b-RAG checkpoint.

    ``device_map="auto"`` lets accelerate place the weights on the GPU when
    one is available (CPU otherwise), so no explicit ``model.to("cuda")`` is
    needed — calling ``.to()`` on an accelerate-dispatched model is redundant
    and raises ``RuntimeError`` in recent accelerate versions.
    """
    tok = AutoTokenizer.from_pretrained("himmeow/vi-gemma-2b-RAG")
    mdl = AutoModelForCausalLM.from_pretrained(
        "himmeow/vi-gemma-2b-RAG",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    return tok, mdl


tokenizer, model = load_model_and_tokenizer()
# Function to extract text from PDF
def extract_text_from_pdf(pdf_path):
    """Extract the text of every page of a PDF, one page per line.

    Parameters
    ----------
    pdf_path : str | os.PathLike | file-like
        Filesystem path to a PDF, or an already-open binary file-like object
        (e.g. the Streamlit ``UploadedFile`` this app actually passes in).
        The original version unconditionally called ``open()``, which fails
        with ``TypeError`` when handed a file-like object.

    Returns
    -------
    str
        Concatenated page texts, each followed by a newline.
    """
    def _collect(reader):
        # extract_text() may return None for image-only pages; treat as "".
        return "".join((page.extract_text() or "") + "\n" for page in reader.pages)

    # File-like objects (duck-typed via .read) go straight to PdfReader.
    if hasattr(pdf_path, "read"):
        return _collect(PdfReader(pdf_path))
    with open(pdf_path, "rb") as file:
        return _collect(PdfReader(file))
# Streamlit app — page header: author credit, title, and sidebar uploader.
st.write("**Created by: Engr. Hamesh Raj** [LinkedIn](https://www.linkedin.com/in/datascientisthameshraj/)")
# Original title contained mojibake ("πŸ“„" — a UTF-8 emoji decoded through a
# legacy codepage); restored to the intended page-facing-up emoji.
st.title("📄 PDF Question Answering")
# Sidebar for PDF upload; file_uploader returns an UploadedFile (file-like)
# or None until the user picks a file.
uploaded_file = st.sidebar.file_uploader("Upload a PDF file", type="pdf")
# Main interaction flow: runs only once a PDF has been uploaded.
if uploaded_file is not None:
    # Extract and preview the text of the uploaded PDF.
    pdf_text = extract_text_from_pdf(uploaded_file)
    st.text_area("Extracted PDF Text", pdf_text, height=200)

    # Initialize chat history BEFORE first use so numbering and the history
    # display are consistent on the very first question.
    if "history" not in st.session_state:
        st.session_state.history = []

    # Input field for the user's question.
    user_query = st.text_input("Enter your question:")

    # Submit button below the input field.
    if st.button("Submit") and user_query:
        # Include the extracted PDF text as context — the original prompt
        # sent only the question, so the model could not ground its answer
        # in the uploaded document. Context is truncated to keep the prompt
        # within the 2b model's sequence budget.
        input_text = (
            f"### Context:\n{pdf_text[:4000]}\n\n"
            f"### Question:\n{user_query}\n\n"
            f"### Response:\n"
        )
        # Encode the prompt; move inputs to the GPU when one is available
        # (the model itself was placed by device_map at load time).
        input_ids = tokenizer(input_text, return_tensors="pt")
        if torch.cuda.is_available():
            input_ids = input_ids.to("cuda")

        # Generate and decode the answer.
        outputs = model.generate(
            **input_ids,
            max_new_tokens=150,      # cap the response length
            no_repeat_ngram_size=5,  # prevent repetition of 5-gram phrases
        )
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Append to history only; the loop below renders everything, so the
        # newest Q/A is shown exactly once. (The original printed it twice
        # and numbered it with len(st.session_state) — the count of ALL
        # session keys, not of questions — yielding wrong numbers.)
        st.session_state.history.append({
            "question": user_query,
            "answer": answer.strip(),
        })

    # Display the full chat history with correct 1-based numbering.
    for i, qa in enumerate(st.session_state.history, start=1):
        st.write(f"**Q{i}: {qa['question']}**")
        st.write(f"**A{i}: {qa['answer']}**")