Skip to content

Commit

Permalink
Update AutoML notebook with more demos.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 564459044
  • Loading branch information
qiuyiz authored and copybara-github committed Sep 11, 2023
1 parent 38e1c0a commit 7f5ee12
Showing 1 changed file with 88 additions and 7 deletions.
95 changes: 88 additions & 7 deletions docs/tutorials/automl_conf_2023.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@
},
"outputs": [],
"source": [
"# @title Test\n",
"# @title Test for Exercise\n",
"experimenter = CookieExperimenter()\n",
"trial = vz.Trial(parameters={'chocolate': 0.5, 'salt': 0.5, 'sugar': 0.5})\n",
"experimenter.evaluate([trial])\n",
Expand Down Expand Up @@ -632,16 +632,16 @@
"benchmark_states = []\n",
"\n",
"grid_designer_factory = grid_lib.GridSearchDesigner.from_problem\n",
"benchmark_state_factory = vzb.DesignerBenchmarkStateFactory(\n",
"grid_state_factory = vzb.DesignerBenchmarkStateFactory(\n",
" experimenter=experimenter, designer_factory=grid_designer_factory\n",
")\n",
"benchmark_states.append(benchmark_state_factory())\n",
"benchmark_states.append(grid_state_factory())\n",
"\n",
"random_designer_factory = random_lib.RandomDesigner.from_problem\n",
"benchmark_state_factory = vzb.DesignerBenchmarkStateFactory(\n",
"random_state_factory = vzb.DesignerBenchmarkStateFactory(\n",
" experimenter=experimenter, designer_factory=random_designer_factory\n",
")\n",
"benchmark_states.append(benchmark_state_factory())"
"benchmark_states.append(random_state_factory())"
]
},
{
Expand All @@ -652,7 +652,7 @@
},
"outputs": [],
"source": [
"# @title Test\n",
"# @title Test for Exercise\n",
"runner = vzb.BenchmarkRunner(\n",
" benchmark_subroutines=[\n",
" vzb.GenerateSuggestions(),\n",
Expand Down Expand Up @@ -682,7 +682,7 @@
},
"outputs": [],
"source": [
"# @title Demonstration\n",
"# @title Analysis Demonstration\n",
"from vizier.benchmarks import analyzers\n",
"import matplotlib.pyplot as plt\n",
"\n",
Expand All @@ -694,6 +694,87 @@
"plt.ylabel('Objective value')\n",
"plt.show()\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8sMKUkik7hi4"
},
"source": [
"Finally, we show the flexibility of our basic setup. In a few lines of code, we can accomplish the following with relative ease:\n",
"\n",
"* Add a noisy Cookie benchmark\n",
"* Add discretization to the chocolate Cookie parameter\n",
"* Add normalized metrics for analysis\n",
"* Add another algorithm for comparison\n",
"* Add repeats and error bars"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "2tJ48uvd7bf6"
},
"outputs": [],
"source": [
"from vizier._src.algorithms.designers.eagle_strategy import eagle_strategy\n",
"\n",
"NUM_REPEATS = 5 # @param\n",
"NUM_ITERATIONS = 100 # @param\n",
"\n",
"algorithms = {\n",
" 'grid': grid_lib.GridSearchDesigner.from_problem,\n",
" 'random': random_lib.RandomDesigner.from_problem,\n",
" 'eagle': eagle_strategy.EagleStrategyDesigner,\n",
"}\n",
"\n",
"\n",
"class CookieExperimenterFactory(experimenters.SerializableExperimenterFactory):\n",
"\n",
" def __call__(self, *, seed=None) -\u003e experimenters.Experimenter:\n",
" return CookieExperimenter()\n",
"\n",
" def dump(self):\n",
" return vz.Metadata({'name': 'CookieExperimenter'})\n",
"\n",
"\n",
"experimenter_factories = [\n",
" CookieExperimenterFactory(),\n",
" experimenters.SingleObjectiveExperimenterFactory(\n",
" base_factory=CookieExperimenterFactory(),\n",
" noise_type='SEVERE_ADDITIVE_GAUSSIAN',\n",
" ),\n",
" experimenters.SingleObjectiveExperimenterFactory(\n",
" base_factory=CookieExperimenterFactory(),\n",
" discrete_dict = {0: 4}\n",
" )\n",
"]\n",
"\n",
"records = []\n",
"for experimenter_factory in experimenter_factories:\n",
" for algo_name, algo_factory in algorithms.items():\n",
" benchmark_state_factory = vzb.ExperimenterDesignerBenchmarkStateFactory(\n",
" experimenter_factory=experimenter_factory, designer_factory=algo_factory\n",
" )\n",
" states = []\n",
" for _ in range(NUM_REPEATS):\n",
" benchmark_state = benchmark_state_factory()\n",
" runner.run(benchmark_state)\n",
" states.append(benchmark_state)\n",
" record = analyzers.BenchmarkStateAnalyzer.to_record(\n",
" algorithm=algo_name,\n",
" experimenter_factory=experimenter_factory,\n",
" states=states,\n",
" )\n",
" records.append(record)\n",
"\n",
"analyzed_records = analyzers.BenchmarkRecordAnalyzer.add_comparison_metrics(\n",
" records=records, baseline_algo='random'\n",
")\n",
"analyzers.plot_from_records(analyzed_records, title_maxlen=100, col_figsize=12)"
]
}
],
"metadata": {
Expand Down

0 comments on commit 7f5ee12

Please sign in to comment.