metadata
tags:
- PyTorch
- CNN
datasets:
- sklearn-digits
Basic TinyCNN PyTorch model trained on the scikit-learn Digits dataset.
"""
Credits to Zama.ai - https://github.com/zama-ai/concrete-ml/blob/main/docs/user/advanced_examples/ConvolutionalNeuralNetwork.ipynb
"""
import numpy as np
import torch
from torch import nn
from torch.nn.utils import prune
class TinyCNN(nn.Module):
    """A very small CNN to classify the sklearn digits dataset.

    The network is three convolutions followed by one linear decision layer.
    The convolution weights can be pruned so that, on average, at most 10
    weights per output neuron stay non-zero, which should help keep the
    accumulator bit width low.

    Pruning is enabled by the constructor (ready for training) and can be
    made permanent or re-applied with :meth:`toggle_pruning`.
    """

    def __init__(self, n_classes: int) -> None:
        """Construct the CNN with a configurable number of classes.

        Args:
            n_classes: Number of output logits produced by the final linear layer.
        """
        super().__init__()

        # For 8x8 single-channel inputs and 10 classes this network has a
        # total complexity of 1216 MAC (648 + 216 + 192 + 160).
        self.conv1 = nn.Conv2d(1, 2, 3, stride=1, padding=0)
        self.conv2 = nn.Conv2d(2, 3, 3, stride=2, padding=0)
        self.conv3 = nn.Conv2d(3, 16, 2, stride=1, padding=0)
        self.fc1 = nn.Linear(16, n_classes)

        # Enable pruning, prepared for training
        self.toggle_pruning(True)

    def toggle_pruning(self, enable: bool) -> None:
        """Enable or remove pruning on the convolution layers.

        The call is idempotent: enabling pruning on a layer that is already
        pruned does not prune it a second time, and disabling pruning on a
        layer that is not pruned is a no-op instead of raising ``ValueError``.

        Args:
            enable: ``True`` attaches an L1 unstructured pruning mask to every
                convolution whose fan-in exceeds the active-neuron budget.
                ``False`` makes the pruning permanent: the mask is folded into
                the weights and the pruning re-parametrization is removed.
        """
        # Maximum number of active neurons (i.e. corresponding weight != 0)
        n_active = 10

        # Go through all the convolution layers
        for layer in (self.conv1, self.conv2, self.conv3):
            # Fan-out is the number of neurons (output channels) in the layer;
            # fan-in is the number of inputs to one neuron, i.e. the product
            # of kernel width x height x inChannels.
            fan_out = layer.weight.shape[0]
            fan_in = layer.weight[0].numel()

            # Layers that already fit in the budget are never pruned.
            if fan_in <= n_active:
                continue

            # Pruning "weight" registers a "weight_mask" buffer on the layer,
            # so its presence tells whether pruning is currently attached.
            currently_pruned = hasattr(layer, "weight_mask")

            if enable:
                if currently_pruned:
                    # Pruning again would stack a second mask and zero out
                    # more weights than the budget asks for.
                    continue
                # This creates a forward pre-hook that multiplies the weights
                # with a 0/1 mask on every forward pass. The amount is counted
                # over the whole layer, so the budget holds on average per
                # neuron rather than strictly for each neuron.
                prune.l1_unstructured(
                    layer, "weight", amount=(fan_in - n_active) * fan_out
                )
            elif currently_pruned:
                # When disabling pruning, the mask is multiplied with the
                # weights and the result is stored in the weights member.
                prune.remove(layer, "weight")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run inference on the tiny CNN.

        The decision layer is applied on the reshaped output of the last
        convolution.

        Args:
            x: Batch of single-channel images. With 8x8 images the convolution
                stack produces exactly 16 features per sample; for larger
                images the reshape below would fold spatial positions into
                the batch dimension.

        Returns:
            The raw (un-normalized) class scores, one row per reshaped sample.
        """
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))

        # Flatten to the width expected by the decision layer.
        x = x.view(-1, self.fc1.in_features)
        return self.fc1(x)