-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
examples: remove client.py, refactor setup.py to use new client
- Loading branch information
1 parent
9706738
commit fb10fe4
Showing
3 changed files
with
256 additions
and
1,140 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -94,10 +94,9 @@ | |
"register_python_source_file(\"scripts\", Path(\"..\", \"scripts\", \"__init__.py\"))\n", | ||
"\n", | ||
"# Register the examples/scripts directory as a Python module\n", | ||
"from scripts.client import DioptraClient\n", | ||
"from scripts.utils import make_tar\n", | ||
"from scripts.setup import upload_experiment, run_experiment, delete_all\n", | ||
"\n", | ||
"from dioptra.client import connect_json_dioptra_client\n", | ||
"# Set DIOPTRA_RESTAPI_URI variable if not defined, used to connect to RESTful API service\n", | ||
"if os.getenv(\"DIOPTRA_RESTAPI_URI\") is None:\n", | ||
" os.environ[\"DIOPTRA_RESTAPI_URI\"] = RESTAPI_ADDRESS\n", | ||
|
@@ -207,7 +206,7 @@ | |
}, | ||
"outputs": [], | ||
"source": [ | ||
"client = DioptraClient()" | ||
"client = connect_json_dioptra_client()" | ||
] | ||
}, | ||
{ | ||
|
@@ -226,10 +225,18 @@ | |
"outputs": [], | ||
"source": [ | ||
"try:\n", | ||
" client.users.create('pluginuser','[email protected]','pleasemakesuretoPLUGINthecomputer','pleasemakesuretoPLUGINthecomputer')\n", | ||
" client.users.create(\n", | ||
" username='pluginuser',\n", | ||
" email='[email protected]',\n", | ||
" password='pleasemakesuretoPLUGINthecomputer'\n", | ||
" )\n", | ||
"except:\n", | ||
" pass # ignore if user exists already\n", | ||
"client.auth.login('pluginuser','pleasemakesuretoPLUGINthecomputer')" | ||
"\n", | ||
"client.auth.login(\n", | ||
" username='pluginuser',\n", | ||
" password='pleasemakesuretoPLUGINthecomputer'\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
|
@@ -271,13 +278,13 @@ | |
"source": [ | ||
"job_time_limit = '1h'\n", | ||
"\n", | ||
"training_job = client.experiments.create_jobs_by_experiment_id(\n", | ||
" experiment_id, \n", | ||
" f\"training job for {experiment_id}\", \n", | ||
" queue_id,\n", | ||
" train_ep, \n", | ||
" {\"epochs\":\"30\"}, \n", | ||
" job_time_limit\n", | ||
"training_job = client.experiments.jobs.create(\n", | ||
" experiment_id=experiment_id, \n", | ||
" description=f\"training job for {experiment_id}\", \n", | ||
" queue_id=queue_id,\n", | ||
" entrypoint_id=train_ep, \n", | ||
" values={\"epochs\":\"30\"}, \n", | ||
" timeout=job_time_limit\n", | ||
")" | ||
] | ||
}, | ||
|
@@ -297,13 +304,13 @@ | |
"job_time_limit = '1h'\n", | ||
"\n", | ||
"wait_for_job(training_job, 'fgm')\n", | ||
"fgm_job = client.experiments.create_jobs_by_experiment_id(\n", | ||
" experiment_id,\n", | ||
" f\"fgm job for {experiment_id}\",\n", | ||
" queue_id,\n", | ||
" fgm_ep,\n", | ||
" {\"model_name\": MODEL_NAME, \"model_version\": str(-1)}, # -1 means get the latest model\n", | ||
" job_time_limit\n", | ||
"fgm_job = client.experiments.jobs.create(\n", | ||
" experiment_id=experiment_id,\n", | ||
" description=f\"fgm job for {experiment_id}\",\n", | ||
" queue_id=queue_id,\n", | ||
" entrypoint_id=fgm_ep,\n", | ||
" values={\"model_name\": MODEL_NAME, \"model_version\": str(-1)}, # -1 means get the latest model\n", | ||
" timeout=job_time_limit\n", | ||
")" | ||
] | ||
}, | ||
|
@@ -325,18 +332,18 @@ | |
"job_time_limit = '1h'\n", | ||
"\n", | ||
"wait_for_job(training_job, 'patch_gen')\n", | ||
"patch_gen_job = client.experiments.create_jobs_by_experiment_id(\n", | ||
" experiment_id,\n", | ||
" f\"patch generation job for {experiment_id}\",\n", | ||
" queue_id,\n", | ||
" patch_gen_ep,\n", | ||
" {\"model_name\": MODEL_NAME,\n", | ||
"patch_gen_job = client.experiments.jobs.create(\n", | ||
" experiment_id=experiment_id,\n", | ||
" description=f\"patch generation job for {experiment_id}\",\n", | ||
" queue_id=queue_id,\n", | ||
" entrypoint_id=patch_gen_ep,\n", | ||
" values={\"model_name\": MODEL_NAME,\n", | ||
" \"model_version\": str(-1), # -1 means get the latest\n", | ||
" \"rotation_max\": str(180),\n", | ||
" \"max_iter\": str(5000),\n", | ||
" \"learning_rate\": str(5.0),\n", | ||
" },\n", | ||
" job_time_limit\n", | ||
" timeout=job_time_limit\n", | ||
")" | ||
] | ||
}, | ||
|
@@ -356,18 +363,18 @@ | |
"job_time_limit = '1h'\n", | ||
"\n", | ||
"wait_for_job(training_job, 'patch_apply')\n", | ||
"patch_apply_job = client.experiments.create_jobs_by_experiment_id(\n", | ||
" experiment_id,\n", | ||
" f\"patch generation job for {experiment_id}\",\n", | ||
" queue_id,\n", | ||
" patch_apply_ep,\n", | ||
" {\"model_name\": MODEL_NAME, \n", | ||
"patch_apply_job = client.experiments.jobs.create(\n", | ||
" experiment_id=experiment_id,\n", | ||
" description=f\"patch generation job for {experiment_id}\",\n", | ||
" queue_id=queue_id,\n", | ||
" entrypoint_id=patch_apply_ep,\n", | ||
" values={\"model_name\": MODEL_NAME, \n", | ||
" \"model_version\": str(-1), # -1 means get the latest model\n", | ||
" \"job_id\": str(patch_gen_job['id']),# we need the patches we just generated too\n", | ||
" \"patch_scale\": str(0.5),\n", | ||
" \"rotation_max\": str(180),\n", | ||
" }, \n", | ||
" job_time_limit\n", | ||
" timeout=job_time_limit\n", | ||
")" | ||
] | ||
}, | ||
|
@@ -385,20 +392,22 @@ | |
"outputs": [], | ||
"source": [ | ||
"def run_job(experiment_id, queue_id, ep, title, prev_job_id=False, latest_model=False, args=None, prev_job=None, job_time_limit='1h'):\n", | ||
" args = {} if args is None else args\n", | ||
" prev_job = {} if prev_job is None else prev_job\n", | ||
" if prev_job is not None:\n", | ||
" wait_for_job(prev_job, title, quiet=False)\n", | ||
" if prev_job_id:\n", | ||
" if prev_job_id and 'id' in prev_job.keys():\n", | ||
" args['job_id'] = str(prev_job['id'])\n", | ||
" if latest_model:\n", | ||
" args['model_name'] = MODEL_NAME \n", | ||
" args['model_version'] = str(-1)\n", | ||
" job = client.experiments.create_jobs_by_experiment_id(\n", | ||
" experiment_id,\n", | ||
" f\"{title} job for {experiment_id}\",\n", | ||
" queue_id,\n", | ||
" ep,\n", | ||
" args,\n", | ||
" job_time_limit\n", | ||
" job = client.experiments.jobs.create(\n", | ||
" experiment_id=experiment_id,\n", | ||
" description=f\"{title} job for {experiment_id}\",\n", | ||
" queue_id=queue_id,\n", | ||
" entrypoint_id=ep,\n", | ||
" values=args,\n", | ||
" timeout=job_time_limit\n", | ||
" )\n", | ||
" return job\n", | ||
"def get_prev_tar_file(adv=\"def\"):\n", | ||
|
Oops, something went wrong.