SuryaOCR / surya /recognition.py
Jiangxz01's picture
Upload 56 files
52f1bcb verified
raw
history blame
8.26 kB
from typing import List
import torch
from PIL import Image
from surya.postprocessing.math.latex import fix_math, contains_math
from surya.postprocessing.text import truncate_repetitions
from surya.settings import settings
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F
def get_batch_size():
    """Return the batch size to use for text recognition.

    An explicit ``settings.RECOGNITION_BATCH_SIZE`` always wins. When it is
    unset (``None``), a default is picked from ``settings.TORCH_DEVICE_MODEL``:
    512 for "cuda", 64 for "mps", and 32 for anything else.
    """
    configured = settings.RECOGNITION_BATCH_SIZE
    if configured is not None:
        return configured

    device = settings.TORCH_DEVICE_MODEL
    if device == "cuda":
        return 512
    if device == "mps":
        return 64  # 12GB RAM max
    return 32
def pad_to_batch_size(tensor, batch_size):
    """Zero-pad ``tensor`` along its first dimension up to ``batch_size`` rows.

    The input is returned unchanged (same object) when it already has at
    least ``batch_size`` rows; otherwise a new tensor is returned whose extra
    trailing rows are filled with zeros.
    """
    missing_rows = batch_size - tensor.shape[0]
    if missing_rows <= 0:
        return tensor

    # F.pad reads (before, after) pairs starting from the LAST dimension, so
    # every trailing dimension gets a (0, 0) pair and the final pair, which
    # addresses dimension 0, appends the missing rows at the end.
    untouched_dims = (0, 0) * (tensor.dim() - 1)
    pad_spec = untouched_dims + (0, missing_rows)
    return F.pad(tensor, pad_spec, mode='constant', value=0)
def batch_recognition(images: List, languages: List[List[str] | None], model, processor, batch_size=None):
    """Recognize the text of a list of line images with greedy, batched decoding.

    Images are processed in ascending width order so that similarly sized
    lines share a batch; the results are put back into the caller's original
    order before returning.

    Args:
        images: PIL images, one per text line.
        languages: One entry per image (a list of language codes or ``None``),
            aligned with ``images``.
        model: Recognition model. This function uses its ``encoder``,
            ``text_encoder``, ``decoder``, ``config``, ``dtype`` and ``device``.
        processor: Callable returning a mapping with ``"pixel_values"`` and
            ``"langs"``; its ``tokenizer`` must provide ``pad_id``, ``eos_id``
            and ``batch_decode``.
        batch_size: Number of images per batch. Defaults to ``get_batch_size()``.

    Returns:
        A ``(texts, confidences)`` tuple of two lists aligned with ``images``:
        the decoded string for every image and the mean of the per-token
        maximum softmax probabilities of its generated tokens.

    Raises:
        AssertionError: If an element of ``images`` is not a PIL image or if
            ``images`` and ``languages`` differ in length.
    """
    assert all([isinstance(image, Image.Image) for image in images])
    assert len(images) == len(languages)

    if len(images) == 0:
        return [], []

    if batch_size is None:
        batch_size = get_batch_size()

    # Sort images by width, so similar length ones go together
    sorted_pairs = sorted(enumerate(images), key=lambda x: x[1].width, reverse=False)
    indices, images = zip(*sorted_pairs)
    indices = list(indices)
    images = list(images)
    # Apply the same permutation to the languages. Slicing the caller's list
    # directly would pair each width-sorted image with the languages of a
    # different image. This builds a new list; the caller's list is untouched.
    languages = [languages[idx] for idx in indices]

    output_text = []
    confidences = []
    pad_id = processor.tokenizer.pad_id
    eos_id = processor.tokenizer.eos_id

    for i in tqdm(range(0, len(images), batch_size), desc="Recognizing Text"):
        batch_images = images[i:i+batch_size]
        batch_images = [image.convert("RGB") for image in batch_images]  # also copies the images

        batch_langs = languages[i:i+batch_size]
        has_math = [lang and "_math" in lang for lang in batch_langs]

        processed_batch = processor(text=[""] * len(batch_images), images=batch_images, langs=batch_langs)

        batch_pixel_values = processed_batch["pixel_values"]
        batch_langs = processed_batch["langs"]
        batch_decoder_input = [[model.config.decoder_start_token_id] + lang for lang in batch_langs]
        max_input_length = max(len(tokens) for tokens in batch_decoder_input)

        # Left-pad every decoder prompt to the longest one so they stack into a tensor
        batch_decoder_input = [
            [pad_id] * (max_input_length - len(tokens)) + tokens
            for tokens in batch_decoder_input
        ]
        current_batch_size = len(batch_pixel_values)

        batch_pixel_values = torch.tensor(np.stack(batch_pixel_values, axis=0), dtype=model.dtype, device=model.device)
        batch_decoder_input = torch.tensor(np.stack(batch_decoder_input, axis=0), dtype=torch.long, device=model.device)

        token_count = 0
        inference_token_count = batch_decoder_input.shape[-1]
        batch_predictions = [[] for _ in range(current_batch_size)]

        # Positions 0..prompt_len-1 for the prefill step
        decoder_position_ids = torch.ones_like(batch_decoder_input[0, :], dtype=torch.int64, device=model.device).cumsum(0) - 1
        model.decoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)
        model.text_encoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)

        sequence_scores = None
        all_done = torch.zeros(current_batch_size, dtype=torch.bool, device=model.device)
        encoder_hidden_states = None

        with torch.no_grad():  # inference_mode doesn't work with torch.compile
            # The image encoder runs in smaller chunks than the decoder batch
            encoder_batch_size = batch_size // settings.RECOGNITION_ENCODER_BATCH_DIVISOR + 1
            for z in range(0, batch_pixel_values.shape[0], encoder_batch_size):
                # Slicing clamps at the end of the tensor, no explicit min() needed
                encoder_pixel_values = batch_pixel_values[z:z + encoder_batch_size]
                encoder_hidden_states_batch = model.encoder(pixel_values=encoder_pixel_values).last_hidden_state
                if encoder_hidden_states is None:
                    encoder_hidden_states = encoder_hidden_states_batch
                else:
                    encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_batch], dim=0)

            # One row of query token ids (0..query_token_count-1) per image
            text_encoder_input_ids = torch.arange(
                model.text_encoder.config.query_token_count,
                device=encoder_hidden_states.device,
                dtype=torch.long
            ).unsqueeze(0).expand(encoder_hidden_states.size(0), -1)

            encoder_text_hidden_states = model.text_encoder(
                input_ids=text_encoder_input_ids,
                cache_position=None,
                attention_mask=None,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=None,
                use_cache=False
            ).hidden_states
            del encoder_hidden_states

            if settings.RECOGNITION_STATIC_CACHE:
                # Pad inputs to max batch size for static cache
                encoder_text_hidden_states = pad_to_batch_size(encoder_text_hidden_states, batch_size)
                batch_decoder_input = pad_to_batch_size(batch_decoder_input, batch_size)

            while token_count < settings.RECOGNITION_MAX_TOKENS - 1:
                is_prefill = token_count == 0
                # TODO: add attention mask
                return_dict = model.decoder(
                    input_ids=batch_decoder_input,
                    encoder_hidden_states=encoder_text_hidden_states,
                    cache_position=decoder_position_ids,
                    use_cache=True,
                    prefill=is_prefill
                )

                decoder_position_ids = decoder_position_ids[-1:] + 1
                logits = return_dict["logits"][:current_batch_size]  # Ignore batch padding

                # Greedy decoding: most likely token and its probability
                last_logits = logits[:, -1]
                preds = torch.argmax(last_logits, dim=-1)
                scores = torch.max(F.softmax(last_logits, dim=-1), dim=-1).values.unsqueeze(1)
                done = (preds == eos_id) | (preds == pad_id)
                all_done = all_done | done

                if is_prefill:
                    sequence_scores = scores
                else:
                    # The mask must be a (batch, 1) column like `scores`. A 1-D
                    # mask would broadcast along the last dimension instead of
                    # lining up with the rows, so finished sequences would not
                    # be zeroed individually.
                    scores = scores.masked_fill(all_done.unsqueeze(1), 0)
                    sequence_scores = torch.cat([sequence_scores, scores], dim=1)

                if all_done.all():
                    break

                batch_decoder_input = preds.unsqueeze(1)

                # Convert once to Python values instead of iterating device tensors
                for j, (pred, finished) in enumerate(zip(preds.tolist(), all_done.tolist())):
                    if not finished:
                        batch_predictions[j].append(pred)

                token_count += inference_token_count
                inference_token_count = batch_decoder_input.shape[-1]
                max_position_id = torch.max(decoder_position_ids).item()
                decoder_position_ids = torch.ones_like(batch_decoder_input[0, :], dtype=torch.int64, device=model.device).cumsum(0) - 1 + max_position_id

                if settings.RECOGNITION_STATIC_CACHE:
                    batch_decoder_input = pad_to_batch_size(batch_decoder_input, batch_size)

        # Mean over the non-zero entries only; zeros mark steps after a sequence finished
        sequence_scores = torch.sum(sequence_scores, dim=-1) / torch.sum(sequence_scores != 0, dim=-1)
        detected_text = processor.tokenizer.batch_decode(batch_predictions)
        detected_text = [truncate_repetitions(dt) for dt in detected_text]

        # Postprocess to fix LaTeX output (add $$ signs, etc)
        detected_text = [fix_math(text) if math and contains_math(text) else text for text, math in zip(detected_text, has_math)]
        output_text.extend(detected_text)
        confidences.extend(sequence_scores.tolist())

        del encoder_text_hidden_states

    # Undo the width sort so results line up with the caller's input order
    output_text = sorted(zip(indices, output_text), key=lambda x: x[0])
    confidences = sorted(zip(indices, confidences), key=lambda x: x[0])
    output_text = [text for _, text in output_text]
    confidences = [conf for _, conf in confidences]
    return output_text, confidences