Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

high-resolution images & a 2d discrete variational factor #1068

Merged
merged 63 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
63 commits
Select commit Hold shift + click to select a range
5b4ca23
null sample renders bug
jeff-regier Aug 5, 2024
82ded83
fix annoying data_source warning
jeff-regier Aug 5, 2024
79f7d20
aug5_discretizedbox_quarterpixels
jeff-regier Aug 5, 2024
aa4cadd
not much, some m2 notebook stuff
jeff-regier Aug 9, 2024
16d2083
sdss case study
jeff-regier Aug 16, 2024
bf40a00
revert convnets to master
jeff-regier Aug 16, 2024
c8635eb
merge
jeff-regier Aug 16, 2024
8584488
fixed some merge discrepancies
jeff-regier Aug 16, 2024
8b29fd0
bring back former to_tile_catalog
jeff-regier Aug 18, 2024
0d62d71
self.double_sample_prob and nll hack
jeff-regier Aug 18, 2024
49cd763
sdss demo needs retrain with one band fluxes
jeff-regier Aug 18, 2024
095fcbd
tiny changes to m2 case study
jeff-regier Aug 18, 2024
546f48a
revert problematic changes
jeff-regier Aug 18, 2024
cd91253
fix bin_cutoffs
jeff-regier Aug 18, 2024
2a6ee90
logsumexp for double detections
jeff-regier Aug 19, 2024
77b26d8
back to 32-true
jeff-regier Aug 19, 2024
df502bc
lots of 1x1 kernels
jeff-regier Aug 19, 2024
dab7ff7
updated sample for logsumexp double detect
jeff-regier Aug 19, 2024
4032fa2
wider context net
jeff-regier Aug 19, 2024
d1848d4
removed incorrect masking
jeff-regier Aug 20, 2024
6cd81f1
gate new_est_cat2 during sampling
jeff-regier Aug 20, 2024
2540ad4
simplify color context
jeff-regier Aug 20, 2024
168b77b
count_net
jeff-regier Aug 20, 2024
5027822
countnet refactor
jeff-regier Aug 20, 2024
d341192
richer color history
jeff-regier Aug 20, 2024
d5dcfe8
no groupnorm in heads
jeff-regier Aug 21, 2024
e3245c6
minimalist_history
jeff-regier Aug 21, 2024
c13f441
deeper
jeff-regier Aug 21, 2024
6c594a6
1x1 for real
jeff-regier Aug 21, 2024
16dac53
centered locs for colors
jeff-regier Aug 22, 2024
c1fc02f
restore flux history to local context
jeff-regier Aug 22, 2024
1c037d3
add n_sources to color history
jeff-regier Aug 22, 2024
0e70577
only nsources color
jeff-regier Aug 22, 2024
51940c3
spatial context for color
jeff-regier Aug 22, 2024
78e2798
spatial countnet too
jeff-regier Aug 22, 2024
a6c6db5
locs only local context
jeff-regier Aug 22, 2024
fac16d1
embedding
jeff-regier Aug 22, 2024
e67f78e
restore fluxes to local context
jeff-regier Aug 23, 2024
566e4db
groupnorm4all less color spatial
jeff-regier Aug 23, 2024
0b75b64
extra spatial for countnet and color
jeff-regier Aug 23, 2024
39a7049
shallower
jeff-regier Aug 23, 2024
d44461d
earlier first skip connection
jeff-regier Aug 23, 2024
1db8498
null normalizer
jeff-regier Aug 23, 2024
e6b3e70
no groupnorm for localnet or detectionnet
jeff-regier Aug 23, 2024
80fb47d
restore normalizers
jeff-regier Aug 24, 2024
2be7fa1
flux in color context
jeff-regier Aug 24, 2024
7c94649
some merging; give decoder bounds
jeff-regier Aug 28, 2024
317bba8
recovered from merge?
jeff-regier Aug 28, 2024
3db820e
use multiprocessing for generate
jeff-regier Aug 29, 2024
62f90fb
add mask patterns
jeff-regier Aug 29, 2024
e22be1d
pylint
jeff-regier Aug 29, 2024
fe0cc24
flake8
jeff-regier Aug 29, 2024
b895d33
remove simulator
jeff-regier Aug 29, 2024
b606b0a
forgot to remove line
jeff-regier Aug 29, 2024
201c3d1
wrong path
jeff-regier Aug 29, 2024
a8c1ad5
simplify prior
jeff-regier Aug 30, 2024
e82dfd9
restored double downsample
jeff-regier Sep 2, 2024
d06402b
remove fixtures; fix test paths
jeff-regier Sep 2, 2024
d1f5496
revert log10 division change
jeff-regier Sep 2, 2024
a09de3e
tiles_to_crop
jeff-regier Sep 2, 2024
efed656
smaller images
jeff-regier Sep 2, 2024
4b9dfcf
new base_config encoder
jeff-regier Sep 3, 2024
2656a3d
notebooks running again
jeff-regier Sep 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ disable=too-many-ancestors,
unused-argument,

# flake8 recommends against f-strings
consider-using-f-string,
logging-fstring-interpolation,

# flake8 already checks for lambda expressions, which are OK at times
unnecessary-lambda-assignment,
Expand Down
1 change: 1 addition & 0 deletions bliss/cached_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def __call__(self, datum_in):

class ChunkingSampler(Sampler):
def __init__(self, dataset: Dataset) -> None:
super().__init__()
# please don't pass dataset to the following __init__()
# according to https://pytorch.org/docs/stable/data.html#torch.utils.data.Sampler
# the parameter `data_source` has been deprecated
Expand Down
46 changes: 40 additions & 6 deletions bliss/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,38 @@
[tiles_to_crop, self.n_tiles_w - tiles_to_crop],
)

def filter_by_ploc_box(self, box_origin: torch.Tensor, box_len: float):
    """Restrict this catalog to sources inside a square box in pixel coordinates.

    Args:
        box_origin: length-2 tensor (row, col) giving the box's top-left corner.
        box_len: side length of the square box, in pixels.

    Returns:
        A FullCatalog of height/width `box_len` containing only the sources whose
        `plocs` fall strictly inside the box, with `plocs` shifted to box-relative
        coordinates.
    """
    # Allow the box to extend exactly to the image edge (`<=`, not `<`), matching
    # FullCatalog.filter_by_ploc_box; the strict inequality below still excludes
    # sources sitting precisely on the boundary.
    assert box_origin[0] + box_len <= self.height, "invalid box"
    assert box_origin[1] + box_len <= self.width, "invalid box"

    box_origin_tensor = box_origin.view(1, 1, 2).to(device=self.device)
    box_end_tensor = (box_origin + box_len).view(1, 1, 2).to(device=self.device)

    # True for each source whose (row, col) both lie strictly inside the box.
    plocs_mask = torch.all(
        (self["plocs"] < box_end_tensor) & (self["plocs"] > box_origin_tensor), dim=2
    )

    # Concatenate kept-source indexes before dropped-source indexes, then stable-sort
    # by batch index so that, within each batch element, kept sources come first.
    plocs_mask_indexes = plocs_mask.nonzero()
    plocs_inverse_mask_indexes = (~plocs_mask).nonzero()
    plocs_full_mask_indexes = torch.cat((plocs_mask_indexes, plocs_inverse_mask_indexes), dim=0)
    _, index_order = plocs_full_mask_indexes[:, 0].sort(stable=True)
    plocs_full_mask_sorted_indexes = plocs_full_mask_indexes[index_order, :]

    d = {}
    # After reordering, truncating to the largest per-batch kept count drops exactly
    # the out-of-box sources.
    new_max_sources = plocs_mask.sum(dim=1).max()
    for k, v in self.items():
        if k == "n_sources":
            d[k] = plocs_mask.sum(dim=1)
        else:
            d[k] = v[
                plocs_full_mask_sorted_indexes[:, 0],
                plocs_full_mask_sorted_indexes[:, 1],
            ].view(-1, self.max_sources, v.shape[-1])[:, :new_max_sources, :]

    # Shift pixel locations to be relative to the box origin.
    d["plocs"] -= box_origin_tensor

    return FullCatalog(box_len, box_len, d)


class TileCatalog(BaseTileCatalog):
galaxy_params = [
Expand Down Expand Up @@ -335,9 +367,8 @@
ns11 = rearrange(self["n_sources"], "b ht wt -> b ht wt 1 1")
for k, v in self.items():
if k == "n_sources":
assert not disjoint or ((v == 0) | (other[k] == 0)).all()

Check warning on line 370 in bliss/catalog.py

View check run for this annotation

Codecov / codecov/patch

bliss/catalog.py#L370

Added line #L370 was not covered by tests
d[k] = v + other[k]
if disjoint:
assert d[k].max() <= 1
else:
if disjoint:
d1 = v
Expand Down Expand Up @@ -734,9 +765,7 @@

return TileCatalog(tile_params)

# pylint: enable=R0912,R0915

def filter_by_ploc_box(self, box_origin: torch.Tensor, box_len: float):
def filter_by_ploc_box(self, box_origin: torch.Tensor, box_len: float, exclude_box=False):
assert box_origin[0] + box_len <= self.height, "invalid box"
assert box_origin[1] + box_len <= self.width, "invalid box"

Expand All @@ -747,6 +776,9 @@
(self["plocs"] < box_end_tensor) & (self["plocs"] > box_origin_tensor), dim=2
)

if exclude_box:
plocs_mask = ~plocs_mask

Check warning on line 780 in bliss/catalog.py

View check run for this annotation

Codecov / codecov/patch

bliss/catalog.py#L780

Added line #L780 was not covered by tests

plocs_mask_indexes = plocs_mask.nonzero()
plocs_inverse_mask_indexes = (~plocs_mask).nonzero()
plocs_full_mask_indexes = torch.cat((plocs_mask_indexes, plocs_inverse_mask_indexes), dim=0)
Expand All @@ -764,6 +796,8 @@
plocs_full_mask_sorted_indexes[:, 1].tolist(),
].view(-1, self.max_sources, v.shape[-1])[:, :new_max_sources, :]

d["plocs"] -= box_origin_tensor
if exclude_box:
return FullCatalog(self.height, self.width, d)

Check warning on line 800 in bliss/catalog.py

View check run for this annotation

Codecov / codecov/patch

bliss/catalog.py#L800

Added line #L800 was not covered by tests

d["plocs"] -= box_origin_tensor
return FullCatalog(box_len, box_len, d)
90 changes: 52 additions & 38 deletions bliss/conf/base_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,30 @@ paths:
cached_data: /data/scratch/regier/sdss_like
output: ${oc.env:HOME}/bliss_output

# this prior is sdss-like; the flux parameters were fit using SDSS catalogs
# this prior is sdss-like; the parameters were fit using SDSS catalogs
prior:
_target_: bliss.simulator.prior.CatalogPrior
survey_bands: [u, g, r, i, z] # SDSS available band filters
reference_band: 2 # SDSS r-band
star_color_model_path: ${simulator.decoder.survey.dir_path}/color_models/star_gmm_nmgy.pkl
gal_color_model_path: ${simulator.decoder.survey.dir_path}/color_models/gal_gmm_nmgy.pkl
n_tiles_h: 20
n_tiles_w: 20
batch_size: 64
star_color_model_path: ${paths.sdss}/color_models/star_gmm_nmgy.pkl
gal_color_model_path: ${paths.sdss}/color_models/gal_gmm_nmgy.pkl
n_tiles_h: 68 # cropping 2 tiles from each side
n_tiles_w: 68 # cropping 2 tiles from each side
batch_size: 8
max_sources: 1
mean_sources: 0.01 # 0.0025 is more realistic for SDSS but training takes more iterations
mean_sources: 0.0025
min_sources: 0
prob_galaxy: 0.5144
star_flux_exponent: 0.4689157382430609
star_flux_truncation: 613313.768995269
star_flux_loc: -0.5534648001193676
star_flux_scale: 1.1846035501201129
galaxy_flux_exponent: 1.5609458661807678
galaxy_flux_truncation: 28790.449063519092
galaxy_flux_loc: -3.29383532288203
galaxy_flux_scale: 3.924799999613338
star_flux:
exponent: 0.4689157382430609
truncation: 613313.768995269
loc: -0.5534648001193676
scale: 1.1846035501201129
galaxy_flux:
exponent: 1.5609458661807678
truncation: 28790.449063519092
loc: -3.29383532288203
scale: 3.924799999613338
galaxy_a_concentration: 0.39330758068481686
galaxy_a_loc: 0.8371888967872619
galaxy_a_scale: 4.432725319432478
Expand All @@ -51,20 +53,11 @@ decoder:
with_dither: true
with_noise: true

simulator:
_target_: bliss.simulator.simulated_dataset.SimulatedDataset
prior: ${prior}
decoder: ${decoder}
n_batches: 128
num_workers: 32
valid_n_batches: 10 # 256
fix_validation_set: true

cached_simulator:
_target_: bliss.cached_dataset.CachedSimulatedDataModule
batch_size: 64
batch_size: 16
splits: 0:80/80:90/90:100 # train/val/test splits as percent ranges
num_workers: 8
num_workers: 4
cached_data_path: ${paths.cached_data}
train_transforms:
- _target_: bliss.data_augmentation.RotateFlipTransform
Expand Down Expand Up @@ -140,23 +133,42 @@ variational_factors:
nll_gating:
_target_: bliss.encoder.variational_dist.GalaxyGating

# these are in nanomaggies
sdss_mag_zero_point: 1e9
sdss_flux_cutoffs:
- 1.4928
- 1.9055
- 2.7542
- 3.9811
- 5.7544
- 8.3176
- 12.0227
- 17.3780
- 25.1189

metrics:
detection_performance:
_target_: bliss.encoder.metrics.DetectionPerformance
base_flux_bin_cutoffs: [200, 400, 600, 800, 1000]
mag_zero_point: 3631e9 # for DC2
base_flux_bin_cutoffs: ${sdss_flux_cutoffs}
mag_zero_point: ${sdss_mag_zero_point}
report_bin_unit: mag
exclude_last_bin: true
ref_band: 2
source_type_accuracy:
_target_: bliss.encoder.metrics.SourceTypeAccuracy
base_flux_bin_cutoffs: [200, 400, 600, 800, 1000]
mag_zero_point: 3631e9 # for DC2
base_flux_bin_cutoffs: ${sdss_flux_cutoffs}
mag_zero_point: ${sdss_mag_zero_point}
report_bin_unit: mag
exclude_last_bin: true
ref_band: 2
flux_error:
_target_: bliss.encoder.metrics.FluxError
survey_bands: ${encoder.survey_bands}
base_flux_bin_cutoffs: [200, 400, 600, 800, 1000]
mag_zero_point: 3631e9 # for DC2
base_flux_bin_cutoffs: ${sdss_flux_cutoffs}
mag_zero_point: ${sdss_mag_zero_point}
report_bin_unit: mag
exclude_last_bin: true
ref_band: 2

image_normalizers:
psf:
Expand All @@ -173,7 +185,7 @@ encoder:
_target_: bliss.encoder.encoder.Encoder
survey_bands: [u, g, r, i, z]
reference_band: 2 # SDSS r-band
tile_slen: ${simulator.decoder.tile_slen}
tile_slen: ${decoder.tile_slen}
optimizer_params:
lr: 1e-3
scheduler_params:
Expand Down Expand Up @@ -201,7 +213,7 @@ encoder:
frequency: 1
restrict_batch: 0
tiles_to_crop: 0
tile_slen: ${simulator.decoder.tile_slen}
tile_slen: ${decoder.tile_slen}
use_double_detect: false
use_checkerboard: false
n_sampler_colors: 4
Expand Down Expand Up @@ -278,11 +290,13 @@ surveys:
mode: train

generate:
n_image_files: 64
n_batches_per_file: 16
simulator: ${simulator}
prior: ${prior}
decoder: ${decoder}
tiles_to_crop: 2
n_image_files: 512
n_batches_per_file: 32 # multiply by prior.batch_size to get total number of images
n_processes: 16 # using more isn't necessarily faster
cached_data_path: ${paths.cached_data}
file_prefix: dataset
store_full_catalog: false

train:
Expand Down
27 changes: 17 additions & 10 deletions bliss/encoder/convnet_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@


class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, gn=True):
super().__init__()
assert kernel_size % 2 == 1, "kernel size must be odd"
padding = (kernel_size - 1) // 2
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
# seems to work about as well as BatchNorm2d
self.norm = nn.GroupNorm(out_channels // 8, out_channels)
n_groups = out_channels // 8
use_gn = gn and n_groups >= 16
self.norm = nn.GroupNorm(n_groups, out_channels) if use_gn else nn.Identity()
self.activation = nn.SiLU(inplace=True)

def forward(self, x):
Expand All @@ -27,11 +31,12 @@ def forward(self, x):


class Bottleneck(nn.Module):
def __init__(self, c1, c2, shortcut=True, e=0.5):
def __init__(self, c1, c2, shortcut=True, e=0.5, gn=True, spatial=True):
super().__init__()
ch = int(c2 * e)
self.cv1 = ConvBlock(c1, ch, kernel_size=1, padding=0)
self.cv2 = ConvBlock(ch, c2, kernel_size=3, padding=1, stride=1)
self.cv1 = ConvBlock(c1, ch, kernel_size=1, gn=gn)
ks = 3 if spatial else 1
self.cv2 = ConvBlock(ch, c2, kernel_size=ks, stride=1, gn=gn)
self.add = shortcut and c1 == c2

def forward(self, x):
Expand All @@ -40,13 +45,15 @@ def forward(self, x):


class C3(nn.Module):
    """CSP-style bottleneck block with three 1x1 ConvBlocks (YOLOv5's "C3").

    The input is split into two 1x1-projected branches; one passes through `n`
    Bottleneck units, the other is an identity shortcut. The branches are
    concatenated and fused by a final 1x1 ConvBlock.
    """

    def __init__(self, c1, c2, n=1, shortcut=True, e=0.5, gn=True, spatial=True):
        """Args:
        c1: input channels.
        c2: output channels.
        n: number of Bottleneck units in the main branch.
        shortcut: whether Bottleneck units use residual connections.
        e: hidden-channel expansion ratio.
        gn: whether ConvBlocks use GroupNorm.
        spatial: whether Bottleneck units use 3x3 (vs 1x1) second convolutions.
        """
        super().__init__()
        ch = int(c2 * e)
        self.cv1 = ConvBlock(c1, ch, kernel_size=1, gn=gn)
        self.cv2 = ConvBlock(c1, ch, kernel_size=1, gn=gn)
        self.cv3 = ConvBlock(2 * ch, c2, kernel_size=1, gn=gn)
        # Forward `gn` to the Bottleneck units too, for consistency: previously only
        # `spatial` was passed, so gn=False silently left GroupNorm enabled inside them.
        self.m = nn.Sequential(
            *(Bottleneck(ch, ch, shortcut, e=1.0, gn=gn, spatial=spatial) for _ in range(n)),
        )

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
Loading
Loading