Weak lensing DC2 updates (#1067)
* Refactor generate_cached_data in lensing_dc2

* Decrease learning rate, remove clamp on convergence stdev

* Remove some print statements in lensing_encoder

* Update ellipticity and redshift notebooks after new split files

* Rename notebooks

* Notebook to examine flux/mag detection limit in DC2 images

* Filter flux_r >= 200 instead of >= 50

* New average ellipticity estimator (lensing_dc2 + ellip ipynb)

* Move avg ellipticity to utils script

* Move avg ellipticity to utils script (cont.)

* Rerun redshift and two-point notebooks with updated catalog (flux_r >= 200)

* Rename catalogs and split directories

* Change avg ellipticity kernel params

* Rerun notebooks with new catalog

* Allow duplicate code in both DC2 lensing catalog generation scripts

* Refactor image and tile cat splitting to avoid duplicate code in lensing DC2 subclass
timwhite0 authored Aug 26, 2024
1 parent e536606 commit bd96328
Showing 16 changed files with 3,393 additions and 3,373 deletions.
53 changes: 29 additions & 24 deletions bliss/surveys/dc2.py
@@ -221,14 +221,7 @@ def load_image_and_catalog(self, image_index):
},
}

def generate_cached_data(self, image_index):
result_dict = self.load_image_and_catalog(image_index)

image = result_dict["inputs"]["image"]
tile_dict = result_dict["tile_dict"]
wcs_header_str = result_dict["other_info"]["wcs_header_str"]
psf_params = result_dict["inputs"]["psf_params"]

def split_image_and_tile_cat(self, image, tile_cat, tile_cat_keys_to_split, psf_params):
# split image
split_lim = self.image_lim[0] // self.n_image_split
image_splits = split_tensor(image, split_lim, 1, 2)
@@ -237,6 +230,31 @@ def generate_cached_data(self, image_index):

# split tile cat
tile_cat_splits = {}
for param_name in tile_cat_keys_to_split:
tile_cat_splits[param_name] = split_tensor(
tile_cat[param_name], split_lim // self.tile_slen, 0, 1
)

return {
"tile_catalog": unpack_dict(tile_cat_splits),
"images": image_splits,
"image_height_index": (
torch.arange(0, len(image_splits)) // split_image_num_on_width
).tolist(),
"image_width_index": (
torch.arange(0, len(image_splits)) % split_image_num_on_width
).tolist(),
"psf_params": [psf_params for _ in range(self.n_image_split**2)],
}

def generate_cached_data(self, image_index):
result_dict = self.load_image_and_catalog(image_index)

image = result_dict["inputs"]["image"]
tile_dict = result_dict["tile_dict"]
wcs_header_str = result_dict["other_info"]["wcs_header_str"]
psf_params = result_dict["inputs"]["psf_params"]

param_list = [
"locs",
"n_sources",
@@ -252,24 +270,11 @@ def generate_cached_data(self, image_index):
"two_sources_mask",
"more_than_two_sources_mask",
]
for param_name in param_list:
tile_cat_splits[param_name] = split_tensor(
tile_dict[param_name], split_lim // self.tile_slen, 0, 1
)

data_splits = {
"tile_catalog": unpack_dict(tile_cat_splits),
"images": image_splits,
"image_height_index": (
torch.arange(0, len(image_splits)) // split_image_num_on_width
).tolist(),
"image_width_index": (
torch.arange(0, len(image_splits)) % split_image_num_on_width
).tolist(),
"psf_params": [psf_params for _ in range(self.n_image_split**2)],
}
splits = self.split_image_and_tile_cat(image, tile_dict, param_list, psf_params)

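        # split_list (presumably) chunks the per-split samples so that each
        # cached file stores data_in_one_cached_file of them.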
data_splits = split_list(
unpack_dict(data_splits),
unpack_dict(splits),
sub_list_len=self.data_in_one_cached_file,
)

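For reference, a minimal sketch of the splitting helper that the new split_image_and_tile_cat method relies on. The real split_tensor is defined elsewhere in BLISS; the version below is an assumption inferred only from the two call sites in this diff, taking a tensor, a chunk size, and the two dimensions to split along, and returning the blocks in row-major order.

# Hypothetical stand-in for BLISS's split_tensor, inferred from the call
# sites above; the real implementation may differ.
import torch

def split_tensor(t: torch.Tensor, chunk: int, dim1: int, dim2: int) -> list:
    """Split t into chunk-sized blocks along dim1 and dim2, row-major."""
    blocks = []
    for row in t.split(chunk, dim=dim1):
        blocks.extend(row.split(chunk, dim=dim2))
    return blocks

# With image_lim = 80 and n_image_split = 2 (toy numbers), a (bands, 80, 80)
# image yields four (bands, 40, 40) sub-images; index // 2 is the row and
# index % 2 the column, matching image_height_index / image_width_index.
image = torch.randn(6, 80, 80)
splits = split_tensor(image, 80 // 2, 1, 2)
assert len(splits) == 4 and splits[0].shape == (6, 40, 40)
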
133 changes: 69 additions & 64 deletions case_studies/weak_lensing/generate_dc2_lensing_catalog.py
@@ -1,3 +1,4 @@
# pylint: disable=R0801
import os
import pickle as pkl

@@ -17,32 +18,61 @@
raise FileExistsError(f"{file_path} already exists.")


print("Loading truth...\n") # noqa: WPS421

truth_cat = GCRCatalogs.load_catalog("desc_dc2_run2.2i_dr6_truth")

truth_df = truth_cat.get_quantities(
quantities=[
"cosmodc2_id",
"id",
"match_objectId",
"truth_type",
"ra",
"dec",
"redshift",
"flux_u",
"flux_g",
"flux_r",
"flux_i",
"flux_z",
"flux_y",
"mag_u",
"mag_g",
"mag_r",
"mag_i",
"mag_z",
"mag_y",
]
)
truth_df = pd.DataFrame(truth_df)

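# Keep galaxies only (truth_type == 1), then cut at the flux limit examined in
# the DC2 detection-limit notebook (flux_r >= 200, raised from the earlier 50).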
truth_df = truth_df[truth_df["truth_type"] == 1]

truth_df = truth_df[truth_df["flux_r"] >= 200]

max_ra = np.nanmax(truth_df["ra"])
min_ra = np.nanmin(truth_df["ra"])
max_dec = np.nanmax(truth_df["dec"])
min_dec = np.nanmin(truth_df["dec"])
ra_dec_filters = [f"ra >= {min_ra}", f"ra <= {max_ra}", f"dec >= {min_dec}", f"dec <= {max_dec}"]

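# NSIDE=32 healpix pixels overlapping the footprint polygon; inclusive=True
# also keeps pixels that only partially intersect it. The GCRQuery below can
# be passed as a native filter so catalog reads skip data outside the region.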
vertices = hp.ang2vec(
np.array([min_ra, max_ra, max_ra, min_ra]),
np.array([min_dec, min_dec, max_dec, max_dec]),
lonlat=True,
)
ipix = hp.query_polygon(32, vertices, inclusive=True)
healpix_filter = GCRQuery((lambda h: np.isin(h, ipix, assume_unique=True), "healpix_pixel"))


print("Loading object-with-truth-match...\n") # noqa: WPS421

object_truth_cat = GCRCatalogs.load_catalog("desc_dc2_run2.2i_dr6_object_with_truth_match")

object_truth_df = object_truth_cat.get_quantities(
quantities=[
"cosmodc2_id_truth",
"id_truth",
"objectId",
"match_objectId",
"truth_type",
"ra_truth",
"dec_truth",
"redshift_truth",
"flux_u_truth",
"flux_g_truth",
"flux_r_truth",
"flux_i_truth",
"flux_z_truth",
"flux_y_truth",
"mag_u_truth",
"mag_g_truth",
"mag_r_truth",
"mag_i_truth",
"mag_z_truth",
"mag_y_truth",
"Ixx_pixel",
"Iyy_pixel",
"Ixy_pixel",
@@ -70,32 +100,15 @@
"psf_fwhm_i",
"psf_fwhm_z",
"psf_fwhm_y",
],
]
)
object_truth_df = pd.DataFrame(object_truth_df)

max_ra = np.nanmax(object_truth_df["ra_truth"])
min_ra = np.nanmin(object_truth_df["ra_truth"])
max_dec = np.nanmax(object_truth_df["dec_truth"])
min_dec = np.nanmin(object_truth_df["dec_truth"])
ra_dec_filters = [f"ra >= {min_ra}", f"ra <= {max_ra}", f"dec >= {min_dec}", f"dec <= {max_dec}"]

vertices = hp.ang2vec(
np.array([min_ra, max_ra, max_ra, min_ra]),
np.array([min_dec, min_dec, max_dec, max_dec]),
lonlat=True,
)
ipix = hp.query_polygon(32, vertices, inclusive=True)
healpix_filter = GCRQuery((lambda h: np.isin(h, ipix, assume_unique=True), "healpix_pixel"))

object_truth_df = object_truth_df[object_truth_df["truth_type"] == 1]

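# More than one detected object may match the same true galaxy; keep a single
# row per cosmodc2 id before the merges below.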
object_truth_df.drop_duplicates(subset=["cosmodc2_id_truth"], inplace=True)


print("Loading CosmoDC2...\n") # noqa: WPS421

config_overwrite = {"catalog_root_dir": "/data/scratch/dc2_nfs/cosmoDC2"}

cosmo_cat = GCRCatalogs.load_catalog("desc_cosmodc2", config_overwrite)

cosmo_df = cosmo_cat.get_quantities(
@@ -115,39 +128,31 @@
cosmo_df = pd.DataFrame(cosmo_df)


print("Merging...\n") # noqa: WPS421
print("Merging truth with object-with-truth-match...\n") # noqa: WPS421

merge_df = object_truth_df.merge(
cosmo_df, left_on="cosmodc2_id_truth", right_on="galaxy_id", how="left"
merge_df1 = truth_df.merge(
object_truth_df, left_on="cosmodc2_id", right_on="cosmodc2_id_truth", how="left"
)

merge_df = merge_df[~merge_df["galaxy_id"].isna()]

merge_df.drop(columns=["ra_truth", "dec_truth"], inplace=True)

merge_df.rename(
columns={
"redshift_truth": "redshift",
"flux_u_truth": "flux_u",
"flux_g_truth": "flux_g",
"flux_r_truth": "flux_r",
"flux_i_truth": "flux_i",
"flux_z_truth": "flux_z",
"flux_y_truth": "flux_y",
"mag_u_truth": "mag_u",
"mag_g_truth": "mag_g",
"mag_r_truth": "mag_r",
"mag_i_truth": "mag_i",
"mag_z_truth": "mag_z",
"mag_y_truth": "mag_y",
},
inplace=True,
)
merge_df1.drop_duplicates(subset=["cosmodc2_id"], inplace=True)

merge_df1.drop(columns=["cosmodc2_id_truth"], inplace=True)


print("Merging with CosmoDC2...\n") # noqa: WPS421

merge_df2 = merge_df1.merge(cosmo_df, left_on="cosmodc2_id", right_on="galaxy_id", how="left")

merge_df2 = merge_df2[~merge_df2["galaxy_id"].isna()]

merge_df2.drop(columns=["ra_y", "dec_y"], inplace=True)

merge_df2.rename(columns={"ra_x": "ra", "dec_x": "dec"}, inplace=True)


print("Saving...\n") # noqa: WPS421

with open(file_path, "wb") as f:
pkl.dump(merge_df, f)
pkl.dump(merge_df2, f)

print(f"Catalog has been saved at {file_path}") # noqa: WPS421
(Diffs for the remaining 14 changed files are not shown here.)