Skip to content

Commit

Permalink
refactor get evals function
Browse files Browse the repository at this point in the history
  • Loading branch information
guidopetri committed Apr 21, 2024
1 parent cb414aa commit 38d5219
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 82 deletions.
171 changes: 90 additions & 81 deletions src/chess_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
from calendar import timegm
from datetime import datetime, timedelta
from typing import Type
from typing import Any, Type

import lichess.api
import pandas as pd
Expand Down Expand Up @@ -42,7 +42,7 @@
from utils.types import Json, Visitor


def run_remote_sql_query(sql, **params):
def run_remote_sql_query(sql, **params) -> pd.DataFrame:
pg_cfg = postgres_cfg()
user = pg_cfg.user
password = pg_cfg.password
Expand All @@ -57,7 +57,7 @@ def run_remote_sql_query(sql, **params):
port=port,
)

df = pd.read_sql_query(sql, db, params=params)
df: pd.DataFrame = pd.read_sql_query(sql, db, params=params)

return df

Expand Down Expand Up @@ -196,6 +196,92 @@ def clean_chess_df(pgn: pd.DataFrame, json: pd.DataFrame) -> pd.DataFrame:
return df


def get_evals(df: pd.DataFrame,
              local_stockfish: bool,
              task: Task,
              ) -> pd.DataFrame:
    """Resolve an evaluation for every position in `df`.

    Evaluations come from three sources, in order of preference:
    1. evaluations already present in `df` (e.g. parsed from the PGN),
    2. cached evaluations fetched from the `position_evals` DB table,
    3. if `local_stockfish` is true, a local Stockfish run for anything
       still missing.

    Parameters
    ----------
    df : DataFrame with 'evaluations', 'eval_depths' and 'positions'
        columns, each holding list-likes per game row.
    local_stockfish : whether to evaluate uncached positions locally.
    task : the running task; used only for progress/status reporting
        via set_status_message / set_progress_percentage.

    Returns
    -------
    DataFrame with columns 'fen', 'evaluation', 'eval_depth'.
    """
    sf_params: Any = stockfish_cfg()

    df = df[['evaluations', 'eval_depths', 'positions']]

    # explode the two different list-likes separately, then concat
    # rows with no pre-existing evaluations are handled separately below
    no_evals: pd.DataFrame = df[~df['evaluations'].astype(bool)]
    df = df[df['evaluations'].astype(bool)]

    no_evals = pd.DataFrame(no_evals['positions'].explode())
    no_evals['positions'] = get_clean_fens(no_evals['positions'])

    # reset_index so the three exploded series align positionally when
    # concatenated back together further down
    evals: pd.Series = df['evaluations'].explode().reset_index(drop=True)
    depths: pd.Series = df['eval_depths'].explode().reset_index(drop=True)
    positions: pd.Series = df['positions'].explode().reset_index(drop=True)
    positions = get_clean_fens(positions)

    # look up every position (with or without a PGN eval) in the DB cache
    sql: str = """SELECT fen, evaluation, eval_depth
                  FROM position_evals
                  WHERE fen IN %(positions)s;
                  """
    db_evaluations = run_remote_sql_query(sql,
                                          positions=tuple(positions.tolist() + no_evals['positions'].tolist()),  # noqa
                                          )
    positions_evaluated = db_evaluations['fen'].drop_duplicates()

    df = pd.concat([positions, evals, depths], axis=1)

    if local_stockfish:

        local_evals: list[float | None] = []

        counter: int = 0
        position_count: int = len(no_evals['positions'])
        evaluation: float | None = None

        for position in no_evals['positions'].tolist():
            if position in positions_evaluated.values:
                # position will be dropped later if evaluation is None
                evaluation = None
            else:
                # ' 0' restores a counter field stripped by
                # get_clean_fens so the engine gets a full FEN
                sf_eval: float | None = get_sf_evaluation(position + ' 0',
                                                          sf_params.location,
                                                          sf_params.depth)
                if sf_eval is not None:
                    # TODO: this is implicitly setting evaluation = last
                    # eval if in a checkmate position. handle this better
                    evaluation = sf_eval

            local_evals.append(evaluation)

            # progress bar stuff
            counter += 1

            current_progress = counter / position_count
            task.set_status_message(f'Analyzed :: '
                                    f'{counter} / {position_count}')
            task.set_progress_percentage(round(current_progress * 100, 2))

        task.set_status_message(f'Analyzed all {position_count} positions')
        task.set_progress_percentage(100)

        no_evals['evaluations'] = local_evals
        no_evals['eval_depths'] = sf_params.depth
        # drops rows whose evaluation stayed None (i.e. DB-cached ones)
        no_evals.dropna(inplace=True)

        df = pd.concat([df, no_evals], axis=0, ignore_index=True)

    # DB-cached positions are re-added from db_evaluations below, so
    # remove any locally-held duplicates first
    df = df[~df['positions'].isin(positions_evaluated)]

    df.rename(columns={'evaluations': 'evaluation',
                       'eval_depths': 'eval_depth',
                       'positions': 'fen'},
              inplace=True)
    # coerce non-numeric evals (e.g. mate strings) to NaN ...
    df['evaluation'] = pd.to_numeric(df['evaluation'],
                                     errors='coerce')

    # ... so this dropna removes them before merging in the DB rows
    df.dropna(inplace=True)
    df = pd.concat([df, db_evaluations], axis=0, ignore_index=True)

    return df


class FetchLichessApiJSON(Task):

player = Parameter(default='thibault')
Expand Down Expand Up @@ -314,84 +400,7 @@ def complete(self):

return

stockfish_params = stockfish_cfg()

df = df[['evaluations', 'eval_depths', 'positions']]

# explode the two different list-likes separately, then concat
no_evals = df[~df['evaluations'].astype(bool)]
df = df[df['evaluations'].astype(bool)]

no_evals = pd.DataFrame(no_evals['positions'].explode())
no_evals['positions'] = get_clean_fens(no_evals['positions'])

evals = df['evaluations'].explode().reset_index(drop=True)
depths = df['eval_depths'].explode().reset_index(drop=True)
positions = df['positions'].explode().reset_index(drop=True)
positions = get_clean_fens(positions)

sql = """SELECT fen, evaluation, eval_depth
FROM position_evals
WHERE fen IN %(positions)s;
"""
db_evaluations = run_remote_sql_query(sql,
positions=tuple(positions.tolist() + no_evals['positions'].tolist()), # noqa
)
positions_evaluated = db_evaluations['fen'].drop_duplicates()

df = pd.concat([positions, evals, depths], axis=1)

if self.local_stockfish:

local_evals = []

counter = 0
position_count = len(no_evals['positions'])
evaluation = None

for position in no_evals['positions'].tolist():
if position in positions_evaluated.values:
# position will be dropped later if evaluation is None
evaluation = None
else:
sf_eval = get_sf_evaluation(position + ' 0',
stockfish_params.location,
stockfish_params.depth)
if sf_eval is not None:
# TODO: this is implicitly setting evaluation = last
# eval if in a checkmate position. handle this better
evaluation = sf_eval

local_evals.append(evaluation)

# progress bar stuff
counter += 1

current_progress = counter / position_count
self.set_status_message(f'Analyzed :: '
f'{counter} / {position_count}')
self.set_progress_percentage(round(current_progress * 100, 2))

self.set_status_message(f'Analyzed all {position_count} positions')
self.set_progress_percentage(100)

no_evals['evaluations'] = local_evals
no_evals['eval_depths'] = stockfish_params.depth
no_evals.dropna(inplace=True)

df = pd.concat([df, no_evals], axis=0, ignore_index=True)

df = df[~df['positions'].isin(positions_evaluated)]

df.rename(columns={'evaluations': 'evaluation',
'eval_depths': 'eval_depth',
'positions': 'fen'},
inplace=True)
df['evaluation'] = pd.to_numeric(df['evaluation'],
errors='coerce')

df.dropna(inplace=True)
df = pd.concat([df, db_evaluations], axis=0, ignore_index=True)
df: pd.DataFrame = get_evals(df, self.local_stockfish, self)

with self.output().temporary_path() as temp_output_path:
df.to_pickle(temp_output_path, compression=None)
Expand Down
2 changes: 1 addition & 1 deletion src/pipeline_import/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def convert_clock_to_seconds(clocks):
return clocks


def get_clean_fens(positions: pd.Series) -> pd.Series:
    """Remove the last whitespace-separated field from each FEN string.

    Each value is tokenized on whitespace, the final token (presumably a
    move counter — confirm against callers) is dropped, and the rest are
    re-joined with single spaces. NaN entries pass through unchanged.
    """
    tokenized = positions.str.split()
    return tokenized.str[:-1].str.join(' ')

Expand Down

0 comments on commit 38d5219

Please sign in to comment.