Skip to content

Commit

Permalink
multi-task tutorial change from legacy to mbm models (#1778)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1778

Changed legacy GPEI and MT_MTGP from legacy to MBM in the multi-task tutorial.

Reviewed By: saitcakmak

Differential Revision: D48245811

fbshipit-source-id: 9802b781bb0b2ab791a2aa232b92062b7b90a287
  • Loading branch information
Jelena Markovic-Voronov authored and facebook-github-bot committed Aug 14, 2023
1 parent 7fcdb3d commit 5646dcd
Showing 1 changed file with 129 additions and 41 deletions.
170 changes: 129 additions & 41 deletions tutorials/multi_task.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,43 +28,53 @@
"metadata": {
"code_folding": [],
"hidden_ranges": [],
"originalKey": "3ce827be-d20b-48d3-a6ff-291bd442c748"
"originalKey": "3ce827be-d20b-48d3-a6ff-291bd442c748",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"import os\n",
"import time\n",
"\n",
"from copy import deepcopy\n",
"from typing import Optional\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"from scipy.stats import norm\n",
"\n",
"import torch\n",
"\n",
"from ax.core.data import Data\n",
"from ax.core.experiment import Experiment\n",
"from ax.core.generator_run import GeneratorRun\n",
"from ax.core.multi_type_experiment import MultiTypeExperiment\n",
"from ax.core.objective import Objective\n",
"from ax.core.observation import ObservationFeatures, observations_from_data\n",
"from ax.core.optimization_config import OptimizationConfig\n",
"from ax.core.parameter import ParameterType, RangeParameter\n",
"from ax.core.search_space import SearchSpace\n",
"from ax.core.objective import Objective\n",
"from ax.runners.synthetic import SyntheticRunner\n",
"from ax.modelbridge.random import RandomModelBridge\n",
"from ax.core.types import ComparisonOp\n",
"from ax.core.parameter import RangeParameter, ParameterType\n",
"from ax.core.multi_type_experiment import MultiTypeExperiment\n",
"from ax.metrics.hartmann6 import Hartmann6Metric\n",
"from ax.metrics.l2norm import L2NormMetric\n",
"from ax.modelbridge.factory import get_sobol, get_GPEI, get_MTGP_LEGACY\n",
"from ax.core.generator_run import GeneratorRun\n",
"from ax.modelbridge.factory import get_sobol\n",
"from ax.modelbridge.registry import Models, MT_MTGP_trans, ST_MTGP_trans\n",
"from ax.modelbridge.torch import TorchModelBridge\n",
"from ax.modelbridge.transforms.convert_metric_names import tconfig_from_mt_experiment\n",
"from ax.plot.diagnostic import interact_batch_comparison\n",
"from ax.plot.trace import optimization_trace_all_methods\n",
"from ax.runners.synthetic import SyntheticRunner\n",
"from ax.utils.notebook.plotting import init_notebook_plotting, render\n",
"from ax.utils.common.typeutils import checked_cast\n",
"\n",
"init_notebook_plotting()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"SMOKE_TEST = os.environ.get(\"SMOKE_TEST\")"
Expand All @@ -89,7 +99,10 @@
"metadata": {
"code_folding": [],
"hidden_ranges": [],
"originalKey": "2315ca64-74e5-4084-829e-e8a482c653e5"
"originalKey": "2315ca64-74e5-4084-829e-e8a482c653e5",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -133,7 +146,10 @@
"metadata": {
"code_folding": [],
"hidden_ranges": [],
"originalKey": "39504f84-793e-4dae-ae55-068f1b762706"
"originalKey": "39504f84-793e-4dae-ae55-068f1b762706",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -203,7 +219,10 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"originalKey": "8260b668-91ef-404e-aa8c-4bf43f6a5660"
"originalKey": "8260b668-91ef-404e-aa8c-4bf43f6a5660",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -238,7 +257,10 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"originalKey": "3d124563-8a1f-411e-9822-972568ce1970"
"originalKey": "3d124563-8a1f-411e-9822-972568ce1970",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -274,7 +296,10 @@
"metadata": {
"code_folding": [],
"hidden_ranges": [],
"originalKey": "040354c2-4313-46db-b40d-8adc8da6fafb"
"originalKey": "040354c2-4313-46db-b40d-8adc8da6fafb",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
Expand All @@ -291,7 +316,7 @@
" for b in range(n_batches):\n",
" print(\"Online-only batch\", b, time.time() - t1)\n",
" # Fit the GP\n",
" m = get_GPEI(\n",
" m = Models.BOTORCH_MODULAR(\n",
" experiment=exp_online,\n",
" data=exp_online.fetch_data(),\n",
" search_space=exp_online.search_space,\n",
Expand Down Expand Up @@ -323,13 +348,89 @@
"7. <b> Update model and repeat </b> - Update the model with the online observations, and repeat from step 3 for the next batch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"def get_MTGP(\n",
" experiment: Experiment,\n",
" data: Data,\n",
" search_space: Optional[SearchSpace] = None,\n",
" trial_index: Optional[int] = None,\n",
" device: torch.device = torch.device(\"cpu\"),\n",
" dtype: torch.dtype = torch.double,\n",
") -> TorchModelBridge:\n",
" \"\"\"Instantiates a Multi-task Gaussian Process (MTGP) model that generates\n",
" points with EI.\n",
"\n",
" If the input experiment is a MultiTypeExperiment then a\n",
" Multi-type Multi-task GP model will be instantiated.\n",
" Otherwise, the model will be a Single-type Multi-task GP.\n",
" \"\"\"\n",
"\n",
" if isinstance(experiment, MultiTypeExperiment):\n",
" trial_index_to_type = {\n",
" t.index: t.trial_type for t in experiment.trials.values()\n",
" }\n",
" transforms = MT_MTGP_trans\n",
" transform_configs = {\n",
" \"TrialAsTask\": {\"trial_level_map\": {\"trial_type\": trial_index_to_type}},\n",
" \"ConvertMetricNames\": tconfig_from_mt_experiment(experiment),\n",
" }\n",
" else:\n",
" # Set transforms for a Single-type MTGP model.\n",
" transforms = ST_MTGP_trans\n",
" transform_configs = None\n",
"\n",
" # Choose the status quo features for the experiment from the selected trial.\n",
" # If trial_index is None, we will look for a status quo from the last\n",
" # experiment trial to use as a status quo for the experiment.\n",
" if trial_index is None:\n",
" trial_index = len(experiment.trials) - 1\n",
" elif trial_index >= len(experiment.trials):\n",
" raise ValueError(\"trial_index is bigger than the number of experiment trials\")\n",
"\n",
" status_quo = experiment.trials[trial_index].status_quo\n",
" if status_quo is None:\n",
" status_quo_features = None\n",
" else:\n",
" status_quo_features = ObservationFeatures(\n",
" parameters=status_quo.parameters,\n",
" trial_index=trial_index, # pyre-ignore[6]\n",
" )\n",
"\n",
" \n",
" return checked_cast(\n",
" TorchModelBridge,\n",
" Models.ST_MTGP(\n",
" experiment=experiment,\n",
" search_space=search_space or experiment.search_space,\n",
" data=data,\n",
" transforms=transforms,\n",
" transform_configs=transform_configs,\n",
" torch_dtype=dtype,\n",
" torch_device=device,\n",
" status_quo_features=status_quo_features,\n",
" ),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"code_folding": [],
"hidden_ranges": [],
"originalKey": "37735b0e-e488-4927-a3da-a7d32d9f1ae0"
"originalKey": "37735b0e-e488-4927-a3da-a7d32d9f1ae0",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -378,7 +479,7 @@
" for b in range(n_batches):\n",
" print(\"Multi-task batch\", b, time.time() - t1)\n",
" # (2 / 7). Fit the MTGP\n",
" m = get_MTGP_LEGACY(\n",
" m = get_MTGP(\n",
" experiment=exp_multitask,\n",
" data=exp_multitask.fetch_data(),\n",
" search_space=exp_multitask.search_space,\n",
Expand All @@ -397,7 +498,7 @@
" exp_multitask.new_batch_trial(trial_type=\"offline\", generator_run=gr).run()\n",
"\n",
" # 5. Update the model\n",
" m = get_MTGP_LEGACY(\n",
" m = get_MTGP(\n",
" experiment=exp_multitask,\n",
" data=exp_multitask.fetch_data(),\n",
" search_space=exp_multitask.search_space,\n",
Expand Down Expand Up @@ -432,7 +533,10 @@
"metadata": {
"code_folding": [],
"hidden_ranges": [],
"originalKey": "f94a7537-61a6-4200-8e56-01de41aff6c9"
"originalKey": "f94a7537-61a6-4200-8e56-01de41aff6c9",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -460,23 +564,7 @@
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
}
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
Expand Down

0 comments on commit 5646dcd

Please sign in to comment.