Skip to content

Commit

Permalink
client_name arg renamed client_id
Browse files Browse the repository at this point in the history
  • Loading branch information
Wrede committed Nov 6, 2024
1 parent 86667f0 commit 9e1cc5c
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 35 deletions.
22 changes: 8 additions & 14 deletions fedn/network/clients/client_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,6 @@ def init_grpchandler(self, config: GrpcConnectionOptions, client_name: str, toke
logger.error("Error: Could not initialize GRPC connection")
return False


def send_heartbeats(self, client_name: str, client_id: str, update_frequency: float = 2.0):
self.grpc_handler.send_heartbeats(client_name=client_name, client_id=client_id, update_frequency=update_frequency)

Expand All @@ -220,24 +219,25 @@ def _task_stream_callback(self, request):
elif request.type == fedn.StatusType.MODEL_VALIDATION:
self.validate(request)

def get_model_from_combiner(self, id: str, client_name: str, timeout: int = 20) -> BytesIO:
return self.grpc_handler.get_model_from_combiner(id=id, client_name=client_name, timeout=timeout)
def get_model_from_combiner(self, id: str, client_id: str, timeout: int = 20) -> BytesIO:
return self.grpc_handler.get_model_from_combiner(id=id, client_name=client_id, timeout=timeout)

def send_model_to_combiner(self, model: BytesIO, id: str):
return self.grpc_handler.send_model_to_combiner(model, id)

def send_status(self, msg: str, log_level=fedn.Status.INFO, type=None, request=None, sesssion_id: str = None, sender_name: str = None):
return self.grpc_handler.send_status(msg, log_level, type, request, sesssion_id, sender_name)

def send_model_update(self,
def send_model_update(
self,
sender_name: str,
sender_role: fedn.Role,
client_id: str,
model_id: str,
model_update_id: str,
receiver_name: str,
receiver_role: fedn.Role,
meta: dict
meta: dict,
) -> bool:
return self.grpc_handler.send_model_update(
sender_name=sender_name,
Expand All @@ -247,17 +247,11 @@ def send_model_update(self,
model_update_id=model_update_id,
receiver_name=receiver_name,
receiver_role=receiver_role,
meta=meta
meta=meta,
)

def send_model_validation(self,
sender_name: str,
receiver_name: str,
receiver_role: fedn.Role,
model_id: str,
metrics: dict,
correlation_id: str,
session_id: str
def send_model_validation(
self, sender_name: str, receiver_name: str, receiver_role: fedn.Role, model_id: str, metrics: dict, correlation_id: str, session_id: str
) -> bool:
return self.grpc_handler.send_model_validation(sender_name, receiver_name, receiver_role, model_id, metrics, correlation_id, session_id)

Expand Down
36 changes: 17 additions & 19 deletions fedn/network/clients/client_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,17 @@ def to_json(self):


class Client:
def __init__(self,
api_url: str,
api_port: int,
client_obj: ClientOptions,
combiner_host: str = None,
combiner_port: int = None,
token: str = None,
package_checksum: str = None,
helper_type: str = None
):
def __init__(
self,
api_url: str,
api_port: int,
client_obj: ClientOptions,
combiner_host: str = None,
combiner_port: int = None,
token: str = None,
package_checksum: str = None,
helper_type: str = None,
):
self.api_url = api_url
self.api_port = api_port
self.combiner_host = combiner_host
Expand Down Expand Up @@ -149,7 +150,6 @@ def on_validation(self, request):
logger.info("Received validation request")
self._process_validation_request(request)


def _process_training_request(self, request) -> Tuple[str, dict]:
"""Process a training (model update) request.
Expand All @@ -164,16 +164,14 @@ def _process_training_request(self, request) -> Tuple[str, dict]:
session_id: str = request.session_id

self.client_api.send_status(
f"\t Starting processing of training request for model_id {model_id}",
sesssion_id=session_id,
sender_name=self.client_obj.name
f"\t Starting processing of training request for model_id {model_id}", sesssion_id=session_id, sender_name=self.client_obj.name
)

try:
meta = {}
tic = time.time()

model = self.client_api.get_model_from_combiner(id=str(model_id), client_name=self.client_obj.client_id)
model = self.client_api.get_model_from_combiner(id=str(model_id), client_id=self.client_obj.client_id)

if model is None:
logger.error("Could not retrieve model from combiner. Aborting training request.")
Expand Down Expand Up @@ -246,7 +244,7 @@ def _process_training_request(self, request) -> Tuple[str, dict]:
type=fedn.StatusType.MODEL_UPDATE,
request=request,
sesssion_id=session_id,
sender_name=self.client_obj.name
sender_name=self.client_obj.name,
)

def _process_validation_request(self, request):
Expand All @@ -266,7 +264,7 @@ def _process_validation_request(self, request):
self.client_api.send_status(f"Processing {cmd} request for model_id {model_id}", sesssion_id=session_id, sender_name=self.client_obj.name)

try:
model = self.client_api.get_model_from_combiner(id=str(model_id), client_name=self.client_obj.client_id)
model = self.client_api.get_model_from_combiner(id=str(model_id), client_id=self.client_obj.client_id)
if model is None:
logger.error("Could not retrieve model from combiner. Aborting validation request.")
return
Expand Down Expand Up @@ -318,13 +316,13 @@ def _process_validation_request(self, request):
type=fedn.StatusType.MODEL_VALIDATION,
request=validation,
sesssion_id=request.session_id,
sender_name=self.client_obj.name
sender_name=self.client_obj.name,
)
else:
self.client_api.send_status(
"Client {} failed to complete model validation.".format(self.name),
log_level=fedn.Status.WARNING,
request=request,
sesssion_id=request.session_id,
sender_name=self.client_obj.name
sender_name=self.client_obj.name,
)
4 changes: 2 additions & 2 deletions fedn/network/clients/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def send_status(self, msg: str, log_level=fedn.Status.INFO, type=None, request=N
logger.error(f"GRPC (SendStatus): An error occurred: {e}")
self._disconnect()

def get_model_from_combiner(self, id: str, client_name: str, timeout: int = 20) -> BytesIO:
def get_model_from_combiner(self, id: str, client_id: str, timeout: int = 20) -> BytesIO:
"""Fetch a model from the assigned combiner.
Downloads the model update object via a gRPC streaming channel.
Expand All @@ -218,7 +218,7 @@ def get_model_from_combiner(self, id: str, client_name: str, timeout: int = 20)
data = BytesIO()
time_start = time.time()
request = fedn.ModelRequest(id=id)
request.sender.client_id = client_name
request.sender.client_id = client_id
request.sender.role = fedn.WORKER

try:
Expand Down

0 comments on commit 9e1cc5c

Please sign in to comment.