From 25729d3b999d8d6eb5c63077e1e6b086d4b940cb Mon Sep 17 00:00:00 2001 From: Kirill Mikhailov Date: Tue, 16 Apr 2024 13:59:23 +0200 Subject: [PATCH] Not `panic`ing on wrong key length --- esp-hal/src/aes/mod.rs | 58 ++++++++++++++++++++++++++++++++++------- examples/src/bin/aes.rs | 4 +-- hil-test/tests/aes.rs | 12 ++++----- 3 files changed, 56 insertions(+), 18 deletions(-) diff --git a/esp-hal/src/aes/mod.rs b/esp-hal/src/aes/mod.rs index 50cb4c5e8a8..bc56cf9257d 100644 --- a/esp-hal/src/aes/mod.rs +++ b/esp-hal/src/aes/mod.rs @@ -37,10 +37,10 @@ //! //! ```no_run //! let mut block = block_buf.clone(); -//! aes.process(&mut block, Mode::Encryption128, &keybuf); +//! aes.process(&mut block, Mode::Encryption128, keybuf.into()); //! let hw_encrypted = block.clone(); //! -//! aes.process(&mut block, Mode::Decryption128, &keybuf); +//! aes.process(&mut block, Mode::Decryption128, keybuf.into()); //! let hw_decrypted = block; //! ``` //! @@ -122,6 +122,49 @@ mod aes_spec_impl; const ALIGN_SIZE: usize = core::mem::size_of::(); +/// Represents the various key sizes allowed for AES encryption and decryption. +pub enum Key { + /// 128-bit AES key + Key16([u8; 16]), + #[cfg(any(feature = "esp32", feature = "esp32s2"))] + /// 192-bit AES key + Key24([u8; 24]), + /// 256-bit AES key + Key32([u8; 32]), +} + +// Implementing From for easy conversion from array to Key enum. +impl From<[u8; 16]> for Key { + fn from(key: [u8; 16]) -> Self { + Key::Key16(key) + } +} + +#[cfg(any(feature = "esp32", feature = "esp32s2"))] +impl From<[u8; 24]> for Key { + fn from(key: [u8; 24]) -> Self { + Key::Key24(key) + } +} + +impl From<[u8; 32]> for Key { + fn from(key: [u8; 32]) -> Self { + Key::Key32(key) + } +} + +impl Key { + /// Returns a slice representation of the AES key. + fn as_slice(&self) -> &[u8] { + match self { + Key::Key16(ref key) => key, + #[cfg(any(feature = "esp32", feature = "esp32s2"))] + Key::Key24(ref key) => key, + Key::Key32(ref key) => key, + } + } +} + pub enum Mode { Encryption128 = 0, #[cfg(any(esp32, esp32s2))] @@ -154,14 +197,9 @@ impl<'d> Aes<'d> { } /// Encrypts/Decrypts the given buffer based on `mode` parameter - pub fn process(&mut self, block: &mut [u8; 16], mode: Mode, key: &[u8]) { - assert!( - key.len() == 16 - || (cfg!(any(feature = "esp32", feature = "esp32s2")) && key.len() == 24) - || key.len() == 32, - "Invalid key size" - ); - self.write_key(key); + pub fn process(&mut self, block: &mut [u8; 16], mode: Mode, key: Key){ + // Convert from Key enum to required byte slice + self.write_key(key.as_slice()); self.set_mode(mode as u8); self.set_block(block); self.start(); diff --git a/examples/src/bin/aes.rs b/examples/src/bin/aes.rs index 3f3f3824927..0c23c9d0f05 100644 --- a/examples/src/bin/aes.rs +++ b/examples/src/bin/aes.rs @@ -37,7 +37,7 @@ fn main() -> ! { let mut block = block_buf.clone(); let pre_hw_encrypt = cycles(); - aes.process(&mut block, Mode::Encryption128, &keybuf); + aes.process(&mut block, Mode::Encryption128, keybuf.into()); let post_hw_encrypt = cycles(); println!( "it took {} cycles for hw encrypt", @@ -45,7 +45,7 @@ fn main() -> ! { ); let hw_encrypted = block.clone(); let pre_hw_decrypt = cycles(); - aes.process(&mut block, Mode::Decryption128, &keybuf); + aes.process(&mut block, Mode::Decryption128, keybuf.into()); let post_hw_decrypt = cycles(); println!( "it took {} cycles for hw decrypt", diff --git a/hil-test/tests/aes.rs b/hil-test/tests/aes.rs index ee3c30e46a8..3c6051c86f9 100644 --- a/hil-test/tests/aes.rs +++ b/hil-test/tests/aes.rs @@ -53,7 +53,7 @@ mod tests { block_buf[..plaintext.len()].copy_from_slice(plaintext); let mut block = block_buf.clone(); - ctx.aes.process(&mut block, Mode::Encryption128, &keybuf); + ctx.aes.process(&mut block, Mode::Encryption128, keybuf.into()); assert_eq!(block, encrypted_message); } @@ -70,7 +70,7 @@ mod tests { keybuf[..keytext.len()].copy_from_slice(keytext); ctx.aes - .process(&mut encrypted_message, Mode::Decryption128, &keybuf); + .process(&mut encrypted_message, Mode::Decryption128, keybuf.into()); assert_eq!(&encrypted_message[..plaintext.len()], plaintext); } @@ -91,7 +91,7 @@ mod tests { block_buf[..plaintext.len()].copy_from_slice(plaintext); let mut block = block_buf.clone(); - ctx.aes.process(&mut block, Mode::Encryption192, &keybuf); + ctx.aes.process(&mut block, Mode::Encryption192, keybuf.into()); assert_eq!(block, encrypted_message); } @@ -109,7 +109,7 @@ mod tests { keybuf[..keytext.len()].copy_from_slice(keytext); ctx.aes - .process(&mut encrypted_message, Mode::Decryption192, &keybuf); + .process(&mut encrypted_message, Mode::Decryption192, keybuf.into()); assert_eq!(&encrypted_message[..plaintext.len()], plaintext); } @@ -129,7 +129,7 @@ mod tests { block_buf[..plaintext.len()].copy_from_slice(plaintext); let mut block = block_buf.clone(); - ctx.aes.process(&mut block, Mode::Encryption256, &keybuf); + ctx.aes.process(&mut block, Mode::Encryption256, keybuf.into()); assert_eq!(block, encrypted_message); } @@ -146,7 +146,7 @@ mod tests { keybuf[..keytext.len()].copy_from_slice(keytext); ctx.aes - .process(&mut encrypted_message, Mode::Decryption256, &keybuf); + .process(&mut encrypted_message, Mode::Decryption256, keybuf.into()); assert_eq!(&encrypted_message[..plaintext.len()], plaintext); } }