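"""Batched OCR text recognition built on surya's recognition model and settings."""
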
from typing import List
import torch
from PIL import Image
from surya.postprocessing.math.latex import fix_math, contains_math
from surya.postprocessing.text import truncate_repetitions
from surya.settings import settings
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F


def get_batch_size():
    batch_size = settings.RECOGNITION_BATCH_SIZE
    if batch_size is None:
        batch_size = 32
        if settings.TORCH_DEVICE_MODEL == "mps":
            batch_size = 64  # 12GB RAM max
        if settings.TORCH_DEVICE_MODEL == "cuda":
            batch_size = 512
    return batch_size
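

# Pads the batch (first) dimension with zeros so tensors keep a fixed batch size,
# which the static-cache path further down relies on.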
def pad_to_batch_size(tensor, batch_size):
    current_batch_size = tensor.shape[0]
    if current_batch_size >= batch_size:
        return tensor

    pad_size = batch_size - current_batch_size
    padding = (0, 0) * (tensor.dim() - 1) + (0, pad_size)

    return F.pad(tensor, padding, mode='constant', value=0)
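

# Recognize text for a list of PIL images. `languages` holds one list of language codes
# (or None) per image; returns the decoded strings and a confidence per image, in input order.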
def batch_recognition(images: List, languages: List[List[str] | None], model, processor, batch_size=None):
    assert all([isinstance(image, Image.Image) for image in images])
    assert len(images) == len(languages)

    if len(images) == 0:
        return [], []

    if batch_size is None:
        batch_size = get_batch_size()

    # Sort images by width, so similar length ones go together
    sorted_pairs = sorted(enumerate(images), key=lambda x: x[1].width, reverse=False)
    indices, images = zip(*sorted_pairs)
    indices = list(indices)
    images = list(images)

    output_text = []
    confidences = []

    for i in tqdm(range(0, len(images), batch_size), desc="Recognizing Text"):
        batch_images = images[i:i+batch_size]
        batch_images = [image.convert("RGB") for image in batch_images]  # also copies the images
        batch_langs = languages[i:i+batch_size]
        has_math = [lang and "_math" in lang for lang in batch_langs]

        processed_batch = processor(text=[""] * len(batch_images), images=batch_images, langs=batch_langs)

        batch_pixel_values = processed_batch["pixel_values"]
        batch_langs = processed_batch["langs"]
        batch_decoder_input = [[model.config.decoder_start_token_id] + lang for lang in batch_langs]
        max_input_length = max([len(tokens) for tokens in batch_decoder_input])

        # Pad decoder input to max length if needed, to ensure we can convert to a tensor
        for token_idx in range(len(batch_decoder_input)):
            lang_len = len(batch_decoder_input[token_idx])
            if lang_len < max_input_length:
                batch_decoder_input[token_idx] = [processor.tokenizer.pad_id] * (max_input_length - lang_len) + batch_decoder_input[token_idx]

        current_batch_size = len(batch_pixel_values)
        batch_pixel_values = torch.tensor(np.stack(batch_pixel_values, axis=0), dtype=model.dtype, device=model.device)
        batch_decoder_input = torch.tensor(np.stack(batch_decoder_input, axis=0), dtype=torch.long, device=model.device)
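
        # Per-batch decoding state: token counters, per-item prediction buffers, and
        # KV caches allocated once for the full (padded) batch size.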
        token_count = 0
        inference_token_count = batch_decoder_input.shape[-1]
        batch_predictions = [[] for _ in range(current_batch_size)]

        decoder_position_ids = torch.ones_like(batch_decoder_input[0, :], dtype=torch.int64, device=model.device).cumsum(0) - 1
        model.decoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)
        model.text_encoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)

        sequence_scores = None
        all_done = torch.zeros(current_batch_size, dtype=torch.bool, device=model.device)
        encoder_hidden_states = None
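
        # All model calls below run without gradients; the vision encoder is run in
        # smaller sub-batches to limit peak memory, then its outputs are concatenated.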
        with torch.no_grad():  # inference_mode doesn't work with torch.compile
            encoder_batch_size = batch_size // settings.RECOGNITION_ENCODER_BATCH_DIVISOR + 1
            for z in range(0, batch_pixel_values.shape[0], encoder_batch_size):
                encoder_pixel_values = batch_pixel_values[z:min(z + encoder_batch_size, batch_pixel_values.shape[0])]
                encoder_hidden_states_batch = model.encoder(pixel_values=encoder_pixel_values).last_hidden_state
                if encoder_hidden_states is None:
                    encoder_hidden_states = encoder_hidden_states_batch
                else:
                    encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_batch], dim=0)
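
            # Run a fixed set of query tokens over the image features; the resulting
            # hidden states are what the decoder cross-attends to.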
            text_encoder_input_ids = torch.arange(
                model.text_encoder.config.query_token_count,
                device=encoder_hidden_states.device,
                dtype=torch.long
            ).unsqueeze(0).expand(encoder_hidden_states.size(0), -1)

            encoder_text_hidden_states = model.text_encoder(
                input_ids=text_encoder_input_ids,
                cache_position=None,
                attention_mask=None,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=None,
                use_cache=False
            ).hidden_states
            del encoder_hidden_states

            if settings.RECOGNITION_STATIC_CACHE:
                # Pad inputs to max batch size for static cache
                encoder_text_hidden_states = pad_to_batch_size(encoder_text_hidden_states, batch_size)
                batch_decoder_input = pad_to_batch_size(batch_decoder_input, batch_size)
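
            # Greedy autoregressive decoding: emit one token per item per step and stop once
            # every item has produced an EOS/pad token or the token budget is exhausted.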
            while token_count < settings.RECOGNITION_MAX_TOKENS - 1:
                is_prefill = token_count == 0
                # TODO: add attention mask
                return_dict = model.decoder(
                    input_ids=batch_decoder_input,
                    encoder_hidden_states=encoder_text_hidden_states,
                    cache_position=decoder_position_ids,
                    use_cache=True,
                    prefill=is_prefill
                )

                decoder_position_ids = decoder_position_ids[-1:] + 1
                logits = return_dict["logits"][:current_batch_size]  # Ignore batch padding
                aux_logits = return_dict.get("aux_logits", None)

                preds = torch.argmax(logits[:, -1], dim=-1)
                scores = torch.max(F.softmax(logits[:, -1], dim=-1), dim=-1).values.unsqueeze(1)
                done = (preds == processor.tokenizer.eos_id) | (preds == processor.tokenizer.pad_id)
                all_done = all_done | done

                if is_prefill:
                    sequence_scores = scores
                else:
                    scores = scores.masked_fill(all_done, 0)
                    sequence_scores = torch.cat([sequence_scores, scores], dim=1)

                if all_done.all():
                    break

                batch_decoder_input = preds.unsqueeze(1)

                for j, (pred, status) in enumerate(zip(preds, all_done)):
                    if not status:
                        batch_predictions[j].append(int(pred))

                token_count += inference_token_count
                inference_token_count = batch_decoder_input.shape[-1]
                max_position_id = torch.max(decoder_position_ids).item()
                decoder_position_ids = torch.ones_like(batch_decoder_input[0, :], dtype=torch.int64, device=model.device).cumsum(0) - 1 + max_position_id

                if settings.RECOGNITION_STATIC_CACHE:
                    batch_decoder_input = pad_to_batch_size(batch_decoder_input, batch_size)
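
        # Average per-token confidence; positions zeroed out after an item finished are excluded.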
        sequence_scores = torch.sum(sequence_scores, dim=-1) / torch.sum(sequence_scores != 0, dim=-1)
        detected_text = processor.tokenizer.batch_decode(batch_predictions)
        detected_text = [truncate_repetitions(dt) for dt in detected_text]

        # Postprocess to fix LaTeX output (add $$ signs, etc)
        detected_text = [fix_math(text) if math and contains_math(text) else text for text, math in zip(detected_text, has_math)]
        output_text.extend(detected_text)
        confidences.extend(sequence_scores.tolist())

        del encoder_text_hidden_states
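
    # Undo the width sort so results line up with the caller's input order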
    output_text = sorted(zip(indices, output_text), key=lambda x: x[0])
    confidences = sorted(zip(indices, confidences), key=lambda x: x[0])
    output_text = [text for _, text in output_text]
    confidences = [conf for _, conf in confidences]

    return output_text, confidences
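

# Example usage (illustrative sketch only; assumes surya's recognition load_model /
# load_processor helpers, which are not defined in this file):
#
#   from surya.model.recognition.model import load_model
#   from surya.model.recognition.processor import load_processor
#
#   model = load_model()
#   processor = load_processor()
#   texts, confidences = batch_recognition(images, [["en"]] * len(images), model, processor)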