Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use hardware accelerated Aes on x64
Browse files Browse the repository at this point in the history
benaadams committed Jan 22, 2023
1 parent 6fe2983 commit 7ebec69
Showing 4 changed files with 433 additions and 5 deletions.
422 changes: 422 additions & 0 deletions src/Nethermind/Nethermind.Crypto/AesEngineX86Intrinsic.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,422 @@
// SPDX-FileCopyrightText: 2023 Demerzel Solutions Limited
// SPDX-License-Identifier: LGPL-3.0-only
// Modified from BouncyCastle MIT

using System;
using System.Buffers.Binary;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;

using Org.BouncyCastle.Crypto;
using Org.BouncyCastle.Crypto.Parameters;

using Aes = System.Runtime.Intrinsics.X86.Aes;
using Sse2 = System.Runtime.Intrinsics.X86.Sse2;

namespace Nethermind.Crypto;

public sealed class AesEngineX86Intrinsic : IBlockCipher
{
public static bool IsSupported => Aes.IsSupported;
public bool IsPartialBlockOkay => false;
public void Reset() { }

public AesEngineX86Intrinsic()
{
if (!IsSupported)
throw new PlatformNotSupportedException(nameof(AesEngineX86Intrinsic));
}

public string AlgorithmName => "AES";

public int GetBlockSize() => 16;

private AesEncoderDecoder _implementation;

public void Init(bool forEncryption, ICipherParameters parameters)
{
if (parameters is not KeyParameter keyParameter)
{
ArgumentNullException.ThrowIfNull(parameters, nameof(parameters));
throw new ArgumentException("invalid type: " + parameters.GetType(), nameof(parameters));
}

Vector128<byte>[] roundKeys = CreateRoundKeys(keyParameter.GetKey(), forEncryption);
_implementation = AesEncoderDecoder.Init(forEncryption, roundKeys);
}

public int ProcessBlock(byte[] inBuf, int inOff, byte[] outBuf, int outOff)
{
Check.DataLength(inBuf, inOff, 16);
Check.OutputLength(outBuf, outOff, 16);

Vector128<byte> state = Unsafe.As<byte, Vector128<byte>>(ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(inBuf), inOff));

_implementation.ProcessRounds(ref state);

Unsafe.As<byte, Vector128<byte>>(ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(outBuf), outOff)) = state;

return 16;
}

private static Vector128<byte>[] CreateRoundKeys(byte[] key, bool forEncryption)
{
Vector128<byte>[] K = key.Length switch
{
16 => KeyLength16(key),
24 => KeyLength24(key),
32 => KeyLength32(key),
_ => throw new ArgumentException("Key length not 128/192/256 bits.")
};

if (!forEncryption)
{
for (int i = 1, last = K.Length - 1; i < last; ++i)
{
K[i] = Aes.InverseMixColumns(K[i]);
}

Array.Reverse(K);
}

return K;

[SkipLocalsInit]
static Vector128<byte>[] KeyLength16(byte[] key)
{
ReadOnlySpan<byte> rcon = stackalloc byte[] { 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36 };

Vector128<byte> s = MemoryMarshal.Read<Vector128<byte>>(key.AsSpan(0, 16));
Vector128<byte>[] K = new Vector128<byte>[11];
K[0] = s;

for (int round = 0; round < 10;)
{
Vector128<byte> t = Aes.KeygenAssist(s, rcon[round++]);
t = Sse2.Shuffle(t.AsInt32(), 0xFF).AsByte();
s = Sse2.Xor(s, Sse2.ShiftLeftLogical128BitLane(s, 8));
s = Sse2.Xor(s, Sse2.ShiftLeftLogical128BitLane(s, 4));
s = Sse2.Xor(s, t);
K[round] = s;
}

return K;
}

static Vector128<byte>[] KeyLength24(byte[] key)
{
Vector128<byte> s1 = MemoryMarshal.Read<Vector128<byte>>(key.AsSpan(0, 16));
Vector128<byte> s2 = MemoryMarshal.Read<Vector64<byte>>(key.AsSpan(16, 8)).ToVector128();
Vector128<byte>[] K = new Vector128<byte>[13];
K[0] = s1;

byte rcon = 0x01;
for (int round = 0; ;)
{
Vector128<byte> t1 = Aes.KeygenAssist(s2, rcon); rcon <<= 1;
t1 = Sse2.Shuffle(t1.AsInt32(), 0x55).AsByte();

s1 = Sse2.Xor(s1, Sse2.ShiftLeftLogical128BitLane(s1, 8));
s1 = Sse2.Xor(s1, Sse2.ShiftLeftLogical128BitLane(s1, 4));
s1 = Sse2.Xor(s1, t1);

K[++round] = Sse2.Xor(s2, Sse2.ShiftLeftLogical128BitLane(s1, 8));

Vector128<byte> s3 = Sse2.Xor(s2, Sse2.ShiftRightLogical128BitLane(s1, 12));
s3 = Sse2.Xor(s3, Sse2.ShiftLeftLogical128BitLane(s3, 4));

K[++round] = Sse2.Xor(
Sse2.ShiftRightLogical128BitLane(s1, 8),
Sse2.ShiftLeftLogical128BitLane(s3, 8));

Vector128<byte> t2 = Aes.KeygenAssist(s3, rcon); rcon <<= 1;
t2 = Sse2.Shuffle(t2.AsInt32(), 0x55).AsByte();

s1 = Sse2.Xor(s1, Sse2.ShiftLeftLogical128BitLane(s1, 8));
s1 = Sse2.Xor(s1, Sse2.ShiftLeftLogical128BitLane(s1, 4));
s1 = Sse2.Xor(s1, t2);

K[++round] = s1;

if (round == 12)
break;

s2 = Sse2.Xor(s3, Sse2.ShiftRightLogical128BitLane(s1, 12));
s2 = Sse2.Xor(s2, Sse2.ShiftLeftLogical128BitLane(s2, 4));
s2 = s2.WithUpper(Vector64<byte>.Zero);
}

return K;
}

static Vector128<byte>[] KeyLength32(byte[] key)
{
Vector128<byte> s1 = MemoryMarshal.Read<Vector128<byte>>(key.AsSpan(0, 16));
Vector128<byte> s2 = MemoryMarshal.Read<Vector128<byte>>(key.AsSpan(16, 16));
Vector128<byte>[] K = new Vector128<byte>[15];
K[0] = s1;
K[1] = s2;

byte rcon = 0x01;
for (int round = 1; ;)
{
Vector128<byte> t1 = Aes.KeygenAssist(s2, rcon); rcon <<= 1;
t1 = Sse2.Shuffle(t1.AsInt32(), 0xFF).AsByte();
s1 = Sse2.Xor(s1, Sse2.ShiftLeftLogical128BitLane(s1, 8));
s1 = Sse2.Xor(s1, Sse2.ShiftLeftLogical128BitLane(s1, 4));
s1 = Sse2.Xor(s1, t1);
K[++round] = s1;

if (round == 14)
break;

Vector128<byte> t2 = Aes.KeygenAssist(s1, 0x00);
t2 = Sse2.Shuffle(t2.AsInt32(), 0xAA).AsByte();
s2 = Sse2.Xor(s2, Sse2.ShiftLeftLogical128BitLane(s2, 8));
s2 = Sse2.Xor(s2, Sse2.ShiftLeftLogical128BitLane(s2, 4));
s2 = Sse2.Xor(s2, t2);
K[++round] = s2;
}

return K;
}
}

private abstract class AesEncoderDecoder
{
protected readonly Vector128<byte>[] _roundKeys;

public AesEncoderDecoder(Vector128<byte>[] roundKeys)
{
_roundKeys = roundKeys;
}

public static AesEncoderDecoder Init(bool forEncryption, Vector128<byte>[] roundKeys)
{
if (roundKeys.Length == 11)
{
return forEncryption ? new Encode128(roundKeys) : new Decode128(roundKeys);
}
else if (roundKeys.Length == 13)
{
return forEncryption ? new Encode192(roundKeys) : new Decode192(roundKeys);
}
else
{
return forEncryption ? new Encode256(roundKeys) : new Decode256(roundKeys);
}
}

public abstract void ProcessRounds(ref Vector128<byte> state);

private sealed class Encode128 : AesEncoderDecoder
{
public Encode128(Vector128<byte>[] roundKeys) : base(roundKeys) { }

public override void ProcessRounds(ref Vector128<byte> state)
{
// Take local refence to array so Jit can reason length doesn't change in method
Vector128<byte>[] roundKeys = _roundKeys;
{
// Get the Jit to bounds check once rather than each increasing array access
Vector128<byte> temp = roundKeys[10];
}

// Operate on non-ref local so it remains in register rather than operating on memory
Vector128<byte> state2 = Sse2.Xor(state, roundKeys[0]);
state2 = Aes.Encrypt(state2, roundKeys[1]);
state2 = Aes.Encrypt(state2, roundKeys[2]);
state2 = Aes.Encrypt(state2, roundKeys[3]);
state2 = Aes.Encrypt(state2, roundKeys[4]);
state2 = Aes.Encrypt(state2, roundKeys[5]);
state2 = Aes.Encrypt(state2, roundKeys[6]);
state2 = Aes.Encrypt(state2, roundKeys[7]);
state2 = Aes.Encrypt(state2, roundKeys[8]);
state2 = Aes.Encrypt(state2, roundKeys[9]);
state2 = Aes.EncryptLast(state2, roundKeys[10]);
// Copy back to ref
state = state2;
}
}

private sealed class Decode128 : AesEncoderDecoder
{
public Decode128(Vector128<byte>[] roundKeys) : base(roundKeys) { }

public override void ProcessRounds(ref Vector128<byte> state)
{
// Take local refence to array so Jit can reason length doesn't change in method
Vector128<byte>[] roundKeys = _roundKeys;
{
// Get the Jit to bounds check once rather than each increasing array access
Vector128<byte> temp = roundKeys[10];
}

// Operate on non-ref local so it remains in register rather than operating on memory
Vector128<byte> state2 = Sse2.Xor(state, roundKeys[0]);
state2 = Aes.Decrypt(state2, roundKeys[1]);
state2 = Aes.Decrypt(state2, roundKeys[2]);
state2 = Aes.Decrypt(state2, roundKeys[3]);
state2 = Aes.Decrypt(state2, roundKeys[4]);
state2 = Aes.Decrypt(state2, roundKeys[5]);
state2 = Aes.Decrypt(state2, roundKeys[6]);
state2 = Aes.Decrypt(state2, roundKeys[7]);
state2 = Aes.Decrypt(state2, roundKeys[8]);
state2 = Aes.Decrypt(state2, roundKeys[9]);
state2 = Aes.DecryptLast(state2, roundKeys[10]);
// Copy back to ref
state = state2;
}
}

private sealed class Encode192 : AesEncoderDecoder
{
public Encode192(Vector128<byte>[] roundKeys) : base(roundKeys) { }

public override void ProcessRounds(ref Vector128<byte> state)
{
// Take local refence to array so Jit can reason length doesn't change in method
Vector128<byte>[] roundKeys = _roundKeys;
{
// Get the Jit to bounds check once rather than each increasing array access
Vector128<byte> temp = roundKeys[12];
}

// Operate on non-ref local so it remains in register rather than operating on memory
Vector128<byte> state2 = Sse2.Xor(state, roundKeys[0]);
state2 = Aes.Encrypt(state2, roundKeys[1]);
state2 = Aes.Encrypt(state2, roundKeys[2]);
state2 = Aes.Encrypt(state2, roundKeys[3]);
state2 = Aes.Encrypt(state2, roundKeys[4]);
state2 = Aes.Encrypt(state2, roundKeys[5]);
state2 = Aes.Encrypt(state2, roundKeys[6]);
state2 = Aes.Encrypt(state2, roundKeys[7]);
state2 = Aes.Encrypt(state2, roundKeys[8]);
state2 = Aes.Encrypt(state2, roundKeys[9]);
state2 = Aes.Encrypt(state2, roundKeys[10]);
state2 = Aes.Encrypt(state2, roundKeys[11]);
state2 = Aes.EncryptLast(state2, roundKeys[12]);
// Copy back to ref
state = state2;
}
}

private sealed class Decode192 : AesEncoderDecoder
{
public Decode192(Vector128<byte>[] roundKeys) : base(roundKeys) { }

public override void ProcessRounds(ref Vector128<byte> state)
{
// Take local refence to array so Jit can reason length doesn't change in method
Vector128<byte>[] roundKeys = _roundKeys;
{
// Get the Jit to bounds check once rather than each increasing array access
Vector128<byte> temp = roundKeys[12];
}

// Operate on non-ref local so it remains in register rather than operating on memory
Vector128<byte> state2 = Sse2.Xor(state, roundKeys[0]);
state2 = Aes.Decrypt(state2, roundKeys[1]);
state2 = Aes.Decrypt(state2, roundKeys[2]);
state2 = Aes.Decrypt(state2, roundKeys[3]);
state2 = Aes.Decrypt(state2, roundKeys[4]);
state2 = Aes.Decrypt(state2, roundKeys[5]);
state2 = Aes.Decrypt(state2, roundKeys[6]);
state2 = Aes.Decrypt(state2, roundKeys[7]);
state2 = Aes.Decrypt(state2, roundKeys[8]);
state2 = Aes.Decrypt(state2, roundKeys[9]);
state2 = Aes.Decrypt(state2, roundKeys[10]);
state2 = Aes.Decrypt(state2, roundKeys[11]);
state2 = Aes.DecryptLast(state2, roundKeys[12]);
// Copy back to ref
state = state2;
}
}

private sealed class Encode256 : AesEncoderDecoder
{
public Encode256(Vector128<byte>[] roundKeys) : base(roundKeys) { }

public override void ProcessRounds(ref Vector128<byte> state)
{
// Take local refence to array so Jit can reason length doesn't change in method
Vector128<byte>[] roundKeys = _roundKeys;
{
// Get the Jit to bounds check once rather than each increasing array access
Vector128<byte> temp = roundKeys[14];
}

// Operate on non-ref local so it remains in register rather than operating on memory
Vector128<byte> state2 = Sse2.Xor(state, roundKeys[0]);
state2 = Aes.Encrypt(state2, roundKeys[1]);
state2 = Aes.Encrypt(state2, roundKeys[2]);
state2 = Aes.Encrypt(state2, roundKeys[3]);
state2 = Aes.Encrypt(state2, roundKeys[4]);
state2 = Aes.Encrypt(state2, roundKeys[5]);
state2 = Aes.Encrypt(state2, roundKeys[6]);
state2 = Aes.Encrypt(state2, roundKeys[7]);
state2 = Aes.Encrypt(state2, roundKeys[8]);
state2 = Aes.Encrypt(state2, roundKeys[9]);
state2 = Aes.Encrypt(state2, roundKeys[10]);
state2 = Aes.Encrypt(state2, roundKeys[11]);
state2 = Aes.Encrypt(state2, roundKeys[12]);
state2 = Aes.Encrypt(state2, roundKeys[13]);
state2 = Aes.EncryptLast(state2, roundKeys[14]);
// Copy back to ref
state = state2;
}
}

private sealed class Decode256 : AesEncoderDecoder
{
public Decode256(Vector128<byte>[] roundKeys) : base(roundKeys) { }

public override void ProcessRounds(ref Vector128<byte> state)
{
// Take local refence to array so Jit can reason length doesn't change in method
Vector128<byte>[] roundKeys = _roundKeys;
{
// Get the Jit to bounds check once rather than each increasing array access
Vector128<byte> temp = roundKeys[14];
}

// Operate on non-ref local so it remains in register rather than operating on memory
Vector128<byte> state2 = Sse2.Xor(state, roundKeys[0]);
state2 = Aes.Decrypt(state2, roundKeys[1]);
state2 = Aes.Decrypt(state2, roundKeys[2]);
state2 = Aes.Decrypt(state2, roundKeys[3]);
state2 = Aes.Decrypt(state2, roundKeys[4]);
state2 = Aes.Decrypt(state2, roundKeys[5]);
state2 = Aes.Decrypt(state2, roundKeys[6]);
state2 = Aes.Decrypt(state2, roundKeys[7]);
state2 = Aes.Decrypt(state2, roundKeys[8]);
state2 = Aes.Decrypt(state2, roundKeys[9]);
state2 = Aes.Decrypt(state2, roundKeys[10]);
state2 = Aes.Decrypt(state2, roundKeys[11]);
state2 = Aes.Decrypt(state2, roundKeys[12]);
state2 = Aes.Decrypt(state2, roundKeys[13]);
state2 = Aes.DecryptLast(state2, roundKeys[14]);
// Copy back to ref
state = state2;
}
}
}

private static class Check
{
public static void DataLength(byte[] buf, int off, int len)
{
if (off > (buf.Length - len)) ThrowDataLengthException();

static void ThrowDataLengthException() => throw new DataLengthException("input buffer too short");
}

public static void OutputLength(byte[] buf, int off, int len)
{
if (off > (buf.Length - len)) ThrowOutputLengthException();

static void ThrowOutputLengthException() => throw new OutputLengthException("output buffer too short");
}
}
}
2 changes: 1 addition & 1 deletion src/Nethermind/Nethermind.Crypto/EciesCipher.cs
Original file line number Diff line number Diff line change
@@ -73,7 +73,7 @@ private byte[] Decrypt(PublicKey ephemeralPublicKey, PrivateKey privateKey, byte

private IIesEngine MakeIesEngine(bool isEncrypt, PublicKey publicKey, PrivateKey privateKey, byte[] iv)
{
AesEngine aesFastEngine = new();
IBlockCipher aesFastEngine = AesEngineX86Intrinsic.IsSupported ? new AesEngineX86Intrinsic() : new AesEngine();

EthereumIesEngine iesEngine = new(
new HMac(new Sha256Digest()),
5 changes: 4 additions & 1 deletion src/Nethermind/Nethermind.Network/Rlpx/FrameCipher.cs
Original file line number Diff line number Diff line change
@@ -2,6 +2,9 @@
// SPDX-License-Identifier: LGPL-3.0-only

using System.Diagnostics;

using Nethermind.Crypto;

using Org.BouncyCastle.Crypto;
using Org.BouncyCastle.Crypto.Engines;
using Org.BouncyCastle.Crypto.Modes;
@@ -20,7 +23,7 @@ public class FrameCipher : IFrameCipher

public FrameCipher(byte[] aesKey)
{
AesEngine aes = new();
IBlockCipher aes = AesEngineX86Intrinsic.IsSupported ? new AesEngineX86Intrinsic() : new AesEngine();

Debug.Assert(aesKey.Length == KeySize, $"AES key expected to be {KeySize} bytes long");

9 changes: 6 additions & 3 deletions src/Nethermind/Nethermind.Network/Rlpx/FrameMacProcessor.cs
Original file line number Diff line number Diff line change
@@ -5,6 +5,9 @@
using System.IO;
using Nethermind.Core.Attributes;
using Nethermind.Core.Crypto;
using Nethermind.Crypto;

using Org.BouncyCastle.Crypto;
using Org.BouncyCastle.Crypto.Digests;
using Org.BouncyCastle.Crypto.Engines;
using Org.BouncyCastle.Crypto.Parameters;
@@ -21,7 +24,7 @@ public class FrameMacProcessor : IFrameMacProcessor
private readonly KeccakDigest _ingressMac;
private readonly KeccakDigest _egressMacCopy;
private readonly KeccakDigest _ingressMacCopy;
private readonly AesEngine _aesEngine;
private readonly IBlockCipher _aesEngine;
private readonly byte[] _macSecret;

public FrameMacProcessor(PublicKey remoteNodeId, EncryptionSecrets secrets)
@@ -39,9 +42,9 @@ public FrameMacProcessor(PublicKey remoteNodeId, EncryptionSecrets secrets)
_egressAesBlockBuffer = new byte[_ingressMac.GetDigestSize()];
}

private AesEngine MakeMacCipher()
private IBlockCipher MakeMacCipher()
{
AesEngine aesFastEngine = new();
IBlockCipher aesFastEngine = AesEngineX86Intrinsic.IsSupported ? new AesEngineX86Intrinsic() : new AesEngine();
aesFastEngine.Init(true, new KeyParameter(_macSecret));
return aesFastEngine;
}

0 comments on commit 7ebec69

Please sign in to comment.