Skip to content

Commit

Permalink
add data augmentation and dc2 psf/deconv
Browse files Browse the repository at this point in the history
  • Loading branch information
Xinyue Li committed Aug 30, 2023
1 parent c85b92a commit b258fb6
Show file tree
Hide file tree
Showing 7 changed files with 263 additions and 117 deletions.
4 changes: 1 addition & 3 deletions bliss/conf/base_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,7 @@ encoder:
concat_psf_params: false
log_transform: false
rolling_z_score: true
data_augmentation:
do_data_augmentation: true
epoch_start: 1
do_data_augmentation: false
architecture:
# this architecture is based on yolov5l.yaml, see
# https://github.com/ultralytics/yolov5/blob/master/models/yolov5l.yaml
Expand Down
161 changes: 71 additions & 90 deletions bliss/data_augmentation.py
Original file line number Diff line number Diff line change
@@ -1,117 +1,98 @@
import random

import numpy as np
import torchvision as TF
import torch
from torchvision.transforms import functional as TF

from bliss.catalog import TileCatalog


def augment_data(tile_catalog, image, deconv_image=None):
def augment_data(tile_catalog, image):
origin_tile = TileCatalog(4, tile_catalog)
origin_full = origin_tile.to_full_params()
aug_full, aug_image = origin_full, image

num_transform_list = [1, 2, 3]
num_transform = random.choices(num_transform_list, weights=[0.7, 0.2, 0.1], k=1)[0]
transform_list = ["vflip", "hflip", "rotate90", "rotate180", "rotate270", "shift"]

aug_method = np.random.choice(
transform_list, num_transform, p=(0.1, 0.1, 0.1, 0.1, 0.1, 0.5), replace=False
)

aug_image, aug_tile, aug_deconv = origin_full, image, deconv_image
for i in aug_method:
if i == "vflip":
aug_image, aug_tile, aug_deconv = aug_vflip(aug_image, aug_tile, aug_deconv)
if i == "hflip":
aug_image, aug_tile, aug_deconv = aug_hflip(aug_image, aug_tile, aug_deconv)
if i == "rotate90":
aug_image, aug_tile, aug_deconv = aug_rotate90(aug_image, aug_tile, aug_deconv)
if i == "rotate180":
aug_image, aug_tile, aug_deconv = aug_rotate180(aug_image, aug_tile, aug_deconv)
if i == "rotate270":
aug_image, aug_tile, aug_deconv = aug_rotate270(aug_image, aug_tile, aug_deconv)
if i == "shift":
aug_image, aug_tile, aug_deconv = aug_shift(aug_image, aug_tile, aug_deconv)

return aug_image, aug_tile, aug_deconv


def aug_vflip(origin_full, image, deconv_image):
aug_image = TF.transforms.functional.vflip(image)
aug_deconv = None
if deconv_image is not None:
aug_deconv = TF.transforms.functional.vflip(deconv_image)
image_size = image.size(2)
rotate_list = [None, aug_rotate90, aug_rotate180, aug_rotate270]
flip_list = [None, aug_vflip]
rotate_choice = random.choice(rotate_list)
flip_choice = random.choice(flip_list)

if rotate_choice is not None:
aug_image, aug_full = rotate_choice(aug_full, aug_image)
if flip_choice is not None:
aug_image, aug_full = flip_choice(aug_full, aug_image)

Check warning on line 22 in bliss/data_augmentation.py

View check run for this annotation

Codecov / codecov/patch

bliss/data_augmentation.py#L22

Added line #L22 was not covered by tests

aug_image, aug_full = aug_shift(aug_full, aug_image)
aug_tile = aug_full.to_tile_params(4, 4).get_brightest_source_per_tile().to_dict()
return aug_image, aug_tile


def aug_vflip(origin_full, image):
aug_image = TF.vflip(image)
image_size = image.size(3)
origin_full["plocs"][:, :, 0] = image_size - origin_full["plocs"][:, :, 0] - 1
aug_tile = origin_full.to_tile_params(4, 4).get_brightest_source_per_tile().to_dict()
return aug_image, aug_tile, aug_deconv


def aug_hflip(origin_full, image, deconv_image):
aug_image = TF.transforms.functional.hflip(image)
aug_deconv = None
if deconv_image is not None:
aug_deconv = TF.transforms.functional.hflip(deconv_image)
image_size = image.size(2)
origin_full["plocs"][:, :, 1] = image_size - origin_full["plocs"][:, :, 1] - 1
aug_tile = origin_full.to_tile_params(4, 4).get_brightest_source_per_tile().to_dict()
return aug_image, aug_tile, aug_deconv


def aug_rotate90(origin_full, image, deconv_image):
aug_image = TF.transforms.functional.rotate(image, 90)
aug_deconv = None
if deconv_image is not None:
aug_deconv = TF.transforms.functional.rotate(deconv_image, 90)
image_size = image.size(2)
return aug_image, origin_full


def aug_rotate90(origin_full, image):
num_channel = image.size(2)
aug_images = []
for i in range(num_channel):
rotated_slice = TF.rotate(image[:, :, i, :, :], 90)
aug_images.append(rotated_slice)
aug_image = torch.stack(aug_images, dim=2)

image_size = image.size(3)
plocs = origin_full["plocs"].clone()
origin_full["plocs"][:, :, 1] = plocs[:, :, 0]
origin_full["plocs"][:, :, 0] = image_size - plocs[:, :, 1] - 1
aug_tile = origin_full.to_tile_params(4, 4).get_brightest_source_per_tile().to_dict()
return aug_image, aug_tile, aug_deconv
return aug_image, origin_full


def aug_rotate180(origin_full, image, deconv_image):
aug_image = TF.transforms.functional.rotate(image, 180)
aug_deconv = None
if deconv_image is not None:
aug_deconv = TF.transforms.functional.rotate(deconv_image, 180)
image_size = image.size(2)
def aug_rotate180(origin_full, image):
num_channel = image.size(2)
aug_images = []
for i in range(num_channel):
rotated_slice = TF.rotate(image[:, :, i, :, :], 180)
aug_images.append(rotated_slice)
aug_image = torch.stack(aug_images, dim=2)

image_size = image.size(3)
plocs = origin_full["plocs"].clone()
origin_full["plocs"][:, :, 1] = image_size - plocs[:, :, 1] - 1
origin_full["plocs"][:, :, 0] = image_size - plocs[:, :, 0] - 1
aug_tile = origin_full.to_tile_params(4, 4).get_brightest_source_per_tile().to_dict()
return aug_image, aug_tile, aug_deconv
return aug_image, origin_full


def aug_rotate270(origin_full, image):
num_channel = image.size(2)
aug_images = []
for i in range(num_channel):
rotated_slice = TF.rotate(image[:, :, i, :, :], 270)
aug_images.append(rotated_slice)
aug_image = torch.stack(aug_images, dim=2)

def aug_rotate270(origin_full, image, deconv_image):
aug_image = TF.transforms.functional.rotate(image, 270)
aug_deconv = None
if deconv_image is not None:
aug_deconv = TF.transforms.functional.rotate(deconv_image, 270)
image_size = image.size(2)
image_size = image.size(3)
plocs = origin_full["plocs"].clone()
origin_full["plocs"][:, :, 1] = image_size - plocs[:, :, 0] - 1
origin_full["plocs"][:, :, 0] = plocs[:, :, 1]
aug_tile = origin_full.to_tile_params(4, 4).get_brightest_source_per_tile().to_dict()
return aug_image, aug_tile, aug_deconv


def aug_shift(origin_full, image, deconv_image):
shift_x = random.randint(0, 3)
shift_y = random.randint(0, 3)
shift_xy = (shift_x, shift_y)
image_size = image.size(2)
aug_deconv = None
pad_image = TF.transforms.functional.pad(image, shift_xy, padding_mode="reflect")
if deconv_image is not None:
pad_deconv = TF.transforms.functional.pad(deconv_image, shift_xy, padding_mode="reflect")
aug_deconv = pad_deconv[:, :, :image_size, :image_size]

aug_image = pad_image[:, :, :image_size, :image_size]
return aug_image, origin_full


def aug_shift(origin_full, image):
shift_x = random.randint(-1, 2)
shift_y = random.randint(-1, 2)
image_size = image.size(3)
image_lim = [2 - shift_x, 2 - shift_x + image_size, 2 - shift_y, 2 - shift_y + image_size]

num_channel = image.size(2)
aug_images = []
for i in range(num_channel):
pad_slice = TF.pad(image[:, :, i, :, :], (2, 2), padding_mode="reflect")
aug_images.append(pad_slice[:, :, image_lim[0] : image_lim[1], image_lim[2] : image_lim[3]])
aug_image = torch.stack(aug_images, dim=2)

plocs = origin_full["plocs"].clone()
origin_full["plocs"][:, :, 1] = plocs[:, :, 1] + shift_x
origin_full["plocs"][:, :, 0] = plocs[:, :, 0] + shift_y

aug_tile = origin_full.to_tile_params(4, 4).get_brightest_source_per_tile().to_dict()
return aug_image, aug_tile, aug_deconv
return aug_image, origin_full
33 changes: 18 additions & 15 deletions bliss/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
optimizer_params: Optional[dict] = None,
scheduler_params: Optional[dict] = None,
input_transform_params: Optional[dict] = None,
data_augmentation: Optional[dict] = None,
do_data_augmentation: bool = False,
):
"""Initializes DetectionEncoder.
Expand All @@ -65,8 +65,7 @@ def __init__(
scheduler_params: arguments passed to the learning rate scheduler
input_transform_params: used for determining what channels to use as input (e.g.
deconvolution, concatenate PSF parameters, z-score inputs, etc.)
data_augmentation: used for determining whether or not do data augmentation and the
data augmentation start point (a.g. after 1 epoch)
do_data_augmentation: used for determining whether or not do data augmentation
"""
super().__init__()
self.save_hyperparameters()
Expand All @@ -81,7 +80,7 @@ def __init__(
self.optimizer_params = optimizer_params
self.scheduler_params = scheduler_params if scheduler_params else {"milestones": []}
self.input_transform_params = input_transform_params
self.data_augmentation = data_augmentation
self.do_data_augmentation = do_data_augmentation

transform_enabled = (
"log_transform" in self.input_transform_params
Expand Down Expand Up @@ -334,23 +333,26 @@ def _get_loss(self, pred: Dict[str, Distribution], true_tile_cat: TileCatalog):

return loss_with_components

def _generic_step(self, batch, logging_name, log_metrics=False, plot_images=False):
do_data_augmentation = self.data_augmentation.get("do_data_augmentation")
data_augmentation_start = self.current_epoch >= self.data_augmentation.get("epoch_start")

if data_augmentation_start and (logging_name == "train") and do_data_augmentation:
deconv_image = None
def _generic_step(
self, batch, logging_name, do_data_augmentation=False, log_metrics=False, plot_images=False
):
if do_data_augmentation:
imgs = batch["images"][:, self.bands].unsqueeze(2) # add extra dim for 5d input
bgs = batch["background"][:, self.bands].unsqueeze(2)
aug_input_images = [imgs, bgs]
if self.input_transform_params.get("use_deconv_channel"):
assert (
"deconvolution" in batch
), "use_deconv_channel specified but deconvolution not present in data"
deconv_image = batch["deconvolution"]
aug_input_images.append(batch["background"][:, self.bands].unsqueeze(2))
aug_input_image = torch.cat(aug_input_images, dim=2)

image, tile, deconv = augment_data(batch["tile_catalog"], batch["images"], deconv_image)
batch["images"] = image
aug_output_image, tile = augment_data(batch["tile_catalog"], aug_input_image)
batch["images"] = aug_output_image[:, :, 0, :, :]
batch["background"] = aug_output_image[:, :, 1, :, :]
batch["tile_catalog"] = tile
if self.input_transform_params.get("use_deconv_channel"):
batch["deconvolution"] = deconv
batch["deconvolution"] = aug_output_image[:, :, 2, :, :]

batch_size = batch["images"].size(0)
pred = self.encode_batch(batch)
Expand Down Expand Up @@ -407,7 +409,8 @@ def _generic_step(self, batch, logging_name, log_metrics=False, plot_images=Fals

def training_step(self, batch, batch_idx, optimizer_idx=0):
"""Training step (pytorch lightning)."""
return self._generic_step(batch, "train")
do_data_augmentation = self.do_data_augmentation
return self._generic_step(batch, "train", do_data_augmentation=do_data_augmentation)

def validation_step(self, batch, batch_idx):
"""Pytorch lightning method."""
Expand Down
Loading

0 comments on commit b258fb6

Please sign in to comment.