Skip to content

Commit

Permalink
revised color model (#941)
Browse files Browse the repository at this point in the history
* added 2 new dependent tiling notebooks

* working on some notebooks

* clip log-flux ratios to we don't generate crazy fluxes

* changed misleading variable name

* improving m2 data generation

* made flux gmm directly relative to adjacent band rathe r than the reference band

* use stdevs as log transform thresholds

* support 2x2-pixel tiles as well as 4x4

* fix tests

* added 2percent model
  • Loading branch information
jeff-regier authored Oct 26, 2023
1 parent 7d93f51 commit 893d664
Show file tree
Hide file tree
Showing 24 changed files with 11,608 additions and 412 deletions.
6 changes: 3 additions & 3 deletions bliss/conf/base_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ encoder:
include_original: false
use_deconv_channel: false
concat_psf_params: false
log_transform_thresholds: [-100, 0, 100, 500, 1000]
log_transform_stdevs: [-3, 0, 1, 3]
use_clahe: true
do_data_augmentation: false
compile_model: false # if true, compile model for potential performance
Expand All @@ -102,7 +102,7 @@ region_encoder:
include_original: false
use_deconv_channel: false
concat_psf_params: false
log_transform_thresholds: [-100, 0, 100, 500, 1000]
log_transform_stdevs: [-3, 0, 1, 3]
use_clahe: true
overlap_slen: 0.4
slack: ${encoder.slack}
Expand Down Expand Up @@ -147,7 +147,7 @@ predict:
dataset: ${surveys.sdss}
trainer: ${training.trainer}
encoder: ${encoder}
weight_save_path: ${paths.pretrained_models}/clahed_logged_2percent.pt
weight_save_path: ${paths.pretrained_models}/clahed_logged_20percent.pt
device: "cuda:0"
crop:
do_crop: true
Expand Down
4 changes: 2 additions & 2 deletions bliss/convnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def forward(self, x):


class MarginalNet(nn.Module):
def __init__(self, n_bands, ch_per_band, out_channels):
def __init__(self, n_bands, ch_per_band, out_channels, double_downsample=True):
super().__init__()

nch_hidden = 64
Expand All @@ -68,7 +68,7 @@ def __init__(self, n_bands, ch_per_band, out_channels):
nn.Sequential(*[ConvBlock(64, 64, kernel_size=5, padding=2) for _ in range(4)]),
ConvBlock(64, 128, stride=2),
nn.Sequential(*[ConvBlock(128, 128) for _ in range(5)]),
ConvBlock(128, NUM_FEATURES, stride=2), # 4
ConvBlock(128, NUM_FEATURES, stride=(2 if double_downsample else 1)), # 4
C3(256, 256, n=6), # 5
ConvBlock(256, 512, stride=2),
C3(512, 512, n=3, shortcut=False),
Expand Down
8 changes: 7 additions & 1 deletion bliss/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,13 @@ def __init__(

ch_per_band = self.image_normalizer.num_channels_per_band()
n_params_per_source = sum(param.dim for param in self.dist_param_groups.values())
self.marginal_net = MarginalNet(len(bands), ch_per_band, n_params_per_source)
assert tile_slen in {2, 4}, "tile_slen must be 2 or 4"
self.marginal_net = MarginalNet(
len(bands),
ch_per_band,
n_params_per_source,
double_downsample=(tile_slen == 4),
)
self.conditional_net = ConditionalNet(n_params_per_source)

if compile_model:
Expand Down
31 changes: 17 additions & 14 deletions bliss/image_normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def __init__(
include_original: bool,
use_deconv_channel: bool,
concat_psf_params: bool,
log_transform_thresholds: list,
log_transform_stdevs: list,
use_clahe: bool,
):
"""Initializes DetectionEncoder.
Expand All @@ -20,7 +20,7 @@ def __init__(
include_original: whether to include the original image as an input channel
use_deconv_channel: whether to include the deconvolved image as an input channel
concat_psf_params: whether to include the PSF parameters as input channels
log_transform_thresholds: list of thresholds to apply log transform to (can be empty)
log_transform_stdevs: list of thresholds to apply log transform to (can be empty)
use_clahe: whether to apply Contrast Limited Adaptive Histogram Equalization to images
"""
super().__init__()
Expand All @@ -29,11 +29,11 @@ def __init__(
self.include_original = include_original
self.use_deconv_channel = use_deconv_channel
self.concat_psf_params = concat_psf_params
self.log_transform_thresholds = log_transform_thresholds
self.log_transform_stdevs = log_transform_stdevs
self.use_clahe = use_clahe

if not (log_transform_thresholds or use_clahe):
warnings.warn("Either log transform or rolling z-score should be enabled.")
if not (log_transform_stdevs or use_clahe):
warnings.warn("Either log transform or clahe should be enabled.")

def num_channels_per_band(self):
"""Determine number of input channels for model based on desired input transforms."""
Expand All @@ -44,8 +44,8 @@ def num_channels_per_band(self):
nch += 1
if self.concat_psf_params:
nch += 6 # number of PSF parameters for SDSS, may vary for other surveys
if self.log_transform_thresholds:
nch += len(self.log_transform_thresholds)
if self.log_transform_stdevs:
nch += len(self.log_transform_stdevs)
if self.use_clahe:
nch += 1
return nch
Expand All @@ -63,6 +63,9 @@ def get_input_tensor(self, batch):
assert batch["images"].size(2) % 16 == 0, "image dims must be multiples of 16"
assert batch["images"].size(3) % 16 == 0, "image dims must be multiples of 16"

if self.log_transform_stdevs:
assert batch["background"].min() > 1e-6, "background must be positive"

input_bands = batch["images"].shape[1]
if input_bands < len(self.bands):
msg = f"Expected >= {len(self.bands)} bands in the input but found only {input_bands}"
Expand All @@ -87,21 +90,21 @@ def get_input_tensor(self, batch):
psf_params = batch["psf_params"][:, self.bands]
inputs.append(psf_params.view(n, c, 6 * i, 1, 1).expand(n, c, 6 * i, h, w))

for threshold in self.log_transform_thresholds:
image_offsets = raw_images - backgrounds - threshold
transformed_img = torch.log(torch.clamp(image_offsets, min=1.0))
for threshold in self.log_transform_stdevs:
image_offsets = (raw_images - backgrounds) / backgrounds.sqrt() - threshold
transformed_img = torch.log(torch.clamp(image_offsets + 1.0, min=1.0))
inputs.append(transformed_img)

if self.use_clahe:
renormalized_img = self.rolling_z_score(raw_images, 9, 200, 4)
renormalized_img = self.clahe(raw_images, 9, 200, 4)
inputs.append(renormalized_img)
inputs[0] = self.rolling_z_score(backgrounds, 9, 200, 4)
inputs[0] = self.clahe(backgrounds, 9, 200, 4)

return torch.cat(inputs, dim=2)

@classmethod
def rolling_z_score(cls, imgs, s, c, p):
"""Perform a rolling z_score transform on input images."""
def clahe(cls, imgs, s, c, p):
"""Perform Contrast Limited Adaptive Histogram Equalization (CLAHE) on input images."""
imgs4d = torch.squeeze(imgs, dim=2)
padding = (p, p, p, p)
orig_shape = imgs4d.shape
Expand Down
18 changes: 9 additions & 9 deletions bliss/simulator/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,31 +184,31 @@ def render_images(self, tile_cat: TileCatalog, image_ids, coadd_depth=1):
self.flux_calibration_dict[image_ids[b]] for b in range(batch_size)
]

for b in range(batch_size):
# Convert to electron counts
tile_cat["star_fluxes"][b] *= flux_calibration_rats[b]
tile_cat["galaxy_fluxes"][b] *= flux_calibration_rats[b]
for i in range(batch_size):
# Convert from (linear) physical units to electron counts
tile_cat["star_fluxes"][i] *= flux_calibration_rats[i]
tile_cat["galaxy_fluxes"][i] *= flux_calibration_rats[i]

full_cat = tile_cat.to_full_params()

# generate random WCS shifts as manual image dithering via unaligning WCS
wcs_batch = []

for b in range(batch_size):
n_sources = int(full_cat.n_sources[b].item())
psf = psfs[b]
for i in range(batch_size):
n_sources = int(full_cat.n_sources[i].item())
psf = psfs[i]
for d in range(coadd_depth):
depth_band_shifts, depth_band_wcs_list = self.pixel_shifts(
coadd_depth, self.n_bands, self.ref_band
)
wcs_batch.append(depth_band_wcs_list)
for band in range(self.n_bands):
band_img = galsim.Image(array=images[b, d, band], scale=self.pixel_scale)
band_img = galsim.Image(array=images[i, d, band], scale=self.pixel_scale)
self.draw_sources_on_band_image(
band_img,
n_sources,
full_cat,
b,
i,
psf,
band,
image_dims=(slen_h, slen_w),
Expand Down
69 changes: 36 additions & 33 deletions bliss/simulator/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(
gal_color_model_path: str,
reference_band: int,
):
"""Initializes ImagePrior.
"""Initializes CatalogPrior.
Args:
survey_bands: all band-pass filters available for this survey
Expand Down Expand Up @@ -100,7 +100,7 @@ def __init__(

self.star_color_model_path = star_color_model_path
self.gal_color_model_path = gal_color_model_path
self.b_band = reference_band
self.reference_band = reference_band
self.gmm_star, self.gmm_gal = self._load_color_models()

def sample(self) -> TileCatalog:
Expand All @@ -113,9 +113,8 @@ def sample(self) -> TileCatalog:
The remaining dimensions are variable-specific.
"""
locs = self._sample_locs()
stars_fluxes, gals_fluxes = self._flux_ratios()
galaxy_fluxes, galaxy_params = self._sample_galaxy_prior(gals_fluxes)
star_fluxes = self._sample_star_fluxes(stars_fluxes)
galaxy_fluxes, galaxy_params = self._sample_galaxy_prior()
star_fluxes = self._sample_star_fluxes()

n_sources = self._sample_n_sources()
source_type = self._sample_source_type()
Expand Down Expand Up @@ -155,12 +154,14 @@ def _draw_truncated_pareto(self, alpha, min_x, max_x, n_samples) -> Tensor:
uniform_samples = torch.rand(n_samples) * u_max
return min_x / (1.0 - uniform_samples) ** (1 / alpha)

def _sample_star_fluxes(self, star_ratios):
def _sample_star_fluxes(self):
flux_prop = self._sample_flux_ratios(self.gmm_star)

latent_dims = (self.batch_size, self.n_tiles_h, self.n_tiles_w, self.max_sources, 1)
b_flux = self._draw_truncated_pareto(
ref_band_flux = self._draw_truncated_pareto(
self.star_flux_alpha, self.star_flux_min, self.star_flux_max, latent_dims
)
total_flux = b_flux * star_ratios
total_flux = ref_band_flux * flux_prop

# select specified bands
bands = np.array(range(self.n_bands))
Expand All @@ -176,7 +177,7 @@ def _load_color_models(self):
gmm_gal = pickle.load(f)
return gmm_star, gmm_gal

def _flux_ratios(self) -> Tuple[Tensor, Tensor]:
def _sample_flux_ratios(self, gmm) -> Tuple[Tensor, Tensor]:
"""Sample and compute all star, galaxy fluxes based on real image data.
Instead of pareto-sampling fluxes for each band, we pareto-sample `b`-band flux values,
Expand All @@ -186,38 +187,38 @@ def _flux_ratios(self) -> Tuple[Tensor, Tensor]:
flux_r ~ Pareto, flux_g ~ Pareto => flux_r | flux_g ~ Pareto.
This function is survey-specific as the 'b'-band index depends on the survey.
Args:
gmm: Gaussian mixture model of flux ratios (either star or galaxy)
Returns:
stars_fluxes (Tensor): (b x th x tw x ms x nbands) Tensor containing all star flux
ratios for current batch
gals_fluxes (Tensor): (b x th x tw x ms x nbands) Tensor containing all gal fluxes
ratios for current batch
"""

samples = (self.batch_size, self.n_tiles_h, self.n_tiles_w, self.max_sources)
star_ratios_flat = self.gmm_star.sample(np.prod(samples))[0]
gal_ratios_flat = self.gmm_gal.sample(np.prod(samples))[0]
sample_dims = (self.batch_size, self.n_tiles_h, self.n_tiles_w, self.max_sources)
flux_logdiff, _ = gmm.sample(np.prod(sample_dims))

# A log difference of +/- 2.76 correpsonds to a 3 magnitude difference (e.g. 18 vs 21),
# or equivalently an 15.8x flux ratio.
# It's unlikely that objects will have a ratio larger than this.
flux_logdiff = np.clip(flux_logdiff, -2.76, 2.76)
flux_ratio = np.exp(flux_logdiff)

# Computes the flux in each band as a proportion of the reference band flux
flux_prop = torch.ones(flux_logdiff.shape[0], self.n_bands)
for band in range(self.reference_band - 1, -1, -1):
flux_prop[:, band] = flux_prop[:, band + 1] / flux_ratio[:, band]
for band in range(self.reference_band + 1, self.n_bands):
flux_prop[:, band] = flux_prop[:, band - 1] * flux_ratio[:, band - 1]

# Reshape drawn values into appropriate form
samples = samples + (self.n_bands - 1,)
# TODO: remove band-dimension coercing after building DES/DECaLS-specific GMMs
bands = range(star_ratios_flat.shape[-1] - self.n_bands + 1, star_ratios_flat.shape[-1])
star_ratios = np.exp(np.reshape(star_ratios_flat[..., bands], samples))
gal_ratios = np.exp(np.reshape(gal_ratios_flat[..., bands], samples))

# Append r-band 'ratio' of 1's to sampled ratios
base = np.ones((self.batch_size, self.n_tiles_h, self.n_tiles_w, self.max_sources, 1))
arrs = (star_ratios[..., : self.b_band], base, star_ratios[..., self.b_band :])
star_ratios_b = np.concatenate(arrs, axis=4)
arrs = (gal_ratios[..., : self.b_band], base, gal_ratios[..., self.b_band :])
gal_ratios_b = np.concatenate(arrs, axis=4)

return torch.from_numpy(star_ratios_b), torch.from_numpy(gal_ratios_b)

def _sample_galaxy_prior(self, gal_ratios) -> Tuple[Tensor, Tensor]:
"""Sample the latent galaxy params.
sample_dims = sample_dims + (self.n_bands,)
return flux_prop.view(sample_dims)

Args:
gal_ratios: flux ratios for multiband galaxies (why is this an argument?)
def _sample_galaxy_prior(self) -> Tuple[Tensor, Tensor]:
"""Sample the latent galaxy params.
Returns:
Tuple[Tensor]: A tuple of galaxy fluxes (per band) and galsim parameters, including.
Expand All @@ -228,13 +229,15 @@ def _sample_galaxy_prior(self, gal_ratios) -> Tuple[Tensor, Tensor]:
- bulge_q: minor-to-major axis ratio of the bulge
- a_b: semi-major axis of bulge
"""
flux_prop = self._sample_flux_ratios(self.gmm_gal)

latent_dims = (self.batch_size, self.n_tiles_h, self.n_tiles_w, self.max_sources, 1)

b_flux = self._draw_truncated_pareto(
ref_band_flux = self._draw_truncated_pareto(
self.galaxy_alpha, self.galaxy_flux_min, self.galaxy_flux_max, latent_dims
)

total_flux = gal_ratios * b_flux
total_flux = flux_prop * ref_band_flux

# select fluxes from specified bands
bands = np.array(self.bands)
Expand Down
4 changes: 2 additions & 2 deletions bliss/simulator/simulated_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(
bands=survey.BANDS,
pixel_shift=survey.pixel_shift,
flux_calibration_dict=survey.flux_calibration_dict,
ref_band=prior.b_band,
ref_band=prior.reference_band,
)

self.n_batches = n_batches
Expand Down Expand Up @@ -105,7 +105,7 @@ def align_images(self, images, wcs_batch):
images[b].numpy(),
wcs_list=wcs_batch[b],
ref_depth=0,
ref_band=self.catalog_prior.b_band,
ref_band=self.catalog_prior.reference_band,
)
)
return images
Expand Down
1 change: 1 addition & 0 deletions bliss/surveys/dc2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@


class DC2(Survey):
# why are these bands out of order? why does a test break if they are ordered correctly?
BANDS = ("g", "i", "r", "u", "y", "z")

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion bliss/surveys/sdss.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(
self,
psf_config: PSFConfig,
fields,
pixel_shift,
pixel_shift=0.0,
dir_path="data/sdss",
load_image_data: bool = False,
):
Expand Down
Loading

0 comments on commit 893d664

Please sign in to comment.