# NN Inference
The NN inference benchmark is an image classification task, using a relatively simple neural network.
The input is an encrypted image, and the application must classify that image into one of ten categories. The network is pre-trained and the weights etc. are available in plaintext (i.e. the FHE task is inference only).
We use the MNIST dataset, where the categories are the handwritten digits 0-9 and the images are 28x28 pixels. We also considered using CIFAR-10, where the classes are objects like dogs, cars or planes and the images have a resolution of 32x32 pixels. However, we were unable to reproduce results from previous works that trained FHE-compatible networks for CIFAR-10.
Non-polynomial functions like the ReLU activation function cannot be expressed directly in the BFV, BGV and CKKS schemes. Therefore, we will follow a similar route to the authors of CHET:
> [..] we modified the activation functions to a second-degree polynomial. The key difference with prior work is that our activation functions are f(x) = ax^2 + bx with learnable parameters a and b. During the training phase, the CNN adjusts these parameters to implement an appropriate activation function. We also replaced max-pooling with average-pooling.
When using a binary plaintext space (which is always the case for TFHE), we can easily implement complex functions with a Look-Up-Table (LUT). However, operations like multiplications and additions, which are abundant in CNNs, are very expensive since we need to emulate binary addition and multiplication circuits. As a result, it might be more suitable to explore heavily quantized neural networks that use binary weights.
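To see why, note that with weights and activations binarized to {-1, +1}, the dot products that dominate a CNN collapse to XNOR and popcount over the {0, 1} encoding, exactly the kind of cheap Boolean gates a binary plaintext space handles well. A minimal NumPy sketch of this equivalence (our own illustration, not tied to any TFHE library):

```python
import numpy as np

def binarize(v):
    # Sign binarization, as used in binarized neural networks (BNNs).
    return np.where(v >= 0, 1, -1)

rng = np.random.default_rng(0)
x = binarize(rng.standard_normal(8))    # activations in {-1, +1}
w = binarize(rng.standard_normal(8))    # weights in {-1, +1}

ref = int(x @ w)                        # the usual dot product

# Equivalent XNOR + popcount on the {0, 1} encoding:
xb, wb = (x > 0).astype(int), (w > 0).astype(int)
popcount = int((1 - (xb ^ wb)).sum())   # number of matching bit positions
assert 2 * popcount - len(x) == ref
```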
In order to establish a baseline for performance, and to generate the weights used by some of the tools that do not train the model themselves, we first implemented the networks for both MNIST and CIFAR-10 in standard Keras/Tensorflow. The EVA and CHET papers reference the code that their models were inspired by. We therefore started with a version of the models that closely matched these original implementations and then made the modifications described in the EVA and CHET papers.
We replaced max pooling with average pooling and replaced the ReLU activation function with a degree-2 polynomial (ax^2 + bx), where a and b are parameters learned during training. In addition, we trained all networks using the Adam optimiser, as recommended by the authors of CHET.
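As a sketch, such a trainable activation can be expressed in Keras roughly as follows (our own illustration; the layer name and the identity-function initialisation are assumptions, not code from the EVA or CHET papers):

```python
import tensorflow as tf

class PolyActivation(tf.keras.layers.Layer):
    """Trainable degree-2 activation f(x) = a*x^2 + b*x."""

    def build(self, input_shape):
        # Starting from (a, b) = (0, 1) makes the layer an identity
        # function at initialisation (our choice, not from the papers).
        self.a = self.add_weight(name="a", shape=(), initializer="zeros")
        self.b = self.add_weight(name="b", shape=(), initializer="ones")

    def call(self, inputs):
        return self.a * tf.square(inputs) + self.b * inputs
```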
For MNIST we used a LeNet-5-like network as seen here:
```
conv [5x5x1x32, 'SAME'] + bias + ReLU   # replaced with ax^2 + bx
max_pool [2x2]                          # replaced with avg_pool
conv [5x5x32x64, 'SAME'] + bias + ReLU  # replaced with ax^2 + bx
max_pool [2x2]                          # replaced with avg_pool
flatten
fc [...] + bias + ReLU                  # replaced with ax^2 + bx
dropout 50%                             # only during training
fc [512] + bias
softmax
```
On MNIST, both the original and the modified network achieve 99.3% accuracy.
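Read as a Keras model, the modified network looks roughly like this (using the PolyActivation layer sketched above; the 512-unit width for the elided fc layer and the 10-way output layer are our assumptions):

```python
# Rough Keras rendering of the modified LeNet-5-like network above.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 5, padding="same", input_shape=(28, 28, 1)),
    PolyActivation(),                     # replaces ReLU
    tf.keras.layers.AveragePooling2D(2),  # replaces max_pool
    tf.keras.layers.Conv2D(64, 5, padding="same"),
    PolyActivation(),
    tf.keras.layers.AveragePooling2D(2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512),           # assumed width for "fc [...]"
    PolyActivation(),
    tf.keras.layers.Dropout(0.5),         # active only during training
    tf.keras.layers.Dense(10),
    tf.keras.layers.Softmax(),
])
model.compile(optimizer="adam",           # Adam, as recommended by CHET
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
```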
For CIFAR-10 we used a model based on SqueezeNet-CIFAR:
```
conv [3x3x3x64, 'SAME'] + bias + ReLU
max_pool [3x3]                        # replaced with avg_pool
fire 16/64/64
fire 16/64/64
pool x2
fire 32/128/128
fire 32/128/128
conv [1x1x256x10, 'SAME'] + bias + ReLU
average pool
softmax
```
While the original model achieves around 87% accuracy, the modified network did not learn at all (around 10% accuracy, i.e. random guessing).
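For reference, a fire N/M/K module squeezes its input with N 1x1 filters and expands it with M 1x1 and K 3x3 filters whose outputs are concatenated. A minimal Keras sketch of such a module (our own, following the SqueezeNet design, not the exact benchmark code):

```python
import tensorflow as tf

def fire(x, squeeze, expand_1x1, expand_3x3):
    # SqueezeNet fire module: a 1x1 "squeeze" conv followed by parallel
    # 1x1 and 3x3 "expand" convs whose outputs are concatenated.
    s = tf.keras.layers.Conv2D(squeeze, 1, activation="relu")(x)
    e1 = tf.keras.layers.Conv2D(expand_1x1, 1, activation="relu")(s)
    e3 = tf.keras.layers.Conv2D(expand_3x3, 3, padding="same",
                                activation="relu")(s)
    return tf.keras.layers.Concatenate()([e1, e3])

# "fire 16/64/64" from the listing above would then be:
# y = fire(x, squeeze=16, expand_1x1=64, expand_3x3=64)
```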
SEALion currently only supports fully connected (or "dense") layers. In addition, SEALion hardcodes its activation function to f(x) = x^2. Therefore, the model has to be significantly simplified to work in SEALion. MNIST is already one of SEALion's examples, which uses a simple MLP:
```
flatten
fc[30] + activation (x^2)
fc[10]
softmax
---------------------
Total params: 23,862
```
This architecture achieves around 95% accuracy.
We tried increasing the number of fully connected layers and the number of nodes per layer in order to achieve performance comparable to the LeNet-5 network:
```
flatten
fc[784] + activation (either ReLU, a*x^2+b*x or just x^2)
fc[16] + activation (either ReLU, a*x^2+b*x or just x^2)
fc[10]
softmax
---------------------
Total params: 628,174
```
However, we noticed that accuracy was at best equal to, and often worse than, that of the simpler model. This is most likely due to quantization errors and the simpler activation function. For reference, we also trained an equivalent model both with standard ReLU and with an EVA/CHET-style parameterised activation and observed that both achieved around 97% accuracy on MNIST.
Note that at these sizes, inference on the MNIST test set of 10,000 images already took around 25 minutes. Therefore, we decided to evaluate SEALion using the simple MLP architecture it includes as an example.
SEALion includes estimator.py to automatically select suitable parameters from a list of pre-defined candidates. It first estimates the necessary plaintext modulus; because this estimate is apparently overly cautious in general, it is divided by 16 before proceeding. The resulting estimate is then used to narrow down the candidates to be tested. Starting with the smallest remaining candidate, SEALion evaluates the network on a single batch of data, enforcing stricter noise budgets (as measured by SEAL::Decryptor::invariant_noise_budget) than during normal computations. This is meant to account for variations in the input data by allowing some "extra breathing room" for larger values. Since machine learning in practice only produces meaningful results on inputs that are relatively close to the training data, this approach is well suited to neural network inference. However, it would be incorrect for e.g. financial or statistical calculations.
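In pseudocode, our reading of the selection procedure is roughly the following (the helper names and the concrete margin are hypothetical, not SEALion's actual API):

```python
def select_parameters(candidates, model, batch, margin_bits=10):
    # estimate_plaintext_modulus and noise_budget_after_inference are
    # hypothetical stand-ins for SEALion's internals.
    t = estimate_plaintext_modulus(model, batch) // 16  # cautious estimate, relaxed
    viable = [c for c in candidates if c.plain_modulus >= t]
    for params in sorted(viable, key=lambda c: c.poly_modulus_degree):
        # Demand a noise margin well above the > 0 bits needed for
        # correct decryption, leaving breathing room for inputs larger
        # than those in the trial batch.
        if noise_budget_after_inference(model, batch, params) >= margin_bits:
            return params
    raise RuntimeError("no candidate parameters satisfy the noise margin")
```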
nGraph-HE includes several networks for MNIST in its examples. nGraph seems to have issues with the Keras API functions Flatten() and Reshape(), and the examples recommend using tf.reshape() instead, which is a minor inconvenience.
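For example, one would flatten a convolutional feature map like this (the shapes here are illustrative):

```python
import tensorflow as tf

x = tf.zeros([32, 7, 7, 64])            # stand-in for a conv layer's output
flat = tf.reshape(x, [-1, 7 * 7 * 64])  # instead of tf.keras.layers.Flatten()
```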
#### Plaintext
To test a network using the HE_SEAL backend on unencrypted data (for debugging only), call
```bash
python test.py --batch_size=100 \
               --backend=HE_SEAL \
               --model_file=models/cryptonets.pb \
               --encrypt_server_data=false
```
To test a network using the HE_SEAL backend on encrypted data, call
```bash
python test.py --batch_size=100 \
               --backend=HE_SEAL \
               --model_file=models/cryptonets.pb \
               --encrypt_server_data=true \
               --encryption_parameters=$HE_TRANSFORMER/configs/he_seal_ckks_config_N13_L8.json
```
This setting stores the secret key and public key on the same object and should only be used for debugging and for estimating the runtime and memory overhead.
To test the client-server model, in one terminal call
```bash
python test.py --backend=HE_SEAL \
               --model_file=models/cryptonets.pb \
               --enable_client=true \
               --encryption_parameters=$HE_TRANSFORMER/configs/he_seal_ckks_config_N13_L8.json
```
In another terminal (with the python environment active), call
```bash
python pyclient_mnist.py --batch_size=1024 \
                         --encrypt_data_str=encrypt
```
It appears that the batch size determines how many samples are packed into a single ciphertext: setting it to a number larger than the number of slots causes a packing error, while setting it to just shy of the available ciphertext slots works fine.
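This matches CKKS packing, where the number of slots is half the polynomial modulus degree; with the N13 configuration used above there are 4096 slots, so the batch_size=1024 from the client command fits in one ciphertext:

```python
# CKKS provides poly_modulus_degree / 2 plaintext slots per ciphertext.
poly_modulus_degree = 2**13   # the "N13" in the config file name
slots = poly_modulus_degree // 2
assert slots == 4096          # batch sizes up to 4096 pack into one ciphertext
```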
For SEAL, we skipped non-SIMD implementations since they were unlikely to have feasible performance for benchmarking.
For simplicity and direct comparability we decided to implement the simple MLP architecture we used in SEALion. Implementing the CNN-based network manually remains an open challenge.
For CKKS, the major consideration was dealing with the scale of the fixed-point representation. Because multiplying two n-bit fixed-point numbers with s-bit scale gives a 2n-bit result with 2s-bit scale, traditional fixed-point systems are very limited in multiplicative depth. Note that this is independent of whether the multiplication is ciphertext-ciphertext or ciphertext-plaintext. In FHE, the plaintext number stored in a ciphertext will quickly grow large enough to cause overflow issues that ruin the computation. CKKS offers a rescaling operation to drop down to a lower scale while essentially rounding the number to a shorter bit-length, which significantly increases the multiplicative depth that can be realized.
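A plain-Python illustration of this scale bookkeeping (our own sketch, independent of the SEAL API):

```python
s = 2**40                             # 40-bit scale, a common CKKS choice
a, b = 3.14159, 2.71828
ea, eb = round(a * s), round(b * s)   # fixed-point encodings at scale s
prod = ea * eb                        # the product sits at scale s*s (80 bits)
rescaled = prod // s                  # "rescaling" drops back to scale s
assert abs(rescaled / s - a * b) < 1e-9
```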