
Improve multidetection and add some calibration plots #1052

Merged
merged 15 commits, Aug 6, 2024
6 changes: 5 additions & 1 deletion bliss/cached_dataset.py
@@ -90,7 +90,11 @@ def __call__(self, datum_in):

class ChunkingSampler(Sampler):
def __init__(self, dataset: Dataset) -> None:
super().__init__(dataset)
# Please don't pass `dataset` to the following __init__();
# according to https://pytorch.org/docs/stable/data.html#torch.utils.data.Sampler,
# the `data_source` parameter has been deprecated.
# Make sure your PyTorch version is greater than 2.2.0.
super().__init__()
assert isinstance(dataset, ChunkingDataset), "dataset should be ChunkingDataset"
self.dataset = dataset

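For context, a minimal sketch of how this sampler might be paired with a DataLoader; the ChunkingDataset constructor arguments below are hypothetical and only illustrate that the same dataset instance is passed to both objects.

```python
from torch.utils.data import DataLoader

from bliss.cached_dataset import ChunkingDataset, ChunkingSampler

# hypothetical arguments; the real ChunkingDataset signature may differ
dataset = ChunkingDataset("cached_chunks/", shuffle=True)
sampler = ChunkingSampler(dataset)  # asserts it received a ChunkingDataset

# the DataLoader drives iteration through the sampler rather than the dataset
loader = DataLoader(dataset, batch_size=64, sampler=sampler)
```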
62 changes: 19 additions & 43 deletions bliss/catalog.py
@@ -1,3 +1,4 @@
import copy
import math
from collections import UserDict
from enum import IntEnum
@@ -81,38 +82,6 @@
[tiles_to_crop, self.n_tiles_w - tiles_to_crop],
)

def filter_base_tile_catalog_by_ploc_box(self, box_origin: torch.Tensor, box_len: float):
assert box_origin[0] + box_len < self.height, "invalid box"
assert box_origin[1] + box_len < self.width, "invalid box"

box_origin_tensor = box_origin.view(1, 1, 2).to(device=self.device)
box_end_tensor = (box_origin + box_len).view(1, 1, 2).to(device=self.device)

plocs_mask = torch.all(
(self["plocs"] < box_end_tensor) & (self["plocs"] > box_origin_tensor), dim=2
)

plocs_mask_indexes = plocs_mask.nonzero()
plocs_inverse_mask_indexes = (~plocs_mask).nonzero()
plocs_full_mask_indexes = torch.cat((plocs_mask_indexes, plocs_inverse_mask_indexes), dim=0)
_, index_order = plocs_full_mask_indexes[:, 0].sort(stable=True)
plocs_full_mask_sorted_indexes = plocs_full_mask_indexes[index_order.tolist(), :]

d = {}
new_max_sources = plocs_mask.sum(dim=1).max()
for k, v in self.items():
if k == "n_sources":
d[k] = plocs_mask.sum(dim=1)
else:
d[k] = v[
plocs_full_mask_sorted_indexes[:, 0].tolist(),
plocs_full_mask_sorted_indexes[:, 1].tolist(),
].view(-1, self.max_sources, v.shape[-1])[:, :new_max_sources, :]

d["plocs"] -= box_origin_tensor

return FullCatalog(box_len, box_len, d)


class TileCatalog(BaseTileCatalog):
galaxy_params = [
@@ -126,7 +95,8 @@
galaxy_params_index = {k: i for i, k in enumerate(galaxy_params)}

def __init__(self, d: Dict[str, Tensor]):
self.max_sources = d["locs"].shape[3]
assert "locs" in d
assert len(d["locs"].shape) == 5
super().__init__(d)

def __getitem__(self, name: str):
@@ -137,6 +107,10 @@
return self.data["galaxy_params"][..., idx : (idx + 1)]
return super().__getitem__(name)

@property
def max_sources(self):
return self["locs"].shape[3]

@property
def is_on_mask(self) -> Tensor:
"""Provides tensor which indicates how many sources are present for each batch.
@@ -324,16 +298,22 @@
if key == "n_sources":
d[key] = (sorted_self["n_sources"] - exclude_num).clamp(min=0, max=top_k)
else:
d[key] = val[:, :, :, exclude_num : (exclude_num + top_k)]
slicing_start = exclude_num
slicing_end = exclude_num + top_k
if slicing_end > val.shape[-2]:
pad = torch.zeros_like(val)[:, :, :, 0:1, :].expand(
-1, -1, -1, slicing_end - val.shape[-2], -1
)
val = torch.cat((val, pad), dim=-2)
d[key] = val[:, :, :, slicing_start:slicing_end, :]

return TileCatalog(d)
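As an aside, a self-contained illustration (with made-up tensor shapes) of the zero-padding idiom above, showing what happens when the requested slice extends past the existing source dimension.

```python
import torch

val = torch.randn(2, 3, 3, 4, 5)   # (batch, tiles_h, tiles_w, sources, features)
exclude_num, top_k = 1, 6
slicing_end = exclude_num + top_k   # 7 > 4 available sources

if slicing_end > val.shape[-2]:
    # a zero slice expanded (without copying) to cover the missing sources
    pad = torch.zeros_like(val)[:, :, :, 0:1, :].expand(
        -1, -1, -1, slicing_end - val.shape[-2], -1
    )
    val = torch.cat((val, pad), dim=-2)

topk = val[:, :, :, exclude_num:slicing_end, :]   # shape (2, 3, 3, 6, 5)
```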

def filter_by_flux(self, min_flux=0, max_flux=torch.inf, band=2):
def filter_by_flux(self, min_flux=0, band=2):
"""Restricts TileCatalog to sources that have a flux between min_flux and max_flux.

Args:
min_flux (float): Minimum flux value to keep. Defaults to 0.
max_flux (float): Maximum flux value to keep. Defaults to infinity.
band (int): The band to compare fluxes in. Defaults to 2 (r-band).

Returns:
@@ -344,14 +324,10 @@

# get fluxes of "on" sources to mask by
on_nmgy = sorted_self.on_nmgy[..., band]
flux_mask = (on_nmgy > min_flux) & (on_nmgy < max_flux)
flux_mask = on_nmgy > min_flux

d = {}
for key, val in sorted_self.items():
if key == "n_sources":
d[key] = flux_mask.sum(dim=3) # number of sources within range in tile
else:
d[key] = torch.where(flux_mask.unsqueeze(-1), val, torch.zeros_like(val))
d = copy.copy(sorted_self.data)
d["n_sources"] = flux_mask.sum(dim=3) # number of sources within range in tile

return TileCatalog(d)

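For illustration, a hedged usage sketch of the updated method; `tile_cat` is a hypothetical TileCatalog instance and the threshold is arbitrary.

```python
# keep only sources brighter than 1.0 nmgy in band 2 (r-band)
bright_cat = tile_cat.filter_by_flux(min_flux=1.0, band=2)

# per-tile counts can only shrink, since only "on" sources above the threshold
# are counted; the per-source tensors keep the flux-sorted values rather than
# being zeroed out, as they were in the previous implementation
assert (bright_cat["n_sources"] <= tile_cat["n_sources"]).all()
```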
43 changes: 33 additions & 10 deletions bliss/encoder/variational_dist.py
@@ -109,7 +109,7 @@
raise ValueError("invalide nll_gating string")

def sample(self, params, use_mode=False):
qk = self._get_dist(params)
qk = self.get_dist(params)
sample_cat = qk.mode if use_mode else qk.sample()
if self.sample_rearrange is not None:
sample_cat = rearrange(sample_cat, self.sample_rearrange)
@@ -123,7 +123,7 @@

gating = self.nll_gating(true_tile_cat)

qk = self._get_dist(params)
qk = self.get_dist(params)
if gating.shape != target.shape:
assert gating.shape == target.shape[:-1]
target = torch.where(gating.unsqueeze(-1), target, 0)
@@ -136,7 +136,7 @@
def __init__(self, *args, **kwargs):
super().__init__(1, *args, **kwargs)

def _get_dist(self, params):
def get_dist(self, params):
yes_prob = params.sigmoid().clamp(1e-4, 1 - 1e-4)
no_yes_prob = torch.cat([1 - yes_prob, yes_prob], dim=3)
# this next line may be helpful with nans encountered during training with fp16s
@@ -150,7 +150,7 @@
self.low_clamp = low_clamp
self.high_clamp = high_clamp

def _get_dist(self, params):
def get_dist(self, params):
mean = params[:, :, :, 0]
sd = params[:, :, :, 1].clamp(self.low_clamp, self.high_clamp).exp().sqrt()
return Normal(mean, sd)
@@ -162,7 +162,7 @@
self.low_clamp = low_clamp
self.high_clamp = high_clamp

def _get_dist(self, params):
def get_dist(self, params):
mean = params[:, :, :, :2]
sd = params[:, :, :, 2:].clamp(self.low_clamp, self.high_clamp).exp().sqrt()

@@ -177,7 +177,7 @@
self.low_clamp = low_clamp
self.high_clamp = high_clamp

def _get_dist(self, params):
def get_dist(self, params):
mu = params[:, :, :, :2].sigmoid()
sigma = params[:, :, :, 2:].clamp(self.low_clamp, self.high_clamp).exp().sqrt()
assert not mu.isnan().any() and not mu.isinf().any(), "mu contains invalid values"
@@ -191,7 +191,7 @@
n_params = 2 * dim # mean and std for each dimension (diagonal covariance)
super().__init__(n_params, *args, **kwargs)

def _get_dist(self, params):
def get_dist(self, params):
mu = params[:, :, :, 0 : self.dim].clamp(-40, 40)
sigma = params[:, :, :, self.dim : self.n_params].clamp(-6, 5).exp().sqrt()
iid_dist = LogNormalEpsilon(
@@ -208,7 +208,7 @@
self.high = high
super().__init__(n_params, *args, **kwargs)

def _get_dist(self, params):
def get_dist(self, params):
mu = params[:, :, :, 0 : self.dim]
sigma = params[:, :, :, self.dim : self.n_params].clamp(-10, 10).exp().sqrt()
return RescaledLogitNormal(mu, sigma, low=self.low, high=self.high)
@@ -237,7 +237,8 @@

# we'll need these calculations later for log_prob
prob_in_unit_box_hw = multiple_normals.cdf(self.b) - multiple_normals.cdf(self.a)
self.log_prob_in_unit_box = prob_in_unit_box_hw.log().sum(dim=-1)
self.log_event_prob_in_unit_box = prob_in_unit_box_hw.log()
self.log_prob_in_unit_box = self.log_event_prob_in_unit_box.sum(dim=-1)

def __repr__(self):
return f"{self.__class__.__name__}({self.base_dist.base_dist})"
@@ -253,7 +254,7 @@

"""

shape = sample_shape + self.base_dist.batch_shape + self.base_dist.event_shape
shape = sample_shape + self.batch_shape + self.event_shape

# draw using inverse cdf method
# if Fi is the cdf of the relevant gaussian, then
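To make the inverse-CDF approach concrete, here is a standalone sketch (not this class's actual code) of drawing from a normal truncated to [a, b] by inverting the CDF over that interval.

```python
import torch
from torch.distributions import Normal

def sample_truncated_normal(mu, sigma, a=0.0, b=1.0, n=1000):
    """Draw n samples from N(mu, sigma) conditioned on lying in [a, b]."""
    base = Normal(mu, sigma)
    cdf_a = base.cdf(torch.tensor(a))
    cdf_b = base.cdf(torch.tensor(b))
    # map uniform draws into the CDF mass between a and b, then invert
    u = torch.rand(n) * (cdf_b - cdf_a) + cdf_a
    return base.icdf(u).clamp(a, b)
```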
@@ -303,6 +304,14 @@
assert (self.base_dist.mean >= 0).all() and (self.base_dist.mean <= 1).all()
return self.base_dist.mode

@property
def batch_shape(self):
return self.base_dist.batch_shape

@property
def event_shape(self):
return self.base_dist.event_shape

def log_prob(self, value):
assert (value >= 0).all() and (value <= 1).all()
# subtracting log probability that the base RV is in the unit box
@@ -315,6 +324,20 @@
log_cdf = (cdf_at_val - cdf_at_lb + 1e-9).log().sum(dim=-1) - self.log_prob_in_unit_box
return log_cdf.exp()

def event_cdf(self, value):
cdf_at_val = self.base_dist.base_dist.cdf(value)
cdf_at_lb = self.lower_cdf
log_cdf = (cdf_at_val - cdf_at_lb + 1e-9).log() - self.log_event_prob_in_unit_box
return log_cdf.exp()

def event_icdf(self, value):
assert isinstance(value, torch.Tensor)
assert value.shape == self.lower_cdf.shape
assert (value > 0).all() and (value < 1).all()
converted_cdf = value * (self.upper_cdf - self.lower_cdf) + self.lower_cdf
converted_icdf = self.base_dist.base_dist.icdf(converted_cdf)
return converted_icdf.clamp(self.a, self.b)
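Since the PR title mentions calibration plots, here is a hedged sketch of how these per-dimension helpers could feed such a check; `dist` (an instance of this truncated distribution) and `locs_true` are assumed names for illustration, not code from this PR.

```python
import torch

# per-dimension probability integral transform: for a calibrated model,
# these values should look roughly Uniform(0, 1)
pit = dist.event_cdf(locs_true)

# per-dimension 80% central credible interval via the inverse CDF
lo = dist.event_icdf(torch.full_like(locs_true, 0.1))
hi = dist.event_icdf(torch.full_like(locs_true, 0.9))
coverage = ((locs_true >= lo) & (locs_true <= hi)).float().mean()  # ideally near 0.8
```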


class RescaledLogitNormal(Distribution):
def __init__(self, mu, sigma, low=0, high=1):