diff --git a/src-tauri/src/db/search.rs b/src-tauri/src/db/search.rs index 47c9ca57..d5a0b029 100644 --- a/src-tauri/src/db/search.rs +++ b/src-tauri/src/db/search.rs @@ -5,6 +5,7 @@ use rayon::prelude::*; use serde::{Deserialize, Serialize}; use shakmaty::{fen::Fen, san::SanPlus, Bitboard, ByColor, Chess, Position, Setup}; use std::{path::PathBuf, sync::Mutex, time::Instant}; +use tokio::sync::Semaphore; use crate::{ db::{ @@ -12,7 +13,7 @@ use crate::{ normalize_games, schema::*, ConnectionOptions, MaterialCount, }, error::Error, - AppState, + AppState, GameData, }; #[derive(Debug, Hash, PartialEq, Eq, Clone)] @@ -110,24 +111,23 @@ impl PositionQuery { } } - fn is_reachable(&self, material: &MaterialCount, pawn_home: u16, reverse: bool) -> bool { + fn is_reachable_by(&self, material: &MaterialCount, pawn_home: u16) -> bool { match self { PositionQuery::Exact(ref data) => { - if reverse { - is_end_reachable(pawn_home, data.pawn_home) - && is_material_reachable(material, &data.material) - } else { - is_end_reachable(data.pawn_home, pawn_home) - && is_material_reachable(&data.material, material) - } + is_end_reachable(data.pawn_home, pawn_home) + && is_material_reachable(&data.material, material) } - PositionQuery::Partial(ref data) => { - if reverse { - is_material_reachable(material, &data.material) - } else { - is_material_reachable(&data.material, material) - } + PositionQuery::Partial(ref data) => is_material_reachable(&data.material, material), + } + } + + fn can_reach(&self, material: &MaterialCount, pawn_home: u16) -> bool { + match self { + PositionQuery::Exact(ref data) => { + is_end_reachable(pawn_home, data.pawn_home) + && is_material_reachable(material, &data.material) } + PositionQuery::Partial(_) => true, } } } @@ -137,10 +137,12 @@ fn is_end_reachable(end: u16, pos: u16) -> bool { end & !pos == 0 } +/// Returns true if the end material is reachable fn is_material_reachable(end: &MaterialCount, pos: &MaterialCount) -> bool { end.white <= pos.white && end.black <= pos.black } +/// Returns true if the subset is contained in the container fn is_contained(container: Bitboard, subset: Bitboard) -> bool { container & subset == subset } @@ -173,7 +175,7 @@ fn get_move_after_match( let m = decode_move(*byte, &chess).unwrap(); chess.play_unchecked(&m); let board = chess.board(); - if !query.is_reachable(&get_material_count(board), get_pawn_home(board), false) { + if !query.is_reachable_by(&get_material_count(board), get_pawn_home(board)) { return Ok(None); } if query.matches(&chess) { @@ -222,20 +224,50 @@ pub async fn search_position( info!("got {} games: {:?}", games.len(), start.elapsed()); } + let (openings, ids) = execute_query(&query, &games, &state.new_request); + info!("finished search in {:?}", start.elapsed()); + + if state.new_request.available_permits() == 0 { + drop(permit); + return Err(Error::SearchStopped); + } + + let (white_players, black_players) = diesel::alias!(players as white, players as black); + let games: Vec<(Game, Player, Player, Event, Site)> = games::table + .inner_join(white_players.on(games::white_id.eq(white_players.field(players::id)))) + .inner_join(black_players.on(games::black_id.eq(black_players.field(players::id)))) + .inner_join(events::table.on(games::event_id.eq(events::id))) + .inner_join(sites::table.on(games::site_id.eq(sites::id))) + .filter(games::id.eq_any(ids)) + .load(db)?; + let normalized_games = normalize_games(games); + + state + .line_cache + .insert((query, file), (openings.clone(), normalized_games.clone())); + + Ok((openings, normalized_games)) +} + +fn execute_query( + query: &PositionQuery, + games: &[GameData], + new_request: &Semaphore, +) -> (Vec, Vec) { let openings: DashMap = DashMap::new(); let sample_games: Mutex> = Mutex::new(Vec::new()); games.par_iter().for_each( |(id, result, game, end_pawn_home, white_material, black_material)| { - if state.new_request.available_permits() == 0 { + if new_request.available_permits() == 0 { return; } let end_material: MaterialCount = ByColor { white: *white_material as u8, black: *black_material as u8, }; - if query.is_reachable(&end_material, *end_pawn_home as u16, true) { - if let Ok(Some(m)) = get_move_after_match(game, &query) { + if query.can_reach(&end_material, *end_pawn_home as u16) { + if let Ok(Some(m)) = get_move_after_match(game, query) { if sample_games.lock().unwrap().len() < 10 { sample_games.lock().unwrap().push(*id); } @@ -270,31 +302,11 @@ pub async fn search_position( } }, ); - info!("finished search in {:?}", start.elapsed()); - if state.new_request.available_permits() == 0 { - drop(permit); - return Err(Error::SearchStopped); - } - - let ids: Vec = sample_games.lock().unwrap().clone(); - - let (white_players, black_players) = diesel::alias!(players as white, players as black); - let games: Vec<(Game, Player, Player, Event, Site)> = games::table - .inner_join(white_players.on(games::white_id.eq(white_players.field(players::id)))) - .inner_join(black_players.on(games::black_id.eq(black_players.field(players::id)))) - .inner_join(events::table.on(games::event_id.eq(events::id))) - .inner_join(sites::table.on(games::site_id.eq(sites::id))) - .filter(games::id.eq_any(ids)) - .load(db)?; - let normalized_games = normalize_games(games); let openings: Vec = openings.into_iter().map(|(_, v)| v).collect(); + let sample_games_ids: Vec = sample_games.lock().unwrap().clone(); - state - .line_cache - .insert((query, file), (openings.clone(), normalized_games.clone())); - - Ok((openings, normalized_games)) + (openings, sample_games_ids) } pub async fn is_position_in_db( @@ -339,7 +351,7 @@ pub async fn is_position_in_db( white: *white_material as u8, black: *black_material as u8, }; - query.is_reachable(&end_material, *end_pawn_home as u16, true) + query.can_reach(&end_material, *end_pawn_home as u16) && get_move_after_match(game, &query).unwrap_or(None).is_some() }, ); @@ -398,16 +410,11 @@ mod tests { #[test] #[should_panic] - fn fail_partial_match_1() { + fn fail_partial_match() { assert_partial_match( "8/8/8/8/8/8/8/6N1 w - - 0 1", "3k4/8/8/8/8/4P3/3PKP2/7N w - - 0 1", ); - } - - #[test] - #[should_panic] - fn fail_partial_match_2() { assert_partial_match( "8/8/8/8/8/8/8/6N1 w - - 0 1", "3k4/8/8/8/8/4P3/3PKP2/6n1 w - - 0 1", @@ -415,7 +422,39 @@ mod tests { } #[test] - fn get_move_after_match_test() { + fn correct_exact_is_reachable() { + let query = + PositionQuery::exact_from_fen("rnbqkb1r/pppp1ppp/5n2/4p3/4P3/2N5/PPPP1PPP/R1BQKBNR") + .unwrap(); + let chess = Chess::default(); + assert!(query.is_reachable_by( + &get_material_count(chess.board()), + get_pawn_home(chess.board()) + )); + } + + #[test] + fn correct_partial_is_reachable() { + let query = PositionQuery::partial_from_fen("8/8/8/8/8/8/8/8").unwrap(); + let chess = Chess::default(); + assert!(query.is_reachable_by( + &get_material_count(chess.board()), + get_pawn_home(chess.board()) + )); + } + + #[test] + fn correct_partial_can_reach() { + let query = PositionQuery::partial_from_fen("8/8/8/8/8/8/8/8").unwrap(); + let chess = Chess::default(); + assert!(query.can_reach( + &get_material_count(chess.board()), + get_pawn_home(chess.board()) + )); + } + + #[test] + fn get_move_after_exact_match_test() { let game = vec![12, 12]; // 1. e4 e5 let query = @@ -429,8 +468,18 @@ mod tests { assert_eq!(result, Some("e5".to_string())); let query = - PositionQuery::exact_from_fen("rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR").unwrap(); + PositionQuery::exact_from_fen("rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR") + .unwrap(); let result = get_move_after_match(&game, &query).unwrap(); assert_eq!(result, Some("*".to_string())); } + + #[test] + fn get_move_after_partial_match_test() { + let game = vec![12, 12]; // 1. e4 e5 + + let query = PositionQuery::partial_from_fen("8/pppppppp/8/8/8/8/PPPPPPPP/8").unwrap(); + let result = get_move_after_match(&game, &query).unwrap(); + assert_eq!(result, Some("e4".to_string())); + } } diff --git a/src-tauri/src/main.rs b/src-tauri/src/main.rs index 2abf3e59..281f7d82 100644 --- a/src-tauri/src/main.rs +++ b/src-tauri/src/main.rs @@ -107,6 +107,8 @@ async fn start_server(username: String, verifier: String, window: Window) -> Res })?) } +pub type GameData = (i32, Option, Vec, i32, i32, i32); + #[derive(Derivative)] #[derivative(Default)] pub struct AppState { @@ -115,7 +117,7 @@ pub struct AppState { diesel::r2d2::Pool>, >, line_cache: DashMap<(PositionQuery, PathBuf), (Vec, Vec)>, - db_cache: Mutex, Vec, i32, i32, i32)>>, + db_cache: Mutex>, analysis_cache: DashMap>, #[derivative(Default(value = "Arc::new(Semaphore::new(2))"))] new_request: Arc,