Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix/SK-1365 | Fix so server functions code resets by sessions #806

Merged
merged 8 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions examples/server-functions/server_functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from fedn.common.log_config import logger
from fedn.network.combiner.hooks.allowed_import import Dict, List, ServerFunctionsBase, Tuple, np, random

# See allowed_imports for what packages you can use in this class.
Expand Down Expand Up @@ -39,6 +38,6 @@ def aggregate(self, previous_global: List[np.ndarray], client_updates: Dict[str,
for i in range(len(weighted_sum)):
weighted_sum[i] += client_parameters[i] * num_examples

logger.info("Models aggregated")
print("Models aggregated")
averaged_updates = [weighted / total_weight for weighted in weighted_sum]
return averaged_updates
2 changes: 2 additions & 0 deletions fedn/network/combiner/hooks/allowed_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@

from fedn.common.log_config import logger # noqa: F401
from fedn.network.combiner.hooks.serverfunctionsbase import ServerFunctionsBase # noqa: F401

print = logger.info
18 changes: 13 additions & 5 deletions fedn/network/combiner/hooks/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import fedn.network.grpc.fedn_pb2 as fedn
import fedn.network.grpc.fedn_pb2_grpc as rpc
from fedn.common.log_config import logger
from fedn.network.combiner.hooks.allowed_import import * # noqa: F403
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

går det inte att bara köra "import fedn.....allowed_imports" alltså skippa from så du inte behöver *?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Borde funka, testar


# imports for user code
from fedn.network.combiner.hooks.allowed_import import Dict, List, ServerFunctionsBase, Tuple, np, random # noqa: F401
from fedn.network.combiner.hooks.allowed_import import ServerFunctionsBase
from fedn.network.combiner.modelservice import bytesIO_request_generator, model_as_bytesIO, unpack_model
from fedn.utils.helpers.plugins.numpyhelper import Helper

Expand All @@ -30,7 +31,7 @@ def __init__(self) -> None:
self.server_functions: ServerFunctionsBase = None
self.server_functions_code: str = None
self.client_updates = {}
self.implemented_functions = None
self.implemented_functions = {}

def HandleClientConfig(self, request_iterator: fedn.ClientConfigRequest, context):
"""Distribute client configs to clients from user defined code.
Expand Down Expand Up @@ -122,13 +123,20 @@ def HandleProvidedFunctions(self, request: fedn.ProvidedFunctionsResponse, conte
:rtype: :class:`fedn.network.grpc.fedn_pb2.ProvidedFunctionsResponse`
"""
logger.info("Receieved provided functions request.")
if self.implemented_functions is not None:
return fedn.ProvidedFunctionsResponse(available_functions=self.implemented_functions)
server_functions_code = request.function_code
# if no new code return previous
if server_functions_code == self.server_functions_code:
logger.info("No new server function code provided.")
logger.info(f"Provided function: {self.implemented_functions}")
return fedn.ProvidedFunctionsResponse(available_functions=self.implemented_functions)

self.server_functions_code = server_functions_code
self.implemented_functions = {}
self._instansiate_server_functions_code()
# if crashed or not returning None we assume function is implemented
# We are not sending dummy values here since the implementation might depend on model shape / implementations

# Implemented=False if return is None (indicating base implementation)

# check if aggregation is available
try:
ret = self.server_functions.aggregate(0, 0)
Expand Down
1 change: 1 addition & 0 deletions fedn/network/combiner/roundhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ def execute_training_round(self, config):
# Download model to update and set in temp storage.
self.stage_model(config["model_id"])

# dictionary to which functions are provided
provided_functions = self.hook_interface.provided_functions(self.server_functions)

if provided_functions.get("client_selection", False):
Expand Down
Loading