Skip to content

Commit

Permalink
Weak lensing: data aug + recent DC2 results with baselines (#1077)
Browse files Browse the repository at this point in the history
* Tune avg ellip estimator in ellipticity notebook

* Add plot to ellipticity notebook

* Fix bug when computing MSEs in ellipticity notebook

* Minor update to README

* Add weighted avg ellipticity estimator as an additional baseline

* Notebook to demonstrate shear transformation under rotation

* Add ra/dec to catalog

* Close figures in lensing_plots.py

* Rerun encoder eval notebook after recent training run

* Add 2PCFs to encoder eval notebook (work in progress)

* Set seed in encoder eval notebook

* Add flips to data augmentation notebook

* lensing rotate and flip

* Fix pred_convergence shape in lensing metrics

* Allow train_transforms to be an argument in LensingDC2DataModule

* Update encoder eval notebook with most recent encoder weights

* Fix split files directory in dc2 config

* updated augmentation

* Rename data aug and metrics, tweak data aug

* Fix output path in dc2 config

---------

Co-authored-by: shreyasc <[email protected]>
Co-authored-by: Shreyas Chandrashekaran <[email protected]>
  • Loading branch information
3 people authored Oct 19, 2024
1 parent 4694457 commit 0eb412b
Show file tree
Hide file tree
Showing 10 changed files with 806 additions and 278 deletions.
2 changes: 1 addition & 1 deletion case_studies/weak_lensing/README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
### Neural posterior estimation of weak lensing shear and convergence for the LSST DESC DC2 simulated sky survey
### Neural posterior estimation of weak lensing shear and convergence
#### Shreyas Chandrashekaran, Tim White, Camille Avestruz, and Jeffrey Regier, with assistance from Steve Fan and Tahseen Younus

This case study aims to estimate weak lensing shear and convergence for the DC2 simulated sky survey. See `notebooks/dc2/evaluate_encoder.ipynb` for our most recent results.
Expand Down
12 changes: 7 additions & 5 deletions case_studies/weak_lensing/lensing_config_dc2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ my_normalizers:

my_metrics:
lensing_map:
_target_: case_studies.weak_lensing.lensing_metrics.LensingMapMSE
_target_: case_studies.weak_lensing.lensing_metrics.LensingMSE

my_render:
lensing_shear_conv:
Expand Down Expand Up @@ -77,17 +77,19 @@ surveys:
n_image_split: 2 # split into n_image_split**2 subimages
tile_slen: 256
splits: 0:80/80:90/90:100
avg_ellip_kernel_size: 7 # needs to be odd
avg_ellip_kernel_sigma: 3
avg_ellip_kernel_size: 15 # needs to be odd
avg_ellip_kernel_sigma: 4
batch_size: 1
num_workers: 1
cached_data_path: ${paths.dc2}/dc2_corrected_shear_only_cd_fix
cached_data_path: ${paths.dc2}/dc2_lensing_splits_radec
train_transforms:
- _target_: case_studies.weak_lensing.lensing_data_augmentation.LensingRotateFlipTransform

train:
trainer:
logger:
name: weak_lensing_experiments_dc2
version: october1
version: october18
max_epochs: 250
devices: 1
use_distributed_sampler: false
Expand Down
2 changes: 1 addition & 1 deletion case_studies/weak_lensing/lensing_config_simulator.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ my_normalizers:

my_metrics:
lensing_map:
_target_: case_studies.weak_lensing.lensing_metrics.LensingMapMSE
_target_: case_studies.weak_lensing.lensing_metrics.LensingMSE

my_render:
lensing_shear_conv:
Expand Down
62 changes: 62 additions & 0 deletions case_studies/weak_lensing/lensing_data_augmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import random

import torch


class LensingRotateFlipTransform(torch.nn.Module):
    """Random dihedral-group data augmentation for weak-lensing tile data.

    Applies one of 4 90-degree rotations followed by an optional flip (8
    distinct states) consistently to the image, every tile-catalog map, the
    spin-2 shear components, and the per-tile source locations.

    NOTE(review): `rotate_id`, `flip_id`, and `seen_states` are mutated on
    every call, so a shared instance is not safe across threads; with a
    multi-worker DataLoader each worker process holds its own copy, which
    also resets any `without_replacement` bookkeeping — confirm this is the
    intended behavior.
    """

    def __init__(self, without_replacement=False):
        """
        Args:
            without_replacement: if True, cycle through all 8 distinct
                (rotation, flip) states before any state repeats.
        """
        super().__init__()
        self.rotate_id = -1  # number of 90-degree rotations applied (0-3)
        self.flip_id = -1  # 1 if a flip was applied, else 0
        self.seen_states = set()  # (rotate_id, flip_id) pairs used so far
        self.without_replacement = without_replacement

    def __call__(self, datum):
        """Return a new datum dict with a random rotation/flip applied.

        Args:
            datum: dict with "images" (bands, H, W), "psf_params", and
                "tile_catalog" — a dict of per-tile tensors whose first two
                dims index the tile grid.

        Returns:
            A new dict with transformed "images" and "tile_catalog";
            "psf_params" is passed through unchanged.
        """
        # Sample one of the 8 dihedral states.
        self.rotate_id = random.randint(0, 3)
        self.flip_id = random.randint(0, 1)

        if self.without_replacement:
            # Once all 8 states have been used, start a fresh cycle; otherwise
            # rejection-sample until we draw an unseen (rotation, flip) pair.
            if len(self.seen_states) == 8:
                self.seen_states = set()
            while (self.rotate_id, self.flip_id) in self.seen_states:
                self.rotate_id = random.randint(0, 3)
                self.flip_id = random.randint(0, 1)
            self.seen_states.add((self.rotate_id, self.flip_id))

        # problematic if the psf isn't rotationally invariant
        datum_out = {"psf_params": datum["psf_params"]}

        # apply rotation: dims (1, 2) of the image are (H, W); dims (0, 1) of
        # every tile-catalog tensor index the tile grid
        datum_out["images"] = datum["images"].rot90(self.rotate_id, [1, 2])
        d = datum["tile_catalog"]
        datum_out["tile_catalog"] = {k: v.rot90(self.rotate_id, [0, 1]) for k, v in d.items()}

        # apply flip along the row axis
        if self.flip_id == 1:
            datum_out["images"] = datum_out["images"].flip([1])
            d = datum_out["tile_catalog"]
            datum_out["tile_catalog"] = {k: v.flip([0]) for k, v in d.items()}

        # Shear is spin-2: each 90-degree rotation multiplies the complex
        # shear by exp(2i*90deg) = -1, so both components are negated exactly
        # when rotate_id is odd; a flip negates the cross component only.
        if all(k in datum["tile_catalog"] for k in ("shear_1", "shear_2")):
            shear1 = datum_out["tile_catalog"]["shear_1"]
            shear2 = datum_out["tile_catalog"]["shear_2"]
            if self.rotate_id % 2 == 1:
                shear1 = -shear1
                shear2 = -shear2
            if self.flip_id == 1:
                shear2 = -shear2
            datum_out["tile_catalog"]["shear_1"] = shear1
            datum_out["tile_catalog"]["shear_2"] = shear2

        # Locations hold fractional (row, col) coordinates within each tile,
        # so the coordinates themselves must be remapped, not just the grid.
        if "locs" in datum["tile_catalog"]:
            locs = datum_out["tile_catalog"]["locs"]
            for _ in range(self.rotate_id):
                # One 90-degree rotation maps (r, c) -> (1 - c, r).
                # dim=-1 (rather than a hard-coded dim=3) keeps this valid for
                # any leading tile/source dimensions.
                locs = torch.stack((1 - locs[..., 1], locs[..., 0]), dim=-1)
            if self.flip_id == 1:
                # Row flip maps (r, c) -> (1 - r, c).
                locs = torch.stack((1 - locs[..., 0], locs[..., 1]), dim=-1)
            datum_out["tile_catalog"]["locs"] = locs

        return datum_out
12 changes: 11 additions & 1 deletion case_studies/weak_lensing/lensing_dc2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import math
import sys
from typing import List

import pandas as pd
import torch
Expand Down Expand Up @@ -30,6 +31,7 @@ def __init__(
batch_size: int,
num_workers: int,
cached_data_path: str,
train_transforms: List,
**kwargs,
):
super().__init__(
Expand All @@ -47,7 +49,7 @@ def __init__(
batch_size=batch_size,
num_workers=num_workers,
cached_data_path=cached_data_path,
train_transforms=[],
train_transforms=train_transforms,
nontrain_transforms=[],
subset_fraction=None,
)
Expand Down Expand Up @@ -166,6 +168,8 @@ def generate_cached_data(self, image_index):
ellip2_lensed = tile_dict["ellip2_lensed_sum"] / tile_dict["ellip2_lensed_count"]
ellip_lensed = torch.stack((ellip1_lensed.squeeze(-1), ellip2_lensed.squeeze(-1)), dim=-1)
redshift = tile_dict["redshift_sum"] / tile_dict["redshift_count"]
ra = tile_dict["ra_sum"] / tile_dict["ra_count"]
dec = tile_dict["dec_sum"] / tile_dict["dec_count"]

tile_dict["shear_1"] = shear1
tile_dict["shear_2"] = shear2
Expand All @@ -175,6 +179,8 @@ def generate_cached_data(self, image_index):
tile_dict, self.avg_ellip_kernel_size, self.avg_ellip_kernel_sigma
)
tile_dict["redshift"] = redshift
tile_dict["ra"] = ra
tile_dict["dec"] = dec

data_splits = self.split_image_and_tile_cat(image, tile_dict, tile_dict.keys(), psf_params)

Expand Down Expand Up @@ -225,6 +231,8 @@ def from_file(cls, cat_path, wcs, height, width, **kwargs):
plocs_mask = x0_mask * x1_mask

galid = galid[plocs_mask]
ra = ra[plocs_mask]
dec = dec[plocs_mask]
plocs = plocs[plocs_mask]

shear1 = shear1[plocs_mask]
Expand All @@ -241,6 +249,8 @@ def from_file(cls, cat_path, wcs, height, width, **kwargs):
nobj = galid.shape[0]

d = {
"ra": ra.reshape(1, nobj, 1),
"dec": dec.reshape(1, nobj, 1),
"plocs": plocs.reshape(1, nobj, 2),
"shear1": shear1.reshape(1, nobj, 1),
"shear2": shear2.reshape(1, nobj, 1),
Expand Down
65 changes: 41 additions & 24 deletions case_studies/weak_lensing/lensing_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,63 +2,80 @@
from torchmetrics import Metric


class LensingMapMSE(Metric):
class LensingMSE(Metric):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.add_state("shear1_sum_squared_err", default=torch.zeros(1), dist_reduce_fx="sum")
self.add_state(
"baseline_shear1_sum_squared_err", default=torch.zeros(1), dist_reduce_fx="sum"
"zero_baseline_shear1_sum_squared_err", default=torch.zeros(1), dist_reduce_fx="sum"
)
self.add_state(
"ellip_baseline_shear1_sum_squared_err", default=torch.zeros(1), dist_reduce_fx="sum"
)
self.add_state("shear2_sum_squared_err", default=torch.zeros(1), dist_reduce_fx="sum")
self.add_state(
"baseline_shear2_sum_squared_err", default=torch.zeros(1), dist_reduce_fx="sum"
"zero_baseline_shear2_sum_squared_err", default=torch.zeros(1), dist_reduce_fx="sum"
)
self.add_state(
"ellip_baseline_shear2_sum_squared_err", default=torch.zeros(1), dist_reduce_fx="sum"
)
self.add_state("convergence_sum_squared_err", default=torch.zeros(1), dist_reduce_fx="sum")
# potentially throws a division by zero error if true_idx is empty and uncaught
self.add_state("total", default=torch.zeros(1), dist_reduce_fx="sum")

def update(self, true_cat, est_cat, matching) -> None:
true_shear1 = true_cat["shear_1"]
true_shear2 = true_cat["shear_2"]
pred_shear1 = est_cat["shear_1"]
pred_shear2 = est_cat["shear_2"]
true_shear = torch.cat((true_shear1, true_shear2), dim=-1)
pred_shear = torch.cat((pred_shear1, pred_shear2), dim=-1)
true_shear = true_shear.flatten(1, 2)
pred_shear = pred_shear.flatten(1, 2)
baseline_pred_shear = torch.zeros_like(true_shear)
true_shear1 = true_cat["shear_1"].flatten(1, 2)
true_shear2 = true_cat["shear_2"].flatten(1, 2)
pred_shear1 = est_cat["shear_1"].flatten(1, 2)
pred_shear2 = est_cat["shear_2"].flatten(1, 2)
zero_baseline_pred_shear1 = torch.zeros_like(true_shear1)
zero_baseline_pred_shear2 = torch.zeros_like(true_shear2)
ellip_baseline_pred_shear1 = (
true_cat["ellip_lensed_wavg"][..., 0].unsqueeze(-1).flatten(1, 2)
)
ellip_baseline_pred_shear2 = (
true_cat["ellip_lensed_wavg"][..., 1].unsqueeze(-1).flatten(1, 2)
)

if "convergence" not in est_cat:
true_convergence = torch.zeros_like(true_shear1).flatten(1, 2)
pred_convergence = torch.zeros_like(true_convergence).flatten(1, 2)
pred_convergence = torch.zeros_like(true_convergence)
else:
true_convergence = true_cat["convergence"].flatten(1, 2)
pred_convergence = est_cat["convergence"].flatten(1, 2)

shear1_sq_err = ((true_shear[:, :, 0] - pred_shear[:, :, 0]) ** 2).sum()
baseline_shear1_sq_err = ((true_shear[:, :, 0] - baseline_pred_shear[:, :, 0]) ** 2).sum()
shear2_sq_err = ((true_shear[:, :, 1] - pred_shear[:, :, 1]) ** 2).sum()
baseline_shear2_sq_err = ((true_shear[:, :, 1] - baseline_pred_shear[:, :, 1]) ** 2).sum()
shear1_sq_err = ((true_shear1 - pred_shear1) ** 2).sum()
zero_baseline_shear1_sq_err = ((true_shear1 - zero_baseline_pred_shear1) ** 2).sum()
ellip_baseline_shear1_sq_err = ((true_shear1 - ellip_baseline_pred_shear1) ** 2).sum()
shear2_sq_err = ((true_shear2 - pred_shear2) ** 2).sum()
zero_baseline_shear2_sq_err = ((true_shear2 - zero_baseline_pred_shear2) ** 2).sum()
ellip_baseline_shear2_sq_err = ((true_shear2 - ellip_baseline_pred_shear2) ** 2).sum()
convergence_sq_err = ((true_convergence - pred_convergence) ** 2).sum()

self.shear1_sum_squared_err += shear1_sq_err
self.baseline_shear1_sum_squared_err += baseline_shear1_sq_err
self.zero_baseline_shear1_sum_squared_err += zero_baseline_shear1_sq_err
self.ellip_baseline_shear1_sum_squared_err += ellip_baseline_shear1_sq_err
self.shear2_sum_squared_err += shear2_sq_err
self.baseline_shear2_sum_squared_err += baseline_shear2_sq_err
self.zero_baseline_shear2_sum_squared_err += zero_baseline_shear2_sq_err
self.ellip_baseline_shear2_sum_squared_err += ellip_baseline_shear2_sq_err
self.convergence_sum_squared_err += convergence_sq_err

self.total += torch.tensor(true_convergence.shape[1])

def compute(self):
shear1_mse = self.shear1_sum_squared_err / self.total
baseline_shear1_mse = self.baseline_shear1_sum_squared_err / self.total
zero_baseline_shear1_mse = self.zero_baseline_shear1_sum_squared_err / self.total
ellip_baseline_shear1_mse = self.ellip_baseline_shear1_sum_squared_err / self.total
shear2_mse = self.shear2_sum_squared_err / self.total
baseline_shear2_mse = self.baseline_shear2_sum_squared_err / self.total
zero_baseline_shear2_mse = self.zero_baseline_shear2_sum_squared_err / self.total
ellip_baseline_shear2_mse = self.ellip_baseline_shear2_sum_squared_err / self.total
convergence_mse = self.convergence_sum_squared_err / self.total

return {
"Shear 1 MSE": shear1_mse,
"Baseline shear 1 MSE": baseline_shear1_mse,
"Zero baseline shear 1 MSE": zero_baseline_shear1_mse,
"Ellip baseline shear 1 MSE": ellip_baseline_shear1_mse,
"Shear 2 MSE": shear2_mse,
"Baseline shear 2 MSE": baseline_shear2_mse,
"Zero baseline shear 2 MSE": zero_baseline_shear2_mse,
"Ellip baseline shear 2 MSE": ellip_baseline_shear2_mse,
"Convergence MSE": convergence_mse,
}
6 changes: 4 additions & 2 deletions case_studies/weak_lensing/lensing_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ def plot_maps(
if not Path(save_local).exists():
Path(save_local).mkdir(parents=True)
fig.savefig(f"{save_local}/lensing_maps_{current_epoch}.png")
return fig, axes

plt.close(fig)


def plot_lensing_scatterplots(
Expand Down Expand Up @@ -186,4 +187,5 @@ def plot_lensing_scatterplots(
if not Path(save_local).exists():
Path(save_local).mkdir(parents=True)
fig.savefig(f"{save_local}/lensing_scatterplots_{current_epoch}.png")
return fig, axes

plt.close(fig)
516 changes: 295 additions & 221 deletions case_studies/weak_lensing/notebooks/dc2/ellipticity.ipynb

Large diffs are not rendered by default.

64 changes: 41 additions & 23 deletions case_studies/weak_lensing/notebooks/dc2/evaluate_encoder.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit 0eb412b

Please sign in to comment.