Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add data augmentation and dc2 psf/deconv #932

Merged
merged 4 commits into from
Aug 31, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ data/sdss/rcf_list
data/sdss_all
data/dc2/lsstdesc-public
data/dc2/calibrated_exposures
data/dc2/coadd_deconv_image
data/dc2/merged_catalog
case_studies/sdss_galaxies/models/simulated_blended_galaxies.pt
case_studies/*/data/
venv
Expand Down
4 changes: 4 additions & 0 deletions bliss/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,10 @@
for ii in range(self.batch_size):
n_sources = int(self.n_sources[ii].item())
for idx, coords in enumerate(tile_coords[ii][:n_sources]):
if coords[0] >= tile_n_sources.shape[1] or coords[1] >= tile_n_sources.shape[2]:
continue

Check warning on line 492 in bliss/catalog.py

View check run for this annotation

Codecov / codecov/patch

bliss/catalog.py#L492

Added line #L492 was not covered by tests
# ignore sources outside of the image (usually caused by data augmentation - shift)
jeff-regier marked this conversation as resolved.
Show resolved Hide resolved

source_idx = tile_n_sources[ii, coords[0], coords[1]].item()
if source_idx >= max_sources_per_tile:
if not ignore_extra_sources:
Expand Down
7 changes: 5 additions & 2 deletions bliss/conf/base_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ encoder:
concat_psf_params: false
log_transform: false
rolling_z_score: true
do_data_augmentation: false
architecture:
# this architecture is based on yolov5l.yaml, see
# https://github.com/ultralytics/yolov5/blob/master/models/yolov5l.yaml
Expand Down Expand Up @@ -227,7 +228,9 @@ surveys:
dc2:
_target_: bliss.surveys.dc2.DC2
data_dir: /nfs/turbo/lsa-regier/lsstdesc-public/dc2/run2.2i-dr6-v4/coadd-t3828-t3829/deepCoadd-results/
cat_path: ${paths.dc2}/merged_catalog/merged_catalog_622.pkl
batch_size: 128
cat_path: ${paths.dc2}/merged_catalog/merged_catalog_psf_100.pkl
jeff-regier marked this conversation as resolved.
Show resolved Hide resolved
batch_size: 64
n_split: 50
image_lim: [4000, 4000]
use_deconv_channel: ${encoder.input_transform_params.use_deconv_channel}
deconv_path: ${paths.dc2}/coadd_deconv_image
98 changes: 98 additions & 0 deletions bliss/data_augmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import random

import torch
from torchvision.transforms import functional as TF

from bliss.catalog import TileCatalog


def augment_data(tile_catalog, image, tile_slen=4, max_sources_per_tile=4):
    """Apply a random flip/rotation plus a small random shift to a batch.

    Args:
        tile_catalog: dict representation of a TileCatalog for the batch.
        image: 5d tensor (batch, band, channel, height, width) holding the
            stacked image planes (image, background, optionally deconvolution).
            Assumes height == width — TODO confirm for non-square inputs.
        tile_slen: tile side length in pixels (default 4, matching DC2 usage).
        max_sources_per_tile: max sources per tile when re-tiling the catalog.

    Returns:
        Tuple of (augmented image tensor, augmented tile catalog as a dict).
    """
    origin_tile = TileCatalog(tile_slen, tile_catalog)
    aug_full, aug_image = origin_tile.to_full_params(), image

    # Pick one rotation (possibly none) and one flip (possibly none),
    # each uniformly at random.
    rotate_choice = random.choice([None, aug_rotate90, aug_rotate180, aug_rotate270])
    flip_choice = random.choice([None, aug_vflip])

    if rotate_choice is not None:
        aug_image, aug_full = rotate_choice(aug_full, aug_image)
    if flip_choice is not None:
        aug_image, aug_full = flip_choice(aug_full, aug_image)

    # Always apply a small random shift; sources pushed outside the image
    # are handled downstream when converting back to tile parameters.
    aug_image, aug_full = aug_shift(aug_full, aug_image)
    aug_tile = (
        aug_full.to_tile_params(tile_slen, max_sources_per_tile)
        .get_brightest_source_per_tile()
        .to_dict()
    )
    return aug_image, aug_tile


def aug_vflip(origin_full, image):
    """Vertically flip the image batch and mirror the catalog to match.

    Args:
        origin_full: full catalog whose "plocs" (row, col pixel locations)
            are updated in place.
        image: 5d image tensor (batch, band, channel, height, width).

    Returns:
        Tuple of (flipped image, updated catalog).
    """
    flipped = TF.vflip(image)
    side = image.size(3)
    # Mirror the row coordinate (plocs index 0) about the horizontal axis.
    origin_full["plocs"][:, :, 0] = side - origin_full["plocs"][:, :, 0] - 1
    return flipped, origin_full


def aug_rotate90(origin_full, image):
    """Rotate the image batch by 90 degrees and update the catalog to match.

    Args:
        origin_full: full catalog whose "plocs" (row, col pixel locations)
            are updated in place to follow the rotation.
        image: 5d image tensor (batch, band, channel, height, width);
            assumes height == width — TODO confirm for non-square inputs.

    Returns:
        Tuple of (rotated image, updated catalog).
    """
    # Rotate every channel in one call by collapsing (band, channel) into a
    # single dim; TF.rotate acts on the trailing (H, W) dims of a 4d tensor.
    aug_image = TF.rotate(image.flatten(1, 2), 90).reshape(image.shape)

    image_size = image.size(3)
    plocs = origin_full["plocs"].clone()
    origin_full["plocs"][:, :, 1] = plocs[:, :, 0]
    origin_full["plocs"][:, :, 0] = image_size - plocs[:, :, 1] - 1
    return aug_image, origin_full


def aug_rotate180(origin_full, image):
    """Rotate the image batch by 180 degrees and update the catalog to match.

    Args:
        origin_full: full catalog whose "plocs" (row, col pixel locations)
            are updated in place to follow the rotation.
        image: 5d image tensor (batch, band, channel, height, width);
            assumes height == width — TODO confirm for non-square inputs.

    Returns:
        Tuple of (rotated image, updated catalog).
    """
    # Rotate every channel in one call by collapsing (band, channel) into a
    # single dim; TF.rotate acts on the trailing (H, W) dims of a 4d tensor.
    aug_image = TF.rotate(image.flatten(1, 2), 180).reshape(image.shape)

    image_size = image.size(3)
    plocs = origin_full["plocs"].clone()
    origin_full["plocs"][:, :, 1] = image_size - plocs[:, :, 1] - 1
    origin_full["plocs"][:, :, 0] = image_size - plocs[:, :, 0] - 1
    return aug_image, origin_full


def aug_rotate270(origin_full, image):
    """Rotate the image batch by 270 degrees and update the catalog to match.

    Args:
        origin_full: full catalog whose "plocs" (row, col pixel locations)
            are updated in place to follow the rotation.
        image: 5d image tensor (batch, band, channel, height, width);
            assumes height == width — TODO confirm for non-square inputs.

    Returns:
        Tuple of (rotated image, updated catalog).
    """
    # Rotate every channel in one call by collapsing (band, channel) into a
    # single dim; TF.rotate acts on the trailing (H, W) dims of a 4d tensor.
    aug_image = TF.rotate(image.flatten(1, 2), 270).reshape(image.shape)

    image_size = image.size(3)
    plocs = origin_full["plocs"].clone()
    origin_full["plocs"][:, :, 1] = image_size - plocs[:, :, 0] - 1
    origin_full["plocs"][:, :, 0] = plocs[:, :, 1]
    return aug_image, origin_full


def aug_shift(origin_full, image):
    """Randomly shift the image batch by a few pixels and update the catalog.

    Args:
        origin_full: full catalog whose "plocs" (row, col pixel locations)
            are updated in place to follow the shift.
        image: 5d image tensor (batch, band, channel, height, width);
            assumes height == width — TODO confirm for non-square inputs.

    Returns:
        Tuple of (shifted image, updated catalog).
    """
    # NOTE(review): random.randint is inclusive, so this range is the
    # asymmetric [-1, 2] — confirm whether a symmetric [-1, 1] was intended.
    shift_x = random.randint(-1, 2)
    shift_y = random.randint(-1, 2)
    image_size = image.size(3)

    # Reflect-pad 2 pixels on every side, then crop back to the original size
    # at an offset: rows move by shift_x, columns by shift_y. All channels are
    # padded at once by collapsing (band, channel) into a single dim.
    padded = TF.pad(image.flatten(1, 2), (2, 2), padding_mode="reflect")
    rows = slice(2 - shift_x, 2 - shift_x + image_size)
    cols = slice(2 - shift_y, 2 - shift_y + image_size)
    aug_image = padded[:, :, rows, cols].reshape(image.shape)

    # plocs is (row, col): rows moved by shift_x, columns by shift_y.
    # (The previous version added shift_x to columns and shift_y to rows,
    # desynchronizing catalog and image whenever shift_x != shift_y.)
    plocs = origin_full["plocs"].clone()
    origin_full["plocs"][:, :, 0] = plocs[:, :, 0] + shift_x
    origin_full["plocs"][:, :, 1] = plocs[:, :, 1] + shift_y

    return aug_image, origin_full
29 changes: 27 additions & 2 deletions bliss/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from bliss.backbone import Backbone
from bliss.catalog import FullCatalog, SourceType, TileCatalog
from bliss.data_augmentation import augment_data
from bliss.metrics import BlissMetrics, MetricsMode
from bliss.plotting import plot_detections
from bliss.transforms import log_transform, rolling_z_score
Expand Down Expand Up @@ -48,6 +49,7 @@ def __init__(
optimizer_params: Optional[dict] = None,
scheduler_params: Optional[dict] = None,
input_transform_params: Optional[dict] = None,
do_data_augmentation: bool = False,
):
"""Initializes DetectionEncoder.

Expand All @@ -63,6 +65,7 @@ def __init__(
scheduler_params: arguments passed to the learning rate scheduler
input_transform_params: used for determining what channels to use as input (e.g.
deconvolution, concatenate PSF parameters, z-score inputs, etc.)
do_data_augmentation: used for determining whether or not to do data augmentation
"""
super().__init__()
self.save_hyperparameters()
Expand All @@ -77,6 +80,7 @@ def __init__(
self.optimizer_params = optimizer_params
self.scheduler_params = scheduler_params if scheduler_params else {"milestones": []}
self.input_transform_params = input_transform_params
self.do_data_augmentation = do_data_augmentation

transform_enabled = (
"log_transform" in self.input_transform_params
Expand Down Expand Up @@ -329,7 +333,27 @@ def _get_loss(self, pred: Dict[str, Distribution], true_tile_cat: TileCatalog):

return loss_with_components

def _generic_step(self, batch, logging_name, log_metrics=False, plot_images=False):
def _generic_step(
self, batch, logging_name, do_data_augmentation=False, log_metrics=False, plot_images=False
):
if do_data_augmentation:
imgs = batch["images"][:, self.bands].unsqueeze(2) # add extra dim for 5d input
bgs = batch["background"][:, self.bands].unsqueeze(2)
aug_input_images = [imgs, bgs]
if self.input_transform_params.get("use_deconv_channel"):
jeff-regier marked this conversation as resolved.
Show resolved Hide resolved
assert (
"deconvolution" in batch
), "use_deconv_channel specified but deconvolution not present in data"
aug_input_images.append(batch["background"][:, self.bands].unsqueeze(2))
aug_input_image = torch.cat(aug_input_images, dim=2)

aug_output_image, tile = augment_data(batch["tile_catalog"], aug_input_image)
batch["images"] = aug_output_image[:, :, 0, :, :]
batch["background"] = aug_output_image[:, :, 1, :, :]
batch["tile_catalog"] = tile
if self.input_transform_params.get("use_deconv_channel"):
batch["deconvolution"] = aug_output_image[:, :, 2, :, :]

batch_size = batch["images"].size(0)
pred = self.encode_batch(batch)
true_tile_cat = TileCatalog(self.tile_slen, batch["tile_catalog"])
Expand Down Expand Up @@ -385,7 +409,8 @@ def _generic_step(self, batch, logging_name, log_metrics=False, plot_images=Fals

def training_step(self, batch, batch_idx, optimizer_idx=0):
"""Training step (pytorch lightning)."""
return self._generic_step(batch, "train")
do_data_augmentation = self.do_data_augmentation
jeff-regier marked this conversation as resolved.
Show resolved Hide resolved
return self._generic_step(batch, "train", do_data_augmentation=do_data_augmentation)

def validation_step(self, batch, batch_idx):
"""Pytorch lightning method."""
Expand Down
58 changes: 42 additions & 16 deletions bliss/surveys/dc2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
class DC2(Survey):
BANDS = ("g", "i", "r", "u", "y", "z")

def __init__(self, data_dir, cat_path, batch_size, n_split, image_lim):
def __init__(
self, data_dir, cat_path, batch_size, n_split, image_lim, use_deconv_channel, deconv_path
):
super().__init__()
self.data_dir = data_dir
self.cat_path = cat_path
Expand All @@ -30,6 +32,8 @@ def __init__(self, data_dir, cat_path, batch_size, n_split, image_lim):
self.test = []
self.n_split = n_split
self.image_lim = image_lim
self.use_deconv_channel = use_deconv_channel
self.deconv_path = deconv_path

self._predict_batch = None

Expand All @@ -49,8 +53,8 @@ def image_ids(self):
return [self.dc2_data[i]["images"] for i in range(len(self))]

def prepare_data(self):
img_pattern = "3828/*/calexp*.fits"
bg_pattern = "3828/*/bkgd*.fits"
img_pattern = "**/*/calexp*.fits"
bg_pattern = "**/*/bkgd*.fits"
image_files = []
bg_files = []

Expand All @@ -75,8 +79,10 @@ def prepare_data(self):
plocs_lim = image[0].shape
height = plocs_lim[0]
width = plocs_lim[1]
full_cat = Dc2FullCatalog.from_file(self.cat_path, wcs, height, width, self.bands)
tile_cat = full_cat.to_tile_params(4, 1, 1)
full_cat, psf_params = Dc2FullCatalog.from_file(
self.cat_path, wcs, height, width, self.bands
)
tile_cat = full_cat.to_tile_params(4, 5).get_brightest_source_per_tile()
tile_dict = tile_cat.to_dict()

tile_dict["locs"] = rearrange(tile_cat.to_dict()["locs"], "1 h w nh nw -> h w nh nw")
Expand Down Expand Up @@ -105,14 +111,9 @@ def prepare_data(self):
# split image
split_lim = self.image_lim[0] // self.n_split
image = torch.from_numpy(image)
split1_image = torch.stack(torch.split(image, split_lim, dim=1))
split2_image = torch.stack(torch.split(split1_image, split_lim, dim=3))
split_image = list(torch.split(split2_image.flatten(0, 2), 6))

split_image = split_full_image(image, split_lim)
bg = torch.from_numpy(bg)
split1_bg = torch.stack(torch.split(bg, split_lim, dim=1))
split2_bg = torch.stack(torch.split(split1_bg, split_lim, dim=3))
split_bg = list(torch.split(split2_bg.flatten(0, 2), 6))
split_bg = split_full_image(bg, split_lim)

tile_split = {}
param_list = [
Expand All @@ -133,8 +134,14 @@ def prepare_data(self):
"tile_catalog": [dict(zip(tile_split, i)) for i in zip(*tile_split.values())],
"images": split_image,
"background": split_bg,
"psf_params": [psf_params for _ in range(self.n_split**2)],
}

if self.use_deconv_channel:
file_path = self.deconv_path + "/" + str(image_files[0][n])[-15:-5] + ".pt"
deconv_images = torch.load(file_path)
data_split["deconvolution"] = split_full_image(deconv_images, split_lim)

data.extend([dict(zip(data_split, i)) for i in zip(*data_split.values())])

random.shuffle(data)
Expand Down Expand Up @@ -188,9 +195,7 @@ def from_file(cls, cat_path, wcs, height, width, band):
dec = torch.tensor(catalog["dec"].values)
galaxy_bools = torch.tensor((catalog["truth_type"] == 1).values)
star_bools = torch.tensor((catalog["truth_type"] == 2).values)
flux_list = []
for b in band:
flux_list.append(torch.tensor((catalog["flux_" + b]).values))
flux_list, psf_params = get_band(band, catalog)

flux = torch.stack(flux_list).t()

Expand Down Expand Up @@ -265,7 +270,7 @@ def from_file(cls, cat_path, wcs, height, width, band):
"star_log_fluxes": star_log_fluxes.reshape(1, nobj, 6),
}

return cls(height, width, d)
return cls(height, width, d), torch.stack(psf_params)


def read_frame_for_band(image_files, bg_files, n, n_bands, image_lim):
Expand All @@ -292,3 +297,24 @@ def read_frame_for_band(image_files, bg_files, n, n_bands, image_lim):
bg_list.append(bg)

return image_list, bg_list, wcs


def split_full_image(image, split_lim):
    """Cut a full frame into square sub-images of side `split_lim`.

    Args:
        image: 3d tensor (bands, height, width); height and width must be
            divisible by `split_lim` — TODO confirm with callers.
        split_lim: side length in pixels of each square sub-image.

    Returns:
        List of tensors, each stacking 6 (split_lim, split_lim) planes.
    """
    by_rows = torch.stack(image.split(split_lim, dim=1))
    by_tiles = torch.stack(by_rows.split(split_lim, dim=3))
    # Collapse (col-chunk, row-chunk, band) into one axis, then regroup in 6s.
    return list(by_tiles.flatten(0, 2).split(6))


def get_band(band, catalog):
    """Extract per-band fluxes and median PSF parameters from a DC2 catalog.

    Args:
        band: iterable of band names (e.g. ("g", "i", "r", "u", "y", "z")).
        catalog: pandas DataFrame with a "flux_<b>" column and the PSF
            moment/FWHM columns named below for every band b.

    Returns:
        Tuple (flux_list, psf_params): flux_list is a list of 1d tensors
        (one per band, one entry per catalog row); psf_params is a list of
        1d float32 tensors of length 4 holding the nanmedian of
        (IxxPSF, IyyPSF, IxyPSF, psf_fwhm) for each band.
    """
    # Loop-invariant column-name prefixes, hoisted out of the per-band loop.
    psf_param_names = ("IxxPSF_pixel_", "IyyPSF_pixel_", "IxyPSF_pixel_", "psf_fwhm_")

    flux_list = []
    psf_params = []
    for b in band:
        flux_list.append(torch.tensor(catalog["flux_" + b].values))
        # nanmedian ignores missing entries; cast to float32 up front.
        # (The stacked result is 1d, so the previous no-op `.t()` is dropped.)
        medians = [
            torch.tensor(np.nanmedian(catalog[name + b].values).astype(np.float32))
            for name in psf_param_names
        ]
        psf_params.append(torch.stack(medians))

    return flux_list, psf_params
Loading