diff --git a/.gitignore b/.gitignore index 02eefcbe6f..6e9fec5000 100644 --- a/.gitignore +++ b/.gitignore @@ -28,6 +28,8 @@ data/sdss/rcf_list data/sdss_all data/dc2/lsstdesc-public data/dc2/calibrated_exposures +data/dc2/coadd_deconv_image +data/dc2/merged_catalog case_studies/sdss_galaxies/models/simulated_blended_galaxies.pt case_studies/*/data/ venv diff --git a/bliss/catalog.py b/bliss/catalog.py index 73a21083c7..da482a8d7a 100644 --- a/bliss/catalog.py +++ b/bliss/catalog.py @@ -488,6 +488,10 @@ def to_tile_params( for ii in range(self.batch_size): n_sources = int(self.n_sources[ii].item()) for idx, coords in enumerate(tile_coords[ii][:n_sources]): + if coords[0] >= tile_n_sources.shape[1] or coords[1] >= tile_n_sources.shape[2]: + continue + # ignore sources outside of the image (usually caused by data augmentation - shift) + source_idx = tile_n_sources[ii, coords[0], coords[1]].item() if source_idx >= max_sources_per_tile: if not ignore_extra_sources: diff --git a/bliss/conf/base_config.yaml b/bliss/conf/base_config.yaml index 6f23796b8a..4d7c3bc709 100644 --- a/bliss/conf/base_config.yaml +++ b/bliss/conf/base_config.yaml @@ -84,6 +84,9 @@ encoder: concat_psf_params: false log_transform: false rolling_z_score: true + data_augmentation: + do_data_augmentation: true + epoch_start: 1 architecture: # this architecture is based on yolov5l.yaml, see # https://github.com/ultralytics/yolov5/blob/master/models/yolov5l.yaml @@ -227,7 +230,9 @@ surveys: dc2: _target_: bliss.surveys.dc2.DC2 data_dir: /nfs/turbo/lsa-regier/lsstdesc-public/dc2/run2.2i-dr6-v4/coadd-t3828-t3829/deepCoadd-results/ - cat_path: ${paths.dc2}/merged_catalog/merged_catalog_622.pkl - batch_size: 128 + cat_path: ${paths.dc2}/merged_catalog/merged_catalog_psf_100.pkl + batch_size: 64 n_split: 50 image_lim: [4000, 4000] + use_deconv_channel: ${encoder.input_transform_params.use_deconv_channel} + deconv_path: ${paths.dc2}/coadd_deconv_image diff --git a/bliss/data_augmentation.py 
b/bliss/data_augmentation.py new file mode 100644 index 0000000000..838be87402 --- /dev/null +++ b/bliss/data_augmentation.py @@ -0,0 +1,123 @@ +import random + +import numpy as np +import torchvision as TF + +from bliss.catalog import TileCatalog + + +def augment_data(tile_catalog, image, deconv_image=None): + """Randomly augment a batch with one to three distinct flips, rotations, or shifts. + + Returns the augmented image, the matching tile-catalog dict, and the augmented + deconvolved image (None when no deconvolved image is given). + """ + origin_tile = TileCatalog(4, tile_catalog) + origin_full = origin_tile.to_full_params() + + num_transform_list = [1, 2, 3] + num_transform = random.choices(num_transform_list, weights=[0.7, 0.2, 0.1], k=1)[0] + transform_list = ["vflip", "hflip", "rotate90", "rotate180", "rotate270", "shift"] + + aug_method = np.random.choice( + transform_list, num_transform, p=(0.1, 0.1, 0.1, 0.1, 0.1, 0.5), replace=False + ) + + # every aug_* helper edits origin_full["plocs"] in place, so chained transforms compose + aug_image, aug_tile, aug_deconv = image, tile_catalog, deconv_image + for i in aug_method: + if i == "vflip": + aug_image, aug_tile, aug_deconv = aug_vflip(origin_full, aug_image, aug_deconv) + if i == "hflip": + aug_image, aug_tile, aug_deconv = aug_hflip(origin_full, aug_image, aug_deconv) + if i == "rotate90": + aug_image, aug_tile, aug_deconv = aug_rotate90(origin_full, aug_image, aug_deconv) + if i == "rotate180": + aug_image, aug_tile, aug_deconv = aug_rotate180(origin_full, aug_image, aug_deconv) + if i == "rotate270": + aug_image, aug_tile, aug_deconv = aug_rotate270(origin_full, aug_image, aug_deconv) + if i == "shift": + aug_image, aug_tile, aug_deconv = aug_shift(origin_full, aug_image, aug_deconv) + + return aug_image, aug_tile, aug_deconv + + +def aug_vflip(origin_full, image, deconv_image): + aug_image = TF.transforms.functional.vflip(image) + aug_deconv = None + if deconv_image is not None: + aug_deconv = TF.transforms.functional.vflip(deconv_image) + image_size = image.size(2) + origin_full["plocs"][:, :, 0] = image_size - origin_full["plocs"][:, :, 0] - 1 + aug_tile = origin_full.to_tile_params(4, 4).get_brightest_source_per_tile().to_dict() + return aug_image, aug_tile, aug_deconv + + +def aug_hflip(origin_full, image, deconv_image): + aug_image = 
TF.transforms.functional.hflip(image) + aug_deconv = None + if deconv_image is not None: + aug_deconv = TF.transforms.functional.hflip(deconv_image) + image_size = image.size(2) + origin_full["plocs"][:, :, 1] = image_size - origin_full["plocs"][:, :, 1] - 1 + aug_tile = origin_full.to_tile_params(4, 4).get_brightest_source_per_tile().to_dict() + return aug_image, aug_tile, aug_deconv + + +def aug_rotate90(origin_full, image, deconv_image): + aug_image = TF.transforms.functional.rotate(image, 90) + aug_deconv = None + if deconv_image is not None: + aug_deconv = TF.transforms.functional.rotate(deconv_image, 90) + image_size = image.size(2) + plocs = origin_full["plocs"].clone() + origin_full["plocs"][:, :, 1] = plocs[:, :, 0] + origin_full["plocs"][:, :, 0] = image_size - plocs[:, :, 1] - 1 + aug_tile = origin_full.to_tile_params(4, 4).get_brightest_source_per_tile().to_dict() + return aug_image, aug_tile, aug_deconv + + +def aug_rotate180(origin_full, image, deconv_image): + aug_image = TF.transforms.functional.rotate(image, 180) + aug_deconv = None + if deconv_image is not None: + aug_deconv = TF.transforms.functional.rotate(deconv_image, 180) + image_size = image.size(2) + plocs = origin_full["plocs"].clone() + origin_full["plocs"][:, :, 1] = image_size - plocs[:, :, 1] - 1 + origin_full["plocs"][:, :, 0] = image_size - plocs[:, :, 0] - 1 + aug_tile = origin_full.to_tile_params(4, 4).get_brightest_source_per_tile().to_dict() + return aug_image, aug_tile, aug_deconv + + +def aug_rotate270(origin_full, image, deconv_image): + aug_image = TF.transforms.functional.rotate(image, 270) + aug_deconv = None + if deconv_image is not None: + aug_deconv = TF.transforms.functional.rotate(deconv_image, 270) + image_size = image.size(2) + plocs = origin_full["plocs"].clone() + origin_full["plocs"][:, :, 1] = image_size - plocs[:, :, 0] - 1 + origin_full["plocs"][:, :, 0] = plocs[:, :, 1] + aug_tile = origin_full.to_tile_params(4, 4).get_brightest_source_per_tile().to_dict() + 
return aug_image, aug_tile, aug_deconv + + +def aug_shift(origin_full, image, deconv_image): + shift_x = random.randint(0, 3) + shift_y = random.randint(0, 3) + shift_xy = (shift_x, shift_y) + image_size = image.size(2) + aug_deconv = None + pad_image = TF.transforms.functional.pad(image, shift_xy, padding_mode="reflect") + if deconv_image is not None: + pad_deconv = TF.transforms.functional.pad(deconv_image, shift_xy, padding_mode="reflect") + aug_deconv = pad_deconv[:, :, :image_size, :image_size] + + aug_image = pad_image[:, :, :image_size, :image_size] + plocs = origin_full["plocs"].clone() + origin_full["plocs"][:, :, 1] = plocs[:, :, 1] + shift_x + origin_full["plocs"][:, :, 0] = plocs[:, :, 0] + shift_y + + aug_tile = origin_full.to_tile_params(4, 4).get_brightest_source_per_tile().to_dict() + return aug_image, aug_tile, aug_deconv diff --git a/bliss/encoder.py b/bliss/encoder.py index 88a6d789a3..14fe63f3d1 100644 --- a/bliss/encoder.py +++ b/bliss/encoder.py @@ -15,6 +15,7 @@ from bliss.backbone import Backbone from bliss.catalog import FullCatalog, SourceType, TileCatalog +from bliss.data_augmentation import augment_data from bliss.metrics import BlissMetrics, MetricsMode from bliss.plotting import plot_detections from bliss.transforms import log_transform, rolling_z_score @@ -48,6 +49,7 @@ def __init__( optimizer_params: Optional[dict] = None, scheduler_params: Optional[dict] = None, input_transform_params: Optional[dict] = None, + data_augmentation: Optional[dict] = None, ): """Initializes DetectionEncoder. @@ -63,6 +65,8 @@ def __init__( scheduler_params: arguments passed to the learning rate scheduler input_transform_params: used for determining what channels to use as input (e.g. deconvolution, concatenate PSF parameters, z-score inputs, etc.) + data_augmentation: used for determining whether or not to do data augmentation and the + data augmentation start point (e.g.
after 1 epoch) """ super().__init__() self.save_hyperparameters() @@ -77,6 +81,7 @@ def __init__( self.optimizer_params = optimizer_params self.scheduler_params = scheduler_params if scheduler_params else {"milestones": []} self.input_transform_params = input_transform_params + self.data_augmentation = data_augmentation transform_enabled = ( "log_transform" in self.input_transform_params @@ -330,6 +335,23 @@ def _get_loss(self, pred: Dict[str, Distribution], true_tile_cat: TileCatalog): return loss_with_components def _generic_step(self, batch, logging_name, log_metrics=False, plot_images=False): + do_data_augmentation = self.data_augmentation.get("do_data_augmentation") + data_augmentation_start = self.current_epoch >= self.data_augmentation.get("epoch_start") + + if data_augmentation_start and (logging_name == "train") and do_data_augmentation: + deconv_image = None + if self.input_transform_params.get("use_deconv_channel"): + assert ( + "deconvolution" in batch + ), "use_deconv_channel specified but deconvolution not present in data" + deconv_image = batch["deconvolution"] + + image, tile, deconv = augment_data(batch["tile_catalog"], batch["images"], deconv_image) + batch["images"] = image + batch["tile_catalog"] = tile + if self.input_transform_params.get("use_deconv_channel"): + batch["deconvolution"] = deconv + batch_size = batch["images"].size(0) pred = self.encode_batch(batch) true_tile_cat = TileCatalog(self.tile_slen, batch["tile_catalog"]) diff --git a/bliss/surveys/dc2.py b/bliss/surveys/dc2.py index 13f9e13d0f..4e9cc47f2c 100644 --- a/bliss/surveys/dc2.py +++ b/bliss/surveys/dc2.py @@ -17,7 +17,9 @@ class DC2(Survey): BANDS = ("g", "i", "r", "u", "y", "z") - def __init__(self, data_dir, cat_path, batch_size, n_split, image_lim): + def __init__( + self, data_dir, cat_path, batch_size, n_split, image_lim, use_deconv_channel, deconv_path + ): super().__init__() self.data_dir = data_dir self.cat_path = cat_path @@ -30,6 +32,8 @@ def __init__(self, 
data_dir, cat_path, batch_size, n_split, image_lim): self.test = [] self.n_split = n_split self.image_lim = image_lim + self.use_deconv_channel = use_deconv_channel + self.deconv_path = deconv_path self._predict_batch = None @@ -49,8 +53,8 @@ def image_ids(self): return [self.dc2_data[i]["images"] for i in range(len(self))] def prepare_data(self): - img_pattern = "3828/*/calexp*.fits" - bg_pattern = "3828/*/bkgd*.fits" + img_pattern = "**/*/calexp*.fits" + bg_pattern = "**/*/bkgd*.fits" image_files = [] bg_files = [] @@ -75,8 +79,10 @@ def prepare_data(self): plocs_lim = image[0].shape height = plocs_lim[0] width = plocs_lim[1] - full_cat = Dc2FullCatalog.from_file(self.cat_path, wcs, height, width, self.bands) - tile_cat = full_cat.to_tile_params(4, 1, 1) + full_cat, psf_params = Dc2FullCatalog.from_file( + self.cat_path, wcs, height, width, self.bands + ) + tile_cat = full_cat.to_tile_params(4, 5).get_brightest_source_per_tile() tile_dict = tile_cat.to_dict() tile_dict["locs"] = rearrange(tile_cat.to_dict()["locs"], "1 h w nh nw -> h w nh nw") @@ -105,14 +111,9 @@ def prepare_data(self): # split image split_lim = self.image_lim[0] // self.n_split image = torch.from_numpy(image) - split1_image = torch.stack(torch.split(image, split_lim, dim=1)) - split2_image = torch.stack(torch.split(split1_image, split_lim, dim=3)) - split_image = list(torch.split(split2_image.flatten(0, 2), 6)) - + split_image = split_full_image(image, split_lim) bg = torch.from_numpy(bg) - split1_bg = torch.stack(torch.split(bg, split_lim, dim=1)) - split2_bg = torch.stack(torch.split(split1_bg, split_lim, dim=3)) - split_bg = list(torch.split(split2_bg.flatten(0, 2), 6)) + split_bg = split_full_image(bg, split_lim) tile_split = {} param_list = [ @@ -133,8 +134,14 @@ def prepare_data(self): "tile_catalog": [dict(zip(tile_split, i)) for i in zip(*tile_split.values())], "images": split_image, "background": split_bg, + "psf_params": [psf_params for _ in range(self.n_split**2)], } + if 
self.use_deconv_channel: + file_path = self.deconv_path + "/" + str(image_files[0][n])[-15:-5] + ".pt" + deconv_images = torch.load(file_path) + data_split["deconvolution"] = split_full_image(deconv_images, split_lim) + data.extend([dict(zip(data_split, i)) for i in zip(*data_split.values())]) random.shuffle(data) @@ -188,9 +195,7 @@ def from_file(cls, cat_path, wcs, height, width, band): dec = torch.tensor(catalog["dec"].values) galaxy_bools = torch.tensor((catalog["truth_type"] == 1).values) star_bools = torch.tensor((catalog["truth_type"] == 2).values) - flux_list = [] - for b in band: - flux_list.append(torch.tensor((catalog["flux_" + b]).values)) + flux_list, psf_params = get_band(band, catalog) flux = torch.stack(flux_list).t() @@ -265,7 +270,7 @@ def from_file(cls, cat_path, wcs, height, width, band): "star_log_fluxes": star_log_fluxes.reshape(1, nobj, 6), } - return cls(height, width, d) + return cls(height, width, d), torch.stack(psf_params) def read_frame_for_band(image_files, bg_files, n, n_bands, image_lim): @@ -292,3 +297,24 @@ def read_frame_for_band(image_files, bg_files, n, n_bands, image_lim): bg_list.append(bg) return image_list, bg_list, wcs + + +def split_full_image(image, split_lim): + split1_image = torch.stack(torch.split(image, split_lim, dim=1)) + split2_image = torch.stack(torch.split(split1_image, split_lim, dim=3)) + return list(torch.split(split2_image.flatten(0, 2), 6)) + + +def get_band(band, catalog): + flux_list = [] + psf_params = [] + for b in band: + flux_list.append(torch.tensor((catalog["flux_" + b]).values)) + psf_params_name = ["IxxPSF_pixel_", "IyyPSF_pixel_", "IxyPSF_pixel_", "psf_fwhm_"] + psf_params_band = [] + for i in psf_params_name: + median_psf = np.nanmedian((catalog[i + b]).values).astype(np.float32) + psf_params_band.append(torch.tensor(median_psf)) + psf_params.append(torch.stack(psf_params_band).t()) + + return flux_list, psf_params diff --git a/data/tests/dc2/bkgd.fits b/data/tests/dc2/bkgd.fits deleted file 
mode 100644 index 50c153e8a4..0000000000 --- a/data/tests/dc2/bkgd.fits +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:d972241bbf7da0bf2bb06edee4735e72d82ab13a479f636a90c86b074e3022c7 -size 5760 diff --git a/data/tests/dc2/calexp.fits b/data/tests/dc2/calexp.fits deleted file mode 100644 index 0861b21b05..0000000000 --- a/data/tests/dc2/calexp.fits +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:ad9b281c0300b1d82cd852066894d4ee364989ceff56655d4550b0f33a548fd3 -size 645120 diff --git a/data/tests/dc2/coadd_deconv_image/g-3828-0,0.pt b/data/tests/dc2/coadd_deconv_image/g-3828-0,0.pt new file mode 100644 index 0000000000..a9f5afc803 --- /dev/null +++ b/data/tests/dc2/coadd_deconv_image/g-3828-0,0.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:119fa5997c16cc67332fbd134d7e74973d4c105821358f290f7cf28f306d62fc +size 384000756 diff --git a/tests/test_dc2.py b/tests/test_dc2.py index ef29e504b0..674a718d08 100644 --- a/tests/test_dc2.py +++ b/tests/test_dc2.py @@ -1,6 +1,9 @@ import torch +from einops import rearrange from hydra.utils import instantiate +from bliss import data_augmentation +from bliss.catalog import TileCatalog from bliss.train import train @@ -38,7 +41,7 @@ def test_dc2(self, cfg): for k in params: assert isinstance(dc2_tile[k], torch.Tensor) - for i in ("images", "background"): + for i in ("images", "background", "psf_params"): assert isinstance(dc2_obj[i], torch.Tensor) def test_train_on_dc2(self, cfg): @@ -46,5 +49,51 @@ def test_train_on_dc2(self, cfg): train_dc2_cfg.encoder.bands = [0, 1, 2, 3, 4, 5] train_dc2_cfg.encoder.survey_bands = ["g", "i", "r", "u", "y", "z"] train_dc2_cfg.training.data_source = train_dc2_cfg.surveys.dc2 + train_dc2_cfg.encoder.input_transform_params.use_deconv_channel = True + train_dc2_cfg.encoder.data_augmentation.epoch_start = 0 train_dc2_cfg.training.pretrained_weights = None train(train_dc2_cfg) + + def 
test_dc2_augmentation(self, cfg): + train_dc2_cfg = cfg.copy() + train_dc2_cfg.encoder.input_transform_params.use_deconv_channel = True + + dataset = instantiate(train_dc2_cfg.surveys.dc2) + dataset.prepare_data() + dc2_obj = dataset.dc2_data[0] + + tile_dict = {} + dc2_tile = dc2_obj["tile_catalog"] + tile_dict["locs"] = rearrange(dc2_tile["locs"], "h w nh nw -> 1 h w nh nw") + tile_dict["n_sources"] = rearrange(dc2_tile["n_sources"], "h w -> 1 h w") + tile_dict["source_type"] = rearrange(dc2_tile["source_type"], "h w nh nw -> 1 h w nh nw") + tile_dict["galaxy_fluxes"] = rearrange( + dc2_tile["galaxy_fluxes"], "h w nh nw -> 1 h w nh nw" + ) + tile_dict["galaxy_params"] = rearrange( + dc2_tile["galaxy_params"], "h w nh nw -> 1 h w nh nw" + ) + tile_dict["star_fluxes"] = rearrange(dc2_tile["star_fluxes"], "h w nh nw -> 1 h w nh nw") + tile_dict["star_log_fluxes"] = rearrange( + dc2_tile["star_log_fluxes"], "h w nh nw -> 1 h w nh nw" + ) + origin_tile = TileCatalog(4, tile_dict) + origin_full = origin_tile.to_full_params() + + image = rearrange(dc2_obj["images"], "b h w -> 1 b h w") + deconv_image = rearrange(dc2_obj["deconvolution"], "b h w -> 1 b h w") + + aug_list = [ + data_augmentation.aug_vflip, + data_augmentation.aug_hflip, + data_augmentation.aug_rotate90, + data_augmentation.aug_rotate180, + data_augmentation.aug_rotate270, + data_augmentation.aug_shift, + ] + + for i in aug_list: + aug_image, aug_tile, aug_deconv = i(origin_full, image, deconv_image) + assert aug_image.shape == image.shape + assert aug_deconv.shape == deconv_image.shape + assert aug_tile["n_sources"].sum() <= origin_full.n_sources diff --git a/tests/testing_config.yaml b/tests/testing_config.yaml index 754b201874..a1afe7fcff 100644 --- a/tests/testing_config.yaml +++ b/tests/testing_config.yaml @@ -58,6 +58,7 @@ surveys: fields: [12] dc2: data_dir: ${paths.data}/tests/dc2/dc2_multiband/ - cat_path: ${paths.dc2}/merged_catalog/merged_catalog_100000.pkl + cat_path: 
${paths.dc2}/merged_catalog/merged_catalog_psf_100000.pkl n_split: 5 image_lim: [800, 800] + deconv_path: ${paths.data}/tests/dc2/coadd_deconv_image