from typing import List
import torch
from PIL import Image
from surya.postprocessing.math.latex import fix_math, contains_math
from surya.postprocessing.text import truncate_repetitions
from surya.settings import settings
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F


def get_batch_size():
    """Return the recognition batch size.

    Uses ``settings.RECOGNITION_BATCH_SIZE`` when it is set. Otherwise the
    default depends on ``settings.TORCH_DEVICE_MODEL``: 64 for "mps",
    512 for "cuda", and 32 for anything else.
    """
    batch_size = settings.RECOGNITION_BATCH_SIZE
    if batch_size is None:
        batch_size = 32
        if settings.TORCH_DEVICE_MODEL == "mps":
            batch_size = 64  # 12GB RAM max
        if settings.TORCH_DEVICE_MODEL == "cuda":
            batch_size = 512
    return batch_size


def pad_to_batch_size(tensor, batch_size):
    """Zero-pad ``tensor`` along dim 0 so it has ``batch_size`` rows.

    A tensor whose first dimension is already ``>= batch_size`` is returned
    unchanged.
    """
    current_batch_size = tensor.shape[0]
    if current_batch_size >= batch_size:
        return tensor

    pad_size = batch_size - current_batch_size
    # F.pad takes (before, after) pairs starting from the LAST dimension, so
    # every trailing dim gets (0, 0) and only dim 0 is extended.
    padding = (0, 0) * (tensor.dim() - 1) + (0, pad_size)
    return F.pad(tensor, padding, mode='constant', value=0)


def _encode_images(model, batch_pixel_values, batch_size):
    """Run the image encoder in sub-batches, then the text encoder on the result.

    Returns the text encoder's ``hidden_states``. Intended to be called
    inside a ``torch.no_grad()`` context.
    """
    # Sub-batch size for the image encoder, derived from the full batch size.
    encoder_batch_size = batch_size // settings.RECOGNITION_ENCODER_BATCH_DIVISOR + 1
    total = batch_pixel_values.shape[0]
    chunks = [
        model.encoder(pixel_values=batch_pixel_values[z:min(z + encoder_batch_size, total)]).last_hidden_state
        for z in range(0, total, encoder_batch_size)
    ]
    encoder_hidden_states = torch.cat(chunks, dim=0)
    del chunks

    # Every row gets the same query ids 0..query_token_count-1.
    text_encoder_input_ids = torch.arange(
        model.text_encoder.config.query_token_count,
        device=encoder_hidden_states.device,
        dtype=torch.long,
    ).unsqueeze(0).expand(encoder_hidden_states.size(0), -1)

    encoder_text_hidden_states = model.text_encoder(
        input_ids=text_encoder_input_ids,
        cache_position=None,
        attention_mask=None,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=None,
        use_cache=False,
    ).hidden_states
    del encoder_hidden_states
    return encoder_text_hidden_states


def batch_recognition(images: List, languages: List[List[str] | None], model, processor, batch_size=None):
    """Recognize text in a list of images using greedy decoding.

    Args:
        images: ``PIL.Image.Image`` objects to recognize.
        languages: Per-image language list (or ``None``), aligned with ``images``.
        model: Recognition model exposing ``encoder``, ``text_encoder`` and
            ``decoder``.
        processor: Turns images/languages into model inputs; its ``tokenizer``
            is used for the pad/EOS ids and for decoding predictions.
        batch_size: Images per batch; defaults to ``get_batch_size()``.

    Returns:
        ``(texts, confidences)``, both in the same order as ``images``. A
        confidence is the mean of the per-step max softmax probabilities kept
        for that image. Scores are zeroed (and excluded from the mean) once the
        image has produced an EOS or pad token.

    Raises:
        AssertionError: if an image is not a ``PIL.Image.Image`` or the two
            input lists differ in length.
    """
    assert all(isinstance(image, Image.Image) for image in images)
    assert len(images) == len(languages)

    if len(images) == 0:
        return [], []

    if batch_size is None:
        batch_size = get_batch_size()

    # Sort images by width, so similar length ones go together. The languages
    # must be permuted identically; otherwise images receive other images'
    # language prompts.
    indices = sorted(range(len(images)), key=lambda idx: images[idx].width)
    images = [images[idx] for idx in indices]
    languages = [languages[idx] for idx in indices]

    pad_id = processor.tokenizer.pad_id
    eos_id = processor.tokenizer.eos_id

    output_text = []
    confidences = []
    for i in tqdm(range(0, len(images), batch_size), desc="Recognizing Text"):
        batch_images = [image.convert("RGB") for image in images[i:i + batch_size]]  # also copies the images
        batch_langs = languages[i:i + batch_size]
        has_math = [lang and "_math" in lang for lang in batch_langs]

        processed_batch = processor(text=[""] * len(batch_images), images=batch_images, langs=batch_langs)
        batch_pixel_values = processed_batch["pixel_values"]
        batch_langs = processed_batch["langs"]

        # Decoder prompt = start token + language tokens, left-padded with the
        # pad token to a common length so the batch stacks into one tensor.
        decoder_prompts = [[model.config.decoder_start_token_id] + lang for lang in batch_langs]
        max_input_length = max(len(tokens) for tokens in decoder_prompts)
        decoder_prompts = [[pad_id] * (max_input_length - len(tokens)) + tokens for tokens in decoder_prompts]

        current_batch_size = len(batch_pixel_values)
        batch_pixel_values = torch.tensor(np.stack(batch_pixel_values, axis=0), dtype=model.dtype, device=model.device)
        batch_decoder_input = torch.tensor(np.stack(decoder_prompts, axis=0), dtype=torch.long, device=model.device)

        token_count = 0
        inference_token_count = batch_decoder_input.shape[-1]
        batch_predictions = [[] for _ in range(current_batch_size)]

        # Positions 0..prompt_len-1 for the prefill step.
        decoder_position_ids = torch.arange(inference_token_count, dtype=torch.int64, device=model.device)
        model.decoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)
        model.text_encoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)

        sequence_scores = None
        all_done = torch.zeros(current_batch_size, dtype=torch.bool, device=model.device)

        with torch.no_grad():  # inference_mode doesn't work with torch.compile
            encoder_text_hidden_states = _encode_images(model, batch_pixel_values, batch_size)

            if settings.RECOGNITION_STATIC_CACHE:
                # Pad inputs to max batch size for static cache
                encoder_text_hidden_states = pad_to_batch_size(encoder_text_hidden_states, batch_size)
                batch_decoder_input = pad_to_batch_size(batch_decoder_input, batch_size)

            while token_count < settings.RECOGNITION_MAX_TOKENS - 1:
                is_prefill = token_count == 0
                # TODO: add attention mask
                return_dict = model.decoder(
                    input_ids=batch_decoder_input,
                    encoder_hidden_states=encoder_text_hidden_states,
                    cache_position=decoder_position_ids,
                    use_cache=True,
                    prefill=is_prefill,
                )
                # After prefill one token per row is fed per step, at the
                # position right after the last one processed.
                decoder_position_ids = decoder_position_ids[-1:] + 1

                logits = return_dict["logits"][:current_batch_size]  # Ignore batch padding
                last_logits = logits[:, -1]
                preds = torch.argmax(last_logits, dim=-1)
                scores = torch.max(F.softmax(last_logits, dim=-1), dim=-1).values.unsqueeze(1)
                done = (preds == eos_id) | (preds == pad_id)
                all_done = all_done | done

                if is_prefill:
                    sequence_scores = scores
                else:
                    # Zero the scores of finished rows. The mask must be
                    # (batch, 1) like ``scores``; a bare (batch,) mask would
                    # broadcast to (batch, batch) and mix rows together.
                    scores = scores.masked_fill(all_done.unsqueeze(1), 0)
                    sequence_scores = torch.cat([sequence_scores, scores], dim=1)

                if all_done.all():
                    break

                # Move predictions to the host once per step, not once per row.
                for j, (pred, finished) in enumerate(zip(preds.tolist(), all_done.tolist())):
                    if not finished:
                        batch_predictions[j].append(pred)

                batch_decoder_input = preds.unsqueeze(1)
                token_count += inference_token_count
                inference_token_count = batch_decoder_input.shape[-1]

                if settings.RECOGNITION_STATIC_CACHE:
                    batch_decoder_input = pad_to_batch_size(batch_decoder_input, batch_size)

            del encoder_text_hidden_states

        # Mean over each row's non-zero (unmasked) step scores.
        sequence_scores = torch.sum(sequence_scores, dim=-1) / torch.sum(sequence_scores != 0, dim=-1)
        detected_text = processor.tokenizer.batch_decode(batch_predictions)
        detected_text = [truncate_repetitions(dt) for dt in detected_text]

        # Postprocess to fix LaTeX output (add $$ signs, etc)
        detected_text = [fix_math(text) if math and contains_math(text) else text for text, math in zip(detected_text, has_math)]
        output_text.extend(detected_text)
        confidences.extend(sequence_scores.tolist())

    # Undo the width sort so results line up with the caller's input order.
    ordered = sorted(zip(indices, output_text, confidences), key=lambda x: x[0])
    output_text = [text for _, text, _ in ordered]
    confidences = [conf for _, _, conf in ordered]
    return output_text, confidences