diff --git a/bliss/__init__.py b/bliss/__init__.py index d6b876350..2ccff7dff 100644 --- a/bliss/__init__.py +++ b/bliss/__init__.py @@ -2,8 +2,14 @@ import omegaconf -def make_range(start, stop, step): - return omegaconf.ListConfig(list(range(start, stop, step))) +def make_range(start, stop, step, *args): + orig_range = list(range(start, stop, step)) + for arg in args: + try: + orig_range.remove(arg) + except ValueError: + continue + return omegaconf.ListConfig(orig_range) # resolve ranges in config files (we want this to execute at any entry point) diff --git a/bliss/cached_dataset.py b/bliss/cached_dataset.py index 2294cb819..2547cae1d 100644 --- a/bliss/cached_dataset.py +++ b/bliss/cached_dataset.py @@ -90,7 +90,7 @@ def __call__(self, datum_in): class ChunkingSampler(Sampler): def __init__(self, dataset: Dataset) -> None: - super().__init__() + super().__init__(dataset) assert isinstance(dataset, ChunkingDataset), "dataset should be ChunkingDataset" self.dataset = dataset diff --git a/bliss/conf/base_config.yaml b/bliss/conf/base_config.yaml index 76eddc0c8..4bb136c55 100644 --- a/bliss/conf/base_config.yaml +++ b/bliss/conf/base_config.yaml @@ -152,6 +152,8 @@ metrics: flux_error: _target_: bliss.encoder.metrics.FluxError survey_bands: ${encoder.survey_bands} + bin_cutoffs: [23.9, 24.1, 24.5, 24.9, 25.6] + bin_type: "njymag" image_normalizers: psf: @@ -207,7 +209,10 @@ surveys: fields: # TODO: better arbitary name for fields/bricks? - run: 94 camcol: 1 - fields: [12] + fields: [12] # can also use ${range:start,stop,step,*exclude} + - run: 3900 + camcol: 6 + fields: [269] psf_config: pixel_scale: 0.396 psf_slen: 25 diff --git a/bliss/encoder/metrics.py b/bliss/encoder/metrics.py index 045b33328..5e99f0e27 100644 --- a/bliss/encoder/metrics.py +++ b/bliss/encoder/metrics.py @@ -164,7 +164,7 @@ def __init__( self.bin_type = bin_type self.exclude_last_bin = exclude_last_bin - assert self.bin_type in {"nmgy", "njymag"}, "invalid bin type" + assert self.bin_type in {"mag", "nmgy", "njymag"}, "invalid bin type" detection_metrics = [ "n_true_sources", @@ -343,14 +343,14 @@ def plot(self): axes[0].step( range(len(xlabels)), n_true_sources.tolist(), - label=f"# true sources {fig_tag}", + label=f"Number of true sources {fig_tag}", where="mid", color=c4, ) axes[0].step( range(len(xlabels)), n_true_matches.tolist(), - label=f"# BLISS matches {fig_tag}", + label=f"Number of BLISS matches {fig_tag}", ls="--", where="mid", color=c4, @@ -381,7 +381,7 @@ def __init__( self.bin_type = bin_type assert self.bin_cutoffs, "flux_bin_cutoffs can't be None or empty" - assert self.bin_type in {"nmgy", "njymag"}, "invalid bin type" + assert self.bin_type in {"mag", "nmgy", "njymag"}, "invalid bin type" n_bins = len(self.bin_cutoffs) + 1 @@ -501,59 +501,218 @@ def get_internal_states(self): class FluxError(Metric): - def __init__(self, survey_bands): + def __init__( + self, + survey_bands, + bin_cutoffs: list, + ref_band: int = 2, + bin_type: str = "mag", + exclude_last_bin: bool = False, + ): super().__init__() - self.survey_bands = survey_bands + self.survey_bands = survey_bands # list of band names (e.g. "r") + self.ref_band = ref_band + self.bin_cutoffs = bin_cutoffs + self.bin_type = bin_type + self.exclude_last_bin = exclude_last_bin + self.n_bins = len(self.bin_cutoffs) + 1 - fe_init = torch.zeros(len(self.survey_bands)) - self.add_state("flux_err", default=fe_init, dist_reduce_fx="sum") - self.add_state("n_matches", default=torch.zeros(1), dist_reduce_fx="sum") + self.add_state( + "flux_abs_err", + default=torch.zeros((len(self.survey_bands), self.n_bins)), # n_bins per band + dist_reduce_fx="sum", + ) + self.add_state( + "flux_pct_err", + default=torch.zeros((len(self.survey_bands), self.n_bins)), # n_bins per band + dist_reduce_fx="sum", + ) + self.add_state( + "flux_abs_pct_err", + default=torch.zeros((len(self.survey_bands), self.n_bins)), # n_bins per band + dist_reduce_fx="sum", + ) + self.add_state("n_matches", default=torch.zeros(self.n_bins), dist_reduce_fx="sum") def update(self, true_cat, est_cat, matching): + cutoffs = torch.tensor(self.bin_cutoffs, device=self.device) + true_bin_measures = true_cat.on_fluxes(self.bin_type)[:, :, self.ref_band].contiguous() + for i in range(true_cat.batch_size): tcat_matches, ecat_matches = matching[i] - self.n_matches += tcat_matches.size(0) - - true_flux = true_cat.on_nmgy[i, tcat_matches, :] - est_flux = est_cat.on_nmgy[i, ecat_matches, :] - self.flux_err += (true_flux - est_flux).abs().sum(dim=0) + n_true = true_cat["n_sources"][i].int().sum().item() + bin_measure = true_bin_measures[i, 0:n_true][tcat_matches].contiguous() + bins = torch.bucketize(bin_measure, cutoffs) + + true_flux = true_cat.on_fluxes("nmgy")[i, tcat_matches] + est_flux = est_cat.on_fluxes("nmgy")[i, ecat_matches] + + # Compute and update percent error per band + abs_err = (true_flux - est_flux).abs() + pct_err = (true_flux - est_flux) / true_flux + abs_pct_err = pct_err.abs() + for band in range(len(self.survey_bands)): # noqa: WPS518 + tmp = torch.zeros((self.n_bins,), dtype=torch.float, device=self.device) + tmp = tmp.scatter_add(0, bins.reshape(-1), abs_err[..., band].reshape(-1)) + self.flux_abs_err[band] += tmp + + tmp = torch.zeros((self.n_bins,), dtype=torch.float, device=self.device) + tmp = tmp.scatter_add(0, bins.reshape(-1), pct_err[..., band].reshape(-1)) + self.flux_pct_err[band] += tmp + + tmp = torch.zeros((self.n_bins,), dtype=torch.float, device=self.device) + tmp = tmp.scatter_add(0, bins.reshape(-1), abs_pct_err[..., band].reshape(-1)) + self.flux_abs_pct_err[band] += tmp.abs() + self.n_matches += bins.bincount(minlength=self.n_bins) def compute(self): - avg_flux_err = self.flux_err / self.n_matches + final_idx = -1 if self.exclude_last_bin else None + flux_abs_err = self.flux_abs_err[:, :final_idx] + flux_pct_err = self.flux_pct_err[:, :final_idx] + flux_abs_pct_err = self.flux_abs_pct_err[:, :final_idx] + n_matches = self.n_matches[:final_idx] + + # Compute final metrics + mae = flux_abs_err.sum(dim=1) / n_matches.sum() + binned_mae = flux_abs_err / n_matches + mpe = flux_pct_err.sum(dim=1) / n_matches.sum() + binned_mpe = flux_pct_err / n_matches + mape = flux_abs_pct_err.sum(dim=1) / n_matches.sum() + binned_mape = flux_abs_pct_err / n_matches + results = {} for i, band in enumerate(self.survey_bands): - results[f"flux_err_{band}_mae"] = avg_flux_err[i].item() + results[f"flux_err_{band}_mae"] = mae[i] + results[f"flux_err_{band}_mpe"] = mpe[i] + results[f"flux_err_{band}_mape"] = mape[i] + for j in range(binned_mpe.shape[1]): + results[f"flux_err_{band}_mae_bin_{j}"] = binned_mae[i, j] + results[f"flux_err_{band}_mpe_bin_{j}"] = binned_mpe[i, j] + results[f"flux_err_{band}_mape_bin_{j}"] = binned_mape[i, j] + return results class GalaxyShapeError(Metric): - GALSIM_NAMES = ["disk_frac", "beta_radians", "disk_q", "a_d", "bulge_q", "a_b"] + galaxy_params = [ + "galaxy_disk_frac", + "galaxy_beta_radians", + "galaxy_disk_q", + "galaxy_a_d", + "galaxy_bulge_q", + "galaxy_a_b", + ] + galaxy_param_to_idx = {param: i for i, param in enumerate(galaxy_params)} - def __init__(self): + def __init__( + self, + bin_cutoffs, + ref_band=2, + bin_type="mag", + exclude_last_bin=False, + ): super().__init__() - gpe_init = torch.zeros(len(self.GALSIM_NAMES)) - self.add_state("galsim_param_err", default=gpe_init, dist_reduce_fx="sum") - self.add_state("n_true_galaxies", default=torch.zeros(1), dist_reduce_fx="sum") + self.ref_band = ref_band + self.bin_cutoffs = bin_cutoffs + self.n_bins = len(self.bin_cutoffs) + 1 + self.bin_type = bin_type + self.exclude_last_bin = exclude_last_bin # used to ignore dim objects + + gpe_init = torch.zeros((len(self.galaxy_params), self.n_bins)) + self.add_state("galaxy_param_err", default=gpe_init, dist_reduce_fx="sum") + self.add_state("disk_hlr_err", torch.zeros(self.n_bins), dist_reduce_fx="sum") + self.add_state("bulge_hlr_err", torch.zeros(self.n_bins), dist_reduce_fx="sum") + self.add_state("n_true_galaxies", default=torch.zeros(self.n_bins), dist_reduce_fx="sum") def update(self, true_cat, est_cat, matching): + true_bin_meas = true_cat.on_fluxes(self.bin_type)[:, :, self.ref_band].contiguous() + cutoffs = torch.tensor(self.bin_cutoffs, device=self.device) + for i in range(true_cat.batch_size): tcat_matches, ecat_matches = matching[i] + n_true = true_cat["n_sources"][i].sum().item() + + is_gal = true_cat.galaxy_bools[i][tcat_matches][:, 0] + # Skip if no galaxies in this image + if (~is_gal).all(): + continue + true_matched_mags = true_bin_meas[i, 0:n_true][tcat_matches] + true_gal_mags = true_matched_mags[is_gal] + + # get magnitude bin for each matched galaxy + mag_bins = torch.bucketize(true_gal_mags, cutoffs) + self.n_true_galaxies += mag_bins.bincount(minlength=self.n_bins) + + true_gal_params = true_cat["galaxy_params"][i, tcat_matches][is_gal] - true_gal = true_cat.galaxy_bools[i][tcat_matches] - self.n_true_galaxies += true_gal.sum() + for j, name in enumerate(self.galaxy_params): + true_param = true_gal_params[:, j] + est_param = est_cat[name][i, ecat_matches][is_gal, 0] + abs_res = (true_param - est_param).abs() - # TODO: only compute error for *true* galaxies - true_gp = true_cat["galaxy_params"][i, tcat_matches, :] - est_gp = est_cat["galaxy_params"][i, ecat_matches, :] - gp_err = (true_gp - est_gp).abs().sum(dim=0) - # TODO: angle is a special case, need to wrap around pi (not 2pi due to symmetry) - self.galsim_param_err += gp_err + # Wrap angle around pi + if name == "galaxy_beta_radians": + abs_res = abs_res % torch.pi + + # Update bins + tmp = torch.zeros(self.n_bins, dtype=torch.float, device=self.device) + self.galaxy_param_err[j] += tmp.scatter_add(0, mag_bins, abs_res) + + # Compute HLRs for disk and bulge + true_a_d = true_gal_params[:, self.galaxy_param_to_idx["galaxy_a_d"]] + true_disk_q = true_gal_params[:, self.galaxy_param_to_idx["galaxy_disk_q"]] + true_b_d = true_a_d * true_disk_q + true_disk_hlr = torch.sqrt(true_a_d * true_b_d) + + est_a_d = est_cat["galaxy_a_d"][i, ecat_matches][is_gal, 0] + est_disk_q = est_cat["galaxy_disk_q"][i, ecat_matches][is_gal, 0] + est_b_d = est_a_d * est_disk_q + est_disk_hlr = torch.sqrt(est_a_d * est_b_d) + + true_a_b = true_gal_params[:, self.galaxy_param_to_idx["galaxy_a_b"]] + true_bulge_q = true_gal_params[:, self.galaxy_param_to_idx["galaxy_bulge_q"]] + true_b_b = true_a_b * true_bulge_q + true_bulge_hlr = torch.sqrt(true_a_b * true_b_b) + + est_a_b = est_cat["galaxy_a_b"][i, ecat_matches][is_gal, 0] + est_bulge_q = est_cat["galaxy_bulge_q"][i, ecat_matches][is_gal, 0] + est_b_b = est_a_b * est_bulge_q + est_bulge_hlr = torch.sqrt(est_a_b * est_b_b) + + abs_disk_hlr_res = (true_disk_hlr - est_disk_hlr).abs() + tmp = torch.zeros(self.n_bins, dtype=torch.float, device=self.device) + self.disk_hlr_err += tmp.scatter_add(0, mag_bins, abs_disk_hlr_res) + + abs_bulge_hlr_res = (true_bulge_hlr - est_bulge_hlr).abs() + tmp = torch.zeros(self.n_bins, dtype=torch.float, device=self.device) + self.bulge_hlr_err += tmp.scatter_add(0, mag_bins, abs_bulge_hlr_res) def compute(self): - avg_galsim_param_err = self.galsim_param_err / self.n_true_galaxies + final_idx = -1 if self.exclude_last_bin else None + galaxy_param_err = self.galaxy_param_err[:, :final_idx] + disk_hlr_err = self.disk_hlr_err[:final_idx] + bulge_hlr_err = self.bulge_hlr_err[:final_idx] + n_galaxies = self.n_true_galaxies[:final_idx] + + gal_param_mae = galaxy_param_err.sum(dim=1) / n_galaxies.sum() + binned_gal_param_mae = galaxy_param_err / n_galaxies + + disk_hlr_mae = disk_hlr_err.sum() / n_galaxies.sum() + binned_disk_hlr_mae = disk_hlr_err / n_galaxies + bulge_hlr_mae = bulge_hlr_err.sum() / n_galaxies.sum() + binned_bulge_hlr_mae = bulge_hlr_err / n_galaxies + results = {} - for i, gs_name in enumerate(self.GALSIM_NAMES): - results[f"{gs_name}_mae"] = avg_galsim_param_err[i] + for i, name in enumerate(self.galaxy_params): + results[f"{name}_mae"] = gal_param_mae[i] + for j in range(self.n_bins): + results[f"{name}_mae_bin_{j}"] = binned_gal_param_mae[i, j] + + results["galaxy_disk_hlr_mae"] = disk_hlr_mae + results["galaxy_bulge_hlr_mae"] = bulge_hlr_mae + for j in range(self.n_bins): + results[f"galaxy_disk_hlr_mae_bin_{j}"] = binned_disk_hlr_mae[j] + results[f"galaxy_bulge_hlr_mae_bin_{j}"] = binned_bulge_hlr_mae[j] return results diff --git a/case_studies/psf_variation/README.md b/case_studies/psf_variation/README.md new file mode 100644 index 000000000..58441c76e --- /dev/null +++ b/case_studies/psf_variation/README.md @@ -0,0 +1,18 @@ +Spatially Varying Backgrounds and PSFs Case Study +============================================ + +This case study contains code to reproduce the results and figures for the paper "Neural Posterior Estimation for Cataloging Astronomical Images with Spatially Varying Backgrounds and Point Spread Functions". + +To run the experiments, run +``` +sh run_experiments.sh +``` + +This will generate two synthetic datasets, train three models on the appropriate dataset, and run evaluation and generate figures similar to those seen in the paper. Note that the results may not be exact due to differences in the generated data. + +There are three main config files, one for each model: +- `conf/single_field.yaml` +- `conf/psf_unaware.yaml` +- `conf/psf_aware.yaml` + +You should modify the `paths/cached_data` in each file to point to where you would like to save the generated datasets. diff --git a/case_studies/psf_variation/conf/config.yaml b/case_studies/psf_variation/conf/config.yaml new file mode 100644 index 000000000..10be10845 --- /dev/null +++ b/case_studies/psf_variation/conf/config.yaml @@ -0,0 +1,304 @@ +--- +defaults: + - _self_ + - override /hydra/job_logging: stdout + +# completely disable hydra logging +# https://github.com/facebookresearch/hydra/issues/910 +hydra: + output_subdir: null + run: + dir: . + +paths: + sdss: /data/scratch/sdss + decals: /data/scratch/decals + des: /data/scratch/des + dc2: /data/scratch/dc2local + cached_data: null # TODO: override in other configs + output: ${oc.env:HOME}/bliss_output + +# this prior is sdss-like; the flux parameters were fit using SDSS catalogs +prior: + _target_: bliss.simulator.prior.CatalogPrior + survey_bands: ["u", "g", "r", "i", "z"] # SDSS available band filters + reference_band: 2 # SDSS r-band + star_color_model_path: ${simulator.survey.dir_path}/color_models/star_gmm_nmgy.pkl + gal_color_model_path: ${simulator.survey.dir_path}/color_models/gal_gmm_nmgy.pkl + n_tiles_h: 20 + n_tiles_w: 20 + tile_slen: 4 + batch_size: 64 + max_sources: 1 + mean_sources: 0.01 # 0.0025 is more realistic for SDSS but training takes more iterations + min_sources: 0 + prob_galaxy: 0.5144 + star_flux_exponent: 0.4689157382430609 + # star_flux_truncation: 613313.768995269 + star_flux_truncation: 1000 + star_flux_loc: -0.5534648001193676 + star_flux_scale: 1.1846035501201129 + galaxy_flux_exponent: 1.5609458661807678 + # galaxy_flux_truncation: 28790.449063519092 + galaxy_flux_truncation: 1000 + galaxy_flux_loc: -3.29383532288203 + galaxy_flux_scale: 3.924799999613338 + galaxy_a_concentration: 0.39330758068481686 + galaxy_a_loc: 0.8371888967872619 + galaxy_a_scale: 4.432725319432478 + galaxy_a_bd_ratio: 2.0 + +surveys: + sdss: + _target_: bliss.surveys.sdss.SloanDigitalSkySurvey + dir_path: ${paths.sdss} + fields: + - run: 94 + camcol: 1 + fields: ${range:12,482,20} + - run: 125 + camcol: 1 + fields: ${range:15,565,20,435} + - run: 752 + camcol: 1 + fields: ${range:15,685,20} + - run: 3900 + camcol: 6 + fields: ${range:16,596,20,76} + psf_config: + pixel_scale: 0.396 + psf_slen: 25 + align_to_band: null # we should set this to 2 (r-band) + load_image_data: false + +simulator: + _target_: bliss.simulator.simulated_dataset.SimulatedDataset + survey: ${surveys.sdss} + prior: ${prior} + n_batches: 128 + coadd_depth: 1 + num_workers: 32 + valid_n_batches: 10 # 256 + fix_validation_set: true + +cached_simulator: + _target_: bliss.cached_dataset.CachedSimulatedDataModule + batch_size: 64 + splits: 0:80/80:90/90:100 # train/val/test splits as percent ranges + num_workers: 8 + cached_data_path: ${paths.cached_data} + train_transforms: + - _target_: bliss.data_augmentation.RotateFlipTransform + nontrain_transforms: [] + +variational_factors: + - _target_: bliss.encoder.variational_dist.BernoulliFactor + name: n_sources + sample_rearrange: null + nll_rearrange: null + nll_gating: null + - _target_: bliss.encoder.variational_dist.TDBNFactor + name: locs + sample_rearrange: "b ht wt d -> b ht wt 1 d" + nll_rearrange: "b ht wt 1 d -> b ht wt d" + nll_gating: + _target_: bliss.encoder.variational_dist.SourcesGating + - _target_: bliss.encoder.variational_dist.BernoulliFactor + name: source_type + sample_rearrange: "b ht wt -> b ht wt 1 1" + nll_rearrange: "b ht wt 1 1 -> b ht wt" + nll_gating: + _target_: bliss.encoder.variational_dist.SourcesGating + - _target_: bliss.encoder.variational_dist.LogNormalFactor + name: star_fluxes + dim: 5 + sample_rearrange: "b ht wt d -> b ht wt 1 d" + nll_rearrange: "b ht wt 1 d -> b ht wt d" + nll_gating: + _target_: bliss.encoder.variational_dist.StarGating + - _target_: bliss.encoder.variational_dist.LogNormalFactor + name: galaxy_fluxes + dim: 5 + sample_rearrange: "b ht wt d -> b ht wt 1 d" + nll_rearrange: "b ht wt 1 d -> b ht wt d" + nll_gating: + _target_: bliss.encoder.variational_dist.GalaxyGating + - _target_: bliss.encoder.variational_dist.LogitNormalFactor + name: galaxy_disk_frac + sample_rearrange: "b ht wt d -> b ht wt 1 d" + nll_rearrange: "b ht wt 1 1 -> b ht wt 1" + nll_gating: + _target_: bliss.encoder.variational_dist.GalaxyGating + - _target_: bliss.encoder.variational_dist.LogitNormalFactor + name: galaxy_beta_radians + high: 3.1415926 + sample_rearrange: "b ht wt d -> b ht wt 1 d" + nll_rearrange: "b ht wt 1 1 -> b ht wt 1" + nll_gating: + _target_: bliss.encoder.variational_dist.GalaxyGating + - _target_: bliss.encoder.variational_dist.LogitNormalFactor + name: galaxy_disk_q + sample_rearrange: "b ht wt d -> b ht wt 1 d" + nll_rearrange: "b ht wt 1 1 -> b ht wt 1" + nll_gating: + _target_: bliss.encoder.variational_dist.GalaxyGating + - _target_: bliss.encoder.variational_dist.LogNormalFactor + name: galaxy_a_d + sample_rearrange: "b ht wt d -> b ht wt 1 d" + nll_rearrange: "b ht wt 1 1 -> b ht wt 1" + nll_gating: + _target_: bliss.encoder.variational_dist.GalaxyGating + - _target_: bliss.encoder.variational_dist.LogitNormalFactor + name: galaxy_bulge_q + sample_rearrange: "b ht wt d -> b ht wt 1 d" + nll_rearrange: "b ht wt 1 1 -> b ht wt 1" + nll_gating: + _target_: bliss.encoder.variational_dist.GalaxyGating + - _target_: bliss.encoder.variational_dist.LogNormalFactor + name: galaxy_a_b + sample_rearrange: "b ht wt d -> b ht wt 1 d" + nll_rearrange: "b ht wt 1 1 -> b ht wt 1" + nll_gating: + _target_: bliss.encoder.variational_dist.GalaxyGating + +psf_asinh_normalizers: + psf: + _target_: bliss.encoder.image_normalizer.PsfAsImage + num_psf_params: 6 # 6 for SDSS, 4 for DC2, 10 for DES + asinh: + _target_: bliss.encoder.image_normalizer.AsinhQuantileNormalizer + q: [0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99, 0.999, 0.9999, 0.99999] + +asinh_only_normalizer: + asinh: + _target_: bliss.encoder.image_normalizer.AsinhQuantileNormalizer + q: [0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99, 0.999, 0.9999, 0.99999] + +all_normalizers: + psf: + _target_: bliss.encoder.image_normalizer.PsfAsImage + num_psf_params: 6 # 6 for SDSS, 4 for DC2, 10 for DES + clahe: + _target_: bliss.encoder.image_normalizer.ClaheNormalizer + min_stdev: 200 + asinh: + _target_: bliss.encoder.image_normalizer.AsinhQuantileNormalizer + q: [0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99, 0.999, 0.9999, 0.99999] + +metrics: + detection_performance: + _target_: bliss.encoder.metrics.DetectionPerformance + bin_cutoffs: [17.777, 19.101, 19.781, 20.258, 20.625, 20.940, 21.227, 21.495, 21.746, 22.000] + ref_band: 2 + bin_type: mag + source_type_accuracy: + _target_: bliss.encoder.metrics.SourceTypeAccuracy + bin_cutoffs: [17.777, 19.101, 19.781, 20.258, 20.625, 20.940, 21.227, 21.495, 21.746, 22.000] + ref_band: 2 + bin_type: mag + flux_error: + _target_: bliss.encoder.metrics.FluxError + bin_cutoffs: [17.777, 19.101, 19.781, 20.258, 20.625, 20.940, 21.227, 21.495, 21.746, 22.000] + ref_band: 2 + survey_bands: ${encoder.survey_bands} + bin_type: mag + gal_shape_error: + _target_: bliss.encoder.metrics.GalaxyShapeError + bin_cutoffs: [17.777, 19.101, 19.781, 20.258, 20.625, 20.940, 21.227, 21.495, 21.746, 22.000] + ref_band: 2 + bin_type: mag + +encoder: + _target_: bliss.encoder.encoder.Encoder + survey_bands: ["u", "g", "r", "i", "z"] + reference_band: 2 # SDSS r-band + tile_slen: ${simulator.prior.tile_slen} + optimizer_params: + lr: 1e-3 + scheduler_params: + milestones: [32] + gamma: 0.1 + image_normalizers: null # TODO: override in other configs + var_dist: null # TODO: override in other configs + matcher: + _target_: bliss.encoder.metrics.CatalogMatcher + dist_slack: 1.0 + mag_slack: null + mag_band: 2 # SDSS r-band + mode_metrics: + _target_: torchmetrics.MetricCollection + _convert_: "partial" + metrics: ${metrics} + sample_metrics: null + sample_image_renders: + _target_: torchmetrics.MetricCollection + metrics: + - _target_: bliss.encoder.sample_image_renders.PlotSampleImages + frequency: 1 + restrict_batch: 0 + tiles_to_crop: 0 + tile_slen: ${simulator.prior.tile_slen} + use_double_detect: false + use_checkerboard: false + +####################################################################### +# things above matter only if they are referenced below + +mode: train + +generate: + n_image_files: 8 + n_batches_per_file: 16 + simulator: ${simulator} + cached_data_path: ${paths.cached_data} + file_prefix: dataset + store_full_catalog: false + +train: + trainer: + _target_: pytorch_lightning.Trainer + logger: + _target_: pytorch_lightning.loggers.TensorBoardLogger + save_dir: ${paths.output} + name: null # TODO: override in other configs + version: null # TODO: override in other configs + default_hp_metric: false + reload_dataloaders_every_n_epochs: 0 + check_val_every_n_epoch: 1 + log_every_n_steps: 10 # corresponds to n_batches + min_epochs: 1 + max_epochs: 50 + accelerator: "gpu" + devices: 1 + precision: 32-true + callbacks: + checkpointing: + _target_: pytorch_lightning.callbacks.ModelCheckpoint + filename: best_encoder + save_top_k: 1 + verbose: True + monitor: val/_loss + mode: min + save_on_train_epoch_end: False + auto_insert_metric_name: False + early_stopping: + _target_: pytorch_lightning.callbacks.early_stopping.EarlyStopping + monitor: val/_loss + mode: min + patience: 10 + data_source: ${cached_simulator} + encoder: ${encoder} + seed: 12345 + pretrained_weights: null + ckpt_path: null + matmul_precision: high + +predict: + dataset: ${surveys.sdss} + trainer: + _target_: pytorch_lightning.Trainer + accelerator: "gpu" + precision: ${train.trainer.precision} + encoder: ${encoder} + weight_save_path: null + device: "cuda:0" diff --git a/case_studies/psf_variation/conf/psf_aware.yaml b/case_studies/psf_variation/conf/psf_aware.yaml new file mode 100644 index 000000000..1aca9592c --- /dev/null +++ b/case_studies/psf_variation/conf/psf_aware.yaml @@ -0,0 +1,20 @@ +defaults: + - ./@_here_: config + - _self_ + - override hydra/job_logging: stdout + +paths: + cached_data: /data/scratch/aakash/multi_field + +encoder: + image_normalizers: ${psf_asinh_normalizers} + var_dist: + _target_: bliss.encoder.variational_dist.VariationalDist + tile_slen: ${encoder.tile_slen} + factors: ${variational_factors} + +train: + trainer: + logger: + name: PSF_MODELS + version: psf_aware diff --git a/case_studies/psf_variation/conf/psf_unaware.yaml b/case_studies/psf_variation/conf/psf_unaware.yaml new file mode 100644 index 000000000..2b17a8ff2 --- /dev/null +++ b/case_studies/psf_variation/conf/psf_unaware.yaml @@ -0,0 +1,20 @@ +defaults: + - ./@_here_: config + - _self_ + - override hydra/job_logging: stdout + +paths: + cached_data: /data/scratch/aakash/multi_field + +encoder: + image_normalizers: ${asinh_only_normalizer} + var_dist: + _target_: bliss.encoder.variational_dist.VariationalDist + tile_slen: ${encoder.tile_slen} + factors: ${variational_factors} + +train: + trainer: + logger: + name: PSF_MODELS + version: psf_unaware diff --git a/case_studies/psf_variation/conf/single_field.yaml b/case_studies/psf_variation/conf/single_field.yaml new file mode 100644 index 000000000..c38bbea08 --- /dev/null +++ b/case_studies/psf_variation/conf/single_field.yaml @@ -0,0 +1,27 @@ +defaults: + - ./@_here_: config + - _self_ + - override hydra/job_logging: stdout + +paths: + cached_data: /data/scratch/aakash/single_field + +surveys: + sdss: + fields: + - run: 94 + camcol: 1 + fields: [12] + +encoder: + image_normalizers: ${asinh_only_normalizer} + var_dist: + _target_: bliss.encoder.variational_dist.VariationalDist + tile_slen: ${encoder.tile_slen} + factors: ${variational_factors} + +train: + trainer: + logger: + name: PSF_MODELS + version: single_field \ No newline at end of file diff --git a/case_studies/psf_variation/config.yaml b/case_studies/psf_variation/config.yaml deleted file mode 100644 index 73032b9b1..000000000 --- a/case_studies/psf_variation/config.yaml +++ /dev/null @@ -1,29 +0,0 @@ ---- -defaults: - - ../../bliss/conf@_here_: base_config - - _self_ - - override hydra/job_logging: stdout - -encoder: - survey_bands: ["r"] - tile_slen: 4 - min_flux_for_loss: 1.59 # 22 magnitude (0.63 is 23 mag; 0.25 is 24 mag) - use_double_detect: false - image_normalizer: - include_original: true - concat_psf_params: true - num_psf_params: 6 # for SDSS, 4 for DC2 - log_transform_stdevs: [0] - use_clahe: false - -cached_simulator: - cached_data_path: /data/scratch/aakash/train_multi_field - batch_size: 64 - splits: 0:80/80:90/90:100 # train/val/test splits as percent ranges - -train: - trainer: - logger: - name: PSF_MODELS - version: multi_field_psf_param - precision: 32 diff --git a/case_studies/psf_variation/data/clear_psf.pt b/case_studies/psf_variation/data/clear_psf.pt deleted file mode 100644 index 022d8e8b6..000000000 --- a/case_studies/psf_variation/data/clear_psf.pt +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:ad6fe30747ab7b134abbeb848c9682fca646c65e418c6a7f66f81daa31d3952d -size 3151876 diff --git a/case_studies/psf_variation/data/cloudy_psf.pt b/case_studies/psf_variation/data/cloudy_psf.pt deleted file mode 100644 index ff810c82e..000000000 --- a/case_studies/psf_variation/data/cloudy_psf.pt +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:02ee3625bf879dd7c89e319d48f54354c6605f329af5c0580b703a63501997f7 -size 3151876 diff --git a/case_studies/psf_variation/evaluate_models.ipynb b/case_studies/psf_variation/evaluate_models.ipynb deleted file mode 100644 index a93da9898..000000000 --- a/case_studies/psf_variation/evaluate_models.ipynb +++ /dev/null @@ -1,258 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "%matplotlib inline" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "from os import environ\n", - "import torch\n", - "\n", - "from hydra import initialize, compose\n", - "from hydra.utils import instantiate\n", - "\n", - "import pandas as pd\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import seaborn as sns" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with initialize(config_path=\".\", version_base=None):\n", - " cfg = compose(\"config\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# evaluate on fields not trained on\n", - "test_path = \"/data/scratch/aakash/test_small\"\n", - "test_dataset = instantiate(cfg.cached_simulator, cached_data_path=test_path, splits=\"0:100/0:100/0:100\")\n", - "trainer = instantiate(cfg.train.trainer, logger=None)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "BASE_PATH = \"../../output/PSF_MODELS/single_field_base/checkpoints/epoch15.ckpt\"\n", - "UNAWARE_PATH = \"../../output/PSF_MODELS/multi_field_psf_unaware/checkpoints/epoch19.ckpt\"\n", - "PARAMS_ONLY_PATH = \"../../output/PSF_MODELS/multi_field_psf_params_only/checkpoints/epoch15.ckpt\"\n", - "\n", - "base_model = instantiate(cfg.encoder, image_normalizer={\"concat_psf_params\": False})\n", - "base_model.load_state_dict(torch.load(BASE_PATH)[\"state_dict\"])\n", - "base_model.eval();\n", - "\n", - "unaware_model = instantiate(cfg.encoder, image_normalizer={\"concat_psf_params\": False})\n", - "unaware_model.load_state_dict(torch.load(UNAWARE_PATH)[\"state_dict\"])\n", - "unaware_model.eval();\n", - "\n", - "params_only_model = instantiate(cfg.encoder, image_normalizer={\"concat_psf_params\": True})\n", - "params_only_model.load_state_dict(torch.load(PARAMS_ONLY_PATH)[\"state_dict\"])\n", - "params_only_model.eval();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Base model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "base_results = trainer.test(base_model, datamodule=test_dataset)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### PSF-unaware model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "unaware_results = trainer.test(unaware_model, datamodule=test_dataset)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Concat params only model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "params_results = trainer.test(params_only_model, datamodule=test_dataset)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Concatenate results into dataframe" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "models = {\n", - " \"base\": (base_model, base_results),\n", - " \"unaware\": (unaware_model, unaware_results),\n", - " \"params\": (params_only_model, params_results),\n", - "}\n", - "\n", - "# Results\n", - "keys = list(base_results[0].keys())\n", - "data = { model_name: [results[0][key] for key in keys] for model_name, (_, results) in models.items() }\n", - "data_flat = pd.DataFrame.from_dict(data, orient=\"index\", columns=[key.split(\"/\")[1] for key in keys]).reset_index()\n", - "data_flat= data_flat.rename(columns={\"index\": \"model\"})\n", - "data_melt = pd.melt(data_flat, id_vars=\"model\", value_vars=[key.split(\"/\")[1] for key in keys], var_name=\"metric\", value_name=\"value\")\n", - "data_melt.to_csv(\"psf_model_results.csv\")\n", - "\n", - "from IPython.display import HTML\n", - "HTML(data_flat.to_html())" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Plot Results" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def plot_results(data, ncols=3, title=None):\n", - " sns.set_style('ticks')\n", - " sns.set(font_scale=0.8)\n", - "\n", - " hue = \"bin\" if \"bin\" in data.columns else None\n", - "\n", - " g = sns.catplot(\n", - " data,\n", - " kind=\"bar\",\n", - " x=\"model\", y=\"value\", col=\"metric\", hue=hue,\n", - " sharex=False, sharey=False, col_wrap=ncols,\n", - " height=3, aspect=1.5,\n", - " palette=\"dark\", alpha=0.6,\n", - " legend=True\n", - " )\n", - " g.set_titles(template=\"{col_name}\")\n", - "\n", - " for ax in g.axes:\n", - " remove_ticks = False\n", - " heights = []\n", - " for container in ax.containers:\n", - " heights.extend([rect.get_height() for rect in container.patches])\n", - " median = np.median(heights)\n", - "\n", - " for container in ax.containers:\n", - " orig_heights = [rect.get_height() for rect in container]\n", - " # clip outlier heights\n", - " for rect in container.patches:\n", - " if rect.get_height() > np.abs(5 * median):\n", - " rect.set_height(np.abs(5 * median))\n", - " remove_ticks = True\n", - "\n", - " # add labels\n", - " labels = ax.bar_label(container, labels=[f\"{height:.3f}\" for height in orig_heights], fontsize=6)\n", - "\n", - " new_heights = []\n", - " for container in ax.containers:\n", - " new_heights.extend([rect.get_height() for rect in container.patches])\n", - " ax.set_ylim(min(0, min(new_heights) * 1.1 + 0.1), max(new_heights) * 1.1 + 0.1)\n", - "\n", - " ax.tick_params(axis=\"x\", labelsize=6)\n", - "\n", - " # remove y ticks and labels\n", - " if remove_ticks:\n", - " ax.set(yticklabels=[])\n", - " \n", - " if title:\n", - " fig = g.axes[-1].get_figure()\n", - " plt.suptitle(title)\n", - " fig.set_tight_layout(True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "keep_keys = [\"f1\", \"star_fluxes_r_mae\", \"gal_fluxes_r_mae\", \"disk_hlr_mae\", \"bulge_hlr_mae\"]\n", - "data_to_plot = data_melt[np.isin(data_melt[\"metric\"], keep_keys)]\n", - "\n", - "data_to_plot = data_to_plot.replace({\"base\": \"Single Field\", \"unaware\": \"PSF-unaware\", \"params\": \"PSF Encoding\"})\n", - "data_to_plot = data_to_plot.replace({\"f1\": \"F1-score\", \"galaxy_fluxes_r_mae\": \"Galaxy flux, median estimation error\", \"star_fluxes_r_mae\": \"Star flux, median estimation error\", \"disk_hlr_mae\": \"Disk half-light radius, median estimation error\", \"bulge_hlr_mae\": \"Bulge half-light radius, median estimation error\"})\n", - "plot_results(data_to_plot, ncols=2, title=\"Results on simulated data with variable PSFs\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "bliss-deblender-av05Bskt-py3.10", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/case_studies/psf_variation/evaluate_models.py b/case_studies/psf_variation/evaluate_models.py new file mode 100644 index 000000000..336e51a54 --- /dev/null +++ b/case_studies/psf_variation/evaluate_models.py @@ -0,0 +1,1309 @@ +# flake8: noqa +# pylint: skip-file +# Ignoring flake8/pylint for this file since this is just a plotting script + +import os +import argparse + +import torch +from torch.utils.data import DataLoader +import pandas as pd +import numpy as np +from sklearn.metrics import roc_curve, roc_auc_score + +import matplotlib +import matplotlib.pyplot as plt +import seaborn as sns + +from tqdm import tqdm + +from hydra import initialize, compose +from hydra.utils import instantiate + +from bliss.catalog import TileCatalog, convert_mag_to_nmgy + +# Set up plots, colors, global constants +sns.set_theme("paper") +matplotlib.rc("text", usetex=True) +plt.rc("font", family="serif") + +COLORS = [ + "#0072BD", # blue + "#D95319", # orange + "#EDB120", # yellow + "#7E2F8E", # purple + "#77AC30", # green + "#4DBEEE", # light blue + "#A2142F", # dark red +] + +BRIGHT_THRESHOLD = 17.777 +FAINT_THRESHOLD = 21.746 + +# Parse arguments +parser = argparse.ArgumentParser() +parser.add_argument("--run_eval", action="store_true") +parser.add_argument("--plot_eval", action="store_true") +parser.add_argument("--run_calibration", action="store_true") +parser.add_argument("--plot_calibration", action="store_true") +parser.add_argument("--data_path", type=str, required=True, help="Path to test data directory") + +args = parser.parse_args() + +# Load config, data, models +with initialize(config_path="./conf", version_base=None): + base_cfg = compose("config") + +data_path = args.data_path # "/data/scratch/aakash/multi_field" +dataset_name = data_path.split("/")[-1] # used for save directory names + +cached_dataset = instantiate( + base_cfg.cached_simulator, cached_data_path=data_path, splits="0:0/0:0/90:100" +) +cached_dataset.setup(stage="test") +calib_dataloader = cached_dataset.test_dataloader() +print(f"Test dataset size: {len(cached_dataset.test_dataset)}") + +trainer = instantiate(base_cfg.train.trainer, logger=None) + +models = { + "single_field": { + "ckpt_path": "/home/aakashdp/bliss_output/PSF_MODELS/single_field_with_gal_params/checkpoints/best_encoder.ckpt", + "config_path": "single_field.yaml", + "plot_config": {"name": "Single-field", "marker": "o", "color": COLORS[0]}, + }, + "psf_unaware": { + "ckpt_path": "/home/aakashdp/bliss_output/PSF_MODELS/psf_unaware_with_gal_params/checkpoints/best_encoder.ckpt", + "config_path": "psf_unaware.yaml", + "plot_config": {"name": "PSF-unaware", "marker": "s", "color": COLORS[1]}, + }, + "psf_aware": { + "ckpt_path": "/home/aakashdp/bliss_output/PSF_MODELS/psf_aware_with_gal_params/checkpoints/best_encoder.ckpt", + "config_path": "psf_aware.yaml", + "plot_config": {"name": "PSF-aware", "marker": "^", "color": COLORS[2]}, + }, +} + +rep_key = list(models.keys())[0] + +for model_name, model_info in models.items(): + with initialize(config_path="./conf", version_base=None): + cfg = compose(model_info["config_path"]) + + encoder = instantiate(cfg.encoder) + encoder.load_state_dict(torch.load(model_info["ckpt_path"], map_location="cpu")["state_dict"]) + encoder.eval() + model_info["encoder"] = encoder + model_info["config"] = cfg + + +def run_eval(): + """Compute metrics and standard deviations for each model.""" + # Compute metrics for each model + for model_name in models: + print(f"Evaluating {model_name} model...") + results = trainer.test( + models[model_name]["encoder"], datamodule=cached_dataset, verbose=False + ) + models[model_name]["results"] = results + + # Compute bootstrap variance + N_samples = 5 + orig_test_slice = cached_dataset.slices[2] + orig_start = orig_test_slice.start + orig_stop = orig_test_slice.stop + + data_for_var = { + model: {key: [] for key in models[rep_key]["results"][0].keys()} for model in models + } + + for i in tqdm(range(N_samples), desc=f"Bootstrapping {N_samples} samples"): + random_batch = np.random.randint(orig_start, orig_stop - 1) + cached_dataset.slices[2] = slice(random_batch, random_batch + 1) + cached_dataset.setup(stage="test") + + for model_name in models: + results = trainer.test( + models[model_name]["encoder"], dataloaders=cached_dataset, verbose=False + ) + + for key, val in results[0].items(): + data_for_var[model_name][key].append(val) + + cached_dataset.slices[2] = orig_test_slice + + stds = { + model: {f"{key}_std": np.nanstd(val) for key, val in data_for_var[model].items()} + for model in data_for_var + } + + # Concatenate results into dataframe + keys = list(models[rep_key]["results"][0].keys()) + keys.extend(stds[rep_key].keys()) + + data = {} + for model_name in models: + results = models[model_name]["results"][0] | stds[model_name] + model_vals = [results[key] for key in keys] + data[model_name] = model_vals + + data_flat = pd.DataFrame.from_dict( + data, orient="index", columns=[key.split("/")[-1] for key in keys] + ).reset_index() + data_flat = data_flat.rename(columns={"index": "model"}) + data_flat = data_flat.set_index("model") + + os.makedirs(f"data/{dataset_name}", exist_ok=True) + with open(f"data/{dataset_name}/metrics.pt", "wb") as f: + torch.save(data_flat.to_dict(), f) + + +def run_calibration(): + """Get posterior distributions for each model and save to disk.""" + # Precompute predicted distributions + pred_dists = {} + assert torch.cuda.is_available(), "ERROR: GPU not found." + device = "cuda" + + with torch.no_grad(): + for model_name, model in models.items(): + encoder = model["encoder"].to(device) + pred_dists[model_name] = [] + + for batch in tqdm( + calib_dataloader, desc=f"Getting calibration metrics for {model_name}..." + ): + batch_size, _n_bands, h, w = batch["images"].shape[0:4] + ht, wt = h // encoder.tile_slen, w // encoder.tile_slen + + target_cat = TileCatalog(batch["tile_catalog"]) + target_cat = target_cat.filter_by_flux(min_flux=1.59, band=2) + target_cat = target_cat.to(device) + + batch["images"] = batch["images"].to(device) + batch["psf_params"] = batch["psf_params"].to(device) + + # Get predicted params + x_features = encoder.get_features(batch) + patterns_to_use = (0,) # no checkerboard + mask_pattern = encoder.mask_patterns[patterns_to_use, ...][0] + mask = mask_pattern.repeat([batch_size, ht // 2, wt // 2]) + context1 = encoder.make_context(target_cat, mask) + x_cat1 = encoder.catalog_net(x_features, context1) + factor_param_zip = encoder.var_dist._factor_param_pairs(x_cat1) + + batch_dists = {} + for factor, params in factor_param_zip: + params = params.to("cpu") + factor_name = factor.name + dist = factor._get_dist(params) + batch_dists[factor_name] = dist + + pred_dists[model_name].append(batch_dists) + + os.makedirs("data", exist_ok=True) + torch.save(pred_dists, f"data/{dataset_name}/posterior_dists.pt") + + +def plot_metric(data, metric, metric_cfg): + """Plot results for a specific metric. + + Args: + data (dict): All metrics results computed by run_eval. + metric (str): Name of metric to plot. Should be a key in data. + metric_cfg (dict): Config info like the type of metric and ylabel to use for the plot. + """ + fig, ax = plt.subplots(figsize=(7.25, 5)) + + bins = [17.777, 19.101, 19.781, 20.258, 20.625, 20.940, 21.227, 21.495, 21.746, 22.000] + n_bins = len(bins) + xlabel = "r-band magnitude" + + # Plot each model + for i, name in enumerate(data[metric].keys()): + plot_config = models[name]["plot_config"] + + # Get metric values in each bin + binned_vals = np.array([data[f"{metric}_bin_{j}"][name] for j in range(n_bins)]) + + # Construct label for legend: "model_name [average value for metric]"" + model_name = plot_config["name"] + average_val = data[metric][name] + label = f"{model_name} [{average_val:.3f}]" + + # Plot line + ax.plot( + binned_vals, + c=plot_config["color"], + markeredgecolor="k", + markersize=6, + linewidth=1, + marker=plot_config["marker"], + label=label, + ) + + # Fill in +/- one standard deviation + if f"{metric}_bin_0_std" in data and name in data[f"{metric}_bin_0_std"]: + binned_stds = np.array([data[f"{metric}_bin_{j}_std"][name] for j in range(n_bins)]) + lower = binned_vals - binned_stds + upper = binned_vals + binned_stds + + ax.fill_between( + np.arange(len(lower)), lower, upper, color=plot_config["color"], alpha=0.2 + ) + + # Place bin values on xticks + xticklabels = [f"{bins[i+1]:.2f}" for i in range(n_bins - 1)] + xticklabels.insert(0, f"$<${bins[0]:.2f}") + ax.set_xticks(range(len(xticklabels)), xticklabels) + ax.tick_params(axis="both", which="major", labelsize="x-large") + + if metric_cfg["yaxis_in_percent"]: + ax.yaxis.set_major_formatter(matplotlib.ticker.PercentFormatter(xmax=1.0)) + + # Set axis labels and legend + ax.set_xlabel(xlabel, fontsize="xx-large") + ax.set_ylabel(metric_cfg["ylabel"], fontsize="xx-large") + ax.legend(fontsize="x-large") + + # Save figure + fig.tight_layout() + plt.savefig(f"plots/{dataset_name}/metrics/{metric}.pdf") + + +def compute_expected_sources(pred_dists, bins, cached_path): + """Compute expected value of number of sources and save to disk. + + Args: + pred_dists (dict): Dictionary of calibration results computed by run_calibration. + bins (list): List of magnitude bins. + cached_path (str): Where to save the resulting data. + + Returns: + dict: Dictionary of results. + """ + n_bins = len(bins) + sum_all = {name: torch.zeros(n_bins) for name in models} + all_count = {name: torch.zeros(n_bins) for name in models} + + sum_bright = {name: torch.zeros(n_bins) for name in models} + bright_count = {name: torch.zeros(n_bins) for name in models} + + sum_dim = {name: torch.zeros(n_bins) for name in models} + dim_count = {name: torch.zeros(n_bins) for name in models} + + for i, batch in enumerate(tqdm(calib_dataloader, desc="Computing expected number of sources")): + target_cat = TileCatalog(batch["tile_catalog"]) + target_cat = target_cat.filter_by_flux(min_flux=1.59, band=2) + + normal_mask = (target_cat.on_fluxes("mag")[..., 2] < 22).squeeze() + bright_mask = (target_cat.on_fluxes("mag")[..., 2] < BRIGHT_THRESHOLD).squeeze() + dim_mask = (target_cat.on_fluxes("mag")[..., 2] > FAINT_THRESHOLD).squeeze() * normal_mask + + true_sources = (target_cat["n_sources"].bool() * normal_mask).sum(dim=(1, 2)) + true_bright = (target_cat["n_sources"].bool() * bright_mask).sum(dim=(1, 2)) + true_dim = (target_cat["n_sources"].bool() * dim_mask).sum(dim=(1, 2)) + + binned_true = torch.bucketize(true_sources, bins) + binned_bright = torch.bucketize(true_bright, bins) + binned_dim = torch.bucketize(true_dim, bins) + + for name in models: + on_prob = pred_dists[name][i]["n_sources"].probs[..., 1] + + all_on = on_prob.sum(dim=(1, 2)) + bright_on = (on_prob * bright_mask).sum(dim=(1, 2)) + dim_on = (on_prob * dim_mask).sum(dim=(1, 2)) + + tmp = torch.zeros(n_bins, dtype=on_prob.dtype) + sum_all[name] += tmp.scatter_add(0, binned_true, all_on) + all_count[name] += binned_true.bincount(minlength=n_bins) + + tmp = torch.zeros(n_bins, dtype=on_prob.dtype) + sum_bright[name] += tmp.scatter_add(0, binned_bright, bright_on) + bright_count[name] += binned_bright.bincount(minlength=n_bins) + + tmp = torch.zeros(n_bins, dtype=on_prob.dtype) + sum_dim[name] += tmp.scatter_add(0, binned_dim, dim_on) + dim_count[name] += binned_dim.bincount(minlength=n_bins) + + source_data = {"mean_sources_per_bin": {}, "mean_bright_per_bin": {}, "mean_dim_per_bin": {}} + for name in models: + source_data["mean_sources_per_bin"][name] = sum_all[name] / all_count[name] + source_data["mean_bright_per_bin"][name] = sum_bright[name] / bright_count[name] + source_data["mean_dim_per_bin"][name] = sum_dim[name] / dim_count[name] + + torch.save(source_data, cached_path) + return source_data + + +def plot_expected_vs_predicted_sources(mean_sources_per_bin, mean_bright_per_bin, mean_dim_per_bin): + """Plot the expected value of number of sources vs true number of sources. + + Args: + mean_sources_per_bin (dict): Average number of sources in each bin. + mean_bright_per_bin (dict): Average number of bright sources in each bin. + mean_dim_per_bin (dict): Average number of dim sources in each bin. + """ + fig, ax = plt.subplots(1, 3, figsize=(5 * 3, 5)) + + bins = torch.arange(20) + n_bins = len(bins) + shared_params = {"markeredgecolor": "k", "markersize": 6, "linewidth": 1} + + # All + max_n_all = max( + [ + torch.argmax(torch.arange(n_bins) * ~mean_sources_per_bin[name].isnan()) + for name in models + ] + ) + ax[0].plot(torch.arange(max_n_all + 1), c="darkgray", linewidth=1, linestyle="dashed") + for name in mean_sources_per_bin: + plot_config = models[name]["plot_config"] + ax[0].plot( + mean_sources_per_bin[name], + c=plot_config["color"], + marker=plot_config["marker"], + label=plot_config["name"], + **shared_params, + ) + ax[0].legend(fontsize="x-large") + ax[0].set_title("All sources", fontsize="xx-large") + + # Bright + max_n_bright = max( + [torch.argmax(torch.arange(n_bins) * ~mean_bright_per_bin[name].isnan()) for name in models] + ) + ax[1].plot(torch.arange(max_n_bright + 1), c="darkgray", linewidth=1, linestyle="dashed") + for name in mean_bright_per_bin: + plot_config = models[name]["plot_config"] + ax[1].plot( + mean_bright_per_bin[name], + c=plot_config["color"], + marker=plot_config["marker"], + label=plot_config["name"], + **shared_params, + ) + ax[1].set_title(f"Bright (magnitude $<$ {BRIGHT_THRESHOLD:.2f})", fontsize="xx-large") + + # Faint + max_n_dim = max( + [torch.argmax(torch.arange(n_bins) * ~mean_dim_per_bin[name].isnan()) for name in models] + ) + ax[2].plot(torch.arange(max_n_dim + 1), c="darkgray", linewidth=1, linestyle="dashed") + for name in mean_dim_per_bin: + plot_config = models[name]["plot_config"] + ax[2].plot( + mean_dim_per_bin[name], + c=plot_config["color"], + marker=plot_config["marker"], + label=plot_config["name"], + **shared_params, + ) + ax[2].set_title(f"Faint (magnitude {FAINT_THRESHOLD:.2f}-22)", fontsize="xx-large") + + for a in ax: + a.tick_params(axis="both", which="major", labelsize="x-large") + a.set_xlabel("True sources", fontsize="xx-large") + ax[0].set_ylabel("Detected sources", fontsize="xx-large") + + fig.tight_layout() + plt.savefig(f"plots/{dataset_name}/calibration/true_vs_pred_sources_by_mag.pdf") + + +def compute_prob_flux_within_one_mag(pred_dists, bins, cached_path): + """Compute probability of estimated flux being within 1 magnitude of true flux. + + Args: + pred_dists (dict): Calibration results computed by run_calibration + bins (list): Magnitude bins + cached_path (str): Where to save the results + + Returns: + Dict: dictionary of results + """ + n_bins = len(bins) + sum_probs = {name: torch.zeros(n_bins) for name in models} + bin_count = {name: torch.zeros(n_bins) for name in models} + + for i, batch in enumerate(tqdm(calib_dataloader, desc="Prob flux within 1 of true mag")): + # Get target catalog and magnitudes, construct upper and lower bounds + target_cat = TileCatalog(batch["tile_catalog"]) + target_cat = target_cat.filter_by_flux(min_flux=1.59, band=2) + + target_mags = target_cat.on_fluxes("mag") + lb = convert_mag_to_nmgy(target_mags + 1).squeeze() + ub = convert_mag_to_nmgy(target_mags - 1).squeeze() + + target_on_mags = target_mags[target_cat.is_on_mask][:, 2].contiguous() + binned_target_on_mags = torch.bucketize(target_on_mags, bins) + + for name, model in models.items(): + # Get probabilities of flux within 1 magnitude of true + q_star_flux = pred_dists[name][i]["star_fluxes"].base_dist + q_gal_flux = pred_dists[name][i]["galaxy_fluxes"].base_dist + + star_flux_probs = q_star_flux.cdf(ub) - q_star_flux.cdf(lb) + gal_flux_probs = q_gal_flux.cdf(ub) - q_gal_flux.cdf(lb) + + pred_probs = torch.where( + target_cat.star_bools, star_flux_probs.unsqueeze(-2), gal_flux_probs.unsqueeze(-2) + ) + pred_probs = pred_probs[target_cat.is_on_mask][:, 2] + + probs_per_bin = torch.zeros(n_bins, dtype=pred_probs.dtype) + sum_probs[name] += probs_per_bin.scatter_add(0, binned_target_on_mags, pred_probs) + bin_count[name] += binned_target_on_mags.bincount(minlength=n_bins) + + binned_avg_flux_probs = {} + for name in models: + binned_avg_flux_probs[name] = sum_probs[name] / bin_count[name] + torch.save(binned_avg_flux_probs, cached_path) + return binned_avg_flux_probs + + +def plot_flux_within_one_mag(binned_avg_flux_probs, bins): + """Plot the probability of estimated flux being within 1 magnitude of true flux. + + Args: + binned_avg_flux_probs (dict): results from compute_avg_flux_probs + bins (list): Magnitude bins + """ + fig, ax = plt.subplots(figsize=(7.25, 5)) + for i, name in enumerate(models): + if name not in binned_avg_flux_probs: + continue + + plot_config = models[name]["plot_config"] + + binned_vals = binned_avg_flux_probs[name].detach() + label = plot_config["name"] + ax.plot( + binned_vals, + c=plot_config["color"], + markeredgecolor="k", + markersize=6, + linewidth=1, + marker=plot_config["marker"], + label=label, + ) + + xticklabels = [f"{bins[i+1]:.2f}" for i in range(len(bins) - 1)] + xticklabels.insert(0, f"$<${bins[0]:.2f}") + ax.set_xticks(range(len(xticklabels)), xticklabels) + + ax.tick_params(axis="both", which="major", labelsize="xx-large") + ax.set_xlabel("r-band Magnitude", fontsize="xx-large") + ax.set_ylabel("Pr($f_{\mathrm{pred}}$ within 1 magnitude)", fontsize="xx-large") + ax.legend(fontsize="x-large") + + fig.tight_layout() + plt.savefig(f"plots/{dataset_name}/calibration/prob_flux_within_1_mag.pdf") + + +def compute_prop_flux_in_interval(pred_dists, intervals, cached_path): + """Compute proportion of sources that fall in equal-tailed credible intervals. + + Args: + pred_dists (dict): Calibration results computed by run_calibration + intervals (list): List of credible interval sizes + cached_path (str): Where to save the results + + Returns: + Dict: dictionary of results + """ + sum_all_in_eti = {name: torch.zeros(len(intervals)) for name in models} + sum_bright_in_eti = {name: torch.zeros(len(intervals)) for name in models} + sum_dim_in_eti = {name: torch.zeros(len(intervals)) for name in models} + all_count = 0 + bright_count = 0 + dim_count = 0 + + for i, batch in enumerate(tqdm(calib_dataloader, desc="Computing prob in credible interval")): + target_cat = TileCatalog(batch["tile_catalog"]) + target_cat = target_cat.filter_by_flux(min_flux=1.59, band=2) + true_fluxes = target_cat.on_fluxes("nmgy")[..., 0, 2] + + normal_mask = (target_cat.on_fluxes("mag")[..., 2] < 22).squeeze() + bright_mask = (target_cat.on_fluxes("mag")[..., 2] < BRIGHT_THRESHOLD).squeeze() + dim_mask = (target_cat.on_fluxes("mag")[..., 2] > FAINT_THRESHOLD).squeeze() * normal_mask + + all_count += target_cat["n_sources"].sum() + bright_count += bright_mask.sum() + dim_count += dim_mask.sum() + + for name, model in models.items(): + q_star_flux = pred_dists[name][i]["star_fluxes"].base_dist + q_gal_flux = pred_dists[name][i]["galaxy_fluxes"].base_dist + + for j, interval in enumerate(intervals): + # construct equal tail intervals and determine if true flux is within ETI + tail_prob = (1 - interval) / 2 + star_lb = q_star_flux.icdf(tail_prob)[..., 2] + star_ub = q_star_flux.icdf(1 - tail_prob)[..., 2] + gal_lb = q_gal_flux.icdf(tail_prob)[..., 2] + gal_ub = q_gal_flux.icdf(1 - tail_prob)[..., 2] + + star_flux_in_eti = (true_fluxes >= star_lb) & (true_fluxes <= star_ub) + gal_flux_in_eti = (true_fluxes >= gal_lb) & (true_fluxes <= gal_ub) + + source_in_eti = torch.where( + target_cat.star_bools.squeeze(), star_flux_in_eti, gal_flux_in_eti + ) + + sum_all_in_eti[name][j] += (source_in_eti * target_cat.is_on_mask.squeeze()).sum() + sum_bright_in_eti[name][j] += (source_in_eti * bright_mask).sum() + sum_dim_in_eti[name][j] += (source_in_eti * dim_mask).sum() + + # Compute proportions and save data + prop_all_in_eti = {} + prop_bright_in_eti = {} + prop_dim_in_eti = {} + for name in models: + prop_all_in_eti[name] = sum_all_in_eti[name] / all_count + prop_bright_in_eti[name] = sum_bright_in_eti[name] / bright_count + prop_dim_in_eti[name] = sum_dim_in_eti[name] / dim_count + + data = { + "prop_all_in_eti": prop_all_in_eti, + "prop_bright_in_eti": prop_bright_in_eti, + "prop_dim_in_eti": prop_dim_in_eti, + } + torch.save(data, cached_path) + return data + + +def plot_prop_flux_in_interval(prop_all_in_eti, prop_bright_in_eti, prop_dim_in_eti, intervals): + """Plot proportion of sources that fall in credible interval. + + Args: + prop_all_in_eti (dict): From compute_prop_in_interval + prop_bright_in_eti (dict): From compute_prop_in_interval + prop_dim_in_eti (dict): From compute_prop_in_interval + intervals (List): List of credible intervals + """ + fig, ax = plt.subplots(1, 3, figsize=(5 * 3, 5)) + + ax[0].plot(intervals, intervals, color="darkgray", linewidth=1, linestyle="dashed") + ax[1].plot(intervals, intervals, color="darkgray", linewidth=1, linestyle="dashed") + ax[2].plot(intervals, intervals, color="darkgray", linewidth=1, linestyle="dashed") + + for name in models: + plot_config = models[name]["plot_config"] + kwargs = { + "color": plot_config["color"], + "marker": plot_config["marker"], + "markeredgecolor": "k", + "markersize": 6, + "linewidth": 1, + "label": plot_config["name"], + } + ax[0].plot(intervals, prop_all_in_eti[name], **kwargs) + ax[1].plot(intervals, prop_bright_in_eti[name], **kwargs) + ax[2].plot(intervals, prop_dim_in_eti[name], **kwargs) + + ax[0].legend(fontsize="x-large") + ax[0].set_title("All sources", fontsize="xx-large") + ax[1].set_title(f"Bright (magnitude $<$ {BRIGHT_THRESHOLD:.2f})", fontsize="xx-large") + ax[2].set_title(f"Faint (magnitude {FAINT_THRESHOLD:.2f}-22)", fontsize="xx-large") + + xticks = (intervals * 100).int().tolist() + for a in ax: + a.set_xticks(intervals, xticks) + a.tick_params(axis="both", which="major", labelsize="x-large") + a.set_xlabel("\% Credible Interval", fontsize="xx-large") + ax[0].set_ylabel("Proportion of true fluxes in interval", fontsize="xx-large") + + fig.tight_layout() + plt.savefig(f"plots/{dataset_name}/calibration/prop_flux_in_interval_by_mag.pdf") + + +def compute_avg_prob_true_source_type(pred_dists, bins, cached_path): + """Compute average probability of correct class by magnitude bin. + + Args: + pred_dists (dict): Calibration results computed by run_calibration + bins (list): Magnitude bins + cached_path (str): Where to save the results + + Returns: + Dict: dictionary of results + """ + n_bins = len(bins) + sum_probs = {name: torch.zeros(n_bins) for name in models} + bin_count = {name: torch.zeros(n_bins) for name in models} + + for i, batch in enumerate(tqdm(calib_dataloader, desc="Computing prob of true class")): + target_cat = TileCatalog(batch["tile_catalog"]) + target_cat = target_cat.filter_by_flux(min_flux=1.59, band=2) + + target_mags = target_cat.on_fluxes("mag") + target_on_mags = target_mags[target_cat.is_on_mask][:, 2].contiguous() + binned_target_on_mags = torch.bucketize(target_on_mags, bins) + + target_types = target_cat["source_type"].squeeze().bool() + + for name, model in models.items(): + gal_probs = pred_dists[name][i]["source_type"].probs[..., 1] + + true_type_prob = torch.where(target_types, gal_probs, 1 - gal_probs) + on_gal_probs = true_type_prob[target_cat.is_on_mask.squeeze()] + + probs_per_bin = torch.zeros(n_bins, dtype=true_type_prob.dtype) + sum_probs[name] += probs_per_bin.scatter_add(0, binned_target_on_mags, on_gal_probs) + bin_count[name] += binned_target_on_mags.bincount(minlength=n_bins) + + binned_source_type_probs = {} + for name in models: + binned_source_type_probs[name] = sum_probs[name] / bin_count[name] + torch.save(binned_source_type_probs, cached_path) + return binned_source_type_probs + + +def plot_prob_true_source_type(binned_source_type_probs, bins): + """Plot average probability of true source type. + + Args: + binned_source_type_probs (dict): From compute_source_type_probs + bins (list): Magnitude bins + """ + fig, ax = plt.subplots(figsize=(7.25, 5)) + + for name in binned_source_type_probs: + plot_config = models[name]["plot_config"] + + binned_vals = binned_source_type_probs[name].detach() + ax.plot( + binned_vals, + c=plot_config["color"], + markeredgecolor="k", + markersize=6, + linewidth=1, + marker=plot_config["marker"], + label=plot_config["name"], + ) + + xticklabels = [f"{bins[i+1]:.2f}" for i in range(len(bins) - 1)] + xticklabels.insert(0, f"$<${bins[0]:.2f}") + + ax.set_xticks(range(len(xticklabels)), xticklabels) + ax.tick_params(axis="both", which="major", labelsize="xx-large") + ax.set_xlabel("r-band Magnitude", fontsize="xx-large") + ax.set_ylabel("Probability of correct classification", fontsize="xx-large") + ax.legend(fontsize="x-large") + + fig.tight_layout() + plt.savefig(f"plots/{dataset_name}/calibration/prob_true_source_type.pdf") + + +def compute_classification_probs_by_threshold(pred_dists, thresholds, cached_path): + """Compute probability of correct classification by decision threshold. + + Args: + pred_dists (dict): Calibration results computed by run_calibration + thresholds (list): Classification thresholds + cached_path (str): Where to save the results + + Returns: + Dict: dictionary of results + """ + pred_all_gal = {name: torch.zeros(len(thresholds)) for name in models} + pred_bright_gal = {name: torch.zeros(len(thresholds)) for name in models} + pred_dim_gal = {name: torch.zeros(len(thresholds)) for name in models} + + pred_all_star = {name: torch.zeros(len(thresholds)) for name in models} + pred_bright_star = {name: torch.zeros(len(thresholds)) for name in models} + pred_dim_star = {name: torch.zeros(len(thresholds)) for name in models} + + true_all_gal = 0 + true_bright_gal = 0 + true_dim_gal = 0 + + true_all_star = 0 + true_bright_star = 0 + true_dim_star = 0 + + for i, batch in enumerate(tqdm(calib_dataloader, desc="Prob correct star/gal by threshold")): + target_cat = TileCatalog(batch["tile_catalog"]) + target_cat = target_cat.filter_by_flux(min_flux=1.59, band=2) + + normal_mask = (target_cat.on_fluxes("mag")[..., 2] < 22).squeeze() + bright_mask = (target_cat.on_fluxes("mag")[..., 2] < BRIGHT_THRESHOLD).squeeze() + dim_mask = (target_cat.on_fluxes("mag")[..., 2] > FAINT_THRESHOLD).squeeze() * normal_mask + + true_all_gal += target_cat.galaxy_bools.sum() + true_bright_gal += (target_cat.galaxy_bools.squeeze() * bright_mask).sum() + true_dim_gal += (target_cat.galaxy_bools.squeeze() * dim_mask).sum() + + true_all_star += target_cat.star_bools.sum() + true_bright_star += (target_cat.star_bools.squeeze() * bright_mask).sum() + true_dim_star += (target_cat.star_bools.squeeze() * dim_mask).sum() + + for name, model in models.items(): + gal_probs = pred_dists[name][i]["source_type"].probs[..., 1] + star_probs = 1 - gal_probs + + all_gal_probs = gal_probs * target_cat.galaxy_bools.squeeze() + bright_gal_probs = gal_probs * bright_mask * target_cat.galaxy_bools.squeeze() + dim_gal_probs = gal_probs * dim_mask * target_cat.galaxy_bools.squeeze() + + all_star_probs = star_probs * target_cat.star_bools.squeeze() + bright_star_probs = star_probs * bright_mask * target_cat.star_bools.squeeze() + dim_star_probs = star_probs * dim_mask * target_cat.star_bools.squeeze() + + for j, threshold in enumerate(thresholds): + pred_all_gal[name][j] += (all_gal_probs > threshold).sum() + pred_bright_gal[name][j] += (bright_gal_probs > threshold).sum() + pred_dim_gal[name][j] += (dim_gal_probs > threshold).sum() + + pred_all_star[name][j] += (all_star_probs > threshold).sum() + pred_bright_star[name][j] += (bright_star_probs > threshold).sum() + pred_dim_star[name][j] += (dim_star_probs > threshold).sum() + + prop_all_gal = {} + prop_bright_gal = {} + prop_dim_gal = {} + prop_all_star = {} + prop_bright_star = {} + prop_dim_star = {} + for name in models: + prop_all_gal[name] = pred_all_gal[name] / true_all_gal + prop_bright_gal[name] = pred_bright_gal[name] / true_bright_gal + prop_dim_gal[name] = pred_dim_gal[name] / true_dim_gal + prop_all_star[name] = pred_all_star[name] / true_all_star + prop_bright_star[name] = pred_bright_star[name] / true_bright_star + prop_dim_star[name] = pred_dim_star[name] / true_dim_star + + data = { + "prop_all_gal": prop_all_gal, + "prop_bright_gal": prop_bright_gal, + "prop_dim_gal": prop_dim_gal, + "prop_all_star": prop_all_star, + "prop_bright_star": prop_bright_star, + "prop_dim_star": prop_dim_star, + } + torch.save(data, cached_path) + return data + + +def plot_classification_by_threshold(prop_all, prop_bright, prop_dim, source_type, thresholds): + """Plot classification accuracy by threshold + + Args: + prop_all (dict): from compute_classification_probs_by_threshold + prop_bright (dict): from compute_classification_probs_by_threshold + prop_dim (dict): from compute_classification_probs_by_threshold + source_type (dict): from compute_classification_probs_by_threshold + thresholds (list): List of classification thresholds + """ + fig, ax = plt.subplots(1, 3, figsize=(5 * 3, 5)) + + ax[0].hlines([1], 0, 1, color="darkgray", linewidth=1, linestyle="dashed") + ax[1].hlines([1], 0, 1, color="darkgray", linewidth=1, linestyle="dashed") + ax[2].hlines([1], 0, 1, color="darkgray", linewidth=1, linestyle="dashed") + + for name in models: + plot_config = models[name]["plot_config"] + kwargs = { + "color": plot_config["color"], + "marker": plot_config["marker"], + "markeredgecolor": "k", + "markersize": 6, + "linewidth": 1, + "label": plot_config["name"], + } + ax[0].plot(thresholds, prop_all[name], **kwargs) + ax[1].plot(thresholds, prop_bright[name], **kwargs) + ax[2].plot(thresholds, prop_dim[name], **kwargs) + + ax[0].legend(fontsize="x-large") + ax[0].set_title("All sources", fontsize="xx-large") + ax[1].set_title(f"Bright (magnitude $<$ {BRIGHT_THRESHOLD:.2f})", fontsize="xx-large") + ax[2].set_title(f"Faint (magnitude {FAINT_THRESHOLD:.2f}-22)", fontsize="xx-large") + + for a in ax: + a.tick_params(axis="both", which="major", labelsize="x-large") + a.set_xlabel("Threshold", fontsize="xx-large") + ax[0].set_ylabel(f"Proportion of true {source_type} predicted", fontsize="xx-large") + + fig.tight_layout() + plt.savefig(f"plots/{dataset_name}/calibration/prop_{source_type}_threshold_by_mag.pdf") + + +def compute_source_type_roc_curve(pred_dists, cached_path): + """Compute the ROC curve for source type classification at different thresholds. + + Args: + pred_dists (dict): Calibration results computed by run_calibration + cached_path (str): Where to save the results + + Returns: + Dict: dictionary of results + """ + all_true, bright_true, dim_true = [], [], [] + all_pred = {name: [] for name in models} + bright_pred = {name: [] for name in models} + dim_pred = {name: [] for name in models} + + for i, batch in enumerate(tqdm(calib_dataloader, desc="Prob correct star/gal by threshold")): + target_cat = TileCatalog(batch["tile_catalog"]) + target_cat = target_cat.filter_by_flux(min_flux=1.59, band=2) + + normal_mask = (target_cat.on_fluxes("mag")[..., 2] < 22).squeeze() + bright_mask = (target_cat.on_fluxes("mag")[..., 2] < BRIGHT_THRESHOLD).squeeze() + dim_mask = (target_cat.on_fluxes("mag")[..., 2] > FAINT_THRESHOLD).squeeze() * normal_mask + on_mask = target_cat.is_on_mask.squeeze() + + true_source_type = target_cat["source_type"].squeeze() + + all_true.extend(true_source_type[on_mask * normal_mask].tolist()) + bright_true.extend(true_source_type[on_mask * bright_mask].tolist()) + dim_true.extend(true_source_type[on_mask * dim_mask].tolist()) + + for name, model in models.items(): + gal_probs = pred_dists[name][i]["source_type"].probs[..., 1] + + all_pred[name].extend(gal_probs[on_mask * normal_mask]) + bright_pred[name].extend(gal_probs[on_mask * bright_mask]) + dim_pred[name].extend(gal_probs[on_mask * dim_mask]) + + all_roc = {} + bright_roc = {} + dim_roc = {} + for name in models: + all_fpr, all_tpr, _ = roc_curve(all_true, all_pred[name]) + all_auc = roc_auc_score(all_true, all_pred[name]) + bright_fpr, bright_tpr, _ = roc_curve(bright_true, bright_pred[name]) + bright_auc = roc_auc_score(bright_true, bright_pred[name]) + dim_fpr, dim_tpr, _ = roc_curve(dim_true, dim_pred[name]) + dim_auc = roc_auc_score(dim_true, dim_pred[name]) + + all_roc[name] = {"fpr": all_fpr, "tpr": all_tpr, "auc": all_auc} + bright_roc[name] = {"fpr": bright_fpr, "tpr": bright_tpr, "auc": bright_auc} + dim_roc[name] = {"fpr": dim_fpr, "tpr": dim_tpr, "auc": dim_auc} + + data = { + "all_roc": all_roc, + "bright_roc": bright_roc, + "dim_roc": dim_roc, + } + torch.save(data, cached_path) + return data + + +def plot_source_type_roc_curve(all_roc, bright_roc, dim_roc): + """Plot ROC curve for source type classification. + + Args: + all_roc (dict): from compute_source_type_roc_curve + bright_roc (dict): from compute_source_type_roc_curve + dim_roc (dict): from compute_source_type_roc_curve + """ + fig, ax = plt.subplots(1, 3, figsize=(5 * 3, 5)) + + ax[0].plot(np.linspace(0, 1, 2), color="darkgray", linewidth=1, linestyle="dashed") + ax[1].plot(np.linspace(0, 1, 2), color="darkgray", linewidth=1, linestyle="dashed") + ax[2].plot(np.linspace(0, 1, 2), color="darkgray", linewidth=1, linestyle="dashed") + + for name in models: + plot_config = models[name]["plot_config"] + kwargs = { + "color": plot_config["color"], + "linewidth": 1.5, + } + ax[0].plot( + all_roc[name]["fpr"], + all_roc[name]["tpr"], + label=f"{plot_config['name']} [{all_roc[name]['auc']:.3f}]", + **kwargs, + ) + ax[1].plot( + bright_roc[name]["fpr"], + bright_roc[name]["tpr"], + label=f"{plot_config['name']} [{bright_roc[name]['auc']:.3f}]", + **kwargs, + ) + ax[2].plot( + dim_roc[name]["fpr"], + dim_roc[name]["tpr"], + label=f"{plot_config['name']} [{dim_roc[name]['auc']:.3f}]", + **kwargs, + ) + + ax[0].legend(fontsize="x-large") + ax[1].legend(fontsize="x-large") + ax[2].legend(fontsize="x-large") + + ax[0].set_title("All sources", fontsize="xx-large") + ax[1].set_title(f"Bright (magnitude $<$ {BRIGHT_THRESHOLD:.2f})", fontsize="xx-large") + ax[2].set_title(f"Faint (magnitude {FAINT_THRESHOLD:.2f}-22)", fontsize="xx-large") + + for a in ax: + a.tick_params(axis="both", which="major", labelsize="x-large") + a.set_xlabel("Galaxy False Positive Rate", fontsize="xx-large") + ax[0].set_ylabel(f"Galaxy True Positive Rate", fontsize="xx-large") + + fig.tight_layout() + plt.savefig(f"plots/{dataset_name}/calibration/source_type_roc.pdf") + + +def compute_ci_width(pred_dists, bins, cached_path): + """Compute average flux credible interval width and average flux standard deviation. + + Args: + pred_dists (dict): Calibration results computed by run_calibration + bins (list): Magnitude bins + cached_path (str): Where to save the results + + Returns: + Dict: dictionary of results + """ + interval = 0.95 + tail_prob = torch.tensor((1 - interval) / 2) + n_bins = len(bins) + + ci_width = {name: torch.zeros(n_bins) for name in models} + ci_width_prop = {name: torch.zeros(n_bins) for name in models} + flux_scale = {name: torch.zeros(n_bins) for name in models} + bin_count = torch.zeros(n_bins) + + for i, batch in enumerate(tqdm(calib_dataloader, desc="CI width")): + target_cat = TileCatalog(batch["tile_catalog"]) + target_cat = target_cat.filter_by_flux(min_flux=1.59, band=2) + + target_on_mags = target_cat.on_fluxes("mag")[target_cat.is_on_mask][:, 2].contiguous() + binned_target_on_mags = torch.bucketize(target_on_mags, bins) + + bin_count += binned_target_on_mags.bincount(minlength=n_bins) + + for name, model in models.items(): + q_star_flux = pred_dists[name][i]["star_fluxes"].base_dist + q_gal_flux = pred_dists[name][i]["galaxy_fluxes"].base_dist + + # construct equal tail intervals + star_intervals = q_star_flux.icdf(1 - tail_prob) - q_star_flux.icdf(tail_prob) + gal_intervals = q_gal_flux.icdf(1 - tail_prob) - q_gal_flux.icdf(tail_prob) + + # Compute CI width for true sources based on source type + width = torch.where( + target_cat.star_bools, star_intervals.unsqueeze(-2), gal_intervals.unsqueeze(-2) + ) + width = width[target_cat.is_on_mask][:, 2] + width[width == torch.inf] = 0 # temp hack to not get inf + + tmp = torch.zeros(n_bins, dtype=width.dtype) + ci_width[name] += tmp.scatter_add(0, binned_target_on_mags, width) + + tmp = torch.zeros(n_bins, dtype=width.dtype) + ci_width_prop[name] += tmp.scatter_add( + 0, + binned_target_on_mags, + width / target_cat.on_fluxes("nmgy")[target_cat.is_on_mask][:, 2], + ) + + # Get flux scale for true sources based on source type + scale = torch.where( + target_cat.star_bools, + q_star_flux.scale.unsqueeze(-2), + q_gal_flux.scale.unsqueeze(-2), + ) + scale = scale[target_cat.is_on_mask][:, 2] + + scales_per_bin = torch.zeros(n_bins, dtype=scale.dtype) + flux_scale[name] += scales_per_bin.scatter_add(0, binned_target_on_mags, scale) + + for name in models: + ci_width[name] = ci_width[name] / bin_count + ci_width_prop[name] = ci_width_prop[name] / bin_count + flux_scale[name] = flux_scale[name] / bin_count + data = {"ci_width": ci_width, "ci_width_prop": ci_width_prop, "flux_scale": flux_scale} + torch.save(data, cached_path) + return data + + +def plot_ci_width_data(data, plot_name, cfg_dict, bins): + """Plot credible interval width and average standard deviation. + + Args: + data (dict): from compute_ci_width + plot_name (str): name of metric being plotted (for saving) + cfg_dict (dict): dictionary of plot config + bins (list): Magnitude bins + """ + fig, ax = plt.subplots(figsize=(7.25, 5)) + + for name in data: + plot_config = models[name]["plot_config"] + + binned_vals = data[name].detach() + kwargs = { + "color": plot_config["color"], + "marker": plot_config["marker"], + "markeredgecolor": "k", + "markersize": 6, + "linewidth": 1, + "label": plot_config["name"], + } + ax.plot(binned_vals, **kwargs) + + xticklabels = [f"{bins[i+1]:.2f}" for i in range(len(bins) - 1)] + xticklabels.insert(0, f"$<${bins[0]:.2f}") + ax.set_xticks(range(len(xticklabels)), xticklabels) + + ax.tick_params(axis="both", which="major", labelsize="xx-large") + ax.set_xlabel("r-band Magnitude", fontsize="xx-large") + ax.set_ylabel(cfg_dict["ylabel"], fontsize=cfg_dict["ylabel_size"]) + ax.legend(fontsize="x-large") + + fig.tight_layout() + plt.savefig(f"plots/{dataset_name}/calibration/{plot_name}_by_mag.pdf") + + +################################################# +# Evaluate models +################################################# +if args.run_eval: + print("Computing metrics for eval") + run_eval() + + +################################################# +# Plot eval results +################################################# +if args.plot_eval: + print("Plotting eval results") + # Load saved results + cached_path = f"data/{dataset_name}/metrics.pt" + assert os.path.exists( + cached_path + ), f"ERROR: could not find cached metrics at {cached_path}. Try running with the --run_eval flag." + data = torch.load(f"data/{dataset_name}/metrics.pt") + + # Choose metrics to plot and specify labels, marker, and color for each model + metrics_to_plot = { + "detection_precision": { + "ylabel": "Precision", + "metric_class": "detection_performance", + "yaxis_in_percent": False, + }, + "detection_recall": { + "ylabel": "Recall", + "metric_class": "detection_performance", + "yaxis_in_percent": False, + }, + "detection_f1": { + "ylabel": "F1-Score", + "metric_class": "detection_performance", + "yaxis_in_percent": False, + }, + "classification_acc": { + "ylabel": "Classification Accuracy", + "metric_class": "source_type_accuracy", + "yaxis_in_percent": False, + }, + "flux_err_r_mpe": { + "ylabel": "r-band Flux Mean \% Error", + "metric_class": "flux_error", + "yaxis_in_percent": True, + }, + "flux_err_r_mape": { + "ylabel": "r-band Flux Mean Abosolute \% Error", + "metric_class": "flux_error", + "yaxis_in_percent": True, + }, + "galaxy_disk_frac_mae": { + "ylabel": "Disk fraction of flux MAE", + "metric_class": "gal_shape_error", + "yaxis_in_percent": False, + }, + "galaxy_beta_radians_mae": { + "ylabel": "Angle MAE", + "metric_class": "gal_shape_error", + "yaxis_in_percent": False, + }, + "galaxy_disk_q_mae": { + "ylabel": "Disk minor:major ratio MAE", + "metric_class": "gal_shape_error", + "yaxis_in_percent": False, + }, + "galaxy_a_d_mae": { + "ylabel": "Disk major axis MAE", + "metric_class": "gal_shape_error", + "yaxis_in_percent": False, + }, + "galaxy_bulge_q_mae": { + "ylabel": "Bulge minor:major ratio MAE", + "metric_class": "gal_shape_error", + "yaxis_in_percent": False, + }, + "galaxy_a_b_mae": { + "ylabel": "Bulge major axis MAE", + "metric_class": "gal_shape_error", + "yaxis_in_percent": False, + }, + "galaxy_disk_hlr_mae": { + "ylabel": "Disk HLR MAE", + "metric_class": "gal_shape_error", + "yaxis_in_percent": False, + }, + "galaxy_bulge_hlr_mae": { + "ylabel": "Bulge HLR MAE", + "metric_class": "gal_shape_error", + "yaxis_in_percent": False, + }, + } + + # Plot! + os.makedirs(f"plots/{dataset_name}/metrics", exist_ok=True) + for metric, metric_cfg in metrics_to_plot.items(): + plot_metric(data, metric, metric_cfg) + + +################################################# +# Run calibration +################################################# +if args.run_calibration: + print("Computing calibration results") + run_calibration() + + +################################################# +# Plot calibration results +################################################# +if args.plot_calibration: + print("Plotting calibration results") + cached_path = ( + f"/home/aakashdp/bliss/case_studies/psf_variation/data/{dataset_name}/posterior_dists.pt" + ) + assert os.path.exists( + cached_path + ), f"ERROR: could not find cached calibration data at {cached_path}. Try running with the --run_calibration flag." + pred_dists = torch.load(cached_path) + os.makedirs(f"plots/{dataset_name}/calibration", exist_ok=True) + + ### Expected number of sources + bins = torch.arange(20) + cached_path = f"data/{dataset_name}/true_vs_pred_sources.pt" + if os.path.exists(cached_path): + print(f"Loading cached data from {cached_path}") + source_data = torch.load(cached_path) + else: + source_data = compute_expected_sources(pred_dists, bins, cached_path) + + mean_sources_per_bin = source_data["mean_sources_per_bin"] + mean_bright_per_bin = source_data["mean_bright_per_bin"] + mean_dim_per_bin = source_data["mean_dim_per_bin"] + + plot_expected_vs_predicted_sources(mean_sources_per_bin, mean_bright_per_bin, mean_dim_per_bin) + + ### Probability predicted magnitude is within x of true magnitude + bins = torch.tensor( + [17.777, 19.101, 19.781, 20.258, 20.625, 20.940, 21.227, 21.495, 21.746, 22.000] + ) + cached_path = f"data/{dataset_name}/prob_flux_within_one_mag.pt" + if os.path.exists(cached_path): # Load cached data if exists + print(f"Loading cached data from {cached_path}") + binned_avg_flux_probs = torch.load(cached_path) + else: + binned_avg_flux_probs = compute_prob_flux_within_one_mag(pred_dists, bins, cached_path) + + plot_flux_within_one_mag(binned_avg_flux_probs, bins) + + ### Proportion of true fluxes in credible interval + intervals = torch.linspace(0.5, 1, 11) + cached_path = f"data/{dataset_name}/prop_flux_in_interval.pt" + if os.path.exists(cached_path): + print(f"Loading cached data from {cached_path}") + data = torch.load(cached_path) + else: + data = compute_prop_flux_in_interval(pred_dists, intervals, cached_path) + + prop_all_in_eti = data["prop_all_in_eti"] + prop_bright_in_eti = data["prop_bright_in_eti"] + prop_dim_in_eti = data["prop_dim_in_eti"] + + plot_prop_flux_in_interval(prop_all_in_eti, prop_bright_in_eti, prop_dim_in_eti, intervals) + + ### Prob source type by magnitude + bins = torch.tensor( + [17.777, 19.101, 19.781, 20.258, 20.625, 20.940, 21.227, 21.495, 21.746, 22.000] + ) + cached_path = f"data/{dataset_name}/true_source_type_probs.pt" + if os.path.exists(cached_path): + print(f"Loading cached data from {cached_path}") + binned_source_type_probs = torch.load(cached_path) + else: + binned_source_type_probs = compute_avg_prob_true_source_type(pred_dists, bins, cached_path) + + plot_prob_true_source_type(binned_source_type_probs, bins) + + ### Prob correct galaxy / star by threshold + thresholds = torch.linspace(0, 1, 11) + thresholds[-1] = 0.99 # 1 is trivial - 0.99 gives more information + cached_path = f"data/{dataset_name}/source_type_classification_by_threshold.pt" + if os.path.exists(cached_path): + print(f"Loading cached data from {cached_path}") + data = torch.load(cached_path) + else: + data = compute_classification_probs_by_threshold(pred_dists, thresholds, cached_path) + + prop_all_gal = data["prop_all_gal"] + prop_bright_gal = data["prop_bright_gal"] + prop_dim_gal = data["prop_dim_gal"] + prop_all_star = data["prop_all_star"] + prop_bright_star = data["prop_bright_star"] + prop_dim_star = data["prop_dim_star"] + + # Plot gal + plot_classification_by_threshold( + prop_all_gal, prop_bright_gal, prop_dim_gal, "galaxies", thresholds + ) + + # Plot star + plot_classification_by_threshold( + prop_all_star, prop_bright_star, prop_dim_star, "stars", thresholds + ) + + ## Source type classification ROC curve + cached_path = ( + f"/home/aakashdp/bliss/case_studies/psf_variation/data/{dataset_name}/source_type_roc.pt" + ) + if os.path.exists(cached_path): + print(f"Loading cached data from {cached_path}") + data = torch.load(cached_path) + else: + data = compute_source_type_roc_curve(pred_dists, cached_path) + + all_roc = data["all_roc"] + bright_roc = data["bright_roc"] + dim_roc = data["dim_roc"] + + plot_source_type_roc_curve(all_roc, bright_roc, dim_roc) + + ### CI width / standard deviation vs magnitude + bins = torch.tensor( + [17.777, 19.101, 19.781, 20.258, 20.625, 20.940, 21.227, 21.495, 21.746, 22.000] + ) + cached_path = f"data/{dataset_name}/ci_width_and_flux_scale.pt" + if os.path.exists(cached_path): + print(f"Loading cached data from {cached_path}") + data = torch.load(cached_path) + else: + data = compute_ci_width(pred_dists, bins, cached_path) + + ci_width = data["ci_width"] + ci_width_prop = data["ci_width_prop"] + flux_scale = data["flux_scale"] + + metrics_to_plot = { + "ci_width": {"ylabel": "Average width of 95\% CI (nmgy)", "ylabel_size": "xx-large"}, + "ci_width_prop": { + "ylabel": "Average width of 95\% CI / true flux", + "ylabel_size": "xx-large", + }, + "flux_scale": {"ylabel": "Average $\\sigma$ for predicted flux", "ylabel_size": "xx-large"}, + } + + for metric, cfg_dict in metrics_to_plot.items(): + plot_ci_width_data(eval(metric), metric, cfg_dict, bins) diff --git a/case_studies/psf_variation/find_param_bins.ipynb b/case_studies/psf_variation/find_param_bins.ipynb new file mode 100644 index 000000000..b413257db --- /dev/null +++ b/case_studies/psf_variation/find_param_bins.ipynb @@ -0,0 +1,158 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from os import environ\n", + "\n", + "import torch\n", + "from torch.utils.data import DataLoader\n", + "import pandas as pd\n", + "import numpy as np\n", + "\n", + "import matplotlib\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "\n", + "from tqdm.notebook import tqdm\n", + "\n", + "from hydra import initialize, compose\n", + "from hydra.utils import instantiate\n", + "\n", + "from bliss.catalog import TileCatalog, convert_mag_to_nmgy, convert_nmgy_to_mag" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "with initialize(config_path=\"./conf\", version_base=None):\n", + " cfg = compose(\"config\")\n", + "\n", + "data_path = \"/data/scratch/aakash/multi_field\"\n", + "cached_dataset = instantiate(cfg.cached_simulator, cached_data_path=data_path, splits=\"0:0/0:0/90:100\")\n", + "cached_dataset.setup(stage=\"test\")\n", + "test_size = len(cached_dataset.test_dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b570547e31ef478ea6ac2e65cd37b6ce", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Getting fluxes: 0%| | 0/416 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.hist(all_mags, bins=50);" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "bliss-toolkit-av05Bskt-py3.10", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/case_studies/psf_variation/psf_variation_demo.ipynb b/case_studies/psf_variation/psf_variation_demo.ipynb deleted file mode 100644 index 38226ab12..000000000 --- a/case_studies/psf_variation/psf_variation_demo.ipynb +++ /dev/null @@ -1,325 +0,0 @@ -{ - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Demo: Encoding PSF variation with BLISS" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This notebook demonstrates the capability of BLISS to encode and use PSF information to make more accurate predictions compared to a PSF-unaware model." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%matplotlib inline\n", - "%load_ext autoreload\n", - "%autoreload 2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from os import environ\n", - "from pathlib import Path\n", - "import re\n", - "\n", - "import torch\n", - "from torch.utils.data import DataLoader\n", - "import matplotlib.pyplot as plt\n", - "import pandas as pd\n", - "\n", - "import hydra\n", - "\n", - "from bliss.catalog import TileCatalog\n", - "from bliss.encoder.encoder import Encoder\n", - "from bliss.simulator.decoder import Decoder\n", - "from bliss.encoder.metrics import CatalogMetrics" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# load config\n", - "with hydra.initialize(config_path=\".\", version_base=None):\n", - " cfg = hydra.compose(\"config\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Load models and the data" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We will compare two models:\n", - "1. A \"PSF-unaware\" model that has been trained on a single PSF (and has no information about the PSF during inference), and\n", - "2. A \"PSF-aware\" model that has been trained on images from different PSFs, and uses the PSF parameters during inference." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# psf-unaware model (trained on single PSF)\n", - "base_model: Encoder = hydra.utils.instantiate(\n", - " cfg.encoder,\n", - " image_normalizer={\"concat_psf_params\": False}\n", - ")\n", - "base_model.load_state_dict(torch.load(Path(cfg.paths.pretrained_models) / \"single_band_base.pt\"))\n", - "base_model.eval();\n", - "\n", - "# psf-aware model (trained with varying PSFs)\n", - "psf_model: Encoder = hydra.utils.instantiate(cfg.encoder)\n", - "psf_model.load_state_dict(torch.load(Path(cfg.paths.pretrained_models) / \"psf_aware.pt\"));\n", - "psf_model.eval();" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We will evaluate these models on two images. Both are generated from the same catalog; they have sources at the same locations with the same parameters, but the only difference is one has a more localized PSF than the other, which is more spread out. We can think of these as being the PSFs on a clear night vs. a cloudy night, but looking at the same region of the sky." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# load data\n", - "with open(\"data/clear_psf.pt\", \"rb\") as f:\n", - " clear_data = torch.load(f)\n", - "\n", - "with open(\"data/cloudy_psf.pt\", \"rb\") as f:\n", - " cloudy_data = torch.load(f)\n", - "\n", - "dataloader = DataLoader(clear_data + cloudy_data)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now let's plot the images and PSFs to see the differences:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# construct PSF images from saved params\n", - "decoder: Decoder = hydra.utils.instantiate(cfg.simulator.decoder)\n", - "clear_psf = decoder._get_psf(clear_data[0][\"psf_params\"])\n", - "cloudy_psf = decoder._get_psf(cloudy_data[0][\"psf_params\"])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# plot images and PSFs\n", - "fig, ax = plt.subplots(2, 2, figsize=(6, 6))\n", - "ax[0, 0].set_title(\"'Clear' PSF\", size=\"large\")\n", - "ax[0, 1].set_title(\"'Cloudy' PSF\", size=\"large\")\n", - "ax[0, 0].set_ylabel(\"Image\", rotation=90, size=\"large\")\n", - "ax[1, 0].set_ylabel(\"PSF\", rotation=90, size=\"large\")\n", - "\n", - "ax[0, 0].imshow(clear_data[0][\"images\"][2]) # plot r band\n", - "ax[0, 1].imshow(cloudy_data[0][\"images\"][2])\n", - "\n", - "ax[1, 0].imshow(clear_psf[0].original.image.array)\n", - "ax[1, 1].imshow(cloudy_psf[0].original.image.array)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Make predictions on both models" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# instantiate the Trainer\n", - "trainer = hydra.utils.instantiate(cfg.train.trainer, accelerator=\"cpu\", logger=None)\n", - "\n", - "# make predictions\n", - "base_results = trainer.predict(base_model, dataloader)\n", - "psf_results = trainer.predict(psf_model, dataloader)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Plot predictions and true locations" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# get true catalog and crop to account for tiles cropped by model\n", - "true_catalog = TileCatalog(cfg.encoder.tile_slen, clear_data[0][\"tile_catalog\"])\n", - "true_catalog = true_catalog.symmetric_crop(base_model.tiles_to_crop).to_full_catalog(cfg.encoder.tile_slen)\n", - "true_locs = true_catalog[\"plocs\"]\n", - "\n", - "# get predicted locations\n", - "px_to_crop = 0 #base_model.tiles_to_crop * base_model.tile_slen\n", - "est_locs = [\n", - " [\n", - " base_results[0][\"est_cat\"].to_full_catalog(cfg.encoder.tile_slen)[\"plocs\"] + px_to_crop,\n", - " base_results[1][\"est_cat\"].to_full_catalog(cfg.encoder.tile_slen)[\"plocs\"] + px_to_crop\n", - " ],\n", - " [\n", - " psf_results[0][\"est_cat\"].to_full_catalog(cfg.encoder.tile_slen)[\"plocs\"] + px_to_crop,\n", - " psf_results[1][\"est_cat\"].to_full_catalog(cfg.encoder.tile_slen)[\"plocs\"] + px_to_crop\n", - " ]\n", - "]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Plot true and estimated predictions for both models and both images\n", - "fig, ax = plt.subplots(2, 2, figsize=(8, 8))\n", - "\n", - "true_plot_args = dict(color=\"r\", s=25, marker=\"X\", edgecolors=\"k\", linewidth=0.5, label=\"True\")\n", - "est_plot_args = dict(color=\"y\", s=25, marker=\"P\", edgecolors=\"k\", linewidth=0.5, label=\"Predicted\")\n", - "\n", - "for row in range(2):\n", - " for col in range(2):\n", - " image = clear_data[0][\"images\"] if col == 0 else cloudy_data[0][\"images\"]\n", - " image = image[2, 4:76, 4:76]\n", - " \n", - " ax[row, col].imshow(image, origin=\"lower\", extent=(0, 72, 0, 72))\n", - " ax[row, col].scatter(true_locs[0, :, 1], true_locs[0, :, 0], **true_plot_args)\n", - " ax[row, col].scatter(est_locs[row][col][0, :, 1], est_locs[row][col][0, :, 0], **est_plot_args)\n", - "\n", - "# add row and column labels\n", - "ax[0, 0].set_ylabel(\"Base model\", rotation=90, size=12);\n", - "ax[1, 0].set_ylabel(\"PSF-aware model\", rotation=90, size=12);\n", - "ax[0, 0].set_title(\"'Clear' PSF\", size=12);\n", - "ax[0, 1].set_title(\"'Cloudy' PSF\", size=12);\n", - "\n", - "# add legend\n", - "handles, labels = ax[0, 0].get_legend_handles_labels()\n", - "fig.legend(handles, labels, loc='lower center', ncol=2, bbox_to_anchor=(0, -0.03, 1, 1), fontsize=10)\n", - "fig.tight_layout()" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The PSF-unaware model performs well on the \"clear\" image, but much more poorly on the \"cloudy\" image; it misses the brightest source altogether, and some of the other predicted sources are offset from the true location. The PSF-aware model, on the other hand, is able to deal with the difference in PSF and performs equally well on the \"cloudy\" image as the \"clear\" image." - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Finally, let's take a look at some metrics for the PSF-unaware model compared to the PSF-aware model." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "metrics = CatalogMetrics(mode=\"matching\")\n", - "keys = [\"f1\", \"avg_distance_keep\", \"(disk|bulge)_.*_mae\", \".*_r_mae\"]\n", - "\n", - "results = {\n", - " \"psf-unaware\": { key: val for key, val in metrics(true_catalog, base_results[1][\"est_cat\"].to_full_catalog(4)).items()},\n", - " \"psf-aware\": { key: val for key, val in metrics(true_catalog, psf_results[1][\"est_cat\"].to_full_catalog(4)).items()}\n", - "}\n", - "\n", - "df = pd.DataFrame.from_dict(results, orient=\"index\")\n", - "df[[col for col in df.columns if any([re.match(pattern, col) for pattern in keys])]]" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We see that across a majority of the relevant metrics, the psf-aware model outperforms the psf-unaware model, demonstrating the effectiveness of using PSF information in the encoder." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "bliss-deblender-av05Bskt-py3.10", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/case_studies/psf_variation/run_experiments.sh b/case_studies/psf_variation/run_experiments.sh new file mode 100755 index 000000000..df71e27e2 --- /dev/null +++ b/case_studies/psf_variation/run_experiments.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# This script runs the experiments from the BLISS spatially-variant PSF paper. + +# Generate new data +bash ~/bliss/scripts/generate_data_in_parallel.sh -n 32 -cp ~/bliss/case_studies/psf_variation/conf -cn psf_aware +bash ~/bliss/scripts/generate_data_in_parallel.sh -n 32 -cp ~/bliss/case_studies/psf_variation/conf -cn single_field + +# Train single-field model +bliss -cp ~/bliss/case_studies/psf_variation/conf -cn single_field mode=train + +# Train psf-unaware model +bliss -cp ~/bliss/case_studies/psf_variation/conf -cn psf_unaware mode=train + +# Train psf-aware model +bliss -cp ~/bliss/case_studies/psf_variation/conf -cn psf_aware mode=train + +# Run evaluation and generate figures +python evaluate_models.py --run_eval --plot_eval --run_calibration --plot_calibration --data_path=/data/scratch/aakash/multi_field +python evaluate_models.py --run_eval --plot_eval --run_calibration --plot_calibration --data_path=/data/scratch/aakash/single_field \ No newline at end of file diff --git a/case_studies/psf_variation/train_models.sh b/case_studies/psf_variation/train_models.sh deleted file mode 100755 index 519c219c6..000000000 --- a/case_studies/psf_variation/train_models.sh +++ /dev/null @@ -1,27 +0,0 @@ -#!/bin/bash - -# Runs the experiments for the BLISS PSF variation paper. - -# train base model -bliss -cp ~/bliss/case_studies/psf_variation -cn config mode=train \ - train.trainer.logger.name=PSF_MODELS \ - train.trainer.logger.version=base_model \ - encoder.image_normalizer.concat_psf_params=False \ - cached_simulator.cached_data_path=/data/scratch/aakash/train_single_94-1-12 - -# train psf-unaware model -bliss -cp ~/bliss/case_studies/psf_variation -cn config mode=train \ - train.trainer.logger.name=PSF_MODELS \ - train.trainer.logger.version=multi_field_psf_unaware \ - encoder.image_normalizer.concat_psf_params=False \ - cached_simulator.cached_data_path=/data/scratch/aakash/train_multi_field - -# train concat params only model -bliss -cp ~/bliss/case_studies/psf_variation -cn config mode=train \ - train.trainer.logger.name=PSF_MODELS \ - train.trainer.logger.version=multi_field_psf_params_only \ - encoder.image_normalizer.concat_psf_params=True \ - cached_simulator.cached_data_path=/data/scratch/aakash/train_multi_field - -# run the evaluation notebook -jupyter nbconvert --execute evaluate_models.ipynb --to html diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 13c0a9b7d..769a76a68 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -8,6 +8,7 @@ CatalogMatcher, DetectionPerformance, FluxError, + GalaxyShapeError, SourceTypeAccuracy, ) from bliss.surveys.des import TractorFullCatalog @@ -46,7 +47,12 @@ def test_metrics(self): "source_type": est_source_type, "star_fluxes": torch.ones(2, 2, 5), "galaxy_fluxes": torch.ones(2, 2, 5), - "galaxy_params": torch.ones(2, 2, 6), + "galaxy_disk_frac": torch.ones(2, 2, 1), + "galaxy_beta_radians": torch.ones(2, 2, 1), + "galaxy_disk_q": torch.ones(2, 2, 1), + "galaxy_a_d": torch.ones(2, 2, 1), + "galaxy_bulge_q": torch.ones(2, 2, 1), + "galaxy_a_b": torch.ones(2, 2, 1), } est_params = FullCatalog(slen, slen, d_est) @@ -62,6 +68,10 @@ def test_metrics(self): acc_results = acc_metrics(true_params, est_params, matching) assert np.isclose(acc_results["classification_acc"], 1 / 2) + gal_shape_metrics = GalaxyShapeError(bin_cutoffs=[200, 400, 600, 800, 1000]) + gal_shape_results = gal_shape_metrics(true_params, est_params, matching) + assert gal_shape_results["galaxy_disk_hlr_mae"] == 0 + def test_no_sources(self): """Tests that metrics work when there are no true or estimated sources.""" true_locs = torch.tensor( @@ -117,7 +127,7 @@ def test_self_agreement(self, tile_catalog): acc_results = acc_metrics(full_catalog, full_catalog, matching) assert acc_results["classification_acc"] == 1 - flux_metrics = FluxError("ugriz") + flux_metrics = FluxError("ugriz", bin_cutoffs=[200, 400, 600, 800, 1000]) flux_results = flux_metrics(full_catalog, full_catalog, matching) assert flux_results["flux_err_r_mae"] == 0 diff --git a/tests/testing_config.yaml b/tests/testing_config.yaml index 98a0ef51a..1b2587c75 100644 --- a/tests/testing_config.yaml +++ b/tests/testing_config.yaml @@ -78,6 +78,8 @@ encoder: flux_error: _target_: bliss.encoder.metrics.FluxError survey_bands: ${encoder.survey_bands} + bin_cutoffs: [23.9, 24.1, 24.5, 24.9, 25.6] + bin_type: "njymag" #########################################