Skip to content

Commit

Permalink
Added back time_model_load/aggregation
Browse files (browse the repository at this point in the history)
  • Loading branch information
Andreas Hellander committed Dec 16, 2024
1 parent e772691 commit 6087890
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 28 deletions.
30 changes: 22 additions & 8 deletions fedn/network/combiner/aggregators/fedavg.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import time
import traceback

from fedn.common.log_config import logger
Expand Down Expand Up @@ -41,37 +42,50 @@ def combine_models(self, helper=None, delete_models=True, parameters=None):
nr_aggregated_models = 0
total_examples = 0

logger.info("AGGREGATOR({}): Aggregating model updates... ".format(self.name))
logger.info(
"AGGREGATOR({}): Aggregating model updates... ".format(self.name))

while not self.update_handler.model_updates.empty():
try:
logger.info("AGGREGATOR({}): Getting next model update from queue.".format(self.name))
logger.info(
"AGGREGATOR({}): Getting next model update from queue.".format(self.name))
model_update = self.update_handler.next_model_update()

# Load model parameters and metadata
logger.info("AGGREGATOR({}): Loading model metadata {}.".format(self.name, model_update.model_update_id))
model_next, metadata = self.update_handler.load_model_update(model_update, helper)
logger.info("AGGREGATOR({}): Loading model metadata {}.".format(
self.name, model_update.model_update_id))

logger.info("AGGREGATOR({}): Processing model update {}, metadata: {} ".format(self.name, model_update.model_update_id, metadata))
tic = time.time()
model_next, metadata = self.update_handler.load_model_update(
model_update, helper)
data['time_model_load'] += time.time()-tic

logger.info("AGGREGATOR({}): Processing model update {}, metadata: {} ".format(
self.name, model_update.model_update_id, metadata))

# Increment total number of examples
total_examples += metadata["num_examples"]

tic = time.time()
if nr_aggregated_models == 0:
model = model_next
else:
model = helper.increment_average(model, model_next, metadata["num_examples"], total_examples)
model = helper.increment_average(
model, model_next, metadata["num_examples"], total_examples)
data['time_model_aggregration'] += time.time()-tic

nr_aggregated_models += 1
# Delete model from storage
if delete_models:
self.update_handler.delete_model(model_update)
except Exception as e:
tb = traceback.format_exc()
logger.error(f"AGGREGATOR({self.name}): Error encoutered while processing model update: {e}")
logger.error(
f"AGGREGATOR({self.name}): Error encoutered while processing model update: {e}")
logger.error(tb)

data["nr_aggregated_models"] = nr_aggregated_models

logger.info("AGGREGATOR({}): Aggregation completed, aggregated {} models.".format(self.name, nr_aggregated_models))
logger.info("AGGREGATOR({}): Aggregation completed, aggregated {} models.".format(
self.name, nr_aggregated_models))
return model, data
66 changes: 46 additions & 20 deletions fedn/network/combiner/aggregators/fedopt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
import time

from fedn.common.exceptions import InvalidParameterError
from fedn.common.log_config import logger
Expand Down Expand Up @@ -61,7 +62,8 @@ def combine_models(self, helper=None, delete_models=True, parameters=None):
try:
parameters.validate(parameter_schema)
except InvalidParameterError as e:
logger.error("Aggregator {} recieved invalid parameters. Reason {}".format(self.name, e))
logger.error(
"Aggregator {} recieved invalid parameters. Reason {}".format(self.name, e))
return None, data

# Default hyperparameters. Note that these may need fine tuning.
Expand All @@ -78,10 +80,12 @@ def combine_models(self, helper=None, delete_models=True, parameters=None):
try:
parameters.validate(parameter_schema)
except InvalidParameterError as e:
logger.error("Aggregator {} recieved invalid parameters. Reason {}".format(self.name, e))
logger.error(
"Aggregator {} recieved invalid parameters. Reason {}".format(self.name, e))
return None, data
else:
logger.info("Aggregator {} using default parameteres.", format(self.name))
logger.info("Aggregator {} using default parameteres.",
format(self.name))
parameters = self.default_parameters

# Override missing parameters with defaults
Expand All @@ -93,48 +97,67 @@ def combine_models(self, helper=None, delete_models=True, parameters=None):
nr_aggregated_models = 0
total_examples = 0

logger.info("AGGREGATOR({}): Aggregating model updates... ".format(self.name))
logger.info(
"AGGREGATOR({}): Aggregating model updates... ".format(self.name))

while not self.update_handler.model_updates.empty():
try:
logger.info("AGGREGATOR({}): Getting next model update from queue.".format(self.name))
logger.info(
"AGGREGATOR({}): Getting next model update from queue.".format(self.name))
model_update = self.update_handler.next_model_update()
# Load model parameters and metadata
model_next, metadata = self.update_handler.load_model_update(model_update, helper)

logger.info("AGGREGATOR({}): Processing model update {}".format(self.name, model_update.model_update_id))
tic = time.time()
model_next, metadata = self.update_handler.load_model_update(
model_update, helper)
data['time_model_load'] += time.time()-tic

logger.info("AGGREGATOR({}): Processing model update {}".format(
self.name, model_update.model_update_id))

# Increment total number of examples
total_examples += metadata["num_examples"]

tic = time.time()
if nr_aggregated_models == 0:
model_old = self.update_handler.load_model(helper, model_update.model_id)
model_old = self.update_handler.load_model(
helper, model_update.model_id)
pseudo_gradient = helper.subtract(model_next, model_old)
else:
pseudo_gradient_next = helper.subtract(model_next, model_old)
pseudo_gradient = helper.increment_average(pseudo_gradient, pseudo_gradient_next, metadata["num_examples"], total_examples)
pseudo_gradient_next = helper.subtract(
model_next, model_old)
pseudo_gradient = helper.increment_average(
pseudo_gradient, pseudo_gradient_next, metadata["num_examples"], total_examples)
data['time_model_aggregration'] += time.time()-tic

nr_aggregated_models += 1
# Delete model from storage
if delete_models:
self.update_handler.delete_model(model_update.model_update_id)
logger.info("AGGREGATOR({}): Deleted model update {} from storage.".format(self.name, model_update.model_update_id))
self.update_handler.delete_model(
model_update.model_update_id)
logger.info("AGGREGATOR({}): Deleted model update {} from storage.".format(
self.name, model_update.model_update_id))
except Exception as e:
logger.error("AGGREGATOR({}): Error encoutered while processing model update {}, skipping this update.".format(self.name, e))
logger.error(
"AGGREGATOR({}): Error encoutered while processing model update {}, skipping this update.".format(self.name, e))

if parameters["serveropt"] == "adam":
model = self.serveropt_adam(helper, pseudo_gradient, model_old, parameters)
model = self.serveropt_adam(
helper, pseudo_gradient, model_old, parameters)
elif parameters["serveropt"] == "yogi":
model = self.serveropt_yogi(helper, pseudo_gradient, model_old, parameters)
model = self.serveropt_yogi(
helper, pseudo_gradient, model_old, parameters)
elif parameters["serveropt"] == "adagrad":
model = self.serveropt_adagrad(helper, pseudo_gradient, model_old, parameters)
model = self.serveropt_adagrad(
helper, pseudo_gradient, model_old, parameters)
else:
logger.error("Unsupported server optimizer passed to FedOpt.")
return None, data

data["nr_aggregated_models"] = nr_aggregated_models

logger.info("AGGREGATOR({}): Aggregation completed, aggregated {} models.".format(self.name, nr_aggregated_models))
logger.info("AGGREGATOR({}): Aggregation completed, aggregated {} models.".format(
self.name, nr_aggregated_models))
return model, data

def serveropt_adam(self, helper, pseudo_gradient, model_old, parameters):
Expand All @@ -160,7 +183,8 @@ def serveropt_adam(self, helper, pseudo_gradient, model_old, parameters):
self.v = helper.ones(pseudo_gradient, math.pow(tau, 2))

if not self.m:
self.m = helper.multiply(pseudo_gradient, [(1.0 - beta1)] * len(pseudo_gradient))
self.m = helper.multiply(
pseudo_gradient, [(1.0 - beta1)] * len(pseudo_gradient))
else:
self.m = helper.add(self.m, pseudo_gradient, beta1, (1.0 - beta1))

Expand Down Expand Up @@ -196,7 +220,8 @@ def serveropt_yogi(self, helper, pseudo_gradient, model_old, parameters):
self.v = helper.ones(pseudo_gradient, math.pow(tau, 2))

if not self.m:
self.m = helper.multiply(pseudo_gradient, [(1.0 - beta1)] * len(pseudo_gradient))
self.m = helper.multiply(
pseudo_gradient, [(1.0 - beta1)] * len(pseudo_gradient))
else:
self.m = helper.add(self.m, pseudo_gradient, beta1, (1.0 - beta1))

Expand Down Expand Up @@ -233,7 +258,8 @@ def serveropt_adagrad(self, helper, pseudo_gradient, model_old, parameters):
self.v = helper.ones(pseudo_gradient, math.pow(tau, 2))

if not self.m:
self.m = helper.multiply(pseudo_gradient, [(1.0 - beta1)] * len(pseudo_gradient))
self.m = helper.multiply(
pseudo_gradient, [(1.0 - beta1)] * len(pseudo_gradient))
else:
self.m = helper.add(self.m, pseudo_gradient, beta1, (1.0 - beta1))

Expand Down

0 comments on commit 6087890

Please sign in to comment.