From b5199b1f1c16718afbc30ced7b6343fb560c5118 Mon Sep 17 00:00:00 2001 From: John Jewell Date: Wed, 28 Feb 2024 09:24:45 -0500 Subject: [PATCH] Fix formatting --- florist/tests/integration/api/test_launch.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/florist/tests/integration/api/test_launch.py b/florist/tests/integration/api/test_launch.py index f29543b1..f5965e7c 100644 --- a/florist/tests/integration/api/test_launch.py +++ b/florist/tests/integration/api/test_launch.py @@ -13,9 +13,7 @@ from florist.tests.utils.api.models import MnistNet -def fit_config( - batch_size: int, local_epochs: int, current_server_round: int -) -> Dict[str, int]: +def fit_config(batch_size: int, local_epochs: int, current_server_round: int) -> Dict[str, int]: return { "batch_size": batch_size, "current_server_round": current_server_round, @@ -30,9 +28,7 @@ def get_server( local_epochs: int = 1, ) -> FlServer: fit_config_fn = partial(fit_config, batch_size, local_epochs) - server = get_server_fedavg( - model=MnistNet(), n_clients=n_clients, fit_config_fn=fit_config_fn - ) + server = get_server_fedavg(model=MnistNet(), n_clients=n_clients, fit_config_fn=fit_config_fn) return server @@ -52,10 +48,7 @@ def test_launch() -> None: client_data_paths = [Path(f"{temp_dir}/{i}") for i in range(n_clients)] for client_data_path in client_data_paths: os.mkdir(client_data_path) - clients = [ - MnistClient(client_data_path, [], torch.device("cpu")) - for client_data_path in client_data_paths - ] + clients = [MnistClient(client_data_path, [], torch.device("cpu")) for client_data_path in client_data_paths] server_path = os.path.join(temp_dir, "server.out") client_base_path = f"{temp_dir}/client"