diff --git a/examples/core_examples/pipe_example.yml b/examples/core_examples/pipe_example.yml index 9ea1cdb..7bc5ce8 100644 --- a/examples/core_examples/pipe_example.yml +++ b/examples/core_examples/pipe_example.yml @@ -12,15 +12,26 @@ stages: module_name: rail.creation.engines.flowEngine name: flow_engine_test nprocess: 1 + aliases: + output: output_flow_engine_test - classname: LSSTErrorModel module_name: rail.creation.degraders.lsst_error_model name: lsst_error_model_test nprocess: 1 + aliases: + input: output_flow_engine_test + output: output_lsst_error_model_test - classname: ColumnMapper module_name: rail.tools.table_tools name: col_remapper_test nprocess: 1 + aliases: + input: output_lsst_error_model_test + output: output_col_remapper_test - classname: TableConverter module_name: rail.tools.table_tools name: table_conv_test nprocess: 1 + aliases: + input: output_col_remapper_test + output: output_table_conv_test diff --git a/examples/core_examples/pipe_example_config.yml b/examples/core_examples/pipe_example_config.yml index c871d35..55cbc37 100644 --- a/examples/core_examples/pipe_example_config.yml +++ b/examples/core_examples/pipe_example_config.yml @@ -1,7 +1,4 @@ col_remapper_test: - aliases: - input: output_lsst_error_model_test - output: output_col_remapper_test chunk_size: 100000 columns: mag_g_lsst_err: mag_err_g_lsst @@ -17,8 +14,6 @@ col_remapper_test: name: col_remapper_test output_mode: default flow_engine_test: - aliases: - output: output_flow_engine_test config: null model: ${FLOWDIR}/pretrained_flow.pkl n_samples: 50 @@ -36,9 +31,6 @@ lsst_error_model_test: y: 23.73 z: 24.16 airmass: 1.2 - aliases: - input: output_flow_engine_test - output: output_lsst_error_model_test bandNames: g: mag_g_lsst i: mag_i_lsst @@ -96,9 +88,6 @@ lsst_error_model_test: z: 0.69 tvis: 30.0 table_conv_test: - aliases: - input: output_col_remapper_test - output: output_table_conv_test config: null input: None name: table_conv_test diff --git a/examples/estimation_examples/pzflow_demo.ipynb b/examples/estimation_examples/pzflow_demo.ipynb new file mode 100644 index 0000000..12aec82 --- /dev/null +++ b/examples/estimation_examples/pzflow_demo.ipynb @@ -0,0 +1,179 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "327d391f-58bc-4b6a-9bbe-3987b969c8f4", + "metadata": {}, + "source": [ + "PZFlow Informer and Estimator Demo\n", + "\n", + "Author: Tianqing Zhang\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "916a05ad", + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "import os\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "import rail\n", + "from rail.core.data import TableHandle\n", + "from rail.core.stage import RailStage\n", + "import qp\n", + "import tables_io\n", + "\n", + "from rail.estimation.algos.pzflow_nf import PZFlowInformer, PZFlowEstimator\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d8ef87d3", + "metadata": {}, + "outputs": [], + "source": [ + "DS = RailStage.data_store\n", + "DS.__class__.allow_overwrite = True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f79c3a7b", + "metadata": {}, + "outputs": [], + "source": [ + "from rail.utils.path_utils import find_rail_file\n", + "trainFile = find_rail_file('examples_data/testdata/test_dc2_training_9816.hdf5')\n", + "testFile = find_rail_file('examples_data/testdata/test_dc2_validation_9816.hdf5')\n", + "training_data = DS.read_file(\"training_data\", TableHandle, trainFile)\n", + "test_data = DS.read_file(\"test_data\", TableHandle, testFile)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "756d78a3", + "metadata": {}, + "outputs": [], + "source": [ + "pzflow_dict = dict(hdf5_groupname='photometry',output_mode = 'not_fiducial' )\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0857e6bb-18eb-4f89-bc4b-29bed1ffa122", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1042a9f3", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# epoch = 200 gives a reasonable converged loss\n", + "pzflow_train = PZFlowInformer.make_stage(name='inform_pzflow',model='demo_pzflow.pkl',num_training_epochs = 30, **pzflow_dict)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c407f45b", + "metadata": {}, + "outputs": [], + "source": [ + "# training of the pzflow\n", + "pzflow_train.inform(training_data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "156b6e3d", + "metadata": {}, + "outputs": [], + "source": [ + "pzflow_dict = dict(hdf5_groupname='photometry')\n", + "\n", + "pzflow_estimator = PZFlowEstimator.make_stage(name='estimate_pzflow',model='demo_pzflow.pkl',**pzflow_dict, chunk_size = 20000)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "00911d60", + "metadata": {}, + "outputs": [], + "source": [ + "# estimate using the test data\n", + "estimate_results = pzflow_estimator.estimate(test_data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4cbdece3", + "metadata": {}, + "outputs": [], + "source": [ + "mode = estimate_results.read(force=True).ancil['zmode']\n", + "truth = np.array(test_data.data['photometry']['redshift'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ba076bab-c5ab-4292-8de9-415e7b30af5c", + "metadata": {}, + "outputs": [], + "source": [ + "# visualize the prediction. \n", + "plt.figure(figsize = (8,8))\n", + "plt.scatter(truth, mode, s = 0.5)\n", + "plt.xlabel('True Redshift')\n", + "plt.ylabel('Mode of Estimated Redshift')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ed5bf266-3b5c-4d9b-8428-77a2833cafef", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}