# Hugging Face Space app (scraped page header said: "Spaces: Sleeping")
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from PyPDF2 import PdfReader

# Load the tokenizer and model once at module import (Streamlit reruns the
# script on every interaction; heavyweight loads belong at the top).
tokenizer = AutoTokenizer.from_pretrained("himmeow/vi-gemma-2b-RAG")
model = AutoModelForCausalLM.from_pretrained(
    "himmeow/vi-gemma-2b-RAG",
    device_map="auto",          # accelerate places weights on GPU when available
    torch_dtype=torch.bfloat16,
)
# NOTE(review): the original additionally called model.to("cuda") when CUDA was
# available. With device_map="auto" the model is already dispatched by
# accelerate, and calling .to() on a dispatched model is redundant and raises a
# RuntimeError in recent accelerate versions — so the explicit move is dropped.
# Function to extract text from PDF | |
def extract_text_from_pdf(pdf_path):
    """Return the concatenated text of every page of a PDF.

    Args:
        pdf_path: Either a filesystem path or a binary file-like object
            (e.g. the Streamlit ``UploadedFile`` this app's caller passes).
            The original implementation did ``open(pdf_path, "rb")``, which
            breaks for file-like objects; ``PdfReader`` accepts both a path
            and a file-like object directly, so the explicit ``open`` is
            only used for path inputs.

    Returns:
        str: The text of all pages, each followed by a newline. Pages with
        no extractable text (e.g. scanned images) contribute an empty
        string rather than crashing.
    """
    if hasattr(pdf_path, "read"):
        # File-like object (Streamlit upload) — hand it to PdfReader as-is.
        reader = PdfReader(pdf_path)
        # extract_text() may return None for image-only pages; coerce to "".
        return "".join((page.extract_text() or "") + "\n" for page in reader.pages)
    with open(pdf_path, "rb") as file:
        reader = PdfReader(file)
        return "".join((page.extract_text() or "") + "\n" for page in reader.pages)
# ---------------------------- Streamlit UI ----------------------------
st.write("**Created by: Engr. Hamesh Raj** [LinkedIn](https://www.linkedin.com/in/datascientisthameshraj/)")
# NOTE(review): the scraped source showed mojibake ("π") where an emoji was;
# restored to a PDF emoji — confirm against the original app.
st.title("📄 PDF Question Answering")

# Sidebar upload; extracted text is shown so the user can sanity-check it.
uploaded_file = st.sidebar.file_uploader("Upload a PDF file", type="pdf")
if uploaded_file is not None:
    pdf_text = extract_text_from_pdf(uploaded_file)
    st.text_area("Extracted PDF Text", pdf_text, height=200)

# Question input + submit button.
user_query = st.text_input("Enter your question:")
if st.button("Submit") and user_query:
    # NOTE(review): the extracted PDF text is never included in the prompt,
    # so the model answers without seeing the document. For a RAG model the
    # prompt should presumably contain pdf_text as context — TODO confirm
    # the intended prompt template before changing the string below.
    input_text = f"{user_query}\n\n### Response:\n"
    input_ids = tokenizer(input_text, return_tensors="pt")
    if torch.cuda.is_available():
        input_ids = input_ids.to("cuda")
    outputs = model.generate(
        **input_ids,
        max_new_tokens=150,      # cap the generated length
        no_repeat_ngram_size=5,  # suppress repeated 5-gram phrases
    )
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Append to history FIRST, then render the history once below. The
    # original printed the fresh Q/A immediately AND again inside the
    # history loop (duplicate output), and numbered it with
    # len(st.session_state), which counts every session key — not questions.
    if "history" not in st.session_state:
        st.session_state.history = []
    st.session_state.history.append({
        "question": user_query,
        "answer": answer.strip(),
    })

# Render the full chat history (including the just-submitted Q/A, once).
if "history" in st.session_state:
    for i, qa in enumerate(st.session_state.history, start=1):
        st.write(f"**Q{i}: {qa['question']}**")
        st.write(f"**A{i}: {qa['answer']}**")