
Commit

add data augmentation and dc2 psf/deconv
Xinyue Li committed Aug 31, 2023
1 parent b258fb6 commit 63825dd
Showing 2 changed files with 38 additions and 53 deletions.
48 changes: 18 additions & 30 deletions bliss/data_augmentation.py
@@ -1,6 +1,6 @@
 import random
 
-import torch
+from einops import rearrange
 from torchvision.transforms import functional as TF
 
 from bliss.catalog import TileCatalog
@@ -34,13 +34,7 @@ def aug_vflip(origin_full, image):
 
 
 def aug_rotate90(origin_full, image):
-    num_channel = image.size(2)
-    aug_images = []
-    for i in range(num_channel):
-        rotated_slice = TF.rotate(image[:, :, i, :, :], 90)
-        aug_images.append(rotated_slice)
-    aug_image = torch.stack(aug_images, dim=2)
-
+    aug_image = rotate_images(image, 90)
     image_size = image.size(3)
     plocs = origin_full["plocs"].clone()
     origin_full["plocs"][:, :, 1] = plocs[:, :, 0]
@@ -49,13 +43,7 @@ def aug_rotate90(origin_full, image):
 
 
 def aug_rotate180(origin_full, image):
-    num_channel = image.size(2)
-    aug_images = []
-    for i in range(num_channel):
-        rotated_slice = TF.rotate(image[:, :, i, :, :], 180)
-        aug_images.append(rotated_slice)
-    aug_image = torch.stack(aug_images, dim=2)
-
+    aug_image = rotate_images(image, 180)
     image_size = image.size(3)
     plocs = origin_full["plocs"].clone()
     origin_full["plocs"][:, :, 1] = image_size - plocs[:, :, 1] - 1
@@ -64,35 +52,35 @@ def aug_rotate180(origin_full, image):
 
 
 def aug_rotate270(origin_full, image):
-    num_channel = image.size(2)
-    aug_images = []
-    for i in range(num_channel):
-        rotated_slice = TF.rotate(image[:, :, i, :, :], 270)
-        aug_images.append(rotated_slice)
-    aug_image = torch.stack(aug_images, dim=2)
-
+    aug_image = rotate_images(image, 270)
     image_size = image.size(3)
     plocs = origin_full["plocs"].clone()
     origin_full["plocs"][:, :, 1] = image_size - plocs[:, :, 0] - 1
     origin_full["plocs"][:, :, 0] = plocs[:, :, 1]
     return aug_image, origin_full
 
 
+def rotate_images(image, degree):
+    num_batch = image.size(0)
+    combined_image = rearrange(image, "bt bd ch h w -> (bt bd) ch h w")
+    rotated_image = TF.rotate(combined_image, degree)
+    return rearrange(rotated_image, "(bt bd) ch h w -> bt bd ch h w", bt=num_batch)
+
+
 def aug_shift(origin_full, image):
     shift_x = random.randint(-1, 2)
     shift_y = random.randint(-1, 2)
     image_size = image.size(3)
     image_lim = [2 - shift_x, 2 - shift_x + image_size, 2 - shift_y, 2 - shift_y + image_size]
 
-    num_channel = image.size(2)
-    aug_images = []
-    for i in range(num_channel):
-        pad_slice = TF.pad(image[:, :, i, :, :], (2, 2), padding_mode="reflect")
-        aug_images.append(pad_slice[:, :, image_lim[0] : image_lim[1], image_lim[2] : image_lim[3]])
-    aug_image = torch.stack(aug_images, dim=2)
+    num_batch = image.size(0)
+    combined_image = rearrange(image, "bt bd ch h w -> (bt bd) ch h w")
+    pad_combined_image = TF.pad(combined_image, (2, 2), padding_mode="reflect")
+    pad_image = rearrange(pad_combined_image, "(bt bd) ch h w -> bt bd ch h w", bt=num_batch)
+    aug_image = pad_image[:, :, :, image_lim[0] : image_lim[1], image_lim[2] : image_lim[3]]
 
     plocs = origin_full["plocs"].clone()
-    origin_full["plocs"][:, :, 1] = plocs[:, :, 1] + shift_x
-    origin_full["plocs"][:, :, 0] = plocs[:, :, 0] + shift_y
+    origin_full["plocs"][:, :, 1] = plocs[:, :, 1] + shift_y
+    origin_full["plocs"][:, :, 0] = plocs[:, :, 0] + shift_x
 
     return aug_image, origin_full
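
Note for readers skimming the diff: each rotation helper's per-band loop is replaced by one call to the new rotate_images, which folds the batch and band dimensions together so torchvision's functional.rotate can act on an ordinary 4-D (N, C, H, W) tensor, then unfolds them again. A minimal, self-contained sketch of that pattern; the tensor shape below is illustrative, not taken from any BLISS config:

```python
import torch
from einops import rearrange
from torchvision.transforms import functional as TF


def rotate_images(image, degree):
    # fold (batch, band) into one leading dim, rotate every (channel, H, W) slice, unfold
    num_batch = image.size(0)
    combined = rearrange(image, "bt bd ch h w -> (bt bd) ch h w")
    rotated = TF.rotate(combined, degree)
    return rearrange(rotated, "(bt bd) ch h w -> bt bd ch h w", bt=num_batch)


x = torch.rand(2, 5, 1, 8, 8)  # (batch, band, channel, height, width), made-up shape
y = rotate_images(x, 90)
print(y.shape)  # torch.Size([2, 5, 1, 8, 8]); square stamps keep their shape
```

Since rotation preserves the shape of square stamps, only the plocs bookkeeping differs between the 90, 180, and 270 degree helpers.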
43 changes: 20 additions & 23 deletions bliss/encoder.py
@@ -333,27 +333,7 @@ def _get_loss(self, pred: Dict[str, Distribution], true_tile_cat: TileCatalog):
 
         return loss_with_components
 
-    def _generic_step(
-        self, batch, logging_name, do_data_augmentation=False, log_metrics=False, plot_images=False
-    ):
-        if do_data_augmentation:
-            imgs = batch["images"][:, self.bands].unsqueeze(2)  # add extra dim for 5d input
-            bgs = batch["background"][:, self.bands].unsqueeze(2)
-            aug_input_images = [imgs, bgs]
-            if self.input_transform_params.get("use_deconv_channel"):
-                assert (
-                    "deconvolution" in batch
-                ), "use_deconv_channel specified but deconvolution not present in data"
-                aug_input_images.append(batch["background"][:, self.bands].unsqueeze(2))
-            aug_input_image = torch.cat(aug_input_images, dim=2)
-
-            aug_output_image, tile = augment_data(batch["tile_catalog"], aug_input_image)
-            batch["images"] = aug_output_image[:, :, 0, :, :]
-            batch["background"] = aug_output_image[:, :, 1, :, :]
-            batch["tile_catalog"] = tile
-            if self.input_transform_params.get("use_deconv_channel"):
-                batch["deconvolution"] = aug_output_image[:, :, 2, :, :]
-
+    def _generic_step(self, batch, logging_name, log_metrics=False, plot_images=False):
         batch_size = batch["images"].size(0)
         pred = self.encode_batch(batch)
         true_tile_cat = TileCatalog(self.tile_slen, batch["tile_catalog"])
@@ -409,8 +389,25 @@ def _generic_step(
 
     def training_step(self, batch, batch_idx, optimizer_idx=0):
         """Training step (pytorch lightning)."""
-        do_data_augmentation = self.do_data_augmentation
-        return self._generic_step(batch, "train", do_data_augmentation=do_data_augmentation)
+        if self.do_data_augmentation:
+            imgs = batch["images"][:, self.bands].unsqueeze(2)  # add extra dim for 5d input
+            bgs = batch["background"][:, self.bands].unsqueeze(2)
+            aug_input_images = [imgs, bgs]
+            if self.input_transform_params.get("use_deconv_channel"):
+                assert (
+                    "deconvolution" in batch
+                ), "use_deconv_channel specified but deconvolution not present in data"
+                aug_input_images.append(batch["background"][:, self.bands].unsqueeze(2))
+            aug_input_image = torch.cat(aug_input_images, dim=2)
+
+            aug_output_image, tile = augment_data(batch["tile_catalog"], aug_input_image)
+            batch["images"] = aug_output_image[:, :, 0, :, :]
+            batch["background"] = aug_output_image[:, :, 1, :, :]
+            batch["tile_catalog"] = tile
+            if self.input_transform_params.get("use_deconv_channel"):
+                batch["deconvolution"] = aug_output_image[:, :, 2, :, :]
+
+        return self._generic_step(batch, "train")
 
     def validation_step(self, batch, batch_idx):
         """Pytorch lightning method."""
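
The training_step change above packs the image, background, and (when use_deconv_channel is set) deconvolution planes along dim 2 before calling augment_data, so a single geometric transform is applied consistently to every plane while the tile catalog is updated alongside. A rough sketch of that pack/unpack convention, with a plain 180-degree flip standing in for augment_data (which in BLISS also returns the transformed catalog) and made-up tensor shapes:

```python
import torch


def flip180(image):
    # stand-in for the geometric part of augment_data: rotate every (H, W) plane by 180 degrees
    return torch.flip(image, dims=(-2, -1))


batch_size, n_bands, height, width = 4, 2, 8, 8  # illustrative only
imgs = torch.rand(batch_size, n_bands, height, width).unsqueeze(2)  # (batch, band, 1, H, W)
bgs = torch.rand(batch_size, n_bands, height, width).unsqueeze(2)

stacked = torch.cat([imgs, bgs], dim=2)  # (batch, band, 2, H, W)
augmented = flip180(stacked)             # one transform hits image and background identically

images_out = augmented[:, :, 0, :, :]      # slot 0: images, as unpacked in training_step
background_out = augmented[:, :, 1, :, :]  # slot 1: background
```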
