---
tags:
- PyTorch
- CNN
datasets:
- sklearn-digits
---
A basic TinyCNN PyTorch model trained on the scikit-learn digits dataset.
```python
"""
Credits to Zama.ai - https://github.com/zama-ai/concrete-ml/blob/main/docs/user/advanced_examples/ConvolutionalNeuralNetwork.ipynb
"""
import numpy as np
import torch
from torch import nn
from torch.nn.utils import prune
class TinyCNN(nn.Module):
    """A very small CNN to classify the sklearn digits dataset.

    The model supports pruning the convolution layers down to a configurable
    number of active inputs per neuron (default 10), which should help keep
    the accumulator bit width low.
    """

    def __init__(self, n_classes) -> None:
        """Construct the CNN with a configurable number of classes.

        Args:
            n_classes: Number of output classes (10 for sklearn digits).
        """
        super().__init__()
        # This network has a total complexity of 1216 MAC.
        # For the 8x8 digits inputs, spatial sizes are: 8x8 -> 6x6 -> 2x2 -> 1x1.
        self.conv1 = nn.Conv2d(1, 2, 3, stride=1, padding=0)
        self.conv2 = nn.Conv2d(2, 3, 3, stride=2, padding=0)
        self.conv3 = nn.Conv2d(3, 16, 2, stride=1, padding=0)
        self.fc1 = nn.Linear(16, n_classes)
        # Enable pruning, prepared for training
        self.toggle_pruning(True)

    def toggle_pruning(self, enable, n_active=10):
        """Enable or remove pruning on the convolution layers.

        Args:
            enable: If True, attach an L1-unstructured pruning mask so that
                on average at most ``n_active`` inputs per neuron stay
                non-zero. If False, fold any existing mask into the weights
                and remove the pruning re-parametrization.
            n_active: Maximum number of active weights (weight != 0) per
                neuron. Defaults to 10, preserving the original behavior.
        """
        for layer in (self.conv1, self.conv2, self.conv3):
            shape = layer.weight.shape
            # fan_out: number of neurons in the layer (out_channels);
            # fan_in: inputs per neuron = in_channels * kernel_h * kernel_w.
            fan_out = shape[0]
            fan_in = int(np.prod(shape[1:]))
            if fan_in <= n_active:
                # Layer is already within the budget; nothing to prune.
                continue
            if enable:
                # Registers a forward pre-hook that multiplies the weights
                # with a 0/1 mask. NOTE: l1_unstructured zeroes the
                # (fan_in - n_active) * fan_out smallest-magnitude weights
                # layer-wide, so the per-neuron budget holds on average.
                prune.l1_unstructured(
                    layer, "weight", (fan_in - n_active) * fan_out
                )
            else:
                # Bake the mask into the weight tensor and drop the hook.
                prune.remove(layer, "weight")

    def forward(self, x):
        """Run inference; apply the decision layer on the flattened conv output."""
        x = self.conv1(x)
        x = torch.relu(x)
        x = self.conv2(x)
        x = torch.relu(x)
        x = self.conv3(x)
        x = torch.relu(x)
        # conv3 output is (N, 16, 1, 1); flatten to (N, 16) for the linear layer.
        x = x.view(-1, 16)
        x = self.fc1(x)
        return x
```
|