Skip to content

Commit

Permalink
Merge pull request #155 from LSSTDESC/tqz/pzflow_inform_estimate_note…
Browse files Browse the repository at this point in the history
…book

Add a pzflow inform estimate demo to the examples
  • Loading branch information
ztq1996 committed Aug 2, 2024
2 parents 7e66a94 + 1fca2e6 commit c22b723
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 11 deletions.
11 changes: 11 additions & 0 deletions examples/core_examples/pipe_example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,26 @@ stages:
module_name: rail.creation.engines.flowEngine
name: flow_engine_test
nprocess: 1
aliases:
output: output_flow_engine_test
- classname: LSSTErrorModel
module_name: rail.creation.degraders.lsst_error_model
name: lsst_error_model_test
nprocess: 1
aliases:
input: output_flow_engine_test
output: output_lsst_error_model_test
- classname: ColumnMapper
module_name: rail.tools.table_tools
name: col_remapper_test
nprocess: 1
aliases:
input: output_lsst_error_model_test
output: output_col_remapper_test
- classname: TableConverter
module_name: rail.tools.table_tools
name: table_conv_test
nprocess: 1
aliases:
input: output_col_remapper_test
output: output_table_conv_test
11 changes: 0 additions & 11 deletions examples/core_examples/pipe_example_config.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
col_remapper_test:
aliases:
input: output_lsst_error_model_test
output: output_col_remapper_test
chunk_size: 100000
columns:
mag_g_lsst_err: mag_err_g_lsst
Expand All @@ -17,8 +14,6 @@ col_remapper_test:
name: col_remapper_test
output_mode: default
flow_engine_test:
aliases:
output: output_flow_engine_test
config: null
model: ${FLOWDIR}/pretrained_flow.pkl
n_samples: 50
Expand All @@ -36,9 +31,6 @@ lsst_error_model_test:
y: 23.73
z: 24.16
airmass: 1.2
aliases:
input: output_flow_engine_test
output: output_lsst_error_model_test
bandNames:
g: mag_g_lsst
i: mag_i_lsst
Expand Down Expand Up @@ -96,9 +88,6 @@ lsst_error_model_test:
z: 0.69
tvis: 30.0
table_conv_test:
aliases:
input: output_col_remapper_test
output: output_table_conv_test
config: null
input: None
name: table_conv_test
Expand Down
179 changes: 179 additions & 0 deletions examples/estimation_examples/pzflow_demo.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "327d391f-58bc-4b6a-9bbe-3987b969c8f4",
"metadata": {},
"source": [
"PZFlow Informer and Estimator Demo\n",
"\n",
"Author: Tianqing Zhang\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "916a05ad",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import os\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"import rail\n",
"from rail.core.data import TableHandle\n",
"from rail.core.stage import RailStage\n",
"import qp\n",
"import tables_io\n",
"\n",
"from rail.estimation.algos.pzflow_nf import PZFlowInformer, PZFlowEstimator\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d8ef87d3",
"metadata": {},
"outputs": [],
"source": [
"DS = RailStage.data_store\n",
"DS.__class__.allow_overwrite = True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f79c3a7b",
"metadata": {},
"outputs": [],
"source": [
"from rail.utils.path_utils import find_rail_file\n",
"trainFile = find_rail_file('examples_data/testdata/test_dc2_training_9816.hdf5')\n",
"testFile = find_rail_file('examples_data/testdata/test_dc2_validation_9816.hdf5')\n",
"training_data = DS.read_file(\"training_data\", TableHandle, trainFile)\n",
"test_data = DS.read_file(\"test_data\", TableHandle, testFile)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "756d78a3",
"metadata": {},
"outputs": [],
"source": [
"pzflow_dict = dict(hdf5_groupname='photometry',output_mode = 'not_fiducial' )\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0857e6bb-18eb-4f89-bc4b-29bed1ffa122",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "1042a9f3",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# epoch = 200 gives a reasonable converged loss\n",
"pzflow_train = PZFlowInformer.make_stage(name='inform_pzflow',model='demo_pzflow.pkl',num_training_epochs = 30, **pzflow_dict)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c407f45b",
"metadata": {},
"outputs": [],
"source": [
"# training of the pzflow\n",
"pzflow_train.inform(training_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "156b6e3d",
"metadata": {},
"outputs": [],
"source": [
"pzflow_dict = dict(hdf5_groupname='photometry')\n",
"\n",
"pzflow_estimator = PZFlowEstimator.make_stage(name='estimate_pzflow',model='demo_pzflow.pkl',**pzflow_dict, chunk_size = 20000)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "00911d60",
"metadata": {},
"outputs": [],
"source": [
"# estimate using the test data\n",
"estimate_results = pzflow_estimator.estimate(test_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4cbdece3",
"metadata": {},
"outputs": [],
"source": [
"mode = estimate_results.read(force=True).ancil['zmode']\n",
"truth = np.array(test_data.data['photometry']['redshift'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ba076bab-c5ab-4292-8de9-415e7b30af5c",
"metadata": {},
"outputs": [],
"source": [
"# visualize the prediction. \n",
"plt.figure(figsize = (8,8))\n",
"plt.scatter(truth, mode, s = 0.5)\n",
"plt.xlabel('True Redshift')\n",
"plt.ylabel('Mode of Estimated Redshift')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ed5bf266-3b5c-4d9b-8428-77a2833cafef",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit c22b723

Please sign in to comment.