-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add data augmentation and dc2 psf/deconv
- Loading branch information
Xinyue Li
committed
Aug 30, 2023
1 parent
c85b92a
commit b258fb6
Showing
7 changed files
with
263 additions
and
117 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,117 +1,98 @@ | ||
import random | ||
|
||
import numpy as np | ||
import torchvision as TF | ||
import torch | ||
from torchvision.transforms import functional as TF | ||
|
||
from bliss.catalog import TileCatalog | ||
|
||
|
||
def augment_data(tile_catalog, image, deconv_image=None): | ||
def augment_data(tile_catalog, image): | ||
origin_tile = TileCatalog(4, tile_catalog) | ||
origin_full = origin_tile.to_full_params() | ||
aug_full, aug_image = origin_full, image | ||
|
||
num_transform_list = [1, 2, 3] | ||
num_transform = random.choices(num_transform_list, weights=[0.7, 0.2, 0.1], k=1)[0] | ||
transform_list = ["vflip", "hflip", "rotate90", "rotate180", "rotate270", "shift"] | ||
|
||
aug_method = np.random.choice( | ||
transform_list, num_transform, p=(0.1, 0.1, 0.1, 0.1, 0.1, 0.5), replace=False | ||
) | ||
|
||
aug_image, aug_tile, aug_deconv = origin_full, image, deconv_image | ||
for i in aug_method: | ||
if i == "vflip": | ||
aug_image, aug_tile, aug_deconv = aug_vflip(aug_image, aug_tile, aug_deconv) | ||
if i == "hflip": | ||
aug_image, aug_tile, aug_deconv = aug_hflip(aug_image, aug_tile, aug_deconv) | ||
if i == "rotate90": | ||
aug_image, aug_tile, aug_deconv = aug_rotate90(aug_image, aug_tile, aug_deconv) | ||
if i == "rotate180": | ||
aug_image, aug_tile, aug_deconv = aug_rotate180(aug_image, aug_tile, aug_deconv) | ||
if i == "rotate270": | ||
aug_image, aug_tile, aug_deconv = aug_rotate270(aug_image, aug_tile, aug_deconv) | ||
if i == "shift": | ||
aug_image, aug_tile, aug_deconv = aug_shift(aug_image, aug_tile, aug_deconv) | ||
|
||
return aug_image, aug_tile, aug_deconv | ||
|
||
|
||
def aug_vflip(origin_full, image, deconv_image): | ||
aug_image = TF.transforms.functional.vflip(image) | ||
aug_deconv = None | ||
if deconv_image is not None: | ||
aug_deconv = TF.transforms.functional.vflip(deconv_image) | ||
image_size = image.size(2) | ||
rotate_list = [None, aug_rotate90, aug_rotate180, aug_rotate270] | ||
flip_list = [None, aug_vflip] | ||
rotate_choice = random.choice(rotate_list) | ||
flip_choice = random.choice(flip_list) | ||
|
||
if rotate_choice is not None: | ||
aug_image, aug_full = rotate_choice(aug_full, aug_image) | ||
if flip_choice is not None: | ||
aug_image, aug_full = flip_choice(aug_full, aug_image) | ||
|
||
aug_image, aug_full = aug_shift(aug_full, aug_image) | ||
aug_tile = aug_full.to_tile_params(4, 4).get_brightest_source_per_tile().to_dict() | ||
return aug_image, aug_tile | ||
|
||
|
||
def aug_vflip(origin_full, image): | ||
aug_image = TF.vflip(image) | ||
image_size = image.size(3) | ||
origin_full["plocs"][:, :, 0] = image_size - origin_full["plocs"][:, :, 0] - 1 | ||
aug_tile = origin_full.to_tile_params(4, 4).get_brightest_source_per_tile().to_dict() | ||
return aug_image, aug_tile, aug_deconv | ||
|
||
|
||
def aug_hflip(origin_full, image, deconv_image): | ||
aug_image = TF.transforms.functional.hflip(image) | ||
aug_deconv = None | ||
if deconv_image is not None: | ||
aug_deconv = TF.transforms.functional.hflip(deconv_image) | ||
image_size = image.size(2) | ||
origin_full["plocs"][:, :, 1] = image_size - origin_full["plocs"][:, :, 1] - 1 | ||
aug_tile = origin_full.to_tile_params(4, 4).get_brightest_source_per_tile().to_dict() | ||
return aug_image, aug_tile, aug_deconv | ||
|
||
|
||
def aug_rotate90(origin_full, image, deconv_image): | ||
aug_image = TF.transforms.functional.rotate(image, 90) | ||
aug_deconv = None | ||
if deconv_image is not None: | ||
aug_deconv = TF.transforms.functional.rotate(deconv_image, 90) | ||
image_size = image.size(2) | ||
return aug_image, origin_full | ||
|
||
|
||
def aug_rotate90(origin_full, image): | ||
num_channel = image.size(2) | ||
aug_images = [] | ||
for i in range(num_channel): | ||
rotated_slice = TF.rotate(image[:, :, i, :, :], 90) | ||
aug_images.append(rotated_slice) | ||
aug_image = torch.stack(aug_images, dim=2) | ||
|
||
image_size = image.size(3) | ||
plocs = origin_full["plocs"].clone() | ||
origin_full["plocs"][:, :, 1] = plocs[:, :, 0] | ||
origin_full["plocs"][:, :, 0] = image_size - plocs[:, :, 1] - 1 | ||
aug_tile = origin_full.to_tile_params(4, 4).get_brightest_source_per_tile().to_dict() | ||
return aug_image, aug_tile, aug_deconv | ||
return aug_image, origin_full | ||
|
||
|
||
def aug_rotate180(origin_full, image, deconv_image): | ||
aug_image = TF.transforms.functional.rotate(image, 180) | ||
aug_deconv = None | ||
if deconv_image is not None: | ||
aug_deconv = TF.transforms.functional.rotate(deconv_image, 180) | ||
image_size = image.size(2) | ||
def aug_rotate180(origin_full, image): | ||
num_channel = image.size(2) | ||
aug_images = [] | ||
for i in range(num_channel): | ||
rotated_slice = TF.rotate(image[:, :, i, :, :], 180) | ||
aug_images.append(rotated_slice) | ||
aug_image = torch.stack(aug_images, dim=2) | ||
|
||
image_size = image.size(3) | ||
plocs = origin_full["plocs"].clone() | ||
origin_full["plocs"][:, :, 1] = image_size - plocs[:, :, 1] - 1 | ||
origin_full["plocs"][:, :, 0] = image_size - plocs[:, :, 0] - 1 | ||
aug_tile = origin_full.to_tile_params(4, 4).get_brightest_source_per_tile().to_dict() | ||
return aug_image, aug_tile, aug_deconv | ||
return aug_image, origin_full | ||
|
||
|
||
def aug_rotate270(origin_full, image): | ||
num_channel = image.size(2) | ||
aug_images = [] | ||
for i in range(num_channel): | ||
rotated_slice = TF.rotate(image[:, :, i, :, :], 270) | ||
aug_images.append(rotated_slice) | ||
aug_image = torch.stack(aug_images, dim=2) | ||
|
||
def aug_rotate270(origin_full, image, deconv_image): | ||
aug_image = TF.transforms.functional.rotate(image, 270) | ||
aug_deconv = None | ||
if deconv_image is not None: | ||
aug_deconv = TF.transforms.functional.rotate(deconv_image, 270) | ||
image_size = image.size(2) | ||
image_size = image.size(3) | ||
plocs = origin_full["plocs"].clone() | ||
origin_full["plocs"][:, :, 1] = image_size - plocs[:, :, 0] - 1 | ||
origin_full["plocs"][:, :, 0] = plocs[:, :, 1] | ||
aug_tile = origin_full.to_tile_params(4, 4).get_brightest_source_per_tile().to_dict() | ||
return aug_image, aug_tile, aug_deconv | ||
|
||
|
||
def aug_shift(origin_full, image, deconv_image): | ||
shift_x = random.randint(0, 3) | ||
shift_y = random.randint(0, 3) | ||
shift_xy = (shift_x, shift_y) | ||
image_size = image.size(2) | ||
aug_deconv = None | ||
pad_image = TF.transforms.functional.pad(image, shift_xy, padding_mode="reflect") | ||
if deconv_image is not None: | ||
pad_deconv = TF.transforms.functional.pad(deconv_image, shift_xy, padding_mode="reflect") | ||
aug_deconv = pad_deconv[:, :, :image_size, :image_size] | ||
|
||
aug_image = pad_image[:, :, :image_size, :image_size] | ||
return aug_image, origin_full | ||
|
||
|
||
def aug_shift(origin_full, image): | ||
shift_x = random.randint(-1, 2) | ||
shift_y = random.randint(-1, 2) | ||
image_size = image.size(3) | ||
image_lim = [2 - shift_x, 2 - shift_x + image_size, 2 - shift_y, 2 - shift_y + image_size] | ||
|
||
num_channel = image.size(2) | ||
aug_images = [] | ||
for i in range(num_channel): | ||
pad_slice = TF.pad(image[:, :, i, :, :], (2, 2), padding_mode="reflect") | ||
aug_images.append(pad_slice[:, :, image_lim[0] : image_lim[1], image_lim[2] : image_lim[3]]) | ||
aug_image = torch.stack(aug_images, dim=2) | ||
|
||
plocs = origin_full["plocs"].clone() | ||
origin_full["plocs"][:, :, 1] = plocs[:, :, 1] + shift_x | ||
origin_full["plocs"][:, :, 0] = plocs[:, :, 0] + shift_y | ||
|
||
aug_tile = origin_full.to_tile_params(4, 4).get_brightest_source_per_tile().to_dict() | ||
return aug_image, aug_tile, aug_deconv | ||
return aug_image, origin_full |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.