-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Weak lensing: data aug + recent DC2 results with baselines (#1077)
* Tune avg ellip estimator in ellipticity notebook * Add plot to ellipticity notebook * Fix bug when computing MSEs in ellipticity notebook * Minor update to README * Add weighed avg ellipticity estimator as an additional baseline * Notebook to demonstrate shear transformation under rotation * Add ra/dec to catalog * Close figures in lensing_plots.py * Rerun encoder eval notebook after recent training run * Add 2PCFs to encoder eval notebook (work in progress) * Set seed in encoder eval notebook * Add flips to data augmentation notebook * lensing rotate and flip * Fix pred_convergence shape in lensing metrics * Allow train_transforms to be an argument in LensingDC2DataModule * Update encoder eval notebook with most recent encoder weights * Fix split files directory in dc2 config * updated augmentation * Rename data aug and metrics, tweak data aug * Fix output path in dc2 config --------- Co-authored-by: shreyasc <[email protected]> Co-authored-by: Shreyas Chandrashekaran <[email protected]>
- Loading branch information
1 parent
4694457
commit 0eb412b
Showing
10 changed files
with
806 additions
and
278 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import random | ||
|
||
import torch | ||
|
||
|
||
class LensingRotateFlipTransform(torch.nn.Module): | ||
def __init__(self, without_replacement=False): | ||
super().__init__() | ||
self.rotate_id = -1 | ||
self.flip_id = -1 | ||
self.seen_states = set() | ||
self.without_replacement = without_replacement | ||
|
||
def __call__(self, datum): | ||
self.rotate_id = random.randint(0, 3) | ||
self.flip_id = random.randint(0, 1) | ||
|
||
if self.without_replacement: | ||
if len(self.seen_states) == 8: | ||
self.seen_states = set() | ||
while ((self.rotate_id, self.flip_id)) in self.seen_states: | ||
self.rotate_id = random.randint(0, 3) | ||
self.flip_id = random.randint(0, 1) | ||
self.seen_states.add((self.rotate_id, self.flip_id)) | ||
|
||
# problematic if the psf isn't rotationally invariant | ||
datum_out = {"psf_params": datum["psf_params"]} | ||
|
||
# apply rotation | ||
datum_out["images"] = datum["images"].rot90(self.rotate_id, [1, 2]) | ||
d = datum["tile_catalog"] | ||
datum_out["tile_catalog"] = {k: v.rot90(self.rotate_id, [0, 1]) for k, v in d.items()} | ||
|
||
# apply flip | ||
if self.flip_id == 1: | ||
datum_out["images"] = datum_out["images"].flip([1]) | ||
d = datum_out["tile_catalog"] | ||
datum_out["tile_catalog"] = {k: v.flip([0]) for k, v in d.items()} | ||
|
||
# shear requires special logic | ||
if all(k in datum["tile_catalog"] for k in ("shear_1", "shear_2")): | ||
shear1 = datum_out["tile_catalog"]["shear_1"] | ||
shear2 = datum_out["tile_catalog"]["shear_2"] | ||
for _ in range(self.rotate_id): | ||
shear1 = -shear1 | ||
shear2 = -shear2 | ||
if self.flip_id == 1: | ||
shear2 = -shear2 | ||
datum_out["tile_catalog"]["shear_1"] = shear1 | ||
datum_out["tile_catalog"]["shear_2"] = shear2 | ||
|
||
# locations require special logic | ||
if "locs" in datum["tile_catalog"]: | ||
locs = datum_out["tile_catalog"]["locs"] | ||
for _ in range(self.rotate_id): | ||
# Rotate 90 degrees clockwise (in pixel coordinates) | ||
locs = torch.stack((1 - locs[..., 1], locs[..., 0]), dim=3) | ||
if self.flip_id == 1: | ||
locs = torch.stack((1 - locs[..., 0], locs[..., 1]), dim=3) | ||
datum_out["tile_catalog"]["locs"] = locs | ||
|
||
return datum_out |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
516 changes: 295 additions & 221 deletions
516
case_studies/weak_lensing/notebooks/dc2/ellipticity.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
64 changes: 41 additions & 23 deletions
64
case_studies/weak_lensing/notebooks/dc2/evaluate_encoder.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
Oops, something went wrong.