from typing import List
import torch
from PIL import Image

from surya.postprocessing.math.latex import fix_math, contains_math
from surya.postprocessing.text import truncate_repetitions
from surya.settings import settings
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F


def get_batch_size():
    batch_size = settings.RECOGNITION_BATCH_SIZE
    if batch_size is None:
        batch_size = 32
        if settings.TORCH_DEVICE_MODEL == "mps":
            batch_size = 64 # 12GB RAM max
        if settings.TORCH_DEVICE_MODEL == "cuda":
            batch_size = 512
    return batch_size


def pad_to_batch_size(tensor, batch_size):
    current_batch_size = tensor.shape[0]
    if current_batch_size >= batch_size:
        return tensor

    pad_size = batch_size - current_batch_size
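    # F.pad takes (left, right) pairs starting from the *last* dimension, so the
    # trailing (0, pad_size) pair below pads dim 0 (the batch dimension)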
    padding = (0, 0) * (tensor.dim() - 1) + (0, pad_size)

    return F.pad(tensor, padding, mode='constant', value=0)


def batch_recognition(images: List, languages: List[List[str] | None], model, processor, batch_size=None):
    assert all([isinstance(image, Image.Image) for image in images])
    assert len(images) == len(languages)

    if len(images) == 0:
        return [], []

    if batch_size is None:
        batch_size = get_batch_size()

    # Sort images by width so that images of similar width land in the same batch
    sorted_pairs = sorted(enumerate(images), key=lambda x: x[1].width)
    indices, images = zip(*sorted_pairs)
    indices = list(indices)
    images = list(images)

    output_text = []
    confidences = []
    for i in tqdm(range(0, len(images), batch_size), desc="Recognizing Text"):
        batch_images = images[i:i+batch_size]
        batch_images = [image.convert("RGB") for image in batch_images]  # also copies the images

        batch_langs = languages[i:i+batch_size]
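        # The special "_math" language code flags a line for math postprocessing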
        has_math = [lang and "_math" in lang for lang in batch_langs]

        processed_batch = processor(text=[""] * len(batch_images), images=batch_images, langs=batch_langs)

        batch_pixel_values = processed_batch["pixel_values"]
        batch_langs = processed_batch["langs"]
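        # The decoder is prompted with the start token followed by the language tokens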
        batch_decoder_input = [[model.config.decoder_start_token_id] + lang for lang in batch_langs]
        max_input_length = max([len(tokens) for tokens in batch_decoder_input])

        # Pad decoder input to max length if needed, to ensure we can convert to a tensor
        for token_idx in range(len(batch_decoder_input)):
            lang_len = len(batch_decoder_input[token_idx])
            if lang_len < max_input_length:
                batch_decoder_input[token_idx] = [processor.tokenizer.pad_id] * (max_input_length - lang_len) + batch_decoder_input[token_idx]

        current_batch_size = len(batch_pixel_values)

        batch_pixel_values = torch.tensor(np.stack(batch_pixel_values, axis=0), dtype=model.dtype, device=model.device)
        batch_decoder_input = torch.tensor(np.stack(batch_decoder_input, axis=0), dtype=torch.long, device=model.device)

        token_count = 0
        inference_token_count = batch_decoder_input.shape[-1]
        batch_predictions = [[] for _ in range(current_batch_size)]

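        # Prefill position ids are 0..T-1; the decoder and text-encoder KV caches
        # are sized to the full batch_size (inputs are padded up to it when the
        # static cache is enabled)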
        decoder_position_ids = torch.ones_like(batch_decoder_input[0, :], dtype=torch.int64, device=model.device).cumsum(0) - 1
        model.decoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)
        model.text_encoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)

        sequence_scores = None
        all_done = torch.zeros(current_batch_size, dtype=torch.bool, device=model.device)
        encoder_hidden_states = None

        with torch.no_grad(): # inference_mode doesn't work with torch.compile
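            # Run the image encoder in smaller chunks to keep peak memory down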
            encoder_batch_size = batch_size // settings.RECOGNITION_ENCODER_BATCH_DIVISOR + 1
            for z in range(0, batch_pixel_values.shape[0], encoder_batch_size):
                encoder_pixel_values = batch_pixel_values[z:min(z + encoder_batch_size, batch_pixel_values.shape[0])]
                encoder_hidden_states_batch = model.encoder(pixel_values=encoder_pixel_values).last_hidden_state
                if encoder_hidden_states is None:
                    encoder_hidden_states = encoder_hidden_states_batch
                else:
                    encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_batch], dim=0)

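            # The text encoder maps a fixed set of query tokens, attending over the
            # image features, into the hidden states the decoder cross-attends to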
            text_encoder_input_ids = torch.arange(
                model.text_encoder.config.query_token_count,
                device=encoder_hidden_states.device,
                dtype=torch.long
            ).unsqueeze(0).expand(encoder_hidden_states.size(0), -1)

            encoder_text_hidden_states = model.text_encoder(
                input_ids=text_encoder_input_ids,
                cache_position=None,
                attention_mask=None,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=None,
                use_cache=False
            ).hidden_states
            del encoder_hidden_states

            if settings.RECOGNITION_STATIC_CACHE:
                # Pad inputs to max batch size for static cache
                encoder_text_hidden_states = pad_to_batch_size(encoder_text_hidden_states, batch_size)
                batch_decoder_input = pad_to_batch_size(batch_decoder_input, batch_size)

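            # Greedy autoregressive decoding: the first pass prefills the cache with
            # the whole prompt, after which each pass decodes one token per sequence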
            while token_count < settings.RECOGNITION_MAX_TOKENS - 1:
                is_prefill = token_count == 0
                #TODO: add attention mask
                return_dict = model.decoder(
                    input_ids=batch_decoder_input,
                    encoder_hidden_states=encoder_text_hidden_states,
                    cache_position=decoder_position_ids,
                    use_cache=True,
                    prefill=is_prefill
                )

                decoder_position_ids = decoder_position_ids[-1:] + 1
                logits = return_dict["logits"][:current_batch_size] # Ignore batch padding
                aux_logits = return_dict.get("aux_logits", None)

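                # Greedy step: take the argmax token and keep its softmax
                # probability as a per-token confidence score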
                preds = torch.argmax(logits[:, -1], dim=-1)
                scores = torch.max(F.softmax(logits[:, -1], dim=-1), dim=-1).values.unsqueeze(1)
                done = (preds == processor.tokenizer.eos_id) | (preds == processor.tokenizer.pad_id)
                all_done = all_done | done

                if is_prefill:
                    sequence_scores = scores
                else:
                    # Stop accumulating confidence once a sequence has finished;
                    # unsqueeze so the (B,) mask lines up with the (B, 1) scores
                    scores = scores.masked_fill(all_done.unsqueeze(1), 0)
                    sequence_scores = torch.cat([sequence_scores, scores], dim=1)

                if all_done.all():
                    break

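                # From here on, the decoder input is just the last predicted token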
                batch_decoder_input = preds.unsqueeze(1)

                for j, (pred, status) in enumerate(zip(preds, all_done)):
                    if not status:
                        batch_predictions[j].append(int(pred))

                token_count += inference_token_count
                inference_token_count = batch_decoder_input.shape[-1]
                max_position_id = torch.max(decoder_position_ids).item()
                decoder_position_ids = torch.ones_like(batch_decoder_input[0, :], dtype=torch.int64, device=model.device).cumsum(0) - 1 + max_position_id

                if settings.RECOGNITION_STATIC_CACHE:
                    batch_decoder_input = pad_to_batch_size(batch_decoder_input, batch_size)

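        # Per-sequence confidence: mean of the nonzero per-token scores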
        sequence_scores = torch.sum(sequence_scores, dim=-1) / torch.sum(sequence_scores != 0, dim=-1)
        detected_text = processor.tokenizer.batch_decode(batch_predictions)
        detected_text = [truncate_repetitions(dt) for dt in detected_text]

        # Postprocess to fix LaTeX output (add $$ signs, etc)
        detected_text = [fix_math(text) if math and contains_math(text) else text for text, math in zip(detected_text, has_math)]
        output_text.extend(detected_text)
        confidences.extend(sequence_scores.tolist())

        del encoder_text_hidden_states

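    # Undo the width sort so outputs line up with the input order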
    output_text = sorted(zip(indices, output_text), key=lambda x: x[0])
    confidences = sorted(zip(indices, confidences), key=lambda x: x[0])
    output_text = [text for _, text in output_text]
    confidences = [conf for _, conf in confidences]
    return output_text, confidences