diff --git a/bliss/cached_dataset.py b/bliss/cached_dataset.py index d9c58f6b9..87863d92f 100644 --- a/bliss/cached_dataset.py +++ b/bliss/cached_dataset.py @@ -69,7 +69,7 @@ def __call__(self, datum_in): class ChunkingSampler(Sampler): def __init__(self, dataset: Dataset) -> None: - super().__init__() + super().__init__(dataset) assert isinstance(dataset, ChunkingDataset), "dataset should be ChunkingDataset" self.dataset = dataset diff --git a/case_studies/galaxy_clustering/README.md b/case_studies/galaxy_clustering/README.md index 5df880db5..e39975475 100644 --- a/case_studies/galaxy_clustering/README.md +++ b/case_studies/galaxy_clustering/README.md @@ -10,28 +10,17 @@ Built on work done in Winter 2024 by [Li Shihang](https://www.linkedin.com/in/sh ## Generation of Data -Data can be generated by running the bash scipt `data-gen.sh`. The bash script has 3 options: -1. `-n`: the number of files to be generated (defaults to 100) -2. `-s`: the size of the image (defaults to 4800) -3. `-t`: the tile size (defaults to 4) - -As an example, the command - -``` -bash data-gen.sh -n 10 -s 2400 -t 8 -``` - -generates 10 images of size 2400 x 2400 tiled with a tile size of 8 x 8. - -The bash command `data-gen.sh` runs three scripts (located under the data_generation directory) for data generation: -1. `catalog_gen.py` which generates catalogs of images and stores them in the data/catalogs subdirectory. Keyword arguments: `image_size` and `nfiles`. -2. `galsim-des.yaml` then reads in these catalogs and uses GalSim to generate corresponding images, which are stored as .fits files (one for each band) in the data/images subdirectory. Keyword arguments: `image_size` and `nfiles`. -3. `file_datum_generation.py` reads in the catalogs and images and saves them as *FileDatum* objects which contain the tile catalog and images in a dictionary. Keyword arguments: `image_size` and `tile_size`. 
- -Often, after image data has been generated, we would want to retile it with a different tile size. This can be done by just running the file `file_datum_generation.py` with appropriate arguments (you would have to pass in the image size as well since it defaults to 4800). For example, if we have 80 x 80 images that we want to tile with tile size 8, we may run - -``` -python data_generation/file_datum_generation.py image_size=80 tile_size=8 -``` - -Note that you must run the file from the galaxy_clustering directory (since the script takes in the current working directory for the data paths). +The data generation routine proceeds through phases. The entire routine is conveniently wrapped into a single python script `data_gen.py` that draws its parameters from the Hydra configuration, located under `conf/config.yaml` under the `data_gen` key. These phases proceed as follows. + +1. **Catalog Generation.** First, we sample semi-synthetic source catalogs with their relevant properties, which are stored as `.dat` files in the `data_dir/catalogs` subdirectory. +2. **Image Generation.** Then, we take in the aforementioned source catalogs and use GalSim to render them as images, which are stored as `.fits` files (one for each band) in the `data_dir/images` subdirectory. +3. **File Datum Generation.** Finally, we convert the full source catalogs generated in phase 1 into tile catalogs, stack them up with their corresponding images, and store these objects as `.pt` files (which is what the encoder ultimately uses) in the `data_dir/file_data` subdirectory. + +The following parameters can be set within the configuration file `config.yaml`. +1. `data_dir`: the path of the directory where generated data will be stored. +2. `image_size`: size of the image (pixels). +3. `tile_size`: size of tile to be used (pixels). +4. `nfiles`: number of files to be generated. +5. `n_catalogs_per_file`: number of catalogs to be stored in each file datum object. +6. 
`bands`: survey bands to be used (`["g", "r", "i", "z"]` for DES). +7. `min_flux_for_loss`: minimum flux for filtering. diff --git a/case_studies/galaxy_clustering/config.yaml b/case_studies/galaxy_clustering/conf/config.yaml similarity index 75% rename from case_studies/galaxy_clustering/config.yaml rename to case_studies/galaxy_clustering/conf/config.yaml index 6b1e26260..83c0f0b6c 100644 --- a/case_studies/galaxy_clustering/config.yaml +++ b/case_studies/galaxy_clustering/conf/config.yaml @@ -1,9 +1,17 @@ --- defaults: - - ../../bliss/conf@_here_: base_config + - ../../../bliss/conf@_here_: base_config - _self_ - override hydra/job_logging: stdout +data_gen: + data_dir: /nfs/turbo/lsa-regier/scratch/kapnadak/new_data + image_size: 1280 + tile_size: 128 + nfiles: 5000 + n_catalogs_per_file: 500 + bands: ["g", "r", "i", "z"] + min_flux_for_loss: 0 prior: _target_: case_studies.galaxy_clustering.prior.GalaxyClusterPrior @@ -48,19 +56,15 @@ my_metrics: cluster_membership_acc: _target_: case_studies.galaxy_clustering.encoder.metrics.ClusterMembershipAccuracy +my_image_normalizers: + asinh: + _target_: bliss.encoder.image_normalizer.AsinhQuantileNormalizer + q: [0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99, 0.999, 0.9999, 0.99999] + encoder: _target_: case_studies.galaxy_clustering.encoder.encoder.GalaxyClusterEncoder survey_bands: ["g", "r", "i", "z"] - image_normalizer: - _target_: bliss.encoder.image_normalizer.ImageNormalizer - bands: [0, 1, 2, 3] - include_original: true - include_background: false - concat_psf_params: false - num_psf_params: 6 # for SDSS, 4 for DC2 - log_transform_stdevs: null - use_clahe: false - clahe_min_stdev: null + image_normalizers: ${my_image_normalizers} mode_metrics: _target_: torchmetrics.MetricCollection _convert_: "partial" @@ -80,14 +84,21 @@ predict: _target_: case_studies.galaxy_clustering.cached_dataset.CachedDESModule cached_data_path: 
/nfs/turbo/lsa-regier/scratch/gapatron/desdr-server.ncsa.illinois.edu/despublic/dr2_tiles tiles_per_img: 64 - batch_size: 1 + batch_size: 2 num_workers: 4 trainer: _target_: pytorch_lightning.Trainer accelerator: "gpu" - devices: "6,7" + devices: [6,5] strategy: "ddp" precision: ${train.trainer.precision} + callbacks: + - ${predict.callbacks.writer} + callbacks: + writer: + _target_: case_studies.galaxy_clustering.inference.inference_callbacks.DESPredictionsWriter + output_dir: "/data/scratch/des/dr2_detection_output/run_1" + write_interval: "batch" encoder: ${encoder} weight_save_path: /nfs/turbo/lsa-regier/scratch/gapatron/best_encoder.ckpt device: "cuda:0" diff --git a/case_studies/galaxy_clustering/data_generation/DES_data_extraction.py b/case_studies/galaxy_clustering/data_generation/DES_data_extraction.py deleted file mode 100644 index a70415a0b..000000000 --- a/case_studies/galaxy_clustering/data_generation/DES_data_extraction.py +++ /dev/null @@ -1,48 +0,0 @@ -# flake8: noqa -import os - -import numpy as np -import requests -from astropy import units -from numpy.core.defchararray import startswith -from pyvo.dal import sia - -from case_studies.galaxy_clustering.utils import cluster_utils as utils - -DES_DATAPATH = os.environ["HOME"] + "/bliss/case_studies/galaxy_clustering/data/DES_images" -if not os.path.exists(DES_DATAPATH): - os.makedirs(DES_DATAPATH) - -DATASET = "des_dr1" -DEF_ACCESS_URL = "https://datalab.noirlab.edu/sia/" + DATASET -svc = sia.SIAService(DEF_ACCESS_URL) - - -def compute_fov(m500, z): - m200 = utils.m500_to_m200(m500, z) * units.solMass - r200 = utils.m200_to_r200(m200, z) - da = utils.angular_diameter_distance(z) - fov = (r200 / da) * (360 / (2 * np.pi)) - return fov.value - - -def download_image(m500, z, ra, dec, band, fov_scale=4): - fov = fov_scale * compute_fov(m500, z) - dec_radian = dec * np.pi / 180 - img_table = svc.search((ra, dec), (fov / np.cos(dec_radian), fov), verbosity=2).to_table() - sel = ( - (img_table["proctype"] == 
"Stack") - & (img_table["prodtype"] == "image") - & (startswith(img_table["obs_bandpass"].astype(str), band)) - ) - row = img_table[sel][0] - url = row["access_url"] # get the download URL - filename = DES_DATAPATH + "/" + str(ra) + "_" + str(dec) + "_" + band + ".fits" - if not url.lower().startswith("http"): - raise ValueError("URL must start with http") - response = requests.get(url, timeout=200) - if response.status_code == 200: - with open(filename, "wb") as file: - file.write(response.content) - else: - raise ValueError("Failed to download file.") diff --git a/case_studies/galaxy_clustering/data_generation/catalog_gen.py b/case_studies/galaxy_clustering/data_generation/catalog_gen.py deleted file mode 100644 index ea5107975..000000000 --- a/case_studies/galaxy_clustering/data_generation/catalog_gen.py +++ /dev/null @@ -1,41 +0,0 @@ -import os -import sys - -import numpy as np -import pandas as pd -from astropy.io import ascii as astro_ascii -from astropy.table import Table - -from case_studies.galaxy_clustering.data_generation.prior import BackgroundPrior, ClusterPrior - -DATA_PATH = "/home/kapnadak/bliss/case_studies/galaxy_clustering/data" -CATALOG_PATH = os.path.join(DATA_PATH, "catalogs") -FILE_PREFIX = "galsim_des" -CLUSTER_PROB = 0.5 - - -def main(**kwargs): - if not os.path.exists(CATALOG_PATH): - os.makedirs(CATALOG_PATH) - - nfiles = int(kwargs.get("nfiles", 100)) - cluster_prior = ClusterPrior(image_size=int(kwargs.get("image_size", 1280))) - background_prior = BackgroundPrior(image_size=int(kwargs.get("image_size", 1280))) - - combined_catalogs = [] - for _ in range(nfiles): - background_catalog = background_prior.sample_background() - if np.random.uniform() < CLUSTER_PROB: - cluster_catalog = cluster_prior.sample_cluster() - combined_catalogs.append(pd.concat([cluster_catalog, background_catalog])) - else: - combined_catalogs.append(background_catalog) - - for i, catalog in enumerate(combined_catalogs): - file_name = 
f"{CATALOG_PATH}/{FILE_PREFIX}_{i:03}.dat" - catalog_table = Table.from_pandas(catalog) - astro_ascii.write(catalog_table, file_name, format="no_header", overwrite=True) - - -if __name__ == "__main__": - main(**dict(arg.split("=") for arg in sys.argv[1:])) diff --git a/case_studies/galaxy_clustering/data_generation/data-gen.sh b/case_studies/galaxy_clustering/data_generation/data-gen.sh deleted file mode 100644 index e1ca01eba..000000000 --- a/case_studies/galaxy_clustering/data_generation/data-gen.sh +++ /dev/null @@ -1,67 +0,0 @@ -#!/bin/bash - -while getopts ":n:s:t:" opt; do - case $opt in - n) nfiles="$OPTARG" - ;; - s) image_size="$OPTARG" - ;; - t) tile_size="$OPTARG" - ;; - \?) echo "Invalid option -$OPTARG" >&2 - exit 1 - ;; - esac - - case $OPTARG in - -*) echo "Option $opt needs a valid argument" - exit 1 - ;; - esac -done - -echo "Generating Catalogs..." -if [ -z "$nfiles" ]; then - if [ -z "$image_size" ]; then - python3 catalog_gen.py - else - python3 catalog_gen.py image_size="$image_size" - fi -else - if [ -z "$image_size" ]; then - python3 catalog_gen.py nfiles="$nfiles" - else - python3 catalog_gen.py nfiles="$nfiles" image_size="$image_size" - fi -fi -echo "...Done!" -echo "Generating Images..." -if [ -z "$nfiles" ]; then - if [ -z "$image_size" ]; then - galsim galsim-des.yaml - else - galsim galsim-des.yaml variables.image_size="$image_size" - fi -else - if [ -z "$image_size" ]; then - galsim galsim-des.yaml variables.nfiles="$nfiles" - else - galsim galsim-des.yaml variables.nfiles="$nfiles" variables.image_size="$image_size" - fi -fi -echo "...Done!" -echo "Generating File Datums..." 
-if [ -z "$image_size" ]; then - if [ -z "$tile_size" ]; then - python3 file_datum_generation.py - else - python3 file_datum_generation.py tile_size="$tile_size" - fi -else - if [ -z "$tile_size" ]; then - python3 file_datum_generation.py image_size="$image_size" - else - python3 file_datum_generation.py image_size="$image_size" tile_size="$tile_size" - fi -fi -echo "...Done! Data Generation Complete!" diff --git a/case_studies/galaxy_clustering/data_generation/data_gen.py b/case_studies/galaxy_clustering/data_generation/data_gen.py new file mode 100644 index 000000000..b1c46c26f --- /dev/null +++ b/case_studies/galaxy_clustering/data_generation/data_gen.py @@ -0,0 +1,177 @@ +# flake8: noqa +import os +import subprocess +from pathlib import Path +from typing import Dict, List + +import hydra +import numpy as np +import pandas as pd +import torch +from astropy.io import ascii as astro_ascii +from astropy.io import fits +from astropy.table import Table + +from bliss.catalog import FullCatalog +from case_studies.galaxy_clustering.data_generation.prior import BackgroundPrior, ClusterPrior + +COL_NAMES = ( + "RA", + "DEC", + "X", + "Y", + "MEM", + "FLUX_G", + "FLUX_R", + "FLUX_I", + "FLUX_Z", + "HLR", + "FRACDEV", + "G1", + "G2", + "Z", + "SOURCE_TYPE", +) + +# ============================== Generate Catalogs ============================== + + +def catalog_gen(cfg): + nfiles = int(cfg.nfiles) + image_size = int(cfg.image_size) + data_dir = cfg.data_dir + catalogs_path = f"{data_dir}/catalogs/" + file_prefix = "galsim_des" + if not os.path.exists(catalogs_path): + os.makedirs(catalogs_path) + cluster_prior = ClusterPrior(image_size=image_size) + background_prior = BackgroundPrior(image_size=image_size) + + combined_catalogs = [] + for _ in range(nfiles): + background_catalog = background_prior.sample_background() + if np.random.uniform() < 0.5: + cluster_catalog = cluster_prior.sample_cluster() + combined_catalogs.append(pd.concat([cluster_catalog, background_catalog])) 
+ else: + combined_catalogs.append(background_catalog) + + for i, catalog in enumerate(combined_catalogs): + file_name = f"{catalogs_path}/{file_prefix}_{i:03}.dat" + catalog_table = Table.from_pandas(catalog) + astro_ascii.write(catalog_table, file_name, format="no_header", overwrite=True) + + +# ============================== Generate Images ============================== + + +def image_gen(cfg): + image_size = cfg.image_size + nfiles = cfg.nfiles + data_dir = cfg.data_dir + input_dir = f"{data_dir}/catalogs" + output_dir = f"{data_dir}/images" + args = [] + args.append("galsim") + args.append("galsim-des.yaml") + args.append(f"variables.image_size={image_size}") + args.append(f"variables.nfiles={nfiles}") + args.append(f"variables.input_dir={input_dir}") + args.append(f"variables.output_dir={output_dir}") + # check=True: raise CalledProcessError if GalSim fails, so file_data_gen never runs on missing or stale images + subprocess.run(args, shell=False, check=True) + + +# ============================== Generate File Datums ============================== + + +def file_data_gen(cfg): + image_size = int(cfg.image_size) + tile_size = int(cfg.tile_size) + data_dir = cfg.data_dir + n_catalogs_per_file = int(cfg.n_catalogs_per_file) + bands = cfg.bands + min_flux_for_loss = float(cfg.min_flux_for_loss) + catalogs_path = Path(f"{data_dir}/catalogs/") + images_path = f"{data_dir}/images/" + file_path = f"{data_dir}/file_data/" + if not os.path.exists(file_path): + os.makedirs(file_path) + n_tiles = int(image_size / tile_size) + data: List[Dict] = [] + catalog_counter = 0 + file_counter = 0 + + for catalog_path in catalogs_path.glob("*.dat"): + catalog = pd.read_csv(catalog_path, sep=" ", header=None, names=COL_NAMES) + + catalog_dict = {} + catalog_dict["plocs"] = torch.tensor([catalog[["X", "Y"]].to_numpy()]) + catalog_dict["n_sources"] = torch.sum(catalog_dict["plocs"][:, :, 0] != 0, axis=1) + catalog_dict["galaxy_fluxes"] = torch.tensor( + [catalog[["FLUX_G", "FLUX_R", "FLUX_I", "FLUX_Z"]].to_numpy()] + ) + catalog_dict["star_fluxes"] = 
torch.zeros_like(catalog_dict["galaxy_fluxes"]) + catalog_dict["membership"] = torch.tensor([catalog[["MEM"]].to_numpy()]) + catalog_dict["redshift"] = torch.tensor([catalog[["Z"]].to_numpy()]) + catalog_dict["galaxy_params"] = torch.tensor( + [catalog[["HLR", "G1", "G2", "FRACDEV"]].to_numpy()] + ) + catalog_dict["source_type"] = torch.ones_like(catalog_dict["membership"]) + full_catalog = FullCatalog(height=image_size, width=image_size, d=catalog_dict) + tile_catalog = full_catalog.to_tile_catalog( + tile_slen=tile_size, + max_sources_per_tile=12 * tile_size, + ) + tile_catalog = tile_catalog.filter_by_flux(min_flux=min_flux_for_loss) + tile_catalog = tile_catalog.get_brightest_sources_per_tile(band=2, exclude_num=0) + + membership_array = np.zeros((n_tiles, n_tiles), dtype=bool) + for i, coords in enumerate(full_catalog["plocs"].squeeze()): + if full_catalog["membership"][0, i, 0] > 0: + tile_coord_y, tile_coord_x = ( + torch.div(coords, tile_size, rounding_mode="trunc").to(torch.int).tolist() + ) + membership_array[tile_coord_x, tile_coord_y] = True + + tile_catalog["membership"] = ( + torch.tensor(membership_array).unsqueeze(0).unsqueeze(3).unsqueeze(4) + ) + + tile_catalog_dict = {} + for key, value in tile_catalog.items(): + tile_catalog_dict[key] = torch.squeeze(value, 0) + + image_bands = [] + for band in bands: + fits_filepath = images_path / Path(f"{catalog_path.stem}_{band}.fits") + with fits.open(fits_filepath) as hdul: + image_data = hdul[0].data.astype(np.float32) + image_bands.append(torch.from_numpy(image_data)) + stacked_image = torch.stack(image_bands, dim=0) + + data.append( + { + "tile_catalog": tile_catalog_dict, + "images": stacked_image, + } + ) + catalog_counter += 1 + if catalog_counter == n_catalogs_per_file: + stackname = f"{file_path}/file_data_{file_counter}_size_{n_catalogs_per_file}.pt" + torch.save(data, stackname) + file_counter += 1 + data, catalog_counter = [], 0 + + # After the loop (function-body level): save the final partial chunk, which would otherwise be dropped when the catalog count is not a multiple of n_catalogs_per_file + if data: + stackname = f"{file_path}/file_data_{file_counter}_size_{len(data)}.pt" + torch.save(data, stackname) + + +# ============================== CLI 
============================== + + +@hydra.main(config_path="../conf", config_name="config", version_base=None) +def main(cfg): + catalog_gen(cfg.data_gen) + image_gen(cfg.data_gen) + file_data_gen(cfg.data_gen) + + +if __name__ == "__main__": + main() diff --git a/case_studies/galaxy_clustering/data_generation/dict_creation.py b/case_studies/galaxy_clustering/data_generation/dict_creation.py new file mode 100644 index 000000000..663f333ba --- /dev/null +++ b/case_studies/galaxy_clustering/data_generation/dict_creation.py @@ -0,0 +1,28 @@ +import os +import pickle +from pathlib import Path + +import pandas as pd +from astropy.io import fits + + +def main(): + des_dir = Path( + "/nfs/turbo/lsa-regier/scratch/gapatron/desdr-server.ncsa.illinois.edu/despublic/dr2_tiles/" + ) + des_subdirs = [d for d in os.listdir(des_dir) if d.startswith("DES")] + obj_tile_mapping = {} + output_filename = "/data/scratch/des/obj_to_tile.pickle" + for des_subdir in des_subdirs: + catalog_path = des_dir / Path(des_subdir) / Path(f"{des_subdir}_dr2_main.fits") + catalog_data = fits.getdata(catalog_path) + source_df = pd.DataFrame(catalog_data) + for obj_id in source_df["COADD_OBJECT_ID"]: + obj_tile_mapping[obj_id] = des_subdir + + with open(output_filename, "wb") as handle: + pickle.dump(obj_tile_mapping, handle, protocol=pickle.HIGHEST_PROTOCOL) + + +if __name__ == "__main__": + main() diff --git a/case_studies/galaxy_clustering/data_generation/file_datum_generation.py b/case_studies/galaxy_clustering/data_generation/file_datum_generation.py deleted file mode 100644 index 50e6e1471..000000000 --- a/case_studies/galaxy_clustering/data_generation/file_datum_generation.py +++ /dev/null @@ -1,110 +0,0 @@ -import os -import sys -from pathlib import Path - -import numpy as np -import pandas as pd -import torch -from astropy.io import fits - -from bliss.catalog import FullCatalog - -min_flux_for_loss = 0 -DATA_PATH = "/home/kapnadak/bliss/case_studies/galaxy_clustering/data" -CATALOGS_PATH = 
DATA_PATH / Path("catalogs") -IMAGES_PATH = DATA_PATH / Path("images") -FILE_DATA_PATH = DATA_PATH / Path("file_data") -if not os.path.exists(FILE_DATA_PATH): - os.makedirs(FILE_DATA_PATH) -COL_NAMES = ( - "RA", - "DEC", - "X", - "Y", - "MEM", - "FLUX_R", - "FLUX_G", - "FLUX_I", - "FLUX_Z", - "HLR", - "FRACDEV", - "G1", - "G2", - "Z", - "SOURCE_TYPE", -) -BANDS = ("g", "r", "i", "z") -N_CATALOGS_PER_FILE = 50 - - -def main(**kwargs): - image_size = int(kwargs.get("image_size", 1280)) - tile_size = int(kwargs.get("tile_size", 128)) - n_tiles = int(image_size / tile_size) - data = [] - - for catalog_path in CATALOGS_PATH.glob("*.dat"): - catalog = pd.read_csv(catalog_path, sep=" ", header=None, names=COL_NAMES) - - catalog_dict = {} - catalog_dict["plocs"] = torch.tensor([catalog[["X", "Y"]].to_numpy()]) - n_sources = torch.sum(catalog_dict["plocs"][:, :, 0] != 0, axis=1) - catalog_dict["n_sources"] = n_sources - catalog_dict["galaxy_fluxes"] = torch.tensor( - [catalog[["FLUX_R", "FLUX_G", "FLUX_I", "FLUX_Z"]].to_numpy()] - ) - catalog_dict["star_fluxes"] = torch.zeros_like(catalog_dict["galaxy_fluxes"]) - catalog_dict["membership"] = torch.tensor([catalog[["MEM"]].to_numpy()]) - catalog_dict["redshift"] = torch.tensor([catalog[["Z"]].to_numpy()]) - catalog_dict["galaxy_params"] = torch.tensor( - [catalog[["HLR", "G1", "G2", "FRACDEV"]].to_numpy()] - ) - catalog_dict["source_type"] = torch.ones_like(catalog_dict["membership"]) - full_catalog = FullCatalog(height=image_size, width=image_size, d=catalog_dict) - tile_catalog = full_catalog.to_tile_catalog( - tile_slen=tile_size, - max_sources_per_tile=12 * tile_size, - ) - tile_catalog = tile_catalog.filter_by_flux(min_flux=min_flux_for_loss) - tile_catalog = tile_catalog.get_brightest_sources_per_tile(band=2, exclude_num=0) - - membership_array = np.zeros((n_tiles, n_tiles), dtype=bool) - for i, coords in enumerate(full_catalog["plocs"].squeeze()): - if full_catalog["membership"][0, i, 0] > 0: - tile_coord_y, 
tile_coord_x = ( - torch.div(coords, tile_size, rounding_mode="trunc").to(torch.int).tolist() - ) - membership_array[tile_coord_x, tile_coord_y] = True - - tile_catalog["membership"] = ( - torch.tensor(membership_array).unsqueeze(0).unsqueeze(3).unsqueeze(4) - ) - - tile_catalog_dict = {} - for key, value in tile_catalog.items(): - tile_catalog_dict[key] = torch.squeeze(value, 0) - - filename = catalog_path.stem - image_bands = [] - for band in BANDS: - fits_filepath = IMAGES_PATH / Path(f"{filename}_{band}.fits") - # Should the ordering in the bands matter? It does here. - with fits.open(fits_filepath) as hdul: - image_data = hdul[0].data.astype(np.float32) - image_bands.append(torch.from_numpy(image_data)) - stacked_image = torch.stack(image_bands, dim=0) - - file_datum = { - "tile_catalog": tile_catalog_dict, - "images": stacked_image, - "background": stacked_image, - } - data.append(file_datum) - - chunks = [data[i : i + N_CATALOGS_PER_FILE] for i in range(0, len(data), N_CATALOGS_PER_FILE)] - for i, chunk in enumerate(chunks): - torch.save(chunk, f"{FILE_DATA_PATH}/file_data_{i}_size_{N_CATALOGS_PER_FILE}.pt") - - -if __name__ == "__main__": - main(**dict(arg.split("=") for arg in sys.argv[1:])) diff --git a/case_studies/galaxy_clustering/data_generation/galsim-des.yaml b/case_studies/galaxy_clustering/data_generation/galsim-des.yaml index 0d8a1b564..d189a4798 100644 --- a/case_studies/galaxy_clustering/data_generation/galsim-des.yaml +++ b/case_studies/galaxy_clustering/data_generation/galsim-des.yaml @@ -21,9 +21,9 @@ variables : nfiles : 100 - image_size: 4800 - input_dir : /home/kapnadak/bliss/case_studies/galaxy_clustering/data/catalogs - output_dir : /home/kapnadak/bliss/case_studies/galaxy_clustering/data/images + image_size: 1280 + input_dir : /nfs/turbo/lsa-regier/scratch/kapnadak/new_data_2/catalogs + output_dir : /nfs/turbo/lsa-regier/scratch/kapnadak/new_data_2/images psf: type: Convolve diff --git 
a/case_studies/galaxy_clustering/data_generation/prior.py b/case_studies/galaxy_clustering/data_generation/prior.py index 1f243c52b..6df7ba9bd 100644 --- a/case_studies/galaxy_clustering/data_generation/prior.py +++ b/case_studies/galaxy_clustering/data_generation/prior.py @@ -18,7 +18,7 @@ class Prior: - def __init__(self, image_size=4800): + def __init__(self, image_size=1280): super().__init__() self.width = image_size self.height = image_size @@ -96,15 +96,15 @@ def make_catalog( source_types, membership, ): - """Makes a background catalog from generated samples. + """Makes a catalog from generated samples. Args: flux_samples: flux samples in all bands hlr_samples: samples of HLR g1_size_samples: samples of G1 g2_size_samples: samples of G2 - gal_locs: samples of background locations in galactic coordinates - cartesian_locs: samples of background locations in cartesian coordinates + gal_locs: samples of locations in galactic coordinates + cartesian_locs: samples of locations in cartesian coordinates source_types: source types for each source (0 for star, 1 for galaxy) membership: background (0) or cluster (1) @@ -131,7 +131,7 @@ def make_catalog( class ClusterPrior(Prior): - def __init__(self, image_size=4800): + def __init__(self, image_size=1280): super().__init__(image_size) self.full_cluster_df = Table.read(CLUSTER_CATALOG_PATH).to_pandas() @@ -246,7 +246,7 @@ def sample_cluster(self): class BackgroundPrior(Prior): - def __init__(self, image_size=4800): + def __init__(self, image_size=1280): super().__init__(image_size) self.pixel_scale = 0.263 @@ -266,15 +266,9 @@ def sample_n_sources(self): def sample_des_catalog(self): """Sample a random DES dataframe.""" tile_choice = random.choice(DES_SUBDIRS) - main_path = DES_DIR / Path(tile_choice) / Path(f"{tile_choice}_dr2_main.fits") - flux_path = DES_DIR / Path(tile_choice) / Path(f"{tile_choice}_dr2_flux.fits") - main_data = fits.getdata(main_path) - main_df = pd.DataFrame(main_data) - flux_data = 
fits.getdata(flux_path) - flux_df = pd.DataFrame(flux_data) - self.source_df = pd.merge( - main_df, flux_df, left_on="COADD_OBJECT_ID", right_on="COADD_OBJECT_ID", how="left" - ) + catalog_path = DES_DIR / Path(tile_choice) / Path(f"{tile_choice}_dr2_main.fits") + catalog_data = fits.getdata(catalog_path) + self.source_df = pd.DataFrame(catalog_data) def sample_sources(self, n_sources): """Samples random sources from the current DES catalog. @@ -326,6 +320,23 @@ def sample_hlr(self, sources): hlr_samples = self.pixel_scale * np.array(sources["FLUX_RADIUS_R"]) return 1e-4 + (hlr_samples * (hlr_samples > 0)) + def sample_shapes(self, sources): + """Samples shapes for each source in the catalog. + Shapes are from DES Table in the form of (a, b) + Converted to (g1, g2) + + Args: + sources: Dataframe of DES sources + + Returns: + samples for g1, g2 for each source + """ + a = np.array(sources["A_IMAGE"]) + b = np.array(sources["B_IMAGE"]) + g = (a - b) / (a + b) + angle = np.arctan(b / a) + return g * np.cos(angle), g * np.sin(angle) + def sample_fluxes(self, sources): """Samples fluxes for all bands for each source. 
@@ -335,17 +346,18 @@ def sample_fluxes(self, sources): Returns: 5-band array containing fluxes (clamped at 1 from below) """ - fluxes = np.array( + mags = np.array( sources[ [ - "FLUX_AUTO_G_x", - "FLUX_AUTO_R_x", - "FLUX_AUTO_I_x", - "FLUX_AUTO_Z_x", + "MAG_AUTO_G", + "MAG_AUTO_R", + "MAG_AUTO_I", + "MAG_AUTO_Z", ] ] ) + fluxes = 1000 * convert_mag_to_nmgy(mags) return fluxes * (fluxes > 0) def sample_background(self): @@ -364,7 +376,7 @@ def sample_background(self): gal_source_locs = self.cartesian_to_gal(cartesian_source_locs) source_types = self.sample_source_types(des_sources) flux_samples = self.sample_fluxes(des_sources) - g1_size_samples, g2_size_samples = self.sample_shape(n_sources) + g1_size_samples, g2_size_samples = self.sample_shapes(n_sources) hlr_samples = self.sample_hlr(des_sources) return self.make_catalog( flux_samples, diff --git a/case_studies/galaxy_clustering/encoder/encoder.py b/case_studies/galaxy_clustering/encoder/encoder.py index cb755171d..06a39b95d 100644 --- a/case_studies/galaxy_clustering/encoder/encoder.py +++ b/case_studies/galaxy_clustering/encoder/encoder.py @@ -18,10 +18,10 @@ def initialize_networks(self): """ power_of_two = (self.tile_slen != 0) & (self.tile_slen & (self.tile_slen - 1) == 0) assert power_of_two, "tile_slen must be a power of two" - ch_per_band = self.image_normalizer.num_channels_per_band() + ch_per_band = sum(inorm.num_channels_per_band() for inorm in self.image_normalizers) num_features = 256 self.features_net = GalaxyClusterFeaturesNet( - len(self.image_normalizer.bands), + len(self.survey_bands), ch_per_band, num_features, tile_slen=self.tile_slen, @@ -39,7 +39,8 @@ def get_features_and_parameters(self, batch): batch_size, _n_bands, h, w = batch["images"].shape[0:4] ht, wt = h // self.tile_slen, w // self.tile_slen - x = self.image_normalizer.get_input_tensor(batch) + input_lst = [inorm.get_input_tensor(batch) for inorm in self.image_normalizers] + x = torch.cat(input_lst, dim=2) x_features = 
self.features_net(x) mask = torch.zeros([batch_size, ht, wt]) context = self.make_context(None, mask).to("cuda") @@ -74,8 +75,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): } def _compute_loss(self, batch, logging_name): - batch_size, _n_bands, h, w = batch["images"].shape[0:4] - ht, wt = h // self.tile_slen, w // self.tile_slen + batch_size = batch["images"].shape[0] target_cat = TileCatalog(self.tile_slen, batch["tile_catalog"]) @@ -91,11 +91,8 @@ def _compute_loss(self, batch, logging_name): # make predictions/inferences pred = {} - x = self.image_normalizer.get_input_tensor(batch) - x_features = self.features_net(x) - mask = torch.zeros([batch_size, ht, wt]) - context = self.make_context(None, mask).to("cuda") - pred["x_cat_marginal"] = self.catalog_net(x_features, context) + x_features, x_cat_marginal = self.get_features_and_parameters(batch) + pred["x_cat_marginal"] = x_cat_marginal x_features = x_features.detach() # is this helpful? doing it here to match old code loss = self.var_dist.compute_nll(pred["x_cat_marginal"], target_cat1) diff --git a/case_studies/galaxy_clustering/DES_inference.py b/case_studies/galaxy_clustering/inference/DES_inference.py similarity index 95% rename from case_studies/galaxy_clustering/DES_inference.py rename to case_studies/galaxy_clustering/inference/DES_inference.py index 5d313a167..e68e56fa7 100644 --- a/case_studies/galaxy_clustering/DES_inference.py +++ b/case_studies/galaxy_clustering/inference/DES_inference.py @@ -24,6 +24,7 @@ def inference(predict_cfg): encoder = load_encoder(predict_cfg) trainer = instantiate(predict_cfg.trainer) dataset = instantiate(predict_cfg.cached_dataset) + enc_output = trainer.predict(encoder, datamodule=dataset) gpu_rank = ( distributed.get_rank() if distributed.is_available() and distributed.is_initialized() else 0 @@ -33,7 +34,7 @@ def inference(predict_cfg): def main(): - with initialize(config_path=".", version_base=None): + with initialize(config_path="../conf", 
version_base=None): cfg = compose("config") predict_cfg = cfg.predict diff --git a/case_studies/galaxy_clustering/cached_dataset.py b/case_studies/galaxy_clustering/inference/cached_dataset.py similarity index 99% rename from case_studies/galaxy_clustering/cached_dataset.py rename to case_studies/galaxy_clustering/inference/cached_dataset.py index 6d37df5bc..cf43758aa 100644 --- a/case_studies/galaxy_clustering/cached_dataset.py +++ b/case_studies/galaxy_clustering/inference/cached_dataset.py @@ -130,7 +130,7 @@ def _get_dataloader(self, dataset): seed=42, ) else: - sampler = DESSampler() + sampler = DESSampler(dataset) return DataLoader( dataset, diff --git a/case_studies/galaxy_clustering/inference/inference_callbacks.py b/case_studies/galaxy_clustering/inference/inference_callbacks.py new file mode 100644 index 000000000..cd885c89e --- /dev/null +++ b/case_studies/galaxy_clustering/inference/inference_callbacks.py @@ -0,0 +1,30 @@ +import os + +import torch +from pytorch_lightning.callbacks import BasePredictionWriter + + +class DESPredictionsWriter(BasePredictionWriter): + def __init__(self, output_dir, write_interval): + super().__init__(write_interval) + self.output_dir = output_dir + + def write_on_batch_end( + self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx + ): + # this will create N (num processes) files in `output_dir` each containing + # the predictions of it's respective rank + name = f"rank_{trainer.global_rank}_batchIdx_{batch_idx}_dataloaderIdx_{dataloader_idx}.pt" + torch.save( + prediction, + os.path.join( + self.output_dir, + name, + ), + ) + + # optionally, you can also save `batch_indices` to get the information about the data index + # from your prediction data + torch.save( + batch_indices, os.path.join(self.output_dir, f"batch_indices_{trainer.global_rank}.pt") + ) diff --git a/case_studies/galaxy_clustering/notebooks/DESInference.ipynb b/case_studies/galaxy_clustering/notebooks/DESInference.ipynb new 
file mode 100644 index 000000000..19049bafe --- /dev/null +++ b/case_studies/galaxy_clustering/notebooks/DESInference.ipynb @@ -0,0 +1,255 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import numpy as np\n", + "import os\n", + "import random" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "DES_DIR = \"/nfs/turbo/lsa-regier/scratch/gapatron/desdr-server.ncsa.illinois.edu/despublic/dr2_tiles\"\n", + "DES_BANDS = (\"g\", \"r\", \"i\", \"z\")\n", + "DES_SUBDIRS = [d for d in os.listdir(DES_DIR) if d.startswith(\"DES\")]\n", + "tiles_per_img = 64" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def convert_to_global_idx(tile_idx, gpu_idx):\n", + " num_gpus = 2\n", + " tiles_per_img = 64\n", + " batch_size = 2\n", + " dir_idx = int(num_gpus * (tile_idx // (tiles_per_img / batch_size)) + gpu_idx)\n", + " subimage_idx = [(batch_size * tile_idx + i) % tiles_per_img for i in range(batch_size)]\n", + " return dir_idx, subimage_idx\n", + "\n", + "def convert_to_tile_idx(dir_idx):\n", + " num_gpus = 2\n", + " tiles_per_img = 64\n", + " batch_size = 2\n", + " gpu_idx = dir_idx % num_gpus\n", + " tile_starting_idx = (tiles_per_img / batch_size) * (dir_idx // num_gpus)\n", + " return int(tile_starting_idx), int(gpu_idx)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Directory: DES0053-2041\n", + "Subimage: [0, 1]\n" + ] + } + ], + "source": [ + "tile_idx = 0\n", + "gpu_idx = 0\n", + "\n", + "dir_idx, subimage_idx = convert_to_global_idx(tile_idx, gpu_idx)\n", + "print(f\"Directory: {DES_SUBDIRS[dir_idx]}\")\n", + "print(f\"Subimage: {subimage_idx}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + 
"data": { + "text/plain": [ + "torch.Size([64, 1280, 1280])" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "memberships = torch.empty((0,10,10))\n", + "output_dir = \"/data/scratch/des/dr2_detection_output/run_0\"\n", + "dir_idx = 10167\n", + "tile_starting_idx, gpu_idx = convert_to_tile_idx(dir_idx)\n", + "for tile in range(tile_starting_idx, tile_starting_idx + 32):\n", + " file = torch.load(f\"{output_dir}/tile_{tile}_gpu_{gpu_idx}.pt\")\n", + " memberships = torch.cat((memberships, file[\"mode_cat\"][\"membership\"].squeeze()), dim=0)\n", + "\n", + "expanded_memberships = torch.repeat_interleave(memberships, repeats=128, dim=1)\n", + "expanded_memberships = torch.repeat_interleave(expanded_memberships, repeats=128, dim=2)\n", + "expanded_memberships.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "500" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# gpu 0 : 0, 2, 4, 6, 8, ...\n", + "# gpu 1 : 1, 3, 5, 7, 9, ...\n", + "\n", + "# tile_0 : 0, 1 -- image 0\n", + "# tile_1 : 2, 3 -- image 0\n", + "# tile_2 : 4, 5 -- image 0\n", + "# ...\n", + "# tile_31 : 62, 63 -- image 0\n", + "# tile_32 : 64, 65 -- image 1 (image 2 overall)\n", + "\n", + "# tile_t_gpu_g --> dir_id = 2 * (t // 32) + g\n", + "# --> sub_id = (2*t % 64), (2*t + 1) % 64" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "def count_num_clusters(dir_idx):\n", + " memberships = torch.empty((0,10,10))\n", + " tile_starting_idx, gpu_idx = convert_to_tile_idx(dir_idx)\n", + " for tile in range(tile_starting_idx, tile_starting_idx + 32):\n", + " file = torch.load(f\"{output_dir}/tile_{tile}_gpu_{gpu_idx}.pt\")\n", + " memberships = torch.cat((memberships, file[\"mode_cat\"][\"membership\"].squeeze()), dim=0)\n", + " memberships = 
torch.repeat_interleave(memberships, repeats=128, dim=1)\n", + " memberships = torch.repeat_interleave(memberships, repeats=128, dim=2)\n", + " return torch.any(memberships.view(memberships.shape[0], -1), dim=1).sum()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "random_sample = random.sample(list(enumerate(DES_SUBDIRS)), 100)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "num_clusters = {}\n", + "for dir_idx , dir in random_sample:\n", + " num_clusters[dir] = count_num_clusters(dir_idx)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "num_clusters_vals = list(num_clusters.values())" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "plt.figure(figsize=(8,8))\n", + "plt.xlabel(\"Number of image patches containing clusters in each DES tile\")\n", + "plt.ylabel(\"Counts\")\n", + "plt.title(\"Histogram of random sample of 100 DES tiles\")\n", + "plt.hist(num_clusters_vals)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "41.04" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.array(num_clusters_vals).mean()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}