gaunernst committed on
Commit eb407e4
1 Parent(s): 1134889

update decode logic

Files changed (1)
  1. app.py +15 -9
app.py CHANGED
@@ -7,8 +7,9 @@ from PIL import Image
 from transformers import AutoTokenizer, LayoutLMForQuestionAnswering
 from transformers.pipelines.document_question_answering import apply_tesseract
 
-MODEL = LayoutLMForQuestionAnswering.from_pretrained("impira/layoutlm-document-qa").eval()
-TOKENIZER = AutoTokenizer.from_pretrained("impira/layoutlm-document-qa")
+model_tag = "impira/layoutlm-document-qa"
+MODEL = LayoutLMForQuestionAnswering.from_pretrained(model_tag).eval()
+TOKENIZER = AutoTokenizer.from_pretrained(model_tag)
 OCR = PaddleOCR(
     use_angle_cls=True,
     lang="en",
@@ -56,6 +57,7 @@ def predict(image: Image.Image, question: str, ocr_engine: str):
 
     token_ids = TOKENIZER(question)["input_ids"]
     token_boxes = [[0] * 4] * (len(token_ids) - 1) + [[1000] * 4]
+    n_question_tokens = len(token_ids)
 
     token_ids.append(TOKENIZER.sep_token_id)
     token_boxes.append([1000] * 4)
@@ -74,14 +76,19 @@ def predict(image: Image.Image, question: str, ocr_engine: str):
         bbox=torch.tensor(token_boxes).unsqueeze(0),
     )
 
-    start_scores = outputs.start_logits.squeeze(0).softmax(-1)
-    end_scores = outputs.end_logits.squeeze(0).softmax(-1)
+    start_scores = outputs.start_logits.squeeze(0).softmax(-1)[n_question_tokens:]
+    end_scores = outputs.end_logits.squeeze(0).softmax(-1)[n_question_tokens:]
+
+    span_scores = start_scores.view(-1, 1) * end_scores.view(1, -1)
+    span_scores = torch.triu(span_scores)  # don't allow end < start
+
+    score, indices = span_scores.flatten().max(-1)
+    start_idx = n_question_tokens + indices // span_scores.shape[1]
+    end_idx = n_question_tokens + indices % span_scores.shape[1]
 
-    start_score, start_idx = start_scores.max(-1)
-    end_score, end_idx = end_scores.max(-1)
     answer = TOKENIZER.decode(token_ids[start_idx : end_idx + 1])
 
-    return answer, start_score, end_score, image_np
+    return answer, score, image_np
 
 
 gr.Interface(
@@ -93,8 +100,7 @@ gr.Interface(
     ],
     outputs=[
         gr.Textbox(label="Answer"),
-        gr.Number(label="Start score"),
-        gr.Number(label="End score"),
+        gr.Number(label="Score"),
         gr.Image(label="OCR results"),
     ],
     examples=[
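
For reference, a minimal standalone sketch (not part of the commit) of the new decode step: it runs the same span-scoring math as the updated predict() on made-up probability tensors. The value of n_question_tokens and the probabilities below are assumptions chosen only to illustrate how the flattened argmax over torch.triu of the start/end outer product recovers the best (start, end) pair.

# Sketch only: dummy numbers, not the commit's model outputs.
import torch

n_question_tokens = 3                              # assumed number of question tokens skipped
start_scores = torch.tensor([0.1, 0.6, 0.2, 0.1])  # fake P(start) per context token
end_scores = torch.tensor([0.1, 0.1, 0.7, 0.1])    # fake P(end) per context token

# Outer product: span_scores[i, j] = P(start = i) * P(end = j).
span_scores = start_scores.view(-1, 1) * end_scores.view(1, -1)
span_scores = torch.triu(span_scores)              # keep only spans with end >= start

score, indices = span_scores.flatten().max(-1)     # best span probability and its flat index
start_idx = n_question_tokens + indices // span_scores.shape[1]
end_idx = n_question_tokens + indices % span_scores.shape[1]

print(round(score.item(), 2), start_idx.item(), end_idx.item())  # 0.42 4 5

The offsets by n_question_tokens shift the argmax back into full-sequence positions, matching how the diff slices the softmaxed logits with [n_question_tokens:] so the answer span cannot start inside the question.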