Skip to content

Commit

Permalink
Add a bit more documentation, fix typing and pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
jewelltaylor committed Feb 22, 2024
1 parent 5400716 commit 771af14
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 13 deletions.
16 changes: 14 additions & 2 deletions florist/api/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,20 @@
from flwr.server import ServerConfig


def redirect_logging_from_console_to_file(log_file_name: str) -> None:
fh = logging.FileHandler(log_file_name)
def redirect_logging_from_console_to_file(log_file_path: str) -> None:
"""
Function that redirects loggers outputing to console to specified file.
Args:
log_file_name (str): The path to the file to log to.
"""

# Define file handler to log to and set format
fh = logging.FileHandler(log_file_path)
fh.setFormatter(DEFAULT_FORMATTER)

# Loop through existing loggers to check if they have one or more streamhandlers
# If they do, remove them (to prevent logging to the console) and add filehandler
for name in logging.root.manager.loggerDict:
logger = logging.getLogger(name)
if not all([isinstance(h, logging.StreamHandler) is False for h in logger.handlers]):
Expand All @@ -40,6 +50,7 @@ def start_server(
log_file_name = "server.out"
redirect_logging_from_console_to_file(log_file_name)
log_file = open(log_file_name, "a")
# Send remaining ouput (ie print) from stdout and stderr to file
sys.stdout = sys.stderr = log_file
server = server_constructor()
fl.server.start_server(
Expand All @@ -63,6 +74,7 @@ def start_client(client: BasicClient, server_address: str) -> None:
log_file_name = f"client_{str(os.getpid())}.out"
redirect_logging_from_console_to_file(log_file_name)
log_file = open(log_file_name, "a")
# Send remaining ouput (ie print) from stdout and stderr to file
sys.stdout = sys.stderr = log_file
fl.client.start_numpy_client(server_address=server_address, client=client)
client.shutdown()
Expand Down
19 changes: 8 additions & 11 deletions florist/tests/api/test_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from florist.tests.api.utils.models import MnistNet


def fit_config(batch_size: int, local_epochs: int, current_server_round: int) -> Dict:
def fit_config(batch_size: int, local_epochs: int, current_server_round: int) -> Dict[str, int]:
return {
"batch_size": batch_size,
"current_server_round": current_server_round,
Expand All @@ -21,7 +21,7 @@ def fit_config(batch_size: int, local_epochs: int, current_server_round: int) ->


def get_server(
fit_config: Callable = fit_config,
fit_config: Callable[..., Dict[str, int]] = fit_config,
n_clients: int = 2,
batch_size: int = 8,
local_epochs: int = 1,
Expand All @@ -42,12 +42,9 @@ def test_launch() -> None:
os.mkdir(client_data_path)
clients = [MnistClient(client_data_path, [], torch.device("cpu")) for client_data_path in client_data_paths]

try:
launch(
get_server,
server_address,
n_server_rounds,
clients,
)
finally:
pass
launch(
get_server,
server_address,
n_server_rounds,
clients,
)

0 comments on commit 771af14

Please sign in to comment.