Skip to content

Commit

Permalink
To tile coords (#939)
Browse files Browse the repository at this point in the history
* add data augmentation and dc2 psf/deconv

* add data augmentation and dc2 psf/deconv

* add data augmentation and dc2 psf/deconv

* new to_tile_coords

* new to_tile_coords

* add tests and filter

* new to_tile_coords

---------

Co-authored-by: Xinyue Li <[email protected]>
  • Loading branch information
XinyueLi1012 and Xinyue Li authored Sep 25, 2023
1 parent 0a8b6db commit bd59bd1
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 28 deletions.
83 changes: 64 additions & 19 deletions bliss/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,11 @@ def apply_param_bin(self, pname: str, p_min: float, p_max: float):
return type(self)(self.height, self.width, d_new)

def to_tile_params(
self, tile_slen: int, max_sources_per_tile: int, ignore_extra_sources=False
self,
tile_slen: int,
max_sources_per_tile: int,
ignore_extra_sources=False,
filter_oob=False,
) -> TileCatalog:
"""Returns the TileCatalog corresponding to this FullCatalog.
Expand All @@ -463,6 +467,9 @@ def to_tile_params(
ignore_extra_sources: If False (default), raises an error if the number of sources
in one tile exceeds the `max_sources_per_tile`. If True, only adds the tile
parameters of the first `max_sources_per_tile` sources to the new TileCatalog.
filter_oob: If True, filter out sources located outside the image (e.g. after data
augmentation, some sources may end up outside the image). Defaults to False.
Returns:
TileCatalog corresponding to each source in the FullCatalog.
Expand All @@ -485,24 +492,62 @@ def to_tile_params(
size = (self.batch_size, n_tiles_h, n_tiles_w, max_sources_per_tile, v.shape[-1])
tile_params[k] = torch.zeros(size, dtype=dtype, device=self.device)

tile_params["locs"] = tile_locs

for ii in range(self.batch_size):
n_sources = int(self.n_sources[ii].item())
for idx, coords in enumerate(tile_coords[ii][:n_sources]):
if coords[0] >= tile_n_sources.shape[1] or coords[1] >= tile_n_sources.shape[2]:
continue
# ignore sources outside of the image (usually caused by data augmentation - shift)

source_idx = tile_n_sources[ii, coords[0], coords[1]].item()
if source_idx >= max_sources_per_tile:
if not ignore_extra_sources:
raise ValueError( # noqa: WPS220
"# of sources per tile exceeds `max_sources_per_tile`."
)
continue # ignore extra sources in this tile.
tile_loc = (self.plocs[ii, idx] - coords * tile_slen) / tile_slen
tile_locs[ii, coords[0], coords[1], source_idx] = tile_loc
for k, v in tile_params.items():
v[ii, coords[0], coords[1], source_idx] = self[k][ii, idx]
tile_n_sources[ii, coords[0], coords[1]] = source_idx + 1
tile_params.update({"locs": tile_locs, "n_sources": tile_n_sources})
plocs_ii = self.plocs[ii][:n_sources]
filter_sources = n_sources
source_indices = tile_coords[ii][:n_sources]
if filter_oob:
x0_mask = (plocs_ii[:, 0] > 0) & (plocs_ii[:, 0] < self.height)
x1_mask = (plocs_ii[:, 1] > 0) & (plocs_ii[:, 1] < self.width)
x_mask = x0_mask * x1_mask
filter_sources = x_mask.sum()
source_indices = source_indices[x_mask]
source_indices = source_indices[:, 0] * n_tiles_w + source_indices[:, 1].unsqueeze(0)
tile_range = torch.arange(n_tiles_h * n_tiles_w, device=self.device).unsqueeze(1)
# get mask, for tiles
source_mask = (source_indices == tile_range).long() # (nth*ntw) x n_sources
if source_mask.sum(-1).max() > max_sources_per_tile:
if not ignore_extra_sources:
raise ValueError( # noqa: WPS220
"# of sources per tile exceeds `max_sources_per_tile`."
)

mask_sources = torch.zeros_like(source_mask, dtype=torch.bool, device=self.device)

# Find the indices of the first 'max_sources_per_tile' truth values in each row
sorted_indices = torch.argsort(source_mask, dim=1, descending=True)
top_indices = sorted_indices[:, :max_sources_per_tile]

# Set the corresponding positions in the output tensor to True
mask_sources = mask_sources.scatter_(1, top_indices, 1) & source_mask

# get n_sources for each tile
tile_n_sources[ii] = mask_sources.reshape(n_tiles_h, n_tiles_w, filter_sources).sum(-1)

for k, v in tile_params.items():
if k == "locs":
k = "plocs"
source_matrix = self[k][ii][:n_sources]
if filter_oob:
source_matrix = source_matrix[x_mask]
num_tile = n_tiles_h * n_tiles_w
source_matrix_expand = source_matrix.unsqueeze(0).expand(num_tile, -1, -1)

masked_params = source_matrix_expand * mask_sources.unsqueeze(2)
# gather params values
gathered_params = masked_params.reshape(
n_tiles_h, n_tiles_w, filter_sources, v[ii].shape[-1]
).to(v[ii].dtype)

# move nonzero entries to the front (e.g. [0, 1, 0, 2] --> [1, 2, 0, 0]) for later use
fill_indice = min(filter_sources, max_sources_per_tile)
index = torch.sort((gathered_params != 0).long(), dim=2, descending=True)[1]
v[ii][:, :, :fill_indice] = gathered_params.gather(2, index)[:, :, :fill_indice]

# modify tile location
tile_params["locs"][ii] = (tile_params["locs"][ii] % tile_slen) / tile_slen
tile_params.update({"n_sources": tile_n_sources})
return TileCatalog(tile_slen, tile_params)
4 changes: 3 additions & 1 deletion bliss/data_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ def augment_data(tile_catalog, image):
aug_image, aug_full = flip_choice(aug_full, aug_image)

aug_image, aug_full = aug_shift(aug_full, aug_image)
aug_tile = aug_full.to_tile_params(4, 4).get_brightest_source_per_tile().to_dict()
aug_tile = (
aug_full.to_tile_params(4, 4, filter_oob=True).get_brightest_source_per_tile().to_dict()
)
return aug_image, aug_tile


Expand Down
4 changes: 4 additions & 0 deletions tests/test_catalogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,10 @@ def test_multiple_sources_one_tile(self):
tile_cat = full_cat.to_tile_params(1, 1, ignore_extra_sources=True)
assert torch.equal(tile_cat.n_sources, torch.tensor([[[1, 0], [0, 0]]]))

# test the to_tile_params / to_full_params round trip (max_sources_per_tile set to 2)
convert_full_cat = full_cat.to_tile_params(1, 2).to_full_params()
assert torch.allclose(convert_full_cat.plocs, full_cat.plocs)

correct_locs = torch.tensor([[[0.5, 0.5], [0, 0]], [[0, 0], [0, 0]]]).reshape(1, 2, 2, 1, 2)
assert torch.allclose(tile_cat.locs, correct_locs)

Expand Down
20 changes: 12 additions & 8 deletions tests/test_dc2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from einops import rearrange
from hydra.utils import instantiate

from bliss import data_augmentation
from bliss.catalog import TileCatalog
from bliss.data_augmentation import aug_rotate90, aug_rotate180, aug_rotate270, aug_shift, aug_vflip
from bliss.train import train


Expand Down Expand Up @@ -87,17 +87,21 @@ def test_dc2_augmentation(self, cfg):
aug_input_images = [imgs, bgs, deconv_image]
aug_input_images = torch.cat(aug_input_images, dim=2)

aug_list = [
data_augmentation.aug_vflip,
data_augmentation.aug_rotate90,
data_augmentation.aug_rotate180,
data_augmentation.aug_rotate270,
data_augmentation.aug_shift,
]
aug_list = [aug_vflip, aug_rotate90, aug_rotate180, aug_rotate270, aug_shift]

for aug_method in aug_list:
aug_image, aug_full = aug_method(origin_full, aug_input_images)
assert aug_image[0, :, 0, :, :].shape == dc2_obj["images"].shape
assert aug_image[0, :, 1, :, :].shape == dc2_obj["background"].shape
assert aug_image[0, :, 2, :, :].shape == dc2_obj["deconvolution"].shape
assert aug_full["n_sources"] <= origin_full.n_sources

# test rotation
aug_image90, aug_full90 = aug_rotate90(origin_full, aug_input_images)
_, aug_full270 = aug_rotate270(origin_full, aug_input_images)

_, aug_full90180 = aug_rotate180(aug_full90, aug_image90)
_, aug_full90270 = aug_rotate270(aug_full90, aug_image90)

assert aug_full90270 == origin_full
assert aug_full90180 == aug_full270

0 comments on commit bd59bd1

Please sign in to comment.