def get_evals(df: pd.DataFrame,
              local_stockfish: bool,
              task: Task,
              ) -> pd.DataFrame:
    """Assemble position evaluations from the games, the DB and Stockfish.

    Rows whose ``evaluations`` entry is non-empty contribute the
    evaluations they already carry. Every position is also looked up in the
    ``position_evals`` table, and whatever the database returns replaces
    the corresponding rows. When ``local_stockfish`` is true, positions from
    rows without evaluations that the database does not know are analysed
    with ``get_sf_evaluation``.

    Args:
        df: frame with list-like ``evaluations``, ``eval_depths`` and
            ``positions`` columns.
        local_stockfish: whether to evaluate the remaining positions with
            the Stockfish configured through ``stockfish_cfg()``.
        task: receives progress updates through ``set_status_message`` and
            ``set_progress_percentage`` while Stockfish is running.

    Returns:
        A frame with ``fen``, ``evaluation`` and ``eval_depth`` columns:
        the newly gathered rows followed by the rows read from the
        database.
    """
    sf_params: Any = stockfish_cfg()

    df = df[['evaluations', 'eval_depths', 'positions']]

    # rows with and without evaluations hold different list-likes, so they
    # are exploded separately and concatenated afterwards
    has_evals: pd.Series = df['evaluations'].astype(bool)
    no_evals: pd.DataFrame = pd.DataFrame(
        df.loc[~has_evals, 'positions'].explode()
    )
    no_evals['positions'] = get_clean_fens(no_evals['positions'])
    df = df[has_evals]

    evals: pd.Series = df['evaluations'].explode().reset_index(drop=True)
    depths: pd.Series = df['eval_depths'].explode().reset_index(drop=True)
    positions: pd.Series = df['positions'].explode().reset_index(drop=True)
    positions = get_clean_fens(positions)

    sql: str = """SELECT fen, evaluation, eval_depth
                  FROM position_evals
                  WHERE fen IN %(positions)s;
                  """
    # NOTE(review): with no positions at all this passes an empty tuple,
    # which presumably renders as `IN ()` -- confirm callers never get here
    # with an empty frame.
    wanted_fens = tuple(positions.tolist() + no_evals['positions'].tolist())
    db_evaluations = run_remote_sql_query(sql, positions=wanted_fens)
    positions_evaluated = db_evaluations['fen'].drop_duplicates()

    df = pd.concat([positions, evals, depths], axis=1)

    if local_stockfish:
        # built once so the membership test below is a hash lookup instead
        # of a scan over every fen the database returned
        known_fens = set(positions_evaluated)

        local_evals: list[float | None] = []
        pending: list = no_evals['positions'].tolist()
        position_count: int = len(pending)
        evaluation: float | None = None

        for counter, position in enumerate(pending, start=1):
            if position in known_fens:
                # position will be dropped later if evaluation is None
                evaluation = None
            else:
                # get_clean_fens stripped the last field of the fen; a
                # placeholder is appended before handing it to Stockfish
                sf_eval: float | None = get_sf_evaluation(position + ' 0',
                                                          sf_params.location,
                                                          sf_params.depth)
                if sf_eval is not None:
                    # TODO: this is implicitly setting evaluation = last
                    # eval if in a checkmate position. handle this better
                    evaluation = sf_eval

            local_evals.append(evaluation)

            # progress bar stuff
            current_progress = counter / position_count
            task.set_status_message(f'Analyzed :: '
                                    f'{counter} / {position_count}')
            task.set_progress_percentage(round(current_progress * 100, 2))

        task.set_status_message(f'Analyzed all {position_count} positions')
        task.set_progress_percentage(100)

        no_evals['evaluations'] = local_evals
        no_evals['eval_depths'] = sf_params.depth
        # rows left at None are dropped here
        no_evals = no_evals.dropna()

        df = pd.concat([df, no_evals], axis=0, ignore_index=True)

    # anything the database already knows is taken from there instead
    df = df[~df['positions'].isin(positions_evaluated)]

    # reassign rather than mutate in place: df is a filtered slice here
    df = df.rename(columns={'evaluations': 'evaluation',
                            'eval_depths': 'eval_depth',
                            'positions': 'fen'})
    df['evaluation'] = pd.to_numeric(df['evaluation'],
                                     errors='coerce')

    df = df.dropna()
    df = pd.concat([df, db_evaluations], axis=0, ignore_index=True)

    return df
- - evals = df['evaluations'].explode().reset_index(drop=True) - depths = df['eval_depths'].explode().reset_index(drop=True) - positions = df['positions'].explode().reset_index(drop=True) - positions = get_clean_fens(positions) - - sql = """SELECT fen, evaluation, eval_depth - FROM position_evals - WHERE fen IN %(positions)s; - """ - db_evaluations = run_remote_sql_query(sql, - positions=tuple(positions.tolist() + no_evals['positions'].tolist()), # noqa - ) - positions_evaluated = db_evaluations['fen'].drop_duplicates() - - df = pd.concat([positions, evals, depths], axis=1) - - if self.local_stockfish: - - local_evals = [] - - counter = 0 - position_count = len(no_evals['positions']) - evaluation = None - - for position in no_evals['positions'].tolist(): - if position in positions_evaluated.values: - # position will be dropped later if evaluation is None - evaluation = None - else: - sf_eval = get_sf_evaluation(position + ' 0', - stockfish_params.location, - stockfish_params.depth) - if sf_eval is not None: - # TODO: this is implicitly setting evaluation = last - # eval if in a checkmate position. 
def get_clean_fens(positions: pd.Series) -> pd.Series:
    """Return each position string with its last whitespace-separated field removed.

    The remaining fields are re-joined with single spaces.
    """
    fields = positions.str.split()
    all_but_last = fields.str[:-1]
    return all_but_last.str.join(' ')