Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

To tile coords #939

Merged
merged 7 commits into from
Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 64 additions & 19 deletions bliss/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,11 @@ def apply_param_bin(self, pname: str, p_min: float, p_max: float):
return type(self)(self.height, self.width, d_new)

def to_tile_params(
self, tile_slen: int, max_sources_per_tile: int, ignore_extra_sources=False
self,
tile_slen: int,
max_sources_per_tile: int,
ignore_extra_sources=False,
filter_oob=False,
) -> TileCatalog:
"""Returns the TileCatalog corresponding to this FullCatalog.

Expand All @@ -463,6 +467,9 @@ def to_tile_params(
ignore_extra_sources: If False (default), raises an error if the number of sources
in one tile exceeds the `max_sources_per_tile`. If True, only adds the tile
parameters of the first `max_sources_per_tile` sources to the new TileCatalog.
filter_oob: If filter_oob is true, filter out the sources outside the image. (e.g. In
case of data augmentation, there is a chance of some sources located outside the
image)

Returns:
TileCatalog correspond to the each source in the FullCatalog.
Expand All @@ -485,24 +492,62 @@ def to_tile_params(
size = (self.batch_size, n_tiles_h, n_tiles_w, max_sources_per_tile, v.shape[-1])
tile_params[k] = torch.zeros(size, dtype=dtype, device=self.device)

tile_params["locs"] = tile_locs

for ii in range(self.batch_size):
n_sources = int(self.n_sources[ii].item())
for idx, coords in enumerate(tile_coords[ii][:n_sources]):
if coords[0] >= tile_n_sources.shape[1] or coords[1] >= tile_n_sources.shape[2]:
continue
# ignore sources outside of the image (usually caused by data augmentation - shift)

source_idx = tile_n_sources[ii, coords[0], coords[1]].item()
if source_idx >= max_sources_per_tile:
if not ignore_extra_sources:
raise ValueError( # noqa: WPS220
"# of sources per tile exceeds `max_sources_per_tile`."
)
continue # ignore extra sources in this tile.
tile_loc = (self.plocs[ii, idx] - coords * tile_slen) / tile_slen
tile_locs[ii, coords[0], coords[1], source_idx] = tile_loc
for k, v in tile_params.items():
v[ii, coords[0], coords[1], source_idx] = self[k][ii, idx]
tile_n_sources[ii, coords[0], coords[1]] = source_idx + 1
tile_params.update({"locs": tile_locs, "n_sources": tile_n_sources})
plocs_ii = self.plocs[ii][:n_sources]
filter_sources = n_sources
source_indices = tile_coords[ii][:n_sources]
if filter_oob:
x0_mask = (plocs_ii[:, 0] > 0) & (plocs_ii[:, 0] < self.height)
x1_mask = (plocs_ii[:, 1] > 0) & (plocs_ii[:, 1] < self.width)
x_mask = x0_mask * x1_mask
filter_sources = x_mask.sum()
source_indices = source_indices[x_mask]
source_indices = source_indices[:, 0] * n_tiles_w + source_indices[:, 1].unsqueeze(0)
tile_range = torch.arange(n_tiles_h * n_tiles_w, device=self.device).unsqueeze(1)
# get mask, for tiles
source_mask = (source_indices == tile_range).long() # (nth*ntw) x n_sources
if source_mask.sum(-1).max() > max_sources_per_tile:
if not ignore_extra_sources:
raise ValueError( # noqa: WPS220
"# of sources per tile exceeds `max_sources_per_tile`."
)

mask_sources = torch.zeros_like(source_mask, dtype=torch.bool, device=self.device)

# Find the indices of the first 'max_sources_per_tile' truth values in each row
sorted_indices = torch.argsort(source_mask, dim=1, descending=True)
top_indices = sorted_indices[:, :max_sources_per_tile]

# Set the corresponding positions in the output tensor to True
mask_sources = mask_sources.scatter_(1, top_indices, 1) & source_mask

# get n_sources for each tile
tile_n_sources[ii] = mask_sources.reshape(n_tiles_h, n_tiles_w, filter_sources).sum(-1)

for k, v in tile_params.items():
if k == "locs":
k = "plocs"
source_matrix = self[k][ii][:n_sources]
if filter_oob:
source_matrix = source_matrix[x_mask]
num_tile = n_tiles_h * n_tiles_w
source_matrix_expand = source_matrix.unsqueeze(0).expand(num_tile, -1, -1)

masked_params = source_matrix_expand * mask_sources.unsqueeze(2)
# gather params values
gathered_params = masked_params.reshape(
n_tiles_h, n_tiles_w, filter_sources, v[ii].shape[-1]
).to(v[ii].dtype)

# move nonzero ahead (eg. [0, 1, 0, 2] --> [1, 2, 0, 0]) for later used
fill_indice = min(filter_sources, max_sources_per_tile)
index = torch.sort((gathered_params != 0).long(), dim=2, descending=True)[1]
v[ii][:, :, :fill_indice] = gathered_params.gather(2, index)[:, :, :fill_indice]

# modify tile location
tile_params["locs"][ii] = (tile_params["locs"][ii] % tile_slen) / tile_slen
tile_params.update({"n_sources": tile_n_sources})
return TileCatalog(tile_slen, tile_params)
4 changes: 3 additions & 1 deletion bliss/data_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ def augment_data(tile_catalog, image):
aug_image, aug_full = flip_choice(aug_full, aug_image)

aug_image, aug_full = aug_shift(aug_full, aug_image)
aug_tile = aug_full.to_tile_params(4, 4).get_brightest_source_per_tile().to_dict()
aug_tile = (
aug_full.to_tile_params(4, 4, filter_oob=True).get_brightest_source_per_tile().to_dict()
)
return aug_image, aug_tile


Expand Down
4 changes: 4 additions & 0 deletions tests/test_catalogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,10 @@ def test_multiple_sources_one_tile(self):
tile_cat = full_cat.to_tile_params(1, 1, ignore_extra_sources=True)
assert torch.equal(tile_cat.n_sources, torch.tensor([[[1, 0], [0, 0]]]))

# test to_tile_coords and to_full_coords (set max_sources_per_tile to 2)
convert_full_cat = full_cat.to_tile_params(1, 2).to_full_params()
assert torch.allclose(convert_full_cat.plocs, full_cat.plocs)

correct_locs = torch.tensor([[[0.5, 0.5], [0, 0]], [[0, 0], [0, 0]]]).reshape(1, 2, 2, 1, 2)
assert torch.allclose(tile_cat.locs, correct_locs)

Expand Down
20 changes: 12 additions & 8 deletions tests/test_dc2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from einops import rearrange
from hydra.utils import instantiate

from bliss import data_augmentation
from bliss.catalog import TileCatalog
from bliss.data_augmentation import aug_rotate90, aug_rotate180, aug_rotate270, aug_shift, aug_vflip
from bliss.train import train


Expand Down Expand Up @@ -87,17 +87,21 @@ def test_dc2_augmentation(self, cfg):
aug_input_images = [imgs, bgs, deconv_image]
aug_input_images = torch.cat(aug_input_images, dim=2)

aug_list = [
data_augmentation.aug_vflip,
data_augmentation.aug_rotate90,
data_augmentation.aug_rotate180,
data_augmentation.aug_rotate270,
data_augmentation.aug_shift,
]
aug_list = [aug_vflip, aug_rotate90, aug_rotate180, aug_rotate270, aug_shift]

for aug_method in aug_list:
aug_image, aug_full = aug_method(origin_full, aug_input_images)
assert aug_image[0, :, 0, :, :].shape == dc2_obj["images"].shape
assert aug_image[0, :, 1, :, :].shape == dc2_obj["background"].shape
assert aug_image[0, :, 2, :, :].shape == dc2_obj["deconvolution"].shape
assert aug_full["n_sources"] <= origin_full.n_sources

# test rotatation
aug_image90, aug_full90 = aug_rotate90(origin_full, aug_input_images)
_, aug_full270 = aug_rotate270(origin_full, aug_input_images)

_, aug_full90180 = aug_rotate180(aug_full90, aug_image90)
_, aug_full90270 = aug_rotate270(aug_full90, aug_image90)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good test to have, but I realize it doesn't quite test what I thought it would, because there's no conversion btw tile and full catalogs happening.

Would you add an additional test that converts from a tile catalog to a full catalog and then back again to a tile catalog? (And that the first and last tile catalogs are equal?)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes make sense


assert aug_full90270 == origin_full
assert aug_full90180 == aug_full270