""" Contains PyTorch model code to instantiate a TinyVGG model. """ import torch from torch import nn class TrashClassificationCNNModel(nn.Module): def __init__(self, input_shape: int, hidden_units: int, output_shape: int): super().__init__() self.block_1 = nn.Sequential( nn.Conv2d(input_shape, hidden_units, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.Conv2d(hidden_units, hidden_units, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2) ) self.block_2 = nn.Sequential( nn.Conv2d(hidden_units, hidden_units, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.Conv2d(hidden_units, hidden_units, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2) ) self.classifier = nn.Sequential( nn.Flatten(), nn.Linear(in_features=hidden_units*28*28, out_features=output_shape) ) def forward(self, x): x = self.block_1(x) x = self.block_2(x) x = self.classifier(x) return x