diff --git a/Cargo.toml b/Cargo.toml index c8766b9f..b6bed774 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,14 +47,11 @@ criterion-macro = { version = "0.4.0", default-features = false } proptest = { version = "1.5.0", default-features = false, features = ["std"] } test-strategy = { version = "0.4.0", default-features = false } -[profile.release.package.ruzstd] -opt-level = 's' - [profile.release] codegen-units = 1 -lto = true panic = "abort" -strip = "symbols" +lto = true +strip = true [profile.dev] opt-level = 3 diff --git a/bin/main.rs b/bin/main.rs index 0ed790d1..89f1c853 100644 --- a/bin/main.rs +++ b/bin/main.rs @@ -1,36 +1,21 @@ -use futures::executor::{block_on, block_on_stream}; -use futures::{channel::mpsc, prelude::*}; +use futures::executor::block_on; +use futures::{channel::mpsc::unbounded, sink::unfold}; use lib::uci::Uci; -use std::io::{prelude::*, stdin, stdout, LineWriter}; -use std::thread; +use std::io::{prelude::*, stdin, stdout}; +use std::{future::ready, thread}; fn main() { - let (mut tx, input) = mpsc::channel(32); - let (output, rx) = mpsc::channel(32); + let (tx, rx) = unbounded(); thread::spawn(move || { - for item in stdin().lock().lines() { - match item { - Err(error) => return eprint!("{error}"), - Ok(line) => { - if let Err(error) = block_on(tx.send(line)) { - if error.is_disconnected() { - break; - } - } - } + for line in stdin().lock().lines() { + if tx.unbounded_send(line.unwrap()).is_err() { + break; } } }); - thread::spawn(move || { - let mut stdout = LineWriter::new(stdout().lock()); - for line in block_on_stream(rx) { - if let Err(error) = writeln!(stdout, "{line}") { - return eprint!("{error}"); - } - } - }); - - block_on(Uci::new(input, output).run()).ok(); + let mut stdout = stdout().lock(); + let output = unfold((), |_, line: String| ready(writeln!(stdout, "{line}"))); + block_on(Uci::new(rx, output).run()).unwrap(); } diff --git a/lib/chess/bitboard.rs b/lib/chess/bitboard.rs index 2b38af61..7379ca50 100644 --- a/lib/chess/bitboard.rs +++ b/lib/chess/bitboard.rs @@ -1,7 +1,7 @@ use crate::chess::{File, Perspective, Rank, Square}; use crate::util::{Assume, Integer}; use derive_more::{Debug, *}; -use std::fmt::{self, Write}; +use std::fmt::{self, Formatter, Write}; use std::{cell::SyncUnsafeCell, mem::MaybeUninit}; /// A set of squares on a chess board. @@ -26,9 +26,9 @@ use std::{cell::SyncUnsafeCell, mem::MaybeUninit}; #[repr(transparent)] pub struct Bitboard(u64); -impl fmt::Debug for Bitboard { +impl Debug for Bitboard { #[coverage(off)] - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.write_char('\n')?; for rank in Rank::iter().rev() { for file in File::iter() { diff --git a/lib/chess/board.rs b/lib/chess/board.rs index ab7d1324..a84a7f76 100644 --- a/lib/chess/board.rs +++ b/lib/chess/board.rs @@ -1,8 +1,9 @@ -use crate::{chess::*, util::Integer}; -use arrayvec::ArrayString; +use crate::chess::*; +use crate::util::{Assume, Integer}; use derive_more::{Debug, Display, Error}; -use std::fmt::{self, Write}; -use std::{ops::Index, str::FromStr}; +use std::fmt::{self, Formatter, Write}; +use std::str::{self, FromStr}; +use std::{io::Write as _, ops::Index}; /// The chess board. #[derive(Debug, Clone, Eq, PartialEq, Hash)] @@ -163,26 +164,31 @@ impl Index for Board { } impl Display for Board { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let mut skip = 0; for sq in Square::iter().map(|sq| sq.flip()) { - let mut buffer = ArrayString::<2>::new(); + let mut buffer = [b'\0'; 2]; - match self[sq] { - None => skip += 1, - Some(p) => write!(buffer, "{}", p)?, + if sq.file() == File::H { + buffer[0] = if sq.rank() == Rank::First { b' ' } else { b'/' }; } - if sq.file() == File::H { - buffer.push(if sq.rank() == Rank::First { ' ' } else { '/' }); + match self[sq] { + None => skip += 1, + Some(p) => { + buffer[1] = buffer[0]; + write!(&mut buffer[..1], "{p}").assume() + } } - if !buffer.is_empty() && skip > 0 { - write!(f, "{}", skip)?; + if skip > 0 && buffer != [b'\0'; 2] { + write!(f, "{skip}")?; skip = 0; } - f.write_str(&buffer)?; + for b in buffer.into_iter().take_while(|&b| b != b'\0') { + f.write_char(b.into())?; + } } match self.turn { @@ -197,7 +203,7 @@ impl Display for Board { } if let Some(ep) = self.en_passant { - write!(f, "{} ", ep)?; + write!(f, "{ep} ")?; } else { f.write_str("- ")?; } @@ -230,6 +236,7 @@ pub enum ParseFenError { impl FromStr for Board { type Err = ParseFenError; + #[inline(always)] fn from_str(s: &str) -> Result { let fields: Vec<_> = s.split(' ').collect(); let [board, turn, castles, en_passant, halfmoves, fullmoves] = &fields[..] else { diff --git a/lib/chess/castles.rs b/lib/chess/castles.rs index 59de5846..d2262d67 100644 --- a/lib/chess/castles.rs +++ b/lib/chess/castles.rs @@ -1,7 +1,8 @@ use crate::chess::{Color, Perspective, Piece, Role, Square}; use crate::util::{Bits, Integer}; use derive_more::{Debug, *}; -use std::{cell::SyncUnsafeCell, fmt, mem::MaybeUninit, str::FromStr}; +use std::fmt::{self, Formatter}; +use std::{cell::SyncUnsafeCell, mem::MaybeUninit, str::FromStr}; /// The castling rights in a chess [`Position`][`crate::chess::Position`]. #[derive( @@ -107,15 +108,15 @@ impl From for Castles { } } -impl fmt::Display for Castles { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +impl Display for Castles { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { for side in Color::iter() { if self.has_short(side) { - fmt::Display::fmt(&Piece::new(Role::King, side), f)?; + Display::fmt(&Piece::new(Role::King, side), f)?; } if self.has_long(side) { - fmt::Display::fmt(&Piece::new(Role::Queen, side), f)?; + Display::fmt(&Piece::new(Role::Queen, side), f)?; } } @@ -131,6 +132,7 @@ pub struct ParseCastlesError; impl FromStr for Castles { type Err = ParseCastlesError; + #[inline(always)] fn from_str(s: &str) -> Result { let mut castles = Castles::none(); diff --git a/lib/chess/file.rs b/lib/chess/file.rs index 34546ef5..3aa5bfb0 100644 --- a/lib/chess/file.rs +++ b/lib/chess/file.rs @@ -1,28 +1,21 @@ use crate::chess::{Bitboard, Mirror}; use crate::util::Integer; use derive_more::{Display, Error}; +use std::fmt::{self, Formatter, Write}; use std::{ops::Sub, str::FromStr}; /// A column on the chess board. -#[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] +#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] #[cfg_attr(test, derive(test_strategy::Arbitrary))] #[repr(i8)] pub enum File { - #[display("a")] A, - #[display("b")] B, - #[display("c")] C, - #[display("d")] D, - #[display("e")] E, - #[display("f")] F, - #[display("g")] G, - #[display("h")] H, } @@ -57,30 +50,29 @@ impl Sub for File { } } +impl Display for File { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.write_char((b'a' + self.cast::()).into()) + } +} + /// The reason why parsing [`File`] failed. #[derive(Debug, Display, Clone, Eq, PartialEq, Error)] -#[display( - "failed to parse file, expected letter in the range `({}..={})`", - File::A, - File::H -)] +#[display("failed to parse file")] pub struct ParseFileError; impl FromStr for File { type Err = ParseFileError; + #[inline(always)] fn from_str(s: &str) -> Result { - match s { - "a" => Ok(File::A), - "b" => Ok(File::B), - "c" => Ok(File::C), - "d" => Ok(File::D), - "e" => Ok(File::E), - "f" => Ok(File::F), - "g" => Ok(File::G), - "h" => Ok(File::H), - _ => Err(ParseFileError), - } + let [c] = s.as_bytes() else { + return Err(ParseFileError); + }; + + c.checked_sub(b'a') + .and_then(Integer::convert) + .ok_or(ParseFileError) } } diff --git a/lib/chess/magic.rs b/lib/chess/magic.rs index f7410a9b..5f2db34d 100644 --- a/lib/chess/magic.rs +++ b/lib/chess/magic.rs @@ -4,14 +4,17 @@ use crate::chess::{Bitboard, Square}; pub struct Magic(Bitboard, u64, usize); impl Magic { + #[inline(always)] pub fn mask(&self) -> Bitboard { self.0 } + #[inline(always)] pub fn factor(&self) -> u64 { self.1 } + #[inline(always)] pub fn offset(&self) -> usize { self.2 } diff --git a/lib/chess/move.rs b/lib/chess/move.rs index 115823e5..7871ef5e 100644 --- a/lib/chess/move.rs +++ b/lib/chess/move.rs @@ -1,6 +1,6 @@ use crate::chess::{Bitboard, Perspective, Piece, Rank, Role, Square, Squares}; use crate::util::{Assume, Binary, Bits, Integer}; -use std::fmt::{self, Write}; +use std::fmt::{self, Debug, Display, Formatter, Write}; use std::{num::NonZeroU16, ops::RangeBounds}; /// A chess move. @@ -135,10 +135,10 @@ impl Move { } } -impl fmt::Debug for Move { +impl Debug for Move { #[coverage(off)] - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Display::fmt(&self, f)?; + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + Display::fmt(&self, f)?; if self.is_en_passant() { f.write_char('^')?; @@ -152,13 +152,13 @@ impl fmt::Debug for Move { } } -impl fmt::Display for Move { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Display::fmt(&self.whence(), f)?; - fmt::Display::fmt(&self.whither(), f)?; +impl Display for Move { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + Display::fmt(&self.whence(), f)?; + Display::fmt(&self.whither(), f)?; if let Some(r) = self.promotion() { - fmt::Display::fmt(&r, f)?; + Display::fmt(&r, f)?; } Ok(()) diff --git a/lib/chess/outcome.rs b/lib/chess/outcome.rs index e7a50f90..5c4b4d62 100644 --- a/lib/chess/outcome.rs +++ b/lib/chess/outcome.rs @@ -25,16 +25,19 @@ impl Outcome { /// Whether the outcome is a [draw] and neither side has won. /// /// [draw]: https://www.chessprogramming.org/Draw + #[inline(always)] pub fn is_draw(&self) -> bool { !self.is_decisive() } /// Whether the outcome is a decisive and one of the sides has won. + #[inline(always)] pub fn is_decisive(&self) -> bool { matches!(self, Outcome::Checkmate(_)) } /// The winning side, if the outcome is [decisive](`Self::is_decisive`). + #[inline(always)] pub fn winner(&self) -> Option { match *self { Outcome::Checkmate(c) => Some(c), diff --git a/lib/chess/piece.rs b/lib/chess/piece.rs index f6307de8..eac02bcf 100644 --- a/lib/chess/piece.rs +++ b/lib/chess/piece.rs @@ -1,36 +1,25 @@ use crate::chess::{Bitboard, Color, Magic, Perspective, Role, Square}; use crate::util::{Assume, Integer}; use derive_more::{Display, Error}; +use std::fmt::{self, Formatter, Write}; use std::{cell::SyncUnsafeCell, mem::MaybeUninit, str::FromStr}; /// A chess [piece][`Role`] of a certain [`Color`]. -#[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] +#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] #[cfg_attr(test, derive(test_strategy::Arbitrary))] #[repr(u8)] pub enum Piece { - #[display("P")] WhitePawn, - #[display("p")] BlackPawn, - #[display("N")] WhiteKnight, - #[display("n")] BlackKnight, - #[display("B")] WhiteBishop, - #[display("b")] BlackBishop, - #[display("R")] WhiteRook, - #[display("r")] BlackRook, - #[display("Q")] WhiteQueen, - #[display("q")] BlackQueen, - #[display("K")] WhiteKing, - #[display("k")] BlackKing, } @@ -201,28 +190,34 @@ impl Perspective for Piece { } } +impl Display for Piece { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Piece::WhitePawn => f.write_char('P'), + Piece::BlackPawn => f.write_char('p'), + Piece::WhiteKnight => f.write_char('N'), + Piece::BlackKnight => f.write_char('n'), + Piece::WhiteBishop => f.write_char('B'), + Piece::BlackBishop => f.write_char('b'), + Piece::WhiteRook => f.write_char('R'), + Piece::BlackRook => f.write_char('r'), + Piece::WhiteQueen => f.write_char('Q'), + Piece::BlackQueen => f.write_char('q'), + Piece::WhiteKing => f.write_char('K'), + Piece::BlackKing => f.write_char('k'), + } + } +} + /// The reason why parsing the piece. #[derive(Debug, Display, Clone, Eq, PartialEq, Error)] -#[display( - "failed to parse piece, expected one of `[{}{}{}{}{}{}{}{}{}{}{}{}]`", - Piece::WhitePawn, - Piece::BlackPawn, - Piece::WhiteKnight, - Piece::BlackKnight, - Piece::WhiteBishop, - Piece::BlackBishop, - Piece::WhiteRook, - Piece::BlackRook, - Piece::WhiteQueen, - Piece::BlackQueen, - Piece::WhiteKing, - Piece::BlackKing -)] +#[display("failed to parse piece")] pub struct ParsePieceError; impl FromStr for Piece { type Err = ParsePieceError; + #[inline(always)] fn from_str(s: &str) -> Result { match s { "P" => Ok(Piece::WhitePawn), diff --git a/lib/chess/position.rs b/lib/chess/position.rs index 4832ddb6..8d0d44dc 100644 --- a/lib/chess/position.rs +++ b/lib/chess/position.rs @@ -2,8 +2,9 @@ use crate::chess::*; use crate::util::{Assume, Integer}; use arrayvec::{ArrayVec, CapacityError}; use derive_more::{Debug, Display, Error, From}; +use std::fmt::{self, Formatter}; use std::hash::{Hash, Hasher}; -use std::{fmt, num::NonZeroU32, str::FromStr}; +use std::{num::NonZeroU32, str::FromStr}; #[cfg(test)] use proptest::{prelude::*, sample::*}; @@ -617,9 +618,8 @@ impl Position { } impl Display for Position { - #[inline(always)] - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Display::fmt(&self.board, f) + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + Display::fmt(&self.board, f) } } @@ -635,6 +635,7 @@ pub enum ParsePositionError { impl FromStr for Position { type Err = ParsePositionError; + #[inline(always)] fn from_str(s: &str) -> Result { use {ParsePositionError::*, Role::*}; diff --git a/lib/chess/rank.rs b/lib/chess/rank.rs index 3f101e43..986baf18 100644 --- a/lib/chess/rank.rs +++ b/lib/chess/rank.rs @@ -1,28 +1,21 @@ use crate::chess::{Bitboard, Perspective}; use crate::util::Integer; use derive_more::{Display, Error}; +use std::fmt::{self, Formatter, Write}; use std::{ops::Sub, str::FromStr}; /// A row on the chess board. -#[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] +#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] #[cfg_attr(test, derive(test_strategy::Arbitrary))] #[repr(i8)] pub enum Rank { - #[display("1")] First, - #[display("2")] Second, - #[display("3")] Third, - #[display("4")] Fourth, - #[display("5")] Fifth, - #[display("6")] Sixth, - #[display("7")] Seventh, - #[display("8")] Eighth, } @@ -57,30 +50,29 @@ impl Sub for Rank { } } +impl Display for Rank { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.write_char((b'1' + self.cast::()).into()) + } +} + /// The reason why parsing [`Rank`] failed. #[derive(Debug, Display, Clone, Eq, PartialEq, Error)] -#[display( - "failed to parse rank, expected digit in the range `({}..={})`", - Rank::First, - Rank::Eighth -)] +#[display("failed to parse rank")] pub struct ParseRankError; impl FromStr for Rank { type Err = ParseRankError; + #[inline(always)] fn from_str(s: &str) -> Result { - match s { - "1" => Ok(Rank::First), - "2" => Ok(Rank::Second), - "3" => Ok(Rank::Third), - "4" => Ok(Rank::Fourth), - "5" => Ok(Rank::Fifth), - "6" => Ok(Rank::Sixth), - "7" => Ok(Rank::Seventh), - "8" => Ok(Rank::Eighth), - _ => Err(ParseRankError), - } + let [c] = s.as_bytes() else { + return Err(ParseRankError); + }; + + c.checked_sub(b'1') + .and_then(Integer::convert) + .ok_or(ParseRankError) } } diff --git a/lib/chess/role.rs b/lib/chess/role.rs index 23cc429e..f1e79555 100644 --- a/lib/chess/role.rs +++ b/lib/chess/role.rs @@ -1,23 +1,18 @@ use crate::util::Integer; use derive_more::{Display, Error}; +use std::fmt::{self, Formatter, Write}; use std::str::FromStr; /// The type of a chess [`Piece`][`crate::Piece`]. -#[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] +#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] #[cfg_attr(test, derive(test_strategy::Arbitrary))] #[repr(u8)] pub enum Role { - #[display("p")] Pawn, - #[display("n")] Knight, - #[display("b")] Bishop, - #[display("r")] Rook, - #[display("q")] Queen, - #[display("k")] King, } @@ -27,22 +22,28 @@ unsafe impl Integer for Role { const MAX: Self::Repr = Role::King as _; } +impl Display for Role { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Role::Pawn => f.write_char('p'), + Role::Knight => f.write_char('n'), + Role::Bishop => f.write_char('b'), + Role::Rook => f.write_char('r'), + Role::Queen => f.write_char('q'), + Role::King => f.write_char('k'), + } + } +} + /// The reason why parsing the piece. #[derive(Debug, Display, Clone, Eq, PartialEq, Error)] -#[display( - "failed to parse piece, expected one of `[{}{}{}{}{}{}]`", - Role::Pawn, - Role::Knight, - Role::Bishop, - Role::Rook, - Role::Queen, - Role::King -)] +#[display("failed to parse piece")] pub struct ParseRoleError; impl FromStr for Role { type Err = ParseRoleError; + #[inline(always)] fn from_str(s: &str) -> Result { match s { "p" => Ok(Role::Pawn), diff --git a/lib/chess/square.rs b/lib/chess/square.rs index 8e6e9284..ac599ec2 100644 --- a/lib/chess/square.rs +++ b/lib/chess/square.rs @@ -1,8 +1,9 @@ use crate::chess::{Bitboard, File, Mirror, ParseFileError, ParseRankError, Perspective, Rank}; use crate::util::{Assume, Binary, Bits, Integer}; use derive_more::{Display, Error, From}; +use std::fmt::{self, Formatter}; use std::ops::{Add, AddAssign, Sub, SubAssign}; -use std::{fmt, str::FromStr}; +use std::str::FromStr; /// A square on the chess board. #[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] @@ -123,10 +124,10 @@ impl AddAssign for Square { } } -impl fmt::Display for Square { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Display::fmt(&self.file(), f)?; - fmt::Display::fmt(&self.rank(), f)?; +impl Display for Square { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + Display::fmt(&self.file(), f)?; + Display::fmt(&self.rank(), f)?; Ok(()) } } @@ -143,8 +144,9 @@ pub enum ParseSquareError { impl FromStr for Square { type Err = ParseSquareError; + #[inline(always)] fn from_str(s: &str) -> Result { - let i = s.char_indices().nth(1).map_or_else(|| s.len(), |(i, _)| i); + let i = s.ceil_char_boundary(1); Ok(Square::new(s[..i].parse()?, s[i..].parse()?)) } } diff --git a/lib/lib.rs b/lib/lib.rs index 99cf12e2..d2ab2698 100644 --- a/lib/lib.rs +++ b/lib/lib.rs @@ -2,6 +2,7 @@ #![feature( array_chunks, coverage_attribute, + round_char_boundary, new_zeroed_alloc, optimize_attribute, ptr_as_ref_unchecked, diff --git a/lib/nnue/hidden.rs b/lib/nnue/hidden.rs index 925f8066..77ca2a63 100644 --- a/lib/nnue/hidden.rs +++ b/lib/nnue/hidden.rs @@ -1,8 +1,8 @@ use crate::util::AlignTo64; -use derive_more::{Constructor, Shl}; +use std::ops::Shl; /// The hidden layer. -#[derive(Debug, Clone, Eq, PartialEq, Hash, Constructor)] +#[derive(Debug, Clone, Eq, PartialEq, Hash)] #[cfg_attr(test, derive(test_strategy::Arbitrary))] pub struct Hidden { #[cfg_attr(test, map(|b: i8| i32::from(b)))] diff --git a/lib/nnue/transformer.rs b/lib/nnue/transformer.rs index 48ccef00..20805af9 100644 --- a/lib/nnue/transformer.rs +++ b/lib/nnue/transformer.rs @@ -1,6 +1,5 @@ use crate::nnue::Feature; use crate::util::{AlignTo64, Assume, Integer}; -use derive_more::Constructor; use std::ops::{Add, AddAssign, Sub, SubAssign}; #[cfg(test)] @@ -10,7 +9,7 @@ use proptest::{prelude::*, sample::Index}; use std::ops::Range; /// A feature transformer. -#[derive(Debug, Clone, Eq, PartialEq, Hash, Constructor)] +#[derive(Debug, Clone, Eq, PartialEq, Hash)] pub struct Transformer { pub(super) bias: AlignTo64<[T; N]>, pub(super) weight: AlignTo64<[[T; N]; Feature::LEN]>, diff --git a/lib/search/driver.rs b/lib/search/driver.rs index 6914e527..4754a616 100644 --- a/lib/search/driver.rs +++ b/lib/search/driver.rs @@ -20,6 +20,7 @@ pub enum Driver { impl Driver { /// Constructs a parallel search driver with the given [`ThreadCount`]. + #[inline(always)] pub fn new(threads: ThreadCount) -> Self { match threads.get() { 1 => Self::Sequential, diff --git a/lib/search/engine.rs b/lib/search/engine.rs index c51cb852..01c44713 100644 --- a/lib/search/engine.rs +++ b/lib/search/engine.rs @@ -122,7 +122,7 @@ impl Engine { ply: Ply, ctrl: &Control, ) -> Result { - self.pvs::(pos, Score::lower()..Score::upper(), depth, ply, ctrl) + self.pvs(pos, Score::lower()..Score::upper(), depth, ply, ctrl) } /// A [zero-window] alpha-beta search. @@ -136,14 +136,14 @@ impl Engine { ply: Ply, ctrl: &Control, ) -> Result { - self.pvs::(pos, beta - 1..beta, depth, ply, ctrl) + self.pvs(pos, beta - 1..beta, depth, ply, ctrl) } /// An implementation of the [PVS] variation of [alpha-beta pruning] algorithm. /// /// [PVS]: https://www.chessprogramming.org/Principal_Variation_Search /// [alpha-beta pruning]: https://www.chessprogramming.org/Alpha-Beta - fn pvs( + fn pvs( &self, pos: &Evaluator, bounds: Range, @@ -178,8 +178,9 @@ impl Engine { _ => depth, }; + let is_pv = alpha + 1 < beta; if let Some(t) = transposition { - if !PV && t.depth() >= depth - ply { + if !is_pv && t.depth() >= depth - ply { let (lower, upper) = t.bounds().into_inner(); if lower >= upper || upper <= alpha || lower >= beta { return Ok(Pv::new(t.score().normalize(ply), Some(t.best()))); @@ -202,7 +203,7 @@ impl Engine { if alpha >= beta || ply >= Ply::MAX { return Ok(Pv::new(score, None)); - } else if !PV && !pos.is_check() { + } else if !is_pv && !pos.is_check() { if let Some(d) = self.nmp(pos, score, beta, depth, ply) { let mut next = pos.clone(); next.pass(); @@ -241,7 +242,7 @@ impl Engine { Some((m, _)) => { let mut next = pos.clone(); next.play(m); - m >> -self.pvs::(&next, -beta..-alpha, depth, ply + 1, ctrl)? + m >> -self.pvs(&next, -beta..-alpha, depth, ply + 1, ctrl)? } }; @@ -271,7 +272,7 @@ impl Engine { let pv = match -self.nw(&next, -alpha, depth, ply + 1, ctrl)? { pv if pv <= alpha || pv >= beta => m >> pv, - _ => m >> -self.pvs::(&next, -beta..-alpha, depth, ply + 1, ctrl)?, + _ => m >> -self.pvs(&next, -beta..-alpha, depth, ply + 1, ctrl)?, }; Ok(pv) @@ -321,8 +322,7 @@ impl Engine { break 'id; } - let bounds = lower..upper; - let Ok(partial) = self.pvs::(pos, bounds, depth, Ply::new(0), &ctrl) else { + let Ok(partial) = self.pvs(pos, lower..upper, depth, Ply::new(0), &ctrl) else { break 'id; }; @@ -502,7 +502,7 @@ mod tests { d: Depth, p: Ply, ) { - e.pvs::(&pos, b.end..b.start, d, p, &Control::Unlimited)?; + e.pvs(&pos, b.end..b.start, d, p, &Control::Unlimited)?; } #[proptest] @@ -515,7 +515,7 @@ mod tests { ) { let interrupter = Trigger::armed(); let ctrl = Control::Limited(Counter::new(0), Timer::infinite(), &interrupter); - assert_eq!(e.pvs::(&pos, b, d, p, &ctrl), Err(Interrupted)); + assert_eq!(e.pvs(&pos, b, d, p, &ctrl), Err(Interrupted)); } #[proptest] @@ -533,7 +533,7 @@ mod tests { &interrupter, ); std::thread::sleep(Duration::from_millis(1)); - assert_eq!(e.pvs::(&pos, b, d, p, &ctrl), Err(Interrupted)); + assert_eq!(e.pvs(&pos, b, d, p, &ctrl), Err(Interrupted)); } #[proptest] @@ -546,7 +546,7 @@ mod tests { ) { let interrupter = Trigger::disarmed(); let ctrl = Control::Limited(Counter::new(u64::MAX), Timer::infinite(), &interrupter); - assert_eq!(e.pvs::(&pos, b, d, p, &ctrl), Err(Interrupted)); + assert_eq!(e.pvs(&pos, b, d, p, &ctrl), Err(Interrupted)); } #[proptest] @@ -557,7 +557,7 @@ mod tests { d: Depth, ) { assert_eq!( - e.pvs::(&pos, b, d, Ply::upper(), &Control::Unlimited), + e.pvs(&pos, b, d, Ply::upper(), &Control::Unlimited), Ok(Pv::new(pos.evaluate().saturate(), None)) ); } @@ -571,7 +571,7 @@ mod tests { p: Ply, ) { assert_eq!( - e.pvs::(&pos, b, d, p, &Control::Unlimited), + e.pvs(&pos, b, d, p, &Control::Unlimited), Ok(Pv::new(Score::new(0), None)) ); } @@ -585,7 +585,7 @@ mod tests { p: Ply, ) { assert_eq!( - e.pvs::(&pos, b, d, p, &Control::Unlimited), + e.pvs(&pos, b, d, p, &Control::Unlimited), Ok(Pv::new(Score::lower().normalize(p), None)) ); } diff --git a/lib/search/limits.rs b/lib/search/limits.rs index a20012a4..224cf8a2 100644 --- a/lib/search/limits.rs +++ b/lib/search/limits.rs @@ -25,6 +25,7 @@ pub enum Limits { impl Limits { /// Maximum depth or [`Depth::MAX`]. + #[inline(always)] pub fn depth(&self) -> Depth { match self { Limits::Depth(d) => *d, @@ -33,6 +34,7 @@ impl Limits { } /// Maximum number of nodes [`u64::MAX`]. + #[inline(always)] pub fn nodes(&self) -> u64 { match self { Limits::Nodes(n) => *n, @@ -41,6 +43,7 @@ impl Limits { } /// Maximum time or [`Duration::MAX`]. + #[inline(always)] pub fn time(&self) -> Duration { match self { Limits::Time(t) => *t, @@ -50,6 +53,7 @@ impl Limits { } /// Time left on the clock or [`Duration::MAX`]. + #[inline(always)] pub fn clock(&self) -> Duration { match self { Limits::Clock(t, _) => *t, @@ -58,6 +62,7 @@ impl Limits { } /// Time increment or [`Duration::ZERO`]. + #[inline(always)] pub fn increment(&self) -> Duration { match self { Limits::Clock(_, i) => *i, diff --git a/lib/search/score.rs b/lib/search/score.rs index acf77934..a06d09a1 100644 --- a/lib/search/score.rs +++ b/lib/search/score.rs @@ -1,6 +1,5 @@ use crate::util::{Binary, Bits, Integer, Saturating}; use crate::{chess::Perspective, search::Ply}; -use std::fmt; #[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] #[cfg_attr(test, derive(test_strategy::Arbitrary))] @@ -20,6 +19,7 @@ impl Score { /// Returns number of plies to mate, if one is in the horizon. /// /// Negative number of plies means the opponent is mating. + #[inline(always)] pub fn mate(&self) -> Option { if *self <= Score::lower() - Ply::MIN { Some((Score::lower() - *self).saturate()) @@ -31,6 +31,7 @@ impl Score { } /// Normalizes mate scores relative to `ply`. + #[inline(always)] pub fn normalize(&self, ply: Ply) -> Self { if *self <= Score::lower() - Ply::MIN { (*self + ply).min(Score::lower() - Ply::MIN) @@ -63,16 +64,6 @@ impl Binary for Score { } } -impl fmt::Display for Score { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self.mate() { - Some(p) if p > 0 => write!(f, "{:+}#{}", self.get(), (p.cast::() + 1) / 2), - Some(p) => write!(f, "{:+}#{}", self.get(), (1 - p.cast::()) / 2), - None => write!(f, "{:+}", self.get()), - } - } -} - #[cfg(test)] mod tests { use super::*; @@ -106,24 +97,4 @@ mod tests { fn decoding_encoded_score_is_an_identity(s: Score) { assert_eq!(Score::decode(s.encode()), s); } - - #[proptest] - fn printing_score_displays_sign(s: Score) { - assert!(s.to_string().starts_with(if s < 0 { "-" } else { "+" })); - } - - #[proptest] - fn printing_mate_score_displays_moves_to_mate(p: Ply) { - if p > 0 { - assert!(Score::upper() - .normalize(p) - .to_string() - .ends_with(&format!("#{}", (p.cast::() + 1) / 2))); - } else { - assert!(Score::lower() - .normalize(-p) - .to_string() - .ends_with(&format!("#{}", (1 - p.cast::()) / 2))); - }; - } } diff --git a/lib/search/transposition.rs b/lib/search/transposition.rs index a8c7dba6..dfd7e26e 100644 --- a/lib/search/transposition.rs +++ b/lib/search/transposition.rs @@ -52,6 +52,7 @@ pub struct Transposition { } impl Transposition { + #[inline(always)] fn new(kind: TranspositionKind, depth: Depth, score: Score, best: Move) -> Self { Transposition { kind, @@ -62,21 +63,25 @@ impl Transposition { } /// Constructs a [`Transposition`] given a lower bound for the score, the depth searched, and best [`Move`]. + #[inline(always)] pub fn lower(depth: Depth, score: Score, best: Move) -> Self { Transposition::new(TranspositionKind::Lower, depth, score, best) } /// Constructs a [`Transposition`] given an upper bound for the score, the depth searched, and best [`Move`]. + #[inline(always)] pub fn upper(depth: Depth, score: Score, best: Move) -> Self { Transposition::new(TranspositionKind::Upper, depth, score, best) } /// Constructs a [`Transposition`] given the exact score, the depth searched, and best [`Move`]. + #[inline(always)] pub fn exact(depth: Depth, score: Score, best: Move) -> Self { Transposition::new(TranspositionKind::Exact, depth, score, best) } /// Bounds for the exact score. + #[inline(always)] pub fn bounds(&self) -> RangeInclusive { match self.kind { TranspositionKind::Lower => self.score..=Score::upper(), @@ -86,16 +91,19 @@ impl Transposition { } /// Depth searched. + #[inline(always)] pub fn depth(&self) -> Depth { self.depth } /// Partial score. + #[inline(always)] pub fn score(&self) -> Score { self.score } /// Best [`Move`] at this depth. + #[inline(always)] pub fn best(&self) -> Move { self.best } @@ -174,6 +182,7 @@ impl TranspositionTable { const WIDTH: usize = size_of::< as Binary>::Bits>(); /// Constructs a transposition table of at most `size` many bytes. + #[inline(always)] pub fn new(size: HashSize) -> Self { let capacity = (1 + size.get() / 2).next_power_of_two() / Self::WIDTH; @@ -183,11 +192,13 @@ impl TranspositionTable { } /// The actual size of this table in bytes. + #[inline(always)] pub fn size(&self) -> HashSize { HashSize::new(self.capacity() * Self::WIDTH) } /// The actual size of this table in number of entries. + #[inline(always)] pub fn capacity(&self) -> usize { self.cache.len() } diff --git a/lib/uci.rs b/lib/uci.rs index 07fba8e1..a7b2f361 100644 --- a/lib/uci.rs +++ b/lib/uci.rs @@ -1,12 +1,12 @@ use crate::chess::{Color, Move, Perspective}; use crate::nnue::Evaluator; -use crate::search::{Engine, HashSize, Limits, Options, Score, ThreadCount}; +use crate::search::{Engine, HashSize, Limits, Options, ThreadCount}; use crate::util::{Assume, Integer, Trigger}; -use arrayvec::ArrayString; -use derive_more::{Deref, Display}; -use futures::{channel::oneshot, future::FusedFuture, prelude::*, select_biased as select}; +use derive_more::Display; +use futures::channel::oneshot::channel as oneshot; +use futures::{future::FusedFuture, prelude::*, select_biased as select, stream::FusedStream}; use std::time::{Duration, Instant}; -use std::{fmt::Debug, mem::transmute, thread}; +use std::{fmt::Debug, io::Write, mem::transmute, ops::Deref, str, thread}; #[cfg(test)] use proptest::prelude::*; @@ -18,36 +18,44 @@ use proptest::prelude::*; /// Must be awaited on through completion strictly before any /// of the variables `f` may capture is dropped. #[must_use] -unsafe fn unblock<'a, F, R>(f: F) -> impl FusedFuture + 'a +unsafe fn unblock(f: F) -> impl FusedFuture where - F: FnOnce() -> R + Send + 'a, - R: Send + 'a, + F: FnOnce() -> R + Send, + R: Send, { - let (tx, rx) = oneshot::channel(); + let (tx, rx) = oneshot(); thread::spawn(transmute::< - Box, + Box, Box, >(Box::new(move || tx.send(f()).assume()) as _)); rx.map(Assume::assume) } -#[derive(Debug, Display, Default, Clone, Eq, PartialEq, Hash, Deref)] +#[derive(Debug, Display, Default, Clone, Eq, PartialEq, Hash)] #[cfg_attr(test, derive(test_strategy::Arbitrary))] -struct UciMove( - #[deref(forward)] - #[cfg_attr(test, map(|m: Move| ArrayString::from(&m.to_string()).unwrap()))] - ArrayString<5>, -); +#[display("{}", *self)] +struct UciMove([u8; 5]); -impl From for UciMove { - fn from(m: Move) -> Self { - Self(ArrayString::from(&m.to_string()).assume()) +impl Deref for UciMove { + type Target = str; + + fn deref(&self) -> &Self::Target { + let len = if self.0[4] == b'\0' { 4 } else { 5 }; + unsafe { str::from_utf8_unchecked(&self.0[..len]) } } } impl PartialEq<&str> for UciMove { fn eq(&self, other: &&str) -> bool { - self.0.eq(*other) + **self == **other + } +} + +impl From for UciMove { + fn from(m: Move) -> Self { + let mut buffer = [b'\0'; 5]; + write!(&mut buffer[..], "{m}").assume(); + Self(buffer) } } @@ -81,7 +89,7 @@ impl Uci { } } -impl + Unpin, O: Sink + Unpin> Uci { +impl + Unpin, O: Sink + Unpin> Uci { /// Runs the UCI server. pub async fn run(&mut self) -> Result<(), O::Error> { while let Some(line) = self.input.next().await { @@ -93,18 +101,23 @@ impl + Unpin, O: Sink + Unpin> Uci { Ok(()) } - fn play(&mut self, uci: &str) { - let mut moves = self.position.moves().flatten(); - let Some(m) = moves.find(|&m| UciMove::from(m) == uci) else { - return if !(0..=5).contains(&uci.len()) || !uci.is_ascii() { - eprintln!("invalid move `{uci}`") - } else { - eprintln!("illegal move `{uci}` in position `{}`", self.position) - }; - }; + fn play(&mut self, s: &str) { + if let Ok(whence) = s[..s.ceil_char_boundary(2)].parse() { + for ms in self.position.moves() { + if ms.whence() == whence { + for m in ms { + let uci = UciMove::from(m); + if uci == s { + self.position.play(m); + self.moves.push(uci); + return; + } + } + } + } + } - self.position.play(m); - self.moves.push(UciMove::from(m)); + eprintln!("illegal move `{s}` in position `{}`", self.position) } async fn go(&mut self, limits: &Limits) -> Result<(), O::Error> { @@ -113,30 +126,29 @@ impl + Unpin, O: Sink + Unpin> Uci { let mut search = unsafe { unblock(|| self.engine.search(&self.position, limits, &interrupter)) }; - let stop = async { - loop { - match self.input.next().await.as_deref().map(str::trim) { - None => break false, - Some("stop") => break interrupter.disarm(), - Some(cmd) => eprintln!("ignored unsupported command `{cmd}` during search"), - }; + let pv = loop { + select! { + pv = search => break pv, + line = self.input.next() => { + match line.as_deref().map(str::trim) { + None | Some("stop") => { interrupter.disarm(); }, + Some(cmd) => eprintln!("ignored unsupported command `{cmd}` during search"), + } + } } }; - let pv = select! { - pv = search => pv, - _ = stop.fuse() => search.await - }; - - let best = pv.best().expect("the engine failed to find a move"); let info = match pv.score().mate() { - Some(p) if p > 0 => format!("info score mate {} pv {best}", (p + 1).get() / 2), - Some(p) => format!("info score mate {} pv {best}", (p - 1).get() / 2), - None => format!("info score cp {} pv {best}", pv.score().get()), + Some(p) if p > 0 => format!("info score mate {}", (p + 1) / 2), + Some(p) => format!("info score mate {}", (p - 1) / 2), + None => format!("info score cp {:+}", pv.score()), }; self.output.send(info).await?; - self.output.send(format!("bestmove {best}")).await?; + + if let Some(m) = pv.best() { + self.output.send(format!("bestmove {m}")).await?; + } Ok(()) } @@ -148,7 +160,7 @@ impl + Unpin, O: Sink + Unpin> Uci { let millis = timer.elapsed().as_millis(); let info = match limits { - Limits::Depth(d) => format!("info time {millis} depth {}", d.get()), + Limits::Depth(d) => format!("info time {millis} depth {d}"), Limits::Nodes(nodes) => format!( "info time {millis} nodes {nodes} nps {}", *nodes as u128 * 1000 / millis @@ -168,13 +180,6 @@ impl + Unpin, O: Sink + Unpin> Uci { ["stop"] => Ok(true), ["quit"] => Ok(false), - ["position", "startpos", "moves", m, n] => { - self.position = Evaluator::default(); - self.play(m); - self.play(n); - Ok(true) - } - ["position", "startpos", "moves", moves @ .., m, n] if self.moves == moves => { self.play(m); self.play(n); @@ -226,8 +231,8 @@ impl + Unpin, O: Sink + Unpin> Uci { } ["go", "depth", depth] => { - match depth.parse::() { - Ok(d) => self.go(&Limits::Depth(d.saturate())).await?, + match depth.parse() { + Ok(d) => self.go(&Limits::Depth(d)).await?, Err(e) => eprintln!("{e}"), } @@ -235,8 +240,8 @@ impl + Unpin, O: Sink + Unpin> Uci { } ["go", "nodes", nodes] => { - match nodes.parse::() { - Ok(n) => self.go(&n.into()).await?, + match nodes.parse() { + Ok(n) => self.go(&Limits::Nodes(n)).await?, Err(e) => eprintln!("{e}"), } @@ -258,8 +263,8 @@ impl + Unpin, O: Sink + Unpin> Uci { } ["bench", "depth", depth] => { - match depth.parse::() { - Ok(d) => self.bench(&Limits::Depth(d.saturate())).await?, + match depth.parse() { + Ok(d) => self.bench(&Limits::Depth(d)).await?, Err(e) => eprintln!("{e}"), } @@ -267,8 +272,8 @@ impl + Unpin, O: Sink + Unpin> Uci { } ["bench", "nodes", nodes] => { - match nodes.parse::() { - Ok(n) => self.bench(&n.into()).await?, + match nodes.parse() { + Ok(n) => self.bench(&Limits::Nodes(n)).await?, Err(e) => eprintln!("{e}"), } @@ -278,10 +283,10 @@ impl + Unpin, O: Sink + Unpin> Uci { ["eval"] => { let pos = &self.position; let turn = self.position.turn(); - let mat: Score = pos.material().evaluate().perspective(turn).saturate(); - let positional: Score = pos.positional().evaluate().perspective(turn).saturate(); - let value: Score = pos.evaluate().perspective(turn).saturate(); - let info = format!("info material {mat} positional {positional} value {value}"); + let mat = pos.material().evaluate().perspective(turn); + let psn = pos.positional().evaluate().perspective(turn); + let val = pos.evaluate().perspective(turn); + let info = format!("info material {mat:+} positional {psn:+} value {val:+}"); self.output.send(info).await?; Ok(true) } @@ -327,7 +332,7 @@ impl + Unpin, O: Sink + Unpin> Uci { ["setoption", "name", "hash", "value", hash] | ["setoption", "name", "Hash", "value", hash] => { - match hash.parse::() { + match hash.parse() { Err(e) => eprintln!("{e}"), Ok(h) => { if h != self.options.hash { @@ -342,7 +347,7 @@ impl + Unpin, O: Sink + Unpin> Uci { ["setoption", "name", "threads", "value", threads] | ["setoption", "name", "Threads", "value", threads] => { - match threads.parse::() { + match threads.parse() { Err(e) => eprintln!("{e}"), Ok(t) => { if t != self.options.threads { @@ -390,6 +395,12 @@ mod tests { } } + impl FusedStream for StaticStream { + fn is_terminated(&self) -> bool { + self.0.is_empty() + } + } + type MockUci = Uci>; #[proptest] @@ -533,7 +544,7 @@ mod tests { #[proptest] fn handles_go_depth( #[filter(#uci.position.outcome().is_none())] - #[any(StaticStream::new([format!("go depth {}", #_d.get())]))] + #[any(StaticStream::new([format!("go depth {}", #_d)]))] mut uci: MockUci, _d: Depth, ) { @@ -605,7 +616,7 @@ mod tests { Color::Black => -pos.evaluate(), }; - let value = format!("value {:+}", value.get()); + let value = format!("value {:+}", value); assert_eq!(block_on(uci.run()), Ok(())); assert!(uci.output.concat().ends_with(&value)); } diff --git a/lib/util/integer.rs b/lib/util/integer.rs index 3abf7b2c..22935951 100644 --- a/lib/util/integer.rs +++ b/lib/util/integer.rs @@ -1,4 +1,4 @@ -use std::num::{NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8, NonZeroUsize}; +use std::num::{NonZeroU128, NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8, NonZeroUsize}; use std::{mem::transmute_copy, ops::*}; /// Trait for types that can be represented by a contiguous range of primitive integers. @@ -133,6 +133,7 @@ impl_integer_for_non_zero!(NonZeroU8, u8); impl_integer_for_non_zero!(NonZeroU16, u16); impl_integer_for_non_zero!(NonZeroU32, u32); impl_integer_for_non_zero!(NonZeroU64, u64); +impl_integer_for_non_zero!(NonZeroU128, u128); impl_integer_for_non_zero!(NonZeroUsize, usize); macro_rules! impl_primitive_for { diff --git a/lib/util/saturating.rs b/lib/util/saturating.rs index 9102439b..b9a2ec4f 100644 --- a/lib/util/saturating.rs +++ b/lib/util/saturating.rs @@ -1,10 +1,15 @@ -use crate::util::{Integer, Signed}; +use crate::util::Integer; +use derive_more::{Debug, Display, Error}; +use std::fmt::{self, Formatter}; use std::ops::{Add, Div, Mul, Neg, Sub}; -use std::{cmp::Ordering, mem::size_of}; +use std::{cmp::Ordering, mem::size_of, num::Saturating as S, str::FromStr}; /// A saturating bounded integer. #[derive(Debug, Default, Copy, Clone, Hash)] #[cfg_attr(test, derive(test_strategy::Arbitrary))] +#[cfg_attr(test, arbitrary(bound(T, Self: Debug)))] +#[debug("Saturating({self})")] +#[debug(bounds(T: Integer, T::Repr: Display))] #[repr(transparent)] pub struct Saturating(T); @@ -16,155 +21,142 @@ unsafe impl Integer for Saturating { impl Eq for Saturating where Self: PartialEq {} -impl PartialEq for Saturating -where - T: Integer, - U: Integer, - I: Signed, - J: Signed, -{ +impl PartialEq for Saturating { #[inline(always)] fn eq(&self, other: &U) -> bool { - if size_of::() <= size_of::() { - J::eq(&self.cast(), &other.cast()) + if size_of::() > size_of::() { + T::Repr::eq(&self.get(), &other.cast()) } else { - I::eq(&self.cast(), &other.cast()) + U::Repr::eq(&self.cast(), &other.get()) } } } -impl Ord for Saturating -where - T: Integer, - I: Signed + Ord, -{ +impl Ord for Saturating { #[inline(always)] fn cmp(&self, other: &Self) -> Ordering { self.get().cmp(&other.get()) } } -impl PartialOrd for Saturating -where - T: Integer, - U: Integer, - I: Signed, - J: Signed, -{ +impl PartialOrd for Saturating { #[inline(always)] fn partial_cmp(&self, other: &U) -> Option { - if size_of::() <= size_of::() { - J::partial_cmp(&self.cast(), &other.cast()) + if size_of::() > size_of::() { + T::Repr::partial_cmp(&self.get(), &other.cast()) } else { - I::partial_cmp(&self.cast(), &other.cast()) + U::Repr::partial_cmp(&self.cast(), &other.get()) } } } -impl Neg for Saturating +impl Neg for Saturating where - T: Integer, - I: Widen, - J: Signed + Neg, + S: Neg>, { type Output = Self; #[inline(always)] fn neg(self) -> Self::Output { - J::neg(self.cast()).saturate() + S(self.get()).neg().0.saturate() } } -impl Add for Saturating +impl Add for Saturating where - T: Integer, - U: Integer, - I: Widen, - J: Widen, + S: Add>, + S: Add>, { type Output = Self; #[inline(always)] fn add(self, rhs: U) -> Self::Output { - if size_of::() <= size_of::() { - J::Wider::add(self.cast(), rhs.cast()).saturate() + if size_of::() > size_of::() { + S::add(S(self.get()), S(rhs.cast())).0.saturate() } else { - I::Wider::add(self.cast(), rhs.cast()).saturate() + S::add(S(self.cast()), S(rhs.get())).0.saturate() } } } -impl Sub for Saturating +impl Sub for Saturating where - T: Integer, - U: Integer, - I: Widen, - J: Widen, + S: Sub>, + S: Sub>, { type Output = Self; #[inline(always)] fn sub(self, rhs: U) -> Self::Output { - if size_of::() <= size_of::() { - J::Wider::sub(self.cast(), rhs.cast()).saturate() + if size_of::() > size_of::() { + S::sub(S(self.get()), S(rhs.cast())).0.saturate() } else { - I::Wider::sub(self.cast(), rhs.cast()).saturate() + S::sub(S(self.cast()), S(rhs.get())).0.saturate() } } } -impl Mul for Saturating +impl Mul for Saturating where - T: Integer, - U: Integer, - I: Widen, - J: Widen, + S: Mul>, + S: Mul>, { type Output = Self; #[inline(always)] fn mul(self, rhs: U) -> Self::Output { - if size_of::() <= size_of::() { - J::Wider::mul(self.cast(), rhs.cast()).saturate() + if size_of::() > size_of::() { + S::mul(S(self.get()), S(rhs.cast())).0.saturate() } else { - I::Wider::mul(self.cast(), rhs.cast()).saturate() + S::mul(S(self.cast()), S(rhs.get())).0.saturate() } } } -impl Div for Saturating +impl Div for Saturating where - T: Integer, - U: Integer, - I: Widen, - J: Widen, + S: Div>, + S: Div>, { type Output = Self; #[inline(always)] fn div(self, rhs: U) -> Self::Output { - if size_of::() <= size_of::() { - J::Wider::div(self.cast(), rhs.cast()).saturate() + if size_of::() > size_of::() { + S::div(S(self.get()), S(rhs.cast())).0.saturate() } else { - I::Wider::div(self.cast(), rhs.cast()).saturate() + S::div(S(self.cast()), S(rhs.get())).0.saturate() } } } -trait Widen: Signed { - type Wider: Signed; +impl Display for Saturating +where + T::Repr: Display, +{ + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + Display::fmt(&self.get(), f) + } } -impl Widen for i8 { - type Wider = i16; -} +/// The reason why parsing [`Saturating`] failed. +#[derive(Debug, Display, Clone, Eq, PartialEq, Error)] +#[display("failed to parse saturating integer")] +pub struct ParseSaturatingIntegerError; -impl Widen for i16 { - type Wider = i32; -} +impl FromStr for Saturating +where + T::Repr: FromStr, +{ + type Err = ParseSaturatingIntegerError; -impl Widen for i32 { - type Wider = i64; + #[inline(always)] + fn from_str(s: &str) -> Result { + s.parse::() + .ok() + .and_then(Integer::convert) + .ok_or(ParseSaturatingIntegerError) + } } #[cfg(test)] @@ -233,4 +225,39 @@ mod tests { let r: i8 = i16::saturating_div(b.cast(), a.cast()).saturate(); assert_eq!(b / a, r); } + + #[proptest] + fn parsing_printed_saturating_integer_is_an_identity(a: Saturating) { + assert_eq!(a.to_string().parse(), Ok(a)); + } + + #[proptest] + fn parsing_saturating_integer_fails_for_numbers_too_small( + #[strategy(..Saturating::::MIN)] n: i16, + ) { + assert_eq!( + n.to_string().parse::>(), + Err(ParseSaturatingIntegerError) + ); + } + + #[proptest] + fn parsing_saturating_integer_fails_for_numbers_too_large( + #[strategy(Saturating::::MAX + 1..)] n: i16, + ) { + assert_eq!( + n.to_string().parse::>(), + Err(ParseSaturatingIntegerError) + ); + } + + #[proptest] + fn parsing_saturating_integer_fails_for_invalid_number( + #[filter(#s.parse::().is_err())] s: String, + ) { + assert_eq!( + s.to_string().parse::>(), + Err(ParseSaturatingIntegerError) + ); + } }