|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import re |
|
|
|
from ..models.auto import AutoProcessor |
|
from ..models.vision_encoder_decoder import VisionEncoderDecoderModel |
|
from ..utils import is_vision_available |
|
from .base import PipelineTool |
|
|
|
|
|
if is_vision_available(): |
|
from PIL import Image |
|
|
|
|
|
class DocumentQuestionAnsweringTool(PipelineTool):
    """Agent tool that answers a natural-language question about a document image.

    Uses a Donut (``VisionEncoderDecoderModel``) checkpoint fine-tuned on DocVQA:
    the question is injected into the Donut task prompt, the document image is
    turned into pixel values, and the decoder generates an XML-like token
    sequence from which the answer is extracted via ``token2json``.
    """

    default_checkpoint = "naver-clova-ix/donut-base-finetuned-docvqa"
    # NOTE: this description is shown to the agent LLM to decide when to use the
    # tool, so its wording matters. Fixed grammar: "an document" -> "a document".
    description = (
        "This is a tool that answers a question about a document (pdf). It takes an input named `document` which "
        "should be the document containing the information, as well as a `question` that is the question about the "
        "document. It returns a text that contains the answer to the question."
    )
    name = "document_qa"
    pre_processor_class = AutoProcessor
    model_class = VisionEncoderDecoderModel

    inputs = ["image", "text"]
    outputs = ["text"]

    def __init__(self, *args, **kwargs):
        """Validate that Pillow is installed before the usual PipelineTool setup.

        Raises:
            ValueError: If the ``vision`` extra (Pillow) is not available.
        """
        if not is_vision_available():
            raise ValueError("Pillow must be installed to use the DocumentQuestionAnsweringTool.")

        super().__init__(*args, **kwargs)

    def encode(self, document: "Image", question: str):
        """Build the model inputs for one (document, question) pair.

        Args:
            document: The document as a PIL image (the pre-processor converts it
                to pixel values).
            question: The question to ask about the document.

        Returns:
            dict with ``decoder_input_ids`` (the tokenized Donut DocVQA task
            prompt containing the question) and ``pixel_values`` (the encoded
            image), both as PyTorch tensors.
        """
        # Donut's DocVQA task prompt; the question is spliced in between the
        # <s_question> tags and generation continues after <s_answer>.
        task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
        prompt = task_prompt.replace("{user_input}", question)
        # add_special_tokens=False: the prompt already carries the task tokens.
        decoder_input_ids = self.pre_processor.tokenizer(
            prompt, add_special_tokens=False, return_tensors="pt"
        ).input_ids
        pixel_values = self.pre_processor(document, return_tensors="pt").pixel_values

        return {"decoder_input_ids": decoder_input_ids, "pixel_values": pixel_values}

    def forward(self, inputs):
        """Run greedy generation on the encoded inputs.

        Args:
            inputs: The dict produced by :meth:`encode`.

        Returns:
            The generated token id sequences (``.sequences`` of the generate
            output, since ``return_dict_in_generate=True``).
        """
        return self.model.generate(
            inputs["pixel_values"].to(self.device),
            decoder_input_ids=inputs["decoder_input_ids"].to(self.device),
            max_length=self.model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=self.pre_processor.tokenizer.pad_token_id,
            eos_token_id=self.pre_processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,  # greedy decoding, matching the Donut reference setup
            # Never emit the unknown token.
            bad_words_ids=[[self.pre_processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        ).sequences

    def decode(self, outputs):
        """Convert generated token ids into the answer string.

        Args:
            outputs: Token id sequences returned by :meth:`forward`.

        Returns:
            The ``"answer"`` field of the JSON parsed from the generated
            sequence by the processor's ``token2json``.
        """
        sequence = self.pre_processor.batch_decode(outputs)[0]
        sequence = sequence.replace(self.pre_processor.tokenizer.eos_token, "")
        sequence = sequence.replace(self.pre_processor.tokenizer.pad_token, "")
        # Strip only the first tag (the task start token); the remaining tags
        # are needed by token2json to reconstruct the fields.
        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
        sequence = self.pre_processor.token2json(sequence)

        # assumes token2json always yields an "answer" key for the DocVQA
        # checkpoint — TODO(review): confirm for malformed generations.
        return sequence["answer"]
|
|