Skip to content

Commit

Permalink
Spatially-variant PSF/background reproducibility (#1053)
Browse files Browse the repository at this point in the history
* Add config and fix minor errors; remove old data and notebook

* Add binned flux metrics; update source type metrics to bin by magnitude

* Add new config files and flux bin notebook

* Update configs

* Add galaxy shape metrics

* Refactor evaluate_models notebook to python script; update train_models.sh with new configs and eval script

* Fix paths in configs

* Fix style checks

* Fix tests

* Add second field to sdss in base_config

* Add galaxy shape metrics to tests
  • Loading branch information
aakashdp6548 authored Aug 3, 2024
1 parent c0645a3 commit 419e75d
Show file tree
Hide file tree
Showing 20 changed files with 2,096 additions and 683 deletions.
10 changes: 8 additions & 2 deletions bliss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,14 @@
import omegaconf


def make_range(start, stop, step, *args):
    """Build an OmegaConf list from ``range(start, stop, step)`` minus exclusions.

    Args:
        start: first value of the range (inclusive).
        stop: end of the range (exclusive).
        step: stride between consecutive values.
        *args: values to leave out of the result; values that do not occur in
            the range are silently ignored.

    Returns:
        An ``omegaconf.ListConfig`` with the remaining values in range order.
    """
    # range() never repeats a value, so filtering on membership drops exactly
    # the entries that a per-argument list.remove() would drop.
    kept = [value for value in range(start, stop, step) if value not in args]
    return omegaconf.ListConfig(kept)


# resolve ranges in config files (we want this to execute at any entry point)
Expand Down
2 changes: 1 addition & 1 deletion bliss/cached_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __call__(self, datum_in):

class ChunkingSampler(Sampler):
def __init__(self, dataset: Dataset) -> None:
super().__init__()
super().__init__(dataset)
assert isinstance(dataset, ChunkingDataset), "dataset should be ChunkingDataset"
self.dataset = dataset

Expand Down
7 changes: 6 additions & 1 deletion bliss/conf/base_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ metrics:
flux_error:
_target_: bliss.encoder.metrics.FluxError
survey_bands: ${encoder.survey_bands}
bin_cutoffs: [23.9, 24.1, 24.5, 24.9, 25.6]
bin_type: "njymag"

image_normalizers:
psf:
Expand Down Expand Up @@ -207,7 +209,10 @@ surveys:
fields: # TODO: better arbitrary name for fields/bricks?
- run: 94
camcol: 1
fields: [12]
fields: [12] # can also use ${range:start,stop,step,*exclude}
- run: 3900
camcol: 6
fields: [269]
psf_config:
pixel_scale: 0.396
psf_slen: 25
Expand Down
223 changes: 191 additions & 32 deletions bliss/encoder/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def __init__(
self.bin_type = bin_type
self.exclude_last_bin = exclude_last_bin

assert self.bin_type in {"nmgy", "njymag"}, "invalid bin type"
assert self.bin_type in {"mag", "nmgy", "njymag"}, "invalid bin type"

detection_metrics = [
"n_true_sources",
Expand Down Expand Up @@ -343,14 +343,14 @@ def plot(self):
axes[0].step(
range(len(xlabels)),
n_true_sources.tolist(),
label=f"# true sources {fig_tag}",
label=f"Number of true sources {fig_tag}",
where="mid",
color=c4,
)
axes[0].step(
range(len(xlabels)),
n_true_matches.tolist(),
label=f"# BLISS matches {fig_tag}",
label=f"Number of BLISS matches {fig_tag}",
ls="--",
where="mid",
color=c4,
Expand Down Expand Up @@ -381,7 +381,7 @@ def __init__(
self.bin_type = bin_type

assert self.bin_cutoffs, "flux_bin_cutoffs can't be None or empty"
assert self.bin_type in {"nmgy", "njymag"}, "invalid bin type"
assert self.bin_type in {"mag", "nmgy", "njymag"}, "invalid bin type"

n_bins = len(self.bin_cutoffs) + 1

Expand Down Expand Up @@ -501,59 +501,218 @@ def get_internal_states(self):


class FluxError(Metric):
    """Accumulate flux errors of matched sources, per band and per bin.

    Matched true sources are assigned to bins by bucketizing the value
    returned by ``true_cat.on_fluxes(bin_type)`` in band ``ref_band`` against
    ``bin_cutoffs``. For every band three sums are tracked per bin: absolute
    error, signed percent error ``(true - est) / true`` and absolute percent
    error, all computed on the ``"nmgy"`` fluxes.
    """

    def __init__(
        self,
        survey_bands,
        bin_cutoffs: list,
        ref_band: int = 2,
        bin_type: str = "mag",
        exclude_last_bin: bool = False,
    ):
        """Register the per-band, per-bin accumulators.

        Args:
            survey_bands: list of band names (e.g. "r"); one row of state per band.
            bin_cutoffs: bin boundaries passed to ``torch.bucketize``; yields
                ``len(bin_cutoffs) + 1`` bins.
            ref_band: index of the band whose value decides the bin.
            bin_type: flux representation requested from ``on_fluxes`` for binning.
            exclude_last_bin: if True, ``compute`` drops the last bin.
        """
        super().__init__()
        self.survey_bands = survey_bands  # list of band names (e.g. "r")
        self.ref_band = ref_band
        self.bin_cutoffs = bin_cutoffs
        self.bin_type = bin_type
        self.exclude_last_bin = exclude_last_bin
        self.n_bins = len(self.bin_cutoffs) + 1

        # one running sum per (band, bin) for each error flavour
        for state_name in ("flux_abs_err", "flux_pct_err", "flux_abs_pct_err"):
            self.add_state(
                state_name,
                default=torch.zeros((len(self.survey_bands), self.n_bins)),
                dist_reduce_fx="sum",
            )
        # number of matched sources that fell into each bin
        self.add_state("n_matches", default=torch.zeros(self.n_bins), dist_reduce_fx="sum")

    def _sum_by_bin(self, bins, values):
        """Return a length-``n_bins`` float tensor with ``values`` summed per bin index."""
        totals = torch.zeros((self.n_bins,), dtype=torch.float, device=self.device)
        return totals.scatter_add(0, bins.reshape(-1), values.reshape(-1))

    def update(self, true_cat, est_cat, matching):
        """Add the flux errors of every matched pair in the batch to the state.

        Args:
            true_cat: ground-truth catalog; read via ``on_fluxes``,
                ``batch_size`` and ``["n_sources"]``.
            est_cat: estimated catalog; read via ``on_fluxes``.
            matching: per-image pairs ``(true_indices, est_indices)``.
        """
        cutoffs = torch.tensor(self.bin_cutoffs, device=self.device)
        true_bin_measures = true_cat.on_fluxes(self.bin_type)[:, :, self.ref_band].contiguous()

        for i in range(true_cat.batch_size):
            tcat_matches, ecat_matches = matching[i]

            # bin index of each matched true source
            n_true = true_cat["n_sources"][i].int().sum().item()
            bin_measure = true_bin_measures[i, 0:n_true][tcat_matches].contiguous()
            bins = torch.bucketize(bin_measure, cutoffs)

            true_flux = true_cat.on_fluxes("nmgy")[i, tcat_matches]
            est_flux = est_cat.on_fluxes("nmgy")[i, ecat_matches]

            residual = true_flux - est_flux
            pct_err = residual / true_flux
            per_source_errors = (
                (self.flux_abs_err, residual.abs()),
                (self.flux_pct_err, pct_err),
                (self.flux_abs_pct_err, pct_err.abs()),
            )

            # Compute and update the binned sums per band
            for band in range(len(self.survey_bands)):
                for state, err in per_source_errors:
                    state[band] += self._sum_by_bin(bins, err[..., band])
            self.n_matches += bins.bincount(minlength=self.n_bins)

    def compute(self):
        """Turn the accumulated sums into mean errors.

        Returns:
            dict with, for each band, ``flux_err_{band}_mae`` / ``_mpe`` /
            ``_mape`` (averaged over all retained bins) and the per-bin
            variants suffixed ``_bin_{j}``. Values are tensors.
        """
        final_idx = -1 if self.exclude_last_bin else None
        n_matches = self.n_matches[:final_idx]
        total_matches = n_matches.sum()

        # insertion order (mae, mpe, mape) fixes the order of the result keys
        sums = {
            "mae": self.flux_abs_err[:, :final_idx],
            "mpe": self.flux_pct_err[:, :final_idx],
            "mape": self.flux_abs_pct_err[:, :final_idx],
        }
        overall = {kind: table.sum(dim=1) / total_matches for kind, table in sums.items()}
        binned = {kind: table / n_matches for kind, table in sums.items()}

        results = {}
        for i, band in enumerate(self.survey_bands):
            for kind in sums:
                results[f"flux_err_{band}_{kind}"] = overall[kind][i]
            for j in range(n_matches.shape[0]):
                for kind in sums:
                    results[f"flux_err_{band}_{kind}_bin_{j}"] = binned[kind][i, j]

        return results


class GalaxyShapeError(Metric):
    """Accumulate absolute errors of galaxy shape parameters, per bin.

    Only matched sources whose *true* type is galaxy contribute. Each such
    galaxy is assigned to a bin by bucketizing the value returned by
    ``true_cat.on_fluxes(bin_type)`` in band ``ref_band`` against
    ``bin_cutoffs``. Besides the six parameters in ``galaxy_params``, the
    disk and bulge half-light radii ``sqrt(a * b)`` with ``b = a * q`` are
    tracked.
    """

    galaxy_params = [
        "galaxy_disk_frac",
        "galaxy_beta_radians",
        "galaxy_disk_q",
        "galaxy_a_d",
        "galaxy_bulge_q",
        "galaxy_a_b",
    ]
    # column of each parameter inside true_cat["galaxy_params"]
    galaxy_param_to_idx = {param: i for i, param in enumerate(galaxy_params)}

    def __init__(
        self,
        bin_cutoffs,
        ref_band=2,
        bin_type="mag",
        exclude_last_bin=False,
    ):
        """Register the per-bin accumulators.

        Args:
            bin_cutoffs: bin boundaries passed to ``torch.bucketize``; yields
                ``len(bin_cutoffs) + 1`` bins.
            ref_band: index of the band whose value decides the bin.
            bin_type: flux representation requested from ``on_fluxes`` for binning.
            exclude_last_bin: if True, ``compute`` drops the last bin.
        """
        super().__init__()

        self.ref_band = ref_band
        self.bin_cutoffs = bin_cutoffs
        self.n_bins = len(self.bin_cutoffs) + 1
        self.bin_type = bin_type
        self.exclude_last_bin = exclude_last_bin  # used to ignore dim objects

        gpe_init = torch.zeros((len(self.galaxy_params), self.n_bins))
        self.add_state("galaxy_param_err", default=gpe_init, dist_reduce_fx="sum")
        self.add_state("disk_hlr_err", torch.zeros(self.n_bins), dist_reduce_fx="sum")
        self.add_state("bulge_hlr_err", torch.zeros(self.n_bins), dist_reduce_fx="sum")
        self.add_state("n_true_galaxies", default=torch.zeros(self.n_bins), dist_reduce_fx="sum")

    def _sum_by_bin(self, bins, values):
        """Return a length-``n_bins`` float tensor with ``values`` summed per bin index."""
        totals = torch.zeros(self.n_bins, dtype=torch.float, device=self.device)
        return totals.scatter_add(0, bins, values)

    @staticmethod
    def _hlr(semi_major, axis_ratio):
        """Half-light radius ``sqrt(a * b)`` where ``b = a * q``."""
        semi_minor = semi_major * axis_ratio
        return torch.sqrt(semi_major * semi_minor)

    def update(self, true_cat, est_cat, matching):
        """Add the shape errors of every matched true galaxy to the state.

        Args:
            true_cat: ground-truth catalog; read via ``on_fluxes``,
                ``batch_size``, ``galaxy_bools``, ``["n_sources"]`` and
                ``["galaxy_params"]``.
            est_cat: estimated catalog; read via one key per entry of
                ``galaxy_params``.
            matching: per-image pairs ``(true_indices, est_indices)``.
        """
        true_bin_meas = true_cat.on_fluxes(self.bin_type)[:, :, self.ref_band].contiguous()
        cutoffs = torch.tensor(self.bin_cutoffs, device=self.device)
        idx = self.galaxy_param_to_idx

        for i in range(true_cat.batch_size):
            tcat_matches, ecat_matches = matching[i]
            # cast to int so the value is always usable as a slice bound,
            # matching how FluxError.update reads n_sources
            n_true = int(true_cat["n_sources"][i].int().sum().item())

            is_gal = true_cat.galaxy_bools[i][tcat_matches][:, 0]
            # Skip if no galaxies in this image
            if (~is_gal).all():
                continue
            true_matched_mags = true_bin_meas[i, 0:n_true][tcat_matches]
            true_gal_mags = true_matched_mags[is_gal]

            # get magnitude bin for each matched galaxy
            mag_bins = torch.bucketize(true_gal_mags, cutoffs)
            self.n_true_galaxies += mag_bins.bincount(minlength=self.n_bins)

            true_gal_params = true_cat["galaxy_params"][i, tcat_matches][is_gal]

            # estimated value of every parameter for the matched true galaxies
            est_params = {
                name: est_cat[name][i, ecat_matches][is_gal, 0] for name in self.galaxy_params
            }

            for j, name in enumerate(self.galaxy_params):
                abs_res = (true_gal_params[:, j] - est_params[name]).abs()

                # Wrap angle around pi
                # NOTE(review): `% pi` maps the residual into [0, pi) but does not
                # fold it to the shorter arc (min(r, pi - r)); confirm whether the
                # pi-periodic distance is intended before changing reported values.
                if name == "galaxy_beta_radians":
                    abs_res = abs_res % torch.pi

                # Update bins
                self.galaxy_param_err[j] += self._sum_by_bin(mag_bins, abs_res)

            # Compute HLRs for disk and bulge
            true_disk_hlr = self._hlr(
                true_gal_params[:, idx["galaxy_a_d"]],
                true_gal_params[:, idx["galaxy_disk_q"]],
            )
            est_disk_hlr = self._hlr(est_params["galaxy_a_d"], est_params["galaxy_disk_q"])

            true_bulge_hlr = self._hlr(
                true_gal_params[:, idx["galaxy_a_b"]],
                true_gal_params[:, idx["galaxy_bulge_q"]],
            )
            est_bulge_hlr = self._hlr(est_params["galaxy_a_b"], est_params["galaxy_bulge_q"])

            abs_disk_hlr_res = (true_disk_hlr - est_disk_hlr).abs()
            self.disk_hlr_err += self._sum_by_bin(mag_bins, abs_disk_hlr_res)

            abs_bulge_hlr_res = (true_bulge_hlr - est_bulge_hlr).abs()
            self.bulge_hlr_err += self._sum_by_bin(mag_bins, abs_bulge_hlr_res)

    def compute(self):
        """Turn the accumulated sums into mean absolute errors.

        Returns:
            dict with ``{param}_mae`` for every entry of ``galaxy_params``,
            ``galaxy_disk_hlr_mae`` and ``galaxy_bulge_hlr_mae`` (averaged
            over all retained bins), plus per-bin variants suffixed
            ``_bin_{j}``. Values are tensors.
        """
        final_idx = -1 if self.exclude_last_bin else None
        galaxy_param_err = self.galaxy_param_err[:, :final_idx]
        disk_hlr_err = self.disk_hlr_err[:final_idx]
        bulge_hlr_err = self.bulge_hlr_err[:final_idx]
        n_galaxies = self.n_true_galaxies[:final_idx]

        # Number of bins actually reported. This is n_bins - 1 when the last
        # bin is excluded, so iterating over self.n_bins would index past the
        # end of the sliced tensors.
        n_reported_bins = n_galaxies.shape[0]

        gal_param_mae = galaxy_param_err.sum(dim=1) / n_galaxies.sum()
        binned_gal_param_mae = galaxy_param_err / n_galaxies

        disk_hlr_mae = disk_hlr_err.sum() / n_galaxies.sum()
        binned_disk_hlr_mae = disk_hlr_err / n_galaxies
        bulge_hlr_mae = bulge_hlr_err.sum() / n_galaxies.sum()
        binned_bulge_hlr_mae = bulge_hlr_err / n_galaxies

        results = {}
        for i, name in enumerate(self.galaxy_params):
            results[f"{name}_mae"] = gal_param_mae[i]
            for j in range(n_reported_bins):
                results[f"{name}_mae_bin_{j}"] = binned_gal_param_mae[i, j]

        results["galaxy_disk_hlr_mae"] = disk_hlr_mae
        results["galaxy_bulge_hlr_mae"] = bulge_hlr_mae
        for j in range(n_reported_bins):
            results[f"galaxy_disk_hlr_mae_bin_{j}"] = binned_disk_hlr_mae[j]
            results[f"galaxy_bulge_hlr_mae_bin_{j}"] = binned_bulge_hlr_mae[j]

        return results
18 changes: 18 additions & 0 deletions case_studies/psf_variation/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
Spatially Varying Backgrounds and PSFs Case Study
============================================

This case study contains code to reproduce the results and figures for the paper "Neural Posterior Estimation for Cataloging Astronomical Images with Spatially Varying Backgrounds and Point Spread Functions".

To run the experiments, run
```
sh run_experiments.sh
```

This will generate two synthetic datasets, train each of the three models on its corresponding dataset, run evaluation, and generate figures similar to those in the paper. Note that the results may not match the paper exactly due to differences in the generated data.

There are three main config files, one for each model:
- `conf/single_field.yaml`
- `conf/psf_unaware.yaml`
- `conf/psf_aware.yaml`

You should modify the `paths/cached_data` entry in each file to point to where you would like to save the generated datasets.
Loading

0 comments on commit 419e75d

Please sign in to comment.