diff --git a/florist/tests/unit/api/servers/test_launch.py b/florist/tests/unit/api/servers/test_launch.py index 73fed65..17b5a3f 100644 --- a/florist/tests/unit/api/servers/test_launch.py +++ b/florist/tests/unit/api/servers/test_launch.py @@ -13,6 +13,8 @@ def test_launch_local_server(mock_launch_server: Mock) -> None: test_n_clients = 2 test_server_address = "test-server-address" test_n_server_rounds = 5 + test_batch_size = 8 + test_local_epochs = 1 test_redis_host = "test-redis-host" test_redis_port = "test-redis-port" test_server_process = "test-server-process" @@ -23,6 +25,8 @@ def test_launch_local_server(mock_launch_server: Mock) -> None: test_n_clients, test_server_address, test_n_server_rounds, + test_batch_size, + test_local_epochs, test_redis_host, test_redis_port, ) @@ -41,7 +45,13 @@ def test_launch_local_server(mock_launch_server: Mock) -> None: ) assert call_kwargs == {"seconds_to_sleep": 0} assert call_args[0].func == get_server - assert call_args[0].keywords == {"model": test_model, "n_clients": test_n_clients, "metrics_reporter": ANY} + assert call_args[0].keywords == { + "model": test_model, + "n_clients": test_n_clients, + "batch_size": test_batch_size, + "local_epochs": test_local_epochs, + "metrics_reporter": ANY, + } metrics_reporter = call_args[0].keywords["metrics_reporter"] assert isinstance(metrics_reporter, RedisMetricsReporter)