
NN Inference

Alexander Viand edited this page Aug 29, 2020 · 49 revisions

The NN inference benchmark is an image classification task, using a relatively simple neural network.

Task

The input is an encrypted image, and the application must classify that image into one of ten categories. The network is pre-trained, and its weights and other parameters are available in plaintext (i.e. the FHE task is inference only).

We use the MNIST dataset, where the categories are the handwritten digits 0-9 and the images are 28x28 px. We also considered CIFAR-10, where the classes are objects such as dogs, cars or planes and the images have a resolution of 32x32 px. However, we were unable to reproduce the results of previous works that trained FHE-compatible networks for CIFAR-10.

FHE Considerations

Non-Polynomial Functions

Non-polynomial functions like the ReLU activation function cannot be expressed directly in the BFV, BGV and CKKS schemes. Therefore, we will follow a similar route to the authors of CHET:

[..] we modified the activation functions to a second-degree polynomial. The key difference with prior work is that our activation functions are f(x)=ax^2+bx with learnable parameters a and b. During the training phase, the CNN adjusts these parameters to implement an appropriate activation function. We also replaced max-pooling with average-pooling.
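To make the quoted activation concrete, here is a small, dependency-free sketch of f(x) = a*x^2 + b*x. As an illustration only (not CHET's procedure, where a and b are learned during training), it fits a and b to ReLU by least squares on the interval [-1, 1] to show what such a quadratic can express. The function names `poly_act` and `fit_to_relu` are hypothetical.

```python
# Sketch of the degree-2 activation f(x) = a*x^2 + b*x.
# The least-squares fit to ReLU is only an illustration; in CHET and in our
# training, a and b are learnable parameters set by the optimizer.

def poly_act(x: float, a: float, b: float) -> float:
    """Degree-2 activation without constant term."""
    return a * x * x + b * x


def fit_to_relu(bound: float = 1.0, steps: int = 201) -> tuple[float, float]:
    """Least-squares fit of a*x^2 + b*x to max(0, x) on [-bound, bound]."""
    xs = [-bound + 2 * bound * i / (steps - 1) for i in range(steps)]
    ys = [max(0.0, x) for x in xs]
    # Normal equations for the basis functions x^2 and x:
    #   a*s4 + b*s3 = t2
    #   a*s3 + b*s2 = t1
    s4 = sum(x ** 4 for x in xs)
    s3 = sum(x ** 3 for x in xs)
    s2 = sum(x ** 2 for x in xs)
    t2 = sum(x * x * y for x, y in zip(xs, ys))
    t1 = sum(x * y for x, y in zip(xs, ys))
    det = s4 * s2 - s3 * s3
    a = (t2 * s2 - s3 * t1) / det
    b = (s4 * t1 - s3 * t2) / det
    return a, b


a, b = fit_to_relu()
print(f"a = {a:.3f}, b = {b:.3f}")
print("f(0.5) =", poly_act(0.5, a, b), "vs ReLU(0.5) =", max(0.0, 0.5))
```

On a symmetric interval the fit gives b = 0.5 and a slightly above 0.6, i.e. a parabola that follows ReLU roughly; outside the fitted range the approximation diverges, which is one reason the inputs to such activations must stay bounded.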

Binary Networks

When using a binary plaintext space (which is always the case for TFHE), we can easily implement complex functions with a Look-Up-Table (LUT). However, operations like multiplications and additions, which are abundant in CNNs, are very expensive since we need to emulate binary addition and multiplication circuits. As a result, it might be more suitable to explore heavily quantized neural networks that use binary weights.
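The cost of emulating arithmetic with a binary plaintext space can be sketched as a ripple-carry adder built from Boolean gates. This runs on plaintext bits; each gate call stands in for one homomorphic gate evaluation (e.g. a bootstrapped gate in TFHE). The names `GateCounter` and `ripple_add` are hypothetical.

```python
# Sketch: integer addition emulated with Boolean gates, counting gate calls.
# Each gate here is a stand-in for one (expensive) homomorphic gate.

class GateCounter:
    def __init__(self) -> None:
        self.count = 0

    def xor(self, x: int, y: int) -> int:
        self.count += 1
        return x ^ y

    def and_(self, x: int, y: int) -> int:
        self.count += 1
        return x & y

    def or_(self, x: int, y: int) -> int:
        self.count += 1
        return x | y


def ripple_add(a: int, b: int, n_bits: int, g: GateCounter) -> int:
    """Add two n-bit unsigned integers bit by bit (result mod 2**n_bits)."""
    carry, result = 0, 0
    for i in range(n_bits):
        x, y = (a >> i) & 1, (b >> i) & 1
        t = g.xor(x, y)                                 # shared by sum and carry
        s = g.xor(t, carry)                             # sum bit
        carry = g.or_(g.and_(x, y), g.and_(t, carry))   # carry-out
        result |= s << i
    return result


g = GateCounter()
print(ripple_add(13, 7, 8, g), g.count)  # 20 40
```

At five gates per bit, a single 8-bit addition already costs 40 gate evaluations, and a schoolbook n-bit multiplier built from n such additions needs O(n²) gates, which is why the multiply-accumulate-heavy layers of a CNN are expensive in this setting.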

Implementations

Plaintext Reference Training

In order to establish a performance baseline, and to generate the weights used by those tools that do not train the model themselves, we first implemented the networks for both MNIST and CIFAR-10 in standard Keras/TensorFlow. The EVA and CHET papers reference the code their models were inspired by. We therefore started with versions of the models that closely matched these original implementations and then made the modifications described in the EVA and CHET papers.

We replaced "max pooling" with "average pooling" and replaced the activation function (ReLU) with a degree-2 polynomial (ax^2 + bx), where a and b are parameters that are learned during training. In addition, we trained all networks using an ADAM optimiser as recommended by the authors of CHET.
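Average pooling is FHE-friendly because it is linear: it needs only additions and a multiplication by a public constant, whereas max pooling needs comparisons. A minimal sketch, assuming non-overlapping 2x2 windows with stride 2 (the Keras default when the stride equals the pool size); `avg_pool_2x2` is a hypothetical name.

```python
# Sketch: 2x2 average pooling with stride 2 uses only additions and one
# multiplication by the public constant 1/4, both supported directly by
# BFV/BGV/CKKS. Max pooling would require non-polynomial comparisons.

def avg_pool_2x2(img: list[list[float]]) -> list[list[float]]:
    """Non-overlapping 2x2 average pooling; height and width must be even."""
    h, w = len(img), len(img[0])
    return [
        [
            (img[i][j] + img[i][j + 1] + img[i + 1][j] + img[i + 1][j + 1]) * 0.25
            for j in range(0, w, 2)
        ]
        for i in range(0, h, 2)
    ]


image = [[1.0, 3.0, 0.0, 4.0],
         [5.0, 7.0, 8.0, 0.0],
         [2.0, 2.0, 1.0, 1.0],
         [2.0, 2.0, 1.0, 1.0]]
print(avg_pool_2x2(image))  # [[4.0, 3.0], [2.0, 1.0]]
```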

MNIST

For MNIST we used a LeNet-5-like network as seen here:

conv [5x5x1x32, 'SAME'] + bias + ReLU   # replaced with ax^2 + bx
max_pool [2x2]                          # replaced with avg_pool
conv [5x5x32x64, 'SAME'] + bias + ReLU  # replaced with ax^2 + bx
max_pool [2x2]                          # replaced with avg_pool
flatten
fc [...] + bias + ReLU                  # replaced with ax^2 + bx
dropout 50%                             # only during training
fc [512] + bias
softmax

On MNIST, both the original and the modified network achieve 99.3% accuracy.
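The tensor shapes through the convolutional part of this network can be traced with a short sketch. It assumes 'SAME' convolutions with stride 1 and 2x2 pooling with stride 2 (neither stride is stated in the listing), and it does not count the learnable a and b of the activations. The helper names are hypothetical.

```python
# Sketch: shapes and parameter counts for the conv/pool layers above.
# Assumptions: 'SAME' convs use stride 1, 2x2 pooling uses stride 2,
# and the a, b parameters of each ax^2 + bx activation are not counted.

def conv_same(shape, k, out_ch):
    """k x k convolution with 'SAME' padding and stride 1, plus bias."""
    h, w, in_ch = shape
    return (h, w, out_ch), k * k * in_ch * out_ch + out_ch


def pool_2x2(shape):
    """2x2 pooling with stride 2 (average or max: same shape, no params)."""
    h, w, c = shape
    return (h // 2, w // 2, c), 0


layers = [
    ("conv1", lambda s: conv_same(s, 5, 32)),
    ("pool1", pool_2x2),
    ("conv2", lambda s: conv_same(s, 5, 64)),
    ("pool2", pool_2x2),
]

shape, total = (28, 28, 1), 0
for name, layer in layers:
    shape, params = layer(shape)
    total += params
    print(f"{name}: {shape}, {params} params")

flat = shape[0] * shape[1] * shape[2]
print(f"flatten: {flat} features; conv params: {total}")
```

Under these assumptions the flattened vector feeding the first fully connected layer has 7*7*64 = 3136 entries.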

CIFAR-10

For CIFAR-10 we used a model based on SqueezeNet-CIFAR:

conv [3x3x3x64, 'SAME'] + bias + ReLU
max_pool [3x3] # replaced with average pool
fire 16/64/64
fire 16/64/64
pool x2
fire 32/128/128
fire 32/128/128
conv [1x1x256x10, 'SAME'] + bias + ReLU
average pool
softmax

While the original model achieves around 87% accuracy, the modified network did not learn at all (around 10% accuracy, i.e. random guessing).
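The channel bookkeeping of the fire modules can be checked with a small sketch. It assumes "fire s/e1/e3" means s squeeze filters (1x1), e1 expand filters (1x1) and e3 expand filters (3x3), with both expand branches concatenated along the channel axis and a bias on every convolution, as in the SqueezeNet paper. The helper `fire` is hypothetical.

```python
# Sketch: output channels and parameter count of SqueezeNet fire modules.
# Assumption: "fire s/e1/e3" = squeeze 1x1 (s), expand 1x1 (e1), expand 3x3 (e3),
# expand outputs concatenated; every convolution has a bias.

def fire(in_ch: int, squeeze: int, e1: int, e3: int) -> tuple[int, int]:
    """Return (output channels, parameter count) of one fire module."""
    p_squeeze = 1 * 1 * in_ch * squeeze + squeeze
    p_e1 = 1 * 1 * squeeze * e1 + e1
    p_e3 = 3 * 3 * squeeze * e3 + e3
    return e1 + e3, p_squeeze + p_e1 + p_e3


ch, total = 64, 0  # channels after the initial 3x3x3x64 convolution
for s, e1, e3 in [(16, 64, 64), (16, 64, 64), (32, 128, 128), (32, 128, 128)]:
    ch, params = fire(ch, s, e1, e3)
    total += params
    print(f"fire {s}/{e1}/{e3}: {ch} output channels, {params} params")

print("channels into the final 1x1 conv:", ch)
```

The last fire module outputs 128 + 128 = 256 channels, which matches the input depth of the final conv [1x1x256x10] layer.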

SEALion

SEALion currently only supports fully connected (or "dense") layers. In addition, SEALion hardcodes its activation function to f(x) = x^2. Therefore, the model has to be significantly simplified to work in SEALion. MNIST is already one of SEALion's examples, and it uses a simple MLP:

flatten
fc[30] + activation (x^2)
fc[10]
softmax
---------------------
Total params: 23,862

This architecture achieves around 95% accuracy.
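For reference, the plaintext forward pass of this MLP is only a few lines. The weights below are random placeholders (hypothetical, not trained values), and softmax is omitted because it does not change which class has the largest score.

```python
# Plaintext sketch of the example MLP: flatten -> fc[30] -> x^2 -> fc[10].
# Weights are random placeholders, not a trained model.
import random


def dense(x, weights, bias):
    """y[j] = sum_i x[i] * weights[i][j] + bias[j]."""
    return [sum(xi * w[j] for xi, w in zip(x, weights)) + bias[j]
            for j in range(len(bias))]


def mlp_forward(image, w1, b1, w2, b2):
    x = [p for row in image for p in row]    # flatten 28x28 -> 784
    h = [v * v for v in dense(x, w1, b1)]    # fc[30] + x^2 activation
    logits = dense(h, w2, b2)                # fc[10]
    return max(range(len(logits)), key=logits.__getitem__)  # argmax


rng = random.Random(0)
w1 = [[rng.gauss(0, 0.05) for _ in range(30)] for _ in range(784)]
b1 = [0.0] * 30
w2 = [[rng.gauss(0, 0.1) for _ in range(10)] for _ in range(30)]
b2 = [0.0] * 10
image = [[rng.random() for _ in range(28)] for _ in range(28)]
print("predicted class:", mlp_forward(image, w1, b1, w2, b2))
```

The x^2 activation is the only non-linear step, so the whole network is a degree-2 polynomial in the input pixels, which is what makes it evaluable in SEALion.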

We tried to increase the number of fully connected layers and number of nodes per layer in order to achieve comparable performance to the LeNet-5 network:

flatten
fc[784] + activation (either ReLU, a*x^2+b*x or just x^2)
fc[16] + activation (either ReLU, a*x^2+b*x or just x^2)
fc[10]
softmax
---------------------
Total params: 628,174

However, we noticed that accuracy was at best equal to, if not worse than, that of the simpler model. This is most likely due to quantization errors and the simpler activation function. For reference, we also trained an equivalent model both with standard ReLU and with EVA/CHET-style parameterised activation, and observed that both achieved around 97% accuracy on MNIST.

Note that at these sizes, inference on the MNIST test set of 10,000 images already took around 25 minutes. Therefore, we decided to evaluate SEALion using the simple MLP architecture that it includes as an example.

Parameters

ParamsCandidate degree=4096 primes=[18210817] cost=4096 after the first batch
ParamsCandidate degree=4096 primes=[27353089] cost=4096 after training completes
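Assuming the `primes` listed above are the BFV plaintext moduli SEALion selected, they can be checked against the standard requirement for SIMD batching in BFV: a prime t with t ≡ 1 (mod 2N), where N is the polynomial degree. This sketch checks both conditions; primality uses plain trial division, which is fast enough for numbers of this size. The helper names are hypothetical.

```python
# Sketch: check the BFV batching condition t = 1 (mod 2N) and primality.
# Assumption: the listed "primes" are plaintext moduli and "degree" is N.

def is_prime(n: int) -> bool:
    """Trial division; adequate for moduli of roughly 25 bits."""
    if n < 2:
        return False
    if n % 2 == 0:
        return n == 2
    d = 3
    while d * d <= n:
        if n % d == 0:
            return False
        d += 2
    return True


def batching_friendly(t: int, degree: int) -> bool:
    return is_prime(t) and t % (2 * degree) == 1


for t in (18210817, 27353089):
    print(t, "mod 2N =", t % (2 * 4096), "| prime:", is_prime(t))
```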

nGraph-HE

nGraph-HE includes several networks for MNIST in its examples. We additionally implemented a simple MLP following the same architecture as the one used in SEALion, for an apples-to-apples comparison. nGraph seems to have issues with the Keras layers Flatten() and Reshape(), and the examples recommend using tf.reshape() instead, which is a minor inconvenience.

MLP

The architecture is as described above. With either standard squaring or the learned ax^2+bx activation, the model achieves around 95-96% accuracy and can be run using the N13_L7 config:

{
    "scheme_name": "HE_SEAL",
    "poly_modulus_degree": 8192,
    "security_level": 128,
    "coeff_modulus": [
        30,
        24,
        24,
        24,
        24,
        24,
        30
    ],
    "complex_packing": false,
    "scale": 16777216
}
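A config like this can be sanity-checked with a short script. It assumes SEAL's maximum total coefficient-modulus sizes for 128-bit security (218 bits for N = 8192), that CKKS offers N/2 slots without complex packing, and that the first and last primes are reserved (the last being the special prime), so each inner prime allows one rescale. The function `check_config` is hypothetical.

```python
# Sketch: sanity-check an nGraph-HE CKKS config.
# Assumptions: 128-bit limits are SEAL's default max coeff-modulus bit counts;
# each inner prime permits one rescale; slots = N/2 without complex packing.
import json

MAX_BITS_128 = {4096: 109, 8192: 218, 16384: 438, 32768: 881}


def check_config(text: str) -> dict:
    cfg = json.loads(text)
    n, primes = cfg["poly_modulus_degree"], cfg["coeff_modulus"]
    total = sum(primes)
    return {
        "total_bits": total,
        "within_128bit_limit": total <= MAX_BITS_128[n],
        "slots": n // 2,
        "rescales": len(primes) - 2,
        "scale_bits": cfg["scale"].bit_length() - 1,
    }


n13_l7 = """{"scheme_name": "HE_SEAL", "poly_modulus_degree": 8192,
 "security_level": 128, "coeff_modulus": [30, 24, 24, 24, 24, 24, 30],
 "complex_packing": false, "scale": 16777216}"""
print(check_config(n13_l7))
```

The 24-bit inner primes match the scale of 2^24 = 16777216, so each rescale returns a product to roughly the original scale. Applied to the N13_L8 config used for CryptoNets, the same check gives 204 bits and six rescales.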

Cryptonets

See the CryptoNets paper for the architecture. We use the "squashed" version from nGraph-HE:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input (InputLayer)           [(None, 28, 28, 1)]       0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 14, 14, 5)         130       
_________________________________________________________________
activation (Activation)      (None, 14, 14, 5)         0         
_________________________________________________________________
average_pooling2d (AveragePo (None, 14, 14, 5)         0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 7, 7, 50)          6300      
_________________________________________________________________
average_pooling2d_1 (Average (None, 7, 7, 50)          0         
_________________________________________________________________
flatten (Flatten)            (None, 2450)              0         
_________________________________________________________________
fc_1 (Dense)                 (None, 100)               245100    
_________________________________________________________________
activation_1 (Activation)    (None, 100)               0         
_________________________________________________________________
fc_2 (Dense)                 (None, 10)                1010      
=================================================================
Total params: 252,540
Trainable params: 252,540
Non-trainable params: 0
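The parameter counts in this summary can be recomputed from the layer shapes. The 5x5 kernel size of both convolutions is inferred from the counts, and every layer is assumed to have a bias; the helper names are hypothetical.

```python
# Sketch: recompute the parameter counts of the Keras summary above.
# Assumptions: 5x5 kernels (inferred from the counts), biases on all layers.

def conv_params(k: int, in_ch: int, out_ch: int) -> int:
    return k * k * in_ch * out_ch + out_ch


def dense_params(n_in: int, n_out: int) -> int:
    return n_in * n_out + n_out


layers = {
    "conv2d_1": conv_params(5, 1, 5),
    "conv2d_2": conv_params(5, 5, 50),
    "fc_1": dense_params(7 * 7 * 50, 100),   # flatten gives 2450 features
    "fc_2": dense_params(100, 10),
}
print(layers, sum(layers.values()))
```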

This network is run using the N13_L8 config, he-transformer/configs/he_seal_ckks_config_N13_L8.json:

{
    "scheme_name": "HE_SEAL",
    "poly_modulus_degree": 8192,
    "security_level": 128,
    "coeff_modulus": [
        30,
        24,
        24,
        24,
        24,
        24,
        24,
        30
    ],
    "complex_packing": false,
    "scale": 16777216
}

General Remarks

It appears that the batch size determines how many samples are packed into a single ciphertext: setting it to a number larger than the number of slots causes a packing error, while setting it just below the number of available slots works fine. That is, unlike the optimization proposed in EVA/CHET, nGraph-HE seems to use batching (at least in the settings we were able to explore) only to improve throughput, not the latency of a single inference.
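The packing layout inferred from this behaviour (not taken from the nGraph-HE source) can be sketched with plain lists standing in for ciphertexts: slot i of every ciphertext holds sample i of the batch, so there is one ciphertext per input feature. The function `pack_batch` is hypothetical.

```python
# Sketch of the inferred layout: one slot vector ("ciphertext") per input
# feature, with slot i holding sample i. Lists stand in for ciphertexts.

def pack_batch(batch: list[list[float]], slots: int) -> list[list[float]]:
    """Return one slot vector per feature; raise if the batch exceeds the slots."""
    if len(batch) > slots:
        raise ValueError(f"batch size {len(batch)} exceeds {slots} slots")
    n_features = len(batch[0])
    return [
        [sample[f] for sample in batch] + [0.0] * (slots - len(batch))
        for f in range(n_features)
    ]


packed = pack_batch([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], slots=4)
print(packed)  # [[1.0, 4.0, 0.0, 0.0], [2.0, 5.0, 0.0, 0.0], [3.0, 6.0, 0.0, 0.0]]
```

Requesting a batch larger than the slot count raises an error here, mirroring the packing error observed in nGraph-HE.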

nGraph-HE prints warnings that the security parameters are not being enforced, but all our experiments do in fact achieve the claimed 128-bit security.

SEAL

For SEAL, we skipped non-SIMD implementations since they were unlikely to have feasible performance for benchmarking.

CKKS-based

For simplicity and direct comparability we decided to implement the simple MLP architecture we used in SEALion. Implementing the CNN-based network manually remains an open challenge.

For CKKS, the major consideration was dealing with the scale of the fixed-point representation. Because multiplying two n-bit fixed-point numbers with an s-bit scale gives a 2n-bit result with a 2s-bit scale, traditional fixed-point systems are very limited in multiplicative depth. Note that this holds regardless of whether the multiplication is ciphertext-ciphertext or ciphertext-plaintext. In FHE, the plaintext number stored in a ciphertext will quickly grow large enough to overflow and ruin the computation. CKKS offers a rescaling operation that drops back down to a lower scale, essentially rounding the number to a shorter bit-length. This significantly increases the multiplicative depth that can be realized.
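The effect of rescaling can be sketched with plain Python integers standing in for CKKS plaintexts (no encryption, no modulus). Values are encoded as round(x * 2^S); each product doubles the scale, and rescaling divides by 2^S with rounding. Real CKKS rescaling divides by a modulus prime close to 2^S rather than exactly 2^S, so this only mimics the idea. The scale of 2^24 is taken from the nGraph-HE configs; the helper names are hypothetical.

```python
# Sketch: fixed-point growth with and without rescaling, on plain integers.
# Real CKKS divides by a prime close to 2^S; here we divide by exactly 2^S.

S = 24  # scale bits; "scale": 16777216 in the nGraph-HE configs is 2^24


def encode(x: float) -> int:
    """Fixed-point encoding at scale 2^S."""
    return round(x * (1 << S))


def decode(v: int, scale_bits: int = S) -> float:
    return v / (1 << scale_bits)


def rescale(v: int) -> int:
    """Divide by 2^S, rounding half up: scale 2^(2S) -> 2^S."""
    return (v + (1 << (S - 1))) >> S


x = encode(1.5)
no_rescale, no_rescale_bits = x, S   # value and its current scale in bits
with_rescale = x
for _ in range(4):                   # compute 1.5^5 with four multiplications
    no_rescale *= x
    no_rescale_bits += S             # scale grows by S bits per product
    with_rescale = rescale(with_rescale * x)

print(decode(no_rescale, no_rescale_bits), no_rescale.bit_length())
print(decode(with_rescale), with_rescale.bit_length())
```

Both paths recover 1.5^5 = 7.59375, but without rescaling the stored integer grows to 123 bits after four products, while with rescaling it stays at 27 bits.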

Fast Matrix-Vector Products

Dense (or fully connected) layers are essentially matrix-vector products between the layer's weight matrix and the input vector. Therefore, the most important performance aspect is a fast implementation of matrix-vector products (MVPs).
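The usual starting point for SIMD matrix-vector products is the diagonal method of Halevi and Shoup for square matrices. This sketch simulates it on plaintext lists: `rotate` stands in for a ciphertext rotation and elementwise products for SIMD multiplications. The helper names are hypothetical.

```python
# Sketch: diagonal (Halevi-Shoup) matrix-vector product for a square n x n
# matrix, simulated on plaintext lists. Needs n rotations and n
# plaintext-ciphertext multiplications.

def rotate(v, k):
    """Cyclic left rotation: rotate(v, k)[j] == v[(j + k) % len(v)]."""
    k %= len(v)
    return v[k:] + v[:k]


def diagonal_mvp(W, v):
    n = len(v)
    acc = [0] * n
    for i in range(n):
        diag = [W[j][(j + i) % n] for j in range(n)]  # i-th generalized diagonal
        acc = [a + d * r for a, d, r in zip(acc, diag, rotate(v, i))]
    return acc


W = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
v = [1, 2, 3]
print(diagonal_mvp(W, v))  # [14, 32, 50]
```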

We started our implementation by recycling much of the matrix-vector multiplication logic from a PrivateAI Bootcamp project. However, that implementation supports only square matrices. We therefore used the hybrid method proposed in Gazelle. This technique works best for multiplying vectors of length n with matrices of size m x n where n = m*2^k, i.e. m divides n and the quotient is a power of two. However, MNIST images are 28x28 px, i.e. a vector of 784 entries after flattening, and the layer sizes do not fit this pattern: neither 784/30 (first layer) nor 30/10 (output layer) is a power of two. Therefore, we simply zero-pad the images to 32x32 px and increase the number of units in the first layer to 32 and in the second to 16. While this introduces some dummy classes, the client can simply ignore these after decrypting.
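A Gazelle-style hybrid product for an m x n matrix with n = m*2^k can be sketched the same way on plaintext lists: m generalized diagonals of length n, followed by log2(n/m) rotate-and-sum steps that fold the partial sums into the first m slots. This is a sketch of the technique under those assumptions, not our SEAL code; `hybrid_mvp` and `pad` are hypothetical names.

```python
# Sketch: Gazelle-style hybrid matrix-vector product for an m x n matrix
# with n = m * 2^k, simulated on plaintext lists.

def rotate(v, k):
    """Cyclic left rotation: rotate(v, k)[j] == v[(j + k) % len(v)]."""
    k %= len(v)
    return v[k:] + v[:k]


def hybrid_mvp(W, v):
    m, n = len(W), len(v)
    q = n // m
    assert n % m == 0 and (q & (q - 1)) == 0, "need n = m * 2^k"
    acc = [0] * n
    for i in range(m):  # m diagonals of length n
        diag = [W[j % m][(i + j) % n] for j in range(n)]
        acc = [a + d * r for a, d, r in zip(acc, diag, rotate(v, i))]
    shift = n // 2
    while shift >= m:   # rotate-and-sum: log2(n / m) steps
        acc = [a + b for a, b in zip(acc, rotate(acc, shift))]
        shift //= 2
    return acc[:m]      # the first m slots hold W @ v


def pad(v, size):
    """Zero-pad a vector, e.g. a flattened image from 784 to 1024 entries."""
    return v + [0] * (size - len(v))


W = [[(r * 7 + c) % 5 - 2 for c in range(16)] for r in range(4)]
v = list(range(16))
naive = [sum(w * x for w, x in zip(row, v)) for row in W]
print(hybrid_mvp(W, v) == naive, len(pad([1] * 784, 1024)))
```

With the padded sizes, the first layer is a 32 x 1024 product (quotient 32 = 2^5) and the output layer a 16 x 32 product (quotient 2), so both fit the n = m*2^k requirement.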
