diff --git a/hacspec/src/lib.rs b/hacspec/src/lib.rs index a543d314..7e8f49e3 100644 --- a/hacspec/src/lib.rs +++ b/hacspec/src/lib.rs @@ -733,40 +733,40 @@ pub fn construct_state( ) } -/// the unsigned integer is encoded as a single byte +/// Check for: an unsigned integer encoded as a single byte #[inline(always)] fn is_cbor_uint_1byte(byte: U8) -> bool { let byte = byte.declassify(); return byte >= CBOR_UINT_1BYTE_START && byte <= CBOR_UINT_1BYTE_END; } -/// the unsigned integer is encoded as two bytes +/// Check for: an unsigned integer encoded as two bytes #[inline(always)] fn is_cbor_uint_2bytes(byte: U8) -> bool { return byte.declassify() == CBOR_UINT_1BYTE; } -/// the negative integer is encoded as a single byte +/// Check for: a negative integer encoded as a single byte #[inline(always)] fn is_cbor_neg_int_1byte(byte: U8) -> bool { let byte = byte.declassify(); return byte >= CBOR_NEG_INT_1BYTE_START && byte <= CBOR_NEG_INT_1BYTE_END; } -/// a single byte signals both bstr type and content length +/// Check for: a bstr denoted by a single byte which encodes both type and content length #[inline(always)] fn is_cbor_bstr_1byte_prefix(byte: U8) -> bool { let byte = byte.declassify(); return byte >= CBOR_MAJOR_BYTE_STRING && byte <= CBOR_MAJOR_BYTE_STRING_MAX; } -/// two bytes are used to signal bstr type and content length +/// Check for: a bstr denoted by two bytes, onr for type the other for content length #[inline(always)] fn is_cbor_bstr_2bytes_prefix(byte: U8) -> bool { return byte.declassify() == CBOR_BYTE_STRING; } -/// a single byte signals both array type and content length +/// Check for: an array denoted by a single byte which encodes both type and content length #[inline(always)] fn is_cbor_array_1byte_prefix(byte: U8) -> bool { let byte = byte.declassify(); @@ -1182,7 +1182,6 @@ fn decode_plaintext_3( let mut ead_3 = None::; let mut error = EDHOCError::UnknownError; let mut kid = U8(0xff); - // skip the CBOR magic byte as we know how long the MAC is let mut mac_3 = BytesMac3::new(); // check ID_CRED_I and MAC_3 @@ -1432,6 +1431,7 @@ fn decode_plaintext_2( let mut id_cred_r: U8 = U8(0xff); let mut mac_2 = BytesMac2::new(); + // check CBOR sequence types for c_r, id_cred_r, and mac_2 if (is_cbor_neg_int_1byte(plaintext_2[0]) || is_cbor_uint_1byte(plaintext_2[0])) && (is_cbor_neg_int_1byte(plaintext_2[1]) || is_cbor_uint_1byte(plaintext_2[1])) && (is_cbor_bstr_1byte_prefix(plaintext_2[2])) diff --git a/lib/src/edhoc.rs b/lib/src/edhoc.rs index e0e2632f..e4dfb2c9 100644 --- a/lib/src/edhoc.rs +++ b/lib/src/edhoc.rs @@ -672,6 +672,42 @@ pub fn construct_state( ) } +/// Check for: an unsigned integer encoded as a single byte +#[inline(always)] +fn is_cbor_uint_1byte(byte: U8) -> bool { + return byte >= CBOR_UINT_1BYTE_START && byte <= CBOR_UINT_1BYTE_END; +} + +/// Check for: an unsigned integer encoded as two bytes +#[inline(always)] +fn is_cbor_uint_2bytes(byte: U8) -> bool { + return byte == CBOR_UINT_1BYTE; +} + +/// Check for: a negative integer encoded as a single byte +#[inline(always)] +fn is_cbor_neg_int_1byte(byte: U8) -> bool { + return byte >= CBOR_NEG_INT_1BYTE_START && byte <= CBOR_NEG_INT_1BYTE_END; +} + +/// Check for: a bstr denoted by a single byte which encodes both type and content length +#[inline(always)] +fn is_cbor_bstr_1byte_prefix(byte: U8) -> bool { + return byte >= CBOR_MAJOR_BYTE_STRING && byte <= CBOR_MAJOR_BYTE_STRING_MAX; +} + +/// Check for: a bstr denoted by two bytes, onr for type the other for content length +#[inline(always)] +fn is_cbor_bstr_2bytes_prefix(byte: U8) -> bool { + return byte == CBOR_BYTE_STRING; +} + +/// Check for: an array denoted by a single byte which encodes both type and content length +#[inline(always)] +fn is_cbor_array_1byte_prefix(byte: U8) -> bool { + return byte >= CBOR_MAJOR_ARRAY && byte <= CBOR_MAJOR_ARRAY_MAX; +} + fn parse_suites_i( rcvd_message_1: &BufferMessage1, ) -> Result<(BytesSuites, usize, usize), EDHOCError> { @@ -681,62 +717,55 @@ fn parse_suites_i( let mut suites_i_len: usize = 0; // match based on first byte of SUITES_I, which can be either an int or an array - match rcvd_message_1.content[1] { + if is_cbor_uint_1byte(rcvd_message_1.content[1]) { // CBOR unsigned integer (0..=23) - 0x00..=0x17 => { - suites_i[0] = rcvd_message_1.content[1]; - suites_i_len = 1; - raw_suites_len = 1; - error = EDHOCError::Success; - } + suites_i[0] = rcvd_message_1.content[1]; + suites_i_len = 1; + raw_suites_len = 1; + error = EDHOCError::Success; + } else if is_cbor_uint_2bytes(rcvd_message_1.content[1]) { // CBOR unsigned integer (one-byte uint8_t follows) - 0x18 => { - suites_i[0] = rcvd_message_1.content[2]; - suites_i_len = 1; - raw_suites_len = 2; - error = EDHOCError::Success; - } + suites_i[0] = rcvd_message_1.content[2]; + suites_i_len = 1; + raw_suites_len = 2; + error = EDHOCError::Success; + } else if is_cbor_array_1byte_prefix(rcvd_message_1.content[1]) { // CBOR array (0..=23 data items follow) - 0x80..=0x97 => { - // the CBOR array length is encoded in the first byte, so we extract it - let suites_len: usize = (rcvd_message_1.content[1] - CBOR_MAJOR_ARRAY).into(); - // check surplus array encoding of ciphersuite - if suites_len > 1 { - raw_suites_len = 1; // account for the CBOR_MAJOR_ARRAY byte - if suites_len <= EDHOC_SUITES.len() { - let mut j: usize = 0; // index for addressing cipher suites - while j < suites_len { - raw_suites_len += 1; - // match based on cipher suite identifier - match rcvd_message_1.content[raw_suites_len] { - // CBOR unsigned integer (0..23) - 0x00..=0x17 => { - suites_i[j] = rcvd_message_1.content[raw_suites_len]; - suites_i_len += 1; - } - // CBOR unsigned integer (one-byte uint8_t follows) - 0x18 => { - raw_suites_len += 1; // account for the 0x18 tag byte - suites_i[j] = rcvd_message_1.content[raw_suites_len]; - suites_i_len += 1; - } - _ => { - error = EDHOCError::ParsingError; - break; - } - } - j += 1; + // the CBOR array length is encoded in the first byte, so we extract it + let suites_len: usize = (rcvd_message_1.content[1] - CBOR_MAJOR_ARRAY).into(); + // check surplus array encoding of ciphersuite + raw_suites_len = 1; // account for the CBOR_MAJOR_ARRAY byte + if suites_len > 1 && suites_len <= EDHOC_SUITES.len() { + let mut j: usize = 0; // index for addressing cipher suites + while j < suites_len { + raw_suites_len += 1; + // match based on cipher suite identifier + match rcvd_message_1.content[raw_suites_len] { + // CBOR unsigned integer (0..23) + 0x00..=0x17 => { + suites_i[j] = rcvd_message_1.content[raw_suites_len]; + suites_i_len += 1; + } + // CBOR unsigned integer (one-byte uint8_t follows) + 0x18 => { + raw_suites_len += 1; // account for the 0x18 tag byte + suites_i[j] = rcvd_message_1.content[raw_suites_len]; + suites_i_len += 1; + } + _ => { + error = EDHOCError::ParsingError; + break; } - error = EDHOCError::Success; - } else { - error = EDHOCError::ParsingError; } - } else { - error = EDHOCError::ParsingError; + j += 1; } + error = EDHOCError::Success; + } else { + error = EDHOCError::ParsingError; } - _ => error = EDHOCError::ParsingError, - }; + } else { + error = EDHOCError::ParsingError; + } match error { EDHOCError::Success => Ok((suites_i, suites_i_len, raw_suites_len)), @@ -751,12 +780,14 @@ fn parse_ead(message: &EdhocMessageBuffer, offset: usize) -> Result Ok((label as u8, false)), + Ok((label as u8, false)) + } else if is_cbor_neg_int_1byte(label) { // CBOR negative integer (-1..=-24) - label @ 0x20..=0x37 => Ok((label - (CBOR_NEG_INT_1BYTE_START - 1), true)), - _ => Err(EDHOCError::ParsingError), + Ok((label - (CBOR_NEG_INT_1BYTE_START - 1), true)) + } else { + Err(EDHOCError::ParsingError) }; if res_label.is_ok() { @@ -785,24 +816,6 @@ fn parse_ead(message: &EdhocMessageBuffer, offset: usize) -> Result bool { - return (first_byte & CBOR_MAJOR_ARRAY) == CBOR_MAJOR_ARRAY; -} - -fn is_encoded_conn_id_minimal(enc_conn_id: U8) -> bool { - return (enc_conn_id >= 20 && enc_conn_id <= 37) || (enc_conn_id >= 0 && enc_conn_id <= 17); -} - -fn should_encoded_conn_id_be_minimal(conn_id_byte1: U8, conn_id_byte2: U8) -> bool { - return !is_encoded_conn_id_minimal(conn_id_byte1) - && (conn_id_byte1 == (CBOR_MAJOR_BYTE_STRING | 0x1)) // bstr with length of 1 - && is_encoded_conn_id_minimal(conn_id_byte2); -} - -fn should_encoded_ciphersuite_be_minimal(byte1: U8, byte2: U8) -> bool { - return is_cbor_array(byte1) && (byte1 - CBOR_MAJOR_ARRAY) == 1; -} - fn parse_message_1( rcvd_message_1: &BufferMessage1, ) -> Result< @@ -825,24 +838,22 @@ fn parse_message_1( let mut c_i = 0; let mut ead_1 = None::; - // check surplus array encoding - if !is_cbor_array(rcvd_message_1.content[0]) { + // first element of CBOR sequence must be an integer + if is_cbor_uint_1byte(rcvd_message_1.content[0]) { method = rcvd_message_1.content[0]; let res_suites = parse_suites_i(rcvd_message_1); if res_suites.is_ok() { (suites_i, suites_i_len, raw_suites_len) = res_suites.unwrap(); - let g_x_type = rcvd_message_1.content[1 + raw_suites_len]; - if g_x_type == CBOR_BYTE_STRING { + if is_cbor_bstr_2bytes_prefix(rcvd_message_1.content[1 + raw_suites_len]) { g_x.copy_from_slice( &rcvd_message_1.content[3 + raw_suites_len..3 + raw_suites_len + P256_ELEM_LEN], ); - // check surplus bstr encoding c_i = rcvd_message_1.content[3 + raw_suites_len + P256_ELEM_LEN]; - let c_i_lookahead = rcvd_message_1.content[4 + raw_suites_len + P256_ELEM_LEN]; - if !should_encoded_conn_id_be_minimal(c_i, c_i_lookahead) { + // check that c_i is encoded as single-byte int (we still do not support bstr encoding) + if is_cbor_neg_int_1byte(c_i) || is_cbor_uint_1byte(c_i) { // if there is still more to parse, the rest will be the EAD_1 if rcvd_message_1.len > (4 + raw_suites_len + P256_ELEM_LEN) { // NOTE: since the current implementation only supports one EAD handler, @@ -966,7 +977,7 @@ fn parse_message_2( let mut ciphertext_2: BufferCiphertext2 = BufferCiphertext2::new(); // ensure the whole message is a single CBOR sequence - if rcvd_message_2.content[0] == CBOR_BYTE_STRING + if is_cbor_bstr_2bytes_prefix(rcvd_message_2.content[0]) && rcvd_message_2.content[1] == (rcvd_message_2.len as u8 - 2) { g_y[..].copy_from_slice(&rcvd_message_2.content[2..2 + P256_ELEM_LEN]); @@ -1099,25 +1110,33 @@ fn decode_plaintext_3( ) -> Result<(U8, BytesMac3, Option), EDHOCError> { let mut ead_3 = None::; let mut error = EDHOCError::UnknownError; - - let kid = plaintext_3.content[0usize]; - // skip the CBOR magic byte as we know how long the MAC is + let mut kid: u8 = 0xff; let mut mac_3: BytesMac3 = [0x00; MAC_LENGTH_3]; - mac_3[..].copy_from_slice(&plaintext_3.content[2..2 + MAC_LENGTH_3]); - - // if there is still more to parse, the rest will be the EAD_3 - if plaintext_3.len > (2 + MAC_LENGTH_3) { - // NOTE: since the current implementation only supports one EAD handler, - // we assume only one EAD item - let ead_res = parse_ead(plaintext_3, 2 + MAC_LENGTH_3); - if ead_res.is_ok() { - ead_3 = ead_res.unwrap(); + + // check ID_CRED_I and MAC_3 + if (is_cbor_neg_int_1byte(plaintext_3.content[0]) || is_cbor_uint_1byte(plaintext_3.content[0])) + && (is_cbor_bstr_1byte_prefix(plaintext_3.content[1])) + { + kid = plaintext_3.content[0usize]; + // skip the CBOR magic byte as we know how long the MAC is + mac_3[..].copy_from_slice(&plaintext_3.content[2..2 + MAC_LENGTH_3]); + + // if there is still more to parse, the rest will be the EAD_3 + if plaintext_3.len > (2 + MAC_LENGTH_3) { + // NOTE: since the current implementation only supports one EAD handler, + // we assume only one EAD item + let ead_res = parse_ead(plaintext_3, 2 + MAC_LENGTH_3); + if ead_res.is_ok() { + ead_3 = ead_res.unwrap(); + error = EDHOCError::Success; + } else { + error = ead_res.unwrap_err(); + } + } else if plaintext_3.len == (2 + MAC_LENGTH_3) { error = EDHOCError::Success; } else { - error = ead_res.unwrap_err(); + error = EDHOCError::ParsingError; } - } else if plaintext_3.len == (2 + MAC_LENGTH_3) { - error = EDHOCError::Success; } else { error = EDHOCError::ParsingError; } @@ -1323,29 +1342,40 @@ fn decode_plaintext_2( ) -> Result<(U8, U8, BytesMac2, Option), EDHOCError> { let mut error = EDHOCError::UnknownError; let mut ead_2 = None::; - - let c_r = plaintext_2[0]; - let id_cred_r = plaintext_2[1]; - // skip cbor byte string byte as we know how long the string is + let mut c_r: U8 = 0xff; + let mut id_cred_r: U8 = 0xff; let mut mac_2: BytesMac2 = [0x00; MAC_LENGTH_2]; - mac_2[..].copy_from_slice(&plaintext_2[3..3 + MAC_LENGTH_2]); - - // if there is still more to parse, the rest will be the EAD_2 - if plaintext_2_len > (3 + MAC_LENGTH_2) { - // NOTE: since the current implementation only supports one EAD handler, - // we assume only one EAD item - let ead_res = parse_ead( - &plaintext_2[..plaintext_2_len].try_into().expect("too long"), - 3 + MAC_LENGTH_2, - ); - if ead_res.is_ok() { - ead_2 = ead_res.unwrap(); + + // check CBOR sequence types for c_r, id_cred_r, and mac_2 + if (is_cbor_neg_int_1byte(plaintext_2[0]) || is_cbor_uint_1byte(plaintext_2[0])) + && (is_cbor_neg_int_1byte(plaintext_2[1]) || is_cbor_uint_1byte(plaintext_2[1])) + && (is_cbor_bstr_1byte_prefix(plaintext_2[2])) + // TODO: check mac length as well + { + c_r = plaintext_2[0]; + id_cred_r = plaintext_2[1]; + // skip cbor byte string byte as we know how long the string is + mac_2[..].copy_from_slice(&plaintext_2[3..3 + MAC_LENGTH_2]); + + // if there is still more to parse, the rest will be the EAD_2 + if plaintext_2_len > (3 + MAC_LENGTH_2) { + // NOTE: since the current implementation only supports one EAD handler, + // we assume only one EAD item + let ead_res = parse_ead( + &plaintext_2[..plaintext_2_len].try_into().expect("too long"), + 3 + MAC_LENGTH_2, + ); + if ead_res.is_ok() { + ead_2 = ead_res.unwrap(); + error = EDHOCError::Success; + } else { + error = ead_res.unwrap_err(); + } + } else if plaintext_2_len == (3 + MAC_LENGTH_2) { error = EDHOCError::Success; } else { - error = ead_res.unwrap_err(); + error = EDHOCError::ParsingError; } - } else if plaintext_2_len == (3 + MAC_LENGTH_2) { - error = EDHOCError::Success; } else { error = EDHOCError::ParsingError; } @@ -1667,22 +1697,37 @@ mod tests { #[test] fn test_parse_message_1_invalid_traces() { let message_1_tv = BufferMessage1::from_hex(MESSAGE_1_INVALID_ARRAY_TV); - assert!(parse_message_1(&message_1_tv).is_err()); + assert_eq!( + parse_message_1(&message_1_tv).unwrap_err(), + EDHOCError::ParsingError + ); let message_1_tv = BufferMessage1::from_hex(MESSAGE_1_INVALID_C_I_TV); - assert!(parse_message_1(&message_1_tv).is_err()); + assert_eq!( + parse_message_1(&message_1_tv).unwrap_err(), + EDHOCError::ParsingError + ); let message_1_tv = BufferMessage1::from_hex(MESSAGE_1_INVALID_CIPHERSUITE_TV); - assert!(parse_message_1(&message_1_tv).is_err()); + assert_eq!( + parse_message_1(&message_1_tv).unwrap_err(), + EDHOCError::ParsingError + ); let message_1_tv = BufferMessage1::from_hex(MESSAGE_1_INVALID_TEXT_EPHEMERAL_KEY_TV); - assert!(parse_message_1(&message_1_tv).is_err()); + assert_eq!( + parse_message_1(&message_1_tv).unwrap_err(), + EDHOCError::ParsingError + ); } #[test] fn test_parse_message_2_invalid_traces() { let message_2_tv = BufferMessage1::from_hex(MESSAGE_2_INVALID_NUMBER_OF_CBOR_SEQUENCE_TV); - assert!(parse_message_2(&message_2_tv).is_err()); + assert_eq!( + parse_message_1(&message_2_tv).unwrap_err(), + EDHOCError::ParsingError + ); } #[test] @@ -1822,15 +1867,15 @@ mod tests { let mut plaintext_2_tv_buffer: BytesMaxBuffer = [0x00u8; MAX_BUFFER_LEN]; plaintext_2_tv_buffer[..plaintext_2_tv.len] .copy_from_slice(&plaintext_2_tv.content[..plaintext_2_tv.len]); - let plaintext_2 = decode_plaintext_2(&plaintext_2_tv_buffer, plaintext_2_tv.len); - assert!(plaintext_2.is_err()); + let ret = decode_plaintext_2(&plaintext_2_tv_buffer, plaintext_2_tv.len); + assert_eq!(ret.unwrap_err(), EDHOCError::ParsingError); let plaintext_2_tv = BufferPlaintext2::from_hex(PLAINTEXT_2_SURPLUS_BSTR_ID_CRED_TV); let mut plaintext_2_tv_buffer: BytesMaxBuffer = [0x00u8; MAX_BUFFER_LEN]; plaintext_2_tv_buffer[..plaintext_2_tv.len] .copy_from_slice(&plaintext_2_tv.content[..plaintext_2_tv.len]); - let plaintext_2 = decode_plaintext_2(&plaintext_2_tv_buffer, plaintext_2_tv.len); - assert!(plaintext_2.is_err()); + let ret = decode_plaintext_2(&plaintext_2_tv_buffer, plaintext_2_tv.len); + assert_eq!(ret.unwrap_err(), EDHOCError::ParsingError); } #[test]