From c61d227803ce2942f260e6f6d0b4f1b44b8d3900 Mon Sep 17 00:00:00 2001 From: gatsik <74517072+Gatsik@users.noreply.github.com> Date: Sat, 28 May 2022 22:32:12 +0300 Subject: [PATCH] Round floats before encoding --- server/config.py | 2 ++ server/protocol/protocol.py | 18 ++++++++++- tests/integration_tests/test_game.py | 8 +++-- tests/unit_tests/test_protocol.py | 45 ++++++++++++++++++++++++++++ 4 files changed, 70 insertions(+), 3 deletions(-) diff --git a/server/config.py b/server/config.py index f416b9f3c..af2a76df6 100644 --- a/server/config.py +++ b/server/config.py @@ -124,6 +124,8 @@ def __init__(self): self.QUEUE_POP_DESIRED_MATCHES = 2.5 # How many previous queue sizes to consider self.QUEUE_POP_TIME_MOVING_AVG_SIZE = 5 + # The maximum number of decimal places to use for float serialization + self.JSON_FLOAT_MAX_DIGITS = 2 self._defaults = { key: value for key, value in vars(self).items() if key.isupper() diff --git a/server/protocol/protocol.py b/server/protocol/protocol.py index a525f5a52..4b51b90c9 100644 --- a/server/protocol/protocol.py +++ b/server/protocol/protocol.py @@ -4,10 +4,26 @@ from asyncio import StreamReader, StreamWriter import server.metrics as metrics +from server.config import config from ..asyncio_extensions import synchronizedmethod -json_encoder = json.JSONEncoder(separators=(",", ":")) + +class CustomJSONEncoder(json.JSONEncoder): + # taken from https://stackoverflow.com/a/53798633 + def encode(self, o): + def round_floats(o): + if isinstance(o, float): + return round(o, config.JSON_FLOAT_MAX_DIGITS) + if isinstance(o, dict): + return {k: round_floats(v) for k, v in o.items()} + if isinstance(o, (list, tuple)): + return [round_floats(x) for x in o] + return o + return super().encode(round_floats(o)) + + +json_encoder = CustomJSONEncoder(separators=(",", ":")) class DisconnectedError(ConnectionError): diff --git a/tests/integration_tests/test_game.py b/tests/integration_tests/test_game.py index 7fa648aaf..a7366bd60 100644 --- a/tests/integration_tests/test_game.py +++ b/tests/integration_tests/test_game.py @@ -307,7 +307,10 @@ async def test_game_ended_rates_game(lobby_server): @pytest.mark.rabbitmq @fast_forward(30) -async def test_game_ended_broadcasts_rating_update(lobby_server, channel): +async def test_game_ended_broadcasts_rating_update( + lobby_server, channel, mocker, +): + mocker.patch("server.config.JSON_FLOAT_MAX_DIGITS", 4) mq_proto_all = await connect_mq_consumer( lobby_server, channel, @@ -611,12 +614,13 @@ async def test_partial_game_ended_rates_game(lobby_server, tmp_user): @fast_forward(100) -async def test_ladder_game_draw_bug(lobby_server, database): +async def test_ladder_game_draw_bug(lobby_server, database, mocker): """ This simulates the infamous "draw bug" where a player could self destruct their own ACU in order to kill the enemy ACU and be awarded a victory instead of a draw. """ + mocker.patch("server.config.JSON_FLOAT_MAX_DIGITS", 13) player1_id, proto1, player2_id, proto2 = await queue_players_for_matchmaking(lobby_server) msg1, msg2 = await asyncio.gather(*[ diff --git a/tests/unit_tests/test_protocol.py b/tests/unit_tests/test_protocol.py index 69b37902a..be5f398cd 100644 --- a/tests/unit_tests/test_protocol.py +++ b/tests/unit_tests/test_protocol.py @@ -13,6 +13,7 @@ QDataStreamProtocol, SimpleJsonProtocol ) +from server.protocol.protocol import json_encoder @pytest.fixture(scope="session") @@ -256,3 +257,47 @@ async def test_read_when_disconnected(protocol): with pytest.raises(DisconnectedError): await protocol.read_message() + + +def test_json_encoder_float_serialization(): + assert json_encoder.encode(123.0) == "123.0" + assert json_encoder.encode(0.99) == "0.99" + assert json_encoder.encode(0.999) == "1.0" + + +@given(message=st_messages()) +def test_json_encoder_encodes_server_messages(message): + new_encode = json_encoder.encode + old_encode = json.JSONEncoder(separators=(",", ":")).encode + + assert new_encode(message) == old_encode(message) + + +def st_dictionaries(): + value_types = ( + st.booleans(), + st.text(), + st.integers(), + st.none(), + ) + key_types = (*value_types, st.floats()) + return st.dictionaries( + keys=st.one_of(*key_types), + values=st.one_of( + *value_types, + st.lists(st.one_of(*value_types)), + st.tuples(st.one_of(*value_types)), + ) + ) + + +@ given(dct=st_dictionaries()) +def test_json_encoder_encodes_dicts(dct): + old_encode = json.JSONEncoder(separators=(",", ":")).encode + new_encode = json_encoder.encode + + assert new_encode(dct) == old_encode(dct) + + wrong_dict_key = (1, 2) + with pytest.raises(TypeError): + json_encoder.encode({wrong_dict_key: "a"})