import cv2
import gradio as gr
import numpy as np
import torch
from paddleocr import PaddleOCR
from PIL import Image
from transformers import AutoTokenizer, LayoutLMForQuestionAnswering
from transformers.pipelines.document_question_answering import apply_tesseract

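# Load the model, tokenizer, and OCR engine once at startup so every request
# reuses the same instances instead of re-initializing them per call.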
MODEL_TAG = "impira/layoutlm-document-qa"
MODEL = LayoutLMForQuestionAnswering.from_pretrained(MODEL_TAG).eval()
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_TAG)
OCR = PaddleOCR(
    lang="en",
    det_limit_side_len=10_000,  # avoid downscaling large scans before detection
    det_db_score_mode="slow",  # polygon-based scoring, more accurate for long boxes
)


PADDLE_OCR_LABEL = "PaddleOCR (en)"
TESSERACT_LABEL = "Tesseract (HF default)"


def predict(image: Image.Image, question: str, ocr_engine: str):
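    """Answer a free-text question about a document image.

    Runs the selected OCR engine to extract words and boxes, feeds them to
    LayoutLM, and returns the answer span, its score, and the image annotated
    with the detected OCR boxes.
    """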
    image_np = np.array(image)

    if ocr_engine == PADDLE_OCR_LABEL:
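        # PaddleOCR returns one result list per page; each entry is
        # [quad, (text, confidence)], where quad holds four (x, y) corners.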
        ocr_result = OCR.ocr(image_np, cls=False)[0]
        words = [x[1][0] for x in ocr_result]
        boxes = np.asarray([x[0] for x in ocr_result])  # (n_boxes, 4, 2)

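        # Overlay each detected quad so the "OCR results" panel shows what was found.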
        for box in boxes:
            cv2.polylines(image_np, [box.reshape(-1, 1, 2).astype(int)], True, (0, 255, 255), 3)

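        # Collapse each quad to an axis-aligned box, normalized to the 0-1000
        # coordinate grid that LayoutLM expects for its bbox inputs.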
        x1 = boxes[:, :, 0].min(1) * 1000 / image.width
        y1 = boxes[:, :, 1].min(1) * 1000 / image.height
        x2 = boxes[:, :, 0].max(1) * 1000 / image.width
        y2 = boxes[:, :, 1].max(1) * 1000 / image.height

        # (n_boxes, 4) in xyxy format
        boxes = np.stack([x1, y1, x2, y2], axis=1).astype(int)

    elif ocr_engine == TESSERACT_LABEL:
        words, boxes = apply_tesseract(image, None, "")

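        # apply_tesseract already returns boxes normalized to 0-1000; convert
        # back to pixel coordinates only for drawing.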
        for x1, y1, x2, y2 in boxes:
            x1 = int(x1 * image.width / 1000)
            y1 = int(y1 * image.height / 1000)
            x2 = int(x2 * image.width / 1000)
            y2 = int(y2 * image.height / 1000)
            cv2.rectangle(image_np, (x1, y1), (x2, y2), (0, 255, 255), 3)

    else:
        raise ValueError(f"Unsupported ocr_engine={ocr_engine}")

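    # Build the LayoutLM input manually: question tokens get dummy boxes, an
    # extra SEP separates the question from the document, and each word's box
    # is repeated for every subword token it produces.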
    token_ids = TOKENIZER(question)["input_ids"]
    token_boxes = [[0] * 4] * (len(token_ids) - 1) + [[1000] * 4]
    n_question_tokens = len(token_ids)

    token_ids.append(TOKENIZER.sep_token_id)
    token_boxes.append([1000] * 4)

    for word, box in zip(words, boxes):
        new_ids = TOKENIZER(word, add_special_tokens=False)["input_ids"]
        token_ids.extend(new_ids)
        token_boxes.extend([box] * len(new_ids))

    token_ids.append(TOKENIZER.sep_token_id)
    token_boxes.append([1000] * 4)

    with torch.inference_mode():
        outputs = MODEL(
            input_ids=torch.tensor(token_ids).unsqueeze(0),
            bbox=torch.tensor(token_boxes).unsqueeze(0),
        )

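    # Convert the logits to probabilities and drop the question positions so the
    # answer span can only start and end inside the document words.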
    start_scores = outputs.start_logits.squeeze(0).softmax(-1)[n_question_tokens:]
    end_scores = outputs.end_logits.squeeze(0).softmax(-1)[n_question_tokens:]

    span_scores = start_scores.view(-1, 1) * end_scores.view(1, -1)  # P(start) * P(end)
    span_scores = torch.triu(span_scores)  # zero out invalid spans where end < start

    score, indices = span_scores.flatten().max(-1)
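    # Map the flat argmax back to (start, end) positions in the full sequence.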
    start_idx = n_question_tokens + indices // span_scores.shape[1]
    end_idx = n_question_tokens + indices % span_scores.shape[1]

    answer = TOKENIZER.decode(token_ids[start_idx : end_idx + 1], skip_special_tokens=True).strip()

    return answer, score.item(), image_np


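# Minimal Gradio UI: upload a document image, ask a question, pick an OCR engine.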
gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(label="Question"),
        gr.Radio([PADDLE_OCR_LABEL, TESSERACT_LABEL], value=PADDLE_OCR_LABEL, label="OCR engine"),
    ],
    outputs=[
        gr.Textbox(label="Answer"),
        gr.Number(label="Score"),
        gr.Image(label="OCR results"),
    ],
    examples=[
        ["example_01.jpg", "When did the sample take place?", PADDLE_OCR_LABEL],
        ["example_02.jpg", "What is the ID number?", PADDLE_OCR_LABEL],
    ],
).launch(server_name="0.0.0.0", server_port=7860)