import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import fitz  # PyMuPDF for PDF handling


# Load the model and tokenizer once and cache them across Streamlit reruns
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained("himmeow/vi-gemma-2b-RAG")
    model = AutoModelForCausalLM.from_pretrained(
        "himmeow/vi-gemma-2b-RAG",
        device_map="auto",  # lets accelerate place the model on GPU/CPU automatically
        torch_dtype=torch.bfloat16,
    )
    # No explicit model.to("cuda") here: device_map="auto" already handles
    # placement, and calling .to() on a dispatched model can raise an error.
    return tokenizer, model


# Extract text from the uploaded PDF file using PyMuPDF
def extract_text_from_pdf(pdf_file):
    doc = fitz.open(stream=pdf_file.read(), filetype="pdf")
    text = ""
    for page_num in range(doc.page_count):
        page = doc.load_page(page_num)
        text += page.get_text("text")
    doc.close()
    return text


# Build the prompt, run generation, and return the model's answer
def generate_response(input_text, query, tokenizer, model):
    # Format the input prompt for the model
    prompt = f"""
### Instruction and Input:
Based on the following context/document:
{input_text}
Please answer the question: {query}

### Response:
"""
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    # Move inputs to wherever the model was placed (GPU if available)
    input_ids = input_ids.to(model.device)

    # Generate a response from the model
    outputs = model.generate(
        input_ids=input_ids,
        max_new_tokens=500,
        no_repeat_ngram_size=5,
    )
    # Decode only the tokens generated after the prompt, so the answer
    # is not prefixed with the entire document text
    return tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)


# Streamlit app main function
def main():
    st.title("PDF Question Answering with vi-gemma-2b-RAG")

    # File uploader widget for PDF files
    pdf_file = st.file_uploader("Upload a PDF file", type=["pdf"])

    if pdf_file is not None:
        with st.spinner("Reading the PDF..."):
            # Extract text from the uploaded PDF
            pdf_text = extract_text_from_pdf(pdf_file)
        st.text_area("Extracted Text", pdf_text, height=300)

        # Text input for the user's question
        query = st.text_input("Enter your question:")

        if st.button("Get Answer"):
            if query.strip() == "":
                st.warning("Please enter a question.")
            else:
                with st.spinner("Generating response..."):
                    # Load the (cached) model and tokenizer
                    tokenizer, model = load_model()
                    try:
                        response = generate_response(pdf_text, query, tokenizer, model)
                        st.text_area("Response", response, height=200)
                    except Exception as e:
                        st.error(f"Error generating response: {e}")


if __name__ == "__main__":
    main()
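
# Usage (a minimal sketch; the filename app.py is an assumption, not from the source):
#   pip install streamlit transformers accelerate torch pymupdf
#   streamlit run app.py
# Then open the local URL Streamlit prints, upload a PDF, and ask a question.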