Skip to content

Commit

Permalink
shared fluxes
Browse files Browse the repository at this point in the history
  • Loading branch information
jeff-regier committed Sep 11, 2024
1 parent e4f7ec5 commit c922258
Show file tree
Hide file tree
Showing 34 changed files with 184 additions and 455 deletions.
18 changes: 2 additions & 16 deletions bliss/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,7 @@ def galaxy_bools(self) -> Tensor:

@property
def on_fluxes(self):
# TODO: a tile catalog should store fluxes rather than star_fluxes and galaxy_fluxes
# because that's all that's needed to render the source
if "galaxy_fluxes" not in self:
fluxes = self["star_fluxes"]
else:
fluxes = torch.where(self.galaxy_bools, self["galaxy_fluxes"], self["star_fluxes"])
return torch.where(self.is_on_mask[..., None], fluxes, torch.zeros_like(fluxes))
return self.is_on_mask.unsqueeze(-1) * self["fluxes"]

def on_magnitudes(self, zero_point) -> Tensor:
return convert_flux_to_magnitude(self.on_fluxes, zero_point)
Expand Down Expand Up @@ -470,12 +464,7 @@ def galaxy_bools(self) -> Tensor:

@property
def on_fluxes(self) -> Tensor:
# ideally we'd always store fluxes rather than star_fluxes and galaxy_fluxes
if "galaxy_fluxes" not in self:
fluxes = self["star_fluxes"]
else:
fluxes = torch.where(self.galaxy_bools, self["galaxy_fluxes"], self["star_fluxes"])
return torch.where(self.is_on_mask[..., None], fluxes, torch.zeros_like(fluxes))
return self.is_on_mask.unsqueeze(-1) * self["fluxes"]

def on_magnitudes(self, zero_point) -> Tensor:
return convert_flux_to_magnitude(self.on_fluxes, zero_point)
Expand Down Expand Up @@ -571,9 +560,6 @@ def to_tile_catalog(
"""
assert max_sources_per_tile <= torch.iinfo(inter_int_type).max

# TODO: a FullCatalog only needs to "know" its height and width to convert itself to a
# TileCatalog. So those parameters should be passed on conversion, not initialization.

# initialization #
n_tiles_h = math.ceil(self.height / tile_slen)
n_tiles_w = math.ceil(self.width / tile_slen)
Expand Down
11 changes: 2 additions & 9 deletions bliss/conf/base_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -82,19 +82,12 @@ variational_factors:
nll_gating:
_target_: bliss.encoder.variational_dist.SourcesGating
- _target_: bliss.encoder.variational_dist.LogNormalFactor
name: star_fluxes
name: fluxes
dim: 5
sample_rearrange: b ht wt d -> b ht wt 1 d
nll_rearrange: b ht wt 1 d -> b ht wt d
nll_gating:
_target_: bliss.encoder.variational_dist.StarGating
- _target_: bliss.encoder.variational_dist.LogNormalFactor
name: galaxy_fluxes
dim: 5
sample_rearrange: b ht wt d -> b ht wt 1 d
nll_rearrange: b ht wt 1 d -> b ht wt d
nll_gating:
_target_: bliss.encoder.variational_dist.GalaxyGating
_target_: bliss.encoder.variational_dist.SourcesGating
- _target_: bliss.encoder.variational_dist.LogitNormalFactor
name: galaxy_disk_frac
sample_rearrange: b ht wt d -> b ht wt 1 d
Expand Down
4 changes: 2 additions & 2 deletions bliss/encoder/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ def get_cur_postfix(self):

class NullFilter(CatFilter):
def get_cur_filter_bools(self, true_cat, est_cat):
true_filter_bools = torch.ones_like(true_cat.star_bools.squeeze(2)).bool()
est_filter_bools = torch.ones_like(est_cat.star_bools.squeeze(2)).bool()
true_filter_bools = torch.ones_like(true_cat.is_on_mask)
est_filter_bools = torch.ones_like(est_cat.is_on_mask)

return true_filter_bools, est_filter_bools

Expand Down
10 changes: 4 additions & 6 deletions bliss/simulator/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def render_star(self, psf, band, source_params):
Returns:
GSObject: a galsim representation of the rendered star convolved with the PSF
"""
return psf[band].withFlux(source_params["star_fluxes"][band].item())
return psf[band].withFlux(source_params["fluxes"][band].item())

def render_galaxy(self, psf, band, source_params):
"""Render a galaxy with given params and PSF.
Expand All @@ -62,9 +62,9 @@ def render_galaxy(self, psf, band, source_params):
Returns:
GSObject: a galsim representation of the rendered galaxy convolved with the PSF
"""
disk_flux = source_params["galaxy_fluxes"][band] * source_params["galaxy_disk_frac"]
disk_flux = source_params["fluxes"][band] * source_params["galaxy_disk_frac"]
bulge_frac = 1 - source_params["galaxy_disk_frac"]
bulge_flux = source_params["galaxy_fluxes"][band] * bulge_frac
bulge_flux = source_params["fluxes"][band] * bulge_frac
beta = source_params["galaxy_beta_radians"] * galsim.radians

components = []
Expand Down Expand Up @@ -153,9 +153,7 @@ def render_image(self, tile_cat):
# use the specified flux_calibration ratios indexed by image_id
avg_nelec_conv = np.mean(frame["flux_calibration"], axis=-1)
if n_sources > 0:
full_cat["star_fluxes"] *= rearrange(avg_nelec_conv, "bands -> 1 1 bands")
if "galaxy_fluxes" in tile_cat:
full_cat["galaxy_fluxes"] *= avg_nelec_conv
full_cat["fluxes"] *= rearrange(avg_nelec_conv, "bands -> 1 1 bands")

# generate random WCS shifts as manual image dithering via unaligning WCS
if self.with_dither:
Expand Down
6 changes: 4 additions & 2 deletions bliss/simulator/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,10 @@ def sample(self) -> TileCatalog:
d["locs"] = self._sample_locs()
d["n_sources"] = self._sample_n_sources()
d["source_type"] = self._sample_source_type()
d["star_fluxes"] = self._sample_fluxes(self.gmm_star, self.star_flux)
d["galaxy_fluxes"] = self._sample_fluxes(self.gmm_gal, self.galaxy_flux)

star_fluxes = self._sample_fluxes(self.gmm_star, self.star_flux)
galaxy_fluxes = self._sample_fluxes(self.gmm_gal, self.galaxy_flux)
d["fluxes"] = torch.where(d["source_type"], galaxy_fluxes, star_fluxes)

Check warning on line 108 in bliss/simulator/prior.py

View check run for this annotation

Codecov / codecov/patch

bliss/simulator/prior.py#L106-L108

Added lines #L106 - L108 were not covered by tests

return TileCatalog(d)

Expand Down
10 changes: 3 additions & 7 deletions bliss/surveys/dc2.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,7 @@ def generate_cached_data(self, image_index):
"locs",
"n_sources",
"source_type",
"galaxy_fluxes",
"star_fluxes",
"fluxes",
"redshifts",
"blendedness",
"shear",
Expand Down Expand Up @@ -333,9 +332,7 @@ def from_file(cls, cat_path, wcs, height, width, **kwargs):
source_type = torch.from_numpy(catalog["truth_type"].values)
# we ignore the supernova
source_type = torch.where(source_type == 2, SourceType.STAR, SourceType.GALAXY)
flux, psf_params = cls.get_bands_flux_and_psf(kwargs["bands"], catalog)
star_fluxes = flux
galaxy_fluxes = flux
fluxes, psf_params = cls.get_bands_flux_and_psf(kwargs["bands"], catalog)
blendedness = torch.from_numpy(catalog["blendedness"].values)
shear1 = torch.from_numpy(catalog["shear_1"].values)
shear2 = torch.from_numpy(catalog["shear_2"].values)
Expand All @@ -352,8 +349,7 @@ def from_file(cls, cat_path, wcs, height, width, **kwargs):
"source_type": source_type.view(1, ori_len, 1),
"plocs": plocs.view(1, ori_len, 2),
"redshifts": redshifts.view(1, ori_len, 1),
"galaxy_fluxes": galaxy_fluxes.view(1, ori_len, kwargs["n_bands"]),
"star_fluxes": star_fluxes.view(1, ori_len, kwargs["n_bands"]),
"fluxes": fluxes.view(1, ori_len, kwargs["n_bands"]),
"blendedness": blendedness.view(1, ori_len, 1),
"shear": shear.view(1, ori_len, 2),
"ellipticity": ellipticity.view(1, ori_len, 2),
Expand Down
9 changes: 2 additions & 7 deletions bliss/surveys/sdss.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,14 +366,11 @@ def from_file(cls, cat_path, wcs: WCS, height, width):
decs = column_to_tensor(table, "dec")
galaxy_bools = (objc_type == 3) & (thing_id != -1)
star_bools = (objc_type == 6) & (thing_id != -1)
star_fluxes = column_to_tensor(table, "psfflux") * star_bools.reshape(-1, 1)
star_mags = column_to_tensor(table, "psfmag") * star_bools.reshape(-1, 1)
galaxy_fluxes = column_to_tensor(table, "cmodelflux") * galaxy_bools.reshape(-1, 1)
galaxy_mags = column_to_tensor(table, "cmodelmag") * galaxy_bools.reshape(-1, 1)

# Combine light source parameters to one tensor
star_fluxes = column_to_tensor(table, "psfflux") * star_bools.reshape(-1, 1)
galaxy_fluxes = column_to_tensor(table, "cmodelflux") * galaxy_bools.reshape(-1, 1)
fluxes = star_fluxes + galaxy_fluxes
mags = star_mags + galaxy_mags

# true light source mask
keep = galaxy_bools | star_bools
Expand All @@ -383,7 +380,6 @@ def from_file(cls, cat_path, wcs: WCS, height, width):
ras = ras[keep]
decs = decs[keep]
fluxes = fluxes[keep]
mags = mags[keep]
nobj = ras.shape[0]

# We require all 5 bands for computing loss on predictions.
Expand All @@ -401,7 +397,6 @@ def from_file(cls, cat_path, wcs: WCS, height, width):
"n_sources": torch.tensor((nobj,)),
"source_type": source_type.reshape(1, nobj, 1),
"fluxes": fluxes.reshape(1, nobj, n_bands),
"mags": mags.reshape(1, nobj, n_bands),
"ra": ras.reshape(1, nobj, 1),
"dec": decs.reshape(1, nobj, 1),
}
Expand Down
91 changes: 19 additions & 72 deletions case_studies/dc2_cataloging/cataloging_exp.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -3458,10 +3458,8 @@
" self.ellipticity_sources = 0\n",
" self.ellipticity1_within_ci_num = 0\n",
" self.ellipticity2_within_ci_num = 0\n",
" self.star_flux_sources = 0\n",
" self.star_flux_within_ci_num = {band: 0 for band in bands}\n",
" self.galaxy_flux_sources = 0\n",
" self.galaxy_flux_within_ci_num = {band: 0 for band in bands}\n",
" self.flux_sources = 0\n",
" self.flux_within_ci_num = {band: 0 for band in bands}\n",
"\n",
" bliss_locs_list = []\n",
" bliss_locs_ci_lower_list = []\n",
Expand Down Expand Up @@ -3498,23 +3496,14 @@
" self.ellipticity1_within_ci_num += (ellipticity_within_ci[..., 0:1] & ellipticity_mask).sum().item()\n",
" self.ellipticity2_within_ci_num += (ellipticity_within_ci[..., 1:2] & ellipticity_mask).sum().item()\n",
"\n",
" star_mask = bliss_tile_cat.star_bools & target_tile_cat.star_bools[..., :bliss_m, :]\n",
" self.star_flux_sources += star_mask.sum().item()\n",
" bliss_star_flux_ci_lower = bliss_tile_cat[\"star_fluxes_ci_lower\"]\n",
" bliss_star_flux_ci_upper = bliss_tile_cat[\"star_fluxes_ci_upper\"]\n",
" target_star_flux = target_tile_cat[\"star_fluxes\"][..., :bliss_m, :]\n",
" star_flux_within_ci = (target_star_flux > bliss_star_flux_ci_lower) & (target_star_flux < bliss_star_flux_ci_upper)\n",
" mask = bliss_tile_cat.is_on_mask & target_tile_cat.is_on_mask[..., :bliss_m, :]\n",
" self.flux_sources += mask.sum().item()\n",
" bliss_flux_ci_lower = bliss_tile_cat[\"fluxes_ci_lower\"]\n",
" bliss_flux_ci_upper = bliss_tile_cat[\"fluxes_ci_upper\"]\n",
" target_flux = target_tile_cat[\"fluxes\"][..., :bliss_m, :]\n",
" flux_within_ci = (target_flux > bliss_flux_ci_lower) & (target_flux < bliss_flux_ci_upper)\n",
" for i, band in enumerate(bands):\n",
" self.star_flux_within_ci_num[band] += (star_flux_within_ci[..., i:(i + 1)] & star_mask).sum().item()\n",
"\n",
" galaxy_mask = bliss_tile_cat.galaxy_bools & target_tile_cat.galaxy_bools[..., :bliss_m, :]\n",
" self.galaxy_flux_sources += galaxy_mask.sum().item()\n",
" bliss_galaxy_flux_ci_lower = bliss_tile_cat[\"galaxy_fluxes_ci_lower\"]\n",
" bliss_galaxy_flux_ci_upper = bliss_tile_cat[\"galaxy_fluxes_ci_upper\"]\n",
" target_galaxy_flux = target_tile_cat[\"galaxy_fluxes\"][..., :bliss_m, :]\n",
" galaxy_flux_within_ci = (target_galaxy_flux > bliss_galaxy_flux_ci_lower) & (target_galaxy_flux < bliss_galaxy_flux_ci_upper)\n",
" for i, band in enumerate(bands):\n",
" self.galaxy_flux_within_ci_num[band] += (galaxy_flux_within_ci[..., i:(i + 1)] & galaxy_mask).sum().item()\n",
" self.flux_within_ci_num[band] += (flux_within_ci[..., i:(i + 1)] & mask).sum().item()\n",
"\n",
" rand_int = random.randint(0, len(bliss_locs_list) - 1)\n",
" bliss_locs = bliss_locs_list[rand_int]\n",
Expand Down Expand Up @@ -3546,15 +3535,11 @@
" print(f\"# ellipticity2 within ci: {self.ellipticity2_within_ci_num}\")\n",
" print(f\"ellipticity2 within ci: {self.ellipticity2_within_ci_num / self.ellipticity_sources: .4f}\")\n",
" print()\n",
" print(f\"# star sources: {self.star_flux_sources}\")\n",
" print(f\"# star sources: {self.flux_sources}\")\n",
" for band in bands:\n",
" print(f\"# {band} flux within ci: {self.star_flux_within_ci_num[band]}\")\n",
" print(f\"{band} flux within ci: {self.star_flux_within_ci_num[band] / self.star_flux_sources: .4f}\")\n",
" print(f\"# {band} flux within ci: {self.flux_within_ci_num[band]}\")\n",
" print(f\"{band} flux within ci: {self.flux_within_ci_num[band] / self.flux_sources: .4f}\")\n",
" print()\n",
" print(f\"# galaxy sources: {self.galaxy_flux_sources}\")\n",
" for band in bands:\n",
" print(f\"# {band} flux within ci: {self.galaxy_flux_within_ci_num[band]}\")\n",
" print(f\"{band} flux within ci: {self.galaxy_flux_within_ci_num[band] / self.galaxy_flux_sources: .4f}\")\n",
"\n",
" def plot(self):\n",
" fig1, ax1 = plt.subplots(1, 1, figsize=NoteBookPlottingParams.figsize)\n",
Expand Down Expand Up @@ -3736,20 +3721,18 @@
" locs1_vsbc_list.append(locs_vsbc[..., 0])\n",
" locs2_vsbc_list.append(locs_vsbc[..., 1])\n",
" \n",
" star_flux_vsbc = bliss_tile_cat[\"star_fluxes_vsbc\"]\n",
" galaxy_flux_vsbc = bliss_tile_cat[\"galaxy_fluxes_vsbc\"]\n",
" flux_vsbc = bliss_tile_cat[\"fluxes_vsbc\"]\n",
" for i, band in enumerate(bands):\n",
" self.star_flux_vsbc_dict[band].append(star_flux_vsbc[..., i])\n",
" self.galaxy_flux_vsbc_dict[band].append(galaxy_flux_vsbc[..., i])\n",
" self.flux_vsbc_dict[band].append(flux_vsbc[..., i])\n",
"\n",
" self.ellipticity1_vsbc = torch.cat(ellipticity1_vsbc_list, dim=0).flatten()\n",
" self.ellipticity2_vsbc = torch.cat(ellipticity2_vsbc_list, dim=0).flatten()\n",
" self.locs1_vsbc = torch.cat(locs1_vsbc_list, dim=0).flatten()\n",
" self.locs2_vsbc = torch.cat(locs2_vsbc_list, dim=0).flatten()\n",
" self.is_on_mask = torch.cat(is_on_mask_list, dim=0).flatten()\n",
"\n",
" for k, v in self.star_flux_vsbc_dict.items():\n",
" self.star_flux_vsbc_dict[k] = torch.cat(v, dim=0).flatten()\n",
" for k, v in self.flux_vsbc_dict.items():\n",
" self.flux_vsbc_dict[k] = torch.cat(v, dim=0).flatten()\n",
"\n",
" for k, v in self.galaxy_flux_vsbc_dict.items():\n",
" self.galaxy_flux_vsbc_dict[k] = torch.cat(v, dim=0).flatten()\n",
Expand Down Expand Up @@ -3808,7 +3791,7 @@
"\n",
" return fig1, fig2\n",
" \n",
" def _star_flux_plot(self):\n",
" def _flux_plot(self):\n",
" fig, axes = plt.subplots(3, 2, \n",
" figsize=(NoteBookPlottingParams.figsize[0] * 1.1, NoteBookPlottingParams.figsize[1] * 1.7), \n",
" sharex=\"col\", \n",
Expand Down Expand Up @@ -3842,50 +3825,14 @@
"\n",
" return fig\n",
" \n",
" def _galaxy_flux_plot(self):\n",
" fig, axes = plt.subplots(3, 2, \n",
" figsize=(NoteBookPlottingParams.figsize[0] * 1.1, NoteBookPlottingParams.figsize[1] * 1.7), \n",
" sharex=\"col\", \n",
" sharey=\"row\",\n",
" layout=\"constrained\")\n",
"\n",
" for i, band in enumerate(bands):\n",
" col_index = i % 2\n",
" row_index = i // 2\n",
" ax = axes[row_index, col_index]\n",
" galaxy_flux_vsbc = self.galaxy_flux_vsbc_dict[band]\n",
" galaxy_flux_vsbc = galaxy_flux_vsbc[self.is_on_mask & ~galaxy_flux_vsbc.isnan()]\n",
" ax.hist(galaxy_flux_vsbc, density=True, bins=50)\n",
" ax.axvline(0.5, color=\"red\")\n",
" ax.set_title(f\"{band} band\", fontsize=NoteBookPlottingParams.fontsize)\n",
" ax.grid(visible=False)\n",
"\n",
" axes[1, 0].set_ylabel(\"Density\", fontsize=NoteBookPlottingParams.fontsize)\n",
" axes[2, 0].set_xticks(np.linspace(0.0, 1.0, num=6))\n",
" axes[2, 0].set_xticklabels([f\"{i: .1f}\" for i in np.linspace(0.0, 1.0, num=6)])\n",
" axes[2, 0].tick_params(labelsize=NoteBookPlottingParams.fontsize * 0.7)\n",
" axes[2, 0].set_xlabel(\"$P_{galaxy\\ flux}$\", fontsize=NoteBookPlottingParams.fontsize)\n",
" axes[2, 1].set_xticks(np.linspace(0.0, 1.0, num=6))\n",
" axes[2, 1].set_xticklabels([f\"{i: .1f}\" for i in np.linspace(0.0, 1.0, num=6)])\n",
" axes[2, 1].tick_params(labelsize=NoteBookPlottingParams.fontsize * 0.7)\n",
" axes[2, 1].set_xlabel(\"$P_{galaxy\\ flux}$\", fontsize=NoteBookPlottingParams.fontsize)\n",
"\n",
" axes[0, 0].tick_params(labelsize=NoteBookPlottingParams.fontsize * 0.7)\n",
" axes[1, 0].tick_params(labelsize=NoteBookPlottingParams.fontsize * 0.7)\n",
" axes[2, 0].tick_params(labelsize=NoteBookPlottingParams.fontsize * 0.7)\n",
"\n",
" return fig\n",
"\n",
" def plot(self, plot_type: str):\n",
" match plot_type:\n",
" case \"ellipticity\":\n",
" return self._ellipticity_plot()\n",
" case \"locs\":\n",
" return self._locs_plot()\n",
" case \"star_flux\":\n",
" return self._star_flux_plot()\n",
" case \"galaxy_flux\":\n",
" return self._galaxy_flux_plot()\n",
" case \"flux\":\n",
" return self._flux_plot()\n",
" case _:\n",
" raise NotImplementedError()"
]
Expand Down
11 changes: 2 additions & 9 deletions case_studies/dc2_cataloging/train_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,12 @@ my_variational_factors:
nll_gating:
_target_: bliss.encoder.variational_dist.SourcesGating
- _target_: bliss.encoder.variational_dist.LogNormalFactor
name: star_fluxes
name: fluxes
dim: 6
sample_rearrange: b ht wt d -> b ht wt 1 d
nll_rearrange: b ht wt 1 d -> b ht wt d
nll_gating:
_target_: bliss.encoder.variational_dist.StarGating
- _target_: bliss.encoder.variational_dist.LogNormalFactor
name: galaxy_fluxes
dim: 6
sample_rearrange: b ht wt d -> b ht wt 1 d
nll_rearrange: b ht wt 1 d -> b ht wt d
nll_gating:
_target_: bliss.encoder.variational_dist.GalaxyGating
_target_: bliss.encoder.variational_dist.SourcesGating
- _target_: bliss.encoder.variational_dist.BivariateNormalFactor
name: ellipticity
sample_rearrange: b ht wt d -> b ht wt 1 d
Expand Down
3 changes: 1 addition & 2 deletions case_studies/dc2_cataloging/utils/load_lsst.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def get_lsst_full_cat(lsst_root_dir: str, cur_image_wcs, image_lim, r_band_min_f
"plocs": lsst_plocs.unsqueeze(0),
"n_sources": lsst_n_sources,
"source_type": lsst_source_type.unsqueeze(0),
"galaxy_fluxes": lsst_flux.unsqueeze(0),
"star_fluxes": lsst_flux.unsqueeze(0).clone(),
"fluxes": lsst_flux.unsqueeze(0),
},
)
3 changes: 1 addition & 2 deletions case_studies/dc2_cataloging/utils/lsst_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ def _predict_one_image(self, wcs_header_str, image_lim, height_index, width_inde
"plocs": lsst_plocs.unsqueeze(0),
"n_sources": lsst_n_sources,
"source_type": lsst_source_type.unsqueeze(0),
"galaxy_fluxes": lsst_flux.unsqueeze(0),
"star_fluxes": lsst_flux.unsqueeze(0).clone(),
"fluxes": lsst_flux.unsqueeze(0),
},
).to_tile_catalog(self.tile_slen, self.max_sources_per_tile)

Expand Down
Loading

0 comments on commit c922258

Please sign in to comment.