diff --git a/esp-mbedtls-sys/headers/esp32s3/config.h b/esp-mbedtls-sys/headers/esp32s3/config.h index 72e4bba..c06f13b 100644 --- a/esp-mbedtls-sys/headers/esp32s3/config.h +++ b/esp-mbedtls-sys/headers/esp32s3/config.h @@ -327,6 +327,7 @@ //#define MBEDTLS_GCM_ALT //#define MBEDTLS_NIST_KW_ALT #define MBEDTLS_MPI_EXP_MOD_ALT_FALLBACK +#define MBEDTLS_MPI_MUL_MPI_ALT //#define MBEDTLS_MD5_ALT //#define MBEDTLS_POLY1305_ALT //#define MBEDTLS_RIPEMD160_ALT diff --git a/esp-mbedtls/src/bignum.rs b/esp-mbedtls/src/bignum.rs index 84923ad..16150c4 100644 --- a/esp-mbedtls/src/bignum.rs +++ b/esp-mbedtls/src/bignum.rs @@ -1,6 +1,10 @@ #![allow(non_snake_case)] use crate::hal::prelude::nb; +#[cfg(feature = "esp32s3")] +use crate::hal::rsa::RsaModularMultiplication; +#[cfg(feature = "esp32s3")] +use crate::hal::rsa::RsaMultiplication; use crate::hal::rsa::{operand_sizes, RsaModularExponentiation}; use crypto_bigint::*; @@ -352,3 +356,412 @@ pub unsafe extern "C" fn mbedtls_mpi_exp_mod( } } } + +#[cfg(feature = "esp32s3")] +#[inline] +const fn bits_to_words(bits: usize) -> usize { + (bits + 31) / 32 +} + +/// Deal with the case when X & Y are too long for the hardware unit, by splitting one operand +/// into two halves. +/// +/// Y must be the longer operand +/// +/// Slice Y into Yp, Ypp such that: +/// Yp = lower 'b' bits of Y +/// Ypp = upper 'b' bits of Y (right shifted) +/// +/// Such that +/// Z = X * Y +/// Z = X * (Yp + Ypp< c_int { + let mut ret = 0; + + // Rather than slicing in two on bits we slice on limbs (32 bit words) + let words_slice: usize = y_words / 2; + + // Holds the lower bits of Y (declared to reuse Y's array contents to save on copying) + let yp: mbedtls_mpi = mbedtls_mpi { + private_p: (*Y).private_p, + private_n: words_slice, + private_s: (*Y).private_s, + }; + + // Holds the upper bits of Y, right shifted (also reuse Y's array contents) + let ypp: mbedtls_mpi = mbedtls_mpi { + private_p: unsafe { Y.private_p.add(words_slice) }, + private_n: y_words - words_slice, + private_s: (*Y).private_s, + }; + + let mut x_temp = mbedtls_mpi { + private_s: 0, + private_n: 0, + private_p: core::ptr::null_mut(), + }; + + unsafe { + mbedtls_mpi_init(&mut x_temp); + + error_checked!(mbedtls_mpi_mul_mpi(&mut x_temp, X, &yp)); + + // Z = b_upper * B + error_checked!(mbedtls_mpi_mul_mpi(Z, X, &ypp)); + + // X = X << b + error_checked!(mbedtls_mpi_shift_l(Z, words_slice * 32)); + + // X += Xtemp + error_checked!(mbedtls_mpi_add_mpi(Z, Z, &x_temp)); + + mbedtls_mpi_free(&mut x_temp); + } + + ret +} + +#[cfg(feature = "esp32s3")] +unsafe fn mbedtls_mpi_mult_mpi_failover_mod_mult( + Z: &mut mbedtls_mpi, + X: &mbedtls_mpi, + Y: &mbedtls_mpi, + z_words: usize, +) -> c_int { + match crate::RSA_REF { + None => unimplemented!("mbedtls_mpi_mult_mpi_failover_mod_mult"), + Some(ref mut rsa) => { + let mut ret = 0; + + let x_bits = unsafe { mbedtls_mpi_bitlen(X) }; + let y_bits = unsafe { mbedtls_mpi_bitlen(Y) }; + // TODO: We can have the words value from the mpi + let x_words = bits_to_words(x_bits); + let y_words = bits_to_words(y_bits); + let hw_words = calculate_hw_words(z_words); + + nb::block!(rsa.ready()).unwrap(); + match hw_words { + U2112::LIMBS => { + const OP_SIZE: usize = U2112::LIMBS; + let mut operand_x = [0u32; OP_SIZE]; + let mut operand_y = [0u32; OP_SIZE]; + let mut out = [0u32; OP_SIZE]; + // RINV + let mut rinv = [0u32; OP_SIZE]; + rinv[0] = 1; + // Modulus + let mut modulus = [0u32; OP_SIZE]; + for i in 0..hw_words { + modulus[i] = u32::MAX; + } + + copy_bytes(X.private_p, operand_x.as_mut_ptr(), x_words); + copy_bytes(Y.private_p, operand_y.as_mut_ptr(), y_words); + + let mut calc = RsaModularMultiplication::::new( + rsa, &operand_x, // operand_a (X) X_MEM + &operand_y, // operand_b (Y) Y_MEM + &modulus, // modulus (M) M_MEM + 1, // mprime + ); + calc.start_modular_multiplication(&rinv); // r Z_MEM + + calc.read_results(&mut out); + copy_bytes(out.as_ptr(), Z.private_p, hw_words); + } + U2560::LIMBS => { + const OP_SIZE: usize = U2560::LIMBS; + let mut operand_x = [0u32; OP_SIZE]; + let mut operand_y = [0u32; OP_SIZE]; + let mut out = [0u32; OP_SIZE]; + // RINV + let mut rinv = [0u32; OP_SIZE]; + rinv[0] = 1; + // Modulus + let mut modulus = [0u32; OP_SIZE]; + for i in 0..hw_words { + modulus[i] = u32::MAX; + } + + copy_bytes(X.private_p, operand_x.as_mut_ptr(), x_words); + copy_bytes(Y.private_p, operand_y.as_mut_ptr(), y_words); + + let mut calc = RsaModularMultiplication::::new( + rsa, &operand_x, // operand_a (X) X_MEM + &operand_y, // operand_b (Y) Y_MEM + &modulus, // modulus (M) M_MEM + 1, // mprime + ); + calc.start_modular_multiplication(&rinv); // r Z_MEM + + calc.read_results(&mut out); + copy_bytes(out.as_ptr(), Z.private_p, hw_words); + } + op => { + todo!("implement mod multi op {}", op); + } + } + + // Grow X to result size early, avoid interim allocations + unsafe { + error_checked!(mbedtls_mpi_grow(Z, hw_words)); + } + + Z.private_s = X.private_s * Y.private_s; + + // Relevant: https://github.com/espressif/esp-idf/issues/11850 + // + // If z_words < mpi_words(Z) (the actual words taken by the MPI result), + // the assert fails due to unsigned arithmetic - most likely hardware + // peripheral has produced an incorrect result for MPI operation. + // This can happen if data fed to the peripheral register was incorrect. + // + // z_words is calculated as the worst-case possible size of the result + // MPI Z. The difference between z_words and the actual words taken by + // the MPI result (mpi_words(Z)) can be a maximum of 1 word. + // The value z_bits (actual bits taken by the MPI result) is calculated + // as x_bits + y_bits bits, however, in some cases, z_bits can be + // x_bits + y_bits - 1 bits (see example below). + // 0b1111 * 0b1111 = 0b11100001 -> 8 bits + // 0b1000 * 0b1000 = 0b01000000 -> 7 bits. + // The code rounds up to the nearest word size, so the maximum difference + // could be of only 1 word. The assert handles this. + assert!(z_words - mpi_words(Z) <= 1); + + ret + } + } +} + +// Baseline multiplication: Z = X * Y (HAC 14.12) +#[cfg(feature = "esp32s3")] +#[no_mangle] +pub unsafe extern "C" fn mbedtls_mpi_mul_mpi( + Z: &mut mbedtls_mpi, + X: &mbedtls_mpi, + Y: &mbedtls_mpi, +) -> c_int { + match crate::RSA_REF { + None => unimplemented!("mbedtls_mpi_mul_mpi"), + Some(ref mut rsa) => { + let mut ret = 0; + + let x_bits = unsafe { mbedtls_mpi_bitlen(X) }; + let y_bits = unsafe { mbedtls_mpi_bitlen(Y) }; + // TODO: We can have the words value from the mpi + let x_words = bits_to_words(x_bits); + let y_words = bits_to_words(y_bits); + let z_words = bits_to_words(x_bits + y_bits); + let hw_words = calculate_hw_words(core::cmp::max(x_words, y_words)); + + // Short-circuit eval if either argument is 0 or 1. + // + // This is needed as the mpi modular division + // argument will sometimes call in here when one + // argument is too large for the hardware unit, but other + // argument is zero or one. + if x_bits == 0 || y_bits == 0 { + unsafe { mbedtls_mpi_lset(Z, 0) }; + return 0; + } + if x_bits == 1 { + ret = unsafe { mbedtls_mpi_copy(Z, Y) }; + (*Z).private_s *= (*X).private_s; + return ret; + } + if y_bits == 1 { + ret = unsafe { mbedtls_mpi_copy(Z, X) }; + (*Z).private_s *= (*Y).private_s; + return ret; + } + + // Grow Z to result size early, avoid interim allocations + unsafe { + error_checked!(mbedtls_mpi_grow(Z, z_words)); + } + + // If either factor is over 2048 bits, we can't use the standard hardware multiplier + // (it assumes result is double longest factor, and result is max 4096 bits.) + // + // However, we can fail over to mod_mult for up to 4096 bits of result (modulo + // multiplication doesn't have the same restriction, so result is simply the + // number of bits in X plus number of bits in in Y.) + + if hw_words * 32 > SOC_RSA_MAX_BIT_LEN / 2 { + if z_words * 32 <= SOC_RSA_MAX_BIT_LEN { + // Note: It's possible to use mpi_mult_mpi_overlong + // for this case as well, but it's very slightly + // slower and requires a memory allocation. + return mbedtls_mpi_mult_mpi_failover_mod_mult(Z, X, Y, z_words); + } else { + // Still too long for the hardware unit... + if y_words > x_words { + return mpi_mult_mpi_overlong(Z, X, Y, y_words); + } else { + return mpi_mult_mpi_overlong(Z, Y, X, x_words); + } + } + } + + // Otherwise, we can use the (faster) multiply hardware unit + nb::block!(rsa.ready()).unwrap(); + match hw_words * 4 { + U64::BYTES => { + const OP_SIZE: usize = U64::LIMBS; + let mut operand_x = [0u32; OP_SIZE]; + let mut operand_y = [0u32; OP_SIZE]; + let mut out = [0u32; OP_SIZE * 2]; + copy_bytes(X.private_p, operand_x.as_mut_ptr(), x_words); + copy_bytes(Y.private_p, operand_y.as_mut_ptr(), y_words); + let mut calc = RsaMultiplication::::new(rsa, &operand_x); + calc.start_multiplication(&operand_y); + calc.read_results(&mut out); + copy_bytes(out.as_ptr(), Z.private_p, z_words); + } + U128::BYTES => { + const OP_SIZE: usize = U128::LIMBS; + let mut operand_x = [0u32; OP_SIZE]; + let mut operand_y = [0u32; OP_SIZE]; + let mut out = [0u32; OP_SIZE * 2]; + copy_bytes(X.private_p, operand_x.as_mut_ptr(), x_words); + copy_bytes(Y.private_p, operand_y.as_mut_ptr(), y_words); + let mut calc = RsaMultiplication::::new(rsa, &operand_x); + calc.start_multiplication(&operand_y); + calc.read_results(&mut out); + copy_bytes(out.as_ptr(), Z.private_p, z_words); + } + U256::BYTES => { + const OP_SIZE: usize = U256::LIMBS; + let mut operand_x = [0u32; OP_SIZE]; + let mut operand_y = [0u32; OP_SIZE]; + let mut out = [0u32; OP_SIZE * 2]; + copy_bytes(X.private_p, operand_x.as_mut_ptr(), x_words); + copy_bytes(Y.private_p, operand_y.as_mut_ptr(), y_words); + let mut calc = RsaMultiplication::::new(rsa, &operand_x); + calc.start_multiplication(&operand_y); + calc.read_results(&mut out); + copy_bytes(out.as_ptr(), Z.private_p, z_words); + } + U384::BYTES => { + const OP_SIZE: usize = U384::LIMBS; + let mut operand_x = [0u32; OP_SIZE]; + let mut operand_y = [0u32; OP_SIZE]; + let mut out = [0u32; OP_SIZE * 2]; + copy_bytes(X.private_p, operand_x.as_mut_ptr(), x_words); + copy_bytes(Y.private_p, operand_y.as_mut_ptr(), y_words); + let mut calc = RsaMultiplication::::new(rsa, &operand_x); + calc.start_multiplication(&operand_y); + calc.read_results(&mut out); + copy_bytes(out.as_ptr(), Z.private_p, z_words); + } + U512::BYTES => { + const OP_SIZE: usize = U512::LIMBS; + let mut operand_x = [0u32; OP_SIZE]; + let mut operand_y = [0u32; OP_SIZE]; + let mut out = [0u32; OP_SIZE * 2]; + copy_bytes(X.private_p, operand_x.as_mut_ptr(), x_words); + copy_bytes(Y.private_p, operand_y.as_mut_ptr(), y_words); + let mut calc = RsaMultiplication::::new(rsa, &operand_x); + calc.start_multiplication(&operand_y); + calc.read_results(&mut out); + copy_bytes(out.as_ptr(), Z.private_p, z_words); + } + // TODO: Is it normal to have hw_words * 4 not being a multiple of 32? + 68 => { + const OP_SIZE: usize = U576::LIMBS; + let mut operand_x = [0u32; OP_SIZE]; + let mut operand_y = [0u32; OP_SIZE]; + let mut out = [0u32; OP_SIZE * 2]; + copy_bytes(X.private_p, operand_x.as_mut_ptr(), x_words); + copy_bytes(Y.private_p, operand_y.as_mut_ptr(), y_words); + let mut calc = RsaMultiplication::::new(rsa, &operand_x); + calc.start_multiplication(&operand_y); + calc.read_results(&mut out); + copy_bytes(out.as_ptr(), Z.private_p, z_words); + } + U1024::BYTES => { + const OP_SIZE: usize = U1024::LIMBS; + let mut operand_x = [0u32; OP_SIZE]; + let mut operand_y = [0u32; OP_SIZE]; + let mut out = [0u32; OP_SIZE * 2]; + copy_bytes(X.private_p, operand_x.as_mut_ptr(), x_words); + copy_bytes(Y.private_p, operand_y.as_mut_ptr(), y_words); + let mut calc = RsaMultiplication::::new(rsa, &operand_x); + calc.start_multiplication(&operand_y); + calc.read_results(&mut out); + copy_bytes(out.as_ptr(), Z.private_p, z_words); + } + // TODO: Is it normal to have hw_words * 4 not being a multiple of 32? + 132 | U1088::BYTES => { + const OP_SIZE: usize = U1088::LIMBS; + let mut operand_x = [0u32; OP_SIZE]; + let mut operand_y = [0u32; OP_SIZE]; + let mut out = [0u32; OP_SIZE * 2]; + copy_bytes(X.private_p, operand_x.as_mut_ptr(), x_words); + copy_bytes(Y.private_p, operand_y.as_mut_ptr(), y_words); + let mut calc = RsaMultiplication::::new(rsa, &operand_x); + calc.start_multiplication(&operand_y); + calc.read_results(&mut out); + copy_bytes(out.as_ptr(), Z.private_p, z_words); + } + U1152::BYTES => { + const OP_SIZE: usize = U1152::LIMBS; + let mut operand_x = [0u32; OP_SIZE]; + let mut operand_y = [0u32; OP_SIZE]; + let mut out = [0u32; OP_SIZE * 2]; + copy_bytes(X.private_p, operand_x.as_mut_ptr(), x_words); + copy_bytes(Y.private_p, operand_y.as_mut_ptr(), y_words); + let mut calc = RsaMultiplication::::new(rsa, &operand_x); + calc.start_multiplication(&operand_y); + calc.read_results(&mut out); + copy_bytes(out.as_ptr(), Z.private_p, z_words); + } + U2048::BYTES => { + const OP_SIZE: usize = U2048::LIMBS; + let mut operand_x = [0u32; OP_SIZE]; + let mut operand_y = [0u32; OP_SIZE]; + let mut out = [0u32; OP_SIZE * 2]; + copy_bytes(X.private_p, operand_x.as_mut_ptr(), x_words); + copy_bytes(Y.private_p, operand_y.as_mut_ptr(), y_words); + let mut calc = RsaMultiplication::::new(rsa, &operand_x); + calc.start_multiplication(&operand_y); + calc.read_results(&mut out); + copy_bytes(out.as_ptr(), (*Z).private_p, z_words); + } + op => { + todo!("Implement operand: {}", op); + } + } + Z.private_s = X.private_s * Y.private_s; + + ret + } + } +} + +#[cfg(feature = "esp32s3")] +#[no_mangle] +pub extern "C" fn mbedtls_mpi_mul_int( + X: &mut mbedtls_mpi, + A: &mbedtls_mpi, + mut b: mbedtls_mpi_uint, +) -> c_int { + let B: mbedtls_mpi = mbedtls_mpi { + private_s: 1, + private_n: 1, + private_p: &mut b, + }; + + unsafe { mbedtls_mpi_mul_mpi(X, A, &B) } +} diff --git a/esp-mbedtls/src/lib.rs b/esp-mbedtls/src/lib.rs index 2be2a37..97a9524 100644 --- a/esp-mbedtls/src/lib.rs +++ b/esp-mbedtls/src/lib.rs @@ -16,7 +16,7 @@ pub use esp32s2_hal as hal; #[cfg(feature = "esp32s3")] pub use esp32s3_hal as hal; -use crate::hal::rsa::Rsa; +pub use crate::hal::rsa::Rsa; mod compat; diff --git a/libs/xtensa-esp32s3-none-elf/libmbedcrypto.a b/libs/xtensa-esp32s3-none-elf/libmbedcrypto.a index ff878b4..43caf17 100644 Binary files a/libs/xtensa-esp32s3-none-elf/libmbedcrypto.a and b/libs/xtensa-esp32s3-none-elf/libmbedcrypto.a differ diff --git a/libs/xtensa-esp32s3-none-elf/libmbedtls.a b/libs/xtensa-esp32s3-none-elf/libmbedtls.a index e442ec1..e50f93e 100644 Binary files a/libs/xtensa-esp32s3-none-elf/libmbedtls.a and b/libs/xtensa-esp32s3-none-elf/libmbedtls.a differ diff --git a/libs/xtensa-esp32s3-none-elf/libmbedx509.a b/libs/xtensa-esp32s3-none-elf/libmbedx509.a index 55f258b..807a4cc 100644 Binary files a/libs/xtensa-esp32s3-none-elf/libmbedx509.a and b/libs/xtensa-esp32s3-none-elf/libmbedx509.a differ