import os

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset

from augments import RandAug


# Torch dataset
class TextlineDataset(Dataset):
    """Map-style dataset pairing text-line images with their transcriptions.

    Each row of ``df`` supplies a ``file_name`` (resolved relative to
    ``root_dir``) and a ``text`` transcription. Each item is a dict with:

    * ``"pixel_values"`` -- the output of ``processor(image, return_tensors="pt")``
      with the leading batch dimension removed;
    * ``"labels"`` -- a 1-D ``torch.Tensor`` of token ids from
      ``processor.tokenizer``, padded/truncated to ``max_target_length``, with
      pad-token positions replaced by ``-100`` so the loss ignores them.
    """

    def __init__(self, root_dir, df, processor, augment=False,
                 max_target_length=128):
        """Store configuration for lazy per-item loading.

        Args:
            root_dir: Directory (str or path-like) that ``file_name`` entries
                are joined onto.
            df: pandas DataFrame with ``file_name`` and ``text`` columns. Rows
                are addressed by position, so any index is accepted.
            processor: Callable image processor exposing a ``tokenizer``
                attribute (presumably a Hugging Face processor -- confirm).
            augment: If True, each image is passed through ``RandAug`` before
                processing.
            max_target_length: Length the label sequence is padded or
                truncated to.
        """
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.augment = augment
        self.augmentator = RandAug()
        self.max_target_length = max_target_length

    def __len__(self):
        """Return the number of rows in ``df``."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, optionally augment, and encode the sample at position ``idx``.

        Args:
            idx: Integer position in ``df`` (0 <= idx < len(self)).

        Returns:
            dict with ``"pixel_values"`` and ``"labels"`` tensors.
        """
        # Positional lookup: DataLoader indices are positions 0..len-1. The
        # previous label-based ``df[col][idx]`` returned the wrong row (or
        # raised KeyError) whenever df's index was not a default RangeIndex,
        # e.g. after a train/test split without reset_index().
        file_name = self.df['file_name'].iloc[idx]
        text = self.df['text'].iloc[idx]

        # os.path.join tolerates a root_dir with or without a trailing
        # separator (and path-like roots); the context manager closes the file
        # handle. convert() fully loads the pixels, so the image stays usable
        # after the file is closed.
        image_path = os.path.join(self.root_dir, file_name)
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        # Add image augmentations
        if self.augment:
            image = self.augmentator(image)

        # Extract pixel values; the processor returns a batch of one, so only
        # dim 0 is removed (a bare squeeze() would also drop any other
        # size-1 dimension, such as a single channel).
        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.squeeze(0)

        # Encode the transcription; str() guards against non-string cells.
        tokenizer = self.processor.tokenizer
        labels = tokenizer(str(text),
                           padding="max_length",
                           truncation=True,
                           max_length=self.max_target_length).input_ids

        # important: make sure that PAD tokens are ignored by the loss function
        pad_id = tokenizer.pad_token_id
        labels = [label if label != pad_id else -100 for label in labels]

        encoding = {"pixel_values": pixel_values,
                    "labels": torch.tensor(labels)}
        return encoding