Skip to content

Commit

Permalink
Adding batch_size and local_epochs to server params
Browse files Browse the repository at this point in the history
  • Loading branch information
lotif committed Mar 27, 2024
1 parent a6aa46c commit 349bfe9
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 1 deletion.
6 changes: 6 additions & 0 deletions florist/api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ def start_training(
model: Annotated[str, Form()],
server_address: Annotated[str, Form()],
n_server_rounds: Annotated[int, Form()],
batch_size: Annotated[int, Form()],
local_epochs: Annotated[int, Form()],
redis_host: Annotated[str, Form()],
redis_port: Annotated[str, Form()],
clients_info: Annotated[str, Form()],
Expand All @@ -37,6 +39,8 @@ def start_training(
:param server_address: (str) The address of the FL server to be started. It should be comprised of
the host name and port separated by colon (e.g. "localhost:8080")
:param n_server_rounds: (int) The number of rounds the FL server should run.
:param batch_size: (int) The size of the batch for training
:param local_epochs: (int) The number of epochs to run by the clients
:param redis_host: (str) The host name for the Redis instance for metrics reporting.
:param redis_port: (str) The port for the Redis instance for metrics reporting.
:param clients_info: (str) A JSON string containing the client information. It will be parsed by
Expand Down Expand Up @@ -74,6 +78,8 @@ def start_training(
n_clients=len(clients_info_list),
server_address=server_address,
n_server_rounds=n_server_rounds,
batch_size=batch_size,
local_epochs=local_epochs,
redis_host=redis_host,
redis_port=redis_port,
)
Expand Down
13 changes: 12 additions & 1 deletion florist/api/servers/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ def launch_local_server(
n_clients: int,
server_address: str,
n_server_rounds: int,
batch_size: int,
local_epochs: int,
redis_host: str,
redis_port: str,
) -> Tuple[str, Process]:
Expand All @@ -27,6 +29,8 @@ def launch_local_server(
:param n_clients: (int) The number of clients that will report to this server.
:param server_address: (str) The address the server should start at.
:param n_server_rounds: (int) The number of rounds the training should run for.
:param batch_size: (int) The size of the batch for training
:param local_epochs: (int) The number of epochs to run by the clients
:param redis_host: (str) the host name for the Redis instance for metrics reporting.
:param redis_port: (str) the port for the Redis instance for metrics reporting.
:return: (Tuple[str, multiprocessing.Process]) the UUID of the server, which can be used to pull
Expand All @@ -35,7 +39,14 @@ def launch_local_server(
server_uuid = str(uuid.uuid4())

metrics_reporter = RedisMetricsReporter(host=redis_host, port=redis_port, run_id=server_uuid)
server_constructor = partial(get_server, model=model, n_clients=n_clients, metrics_reporter=metrics_reporter)
server_constructor = partial(
get_server,
model=model,
n_clients=n_clients,
batch_size=batch_size,
local_epochs=local_epochs,
metrics_reporter=metrics_reporter,
)

log_file_name = str(get_server_log_file_path(server_uuid))
server_process = launch_server(
Expand Down
2 changes: 2 additions & 0 deletions florist/tests/integration/api/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def test_train():
"model": (None, "MNIST"),
"server_address": (None, "localhost:8080"),
"n_server_rounds": (None, test_n_server_rounds),
"batch_size": (None, 8),
"local_epochs": (None, 1),
"redis_host": (None, test_redis_host),
"redis_port": (None, test_redis_port),
"clients_info": (None, json.dumps(
Expand Down
34 changes: 34 additions & 0 deletions florist/tests/unit/api/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ def test_start_training_success(mock_requests: Mock, mock_redis: Mock, mock_laun
test_model = "MNIST"
test_server_address = "test-server-address"
test_n_server_rounds = 2
test_batch_size = 8
test_local_epochs = 1
test_redis_host = "test-redis-host"
test_redis_port = "test-redis-port"
test_clients_info = [
Expand Down Expand Up @@ -49,6 +51,8 @@ def test_start_training_success(mock_requests: Mock, mock_redis: Mock, mock_laun
test_model,
test_server_address,
test_n_server_rounds,
test_batch_size,
test_local_epochs,
test_redis_host,
test_redis_port,
json.dumps(test_clients_info),
Expand All @@ -65,6 +69,8 @@ def test_start_training_success(mock_requests: Mock, mock_redis: Mock, mock_laun
n_clients=len(test_clients_info),
server_address=test_server_address,
n_server_rounds=test_n_server_rounds,
batch_size=test_batch_size,
local_epochs=test_local_epochs,
redis_host=test_redis_host,
redis_port=test_redis_port,
)
Expand Down Expand Up @@ -97,6 +103,8 @@ def test_start_fail_unsupported_server_model() -> None:
test_model = "WRONG MODEL"
test_server_address = "test-server-address"
test_n_server_rounds = 2
test_batch_size = 8
test_local_epochs = 1
test_redis_host = "test-redis-host"
test_redis_port = "test-redis-port"
test_clients_info = [
Expand All @@ -120,6 +128,8 @@ def test_start_fail_unsupported_server_model() -> None:
test_model,
test_server_address,
test_n_server_rounds,
test_batch_size,
test_local_epochs,
test_redis_host,
test_redis_port,
json.dumps(test_clients_info),
Expand All @@ -137,6 +147,8 @@ def test_start_fail_unsupported_client() -> None:
test_model = "MNIST"
test_server_address = "test-server-address"
test_n_server_rounds = 2
test_batch_size = 8
test_local_epochs = 1
test_redis_host = "test-redis-host"
test_redis_port = "test-redis-port"
test_clients_info = [
Expand All @@ -160,6 +172,8 @@ def test_start_fail_unsupported_client() -> None:
test_model,
test_server_address,
test_n_server_rounds,
test_batch_size,
test_local_epochs,
test_redis_host,
test_redis_port,
json.dumps(test_clients_info),
Expand All @@ -178,6 +192,8 @@ def test_start_training_launch_server_exception(mock_launch_local_server: Mock)
test_model = "MNIST"
test_server_address = "test-server-address"
test_n_server_rounds = 2
test_batch_size = 8
test_local_epochs = 1
test_redis_host = "test-redis-host"
test_redis_port = "test-redis-port"
test_clients_info = [
Expand All @@ -203,6 +219,8 @@ def test_start_training_launch_server_exception(mock_launch_local_server: Mock)
test_model,
test_server_address,
test_n_server_rounds,
test_batch_size,
test_local_epochs,
test_redis_host,
test_redis_port,
json.dumps(test_clients_info),
Expand All @@ -221,6 +239,8 @@ def test_start_wait_for_metric_exception(mock_redis: Mock, mock_launch_local_ser
test_model = "MNIST"
test_server_address = "test-server-address"
test_n_server_rounds = 2
test_batch_size = 8
test_local_epochs = 1
test_redis_host = "test-redis-host"
test_redis_port = "test-redis-port"
test_clients_info = [
Expand Down Expand Up @@ -249,6 +269,8 @@ def test_start_wait_for_metric_exception(mock_redis: Mock, mock_launch_local_ser
test_model,
test_server_address,
test_n_server_rounds,
test_batch_size,
test_local_epochs,
test_redis_host,
test_redis_port,
json.dumps(test_clients_info),
Expand All @@ -268,6 +290,8 @@ def test_start_wait_for_metric_timeout(_: Mock, mock_redis: Mock, mock_launch_lo
test_model = "MNIST"
test_server_address = "test-server-address"
test_n_server_rounds = 2
test_batch_size = 8
test_local_epochs = 1
test_redis_host = "test-redis-host"
test_redis_port = "test-redis-port"
test_clients_info = [
Expand Down Expand Up @@ -297,6 +321,8 @@ def test_start_wait_for_metric_timeout(_: Mock, mock_redis: Mock, mock_launch_lo
test_model,
test_server_address,
test_n_server_rounds,
test_batch_size,
test_local_epochs,
test_redis_host,
test_redis_port,
json.dumps(test_clients_info),
Expand All @@ -316,6 +342,8 @@ def test_start_training_fail_response(mock_requests: Mock, mock_redis: Mock, moc
test_model = "MNIST"
test_server_address = "test-server-address"
test_n_server_rounds = 2
test_batch_size = 8
test_local_epochs = 1
test_redis_host = "test-redis-host"
test_redis_port = "test-redis-port"
test_clients_info = [
Expand Down Expand Up @@ -350,6 +378,8 @@ def test_start_training_fail_response(mock_requests: Mock, mock_redis: Mock, moc
test_model,
test_server_address,
test_n_server_rounds,
test_batch_size,
test_local_epochs,
test_redis_host,
test_redis_port,
json.dumps(test_clients_info),
Expand All @@ -369,6 +399,8 @@ def test_start_training_no_uuid_in_response(mock_requests: Mock, mock_redis: Moc
test_model = "MNIST"
test_server_address = "test-server-address"
test_n_server_rounds = 2
test_batch_size = 8
test_local_epochs = 1
test_redis_host = "test-redis-host"
test_redis_port = "test-redis-port"
test_clients_info = [
Expand Down Expand Up @@ -403,6 +435,8 @@ def test_start_training_no_uuid_in_response(mock_requests: Mock, mock_redis: Moc
test_model,
test_server_address,
test_n_server_rounds,
test_batch_size,
test_local_epochs,
test_redis_host,
test_redis_port,
json.dumps(test_clients_info),
Expand Down

0 comments on commit 349bfe9

Please sign in to comment.