Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
Wrede committed Jun 10, 2024
1 parent f267971 commit b3310ac
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 28 deletions.
8 changes: 3 additions & 5 deletions fedn/network/combiner/combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@

import fedn.network.grpc.fedn_pb2 as fedn
import fedn.network.grpc.fedn_pb2_grpc as rpc
from fedn.common.log_config import (logger, set_log_level_from_string,
set_log_stream)
from fedn.common.log_config import logger, set_log_level_from_string, set_log_stream
from fedn.network.combiner.connect import ConnectorCombiner, Status
from fedn.network.combiner.modelservice import ModelService
from fedn.network.combiner.roundhandler import RoundConfig, RoundHandler
Expand Down Expand Up @@ -66,7 +65,6 @@ def __init__(self, config):
# Client queues
self.clients = {}


# Validate combiner name
match = re.search(VALID_NAME_REGEX, config["name"])
if not match:
Expand Down Expand Up @@ -196,7 +194,7 @@ def request_model_validation(self, session_id, model_id, clients=[]):
else:
logger.info("Sent model validation request for model {} to {} clients".format(request.model_id, len(clients)))

def request_model_inference(self, session_id: str, model_id: str, clients: list=[]) -> None:
def request_model_inference(self, session_id: str, model_id: str, clients: list = []) -> None:
"""Ask clients to perform inference on the model.
:param model_id: the model id to perform inference on
Expand Down Expand Up @@ -250,7 +248,7 @@ def _send_request_type(self, request_type, session_id, model_id, config=None, cl
if len(clients) == 0:
# TODO: add inference clients type
clients = self.get_active_validators()

# TODO: if inference, request.data should be user-defined data/parameters

for client in clients:
Expand Down
14 changes: 8 additions & 6 deletions fedn/network/combiner/roundhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@

from fedn.common.log_config import logger
from fedn.network.combiner.aggregators.aggregatorbase import get_aggregator
from fedn.network.combiner.modelservice import (load_model_from_BytesIO,
serialize_model_to_BytesIO)
from fedn.network.combiner.modelservice import load_model_from_BytesIO, serialize_model_to_BytesIO
from fedn.utils.helpers.helpers import get_helper
from fedn.utils.parameters import Parameters


class RoundConfig(TypedDict):
"""Round configuration.
:param _job_id: A universally unique identifier for the round. Set by Combiner.
:type _job_id: str
:param committed_at: The time the round was committed. Set by Controller.
Expand Down Expand Up @@ -47,6 +46,7 @@ class RoundConfig(TypedDict):
:param aggregator: The aggregator type.
:type aggregator: str
"""

_job_id: str
committed_at: str
task: str
Expand All @@ -62,6 +62,8 @@ class RoundConfig(TypedDict):
session_id: str
helper_type: str
aggregator: str


class ModelUpdateError(Exception):
pass

Expand Down Expand Up @@ -246,7 +248,7 @@ def _validation_round(self, session_id, model_id, clients):
:type model_id: str
"""
self.server.request_model_validation(session_id, model_id, clients=clients)

def _inference_round(self, session_id: str, model_id: str, clients: list):
"""Send model inference requests to clients.
Expand Down Expand Up @@ -346,7 +348,7 @@ def execute_validation_round(self, session_id, model_id):
self.stage_model(model_id)
validators = self._assign_round_clients(self.server.max_clients, type="validators")
self._validation_round(session_id, model_id, validators)

def execute_inference_round(self, session_id: str, model_id: str) -> None:
"""Coordinate inference rounds as specified in config.
Expand Down Expand Up @@ -423,7 +425,7 @@ def run(self, polling_interval=1.0):
self.server.statestore.set_round_combiner_data(round_meta)
elif round_config["task"] == "validation":
self.execute_validation_round(session_id, model_id)
elif round_config["task"] == "inference":
elif round_config["task"] == "inference":
self.execute_inference_round(session_id, model_id)
else:
logger.warning("config contains unkown task type.")
Expand Down
17 changes: 7 additions & 10 deletions fedn/network/controller/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@
import datetime
import time
import uuid
from typing import TypedDict

from tenacity import (retry, retry_if_exception_type, stop_after_delay,
wait_random)
from tenacity import retry, retry_if_exception_type, stop_after_delay, wait_random

from fedn.common.log_config import logger
from fedn.network.combiner.interfaces import CombinerUnavailableError
Expand Down Expand Up @@ -185,26 +183,25 @@ def session(self, config: RoundConfig) -> None:
# TODO: Report completion of session
self.set_session_status(config["session_id"], "Finished")
self._state = ReducerState.idle

def inference_session(self, config: RoundConfig) -> None:
"""Execute a new inference session.
:param config: The round config.
:type config: InferenceConfig
:return: None
"""

if self._state == ReducerState.instructing:
logger.info("Controller already in INSTRUCTING state. A session is in progress.")
return

if len(self.network.get_combiners()) < 1:
logger.warning("Inference round cannot start, no combiners connected!")
return
if not "model_id" in config.keys():
config["model_id"]= self.statestore.get_latest_model()

if "model_id" not in config.keys():
config["model_id"] = self.statestore.get_latest_model()

config["committed_at"] = datetime.datetime.now()
config["task"] = "inference"
config["rounds"] = str(1)
Expand All @@ -216,7 +213,7 @@ def inference_session(self, config: RoundConfig) -> None:
round_start = self.evaluate_round_start_policy(participating_combiners)

if round_start:
logger.info("Inference round start policy met, {} participating combiners.".format(len(participating_combiners)))
logger.info("Inference round start policy met, {} participating combiners.".format(len(participating_combiners)))
for combiner, _ in participating_combiners:
combiner.submit(config)
logger.info("Inference round submitted to combiner {}".format(combiner))
Expand Down
11 changes: 4 additions & 7 deletions fedn/network/controller/controlbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,12 @@ def idle(self):
return False

def get_model_info(self):
""":return:
"""
""":return:"""
return self.statestore.get_model_trail()

# TODO: remove use statestore.get_events() instead
def get_events(self):
""":return:
"""
""":return:"""
return self.statestore.get_events()

def get_latest_round_id(self):
Expand All @@ -136,8 +134,7 @@ def get_latest_round(self):
return round

def get_compute_package_name(self):
""":return:
"""
""":return:"""
definition = self.statestore.get_compute_package()
if definition:
try:
Expand All @@ -164,7 +161,7 @@ def get_compute_package(self, compute_package=""):
else:
return None

def create_session(self, config: RoundConfig, status: str="Initialized") -> None:
def create_session(self, config: RoundConfig, status: str = "Initialized") -> None:
"""Initialize a new session in backend db."""
if "session_id" not in config.keys():
session_id = uuid.uuid4()
Expand Down

0 comments on commit b3310ac

Please sign in to comment.