Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor predict.py #943

Merged
merged 4 commits into from
Nov 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -121,5 +121,8 @@ ignore =
# there are a number of places where a float zero value makes sense
WPS358

# what's wrong a multiline try block?
WPS229,

# we should consider fixing some of these issues
W503,WPS202,WPS213,WPS214,WPS231,WPS236,WPS336,WPS338,WPS440,WPS441,WPS442,WPS602
9 changes: 7 additions & 2 deletions bliss/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from bliss.catalog import FullCatalog, SourceType
from bliss.conf.igs import base_config
from bliss.generate import generate as _generate
from bliss.predict import predict as _predict
from bliss.predict import predict_and_compare as _predict_and_compare
from bliss.surveys.sdss import SDSSDownloader
from bliss.train import train as _train
from bliss.utils.download_utils import download_git_lfs_file
Expand Down Expand Up @@ -159,7 +159,12 @@ def predict(
cfg.predict.dataset = "${surveys." + survey + "}"
for k, v in kwargs.items():
OmegaConf.update(cfg, k, v)
est_cat, _, _, _, pred_for_image_id = _predict(cfg)

# `predict.predict_and_compare` isn't really the right function to call here;
# it doesn't simply make predictions using bliss, it also loads survey catalogs.
# instead, we should implement and call `predict.bulk_predict`, which would use
# `trainer.predict(encoder, datamodule=dataset)` to make predictions
est_cat, _, _, _, pred_for_image_id = _predict_and_compare(cfg)
est_cat_table = fullcat_to_astropy_table(est_cat, cfg.encoder.survey_bands)
pred_tables = {} # indexed by image_id
for image_id, pred in pred_for_image_id.items():
Expand Down
14 changes: 7 additions & 7 deletions bliss/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,13 +331,13 @@ def test_step(self, batch, batch_idx):

def predict_step(self, batch, batch_idx, dataloader_idx=0):
"""Pytorch lightning method."""

return {
"est_cat": self.sample(batch, use_mode=True),
# a marginal catalog isn't really what we want here, perhaps
# we should return samples from the variation distribution instead
"pred": None,
}
with torch.no_grad():
return {
"est_cat": self.sample(batch, use_mode=True),
# a marginal catalog isn't really what we want here, perhaps
# we should return samples from the variation distribution instead
"pred": None,
}

def configure_optimizers(self):
"""Configure optimizers for training (pytorch lightning)."""
Expand Down
222 changes: 114 additions & 108 deletions bliss/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,37 +13,21 @@
from bliss.surveys.sdss import SloanDigitalSkySurvey as SDSS


def prepare_image(x, device):
x = torch.from_numpy(x).unsqueeze(0)
x = x.to(device=device)
def crop_image(x, cfg):
"""Crop the image to a subregion for prediction."""
# image dimensions must be a multiple of 16
height = x.size(2) - (x.size(2) % 16)
width = x.size(3) - (x.size(3) % 16)
return x[:, :, :height, :width]
x = x[:, :, :height, :width].float()

cp = cfg.predict.crop
if cp.do_crop:
top_left_y, top_left_x = cp.left_upper_corner
w, h = cp.width, cp.height
if ((top_left_y + h) <= x.shape[2]) and ((top_left_x + w) <= x.shape[3]):
x = x[:, :, top_left_y : top_left_y + h, top_left_x : top_left_x + w]

def crop_image(image, background, crop_params):
"""Crop the image (and background) to a subregion for prediction."""
if not crop_params.do_crop:
return image, background

top_left_y = crop_params.left_upper_corner[0]
top_left_x = crop_params.left_upper_corner[1]
width = crop_params.width
height = crop_params.height
if ((top_left_y + height) <= image.shape[2]) and ((top_left_x + width) <= image.shape[3]):
image = image[:, :, top_left_y : top_left_y + height, top_left_x : top_left_x + width]
background = background[
:, :, top_left_y : top_left_y + height, top_left_x : top_left_x + width
]
return image, background


def prepare_batch(images, backgrounds):
batch = {"images": images, "background": backgrounds}
batch["images"] = batch["images"].squeeze(0)
batch["background"] = batch["background"].squeeze(0)
return batch
return x


def align(img, wcs_list, ref_band, ref_depth=0):
Expand All @@ -63,6 +47,7 @@
for d in range(coadd_depth):
for bnd in range(n_bands):
inputs = (img[d, bnd], wcs_list[d][bnd])
# the next line is the computational bottleneck
reproj, footprint = reproject_interp(
inputs, target_wcs, order="bicubic", shape_out=(h, w)
)
Expand Down Expand Up @@ -100,97 +85,114 @@
return est_cat


def predict(cfg):
survey = instantiate(cfg.predict.dataset, load_image_data=True)
def predict_frame(cfg, frame, encoder):
aligned_images = align(
frame["image"],
wcs_list=frame["wcs"],
ref_band=cfg.simulator.prior.reference_band,
)
aligned_backgrounds = align(
frame["background"],
wcs_list=frame["wcs"],
ref_band=cfg.simulator.prior.reference_band,
)

device = encoder.device
images = torch.from_numpy(aligned_images).unsqueeze(0).to(device=device)
backgrounds = torch.from_numpy(aligned_backgrounds).unsqueeze(0).to(device=device)

# cropping the images should be done by the dataset
images = crop_image(images, cfg)
backgrounds = crop_image(backgrounds, cfg)

batch = {"images": images, "background": backgrounds}
est_cat, pred = encoder.predict_step(batch, None).values()

# mean of the nelec_per_mgy per band
nelec_per_nmgy_per_band = np.mean(frame["flux_calibration_list"], axis=1)
est_cat = nelec_to_nmgy_for_catalog(est_cat, nelec_per_nmgy_per_band)
est_full = est_cat.to_full_params()

return est_full, pred, images.squeeze(0), backgrounds.squeeze(0)


def predict_and_compare(cfg):
dataset = instantiate(cfg.predict.dataset, load_image_data=True)

# below collections indexed by image_id
images_for_frame = {}
radecs_for_frame = {}
backgrounds_for_frame = {}
preds_for_frame = {}

plocs_all = None
est_full_all = None # collated catalog for all images
survey_objs = [survey[i] for i in range(len(survey))]
for i, survey_obj in enumerate(survey_objs):
survey_obj["image"] = align(
survey_obj["image"],
wcs_list=survey_obj["wcs"],
ref_band=cfg.simulator.prior.reference_band,
)
survey_obj["background"] = align(
survey_obj["background"],
wcs_list=survey_obj["wcs"],
ref_band=cfg.simulator.prior.reference_band,
)
# collated catalog for all images
survey_plocs_all = None
bliss_cat_all = None

cat_path = survey.downloader.download_catalog(survey.image_id(i))
plocs = survey.catalog_cls.from_file(
cat_path=cat_path,
wcs=survey_obj["wcs"][cfg.simulator.prior.reference_band],
height=survey_obj["image"].shape[1],
width=survey_obj["image"].shape[2],
).plocs[0]
encoder = instantiate(cfg.encoder).to(cfg.predict.device)
enc_state_dict = torch.load(cfg.predict.weight_save_path)
encoder.load_state_dict(enc_state_dict)
encoder.eval()

# get RA, Dec of the center of the image
ra, dec = survey_obj["wcs"][cfg.simulator.prior.reference_band].all_pix2world(
survey_obj["image"].shape[2] / 2, survey_obj["image"].shape[1] / 2, 0
)
radecs_for_frame[survey.image_id(i)] = (ra.item(), dec.item())

encoder = instantiate(cfg.encoder).to(cfg.predict.device)
enc_state_dict = torch.load(cfg.predict.weight_save_path)
encoder.load_state_dict(enc_state_dict)
encoder.eval()
trainer = instantiate(cfg.predict.trainer)
images = prepare_image(survey_obj["image"], cfg.predict.device).float()
backgrounds = prepare_image(survey_obj["background"], cfg.predict.device).float()
images, backgrounds = crop_image(images, backgrounds, cfg.predict.crop)
survey.predict_batch = prepare_batch(images, backgrounds)
est_cat, pred = trainer.predict(encoder, datamodule=survey)[0].values()

# mean of the nelec_per_mgy per band
nelec_per_nmgy_per_band = np.mean(survey_obj["flux_calibration_list"], axis=1)
est_cat = nelec_to_nmgy_for_catalog(est_cat, nelec_per_nmgy_per_band)
est_full = est_cat.to_full_params()

images_for_frame[survey.image_id(i)] = images
backgrounds_for_frame[survey.image_id(i)] = backgrounds
preds_for_frame[survey.image_id(i)] = pred

if plocs_all is None:
plocs_all = plocs
else:
plocs_all = torch.cat((plocs_all, plocs), dim=0)
plocs_all = torch.unique(plocs_all, dim=0)
for i, frame in enumerate(dataset):
images_id = dataset.image_id(i) # should be called frame_id, not image_id

if not est_full_all:
est_full_all = est_full
bliss_cat, bliss_pred, images, backgrounds = predict_frame(cfg, frame, encoder)
images_for_frame[images_id] = images
backgrounds_for_frame[images_id] = backgrounds
preds_for_frame[images_id] = bliss_pred

# we are merging bliss catalogs as we go here, but we shouldn't be!
# we should keep each bliss_cat associated with a particular frame_id
if bliss_cat_all is None:
bliss_cat_all = bliss_cat
else:
d = {}
d["plocs"] = torch.cat((est_full_all.plocs, est_full.plocs), dim=1)
d["n_sources"] = Tensor([est_full_all.n_sources + est_full.n_sources])
est_full_all_dict = est_full_all.to_dict()
for k, v in est_full.items():
d["plocs"] = torch.cat((bliss_cat_all.plocs, bliss_cat.plocs), dim=1)
d["n_sources"] = Tensor([bliss_cat_all.n_sources + bliss_cat.n_sources])
est_full_all_dict = bliss_cat_all.to_dict()
for k, v in bliss_cat.items():
d[k] = torch.cat((est_full_all_dict[k], v), dim=1)
est_full_all = FullCatalog(est_full_all.height, est_full_all.width, d)
bliss_cat_all = FullCatalog(bliss_cat_all.height, bliss_cat_all.width, d)

# store the RA and Dec for the center of the image
ra, dec = frame["wcs"][cfg.simulator.prior.reference_band].all_pix2world(
frame["image"].shape[2] / 2, frame["image"].shape[1] / 2, 0
)
radecs_for_frame[images_id] = (ra.item(), dec.item())

# now we load the survey's catalog from this frame
survey_cat_path = dataset.downloader.download_catalog(images_id)
try:
survey_cat = dataset.catalog_cls.from_file(
cat_path=survey_cat_path,
wcs=frame["wcs"][cfg.simulator.prior.reference_band],
height=frame["image"].shape[1],
width=frame["image"].shape[2],
)
survey_plocs = survey_cat.plocs[0]
except IndexError:
survey_plocs = None

Check warning on line 174 in bliss/predict.py

View check run for this annotation

Codecov / codecov/patch

bliss/predict.py#L173-L174

Added lines #L173 - L174 were not covered by tests

# we are merging survey catalogs as we go here, but we shouldn't be!
# we should keep each catalog associated with a particular frame_id
if survey_plocs_all is None:
survey_plocs_all = survey_plocs
else:
survey_plocs_all = torch.cat((survey_plocs_all, survey_plocs), dim=0)
survey_plocs_all = torch.unique(survey_plocs_all, dim=0)

assert est_full_all is not None and isinstance(
est_full_all, FullCatalog
), "Should have estimated catalog for at least one image"
if cfg.predict.plot.show_plot and (plocs_all is not None):
# better not to have this here...the caller should call plot_predict directly
if cfg.predict.plot.show_plot and (survey_plocs_all is not None):
plot_predict(
cfg,
images_for_frame,
backgrounds_for_frame,
radecs_for_frame,
plocs_all,
est_full_all,
survey_plocs_all,
bliss_cat_all,
)

images_for_frame = {k: v[0] for k, v in images_for_frame.items()}
backgrounds_for_frame = {k: v[0] for k, v in backgrounds_for_frame.items()}
return est_full_all, images_for_frame, backgrounds_for_frame, plocs_all, preds_for_frame
return bliss_cat_all, images_for_frame, backgrounds_for_frame, survey_plocs_all, preds_for_frame


def crop_plocs(cfg, w, h, plocs, do_crop=False):
Expand Down Expand Up @@ -224,15 +226,16 @@
size=10,
fill_color=None,
)
p.scatter(
survey_true_plocs[:, 1],
survey_true_plocs[:, 0],
marker="circle",
color="hotpink",
legend_label="consolidated survey catalog",
size=20,
fill_color=None,
)
if survey_true_plocs is not None:
p.scatter(
survey_true_plocs[:, 1],
survey_true_plocs[:, 0],
marker="circle",
color="hotpink",
legend_label="consolidated survey catalog",
size=20,
fill_color=None,
)
if sdss_plocs is not None:
p.scatter(
sdss_plocs[:, 1],
Expand Down Expand Up @@ -329,8 +332,8 @@
background = backgrounds_for_frame[image_id]

ptc = cfg.encoder.tiles_to_crop * cfg.encoder.tile_slen
image = image[0, 0, ptc:-ptc, ptc:-ptc]
background = background[0, 0, ptc:-ptc, ptc:-ptc]
image = image[0, ptc:-ptc, ptc:-ptc] # uh, are we always plotting the u band image here?
background = background[0, ptc:-ptc, ptc:-ptc]

w, h = image.shape

Expand Down Expand Up @@ -364,6 +367,9 @@
tab1_title = f"{title_prefix}true image{title_suffix}"
tab2_title = f"{title_prefix}reconstructed image{title_suffix}"
tab3_title = f"residual{title_suffix}"

# we are plotting detections from all frames on each frame here, but we shouldn't be!
# each frame has it's own bliss catalog and it's own survey catalog
tab1 = plot_image(cfg, ra, dec, np_image, w, h, est_plocs, survey_true_plocs, tab1_title)
tab2 = plot_image(cfg, ra, dec, np_recon, w, h, est_plocs, survey_true_plocs, tab2_title)
tab3 = plot_image(cfg, ra, dec, np_res, w, h, est_plocs, survey_true_plocs, tab3_title)
Expand Down
12 changes: 6 additions & 6 deletions case_studies/dependent_tiling/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ predict:
_target_: bliss.surveys.sdss.SloanDigitalSkySurvey
dir_path: ${paths.sdss}
fields:
- run: 94 # 2583
camcol: 1 # 2
fields: [12, 13] # [136]
- run: 2583
camcol: 2
fields: [136]
psf_config:
pixel_scale: 0.396
psf_slen: 25
Expand All @@ -55,9 +55,9 @@ predict:
device: "cuda:0"
crop:
do_crop: true
left_upper_corner: [160, 160]
width: 640
height: 640
left_upper_corner: [630, 310]
width: 112
height: 112
plot:
show_plot: true
width: 1000
Expand Down
Loading