from copy import deepcopy
from typing import List, Dict

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from surya.model.ordering.encoderdecoder import OrderVisionEncoderDecoderModel
from surya.model.table_rec.config import SPECIAL_TOKENS
from surya.schema import TableResult, TableCell, Bbox
from surya.settings import settings
def get_batch_size():
    batch_size = settings.TABLE_REC_BATCH_SIZE
    if batch_size is None:
        batch_size = 8
        if settings.TORCH_DEVICE_MODEL == "mps":
            batch_size = 8
        if settings.TORCH_DEVICE_MODEL == "cuda":
            batch_size = 64
    return batch_size
def sort_bboxes(bboxes, tolerance=1):
    vertical_groups = {}
    for block in bboxes:
        group_key = round(block["bbox"][1] / tolerance) * tolerance
        if group_key not in vertical_groups:
            vertical_groups[group_key] = []
        vertical_groups[group_key].append(block)

    # Sort each group horizontally and flatten the groups into a single list
    sorted_page_blocks = []
    for _, group in sorted(vertical_groups.items()):
        sorted_group = sorted(group, key=lambda x: x["bbox"][0])
        sorted_page_blocks.extend(sorted_group)

    return sorted_page_blocks
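
# Illustrative sketch (not part of the pipeline): boxes whose top edges fall in the same
# `tolerance`-sized vertical band are treated as one visual row, then ordered left to right.
#
#   blocks = [{"bbox": [50, 11, 90, 30]}, {"bbox": [5, 10, 45, 30]}, {"bbox": [5, 40, 45, 60]}]
#   sort_bboxes(blocks, tolerance=5)
#   # -> [{"bbox": [5, 10, ...]}, {"bbox": [50, 11, ...]}, {"bbox": [5, 40, ...]}]
#   # y=10 and y=11 both round to the y=10 band, so those two sort left-to-right first.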
def is_rotated(rows, cols):
    # Determine if the table is rotated by looking at row and column width / height ratios
    # Rows should have a >1 ratio, cols <1
    widths = sum([r.width for r in rows])
    heights = sum([r.height for r in rows]) + 1  # +1 avoids division by zero on empty input
    r_ratio = widths / heights

    widths = sum([c.width for c in cols])
    heights = sum([c.height for c in cols]) + 1
    c_ratio = widths / heights

    return r_ratio * 2 < c_ratio
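
# Worked example (hypothetical numbers): an upright table with rows totaling 900 wide by
# 100 tall gives r_ratio ~ 8.9; columns totaling 300 wide by 600 tall give c_ratio ~ 0.5,
# so 8.9 * 2 < 0.5 is False -> not rotated. Turn the page 90 degrees and the ratios roughly
# swap (r_ratio ~ 0.1, c_ratio ~ 2.0), so 0.1 * 2 < 2.0 fires and the table is flagged rotated.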
def batch_table_recognition(images: List, table_cells: List[List[Dict]], model: OrderVisionEncoderDecoderModel, processor, batch_size=None) -> List[TableResult]:
    assert all([isinstance(image, Image.Image) for image in images])
    assert len(images) == len(table_cells)
    if batch_size is None:
        batch_size = get_batch_size()

    output_order = []
    for i in tqdm(range(0, len(images), batch_size), desc="Recognizing tables"):
        batch_table_cells = deepcopy(table_cells[i:i+batch_size])
        batch_table_cells = [sort_bboxes(page_bboxes) for page_bboxes in batch_table_cells]  # Sort bboxes before passing in
        batch_list_bboxes = [[block["bbox"] for block in page] for page in batch_table_cells]

        batch_images = images[i:i+batch_size]
        batch_images = [image.convert("RGB") for image in batch_images]  # also copies the images

        current_batch_size = len(batch_images)
        orig_sizes = [image.size for image in batch_images]
        model_inputs = processor(images=batch_images, boxes=deepcopy(batch_list_bboxes))

        batch_pixel_values = model_inputs["pixel_values"]
        batch_bboxes = model_inputs["input_boxes"]
        batch_bbox_mask = model_inputs["input_boxes_mask"]
        batch_bbox_counts = model_inputs["input_boxes_counts"]

        batch_bboxes = torch.from_numpy(np.array(batch_bboxes, dtype=np.int32)).to(model.device)
        batch_bbox_mask = torch.from_numpy(np.array(batch_bbox_mask, dtype=np.int32)).to(model.device)
        batch_pixel_values = torch.tensor(np.array(batch_pixel_values), dtype=model.dtype).to(model.device)
        batch_bbox_counts = torch.tensor(np.array(batch_bbox_counts), dtype=torch.long).to(model.device)

        # Setup inputs for the decoder
        batch_decoder_input = [[[model.config.decoder.bos_token_id] * 5] for _ in range(current_batch_size)]
        batch_decoder_input = torch.tensor(np.stack(batch_decoder_input, axis=0), dtype=torch.long, device=model.device)
        inference_token_count = batch_decoder_input.shape[1]

        max_tokens = min(batch_bbox_counts[:, 1].max().item(), settings.TABLE_REC_MAX_BOXES)
        decoder_position_ids = torch.ones_like(batch_decoder_input[0, :, 0], dtype=torch.int64, device=model.device).cumsum(0) - 1
        model.decoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)
        model.text_encoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)

        batch_predictions = [[] for _ in range(current_batch_size)]
        with torch.inference_mode():
            encoder_hidden_states = model.encoder(pixel_values=batch_pixel_values).last_hidden_state

            text_encoder_hidden_states = model.text_encoder(
                input_boxes=batch_bboxes,
                input_boxes_counts=batch_bbox_counts,
                cache_position=None,
                attention_mask=batch_bbox_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=None,
                use_cache=False
            ).hidden_states

            token_count = 0
            all_done = torch.zeros(current_batch_size, dtype=torch.bool, device=model.device)

            # Autoregressively decode one (x, y, w, h, class) prediction per table per step
            while token_count < max_tokens:
                is_prefill = token_count == 0
                return_dict = model.decoder(
                    input_ids=batch_decoder_input,
                    encoder_hidden_states=text_encoder_hidden_states,
                    cache_position=decoder_position_ids,
                    use_cache=True,
                    prefill=is_prefill
                )

                decoder_position_ids = decoder_position_ids[-1:] + 1
                box_logits = return_dict["bbox_logits"][:, -1, :].detach()
                rowcol_logits = return_dict["class_logits"][:, -1, :].detach()

                rowcol_preds = torch.argmax(rowcol_logits, dim=-1)
                box_preds = torch.argmax(box_logits, dim=-1)

                # A table is done once it emits an EOS or pad class token
                done = (rowcol_preds == processor.tokenizer.eos_id) | (rowcol_preds == processor.tokenizer.pad_id)
                all_done = all_done | done

                if all_done.all():
                    break

                batch_decoder_input = torch.cat([box_preds.unsqueeze(1), rowcol_preds.unsqueeze(1).unsqueeze(1)], dim=-1)
                for j, (pred, status) in enumerate(zip(batch_decoder_input, all_done)):
                    if not status:
                        batch_predictions[j].append(pred[0].tolist())

                token_count += inference_token_count
                inference_token_count = batch_decoder_input.shape[1]
        for j, (preds, input_cells, orig_size) in enumerate(zip(batch_predictions, batch_table_cells, orig_sizes)):
            img_w, img_h = orig_size
            width_scaler = img_w / model.config.decoder.out_box_size
            height_scaler = img_h / model.config.decoder.out_box_size

            # cx, cy to corners
            for pred_idx, pred in enumerate(preds):  # pred_idx avoids shadowing the batch index i
                w = pred[2] / 2
                h = pred[3] / 2
                x1 = pred[0] - w
                y1 = pred[1] - h
                x2 = pred[0] + w
                y2 = pred[1] + h
                class_ = int(pred[4] - SPECIAL_TOKENS)
                preds[pred_idx] = [x1 * width_scaler, y1 * height_scaler, x2 * width_scaler, y2 * height_scaler, class_]
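
            # Worked example (hypothetical numbers): pred = [300, 200, 100, 40, SPECIAL_TOKENS]
            # with out_box_size = 1024 on a 2048x1024 page gives corners (250, 180, 350, 220)
            # in model space, scaled by (2.0, 1.0) to [500, 180, 700, 220] with class 0 (a row).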
            # Get rows and columns
            bb_rows = [p[:4] for p in preds if p[4] == 0]
            bb_cols = [p[:4] for p in preds if p[4] == 1]

            rows = []
            cols = []
            for row_idx, row in enumerate(bb_rows):
                cell = TableCell(
                    bbox=row,
                    row_id=row_idx
                )
                rows.append(cell)

            for col_idx, col in enumerate(bb_cols):
                cell = TableCell(
                    bbox=col,
                    col_id=col_idx,
                )
                cols.append(cell)
            # Assign cells to rows/columns by greatest bbox overlap
            cells = []
            for cell in input_cells:
                max_intersection = 0
                row_pred = None
                for row_idx, row in enumerate(rows):
                    intersection_pct = Bbox(bbox=cell["bbox"]).intersection_pct(row)
                    if intersection_pct > max_intersection:
                        max_intersection = intersection_pct
                        row_pred = row_idx

                max_intersection = 0
                col_pred = None
                for col_idx, col in enumerate(cols):
                    intersection_pct = Bbox(bbox=cell["bbox"]).intersection_pct(col)
                    if intersection_pct > max_intersection:
                        max_intersection = intersection_pct
                        col_pred = col_idx

                cells.append(
                    TableCell(
                        bbox=cell["bbox"],
                        text=cell.get("text"),
                        row_id=row_pred,
                        col_id=col_pred
                    )
                )
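
            # For example (hypothetical overlaps): a cell overlapping predicted row 2 by 85%
            # and row 3 by 10% gets row_id=2; a cell overlapping no row or column at all keeps
            # row_id/col_id of None and is repaired by the nearest-neighbor fallback below.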
            rotated = is_rotated(rows, cols)
            for cell in cells:
                if cell.row_id is None:
                    closest_row = None
                    closest_row_dist = None
                    for cell2 in cells:
                        if cell2.row_id is None:
                            continue
                        if rotated:
                            cell_y_center = cell.center[0]
                            cell2_y_center = cell2.center[0]
                        else:
                            cell_y_center = cell.center[1]
                            cell2_y_center = cell2.center[1]
                        y_dist = abs(cell_y_center - cell2_y_center)
                        if closest_row_dist is None or y_dist < closest_row_dist:
                            closest_row = cell2.row_id
                            closest_row_dist = y_dist
                    cell.row_id = closest_row

                if cell.col_id is None:
                    closest_col = None
                    closest_col_dist = None
                    for cell2 in cells:
                        if cell2.col_id is None:
                            continue
                        if rotated:
                            cell_x_center = cell.center[1]
                            cell2_x_center = cell2.center[1]
                        else:
                            cell_x_center = cell.center[0]
                            cell2_x_center = cell2.center[0]
                        x_dist = abs(cell2_x_center - cell_x_center)
                        if closest_col_dist is None or x_dist < closest_col_dist:
                            closest_col = cell2.col_id
                            closest_col_dist = x_dist
                    cell.col_id = closest_col
            result = TableResult(
                cells=cells,
                rows=rows,
                cols=cols,
                image_bbox=[0, 0, img_w, img_h],
            )

            output_order.append(result)

    return output_order
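

# Minimal usage sketch. The loader import paths below are assumptions about the surrounding
# surya package and may differ between versions; the detected-cell dicts only need a "bbox"
# key ([x1, y1, x2, y2]) and, optionally, "text".
if __name__ == "__main__":
    # Assumed loader helpers; check your surya version for the actual import paths.
    from surya.model.table_rec.model import load_model
    from surya.model.table_rec.processor import load_processor

    model = load_model()
    processor = load_processor()
    image = Image.open("table_page.png")  # hypothetical input image
    cells = [{"bbox": [10, 10, 120, 40], "text": "Header"}]  # hypothetical detected cells

    [result] = batch_table_recognition([image], [cells], model, processor)
    for cell in result.cells:
        print(cell.row_id, cell.col_id, cell.text)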