Skip to content

Commit

Permalink
Round floats before encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
Gatsik committed May 28, 2022
1 parent 16c70ce commit c61d227
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 3 deletions.
2 changes: 2 additions & 0 deletions server/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ def __init__(self):
self.QUEUE_POP_DESIRED_MATCHES = 2.5
# How many previous queue sizes to consider
self.QUEUE_POP_TIME_MOVING_AVG_SIZE = 5
# The maximum number of decimal places to use for float serialization
self.JSON_FLOAT_MAX_DIGITS = 2

self._defaults = {
key: value for key, value in vars(self).items() if key.isupper()
Expand Down
18 changes: 17 additions & 1 deletion server/protocol/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,26 @@
from asyncio import StreamReader, StreamWriter

import server.metrics as metrics
from server.config import config

from ..asyncio_extensions import synchronizedmethod

json_encoder = json.JSONEncoder(separators=(",", ":"))

class CustomJSONEncoder(json.JSONEncoder):
# taken from https://stackoverflow.com/a/53798633
def encode(self, o):
def round_floats(o):
if isinstance(o, float):
return round(o, config.JSON_FLOAT_MAX_DIGITS)
if isinstance(o, dict):
return {k: round_floats(v) for k, v in o.items()}
if isinstance(o, (list, tuple)):
return [round_floats(x) for x in o]
return o
return super().encode(round_floats(o))


json_encoder = CustomJSONEncoder(separators=(",", ":"))


class DisconnectedError(ConnectionError):
Expand Down
8 changes: 6 additions & 2 deletions tests/integration_tests/test_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,10 @@ async def test_game_ended_rates_game(lobby_server):

@pytest.mark.rabbitmq
@fast_forward(30)
async def test_game_ended_broadcasts_rating_update(lobby_server, channel):
async def test_game_ended_broadcasts_rating_update(
lobby_server, channel, mocker,
):
mocker.patch("server.config.JSON_FLOAT_MAX_DIGITS", 4)
mq_proto_all = await connect_mq_consumer(
lobby_server,
channel,
Expand Down Expand Up @@ -611,12 +614,13 @@ async def test_partial_game_ended_rates_game(lobby_server, tmp_user):


@fast_forward(100)
async def test_ladder_game_draw_bug(lobby_server, database):
async def test_ladder_game_draw_bug(lobby_server, database, mocker):
"""
This simulates the infamous "draw bug" where a player could self destruct
their own ACU in order to kill the enemy ACU and be awarded a victory
instead of a draw.
"""
mocker.patch("server.config.JSON_FLOAT_MAX_DIGITS", 13)
player1_id, proto1, player2_id, proto2 = await queue_players_for_matchmaking(lobby_server)

msg1, msg2 = await asyncio.gather(*[
Expand Down
45 changes: 45 additions & 0 deletions tests/unit_tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
QDataStreamProtocol,
SimpleJsonProtocol
)
from server.protocol.protocol import json_encoder


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -256,3 +257,47 @@ async def test_read_when_disconnected(protocol):

with pytest.raises(DisconnectedError):
await protocol.read_message()


def test_json_encoder_float_serialization():
assert json_encoder.encode(123.0) == "123.0"
assert json_encoder.encode(0.99) == "0.99"
assert json_encoder.encode(0.999) == "1.0"


@given(message=st_messages())
def test_json_encoder_encodes_server_messages(message):
new_encode = json_encoder.encode
old_encode = json.JSONEncoder(separators=(",", ":")).encode

assert new_encode(message) == old_encode(message)


def st_dictionaries():
value_types = (
st.booleans(),
st.text(),
st.integers(),
st.none(),
)
key_types = (*value_types, st.floats())
return st.dictionaries(
keys=st.one_of(*key_types),
values=st.one_of(
*value_types,
st.lists(st.one_of(*value_types)),
st.tuples(st.one_of(*value_types)),
)
)


@ given(dct=st_dictionaries())
def test_json_encoder_encodes_dicts(dct):
old_encode = json.JSONEncoder(separators=(",", ":")).encode
new_encode = json_encoder.encode

assert new_encode(dct) == old_encode(dct)

wrong_dict_key = (1, 2)
with pytest.raises(TypeError):
json_encoder.encode({wrong_dict_key: "a"})

0 comments on commit c61d227

Please sign in to comment.