Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
madcpf committed Oct 6, 2023
1 parent bdf8746 commit 68be945
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 2 deletions.
5 changes: 4 additions & 1 deletion unitary/examples/quantum_chinese_chess/chess_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,21 @@


def test_game_init():
output = io.StringIO()
sys.stdout = output
with mock.patch("builtins.input", side_effect=["y", "Bob", "Ben"]):
game = QuantumChineseChess()
assert game.lang == Language.ZH
assert game.players_name == ["Bob", "Ben"]
assert game.current_player == 0
assert "Welcome" in output.getvalue()
sys.stdout = sys.__stdout__


def test_game_invalid_move():
output = io.StringIO()
sys.stdout = output
with mock.patch("builtins.input", side_effect=["y", "Bob", "Ben", "a1n1", "exit"]):
# with pytest.raises(ValueError, match = "Invalid location string. Make sure they are from a0 to i9."):
game = QuantumChineseChess()
game.play()
assert (
Expand Down
6 changes: 5 additions & 1 deletion unitary/examples/quantum_chinese_chess/move.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def parse_input_string(str_to_parse: str) -> Tuple[List[str], List[str]]:
sources = [str_to_parse[0:2]]
targets = [str_to_parse[2:4]]
if sources[0] == targets[0]:
raise ValueError("source and target should not be the same.")
raise ValueError("Source and target should not be the same.")

# Make sure all the locations are valid.
for location in sources + targets:
Expand Down Expand Up @@ -115,6 +115,10 @@ def __eq__(self, other):
)
return False

def _verify_objects(self, *objects):
# TODO(): add checks that apply to all move types
return

def effect(self, *objects):
# TODO(): add effects according to move_type and move_variant
return
Expand Down
149 changes: 149 additions & 0 deletions unitary/examples/quantum_chinese_chess/move_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright 2023 The Unitary Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unitary.examples.quantum_chinese_chess.move import (
Move,
get_move_from_string,
parse_input_string,
)
from unitary.examples.quantum_chinese_chess.board import Board
from unitary.examples.quantum_chinese_chess.enums import MoveType, MoveVariant
import pytest


def test_parse_success():
assert parse_input_string("a1b1") == (["a1"], ["b1"])
assert parse_input_string("a1b1^c2") == (["a1", "b1"], ["c2"])
assert parse_input_string("a1^b1c2") == (["a1"], ["b1", "c2"])


def test_parse_fail():
with pytest.raises(ValueError, match="Invalid sources/targets string "):
parse_input_string("a1^b1")
with pytest.raises(ValueError, match="Invalid sources/targets string "):
parse_input_string("a^1b1c2")
with pytest.raises(ValueError, match="Two sources should not be the same."):
parse_input_string("a1a1^c2")
with pytest.raises(ValueError, match="Two targets should not be the same."):
parse_input_string("a1^c2c2")
with pytest.raises(ValueError, match="Invalid sources/targets string "):
parse_input_string("a1b")
with pytest.raises(ValueError, match="Source and target should not be the same."):
parse_input_string("a1a1")
with pytest.raises(ValueError, match="Invalid location string."):
parse_input_string("a1n1")


def test_move_eq():
board = Board.from_fen()
move1 = Move(
"a1",
"b2",
board,
"c1",
move_type=MoveType.MERGE_JUMP,
move_variant=MoveVariant.CAPTURE,
)
move2 = Move(
"a1",
"b2",
board,
"c1",
move_type=MoveType.MERGE_JUMP,
move_variant=MoveVariant.CAPTURE,
)
move3 = Move(
"a1", "b2", board, move_type=MoveType.JUMP, move_variant=MoveVariant.CAPTURE
)
move4 = Move(
"a1",
"b2",
board,
"c1",
move_type=MoveType.MERGE_SLIDE,
move_variant=MoveVariant.CAPTURE,
)

assert move1 == move2
assert move1 != move3
assert move1 != move4


def test_move_type():
# TODO(): change to real senarios
board = Board.from_fen()
move1 = Move(
"a1",
"b2",
board,
"c1",
move_type=MoveType.MERGE_JUMP,
move_variant=MoveVariant.CAPTURE,
)
assert move1.is_split_move() == False
assert move1.is_merge_move()

move2 = Move(
"a1",
"b2",
board,
target2="c1",
move_type=MoveType.SPLIT_JUMP,
move_variant=MoveVariant.BASIC,
)
assert move2.is_split_move()
assert move2.is_merge_move() == False

move3 = Move(
"a1", "b2", board, move_type=MoveType.SLIDE, move_variant=MoveVariant.CAPTURE
)
assert move3.is_split_move() == False
assert move3.is_merge_move() == False


def test_to_str():
# TODO(): change to real senarios
board = Board.from_fen()
move1 = Move(
"a0",
"a6",
board,
"c1",
move_type=MoveType.MERGE_JUMP,
move_variant=MoveVariant.CAPTURE,
)
assert move1.to_str(0) == ""
assert move1.to_str(1) == "a0c1^a6"
assert move1.to_str(2) == "a0c1^a6:MERGE_JUMP:CAPTURE"
assert move1.to_str(3) == "a0c1^a6:MERGE_JUMP:CAPTURE:BLACK_ROOK->RED_PAWN"

move2 = Move(
"a0",
"b3",
board,
target2="c1",
move_type=MoveType.SPLIT_JUMP,
move_variant=MoveVariant.BASIC,
)
assert move2.to_str(0) == ""
assert move2.to_str(1) == "a0^b3c1"
assert move2.to_str(2) == "a0^b3c1:SPLIT_JUMP:BASIC"
assert move2.to_str(3) == "a0^b3c1:SPLIT_JUMP:BASIC:BLACK_ROOK->NA_EMPTY"

move3 = Move(
"a0", "a6", board, move_type=MoveType.SLIDE, move_variant=MoveVariant.CAPTURE
)
assert move3.to_str(0) == ""
assert move3.to_str(1) == "a0a6"
assert move3.to_str(2) == "a0a6:SLIDE:CAPTURE"
assert move3.to_str(3) == "a0a6:SLIDE:CAPTURE:BLACK_ROOK->RED_PAWN"

0 comments on commit 68be945

Please sign in to comment.