import gradio as gr
from PyPDF2 import PdfReader
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("himmeow/vi-gemma-2b-RAG")
model = AutoModelForCausalLM.from_pretrained(
    "himmeow/vi-gemma-2b-RAG",
    device_map="auto",
    torch_dtype=torch.bfloat16
)
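# Note (assumption): torch.bfloat16 needs hardware support to run efficiently;
# on older GPUs one might fall back to torch.float16, and on CPU to the
# default torch.float32.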

# Note: device_map="auto" already places the model on the GPU when one is
# available, so an explicit model.to("cuda") is redundant here (and can raise
# on an accelerate-dispatched model).

# Define the prompt format for the model
prompt = """
### Instruction and Input:
Based on the following context/document:
{}
Please answer the question: {}

### Response:
{}
"""

def extract_text_from_pdf(pdf):
    """Extract the text of every page from the uploaded PDF."""
    reader = PdfReader(pdf)
    pdf_text = ""
    for page in reader.pages:
        text = page.extract_text()
        if text:
            pdf_text += text + "\n"
    if not pdf_text.strip():
        pdf_text = "The PDF contains no extractable text."
    print("Extracted Text:\n", pdf_text)  # Debugging statement
    return pdf_text
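
# Long PDFs can exceed the model's context window. A minimal sketch of
# token-level truncation (the 3000-token budget is an assumption, not a
# documented limit of vi-gemma-2b-RAG):
def truncate_context(text, max_tokens=3000):
    token_ids = tokenizer(text, truncation=True, max_length=max_tokens)["input_ids"]
    return tokenizer.decode(token_ids, skip_special_tokens=True)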

def generate_response(pdf, query):
    pdf_text = extract_text_from_pdf(pdf)
    if not pdf_text.strip():
        return "The PDF appears to be empty or unreadable."

    input_text = prompt.format(pdf_text, query, " ")
    print("Input Text for Model:\n", input_text)  # Debugging statement

    inputs = tokenizer(input_text, return_tensors="pt")

    # Move the inputs to whichever device the model was dispatched to.
    inputs = inputs.to(model.device)

    try:
        outputs = model.generate(
            **inputs,
            max_new_tokens=500,
            no_repeat_ngram_size=5,
        )
        # Decode only the newly generated tokens so the prompt is not
        # echoed back in the answer.
        new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    except Exception as e:
        response = "An error occurred while generating the response."
        print("Error:", e)

    print("Generated Response:\n", response)  # Debugging statement
    return response
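
# For more varied answers, generation could be switched to sampling, e.g.
# (illustrative values, not tuned for this model):
#     outputs = model.generate(**inputs, max_new_tokens=500,
#                              do_sample=True, temperature=0.7, top_p=0.9)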

# Gradio interface
iface = gr.Interface(
    fn=generate_response,
    inputs=[gr.File(label="Upload PDF"), gr.Textbox(label="Ask a question")],
    outputs="text",
    title="PDF Question Answering with vi-gemma-2b-RAG",
    description="Upload a PDF and ask a question based on its content. The model will generate a response."
)

iface.launch()
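
# To expose a temporary public URL (e.g. when running on a remote machine or
# in a notebook), Gradio also supports: iface.launch(share=True)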