Spaces:
Running
Running
from typing import List | |
import torch | |
from PIL import Image | |
from surya.postprocessing.math.latex import fix_math, contains_math | |
from surya.postprocessing.text import truncate_repetitions | |
from surya.settings import settings | |
from tqdm import tqdm | |
import numpy as np | |
import torch.nn.functional as F | |
def get_batch_size():
    """Return the number of images to recognize per batch.

    An explicit ``settings.RECOGNITION_BATCH_SIZE`` is returned as-is.
    When it is unset (``None``), a default is chosen from
    ``settings.TORCH_DEVICE_MODEL``: 64 for "mps", 512 for "cuda",
    and 32 for any other device.
    """
    configured = settings.RECOGNITION_BATCH_SIZE
    if configured is not None:
        return configured

    per_device_default = {
        "mps": 64,  # 12GB RAM max
        "cuda": 512,
    }
    return per_device_default.get(settings.TORCH_DEVICE_MODEL, 32)
def pad_to_batch_size(tensor, batch_size):
    """Zero-pad ``tensor`` along dim 0 until it has ``batch_size`` rows.

    New zero rows are appended after the existing ones. A tensor that
    already has ``batch_size`` or more rows is returned unchanged (the
    same object, never truncated).
    """
    missing_rows = batch_size - tensor.shape[0]
    if missing_rows <= 0:
        return tensor

    # F.pad consumes (left, right) pairs starting from the LAST dimension,
    # so every trailing dim gets (0, 0) and the final pair lands on dim 0.
    untouched_dims = tensor.dim() - 1
    pad_spec = (0, 0) * untouched_dims + (0, missing_rows)
    return F.pad(tensor, pad_spec, mode='constant', value=0)
def _build_decoder_prompts(batch_langs, model, processor):
    """Return one decoder prompt per image: start token + language tokens, left-padded to equal length."""
    prompts = [[model.config.decoder_start_token_id] + lang for lang in batch_langs]
    max_input_length = max(len(tokens) for tokens in prompts)
    pad_id = processor.tokenizer.pad_id
    # Pad on the left so the last column holds each row's final prompt token;
    # the first generated token is read from logits[:, -1].
    return [[pad_id] * (max_input_length - len(tokens)) + tokens for tokens in prompts]


def _encode_images(model, batch_pixel_values, batch_size):
    """Encode a batch of images and return the text encoder's hidden states.

    The image encoder runs in sub-batches of
    ``batch_size // settings.RECOGNITION_ENCODER_BATCH_DIVISOR + 1`` rows; their
    outputs are joined along dim 0 and fed to ``model.text_encoder``.
    Callers are expected to invoke this under ``torch.no_grad()``.
    """
    encoder_batch_size = batch_size // settings.RECOGNITION_ENCODER_BATCH_DIVISOR + 1
    chunks = [
        model.encoder(pixel_values=batch_pixel_values[start:start + encoder_batch_size]).last_hidden_state
        for start in range(0, batch_pixel_values.shape[0], encoder_batch_size)
    ]
    # Concatenate once rather than re-copying a growing tensor per chunk, and
    # skip the copy entirely when there is only a single chunk.
    encoder_hidden_states = chunks[0] if len(chunks) == 1 else torch.cat(chunks, dim=0)
    del chunks

    text_encoder_input_ids = torch.arange(
        model.text_encoder.config.query_token_count,
        device=encoder_hidden_states.device,
        dtype=torch.long,
    ).unsqueeze(0).expand(encoder_hidden_states.size(0), -1)

    return model.text_encoder(
        input_ids=text_encoder_input_ids,
        cache_position=None,
        attention_mask=None,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=None,
        use_cache=False,
    ).hidden_states


def batch_recognition(images: List, languages: List[List[str] | None], model, processor, batch_size=None):
    """Recognize text in a list of images with greedy autoregressive decoding.

    Args:
        images: PIL images to recognize.
        languages: one entry per image (a list of language codes, or None),
            passed to ``processor`` as ``langs``; an entry containing
            "_math" enables LaTeX post-processing for that image.
        model: recognition model exposing ``encoder``, ``text_encoder`` and
            ``decoder`` submodules plus ``config``, ``dtype`` and ``device``.
        processor: callable returning ``pixel_values``/``langs``, whose
            ``tokenizer`` supplies pad/eos ids and ``batch_decode``.
        batch_size: images per decoding batch; ``None`` means ``get_batch_size()``.

    Returns:
        ``(texts, confidences)``, both aligned with the input ``images`` order.
        A confidence is the mean, over decoding steps, of the highest softmax
        probability for that image; steps from the one where the image emits
        an EOS/pad token onward are excluded (the first step always counts).

    Raises:
        AssertionError: if an element of ``images`` is not a PIL image or the
            two lists differ in length (asserts are stripped under ``-O``).
    """
    assert all(isinstance(image, Image.Image) for image in images)
    assert len(images) == len(languages)

    if len(images) == 0:
        return [], []

    if batch_size is None:
        batch_size = get_batch_size()

    # Sort images by width so similar-length lines share a batch. Languages are
    # permuted with the SAME order so every image keeps its own language list
    # (previously only the images were reordered).
    indices = sorted(range(len(images)), key=lambda idx: images[idx].width)
    images = [images[idx] for idx in indices]
    languages = [languages[idx] for idx in indices]

    output_text = []
    confidences = []
    for i in tqdm(range(0, len(images), batch_size), desc="Recognizing Text"):
        batch_images = [image.convert("RGB") for image in images[i:i + batch_size]]  # also copies the images
        batch_langs = languages[i:i + batch_size]
        has_math = [lang and "_math" in lang for lang in batch_langs]

        processed_batch = processor(text=[""] * len(batch_images), images=batch_images, langs=batch_langs)
        batch_pixel_values = processed_batch["pixel_values"]
        batch_decoder_input = _build_decoder_prompts(processed_batch["langs"], model, processor)
        current_batch_size = len(batch_pixel_values)

        batch_pixel_values = torch.tensor(np.stack(batch_pixel_values, axis=0), dtype=model.dtype, device=model.device)
        batch_decoder_input = torch.tensor(np.stack(batch_decoder_input, axis=0), dtype=torch.long, device=model.device)

        token_count = 0
        inference_token_count = batch_decoder_input.shape[-1]
        batch_predictions = [[] for _ in range(current_batch_size)]

        decoder_position_ids = torch.ones_like(batch_decoder_input[0, :], dtype=torch.int64, device=model.device).cumsum(0) - 1
        model.decoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)
        model.text_encoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)

        sequence_scores = None
        all_done = torch.zeros(current_batch_size, dtype=torch.bool, device=model.device)

        with torch.no_grad():  # inference_mode doesn't work with torch.compile
            encoder_text_hidden_states = _encode_images(model, batch_pixel_values, batch_size)

            if settings.RECOGNITION_STATIC_CACHE:
                # Pad inputs to max batch size for static cache
                encoder_text_hidden_states = pad_to_batch_size(encoder_text_hidden_states, batch_size)
                batch_decoder_input = pad_to_batch_size(batch_decoder_input, batch_size)

            while token_count < settings.RECOGNITION_MAX_TOKENS - 1:
                is_prefill = token_count == 0
                # TODO: add attention mask
                return_dict = model.decoder(
                    input_ids=batch_decoder_input,
                    encoder_hidden_states=encoder_text_hidden_states,
                    cache_position=decoder_position_ids,
                    use_cache=True,
                    prefill=is_prefill,
                )

                # After the prefill every step feeds exactly one token per row,
                # so the next cache position is simply the last one plus one.
                decoder_position_ids = decoder_position_ids[-1:] + 1
                logits = return_dict["logits"][:current_batch_size]  # Ignore batch padding

                preds = torch.argmax(logits[:, -1], dim=-1)
                scores = torch.max(F.softmax(logits[:, -1], dim=-1), dim=-1).values.unsqueeze(1)
                done = (preds == processor.tokenizer.eos_id) | (preds == processor.tokenizer.pad_id)
                all_done = all_done | done

                if is_prefill:
                    sequence_scores = scores
                else:
                    # The mask must be (B, 1) like `scores`; a bare (B,) mask
                    # broadcasts to (B, B) and mixes every row with every other
                    # row's done flag.
                    scores = scores.masked_fill(all_done.unsqueeze(1), 0)
                    sequence_scores = torch.cat([sequence_scores, scores], dim=1)

                if all_done.all():
                    break

                batch_decoder_input = preds.unsqueeze(1)

                # One host transfer per step instead of one device sync per element.
                for j, (pred, finished) in enumerate(zip(preds.tolist(), all_done.tolist())):
                    if not finished:
                        batch_predictions[j].append(pred)

                token_count += inference_token_count
                inference_token_count = batch_decoder_input.shape[-1]

                if settings.RECOGNITION_STATIC_CACHE:
                    batch_decoder_input = pad_to_batch_size(batch_decoder_input, batch_size)

        if sequence_scores is None:
            # RECOGNITION_MAX_TOKENS <= 1: the decode loop never ran.
            batch_confidences = [0.0] * current_batch_size
        else:
            # Masked steps were zeroed, so average only the non-zero entries.
            batch_confidences = (
                torch.sum(sequence_scores, dim=-1) / torch.sum(sequence_scores != 0, dim=-1)
            ).tolist()

        detected_text = processor.tokenizer.batch_decode(batch_predictions)
        detected_text = [truncate_repetitions(dt) for dt in detected_text]

        # Postprocess to fix LaTeX output (add $$ signs, etc)
        detected_text = [fix_math(text) if math and contains_math(text) else text for text, math in zip(detected_text, has_math)]
        output_text.extend(detected_text)
        confidences.extend(batch_confidences)

        del encoder_text_hidden_states

    # Undo the width sort: sorted position k holds the image originally at indices[k].
    ordered_text = [None] * len(indices)
    ordered_confidences = [None] * len(indices)
    for sorted_pos, original_idx in enumerate(indices):
        ordered_text[original_idx] = output_text[sorted_pos]
        ordered_confidences[original_idx] = confidences[sorted_pos]
    return ordered_text, ordered_confidences