Weak lensing DC2 updates (#1061)
* WL: average ellipticity baseline estimator, lensing catalog generation script (#1058)

* Batched to_tile_catalog (#1057)

batched to_tile_catalog

Co-authored-by: Yicun Duan <[email protected]>

* Add lensed ellipticities to tile catalog; ellipticity vs shear notebook

* Organize notebooks

* Script to generate DC2 pkl file with shear, conv, ellipticity, redshift

* Use avg ellipticity as baseline shear estimator; fix MSE bugs; remove baseline convergence estimator

* Try-except block in cached_dataset

* Update catalog and splits paths in config, update plots dir name

---------

Co-authored-by: Yicun Duan <[email protected]>

* Merged new workflow changes

* Make lensed ellipticity compatible with new to_tile_catalog

* Remove unnecessary subclass of SimulatedDataset

* Add SDSS-like simulator capabilities to lensing_config

* Update image generation notebook under SDSS-like simulator

* Updated vardist and lensing to filter ellipticity

* Preliminary version of redshift notebook

* Fixed mag mask error in lensing_dc2

* Fixed ellip key set

* Small tweaks to lensing config and ellip/redshift notebooks

* New catalog merge strategy: notebook

* Update variable name for r-band magnitude after new merge strategy

* New catalog merge strategy: script

* Run dc2_generate_lensing_catalog in order

* Update splits path in config

* Squeeze tile dict in lensing_dc2 (like we did before)

* Don't use tract_filter when loading in object_with_truth_match

* Rerun ellipticity and redshift notebooks with new catalog

* Fixed vardist NullGating

* Remove old notebooks

* Remove try-except block from cached_dataset

* Remove print statements from lensing_encoder

* Pass flake8 check in catalog.py

* Attempted fix for error with NullGating

---------

Co-authored-by: Yicun Duan <[email protected]>
Co-authored-by: shreyasc <[email protected]>
4 people authored Aug 19, 2024
1 parent bc88cfe commit 41d09e1
Showing 26 changed files with 3,205 additions and 1,965 deletions.
5 changes: 3 additions & 2 deletions bliss/catalog.py
@@ -412,8 +412,9 @@ def __init__(self, height: int, width: int, d: Dict[str, Tensor]) -> None:
self.device = d["plocs"].device
self.batch_size, self.max_sources, hw = d["plocs"].shape
assert hw == 2
-        assert d["n_sources"].max().int().item() <= self.max_sources
-        assert d["n_sources"].shape == (self.batch_size,)
+        if "n_sources" in d:
+            assert d.get("n_sources").max().int().item() <= self.max_sources
+            assert d.get("n_sources").shape == (self.batch_size,)

super().__init__(**d)

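The relaxed check makes `n_sources` optional while still validating it when present (lensing tile catalogs carry per-tile fields but no source counts). A minimal numpy sketch of the guard; `validate` is a standalone illustration with the diff's shapes, not the real `TileCatalog` API:

```python
import numpy as np

def validate(d):
    # plocs has shape (batch_size, max_sources, 2); the last dim is (y, x).
    batch_size, max_sources, hw = d["plocs"].shape
    assert hw == 2
    # n_sources is now optional; only check consistency when it is supplied.
    if "n_sources" in d:
        assert d["n_sources"].max() <= max_sources
        assert d["n_sources"].shape == (batch_size,)
    return batch_size, max_sources

d = {"plocs": np.zeros((4, 10, 2))}       # no n_sources: accepted
print(validate(d))                        # -> (4, 10)
d["n_sources"] = np.array([3, 0, 10, 7])  # consistent n_sources: accepted
print(validate(d))                        # -> (4, 10)
```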
10 changes: 7 additions & 3 deletions bliss/encoder/variational_dist.py
@@ -52,7 +52,11 @@ def __call__(cls, true_tile_cat: TileCatalog):
class NullGating(NllGating):
@classmethod
def __call__(cls, true_tile_cat: TileCatalog):
-        return torch.ones_like(true_tile_cat["n_sources"]).bool()
+        tc_keys = true_tile_cat.keys()
+        if "n_sources" in tc_keys:
+            return torch.ones_like(true_tile_cat["n_sources"]).bool()
+        first = true_tile_cat[list(tc_keys)[0]]
+        return torch.ones(first.shape[:-1]).bool().to(first.device)

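For tile catalogs without `n_sources` (e.g. lensing catalogs carrying only per-tile shear/convergence), the patched `NullGating` derives the mask shape from any available entry by dropping the trailing per-feature dimension. A rough numpy sketch of that logic — the plain dict here is illustrative, not the real `TileCatalog`:

```python
import numpy as np

def null_gating_mask(tile_cat: dict) -> np.ndarray:
    # Prefer n_sources when present; its shape already matches the tile grid.
    if "n_sources" in tile_cat:
        return np.ones_like(tile_cat["n_sources"], dtype=bool)
    # Otherwise take any entry and drop the trailing feature dimension,
    # mirroring first.shape[:-1] in the patch.
    first = next(iter(tile_cat.values()))
    return np.ones(first.shape[:-1], dtype=bool)

cat = {"shear": np.zeros((2, 8, 8, 2))}  # (batch, tiles_h, tiles_w, 2)
print(null_gating_mask(cat).shape)       # (2, 8, 8)
```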

class SourcesGating(NllGating):
@@ -151,8 +155,8 @@ def __init__(self, *args, low_clamp=-20, high_clamp=20, **kwargs):
self.high_clamp = high_clamp

def get_dist(self, params):
-        mean = params[:, :, :, 0]
-        sd = params[:, :, :, 1].clamp(self.low_clamp, self.high_clamp).exp().sqrt()
+        mean = params[:, :, :, 0:1]
+        sd = params[:, :, :, 1:2].clamp(self.low_clamp, self.high_clamp).exp().sqrt()
return Normal(mean, sd)


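The `0` → `0:1` change swaps integer indexing for a length-1 slice, which keeps the trailing singleton dimension, so the `Normal`'s mean and sd stay four-dimensional like the raw parameter tensor (presumably so downstream code that expects a trailing feature axis broadcasts correctly). A numpy illustration of the shape difference:

```python
import numpy as np

params = np.zeros((1, 8, 8, 2))       # (batch, tiles_h, tiles_w, 2 params)
mean_squeezed = params[:, :, :, 0]    # integer index: trailing dim dropped
mean_kept = params[:, :, :, 0:1]      # slice: trailing dim kept
print(mean_squeezed.shape, mean_kept.shape)  # (1, 8, 8) (1, 8, 8, 1)
```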
151 changes: 151 additions & 0 deletions case_studies/weak_lensing/generate_dc2_lensing_catalog.py
@@ -0,0 +1,151 @@
import os
import pickle as pkl

import GCRCatalogs
import healpy as hp
import numpy as np
import pandas as pd
from GCRCatalogs import GCRQuery

GCRCatalogs.set_root_dir("/data/scratch/dc2_nfs/")

file_name = "dc2_lensing_catalog.pkl"
file_path = os.path.join("/data", "scratch", "dc2local", file_name)
file_already_populated = os.path.isfile(file_path)

if file_already_populated:
raise FileExistsError(f"{file_path} already exists.")


print("Loading object-with-truth-match...\n") # noqa: WPS421

object_truth_cat = GCRCatalogs.load_catalog("desc_dc2_run2.2i_dr6_object_with_truth_match")

object_truth_df = object_truth_cat.get_quantities(
quantities=[
"cosmodc2_id_truth",
"id_truth",
"objectId",
"match_objectId",
"truth_type",
"ra_truth",
"dec_truth",
"redshift_truth",
"flux_u_truth",
"flux_g_truth",
"flux_r_truth",
"flux_i_truth",
"flux_z_truth",
"flux_y_truth",
"mag_u_truth",
"mag_g_truth",
"mag_r_truth",
"mag_i_truth",
"mag_z_truth",
"mag_y_truth",
"Ixx_pixel",
"Iyy_pixel",
"Ixy_pixel",
"IxxPSF_pixel_u",
"IxxPSF_pixel_g",
"IxxPSF_pixel_r",
"IxxPSF_pixel_i",
"IxxPSF_pixel_z",
"IxxPSF_pixel_y",
"IyyPSF_pixel_u",
"IyyPSF_pixel_g",
"IyyPSF_pixel_r",
"IyyPSF_pixel_i",
"IyyPSF_pixel_z",
"IyyPSF_pixel_y",
"IxyPSF_pixel_u",
"IxyPSF_pixel_g",
"IxyPSF_pixel_r",
"IxyPSF_pixel_i",
"IxyPSF_pixel_z",
"IxyPSF_pixel_y",
"psf_fwhm_u",
"psf_fwhm_g",
"psf_fwhm_r",
"psf_fwhm_i",
"psf_fwhm_z",
"psf_fwhm_y",
],
)
object_truth_df = pd.DataFrame(object_truth_df)

max_ra = np.nanmax(object_truth_df["ra_truth"])
min_ra = np.nanmin(object_truth_df["ra_truth"])
max_dec = np.nanmax(object_truth_df["dec_truth"])
min_dec = np.nanmin(object_truth_df["dec_truth"])
ra_dec_filters = [f"ra >= {min_ra}", f"ra <= {max_ra}", f"dec >= {min_dec}", f"dec <= {max_dec}"]

vertices = hp.ang2vec(
np.array([min_ra, max_ra, max_ra, min_ra]),
np.array([min_dec, min_dec, max_dec, max_dec]),
lonlat=True,
)
ipix = hp.query_polygon(32, vertices, inclusive=True)
healpix_filter = GCRQuery((lambda h: np.isin(h, ipix, assume_unique=True), "healpix_pixel"))

object_truth_df = object_truth_df[object_truth_df["truth_type"] == 1]


print("Loading CosmoDC2...\n") # noqa: WPS421

config_overwrite = {"catalog_root_dir": "/data/scratch/dc2_nfs/cosmoDC2"}
cosmo_cat = GCRCatalogs.load_catalog("desc_cosmodc2", config_overwrite)

cosmo_df = cosmo_cat.get_quantities(
quantities=[
"galaxy_id",
"ra",
"dec",
"ellipticity_1_true",
"ellipticity_2_true",
"shear_1",
"shear_2",
"convergence",
],
filters=ra_dec_filters,
native_filters=healpix_filter,
)
cosmo_df = pd.DataFrame(cosmo_df)


print("Merging...\n") # noqa: WPS421

merge_df = object_truth_df.merge(
cosmo_df, left_on="cosmodc2_id_truth", right_on="galaxy_id", how="left"
)

merge_df = merge_df[~merge_df["galaxy_id"].isna()]

merge_df.drop(columns=["ra_truth", "dec_truth"], inplace=True)

merge_df.rename(
columns={
"redshift_truth": "redshift",
"flux_u_truth": "flux_u",
"flux_g_truth": "flux_g",
"flux_r_truth": "flux_r",
"flux_i_truth": "flux_i",
"flux_z_truth": "flux_z",
"flux_y_truth": "flux_y",
"mag_u_truth": "mag_u",
"mag_g_truth": "mag_g",
"mag_r_truth": "mag_r",
"mag_i_truth": "mag_i",
"mag_z_truth": "mag_z",
"mag_y_truth": "mag_y",
},
inplace=True,
)


print("Saving...\n") # noqa: WPS421

with open(file_path, "wb") as f:
pkl.dump(merge_df, f)

print(f"Catalog has been saved at {file_path}") # noqa: WPS421
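The script's merge strategy keeps only truth objects with a CosmoDC2 counterpart: a left merge on `cosmodc2_id_truth`/`galaxy_id`, then dropping rows where `galaxy_id` is NaN. A toy pandas example with made-up IDs and columns showing the same pattern:

```python
import pandas as pd

object_truth = pd.DataFrame(
    {"cosmodc2_id_truth": [11, 22, 33], "mag_r": [21.0, 22.5, 23.1]}
)
cosmo = pd.DataFrame({"galaxy_id": [11, 33], "shear_1": [0.01, -0.02]})

merged = object_truth.merge(
    cosmo, left_on="cosmodc2_id_truth", right_on="galaxy_id", how="left"
)
# Rows without a CosmoDC2 counterpart (id 22) get NaN in galaxy_id;
# dropping them leaves only matched galaxies, as in the script above.
merged = merged[~merged["galaxy_id"].isna()]
print(sorted(merged["galaxy_id"].astype(int).tolist()))  # [11, 33]
```

Since the filter removes exactly the unmatched left rows, this is equivalent to `how="inner"` here; the left-merge-plus-filter spelling makes the dropped rows explicit.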
51 changes: 39 additions & 12 deletions case_studies/weak_lensing/lensing_config.yaml
@@ -7,8 +7,36 @@ defaults:
mode: train

paths:
-  dc2: /data/scratch/dc2local # change for gl
-  output: /data/scratch/shreyasc/bliss_output # change for gl
+  dc2: /data/scratch/dc2local # change for gl
+  output: /data/scratch/twhit/bliss_output # change for gl

+prior:
+  _target_: case_studies.weak_lensing.lensing_prior.LensingPrior
+  star_color_model_path: /data/scratch/sdss/color_models/star_gmm_nmgy.pkl
+  gal_color_model_path: /data/scratch/sdss/color_models/gal_gmm_nmgy.pkl
+  n_tiles_h: 8
+  n_tiles_w: 8
+  batch_size: 1
+  prob_galaxy: 1.0
+  mean_sources: 162
+  arcsec_per_pixel: 0.055
+  sample_method: cosmology
+  shear_mean: 0
+  shear_std: 0.0175
+  convergence_mean: 0
+  convergence_std: 0.025
+  num_knots: 4
+
+decoder:
+  _target_: case_studies.weak_lensing.lensing_decoder.LensingDecoder
+  tile_slen: 128
+  survey: ${surveys.sdss}
+
+simulator:
+  n_batches: 128
+  num_workers: 32
+  valid_n_batches: 10
+  fix_validation_set: true

variational_factors:
- _target_: bliss.encoder.variational_dist.BivariateNormalFactor
@@ -34,12 +34,11 @@ my_metrics:

my_render:
lensing_shear_conv:
-    _target_: case_studies.weak_lensing.lensing_plots.PlotWeakLensingShearConvergence
-    frequency: 1
-    restrict_batch: 0
-    tile_slen: 256
-    save_local: "convergence_only_maps"
+    _target_: case_studies.weak_lensing.lensing_plots.PlotWeakLensingShearConvergence
+    frequency: 1
+    restrict_batch: 0
+    tile_slen: 256
+    save_local: "lensing_maps"

encoder:
_target_: case_studies.weak_lensing.lensing_encoder.WeakLensingEncoder
@@ -77,20 +104,20 @@ surveys:
dc2:
_target_: case_studies.weak_lensing.lensing_dc2.LensingDC2DataModule
dc2_image_dir: ${paths.dc2}/run2.2i-dr6-v4/coadd-t3828-t3829/deepCoadd-results/
-    dc2_cat_path: ${paths.dc2}/lensing_catalog.pkl
+    dc2_cat_path: ${paths.dc2}/dc2_lensing_catalog.pkl
image_slen: 2048
tile_slen: 256
splits: 0:80/80:90/90:100
batch_size: 1
num_workers: 1
-    cached_data_path: ${paths.output}/dc2_2048_galid_full_scaled_up
+    cached_data_path: ${paths.dc2}/dc2_lensing_catalog_splits

train:
trainer:
logger:
name: dc2_weak_lensing_exp
-      version: exp_08_05
-      devices: [6] # cuda:0 for gl
+      version: exp_08_12
+      devices: 1 # cuda:0 for gl
use_distributed_sampler: false
precision: 32-true
data_source: ${surveys.dc2}
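The `splits: 0:80/80:90/90:100` value appears to encode train/val/test partitions as `start:end` ranges joined by `/` (presumably percentages of the cached files). A small sketch of how such a spec could be parsed; `parse_splits` is a hypothetical helper for illustration, not code from the repo, and the percentage interpretation is an assumption:

```python
def parse_splits(spec: str):
    """Parse a 'start:end/...' range spec like '0:80/80:90/90:100'.

    Assumption: each segment is a half-open percentage range for one split.
    """
    out = []
    for part in spec.split("/"):
        lo, hi = part.split(":")
        out.append((int(lo), int(hi)))
    return out

print(parse_splits("0:80/80:90/90:100"))  # [(0, 80), (80, 90), (90, 100)]
```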
