diff --git a/bin/uci.rs b/bin/uci.rs index 5b36e258..fad403c6 100644 --- a/bin/uci.rs +++ b/bin/uci.rs @@ -68,7 +68,7 @@ impl Uci { fn go(&mut self, limits: Limits) -> Result<(), Anyhow> { let pv = self.engine.search(&self.position, limits); - let best = *pv.first().expect("expected some legal move"); + let best = *pv.first().context("the engine failed to find a move")?; let score = match pv.score().mate() { Some(p) if p > 0 => UciInfoAttribute::from_mate((p + 1).get() / 2), diff --git a/lib/chess/move.rs b/lib/chess/move.rs index 16286b7c..35b529de 100644 --- a/lib/chess/move.rs +++ b/lib/chess/move.rs @@ -1,5 +1,5 @@ use crate::chess::{Role, Square}; -use crate::util::{Binary, Bits}; +use crate::util::{Assume, Binary, Bits}; use derive_more::{DebugCustom, Display, Error}; use shakmaty as sm; use vampirc_uci::UciMove; @@ -80,7 +80,7 @@ impl Move { pub fn role(&self) -> Role { let mut bits = self.role; match bits.pop::().get() { - 0 => Role::decode(bits.pop()).expect("expected valid encoding"), + 0 => Role::decode(bits.pop()).assume(), _ => Role::Pawn, } } @@ -90,7 +90,7 @@ impl Move { let mut bits = self.role; match bits.pop::().get() { 0 => None, - _ => Some(Role::decode(bits.pop()).expect("expected valid encoding")), + _ => Some(Role::decode(bits.pop()).assume()), } } @@ -104,10 +104,7 @@ impl Move { Square::new(self.whither.file(), self.whence.rank()), )) } else { - Some(( - Role::decode(self.capture).expect("expected valid encoding"), - self.whither, - )) + Some((Role::decode(self.capture).assume(), self.whither)) } } } diff --git a/lib/chess/position.rs b/lib/chess/position.rs index 4c318e9a..a89e48aa 100644 --- a/lib/chess/position.rs +++ b/lib/chess/position.rs @@ -1,5 +1,5 @@ use crate::chess::{Bitboard, Color, Move, Outcome, Piece, Role, Square}; -use crate::util::{Bits, Buffer}; +use crate::util::{Assume, Bits, Buffer}; use derive_more::{DebugCustom, Display, Error, From}; use shakmaty as sm; use std::hash::{Hash, Hasher}; @@ -163,7 +163,7 @@ impl Position { self.by_piece(Piece(side, Role::King)) .into_iter() .next() - .expect("expected king on the board") + .assume() } /// The [`Role`] of the piece on the given [`Square`], if any. diff --git a/lib/chess/promotion.rs b/lib/chess/promotion.rs new file mode 100644 index 00000000..a17154c7 --- /dev/null +++ b/lib/chess/promotion.rs @@ -0,0 +1,138 @@ +use crate::chess::Role; +use crate::util::{Binary, Bits}; +use derive_more::{Display, Error}; +use shakmaty as sm; +use vampirc_uci::UciPiece; + +/// A promotion specifier. +#[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] +#[cfg_attr(test, derive(test_strategy::Arbitrary))] +#[repr(u8)] +pub enum Promotion { + #[display(fmt = "")] + None, + #[display(fmt = "n")] + Knight, + #[display(fmt = "b")] + Bishop, + #[display(fmt = "r")] + Rook, + #[display(fmt = "q")] + Queen, +} + +/// The reason why decoding [`Promotion`] from binary failed. +#[derive(Debug, Display, Clone, Eq, PartialEq, Error)] +#[cfg_attr(test, derive(test_strategy::Arbitrary))] +#[display(fmt = "not a valid promotion")] +pub struct DecodePromotionError; + +impl Binary for Promotion { + type Bits = Bits; + type Error = DecodePromotionError; + + fn encode(&self) -> Self::Bits { + Bits::new(*self as _) + } + + fn decode(bits: Self::Bits) -> Result { + use Promotion::*; + [None, Knight, Bishop, Rook, Queen] + .into_iter() + .nth(bits.get() as _) + .ok_or(DecodePromotionError) + } +} + +impl From for Option { + fn from(p: Promotion) -> Self { + match p { + Promotion::None => None, + Promotion::Knight => Some(Role::Knight), + Promotion::Bishop => Some(Role::Bishop), + Promotion::Rook => Some(Role::Rook), + Promotion::Queen => Some(Role::Queen), + } + } +} + +#[doc(hidden)] +impl From for Option { + fn from(p: Promotion) -> Self { + match p { + Promotion::None => None, + Promotion::Knight => Some(UciPiece::Knight), + Promotion::Bishop => Some(UciPiece::Bishop), + Promotion::Rook => Some(UciPiece::Rook), + Promotion::Queen => Some(UciPiece::Queen), + } + } +} + +#[doc(hidden)] +impl From> for Promotion { + fn from(p: Option) -> Self { + match p { + None => Promotion::None, + Some(UciPiece::Knight) => Promotion::Knight, + Some(UciPiece::Bishop) => Promotion::Bishop, + Some(UciPiece::Rook) => Promotion::Rook, + Some(UciPiece::Queen) => Promotion::Queen, + Some(v) => panic!("unexpected {v:?}"), + } + } +} + +#[doc(hidden)] +impl From> for Promotion { + fn from(p: Option) -> Self { + match p { + None => Promotion::None, + Some(sm::Role::Knight) => Promotion::Knight, + Some(sm::Role::Bishop) => Promotion::Bishop, + Some(sm::Role::Rook) => Promotion::Rook, + Some(sm::Role::Queen) => Promotion::Queen, + Some(v) => panic!("unexpected {v:?}"), + } + } +} + +#[doc(hidden)] +impl From for Option { + fn from(p: Promotion) -> Self { + match p { + Promotion::None => None, + Promotion::Knight => Some(sm::Role::Knight), + Promotion::Bishop => Some(sm::Role::Bishop), + Promotion::Rook => Some(sm::Role::Rook), + Promotion::Queen => Some(sm::Role::Queen), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use test_strategy::proptest; + + #[proptest] + fn decoding_encoded_promotion_is_an_identity(p: Promotion) { + assert_eq!(Promotion::decode(p.encode()), Ok(p)); + } + + #[proptest] + fn decoding_promotion_fails_for_invalid_bits(#[strategy(5u8..8)] n: u8) { + let b = ::Bits::new(n as _); + assert_eq!(Promotion::decode(b), Err(DecodePromotionError)); + } + + #[proptest] + fn promotion_has_an_equivalent_vampirc_uci_representation(p: Promotion) { + assert_eq!(Promotion::from(Option::::from(p)), p); + } + + #[proptest] + fn promotion_has_an_equivalent_shakmaty_representation(p: Promotion) { + assert_eq!(Promotion::from(Option::::from(p)), p); + } +} diff --git a/lib/search/engine.rs b/lib/search/engine.rs index 7261670b..1faf9312 100644 --- a/lib/search/engine.rs +++ b/lib/search/engine.rs @@ -2,7 +2,7 @@ use crate::chess::{Move, Piece, Position, Role, Zobrist}; use crate::nnue::Evaluator; use crate::search::{Depth, Limits, Options, Ply, Pv, Score, Value}; use crate::search::{Transposition, TranspositionTable}; -use crate::util::{Buffer, Timeout, Timer}; +use crate::util::{Assume, Buffer, Timeout, Timer}; use rayon::{prelude::*, ThreadPool, ThreadPoolBuilder}; use std::sync::atomic::{AtomicI16, Ordering}; use std::{cmp::max, ops::Range, time::Duration}; @@ -186,7 +186,7 @@ impl Engine { } else if !in_check { if let Some(d) = self.nmp(pos, score, beta, depth) { let mut next = pos.clone(); - next.pass().expect("expected possible pass"); + next.pass().assume(); if d <= ply || -self.nw(&next, -beta + 1, d, ply + 1, timer)? >= beta { #[cfg(not(test))] // The null move pruning heuristic is not exact. @@ -201,7 +201,7 @@ impl Engine { } let mut next = pos.clone(); - next.play(m).expect("expected legal move"); + next.play(m).assume(); let guess = -next.see(m.whither(), Value::LOWER..Value::UPPER).cast(); let rank = if Some(m) == tpos.map(|t| t.best()) { @@ -218,7 +218,7 @@ impl Engine { None => return Ok(Pv::new(score, [])), Some((m, _, _)) => { let mut next = pos.clone(); - next.play(m).expect("expected legal move"); + next.play(m).assume(); let mut pv = -self.ns(&next, -beta..-alpha, depth, ply + 1, timer)?; pv.shift(m); pv @@ -244,7 +244,7 @@ impl Engine { } let mut next = pos.clone(); - next.play(m).expect("expected legal move"); + next.play(m).assume(); if !in_check { if let Some(d) = self.lmp(&next, guess, alpha, depth) { @@ -276,7 +276,7 @@ impl Engine { }) .chain([Ok(Some((best, i16::MAX)))]) .try_reduce(|| None, |a, b| Ok(max(a, b)))? - .expect("expected at least one principal variation"); + .assume(); self.record(zobrist, bounds, depth, ply, best.score(), best[0]); diff --git a/lib/search/transposition.rs b/lib/search/transposition.rs index 912dcb8c..7432acf0 100644 --- a/lib/search/transposition.rs +++ b/lib/search/transposition.rs @@ -1,6 +1,6 @@ use crate::chess::{Move, Zobrist}; use crate::search::{Depth, Score}; -use crate::util::{Binary, Bits, Cache}; +use crate::util::{Assume, Binary, Bits, Cache}; use derive_more::{Display, Error}; use std::{cmp::Ordering, mem::size_of, ops::RangeInclusive}; @@ -196,7 +196,7 @@ impl TranspositionTable { let sig = self.signature_of(key); let bits = self.cache.load(self.index_of(key)); - match Binary::decode(bits).expect("expected valid encoding") { + match Binary::decode(bits).assume() { Some(SignedTransposition(t, s)) if s == sig => Some(t), _ => None, } @@ -209,12 +209,11 @@ impl TranspositionTable { if self.capacity() > 0 { let sig = self.signature_of(key); let bits = Some(SignedTransposition(transposition, sig)).encode(); - self.cache.update(self.index_of(key), |r| { - match Binary::decode(r).expect("expected valid encoding") { + self.cache + .update(self.index_of(key), |r| match Binary::decode(r).assume() { Some(SignedTransposition(t, _)) if t > transposition => None, _ => Some(bits), - } - }) + }) } } } diff --git a/lib/util.rs b/lib/util.rs index fe56430a..6a91e454 100644 --- a/lib/util.rs +++ b/lib/util.rs @@ -1,3 +1,4 @@ +mod assume; mod binary; mod bits; mod bounds; @@ -6,6 +7,7 @@ mod cache; mod saturating; mod timer; +pub use assume::*; pub use binary::*; pub use bits::*; pub use bounds::*; diff --git a/lib/util/assume.rs b/lib/util/assume.rs new file mode 100644 index 00000000..ff226366 --- /dev/null +++ b/lib/util/assume.rs @@ -0,0 +1,26 @@ +/// A trait for types that can be assumed to be another type. +pub trait Assume { + /// The type of the assumed value. + type Assumed; + + /// Assume `Self` represents a value of `Self::Assumed`. + fn assume(self) -> Self::Assumed; +} + +impl Assume for Option { + type Assumed = T; + + fn assume(self) -> Self::Assumed { + // Definitely not safe, but we'll do it anyway. + unsafe { self.unwrap_unchecked() } + } +} + +impl Assume for Result { + type Assumed = T; + + fn assume(self) -> Self::Assumed { + // Definitely not safe, but we'll do it anyway. + unsafe { self.unwrap_unchecked() } + } +}