Skip to content

Commit

Permalink
Torchgeo Notebook Examples (#171)
Browse files Browse the repository at this point in the history
  • Loading branch information
Pale-Blue-Dot-97 authored May 12, 2023
1 parent e48d1ca commit 5fee2f5
Show file tree
Hide file tree
Showing 12 changed files with 2,356 additions and 1,385 deletions.
15 changes: 11 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@

## About 🔎

Minerva is a package to aid in the building, fitting and testing of neural network models on geo-spatial
rasterised land cover data.
Minerva is a package to aid in the building, fitting and testing of neural network models on multi-spectral geo-spatial data.

## Getting Started ▶

Expand Down Expand Up @@ -122,11 +121,19 @@ Contributions also provided by:
- [Isabel Sargent](https://github.com/PenguinJunk)
- [Steve Coupland](https://github.com/scoupland-os)
- [Joe Guyatt](https://github.com/joeguyatt97)
- [Ben Dickens](https://github.com/BenDickens)
- [Kitty Varghese](https://github.com/kittyvarghese)

## Acknowledgments 📢

I'd like to acknowledge the invaluable supervision and contributions of Prof Jonathon Hare and
Dr Isabel Sargent towards this work.
I'd like to acknowledge the invaluable supervision and contributions of [Prof Jonathon Hare](https://github.com/jonhare) and
[Dr Isabel Sargent](https://github.com/PenguinJunk) towards this work.

The following modules are adapted from open-source third parties:
| Module | Original Author | License | Link |
|:-------|:----------------|:--------|:-----|
| `pytorchtools` | [Bjarte Mehus Sunde](https://github.com/Bjarten) | MIT | https://github.com/Bjarten/early-stopping-pytorch |
| `optimisers` | [Noah Golmant](https://github.com/noahgolmant) | MIT | https://github.com/noahgolmant/pytorch-lars/blob/master/lars.py |

This repository also contains two small ``.tiff`` excerpts from the [ChesapeakeCVPR](https://lila.science/datasets/chesapeakelandcover) dataset used for unit testing purposes. Credit for this data goes to:

Expand Down
28 changes: 14 additions & 14 deletions minerva/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,8 @@ def __init__(

self.model = model_cls(*args, **kwargs)

def __call__(self, *input) -> Any:
return self.forward(*input)
def __call__(self, *inputs) -> Any:
return self.forward(*inputs)

def __getattr__(self, name):
try:
Expand All @@ -246,8 +246,8 @@ def __getattr__(self, name):
def __repr__(self) -> Any:
return self.model.__repr__()

def forward(self, *input) -> Any:
return self.model.forward(*input)
def forward(self, *inputs) -> Any:
return self.model.forward(*inputs)


class MinervaBackbone(MinervaModel):
Expand Down Expand Up @@ -297,22 +297,22 @@ def __init__(
super(MinervaDataParallel, self).__init__()
self.model = paralleliser(model, *args, **kwargs).cuda()

def forward(self, *input: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]:
def forward(self, *inputs: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]:
"""Ensures a forward call to the model goes to the actual wrapped model.
Args:
input (tuple[~torch.Tensor, ...]): Input of tensors to be parsed to the
inputs (tuple[~torch.Tensor, ...]): Input of tensors to be parsed to the
:attr:`~MinervaDataParallel.model` forward.
Returns:
tuple[~torch.Tensor, ...]: Output of :attr:`~MinervaDataParallel.model`.
"""
z = self.model(*input)
z = self.model(*inputs)
assert isinstance(z, tuple) and list(map(type, z)) == [Tensor] * len(z)
return z

def __call__(self, *input) -> Tuple[Tensor, ...]:
return self.forward(*input)
def __call__(self, *inputs) -> Tuple[Tensor, ...]:
return self.forward(*inputs)

def __getattr__(self, name):
try:
Expand All @@ -339,8 +339,8 @@ def __init__(self, model: Module, *args, **kwargs) -> None:

self.model = model

def __call__(self, *input) -> Any:
return self.model.forward(*input)
def __call__(self, *inputs) -> Any:
return self.model.forward(*inputs)

def __getattr__(self, name) -> Any:
try:
Expand All @@ -351,16 +351,16 @@ def __getattr__(self, name) -> Any:
def __repr__(self) -> Any:
return self.model.__repr__()

def forward(self, *input: Any) -> Any:
def forward(self, *inputs: Any) -> Any:
"""Performs a forward pass of the :attr:`~MinervaOnnxModel.model` within.
Args:
input (~typing.Any): Input to be parsed to the ``.forward`` method of :attr:`~MinervaOnnxModel.model`.
inputs (~typing.Any): Input to be parsed to the ``.forward`` method of :attr:`~MinervaOnnxModel.model`.
Returns:
~typing.Any: Output of :attr:`~MinervaOnnxModel.model`.
"""
return self.model.forward(*input)
return self.model.forward(*inputs)


# =====================================================================================================================
Expand Down
10 changes: 5 additions & 5 deletions minerva/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1267,13 +1267,13 @@ def save_model_weights(self, fn: Optional[str] = None) -> None:
torch.save(model.state_dict(), f"{fn}.pt")

def save_model(
self, fn: Optional[Union[Path, str]] = None, format: str = "pt"
self, fn: Optional[Union[Path, str]] = None, fmt: str = "pt"
) -> None:
"""Saves the model object itself to :mod:`torch` file.
Args:
fn (~pathlib.Path | str): Optional; Filename and path (excluding extension) to save model to.
format (str): Optional; Format to save model to. ``pt`` for :mod:`torch`, or :mod:`onnx` for ONNX.
fmt (str): Optional; Format to save model to. ``pt`` for :mod:`torch`, or :mod:`onnx` for ONNX.
Raises:
ValueError: If format is not recognised.
Expand All @@ -1283,13 +1283,13 @@ def save_model(
if fn is None:
fn = str(self.exp_fn)

if format == "pt":
if fmt == "pt":
torch.save(model, f"{fn}.pt")
elif format == "onnx":
elif fmt == "onnx":
x = torch.rand(*self.get_input_size(), device=self.device)
torch.onnx.export(model, (x,), f"{fn}.onnx")
else:
raise ValueError(f"format {format} unrecognised!")
raise ValueError(f"format {fmt} unrecognised!")

def save_backbone(self) -> None:
"""Readies the model for use in downstream tasks and saves to file."""
Expand Down
194 changes: 194 additions & 0 deletions notebooks/Torchgeo_FCN_Demo.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import tempfile\n",
"from pathlib import Path\n",
"\n",
"from torch.utils.data import DataLoader\n",
"from torchvision.models.segmentation import fcn_resnet50\n",
"import torch.nn as nn\n",
"from torchgeo.datasets import NAIP, ChesapeakeDE, stack_samples\n",
"from torchgeo.datasets.utils import download_url\n",
"from torchgeo.samplers import RandomGeoSampler\n",
"from torch.nn import CrossEntropyLoss\n",
"from torch.optim import Adam\n",
"import torch\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from minerva.models import FCN8ResNet18\n",
"from minerva.utils.utils import get_cuda_device\n",
"\n",
"device = get_cuda_device(0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data_root = tempfile.gettempdir()\n",
"train_root = Path(data_root, \"naip\", \"train\")\n",
"test_root = Path(data_root, \"naip\", \"test\")\n",
"naip_url = \"https://naipeuwest.blob.core.windows.net/naip/v002/de/2018/de_060cm_2018/38075/\"\n",
"tiles = [\n",
" \"m_3807511_ne_18_060_20181104.tif\",\n",
" \"m_3807511_se_18_060_20181104.tif\",\n",
" \"m_3807512_nw_18_060_20180815.tif\",\n",
"]\n",
"\n",
"for tile in tiles:\n",
" download_url(naip_url + tile, train_root)\n",
"\n",
"download_url(naip_url + \"m_3807512_sw_18_060_20180815.tif\", test_root)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_naip = NAIP(train_root)\n",
"test_naip = NAIP(test_root)\n",
"\n",
"chesapeake_root = os.path.join(data_root, \"chesapeake\")\n",
"\n",
"chesapeake = ChesapeakeDE(chesapeake_root, crs=train_naip.crs, res=train_naip.res, download=True)\n",
"\n",
"train_dataset = train_naip & chesapeake\n",
"test_dataset = test_naip & chesapeake\n",
"\n",
"sampler = RandomGeoSampler(train_naip, size=256, length=200)\n",
"dataloader = DataLoader(train_dataset, sampler=sampler, collate_fn=stack_samples, batch_size=32)\n",
"\n",
"testsampler = RandomGeoSampler(test_naip, size=256, length=8)\n",
"testdataloader = DataLoader(test_dataset, sampler=testsampler, collate_fn=stack_samples, batch_size=8, num_workers=4)\n",
"testdata = list(testdataloader)[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"crit = CrossEntropyLoss()\n",
"\n",
"# Criterions are normally parsed to models at init in minerva.\n",
"fcn = FCN8ResNet18(crit, input_size=(4, 256, 256), n_classes=13).to(device)\n",
"opt = Adam(fcn.parameters(), lr=1e-3)\n",
"\n",
"# Optimisers need to be set to a model in minerva before training.\n",
"fcn.set_optimiser(opt)\n",
"\n",
"for epoch in range(101):\n",
" losses = []\n",
" for i, sample in enumerate(dataloader):\n",
" image = sample[\"image\"].to(device).float() / 255.0\n",
" target = sample[\"mask\"].to(device).long().squeeze(1)\n",
" \n",
" # Uses MinervaModel.step.\n",
" loss, pred = fcn.step(image, target, train=True)\n",
" losses.append(loss.item())\n",
"\n",
" print(epoch, np.mean(losses))\n",
" if epoch % 10 == 0:\n",
" with torch.no_grad():\n",
" image = testdata[\"image\"].to(device).float() / 255.0\n",
" target = testdata[\"mask\"].to(device).long().squeeze(1)\n",
" pred = fcn(image)\n",
"\n",
" fig, axs = plt.subplots(3, pred.shape[0], figsize=(10,4))\n",
" for i in range(pred.shape[0]):\n",
" axs[0,i].imshow(image[i].cpu().numpy()[:3].transpose(1,2,0))\n",
" axs[1,i].imshow(target[i].cpu().numpy(), cmap=\"Set3\", vmin=0, vmax=12)\n",
" axs[2,i].imshow(pred[i].detach().argmax(dim=0).cpu().numpy(), cmap=\"Set3\", vmin=0, vmax=12)\n",
" plt.setp(plt.gcf().get_axes(), xticks=[], yticks=[])\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fcn = fcn_resnet50(num_classes=13).to(device)\n",
"fcn.backbone.conv1 = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False).to(device)\n",
"\n",
"crit = CrossEntropyLoss()\n",
"opt = Adam(fcn.parameters(), lr=1e-3)\n",
"\n",
"for epoch in range(101):\n",
" losses = []\n",
" for i, sample in enumerate(dataloader):\n",
" image = sample[\"image\"].to(device).float() / 255.0\n",
" target = sample[\"mask\"].to(device).long().squeeze(1)\n",
"\n",
" opt.zero_grad()\n",
" pred = fcn(image)[\"out\"]\n",
" loss = crit(pred, target)\n",
" loss.backward()\n",
" opt.step()\n",
" losses.append(loss.item())\n",
"\n",
" print(epoch, np.mean(losses))\n",
" if epoch % 10 == 0:\n",
" with torch.no_grad():\n",
" image = testdata[\"image\"].to(device).float() / 255.0\n",
" target = testdata[\"mask\"].to(device).long().squeeze(1)\n",
" pred = fcn(image)[\"out\"]\n",
"\n",
" fig, axs = plt.subplots(3, pred.shape[0], figsize=(10,4))\n",
" for i in range(pred.shape[0]):\n",
" axs[0,i].imshow(image[i].cpu().numpy()[:3].transpose(1,2,0))\n",
" axs[1,i].imshow(target[i].cpu().numpy(), cmap=\"Set3\", vmin=0, vmax=12)\n",
" axs[2,i].imshow(pred[i].detach().argmax(dim=0).cpu().numpy(), cmap=\"Set3\", vmin=0, vmax=12)\n",
" plt.setp(plt.gcf().get_axes(), xticks=[], yticks=[])\n",
" plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "minerva-310",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9 (main, Jan 11 2023, 15:21:40) [GCC 11.2.0]"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "3564bae54b830248e5fcf548a4e349b732e585ece6f047dc1ae97c29756580ff"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit 5fee2f5

Please sign in to comment.