import time

import pytorch_lightning as L
import torch
from torch.utils.data import DataLoader


class ImageDataModule(L.LightningDataModule):
    """LightningDataModule that builds its datasets lazily from builder callables."""

    def __init__(
        self,
        train_dataset,
        val_dataset,
        test_dataset,
        global_batch_size,
        num_workers,
        num_nodes=1,
        num_devices=1,
        val_proportion=0.1,
    ):
        super().__init__()
        # Zero-argument builder callables; the datasets themselves are
        # constructed lazily in `setup`, once per stage.
        self._builders = {
            "train": train_dataset,
            "val": val_dataset,
            "test": test_dataset,
        }
        self.num_workers = num_workers
        # Split the global batch evenly across all devices on all nodes.
        self.batch_size = global_batch_size // (num_nodes * num_devices)
        print(f"Each GPU will receive {self.batch_size} images")
        self.val_proportion = val_proportion

    @property
    def num_classes(self):
        # Prefer the dataset built in `setup`; otherwise build one
        # transiently just to read the class count.
        if hasattr(self, "train_dataset"):
            return self.train_dataset.num_classes
        return self._builders["train"]().num_classes

    def setup(self, stage=None):
        """Set up the datamodule.

        Args:
            stage (str): Stage of the datamodule.
                One of "fit", "test", or None (which sets up both).
        """
        print("Stage", stage)
        start_time = time.time()
        if stage == "fit" or stage is None:
            self.train_dataset = self._builders["train"]()
            self.val_dataset = self._builders["val"]()
            print(f"Train dataset size: {len(self.train_dataset)}")
            print(f"Val dataset size: {len(self.val_dataset)}")
        if stage == "test" or stage is None:
            self.test_dataset = self._builders["test"]()
            print(f"Test dataset size: {len(self.test_dataset)}")
        end_time = time.time()
        print(f"Setup took {(end_time - start_time):.2f} seconds")

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            pin_memory=False,
            drop_last=True,  # keep per-device batch sizes identical
            num_workers=self.num_workers,
            collate_fn=self.train_dataset.collate_fn_density,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            pin_memory=False,
            num_workers=self.num_workers,
            collate_fn=self.val_dataset.collate_fn,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            pin_memory=False,
            num_workers=self.num_workers,
            collate_fn=self.test_dataset.collate_fn,
        )
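
# ---------------------------------------------------------------------------
# Usage sketch: a minimal, hypothetical demonstration of the datamodule.
# `_DummyDataset` and its collate functions are assumptions standing in for
# real builders, which must be zero-argument callables returning datasets
# that expose `num_classes`, `collate_fn`, and (for training)
# `collate_fn_density`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _DummyDataset(torch.utils.data.Dataset):
        """Hypothetical stand-in exposing the interface the datamodule expects."""

        num_classes = 10

        def __len__(self):
            return 64

        def __getitem__(self, idx):
            # A constant image tensor and a cycling label, enough to batch.
            return torch.zeros(3, 32, 32), idx % self.num_classes

        @staticmethod
        def collate_fn(batch):
            images, labels = zip(*batch)
            return torch.stack(images), torch.tensor(labels)

        # Assumed to behave like the plain collate for this demonstration.
        collate_fn_density = collate_fn

    dm = ImageDataModule(
        train_dataset=_DummyDataset,  # the class itself is a zero-arg builder
        val_dataset=_DummyDataset,
        test_dataset=_DummyDataset,
        global_batch_size=32,
        num_workers=0,
        num_devices=1,
    )
    dm.setup("fit")
    images, labels = next(iter(dm.train_dataloader()))
    print(images.shape, labels.shape)  # torch.Size([32, 3, 32, 32]) torch.Size([32])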