From eb65c25476239c72c81bda635addf3c64a247c8e Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Fri, 23 Feb 2024 20:46:13 +0000 Subject: [PATCH] Make base trainer class work properly --- .../custom_segmentation_trainer.ipynb | 191 +++++------------- docs/tutorials/custom_segmentation_trainer.py | 18 +- torchgeo/trainers/base.py | 4 + 3 files changed, 67 insertions(+), 146 deletions(-) diff --git a/docs/tutorials/custom_segmentation_trainer.ipynb b/docs/tutorials/custom_segmentation_trainer.ipynb index 9732a50ac8e..9226ed92eec 100644 --- a/docs/tutorials/custom_segmentation_trainer.ipynb +++ b/docs/tutorials/custom_segmentation_trainer.ipynb @@ -52,7 +52,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -63,14 +63,13 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "id": "0b01bf43", "metadata": {}, "outputs": [], "source": [ "import lightning\n", "import lightning.pytorch as pl\n", - "import torch\n", "from lightning.pytorch.callbacks import ModelCheckpoint\n", "from torch.optim import AdamW\n", "from torch.optim.lr_scheduler import CosineAnnealingLR\n", @@ -128,7 +127,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -136,8 +135,6 @@ "\n", " # any keywords we add here between *args and **kwargs will be found in self.hparams\n", " def __init__(self, *args, tmax=50, eta_min=1e-6, **kwargs) -> None:\n", - " if \"ignore\" in kwargs:\n", - " del kwargs[\"ignore\"] # this is a hack\n", " super().__init__(*args, **kwargs) # pass args and kwargs to the parent class\n", "\n", " def configure_optimizers(\n", @@ -195,7 +192,7 @@ " List of callbacks to apply.\n", " \"\"\"\n", " return [\n", - " ModelCheckpoint(every_n_epochs=50, save_top_k=-1),\n", + " ModelCheckpoint(every_n_epochs=50, save_top_k=-1, save_last=True),\n", " ModelCheckpoint(monitor=self.monitor, mode=self.mode, save_top_k=5),\n", " ]\n", "\n", @@ -225,7 +222,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -243,7 +240,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -254,7 +251,6 @@ "\"eta_min\": 1e-06\n", "\"freeze_backbone\": False\n", "\"freeze_decoder\": False\n", - "\"ignore\": weights\n", "\"ignore_index\": None\n", "\"in_channels\": 3\n", "\"loss\": ce\n", @@ -266,7 +262,7 @@ "\"tmax\": 50" ] }, - "execution_count": 11, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -278,33 +274,41 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "GPU available: True (cuda), used: True\n" + "GPU available: True (cuda), used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "/home/calebrobinson/.conda/envs/geo/lib/python3.10/site-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. 
You can set it by doing `Trainer(accelerator='gpu')`.\n", + "`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "TPU available: False, using: 0 TPU cores\n", - "IPU available: False, using: 0 IPUs\n", - "HPU available: False, using: 0 HPUs\n" + "`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.\n" ] } ], "source": [ - "trainer = pl.Trainer(min_epochs=150, max_epochs=250, log_every_n_steps=50)" + "# The following Trainer config is useful just for testing the code in this notebook.\n", + "trainer = pl.Trainer(\n", + " limit_train_batches=1, limit_val_batches=1, num_sanity_val_steps=0, max_epochs=1\n", + ")\n", + "# You can use the following for actual training runs.\n", + "# trainer = pl.Trainer(min_epochs=150, max_epochs=250, log_every_n_steps=50)" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -314,85 +318,10 @@ "The following callbacks returned in `LightningModule.configure_callbacks` will override existing callbacks passed to Trainer: ModelCheckpoint\n" ] }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Downloading https://landcover.ai.linuxpolska.com/download/landcover.ai.v1.zip to data/landcover.ai.v1.zip\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 1538212277/1538212277 [01:25<00:00, 17913845.14it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Processed M-33-20-D-c-4-2 1/41\n", - "Processed M-33-20-D-d-3-3 2/41\n", - "Processed M-33-32-B-b-4-4 3/41\n", - "Processed M-33-48-A-c-4-4 4/41\n", - "Processed M-33-7-A-d-2-3 5/41\n", - "Processed M-33-7-A-d-3-2 6/41\n", - "Processed M-34-32-B-a-4-3 7/41\n", - "Processed M-34-32-B-b-1-3 8/41\n", - "Processed M-34-5-D-d-4-2 9/41\n", - "Processed M-34-51-C-b-2-1 10/41\n", - "Processed M-34-51-C-d-4-1 11/41\n", - "Processed M-34-55-B-b-4-1 12/41\n", - "Processed M-34-56-A-b-1-4 13/41\n", - "Processed M-34-6-A-d-2-2 14/41\n", - "Processed M-34-65-D-a-4-4 15/41\n", - "Processed M-34-65-D-c-4-2 16/41\n", - "Processed M-34-65-D-d-4-1 17/41\n", - "Processed M-34-68-B-a-1-3 18/41\n", - "Processed M-34-77-B-c-2-3 19/41\n", - "Processed N-33-104-A-c-1-1 20/41\n", - "Processed N-33-119-C-c-3-3 21/41\n", - "Processed N-33-130-A-d-3-3 22/41\n", - "Processed N-33-130-A-d-4-4 23/41\n", - "Processed N-33-139-C-d-2-2 24/41\n", - "Processed N-33-139-C-d-2-4 25/41\n", - "Processed N-33-139-D-c-1-3 26/41\n", - "Processed N-33-60-D-c-4-2 27/41\n", - "Processed N-33-60-D-d-1-2 28/41\n", - "Processed N-33-96-D-d-1-1 29/41\n", - "Processed N-34-106-A-b-3-4 30/41\n", - "Processed N-34-106-A-c-1-3 31/41\n", - "Processed N-34-140-A-b-3-2 32/41\n", - "Processed N-34-140-A-b-4-2 33/41\n", - "Processed N-34-140-A-d-3-4 34/41\n", - "Processed N-34-140-A-d-4-2 35/41\n", - "Processed N-34-61-B-a-1-1 36/41\n", - "Processed N-34-66-C-c-4-3 37/41\n", - "Processed N-34-77-A-b-1-4 38/41\n", - "Processed N-34-94-A-b-2-4 39/41\n", - "Processed N-34-97-C-b-1-2 40/41\n" - ] - }, { "name": "stderr", "output_type": "stream", "text": [ - "You are using a CUDA device ('NVIDIA A100 80GB PCIe') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. 
For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Processed N-34-97-D-c-2-4 41/41\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]\n", "\n", " | Name | Type | Params\n", "---------------------------------------------------\n", @@ -405,27 +334,14 @@ "32.5 M Trainable params\n", "0 Non-trainable params\n", "32.5 M Total params\n", - "130.087 Total estimated model params size (MB)\n" + "130.087 Total estimated model params size (MB)\n", + "/home/calebrobinson/.conda/envs/geo/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "00af14780e004ce69b89e436b7b95606", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Sanity Checking: | | 0/? [00:00┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "┃ Test metric DataLoader 0 ┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│ test_MeanIoU 0.3726549744606018 │\n", - "│ test_OverallAccuracy 0.8094286322593689 │\n", - "│ test_OverallF1Score 0.8094285726547241 │\n", - "│ test_OverallPrecision 0.8094286322593689 │\n", - "│ test_OverallRecall 0.8094286322593689 │\n", - "│ test_loss 0.4797952175140381 │\n", + "│ test_MeanIoU 0.012266275472939014 │\n", + "│ test_OverallAccuracy 0.038088466972112656 │\n", + "│ test_OverallF1Score 0.038088466972112656 │\n", + "│ test_OverallPrecision 0.038088466972112656 │\n", + "│ test_OverallRecall 0.038088466972112656 │\n", + "│ test_loss 1.8426358699798584 │\n", "└───────────────────────────┴───────────────────────────┘\n", "\n" ], @@ -529,12 +444,12 @@ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│\u001b[36m \u001b[0m\u001b[36m test_MeanIoU \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.3726549744606018 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_OverallAccuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.8094286322593689 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_OverallF1Score \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.8094285726547241 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_OverallPrecision \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.8094286322593689 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_OverallRecall \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.8094286322593689 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.4797952175140381 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_MeanIoU \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.012266275472939014 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_OverallAccuracy \u001b[0m\u001b[36m 
\u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.038088466972112656 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_OverallF1Score \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.038088466972112656 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_OverallPrecision \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.038088466972112656 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_OverallRecall \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.038088466972112656 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.8426358699798584 \u001b[0m\u001b[35m \u001b[0m│\n", "└───────────────────────────┴───────────────────────────┘\n" ] }, @@ -544,15 +459,15 @@ { "data": { "text/plain": [ - "[{'test_loss': 0.4797952175140381,\n", - " 'test_MeanIoU': 0.3726549744606018,\n", - " 'test_OverallAccuracy': 0.8094286322593689,\n", - " 'test_OverallF1Score': 0.8094285726547241,\n", - " 'test_OverallPrecision': 0.8094286322593689,\n", - " 'test_OverallRecall': 0.8094286322593689}]" + "[{'test_loss': 1.8426358699798584,\n", + " 'test_MeanIoU': 0.012266275472939014,\n", + " 'test_OverallAccuracy': 0.038088466972112656,\n", + " 'test_OverallF1Score': 0.038088466972112656,\n", + " 'test_OverallPrecision': 0.038088466972112656,\n", + " 'test_OverallRecall': 0.038088466972112656}]" ] }, - "execution_count": 12, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } diff --git a/docs/tutorials/custom_segmentation_trainer.py b/docs/tutorials/custom_segmentation_trainer.py index 48ed6c7b56a..050f07a1fb0 100644 --- a/docs/tutorials/custom_segmentation_trainer.py +++ b/docs/tutorials/custom_segmentation_trainer.py @@ -42,7 +42,6 @@ import lightning import lightning.pytorch as pl -import torch from lightning.pytorch.callbacks import ModelCheckpoint from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingLR @@ -81,8 +80,6 @@ class CustomSemanticSegmentationTask(SemanticSegmentationTask): # any keywords we add here between *args and **kwargs will be found in self.hparams def __init__(self, *args, tmax=50, eta_min=1e-6, **kwargs) -> None: - if "ignore" in kwargs: - del kwargs["ignore"] # this is a hack super().__init__(*args, **kwargs) # pass args and kwargs to the parent class def configure_optimizers( @@ -140,7 +137,7 @@ def configure_callbacks(self): List of callbacks to apply. """ return [ - ModelCheckpoint(every_n_epochs=50, save_top_k=-1), + ModelCheckpoint(every_n_epochs=50, save_top_k=-1, save_last=True), ModelCheckpoint(monitor=self.monitor, mode=self.mode, save_top_k=5), ] @@ -170,17 +167,22 @@ def on_train_epoch_start(self) -> None: # validate that the task's hyperparameters are as expected task.hparams -trainer = pl.Trainer(min_epochs=150, max_epochs=250, log_every_n_steps=50) +# The following Trainer config is useful just for testing the code in this notebook. +trainer = pl.Trainer( + limit_train_batches=1, limit_val_batches=1, num_sanity_val_steps=0, max_epochs=1 +) +# You can use the following for actual training runs. +# trainer = pl.Trainer(min_epochs=150, max_epochs=250, log_every_n_steps=50) trainer.fit(task, dm) # ## Test model # -# Finally, we test the model on the test set and visualize the results. +# Finally, we test the model (optionally loading from a previously saved checkpoint). 
-# If you are starting from a checkpoint, run this cell
+# You can load directly from a saved checkpoint with `.load_from_checkpoint(...)`
 task = CustomSemanticSegmentationTask.load_from_checkpoint(
-    "lightning_logs/version_3/checkpoints/epoch=0-step=117.ckpt"
+    "lightning_logs/version_0/checkpoints/epoch=0-step=1.ckpt"
 )
 
 trainer.test(task, dm)
diff --git a/torchgeo/trainers/base.py b/torchgeo/trainers/base.py
index 3a44c047a31..d9677e390a6 100644
--- a/torchgeo/trainers/base.py
+++ b/torchgeo/trainers/base.py
@@ -35,6 +35,10 @@ def __init__(self, ignore: Optional[Union[Sequence[str], str]] = None) -> None:
         ignore: Arguments to skip when saving hyperparameters.
     """
     super().__init__()
+    if isinstance(ignore, str):
+        ignore = [ignore, "ignore"]  # also skip the `ignore` argument itself
+    else:
+        ignore = list(ignore or []) + ["ignore"]  # `or []` handles the default ignore=None
     self.save_hyperparameters(ignore=ignore)
     self.configure_losses()
     self.configure_metrics()
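
A standalone sketch of the `ignore` normalization this commit adds to
`BaseTask.__init__`, including the `or []` guard above for the default
`ignore=None`. The helper name `normalize_ignore` is hypothetical, used here
only to illustrate the branching; it is not part of torchgeo:

    from collections.abc import Sequence
    from typing import Optional, Union

    def normalize_ignore(ignore: Optional[Union[Sequence[str], str]]) -> list[str]:
        # Mirror the logic added to BaseTask.__init__: always append "ignore"
        # so the argument itself is never recorded by save_hyperparameters().
        if isinstance(ignore, str):
            return [ignore, "ignore"]
        return list(ignore or []) + ["ignore"]

    assert normalize_ignore(None) == ["ignore"]
    assert normalize_ignore("weights") == ["weights", "ignore"]
    assert normalize_ignore(["weights", "model"]) == ["weights", "model", "ignore"]

Because the base class now filters its own `ignore` argument, subclasses such
as the tutorial's CustomSemanticSegmentationTask no longer need the
`del kwargs["ignore"]` hack removed elsewhere in this commit, and the spurious
`"ignore": weights` entry disappears from `task.hparams` in the notebook
output above.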