Use FednClient API instead of compute package
benjaminastrand committed Nov 29, 2024
1 parent 4a7466b commit b50c394
Showing 2 changed files with 111 additions and 45 deletions.
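In short, the example clients move from the compute-package Client to the callback-based FednClient API. A condensed sketch of that pattern is shown below; the call names and settings keys are taken from the diff itself, while the URL, the None token, and the stubbed callback bodies are placeholder assumptions and not part of this commit:

    import uuid
    from io import BytesIO

    from fedn import FednClient


    def on_train(in_model: BytesIO, client_settings: dict):
        # Placeholder: update the received parameters and return (out_model, metadata).
        metadata = {"training_metadata": {"num_examples": 1}}
        return in_model, metadata


    def on_validate(in_model: BytesIO):
        # Placeholder: compute metrics for the received model.
        return {"validation_accuracy": 0.0}


    client = FednClient(train_callback=on_train, validate_callback=on_validate)
    client.set_name("example-client")
    client.set_client_id(str(uuid.uuid4()))

    controller_config = {"name": client.name, "client_id": client.client_id, "package": "local", "preferred_combiner": ""}
    result, combiner_config = client.connect_to_api("http://127.0.0.1:8092/", None, controller_config)
    client.init_grpchandler(config=combiner_config, client_name=client.client_id, token=None)
    client.run()  # run_clients.py below runs this on a daemon thread
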
15 changes: 10 additions & 5 deletions examples/async-clients/init_fedn.py
@@ -1,8 +1,13 @@
+from config import settings
 from fedn import APIClient

-DISCOVER_HOST = "127.0.0.1"
-DISCOVER_PORT = 8092
+client = APIClient(
+    host=settings["DISCOVER_HOST"],
+    port=settings["DISCOVER_PORT"],
+    secure=settings["SECURE"],
+    verify=settings["VERIFY"],
+    token=settings["ADMIN_TOKEN"],
+)

-client = APIClient(DISCOVER_HOST, DISCOVER_PORT)
 client.set_active_package("package.tgz", "numpyhelper")
-client.set_active_model("seed.npz")
+result = client.set_active_model("seed.npz")
+print(result["message"])
141 changes: 101 additions & 40 deletions examples/async-clients/run_clients.py
@@ -16,67 +16,128 @@
 (this requires root privileges)
 """

-import copy
+import threading
 import time
+import uuid
+from io import BytesIO
 from multiprocessing import Process

 import numpy as np
+from init_seed import compile_model, make_data
+from sklearn.metrics import accuracy_score

-from fedn.network.clients.client import Client

-# Use with a local deployment
-settings = {
-    "DISCOVER_HOST": "127.0.0.1",
-    "DISCOVER_PORT": 8092,
-    "TOKEN": None,
-    "N_CLIENTS": 10,
-    "N_CYCLES": 100,
-    "CLIENTS_MAX_DELAY": 10,
-    "CLIENTS_ONLINE_FOR_SECONDS": 120,
-}

-client_config = {
-    "discover_host": settings["DISCOVER_HOST"],
-    "discover_port": settings["DISCOVER_PORT"],
-    "token": settings["TOKEN"],
-    "name": "testclient",
-    "client_id": 1,
-    "remote_compute_context": True,
-    "force_ssl": False,
-    "dry_run": False,
-    "secure": False,
-    "preshared_cert": False,
-    "verify": False,
-    "validator": True,
-    "trainer": True,
-    "init": None,
-    "logfile": "test.log",
-    "heartbeat_interval": 2,
-    "reconnect_after_missed_heartbeat": 30,
-}
+from config import settings
+from fedn import FednClient

+HELPER_MODULE = "numpyhelper"


+def get_api_url(host: str, port: int = None, secure: bool = False):
+    if secure:
+        url = f"https://{host}:{port}" if port else f"https://{host}"
+    else:
+        url = f"http://{host}:{port}" if port else f"http://{host}"
+    if not url.endswith("/"):
+        url += "/"
+    return url


+def load_parameters(model_bytes_io: BytesIO):
+    """Load model parameters from a BytesIO object."""
+    model_bytes_io.seek(0)  # Ensure we're at the start of the BytesIO object
+    a = np.load(model_bytes_io)
+    weights = [a[str(i)] for i in range(len(a.files))]
+    return weights


+def load_model(model_bytes_io: BytesIO):
+    parameters = load_parameters(model_bytes_io)

+    model = compile_model()
+    n = len(parameters) // 2
+    model.coefs_ = parameters[:n]
+    model.intercepts_ = parameters[n:]

+    return model


+def on_train(in_model, client_settings):
+    print("Running training callback...")
+    model = load_model(in_model)

+    X_train, y_train, _, _ = make_data()
+    epochs = settings["N_EPOCHS"]
+    for i in range(epochs):
+        model.partial_fit(X_train, y_train)

+    # Prepare updated model parameters
+    updated_parameters = model.coefs_ + model.intercepts_
+    out_model = BytesIO()
+    np.savez_compressed(out_model, **{str(i): w for i, w in enumerate(updated_parameters)})
+    out_model.seek(0)

+    # Metadata needed for aggregation server side
+    training_metadata = {
+        "num_examples": len(X_train),
+        "training_metadata": {
+            "epochs": epochs,
+            "batch_size": len(X_train),
+            "learning_rate": model.learning_rate_init,
+        },
+    }

+    metadata = {"training_metadata": training_metadata}

+    return out_model, metadata


+def on_validate(in_model):
+    model = load_model(in_model)

+    X_train, y_train, X_test, y_test = make_data()

+    # JSON schema
+    metrics = {"validation_accuracy": accuracy_score(y_test, model.predict(X_test)), "training_accuracy": accuracy_score(y_train, model.predict(X_train))}

+    return metrics


 def run_client(online_for=120, name="client"):
     """Simulates a client that starts and stops
     at random intervals.
-    The client will start after a radom time 'mean_delay',
+    The client will start after a random time 'mean_delay',
     stay online for 'online_for' seconds (deterministic),
     then disconnect.
     This is repeated for N_CYCLES.
     """
-    conf = copy.deepcopy(client_config)
-    conf["name"] = name

     for i in range(settings["N_CYCLES"]):
         # Sample a delay until the client starts
         t_start = np.random.randint(0, settings["CLIENTS_MAX_DELAY"])
         time.sleep(t_start)
-        fl_client = Client(conf)

+        fl_client = FednClient(train_callback=on_train, validate_callback=on_validate)
+        fl_client.set_name(name)
+        fl_client.set_client_id(str(uuid.uuid4()))

+        controller_config = {
+            "name": fl_client.name,
+            "client_id": fl_client.client_id,
+            "package": "local",
+            "preferred_combiner": "",
+        }

+        url = get_api_url(host=settings["DISCOVER_HOST"], port=settings["DISCOVER_PORT"], secure=settings["SECURE"])

+        result, combiner_config = fl_client.connect_to_api(url, settings["CLIENT_TOKEN"], controller_config)

+        fl_client.init_grpchandler(config=combiner_config, client_name=fl_client.client_id, token=settings["CLIENT_TOKEN"])

+        threading.Thread(target=fl_client.run, daemon=True).start()
         time.sleep(online_for)
-        fl_client.disconnect()
+        fl_client.grpc_handler._disconnect()


if __name__ == "__main__":
@@ -87,7 +148,7 @@ def run_client(online_for=120, name="client"):
             target=run_client,
             args=(
                 settings["CLIENTS_ONLINE_FOR_SECONDS"],
-                "client{}".format(i),
+                "client{}".format(i + 1),
             ),
         )
         processes.append(p)
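Both scripts now read their deployment parameters from a local config module, which is not part of this commit. A minimal sketch of such a module follows; the key names come from the diff above, while every value is a placeholder assumption to adjust for the actual deployment:

    # config.py -- illustrative sketch only; values are assumptions, not from the commit
    settings = {
        "DISCOVER_HOST": "127.0.0.1",        # FEDn controller/API host
        "DISCOVER_PORT": 8092,               # controller REST port
        "SECURE": False,                     # use https when True
        "VERIFY": False,                     # verify TLS certificates
        "ADMIN_TOKEN": None,                 # admin token used by init_fedn.py
        "CLIENT_TOKEN": None,                # client token used by run_clients.py
        "N_CLIENTS": 10,                     # number of simulated client processes
        "N_CYCLES": 100,                     # connect/disconnect cycles per client
        "N_EPOCHS": 1,                       # local epochs per training callback
        "CLIENTS_MAX_DELAY": 10,             # max random start delay in seconds
        "CLIENTS_ONLINE_FOR_SECONDS": 120,   # time each client stays connected
    }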
