removed duplicate methods in lensing encoder
shreyasc30 committed Jul 23, 2024
1 parent 1b0bcd1 commit bb9a8a7
Showing 1 changed file with 0 additions and 62 deletions.
case_studies/weak_lensing/lensing_encoder.py
@@ -1,14 +1,11 @@
 from typing import Optional
 
 import torch
-from torch.optim import Adam
-from torch.optim.lr_scheduler import MultiStepLR
 from torchmetrics import MetricCollection
 
 from bliss.catalog import BaseTileCatalog
 from bliss.encoder.encoder import Encoder
 from bliss.encoder.variational_dist import VariationalDist
-from bliss.global_env import GlobalEnv
 from case_studies.weak_lensing.lensing_convnet import WeakLensingCatalogNet, WeakLensingFeaturesNet
 

@@ -89,16 +86,6 @@ def _compute_loss(self, batch, logging_name):
 
         return loss
 
-    def on_fit_start(self):
-        GlobalEnv.current_encoder_epoch = self.current_epoch
-
-    def on_train_epoch_start(self):
-        GlobalEnv.current_encoder_epoch = self.current_epoch
-
-    def training_step(self, batch, batch_idx, optimizer_idx=0):
-        """Training step (pytorch lightning)."""
-        return self._compute_loss(batch, "train")
-
     def update_metrics(self, batch, batch_idx):
         target_cat = BaseTileCatalog(batch["tile_catalog"])
 
@@ -116,52 +103,3 @@ def update_metrics(self, batch, batch_idx):
             self.current_epoch,
             batch_idx,
         )
-
-    def validation_step(self, batch, batch_idx):
-        """Pytorch lightning method."""
-        self._compute_loss(batch, "val")
-        self.update_metrics(batch, batch_idx)
-
-    def report_metrics(self, metrics, logging_name, show_epoch=False):
-        for k, v in metrics.compute().items():
-            self.log(f"{logging_name}/{k}", v, sync_dist=True)
-
-        for metric_name, metric in metrics.items():
-            if hasattr(metric, "plot"):  # noqa: WPS421
-                plot_or_none = metric.plot()
-                name = f"Epoch:{self.current_epoch}" if show_epoch else ""
-                name += f"/{logging_name} {metric_name}"
-                if self.logger and plot_or_none:
-                    fig, _axes = plot_or_none
-                    self.logger.experiment.add_figure(name, fig)
-
-        metrics.reset()
-
-    def on_validation_epoch_end(self):
-        self.report_metrics(self.mode_metrics, "val/mode", show_epoch=True)
-        self.report_metrics(self.sample_metrics, "sample/mode", show_epoch=True)
-        self.report_metrics(self.sample_image_renders, "val/image_renders", show_epoch=True)
-
-    def test_step(self, batch, batch_idx):
-        """Pytorch lightning method."""
-        self._compute_loss(batch, "test")
-        self.update_metrics(batch, batch_idx)
-
-    def on_test_epoch_end(self):
-        self.report_metrics(self.mode_metrics, "test/mode", show_epoch=False)
-        self.report_metrics(self.sample_metrics, "test/mode", show_epoch=False)
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        """Pytorch lightning method."""
-        with torch.no_grad():
-            return {
-                "mode_cat": self.sample(batch, use_mode=True),
-                # we may want multiple samples
-                "sample_cat": self.sample(batch, use_mode=False),
-            }
-
-    def configure_optimizers(self):
-        """Configure optimizers for training (pytorch lightning)."""
-        optimizer = Adam(self.parameters(), **self.optimizer_params)
-        scheduler = MultiStepLR(optimizer, **self.scheduler_params)
-        return [optimizer], [scheduler]
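Note on why this deletion is behavior-preserving: the commit message says the removed methods were duplicates, i.e. the lensing encoder re-defined Lightning hooks that its parent class (imported above as bliss.encoder.encoder.Encoder) already provides verbatim. Assuming that is the case, a minimal self-contained sketch of the inheritance argument — the class names below are hypothetical stand-ins, not the actual BLISS classes:

# Sketch under the stated assumption: the subclass's copies were identical
# to the parent's, so removing them leaves behavior unchanged because
# Python's method resolution order serves the inherited versions.

class BaseEncoder:  # stand-in for bliss.encoder.encoder.Encoder
    def _compute_loss(self, batch, logging_name):
        return 0.0  # placeholder; the real method computes the NLL loss

    def training_step(self, batch, batch_idx, optimizer_idx=0):
        """Training step (pytorch lightning)."""
        return self._compute_loss(batch, "train")


class WeakLensingEncoder(BaseEncoder):  # stand-in for the lensing encoder
    # No override: lookups of training_step resolve to BaseEncoder's
    # definition, exactly as if the duplicate were still defined here.
    pass


# The subclass resolves to the very same function object as the parent.
assert WeakLensingEncoder.training_step is BaseEncoder.training_step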