From c5bc78a0900ca74397af88f0c015f0271cc8d934 Mon Sep 17 00:00:00 2001 From: DaniPopes <57450786+DaniPopes@users.noreply.github.com> Date: Wed, 8 May 2024 00:07:15 +0200 Subject: [PATCH] feat: improve formatting implementations --- src/base_convert.rs | 4 +- src/cmp.rs | 14 +-- src/fmt.rs | 216 +++++++++++++++++++++++++++++++++++++++ src/lib.rs | 6 +- src/string.rs | 166 ++++-------------------------- src/support/ark_ff_04.rs | 2 +- src/utils.rs | 1 + 7 files changed, 251 insertions(+), 158 deletions(-) create mode 100644 src/fmt.rs diff --git a/src/base_convert.rs b/src/base_convert.rs index daa4f900..d3c386f2 100644 --- a/src/base_convert.rs +++ b/src/base_convert.rs @@ -26,9 +26,9 @@ impl fmt::Display for BaseConvertError { #[inline] fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Self::Overflow => f.write_str("The value is too large to fit the target type"), + Self::Overflow => f.write_str("the value is too large to fit the target type"), Self::InvalidBase(base) => { - write!(f, "The requested number base {base} is less than two") + write!(f, "the requested number base {base} is less than two") } Self::InvalidDigit(digit, base) => { write!(f, "digit {digit} is out of range for base {base}") diff --git a/src/cmp.rs b/src/cmp.rs index 80832f6a..130ab43f 100644 --- a/src/cmp.rs +++ b/src/cmp.rs @@ -1,17 +1,17 @@ use crate::Uint; use core::cmp::Ordering; -impl Ord for Uint { +impl PartialOrd for Uint { #[inline] - fn cmp(&self, rhs: &Self) -> Ordering { - crate::algorithms::cmp(self.as_limbs(), rhs.as_limbs()) + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) } } -impl PartialOrd for Uint { +impl Ord for Uint { #[inline] - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) + fn cmp(&self, rhs: &Self) -> Ordering { + crate::algorithms::cmp(self.as_limbs(), rhs.as_limbs()) } } @@ -20,7 +20,7 @@ impl Uint { #[inline] #[must_use] pub fn is_zero(&self) -> bool { - self == &Self::ZERO + *self == Self::ZERO } } diff --git a/src/fmt.rs b/src/fmt.rs new file mode 100644 index 00000000..61a4dd2a --- /dev/null +++ b/src/fmt.rs @@ -0,0 +1,216 @@ +#![allow(clippy::missing_inline_in_public_items)] // allow format functions +#![cfg(feature = "alloc")] + +use crate::Uint; +use core::{ + fmt::{self, Write}, + mem::MaybeUninit, +}; + +mod base { + pub(super) trait Base { + /// Highest power of the base that fits in a `u64`. + const MAX: u64; + /// Number of characters written using `MAX` as the base in + /// `to_base_be`. + /// + /// This is `MAX.log(base)`. + const WIDTH: usize; + /// The prefix for the base. + const PREFIX: &'static str; + } + + pub(super) struct Binary; + impl Base for Binary { + const MAX: u64 = 1 << 63; + const WIDTH: usize = 63; + const PREFIX: &'static str = "0b"; + } + + pub(super) struct Octal; + impl Base for Octal { + const MAX: u64 = 1 << 63; + const WIDTH: usize = 21; + const PREFIX: &'static str = "0o"; + } + + pub(super) struct Decimal; + impl Base for Decimal { + const MAX: u64 = 10_000_000_000_000_000_000; + const WIDTH: usize = 19; + const PREFIX: &'static str = ""; + } + + pub(super) struct Hexadecimal; + impl Base for Hexadecimal { + const MAX: u64 = 1 << 60; + const WIDTH: usize = 15; + const PREFIX: &'static str = "0x"; + } +} +use base::Base; + +macro_rules! write_digits { + ($self:expr, $f:expr; $base:ty, $base_char:literal) => { + if LIMBS == 0 || $self.is_zero() { + return $f.pad_integral(true, <$base>::PREFIX, "0"); + } + // Use `BITS` for all bases since `generic_const_exprs` is not yet stable. + let mut buffer = DisplayBuffer::::new(); + for (i, spigot) in $self.to_base_be(<$base>::MAX).enumerate() { + write!( + buffer, + concat!("{:0width$", $base_char, "}"), + spigot, + width = if i == 0 { 0 } else { <$base>::WIDTH }, + ) + .unwrap(); + } + return $f.pad_integral(true, <$base>::PREFIX, buffer.as_str()); + }; +} + +impl fmt::Display for Uint { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write_digits!(self, f; base::Decimal, ""); + } +} + +impl fmt::Debug for Uint { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(self, f) + } +} + +impl fmt::Binary for Uint { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write_digits!(self, f; base::Binary, "b"); + } +} + +impl fmt::Octal for Uint { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write_digits!(self, f; base::Octal, "o"); + } +} + +impl fmt::LowerHex for Uint { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write_digits!(self, f; base::Hexadecimal, "x"); + } +} + +impl fmt::UpperHex for Uint { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write_digits!(self, f; base::Hexadecimal, "X"); + } +} + +struct DisplayBuffer { + buf: [MaybeUninit; SIZE], + len: usize, +} + +impl DisplayBuffer { + #[inline] + const fn new() -> Self { + Self { + buf: unsafe { MaybeUninit::uninit().assume_init() }, + len: 0, + } + } + + #[inline] + fn as_str(&self) -> &str { + // SAFETY: `buf` is only written to by the `fmt::Write::write_str` + // implementation which writes a valid UTF-8 string to `buf` and + // correctly sets `len`. + unsafe { core::str::from_utf8_unchecked(&self.as_bytes_full()[..self.len]) } + } + + #[inline] + const fn as_bytes_full(&self) -> &[u8] { + unsafe { &*(self.buf.as_slice() as *const [_] as *const [u8]) } + } +} + +impl fmt::Write for DisplayBuffer { + fn write_str(&mut self, s: &str) -> fmt::Result { + if self.len + s.len() > SIZE { + return Err(fmt::Error); + } + unsafe { + let dst = self.buf.as_mut_ptr().add(self.len).cast(); + core::ptr::copy_nonoverlapping(s.as_ptr(), dst, s.len()); + } + self.len += s.len(); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use proptest::{prop_assert_eq, proptest}; + + #[allow(unused_imports)] + use alloc::string::ToString; + + #[allow(clippy::unreadable_literal)] + const N: Uint<256, 4> = Uint::from_limbs([ + 0xa8ec92344438aaf4_u64, + 0x9819ebdbd1faaab1_u64, + 0x573b1a7064c19c1a_u64, + 0xc85ef7d79691fe79_u64, + ]); + + #[test] + fn test_num() { + assert_eq!( + N.to_string(), + "90630363884335538722706632492458228784305343302099024356772372330524102404852" + ); + assert_eq!( + format!("{N:x}"), + "c85ef7d79691fe79573b1a7064c19c1a9819ebdbd1faaab1a8ec92344438aaf4" + ); + assert_eq!( + format!("{N:b}"), + "1100100001011110111101111101011110010110100100011111111001111001010101110011101100011010011100000110010011000001100111000001101010011000000110011110101111011011110100011111101010101010101100011010100011101100100100100011010001000100001110001010101011110100" + ); + assert_eq!( + format!("{N:o}"), + "14413675753626443771712563543234062301470152300636573364375252543243544443210416125364" + ); + } + + #[test] + fn test_fmt() { + proptest!(|(value: u128)| { + let n: Uint<128, 2> = Uint::from(value); + + prop_assert_eq!(format!("{n:b}"), format!("{value:b}")); + prop_assert_eq!(format!("{n:064b}"), format!("{value:064b}")); + prop_assert_eq!(format!("{n:#b}"), format!("{value:#b}")); + + prop_assert_eq!(format!("{n:o}"), format!("{value:o}")); + prop_assert_eq!(format!("{n:064o}"), format!("{value:064o}")); + prop_assert_eq!(format!("{n:#o}"), format!("{value:#o}")); + + prop_assert_eq!(format!("{n:}"), format!("{value:}")); + prop_assert_eq!(format!("{n:064}"), format!("{value:064}")); + prop_assert_eq!(format!("{n:#}"), format!("{value:#}")); + prop_assert_eq!(format!("{n:?}"), format!("{value:?}")); + prop_assert_eq!(format!("{n:064}"), format!("{value:064?}")); + prop_assert_eq!(format!("{n:#?}"), format!("{value:#?}")); + + prop_assert_eq!(format!("{n:x}"), format!("{value:x}")); + prop_assert_eq!(format!("{n:064x}"), format!("{value:064x}")); + prop_assert_eq!(format!("{n:#x}"), format!("{value:#x}")); + + prop_assert_eq!(format!("{n:X}"), format!("{value:X}")); + prop_assert_eq!(format!("{n:064X}"), format!("{value:064X}")); + prop_assert_eq!(format!("{n:#X}"), format!("{value:#X}")); + }); + } +} diff --git a/src/lib.rs b/src/lib.rs index f6ddf7a8..55e4c02f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,10 +19,7 @@ clippy::cast_sign_loss, clippy::cast_lossless, )] -#![cfg_attr( - any(test, feature = "bench"), - allow(clippy::wildcard_imports, clippy::cognitive_complexity) -)] +#![cfg_attr(test, allow(clippy::wildcard_imports, clippy::cognitive_complexity))] #![cfg_attr(not(feature = "std"), no_std)] // Unstable features #![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] @@ -51,6 +48,7 @@ mod bytes; mod cmp; mod const_for; mod div; +mod fmt; mod from; mod gcd; mod log; diff --git a/src/string.rs b/src/string.rs index 61392517..4bc4bcf6 100644 --- a/src/string.rs +++ b/src/string.rs @@ -1,100 +1,8 @@ #![allow(clippy::missing_inline_in_public_items)] // allow format functions -use crate::{base_convert::BaseConvertError, utils::rem_up, Uint}; +use crate::{base_convert::BaseConvertError, Uint}; use core::{fmt, str::FromStr}; -// FEATURE: Respect width parameter in formatters. - -// TODO: Do we want to write `0` for `BITS == 0`. - -#[cfg(feature = "alloc")] -impl fmt::Display for Uint { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - // Base convert 19 digits at a time - const BASE: u64 = 10_000_000_000_000_000_000_u64; - let mut spigot = self.to_base_be(BASE); - write!(f, "{}", spigot.next().unwrap_or(0))?; - for digits in spigot { - write!(f, "{digits:019}")?; - } - Ok(()) - } -} - -impl fmt::Debug for Uint { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{self:#x}_U{BITS}") - } -} - -impl fmt::LowerHex for Uint { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.fmt_hex::(f) - } -} - -impl fmt::UpperHex for Uint { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.fmt_hex::(f) - } -} - -impl fmt::Binary for Uint { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if f.alternate() { - write!(f, "0b")?; - } - if LIMBS == 0 || *self == Self::ZERO { - return f.write_str("0"); - } - - for (i, &limb) in self.limbs.iter().rev().enumerate() { - let width = if i == 0 { rem_up(Self::BITS, 64) } else { 64 }; - write!(f, "{limb:0width$b}")?; - } - Ok(()) - } -} - -#[cfg(feature = "alloc")] -impl fmt::Octal for Uint { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - // Base convert 21 digits at a time - const BASE: u64 = 0x8000_0000_0000_0000_u64; - let mut spigot = self.to_base_be(BASE); - write!(f, "{:o}", spigot.next().unwrap_or(0))?; - for digits in spigot { - write!(f, "{digits:021o}")?; - } - Ok(()) - } -} - -impl Uint { - fn fmt_hex(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if f.alternate() { - write!(f, "0x")?; - } - if LIMBS == 0 || *self == Self::ZERO { - return f.write_str("0"); - } - - for (i, &limb) in self.limbs.iter().rev().enumerate() { - let width = if i == 0 { - 2 * rem_up(Self::BITS, 8) - } else { - 16 - }; - if UPPER { - write!(f, "{limb:0width$X}")?; - } else { - write!(f, "{limb:0width$x}")?; - } - } - Ok(()) - } -} - /// Error for [`from_str_radix`](Uint::from_str_radix). #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum ParseError { @@ -129,9 +37,9 @@ impl From for ParseError { impl fmt::Display for ParseError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Self::BaseConvertError(e) => fmt::Display::fmt(e, f), - Self::InvalidDigit(c) => write!(f, "Invalid digit: {c}"), - Self::InvalidRadix(r) => write!(f, "Invalid radix {r}, up to 64 is supported"), + Self::BaseConvertError(e) => e.fmt(f), + Self::InvalidDigit(c) => write!(f, "invalid digit: {c}"), + Self::InvalidRadix(r) => write!(f, "invalid radix {r}, up to 64 is supported"), } } } @@ -198,65 +106,35 @@ impl FromStr for Uint { type Err = ParseError; fn from_str(src: &str) -> Result { - if src.is_char_boundary(2) { + let (src, radix) = if src.is_char_boundary(2) { let (prefix, rest) = src.split_at(2); match prefix { - "0x" | "0X" => return Self::from_str_radix(rest, 16), - "0o" | "0O" => return Self::from_str_radix(rest, 8), - "0b" | "0B" => return Self::from_str_radix(rest, 2), - _ => {} + "0x" | "0X" => (rest, 16), + "0o" | "0O" => (rest, 8), + "0b" | "0B" => (rest, 2), + _ => (src, 10), } - } - Self::from_str_radix(src, 10) + } else { + (src, 10) + }; + Self::from_str_radix(src, radix) } } #[cfg(test)] mod tests { use super::*; - use proptest::proptest; - - #[allow(unused_imports)] - use alloc::string::ToString; - - #[allow(clippy::unreadable_literal)] - const N: Uint<256, 4> = Uint::from_limbs([ - 0xa8ec92344438aaf4_u64, - 0x9819ebdbd1faaab1_u64, - 0x573b1a7064c19c1a_u64, - 0xc85ef7d79691fe79_u64, - ]); - - #[test] - fn test_num() { - assert_eq!( - N.to_string(), - "90630363884335538722706632492458228784305343302099024356772372330524102404852" - ); - assert_eq!( - format!("{N:x}"), - "c85ef7d79691fe79573b1a7064c19c1a9819ebdbd1faaab1a8ec92344438aaf4" - ); - assert_eq!( - format!("{N:b}"), - "1100100001011110111101111101011110010110100100011111111001111001010101110011101100011010011100000110010011000001100111000001101010011000000110011110101111011011110100011111101010101010101100011010100011101100100100100011010001000100001110001010101011110100" - ); - assert_eq!( - format!("{N:o}"), - "14413675753626443771712563543234062301470152300636573364375252543243544443210416125364" - ); - } + use proptest::{prop_assert_eq, proptest}; #[test] - fn test_hex() { - proptest!(|(value: u64)| { - let n: Uint<64, 1> = Uint::from(value); - assert_eq!(format!("{n:x}"), format!("{value:016x}")); - assert_eq!(format!("{n:#x}"), format!("{value:#018x}")); - assert_eq!(format!("{n:X}"), format!("{value:016X}")); - assert_eq!(format!("{n:#X}"), format!("{value:#018X}")); - assert_eq!(format!("{n:b}"), format!("{value:064b}")); - assert_eq!(format!("{n:#b}"), format!("{value:#066b}")); + fn test_parse() { + proptest!(|(value: u128)| { + type U = Uint<128, 2>; + prop_assert_eq!(U::from_str(&format!("{value:#b}")), Ok(U::from(value))); + prop_assert_eq!(U::from_str(&format!("{value:#o}")), Ok(U::from(value))); + prop_assert_eq!(U::from_str(&format!("{value:}")), Ok(U::from(value))); + prop_assert_eq!(U::from_str(&format!("{value:#x}")), Ok(U::from(value))); + prop_assert_eq!(U::from_str(&format!("{value:#X}")), Ok(U::from(value))); }); } } diff --git a/src/support/ark_ff_04.rs b/src/support/ark_ff_04.rs index 694ce65c..95ab4133 100644 --- a/src/support/ark_ff_04.rs +++ b/src/support/ark_ff_04.rs @@ -1,6 +1,6 @@ //! Support for the [`ark-ff`](https://crates.io/crates/ark-ff) crate. #![cfg(feature = "ark-ff-04")] -#![cfg_attr(has_doc_cfg, doc(cfg(feature = "ark-ff-04")))] +#![cfg_attr(docsrs, doc(cfg(feature = "ark-ff-04")))] use crate::{ToFieldError, Uint}; use ark_ff_04::{ diff --git a/src/utils.rs b/src/utils.rs index 75b9e773..b06d1d0a 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -3,6 +3,7 @@ use alloc::vec::Vec; /// Like `a % b` but returns `b` instead of `0`. +#[allow(dead_code)] // This is used by some support features. #[must_use] pub(crate) const fn rem_up(a: usize, b: usize) -> usize { let rem = a % b;