NN Inference

Alexander Viand edited this page Aug 26, 2020 · 49 revisions

The NN inference benchmark is an image classification task, using a relatively simple neural network.

Task

The input is an encrypted image, and the application must classify that image into one of ten categories. The network is pre-trained, and its weights and other parameters are available in plaintext (i.e. the FHE task is inference only).

We use the MNIST dataset, where the categories are the handwritten digits 0-9 and the images are 28x28 pixels. We also considered using CIFAR-10, where the classes are objects like dogs, cars or planes and the images have a resolution of 32x32 pixels. However, we were unable to reproduce results from previous works that trained FHE-compatible networks for CIFAR-10.

FHE Considerations

Non-Polynomial Functions

Non-polynomial functions like the ReLU activation function cannot be expressed directly in the BFV, BGV and CKKS schemes. Therefore, we will follow a similar route to the authors of CHET:

[..] we modified the activation functions to a second-degree polynomial. The key difference with prior work is that our activation functions are f(x)=ax^2+bx with learnable parameters a and b. During the training phase, the CNN adjusts these parameters to implement an appropriate activation function. We also replaced max-pooling with average-pooling.
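As a minimal sketch of why these two substitutions are FHE-friendly, the plaintext Python below evaluates the polynomial activation and a 2x2 average pool using only additions and multiplications. The function names and the values chosen for a and b are hypothetical; in the actual networks a and b are learned, and the stride-2 window is an assumption.

```python
def poly_activation(x, a, b):
    """Degree-2 activation f(x) = a*x^2 + b*x (additions and multiplications only)."""
    return a * x * x + b * x


def avg_pool_2x2(image):
    """Average-pool a 2D list with a 2x2 window and stride 2 (even dimensions assumed).

    The division by 4 is written as a multiplication by the constant 0.25,
    because multiplying by a plaintext constant is available in BFV/BGV/CKKS
    whereas a comparison (as needed for max-pooling) is not.
    """
    rows, cols = len(image), len(image[0])
    return [
        [
            (image[r][c] + image[r][c + 1]
             + image[r + 1][c] + image[r + 1][c + 1]) * 0.25
            for c in range(0, cols, 2)
        ]
        for r in range(0, rows, 2)
    ]


# Hypothetical coefficients; the real values come out of training.
a, b = 0.25, 0.5
feature_map = [
    [1.0, 2.0, 3.0, 4.0],
    [5.0, 6.0, 7.0, 8.0],
    [-1.0, -2.0, -3.0, -4.0],
    [-5.0, -6.0, -7.0, -8.0],
]
activated = [[poly_activation(v, a, b) for v in row] for row in feature_map]
pooled = avg_pool_2x2(activated)
```

Each activation costs one ciphertext-ciphertext multiplication (the square), so every activation layer adds one level of multiplicative depth.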

Binary Networks

When using a binary plaintext space (which is always the case for TFHE), we can easily implement complex functions with a Look-Up-Table (LUT). However, operations like multiplications and additions, which are abundant in CNNs, are very expensive since we need to emulate binary addition and multiplication circuits. As a result, it might be more suitable to explore heavily quantized neural networks that use binary weights.
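To give a rough feel for the cost of emulating arithmetic over bits, the sketch below builds an 8-bit ripple-carry adder from two-input gates and counts how many gates it evaluates. It runs on plaintext bits; the `GateCounter` class and helper names are hypothetical, and the mapping of one gate to one homomorphic gate evaluation is a simplifying assumption.

```python
class GateCounter:
    """Evaluates two-input boolean gates on plaintext bits and counts them."""

    def __init__(self):
        self.count = 0

    def xor(self, x, y):
        self.count += 1
        return x ^ y

    def and_(self, x, y):
        self.count += 1
        return x & y

    def or_(self, x, y):
        self.count += 1
        return x | y


def full_adder(g, x, y, carry_in):
    """One-bit full adder: 2 XOR, 2 AND and 1 OR gate."""
    partial = g.xor(x, y)
    total = g.xor(partial, carry_in)
    carry_out = g.or_(g.and_(x, y), g.and_(partial, carry_in))
    return total, carry_out


def ripple_carry_add(g, xs, ys):
    """Add two little-endian bit lists of equal length; appends the final carry."""
    carry = 0
    out = []
    for x, y in zip(xs, ys):
        s, carry = full_adder(g, x, y, carry)
        out.append(s)
    return out + [carry]


def to_bits(value, width):
    return [(value >> i) & 1 for i in range(width)]


def from_bits(bits):
    return sum(bit << i for i, bit in enumerate(bits))


gates = GateCounter()
result = from_bits(ripple_carry_add(gates, to_bits(100, 8), to_bits(57, 8)))
```

Under these assumptions a single 8-bit addition already needs 40 gate evaluations, and a multiplier built from such adders needs many more, which is why networks with binary weights look attractive in this setting.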

Implementations

Plaintext Reference Training

To establish a performance baseline, and to generate the weights for tools that do not train the model themselves, we first implemented the networks for both MNIST and CIFAR-10 in standard Keras/TensorFlow. The EVA and CHET papers reference the code that inspired their models. We therefore started with versions of the models that closely matched these original implementations and then made the modifications described in the EVA and CHET papers.

We replaced max pooling with average pooling and the ReLU activation function with a degree-2 polynomial (ax^2 + bx), where a and b are parameters learned during training. In addition, we trained all networks with the Adam optimiser, as recommended by the authors of CHET.

MNIST

For MNIST we used a LeNet-5-like network as seen here:

conv [5x5x1x32, 'SAME'] + bias + ReLU # replaced with ax^2 + bx
max_pool [2x2] # replaced with avg_pool
conv [5x5x32x64, 'SAME'] + bias + ReLU # replaced with ax^2 + bx
max_pool [2x2] # replaced with avg_pool
flatten
fc [...] + bias + ReLU # replaced with ax^2 + bx
dropout 50% # only during training
fc [512] + bias
softmax

On MNIST, both the original and the modified network achieve 99.3% accuracy.
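The tensor shapes implied by the listing above can be traced with a few lines of Python. This is a sketch under two assumptions not stated in the listing: the convolutions use unit stride (so 'SAME' padding preserves height and width), and the 2x2 pooling uses a stride of 2. The helper names are hypothetical.

```python
def conv_same(shape, out_channels):
    """Unit-stride convolution with 'SAME' padding: only the channel count changes."""
    height, width, _ = shape
    return (height, width, out_channels)


def pool(shape, size):
    """Pooling with window == stride == size: height and width shrink by `size`."""
    height, width, channels = shape
    return (height // size, width // size, channels)


shape = (28, 28, 1)            # one greyscale MNIST image
shape = conv_same(shape, 32)   # (28, 28, 32)
shape = pool(shape, 2)         # (14, 14, 32)
shape = conv_same(shape, 64)   # (14, 14, 64)
shape = pool(shape, 2)         # (7, 7, 64)

flattened = shape[0] * shape[1] * shape[2]

# Weights plus biases of the two convolutions, read directly off the listing.
conv1_params = 5 * 5 * 1 * 32 + 32
conv2_params = 5 * 5 * 32 * 64 + 64
```

Under these assumptions the flattened vector has 3,136 entries, so most of the work in an encrypted evaluation of this network sits in the second convolution and the first fully connected layer.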

CIFAR-10

For CIFAR-10 we used a model based on SqueezeNet-CIFAR:

conv [3x3x3x64, 'SAME'] + bias + ReLU
max_pool [3x3] # replaced with average pool
fire 16/64/64
fire 16/64/64
pool x2
fire 32/128/128
fire 32/128/128
conv [1x1x256x10, 'SAME'] + bias + ReLU
average pool
softmax

While the original model achieves around 87% accuracy, the modified network did not learn at all (around 10% accuracy, i.e. no better than random guessing among ten classes).

SEALion

SEALion currently only supports fully connected (or "dense") layers. In addition, SEALion hardcodes its activation function to f(x) = x^2. Therefore, the model has to be simplified significantly to work in SEALion. MNIST is already one of the SEALion examples, which uses a simple MLP:

flatten
fc[30] + activation (x^2)
fc[10]
softmax
---------------------
Total params: 23,862

This architecture achieves around 95% accuracy.

We tried to increase the number of fully connected layers and number of nodes per layer in order to achieve comparable performance to the LeNet-5 network:

flatten
fc[784] + activation (either ReLU, a*x^2+b*x or just x^2)
fc[16] + activation (either ReLU, a*x^2+b*x or just x^2)
fc[10]
softmax
---------------------
Total params: 628,174

However, we noticed that the accuracy was at best equal to, if not worse than, that of the simpler model. This is most likely due to quantization errors and the simpler activation function. For reference, we also trained an equivalent model both with standard ReLU and with the EVA/CHET-style parameterised activation and observed that both achieved around 97% accuracy on MNIST.

Note that at these sizes, inference on the MNIST test set of 10,000 images already took around 25 minutes (roughly 0.15 seconds per image). Therefore, we decided to evaluate SEALion using the simple MLP architecture that it includes as an example.
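The two parameter totals quoted above can be cross-checked by counting weights and biases of the dense layers. The sketch below does this for both MLPs; `dense_params` is a hypothetical helper, and the input size of 784 follows from flattening a 28x28 image.

```python
def dense_params(n_in, n_out):
    """Parameters of a fully connected layer: one weight per connection plus one bias per output."""
    return n_in * n_out + n_out


def mlp_params(layer_sizes):
    """Sum dense-layer parameters over consecutive pairs of layer sizes."""
    return sum(dense_params(n_in, n_out)
               for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))


small_dense = mlp_params([784, 30, 10])       # simple SEALion example MLP
large_dense = mlp_params([784, 784, 16, 10])  # enlarged MLP

# Differences to the totals reported in the listings.
small_remainder = 23862 - small_dense
large_remainder = 628174 - large_dense
```

The dense layers account for 23,860 and 628,170 parameters respectively, leaving remainders of 2 and 4. These would be consistent with two extra scalars (such as a and b) per activation layer, but that reading is an inference from the numbers and is not confirmed by the listings.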

nGraph-HE

nGraph-HE includes several networks for MNIST in its examples. nGraph seems to have issues with the Keras API functions Flatten() or Reshape(), and the examples recommend using tf.reshape() instead, which is a minor inconvenience.

SEAL

For SEAL, we skipped non-SIMD implementations since they were unlikely to have feasible performance for benchmarking.

CKKS-based

For simplicity and direct comparability we decided to implement the simple MLP architecture we used in SEALion. Implementing the CNN-based network manually remains an open challenge.

For CKKS, the major consideration was managing the scale of the fixed-point representation. Because multiplying two n-bit fixed-point numbers with s-bit scale gives a 2n-bit result with 2s-bit scale, traditional fixed-point systems are very limited in multiplicative depth. Note that this holds regardless of whether the multiplication is between two ciphertexts or between a ciphertext and a plaintext. In FHE, the plaintext number stored in a ciphertext will therefore quickly grow large enough to cause overflows that ruin the computation. CKKS offers a rescaling operation that drops down to a lower scale while essentially rounding the number to a shorter bit-length. This significantly increases the multiplicative depth that can be realized.
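The effect of rescaling can be illustrated with plain integers. The sketch below is a plaintext simulation of fixed-point arithmetic, not SEAL code: values are (integer, scale-bits) pairs, `rescale` divides by a power of two with rounding, and all names and the 20-bit scale are hypothetical choices. In real CKKS, rescaling divides by a prime from the modulus chain and consumes a level, which this model ignores.

```python
SCALE_BITS = 20  # hypothetical scale of 2**20


def encode(value, scale_bits=SCALE_BITS):
    """Represent a real number as an integer with an explicit scale."""
    return round(value * (1 << scale_bits)), scale_bits


def decode(x):
    integer, scale_bits = x
    return integer / (1 << scale_bits)


def multiply(x, y):
    """Fixed-point product: integers multiply, scales (in bits) add up."""
    (xi, xs), (yi, ys) = x, y
    return xi * yi, xs + ys


def rescale(x, drop_bits):
    """Divide by 2**drop_bits with rounding, lowering the scale accordingly."""
    integer, scale_bits = x
    rounded = (integer + (1 << (drop_bits - 1))) >> drop_bits
    return rounded, scale_bits - drop_bits


def power_by_squaring(x, squarings, rescale_after_each):
    """Square x repeatedly, optionally rescaling after every multiplication."""
    for _ in range(squarings):
        x = multiply(x, x)
        if rescale_after_each:
            x = rescale(x, SCALE_BITS)
    return x


# Compute 1.5**8 with three squarings, with and without rescaling.
naive = power_by_squaring(encode(1.5), 3, rescale_after_each=False)
managed = power_by_squaring(encode(1.5), 3, rescale_after_each=True)
```

Without rescaling, the scale doubles with every squaring and reaches 160 bits after only three multiplications; with rescaling it stays at 20 bits and the integer remains short, at the cost of a small rounding error per step.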
