Weak lensing simulator (#1071)
* Modify wl config and prior to generate 512x512 images with constant shear and convergence

* 2048x2048 images, slightly greater shear/conv

* Get rid of super().render_galaxy() in lensing decoder to avoid double convolution of galaxy

* Add flag to skip adding the survey background from the SDSS frame (which is too small)

* Update two-point notebook to demonstrate robustness to noise distribution

* Galaxy shape derivation notebook + rename dict in prior

* Apply shear and convergence in prior, get rid of LensingDecoder

* Revert "Apply shear and convergence in prior, get rid of LensingDecoder"

This reverts commit 98eda83.

* Update lensing config

* Update lensing config (again)

* Use universal fluxes in lensing decoder

* Big tiles (256x256) in lensing_config

* Update galaxy shape derivation notebook

* Split up render_galaxy in base decoder to avoid repeated code in lensing decoder
timwhite0 authored Sep 15, 2024
1 parent a1c981b commit 0907a16
Showing 7 changed files with 1,524 additions and 166 deletions.
1 change: 1 addition & 0 deletions bliss/conf/base_config.yaml
@@ -50,6 +50,7 @@ decoder:
_target_: bliss.simulator.decoder.Decoder
tile_slen: 4
survey: ${surveys.sdss}
use_survey_background: true
with_dither: true
with_noise: true

43 changes: 32 additions & 11 deletions bliss/simulator/decoder.py
@@ -15,6 +15,7 @@ def __init__(
self,
tile_slen: int,
survey: Survey,
use_survey_background: bool = True,
with_dither: bool = True,
with_noise: bool = True,
) -> None:
@@ -23,6 +24,7 @@ def __init__(
Args:
tile_slen: side length in pixels of a tile
survey: survey to mimic (psf, background, calibration, etc.)
use_survey_background: if True, add randomly sampled survey background to the images
with_dither: if True, apply random pixel shifts to the images and align them
with_noise: if True, add Poisson noise to the image pixels
"""
@@ -31,6 +33,7 @@ def __init__(

self.tile_slen = tile_slen
self.survey = survey
self.use_survey_background = use_survey_background
self.with_dither = with_dither
self.with_noise = with_noise

@@ -50,17 +53,16 @@ def render_star(self, psf, band, source_params):
"""
return psf[band].withFlux(source_params["fluxes"][band].item())

def render_galaxy(self, psf, band, source_params):
"""Render a galaxy with given params and PSF.
def render_bulge_plus_disk(self, band, source_params):
"""Render a galaxy with given params.
Args:
psf (List): a list of PSFs for each band
band (int): band
source_params (Tensor): Tensor containing the parameters for a particular source
(see prior.py for details about these parameters)
Returns:
GSObject: a galsim representation of the rendered galaxy convolved with the PSF
GSObject: a galsim representation of the rendered galaxy
"""
disk_flux = source_params["fluxes"][band] * source_params["galaxy_disk_frac"]
bulge_frac = 1 - source_params["galaxy_disk_frac"]
@@ -80,7 +82,21 @@ def render_galaxy(self, psf, band, source_params):
bulge = galsim.DeVaucouleurs(flux=bulge_flux, half_light_radius=bulge_hlr_arcsecs)
sheared_bulge = bulge.shear(q=source_params["galaxy_bulge_q"].item(), beta=beta)
components.append(sheared_bulge)
galaxy = galsim.Add(components)
return galsim.Add(components)

def render_galaxy(self, psf, band, source_params):
"""Render a galaxy with given params and PSF.
Args:
psf (List): a list of PSFs for each band
band (int): band
source_params (Tensor): Tensor containing the parameters for a particular source
(see prior.py for details about these parameters)
Returns:
GSObject: a galsim representation of the rendered galaxy convolved with the PSF
"""
galaxy = self.render_bulge_plus_disk(band, source_params)
return galsim.Convolution(galaxy, psf[band])

@property
@@ -125,6 +141,7 @@ def coadd_images(self, images):
coadded_images[b] = self.survey.coadd_images(images[b])
return torch.from_numpy(coadded_images).float()

# pylint: disable=R0915
def render_image(self, tile_cat):
"""Render a single image from a tile catalog."""
batch_size, n_tiles_h, n_tiles_w = tile_cat["n_sources"].shape
@@ -138,12 +155,16 @@ def render_image(self, tile_cat):
image_idx = np.random.randint(len(self.survey), dtype=int)
frame = self.survey[image_idx]

# sample background from a random position in the frame
height, width = frame["background"].shape[-2:]
h_diff, w_diff = height - slen_h, width - slen_w
h = 0 if h_diff == 0 else np.random.randint(h_diff)
w = 0 if w_diff == 0 else np.random.randint(w_diff)
background = frame["background"][:, h : (h + slen_h), w : (w + slen_w)]
if self.use_survey_background:
# sample background from a random position in the frame
height, width = frame["background"].shape[-2:]
h_diff, w_diff = height - slen_h, width - slen_w
h = 0 if h_diff == 0 else np.random.randint(h_diff)
w = 0 if w_diff == 0 else np.random.randint(w_diff)
background = frame["background"][:, h : (h + slen_h), w : (w + slen_w)]
else:
background = 0

image += background

full_cat = tile_cat.to_full_catalog(self.tile_slen)
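
A minimal sketch of how the updated base Decoder could be constructed with the new flag. The survey object and the tile catalog are placeholders, and the return value of render_image is not shown in this diff, so its handling here is an assumption.

from bliss.simulator.decoder import Decoder

# Build a decoder that skips dither, noise, and the survey background,
# mirroring the settings used by the weak-lensing case study below.
decoder = Decoder(
    tile_slen=256,
    survey=sdss_survey,           # placeholder: a bliss Survey instance built elsewhere
    use_survey_background=False,  # new flag: do not crop a background from a survey frame
    with_dither=False,
    with_noise=False,
)

rendered = decoder.render_image(tile_cat)  # tile_cat: a TileCatalog sampled from the prior
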
53 changes: 20 additions & 33 deletions case_studies/weak_lensing/lensing_config.yaml
@@ -7,36 +7,31 @@ defaults:
mode: train

paths:
dc2: /data/scratch/dc2local # change for gl
output: /data/scratch/twhit/bliss_output # change for gl
dc2: /data/scratch/dc2local
cached_data: /data/scratch/weak_lensing/weak_lensing_img2048_shear02
output: /data/scratch/twhit/bliss_output

prior:
_target_: case_studies.weak_lensing.lensing_prior.LensingPrior
star_color_model_path: /data/scratch/sdss/color_models/star_gmm_nmgy.pkl
gal_color_model_path: /data/scratch/sdss/color_models/gal_gmm_nmgy.pkl
n_tiles_h: 8
n_tiles_w: 8
batch_size: 1
n_tiles_h: 12 # cropping 2 tiles from each side (4 total)
n_tiles_w: 12 # cropping 2 tiles from each side (4 total)
batch_size: 2
max_sources: 200
constant_shear: 0.2
constant_convergence: 0.2
prob_galaxy: 1.0
mean_sources: 162
arcsec_per_pixel: 0.055
sample_method: cosmology
shear_mean: 0
shear_std: 0.0175
convergence_mean: 0
convergence_std: 0.025
num_knots: 4
mean_sources: 82 # 0.02 * (256/4) * (256/4)

decoder:
_target_: case_studies.weak_lensing.lensing_decoder.LensingDecoder
tile_slen: 128
survey: ${surveys.sdss}
tile_slen: 256
use_survey_background: false
with_dither: false
with_noise: false

simulator:
n_batches: 128
num_workers: 32
valid_n_batches: 10
fix_validation_set: true
cached_simulator:
batch_size: 2
train_transforms: []

variational_factors:
- _target_: bliss.encoder.variational_dist.BivariateNormalFactor
@@ -115,14 +110,6 @@ surveys:
num_workers: 1
cached_data_path: ${paths.dc2}/dc2_lensing_splits_img2048_tile256

train:
trainer:
logger:
name: dc2_weak_lensing_exp
version: exp_08_12
devices: 1 # cuda:0 for gl
use_distributed_sampler: false
precision: 32-true
data_source: ${surveys.dc2}
pretrained_weights: null
seed: 123123
generate:
n_image_files: 50
n_batches_per_file: 4
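
A quick check of the geometry implied by the in-file comments; the 0.02 rate is inferred from the mean_sources comment (presumably the per-tile rate at the base config's 4-pixel tiles) and is an assumption, not stated elsewhere in this diff.

# 12 tiles of 256 px per side, with 2 tiles cropped from each side:
image_side = 12 * 256          # 3072 px rendered
cropped_side = (12 - 4) * 256  # 2048 px kept, matching the img2048 cache path
# mean_sources scales a rate of 0.02 sources per 4x4-pixel tile up to 256-pixel tiles:
mean_sources = 0.02 * (256 / 4) ** 2  # 81.92, rounded to 82
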
2 changes: 1 addition & 1 deletion case_studies/weak_lensing/lensing_decoder.py
@@ -16,7 +16,7 @@ def render_galaxy(self, psf, band, source_params):
Returns:
GSObject: a galsim representation of the rendered galaxy convolved with the PSF
"""
galaxy = super().render_galaxy(psf, band, source_params)
galaxy = self.render_bulge_plus_disk(band, source_params)

shear = source_params["shear"]
shear1, shear2 = shear
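
The remainder of LensingDecoder.render_galaxy is collapsed in this view. Below is a sketch of one standard way to finish it with galsim's reduced-shear and magnification API; the helper name, and the assumption that BLISS uses exactly these conversions, are mine and not confirmed by the diff.

import galsim

def apply_constant_lensing(galaxy, psf, gamma1, gamma2, kappa):
    """Hypothetical helper: apply a constant shear and convergence, then the PSF."""
    # Convert (shear, convergence) into the reduced shear and magnification galsim expects.
    g1 = gamma1 / (1.0 - kappa)
    g2 = gamma2 / (1.0 - kappa)
    mu = 1.0 / ((1.0 - kappa) ** 2 - gamma1**2 - gamma2**2)  # standard lensing magnification
    lensed = galaxy.lens(g1=g1, g2=g2, mu=mu)
    # Convolve with the per-band PSF exactly once, now that render_bulge_plus_disk
    # returns an unconvolved profile (this is what avoids the double convolution).
    return galsim.Convolution(lensed, psf)
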
136 changes: 15 additions & 121 deletions case_studies/weak_lensing/lensing_prior.py
@@ -1,131 +1,30 @@
import os

import galsim
import numpy as np
import torch
from torch.distributions import Beta, Normal, Uniform

from bliss.catalog import TileCatalog
from bliss.simulator.prior import CatalogPrior
from case_studies.weak_lensing import generate_angular_cl


class LensingPrior(CatalogPrior):
def __init__(
self,
*args,
arcsec_per_pixel: float,
sample_method: str,
shear_mean: float,
shear_std: float,
convergence_mean: float,
convergence_std: float,
num_knots: int,
constant_shear,
constant_convergence,
**kwargs,
):
super().__init__(*args, **kwargs)
self.arcsec_per_pixel = arcsec_per_pixel

self.sample_method = sample_method
self.shear_mean = shear_mean
self.shear_std = shear_std
self.num_knots = [num_knots, num_knots]
self.convergence_mean = convergence_mean
self.convergence_std = convergence_std

if self.sample_method == "cosmology":
self.grid_size = 0.06

if os.path.exists("angular_cl.npy"):
angular_cl = np.load("angular_cl.npy")
else:
generate_angular_cl.main()
angular_cl = np.load("angular_cl.npy")

angular_cl_table = galsim.LookupTable(x=angular_cl[0], f=angular_cl[1])
self.power_spectrum = galsim.PowerSpectrum(angular_cl_table, units=galsim.degrees)
self.constant_shear = constant_shear
self.constant_convergence = constant_convergence

def _sample_shear_and_convergence(self):
shear_map = torch.zeros((self.batch_size, self.n_tiles_h, self.n_tiles_w, 2))
convergence_map = torch.zeros((self.batch_size, self.n_tiles_h, self.n_tiles_w, 1))

for i in range(self.batch_size):
g1, g2, kappa = self.power_spectrum.buildGrid(
grid_spacing=self.grid_size / self.n_tiles_w,
ngrid=self.n_tiles_w,
get_convergence=True,
units=galsim.degrees,
)
gamma1 = g1 * (1 - kappa)
gamma2 = g2 * (1 - kappa)

shear_map[i, :, :, 0] = torch.from_numpy(gamma1)
shear_map[i, :, :, 1] = torch.from_numpy(gamma2)
convergence_map[i, :, :, 0] = torch.from_numpy(kappa)

return (
shear_map.unsqueeze(3).expand(-1, -1, -1, self.max_sources, -1),
convergence_map.unsqueeze(3).expand(-1, -1, -1, self.max_sources, -1),
shear = self.constant_shear * torch.ones(
(self.batch_size, self.n_tiles_h, self.n_tiles_w, self.max_sources, 2)
)
convergence = self.constant_convergence * torch.ones(
(self.batch_size, self.n_tiles_h, self.n_tiles_w, self.max_sources, 1)
)

def _sample_shear(self):
latent_dims = (self.batch_size, self.n_tiles_h, self.n_tiles_w, 2)
if self.sample_method == "interpolate":
# number of knots in each dimension
corners = (self.batch_size, self.num_knots[0], self.num_knots[1], 2)

shear_maps = Normal(self.shear_mean, self.shear_std).sample(corners)
# want to change from 32 x 20 x 20 x 2 to 32 x 2 x 20 x 20
shear_maps = shear_maps.reshape(
(self.batch_size, 2, self.num_knots[0], self.num_knots[1])
)

shear_maps = torch.nn.functional.interpolate(
shear_maps,
scale_factor=(
self.n_tiles_h // self.num_knots[0],
self.n_tiles_w // self.num_knots[1],
),
mode="bilinear",
align_corners=True,
)

# want to change from 32 x 2 x 20 x 20 to 32 x 20 x 20 x 2
shear_maps = torch.swapaxes(shear_maps, 1, 3)
shear_maps = torch.swapaxes(shear_maps, 1, 2)
else:
shear_maps = Uniform(self.shear_min, self.shear_max).sample(latent_dims)

return shear_maps.unsqueeze(3).expand(-1, -1, -1, self.max_sources, -1)

def _sample_convergence(self):
latent_dims = (self.batch_size, self.n_tiles_h, self.n_tiles_w, 1)
if self.sample_method == "interpolate":
# number of knots in each dimension
corners = (self.batch_size, self.num_knots[0], self.num_knots[1], 1)
convergence_map = Normal(self.convergence_mean, self.convergence_std).sample(corners)
# want to change from 32 x 20 x 20 x 2 to 32 x 2 x 20 x 20
convergence_map = convergence_map.reshape(
(self.batch_size, 1, self.num_knots[0], self.num_knots[1])
)

convergence_map = torch.nn.functional.interpolate(
convergence_map,
scale_factor=(
self.n_tiles_h // self.num_knots[0],
self.n_tiles_w // self.num_knots[1],
),
mode="bilinear",
align_corners=True,
)

# want to change from 32 x 1 x 20 x 20 to 32 x 20 x 20 x 1
convergence_map = torch.swapaxes(convergence_map, 1, 3)
convergence_map = torch.swapaxes(convergence_map, 1, 2)
else:
convergence_map = Beta(self.convergence_a, self.convergence_b).sample(latent_dims)

return convergence_map.unsqueeze(3).expand(-1, -1, -1, self.max_sources, -1)
return shear, convergence

def sample(self) -> TileCatalog:
"""Samples latent variables from the prior of an astronomical image.
@@ -137,15 +36,10 @@ def sample(self) -> TileCatalog:
The remaining dimensions are variable-specific.
"""

catalog_params = super().sample()

if self.sample_method == "interpolate":
shear = self._sample_shear()
convergence = self._sample_convergence()
elif self.sample_method == "cosmology":
shear, convergence = self._sample_shear_and_convergence()
d = super().sample()

catalog_params["shear"] = shear
catalog_params["convergence"] = convergence
shear, convergence = self._sample_shear_and_convergence()
d["shear"] = shear
d["convergence"] = convergence

return TileCatalog(catalog_params)
return TileCatalog(d)
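
A minimal sketch of what the simplified prior now produces, plugging in the values from lensing_config.yaml above; the tensor shapes follow _sample_shear_and_convergence as shown in this diff.

import torch

batch_size, n_tiles_h, n_tiles_w, max_sources = 2, 12, 12, 200
constant_shear, constant_convergence = 0.2, 0.2

shear = constant_shear * torch.ones((batch_size, n_tiles_h, n_tiles_w, max_sources, 2))
convergence = constant_convergence * torch.ones((batch_size, n_tiles_h, n_tiles_w, max_sources, 1))

# Every source in every tile of every image sees the same lensing field:
# (gamma1, gamma2) = (0.2, 0.2) and kappa = 0.2.
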
790 changes: 790 additions & 0 deletions case_studies/weak_lensing/notebooks/dc2/twopoint.ipynb

Large diffs are not rendered by default.

