Commit
Make mypy happy I think
calebrob6 committed Feb 23, 2024
1 parent 3e425db · commit aa7a549
Showing 2 changed files with 20 additions and 8 deletions.
14 changes: 10 additions & 4 deletions docs/tutorials/custom_segmentation_trainer.ipynb
@@ -68,6 +68,8 @@
 "metadata": {},
 "outputs": [],
 "source": [
+"from typing import Any, Union, Sequence\n",
+"from lightning.pytorch.callbacks.callback import Callback\n",
 "import lightning\n",
 "import lightning.pytorch as pl\n",
 "from lightning.pytorch.callbacks import ModelCheckpoint\n",
@@ -134,7 +136,7 @@
 "class CustomSemanticSegmentationTask(SemanticSegmentationTask):\n",
 "\n",
 "    # any keywords we add here between *args and **kwargs will be found in self.hparams\n",
-"    def __init__(self, *args, tmax=50, eta_min=1e-6, **kwargs) -> None:\n",
+"    def __init__(self, *args: Any, tmax: int=50, eta_min: float=1e-6, **kwargs: Any) -> None:\n",
 "        super().__init__(*args, **kwargs)  # pass args and kwargs to the parent class\n",
 "\n",
 "    def configure_optimizers(\n",
@@ -185,7 +187,7 @@
 "        self.val_metrics = self.train_metrics.clone(prefix=\"val_\")\n",
 "        self.test_metrics = self.train_metrics.clone(prefix=\"test_\")\n",
 "\n",
-"    def configure_callbacks(self):\n",
+"    def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:\n",
 "        \"\"\"Initialize callbacks for saving the best and latest models.\n",
 "\n",
 "        Returns:\n",
@@ -198,8 +200,12 @@
 "\n",
 "    def on_train_epoch_start(self) -> None:\n",
 "        \"\"\"Log the learning rate at the start of each training epoch.\"\"\"\n",
-"        lr = self.optimizers().param_groups[0][\"lr\"]\n",
-"        self.logger.experiment.add_scalar(\"lr\", lr, self.current_epoch)"
+"        optimizers = self.optimizers()\n",
+"        if isinstance(optimizers, list):\n",
+"            lr = optimizers[0].param_groups[0][\"lr\"]\n",
+"        else:\n",
+"            lr = optimizers.param_groups[0][\"lr\"]\n",
+"        self.logger.experiment.add_scalar(\"lr\", lr, self.current_epoch) # type: ignore"
 ]
},
{
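Aside: Lightning types the configure_callbacks hook as returning either a single Callback or a sequence of callbacks, which is why the override above now carries the Union[Sequence[Callback], Callback] annotation. Below is a minimal sketch of such an annotated override, assuming a task that registers a single checkpoint callback; the class name and the monitor key are illustrative, not the tutorial's collapsed body.

    from typing import Sequence, Union

    from lightning.pytorch.callbacks import ModelCheckpoint
    from lightning.pytorch.callbacks.callback import Callback
    from torchgeo.trainers import SemanticSegmentationTask


    class AnnotatedTask(SemanticSegmentationTask):  # hypothetical subclass for illustration
        def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
            # Returning a list satisfies the Sequence[Callback] half of the union,
            # and mypy can now check this override against the hook's signature.
            return [ModelCheckpoint(monitor="val_loss", save_last=True)]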
14 changes: 10 additions & 4 deletions docs/tutorials/custom_segmentation_trainer.py
@@ -40,6 +40,8 @@
 # UserWarning: Default grid_sample and affine_grid behavior has changed to align_corners=False since 1.3.0. Please specify align_corners=True if the old behavior is desired. See the documentation of grid_sample for details.
 import warnings
 
+from typing import Any, Union, Sequence
+from lightning.pytorch.callbacks.callback import Callback
 import lightning
 import lightning.pytorch as pl
 from lightning.pytorch.callbacks import ModelCheckpoint
@@ -79,7 +81,7 @@
 class CustomSemanticSegmentationTask(SemanticSegmentationTask):
 
     # any keywords we add here between *args and **kwargs will be found in self.hparams
-    def __init__(self, *args, tmax=50, eta_min=1e-6, **kwargs) -> None:
+    def __init__(self, *args: Any, tmax: int=50, eta_min: float=1e-6, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)  # pass args and kwargs to the parent class
 
     def configure_optimizers(
@@ -130,7 +132,7 @@ def configure_metrics(self) -> None:
         self.val_metrics = self.train_metrics.clone(prefix="val_")
         self.test_metrics = self.train_metrics.clone(prefix="test_")
 
-    def configure_callbacks(self):
+    def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
         """Initialize callbacks for saving the best and latest models.
 
         Returns:
@@ -143,8 +145,12 @@ def configure_callbacks(self):
 
     def on_train_epoch_start(self) -> None:
         """Log the learning rate at the start of each training epoch."""
-        lr = self.optimizers().param_groups[0]["lr"]
-        self.logger.experiment.add_scalar("lr", lr, self.current_epoch)
+        optimizers = self.optimizers()
+        if isinstance(optimizers, list):
+            lr = optimizers[0].param_groups[0]["lr"]
+        else:
+            lr = optimizers.param_groups[0]["lr"]
+        self.logger.experiment.add_scalar("lr", lr, self.current_epoch)  # type: ignore
 
 
 # ## Train model
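Aside: mypy flagged the original one-liner because LightningModule.optimizers() is annotated to return either a single (wrapped) optimizer or a list of optimizers when several are configured, so .param_groups is not accessible until the union is narrowed; the isinstance check above does exactly that. A self-contained sketch of the same pattern, assuming a single plain SGD optimizer (the module and layer below are illustrative, not part of the tutorial):

    import lightning.pytorch as pl
    import torch


    class LRLoggingModule(pl.LightningModule):  # hypothetical module for illustration
        def __init__(self) -> None:
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def configure_optimizers(self) -> torch.optim.Optimizer:
            return torch.optim.SGD(self.parameters(), lr=0.01)

        def on_train_epoch_start(self) -> None:
            optimizers = self.optimizers()
            if isinstance(optimizers, list):  # multiple optimizers configured
                lr = optimizers[0].param_groups[0]["lr"]
            else:  # single optimizer, the case this module hits
                lr = optimizers.param_groups[0]["lr"]
            # self.logger is Optional and `experiment` is logger-specific,
            # hence the trailing ignore on the TensorBoard call in the commit.
            self.logger.experiment.add_scalar("lr", lr, self.current_epoch)  # type: ignore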
