File size: 1,938 Bytes
c5bd7aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
"""
Contains functionality for creating PyTorch DataLoaders for 
image classification data.
"""
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

def train_test_dataloader(train_dir: str,
                          test_dir: str,
                          transform: transforms.Compose,
                          batch_size: int,
                          num_workers: int = 0):
    """Create training and testing DataLoaders from image folders.

    Wraps each directory in a torchvision ``ImageFolder`` dataset (one
    sub-directory per class) and then wraps each dataset in a PyTorch
    ``DataLoader``.

    Args:
        train_dir: Path to the training directory.
        test_dir: Path to the testing directory.
        transform: torchvision transforms applied to both the training and
            the testing images.
        batch_size: Number of samples per batch in each DataLoader.
        num_workers: Number of worker subprocesses used by each DataLoader.
            The default of 0 loads data in the main process, which is also
            the DataLoader default.

    Returns:
        A tuple of (train_dataloader, test_dataloader, class_names), where
        class_names is the list of class names that ImageFolder found in
        ``train_dir``.

    Example usage:
        train_dataloader, test_dataloader, class_names = train_test_dataloader(
            train_dir="path/to/train_dir",
            test_dir="path/to/test_dir",
            transform=some_transform,
            batch_size=32,
        )
    """
    # Use ImageFolder to create the datasets.
    dataset_train = ImageFolder(root=train_dir, transform=transform)
    dataset_test = ImageFolder(root=test_dir, transform=transform)

    # Class names are taken from the training set only.
    class_names = dataset_train.classes

    # Shuffle the training data so each epoch sees the samples in a new order.
    train_dataloader = DataLoader(dataset_train,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers)
    # Keep the test data in a fixed order: shuffling brings no benefit for
    # evaluation and makes per-batch results non-reproducible between runs.
    test_dataloader = DataLoader(dataset_test,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=num_workers)

    return train_dataloader, test_dataloader, class_names