Skip to content

Commit

Permalink
fix request instancd same between clients
Browse files Browse the repository at this point in the history
  • Loading branch information
Wrede committed Jun 13, 2024
1 parent 5e58d4d commit 5d4212f
Showing 1 changed file with 28 additions and 31 deletions.
59 changes: 28 additions & 31 deletions fedn/network/combiner/combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,12 @@ def request_model_update(self, session_id, model_id, config, clients=[]):
:type clients: list
"""
request, clients = self._send_request_type(fedn.StatusType.MODEL_UPDATE, session_id, model_id, config, clients)
clients = self._send_request_type(fedn.StatusType.MODEL_UPDATE, session_id, model_id, config, clients)

if len(clients) < 20:
logger.info("Sent model update request for model {} to clients {}".format(request.model_id, clients))
logger.info("Sent model update request for model {} to clients {}".format(model_id, clients))
else:
logger.info("Sent model update request for model {} to {} clients".format(request.model_id, len(clients)))
logger.info("Sent model update request for model {} to {} clients".format(model_id, len(clients)))

def request_model_validation(self, session_id, model_id, clients=[]):
"""Ask clients to validate the current global model.
Expand All @@ -187,12 +187,12 @@ def request_model_validation(self, session_id, model_id, clients=[]):
:type clients: list
"""
request, clients = self._send_request_type(fedn.StatusType.MODEL_VALIDATION, session_id, model_id, clients)
clients = self._send_request_type(fedn.StatusType.MODEL_VALIDATION, session_id, model_id, clients)

if len(clients) < 20:
logger.info("Sent model validation request for model {} to clients {}".format(request.model_id, clients))
logger.info("Sent model validation request for model {} to clients {}".format(model_id, clients))
else:
logger.info("Sent model validation request for model {} to {} clients".format(request.model_id, len(clients)))
logger.info("Sent model validation request for model {} to {} clients".format(model_id, len(clients)))

def request_model_inference(self, session_id: str, model_id: str, clients: list = []) -> None:
"""Ask clients to perform inference on the model.
Expand All @@ -205,12 +205,12 @@ def request_model_inference(self, session_id: str, model_id: str, clients: list
:type clients: list
"""
request, clients = self._send_request_type(fedn.StatusType.INFERENCE, session_id, model_id, clients)
clients = self._send_request_type(fedn.StatusType.INFERENCE, session_id, model_id, clients)

if len(clients) < 20:
logger.info("Sent model inference request for model {} to clients {}".format(request.model_id, clients))
logger.info("Sent model inference request for model {} to clients {}".format(model_id, clients))
else:
logger.info("Sent model inference request for model {} to {} clients".format(request.model_id, len(clients)))
logger.info("Sent model inference request for model {} to {} clients".format(model_id, len(clients)))

def _send_request_type(self, request_type, session_id, model_id, config=None, clients=[]):
"""Send a request of a specific type to clients.
Expand All @@ -223,41 +223,38 @@ def _send_request_type(self, request_type, session_id, model_id, config=None, cl
:type config: dict
:param clients: the clients to send the request to
:type clients: list
:return: the request and the clients
:rtype: tuple
:return: the clients
:rtype: list
"""
request = fedn.TaskRequest()
request.model_id = model_id
request.correlation_id = str(uuid.uuid4())
request.timestamp = str(datetime.now())
request.type = request_type
request.session_id = session_id

request.sender.name = self.id
request.sender.role = fedn.COMBINER

if request_type == fedn.StatusType.MODEL_UPDATE:
request.data = json.dumps(config)
if len(clients) == 0:
if len(clients) == 0:
if request_type == fedn.StatusType.MODEL_UPDATE:
clients = self.get_active_trainers()
elif request_type == fedn.StatusType.MODEL_VALIDATION:
if len(clients) == 0:
elif request_type == fedn.StatusType.MODEL_VALIDATION:
clients = self.get_active_validators()
elif request_type == fedn.StatusType.INFERENCE:
if len(clients) == 0:
elif request_type == fedn.StatusType.INFERENCE:
# TODO: add inference clients type
clients = self.get_active_validators()

for client in clients:
request = fedn.TaskRequest()
request.model_id = model_id
request.correlation_id = str(uuid.uuid4())
request.timestamp = str(datetime.now())
request.type = request_type
request.session_id = session_id

request.sender.name = self.id
request.sender.role = fedn.COMBINER
request.receiver.name = client
request.receiver.role = fedn.WORKER
# Set the request data, not used in validation
if request_type == fedn.StatusType.INFERENCE:
presigned_url = self.repository.presigned_put_url(self.repository.inference_bucket, f"{client}/{session_id}")
# TODO: in inference, request.data should also contain user-defined data/parameters
request.data = json.dumps({"presigned_url": presigned_url})
elif request_type == fedn.StatusType.MODEL_UPDATE:
request.data = json.dumps(config)
self._put_request_to_client_queue(request, fedn.Queue.TASK_QUEUE)

return request, clients
return clients

def get_active_trainers(self):
"""Get a list of active trainers.
Expand Down

0 comments on commit 5d4212f

Please sign in to comment.