from itertools import chain

from transformers import BatchEncoding, GitProcessor


class GIAProcessor(GitProcessor):
    """GIT processor that additionally packs text-only inputs into fixed-size blocks.

    Inputs containing images are handled by :class:`GitProcessor`. Text given
    without images is tokenized, concatenated across the whole batch and
    re-split into chunks of ``self._block_size`` tokens.
    """

    def __init__(self, image_processor, tokenizer):
        super().__init__(image_processor, tokenizer)
        # Number of tokens per packed chunk for text-only inputs.
        self._block_size = 1024

    def _group_texts(self, examples):
        """Concatenate all per-example sequences and re-split them into fixed-size blocks.

        Args:
            examples: Mapping from field name (e.g. ``"input_ids"``) to a list of
                per-example sequences, as produced by the tokenizer for a batch.
                All fields are assumed to have the same total length once
                concatenated (true for tokenizer outputs) — TODO confirm for
                custom inputs.

        Returns:
            dict mapping the same field names to lists of slices, each exactly
            ``self._block_size`` long. Trailing tokens that do not fill a whole
            block are dropped, so with fewer than ``self._block_size`` tokens
            every list is empty. An empty ``examples`` yields ``{}``.
        """
        # Concatenate all sequences of each field into one flat list.
        concatenated_examples = {
            k: list(chain.from_iterable(examples[k])) for k in examples.keys()
        }
        if not concatenated_examples:
            return {}
        total_length = len(next(iter(concatenated_examples.values())))
        # Drop the remainder that does not fill a complete block. Padding could be
        # used instead if the model supported it.
        total_length = (total_length // self._block_size) * self._block_size
        # Split each field into consecutive block_size-long chunks.
        return {
            k: [t[i: i + self._block_size] for i in range(0, total_length, self._block_size)]
            for k, t in concatenated_examples.items()
        }

    def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
        """Prepare model inputs.

        Args:
            text: A string or a list of strings.
            images: Images accepted by the image processor.
            return_tensors: Optional tensor type (e.g. ``"pt"``, ``"np"``).
            **kwargs: Forwarded to ``GitProcessor.__call__`` when images are
                given; ignored for text-only input.

        Returns:
            For text only: the packed blocks from ``_group_texts``, as a plain
            dict of lists when ``return_tensors`` is ``None``, otherwise as a
            ``BatchEncoding`` holding tensors of the requested type.
            Otherwise: whatever ``GitProcessor.__call__`` returns.
        """
        if text is not None and images is None:
            if isinstance(text, str):
                # A lone string would make input_ids a flat list of ints, which
                # cannot be concatenated as a batch of sequences.
                text = [text]
            # Tokenize to Python lists first: packing slices flat token lists and
            # must not operate on (possibly ragged) tensors.
            encoded_text = self.tokenizer(text)
            encoding = self._group_texts(encoded_text)
            if return_tensors is not None:
                encoding = BatchEncoding(encoding, tensor_type=return_tensors)
            return encoding
        # Images (with or without text), or no input at all: defer to GitProcessor,
        # which validates the inputs. Keywords avoid depending on its parameter order.
        return super().__call__(
            text=text, images=images, return_tensors=return_tensors, **kwargs
        )

    def batch_decode(self, *args, **kwargs):
        """Forward to the tokenizer's ``batch_decode``."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward to the tokenizer's ``decode``."""
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """Names of the inputs this processor produces for the model."""
        return ["input_ids", "attention_mask", "pixel_values"]


GIAProcessor.register_for_auto_class("AutoProcessor")