diff --git a/server/config.py b/server/config.py index f416b9f3c..910b0524b 100644 --- a/server/config.py +++ b/server/config.py @@ -124,6 +124,8 @@ def __init__(self): self.QUEUE_POP_DESIRED_MATCHES = 2.5 # How many previous queue sizes to consider self.QUEUE_POP_TIME_MOVING_AVG_SIZE = 5 + # The number of decimal places to use for float serialization + self.JSON_FLOAT_DIGITS_PRECISION = 2 self._defaults = { key: value for key, value in vars(self).items() if key.isupper() diff --git a/server/protocol/protocol.py b/server/protocol/protocol.py index a525f5a52..d996b1455 100644 --- a/server/protocol/protocol.py +++ b/server/protocol/protocol.py @@ -2,12 +2,32 @@ import json from abc import ABCMeta, abstractmethod from asyncio import StreamReader, StreamWriter +from typing import Iterable, Mapping import server.metrics as metrics +from server.config import config from ..asyncio_extensions import synchronizedmethod -json_encoder = json.JSONEncoder(separators=(",", ":")) + +class CustomJSONEncoder(json.JSONEncoder): + # taken from https://stackoverflow.com/a/60243503 + def encode(self, obj): + if isinstance(obj, Mapping): + dict_content = ", ".join( + f"{self.encode(key)}: {self.encode(value)}" + for (key, value) in obj.items() + ) + return '{' + dict_content + '}' + elif isinstance(obj, Iterable) and not isinstance(obj, str): + return "[" + ", ".join(map(self.encode, obj)) + "]" + elif isinstance(obj, float): + return f"{obj:.{config.JSON_FLOAT_DIGITS_PRECISION}f}" + else: + return super().encode(obj) + + +json_encoder = CustomJSONEncoder(separators=(",", ":")) class DisconnectedError(ConnectionError): diff --git a/tests/unit_tests/test_protocol.py b/tests/unit_tests/test_protocol.py index 69b37902a..aa4a72aa5 100644 --- a/tests/unit_tests/test_protocol.py +++ b/tests/unit_tests/test_protocol.py @@ -1,4 +1,5 @@ import asyncio +import copy import json import struct from contextlib import asynccontextmanager, closing @@ -8,11 +9,13 @@ from hypothesis import example, given, settings from hypothesis import strategies as st +from server.config import config from server.protocol import ( DisconnectedError, QDataStreamProtocol, SimpleJsonProtocol ) +from server.protocol.protocol import json_encoder @pytest.fixture(scope="session") @@ -256,3 +259,54 @@ async def test_read_when_disconnected(protocol): with pytest.raises(DisconnectedError): await protocol.read_message() + + +def test_json_encoder_float_serialization(): + assert json_encoder.encode(123.0) == '123.00' + assert json_encoder.encode(0.99) == '0.99' + assert json_encoder.encode(0.999) == '1.00' + assert json_encoder.encode('simple words') == '"simple words"' + assert json_encoder.encode([1, 2, 3]) == '[1, 2, 3]' + assert json_encoder.encode({'a': 1}) == '{"a": 1}' + + # Check that we did not brake the original behaviour of other things + delta_ladder, delta_tmm = 200.3654233558, 19.879223123 + + info_dict = { + "command": "matchmaker_info", + "queues": [ + { + "queue_name": "ladder1v1", + "queue_pop_time": "2002-01-01T23:59:59.595959", + "queue_pop_time_delta": delta_ladder, + "num_players": 3, + "boundary_80s": [(1200, 1600), (1600, 2000), (300, 700)], + "boundary_75s": [(1300, 1500), (1700, 1900), (400, 600)], + "team_size": 1, + }, + { + "queue_name": "tmm2v2", + "queue_pop_time": "2002-01-01T23:56:23.234255", + "queue_pop_time_delta": delta_tmm, + "num_players": 4, + "boundary_80s": [ + (1200, 1600), (1600, 2000), (300, 700), (300, 700), + ], + "boundary_75s": [ + (1300, 1500), (1700, 1900), (400, 600), (400, 600), + ], + "team_size": 2, + }, + ] + } + + expected_info_dict = copy.deepcopy(info_dict) + + for queue in expected_info_dict["queues"]: + queue["queue_pop_time_delta"] = round( + queue["queue_pop_time_delta"], config.JSON_FLOAT_DIGITS_PRECISION, + ) + + assert json_encoder.encode(info_dict) == ( + json.JSONEncoder().encode(expected_info_dict) + )