From 14c5840f5286786047b722aa6722d0f47f5bb489 Mon Sep 17 00:00:00 2001 From: Aakash Patel Date: Fri, 15 Sep 2023 03:13:54 -0400 Subject: [PATCH] Filter region catalog by flux per tile when converting from tile catalog --- .../adaptive_tiling/region_catalog.py | 58 ++++++++++++++++++- tests/test_catalogs.py | 43 +++++++++++++- 2 files changed, 97 insertions(+), 4 deletions(-) diff --git a/case_studies/adaptive_tiling/region_catalog.py b/case_studies/adaptive_tiling/region_catalog.py index 3da838ab8..7b833ed98 100644 --- a/case_studies/adaptive_tiling/region_catalog.py +++ b/case_studies/adaptive_tiling/region_catalog.py @@ -8,7 +8,7 @@ from einops import rearrange, repeat from torch import Tensor -from bliss.catalog import TileCatalog +from bliss.catalog import SourceType, TileCatalog class RegionType(IntEnum): @@ -251,7 +251,7 @@ def crop(self, hlims_tile, wlims_tile): # endregion -def tile_cat_to_region_cat(tile_cat: TileCatalog, overlap_slen: float): +def tile_cat_to_region_cat(tile_cat: TileCatalog, overlap_slen: float, discard_extra_sources=True): """Convert a TileCatalog to RegionCatalog. We do this by checking if a location is within the interior or boundary, and copying the @@ -264,6 +264,8 @@ def tile_cat_to_region_cat(tile_cat: TileCatalog, overlap_slen: float): Args: tile_cat: the tile catalog to convert overlap_slen: the overlap in pixels between tiles + discard_extra_sources: if True, only keep the brightest source in each padded tile. If this + is False and there are multiple sources, a warning will occur.
def filter_regions_by_flux(n_sources, fluxes, band=2):
    """Keep only the brightest occupied region inside every over-full padded tile.

    A padded tile is the 3x3 window of regions obtained by sliding a kernel of size 3 with
    stride 2 and padding 1 over the region grid. For each window whose regions hold more than
    one source in total, every region in the window is switched off and the occupied region
    with the largest flux in ``band`` is switched back on with a count of 1.

    NOTE: all windows are resolved in a single pass. Neighbouring windows share their boundary
    regions, so a region that is the brightest of one window is kept even if it also lies in
    another over-full window that kept a different region.

    Args:
        n_sources: tensor of shape (batch, n_rows, n_cols) containing the number of sources in
            each region. Integer and floating dtypes are accepted; the input is not modified.
        fluxes: tensor containing fluxes. Its last dimension indexes bands and it must hold
            exactly one flux vector per region, e.g. shape (batch, n_rows, n_cols, 1, n_bands).
        band (int, optional): Flux band to filter by. Defaults to 2 (r band).

    Returns:
        New tensor with the same shape, dtype and device as ``n_sources`` in which the extra
        sources have been masked out.

    Raises:
        ValueError: if ``fluxes`` does not provide exactly one value per region for ``band``.
    """
    batch_size, n_rows, n_cols = n_sources.shape
    regions_per_image = n_rows * n_cols
    n_regions = batch_size * regions_per_image

    band_flux = fluxes[..., band].reshape(-1)
    if band_flux.numel() != n_regions:
        raise ValueError(
            f"fluxes must hold one value per region in band {band}: "
            f"expected {n_regions}, got {band_flux.numel()}"
        )
    if n_regions == 0:
        return n_sources.clone()

    device = n_sources.device
    window = {"kernel_size": (3, 3), "padding": 1, "stride": 2}

    # Region ids of each 3x3 window for a single image. Ids are shifted to start at 1 before
    # unfolding so that the zero padding added by unfold becomes -1 after shifting back.
    # Working per image keeps the float32 round trip exact (ids stay far below 2**24) and the
    # batch offset is added afterwards in integer arithmetic.
    local_ids = torch.arange(1, regions_per_image + 1, device=device, dtype=torch.float32)
    local_windows = torch.nn.functional.unfold(
        local_ids.reshape(1, 1, n_rows, n_cols), **window
    )[0]
    local_windows = local_windows.long() - 1  # (9, n_tiles)

    # Find the padded tiles that contain more than one source. unfold only supports floating
    # point input, so the counts are cast explicitly.
    sources_per_tile = torch.nn.functional.unfold(
        n_sources.unsqueeze(1).float(), **window
    ).sum(dim=1)
    batch_idx, tile_idx = (sources_per_tile > 1).nonzero(as_tuple=True)
    if batch_idx.numel() == 0:
        return n_sources.clone()

    # Flat (batch-wide) ids of the 9 regions of each problematic tile; padding stays at -1
    crowded_local = local_windows[:, tile_idx].t()
    batch_offset = (batch_idx * regions_per_image).unsqueeze(1)
    crowded = torch.where(crowded_local < 0, crowded_local, crowded_local + batch_offset)

    # Only occupied regions may win the argmax: empty regions are pushed to -inf so that a
    # source with zero or negative flux is never replaced by an empty neighbour.
    if not band_flux.is_floating_point():
        band_flux = band_flux.float()
    band_flux = band_flux.masked_fill(n_sources.reshape(-1) <= 0, float("-inf"))
    # Extra -inf entry at the end so that the -1 padding ids index a value that never wins
    band_flux = torch.cat((band_flux, band_flux.new_full((1,), float("-inf"))))

    crowded_flux = band_flux[crowded]  # (n_problematic_tiles, 9)
    brightest = crowded.gather(1, crowded_flux.argmax(dim=1, keepdim=True)).reshape(-1)

    # Turn off all regions in problematic tiles, then turn max flux regions back on. The extra
    # trailing element absorbs writes through the -1 padding ids.
    filtered = torch.cat((n_sources.reshape(-1), n_sources.new_zeros(1)))
    filtered[crowded.reshape(-1)] = 0
    filtered[brightest] = 1
    return filtered[:-1].reshape(n_sources.shape)
def test_tile_cat_to_region_basic(self, basic_tilecat):
    """Conversion without flux filtering keeps every source location."""
    converted = tile_cat_to_region_cat(basic_tilecat, 0.5, discard_extra_sources=False)
    reference = basic_tilecat.to_full_params()
    assert converted.to_full_params().plocs.equal(reference.plocs)


def test_tile_cat_to_region_filtering(self):
    """Flux filtering keeps one source per batch element and raises no warning."""
    import warnings

    grid = (3, 2, 2)  # (batch, tile rows, tile cols)
    d = {
        "n_sources": torch.zeros(*grid),
        "locs": torch.zeros(*grid, 1, 2),
        "source_type": torch.ones(*grid, 1, 1).bool(),
        "galaxy_params": torch.zeros(*grid, 1, 6),
        "star_fluxes": torch.full((*grid, 1, 5), 1000.0),
        "galaxy_fluxes": torch.full((*grid, 1, 5), 1000.0),
    }

    # (batch, tile row, tile col, location within the tile)
    placed_sources = [
        (0, 0, 1, [0.5, 0.5]),  # BATCH 0: top right interior
        (0, 1, 1, [0.02, 0.5]),  # BATCH 0: center right boundary
        (1, 0, 0, [0.5, 0.5]),  # BATCH 1: top left interior
        (1, 0, 1, [0.5, 0.02]),  # BATCH 1: top center boundary
        (2, 0, 0, [0.5, 0.5]),  # BATCH 2: only one source in top left
    ]
    for batch, row, col, loc in placed_sources:
        d["n_sources"][batch, row, col] = 1
        d["locs"][batch, row, col, 0] = torch.tensor(loc)

    d["galaxy_fluxes"][0, 0, 1, 0, 2] = 5000  # keep top right
    d["galaxy_fluxes"][1, 0, 1, 0, 2] = 5000  # keep top center

    tilecat = TileCatalog(4, d)

    # make sure no warning when converting (since extra sources have been discarded)
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        region_cat = tile_cat_to_region_cat(tilecat, 0.5, discard_extra_sources=True)

    kept = region_cat.n_sources
    assert kept[0, 0, 2] == kept[1, 0, 1] == kept[2, 0, 0] == 1
    assert kept[0].sum() == kept[1].sum() == kept[2].sum()