SuryaOCR / surya /ordering.py
Jiangxz01's picture
Upload 56 files
52f1bcb verified
raw
history blame
5.94 kB
from copy import deepcopy
from typing import List
import torch
from PIL import Image
from surya.input.processing import convert_if_not_rgb
from surya.model.ordering.encoderdecoder import OrderVisionEncoderDecoderModel
from surya.schema import OrderBox, OrderResult
from surya.settings import settings
from tqdm import tqdm
import numpy as np
def get_batch_size():
    """Return the batch size to use for reading-order inference.

    An explicit ``settings.ORDER_BATCH_SIZE`` always wins. When it is unset
    (``None``), the value is picked from ``settings.TORCH_DEVICE_MODEL``:
    32 for ``"cuda"`` and 8 for everything else (``"mps"`` included).
    """
    configured = settings.ORDER_BATCH_SIZE
    if configured is not None:
        return configured

    # "mps" and any other non-CUDA device share the same default of 8.
    return 32 if settings.TORCH_DEVICE_MODEL == "cuda" else 8
def rank_elements(arr):
    """Return, for each element of ``arr``, its 0-based rank in ascending order.

    The result has the same length as ``arr``; ``result[i]`` is the position
    that ``arr[i]`` would occupy after sorting. Equal values are ranked in
    the order they appear in the input.
    """
    indexed = list(enumerate(arr))
    # list.sort is stable, so ties keep their original relative order.
    indexed.sort(key=lambda pair: pair[1])

    ranks = [0] * len(arr)
    for position, (source_index, _) in enumerate(indexed):
        ranks[source_index] = position
    return ranks
def batch_ordering(images: List, bboxes: List[List[List[float]]], model: OrderVisionEncoderDecoderModel, processor, batch_size=None) -> List[OrderResult]:
    """Predict a reading order for the boxes of each image.

    Images are handled in chunks of ``batch_size``. For every chunk the
    processor builds the model inputs, then the model is called in a loop.
    Each call selects one box index per image, until every image in the chunk
    has one prediction per box or ``settings.ORDER_MAX_BOXES`` iterations
    have run.

    Args:
        images: PIL images, one per page. Each is converted to RGB before use.
        bboxes: One list of boxes per image, aligned with ``images``. The
            boxes are passed to ``processor`` unchanged and echoed back in
            the result.
        model: Ordering model. It is called with pixel values, box tokens,
            box mask, box counts, and the encoder output / key-value cache
            returned by the previous call.
        processor: Callable invoked as ``processor(images=..., boxes=...)``.
            Its result is indexed with ``"pixel_values"``, ``"input_boxes"``,
            ``"input_boxes_mask"`` and ``"input_boxes_counts"``; it must also
            expose ``box_size["height"]``.
        batch_size: Number of images per chunk. Defaults to
            ``get_batch_size()`` when ``None``.

    Returns:
        One ``OrderResult`` per input image, in input order. Its ``bboxes``
        follow the order of the input boxes; each ``OrderBox.position`` is
        the decoding step at which that box was selected. ``image_bbox`` is
        ``[0, 0, width, height]`` taken from the image's ``size``.

    Raises:
        AssertionError: If an element of ``images`` is not a PIL image, if
            ``images`` and ``bboxes`` differ in length, or if the number of
            predictions for an image differs from its number of boxes.
    """
    assert all([isinstance(image, Image.Image) for image in images])
    assert len(images) == len(bboxes)
    if batch_size is None:
        batch_size = get_batch_size()

    output_order = []
    for i in tqdm(range(0, len(images), batch_size), desc="Finding reading order"):
        # Work on a deep copy of the boxes; presumably so the processor cannot
        # mutate the caller's lists -- TODO confirm.
        batch_bboxes = deepcopy(bboxes[i:i+batch_size])
        batch_images = images[i:i+batch_size]
        batch_images = [image.convert("RGB") for image in batch_images] # also copies the images

        # Sizes are recorded here because they are needed for image_bbox below.
        orig_sizes = [image.size for image in batch_images]
        model_inputs = processor(images=batch_images, boxes=batch_bboxes)

        batch_pixel_values = model_inputs["pixel_values"]
        batch_bboxes = model_inputs["input_boxes"]
        batch_bbox_mask = model_inputs["input_boxes_mask"]
        batch_bbox_counts = model_inputs["input_boxes_counts"]

        # Move every processor output onto the model's device as a tensor.
        batch_bboxes = torch.from_numpy(np.array(batch_bboxes, dtype=np.int32)).to(model.device)
        batch_bbox_mask = torch.from_numpy(np.array(batch_bbox_mask, dtype=np.int32)).to(model.device)
        batch_pixel_values = torch.tensor(np.array(batch_pixel_values), dtype=model.dtype).to(model.device)
        batch_bbox_counts = torch.tensor(np.array(batch_bbox_counts), dtype=torch.long).to(model.device)

        token_count = 0
        # Both are None on the first call and are filled from the model's
        # output afterwards.
        past_key_values = None
        encoder_outputs = None
        # batch_predictions[j] collects the box indices chosen for image j,
        # in the order they were chosen.
        batch_predictions = [[] for _ in range(len(batch_images))]
        # done[j] becomes True once image j has received its final prediction.
        done = torch.zeros(len(batch_images), dtype=torch.bool, device=model.device)

        with torch.inference_mode():
            # Hard cap on the number of decoding steps per chunk.
            while token_count < settings.ORDER_MAX_BOXES:
                return_dict = model(
                    pixel_values=batch_pixel_values,
                    decoder_input_boxes=batch_bboxes,
                    decoder_input_boxes_mask=batch_bbox_mask,
                    decoder_input_boxes_counts=batch_bbox_counts,
                    encoder_outputs=encoder_outputs,
                    past_key_values=past_key_values,
                )
                logits = return_dict["logits"].detach()

                # Per-image token and mask entry for the next model call.
                last_tokens = []
                last_token_mask = []
                min_val = torch.finfo(model.dtype).min
                for j in range(logits.shape[0]):
                    label_count = batch_bbox_counts[j, 1] - batch_bbox_counts[j, 0] - 1 # Subtract 1 for the sep token
                    # Only the last sequence position is used; the masking
                    # below writes into `logits` in place.
                    new_logits = logits[j, -1]
                    new_logits[batch_predictions[j]] = min_val # Mask out already predicted tokens, we can only predict each token once
                    new_logits[label_count:] = min_val # Mask out all logit positions above the number of bboxes
                    pred = int(torch.argmax(new_logits, dim=-1).item())

                    # Add one to avoid colliding with the 1000 height/width token for bboxes
                    # The same value is repeated in all four box slots.
                    last_tokens.append([[pred + processor.box_size["height"] + 1] * 4])
                    if len(batch_predictions[j]) == label_count - 1: # Minus one since we're appending the final label
                        last_token_mask.append([0])
                        batch_predictions[j].append(pred)
                        done[j] = True
                    elif len(batch_predictions[j]) < label_count - 1:
                        last_token_mask.append([1])
                        batch_predictions[j].append(pred) # Get rank prediction for given position
                    else:
                        # Image j already has all its predictions: record
                        # nothing and give its token a mask value of 0.
                        last_token_mask.append([0])

                if done.all():
                    break

                # Feed the cache and encoder output back into the next call;
                # presumably this avoids recomputing them -- TODO confirm
                # against the model implementation.
                past_key_values = return_dict["past_key_values"]
                encoder_outputs = (return_dict["encoder_last_hidden_state"],)

                # The next decoder input is only the token chosen in this
                # step, while the mask grows by one column per step.
                batch_bboxes = torch.tensor(last_tokens, dtype=torch.long).to(model.device)
                token_bbox_mask = torch.tensor(last_token_mask, dtype=torch.long).to(model.device)
                batch_bbox_mask = torch.cat([batch_bbox_mask, token_bbox_mask], dim=1)
                token_count += 1

        for j, row_pred in enumerate(batch_predictions):
            row_bboxes = bboxes[i+j]
            # Fails if decoding stopped (e.g. at the step cap) before every
            # box of this image received a prediction.
            assert len(row_pred) == len(row_bboxes), f"Mismatch between logits and bboxes. Logits: {len(row_pred)}, Bboxes: {len(row_bboxes)}"

            orig_size = orig_sizes[j]
            ranks = [0] * len(row_bboxes)

            # row_pred[k] is the box index chosen at step k; invert that so
            # ranks[box index] gives the step at which the box was chosen.
            for box_idx in range(len(row_bboxes)):
                ranks[row_pred[box_idx]] = box_idx

            order_boxes = []
            for row_bbox, rank in zip(row_bboxes, ranks):
                order_box = OrderBox(
                    bbox=row_bbox,
                    position=rank,
                )
                order_boxes.append(order_box)

            result = OrderResult(
                bboxes=order_boxes,
                image_bbox=[0, 0, orig_size[0], orig_size[1]],
            )
            output_order.append(result)
    return output_order