Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Weak lensing: data aug + recent DC2 results with baselines #1077

Merged
merged 22 commits into from
Oct 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
f0bbbb4
Tune avg ellip estimator in ellipticity notebook
timwhite0 Oct 2, 2024
15cf7d5
Add plot to ellipticity notebook
timwhite0 Oct 2, 2024
2038d99
Fix bug when computing MSEs in ellipticity notebook
timwhite0 Oct 2, 2024
79b8203
Minor update to README
timwhite0 Oct 3, 2024
fc0a758
Add weighted avg ellipticity estimator as an additional baseline
timwhite0 Oct 3, 2024
950d3c4
Notebook to demonstrate shear transformation under rotation
timwhite0 Oct 3, 2024
f2ab77d
Add ra/dec to catalog
timwhite0 Oct 7, 2024
3c19dc8
Merge branch 'master' into tw/weak_lensing
timwhite0 Oct 7, 2024
6c8f784
Close figures in lensing_plots.py
timwhite0 Oct 7, 2024
7b7b7e6
Rerun encoder eval notebook after recent training run
timwhite0 Oct 7, 2024
1c877cc
Add 2PCFs to encoder eval notebook (work in progress)
timwhite0 Oct 7, 2024
52509dc
Set seed in encoder eval notebook
timwhite0 Oct 7, 2024
fa885c3
Merge branch 'master' into tw/weak_lensing
timwhite0 Oct 8, 2024
cbcbbbc
Add flips to data augmentation notebook
timwhite0 Oct 8, 2024
c2a3b00
lensing rotate and flip
shreyasc30 Oct 11, 2024
0c47074
Fix pred_convergence shape in lensing metrics
timwhite0 Oct 11, 2024
9c39e1b
Allow train_transforms to be an argument in LensingDC2DataModule
timwhite0 Oct 11, 2024
42616f6
Update encoder eval notebook with most recent encoder weights
timwhite0 Oct 18, 2024
eff27e4
Fix split files directory in dc2 config
timwhite0 Oct 18, 2024
6a55760
updated augmentation
Oct 18, 2024
df378c0
Rename data aug and metrics, tweak data aug
timwhite0 Oct 19, 2024
720973c
Fix output path in dc2 config
timwhite0 Oct 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion case_studies/weak_lensing/README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
### Neural posterior estimation of weak lensing shear and convergence for the LSST DESC DC2 simulated sky survey
### Neural posterior estimation of weak lensing shear and convergence
#### Shreyas Chandrashekaran, Tim White, Camille Avestruz, and Jeffrey Regier, with assistance from Steve Fan and Tahseen Younus

This case study aims to estimate weak lensing shear and convergence for the DC2 simulated sky survey. See `notebooks/dc2/evaluate_encoder.ipynb` for our most recent results.
Expand Down
12 changes: 7 additions & 5 deletions case_studies/weak_lensing/lensing_config_dc2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ my_normalizers:

my_metrics:
lensing_map:
_target_: case_studies.weak_lensing.lensing_metrics.LensingMapMSE
_target_: case_studies.weak_lensing.lensing_metrics.LensingMSE

my_render:
lensing_shear_conv:
Expand Down Expand Up @@ -77,17 +77,19 @@ surveys:
n_image_split: 2 # split into n_image_split**2 subimages
tile_slen: 256
splits: 0:80/80:90/90:100
avg_ellip_kernel_size: 7 # needs to be odd
avg_ellip_kernel_sigma: 3
avg_ellip_kernel_size: 15 # needs to be odd
avg_ellip_kernel_sigma: 4
batch_size: 1
num_workers: 1
cached_data_path: ${paths.dc2}/dc2_corrected_shear_only_cd_fix
cached_data_path: ${paths.dc2}/dc2_lensing_splits_radec
train_transforms:
- _target_: case_studies.weak_lensing.lensing_data_augmentation.LensingRotateFlipTransform

train:
trainer:
logger:
name: weak_lensing_experiments_dc2
version: october1
version: october18
max_epochs: 250
devices: 1
use_distributed_sampler: false
Expand Down
2 changes: 1 addition & 1 deletion case_studies/weak_lensing/lensing_config_simulator.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ my_normalizers:

my_metrics:
lensing_map:
_target_: case_studies.weak_lensing.lensing_metrics.LensingMapMSE
_target_: case_studies.weak_lensing.lensing_metrics.LensingMSE

my_render:
lensing_shear_conv:
Expand Down
62 changes: 62 additions & 0 deletions case_studies/weak_lensing/lensing_data_augmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import random

import torch


class LensingRotateFlipTransform(torch.nn.Module):
    """Data augmentation for weak-lensing training tiles.

    Draws a random quarter-turn rotation (k in {0, 1, 2, 3}) and an optional
    flip, and applies them consistently to the image, every tile-catalog map,
    the sign-sensitive shear components, and the within-tile source locations.
    """

    def __init__(self, without_replacement=False):
        """If ``without_replacement`` is True, cycle through all 8 distinct
        (rotation, flip) states before any state is reused."""
        super().__init__()
        # Most recent draw; -1 means no transform has been applied yet.
        self.rotate_id = -1
        self.flip_id = -1
        # (rotate_id, flip_id) pairs already used in the current cycle.
        self.seen_states = set()
        self.without_replacement = without_replacement

    def __call__(self, datum):
        """Return a transformed copy of ``datum``.

        ``datum`` is a dict with keys ``images``, ``tile_catalog`` and
        ``psf_params``; ``tile_catalog`` maps names to per-tile tensors.
        """
        self.rotate_id = random.randint(0, 3)
        self.flip_id = random.randint(0, 1)

        if self.without_replacement:
            # Start a fresh cycle once all 8 states have been used.
            if len(self.seen_states) == 8:
                self.seen_states = set()
            # Rejection-sample until we draw a state unused this cycle.
            while (self.rotate_id, self.flip_id) in self.seen_states:
                self.rotate_id = random.randint(0, 3)
                self.flip_id = random.randint(0, 1)
            self.seen_states.add((self.rotate_id, self.flip_id))

        # NOTE(review): passing the PSF through unchanged assumes it is
        # rotationally invariant — confirm for non-circular PSFs.
        out = {"psf_params": datum["psf_params"]}

        # Rotate the image (spatial dims 1, 2) and each tile-catalog map
        # (tile-grid dims 0, 1) by the same number of quarter turns.
        out["images"] = datum["images"].rot90(self.rotate_id, [1, 2])
        out["tile_catalog"] = {
            key: val.rot90(self.rotate_id, [0, 1])
            for key, val in datum["tile_catalog"].items()
        }

        # Optionally flip along the first spatial axis.
        if self.flip_id == 1:
            out["images"] = out["images"].flip([1])
            out["tile_catalog"] = {
                key: val.flip([0]) for key, val in out["tile_catalog"].items()
            }

        tiles = datum["tile_catalog"]

        # Shear needs sign corrections on top of the spatial permutation:
        # each quarter turn negates both components (net negation iff the
        # number of turns is odd), and a flip negates the cross component.
        if "shear_1" in tiles and "shear_2" in tiles:
            g1 = out["tile_catalog"]["shear_1"]
            g2 = out["tile_catalog"]["shear_2"]
            if self.rotate_id % 2 == 1:
                g1 = -g1
                g2 = -g2
            if self.flip_id == 1:
                g2 = -g2
            out["tile_catalog"]["shear_1"] = g1
            out["tile_catalog"]["shear_2"] = g2

        # Within-tile locations (normalized to [0, 1] per tile) must track the
        # same rotation/flip that torch.rot90/flip applied to the tile grid.
        if "locs" in tiles:
            locs = out["tile_catalog"]["locs"]
            for _ in range(self.rotate_id):
                locs = torch.stack((1 - locs[..., 1], locs[..., 0]), dim=3)
            if self.flip_id == 1:
                locs = torch.stack((1 - locs[..., 0], locs[..., 1]), dim=3)
            out["tile_catalog"]["locs"] = locs

        return out
12 changes: 11 additions & 1 deletion case_studies/weak_lensing/lensing_dc2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import math
import sys
from typing import List

import pandas as pd
import torch
Expand Down Expand Up @@ -30,6 +31,7 @@ def __init__(
batch_size: int,
num_workers: int,
cached_data_path: str,
train_transforms: List,
**kwargs,
):
super().__init__(
Expand All @@ -47,7 +49,7 @@ def __init__(
batch_size=batch_size,
num_workers=num_workers,
cached_data_path=cached_data_path,
train_transforms=[],
train_transforms=train_transforms,
nontrain_transforms=[],
subset_fraction=None,
)
Expand Down Expand Up @@ -166,6 +168,8 @@ def generate_cached_data(self, image_index):
ellip2_lensed = tile_dict["ellip2_lensed_sum"] / tile_dict["ellip2_lensed_count"]
ellip_lensed = torch.stack((ellip1_lensed.squeeze(-1), ellip2_lensed.squeeze(-1)), dim=-1)
redshift = tile_dict["redshift_sum"] / tile_dict["redshift_count"]
ra = tile_dict["ra_sum"] / tile_dict["ra_count"]
dec = tile_dict["dec_sum"] / tile_dict["dec_count"]

tile_dict["shear_1"] = shear1
tile_dict["shear_2"] = shear2
Expand All @@ -175,6 +179,8 @@ def generate_cached_data(self, image_index):
tile_dict, self.avg_ellip_kernel_size, self.avg_ellip_kernel_sigma
)
tile_dict["redshift"] = redshift
tile_dict["ra"] = ra
tile_dict["dec"] = dec

data_splits = self.split_image_and_tile_cat(image, tile_dict, tile_dict.keys(), psf_params)

Expand Down Expand Up @@ -225,6 +231,8 @@ def from_file(cls, cat_path, wcs, height, width, **kwargs):
plocs_mask = x0_mask * x1_mask

galid = galid[plocs_mask]
ra = ra[plocs_mask]
dec = dec[plocs_mask]
plocs = plocs[plocs_mask]

shear1 = shear1[plocs_mask]
Expand All @@ -241,6 +249,8 @@ def from_file(cls, cat_path, wcs, height, width, **kwargs):
nobj = galid.shape[0]

d = {
"ra": ra.reshape(1, nobj, 1),
"dec": dec.reshape(1, nobj, 1),
"plocs": plocs.reshape(1, nobj, 2),
"shear1": shear1.reshape(1, nobj, 1),
"shear2": shear2.reshape(1, nobj, 1),
Expand Down
65 changes: 41 additions & 24 deletions case_studies/weak_lensing/lensing_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,63 +2,80 @@
from torchmetrics import Metric


class LensingMapMSE(Metric):
class LensingMSE(Metric):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.add_state("shear1_sum_squared_err", default=torch.zeros(1), dist_reduce_fx="sum")
self.add_state(
"baseline_shear1_sum_squared_err", default=torch.zeros(1), dist_reduce_fx="sum"
"zero_baseline_shear1_sum_squared_err", default=torch.zeros(1), dist_reduce_fx="sum"
)
self.add_state(
"ellip_baseline_shear1_sum_squared_err", default=torch.zeros(1), dist_reduce_fx="sum"
)
self.add_state("shear2_sum_squared_err", default=torch.zeros(1), dist_reduce_fx="sum")
self.add_state(
"baseline_shear2_sum_squared_err", default=torch.zeros(1), dist_reduce_fx="sum"
"zero_baseline_shear2_sum_squared_err", default=torch.zeros(1), dist_reduce_fx="sum"
)
self.add_state(
"ellip_baseline_shear2_sum_squared_err", default=torch.zeros(1), dist_reduce_fx="sum"
)
self.add_state("convergence_sum_squared_err", default=torch.zeros(1), dist_reduce_fx="sum")
# potentially throws a division by zero error if true_idx is empty and uncaught
self.add_state("total", default=torch.zeros(1), dist_reduce_fx="sum")

def update(self, true_cat, est_cat, matching) -> None:
true_shear1 = true_cat["shear_1"]
true_shear2 = true_cat["shear_2"]
pred_shear1 = est_cat["shear_1"]
pred_shear2 = est_cat["shear_2"]
true_shear = torch.cat((true_shear1, true_shear2), dim=-1)
pred_shear = torch.cat((pred_shear1, pred_shear2), dim=-1)
true_shear = true_shear.flatten(1, 2)
pred_shear = pred_shear.flatten(1, 2)
baseline_pred_shear = torch.zeros_like(true_shear)
true_shear1 = true_cat["shear_1"].flatten(1, 2)
true_shear2 = true_cat["shear_2"].flatten(1, 2)
pred_shear1 = est_cat["shear_1"].flatten(1, 2)
pred_shear2 = est_cat["shear_2"].flatten(1, 2)
zero_baseline_pred_shear1 = torch.zeros_like(true_shear1)
zero_baseline_pred_shear2 = torch.zeros_like(true_shear2)
ellip_baseline_pred_shear1 = (
true_cat["ellip_lensed_wavg"][..., 0].unsqueeze(-1).flatten(1, 2)
)
ellip_baseline_pred_shear2 = (
true_cat["ellip_lensed_wavg"][..., 1].unsqueeze(-1).flatten(1, 2)
)

if "convergence" not in est_cat:
true_convergence = torch.zeros_like(true_shear1).flatten(1, 2)
pred_convergence = torch.zeros_like(true_convergence).flatten(1, 2)
pred_convergence = torch.zeros_like(true_convergence)
else:
true_convergence = true_cat["convergence"].flatten(1, 2)
pred_convergence = est_cat["convergence"].flatten(1, 2)

shear1_sq_err = ((true_shear[:, :, 0] - pred_shear[:, :, 0]) ** 2).sum()
baseline_shear1_sq_err = ((true_shear[:, :, 0] - baseline_pred_shear[:, :, 0]) ** 2).sum()
shear2_sq_err = ((true_shear[:, :, 1] - pred_shear[:, :, 1]) ** 2).sum()
baseline_shear2_sq_err = ((true_shear[:, :, 1] - baseline_pred_shear[:, :, 1]) ** 2).sum()
shear1_sq_err = ((true_shear1 - pred_shear1) ** 2).sum()
zero_baseline_shear1_sq_err = ((true_shear1 - zero_baseline_pred_shear1) ** 2).sum()
ellip_baseline_shear1_sq_err = ((true_shear1 - ellip_baseline_pred_shear1) ** 2).sum()
shear2_sq_err = ((true_shear2 - pred_shear2) ** 2).sum()
zero_baseline_shear2_sq_err = ((true_shear2 - zero_baseline_pred_shear2) ** 2).sum()
ellip_baseline_shear2_sq_err = ((true_shear2 - ellip_baseline_pred_shear2) ** 2).sum()
convergence_sq_err = ((true_convergence - pred_convergence) ** 2).sum()

self.shear1_sum_squared_err += shear1_sq_err
self.baseline_shear1_sum_squared_err += baseline_shear1_sq_err
self.zero_baseline_shear1_sum_squared_err += zero_baseline_shear1_sq_err
self.ellip_baseline_shear1_sum_squared_err += ellip_baseline_shear1_sq_err
self.shear2_sum_squared_err += shear2_sq_err
self.baseline_shear2_sum_squared_err += baseline_shear2_sq_err
self.zero_baseline_shear2_sum_squared_err += zero_baseline_shear2_sq_err
self.ellip_baseline_shear2_sum_squared_err += ellip_baseline_shear2_sq_err
self.convergence_sum_squared_err += convergence_sq_err

self.total += torch.tensor(true_convergence.shape[1])

def compute(self):
shear1_mse = self.shear1_sum_squared_err / self.total
baseline_shear1_mse = self.baseline_shear1_sum_squared_err / self.total
zero_baseline_shear1_mse = self.zero_baseline_shear1_sum_squared_err / self.total
ellip_baseline_shear1_mse = self.ellip_baseline_shear1_sum_squared_err / self.total
shear2_mse = self.shear2_sum_squared_err / self.total
baseline_shear2_mse = self.baseline_shear2_sum_squared_err / self.total
zero_baseline_shear2_mse = self.zero_baseline_shear2_sum_squared_err / self.total
ellip_baseline_shear2_mse = self.ellip_baseline_shear2_sum_squared_err / self.total
convergence_mse = self.convergence_sum_squared_err / self.total

return {
"Shear 1 MSE": shear1_mse,
"Baseline shear 1 MSE": baseline_shear1_mse,
"Zero baseline shear 1 MSE": zero_baseline_shear1_mse,
"Ellip baseline shear 1 MSE": ellip_baseline_shear1_mse,
"Shear 2 MSE": shear2_mse,
"Baseline shear 2 MSE": baseline_shear2_mse,
"Zero baseline shear 2 MSE": zero_baseline_shear2_mse,
"Ellip baseline shear 2 MSE": ellip_baseline_shear2_mse,
"Convergence MSE": convergence_mse,
}
6 changes: 4 additions & 2 deletions case_studies/weak_lensing/lensing_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ def plot_maps(
if not Path(save_local).exists():
Path(save_local).mkdir(parents=True)
fig.savefig(f"{save_local}/lensing_maps_{current_epoch}.png")
return fig, axes

plt.close(fig)


def plot_lensing_scatterplots(
Expand Down Expand Up @@ -186,4 +187,5 @@ def plot_lensing_scatterplots(
if not Path(save_local).exists():
Path(save_local).mkdir(parents=True)
fig.savefig(f"{save_local}/lensing_scatterplots_{current_epoch}.png")
return fig, axes

plt.close(fig)
516 changes: 295 additions & 221 deletions case_studies/weak_lensing/notebooks/dc2/ellipticity.ipynb

Large diffs are not rendered by default.

64 changes: 41 additions & 23 deletions case_studies/weak_lensing/notebooks/dc2/evaluate_encoder.ipynb

Large diffs are not rendered by default.

Loading
Loading