Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add data augmentation and dc2 psf/deconv #932

Merged
merged 4 commits into from
Aug 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ data/sdss/rcf_list
data/sdss_all
data/dc2/lsstdesc-public
data/dc2/calibrated_exposures
data/dc2/coadd_deconv_image
data/dc2/merged_catalog
case_studies/sdss_galaxies/models/simulated_blended_galaxies.pt
case_studies/*/data/
venv
Expand Down
4 changes: 4 additions & 0 deletions bliss/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,10 @@
for ii in range(self.batch_size):
n_sources = int(self.n_sources[ii].item())
for idx, coords in enumerate(tile_coords[ii][:n_sources]):
if coords[0] >= tile_n_sources.shape[1] or coords[1] >= tile_n_sources.shape[2]:
continue

Check warning on line 492 in bliss/catalog.py

View check run for this annotation

Codecov / codecov/patch

bliss/catalog.py#L492

Added line #L492 was not covered by tests
# ignore sources outside of the image (usually caused by data augmentation - shift)
jeff-regier marked this conversation as resolved.
Show resolved Hide resolved

source_idx = tile_n_sources[ii, coords[0], coords[1]].item()
if source_idx >= max_sources_per_tile:
if not ignore_extra_sources:
Expand Down
7 changes: 5 additions & 2 deletions bliss/conf/base_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ encoder:
concat_psf_params: false
log_transform: false
rolling_z_score: true
do_data_augmentation: false
architecture:
# this architecture is based on yolov5l.yaml, see
# https://github.com/ultralytics/yolov5/blob/master/models/yolov5l.yaml
Expand Down Expand Up @@ -227,7 +228,9 @@ surveys:
dc2:
_target_: bliss.surveys.dc2.DC2
data_dir: /nfs/turbo/lsa-regier/lsstdesc-public/dc2/run2.2i-dr6-v4/coadd-t3828-t3829/deepCoadd-results/
cat_path: ${paths.dc2}/merged_catalog/merged_catalog_622.pkl
batch_size: 128
cat_path: ${paths.dc2}/merged_catalog/merged_catalog_psf_100.pkl
jeff-regier marked this conversation as resolved.
Show resolved Hide resolved
batch_size: 64
n_split: 50
image_lim: [4000, 4000]
use_deconv_channel: ${encoder.input_transform_params.use_deconv_channel}
deconv_path: ${paths.dc2}/coadd_deconv_image
86 changes: 86 additions & 0 deletions bliss/data_augmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import random

from einops import rearrange
from torchvision.transforms import functional as TF

from bliss.catalog import TileCatalog


def augment_data(tile_catalog, image, tile_slen=4, max_sources_per_tile=4):
    """Apply a random rotation/flip and a small random shift to a training batch.

    Args:
        tile_catalog: dict of tile-catalog tensors (``TileCatalog.to_dict`` form).
        image: 5d tensor of shape (batch, n_channels, n_bands, h, w) holding the
            stacked images/backgrounds (and optionally deconvolved images).
        tile_slen: side length of a tile in pixels (default 4, matching the
            previous hard-coded value).
        max_sources_per_tile: max sources per tile used when re-tiling the
            augmented catalog (default 4, matching the previous hard-coded value).

    Returns:
        Tuple of (augmented image tensor, augmented tile-catalog dict).
    """
    origin_tile = TileCatalog(tile_slen, tile_catalog)
    aug_full = origin_tile.to_full_params()
    aug_image = image

    # each transform is chosen independently; None means "skip this transform"
    rotate_choice = random.choice([None, aug_rotate90, aug_rotate180, aug_rotate270])
    flip_choice = random.choice([None, aug_vflip])

    if rotate_choice is not None:
        aug_image, aug_full = rotate_choice(aug_full, aug_image)
    if flip_choice is not None:
        aug_image, aug_full = flip_choice(aug_full, aug_image)

    # a shift is always applied (possibly by zero pixels)
    aug_image, aug_full = aug_shift(aug_full, aug_image)

    aug_tile = (
        aug_full.to_tile_params(tile_slen, max_sources_per_tile)
        .get_brightest_source_per_tile()
        .to_dict()
    )
    return aug_image, aug_tile


def aug_vflip(origin_full, image):
    """Flip the image top-to-bottom and mirror the source row coordinates."""
    flipped = TF.vflip(image)
    size = image.size(3)
    # row index r maps to size - r - 1 under a vertical flip
    origin_full["plocs"][:, :, 0] = size - origin_full["plocs"][:, :, 0] - 1
    return flipped, origin_full


def aug_rotate90(origin_full, image):
    """Rotate the image by 90 degrees and remap the source coordinates."""
    rotated = rotate_images(image, 90)
    size = image.size(3)
    old_plocs = origin_full["plocs"].clone()
    # (row, col) -> (size - col - 1, row) under this rotation
    origin_full["plocs"][:, :, 0] = size - old_plocs[:, :, 1] - 1
    origin_full["plocs"][:, :, 1] = old_plocs[:, :, 0]
    return rotated, origin_full


def aug_rotate180(origin_full, image):
    """Rotate the image by 180 degrees and remap the source coordinates."""
    rotated = rotate_images(image, 180)
    size = image.size(3)
    old_plocs = origin_full["plocs"].clone()
    # (row, col) -> (size - row - 1, size - col - 1) under this rotation
    origin_full["plocs"][:, :, 0] = size - old_plocs[:, :, 0] - 1
    origin_full["plocs"][:, :, 1] = size - old_plocs[:, :, 1] - 1
    return rotated, origin_full


def aug_rotate270(origin_full, image):
    """Rotate the image by 270 degrees and remap the source coordinates."""
    rotated = rotate_images(image, 270)
    size = image.size(3)
    old_plocs = origin_full["plocs"].clone()
    # (row, col) -> (col, size - row - 1) under this rotation
    origin_full["plocs"][:, :, 0] = old_plocs[:, :, 1]
    origin_full["plocs"][:, :, 1] = size - old_plocs[:, :, 0] - 1
    return rotated, origin_full


def rotate_images(image, degree):
    """Rotate a 5d (batch, depth, channel, h, w) tensor by ``degree`` degrees.

    The first two dims are merged so torchvision's rotate (which expects a 4d
    batch) can be applied, then restored afterwards.
    """
    batches = image.size(0)
    flat = image.flatten(0, 1)
    turned = TF.rotate(flat, degree)
    return turned.unflatten(0, (batches, -1))


def aug_shift(origin_full, image):
    """Shift the image by a random per-axis offset, reflect-padding the border.

    The source locations in ``origin_full["plocs"]`` are shifted by the same
    offsets; callers are expected to drop sources pushed outside the image.
    """
    # NOTE(review): randint(-1, 2) is the asymmetric range [-1, 2] — confirm intended
    d_row = random.randint(-1, 2)
    d_col = random.randint(-1, 2)
    size = image.size(3)

    batches = image.size(0)
    flat = rearrange(image, "bt bd ch h w -> (bt bd) ch h w")
    padded_flat = TF.pad(flat, (2, 2), padding_mode="reflect")
    padded = rearrange(padded_flat, "(bt bd) ch h w -> bt bd ch h w", bt=batches)
    top = 2 - d_row
    left = 2 - d_col
    shifted = padded[:, :, :, top : top + size, left : left + size]

    origin_full["plocs"][:, :, 0] = origin_full["plocs"][:, :, 0] + d_row
    origin_full["plocs"][:, :, 1] = origin_full["plocs"][:, :, 1] + d_col

    return shifted, origin_full
22 changes: 22 additions & 0 deletions bliss/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from bliss.backbone import Backbone
from bliss.catalog import FullCatalog, SourceType, TileCatalog
from bliss.data_augmentation import augment_data
from bliss.metrics import BlissMetrics, MetricsMode
from bliss.plotting import plot_detections
from bliss.transforms import log_transform, rolling_z_score
Expand Down Expand Up @@ -48,6 +49,7 @@ def __init__(
optimizer_params: Optional[dict] = None,
scheduler_params: Optional[dict] = None,
input_transform_params: Optional[dict] = None,
do_data_augmentation: bool = False,
):
"""Initializes DetectionEncoder.

Expand All @@ -63,6 +65,7 @@ def __init__(
scheduler_params: arguments passed to the learning rate scheduler
input_transform_params: used for determining what channels to use as input (e.g.
deconvolution, concatenate PSF parameters, z-score inputs, etc.)
do_data_augmentation: used for determining whether or not to do data augmentation
"""
super().__init__()
self.save_hyperparameters()
Expand All @@ -77,6 +80,7 @@ def __init__(
self.optimizer_params = optimizer_params
self.scheduler_params = scheduler_params if scheduler_params else {"milestones": []}
self.input_transform_params = input_transform_params
self.do_data_augmentation = do_data_augmentation

transform_enabled = (
"log_transform" in self.input_transform_params
Expand Down Expand Up @@ -385,6 +389,24 @@ def _generic_step(self, batch, logging_name, log_metrics=False, plot_images=Fals

def training_step(self, batch, batch_idx, optimizer_idx=0):
    """Training step (pytorch lightning).

    When ``do_data_augmentation`` is set, applies a joint random
    rotation/flip/shift to the images, backgrounds, optional deconvolved
    images, and the tile catalog before running the generic step.
    """
    if self.do_data_augmentation:
        imgs = batch["images"][:, self.bands].unsqueeze(2)  # add extra dim for 5d input
        bgs = batch["background"][:, self.bands].unsqueeze(2)
        aug_input_images = [imgs, bgs]
        if self.input_transform_params.get("use_deconv_channel"):
            assert (
                "deconvolution" in batch
            ), "use_deconv_channel specified but deconvolution not present in data"
            # BUGFIX: augment the deconvolved images, not a second copy of the
            # background — channel 2 is read back into batch["deconvolution"] below
            aug_input_images.append(batch["deconvolution"][:, self.bands].unsqueeze(2))
        aug_input_image = torch.cat(aug_input_images, dim=2)

        aug_output_image, tile = augment_data(batch["tile_catalog"], aug_input_image)
        batch["images"] = aug_output_image[:, :, 0, :, :]
        batch["background"] = aug_output_image[:, :, 1, :, :]
        batch["tile_catalog"] = tile
        if self.input_transform_params.get("use_deconv_channel"):
            batch["deconvolution"] = aug_output_image[:, :, 2, :, :]

    return self._generic_step(batch, "train")

def validation_step(self, batch, batch_idx):
Expand Down
58 changes: 42 additions & 16 deletions bliss/surveys/dc2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
class DC2(Survey):
BANDS = ("g", "i", "r", "u", "y", "z")

def __init__(self, data_dir, cat_path, batch_size, n_split, image_lim):
def __init__(
self, data_dir, cat_path, batch_size, n_split, image_lim, use_deconv_channel, deconv_path
):
super().__init__()
self.data_dir = data_dir
self.cat_path = cat_path
Expand All @@ -30,6 +32,8 @@ def __init__(self, data_dir, cat_path, batch_size, n_split, image_lim):
self.test = []
self.n_split = n_split
self.image_lim = image_lim
self.use_deconv_channel = use_deconv_channel
self.deconv_path = deconv_path

self._predict_batch = None

Expand All @@ -49,8 +53,8 @@ def image_ids(self):
return [self.dc2_data[i]["images"] for i in range(len(self))]

def prepare_data(self):
img_pattern = "3828/*/calexp*.fits"
bg_pattern = "3828/*/bkgd*.fits"
img_pattern = "**/*/calexp*.fits"
bg_pattern = "**/*/bkgd*.fits"
image_files = []
bg_files = []

Expand All @@ -75,8 +79,10 @@ def prepare_data(self):
plocs_lim = image[0].shape
height = plocs_lim[0]
width = plocs_lim[1]
full_cat = Dc2FullCatalog.from_file(self.cat_path, wcs, height, width, self.bands)
tile_cat = full_cat.to_tile_params(4, 1, 1)
full_cat, psf_params = Dc2FullCatalog.from_file(
self.cat_path, wcs, height, width, self.bands
)
tile_cat = full_cat.to_tile_params(4, 5).get_brightest_source_per_tile()
tile_dict = tile_cat.to_dict()

tile_dict["locs"] = rearrange(tile_cat.to_dict()["locs"], "1 h w nh nw -> h w nh nw")
Expand Down Expand Up @@ -105,14 +111,9 @@ def prepare_data(self):
# split image
split_lim = self.image_lim[0] // self.n_split
image = torch.from_numpy(image)
split1_image = torch.stack(torch.split(image, split_lim, dim=1))
split2_image = torch.stack(torch.split(split1_image, split_lim, dim=3))
split_image = list(torch.split(split2_image.flatten(0, 2), 6))

split_image = split_full_image(image, split_lim)
bg = torch.from_numpy(bg)
split1_bg = torch.stack(torch.split(bg, split_lim, dim=1))
split2_bg = torch.stack(torch.split(split1_bg, split_lim, dim=3))
split_bg = list(torch.split(split2_bg.flatten(0, 2), 6))
split_bg = split_full_image(bg, split_lim)

tile_split = {}
param_list = [
Expand All @@ -133,8 +134,14 @@ def prepare_data(self):
"tile_catalog": [dict(zip(tile_split, i)) for i in zip(*tile_split.values())],
"images": split_image,
"background": split_bg,
"psf_params": [psf_params for _ in range(self.n_split**2)],
}

if self.use_deconv_channel:
file_path = self.deconv_path + "/" + str(image_files[0][n])[-15:-5] + ".pt"
deconv_images = torch.load(file_path)
data_split["deconvolution"] = split_full_image(deconv_images, split_lim)

data.extend([dict(zip(data_split, i)) for i in zip(*data_split.values())])

random.shuffle(data)
Expand Down Expand Up @@ -188,9 +195,7 @@ def from_file(cls, cat_path, wcs, height, width, band):
dec = torch.tensor(catalog["dec"].values)
galaxy_bools = torch.tensor((catalog["truth_type"] == 1).values)
star_bools = torch.tensor((catalog["truth_type"] == 2).values)
flux_list = []
for b in band:
flux_list.append(torch.tensor((catalog["flux_" + b]).values))
flux_list, psf_params = get_band(band, catalog)

flux = torch.stack(flux_list).t()

Expand Down Expand Up @@ -265,7 +270,7 @@ def from_file(cls, cat_path, wcs, height, width, band):
"star_log_fluxes": star_log_fluxes.reshape(1, nobj, 6),
}

return cls(height, width, d)
return cls(height, width, d), torch.stack(psf_params)


def read_frame_for_band(image_files, bg_files, n, n_bands, image_lim):
Expand All @@ -292,3 +297,24 @@ def read_frame_for_band(image_files, bg_files, n, n_bands, image_lim):
bg_list.append(bg)

return image_list, bg_list, wcs


def split_full_image(image, split_lim):
    """Cut a full (n_bands, H, W) frame into square ``split_lim`` sub-images.

    Splits along the height, then the width, and returns a list of
    (6, split_lim, split_lim) tensors — one entry per spatial tile.
    """
    row_chunks = torch.stack(torch.split(image, split_lim, dim=1))
    grid = torch.stack(torch.split(row_chunks, split_lim, dim=3))
    return list(torch.split(grid.flatten(0, 2), 6))


def get_band(band, catalog):
    """Extract per-band fluxes and median PSF parameters from a DC2 catalog.

    Args:
        band: iterable of band names (e.g. ("g", "i", "r", ...)).
        catalog: pandas DataFrame with ``flux_<b>`` columns and PSF moment /
            FWHM columns per band.

    Returns:
        Tuple (flux_list, psf_params): one flux tensor per band, and one
        4-vector per band of median (IxxPSF, IyyPSF, IxyPSF, fwhm) values.
    """
    psf_columns = ("IxxPSF_pixel_", "IyyPSF_pixel_", "IxyPSF_pixel_", "psf_fwhm_")
    flux_list = []
    psf_params = []
    for b in band:
        flux_list.append(torch.tensor(catalog["flux_" + b].values))
        # NaN-aware median so missing PSF measurements don't poison the summary
        medians = [
            torch.tensor(np.nanmedian(catalog[col + b].values).astype(np.float32))
            for col in psf_columns
        ]
        psf_params.append(torch.stack(medians).t())

    return flux_list, psf_params
Loading