import os

import cv2
import torch
import albumentations as A

import config as CFG


class CLIPDataset(torch.utils.data.Dataset):
    def __init__(self, image_filenames, captions, tokenizer, transforms):
        """
        image_filenames and captions must have the same length; so, if there are
        multiple captions for each image, the image_filenames must contain repeated
        file names.
        """

        self.image_filenames = image_filenames
        self.captions = list(captions)
        # Tokenize every caption once up front; the tokenizer returns a dict of
        # lists (e.g. input_ids, attention_mask) padded/truncated to CFG.max_length.
        self.encoded_captions = tokenizer(
            list(captions), padding=True, truncation=True, max_length=CFG.max_length
        )
        self.transforms = transforms

    def __getitem__(self, idx):
        # Token ids / attention mask for this caption, indexed out of the
        # pre-encoded captions and converted to tensors.
        item = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded_captions.items()
        }

        # Load the image with OpenCV (BGR), convert to RGB, apply the
        # albumentations pipeline, then move channels first for PyTorch.
        image = cv2.imread(f"{CFG.image_path}/{self.image_filenames[idx]}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transforms(image=image)['image']
        item['image'] = torch.tensor(image).permute(2, 0, 1).float()
        item['caption'] = self.captions[idx]

        return item

    def __len__(self):
        return len(self.captions)


def get_transforms(mode="train"):
    # Both modes currently share the same pipeline (resize + normalize);
    # train-time augmentations would be added to the "train" branch.
    if mode == "train":
        return A.Compose(
            [
                A.Resize(CFG.size, CFG.size, always_apply=True),
                A.Normalize(max_pixel_value=255.0, always_apply=True),
            ]
        )
    else:
        return A.Compose(
            [
                A.Resize(CFG.size, CFG.size, always_apply=True),
                A.Normalize(max_pixel_value=255.0, always_apply=True),
            ]
        )
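

# --- Usage sketch (not part of the original module) ---------------------------
# A minimal example of wiring CLIPDataset and get_transforms into a DataLoader.
# It assumes a Hugging Face DistilBERT tokenizer and hypothetical config fields
# (CFG.captions_path, CFG.text_tokenizer, CFG.batch_size, CFG.num_workers) beyond
# the CFG.max_length, CFG.image_path, and CFG.size used above.
if __name__ == "__main__":
    import pandas as pd
    from transformers import DistilBertTokenizer

    # Hypothetical captions file with "image" and "caption" columns.
    df = pd.read_csv(f"{CFG.captions_path}/captions.csv")

    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
    dataset = CLIPDataset(
        df["image"].values,
        df["caption"].values,
        tokenizer=tokenizer,
        transforms=get_transforms(mode="train"),
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=True,
    )

    batch = next(iter(loader))
    # Expected keys: input_ids, attention_mask, image, caption.
    print({k: getattr(v, "shape", len(v)) for k, v in batch.items()})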