Skip to content

Commit

Permalink
handle cloud eval api rate limiting
Browse files Browse the repository at this point in the history
  • Loading branch information
guidopetri committed Nov 14, 2024
1 parent 7d659e5 commit aaaaa89
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 20 deletions.
81 changes: 65 additions & 16 deletions src/pipeline_import/transforms.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
#! /usr/bin/env python3

import logging
import re
from datetime import date, timedelta
from pathlib import Path
from subprocess import SubprocessError
from typing import Type
from typing import Type, cast

import chess
import lichess.api
import pandas as pd
import stockfish
import valkey
from chess.pgn import Game
from pandas import (
Series,
Expand All @@ -21,30 +24,76 @@
)
from utils.types import Json, Visitor

MAX_CLOUD_API_CALLS_PER_DAY = 3


class LichessApiClient(lichess.api.DefaultApiClient):
max_retries = 3


def increment_successful_api_call(valkey_client: valkey.Valkey,
key: str,
expire_at_unix: int,
) -> None:
valkey_client.incr(key, 1)
valkey_client.expireat(key, expire_at_unix, nx=True)


def get_sf_evaluation(fen: str,
sf_location: Path,
sf_depth: int,
valkey_client: valkey.Valkey | None = None,
) -> float:
# get cloud eval if available
try:
client = LichessApiClient()
cloud_eval = lichess.api.cloud_eval(fen=fen, multiPv=1, client=client)
rating = cloud_eval['pvs'][0]
if 'cp' in rating:
rating = rating['cp'] / 100
elif 'mate' in rating:
rating = -9999 if rating['mate'] < 0 else 9999
else:
raise KeyError(f'{fen}, {rating}')
return rating
except lichess.api.ApiHttpError:
# continue execution
pass
today: date = date.today()
valkey_key: str = today.strftime('lichess-cloud-evals-api-%F')

# by default, don't use cloud API if we are not tracking api hits
api_calls_done: int = MAX_CLOUD_API_CALLS_PER_DAY + 1

if valkey_client is not None:
valkey_response: str | None = cast(str | None,
valkey_client.get(valkey_key),
)
api_calls_done: int = int(valkey_response
if valkey_response is not None
else 0
)

if api_calls_done < MAX_CLOUD_API_CALLS_PER_DAY:
try:
# get cloud eval if available
client = LichessApiClient()
cloud_eval = lichess.api.cloud_eval(fen=fen,
multiPv=1,
client=client,
)
rating = cloud_eval['pvs'][0]
if 'cp' in rating:
rating = rating['cp'] / 100
elif 'mate' in rating:
rating = -9999 if rating['mate'] < 0 else 9999
else:
raise KeyError(f'{fen}, {rating}')

if valkey_client is not None:
tomorrow: date = today + timedelta(days=1)
expire_at_unix: int = int(tomorrow.strftime('%s'))

increment_successful_api_call(valkey_client,
valkey_key,
expire_at_unix,
)
return rating
except lichess.api.ApiHttpError as e:
logging.warning(f'Got an API HTTP error: {e}')
# continue execution
pass
except lichess.api.ApiError as e:
logging.warning('Hit an API error (potentially a rate limit) '
f'with only {api_calls_done} calls')
logging.warning(e)
# continue execution
pass

# implicit else
sf = stockfish.Stockfish(sf_location,
Expand Down
11 changes: 10 additions & 1 deletion src/vendors/stockfish.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
from typing import Any

import pandas as pd
import valkey
from luigi import Task
from pipeline_import.configs import stockfish_cfg
from pipeline_import.transforms import get_clean_fens, get_sf_evaluation
Expand Down Expand Up @@ -46,14 +48,21 @@ def get_evals(df: pd.DataFrame,
position_count: int = len(no_evals['positions'])
evaluation: float | None = None

valkey_url: str = os.environ['VALKEY_CONNECTION_URL']
valkey_client: valkey.Valkey = valkey.from_url(valkey_url,
decode_responses=True,
)

for position in no_evals['positions'].tolist():
if position in positions_evaluated.values:
# position will be dropped later if evaluation is None
evaluation = None
else:
evaluation = get_sf_evaluation(position + ' 0',
sf_params.location,
sf_params.depth)
sf_params.depth,
valkey_client,
)

local_evals.append(evaluation)

Expand Down
67 changes: 64 additions & 3 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pandas as pd
import pytest
from pipeline_import import transforms, visitors
from pipeline_import.transforms import MAX_CLOUD_API_CALLS_PER_DAY


def test_parse_headers():
Expand Down Expand Up @@ -228,11 +229,63 @@ def test_get_sf_evaluation_cloud(mocker):
fen = 'r1bqkb1r/pp1ppppp/2n2n2/2p5/8/1P3NP1/PBPPPP1P/RN1QKB1R b KQkq - 0 1'

# loc/depth don't matter
rating = transforms.get_sf_evaluation(fen, '', 1)
rating = transforms.get_sf_evaluation(fen,
'',
1,
valkey_client=mocker.MagicMock(),
)

assert rating == -0.3


@pytest.fixture
def mock_valkey_client():
class MockValkey:
def __init__(self):
self.counter = 0

def get(self, *args, **kwargs):
return '0'

def incr(self, *args, **kwargs):
self.counter += 1

def expireat(self, *args, **kwargs):
pass
return MockValkey()


def test_get_sf_evaluation_tracks_api_calls(mocker, mock_valkey_client):
mock_parsed_resp = {'pvs': [{'cp': -30}]}

mocker.patch('lichess.api.cloud_eval', return_value=mock_parsed_resp)

# loc/depth don't matter
transforms.get_sf_evaluation('',
'',
1,
valkey_client=mock_valkey_client,
)

assert mock_valkey_client.counter == 1


def test_get_sf_evaluation_doesnt_exceed_api_calls(mocker, mock_valkey_client):
mock_sf = mocker.patch('stockfish.Stockfish')
mocker.patch('re.search', return_value=None)

mock_valkey_client.counter = MAX_CLOUD_API_CALLS_PER_DAY + 1

with pytest.raises(SubprocessError):
transforms.get_sf_evaluation('',
'',
1,
valkey_client=mock_valkey_client,
)

mock_sf.assert_called_once()


def test_get_sf_evaluation_cloud_mate_in_x(mocker):
mock_parsed_resp = {'pvs': [{'mate': 1}]}

Expand All @@ -242,15 +295,23 @@ def test_get_sf_evaluation_cloud_mate_in_x(mocker):
fen = 'r1bqkbnr/ppp2ppp/2np4/4p3/2B1P3/5Q2/PPPP1PPP/RNB1K1NR w KQkq - 2 4'

# loc/depth don't matter
rating = transforms.get_sf_evaluation(fen, '', 1)
rating = transforms.get_sf_evaluation(fen,
'',
1,
valkey_client=mocker.MagicMock(),
)

assert rating == 9999


def test_get_sf_evaluation_cloud_error(mocker):
mocker.patch('lichess.api.cloud_eval', return_value={'pvs': ['foobar']})
with pytest.raises(KeyError):
transforms.get_sf_evaluation('fake fen', '', 1)
transforms.get_sf_evaluation('fake fen',
'',
1,
valkey_client=mocker.MagicMock(),
)


def test_get_sf_evaluation_local_returns_error(mocker, mocked_cloud_eval):
Expand Down

0 comments on commit aaaaa89

Please sign in to comment.