import torch.utils.data


class DataProvider:
    """Thin wrapper around torch.utils.data.DataLoader that can be re-iterated each epoch."""

    def __init__(self, cfg, dataset, batch_size=None, shuffle=True):
        super().__init__()
        self.dataset = dataset
        # Fall back to the batch size from the config when none is given explicitly.
        if batch_size is None:
            batch_size = cfg.BATCH_SIZE
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=int(cfg.WORKERS),
            drop_last=False)

    def __len__(self):
        # Length of the underlying dataset, not the number of batches.
        return len(self.dataset)

    def __iter__(self):
        # Yield batches produced by the wrapped DataLoader.
        for i, data in enumerate(self.dataloader):
            yield data
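
A minimal usage sketch, assuming a cfg object that exposes BATCH_SIZE and WORKERS and any standard PyTorch Dataset; the SimpleNamespace config and the toy TensorDataset below are illustrative stand-ins, not part of the original code.

# Usage sketch (assumed cfg fields and toy dataset for illustration only).
import torch
from types import SimpleNamespace
from torch.utils.data import TensorDataset

cfg = SimpleNamespace(BATCH_SIZE=4, WORKERS=0)                  # assumed config fields
dataset = TensorDataset(torch.randn(16, 3), torch.arange(16))   # toy stand-in dataset

provider = DataProvider(cfg, dataset)   # batch_size defaults to cfg.BATCH_SIZE
print(len(provider))                    # 16 -- dataset size, not the number of batches
for images, labels in provider:         # __iter__ yields DataLoader batches
    print(images.shape, labels.shape)   # e.g. torch.Size([4, 3]) torch.Size([4])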