Skip to content

Commit

Permalink
Add tuple return type (#69)
Browse files Browse the repository at this point in the history
* Add tuple as an optional return type to get_wdl_stats function

* Incorporate changes from black and mypy. Add test case for when get_as_tuple=True.

* Change tuple type hint to Tuple[int, int, int] to satisfy mypy. Return tuple naively instead of type casting from list to tuple. Add 2 regression tests for robustness.
  • Loading branch information
evanjhoward11 authored Jun 2, 2024
1 parent f2cb530 commit 496f5b3
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 9 deletions.
32 changes: 23 additions & 9 deletions stockfish/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,11 +706,18 @@ def is_move_correct(self, move_value: str) -> bool:
self.info = old_self_info
return is_move_correct

def get_wdl_stats(self) -> Optional[List[int]]:
def get_wdl_stats(
self, get_as_tuple: bool = False
) -> Union[list[int] | tuple[int, int, int] | None]:
"""Returns Stockfish's win/draw/loss stats for the side to move.
Args:
get_as_tuple:
Option to return the wdl stats as a tuple instead of a list
`Boolean`. Default is `False`.
Returns:
A list of three integers, unless the game is over (in which case
A list or tuple of three integers, unless the game is over (in which case
`None` is returned).
"""

Expand All @@ -730,7 +737,12 @@ def get_wdl_stats(self) -> Optional[List[int]]:
return None
split_line = [line.split(" ") for line in lines if " multipv 1 " in line][-1]
wdl_index = split_line.index("wdl")
return [int(split_line[i]) for i in range(wdl_index + 1, wdl_index + 4)]

wdl_stats = [int(split_line[i]) for i in range(wdl_index + 1, wdl_index + 4)]

if get_as_tuple:
return (wdl_stats[0], wdl_stats[1], wdl_stats[2])
return wdl_stats

def does_current_engine_version_have_wdl_option(self) -> bool:
"""Returns whether the user's version of Stockfish has the option
Expand Down Expand Up @@ -915,13 +927,15 @@ def get_top_moves(
# get move
"Move": self._pick(line, "pv"),
# get cp if available
"Centipawn": int(self._pick(line, "cp")) * perspective
if "cp" in line
else None,
"Centipawn": (
int(self._pick(line, "cp")) * perspective if "cp" in line else None
),
# get mate if available
"Mate": int(self._pick(line, "mate")) * perspective
if "mate" in line
else None,
"Mate": (
int(self._pick(line, "mate")) * perspective
if "mate" in line
else None
),
}

# add more info if verbose
Expand Down
5 changes: 5 additions & 0 deletions tests/stockfish/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,11 @@ def test_get_wdl_stats(self, stockfish: Stockfish):
wdl_stats_3 = stockfish.get_wdl_stats()
assert isinstance(wdl_stats_3, list) and len(wdl_stats_3) == 3

wdl_stats_4 = stockfish.get_wdl_stats(get_as_tuple=True)
assert isinstance(wdl_stats_4, tuple) and len(wdl_stats_4) == 3
assert wdl_stats_3 == list(wdl_stats_4)
assert tuple(wdl_stats_3) == wdl_stats_4

stockfish.set_fen_position("8/8/8/8/8/3k4/3p4/3K4 w - - 0 1")
assert stockfish.get_wdl_stats() is None

Expand Down

0 comments on commit 496f5b3

Please sign in to comment.