Skip to content

Commit

Permalink
simpler data generation; simpler encoder (#1049)
Browse files Browse the repository at this point in the history
* testing encoder weights trained with an SDSS-like prior

* no tile_slen for TileCatalog; filter by flux in dataloaders, not Encoder

* remove min_flux_for_metrics

* validation image sanity check passing

* working on notebooks

* revamp data generation; remove decals

* toy example shows signs of working; m2 bands filtered

* move alignment and cropping to the PredictSurveyIterator

* minor

* duplicated code
  • Loading branch information
jeff-regier authored Jul 23, 2024
1 parent 83fa80f commit d84e937
Show file tree
Hide file tree
Showing 72 changed files with 745 additions and 3,784 deletions.
11 changes: 1 addition & 10 deletions bliss/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,6 @@
from reproject import reproject_interp


def crop_to_mult16(x):
"""Crop the image dimensions to a multiple of 16."""
# note: by cropping the top-right, we preserve the mapping between pixel coordinates
# and the original WCS coordinates
height = x.shape[1] - (x.shape[1] % 16)
width = x.shape[2] - (x.shape[2] % 16)
return x[:, :height, :width]


def align(img, wcs_list, ref_band, ref_depth=0):
"""Reproject images based on some reference WCS for pixel alignment."""
reproj_d = {}
Expand Down Expand Up @@ -52,4 +43,4 @@ def align(img, wcs_list, ref_band, ref_depth=0):

if reproj_out.shape[0] == 1:
reproj_out = reproj_out.squeeze(axis=0)
return reproj_out
return np.float32(reproj_out)
23 changes: 22 additions & 1 deletion bliss/cached_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torch.utils.data import DataLoader, Dataset, DistributedSampler, Sampler
from torchvision import transforms

from bliss.catalog import FullCatalog
from bliss.catalog import FullCatalog, TileCatalog
from bliss.global_env import GlobalEnv

# prevent pytorch_lightning warning for num_workers = 2 in dataloaders with IterableDataset
Expand Down Expand Up @@ -67,6 +67,27 @@ def __call__(self, datum_in):
return datum_out


class FluxFilterTransform(torch.nn.Module):
def __init__(self, reference_band, min_flux):
super().__init__()
self.reference_band = reference_band
self.min_flux = min_flux

def __call__(self, datum_in):
datum_out = copy(datum_in)

d1 = {k: v.unsqueeze(0) for k, v in datum_in["tile_catalog"].items()}
target_cat = TileCatalog(d1)
target_cat = target_cat.filter_by_flux(
min_flux=self.min_flux,
band=self.reference_band,
)
d2 = {k: v.squeeze(0) for k, v in target_cat.items()}
datum_out["tile_catalog"] = d2

return datum_out


class ChunkingSampler(Sampler):
def __init__(self, dataset: Dataset) -> None:
super().__init__(dataset)
Expand Down
103 changes: 27 additions & 76 deletions bliss/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from typing import Dict, Tuple

import torch
from astropy import units
from astropy.table import Table
from astropy.wcs import WCS
from einops import rearrange, reduce, repeat
from torch import Tensor
Expand All @@ -31,10 +29,7 @@ class SourceType(IntEnum):


class BaseTileCatalog(UserDict):
def __init__(self, tile_slen: int, d: Dict[str, Tensor]):
# TODO: a tile catalog shouldn't know it's side length
self.tile_slen = tile_slen

def __init__(self, d: Dict[str, Tensor]):
v = next(iter(d.values()))
self.batch_size, self.n_tiles_h, self.n_tiles_w = v.shape[:3]
self.device = v.device
Expand All @@ -57,13 +52,13 @@ def to(self, device):
out = {}
for k, v in self.items():
out[k] = v.to(device)
return type(self)(self.tile_slen, out)
return type(self)(out)

def crop(self, hlims_tile, wlims_tile):
out = {}
for k, v in self.items():
out[k] = v[:, hlims_tile[0] : hlims_tile[1], wlims_tile[0] : wlims_tile[1]]
return type(self)(self.tile_slen, out)
return type(self)(out)

def symmetric_crop(self, tiles_to_crop):
return self.crop(
Expand Down Expand Up @@ -115,9 +110,9 @@ class TileCatalog(BaseTileCatalog):
]
galaxy_params_index = {k: i for i, k in enumerate(galaxy_params)}

def __init__(self, tile_slen: int, d: Dict[str, Tensor]):
def __init__(self, d: Dict[str, Tensor]):
self.max_sources = d["locs"].shape[3]
super().__init__(tile_slen, d)
super().__init__(d)

def __getitem__(self, name: str):
# a temporary hack until we stop storing galaxy_params as an array
Expand Down Expand Up @@ -177,20 +172,19 @@ def magnitudes(self):
def magnitudes_njy(self):
return convert_nmgy_to_njymag(self.on_fluxes)

def to_full_catalog(self):
def to_full_catalog(self, tile_slen):
"""Converts image parameters in tiles to parameters of full image.
By parameters, we mean samples from the variational distribution, not the variational
parameters.
Args:
tile_slen: The side length of the square tiles (in pixels).
Returns:
The FullCatalog instance corresponding to the TileCatalog instance.
NOTE: The locations (`"locs"`) are between 0 and 1. The output also contains
pixel locations ("plocs") that are between 0 and `slen`.
"""
# TODO: tile_slen should be an argument to this function, not stored in a tile catalog
plocs = self.get_full_locs_from_tiles()
plocs = self._get_plocs_from_tiles(tile_slen)
param_names_to_mask = {"plocs"}.union(set(self.keys()))
tile_params_to_gather = {"plocs": plocs}
tile_params_to_gather.update(self)
Expand All @@ -212,29 +206,32 @@ def to_full_catalog(self):

params["n_sources"] = reduce(self["n_sources"], "b nth ntw -> b", "sum")

height_px = self.n_tiles_h * self.tile_slen
width_px = self.n_tiles_w * self.tile_slen
height_px = self.n_tiles_h * tile_slen
width_px = self.n_tiles_w * tile_slen

return FullCatalog(height_px, width_px, params)

def get_full_locs_from_tiles(self) -> Tensor:
def _get_plocs_from_tiles(self, tile_slen) -> Tensor:
"""Get the full image locations from tile locations.
Args:
tile_slen: The side length of the square tiles (in pixels).
Returns:
Tensor: pixel coordinates of each source (between 0 and slen).
"""
slen = self.n_tiles_h * self.tile_slen
wlen = self.n_tiles_w * self.tile_slen
slen = self.n_tiles_h * tile_slen
wlen = self.n_tiles_w * tile_slen
# coordinates on tiles.
x_coords = torch.arange(0, slen, self.tile_slen, device=self["locs"].device).long()
y_coords = torch.arange(0, wlen, self.tile_slen, device=self["locs"].device).long()
x_coords = torch.arange(0, slen, tile_slen, device=self["locs"].device).long()
y_coords = torch.arange(0, wlen, tile_slen, device=self["locs"].device).long()
tile_coords = torch.cartesian_prod(x_coords, y_coords)

# recenter and renormalize locations.
locs = rearrange(self["locs"], "b nth ntw d xy -> (b nth ntw) d xy", xy=2)
bias = repeat(tile_coords, "n xy -> (r n) 1 xy", r=self.batch_size).float()

plocs = locs * self.tile_slen + bias
plocs = locs * tile_slen + bias
return rearrange(
plocs,
"(b nth ntw) d xy -> b nth ntw d xy",
Expand Down Expand Up @@ -279,7 +276,7 @@ def _sort_sources_by_flux(self, band=2):
idx_to_gather = repeat(top_indexes, "... -> ... pd", pd=param_dim)
d[key] = torch.take_along_dim(val, idx_to_gather, dim=3)

return TileCatalog(self.tile_slen, d)
return TileCatalog(d)

def get_brightest_sources_per_tile(self, top_k=1, exclude_num=0, band=2):
"""Restrict TileCatalog to only the brightest 'on' source per tile.
Expand All @@ -296,7 +293,7 @@ def get_brightest_sources_per_tile(self, top_k=1, exclude_num=0, band=2):
return self

if exclude_num >= self.max_sources:
tc = TileCatalog(self.tile_slen, self.data)
tc = TileCatalog(self.data)
tc["n_sources"] = torch.zeros_like(tc["n_sources"])
return tc

Expand All @@ -309,7 +306,7 @@ def get_brightest_sources_per_tile(self, top_k=1, exclude_num=0, band=2):
else:
d[key] = val[:, :, :, exclude_num : (exclude_num + top_k)]

return TileCatalog(self.tile_slen, d)
return TileCatalog(d)

def filter_by_flux(self, min_flux=0, max_flux=torch.inf, band=2):
"""Restricts TileCatalog to sources that have a flux between min_flux and max_flux.
Expand Down Expand Up @@ -338,7 +335,7 @@ def filter_by_flux(self, min_flux=0, max_flux=torch.inf, band=2):
else:
d[key] = torch.where(flux_mask.unsqueeze(-1), val, torch.zeros_like(val))

return TileCatalog(self.tile_slen, d)
return TileCatalog(d)

def union(self, other, disjoint=False):
"""Returns a new TileCatalog containing the union of the sources in self and other.
Expand All @@ -352,7 +349,6 @@ def union(self, other, disjoint=False):
Returns:
A new TileCatalog containing the union of the sources in self and other.
"""
assert self.tile_slen == other.tile_slen
assert self.batch_size == other.batch_size
assert self.n_tiles_h == other.n_tiles_h
assert self.n_tiles_w == other.n_tiles_w
Expand All @@ -372,7 +368,7 @@ def union(self, other, disjoint=False):
d1 = torch.cat((v, other[k]), dim=-2)
d2 = torch.cat((other[k], v), dim=-2)
d[k] = torch.where(ns11 > 0, d1, d2)
return TileCatalog(self.tile_slen, d)
return TileCatalog(d)

def __repr__(self):
keys = ", ".join(self.keys())
Expand Down Expand Up @@ -639,54 +635,9 @@ def to_tile_catalog(
# modify tile location
tile_params["locs"][ii] = (tile_params["locs"][ii] % tile_slen) / tile_slen
tile_params.update({"n_sources": tile_n_sources})
return TileCatalog(tile_slen, tile_params)
return TileCatalog(tile_params)

def to_astropy_table(self, encoder_survey_bands: Tuple[str]) -> Table:
# Convert dictionary of tensors to list of dictionaries
on_vals = {}
is_on_mask = self.is_on_mask
for k, v in self.items():
if k == "n_sources":
continue
on_vals[k] = v[is_on_mask].cpu()

# Split to different columns for each band
for b, bl in enumerate(encoder_survey_bands):
on_vals[f"star_flux_{bl}"] = on_vals["star_fluxes"][..., b]
on_vals[f"galaxy_flux_{bl}"] = on_vals["galaxy_fluxes"][..., b]

# Remove combined flux columns
on_vals.pop("star_fluxes")
on_vals.pop("galaxy_fluxes")

# declare our astropy table
est_cat_table = Table(names=on_vals.keys())

# Convert all _fluxes columns to units.Quantity
for bl in encoder_survey_bands:
est_cat_table[f"star_flux_{bl}"].unit = units.nmgy
est_cat_table[f"galaxy_flux_{bl}"].unit = units.nmgy

# add units to some galaxy shape properties
est_cat_table["galaxy_beta_radians"].unit = units.radian
est_cat_table["galaxy_a_d"].unit = units.arcsec
est_cat_table["galaxy_a_b"].unit = units.arcsec

# load data into the astropy table
n = is_on_mask.sum() # number of (predicted) objects
for i in range(n):
row = {}
for k, v in on_vals.items():
row[k] = v[i].cpu().float()
# Convert `source_type` to string "star" or "galaxy" labels
row["source_type"] = "star" if row["source_type"] == SourceType.STAR else "galaxy"
# Force `plocs` to be "({x}, {y})" tuple strings for readability
row["plocs"] = str(tuple(row["plocs"].tolist()))
est_cat_table.add_row(row)

return est_cat_table

def filter_full_catalog_by_ploc_box(self, box_origin: torch.Tensor, box_len: float):
def filter_by_ploc_box(self, box_origin: torch.Tensor, box_len: float):
assert box_origin[0] + box_len <= self.height, "invalid box"
assert box_origin[1] + box_len <= self.width, "invalid box"

Expand Down
41 changes: 18 additions & 23 deletions bliss/conf/base_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ hydra:

paths:
sdss: /data/scratch/sdss
decals: /data/scratch/decals
des: /data/scratch/des
dc2: /data/scratch/dc2local
cached_data: /data/scratch/regier/sdss_like
Expand All @@ -23,11 +22,10 @@ prior:
_target_: bliss.simulator.prior.CatalogPrior
survey_bands: ["u", "g", "r", "i", "z"] # SDSS available band filters
reference_band: 2 # SDSS r-band
star_color_model_path: ${simulator.survey.dir_path}/color_models/star_gmm_nmgy.pkl
gal_color_model_path: ${simulator.survey.dir_path}/color_models/gal_gmm_nmgy.pkl
star_color_model_path: ${simulator.decoder.survey.dir_path}/color_models/star_gmm_nmgy.pkl
gal_color_model_path: ${simulator.decoder.survey.dir_path}/color_models/gal_gmm_nmgy.pkl
n_tiles_h: 20
n_tiles_w: 20
tile_slen: 4
batch_size: 64
max_sources: 1
mean_sources: 0.01 # 0.0025 is more realistic for SDSS but training takes more iterations
Expand All @@ -46,12 +44,18 @@ prior:
galaxy_a_scale: 4.432725319432478
galaxy_a_bd_ratio: 2.0

decoder:
_target_: bliss.simulator.decoder.Decoder
tile_slen: 4
survey: ${surveys.sdss}
with_dither: true
with_noise: true

simulator:
_target_: bliss.simulator.simulated_dataset.SimulatedDataset
survey: ${surveys.sdss}
prior: ${prior}
decoder: ${decoder}
n_batches: 128
coadd_depth: 1
num_workers: 32
valid_n_batches: 10 # 256
fix_validation_set: true
Expand Down Expand Up @@ -153,9 +157,7 @@ encoder:
_target_: bliss.encoder.encoder.Encoder
survey_bands: ["u", "g", "r", "i", "z"]
reference_band: 2 # SDSS r-band
tile_slen: ${simulator.prior.tile_slen}
min_flux_for_loss: 0 # set to 0 to include all sources
min_flux_for_metrics: 0 # set to 0 to include all sources
tile_slen: ${simulator.decoder.tile_slen}
optimizer_params:
lr: 1e-3
scheduler_params:
Expand Down Expand Up @@ -183,7 +185,7 @@ encoder:
frequency: 1
restrict_batch: 0
tiles_to_crop: 0
tile_slen: ${simulator.prior.tile_slen}
tile_slen: ${simulator.decoder.tile_slen}
use_double_detect: false
use_checkerboard: false

Expand All @@ -198,19 +200,12 @@ surveys:
psf_config:
pixel_scale: 0.396
psf_slen: 25
align_to_band: null # we should set this to 2 (r-band)
load_image_data: false
decals:
_target_: bliss.surveys.decals.DarkEnergyCameraLegacySurvey
dir_path: ${paths.decals}
sky_coords: # in degrees
# brick '3366m010' corresponds to SDSS RCF 94-1-12
- ra: 336.6643042496718
dec: -0.9316385797930247
bands: [0, 1, 3] # grz
psf_config:
pixel_scale: 0.262
psf_slen: 63
# options below only apply to prediction
align_to_band: null # should be 2 but it's slower then
crop_to_bands: null
crop_to_hw: null

des:
_target_: bliss.surveys.des.DarkEnergySurvey
dir_path: ${paths.des}
Expand All @@ -235,7 +230,7 @@ surveys:
n_image_split: 50
tile_slen: 4
max_sources_per_tile: 5
min_flux_for_loss: ${encoder.min_flux_for_loss}
min_flux: 0.0
prepare_data_processes_num: 4
data_in_one_cached_file: 1250
splits: 0:80/80:90/90:100
Expand Down
4 changes: 2 additions & 2 deletions bliss/data_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def __call__(self, datum, vertical_shift=None, horizontal_shift=None):
datum_out["images"] = img

d = {k: v.unsqueeze(0) for k, v in datum["tile_catalog"].items()}
tile_cat = TileCatalog(self.tile_slen, d)
full_cat = tile_cat.to_full_catalog()
tile_cat = TileCatalog(d)
full_cat = tile_cat.to_full_catalog(self.tile_slen)

full_cat["plocs"][:, :, 0] += vertical_shift
full_cat["plocs"][:, :, 1] += horizontal_shift
Expand Down
Loading

0 comments on commit d84e937

Please sign in to comment.