chainyo committed on
Commit 7289ec7
1 Parent(s): 54e9333

Create README.md

Files changed (1): README.md ADDED (+70 −0)
Basic TinyCNN PyTorch model trained on the scikit-learn Digits dataset.

```python
"""
Credits to Zama.ai - https://github.com/zama-ai/concrete-ml/blob/main/docs/user/advanced_examples/ConvolutionalNeuralNetwork.ipynb
"""
import numpy as np

import torch
from torch import nn
from torch.nn.utils import prune


class TinyCNN(nn.Module):
    """A very small CNN to classify the sklearn digits dataset.

    This class also allows pruning to a maximum of 10 active neurons, which
    should help keep the accumulator bit width low.
    """

    def __init__(self, n_classes) -> None:
        """Construct the CNN with a configurable number of classes."""
        super().__init__()

        # This network has a total complexity of 1216 MAC
        self.conv1 = nn.Conv2d(1, 2, 3, stride=1, padding=0)
        self.conv2 = nn.Conv2d(2, 3, 3, stride=2, padding=0)
        self.conv3 = nn.Conv2d(3, 16, 2, stride=1, padding=0)
        self.fc1 = nn.Linear(16, n_classes)

        # Enable pruning, prepared for training
        self.toggle_pruning(True)

    def toggle_pruning(self, enable):
        """Enable or remove pruning."""

        # Maximum number of active neurons (i.e. corresponding weight != 0)
        n_active = 10

        # Go through all the convolution layers
        for layer in (self.conv1, self.conv2, self.conv3):
            s = layer.weight.shape

            # Compute fan-in (number of inputs to a neuron)
            # and fan-out (number of neurons in the layer)
            st = [s[0], np.prod(s[1:])]

            # The number of input neurons (fan-in) is the product of
            # the kernel width x height x inChannels.
            if st[1] > n_active:
                if enable:
                    # This will create a forward hook to create a mask tensor that is multiplied
                    # with the weights during forward. The mask will contain 0s or 1s
                    prune.l1_unstructured(layer, "weight", (st[1] - n_active) * st[0])
                else:
                    # When disabling pruning, the mask is multiplied with the weights
                    # and the result is stored in the weights member
                    prune.remove(layer, "weight")

    def forward(self, x):
        """Run inference on the tiny CNN, apply the decision layer on the reshaped conv output."""
        x = self.conv1(x)
        x = torch.relu(x)
        x = self.conv2(x)
        x = torch.relu(x)
        x = self.conv3(x)
        x = torch.relu(x)
        x = x.view(-1, 16)
        x = self.fc1(x)
        return x
```
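
Below is a minimal usage sketch showing how a model like this could be trained and evaluated on the sklearn digits dataset. The train/test split, optimizer, learning rate, and epoch count are illustrative assumptions and are not taken from the Zama notebook linked above; the 8x8 digit images are reshaped to `(N, 1, 8, 8)` to match the single-channel input expected by `conv1`.

```python
# Usage sketch (assumptions: split, optimizer, lr, epochs are illustrative only)
import torch
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

digits = load_digits()
x_train, x_test, y_train, y_test = train_test_split(
    digits.images, digits.target, test_size=0.2, random_state=42
)

# Reshape 8x8 images to (N, 1, 8, 8) single-channel tensors
x_train = torch.tensor(x_train, dtype=torch.float32).unsqueeze(1)
x_test = torch.tensor(x_test, dtype=torch.float32).unsqueeze(1)
y_train = torch.tensor(y_train, dtype=torch.long)
y_test = torch.tensor(y_test, dtype=torch.long)

model = TinyCNN(n_classes=10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

# Full-batch training loop; the dataset is small enough (~1,800 samples)
for epoch in range(20):
    optimizer.zero_grad()
    loss = criterion(model(x_train), y_train)
    loss.backward()
    optimizer.step()

# Remove the pruning re-parametrization before saving or exporting the weights
model.toggle_pruning(False)

with torch.no_grad():
    accuracy = (model(x_test).argmax(dim=1) == y_test).float().mean().item()
print(f"Test accuracy: {accuracy:.3f}")
```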