Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Filter region catalog by flux per tile when converting from tile catalog #938

Merged
merged 1 commit into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 56 additions & 2 deletions case_studies/adaptive_tiling/region_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from einops import rearrange, repeat
from torch import Tensor

from bliss.catalog import TileCatalog
from bliss.catalog import SourceType, TileCatalog


class RegionType(IntEnum):
Expand Down Expand Up @@ -251,7 +251,7 @@ def crop(self, hlims_tile, wlims_tile):
# endregion


def tile_cat_to_region_cat(tile_cat: TileCatalog, overlap_slen: float):
def tile_cat_to_region_cat(tile_cat: TileCatalog, overlap_slen: float, discard_extra_sources=True):
"""Convert a TileCatalog to RegionCatalog.

We do this by checking if a location is within the interior or boundary, and copying the
Expand All @@ -264,6 +264,8 @@ def tile_cat_to_region_cat(tile_cat: TileCatalog, overlap_slen: float):
Args:
tile_cat: the tile catalog to convert
overlap_slen: the overlap in pixels between tiles
discard_extra_sources: if True, only keep the brightest source in each padded tile. If this
is False and there are multiple sources, an warning will occur.

Returns:
RegionCatalog: the region-based representation of this TileCatalog
Expand Down Expand Up @@ -311,6 +313,14 @@ def tile_cat_to_region_cat(tile_cat: TileCatalog, overlap_slen: float):
)
d[key][b, new_i, new_j, 0] = val[b, i, j, 0]

# If there are multiple sources in a tile, only keep the brightest one
if discard_extra_sources:
fluxes = torch.where(
d["source_type"] == SourceType.GALAXY, d["galaxy_fluxes"], d["star_fluxes"]
)
fluxes = torch.where(d["n_sources"][..., None, None] > 0, fluxes, 0)
d["n_sources"] = filter_regions_by_flux(d["n_sources"], fluxes)

region_cat = RegionCatalog(height=tile_cat.height, overlap_slen=overlap_slen, d=d)

offset = repeat(
Expand All @@ -324,6 +334,50 @@ def tile_cat_to_region_cat(tile_cat: TileCatalog, overlap_slen: float):
return region_cat


def filter_regions_by_flux(n_sources, fluxes, band=2):
"""Mask out extra sources in each region.

Args:
n_sources: tensor containing number of sources in each region
fluxes: tensor containing fluxes
band (int, optional): Flux band to filter by. Defaults to 2 (r band).

Returns:
New tensor of number of sources in each region with extra sources masked out.
"""
# Construct a tensor of integer indices to each region in n_sources. We increment by 1 before
# unfolding and then subtract 1 to differentiate index 0 from the 0 padding.
int_idx = torch.arange(n_sources.numel()).reshape(n_sources.shape)
unfolded_idx = torch.nn.functional.unfold(
(int_idx + 1).unsqueeze(1).float(), kernel_size=(3, 3), padding=1, stride=2
)
unfolded_idx = (unfolded_idx - 1).long()

# Find indices of regions where tile contains more than one sources
sources_per_tile = torch.nn.functional.unfold(
n_sources.unsqueeze(1), kernel_size=(3, 3), padding=1, stride=2
).sum(axis=1)
b, c = (sources_per_tile > 1).nonzero(as_tuple=True)
check_idx = unfolded_idx[b, :, c] # int idx of 9 regions for each problematic tile

# Flatten fluxes to index by
fluxes = fluxes[..., band].reshape(-1, 1)
fluxes = torch.vstack(
(fluxes, torch.ones(fluxes.shape[-1]) * -torch.inf) # add -inf in last row for -1 indices
)

# Get flux in each region of problematic tiles
check_flux = fluxes[check_idx.flatten()].reshape(check_idx.shape)
argmax_flux = torch.argmax(check_flux, dim=1)
max_flux_idx = check_idx[torch.arange(len(argmax_flux)), argmax_flux] # idx of max flux regions

# Turn off all regions in problematic tiles, then turn max flux regions back on
new_n_sources = torch.hstack((n_sources.flatten(), torch.zeros(1))) # extra val for -1 idx
new_n_sources[check_idx.flatten()] = 0 # all off
new_n_sources[max_flux_idx] = 1 # turn max back on
return new_n_sources[:-1].reshape(n_sources.shape) # remove extra val and unflatten


def region_for_tile_source(loc, pos, n_rows, n_cols, threshold):
"""Determine which region index a tile-based location should be placed in in a RegionCatalog.

Expand Down
43 changes: 41 additions & 2 deletions tests/test_catalogs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from pathlib import Path

import numpy as np
Expand Down Expand Up @@ -238,7 +239,45 @@ def test_convert_to_full(self, region_cat):
)
assert full_cat.plocs.equal(true_locs)

def test_tile_cat_to_region(self, basic_tilecat):
region_cat = tile_cat_to_region_cat(basic_tilecat, 0.5)
def test_tile_cat_to_region_basic(self, basic_tilecat):
region_cat = tile_cat_to_region_cat(basic_tilecat, 0.5, discard_extra_sources=False)
full_cat = basic_tilecat.to_full_params()
assert region_cat.to_full_params().plocs.equal(full_cat.plocs)

def test_tile_cat_to_region_filtering(self):
d = {
"n_sources": torch.zeros(3, 2, 2),
"locs": torch.zeros(3, 2, 2, 1, 2),
"source_type": torch.ones((3, 2, 2, 1, 1)).bool(),
"galaxy_params": torch.zeros((3, 2, 2, 1, 6)),
"star_fluxes": torch.ones((3, 2, 2, 1, 5)) * 1000,
"galaxy_fluxes": torch.ones(3, 2, 2, 1, 5) * 1000,
}
# BATCH 0: top right interior, center right boundary
d["n_sources"][0, 0, 1] = 1
d["n_sources"][0, 1, 1] = 1
d["locs"][0, 0, 1, 0] = torch.tensor([0.5, 0.5])
d["locs"][0, 1, 1, 0] = torch.tensor([0.02, 0.5])
d["galaxy_fluxes"][0, 0, 1, 0, 2] = 5000 # keep top right

# BATCH 1: top left interior, top center boundary
d["n_sources"][1, 0, 0] = 1
d["n_sources"][1, 0, 1] = 1
d["locs"][1, 0, 0, 0] = torch.tensor([0.5, 0.5])
d["locs"][1, 0, 1, 0] = torch.tensor([0.5, 0.02])
d["galaxy_fluxes"][1, 0, 1, 0, 2] = 5000 # keep top center

# BATCH 2: only one source in top left
d["n_sources"][2, 0, 0] = 1
d["locs"][2, 0, 0, 0] = torch.tensor([0.5, 0.5])

tilecat = TileCatalog(4, d)

# make sure no warning when converting (since extra sources have been discarded)
with warnings.catch_warnings():
warnings.simplefilter("error")
region_cat = tile_cat_to_region_cat(tilecat, 0.5, discard_extra_sources=True)

n_sources = region_cat.n_sources
assert n_sources[0, 0, 2] == n_sources[1, 0, 1] == n_sources[2, 0, 0] == 1
assert n_sources[0].sum() == n_sources[1].sum() == n_sources[2].sum()