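# app.py: Streamlit demo for question answering over an uploaded PDF
# with the himmeow/vi-gemma-2b-RAG model.
#
# Assumed setup (inferred from the imports below, not pinned by the source):
#   pip install streamlit transformers torch accelerate pymupdf
#   streamlit run app.py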
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import fitz  # PyMuPDF for PDF handling
# Load the model and tokenizer once and cache them across Streamlit reruns
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained("himmeow/vi-gemma-2b-RAG")
    model = AutoModelForCausalLM.from_pretrained(
        "himmeow/vi-gemma-2b-RAG",
        device_map="auto",  # accelerate places the model on GPU when available
        torch_dtype=torch.bfloat16
    )
    # device_map="auto" already handles placement, so no explicit model.to("cuda")
    return tokenizer, model
# Function to extract text from PDF
def extract_text_from_pdf(pdf_file):
    # Open the uploaded file from its in-memory bytes with PyMuPDF
    doc = fitz.open(stream=pdf_file.read(), filetype="pdf")
    text = ""
    for page_num in range(doc.page_count):
        page = doc.load_page(page_num)
        text += page.get_text("text")  # plain-text extraction per page
    doc.close()
    return text
# Function to generate a response from the model
def generate_response(input_text, query, tokenizer, model):
    # Format the input prompt for the model
    prompt = f"""
### Instruction and Input:
Based on the following context/document:
{input_text}
Please answer the question: {query}

### Response:
"""
    # Tokenize and move the inputs to the same device as the model
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Generate a response from the model
    outputs = model.generate(
        **inputs,
        max_new_tokens=500,
        no_repeat_ngram_size=5
    )
    # Decode only the newly generated tokens so the prompt is not echoed back
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
# Streamlit app main function
def main():
    st.title("PDF Question Answering with vi-gemma-2b-RAG")

    # File uploader widget for PDF files
    pdf_file = st.file_uploader("Upload a PDF file", type=["pdf"])
    if pdf_file is not None:
        with st.spinner("Reading the PDF..."):
            # Extract text from the uploaded PDF
            pdf_text = extract_text_from_pdf(pdf_file)
        st.text_area("Extracted Text", pdf_text, height=300)

        # Text input for the user's question
        query = st.text_input("Enter your question:")
        if st.button("Get Answer"):
            if query.strip() == "":
                st.warning("Please enter a question.")
            else:
                with st.spinner("Generating response..."):
                    # Load the (cached) model and tokenizer
                    tokenizer, model = load_model()
                    # Generate the response using the model
                    try:
                        response = generate_response(pdf_text, query, tokenizer, model)
                        st.text_area("Response", response, height=200)
                    except Exception as e:
                        st.error(f"Error generating response: {e}")


if __name__ == "__main__":
    main()