Skip to content

Commit

Permalink
Refactor/SK-1225 | Use api/v1 in APIClient start_session (#758)
Browse files Browse the repository at this point in the history
* Use api/v1 in APIClient start_session

* Set helper based on file extension of seed model

* Use active model if model_id not in session_config

* Use autogenerated session id when starting session
  • Loading branch information
benjaminastrand authored Nov 26, 2024
1 parent aa0ce51 commit 9d3431d
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 17 deletions.
7 changes: 4 additions & 3 deletions examples/monai-2D-mednist/client/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def validate(in_model_path, out_json_path, data_path=None, client_settings_path=

image_list = clients["client " + str(split_index)]["validation"]

val_ds = MedNISTDataset(data_path=data_path+"/MedNIST/", transforms=val_transforms, image_files=image_list)
val_ds = MedNISTDataset(data_path=data_path + "/MedNIST/", transforms=val_transforms, image_files=image_list)

val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)

Expand Down Expand Up @@ -86,8 +86,9 @@ def validate(in_model_path, out_json_path, data_path=None, client_settings_path=

# JSON schema
report.update({"test_accuracy": accuracy_score(y_true, y_pred), "test_f1_score": f1_score(y_true, y_pred, average="macro")})
for r in report:
print(r, ": ", report[r])

for key, value in report.items():
print(f"{key}: {value}")

# Save JSON
save_metrics(report, out_json_path)
Expand Down
61 changes: 47 additions & 14 deletions fedn/network/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,18 @@ def set_active_model(self, path):
:return: A dict with success or failure message.
:rtype: dict
"""
if path.endswith(".npz"):
helper = "numpyhelper"
elif path.endswith(".bin"):
helper = "binaryhelper"

if helper:
response = requests.put(self._get_url_api_v1("helpers/active"), json={"helper": helper}, verify=self.verify, headers=self.headers)

with open(path, "rb") as file:
response = requests.post(self._get_url("set_initial_model"), files={"file": file}, verify=self.verify, headers=self.headers)
response = requests.post(
self._get_url("set_initial_model"), files={"file": file}, data={"helper": helper}, verify=self.verify, headers=self.headers
)
return response.json()

# --- Packages --- #
Expand Down Expand Up @@ -606,27 +616,50 @@ def start_session(
:return: A dict with success or failure message and session config.
:rtype: dict
"""
if model_id is None:
response = requests.get(self._get_url_api_v1("models/active"), verify=self.verify, headers=self.headers)
if response.status_code == 200:
model_id = response.json()
else:
return response.json()

response = requests.post(
self._get_url("start_session"),
self._get_url_api_v1("sessions"),
json={
"session_id": id,
"aggregator": aggregator,
"aggregator_kwargs": aggregator_kwargs,
"model_id": model_id,
"round_timeout": round_timeout,
"rounds": rounds,
"round_buffer_size": round_buffer_size,
"delete_models": delete_models,
"validate": validate,
"helper": helper,
"min_clients": min_clients,
"requested_clients": requested_clients,
"server_functions": None if server_functions is None else inspect.getsource(server_functions),
"session_config": {
"aggregator": aggregator,
"aggregator_kwargs": aggregator_kwargs,
"round_timeout": round_timeout,
"buffer_size": round_buffer_size,
"model_id": model_id,
"delete_models_storage": delete_models,
"clients_required": min_clients,
"requested_clients": requested_clients,
"validate": validate,
"helper_type": helper,
"server_functions": None if server_functions is None else inspect.getsource(server_functions),
},
},
verify=self.verify,
headers=self.headers,
)

if id is None:
id = response.json()["session_id"]

if response.status_code == 201:
response = requests.post(
self._get_url_api_v1("sessions/start"),
json={
"session_id": id,
"rounds": rounds,
"round_timeout": round_timeout,
},
verify=self.verify,
headers=self.headers,
)

_json = response.json()

return _json
Expand Down

0 comments on commit 9d3431d

Please sign in to comment.