Improve multidetection and add some calibration plots (#1052)
* modify multidetect

* fix bugs

* refactor directory

* add a new train script

* add vsbc test

* update notebook

* change the output path of notebooks

* add postprocess to multidetection

* add vsbc plot for flux

* add credible interval plots

* update notebooks and address PR comments

* update according to the changes in merge

* fix test fail

---------

Co-authored-by: Yicun Duan <[email protected]>
YicunDuanUMich and Yicun Duan authored Aug 6, 2024
1 parent e857c6f commit da6f2b5
Showing 40 changed files with 6,010 additions and 1,032 deletions.
6 changes: 5 additions & 1 deletion bliss/cached_dataset.py
@@ -90,7 +90,11 @@ def __call__(self, datum_in):

class ChunkingSampler(Sampler):
def __init__(self, dataset: Dataset) -> None:
-super().__init__(dataset)
+# please don't pass `dataset` to the following __init__();
+# according to https://pytorch.org/docs/stable/data.html#torch.utils.data.Sampler,
+# the parameter `data_source` has been deprecated
+# make sure your PyTorch version is 2.2.0 or later
+super().__init__()
assert isinstance(dataset, ChunkingDataset), "dataset should be ChunkingDataset"
self.dataset = dataset

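The comment above reflects a real compatibility constraint: before PyTorch 2.2, `Sampler.__init__` required a `data_source` argument, and from 2.2 on that argument is deprecated. A minimal runnable sketch of the same pattern (the iteration logic is a stand-in, not the real ChunkingSampler):

```python
import torch
from torch.utils.data import Sampler

class DemoSampler(Sampler):
    def __init__(self, dataset) -> None:
        # On torch >= 2.2, call the base __init__ with no arguments
        # (`data_source` is deprecated) and keep our own reference.
        super().__init__()
        self.dataset = dataset

    def __iter__(self):
        # Trivial sequential order, just to keep the sketch runnable.
        return iter(range(len(self.dataset)))

    def __len__(self):
        return len(self.dataset)

sampler = DemoSampler(list("abcd"))
assert list(sampler) == [0, 1, 2, 3]
```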
62 changes: 19 additions & 43 deletions bliss/catalog.py
@@ -1,3 +1,4 @@
+import copy
import math
from collections import UserDict
from enum import IntEnum
@@ -81,38 +82,6 @@ def symmetric_crop(self, tiles_to_crop):
[tiles_to_crop, self.n_tiles_w - tiles_to_crop],
)

-def filter_base_tile_catalog_by_ploc_box(self, box_origin: torch.Tensor, box_len: float):
-assert box_origin[0] + box_len < self.height, "invalid box"
-assert box_origin[1] + box_len < self.width, "invalid box"
-
-box_origin_tensor = box_origin.view(1, 1, 2).to(device=self.device)
-box_end_tensor = (box_origin + box_len).view(1, 1, 2).to(device=self.device)
-
-plocs_mask = torch.all(
-(self["plocs"] < box_end_tensor) & (self["plocs"] > box_origin_tensor), dim=2
-)
-
-plocs_mask_indexes = plocs_mask.nonzero()
-plocs_inverse_mask_indexes = (~plocs_mask).nonzero()
-plocs_full_mask_indexes = torch.cat((plocs_mask_indexes, plocs_inverse_mask_indexes), dim=0)
-_, index_order = plocs_full_mask_indexes[:, 0].sort(stable=True)
-plocs_full_mask_sorted_indexes = plocs_full_mask_indexes[index_order.tolist(), :]
-
-d = {}
-new_max_sources = plocs_mask.sum(dim=1).max()
-for k, v in self.items():
-if k == "n_sources":
-d[k] = plocs_mask.sum(dim=1)
-else:
-d[k] = v[
-plocs_full_mask_sorted_indexes[:, 0].tolist(),
-plocs_full_mask_sorted_indexes[:, 1].tolist(),
-].view(-1, self.max_sources, v.shape[-1])[:, :new_max_sources, :]
-
-d["plocs"] -= box_origin_tensor
-
-return FullCatalog(box_len, box_len, d)


class TileCatalog(BaseTileCatalog):
galaxy_params = [
@@ -126,7 +95,8 @@ class TileCatalog(BaseTileCatalog):
galaxy_params_index = {k: i for i, k in enumerate(galaxy_params)}

def __init__(self, d: Dict[str, Tensor]):
-self.max_sources = d["locs"].shape[3]
+assert "locs" in d
+assert len(d["locs"].shape) == 5
super().__init__(d)

def __getitem__(self, name: str):
@@ -137,6 +107,10 @@ def __getitem__(self, name: str):
return self.data["galaxy_params"][..., idx : (idx + 1)]
return super().__getitem__(name)

+@property
+def max_sources(self):
+return self["locs"].shape[3]

@property
def is_on_mask(self) -> Tensor:
"""Provides tensor which indicates how many sources are present for each batch.
@@ -324,16 +298,22 @@ def get_brightest_sources_per_tile(self, top_k=1, exclude_num=0, band=2):
if key == "n_sources":
d[key] = (sorted_self["n_sources"] - exclude_num).clamp(min=0, max=top_k)
else:
-d[key] = val[:, :, :, exclude_num : (exclude_num + top_k)]
+slicing_start = exclude_num
+slicing_end = exclude_num + top_k
+if slicing_end > val.shape[-2]:
+pad = torch.zeros_like(val)[:, :, :, 0:1, :].expand(
+-1, -1, -1, slicing_end - val.shape[-2], -1
+)
+val = torch.cat((val, pad), dim=-2)
+d[key] = val[:, :, :, slicing_start:slicing_end, :]

return TileCatalog(d)
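The padding branch above handles requests that overrun the stored source dimension: when `exclude_num + top_k` exceeds the catalog's source axis, zeros are appended so the result always has exactly `top_k` slots per tile. A standalone sketch of just that logic:

```python
import torch

val = torch.rand(1, 10, 10, 2, 2)  # catalog tensor storing only 2 sources per tile
exclude_num, top_k = 1, 3
slicing_end = exclude_num + top_k
if slicing_end > val.shape[-2]:    # requested window runs past the source axis
    pad = torch.zeros_like(val)[:, :, :, 0:1, :].expand(
        -1, -1, -1, slicing_end - val.shape[-2], -1
    )
    val = torch.cat((val, pad), dim=-2)
out = val[:, :, :, exclude_num:slicing_end, :]
assert out.shape[-2] == top_k      # always top_k slots, zero-padded as needed
```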

-def filter_by_flux(self, min_flux=0, max_flux=torch.inf, band=2):
+def filter_by_flux(self, min_flux=0, band=2):
"""Restricts TileCatalog to sources that have a flux between min_flux and max_flux.
Args:
min_flux (float): Minimum flux value to keep. Defaults to 0.
-max_flux (float): Maximum flux value to keep. Defaults to infinity.
band (int): The band to compare fluxes in. Defaults to 2 (r-band).
Returns:
@@ -344,14 +324,10 @@ def filter_by_flux(self, min_flux=0, max_flux=torch.inf, band=2):

# get fluxes of "on" sources to mask by
on_nmgy = sorted_self.on_nmgy[..., band]
-flux_mask = (on_nmgy > min_flux) & (on_nmgy < max_flux)
+flux_mask = on_nmgy > min_flux

-d = {}
-for key, val in sorted_self.items():
-if key == "n_sources":
-d[key] = flux_mask.sum(dim=3) # number of sources within range in tile
-else:
-d[key] = torch.where(flux_mask.unsqueeze(-1), val, torch.zeros_like(val))
+d = copy.copy(sorted_self.data)
+d["n_sources"] = flux_mask.sum(dim=3) # number of sources within range in tile

return TileCatalog(d)
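Two details make the rewritten body work. First, `copy.copy` takes a shallow copy of the data dict, so reassigning `d["n_sources"]` leaves `sorted_self` untouched while every other tensor is shared as-is. Second, because `sorted_self` orders each tile's sources brightest-first (the same ordering `get_brightest_sources_per_tile` relies on), the sources above `min_flux` always form a prefix, so shrinking `n_sources` is enough: the dimmer trailing entries simply fall outside `is_on_mask` and never need to be zeroed. A small sketch of the prefix argument:

```python
import torch

# Brightest-first fluxes for a single tile; with min_flux = 1.0 the keep-mask
# is a prefix, so its sum alone delimits the "on" sources.
on_nmgy = torch.tensor([[5.0, 2.3, 0.7, 0.1]])
flux_mask = on_nmgy > 1.0
assert flux_mask.tolist() == [[True, True, False, False]]
n_sources = flux_mask.sum(dim=-1)  # tensor([2]); entries past index 1 are ignored
```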

43 changes: 33 additions & 10 deletions bliss/encoder/variational_dist.py
@@ -109,7 +109,7 @@ def _get_nll_gating_instance(self, nll_gating: str):
raise ValueError("invalid nll_gating string")

def sample(self, params, use_mode=False):
-qk = self._get_dist(params)
+qk = self.get_dist(params)
sample_cat = qk.mode if use_mode else qk.sample()
if self.sample_rearrange is not None:
sample_cat = rearrange(sample_cat, self.sample_rearrange)
@@ -123,7 +123,7 @@ def compute_nll(self, params, true_tile_cat):

gating = self.nll_gating(true_tile_cat)

-qk = self._get_dist(params)
+qk = self.get_dist(params)
if gating.shape != target.shape:
assert gating.shape == target.shape[:-1]
target = torch.where(gating.unsqueeze(-1), target, 0)
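The gating tensor carries one flag per tile (`target.shape[:-1]`), so it is unsqueezed to broadcast across the trailing parameter dimension, zeroing gated-off targets before the NLL is computed. A minimal sketch of that broadcast:

```python
import torch

target = torch.rand(2, 4, 4, 2)     # (batch, tiles_h, tiles_w, param_dim)
gating = torch.rand(2, 4, 4) > 0.5  # one flag per tile
masked = torch.where(gating.unsqueeze(-1), target, 0)
assert masked[~gating].abs().sum() == 0  # gated-off entries are zeroed
```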
@@ -136,7 +136,7 @@ class BernoulliFactor(VariationalFactor):
def __init__(self, *args, **kwargs):
super().__init__(1, *args, **kwargs)

-def _get_dist(self, params):
+def get_dist(self, params):
yes_prob = params.sigmoid().clamp(1e-4, 1 - 1e-4)
no_yes_prob = torch.cat([1 - yes_prob, yes_prob], dim=3)
# this next line may be helpful with nans encountered during training with fp16s
@@ -150,7 +150,7 @@ def __init__(self, *args, low_clamp=-20, high_clamp=20, **kwargs):
self.low_clamp = low_clamp
self.high_clamp = high_clamp

-def _get_dist(self, params):
+def get_dist(self, params):
mean = params[:, :, :, 0]
sd = params[:, :, :, 1].clamp(self.low_clamp, self.high_clamp).exp().sqrt()
return Normal(mean, sd)
@@ -162,7 +162,7 @@ def __init__(self, *args, low_clamp=-20, high_clamp=20, **kwargs):
self.low_clamp = low_clamp
self.high_clamp = high_clamp

-def _get_dist(self, params):
+def get_dist(self, params):
mean = params[:, :, :, :2]
sd = params[:, :, :, 2:].clamp(self.low_clamp, self.high_clamp).exp().sqrt()

@@ -177,7 +177,7 @@ def __init__(self, *args, low_clamp=-6, high_clamp=3, **kwargs):
self.low_clamp = low_clamp
self.high_clamp = high_clamp

-def _get_dist(self, params):
+def get_dist(self, params):
mu = params[:, :, :, :2].sigmoid()
sigma = params[:, :, :, 2:].clamp(self.low_clamp, self.high_clamp).exp().sqrt()
assert not mu.isnan().any() and not mu.isinf().any(), "mu contains invalid values"
@@ -191,7 +191,7 @@ def __init__(self, *args, dim=1, **kwargs):
n_params = 2 * dim # mean and std for each dimension (diagonal covariance)
super().__init__(n_params, *args, **kwargs)

-def _get_dist(self, params):
+def get_dist(self, params):
mu = params[:, :, :, 0 : self.dim].clamp(-40, 40)
sigma = params[:, :, :, self.dim : self.n_params].clamp(-6, 5).exp().sqrt()
iid_dist = LogNormalEpsilon(
@@ -208,7 +208,7 @@ def __init__(self, *args, low=0, high=1, dim=1, **kwargs):
self.high = high
super().__init__(n_params, *args, **kwargs)

-def _get_dist(self, params):
+def get_dist(self, params):
mu = params[:, :, :, 0 : self.dim]
sigma = params[:, :, :, self.dim : self.n_params].clamp(-10, 10).exp().sqrt()
return RescaledLogitNormal(mu, sigma, low=self.low, high=self.high)
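Renaming `_get_dist` to `get_dist` across every factor makes the underlying distribution object part of the public interface, which the new calibration plots depend on: credible intervals and VSBC need quantiles and CDFs, not just samples. A hedged sketch of the calling pattern, using a stand-in that mirrors `NormalFactor.get_dist` rather than the real factor classes:

```python
import torch
from torch.distributions import Normal

params = torch.randn(1, 4, 4, 2)   # (batch, tiles_h, tiles_w, n_params)
mean = params[:, :, :, 0]
sd = params[:, :, :, 1].clamp(-20, 20).exp().sqrt()

q = Normal(mean, sd)               # what get_dist would return
lo = q.icdf(torch.tensor(0.025))   # elementwise bounds of a 95% credible interval
hi = q.icdf(torch.tensor(0.975))
assert (lo < hi).all()
```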
@@ -237,7 +237,8 @@ def __init__(self, mu, sigma):

# we'll need these calculations later for log_prob
prob_in_unit_box_hw = multiple_normals.cdf(self.b) - multiple_normals.cdf(self.a)
-self.log_prob_in_unit_box = prob_in_unit_box_hw.log().sum(dim=-1)
+self.log_event_prob_in_unit_box = prob_in_unit_box_hw.log()
+self.log_prob_in_unit_box = self.log_event_prob_in_unit_box.sum(dim=-1)

def __repr__(self):
return f"{self.__class__.__name__}({self.base_dist.base_dist})"
@@ -253,7 +254,7 @@ def sample(self, sample_shape=()):
"""

-shape = sample_shape + self.base_dist.batch_shape + self.base_dist.event_shape
+shape = sample_shape + self.batch_shape + self.event_shape

# draw using inverse cdf method
# if Fi is the cdf of the relevant gaussian, then
@@ -303,6 +304,14 @@ def mode(self):
assert (self.base_dist.mean >= 0).all() and (self.base_dist.mean <= 1).all()
return self.base_dist.mode

+@property
+def batch_shape(self):
+return self.base_dist.batch_shape
+
+@property
+def event_shape(self):
+return self.base_dist.event_shape

def log_prob(self, value):
assert (value >= 0).all() and (value <= 1).all()
# subtracting log probability that the base RV is in the unit box
@@ -315,6 +324,20 @@ def cdf(self, value):
log_cdf = (cdf_at_val - cdf_at_lb + 1e-9).log().sum(dim=-1) - self.log_prob_in_unit_box
return log_cdf.exp()

+def event_cdf(self, value):
+cdf_at_val = self.base_dist.base_dist.cdf(value)
+cdf_at_lb = self.lower_cdf
+log_cdf = (cdf_at_val - cdf_at_lb + 1e-9).log() - self.log_event_prob_in_unit_box
+return log_cdf.exp()
+
+def event_icdf(self, value):
+assert isinstance(value, torch.Tensor)
+assert value.shape == self.lower_cdf.shape
+assert (value > 0).all() and (value < 1).all()
+converted_cdf = value * (self.upper_cdf - self.lower_cdf) + self.lower_cdf
+converted_icdf = self.base_dist.base_dist.icdf(converted_cdf)
+return converted_icdf.clamp(self.a, self.b)
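`event_cdf` and `event_icdf` operate per dimension rather than over the whole event, which is what per-parameter credible intervals need. The transformation in `event_icdf` is the standard inverse-CDF construction for a truncated distribution: rescale a uniform quantile into the base distribution's CDF range over the truncation box, then invert. A self-contained one-dimensional illustration:

```python
import torch
from torch.distributions import Normal

base = Normal(torch.tensor(0.3), torch.tensor(0.5))
a, b = torch.tensor(0.0), torch.tensor(1.0)          # truncation box, as for locs
lower_cdf, upper_cdf = base.cdf(a), base.cdf(b)

u = torch.tensor(0.975)                              # quantile of the truncated dist
converted = u * (upper_cdf - lower_cdf) + lower_cdf  # rescale into the box's CDF range
q = base.icdf(converted).clamp(a, b)                 # 97.5% quantile, inside [0, 1]
```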


class RescaledLogitNormal(Distribution):
def __init__(self, mu, sigma, low=0, high=1):
