sklearn-digits / README.md
chainyo's picture
Create README.md
7289ec7
|
raw
history blame
2.56 kB

A basic TinyCNN PyTorch model trained on the scikit-learn digits dataset (8x8 grayscale images of handwritten digits).

"""
Credits to Zama.ai - https://github.com/zama-ai/concrete-ml/blob/main/docs/user/advanced_examples/ConvolutionalNeuralNetwork.ipynb
"""
import numpy as np

import torch
from torch import nn
from torch.nn.utils import prune

class TinyCNN(nn.Module):
    """A very small CNN to classify the sklearn digits dataset.

    The layer sizes are chosen for 8x8 single-channel inputs: the three
    convolutions reduce such an input to a 16-channel 1x1 map, which ``fc1``
    turns into ``n_classes`` logits.

    The class also supports L1-magnitude pruning of the convolution weights
    (see ``toggle_pruning``), which should help keep the accumulator bit width
    low. Pruning is enabled at construction time, ready for training.
    """

    def __init__(self, n_classes) -> None:
        """Construct the CNN with a configurable number of classes.

        Args:
            n_classes: Number of output logits produced by ``fc1``.
        """
        super().__init__()

        # With 8x8 inputs and n_classes=10 this network costs 1216 MACs:
        # conv1 (6*6*2 outputs x 9) + conv2 (2*2*3 x 18) + conv3 (16 x 12) + fc1 (16 x 10).
        self.conv1 = nn.Conv2d(1, 2, 3, stride=1, padding=0)
        self.conv2 = nn.Conv2d(2, 3, 3, stride=2, padding=0)
        self.conv3 = nn.Conv2d(3, 16, 2, stride=1, padding=0)
        self.fc1 = nn.Linear(16, n_classes)

        # Enable pruning, prepared for training
        self.toggle_pruning(True)

    def toggle_pruning(self, enable, n_active=10):
        """Enable or remove pruning on the convolution weights.

        When enabling, every convolution whose fan-in
        (in_channels * kernel_h * kernel_w) exceeds ``n_active`` gets its
        ``(fan_in - n_active) * fan_out`` smallest-magnitude weights masked to
        zero. The pruning is unstructured over the whole weight tensor, so on
        average ``n_active`` weights per output channel stay non-zero; an
        individual channel may keep more or fewer.

        The call is idempotent per layer. Enabling a layer that is already
        pruned keeps its existing mask. Disabling a layer that is not pruned
        does nothing.

        Args:
            enable: ``True`` to attach a pruning mask, ``False`` to bake the
                current mask into ``weight`` and drop the reparametrization.
            n_active: Target average number of non-zero weights per output
                channel. Only used when enabling a layer that is not already
                pruned.
        """
        for layer in (self.conv1, self.conv2, self.conv3):
            # torch.nn.utils.prune reparametrizes ``weight`` as
            # ``weight_orig * weight_mask``; ``weight_orig`` exists only while pruned.
            is_pruned = hasattr(layer, "weight_orig")

            if not enable:
                if is_pruned:
                    # Make the masked weights permanent in the ``weight`` parameter.
                    prune.remove(layer, "weight")
                continue

            if is_pruned:
                # Re-pruning would stack a second mask and remove even more weights.
                continue

            # Conv weights are (out_channels, in_channels, kh, kw): each of the
            # fan_out output channels sums fan_in products.
            fan_out = layer.weight.shape[0]
            fan_in = layer.weight.numel() // fan_out

            if fan_in > n_active:
                # Installs a forward pre-hook that multiplies weight_orig by a 0/1 mask.
                prune.l1_unstructured(layer, "weight", (fan_in - n_active) * fan_out)

    def forward(self, x):
        """Run inference on the tiny CNN, apply the decision layer on the reshaped conv output.

        Args:
            x: Input batch; the layer sizes are chosen for shape ``(batch, 1, 8, 8)``.

        Returns:
            Logits of shape ``(batch, n_classes)``.
        """
        x = self.conv1(x)
        x = torch.relu(x)
        x = self.conv2(x)
        x = torch.relu(x)
        x = self.conv3(x)
        x = torch.relu(x)
        # For 8x8 inputs conv3 outputs (batch, 16, 1, 1), so this yields (batch, 16).
        x = x.view(-1, 16)
        x = self.fc1(x)
        return x