Skip to content

Commit

Permalink
Make base trainer class work properly
Browse files Browse the repository at this point in the history
  • Loading branch information
calebrob6 committed Feb 23, 2024
1 parent 176d22c commit eb65c25
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 146 deletions.
191 changes: 53 additions & 138 deletions docs/tutorials/custom_segmentation_trainer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -63,14 +63,13 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"id": "0b01bf43",
"metadata": {},
"outputs": [],
"source": [
"import lightning\n",
"import lightning.pytorch as pl\n",
"import torch\n",
"from lightning.pytorch.callbacks import ModelCheckpoint\n",
"from torch.optim import AdamW\n",
"from torch.optim.lr_scheduler import CosineAnnealingLR\n",
Expand Down Expand Up @@ -128,16 +127,14 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class CustomSemanticSegmentationTask(SemanticSegmentationTask):\n",
"\n",
" # any keywords we add here between *args and **kwargs will be found in self.hparams\n",
" def __init__(self, *args, tmax=50, eta_min=1e-6, **kwargs) -> None:\n",
" if \"ignore\" in kwargs:\n",
" del kwargs[\"ignore\"] # this is a hack\n",
" super().__init__(*args, **kwargs) # pass args and kwargs to the parent class\n",
"\n",
" def configure_optimizers(\n",
Expand Down Expand Up @@ -195,7 +192,7 @@
" List of callbacks to apply.\n",
" \"\"\"\n",
" return [\n",
" ModelCheckpoint(every_n_epochs=50, save_top_k=-1),\n",
" ModelCheckpoint(every_n_epochs=50, save_top_k=-1, save_last=True),\n",
" ModelCheckpoint(monitor=self.monitor, mode=self.mode, save_top_k=5),\n",
" ]\n",
"\n",
Expand Down Expand Up @@ -225,7 +222,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -243,7 +240,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 9,
"metadata": {},
"outputs": [
{
Expand All @@ -254,7 +251,6 @@
"\"eta_min\": 1e-06\n",
"\"freeze_backbone\": False\n",
"\"freeze_decoder\": False\n",
"\"ignore\": weights\n",
"\"ignore_index\": None\n",
"\"in_channels\": 3\n",
"\"loss\": ce\n",
Expand All @@ -266,7 +262,7 @@
"\"tmax\": 50"
]
},
"execution_count": 11,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -278,33 +274,41 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: True (cuda), used: True\n"
"GPU available: True (cuda), used: False\n",
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"HPU available: False, using: 0 HPUs\n",
"/home/calebrobinson/.conda/envs/geo/lib/python3.10/site-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.\n",
"`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"HPU available: False, using: 0 HPUs\n"
"`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.\n"
]
}
],
"source": [
"trainer = pl.Trainer(min_epochs=150, max_epochs=250, log_every_n_steps=50)"
"# The following Trainer config is useful just for testing the code in this notebook.\n",
"trainer = pl.Trainer(\n",
" limit_train_batches=1, limit_val_batches=1, num_sanity_val_steps=0, max_epochs=1\n",
")\n",
"# You can use the following for actual training runs.\n",
"# trainer = pl.Trainer(min_epochs=150, max_epochs=250, log_every_n_steps=50)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 11,
"metadata": {},
"outputs": [
{
Expand All @@ -314,85 +318,10 @@
"The following callbacks returned in `LightningModule.configure_callbacks` will override existing callbacks passed to Trainer: ModelCheckpoint\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading https://landcover.ai.linuxpolska.com/download/landcover.ai.v1.zip to data/landcover.ai.v1.zip\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1538212277/1538212277 [01:25<00:00, 17913845.14it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processed M-33-20-D-c-4-2 1/41\n",
"Processed M-33-20-D-d-3-3 2/41\n",
"Processed M-33-32-B-b-4-4 3/41\n",
"Processed M-33-48-A-c-4-4 4/41\n",
"Processed M-33-7-A-d-2-3 5/41\n",
"Processed M-33-7-A-d-3-2 6/41\n",
"Processed M-34-32-B-a-4-3 7/41\n",
"Processed M-34-32-B-b-1-3 8/41\n",
"Processed M-34-5-D-d-4-2 9/41\n",
"Processed M-34-51-C-b-2-1 10/41\n",
"Processed M-34-51-C-d-4-1 11/41\n",
"Processed M-34-55-B-b-4-1 12/41\n",
"Processed M-34-56-A-b-1-4 13/41\n",
"Processed M-34-6-A-d-2-2 14/41\n",
"Processed M-34-65-D-a-4-4 15/41\n",
"Processed M-34-65-D-c-4-2 16/41\n",
"Processed M-34-65-D-d-4-1 17/41\n",
"Processed M-34-68-B-a-1-3 18/41\n",
"Processed M-34-77-B-c-2-3 19/41\n",
"Processed N-33-104-A-c-1-1 20/41\n",
"Processed N-33-119-C-c-3-3 21/41\n",
"Processed N-33-130-A-d-3-3 22/41\n",
"Processed N-33-130-A-d-4-4 23/41\n",
"Processed N-33-139-C-d-2-2 24/41\n",
"Processed N-33-139-C-d-2-4 25/41\n",
"Processed N-33-139-D-c-1-3 26/41\n",
"Processed N-33-60-D-c-4-2 27/41\n",
"Processed N-33-60-D-d-1-2 28/41\n",
"Processed N-33-96-D-d-1-1 29/41\n",
"Processed N-34-106-A-b-3-4 30/41\n",
"Processed N-34-106-A-c-1-3 31/41\n",
"Processed N-34-140-A-b-3-2 32/41\n",
"Processed N-34-140-A-b-4-2 33/41\n",
"Processed N-34-140-A-d-3-4 34/41\n",
"Processed N-34-140-A-d-4-2 35/41\n",
"Processed N-34-61-B-a-1-1 36/41\n",
"Processed N-34-66-C-c-4-3 37/41\n",
"Processed N-34-77-A-b-1-4 38/41\n",
"Processed N-34-94-A-b-2-4 39/41\n",
"Processed N-34-97-C-b-1-2 40/41\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"You are using a CUDA device ('NVIDIA A100 80GB PCIe') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processed N-34-97-D-c-2-4 41/41\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]\n",
"\n",
" | Name | Type | Params\n",
"---------------------------------------------------\n",
Expand All @@ -405,27 +334,14 @@
"32.5 M Trainable params\n",
"0 Non-trainable params\n",
"32.5 M Total params\n",
"130.087 Total estimated model params size (MB)\n"
"130.087 Total estimated model params size (MB)\n",
"/home/calebrobinson/.conda/envs/geo/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "00af14780e004ce69b89e436b7b95606",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Sanity Checking: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d1439daa182f4d27b5194609bd194d56",
"model_id": "0834b6894deb4a7a9edc9d90a8d9ed31",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -439,7 +355,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9b547aad511247b1b0e88f79fbfa7e38",
"model_id": "24564c905f3d432eb9ba6171f03c2273",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -454,7 +370,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/home/calebrobinson/.conda/envs/geo/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...\n"
"`Trainer.fit` stopped: `max_epochs=1` reached.\n"
]
}
],
Expand All @@ -468,38 +384,37 @@
"source": [
"## Test model\n",
"\n",
"Finally, we test the model on the test set and visualize the results."
"Finally, we test the model (optionally loading from a previously saved checkpoint)."
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# If you are starting from a checkpoint, run this cell\n",
"# You can load directly from a saved checkpoint with `.load_from_checkpoint(...)`\n",
"task = CustomSemanticSegmentationTask.load_from_checkpoint(\n",
" \"lightning_logs/version_3/checkpoints/epoch=0-step=117.ckpt\"\n",
" \"lightning_logs/version_0/checkpoints/epoch=0-step=1.ckpt\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The following callbacks returned in `LightningModule.configure_callbacks` will override existing callbacks passed to Trainer: ModelCheckpoint\n",
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]\n"
"The following callbacks returned in `LightningModule.configure_callbacks` will override existing callbacks passed to Trainer: ModelCheckpoint\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b4567eeff83f488891b6911480ea381f",
"model_id": "5dc29f3a1d4a449ab1ef94c58c0937ba",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -516,25 +431,25 @@
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Test metric </span>┃<span style=\"font-weight: bold\"> DataLoader 0 </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│<span style=\"color: #008080; text-decoration-color: #008080\"> test_MeanIoU </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.3726549744606018 </span>│\n",
"│<span style=\"color: #008080; text-decoration-color: #008080\"> test_OverallAccuracy </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.8094286322593689 </span>│\n",
"│<span style=\"color: #008080; text-decoration-color: #008080\"> test_OverallF1Score </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.8094285726547241 </span>│\n",
"│<span style=\"color: #008080; text-decoration-color: #008080\"> test_OverallPrecision </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.8094286322593689 </span>│\n",
"│<span style=\"color: #008080; text-decoration-color: #008080\"> test_OverallRecall </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.8094286322593689 </span>│\n",
"│<span style=\"color: #008080; text-decoration-color: #008080\"> test_loss </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.4797952175140381 </span>│\n",
"│<span style=\"color: #008080; text-decoration-color: #008080\"> test_MeanIoU </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.012266275472939014 </span>│\n",
"│<span style=\"color: #008080; text-decoration-color: #008080\"> test_OverallAccuracy </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.038088466972112656 </span>│\n",
"│<span style=\"color: #008080; text-decoration-color: #008080\"> test_OverallF1Score </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.038088466972112656 </span>│\n",
"│<span style=\"color: #008080; text-decoration-color: #008080\"> test_OverallPrecision </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.038088466972112656 </span>│\n",
"│<span style=\"color: #008080; text-decoration-color: #008080\"> test_OverallRecall </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.038088466972112656 </span>│\n",
"│<span style=\"color: #008080; text-decoration-color: #008080\"> test_loss </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 1.8426358699798584 </span>│\n",
"└───────────────────────────┴───────────────────────────┘\n",
"</pre>\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"\u001b[36m \u001b[0m\u001b[36m test_MeanIoU \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.3726549744606018 \u001b[0m\u001b[35m \u001b[0m│\n",
"\u001b[36m \u001b[0m\u001b[36m test_OverallAccuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.8094286322593689 \u001b[0m\u001b[35m \u001b[0m│\n",
"\u001b[36m \u001b[0m\u001b[36m test_OverallF1Score \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.8094285726547241 \u001b[0m\u001b[35m \u001b[0m│\n",
"\u001b[36m \u001b[0m\u001b[36m test_OverallPrecision \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.8094286322593689 \u001b[0m\u001b[35m \u001b[0m│\n",
"\u001b[36m \u001b[0m\u001b[36m test_OverallRecall \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.8094286322593689 \u001b[0m\u001b[35m \u001b[0m│\n",
"\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.4797952175140381 \u001b[0m\u001b[35m \u001b[0m│\n",
"\u001b[36m \u001b[0m\u001b[36m test_MeanIoU \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.012266275472939014 \u001b[0m\u001b[35m \u001b[0m│\n",
"\u001b[36m \u001b[0m\u001b[36m test_OverallAccuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.038088466972112656 \u001b[0m\u001b[35m \u001b[0m│\n",
"\u001b[36m \u001b[0m\u001b[36m test_OverallF1Score \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.038088466972112656 \u001b[0m\u001b[35m \u001b[0m│\n",
"\u001b[36m \u001b[0m\u001b[36m test_OverallPrecision \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.038088466972112656 \u001b[0m\u001b[35m \u001b[0m│\n",
"\u001b[36m \u001b[0m\u001b[36m test_OverallRecall \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.038088466972112656 \u001b[0m\u001b[35m \u001b[0m│\n",
"\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.8426358699798584 \u001b[0m\u001b[35m \u001b[0m│\n",
"└───────────────────────────┴───────────────────────────┘\n"
]
},
Expand All @@ -544,15 +459,15 @@
{
"data": {
"text/plain": [
"[{'test_loss': 0.4797952175140381,\n",
" 'test_MeanIoU': 0.3726549744606018,\n",
" 'test_OverallAccuracy': 0.8094286322593689,\n",
" 'test_OverallF1Score': 0.8094285726547241,\n",
" 'test_OverallPrecision': 0.8094286322593689,\n",
" 'test_OverallRecall': 0.8094286322593689}]"
"[{'test_loss': 1.8426358699798584,\n",
" 'test_MeanIoU': 0.012266275472939014,\n",
" 'test_OverallAccuracy': 0.038088466972112656,\n",
" 'test_OverallF1Score': 0.038088466972112656,\n",
" 'test_OverallPrecision': 0.038088466972112656,\n",
" 'test_OverallRecall': 0.038088466972112656}]"
]
},
"execution_count": 12,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
Expand Down
Loading

0 comments on commit eb65c25

Please sign in to comment.