diff --git a/server/protocol/protocol.py b/server/protocol/protocol.py index e0f73d963..f8a78cb90 100644 --- a/server/protocol/protocol.py +++ b/server/protocol/protocol.py @@ -14,13 +14,21 @@ class CustomJSONEncoder(json.JSONEncoder): # taken from https://stackoverflow.com/a/60243503 def encode(self, o): if isinstance(o, Mapping): - dict_content = ", ".join( - f"{self.encode(key)}: {self.encode(value)}" - for (key, value) in o.items() - ) - return '{' + dict_content + '}' + # keys must be str, int, float, bool or None + dict_content = [] + for key, value in o.items(): + # do not use precision for float dict keys + if (isinstance(key, int) or isinstance(key, float)): + new_key = f'"{key}"'.lower() + elif key is None: + new_key = '"null"' + else: + new_key = self.encode(key) + new_value = self.encode(value) + dict_content.append(f"{new_key}:{new_value}") + return "{" + ",".join(dict_content) + "}" elif isinstance(o, Iterable) and not isinstance(o, str): - return "[" + ", ".join(map(self.encode, o)) + "]" + return "[" + ",".join(map(self.encode, o)) + "]" elif isinstance(o, float): return f"{o:.{config.JSON_FLOAT_DIGITS_PRECISION}f}" else: diff --git a/tests/unit_tests/test_protocol.py b/tests/unit_tests/test_protocol.py index aa4a72aa5..23c0ccd16 100644 --- a/tests/unit_tests/test_protocol.py +++ b/tests/unit_tests/test_protocol.py @@ -265,14 +265,15 @@ def test_json_encoder_float_serialization(): assert json_encoder.encode(123.0) == '123.00' assert json_encoder.encode(0.99) == '0.99' assert json_encoder.encode(0.999) == '1.00' - assert json_encoder.encode('simple words') == '"simple words"' - assert json_encoder.encode([1, 2, 3]) == '[1, 2, 3]' - assert json_encoder.encode({'a': 1}) == '{"a": 1}' - # Check that we did not brake the original behaviour of other things + +def test_json_encoder_encodes_server_messages(): + new_encode = json_encoder.encode + old_encode = json.JSONEncoder(separators=(",", ":")).encode + delta_ladder, delta_tmm = 200.3654233558, 19.879223123 - info_dict = { + matchmaker_info_dict = { "command": "matchmaker_info", "queues": [ { @@ -300,13 +301,68 @@ def test_json_encoder_float_serialization(): ] } - expected_info_dict = copy.deepcopy(info_dict) + expected_matchmaker_info_dict = copy.deepcopy(matchmaker_info_dict) - for queue in expected_info_dict["queues"]: + for queue in expected_matchmaker_info_dict["queues"]: queue["queue_pop_time_delta"] = round( queue["queue_pop_time_delta"], config.JSON_FLOAT_DIGITS_PRECISION, ) - assert json_encoder.encode(info_dict) == ( - json.JSONEncoder().encode(expected_info_dict) + assert new_encode(matchmaker_info_dict) == ( + old_encode(expected_matchmaker_info_dict) ) + + game_info_dict = { + "command": "game_info", + "visibility": "public", + "password_protected": True, + "uid": 13, + "title": "someone's game", + "state": "playing", + "game_type": "custom", + "featured_mod": "faf", + "sim_mods": {}, + "mapname": "scmp_009", + "map_file_path": "maps/scmp_009.zip", + "host": "Foo", + "num_players": 2, + "launched_at": 1111111111.11, + "rating_type": "faf", + "rating_min": None, + "rating_max": None, + "enforce_rating_range": False, + "team_ids": [ + { + "team_id": 1, + "player_ids": [1], + }, + ], + "teams": { + 1: ["Foo"], + 2: ["Bar"], + }, + } + assert new_encode(game_info_dict) == old_encode(game_info_dict) + + +@pytest.mark.parametrize( + "dict_key", [True, None, '"double_quoted"', "'single_quoted'", 1, 1.22], +) +@pytest.mark.parametrize( + "dict_value", [True, None, '"double_quoted"', "'single quoted'", 1, 1.22], +) +def test_json_encoder_encodes_dicts(dict_key, dict_value): + old_encode = json.JSONEncoder(separators=(",", ":")).encode + new_encode = json_encoder.encode + + assert new_encode({dict_key: dict_value}) == ( + old_encode({dict_key: dict_value}) + ) + + +def test_json_encoder_encodes_lists(): + old_encode = json.JSONEncoder(separators=(",", ":")).encode + new_encode = json_encoder.encode + + weird_list = [True, None, '"double_quoted"', "'single_quoted'", 1, 1.22] + assert new_encode(weird_list) == old_encode(weird_list)