from copy import deepcopy
from typing import List

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from surya.input.processing import convert_if_not_rgb
from surya.model.ordering.encoderdecoder import OrderVisionEncoderDecoderModel
from surya.schema import OrderBox, OrderResult
from surya.settings import settings


def get_batch_size():
    batch_size = settings.ORDER_BATCH_SIZE
    if batch_size is None:
        batch_size = 8
        if settings.TORCH_DEVICE_MODEL == "mps":
            batch_size = 8
        if settings.TORCH_DEVICE_MODEL == "cuda":
            batch_size = 32
    return batch_size


def rank_elements(arr):
    # Return the 0-based rank of each element of arr (the position it would take when sorted)
    enumerated_and_sorted = sorted(enumerate(arr), key=lambda x: x[1])
    rank = [0] * len(arr)
    for rank_value, (original_index, value) in enumerate(enumerated_and_sorted):
        rank[original_index] = rank_value
    return rank


def batch_ordering(images: List, bboxes: List[List[List[float]]], model: OrderVisionEncoderDecoderModel, processor, batch_size=None) -> List[OrderResult]:
    assert all([isinstance(image, Image.Image) for image in images])
    assert len(images) == len(bboxes)
    if batch_size is None:
        batch_size = get_batch_size()

    output_order = []
    for i in tqdm(range(0, len(images), batch_size), desc="Finding reading order"):
        batch_bboxes = deepcopy(bboxes[i:i+batch_size])
        batch_images = images[i:i+batch_size]
        batch_images = [image.convert("RGB") for image in batch_images]  # also copies the images

        orig_sizes = [image.size for image in batch_images]
        model_inputs = processor(images=batch_images, boxes=batch_bboxes)

        batch_pixel_values = model_inputs["pixel_values"]
        batch_bboxes = model_inputs["input_boxes"]
        batch_bbox_mask = model_inputs["input_boxes_mask"]
        batch_bbox_counts = model_inputs["input_boxes_counts"]

        # Move the processor outputs onto the model device with appropriate dtypes
        batch_bboxes = torch.from_numpy(np.array(batch_bboxes, dtype=np.int32)).to(model.device)
        batch_bbox_mask = torch.from_numpy(np.array(batch_bbox_mask, dtype=np.int32)).to(model.device)
        batch_pixel_values = torch.tensor(np.array(batch_pixel_values), dtype=model.dtype).to(model.device)
        batch_bbox_counts = torch.tensor(np.array(batch_bbox_counts), dtype=torch.long).to(model.device)

        token_count = 0
        past_key_values = None
        encoder_outputs = None
        batch_predictions = [[] for _ in range(len(batch_images))]
        done = torch.zeros(len(batch_images), dtype=torch.bool, device=model.device)
        with torch.inference_mode():
            while token_count < settings.ORDER_MAX_BOXES:
                return_dict = model(
                    pixel_values=batch_pixel_values,
                    decoder_input_boxes=batch_bboxes,
                    decoder_input_boxes_mask=batch_bbox_mask,
                    decoder_input_boxes_counts=batch_bbox_counts,
                    encoder_outputs=encoder_outputs,
                    past_key_values=past_key_values,
                )
                logits = return_dict["logits"].detach()

                last_tokens = []
                last_token_mask = []
                min_val = torch.finfo(model.dtype).min
                for j in range(logits.shape[0]):
                    label_count = batch_bbox_counts[j, 1] - batch_bbox_counts[j, 0] - 1  # Subtract 1 for the sep token
                    new_logits = logits[j, -1]
                    new_logits[batch_predictions[j]] = min_val  # Mask out already predicted tokens, we can only predict each token once
                    new_logits[label_count:] = min_val  # Mask out all logit positions above the number of bboxes
                    pred = int(torch.argmax(new_logits, dim=-1).item())

                    # Add one to avoid colliding with the 1000 height/width token for bboxes
                    last_tokens.append([[pred + processor.box_size["height"] + 1] * 4])
                    if len(batch_predictions[j]) == label_count - 1:  # Minus one since we're appending the final label
                        last_token_mask.append([0])
                        batch_predictions[j].append(pred)
                        done[j] = True
                    elif len(batch_predictions[j]) < label_count - 1:
                        last_token_mask.append([1])
                        batch_predictions[j].append(pred)  # Get rank prediction for given position
                    else:
                        last_token_mask.append([0])

                if done.all():
                    break

                # Reuse the encoder output and decoder cache; only the newly predicted tokens are fed back in
                past_key_values = return_dict["past_key_values"]
                encoder_outputs = (return_dict["encoder_last_hidden_state"],)

                batch_bboxes = torch.tensor(last_tokens, dtype=torch.long).to(model.device)
                token_bbox_mask = torch.tensor(last_token_mask, dtype=torch.long).to(model.device)
                batch_bbox_mask = torch.cat([batch_bbox_mask, token_bbox_mask], dim=1)
                token_count += 1
        for j, row_pred in enumerate(batch_predictions):
            row_bboxes = bboxes[i+j]
            assert len(row_pred) == len(row_bboxes), f"Mismatch between logits and bboxes. Logits: {len(row_pred)}, Bboxes: {len(row_bboxes)}"

            orig_size = orig_sizes[j]
            ranks = [0] * len(row_bboxes)

            for box_idx in range(len(row_bboxes)):
                ranks[row_pred[box_idx]] = box_idx

            order_boxes = []
            for row_bbox, rank in zip(row_bboxes, ranks):
                order_box = OrderBox(
                    bbox=row_bbox,
                    position=rank,
                )
                order_boxes.append(order_box)

            result = OrderResult(
                bboxes=order_boxes,
                image_bbox=[0, 0, orig_size[0], orig_size[1]],
            )
            output_order.append(result)

    return output_order
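

# Example usage: a minimal sketch of calling batch_ordering, not part of the module above.
# It assumes the loader helpers load_model() / load_processor() from surya.model.ordering
# are available in this version of the library, and that "page.png" plus the hard-coded
# bboxes stand in for a real page image with layout boxes from a detection step.
if __name__ == "__main__":
    from surya.model.ordering.model import load_model
    from surya.model.ordering.processor import load_processor

    model = load_model()
    processor = load_processor()

    image = Image.open("page.png")
    # One list of [x1, y1, x2, y2] boxes per image, e.g. from layout detection
    bboxes = [[[10, 10, 300, 60], [10, 80, 300, 400]]]

    order_results = batch_ordering([image], bboxes, model, processor)
    for order_box in order_results[0].bboxes:
        print(order_box.position, order_box.bbox)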