From c9222584f5a798dad3b48e0518183e261fae38e1 Mon Sep 17 00:00:00 2001 From: Jeffrey Regier Date: Wed, 11 Sep 2024 14:24:03 -0400 Subject: [PATCH] shared fluxes --- bliss/catalog.py | 18 +-- bliss/conf/base_config.yaml | 11 +- bliss/encoder/metrics.py | 4 +- bliss/simulator/decoder.py | 10 +- bliss/simulator/prior.py | 6 +- bliss/surveys/dc2.py | 10 +- bliss/surveys/sdss.py | 9 +- .../dc2_cataloging/cataloging_exp.ipynb | 91 +++--------- case_studies/dc2_cataloging/train_config.yaml | 11 +- .../dc2_cataloging/utils/load_lsst.py | 3 +- .../dc2_cataloging/utils/lsst_predictor.py | 3 +- .../dc2_cataloging/utils/variational_dist.py | 6 +- .../dc2_multidetection/train_config.yaml | 11 +- .../data_generation/data_gen.py | 3 +- case_studies/psf_variation/conf/config.yaml | 11 +- case_studies/psf_variation/evaluate_models.py | 57 +++----- .../redshift/evaluation/dc2_plot.ipynb | 3 +- .../redshift_from_img/full_train_config.yaml | 10 +- case_studies/spatial_tiling/m2/config.yaml | 10 +- case_studies/spatial_tiling/m2/m2.ipynb | 6 +- .../spatial_tiling/sdss_demo/config.yaml | 24 ++-- .../spatial_tiling/sdss_demo/sdss_field.py | 59 ++++++-- .../toy_example.ipynb | 29 ++-- .../spatial_tiling/toy_example/config.yaml | 132 ------------------ tests/data/base_config_trained_encoder.pt | 4 +- tests/data/multiband_data/dataset_0.pt | 4 +- tests/data/multiband_data/dataset_1.pt | 3 - tests/data/sdss_preds.pt | 4 +- tests/data/test_image/dataset_0.pt | 4 +- tests/test_catalogs.py | 37 +++-- tests/test_dc2.py | 17 +-- tests/test_main.py | 2 +- tests/test_metrics.py | 12 +- tests/test_simulator.py | 15 +- 34 files changed, 184 insertions(+), 455 deletions(-) rename case_studies/spatial_tiling/{toy_example => sdss_demo}/toy_example.ipynb (96%) delete mode 100644 case_studies/spatial_tiling/toy_example/config.yaml delete mode 100644 tests/data/multiband_data/dataset_1.pt diff --git a/bliss/catalog.py b/bliss/catalog.py index 2ffb8207c..084f7bf5a 100644 --- a/bliss/catalog.py +++ b/bliss/catalog.py @@ -171,13 +171,7 @@ def galaxy_bools(self) -> Tensor: @property def on_fluxes(self): - # TODO: a tile catalog should store fluxes rather than star_fluxes and galaxy_fluxes - # because that's all that's needed to render the source - if "galaxy_fluxes" not in self: - fluxes = self["star_fluxes"] - else: - fluxes = torch.where(self.galaxy_bools, self["galaxy_fluxes"], self["star_fluxes"]) - return torch.where(self.is_on_mask[..., None], fluxes, torch.zeros_like(fluxes)) + return self.is_on_mask.unsqueeze(-1) * self["fluxes"] def on_magnitudes(self, zero_point) -> Tensor: return convert_flux_to_magnitude(self.on_fluxes, zero_point) @@ -470,12 +464,7 @@ def galaxy_bools(self) -> Tensor: @property def on_fluxes(self) -> Tensor: - # ideally we'd always store fluxes rather than star_fluxes and galaxy_fluxes - if "galaxy_fluxes" not in self: - fluxes = self["star_fluxes"] - else: - fluxes = torch.where(self.galaxy_bools, self["galaxy_fluxes"], self["star_fluxes"]) - return torch.where(self.is_on_mask[..., None], fluxes, torch.zeros_like(fluxes)) + return self.is_on_mask.unsqueeze(-1) * self["fluxes"] def on_magnitudes(self, zero_point) -> Tensor: return convert_flux_to_magnitude(self.on_fluxes, zero_point) @@ -571,9 +560,6 @@ def to_tile_catalog( """ assert max_sources_per_tile <= torch.iinfo(inter_int_type).max - # TODO: a FullCatalog only needs to "know" its height and width to convert itself to a - # TileCatalog. So those parameters should be passed on conversion, not initialization. - # initialization # n_tiles_h = math.ceil(self.height / tile_slen) n_tiles_w = math.ceil(self.width / tile_slen) diff --git a/bliss/conf/base_config.yaml b/bliss/conf/base_config.yaml index 2b523e1bd..f13dde34d 100644 --- a/bliss/conf/base_config.yaml +++ b/bliss/conf/base_config.yaml @@ -82,19 +82,12 @@ variational_factors: nll_gating: _target_: bliss.encoder.variational_dist.SourcesGating - _target_: bliss.encoder.variational_dist.LogNormalFactor - name: star_fluxes + name: fluxes dim: 5 sample_rearrange: b ht wt d -> b ht wt 1 d nll_rearrange: b ht wt 1 d -> b ht wt d nll_gating: - _target_: bliss.encoder.variational_dist.StarGating - - _target_: bliss.encoder.variational_dist.LogNormalFactor - name: galaxy_fluxes - dim: 5 - sample_rearrange: b ht wt d -> b ht wt 1 d - nll_rearrange: b ht wt 1 d -> b ht wt d - nll_gating: - _target_: bliss.encoder.variational_dist.GalaxyGating + _target_: bliss.encoder.variational_dist.SourcesGating - _target_: bliss.encoder.variational_dist.LogitNormalFactor name: galaxy_disk_frac sample_rearrange: b ht wt d -> b ht wt 1 d diff --git a/bliss/encoder/metrics.py b/bliss/encoder/metrics.py index 89b10bb0c..f02b7e821 100644 --- a/bliss/encoder/metrics.py +++ b/bliss/encoder/metrics.py @@ -76,8 +76,8 @@ def get_cur_postfix(self): class NullFilter(CatFilter): def get_cur_filter_bools(self, true_cat, est_cat): - true_filter_bools = torch.ones_like(true_cat.star_bools.squeeze(2)).bool() - est_filter_bools = torch.ones_like(est_cat.star_bools.squeeze(2)).bool() + true_filter_bools = torch.ones_like(true_cat.is_on_mask) + est_filter_bools = torch.ones_like(est_cat.is_on_mask) return true_filter_bools, est_filter_bools diff --git a/bliss/simulator/decoder.py b/bliss/simulator/decoder.py index 4b7c37145..d9a5f2cbb 100644 --- a/bliss/simulator/decoder.py +++ b/bliss/simulator/decoder.py @@ -48,7 +48,7 @@ def render_star(self, psf, band, source_params): Returns: GSObject: a galsim representation of the rendered star convolved with the PSF """ - return psf[band].withFlux(source_params["star_fluxes"][band].item()) + return psf[band].withFlux(source_params["fluxes"][band].item()) def render_galaxy(self, psf, band, source_params): """Render a galaxy with given params and PSF. @@ -62,9 +62,9 @@ def render_galaxy(self, psf, band, source_params): Returns: GSObject: a galsim representation of the rendered galaxy convolved with the PSF """ - disk_flux = source_params["galaxy_fluxes"][band] * source_params["galaxy_disk_frac"] + disk_flux = source_params["fluxes"][band] * source_params["galaxy_disk_frac"] bulge_frac = 1 - source_params["galaxy_disk_frac"] - bulge_flux = source_params["galaxy_fluxes"][band] * bulge_frac + bulge_flux = source_params["fluxes"][band] * bulge_frac beta = source_params["galaxy_beta_radians"] * galsim.radians components = [] @@ -153,9 +153,7 @@ def render_image(self, tile_cat): # use the specified flux_calibration ratios indexed by image_id avg_nelec_conv = np.mean(frame["flux_calibration"], axis=-1) if n_sources > 0: - full_cat["star_fluxes"] *= rearrange(avg_nelec_conv, "bands -> 1 1 bands") - if "galaxy_fluxes" in tile_cat: - full_cat["galaxy_fluxes"] *= avg_nelec_conv + full_cat["fluxes"] *= rearrange(avg_nelec_conv, "bands -> 1 1 bands") # generate random WCS shifts as manual image dithering via unaligning WCS if self.with_dither: diff --git a/bliss/simulator/prior.py b/bliss/simulator/prior.py index dead8d409..cc8f6f05a 100644 --- a/bliss/simulator/prior.py +++ b/bliss/simulator/prior.py @@ -102,8 +102,10 @@ def sample(self) -> TileCatalog: d["locs"] = self._sample_locs() d["n_sources"] = self._sample_n_sources() d["source_type"] = self._sample_source_type() - d["star_fluxes"] = self._sample_fluxes(self.gmm_star, self.star_flux) - d["galaxy_fluxes"] = self._sample_fluxes(self.gmm_gal, self.galaxy_flux) + + star_fluxes = self._sample_fluxes(self.gmm_star, self.star_flux) + galaxy_fluxes = self._sample_fluxes(self.gmm_gal, self.galaxy_flux) + d["fluxes"] = torch.where(d["source_type"], galaxy_fluxes, star_fluxes) return TileCatalog(d) diff --git a/bliss/surveys/dc2.py b/bliss/surveys/dc2.py index 3bca21614..bc4b31ab5 100644 --- a/bliss/surveys/dc2.py +++ b/bliss/surveys/dc2.py @@ -259,8 +259,7 @@ def generate_cached_data(self, image_index): "locs", "n_sources", "source_type", - "galaxy_fluxes", - "star_fluxes", + "fluxes", "redshifts", "blendedness", "shear", @@ -333,9 +332,7 @@ def from_file(cls, cat_path, wcs, height, width, **kwargs): source_type = torch.from_numpy(catalog["truth_type"].values) # we ignore the supernova source_type = torch.where(source_type == 2, SourceType.STAR, SourceType.GALAXY) - flux, psf_params = cls.get_bands_flux_and_psf(kwargs["bands"], catalog) - star_fluxes = flux - galaxy_fluxes = flux + fluxes, psf_params = cls.get_bands_flux_and_psf(kwargs["bands"], catalog) blendedness = torch.from_numpy(catalog["blendedness"].values) shear1 = torch.from_numpy(catalog["shear_1"].values) shear2 = torch.from_numpy(catalog["shear_2"].values) @@ -352,8 +349,7 @@ def from_file(cls, cat_path, wcs, height, width, **kwargs): "source_type": source_type.view(1, ori_len, 1), "plocs": plocs.view(1, ori_len, 2), "redshifts": redshifts.view(1, ori_len, 1), - "galaxy_fluxes": galaxy_fluxes.view(1, ori_len, kwargs["n_bands"]), - "star_fluxes": star_fluxes.view(1, ori_len, kwargs["n_bands"]), + "fluxes": fluxes.view(1, ori_len, kwargs["n_bands"]), "blendedness": blendedness.view(1, ori_len, 1), "shear": shear.view(1, ori_len, 2), "ellipticity": ellipticity.view(1, ori_len, 2), diff --git a/bliss/surveys/sdss.py b/bliss/surveys/sdss.py index d7f22d17a..91e472059 100644 --- a/bliss/surveys/sdss.py +++ b/bliss/surveys/sdss.py @@ -366,14 +366,11 @@ def from_file(cls, cat_path, wcs: WCS, height, width): decs = column_to_tensor(table, "dec") galaxy_bools = (objc_type == 3) & (thing_id != -1) star_bools = (objc_type == 6) & (thing_id != -1) - star_fluxes = column_to_tensor(table, "psfflux") * star_bools.reshape(-1, 1) - star_mags = column_to_tensor(table, "psfmag") * star_bools.reshape(-1, 1) - galaxy_fluxes = column_to_tensor(table, "cmodelflux") * galaxy_bools.reshape(-1, 1) - galaxy_mags = column_to_tensor(table, "cmodelmag") * galaxy_bools.reshape(-1, 1) # Combine light source parameters to one tensor + star_fluxes = column_to_tensor(table, "psfflux") * star_bools.reshape(-1, 1) + galaxy_fluxes = column_to_tensor(table, "cmodelflux") * galaxy_bools.reshape(-1, 1) fluxes = star_fluxes + galaxy_fluxes - mags = star_mags + galaxy_mags # true light source mask keep = galaxy_bools | star_bools @@ -383,7 +380,6 @@ def from_file(cls, cat_path, wcs: WCS, height, width): ras = ras[keep] decs = decs[keep] fluxes = fluxes[keep] - mags = mags[keep] nobj = ras.shape[0] # We require all 5 bands for computing loss on predictions. @@ -401,7 +397,6 @@ def from_file(cls, cat_path, wcs: WCS, height, width): "n_sources": torch.tensor((nobj,)), "source_type": source_type.reshape(1, nobj, 1), "fluxes": fluxes.reshape(1, nobj, n_bands), - "mags": mags.reshape(1, nobj, n_bands), "ra": ras.reshape(1, nobj, 1), "dec": decs.reshape(1, nobj, 1), } diff --git a/case_studies/dc2_cataloging/cataloging_exp.ipynb b/case_studies/dc2_cataloging/cataloging_exp.ipynb index b087609a1..fe60edf2d 100644 --- a/case_studies/dc2_cataloging/cataloging_exp.ipynb +++ b/case_studies/dc2_cataloging/cataloging_exp.ipynb @@ -3458,10 +3458,8 @@ " self.ellipticity_sources = 0\n", " self.ellipticity1_within_ci_num = 0\n", " self.ellipticity2_within_ci_num = 0\n", - " self.star_flux_sources = 0\n", - " self.star_flux_within_ci_num = {band: 0 for band in bands}\n", - " self.galaxy_flux_sources = 0\n", - " self.galaxy_flux_within_ci_num = {band: 0 for band in bands}\n", + " self.flux_sources = 0\n", + " self.flux_within_ci_num = {band: 0 for band in bands}\n", "\n", " bliss_locs_list = []\n", " bliss_locs_ci_lower_list = []\n", @@ -3498,23 +3496,14 @@ " self.ellipticity1_within_ci_num += (ellipticity_within_ci[..., 0:1] & ellipticity_mask).sum().item()\n", " self.ellipticity2_within_ci_num += (ellipticity_within_ci[..., 1:2] & ellipticity_mask).sum().item()\n", "\n", - " star_mask = bliss_tile_cat.star_bools & target_tile_cat.star_bools[..., :bliss_m, :]\n", - " self.star_flux_sources += star_mask.sum().item()\n", - " bliss_star_flux_ci_lower = bliss_tile_cat[\"star_fluxes_ci_lower\"]\n", - " bliss_star_flux_ci_upper = bliss_tile_cat[\"star_fluxes_ci_upper\"]\n", - " target_star_flux = target_tile_cat[\"star_fluxes\"][..., :bliss_m, :]\n", - " star_flux_within_ci = (target_star_flux > bliss_star_flux_ci_lower) & (target_star_flux < bliss_star_flux_ci_upper)\n", + " mask = bliss_tile_cat.is_on_mask & target_tile_cat.is_on_mask[..., :bliss_m, :]\n", + " self.flux_sources += mask.sum().item()\n", + " bliss_flux_ci_lower = bliss_tile_cat[\"fluxes_ci_lower\"]\n", + " bliss_flux_ci_upper = bliss_tile_cat[\"fluxes_ci_upper\"]\n", + " target_flux = target_tile_cat[\"fluxes\"][..., :bliss_m, :]\n", + " flux_within_ci = (target_flux > bliss_flux_ci_lower) & (target_flux < bliss_flux_ci_upper)\n", " for i, band in enumerate(bands):\n", - " self.star_flux_within_ci_num[band] += (star_flux_within_ci[..., i:(i + 1)] & star_mask).sum().item()\n", - "\n", - " galaxy_mask = bliss_tile_cat.galaxy_bools & target_tile_cat.galaxy_bools[..., :bliss_m, :]\n", - " self.galaxy_flux_sources += galaxy_mask.sum().item()\n", - " bliss_galaxy_flux_ci_lower = bliss_tile_cat[\"galaxy_fluxes_ci_lower\"]\n", - " bliss_galaxy_flux_ci_upper = bliss_tile_cat[\"galaxy_fluxes_ci_upper\"]\n", - " target_galaxy_flux = target_tile_cat[\"galaxy_fluxes\"][..., :bliss_m, :]\n", - " galaxy_flux_within_ci = (target_galaxy_flux > bliss_galaxy_flux_ci_lower) & (target_galaxy_flux < bliss_galaxy_flux_ci_upper)\n", - " for i, band in enumerate(bands):\n", - " self.galaxy_flux_within_ci_num[band] += (galaxy_flux_within_ci[..., i:(i + 1)] & galaxy_mask).sum().item()\n", + " self.flux_within_ci_num[band] += (flux_within_ci[..., i:(i + 1)] & mask).sum().item()\n", "\n", " rand_int = random.randint(0, len(bliss_locs_list) - 1)\n", " bliss_locs = bliss_locs_list[rand_int]\n", @@ -3546,15 +3535,11 @@ " print(f\"# ellipticity2 within ci: {self.ellipticity2_within_ci_num}\")\n", " print(f\"ellipticity2 within ci: {self.ellipticity2_within_ci_num / self.ellipticity_sources: .4f}\")\n", " print()\n", - " print(f\"# star sources: {self.star_flux_sources}\")\n", + " print(f\"# star sources: {self.flux_sources}\")\n", " for band in bands:\n", - " print(f\"# {band} flux within ci: {self.star_flux_within_ci_num[band]}\")\n", - " print(f\"{band} flux within ci: {self.star_flux_within_ci_num[band] / self.star_flux_sources: .4f}\")\n", + " print(f\"# {band} flux within ci: {self.flux_within_ci_num[band]}\")\n", + " print(f\"{band} flux within ci: {self.flux_within_ci_num[band] / self.flux_sources: .4f}\")\n", " print()\n", - " print(f\"# galaxy sources: {self.galaxy_flux_sources}\")\n", - " for band in bands:\n", - " print(f\"# {band} flux within ci: {self.galaxy_flux_within_ci_num[band]}\")\n", - " print(f\"{band} flux within ci: {self.galaxy_flux_within_ci_num[band] / self.galaxy_flux_sources: .4f}\")\n", "\n", " def plot(self):\n", " fig1, ax1 = plt.subplots(1, 1, figsize=NoteBookPlottingParams.figsize)\n", @@ -3736,11 +3721,9 @@ " locs1_vsbc_list.append(locs_vsbc[..., 0])\n", " locs2_vsbc_list.append(locs_vsbc[..., 1])\n", " \n", - " star_flux_vsbc = bliss_tile_cat[\"star_fluxes_vsbc\"]\n", - " galaxy_flux_vsbc = bliss_tile_cat[\"galaxy_fluxes_vsbc\"]\n", + " flux_vsbc = bliss_tile_cat[\"fluxes_vsbc\"]\n", " for i, band in enumerate(bands):\n", - " self.star_flux_vsbc_dict[band].append(star_flux_vsbc[..., i])\n", - " self.galaxy_flux_vsbc_dict[band].append(galaxy_flux_vsbc[..., i])\n", + " self.flux_vsbc_dict[band].append(flux_vsbc[..., i])\n", "\n", " self.ellipticity1_vsbc = torch.cat(ellipticity1_vsbc_list, dim=0).flatten()\n", " self.ellipticity2_vsbc = torch.cat(ellipticity2_vsbc_list, dim=0).flatten()\n", @@ -3748,8 +3731,8 @@ " self.locs2_vsbc = torch.cat(locs2_vsbc_list, dim=0).flatten()\n", " self.is_on_mask = torch.cat(is_on_mask_list, dim=0).flatten()\n", "\n", - " for k, v in self.star_flux_vsbc_dict.items():\n", - " self.star_flux_vsbc_dict[k] = torch.cat(v, dim=0).flatten()\n", + " for k, v in self.flux_vsbc_dict.items():\n", + " self.flux_vsbc_dict[k] = torch.cat(v, dim=0).flatten()\n", "\n", " for k, v in self.galaxy_flux_vsbc_dict.items():\n", " self.galaxy_flux_vsbc_dict[k] = torch.cat(v, dim=0).flatten()\n", @@ -3808,7 +3791,7 @@ "\n", " return fig1, fig2\n", " \n", - " def _star_flux_plot(self):\n", + " def _flux_plot(self):\n", " fig, axes = plt.subplots(3, 2, \n", " figsize=(NoteBookPlottingParams.figsize[0] * 1.1, NoteBookPlottingParams.figsize[1] * 1.7), \n", " sharex=\"col\", \n", @@ -3842,50 +3825,14 @@ "\n", " return fig\n", " \n", - " def _galaxy_flux_plot(self):\n", - " fig, axes = plt.subplots(3, 2, \n", - " figsize=(NoteBookPlottingParams.figsize[0] * 1.1, NoteBookPlottingParams.figsize[1] * 1.7), \n", - " sharex=\"col\", \n", - " sharey=\"row\",\n", - " layout=\"constrained\")\n", - "\n", - " for i, band in enumerate(bands):\n", - " col_index = i % 2\n", - " row_index = i // 2\n", - " ax = axes[row_index, col_index]\n", - " galaxy_flux_vsbc = self.galaxy_flux_vsbc_dict[band]\n", - " galaxy_flux_vsbc = galaxy_flux_vsbc[self.is_on_mask & ~galaxy_flux_vsbc.isnan()]\n", - " ax.hist(galaxy_flux_vsbc, density=True, bins=50)\n", - " ax.axvline(0.5, color=\"red\")\n", - " ax.set_title(f\"{band} band\", fontsize=NoteBookPlottingParams.fontsize)\n", - " ax.grid(visible=False)\n", - "\n", - " axes[1, 0].set_ylabel(\"Density\", fontsize=NoteBookPlottingParams.fontsize)\n", - " axes[2, 0].set_xticks(np.linspace(0.0, 1.0, num=6))\n", - " axes[2, 0].set_xticklabels([f\"{i: .1f}\" for i in np.linspace(0.0, 1.0, num=6)])\n", - " axes[2, 0].tick_params(labelsize=NoteBookPlottingParams.fontsize * 0.7)\n", - " axes[2, 0].set_xlabel(\"$P_{galaxy\\ flux}$\", fontsize=NoteBookPlottingParams.fontsize)\n", - " axes[2, 1].set_xticks(np.linspace(0.0, 1.0, num=6))\n", - " axes[2, 1].set_xticklabels([f\"{i: .1f}\" for i in np.linspace(0.0, 1.0, num=6)])\n", - " axes[2, 1].tick_params(labelsize=NoteBookPlottingParams.fontsize * 0.7)\n", - " axes[2, 1].set_xlabel(\"$P_{galaxy\\ flux}$\", fontsize=NoteBookPlottingParams.fontsize)\n", - "\n", - " axes[0, 0].tick_params(labelsize=NoteBookPlottingParams.fontsize * 0.7)\n", - " axes[1, 0].tick_params(labelsize=NoteBookPlottingParams.fontsize * 0.7)\n", - " axes[2, 0].tick_params(labelsize=NoteBookPlottingParams.fontsize * 0.7)\n", - "\n", - " return fig\n", - "\n", " def plot(self, plot_type: str):\n", " match plot_type:\n", " case \"ellipticity\":\n", " return self._ellipticity_plot()\n", " case \"locs\":\n", " return self._locs_plot()\n", - " case \"star_flux\":\n", - " return self._star_flux_plot()\n", - " case \"galaxy_flux\":\n", - " return self._galaxy_flux_plot()\n", + " case \"flux\":\n", + " return self._flux_plot()\n", " case _:\n", " raise NotImplementedError()" ] diff --git a/case_studies/dc2_cataloging/train_config.yaml b/case_studies/dc2_cataloging/train_config.yaml index fb896492b..df0465967 100644 --- a/case_studies/dc2_cataloging/train_config.yaml +++ b/case_studies/dc2_cataloging/train_config.yaml @@ -25,19 +25,12 @@ my_variational_factors: nll_gating: _target_: bliss.encoder.variational_dist.SourcesGating - _target_: bliss.encoder.variational_dist.LogNormalFactor - name: star_fluxes + name: fluxes dim: 6 sample_rearrange: b ht wt d -> b ht wt 1 d nll_rearrange: b ht wt 1 d -> b ht wt d nll_gating: - _target_: bliss.encoder.variational_dist.StarGating - - _target_: bliss.encoder.variational_dist.LogNormalFactor - name: galaxy_fluxes - dim: 6 - sample_rearrange: b ht wt d -> b ht wt 1 d - nll_rearrange: b ht wt 1 d -> b ht wt d - nll_gating: - _target_: bliss.encoder.variational_dist.GalaxyGating + _target_: bliss.encoder.variational_dist.SourcesGating - _target_: bliss.encoder.variational_dist.BivariateNormalFactor name: ellipticity sample_rearrange: b ht wt d -> b ht wt 1 d diff --git a/case_studies/dc2_cataloging/utils/load_lsst.py b/case_studies/dc2_cataloging/utils/load_lsst.py index 878fb9821..3d902948e 100644 --- a/case_studies/dc2_cataloging/utils/load_lsst.py +++ b/case_studies/dc2_cataloging/utils/load_lsst.py @@ -93,7 +93,6 @@ def get_lsst_full_cat(lsst_root_dir: str, cur_image_wcs, image_lim, r_band_min_f "plocs": lsst_plocs.unsqueeze(0), "n_sources": lsst_n_sources, "source_type": lsst_source_type.unsqueeze(0), - "galaxy_fluxes": lsst_flux.unsqueeze(0), - "star_fluxes": lsst_flux.unsqueeze(0).clone(), + "fluxes": lsst_flux.unsqueeze(0), }, ) diff --git a/case_studies/dc2_cataloging/utils/lsst_predictor.py b/case_studies/dc2_cataloging/utils/lsst_predictor.py index 8f065f641..3bc41b79f 100644 --- a/case_studies/dc2_cataloging/utils/lsst_predictor.py +++ b/case_studies/dc2_cataloging/utils/lsst_predictor.py @@ -61,8 +61,7 @@ def _predict_one_image(self, wcs_header_str, image_lim, height_index, width_inde "plocs": lsst_plocs.unsqueeze(0), "n_sources": lsst_n_sources, "source_type": lsst_source_type.unsqueeze(0), - "galaxy_fluxes": lsst_flux.unsqueeze(0), - "star_fluxes": lsst_flux.unsqueeze(0).clone(), + "fluxes": lsst_flux.unsqueeze(0), }, ).to_tile_catalog(self.tile_slen, self.max_sources_per_tile) diff --git a/case_studies/dc2_cataloging/utils/variational_dist.py b/case_studies/dc2_cataloging/utils/variational_dist.py index 61d61b7c2..50f12f237 100644 --- a/case_studies/dc2_cataloging/utils/variational_dist.py +++ b/case_studies/dc2_cataloging/utils/variational_dist.py @@ -103,14 +103,12 @@ def sample( vsbc_variables = { "locs", "ellipticity", - "star_fluxes", - "galaxy_fluxes", + "fluxes", } credible_interval_variables = { "locs", "ellipticity", - "star_fluxes", - "galaxy_fluxes", + "fluxes", } for qk, params in fp_pairs: if qk.name in vsbc_variables: diff --git a/case_studies/dc2_multidetection/train_config.yaml b/case_studies/dc2_multidetection/train_config.yaml index 9925b3448..9beee5bdf 100644 --- a/case_studies/dc2_multidetection/train_config.yaml +++ b/case_studies/dc2_multidetection/train_config.yaml @@ -25,19 +25,12 @@ my_variational_factors: nll_gating: _target_: bliss.encoder.variational_dist.SourcesGating - _target_: bliss.encoder.variational_dist.LogNormalFactor - name: star_fluxes + name: fluxes dim: 6 sample_rearrange: b ht wt d -> b ht wt 1 d nll_rearrange: b ht wt 1 d -> b ht wt d nll_gating: - _target_: bliss.encoder.variational_dist.StarGating - - _target_: bliss.encoder.variational_dist.LogNormalFactor - name: galaxy_fluxes - dim: 6 - sample_rearrange: b ht wt d -> b ht wt 1 d - nll_rearrange: b ht wt 1 d -> b ht wt d - nll_gating: - _target_: bliss.encoder.variational_dist.GalaxyGating + _target_: bliss.encoder.variational_dist.SourcesGating - _target_: bliss.encoder.variational_dist.BivariateNormalFactor name: ellipticity sample_rearrange: b ht wt d -> b ht wt 1 d diff --git a/case_studies/galaxy_clustering/data_generation/data_gen.py b/case_studies/galaxy_clustering/data_generation/data_gen.py index 19f9e7502..f9c90bca4 100644 --- a/case_studies/galaxy_clustering/data_generation/data_gen.py +++ b/case_studies/galaxy_clustering/data_generation/data_gen.py @@ -116,10 +116,9 @@ def file_data_gen(cfg): catalog_dict = {} catalog_dict["plocs"] = torch.tensor([catalog[["X", "Y"]].to_numpy()]) catalog_dict["n_sources"] = torch.sum(catalog_dict["plocs"][:, :, 0] != 0, axis=1) - catalog_dict["galaxy_fluxes"] = torch.tensor( + catalog_dict["fluxes"] = torch.tensor( [catalog[["FLUX_G", "FLUX_R", "FLUX_I", "FLUX_Z"]].to_numpy()] ) - catalog_dict["star_fluxes"] = torch.zeros_like(catalog_dict["galaxy_fluxes"]) catalog_dict["membership"] = torch.tensor([catalog[["MEM"]].to_numpy()]) catalog_dict["redshift"] = torch.tensor([catalog[["Z"]].to_numpy()]) catalog_dict["galaxy_params"] = torch.tensor( diff --git a/case_studies/psf_variation/conf/config.yaml b/case_studies/psf_variation/conf/config.yaml index 33e721a0a..c9b47d9fa 100644 --- a/case_studies/psf_variation/conf/config.yaml +++ b/case_studies/psf_variation/conf/config.yaml @@ -111,19 +111,12 @@ variational_factors: nll_gating: _target_: bliss.encoder.variational_dist.SourcesGating - _target_: bliss.encoder.variational_dist.LogNormalFactor - name: star_fluxes + name: fluxes dim: 5 sample_rearrange: "b ht wt d -> b ht wt 1 d" nll_rearrange: "b ht wt 1 d -> b ht wt d" nll_gating: - _target_: bliss.encoder.variational_dist.StarGating - - _target_: bliss.encoder.variational_dist.LogNormalFactor - name: galaxy_fluxes - dim: 5 - sample_rearrange: "b ht wt d -> b ht wt 1 d" - nll_rearrange: "b ht wt 1 d -> b ht wt d" - nll_gating: - _target_: bliss.encoder.variational_dist.GalaxyGating + _target_: bliss.encoder.variational_dist.SourcesGating - _target_: bliss.encoder.variational_dist.LogitNormalFactor name: galaxy_disk_frac sample_rearrange: "b ht wt d -> b ht wt 1 d" diff --git a/case_studies/psf_variation/evaluate_models.py b/case_studies/psf_variation/evaluate_models.py index 950d99a68..e613e0f83 100644 --- a/case_studies/psf_variation/evaluate_models.py +++ b/case_studies/psf_variation/evaluate_models.py @@ -440,19 +440,13 @@ def compute_prob_flux_within_one_mag(pred_dists, bins, cached_path): for name, model in models.items(): # Get probabilities of flux within 1 magnitude of true - q_star_flux = pred_dists[name][i]["star_fluxes"].base_dist - q_gal_flux = pred_dists[name][i]["galaxy_fluxes"].base_dist + q_flux = pred_dists[name][i]["fluxes"].base_dist - star_flux_probs = q_star_flux.cdf(ub) - q_star_flux.cdf(lb) - gal_flux_probs = q_gal_flux.cdf(ub) - q_gal_flux.cdf(lb) + flux_probs = q_flux.cdf(ub) - q_flux.cdf(lb) + flux_probs = flux_probs.unsqueeze(-2)[target_cat.is_on_mask][:, 2] - pred_probs = torch.where( - target_cat.star_bools, star_flux_probs.unsqueeze(-2), gal_flux_probs.unsqueeze(-2) - ) - pred_probs = pred_probs[target_cat.is_on_mask][:, 2] - - probs_per_bin = torch.zeros(n_bins, dtype=pred_probs.dtype) - sum_probs[name] += probs_per_bin.scatter_add(0, binned_target_on_mags, pred_probs) + probs_per_bin = torch.zeros(n_bins, dtype=flux_probs.dtype) + sum_probs[name] += probs_per_bin.scatter_add(0, binned_target_on_mags, flux_probs) bin_count[name] += binned_target_on_mags.bincount(minlength=n_bins) binned_avg_flux_probs = {} @@ -535,27 +529,19 @@ def compute_prop_flux_in_interval(pred_dists, intervals, cached_path): dim_count += dim_mask.sum() for name, model in models.items(): - q_star_flux = pred_dists[name][i]["star_fluxes"].base_dist - q_gal_flux = pred_dists[name][i]["galaxy_fluxes"].base_dist + q_flux = pred_dists[name][i]["fluxes"].base_dist for j, interval in enumerate(intervals): # construct equal tail intervals and determine if true flux is within ETI tail_prob = (1 - interval) / 2 - star_lb = q_star_flux.icdf(tail_prob)[..., 2] - star_ub = q_star_flux.icdf(1 - tail_prob)[..., 2] - gal_lb = q_gal_flux.icdf(tail_prob)[..., 2] - gal_ub = q_gal_flux.icdf(1 - tail_prob)[..., 2] - - star_flux_in_eti = (true_fluxes >= star_lb) & (true_fluxes <= star_ub) - gal_flux_in_eti = (true_fluxes >= gal_lb) & (true_fluxes <= gal_ub) + lb = q_flux.icdf(tail_prob)[..., 2] + ub = q_flux.icdf(1 - tail_prob)[..., 2] - source_in_eti = torch.where( - target_cat.star_bools.squeeze(), star_flux_in_eti, gal_flux_in_eti - ) + flux_in_eti = (true_fluxes >= lb) & (true_fluxes <= ub) - sum_all_in_eti[name][j] += (source_in_eti * target_cat.is_on_mask.squeeze()).sum() - sum_bright_in_eti[name][j] += (source_in_eti * bright_mask).sum() - sum_dim_in_eti[name][j] += (source_in_eti * dim_mask).sum() + sum_all_in_eti[name][j] += (flux_in_eti * target_cat.is_on_mask.squeeze()).sum() + sum_bright_in_eti[name][j] += (flux_in_eti * bright_mask).sum() + sum_dim_in_eti[name][j] += (flux_in_eti * dim_mask).sum() # Compute proportions and save data prop_all_in_eti = {} @@ -985,17 +971,13 @@ def compute_ci_width(pred_dists, bins, cached_path): bin_count += binned_target_on_mags.bincount(minlength=n_bins) for name, model in models.items(): - q_star_flux = pred_dists[name][i]["star_fluxes"].base_dist - q_gal_flux = pred_dists[name][i]["galaxy_fluxes"].base_dist + q_flux = pred_dists[name][i]["fluxes"].base_dist # construct equal tail intervals - star_intervals = q_star_flux.icdf(1 - tail_prob) - q_star_flux.icdf(tail_prob) - gal_intervals = q_gal_flux.icdf(1 - tail_prob) - q_gal_flux.icdf(tail_prob) + source_intervals = q_flux.icdf(1 - tail_prob) - q_flux.icdf(tail_prob) # Compute CI width for true sources based on source type - width = torch.where( - target_cat.star_bools, star_intervals.unsqueeze(-2), gal_intervals.unsqueeze(-2) - ) + width = source_intervals.unsqueeze(-2) width = width[target_cat.is_on_mask][:, 2] width[width == torch.inf] = 0 # temp hack to not get inf @@ -1009,14 +991,7 @@ def compute_ci_width(pred_dists, bins, cached_path): width / target_cat.on_fluxes[target_cat.is_on_mask][:, 2], ) - # Get flux scale for true sources based on source type - scale = torch.where( - target_cat.star_bools, - q_star_flux.scale.unsqueeze(-2), - q_gal_flux.scale.unsqueeze(-2), - ) - scale = scale[target_cat.is_on_mask][:, 2] - + scale = q_flux.scale.unsqueeze(-2)[target_cat.is_on_mask][:, 2] scales_per_bin = torch.zeros(n_bins, dtype=scale.dtype) flux_scale[name] += scales_per_bin.scatter_add(0, binned_target_on_mags, scale) diff --git a/case_studies/redshift/evaluation/dc2_plot.ipynb b/case_studies/redshift/evaluation/dc2_plot.ipynb index 65c0de665..074ffd311 100644 --- a/case_studies/redshift/evaluation/dc2_plot.ipynb +++ b/case_studies/redshift/evaluation/dc2_plot.ipynb @@ -774,8 +774,7 @@ " \"plocs\": lsst_plocs.unsqueeze(0),\n", " \"n_sources\": lsst_n_sources,\n", " \"source_type\": lsst_source_type.unsqueeze(0),\n", - " \"galaxy_fluxes\": lsst_flux.unsqueeze(0),\n", - " \"star_fluxes\": lsst_flux.unsqueeze(0).clone(),\n", + " \"fluxes\": lsst_flux.unsqueeze(0),\n", " },\n", " ).to_tile_catalog(self.tile_slen, self.max_sources_per_tile, ignore_extra_sources=True)\n", "\n", diff --git a/case_studies/redshift/redshift_from_img/full_train_config.yaml b/case_studies/redshift/redshift_from_img/full_train_config.yaml index 936942942..56eb7752f 100644 --- a/case_studies/redshift/redshift_from_img/full_train_config.yaml +++ b/case_studies/redshift/redshift_from_img/full_train_config.yaml @@ -26,17 +26,11 @@ variational_factors: nll_rearrange: "b ht wt 1 d -> b ht wt d" nll_gating: n_sources - _target_: bliss.encoder.variational_dist.LogNormalFactor - name: star_fluxes + name: fluxes dim: 6 sample_rearrange: "b ht wt d -> b ht wt 1 d" nll_rearrange: "b ht wt 1 d -> b ht wt d" - nll_gating: is_star - - _target_: bliss.encoder.variational_dist.LogNormalFactor - name: galaxy_fluxes - dim: 6 - sample_rearrange: "b ht wt d -> b ht wt 1 d" - nll_rearrange: "b ht wt 1 d -> b ht wt d" - nll_gating: is_galaxy + nll_gating: n_sources - _target_: bliss.encoder.variational_dist.NormalFactor name: redshifts sample_rearrange: "b ht wt -> b ht wt 1 1" diff --git a/case_studies/spatial_tiling/m2/config.yaml b/case_studies/spatial_tiling/m2/config.yaml index 6625fd8b5..6ff28fd21 100644 --- a/case_studies/spatial_tiling/m2/config.yaml +++ b/case_studies/spatial_tiling/m2/config.yaml @@ -8,19 +8,13 @@ paths: cached_data: /data/scratch/regier/m2_aug30 variational_factors: - - _target_: bliss.encoder.variational_dist.BernoulliFactor - name: source_type - sample_rearrange: "b ht wt -> b ht wt 1 1" - nll_rearrange: "b ht wt 1 1 -> b ht wt" - nll_gating: - _target_: bliss.encoder.variational_dist.SourcesGating - _target_: bliss.encoder.variational_dist.LogNormalFactor - name: star_fluxes + name: fluxes dim: 1 sample_rearrange: "b ht wt d -> b ht wt 1 d" nll_rearrange: "b ht wt 1 d -> b ht wt d" nll_gating: - _target_: bliss.encoder.variational_dist.StarGating + _target_: bliss.encoder.variational_dist.SourcesGating - _target_: bliss.encoder.variational_dist.DiscretizedUnitBoxFactor name: locs sample_rearrange: "b ht wt d -> b ht wt 1 d" diff --git a/case_studies/spatial_tiling/m2/m2.ipynb b/case_studies/spatial_tiling/m2/m2.ipynb index 87a0d0182..15c1a972e 100644 --- a/case_studies/spatial_tiling/m2/m2.ipynb +++ b/case_studies/spatial_tiling/m2/m2.ipynb @@ -230,8 +230,7 @@ "source": [ "d = {\n", " \"plocs\": plocs_square.unsqueeze(0),\n", - " \"star_fluxes\": sdss_r_nmgy.unsqueeze(0).unsqueeze(2),\n", - " \"galaxy_fluxes\": sdss_r_nmgy.unsqueeze(0).unsqueeze(2) * 0.0,\n", + " \"fluxes\": sdss_r_nmgy.unsqueeze(0).unsqueeze(2),\n", " \"n_sources\": torch.tensor(plocs.shape[0]).unsqueeze(0),\n", " \"source_type\": torch.zeros(plocs.shape[0]).unsqueeze(0).unsqueeze(2).long(),\n", "}" @@ -287,8 +286,7 @@ "source": [ "d = {\n", " \"plocs\": plocs_square[is_bright].unsqueeze(0),\n", - " \"star_fluxes\": sdss_r_nmgy[is_bright].unsqueeze(0).unsqueeze(2).expand([-1, -1, 5]),\n", - " \"galaxy_fluxes\": sdss_r_nmgy[is_bright].unsqueeze(0).unsqueeze(2).expand([-1, -1, 5]) * 0.0,\n", + " \"fluxes\": sdss_r_nmgy[is_bright].unsqueeze(0).unsqueeze(2).expand([-1, -1, 5]),\n", " \"n_sources\": torch.tensor(plocs[is_bright].shape[0]).unsqueeze(0),\n", " \"source_type\": torch.zeros(plocs[is_bright].shape[0]).unsqueeze(0).unsqueeze(2).long(),\n", "}\n", diff --git a/case_studies/spatial_tiling/sdss_demo/config.yaml b/case_studies/spatial_tiling/sdss_demo/config.yaml index 9a6759d75..fe03fbda3 100644 --- a/case_studies/spatial_tiling/sdss_demo/config.yaml +++ b/case_studies/spatial_tiling/sdss_demo/config.yaml @@ -28,19 +28,12 @@ variational_factors: nll_gating: _target_: bliss.encoder.variational_dist.SourcesGating - _target_: bliss.encoder.variational_dist.LogNormalFactor - name: star_fluxes + name: fluxes dim: 1 sample_rearrange: b ht wt d -> b ht wt 1 d nll_rearrange: b ht wt 1 d -> b ht wt d nll_gating: - _target_: bliss.encoder.variational_dist.StarGating - - _target_: bliss.encoder.variational_dist.LogNormalFactor - name: galaxy_fluxes - dim: 1 - sample_rearrange: b ht wt d -> b ht wt 1 d - nll_rearrange: b ht wt 1 d -> b ht wt d - nll_gating: - _target_: bliss.encoder.variational_dist.GalaxyGating + _target_: bliss.encoder.variational_dist.SourcesGating - _target_: bliss.encoder.variational_dist.LogitNormalFactor name: galaxy_disk_frac sample_rearrange: b ht wt d -> b ht wt 1 d @@ -106,7 +99,16 @@ nopsf_image_normalizers: my_metrics: detection_performance: _target_: bliss.encoder.metrics.DetectionPerformance - base_flux_bin_cutoffs: ${sdss_flux_cutoffs} + base_flux_bin_cutoffs: + - 1 + - 1.9055 + - 2.7542 + - 3.9811 + - 5.7544 + - 8.3176 + - 12.0227 + - 17.3780 + - 25.1189 mag_zero_point: ${sdss_mag_zero_point} report_bin_unit: mag exclude_last_bin: true @@ -118,7 +120,7 @@ encoder: reference_band: 0 matcher: _target_: bliss.encoder.metrics.CatalogMatcher - dist_slack: 1 + dist_slack: 2 mag_band: 0 # SDSS r-band mode_metrics: _target_: torchmetrics.MetricCollection diff --git a/case_studies/spatial_tiling/sdss_demo/sdss_field.py b/case_studies/spatial_tiling/sdss_demo/sdss_field.py index 0b7e25177..25eb0148c 100644 --- a/case_studies/spatial_tiling/sdss_demo/sdss_field.py +++ b/case_studies/spatial_tiling/sdss_demo/sdss_field.py @@ -2,8 +2,6 @@ # ## Imports import gc - -# %% from os import environ from pathlib import Path @@ -23,7 +21,7 @@ from bliss.surveys.des import TractorFullCatalog from bliss.surveys.sdss import PhotoFullCatalog -environ["CUDA_VISIBLE_DEVICES"] = "7" +environ["CUDA_VISIBLE_DEVICES"] = "4" torch.set_grad_enabled(False) @@ -99,7 +97,7 @@ decals_cat_base = TractorFullCatalog.from_file(decals_path, sdss_wcs, 1488, 2048) # a bit less than 22.5 magnitude, our target -to_keep = decals_cat_base["fluxes"][..., 0] > 1 +to_keep = decals_cat_base["fluxes"][..., 0] > 0.8 d = {k: v[to_keep].unsqueeze(0) for k, v in decals_cat_base.items() if k != "n_sources"} d["n_sources"] = to_keep.sum(1) decals_cat = FullCatalog(decals_cat_base.height, decals_cat_base.width, d) @@ -113,7 +111,7 @@ # %% # change the cfg here to try different checkerboard schemes -encoder = instantiate(cfg_c1.train.encoder).cuda() +encoder = instantiate(cfg_c4.train.encoder).cuda() enc_state_dict = torch.load(cfg0.train.pretrained_weights) if cfg0.train.pretrained_weights.endswith(".ckpt"): enc_state_dict = enc_state_dict["state_dict"] @@ -150,7 +148,7 @@ d[k] = rearrange(v, pattern, hp=patches.shape[1], wp=patches.shape[2]) bliss_tile_cat = TileCatalog(d) -bliss_flux_filter_cat = bliss_tile_cat.filter_by_flux(convert_mag_to_nmgy(22.7), band=0) +bliss_flux_filter_cat = bliss_tile_cat.filter_by_flux(convert_mag_to_nmgy(22.5), band=0) bliss_cat = bliss_flux_filter_cat.to_full_catalog(4).to("cpu") # %% [markdown] @@ -175,13 +173,13 @@ # Create a CatalogMatcher object matcher = CatalogMatcher( - dist_slack=5.0, + dist_slack=2.0, mag_band=2, ) # Match the catalogs based on their positions -match_gt_bliss = matcher.match_catalogs(decals_cat_box, bliss_cat_box)[0] -match_gt_photo = matcher.match_catalogs(decals_cat_box, photo_cat_box)[0] +match_gt_bliss = matcher.match_catalogs(decals_cat_box, bliss_cat_box) +match_gt_photo = matcher.match_catalogs(decals_cat_box, photo_cat_box) fig, ax = plt.subplots(figsize=(14, 14)) bw = np.array(rgb, dtype=np.float32).sum(2) @@ -225,15 +223,19 @@ matches = { # in decals and (bliss or sdss) - "gt_all": set(match_gt_bliss[0].numpy()).union(match_gt_photo[0].numpy()), + "gt_all": set(match_gt_bliss[0][0].numpy()).union(match_gt_photo[0][0].numpy()), # in bliss and decals, not in sdss - "bliss_tp_only": set(match_gt_bliss[0].numpy()).difference(match_gt_photo[0].numpy()), + "bliss_tp_only": set(match_gt_bliss[0][0].numpy()).difference(match_gt_photo[0][0].numpy()), # in sdss and decals, not in bliss - "sdss_tp_only": set(match_gt_photo[0].numpy()).difference(match_gt_bliss[0].numpy()), + "sdss_tp_only": set(match_gt_photo[0][0].numpy()).difference(match_gt_bliss[0][0].numpy()), # in bliss, not in decals - "bliss_fp": set(range(bliss_cat_box["n_sources"].item())).difference(match_gt_bliss[1].numpy()), + "bliss_fp": set(range(bliss_cat_box["n_sources"].item())).difference( + match_gt_bliss[0][1].numpy() + ), # in sdss, not in decals - "sdss_fp": set(range(photo_cat_box["n_sources"].item())).difference(match_gt_photo[1].numpy()), + "sdss_fp": set(range(photo_cat_box["n_sources"].item())).difference( + match_gt_photo[0][1].numpy() + ), } params = { @@ -295,6 +297,33 @@ fontsize=10, ) +# flake8: noqa: WPS421 for k, v in matches.items(): - print(k, len(v)) # noqa: WPS421 + print(k, len(v)) +# %% + +dp = instantiate(cfg0.my_metrics.detection_performance) +dp.update(decals_cat_box, bliss_cat_box, match_gt_bliss) +bliss_scores = dp.compute() +dp.reset() +print( + bliss_scores["detection_precision"], + bliss_scores["detection_recall"], + bliss_scores["detection_f1"], +) +# %% +dp.update(decals_cat_box, photo_cat_box, match_gt_photo) +sdss_scores = dp.compute() +dp.reset() +print( + sdss_scores["detection_precision"], + sdss_scores["detection_recall"], + sdss_scores["detection_f1"], +) +# %% +decoder = instantiate(cfg0.decoder) +decals_tile_cat = decals_cat_box.to_tile_catalog(4, 2) +decals_tile_cat["source_type"] = torch.zeros_like(decals_tile_cat["source_type"]) +semisynth, _ = decoder.render_image(decals_tile_cat) +plt.imshow(semisynth[2].numpy(), origin="lower") # %% diff --git a/case_studies/spatial_tiling/toy_example/toy_example.ipynb b/case_studies/spatial_tiling/sdss_demo/toy_example.ipynb similarity index 96% rename from case_studies/spatial_tiling/toy_example/toy_example.ipynb rename to case_studies/spatial_tiling/sdss_demo/toy_example.ipynb index 7640e2075..de88f5a22 100644 --- a/case_studies/spatial_tiling/toy_example/toy_example.ipynb +++ b/case_studies/spatial_tiling/sdss_demo/toy_example.ipynb @@ -38,7 +38,7 @@ "outputs": [], "source": [ "from os import environ\n", - "environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", + "environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n", "\n", "import torch\n", "from bliss.catalog import TileCatalog\n", @@ -63,9 +63,11 @@ "from hydra import initialize, compose\n", "from hydra.utils import instantiate\n", "\n", - "with initialize(config_path=\".\", version_base=None):\n", + "ckpt = \"/home/regier/bliss_output/sep3_sdss_demo/version_3/checkpoints/best_encoder.ckpt\"\n", + "\n", + "with initialize(config_path=\"../sdss_demo\", version_base=None):\n", " overrides = {\n", - " \"predict.weight_save_path=/home/regier/bliss_output/jul25_toy_example_10_percent/version_0/checkpoints/best_encoder.ckpt\",\n", + " \"predict.weight_save_path=\" + ckpt,\n", " \"decoder.with_noise=true\",\n", " \"decoder.with_dither=false\",\n", " \"encoder.predict_mode_not_samples=false\",\n", @@ -144,18 +146,15 @@ " galaxy_params = torch.ones(n, 20, 20, 1, 6) * 0.5\n", " galaxy_params[:, 10, 9, 0, [3,5]] = 10.0\n", "\n", - " star_fluxes = torch.zeros(n, 20, 20, 1, 5)\n", - " star_fluxes[:, ht, 10] = flux\n", - "\n", - " galaxy_fluxes = torch.zeros(n, 20, 20, 1, 5)\n", - " galaxy_fluxes[:, 10, 9] = 400.0\n", + " fluxes = torch.zeros(n, 20, 20, 1, 5)\n", + " fluxes[:, ht, 10] = flux\n", + " fluxes[:, 10, 9] = 400.0\n", "\n", " true_catalog_dict = {\n", " \"n_sources\": n_sources,\n", " \"source_type\": source_type,\n", " \"locs\": locs,\n", - " \"star_fluxes\": star_fluxes, \n", - " \"galaxy_fluxes\": galaxy_fluxes,\n", + " \"fluxes\": fluxes,\n", " \"galaxy_params\": galaxy_params,\n", " }\n", " true_catalog = TileCatalog(true_catalog_dict)\n", @@ -163,8 +162,7 @@ " images, psf_params = decoder.render_images(true_catalog)\n", "\n", " # one band (without using CachedDataset + OneBandTransform for simplicity)\n", - " true_catalog[\"star_fluxes\"] = true_catalog[\"star_fluxes\"][..., 2:3]\n", - " true_catalog[\"galaxy_fluxes\"] = true_catalog[\"galaxy_fluxes\"][..., 2:3]\n", + " true_catalog[\"fluxes\"] = true_catalog[\"fluxes\"][..., 2:3]\n", "\n", " batch = {\n", " \"images\": images[:, 2:3].cuda(),\n", @@ -631,6 +629,13 @@ "trainer.predict(encoder, dataloaders=[data_source.test_dataloader()], return_predictions=False)\n", "nll_callback.report()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/case_studies/spatial_tiling/toy_example/config.yaml b/case_studies/spatial_tiling/toy_example/config.yaml deleted file mode 100644 index 00f62367d..000000000 --- a/case_studies/spatial_tiling/toy_example/config.yaml +++ /dev/null @@ -1,132 +0,0 @@ ---- -defaults: - - ../../../bliss/conf@_here_: base_config - - _self_ - - override hydra/job_logging: stdout - -paths: - cached_data: /data/scratch/regier/toy_example_10_percent - -# this prior is sdss-like, except for the flux distribution, which is easier -# (i.e, flatter, always greater than 3 nmgy, and always less than 100 nmgy). -# the source density is higher too. -# greatly increasing the source density too here -prior: - # not exactly the mean because the max sources per tile is - # clipped at 1, but close - mean_sources: 0.1 # expect 36 sources per 80x80 image - # truncpareto support is [scale + loc, truncation * scale + loc] - star_flux: - exponent: 0.1 - runcation: 100 - loc: 0 - scale: 1.3 - galaxy_flux: - exponent: 0.01 - truncation: 100 - loc: 0 - scale: 1.3 - -variational_factors: - - _target_: bliss.encoder.variational_dist.BernoulliFactor - name: n_sources - sample_rearrange: null - nll_rearrange: null - nll_gating: null - - _target_: bliss.encoder.variational_dist.TDBNFactor - name: locs - sample_rearrange: "b ht wt d -> b ht wt 1 d" - nll_rearrange: "b ht wt 1 d -> b ht wt d" - nll_gating: - _target_: bliss.encoder.variational_dist.SourcesGating - - _target_: bliss.encoder.variational_dist.BernoulliFactor - name: source_type - sample_rearrange: "b ht wt -> b ht wt 1 1" - nll_rearrange: "b ht wt 1 1 -> b ht wt" - nll_gating: - _target_: bliss.encoder.variational_dist.SourcesGating - - _target_: bliss.encoder.variational_dist.LogNormalFactor - name: star_fluxes - dim: 1 - sample_rearrange: "b ht wt d -> b ht wt 1 d" - nll_rearrange: "b ht wt 1 d -> b ht wt d" - nll_gating: - _target_: bliss.encoder.variational_dist.StarGating - - _target_: bliss.encoder.variational_dist.LogNormalFactor - name: galaxy_fluxes - dim: 1 - sample_rearrange: "b ht wt d -> b ht wt 1 d" - nll_rearrange: "b ht wt 1 d -> b ht wt d" - nll_gating: - _target_: bliss.encoder.variational_dist.GalaxyGating - - _target_: bliss.encoder.variational_dist.LogitNormalFactor - name: galaxy_disk_frac - sample_rearrange: "b ht wt d -> b ht wt 1 d" - nll_rearrange: "b ht wt 1 1 -> b ht wt 1" - nll_gating: - _target_: bliss.encoder.variational_dist.GalaxyGating - - _target_: bliss.encoder.variational_dist.LogitNormalFactor - name: galaxy_beta_radians - high: 3.1415926 - sample_rearrange: "b ht wt d -> b ht wt 1 d" - nll_rearrange: "b ht wt 1 1 -> b ht wt 1" - nll_gating: - _target_: bliss.encoder.variational_dist.GalaxyGating - - _target_: bliss.encoder.variational_dist.LogitNormalFactor - name: galaxy_disk_q - sample_rearrange: "b ht wt d -> b ht wt 1 d" - nll_rearrange: "b ht wt 1 1 -> b ht wt 1" - nll_gating: - _target_: bliss.encoder.variational_dist.GalaxyGating - - _target_: bliss.encoder.variational_dist.LogNormalFactor - name: galaxy_a_d - sample_rearrange: "b ht wt d -> b ht wt 1 d" - nll_rearrange: "b ht wt 1 1 -> b ht wt 1" - nll_gating: - _target_: bliss.encoder.variational_dist.GalaxyGating - - _target_: bliss.encoder.variational_dist.LogitNormalFactor - name: galaxy_bulge_q - sample_rearrange: "b ht wt d -> b ht wt 1 d" - nll_rearrange: "b ht wt 1 1 -> b ht wt 1" - nll_gating: - _target_: bliss.encoder.variational_dist.GalaxyGating - - _target_: bliss.encoder.variational_dist.LogNormalFactor - name: galaxy_a_b - sample_rearrange: "b ht wt d -> b ht wt 1 d" - nll_rearrange: "b ht wt 1 1 -> b ht wt 1" - nll_gating: - _target_: bliss.encoder.variational_dist.GalaxyGating - -my_metrics: - detection_performance: - _target_: bliss.encoder.metrics.DetectionPerformance - bin_cutoffs: [19, 19.4, 19.8, 20.2, 20.6, 21, 21.4, 21.8] - bin_type: "mag" - ref_band: 0 - -encoder: - survey_bands: ['r'] - reference_band: 0 - use_checkerboard: true - matcher: - _target_: bliss.encoder.metrics.CatalogMatcher - dist_slack: 1.0 - mag_slack: null - mag_band: 2 # SDSS r-band - mode_metrics: - _target_: torchmetrics.MetricCollection - _convert_: partial - metrics: ${my_metrics} - sample_metrics: - _target_: torchmetrics.MetricCollection - _convert_: partial - metrics: ${my_metrics} - -cached_simulator: - train_transforms: - - _target_: bliss.cached_dataset.OneBandTransform - band_idx: 2 - - _target_: bliss.data_augmentation.RotateFlipTransform - nontrain_transforms: - - _target_: bliss.cached_dataset.OneBandTransform - band_idx: 2 diff --git a/tests/data/base_config_trained_encoder.pt b/tests/data/base_config_trained_encoder.pt index 3669f64e4..cf841b993 100644 --- a/tests/data/base_config_trained_encoder.pt +++ b/tests/data/base_config_trained_encoder.pt @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:7d777b4a5d04b17cea412c2a68c29c0937f879a0bcdd8cd216dda997a18cf7ef -size 28984282 +oid sha256:ede1fb9c0d8d47eb6c2415870a58dbd5e9acf5f7f1e9c2fd8eafb30792577e9e +size 28973978 diff --git a/tests/data/multiband_data/dataset_0.pt b/tests/data/multiband_data/dataset_0.pt index ac8bcb506..2828b9e67 100644 --- a/tests/data/multiband_data/dataset_0.pt +++ b/tests/data/multiband_data/dataset_0.pt @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ff3262e196cb9653bac9596719bc3d4c0be5e61342b75d2fe7f2c45d776e11a8 -size 142599 +oid sha256:b532e6fcbede153c873fa3b0e16e6621ab942ac51997d8c9d6db65b46d06347f +size 64910 diff --git a/tests/data/multiband_data/dataset_1.pt b/tests/data/multiband_data/dataset_1.pt deleted file mode 100644 index 552660f23..000000000 --- a/tests/data/multiband_data/dataset_1.pt +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:82dfd12399d5980cba6891cf3c82df13e859ad689d00678cbe2b867a2a0c730d -size 142599 diff --git a/tests/data/sdss_preds.pt b/tests/data/sdss_preds.pt index 4d2e19b4a..a3b3d363e 100644 --- a/tests/data/sdss_preds.pt +++ b/tests/data/sdss_preds.pt @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:f9c91766419456d428034c309da8582c26296d0aac3faba21316d2abe45ee528 -size 129804 +oid sha256:c3ecdd205f58768fedb6d516c895f1043eb7dd40dce4271919eb7024fd1ea7fe +size 100839 diff --git a/tests/data/test_image/dataset_0.pt b/tests/data/test_image/dataset_0.pt index 44a83a660..00a87e889 100644 --- a/tests/data/test_image/dataset_0.pt +++ b/tests/data/test_image/dataset_0.pt @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:f01e3bb3d4bc39621e2c12a7d3120d0d3b518afaa8e9357a49eeffad7f649ad8 -size 53415 +oid sha256:10518a0d1f4717114d7ac04885d726d01cd27a79211cded729e2757507bb0720 +size 37534 diff --git a/tests/test_catalogs.py b/tests/test_catalogs.py index 1558b401a..e0a00a577 100644 --- a/tests/test_catalogs.py +++ b/tests/test_catalogs.py @@ -15,8 +15,7 @@ def basic_tilecat(): "locs": torch.zeros(1, 2, 2, 1, 2), "source_type": torch.ones((1, 2, 2, 1, 1)).bool(), "galaxy_params": torch.zeros((1, 2, 2, 1, 6)), - "star_fluxes": torch.zeros((1, 2, 2, 1, 5)), - "galaxy_fluxes": torch.zeros(1, 2, 2, 1, 5), + "fluxes": torch.zeros((1, 2, 2, 1, 5)), } d["locs"][0, 0, 0, 0] = torch.tensor([0.5, 0.5]) d["locs"][0, 0, 1, 0] = torch.tensor([0.5, 0.5]) @@ -32,13 +31,12 @@ def multi_source_tilecat(): "locs": torch.zeros(1, 2, 2, 2, 2), "source_type": torch.ones((1, 2, 2, 2, 1)).bool(), "galaxy_params": torch.zeros((1, 2, 2, 2, 6)), - "star_fluxes": torch.zeros((1, 2, 2, 2, 5)), - "galaxy_fluxes": torch.zeros(1, 2, 2, 2, 5), + "fluxes": torch.zeros(1, 2, 2, 2, 5), } - d["galaxy_fluxes"][0, 0, 0, :, 2] = torch.tensor([1000, 500]) - d["galaxy_fluxes"][0, 0, 1, :, 2] = torch.tensor([10000, 200]) - d["galaxy_fluxes"][0, 1, 0, :, 2] = torch.tensor([0, 800]) - d["galaxy_fluxes"][0, 1, 1, :, 2] = torch.tensor([300, 600]) + d["fluxes"][0, 0, 0, :, 2] = torch.tensor([1000, 500]) + d["fluxes"][0, 0, 1, :, 2] = torch.tensor([10000, 200]) + d["fluxes"][0, 1, 0, :, 2] = torch.tensor([0, 800]) + d["fluxes"][0, 1, 1, :, 2] = torch.tensor([300, 600]) return TileCatalog(d) @@ -49,8 +47,7 @@ def multi_source_fullcat(): "n_sources": torch.tensor([2, 3, 1]), "plocs": torch.zeros((3, 3, 2)), "source_type": torch.ones((3, 3, 1)).bool(), - "star_fluxes": torch.zeros((3, 3, 6)), - "galaxy_fluxes": torch.zeros(3, 3, 6), + "fluxes": torch.zeros(3, 3, 6), } d["plocs"][0, 0, :] = torch.tensor([300, 600]) @@ -60,9 +57,9 @@ def multi_source_fullcat(): d["plocs"][1, 2, :] = torch.tensor([999, 998]) d["plocs"][2, 0, :] = torch.tensor([1999, 1977]) - d["galaxy_fluxes"][0, :, 2] = torch.tensor([1000, 500, 0]) - d["galaxy_fluxes"][1, :, 2] = torch.tensor([10000, 545, 123]) - d["galaxy_fluxes"][2, :, 2] = torch.tensor([124, 0, 0]) + d["fluxes"][0, :, 2] = torch.tensor([1000, 500, 0]) + d["fluxes"][1, :, 2] = torch.tensor([10000, 545, 123]) + d["fluxes"][2, :, 2] = torch.tensor([124, 0, 0]) return FullCatalog(2000, 2000, d) @@ -90,8 +87,8 @@ def test_restrict_tile_cat_to_brightest(self, multi_source_tilecat): assert cat.max_sources == 1 assert cat["n_sources"].max() == 1 assert cat["n_sources"].sum() == 3 - assert cat["galaxy_fluxes"].sum() == 11600.0 - assert cat["galaxy_fluxes"].max() == 10000.0 + assert cat["fluxes"].sum() == 11600.0 + assert cat["fluxes"].max() == 10000.0 # do it again to make sure nothing changes assert cat.get_brightest_sources_per_tile(band=2).max_sources == 1 @@ -100,7 +97,7 @@ def test_filter_tile_cat_by_flux(self, multi_source_tilecat): cat = multi_source_tilecat.filter_by_flux(300) assert cat.max_sources == 2 assert cat["n_sources"].sum() == 4 - r_band_flux = cat["galaxy_fluxes"][..., 2:3] + r_band_flux = cat["fluxes"][..., 2:3] r_band_flux = torch.where(cat.galaxy_bools, r_band_flux, torch.inf) assert r_band_flux.min().item() == 500 @@ -154,10 +151,10 @@ def test_filter_full_catalog_by_ploc_box(self, multi_source_fullcat): assert torch.allclose(cat["plocs"][0, 0, :], torch.tensor([300.0, 600.0])) assert torch.allclose(cat["plocs"][1, 0, :], torch.tensor([730.0, 73.0])) assert torch.allclose(cat["plocs"][1, 1, :], torch.tensor([999.0, 998.0])) - assert cat["galaxy_fluxes"].shape[1] == 2 - assert torch.allclose(cat["galaxy_fluxes"][0, :, 2], torch.tensor([1000.0, 500.0])) - assert torch.allclose(cat["galaxy_fluxes"][1, :, 2], torch.tensor([10000.0, 123.0])) - assert torch.allclose(cat["galaxy_fluxes"][2, :, 2], torch.tensor([124.0, 0.0])) + assert cat["fluxes"].shape[1] == 2 + assert torch.allclose(cat["fluxes"][0, :, 2], torch.tensor([1000.0, 500.0])) + assert torch.allclose(cat["fluxes"][1, :, 2], torch.tensor([10000.0, 123.0])) + assert torch.allclose(cat["fluxes"][2, :, 2], torch.tensor([124.0, 0.0])) def test_tile_full_round_trip(self, cfg): with open(Path(cfg.paths.test_data) / "sdss_preds.pt", "rb") as f: diff --git a/tests/test_dc2.py b/tests/test_dc2.py index 29839e816..a477af8e7 100644 --- a/tests/test_dc2.py +++ b/tests/test_dc2.py @@ -156,8 +156,7 @@ def test_dc2_size_and_type(self, cfg): "locs", "n_sources", "source_type", - "galaxy_fluxes", - "star_fluxes", + "fluxes", ) for k in params: @@ -200,24 +199,16 @@ def test_train_on_dc2(self, cfg): }, { "_target_": "bliss.encoder.variational_dist.LogNormalFactor", - "name": "star_fluxes", + "name": "fluxes", "dim": 6, "sample_rearrange": "b ht wt d -> b ht wt 1 d", "nll_rearrange": "b ht wt 1 d -> b ht wt d", - "nll_gating": "is_star", - }, - { - "_target_": "bliss.encoder.variational_dist.LogNormalFactor", - "name": "galaxy_fluxes", - "dim": 6, - "sample_rearrange": "b ht wt d -> b ht wt 1 d", - "nll_rearrange": "b ht wt 1 d -> b ht wt d", - "nll_gating": "is_galaxy", + "nll_gating": "n_sources", }, ] for f in cfg.variational_factors: - if f.name in {"star_fluxes", "galaxy_fluxes"}: + if f.name == "fluxes": f.dim = 6 train(cfg.train) diff --git a/tests/test_main.py b/tests/test_main.py index f55c467b4..59114b9b0 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -65,7 +65,7 @@ def test_train_des(self, cfg, tmp_path): cfg.prior.survey_bands = DES.BANDS for f in cfg.variational_factors: - if f.name in {"star_fluxes", "galaxy_fluxes"}: + if f.name == "fluxes": f.dim = 4 cfg.encoder.survey_bands = DES.BANDS diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 1742561b5..ea9149d3e 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -29,8 +29,7 @@ def test_metrics(self): "n_sources": torch.tensor([1, 2]), "plocs": true_locs * slen, "source_type": true_source_type, - "star_fluxes": torch.ones(2, 2, 5), - "galaxy_fluxes": torch.ones(2, 2, 5), + "fluxes": torch.ones(2, 2, 5), "galaxy_params": torch.ones(2, 2, 6), } true_params = FullCatalog(slen, slen, d_true) @@ -39,8 +38,7 @@ def test_metrics(self): "n_sources": torch.tensor([2, 2]), "plocs": est_locs * slen, "source_type": est_source_type, - "star_fluxes": torch.ones(2, 2, 5), - "galaxy_fluxes": torch.ones(2, 2, 5), + "fluxes": torch.ones(2, 2, 5), "galaxy_disk_frac": torch.ones(2, 2, 1), "galaxy_beta_radians": torch.ones(2, 2, 1), "galaxy_disk_q": torch.ones(2, 2, 1), @@ -85,8 +83,7 @@ def test_no_sources(self): "n_sources": true_sources, "plocs": true_locs, "source_type": true_source_type, - "star_fluxes": torch.ones(4, 1, 5), - "galaxy_fluxes": torch.ones(4, 1, 5), + "fluxes": torch.ones(4, 1, 5), "galaxy_params": torch.ones(4, 1, 6), } true_params = FullCatalog(50, 50, d_true) @@ -95,8 +92,7 @@ def test_no_sources(self): "n_sources": est_sources, "plocs": est_locs, "source_type": est_source_type, - "star_fluxes": torch.ones(4, 1, 5), - "galaxy_fluxes": torch.ones(4, 1, 5), + "fluxes": torch.ones(4, 1, 5), "galaxy_params": torch.ones(4, 1, 6), } est_params = FullCatalog(50, 50, d_est) diff --git a/tests/test_simulator.py b/tests/test_simulator.py index d9c4c476a..de312f22c 100644 --- a/tests/test_simulator.py +++ b/tests/test_simulator.py @@ -25,7 +25,7 @@ def test_simulate_and_predict(self, cfg): """Test simulating an image from a fixed catalog and making predictions on that catalog.""" # load cached simulated catalog true_catalog = torch.load(cfg.paths.test_data + "/test_image/dataset_0.pt") - true_catalog["star_fluxes"][0, 10, 10] = 10.0 + true_catalog["fluxes"][0, 10, 10] = 10.0 true_catalog = TileCatalog(true_catalog) # simulate image from catalog @@ -57,17 +57,10 @@ def test_simulate_and_predict(self, cfg): assert torch.equal(true_catalog.star_bools, mode_cat.star_bools) # Compare predicted and true fluxes - true_star_fluxes = true_catalog["star_fluxes"] * true_catalog.star_bools - true_galaxy_fluxes = true_catalog["galaxy_fluxes"] * true_catalog.galaxy_bools - true_fluxes = true_star_fluxes + true_galaxy_fluxes - true_fluxes_crop = true_fluxes[0, :, :, 0, 2] + true_fluxes = true_catalog.on_fluxes[0, :, :, 0, 2] + est_fluxes = mode_cat.on_fluxes[0, :, :, 0, 2] - est_star_fluxes = mode_cat["star_fluxes"] * mode_cat.star_bools - est_galaxy_fluxes = mode_cat["galaxy_fluxes"] * mode_cat.galaxy_bools - est_fluxes = est_star_fluxes + est_galaxy_fluxes - est_fluxes = est_fluxes[0, :, :, 0, 2] - - assert (est_fluxes - true_fluxes_crop).abs().sum() / (true_fluxes_crop.abs().sum()) < 1.0 + assert (est_fluxes - true_fluxes).abs().sum() / (true_fluxes.abs().sum()) < 1.0 def test_render_images(self, cfg): with open(Path(cfg.paths.test_data) / "sdss_preds.pt", "rb") as f: