Skip to content

Commit

Permalink
Add integration tests with Redis (#11)
Browse files Browse the repository at this point in the history
* Setting up a Redis instance in the integration tests github workflow
* Making an integration test for end-to-end training
  • Loading branch information
lotif authored Mar 13, 2024
1 parent 253e3c0 commit d8f7055
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 27 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/docs_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
source $(poetry env info --path)/bin/activate
poetry install --with docs,test
cd docs && rm -rf source/reference/api/_autosummary && make html
cd .. && coverage run -m pytest -m "not integration_test" && coverage xml && coverage report -m
cd .. && coverage run -m pytest florist/tests/unit && coverage xml && coverage report -m
# - name: Upload coverage to Codecov
# uses: Wandalen/[email protected]
# with:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/docs_deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
source $(poetry env info --path)/bin/activate
poetry install --with docs,test
cd docs && rm -rf source/reference/api/_autosummary && make html
cd .. && coverage run -m pytest -m "not integration_test" && coverage xml && coverage report -m
cd .. && coverage run -m pytest florist/tests/unit && coverage xml && coverage report -m
# - name: Upload coverage to Codecov
# uses: Wandalen/[email protected]
# with:
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/integration_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ jobs:
- uses: actions/[email protected]
with:
python-version: '3.9'
- name: Setup redis
uses: supercharge/[email protected]
with:
redis-version: 7.2.4
- name: Install dependencies and check code
run: |
poetry env use '3.9'
Expand Down
13 changes: 13 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,19 @@ Then, run the server and client's Redis instance by following
[these instructions](README.md#start-servers-redis-instance) and
[these instructions](README.md#start-clients-redis-instance) respectively.

## Running the tests

To run the unit tests, simply execute:
```shell
pytest florist/tests/unit
```

To run the integration tests, first make sure you have a Redis server running on your
local machine on port 6379, then execute:
```shell
pytest florist/tests/integration
```

## Coding guidelines

For code style, we recommend the [PEP 8 style guide](https://peps.python.org/pep-0008/).
Expand Down
28 changes: 3 additions & 25 deletions florist/tests/integration/api/launchers/test_launch.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,20 @@
import os
import re
import tempfile
from functools import partial
from pathlib import Path
from typing import Callable, Dict

import torch
from fl4health.server.base_server import FlServer

from florist.api.launchers.local import launch
from florist.api.clients.mnist import MnistClient, MnistNet
from florist.tests.utils.api.fl4health_utils import get_server_fedavg


def fit_config(batch_size: int, local_epochs: int, current_server_round: int) -> Dict[str, int]:
return {
"batch_size": batch_size,
"current_server_round": current_server_round,
"local_epochs": local_epochs,
}


def get_server(
fit_config: Callable[..., Dict[str, int]] = fit_config,
n_clients: int = 2,
batch_size: int = 8,
local_epochs: int = 1,
) -> FlServer:
fit_config_fn = partial(fit_config, batch_size, local_epochs)
server = get_server_fedavg(model=MnistNet(), n_clients=n_clients, fit_config_fn=fit_config_fn)
return server
from florist.api.clients.mnist import MnistClient
from florist.tests.utils.api.launch_utils import get_server


def assert_string_in_file(file_path: str, search_string: str) -> bool:
with open(file_path, "r") as f:
file_contents = f.read()
match = re.search(search_string, file_contents)
return match is not None
assert match is not None


def test_launch() -> None:
Expand Down
32 changes: 32 additions & 0 deletions florist/tests/integration/api/test_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import json
import tempfile
from functools import partial
from unittest.mock import ANY

from florist.api import client
from florist.api.launchers.local import launch_server
from florist.tests.utils.api.launch_utils import get_server


def test_train():
test_server_address = "0.0.0.0:8080"

with tempfile.TemporaryDirectory() as temp_dir:
server_constructor = partial(get_server, n_clients=1)
server_log_file = f"{temp_dir}/server.out"
server_process = launch_server(server_constructor, test_server_address, 2, server_log_file)

test_client = "MNIST"
test_data_path = f"{temp_dir}/data"
test_redis_host = "localhost"
test_redis_port = "6379"

response = client.start(test_server_address, test_client, test_data_path, test_redis_host, test_redis_port)

assert json.loads(response.body.decode()) == {"uuid": ANY}

server_process.join()

with open(server_log_file, "r") as f:
file_contents = f.read()
assert "FL finished in" in file_contents
26 changes: 26 additions & 0 deletions florist/tests/utils/api/launch_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from functools import partial
from typing import Callable, Dict

from fl4health.server.base_server import FlServer

from florist.tests.utils.api.fl4health_utils import get_server_fedavg
from florist.api.clients.mnist import MnistNet


def fit_config(batch_size: int, local_epochs: int, current_server_round: int) -> Dict[str, int]:
return {
"batch_size": batch_size,
"current_server_round": current_server_round,
"local_epochs": local_epochs,
}


def get_server(
fit_config: Callable[..., Dict[str, int]] = fit_config,
n_clients: int = 2,
batch_size: int = 8,
local_epochs: int = 1,
) -> FlServer:
fit_config_fn = partial(fit_config, batch_size, local_epochs)
server = get_server_fedavg(model=MnistNet(), n_clients=n_clients, fit_config_fn=fit_config_fn)
return server

0 comments on commit d8f7055

Please sign in to comment.