-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
105 changed files
with
5,822 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -147,3 +147,4 @@ cython_debug/ | |
# Dev Environment Specific | ||
.vscode | ||
.venv | ||
server/keys |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# TODO: remove me from the repo | ||
|
||
FOR TUTORIAL | ||
|
||
- stream logs | ||
- check benchmark execution mlcube training exp ID | ||
- if association request failed for some reason, delete private key (or at least check if rerunning the request will simply overwrite the key) | ||
- define output folders in medperf storage (logs for both, weights for agg) | ||
- adding email to CN currently could be challenging. THINK | ||
- ASSUMPTION: emails are not changed after signup | ||
|
||
- We now have demo data url and hash in training exp (dummy) that we don't use. | ||
- what to say about this in miccai (I think no worries; it's hidden now) | ||
- rethink/review about the following serializers and if necessary use atomic transactions | ||
- association creation (dataset-training, agg-training) | ||
- association approval (dataset-training, agg-training) | ||
- training experiment creation (creating keypair); this could move to approval | ||
- public/private keys uniqueness constraint while blank; check django docs on how | ||
- fix bug about association list; /home/hasan/work/openfl_ws/medperf-private/server/utils/views.py | ||
- pull latest medperf main | ||
- test agg and training exp owner being same user | ||
- basically, test the tutorial steps EXACTLY | ||
|
||
AFTER TUTORIAL | ||
|
||
- FOLLOWUP: collaborators doesn't use tensorboard logs. | ||
- FOLLOWUP: show csr hash on approval is not necessary since now CSRs are transported securely | ||
- test remote aggregator | ||
- make network config better structured (URL to file? no, could be annoying.) | ||
- move key generation after admin approval of training experiments. | ||
- when the training experiment owner wants to "lock" the experiment | ||
- ask for confirmation? it's an easy command and after execution there is no going back; a mess if unintended. | ||
- secretstorage gcloud | ||
|
||
NOT SURE | ||
|
||
- consider if we want to enable restarts and epochs/"fresh restarts" for training exps (it's hard) | ||
- mlcube for agg alone | ||
|
||
LATER / FUTURE INVESTIGATIONS | ||
|
||
- root key thing. | ||
- limit network access (for now we can rely on the review of the experiment owner) | ||
- compatibility tests | ||
- rethink if keys are always needed (just for exps where they on't need a custom cert) | ||
- server side verification of CSRs (check common names) | ||
- later: the whole design might be changed |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
from typing import Optional | ||
from medperf.entities.aggregator import Aggregator | ||
import typer | ||
|
||
import medperf.config as config | ||
from medperf.decorators import clean_except | ||
from medperf.commands.aggregator.submit import SubmitAggregator | ||
from medperf.commands.aggregator.associate import AssociateAggregator | ||
from medperf.commands.aggregator.run import StartAggregator | ||
|
||
from medperf.commands.list import EntityList | ||
from medperf.commands.view import EntityView | ||
|
||
app = typer.Typer() | ||
|
||
|
||
@app.command("submit") | ||
@clean_except | ||
def submit( | ||
name: str = typer.Option(..., "--name", "-n", help="Name of the agg"), | ||
address: str = typer.Option( | ||
..., "--address", "-a", help="UID of benchmark to associate with" | ||
), | ||
port: int = typer.Option( | ||
..., "--port", "-p", help="UID of benchmark to associate with" | ||
), | ||
): | ||
"""Associates a benchmark with a given mlcube or dataset. Only one option at a time.""" | ||
SubmitAggregator.run(name, address, port) | ||
config.ui.print("✅ Done!") | ||
|
||
|
||
@app.command("associate") | ||
@clean_except | ||
def associate( | ||
aggregator_id: int = typer.Option( | ||
..., "--aggregator_id", "-a", help="UID of benchmark to associate with" | ||
), | ||
training_exp_id: int = typer.Option( | ||
..., "--training_exp_id", "-t", help="UID of benchmark to associate with" | ||
), | ||
approval: bool = typer.Option(False, "-y", help="Skip approval step"), | ||
): | ||
"""Associates a benchmark with a given mlcube or dataset. Only one option at a time.""" | ||
AssociateAggregator.run(aggregator_id, training_exp_id, approved=approval) | ||
config.ui.print("✅ Done!") | ||
|
||
|
||
@app.command("start") | ||
@clean_except | ||
def run( | ||
aggregator_id: int = typer.Option( | ||
..., "--aggregator_id", "-a", help="UID of benchmark to associate with" | ||
), | ||
training_exp_id: int = typer.Option( | ||
..., "--training_exp_id", "-t", help="UID of benchmark to associate with" | ||
), | ||
): | ||
"""Associates a benchmark with a given mlcube or dataset. Only one option at a time.""" | ||
StartAggregator.run(training_exp_id, aggregator_id) | ||
config.ui.print("✅ Done!") | ||
|
||
|
||
@app.command("ls") | ||
@clean_except | ||
def list( | ||
local: bool = typer.Option(False, "--local", help="Get local aggregators"), | ||
mine: bool = typer.Option(False, "--mine", help="Get current-user aggregators"), | ||
): | ||
"""List aggregators stored locally and remotely from the user""" | ||
EntityList.run( | ||
Aggregator, | ||
fields=["UID", "Name", "Address", "Port"], | ||
local_only=local, | ||
mine_only=mine, | ||
) | ||
|
||
|
||
@app.command("view") | ||
@clean_except | ||
def view( | ||
entity_id: Optional[int] = typer.Argument(None, help="Benchmark ID"), | ||
format: str = typer.Option( | ||
"yaml", | ||
"-f", | ||
"--format", | ||
help="Format to display contents. Available formats: [yaml, json]", | ||
), | ||
local: bool = typer.Option( | ||
False, | ||
"--local", | ||
help="Display local benchmarks if benchmark ID is not provided", | ||
), | ||
mine: bool = typer.Option( | ||
False, | ||
"--mine", | ||
help="Display current-user benchmarks if benchmark ID is not provided", | ||
), | ||
output: str = typer.Option( | ||
None, | ||
"--output", | ||
"-o", | ||
help="Output file to store contents. If not provided, the output will be displayed", | ||
), | ||
): | ||
"""Displays the information of one or more aggregators""" | ||
EntityView.run(entity_id, Aggregator, format, local, mine, output) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from medperf import config | ||
from medperf.entities.aggregator import Aggregator | ||
from medperf.entities.training_exp import TrainingExp | ||
from medperf.utils import approval_prompt, generate_agg_csr | ||
from medperf.exceptions import InvalidArgumentError | ||
|
||
|
||
class AssociateAggregator: | ||
@staticmethod | ||
def run(training_exp_id: int, agg_uid: int, approved=False): | ||
"""Associates a registered aggregator with a benchmark | ||
Args: | ||
agg_uid (int): UID of the registered aggregator to associate | ||
benchmark_uid (int): UID of the benchmark to associate with | ||
""" | ||
comms = config.comms | ||
ui = config.ui | ||
agg = Aggregator.get(agg_uid) | ||
if agg.id is None: | ||
msg = "The provided aggregator is not registered." | ||
raise InvalidArgumentError(msg) | ||
|
||
training_exp = TrainingExp.get(training_exp_id) | ||
csr, csr_hash = generate_agg_csr(training_exp_id, agg.address, agg.id) | ||
msg = "Please confirm that you would like to associate" | ||
msg += f" the aggregator {agg.name} with the training exp {training_exp.name}." | ||
msg += f" The certificate signing request hash is: {csr_hash}" | ||
msg += " [Y/n]" | ||
|
||
approved = approved or approval_prompt(msg) | ||
if approved: | ||
ui.print("Generating aggregator training association") | ||
# TODO: delete keys if upload fails | ||
# check if on failure, other (possible) request will overwrite key | ||
comms.associate_aggregator(agg.id, training_exp_id, csr) | ||
else: | ||
ui.print("Aggregator association operation cancelled.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
import os | ||
from medperf import config | ||
from medperf.exceptions import InvalidArgumentError | ||
from medperf.entities.training_exp import TrainingExp | ||
from medperf.entities.aggregator import Aggregator | ||
from medperf.entities.cube import Cube | ||
from medperf.utils import storage_path | ||
|
||
|
||
class StartAggregator: | ||
@classmethod | ||
def run(cls, training_exp_id: int, agg_uid: int): | ||
"""Sets approval status for an association between a benchmark and a aggregator or mlcube | ||
Args: | ||
benchmark_uid (int): Benchmark UID. | ||
approval_status (str): Desired approval status to set for the association. | ||
comms (Comms): Instance of Comms interface. | ||
ui (UI): Instance of UI interface. | ||
aggregator_uid (int, optional): Aggregator UID. Defaults to None. | ||
mlcube_uid (int, optional): MLCube UID. Defaults to None. | ||
""" | ||
execution = cls(training_exp_id, agg_uid) | ||
execution.prepare() | ||
execution.validate() | ||
execution.prepare_agg_cert() | ||
execution.prepare_cube() | ||
with config.ui.interactive(): | ||
execution.run_experiment() | ||
|
||
def __init__(self, training_exp_id, agg_uid) -> None: | ||
self.training_exp_id = training_exp_id | ||
self.agg_uid = agg_uid | ||
self.ui = config.ui | ||
|
||
def prepare(self): | ||
self.training_exp = TrainingExp.get(self.training_exp_id) | ||
self.ui.print(f"Training Execution: {self.training_exp.name}") | ||
self.aggregator = Aggregator.get(self.agg_uid) | ||
|
||
def validate(self): | ||
if self.aggregator.id is None: | ||
msg = "The provided aggregator is not registered." | ||
raise InvalidArgumentError(msg) | ||
|
||
training_exp_aggregator = config.comms.get_experiment_aggregator( | ||
self.training_exp.id | ||
) | ||
|
||
if self.aggregator.id != training_exp_aggregator["id"]: | ||
msg = "The provided aggregator is not associated." | ||
raise InvalidArgumentError(msg) | ||
|
||
if self.training_exp.state != "OPERATION": | ||
msg = "The provided training exp is not operational." | ||
raise InvalidArgumentError(msg) | ||
|
||
def prepare_agg_cert(self): | ||
association = config.comms.get_aggregator_association( | ||
self.training_exp.id, self.aggregator.id | ||
) | ||
cert = association["certificate"] | ||
cert_folder = os.path.join( | ||
config.training_exps_storage, | ||
str(self.training_exp.id), | ||
config.agg_cert_folder, | ||
str(self.aggregator.id), | ||
) | ||
cert_folder = storage_path(cert_folder) | ||
os.makedirs(cert_folder, exist_ok=True) | ||
cert_file = os.path.join(cert_folder, "cert.crt") | ||
with open(cert_file, "w") as f: | ||
f.write(cert) | ||
Check failure Code scanning / CodeQL Clear-text storage of sensitive information High
This expression stores
sensitive data (certificate) Error loading related location Loading |
||
|
||
self.agg_cert_path = cert_folder | ||
|
||
def prepare_cube(self): | ||
self.cube = self.__get_cube(self.training_exp.fl_mlcube, "training") | ||
|
||
def __get_cube(self, uid: int, name: str) -> Cube: | ||
self.ui.text = f"Retrieving {name} cube" | ||
cube = Cube.get(uid) | ||
self.ui.print(f"> {name} cube download complete") | ||
return cube | ||
|
||
def run_experiment(self): | ||
task = "start_aggregator" | ||
port = self.aggregator.port | ||
# TODO: this overwrites existing cpu and gpu args | ||
string_params = { | ||
"-Pdocker.cpu_args": f"-p {port}:{port}", | ||
"-Pdocker.gpu_args": f"-p {port}:{port}", | ||
} | ||
|
||
# just for now create some output folders (TODO) | ||
out_logs = os.path.join(self.training_exp.path, "logs") | ||
out_weights = os.path.join(self.training_exp.path, "weights") | ||
os.makedirs(out_logs, exist_ok=True) | ||
os.makedirs(out_weights, exist_ok=True) | ||
|
||
params = { | ||
"node_cert_folder": self.agg_cert_path, | ||
"ca_cert_folder": self.training_exp.cert_path, | ||
"network_config": self.aggregator.network_config_path, | ||
"collaborators": self.training_exp.cols_path, | ||
"output_logs": out_logs, | ||
"output_weights": out_weights, | ||
} | ||
|
||
self.ui.text = "Running Aggregator" | ||
self.cube.run(task=task, string_params=string_params, **params) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import medperf.config as config | ||
from medperf.entities.aggregator import Aggregator | ||
from medperf.utils import remove_path | ||
|
||
|
||
class SubmitAggregator: | ||
@classmethod | ||
def run(cls, name, address, port): | ||
"""Submits a new cube to the medperf platform | ||
Args: | ||
benchmark_info (dict): benchmark information | ||
expected keys: | ||
name (str): benchmark name | ||
description (str): benchmark description | ||
docs_url (str): benchmark documentation url | ||
demo_url (str): benchmark demo dataset url | ||
demo_hash (str): benchmark demo dataset hash | ||
data_preparation_mlcube (int): benchmark data preparation mlcube uid | ||
reference_model_mlcube (int): benchmark reference model mlcube uid | ||
evaluator_mlcube (int): benchmark data evaluator mlcube uid | ||
""" | ||
ui = config.ui | ||
submission = cls(name, address, port) | ||
|
||
with ui.interactive(): | ||
ui.text = "Submitting Aggregator to MedPerf" | ||
updated_benchmark_body = submission.submit() | ||
ui.print("Uploaded") | ||
submission.write(updated_benchmark_body) | ||
|
||
def __init__(self, name, address, port): | ||
self.ui = config.ui | ||
# TODO: server config should be a URL... | ||
server_config = { | ||
"address": address, | ||
"agg_addr": address, | ||
"port": port, | ||
"agg_port": port, | ||
} | ||
self.aggregator = Aggregator(name=name, server_config=server_config) | ||
config.tmp_paths.append(self.aggregator.path) | ||
|
||
def submit(self): | ||
updated_body = self.aggregator.upload() | ||
return updated_body | ||
|
||
def write(self, updated_body): | ||
remove_path(self.aggregator.path) | ||
aggregator = Aggregator(**updated_body) | ||
aggregator.write() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.