Skip to content

Commit

Permalink
examples: remove client.py, refactor setup.py to use new client
Browse files Browse the repository at this point in the history
  • Loading branch information
jtsextonMITRE committed Dec 19, 2024
1 parent 9706738 commit fb10fe4
Show file tree
Hide file tree
Showing 3 changed files with 256 additions and 1,140 deletions.
91 changes: 50 additions & 41 deletions examples/mnist-classifier-demo/demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,9 @@
"register_python_source_file(\"scripts\", Path(\"..\", \"scripts\", \"__init__.py\"))\n",
"\n",
"# Register the examples/scripts directory as a Python module\n",
"from scripts.client import DioptraClient\n",
"from scripts.utils import make_tar\n",
"from scripts.setup import upload_experiment, run_experiment, delete_all\n",
"\n",
"from dioptra.client import connect_json_dioptra_client\n",
"# Set DIOPTRA_RESTAPI_URI variable if not defined, used to connect to RESTful API service\n",
"if os.getenv(\"DIOPTRA_RESTAPI_URI\") is None:\n",
" os.environ[\"DIOPTRA_RESTAPI_URI\"] = RESTAPI_ADDRESS\n",
Expand Down Expand Up @@ -207,7 +206,7 @@
},
"outputs": [],
"source": [
"client = DioptraClient()"
"client = connect_json_dioptra_client()"
]
},
{
Expand All @@ -226,10 +225,18 @@
"outputs": [],
"source": [
"try:\n",
" client.users.create('pluginuser','[email protected]','pleasemakesuretoPLUGINthecomputer','pleasemakesuretoPLUGINthecomputer')\n",
" client.users.create(\n",
" username='pluginuser',\n",
" email='[email protected]',\n",
" password='pleasemakesuretoPLUGINthecomputer'\n",
" )\n",
"except:\n",
" pass # ignore if user exists already\n",
"client.auth.login('pluginuser','pleasemakesuretoPLUGINthecomputer')"
"\n",
"client.auth.login(\n",
" username='pluginuser',\n",
" password='pleasemakesuretoPLUGINthecomputer'\n",
")"
]
},
{
Expand Down Expand Up @@ -271,13 +278,13 @@
"source": [
"job_time_limit = '1h'\n",
"\n",
"training_job = client.experiments.create_jobs_by_experiment_id(\n",
" experiment_id, \n",
" f\"training job for {experiment_id}\", \n",
" queue_id,\n",
" train_ep, \n",
" {\"epochs\":\"30\"}, \n",
" job_time_limit\n",
"training_job = client.experiments.jobs.create(\n",
" experiment_id=experiment_id, \n",
" description=f\"training job for {experiment_id}\", \n",
" queue_id=queue_id,\n",
" entrypoint_id=train_ep, \n",
" values={\"epochs\":\"30\"}, \n",
" timeout=job_time_limit\n",
")"
]
},
Expand All @@ -297,13 +304,13 @@
"job_time_limit = '1h'\n",
"\n",
"wait_for_job(training_job, 'fgm')\n",
"fgm_job = client.experiments.create_jobs_by_experiment_id(\n",
" experiment_id,\n",
" f\"fgm job for {experiment_id}\",\n",
" queue_id,\n",
" fgm_ep,\n",
" {\"model_name\": MODEL_NAME, \"model_version\": str(-1)}, # -1 means get the latest model\n",
" job_time_limit\n",
"fgm_job = client.experiments.jobs.create(\n",
" experiment_id=experiment_id,\n",
" description=f\"fgm job for {experiment_id}\",\n",
" queue_id=queue_id,\n",
" entrypoint_id=fgm_ep,\n",
" values={\"model_name\": MODEL_NAME, \"model_version\": str(-1)}, # -1 means get the latest model\n",
" timeout=job_time_limit\n",
")"
]
},
Expand All @@ -325,18 +332,18 @@
"job_time_limit = '1h'\n",
"\n",
"wait_for_job(training_job, 'patch_gen')\n",
"patch_gen_job = client.experiments.create_jobs_by_experiment_id(\n",
" experiment_id,\n",
" f\"patch generation job for {experiment_id}\",\n",
" queue_id,\n",
" patch_gen_ep,\n",
" {\"model_name\": MODEL_NAME,\n",
"patch_gen_job = client.experiments.jobs.create(\n",
" experiment_id=experiment_id,\n",
" description=f\"patch generation job for {experiment_id}\",\n",
" queue_id=queue_id,\n",
" entrypoint_id=patch_gen_ep,\n",
" values={\"model_name\": MODEL_NAME,\n",
" \"model_version\": str(-1), # -1 means get the latest\n",
" \"rotation_max\": str(180),\n",
" \"max_iter\": str(5000),\n",
" \"learning_rate\": str(5.0),\n",
" },\n",
" job_time_limit\n",
" timeout=job_time_limit\n",
")"
]
},
Expand All @@ -356,18 +363,18 @@
"job_time_limit = '1h'\n",
"\n",
"wait_for_job(training_job, 'patch_apply')\n",
"patch_apply_job = client.experiments.create_jobs_by_experiment_id(\n",
" experiment_id,\n",
" f\"patch generation job for {experiment_id}\",\n",
" queue_id,\n",
" patch_apply_ep,\n",
" {\"model_name\": MODEL_NAME, \n",
"patch_apply_job = client.experiments.jobs.create(\n",
" experiment_id=experiment_id,\n",
" description=f\"patch generation job for {experiment_id}\",\n",
" queue_id=queue_id,\n",
" entrypoint_id=patch_apply_ep,\n",
" values={\"model_name\": MODEL_NAME, \n",
" \"model_version\": str(-1), # -1 means get the latest model\n",
" \"job_id\": str(patch_gen_job['id']),# we need the patches we just generated too\n",
" \"patch_scale\": str(0.5),\n",
" \"rotation_max\": str(180),\n",
" }, \n",
" job_time_limit\n",
" timeout=job_time_limit\n",
")"
]
},
Expand All @@ -385,20 +392,22 @@
"outputs": [],
"source": [
"def run_job(experiment_id, queue_id, ep, title, prev_job_id=False, latest_model=False, args=None, prev_job=None, job_time_limit='1h'):\n",
" args = {} if args is None else args\n",
" prev_job = {} if prev_job is None else prev_job\n",
" if prev_job is not None:\n",
" wait_for_job(prev_job, title, quiet=False)\n",
" if prev_job_id:\n",
" if prev_job_id and 'id' in prev_job.keys():\n",
" args['job_id'] = str(prev_job['id'])\n",
" if latest_model:\n",
" args['model_name'] = MODEL_NAME \n",
" args['model_version'] = str(-1)\n",
" job = client.experiments.create_jobs_by_experiment_id(\n",
" experiment_id,\n",
" f\"{title} job for {experiment_id}\",\n",
" queue_id,\n",
" ep,\n",
" args,\n",
" job_time_limit\n",
" job = client.experiments.jobs.create(\n",
" experiment_id=experiment_id,\n",
" description=f\"{title} job for {experiment_id}\",\n",
" queue_id=queue_id,\n",
" entrypoint_id=ep,\n",
" values=args,\n",
" timeout=job_time_limit\n",
" )\n",
" return job\n",
"def get_prev_tar_file(adv=\"def\"):\n",
Expand Down
Loading

0 comments on commit fb10fe4

Please sign in to comment.