Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SatlasPretrain: ResNet50/152 and Swin_V2_T Weights #2038

Merged
merged 12 commits into from
Aug 27, 2024
3 changes: 2 additions & 1 deletion docs/api/landsat_pretrained_weights.csv
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,5 @@ ResNet50_Weights.LANDSAT_OLI_SR_MOCO,8--9,7,`link <https://github.com/microsoft/
ResNet50_Weights.LANDSAT_OLI_SR_SIMCLR,8--9,7,`link <https://github.com/microsoft/torchgeo>`__,`link <https://proceedings.neurips.cc/paper_files/paper/2023/hash/bbf7ee04e2aefec136ecf60e346c2e61-Abstract-Datasets_and_Benchmarks.html>`__,"CC0-1.0",63.65,46.68,60.01,43.17
ViTSmall16_Weights.LANDSAT_OLI_SR_MOCO,8--9,7,`link <https://github.com/microsoft/torchgeo>`__,`link <https://proceedings.neurips.cc/paper_files/paper/2023/hash/bbf7ee04e2aefec136ecf60e346c2e61-Abstract-Datasets_and_Benchmarks.html>`__,"CC0-1.0",66.81,50.16,64.17,47.24
ViTSmall16_Weights.LANDSAT_OLI_SR_SIMCLR,8--9,7,`link <https://github.com/microsoft/torchgeo>`__,`link <https://proceedings.neurips.cc/paper_files/paper/2023/hash/bbf7ee04e2aefec136ecf60e346c2e61-Abstract-Datasets_and_Benchmarks.html>`__,"CC0-1.0",65.04,48.20,62.61,45.46
Swin_V2_B_Weights.LANDSAT_MS_SI_SATLAS,11,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,"ODC-BY",,,,
Swin_V2_B_Weights.LANDSAT_SI_SATLAS,8--9,11,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,Apache-2.0,,,,
Swin_V2_B_Weights.LANDSAT_MI_SATLAS,8--9,11,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,Apache-2.0,,,,
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
4 changes: 4 additions & 0 deletions docs/api/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@ ResNet

.. autofunction:: resnet18
.. autofunction:: resnet50
.. autofunction:: resnet152
.. autoclass:: ResNet18_Weights
.. autoclass:: ResNet50_Weights
.. autoclass:: ResNet152_Weights

Scale-MAE
^^^^^^^^^
Expand All @@ -59,7 +61,9 @@ Scale-MAE
Swin Transformer
^^^^^^^^^^^^^^^^^^

.. autofunction:: swin_v2_t
.. autofunction:: swin_v2_b
.. autoclass:: Swin_V2_T_Weights
.. autoclass:: Swin_V2_B_Weights

Vision Transformer
Expand Down
3 changes: 2 additions & 1 deletion docs/api/naip_pretrained_weights.csv
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
Weight,Channels,Source,Citation,License
Swin_V2_B_Weights.NAIP_RGB_SI_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,"ODC-BY"
Swin_V2_B_Weights.NAIP_RGB_MI_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,Apache-2.0
Swin_V2_B_Weights.NAIP_RGB_SI_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,Apache-2.0
3 changes: 2 additions & 1 deletion docs/api/sentinel1_pretrained_weights.csv
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
Weight,Channels,Source,Citation,License
ResNet50_Weights.SENTINEL1_ALL_DECUR, 2,`link <https://github.com/zhu-xlab/DeCUR>`__,`link <https://arxiv.org/abs/2309.05300>`__,"Apache-2.0"
ResNet50_Weights.SENTINEL1_ALL_MOCO, 2,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0"
Swin_V2_B_Weights.SENTINEL1_SI_SATLAS,2,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,"ODC-BY"
Swin_V2_B_Weights.SENTINEL1_MI_SATLAS,2,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,Apache-2.0
Swin_V2_B_Weights.SENTINEL1_SI_SATLAS,2,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,Apache-2.0
18 changes: 16 additions & 2 deletions docs/api/sentinel2_pretrained_weights.csv
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,23 @@ ResNet18_Weights.SENTINEL2_RGB_SECO, 3,`link <https://github.com/ServiceNow/seas
ResNet50_Weights.SENTINEL2_ALL_DECUR,13,`link <https://github.com/zhu-xlab/DeCUR>`__,`link <https://arxiv.org/abs/2309.05300>`__,"Apache-2.0",,,,
ResNet50_Weights.SENTINEL2_ALL_DINO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0",90.7,99.1,63.6,
ResNet50_Weights.SENTINEL2_ALL_MOCO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0",91.8,99.1,60.9,
ResNet50_Weights.SENTINEL2_MI_MS_SATLAS,9,`link <https://github.com/allenai/satlaspretrain_models>`__,`link <https://arxiv.org/abs/2211.15660>`__,Apache-2.0,,,,
ResNet50_Weights.SENTINEL2_MI_RGB_SATLAS,3,`link <https://github.com/allenai/satlaspretrain_models>`__,`link <https://arxiv.org/abs/2211.15660>`__,Apache-2.0,,,,
ResNet50_Weights.SENTINEL2_SI_MS_SATLAS,9,`link <https://github.com/allenai/satlaspretrain_models>`__,`link <https://arxiv.org/abs/2211.15660>`__,Apache-2.0,,,,
ResNet50_Weights.SENTINEL2_SI_RGB_SATLAS,3,`link <https://github.com/allenai/satlaspretrain_models>`__,`link <https://arxiv.org/abs/2211.15660>`__,Apache-2.0,,,,
ResNet50_Weights.SENTINEL2_RGB_MOCO, 3,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0",,,
ResNet50_Weights.SENTINEL2_RGB_SECO, 3,`link <https://github.com/ServiceNow/seasonal-contrast>`__,`link <https://arxiv.org/abs/2103.16607>`__,"Apache-2.0",87.81,,,
ResNet152_Weights.SENTINEL2_MI_MS_SATLAS,9,`link <https://github.com/allenai/satlaspretrain_models>`__,`link <https://arxiv.org/abs/2211.15660>`__,Apache-2.0,,,,
ResNet152_Weights.SENTINEL2_MI_RGB_SATLAS,3,`link <https://github.com/allenai/satlaspretrain_models>`__,`link <https://arxiv.org/abs/2211.15660>`__,Apache-2.0,,,,
ResNet152_Weights.SENTINEL2_SI_MS_SATLAS,9,`link <https://github.com/allenai/satlaspretrain_models>`__,`link <https://arxiv.org/abs/2211.15660>`__,Apache-2.0,,,,
ResNet152_Weights.SENTINEL2_SI_RGB_SATLAS,3,`link <https://github.com/allenai/satlaspretrain_models>`__,`link <https://arxiv.org/abs/2211.15660>`__,Apache-2.0,,,,
ViTSmall16_Weights.SENTINEL2_ALL_DINO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0",90.5,99.0,62.2,
ViTSmall16_Weights.SENTINEL2_ALL_MOCO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0",89.9,98.6,61.6,
Swin_V2_B_Weights.SENTINEL2_RGB_SI_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,"ODC-BY",,,,
Swin_V2_B_Weights.SENTINEL2_MS_SI_SATLAS,9,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,"ODC-BY",,,,
Swin_V2_T_Weights.SENTINEL2_MI_MS_SATLAS,9,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,Apache-2.0,,,,
Swin_V2_T_Weights.SENTINEL2_MI_RGB_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,Apache-2.0,,,,
Swin_V2_T_Weights.SENTINEL2_SI_MS_SATLAS,9,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,Apache-2.0,,,,
Swin_V2_T_Weights.SENTINEL2_SI_RGB_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,Apache-2.0,,,,
Swin_V2_B_Weights.SENTINEL2_MI_MS_SATLAS,9,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,Apache-2.0,,,,
Swin_V2_B_Weights.SENTINEL2_MI_RGB_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,Apache-2.0,,,,
Swin_V2_B_Weights.SENTINEL2_SI_MS_SATLAS,9,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,Apache-2.0,,,,
Swin_V2_B_Weights.SENTINEL2_SI_RGB_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,Apache-2.0,,,,
6 changes: 5 additions & 1 deletion hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
dofa_large_patch16_224,
resnet18,
resnet50,
resnet152,
scalemae_large_patch16,
swin_v2_b,
swin_v2_t,
vit_small_patch16_224,
)

Expand All @@ -22,9 +24,11 @@
'dofa_large_patch16_224',
'resnet18',
'resnet50',
'resnet152',
'scalemae_large_patch16',
'swin_v2_t',
'swin_v2_b',
'vit_small_patch16_224',
)

dependencies = ['timm']
dependencies = ['timm', 'torchvision']
8 changes: 8 additions & 0 deletions tests/models/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
DOFALarge16_Weights,
ResNet18_Weights,
ResNet50_Weights,
ResNet152_Weights,
ScaleMAELarge16_Weights,
Swin_V2_B_Weights,
Swin_V2_T_Weights,
ViTSmall16_Weights,
dofa_base_patch16_224,
dofa_large_patch16_224,
Expand All @@ -24,8 +26,10 @@
list_models,
resnet18,
resnet50,
resnet152,
scalemae_large_patch16,
swin_v2_b,
swin_v2_t,
vit_small_patch16_224,
)

Expand All @@ -34,7 +38,9 @@
dofa_large_patch16_224,
resnet18,
resnet50,
resnet152,
scalemae_large_patch16,
swin_v2_t,
swin_v2_b,
vit_small_patch16_224,
]
Expand All @@ -43,7 +49,9 @@
DOFALarge16_Weights,
ResNet18_Weights,
ResNet50_Weights,
ResNet152_Weights,
ScaleMAELarge16_Weights,
Swin_V2_T_Weights,
Swin_V2_B_Weights,
ViTSmall16_Weights,
]
Expand Down
53 changes: 50 additions & 3 deletions tests/models/test_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@
from pytest import MonkeyPatch
from torchvision.models._api import WeightsEnum

from torchgeo.models import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50
from torchgeo.models import (
ResNet18_Weights,
ResNet50_Weights,
ResNet152_Weights,
resnet18,
resnet50,
resnet152,
)


class TestResNet18:
Expand Down Expand Up @@ -44,7 +51,7 @@ def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None:
def test_transforms(self, mocked_weights: WeightsEnum) -> None:
c = mocked_weights.meta['in_chans']
sample = {
'image': torch.arange(c * 224 * 224, dtype=torch.float).view(c, 224, 224)
'image': torch.arange(c * 256 * 256, dtype=torch.float).view(c, 256, 256)
}
mocked_weights.transforms(sample)

Expand Down Expand Up @@ -84,10 +91,50 @@ def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None:
def test_transforms(self, mocked_weights: WeightsEnum) -> None:
c = mocked_weights.meta['in_chans']
sample = {
'image': torch.arange(c * 224 * 224, dtype=torch.float).view(c, 224, 224)
'image': torch.arange(c * 256 * 256, dtype=torch.float).view(c, 256, 256)
}
mocked_weights.transforms(sample)

@pytest.mark.slow
def test_resnet_download(self, weights: WeightsEnum) -> None:
resnet50(weights=weights)


class TestResNet152:
@pytest.fixture(params=[*ResNet152_Weights])
def weights(self, request: SubRequest) -> WeightsEnum:
return request.param

@pytest.fixture
def mocked_weights(
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = timm.create_model('resnet152', in_chans=weights.meta['in_chans'])
torch.save(model.state_dict(), path)
try:
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
return weights

def test_resnet(self) -> None:
resnet152()

def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None:
resnet152(weights=mocked_weights)

def test_transforms(self, mocked_weights: WeightsEnum) -> None:
c = mocked_weights.meta['in_chans']
sample = {
'image': torch.arange(c * 256 * 256, dtype=torch.float).view(c, 256, 256)
}
mocked_weights.transforms(sample)

@pytest.mark.slow
def test_resnet_download(self, weights: WeightsEnum) -> None:
resnet152(weights=weights)
52 changes: 51 additions & 1 deletion tests/models/test_swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,52 @@
from pytest import MonkeyPatch
from torchvision.models._api import WeightsEnum

from torchgeo.models import Swin_V2_B_Weights, swin_v2_b
from torchgeo.models import Swin_V2_B_Weights, Swin_V2_T_Weights, swin_v2_b, swin_v2_t


class TestSwin_V2_T:
@pytest.fixture(params=[*Swin_V2_T_Weights])
def weights(self, request: SubRequest) -> WeightsEnum:
return request.param

@pytest.fixture
def mocked_weights(
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = torchvision.models.swin_v2_t()
num_channels = weights.meta['in_chans']
out_channels = model.features[0][0].out_channels
model.features[0][0] = torch.nn.Conv2d(
num_channels, out_channels, kernel_size=(4, 4), stride=(4, 4)
)
torch.save(model.state_dict(), path)
try:
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
return weights

def test_swin_v2_t(self) -> None:
swin_v2_t()

def test_swin_v2_t_weights(self, mocked_weights: WeightsEnum) -> None:
swin_v2_t(weights=mocked_weights)

def test_transforms(self, mocked_weights: WeightsEnum) -> None:
c = mocked_weights.meta['in_chans']
sample = {
'image': torch.arange(c * 256 * 256, dtype=torch.float).view(c, 256, 256)
}
mocked_weights.transforms(sample)

@pytest.mark.slow
def test_swin_v2_t_download(self, weights: WeightsEnum) -> None:
swin_v2_t(weights=weights)


class TestSwin_V2_B:
Expand All @@ -28,6 +73,11 @@ def mocked_weights(
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = torchvision.models.swin_v2_b()
num_channels = weights.meta['in_chans']
out_channels = model.features[0][0].out_channels
model.features[0][0] = torch.nn.Conv2d(
num_channels, out_channels, kernel_size=(4, 4), stride=(4, 4)
)
torch.save(model.state_dict(), path)
try:
monkeypatch.setattr(weights.value, 'url', str(path))
Expand Down
17 changes: 14 additions & 3 deletions torchgeo/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,16 @@
from .fcn import FCN
from .fcsiam import FCSiamConc, FCSiamDiff
from .rcf import RCF
from .resnet import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50
from .resnet import (
ResNet18_Weights,
ResNet50_Weights,
ResNet152_Weights,
resnet18,
resnet50,
resnet152,
)
from .scale_mae import ScaleMAE, ScaleMAELarge16_Weights, scalemae_large_patch16
from .swin import Swin_V2_B_Weights, swin_v2_b
from .swin import Swin_V2_B_Weights, Swin_V2_T_Weights, swin_v2_b, swin_v2_t
from .vit import ViTSmall16_Weights, vit_small_patch16_224

__all__ = (
Expand All @@ -40,16 +47,20 @@
'RCF',
'resnet18',
'resnet50',
'resnet152',
'ScaleMAE',
'scalemae_large_patch16',
'swin_v2_t',
'swin_v2_b',
'vit_small_patch16_224',
# weights
'DOFABase16_Weights',
'DOFALarge16_Weights',
'ResNet50_Weights',
'ResNet18_Weights',
'ResNet50_Weights',
'ResNet152_Weights',
'ScaleMAELarge16_Weights',
'Swin_V2_T_Weights',
'Swin_V2_B_Weights',
'ViTSmall16_Weights',
# utilities
Expand Down
17 changes: 15 additions & 2 deletions torchgeo/models/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,26 @@
dofa_base_patch16_224,
dofa_large_patch16_224,
)
from .resnet import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50
from .resnet import (
ResNet18_Weights,
ResNet50_Weights,
ResNet152_Weights,
resnet18,
resnet50,
resnet152,
)
from .scale_mae import ScaleMAELarge16_Weights, scalemae_large_patch16
from .swin import Swin_V2_B_Weights, swin_v2_b
from .swin import Swin_V2_B_Weights, Swin_V2_T_Weights, swin_v2_b, swin_v2_t
from .vit import ViTSmall16_Weights, vit_small_patch16_224

_model = {
'dofa_base_patch16_224': dofa_base_patch16_224,
'dofa_large_patch16_224': dofa_large_patch16_224,
'resnet18': resnet18,
'resnet50': resnet50,
'resnet152': resnet152,
'scalemae_large_patch16': scalemae_large_patch16,
'swin_v2_t': swin_v2_t,
'swin_v2_b': swin_v2_b,
'vit_small_patch16_224': vit_small_patch16_224,
}
Expand All @@ -42,14 +51,18 @@
dofa_large_patch16_224: DOFALarge16_Weights,
resnet18: ResNet18_Weights,
resnet50: ResNet50_Weights,
resnet152: ResNet152_Weights,
scalemae_large_patch16: ScaleMAELarge16_Weights,
swin_v2_t: Swin_V2_T_Weights,
swin_v2_b: Swin_V2_B_Weights,
vit_small_patch16_224: ViTSmall16_Weights,
'dofa_base_patch16_224': DOFABase16_Weights,
'dofa_large_patch16_224': DOFALarge16_Weights,
'resnet18': ResNet18_Weights,
'resnet50': ResNet50_Weights,
'resnet152': ResNet152_Weights,
'scalemae_large_patch16': ScaleMAELarge16_Weights,
'swin_v2_t': Swin_V2_T_Weights,
'swin_v2_b': Swin_V2_B_Weights,
'vit_small_patch16_224': ViTSmall16_Weights,
}
Expand Down
Loading