|
A basic TinyCNN PyTorch model for classifying the scikit-learn digits dataset. |
|
|
|
```python |
|
""" |
|
Credits to Zama.ai - https://github.com/zama-ai/concrete-ml/blob/main/docs/user/advanced_examples/ConvolutionalNeuralNetwork.ipynb |
|
""" |
|
import numpy as np |
|
|
|
import torch |
|
from torch import nn |
|
from torch.nn.utils import prune |
|
|
|
class TinyCNN(nn.Module):
    """A very small CNN to classify the sklearn digits dataset.

    The convolution layers can be pruned so that each layer keeps at most
    ``n_active * out_channels`` non-zero weights (i.e. ``n_active`` inputs per
    neuron on average), which should help keep the accumulator bit width low.
    """

    def __init__(self, n_classes, n_active=10) -> None:
        """Construct the CNN with a configurable number of classes.

        Args:
            n_classes: Number of output logits produced by ``fc1``.
            n_active: Target maximum number of non-zero weights per
                convolution neuron when pruning is enabled (default 10).
        """
        super().__init__()

        # Pruning budget read by toggle_pruning; must be set before the
        # toggle_pruning(True) call at the end of this constructor.
        self.n_active = n_active

        # For a 1x8x8 input and 10 classes this network has a total complexity
        # of 1216 MAC: conv1 648 + conv2 216 + conv3 192 + fc1 160.
        self.conv1 = nn.Conv2d(1, 2, 3, stride=1, padding=0)
        self.conv2 = nn.Conv2d(2, 3, 3, stride=2, padding=0)
        self.conv3 = nn.Conv2d(3, 16, 2, stride=1, padding=0)
        self.fc1 = nn.Linear(16, n_classes)

        # Enable pruning, prepared for training
        self.toggle_pruning(True)

    def toggle_pruning(self, enable):
        """Enable or remove pruning on the convolution layers.

        The call is idempotent: enabling pruning on an already-pruned layer,
        or disabling it on a layer that is not pruned, does nothing.

        Args:
            enable: If True, attach an L1-unstructured pruning mask to every
                convolution whose fan-in exceeds the pruning budget. If False,
                fold the current mask into ``weight`` permanently and drop the
                pruning reparametrization.
        """
        # Fall back to the historical default for instances that predate the
        # n_active attribute (e.g. objects restored from an old pickle).
        n_active = getattr(self, "n_active", 10)

        for layer in (self.conv1, self.conv2, self.conv3):
            weight_shape = layer.weight.shape
            # fan-out: number of neurons (output channels) in the layer;
            # fan-in: inputs per neuron = in_channels x kernel_h x kernel_w.
            fan_out = weight_shape[0]
            fan_in = weight_shape[1:].numel()

            if fan_in <= n_active:
                # Already within budget: this layer is never pruned.
                continue

            pruned = prune.is_pruned(layer)
            if enable and not pruned:
                # Adds a forward pre-hook that multiplies the weights by a 0/1
                # mask. l1_unstructured ranks weights across the whole tensor,
                # so this caps the layer's total non-zero weights at
                # n_active * fan_out; a single neuron may still keep more than
                # n_active inputs.
                prune.l1_unstructured(layer, "weight", amount=(fan_in - n_active) * fan_out)
            elif not enable and pruned:
                # Multiply the mask into the weights and store the result back
                # in the ``weight`` parameter.
                prune.remove(layer, "weight")

    def forward(self, x):
        """Run inference on the tiny CNN, apply the decision layer on the reshaped conv output.

        Assumes batched (N, 1, H, W) input for which conv3 produces a 1x1
        spatial map (e.g. 8x8 digits images), giving (N, 16) features for fc1.
        """
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        # flatten(1) keeps the batch dimension intact; the previous
        # view(-1, 16) silently folded leftover spatial positions into the
        # batch when the input was larger than expected.
        x = x.flatten(1)
        return self.fc1(x)
|
``` |
|
|