From 6f53ba9092e7f008b447bdebea466124239284f7 Mon Sep 17 00:00:00 2001 From: Remco Bloemen Date: Mon, 14 Oct 2024 15:34:19 +0200 Subject: [PATCH] Conversions from der Any, Int, Uint and referenced --- src/support/der.rs | 92 +++++++++++++++++++++++++++++++++++++++------- 1 file changed, 79 insertions(+), 13 deletions(-) diff --git a/src/support/der.rs b/src/support/der.rs index 98c31de..224b86f 100644 --- a/src/support/der.rs +++ b/src/support/der.rs @@ -4,7 +4,7 @@ use crate::Uint; use der::{ - asn1::{AnyRef, IntRef, UintRef}, + asn1::{Any, AnyRef, Int, IntRef, Uint as DerUint, UintRef}, DecodeValue, EncodeValue, Error, FixedTag, Header, Length, Reader, Result, Tag, ValueOrd, Writer, }; @@ -43,15 +43,7 @@ impl<'a, const BITS: usize, const LIMBS: usize> DecodeValue<'a> for Uint Length::try_from(Self::BYTES + 1)? { return Err(Self::TAG.non_canonical_error()); } - let bytes = reader.read_vec(header.length)?; - let bytes = match bytes.as_slice() { - [] => Err(Tag::Integer.length_error()), - [0, byte, ..] if *byte < 0x80 => Err(Tag::Integer.non_canonical_error()), - [0, rest @ ..] => Ok(rest), - [byte, ..] if *byte >= 0x80 => Err(Tag::Integer.value_error()), - bytes => Ok(bytes), - }?; - Self::try_from_be_slice(bytes).ok_or_else(|| Tag::Integer.non_canonical_error()) + from_der_slice(reader.read_vec(header.length)?.as_slice()) } } @@ -66,19 +58,69 @@ impl TryFrom> for Uint TryFrom> for Uint { type Error = Error; - fn try_from(any: IntRef<'_>) -> Result { - any.decode_as() + fn try_from(int: IntRef<'_>) -> Result { + from_der_slice(int.as_bytes()) } } impl TryFrom> for Uint { type Error = Error; - fn try_from(any: UintRef<'_>) -> Result { + fn try_from(uint: UintRef<'_>) -> Result { + from_der_uint_slice(uint.as_bytes()) + } +} + +impl TryFrom for Uint { + type Error = Error; + + fn try_from(any: Any) -> Result { any.decode_as() } } +impl TryFrom for Uint { + type Error = Error; + + fn try_from(int: Int) -> Result { + from_der_slice(int.as_bytes()) + } +} + +impl TryFrom for Uint { + type Error = Error; + + fn try_from(uint: DerUint) -> Result { + from_der_uint_slice(uint.as_bytes()) + } +} + +fn from_der_slice( + bytes: &[u8], +) -> Result> { + // Handle sign bits and zero-prefix. + let bytes = match bytes { + [] => Err(Tag::Integer.length_error()), + [0, byte, ..] if *byte < 0x80 => Err(Tag::Integer.non_canonical_error()), + [0, rest @ ..] => Ok(rest), + [byte, ..] if *byte >= 0x80 => Err(Tag::Integer.value_error()), + bytes => Ok(bytes), + }?; + Uint::try_from_be_slice(bytes).ok_or_else(|| Tag::Integer.non_canonical_error()) +} + +fn from_der_uint_slice( + bytes: &[u8], +) -> Result> { + // UintRef and Uint have the leading 0x00 removed. + match bytes { + [] => Err(Tag::Integer.length_error()), + [0] => Ok(Uint::ZERO), + [0, ..] => Err(Tag::Integer.non_canonical_error()), + bytes => Uint::try_from_be_slice(bytes).ok_or_else(|| Tag::Integer.non_canonical_error()), + } +} + #[cfg(test)] mod tests { use super::*; @@ -107,4 +149,28 @@ mod tests { assert_eq!(serialized1, serialized2); }); } + + macro_rules! test_roundtrip { + ($name:ident, $ty:ty) => { + #[test] + fn $name() { + const_for!(BITS in SIZES { + const LIMBS: usize = nlimbs(BITS); + proptest!(|(value: Uint)| { + let serialized = value.to_der().unwrap(); + let der = <$ty>::from_der(&serialized).unwrap(); + let deserialized = der.try_into().unwrap(); + assert_eq!(value, deserialized); + }); + }); + } + }; + } + + test_roundtrip!(test_der_anyref_roundtrip, AnyRef); + test_roundtrip!(test_der_intref_roundtrip, IntRef); + test_roundtrip!(test_der_uintref_roundtrip, UintRef); + test_roundtrip!(test_der_any_roundtrip, Any); + test_roundtrip!(test_der_int_roundtrip, Int); + test_roundtrip!(test_der_uint_roundtrip, DerUint); }