
NN Inference


The NN inference benchmark is an image classification task, using a relatively simple neural network.

Task

The input is an encrypted image, and the application must classify it into one of ten categories. The network is pre-trained, and its weights and other parameters are available in plaintext (i.e., the FHE task is inference only).

We use the MNIST dataset, where the categories are the handwritten digits 0-9 and the images are 28x28 pixels. We also considered CIFAR-10, where the classes are objects such as dogs, cars, or planes and the images have a resolution of 32x32 pixels. However, we were unable to reproduce results from previous works that trained FHE-compatible networks for CIFAR-10.
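For reference, the MNIST data can be loaded directly through Keras. The following is a minimal sketch, assuming a standard TensorFlow 2.x installation:

import tensorflow as tf

# Load MNIST, scale pixel values to [0, 1], and add the channel dimension (28x28x1).
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype("float32")[..., None] / 255.0
x_test = x_test.astype("float32")[..., None] / 255.0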

FHE Considerations

Non-Polynomial Functions

Non-polynomial functions like the ReLU activation function cannot be expressed directly in the BFV, BGV and CKKS schemes. Therefore, we will follow a similar route to the authors of CHET:

[..] we modified the activation functions to a second-degree polynomial. The key difference with prior work is that our activation functions are f(x)=ax^2+bx with learnable parameters a and b. During the training phase, the CNN adjusts these parameters to implement an appropriate activation function. We also replaced max-pooling with average-pooling.
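Such a learnable activation can be expressed as a small custom Keras layer. The sketch below is only an illustration (the name PolyActivation is ours, not taken from the CHET code) and assumes TensorFlow 2.x:

import tensorflow as tf

class PolyActivation(tf.keras.layers.Layer):
    # Degree-2 activation f(x) = a*x^2 + b*x with learnable scalars a and b.
    def build(self, input_shape):
        self.a = self.add_weight(name="a", shape=(), initializer="zeros", trainable=True)
        self.b = self.add_weight(name="b", shape=(), initializer="ones", trainable=True)

    def call(self, x):
        return self.a * tf.square(x) + self.b * x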

Binary Networks

When using a binary plaintext space (which is always the case for TFHE), we can easily implement complex functions with a Look-Up-Table (LUT). However, operations like multiplications and additions, which are abundant in CNNs, are very expensive since we need to emulate binary addition and multiplication circuits. As a result, it might be more suitable to explore heavily quantized neural networks that use binary weights.
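As a rough illustration of the idea (independent of any particular TFHE library), binarizing the weights turns each multiplication into a keep-or-negate decision, which is much cheaper to express over a binary plaintext space than a full integer multiplication circuit:

import numpy as np

def binarize(w):
    # Map real-valued weights to {-1, +1} (sign binarization, zero mapped to +1).
    return np.where(w >= 0, 1, -1).astype(np.int8)

# Multiplying by a binary weight only keeps or negates the input value.
w = np.random.randn(4, 4)
x = np.random.randn(4)
y = binarize(w) @ x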

Implementations

Plaintext Reference Training

In order to establish a baseline for performance, and to generate the weights used by some of the tools that do not train the model themselves, we first implemented the networks for both MNIST and CIFAR-10 in standard Keras/TensorFlow. The EVA and CHET papers reference the code that their models were inspired by. We therefore started from a version of the models that closely matched these original implementations and then made the modifications described in the EVA and CHET papers.

We replaced max pooling with average pooling and replaced the activation function (ReLU) with a degree-2 polynomial (ax^2 + bx), where a and b are parameters learned during training. In addition, we trained all networks using the Adam optimiser, as recommended by the authors of CHET.

MNIST

For MNIST we used a LeNet-5-like network as seen here:

conv [5x5x1x32,'SAME']  + bias + ReLU # replaced with ax^2 + bx
max_pool [2x2] # replaced with avg_pool
conv [5,5,32,64,'SAME'] + bias + ReLU # replaced with ax^2 + bx
max_pool [2x2] # replaced with avg_pool
flatten
fc [...] + bias + ReLU # replaced with ax^2 + bx
dropout 50% # only during training
fc [512] + bias
softmax

On MNIST, both the original and the modified network achieve 99.3% accuracy.
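Putting the pieces together, the modified network corresponds roughly to the following Keras model. This is a sketch rather than the exact benchmark code: it reuses the PolyActivation layer sketched in the FHE Considerations section above, the width of the first dense layer (elided in the listing) is a placeholder, and the final dense layer is given ten units to match the ten MNIST classes:

import tensorflow as tf

FC_WIDTH = 512  # placeholder: the hidden fc width is elided in the listing above

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 5, padding="same", input_shape=(28, 28, 1)),
    PolyActivation(),                     # replaces ReLU
    tf.keras.layers.AveragePooling2D(2),  # replaces max pooling
    tf.keras.layers.Conv2D(64, 5, padding="same"),
    PolyActivation(),
    tf.keras.layers.AveragePooling2D(2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(FC_WIDTH),
    PolyActivation(),
    tf.keras.layers.Dropout(0.5),         # only active during training
    tf.keras.layers.Dense(10),            # one output per MNIST class
    tf.keras.layers.Softmax(),
])

model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])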

CIFAR-10

For CIFAR-10 we used a model based on SqueezeNet-CIFAR:

conv [3x3x3x64, 'SAME'] + bias + ReLU
max_pool [3x3] # replaced with avg_pool
fire 16/64/64
fire 16/64/64
pool x2
fire 32/128/128
fire 32/128/128
conv [1x1x256x10, 'SAME'] + bias + ReLU
average pool
softmax
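The fire s/e1/e3 entries above are SqueezeNet fire modules: a 1x1 "squeeze" convolution followed by parallel 1x1 and 3x3 "expand" convolutions whose outputs are concatenated. A rough Keras sketch of such a module (reading fire 16/64/64 as squeeze=16, expand=64/64; in the FHE-compatible variant the ReLU would again be replaced by the polynomial activation):

import tensorflow as tf

def fire_module(x, squeeze, expand_1x1, expand_3x3):
    # e.g. fire 16/64/64 -> squeeze=16, expand_1x1=64, expand_3x3=64
    s = tf.keras.layers.Conv2D(squeeze, 1, padding="same", activation="relu")(x)
    e1 = tf.keras.layers.Conv2D(expand_1x1, 1, padding="same", activation="relu")(s)
    e3 = tf.keras.layers.Conv2D(expand_3x3, 3, padding="same", activation="relu")(s)
    return tf.keras.layers.Concatenate()([e1, e3])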

SEALion

SEALion currently only supports fully connected (or "dense") layers. In addition, SEALion hardcodes its activation function to f(x) = x^2. The model therefore has to be simplified significantly to work in SEALion. MNIST is already one of SEALion's included examples, which uses a simple MLP:

flatten
fc[30] + activation (x^2)
fc[10]
softmax
---------------------
Total params: 23,862

This architecture achieves around 95% accuracy.
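For comparison, the same architecture expressed in plain Keras (a sketch of the model structure only, not of SEALion's own API) looks as follows:

import tensorflow as tf

# Sketch of the SEALion example MLP: one hidden dense layer with a squared activation.
mlp = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
    tf.keras.layers.Dense(30, activation=tf.square),  # SEALion hardcodes f(x) = x^2
    tf.keras.layers.Dense(10),
    tf.keras.layers.Softmax(),
])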

We tried increasing the number of fully connected layers and the number of nodes per layer in order to achieve performance comparable to the LeNet-5-like network:

flatten
fc[784] + activation (either ReLU, a*x^2+b*x or just x^2)
fc[16] + activation (either ReLU, a*x^2+b*x or just x^2)
fc[10]
softmax
---------------------
Total params: 628,174

However, we noticed that the accuracy was at best equal to, if not worse than, that of the simpler model. This is most likely due to quantization errors and the simpler activation function. For reference, we also trained an equivalent model both with standard ReLU and with an EVA/CHET-style parameterised activation and observed that both achieved around 97% accuracy on MNIST.

Note that at these sizes, inference on the MNIST test set of 10,000 images already took around 25 minutes. Therefore, we decided to evaluate SEALion using the simple MLP architecture that it includes as an example.
