Skip to content

Commit

Permalink
add data augmentation and dc2 psf/deconv
Browse files Browse the repository at this point in the history
  • Loading branch information
Xinyue Li committed Aug 23, 2023
1 parent 84afd3a commit 106a07a
Show file tree
Hide file tree
Showing 11 changed files with 249 additions and 26 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ data/sdss/rcf_list
data/sdss_all
data/dc2/lsstdesc-public
data/dc2/calibrated_exposures
data/dc2/coadd_deconv_image
data/dc2/merged_catalog
case_studies/sdss_galaxies/models/simulated_blended_galaxies.pt
case_studies/*/data/
venv
Expand Down
4 changes: 4 additions & 0 deletions bliss/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,10 @@ def to_tile_params(
for ii in range(self.batch_size):
n_sources = int(self.n_sources[ii].item())
for idx, coords in enumerate(tile_coords[ii][:n_sources]):
if coords[0] >= tile_n_sources.shape[1] or coords[1] >= tile_n_sources.shape[2]:
continue
# ignore sources outside of the image (usually caused by data augmentation - shift)

source_idx = tile_n_sources[ii, coords[0], coords[1]].item()
if source_idx >= max_sources_per_tile:
if not ignore_extra_sources:
Expand Down
9 changes: 7 additions & 2 deletions bliss/conf/base_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ encoder:
concat_psf_params: false
log_transform: false
rolling_z_score: true
data_augmentation:
do_data_augmentation: true
epoch_start: 1
architecture:
# this architecture is based on yolov5l.yaml, see
# https://github.com/ultralytics/yolov5/blob/master/models/yolov5l.yaml
Expand Down Expand Up @@ -227,7 +230,9 @@ surveys:
dc2:
_target_: bliss.surveys.dc2.DC2
data_dir: /nfs/turbo/lsa-regier/lsstdesc-public/dc2/run2.2i-dr6-v4/coadd-t3828-t3829/deepCoadd-results/
cat_path: ${paths.dc2}/merged_catalog/merged_catalog_622.pkl
batch_size: 128
cat_path: ${paths.dc2}/merged_catalog/merged_catalog_psf_100.pkl
batch_size: 64
n_split: 50
image_lim: [4000, 4000]
use_deconv_channel: ${encoder.input_transform_params.use_deconv_channel}
deconv_path: ${paths.dc2}/coadd_deconv_image
117 changes: 117 additions & 0 deletions bliss/data_augmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import random

import numpy as np
import torchvision as TF

from bliss.catalog import TileCatalog


def augment_data(tile_catalog, image, deconv_image=None):
origin_tile = TileCatalog(4, tile_catalog)
origin_full = origin_tile.to_full_params()

num_transform_list = [1, 2, 3]
num_transform = random.choices(num_transform_list, weights=[0.7, 0.2, 0.1], k=1)[0]
transform_list = ["vflip", "hflip", "rotate90", "rotate180", "rotate270", "shift"]

aug_method = np.random.choice(
transform_list, num_transform, p=(0.1, 0.1, 0.1, 0.1, 0.1, 0.5), replace=False
)

aug_image, aug_tile, aug_deconv = origin_full, image, deconv_image
for i in aug_method:
if i == "vflip":
aug_image, aug_tile, aug_deconv = aug_vflip(aug_image, aug_tile, aug_deconv)
if i == "hflip":
aug_image, aug_tile, aug_deconv = aug_hflip(aug_image, aug_tile, aug_deconv)
if i == "rotate90":
aug_image, aug_tile, aug_deconv = aug_rotate90(aug_image, aug_tile, aug_deconv)
if i == "rotate180":
aug_image, aug_tile, aug_deconv = aug_rotate180(aug_image, aug_tile, aug_deconv)
if i == "rotate270":
aug_image, aug_tile, aug_deconv = aug_rotate270(aug_image, aug_tile, aug_deconv)
if i == "shift":
aug_image, aug_tile, aug_deconv = aug_shift(aug_image, aug_tile, aug_deconv)

return aug_image, aug_tile, aug_deconv


def aug_vflip(origin_full, image, deconv_image):
aug_image = TF.transforms.functional.vflip(image)
aug_deconv = None
if deconv_image is not None:
aug_deconv = TF.transforms.functional.vflip(deconv_image)
image_size = image.size(2)
origin_full["plocs"][:, :, 0] = image_size - origin_full["plocs"][:, :, 0] - 1
aug_tile = origin_full.to_tile_params(4, 4).get_brightest_source_per_tile().to_dict()
return aug_image, aug_tile, aug_deconv


def aug_hflip(origin_full, image, deconv_image):
aug_image = TF.transforms.functional.hflip(image)
aug_deconv = None
if deconv_image is not None:
aug_deconv = TF.transforms.functional.hflip(deconv_image)
image_size = image.size(2)
origin_full["plocs"][:, :, 1] = image_size - origin_full["plocs"][:, :, 1] - 1
aug_tile = origin_full.to_tile_params(4, 4).get_brightest_source_per_tile().to_dict()
return aug_image, aug_tile, aug_deconv


def aug_rotate90(origin_full, image, deconv_image):
aug_image = TF.transforms.functional.rotate(image, 90)
aug_deconv = None
if deconv_image is not None:
aug_deconv = TF.transforms.functional.rotate(deconv_image, 90)
image_size = image.size(2)
plocs = origin_full["plocs"].clone()
origin_full["plocs"][:, :, 1] = plocs[:, :, 0]
origin_full["plocs"][:, :, 0] = image_size - plocs[:, :, 1] - 1
aug_tile = origin_full.to_tile_params(4, 4).get_brightest_source_per_tile().to_dict()
return aug_image, aug_tile, aug_deconv


def aug_rotate180(origin_full, image, deconv_image):
aug_image = TF.transforms.functional.rotate(image, 180)
aug_deconv = None
if deconv_image is not None:
aug_deconv = TF.transforms.functional.rotate(deconv_image, 180)
image_size = image.size(2)
plocs = origin_full["plocs"].clone()
origin_full["plocs"][:, :, 1] = image_size - plocs[:, :, 1] - 1
origin_full["plocs"][:, :, 0] = image_size - plocs[:, :, 0] - 1
aug_tile = origin_full.to_tile_params(4, 4).get_brightest_source_per_tile().to_dict()
return aug_image, aug_tile, aug_deconv


def aug_rotate270(origin_full, image, deconv_image):
aug_image = TF.transforms.functional.rotate(image, 270)
aug_deconv = None
if deconv_image is not None:
aug_deconv = TF.transforms.functional.rotate(deconv_image, 270)
image_size = image.size(2)
plocs = origin_full["plocs"].clone()
origin_full["plocs"][:, :, 1] = image_size - plocs[:, :, 0] - 1
origin_full["plocs"][:, :, 0] = plocs[:, :, 1]
aug_tile = origin_full.to_tile_params(4, 4).get_brightest_source_per_tile().to_dict()
return aug_image, aug_tile, aug_deconv


def aug_shift(origin_full, image, deconv_image):
shift_x = random.randint(0, 3)
shift_y = random.randint(0, 3)
shift_xy = (shift_x, shift_y)
image_size = image.size(2)
aug_deconv = None
pad_image = TF.transforms.functional.pad(image, shift_xy, padding_mode="reflect")
if deconv_image is not None:
pad_deconv = TF.transforms.functional.pad(deconv_image, shift_xy, padding_mode="reflect")
aug_deconv = pad_deconv[:, :, :image_size, :image_size]

aug_image = pad_image[:, :, :image_size, :image_size]
plocs = origin_full["plocs"].clone()
origin_full["plocs"][:, :, 1] = plocs[:, :, 1] + shift_x
origin_full["plocs"][:, :, 0] = plocs[:, :, 0] + shift_y

aug_tile = origin_full.to_tile_params(4, 4).get_brightest_source_per_tile().to_dict()
return aug_image, aug_tile, aug_deconv
22 changes: 22 additions & 0 deletions bliss/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from bliss.backbone import Backbone
from bliss.catalog import FullCatalog, SourceType, TileCatalog
from bliss.data_augmentation import augment_data
from bliss.metrics import BlissMetrics, MetricsMode
from bliss.plotting import plot_detections
from bliss.transforms import log_transform, rolling_z_score
Expand Down Expand Up @@ -48,6 +49,7 @@ def __init__(
optimizer_params: Optional[dict] = None,
scheduler_params: Optional[dict] = None,
input_transform_params: Optional[dict] = None,
data_augmentation: Optional[dict] = None,
):
"""Initializes DetectionEncoder.
Expand All @@ -63,6 +65,8 @@ def __init__(
scheduler_params: arguments passed to the learning rate scheduler
input_transform_params: used for determining what channels to use as input (e.g.
deconvolution, concatenate PSF parameters, z-score inputs, etc.)
data_augmentation: used for determining whether or not do data augmentation and the
data augmentation start point (a.g. after 1 epoch)
"""
super().__init__()
self.save_hyperparameters()
Expand All @@ -77,6 +81,7 @@ def __init__(
self.optimizer_params = optimizer_params
self.scheduler_params = scheduler_params if scheduler_params else {"milestones": []}
self.input_transform_params = input_transform_params
self.data_augmentation = data_augmentation

transform_enabled = (
"log_transform" in self.input_transform_params
Expand Down Expand Up @@ -330,6 +335,23 @@ def _get_loss(self, pred: Dict[str, Distribution], true_tile_cat: TileCatalog):
return loss_with_components

def _generic_step(self, batch, logging_name, log_metrics=False, plot_images=False):
do_data_augmentation = self.data_augmentation.get("do_data_augmentation")
data_augmentation_start = self.current_epoch >= self.data_augmentation.get("epoch_start")

if data_augmentation_start and (logging_name == "train") and do_data_augmentation:
deconv_image = None
if self.input_transform_params.get("use_deconv_channel"):
assert (
"deconvolution" in batch
), "use_deconv_channel specified but deconvolution not present in data"
deconv_image = batch["deconvolution"]

image, tile, deconv = augment_data(batch["tile_catalog"], batch["images"], deconv_image)
batch["images"] = image
batch["tile_catalog"] = tile
if self.input_transform_params.get("use_deconv_channel"):
batch["deconvolution"] = deconv

batch_size = batch["images"].size(0)
pred = self.encode_batch(batch)
true_tile_cat = TileCatalog(self.tile_slen, batch["tile_catalog"])
Expand Down
58 changes: 42 additions & 16 deletions bliss/surveys/dc2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
class DC2(Survey):
BANDS = ("g", "i", "r", "u", "y", "z")

def __init__(self, data_dir, cat_path, batch_size, n_split, image_lim):
def __init__(
self, data_dir, cat_path, batch_size, n_split, image_lim, use_deconv_channel, deconv_path
):
super().__init__()
self.data_dir = data_dir
self.cat_path = cat_path
Expand All @@ -30,6 +32,8 @@ def __init__(self, data_dir, cat_path, batch_size, n_split, image_lim):
self.test = []
self.n_split = n_split
self.image_lim = image_lim
self.use_deconv_channel = use_deconv_channel
self.deconv_path = deconv_path

self._predict_batch = None

Expand All @@ -49,8 +53,8 @@ def image_ids(self):
return [self.dc2_data[i]["images"] for i in range(len(self))]

def prepare_data(self):
img_pattern = "3828/*/calexp*.fits"
bg_pattern = "3828/*/bkgd*.fits"
img_pattern = "**/*/calexp*.fits"
bg_pattern = "**/*/bkgd*.fits"
image_files = []
bg_files = []

Expand All @@ -75,8 +79,10 @@ def prepare_data(self):
plocs_lim = image[0].shape
height = plocs_lim[0]
width = plocs_lim[1]
full_cat = Dc2FullCatalog.from_file(self.cat_path, wcs, height, width, self.bands)
tile_cat = full_cat.to_tile_params(4, 1, 1)
full_cat, psf_params = Dc2FullCatalog.from_file(
self.cat_path, wcs, height, width, self.bands
)
tile_cat = full_cat.to_tile_params(4, 5).get_brightest_source_per_tile()
tile_dict = tile_cat.to_dict()

tile_dict["locs"] = rearrange(tile_cat.to_dict()["locs"], "1 h w nh nw -> h w nh nw")
Expand Down Expand Up @@ -105,14 +111,9 @@ def prepare_data(self):
# split image
split_lim = self.image_lim[0] // self.n_split
image = torch.from_numpy(image)
split1_image = torch.stack(torch.split(image, split_lim, dim=1))
split2_image = torch.stack(torch.split(split1_image, split_lim, dim=3))
split_image = list(torch.split(split2_image.flatten(0, 2), 6))

split_image = split_full_image(image, split_lim)
bg = torch.from_numpy(bg)
split1_bg = torch.stack(torch.split(bg, split_lim, dim=1))
split2_bg = torch.stack(torch.split(split1_bg, split_lim, dim=3))
split_bg = list(torch.split(split2_bg.flatten(0, 2), 6))
split_bg = split_full_image(bg, split_lim)

tile_split = {}
param_list = [
Expand All @@ -133,8 +134,14 @@ def prepare_data(self):
"tile_catalog": [dict(zip(tile_split, i)) for i in zip(*tile_split.values())],
"images": split_image,
"background": split_bg,
"psf_params": [psf_params for _ in range(self.n_split**2)],
}

if self.use_deconv_channel:
file_path = self.deconv_path + "/" + str(image_files[0][n])[-15:-5] + ".pt"
deconv_images = torch.load(file_path)
data_split["deconvolution"] = split_full_image(deconv_images, split_lim)

data.extend([dict(zip(data_split, i)) for i in zip(*data_split.values())])

random.shuffle(data)
Expand Down Expand Up @@ -188,9 +195,7 @@ def from_file(cls, cat_path, wcs, height, width, band):
dec = torch.tensor(catalog["dec"].values)
galaxy_bools = torch.tensor((catalog["truth_type"] == 1).values)
star_bools = torch.tensor((catalog["truth_type"] == 2).values)
flux_list = []
for b in band:
flux_list.append(torch.tensor((catalog["flux_" + b]).values))
flux_list, psf_params = get_band(band, catalog)

flux = torch.stack(flux_list).t()

Expand Down Expand Up @@ -265,7 +270,7 @@ def from_file(cls, cat_path, wcs, height, width, band):
"star_log_fluxes": star_log_fluxes.reshape(1, nobj, 6),
}

return cls(height, width, d)
return cls(height, width, d), torch.stack(psf_params)


def read_frame_for_band(image_files, bg_files, n, n_bands, image_lim):
Expand All @@ -292,3 +297,24 @@ def read_frame_for_band(image_files, bg_files, n, n_bands, image_lim):
bg_list.append(bg)

return image_list, bg_list, wcs


def split_full_image(image, split_lim):
split1_image = torch.stack(torch.split(image, split_lim, dim=1))
split2_image = torch.stack(torch.split(split1_image, split_lim, dim=3))
return list(torch.split(split2_image.flatten(0, 2), 6))


def get_band(band, catalog):
flux_list = []
psf_params = []
for b in band:
flux_list.append(torch.tensor((catalog["flux_" + b]).values))
psf_params_name = ["IxxPSF_pixel_", "IyyPSF_pixel_", "IxyPSF_pixel_", "psf_fwhm_"]
psf_params_band = []
for i in psf_params_name:
median_psf = np.nanmedian((catalog[i + b]).values).astype(np.float32)
psf_params_band.append(torch.tensor(median_psf))
psf_params.append(torch.stack(psf_params_band).t())

return flux_list, psf_params
3 changes: 0 additions & 3 deletions data/tests/dc2/bkgd.fits

This file was deleted.

3 changes: 0 additions & 3 deletions data/tests/dc2/calexp.fits

This file was deleted.

3 changes: 3 additions & 0 deletions data/tests/dc2/coadd_deconv_image/g-3828-0,0.pt
Git LFS file not shown
Loading

0 comments on commit 106a07a

Please sign in to comment.