diff --git a/README.md b/README.md
index 7d7c898a6..f8eeed013 100644
--- a/README.md
+++ b/README.md
@@ -106,7 +106,7 @@ Don't forget to give the project a star! Thanks again!
## License 🔏
-Minerva is distributed under a [GNU LGPLv3 License](https://choosealicense.com/licenses/lgpl-3.0/).
+Minerva is distributed under a [MIT License](https://choosealicense.com/licenses/mit/).
(back to top)
diff --git a/minerva/__init__.py b/minerva/__init__.py
index 32472ec66..6ee793904 100644
--- a/minerva/__init__.py
+++ b/minerva/__init__.py
@@ -45,5 +45,5 @@
__version__ = "0.23.4"
__author__ = "Harry Baker"
__contact__ = "hjb1d20@soton.ac.uk"
-__license__ = "GNU LGPLv3"
+__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
diff --git a/minerva/datasets.py b/minerva/datasets.py
index 89a0f71dc..8645bfbd3 100644
--- a/minerva/datasets.py
+++ b/minerva/datasets.py
@@ -34,7 +34,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "hjb1d20@soton.ac.uk"
-__license__ = "GNU LGPLv3"
+__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = [
"PairedDataset",
@@ -384,7 +384,9 @@ def get_subdataset(
)
# Construct the root to the sub-dataset's files.
- sub_dataset_root: Path = universal_path(data_directory) / sub_dataset_params["root"]
+ sub_dataset_root: Path = (
+ universal_path(data_directory) / sub_dataset_params["root"]
+ )
sub_dataset_root = sub_dataset_root.absolute()
return _sub_dataset, str(sub_dataset_root)
diff --git a/minerva/logger.py b/minerva/logger.py
index e778b4b2d..210239522 100644
--- a/minerva/logger.py
+++ b/minerva/logger.py
@@ -31,7 +31,7 @@
__author__ = "Harry Baker"
__contact__ = "hjb1d20@soton.ac.uk"
-__license__ = "GNU LGPLv3"
+__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = [
"MinervaLogger",
diff --git a/minerva/metrics.py b/minerva/metrics.py
index 957b09e16..86342e202 100644
--- a/minerva/metrics.py
+++ b/minerva/metrics.py
@@ -29,7 +29,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "hjb1d20@soton.ac.uk"
-__license__ = "GNU LGPLv3"
+__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = [
"MinervaMetrics",
diff --git a/minerva/modelio.py b/minerva/modelio.py
index afe80864b..48bb6e512 100644
--- a/minerva/modelio.py
+++ b/minerva/modelio.py
@@ -29,7 +29,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "hjb1d20@soton.ac.uk"
-__license__ = "GNU LGPLv3"
+__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = [
"sup_tg",
diff --git a/minerva/models/__depreciated.py b/minerva/models/__depreciated.py
index 6716cee40..3e4afcd04 100644
--- a/minerva/models/__depreciated.py
+++ b/minerva/models/__depreciated.py
@@ -30,7 +30,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "hjb1d20@soton.ac.uk"
-__license__ = "GNU LGPLv3"
+__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
diff --git a/minerva/models/__init__.py b/minerva/models/__init__.py
index 757990d60..8aa2595a1 100644
--- a/minerva/models/__init__.py
+++ b/minerva/models/__init__.py
@@ -30,7 +30,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "hjb1d20@soton.ac.uk"
-__license__ = "GNU LGPLv3"
+__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
# =====================================================================================================================
diff --git a/minerva/models/core.py b/minerva/models/core.py
index 2474c39ac..259ea15ca 100644
--- a/minerva/models/core.py
+++ b/minerva/models/core.py
@@ -30,7 +30,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "hjb1d20@soton.ac.uk"
-__license__ = "GNU LGPLv3"
+__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = [
@@ -449,7 +449,7 @@ def get_output_shape(
assert isinstance(_image_dim, Iterable)
random_input = torch.rand([4, *_image_dim])
- output: Tensor = model(random_input)
+ output: Tensor = model(random_input.to(next(model.parameters()).device))
if len(output[0].data.shape) == 1:
return output[0].data.shape[0]
diff --git a/minerva/models/fcn.py b/minerva/models/fcn.py
index 55cf52721..77db13dd6 100644
--- a/minerva/models/fcn.py
+++ b/minerva/models/fcn.py
@@ -30,7 +30,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "hjb1d20@soton.ac.uk"
-__license__ = "GNU LGPLv3"
+__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = [
diff --git a/minerva/models/resnet.py b/minerva/models/resnet.py
index d0af75a2c..84c48d42d 100644
--- a/minerva/models/resnet.py
+++ b/minerva/models/resnet.py
@@ -30,7 +30,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "hjb1d20@soton.ac.uk"
-__license__ = "GNU LGPLv3"
+__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = [
"ResNet",
diff --git a/minerva/models/siamese.py b/minerva/models/siamese.py
index 9040a10e8..bd9b20bf4 100644
--- a/minerva/models/siamese.py
+++ b/minerva/models/siamese.py
@@ -31,7 +31,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "hjb1d20@soton.ac.uk"
-__license__ = "GNU LGPLv3"
+__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = [
"MinervaSiamese",
diff --git a/minerva/models/unet.py b/minerva/models/unet.py
index c6563f94e..327979855 100644
--- a/minerva/models/unet.py
+++ b/minerva/models/unet.py
@@ -30,7 +30,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "hjb1d20@soton.ac.uk"
-__license__ = "GNU LGPLv3"
+__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = [
diff --git a/minerva/samplers.py b/minerva/samplers.py
index a160208aa..e95e6b403 100644
--- a/minerva/samplers.py
+++ b/minerva/samplers.py
@@ -29,7 +29,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "hjb1d20@soton.ac.uk"
-__license__ = "GNU LGPLv3"
+__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = [
"RandomPairGeoSampler",
diff --git a/minerva/trainer.py b/minerva/trainer.py
index baeff599a..18843ed5f 100644
--- a/minerva/trainer.py
+++ b/minerva/trainer.py
@@ -31,7 +31,7 @@
__author__ = "Harry Baker"
__contact__ = "hjb1d20@soton.ac.uk"
-__license__ = "GNU LGPLv3"
+__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = ["Trainer"]
@@ -39,6 +39,7 @@
# IMPORTS
# =====================================================================================================================
import os
+import warnings
from contextlib import nullcontext
from pathlib import Path
from typing import (
@@ -61,6 +62,7 @@
from alive_progress import alive_bar, alive_it
from inputimeout import TimeoutOccurred, inputimeout
from nptyping import Int, NDArray
+from packaging.version import Version
from torch import Tensor
from torch.nn.modules import Module
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -414,6 +416,16 @@ def __init__(
)
self.model = MinervaDataParallel(self.model, DDP, device_ids=[gpu])
+ # Wraps the model in `torch.compile` to speed up computation time.
+ # Python 3.11+ is not yet supported though, hence the exception clause.
+ if Version(torch.__version__) > Version("2.0.0"): # pragma: no cover
+ try:
+ _compiled_model = torch.compile(self.model)
+ assert isinstance(_compiled_model, (MinervaModel, MinervaDataParallel))
+ self.model = _compiled_model
+ except RuntimeError as err:
+ warnings.warn(str(err))
+
def init_wandb_metrics(self) -> None:
"""Setups up separate step counters for :mod:`wandb` logging of train, val, etc."""
if isinstance(self.writer, Run):
@@ -1063,7 +1075,9 @@ def weighted_knn_validation(
for batch in test_bar:
test_data: Tensor = batch["image"].to(self.device, non_blocking=True)
test_target: Tensor = torch.mode(
- torch.flatten(batch["mask"], start_dim=1)
+ torch.flatten(
+ batch["mask"].to(self.device, non_blocking=True), start_dim=1
+ )
).values
# Get features from passing the input data through the model.
diff --git a/minerva/transforms.py b/minerva/transforms.py
index 1f05256c9..9e6d90ca8 100644
--- a/minerva/transforms.py
+++ b/minerva/transforms.py
@@ -29,7 +29,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "hjb1d20@soton.ac.uk"
-__license__ = "GNU LGPLv3"
+__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = [
"ClassTransform",
diff --git a/minerva/utils/__init__.py b/minerva/utils/__init__.py
index c66942d25..68b5c8b41 100644
--- a/minerva/utils/__init__.py
+++ b/minerva/utils/__init__.py
@@ -37,7 +37,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "hjb1d20@soton.ac.uk"
-__license__ = "GNU LGPLv3"
+__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = [
"universal_path",
diff --git a/minerva/utils/config_load.py b/minerva/utils/config_load.py
index d3e5ae8d9..ce2a18ad3 100644
--- a/minerva/utils/config_load.py
+++ b/minerva/utils/config_load.py
@@ -34,7 +34,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "hjb1d20@soton.ac.uk"
-__license__ = "GNU LGPLv3"
+__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = [
"DEFAULT_CONF_DIR_PATH",
diff --git a/minerva/utils/runner.py b/minerva/utils/runner.py
index e4512e802..f1f5355fd 100644
--- a/minerva/utils/runner.py
+++ b/minerva/utils/runner.py
@@ -35,7 +35,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "hjb1d20@soton.ac.uk"
-__license__ = "GNU LGPLv3"
+__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = [
"GENERIC_PARSER",
@@ -51,8 +51,8 @@
# =====================================================================================================================
import argparse
import os
-import signal
import shlex
+import signal
import subprocess
from argparse import Namespace
from typing import Any, Callable, Optional, Union
diff --git a/minerva/utils/utils.py b/minerva/utils/utils.py
index 812f1d0e8..94d89e3bc 100644
--- a/minerva/utils/utils.py
+++ b/minerva/utils/utils.py
@@ -47,7 +47,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "hjb1d20@soton.ac.uk"
-__license__ = "GNU LGPLv3"
+__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = [
"IMAGERY_CONFIG_PATH",
@@ -121,12 +121,12 @@
import random
import re as regex
import shlex
-from subprocess import Popen
import sys
import webbrowser
from collections import Counter, OrderedDict
from datetime import datetime
from pathlib import Path
+from subprocess import Popen
from types import ModuleType
from typing import Any, Callable
from typing import Counter as CounterType
@@ -1669,7 +1669,9 @@ def run_tensorboard(
os.chdir(_path)
# Activates the correct Conda environment.
- Popen(shlex.split(f"conda activate {env_name}"), shell=True).wait() # nosec B607
+ Popen( # nosec B607, B602
+ shlex.split(f"conda activate {env_name}"), shell=True
+ ).wait()
if _testing:
os.chdir(cwd)
@@ -1677,7 +1679,9 @@ def run_tensorboard(
else: # pragma: no cover
# Runs TensorBoard log.
- Popen(shlex.split(f"tensorboard --logdir {exp_name}"), shell=True) # nosec B607
+ Popen( # nosec B607, B602
+ shlex.split(f"tensorboard --logdir {exp_name}"), shell=True
+ )
# Opens the TensorBoard log in a locally hosted webpage of the default system browser.
webbrowser.open(f"localhost:{host_num}")
diff --git a/minerva/utils/visutils.py b/minerva/utils/visutils.py
index db552c625..b85d01e7d 100644
--- a/minerva/utils/visutils.py
+++ b/minerva/utils/visutils.py
@@ -42,7 +42,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "hjb1d20@soton.ac.uk"
-__license__ = "GNU LGPLv3"
+__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = [
"DATA_CONFIG",
diff --git a/scripts/ManifestMake.py b/scripts/ManifestMake.py
index 05e82816a..714f44b84 100644
--- a/scripts/ManifestMake.py
+++ b/scripts/ManifestMake.py
@@ -26,7 +26,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "hjb1d20@soton.ac.uk"
-__license__ = "GNU GPLv3"
+__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
diff --git a/scripts/MinervaClusterVis.py b/scripts/MinervaClusterVis.py
index c59ca0ec9..1df3d407a 100644
--- a/scripts/MinervaClusterVis.py
+++ b/scripts/MinervaClusterVis.py
@@ -30,7 +30,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "hjb1d20@soton.ac.uk"
-__license__ = "GNU GPLv3"
+__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
diff --git a/scripts/MinervaExp.py b/scripts/MinervaExp.py
index bddffa75b..cc95ac29e 100644
--- a/scripts/MinervaExp.py
+++ b/scripts/MinervaExp.py
@@ -33,7 +33,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "hjb1d20@soton.ac.uk"
-__license__ = "GNU GPLv3"
+__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
diff --git a/scripts/MinervaPipe.py b/scripts/MinervaPipe.py
index e7207d652..2595fd847 100644
--- a/scripts/MinervaPipe.py
+++ b/scripts/MinervaPipe.py
@@ -25,7 +25,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "hjb1d20@soton.ac.uk"
-__license__ = "GNU GPLv3"
+__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
@@ -54,7 +54,7 @@ def main(config_path: str):
)
try:
- exit_code = subprocess.Popen( # nosec B607
+ exit_code = subprocess.Popen( # nosec B607, B602
shlex.split(f"python MinervaExp.py -c {config[key]}"),
shell=True,
).wait()
diff --git a/scripts/RunTensorBoard.py b/scripts/RunTensorBoard.py
index d75f81e47..a0f9e7403 100644
--- a/scripts/RunTensorBoard.py
+++ b/scripts/RunTensorBoard.py
@@ -25,7 +25,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "hjb1d20@soton.ac.uk"
-__license__ = "GNU GPLv3"
+__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
diff --git a/scripts/TorchWeightDownloader.py b/scripts/TorchWeightDownloader.py
index 8a1760c6a..a22698883 100644
--- a/scripts/TorchWeightDownloader.py
+++ b/scripts/TorchWeightDownloader.py
@@ -29,7 +29,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "hjb1d20@soton.ac.uk"
-__license__ = "GNU GPLv3"
+__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
diff --git a/scripts/Torch_to_ONNX.py b/scripts/Torch_to_ONNX.py
index 2b4c35e0d..287382b10 100644
--- a/scripts/Torch_to_ONNX.py
+++ b/scripts/Torch_to_ONNX.py
@@ -25,7 +25,7 @@
# =====================================================================================================================
__author__ = "Harry Baker"
__contact__ = "hjb1d20@soton.ac.uk"
-__license__ = "GNU GPLv3"
+__license__ = "MIT License"
__copyright__ = "Copyright (C) 2023 Harry Baker"
diff --git a/tests/__init__.py b/tests/__init__.py
index e69de29bb..1aa720329 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -0,0 +1,26 @@
+# -*- coding: utf-8 -*-
+# Copyright (C) 2023 Harry Baker
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program in LICENSE.txt. If not,
+# see .
+#
+# @org: University of Southampton
+# Created under a project funded by the Ordnance Survey Ltd.
+r""":mod:`pytest` suite for :mod:`minerva` CI/CD.
+"""
+
+__author__ = "Harry Baker"
+__contact__ = "hjb1d20@soton.ac.uk"
+__license__ = "MIT License"
+__copyright__ = "Copyright (C) 2023 Harry Baker"
diff --git a/tests/conftest.py b/tests/conftest.py
index 95ad75299..594062cbe 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,4 +1,36 @@
# -*- coding: utf-8 -*-
+# Copyright (C) 2023 Harry Baker
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program in LICENSE.txt. If not,
+# see .
+#
+# @org: University of Southampton
+# Created under a project funded by the Ordnance Survey Ltd.
+r""":mod:`pytest` fixtures for :mod:`minerva` CI/CD.
+"""
+# =====================================================================================================================
+# METADATA
+# =====================================================================================================================
+__author__ = "Harry Baker"
+__contact__ = "hjb1d20@soton.ac.uk"
+__license__ = "MIT License"
+__copyright__ = "Copyright (C) 2023 Harry Baker"
+
+
+# =====================================================================================================================
+# IMPORTS
+# =====================================================================================================================
import os
import shutil
from pathlib import Path
@@ -18,6 +50,9 @@
from minerva.utils import CONFIG, utils
+# =====================================================================================================================
+# FIXTURES
+# =====================================================================================================================
@pytest.fixture(scope="session", autouse=True)
def set_seeds():
utils.set_seeds(42)
@@ -35,15 +70,20 @@ def results_dir():
@pytest.fixture
-def data_root():
+def data_root() -> Path:
return Path(__file__).parent / "tmp" / "results"
@pytest.fixture
-def img_root(data_root: Path):
+def img_root(data_root: Path) -> Path:
return data_root.parent / "data" / "test_images"
+@pytest.fixture
+def lc_root(data_root: Path) -> Path:
+ return data_root.parent / "data" / "test_lc"
+
+
@pytest.fixture
def config_root(data_root: Path):
config_path = data_root.parent / "config"
@@ -71,11 +111,36 @@ def config_here():
os.unlink(here / "exp_mf_config.yml")
+@pytest.fixture
+def default_device() -> torch.device:
+ return utils.get_cuda_device()
+
+
+@pytest.fixture
+def std_batch_size() -> int:
+ return 3
+
+
+@pytest.fixture
+def std_n_classes() -> int:
+ return 8
+
+
+@pytest.fixture
+def std_n_batches() -> int:
+ return 2
+
+
@pytest.fixture
def x_entropy_loss():
return nn.CrossEntropyLoss()
+@pytest.fixture
+def small_patch_size() -> Tuple[int, int]:
+ return (32, 32)
+
+
@pytest.fixture
def rgbi_input_size() -> Tuple[int, int, int]:
return (4, 64, 64)
@@ -92,32 +157,58 @@ def exp_cnn(x_entropy_loss, rgbi_input_size) -> MinervaModel:
@pytest.fixture
-def random_mask() -> NDArray[Shape["32, 32"], Int]:
- return np.random.randint(0, 7, size=(32, 32))
+def random_mask(small_patch_size, std_n_classes) -> NDArray[Shape["32, 32"], Int]:
+ return np.random.randint(0, std_n_classes - 1, size=small_patch_size)
@pytest.fixture
-def random_image() -> NDArray[Shape["32, 32, 3"], Float]:
- return np.random.rand(32, 32, 3)
+def random_image(small_patch_size) -> NDArray[Shape["32, 32, 3"], Float]:
+ return np.random.rand(*small_patch_size, 3)
@pytest.fixture
-def random_rgbi_image() -> NDArray[Shape["32, 32, 4"], Float]:
- return np.random.rand(32, 32, 4)
+def random_rgbi_image(small_patch_size) -> NDArray[Shape["32, 32, 4"], Float]:
+ return np.random.rand(*small_patch_size, 4)
@pytest.fixture
-def random_rgbi_tensor(rgbi_input_size) -> Tensor:
+def random_rgbi_tensor(rgbi_input_size: Tuple[int, int, int]) -> Tensor:
return torch.rand(rgbi_input_size)
@pytest.fixture
-def random_tensor_mask() -> LongTensor:
- mask = torch.randint(0, 7, size=(32, 32), dtype=torch.long)
+def random_rgbi_batch(
+ rgbi_input_size: Tuple[int, int, int], std_batch_size: int
+) -> Tensor:
+ return torch.rand((std_batch_size, *rgbi_input_size))
+
+
+@pytest.fixture
+def random_tensor_mask(std_n_classes: int, small_patch_size) -> LongTensor:
+ mask = torch.randint(0, std_n_classes - 1, size=small_patch_size, dtype=torch.long)
assert isinstance(mask, LongTensor)
return mask
+@pytest.fixture
+def random_mask_batch(
+ std_batch_size: int, std_n_classes: int, rgbi_input_size: Tuple[int, int, int]
+) -> LongTensor:
+ mask = torch.randint(
+ 0,
+ std_n_classes - 1,
+ size=(std_batch_size, *rgbi_input_size[1:]),
+ dtype=torch.long,
+ )
+ assert isinstance(mask, LongTensor)
+ return mask
+
+
+@pytest.fixture
+def random_scene_classification_batch(std_batch_size, std_n_classes) -> LongTensor:
+ return torch.randint(0, std_n_classes - 1, size=(std_batch_size,))
+
+
@pytest.fixture
def bounds_for_test_img() -> BoundingBox:
return BoundingBox(
diff --git a/tests/test_config.py b/tests/test_config.py
index f068a8e45..314b53a09 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -1,4 +1,35 @@
# -*- coding: utf-8 -*-
+# Copyright (C) 2023 Harry Baker
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program in LICENSE.txt. If not,
+# see .
+#
+# @org: University of Southampton
+# Created under a project funded by the Ordnance Survey Ltd.
+r"""Tests for :mod:`minerva.utils.config_load`.
+"""
+# =====================================================================================================================
+# METADATA
+# =====================================================================================================================
+__author__ = "Harry Baker"
+__contact__ = "hjb1d20@soton.ac.uk"
+__license__ = "MIT License"
+__copyright__ = "Copyright (C) 2023 Harry Baker"
+
+# =====================================================================================================================
+# IMPORTS
+# =====================================================================================================================
import os
from pathlib import Path
@@ -11,6 +42,9 @@
)
+# =====================================================================================================================
+# TESTS
+# =====================================================================================================================
def test_universal_path():
path1 = "one/two/three/file.txt"
path2 = ["one", "two", "three", "file.txt"]
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 47a58874b..510b84d47 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -1,4 +1,35 @@
# -*- coding: utf-8 -*-
+# Copyright (C) 2023 Harry Baker
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program in LICENSE.txt. If not,
+# see .
+#
+# @org: University of Southampton
+# Created under a project funded by the Ordnance Survey Ltd.
+r"""Tests for :mod:`minerva.datasets`.
+"""
+# =====================================================================================================================
+# METADATA
+# =====================================================================================================================
+__author__ = "Harry Baker"
+__contact__ = "hjb1d20@soton.ac.uk"
+__license__ = "MIT License"
+__copyright__ = "Copyright (C) 2023 Harry Baker"
+
+# =====================================================================================================================
+# IMPORTS
+# =====================================================================================================================
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Union
@@ -19,13 +50,10 @@
from minerva.datasets import PairedDataset, TstImgDataset, TstMaskDataset
from minerva.utils.utils import CONFIG
-data_root = Path("tests", "tmp")
-img_root = str(data_root / "data" / "test_images")
-lc_root = str(data_root / "data" / "test_lc")
-
-bounds = BoundingBox(411248.0, 412484.0, 4058102.0, 4059399.0, 0, 1e12)
-
+# =====================================================================================================================
+# TESTS
+# =====================================================================================================================
def test_make_bounding_box() -> None:
assert mdt.make_bounding_box() is None
assert mdt.make_bounding_box(False) is None
@@ -40,7 +68,7 @@ def test_make_bounding_box() -> None:
_ = mdt.make_bounding_box(True)
-def test_tinydataset() -> None:
+def test_tinydataset(img_root: Path, lc_root: Path) -> None:
"""Source of TIFF: https://github.com/mommermi/geotiff_sample"""
imagery = TstImgDataset(img_root)
@@ -50,9 +78,10 @@ def test_tinydataset() -> None:
assert isinstance(dataset, IntersectionDataset)
-def test_paired_datasets() -> None:
+def test_paired_datasets(img_root: Path) -> None:
dataset = PairedDataset(TstImgDataset, img_root)
+ bounds = BoundingBox(411248.0, 412484.0, 4058102.0, 4059399.0, 0, 1e12)
query_1 = get_random_bounding_box(bounds, (32, 32), 10.0)
query_2 = get_random_bounding_box(bounds, (32, 32), 10.0)
@@ -121,7 +150,7 @@ def test_stack_sample_pairs() -> None:
assert_array_equal(stacked_samples_2[key][i], sample_2[key])
-def test_intersect_datasets() -> None:
+def test_intersect_datasets(img_root: Path, lc_root: Path) -> None:
imagery = PairedDataset(TstImgDataset, img_root)
labels = PairedDataset(TstMaskDataset, lc_root)
diff --git a/tests/test_fcn.py b/tests/test_fcn.py
index 6ad67f7a3..bd76fc68e 100644
--- a/tests/test_fcn.py
+++ b/tests/test_fcn.py
@@ -1,4 +1,37 @@
# -*- coding: utf-8 -*-
+# Copyright (C) 2023 Harry Baker
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program in LICENSE.txt. If not,
+# see .
+#
+# @org: University of Southampton
+# Created under a project funded by the Ordnance Survey Ltd.
+r"""Tests for :mod:`minerva.models.fcn`.
+"""
+# =====================================================================================================================
+# METADATA
+# =====================================================================================================================
+__author__ = "Harry Baker"
+__contact__ = "hjb1d20@soton.ac.uk"
+__license__ = "MIT License"
+__copyright__ = "Copyright (C) 2023 Harry Baker"
+
+# =====================================================================================================================
+# IMPORTS
+# =====================================================================================================================
+from typing import Tuple
+
import pytest
import torch
from torch import Tensor
@@ -20,104 +53,51 @@
)
from minerva.models.fcn import DCN
-input_size = (4, 64, 64)
-batch_size = 2
-n_classes = 8
-
-x = torch.rand((batch_size, *input_size))
-y = torch.randint(0, n_classes, (batch_size, *input_size[1:])) # type: ignore[attr-defined]
-
-def fcn_test(test_model: MinervaModel, x: Tensor, y: Tensor) -> None:
- optimiser = torch.optim.SGD(test_model.parameters(), lr=1.0e-3)
+# =====================================================================================================================
+# TESTS
+# =====================================================================================================================
+@pytest.mark.parametrize(
+ "model_cls",
+ (
+ FCN8ResNet18,
+ FCN8ResNet34,
+ FCN8ResNet50,
+ FCN8ResNet101,
+ FCN8ResNet152,
+ FCN16ResNet18,
+ FCN16ResNet34,
+ FCN16ResNet50,
+ FCN32ResNet18,
+ FCN32ResNet34,
+ FCN32ResNet50,
+ ),
+)
+def test_fcn(
+ model_cls: MinervaModel,
+ x_entropy_loss,
+ random_rgbi_batch: Tensor,
+ random_mask_batch: Tensor,
+ std_batch_size: int,
+ std_n_classes: int,
+ rgbi_input_size: Tuple[int, int, int],
+) -> None:
+ model: MinervaModel = model_cls(x_entropy_loss, input_size=rgbi_input_size)
+ optimiser = torch.optim.SGD(model.parameters(), lr=1.0e-3)
- test_model.set_optimiser(optimiser)
+ model.set_optimiser(optimiser)
- test_model.determine_output_dim()
- assert test_model.output_shape == input_size[1:]
+ model.determine_output_dim()
+ assert model.output_shape == rgbi_input_size[1:]
- loss, z = test_model.step(x, y, True)
+ loss, z = model.step(random_rgbi_batch, random_mask_batch, True)
assert type(loss.item()) is float
assert isinstance(z, Tensor)
- assert z.size() == (batch_size, n_classes, *input_size[1:])
-
-
-def test_fcn32resnet18(x_entropy_loss) -> None:
- model = FCN32ResNet18(x_entropy_loss, input_size=input_size)
- fcn_test(model, x, y)
-
-
-def test_fcn32resnet34(x_entropy_loss) -> None:
- model = FCN32ResNet34(x_entropy_loss, input_size=input_size)
- fcn_test(model, x, y)
-
-
-def test_fcn32resnet50(x_entropy_loss) -> None:
- model = FCN32ResNet50(x_entropy_loss, input_size=input_size)
- fcn_test(model, x, y)
-
-
-def test_fcn16resnet18(x_entropy_loss) -> None:
- model = FCN16ResNet18(x_entropy_loss, input_size=input_size)
- fcn_test(model, x, y)
-
-
-def test_fcn16resnet34(x_entropy_loss) -> None:
- model = FCN16ResNet34(x_entropy_loss, input_size=input_size)
- fcn_test(model, x, y)
-
-
-def test_fcn16resnet50(x_entropy_loss) -> None:
- model = FCN16ResNet50(x_entropy_loss, input_size=input_size)
- fcn_test(model, x, y)
-
-
-def test_fcn8resnet18(x_entropy_loss) -> None:
- model = FCN8ResNet18(x_entropy_loss, input_size=input_size)
- fcn_test(model, x, y)
-
-
-def test_fcn8resnet34(x_entropy_loss) -> None:
- model = FCN8ResNet34(x_entropy_loss, input_size=input_size)
- fcn_test(model, x, y)
-
-
-def test_fcn8resnet50(x_entropy_loss) -> None:
- model = FCN8ResNet50(x_entropy_loss, input_size=input_size)
- fcn_test(model, x, y)
-
-
-def test_fcn8resnet101(x_entropy_loss) -> None:
- model = FCN8ResNet101(x_entropy_loss, input_size=input_size)
- fcn_test(model, x, y)
-
-
-def test_fcn8resnet152(x_entropy_loss) -> None:
- model = FCN8ResNet152(x_entropy_loss, input_size=input_size)
- fcn_test(model, x, y)
-
-
-def test_fcnresnet_torch_weights(x_entropy_loss) -> None:
- for _model in (
- FCN8ResNet18,
- FCN16ResNet34,
- FCN32ResNet50,
- FCN8ResNet101,
- FCN8ResNet152,
- ):
- try:
- model = _model(
- x_entropy_loss,
- input_size=input_size,
- backbone_kwargs={"torch_weights": True},
- )
- fcn_test(model, x, y)
- except ImportError as err:
- print(err)
+ assert z.size() == (std_batch_size, std_n_classes, *rgbi_input_size[1:])
-def test_dcn() -> None:
+def test_dcn(random_rgbi_batch: Tensor) -> None:
with pytest.raises(
NotImplementedError, match="Variant 42 does not match known types"
):
@@ -129,4 +109,4 @@ def test_dcn() -> None:
NotImplementedError, match="Variant 42 does not match known types"
):
dcn.variant = "42" # type: ignore[assignment]
- _ = dcn.forward(resnet(torch.rand((batch_size, *input_size))))
+ _ = dcn.forward(resnet(random_rgbi_batch))
diff --git a/tests/test_logger.py b/tests/test_logger.py
index 7bfd59241..06ed53bcc 100644
--- a/tests/test_logger.py
+++ b/tests/test_logger.py
@@ -1,9 +1,41 @@
# -*- coding: utf-8 -*-
+# Copyright (C) 2023 Harry Baker
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program in LICENSE.txt. If not,
+# see .
+#
+# @org: University of Southampton
+# Created under a project funded by the Ordnance Survey Ltd.
+r"""Tests for :mod:`minerva.logger`.
+"""
+# =====================================================================================================================
+# METADATA
+# =====================================================================================================================
+__author__ = "Harry Baker"
+__contact__ = "hjb1d20@soton.ac.uk"
+__license__ = "MIT License"
+__copyright__ = "Copyright (C) 2023 Harry Baker"
+
import importlib
import shutil
import tempfile
from pathlib import Path
-from typing import Any, Dict, List, Union
+
+# =====================================================================================================================
+# IMPORTS
+# =====================================================================================================================
+from typing import Any, Dict, List, Tuple, Union
import numpy as np
import torch
@@ -18,22 +50,26 @@
from nptyping import NDArray, Shape
from numpy.testing import assert_array_equal
from torch import Tensor
+from torchgeo.datasets.utils import BoundingBox
from minerva.logger import SSLLogger, STGLogger
from minerva.modelio import ssl_pair_tg, sup_tg
from minerva.models import FCN16ResNet18, SimCLR18
from minerva.utils import utils
-device = torch.device("cpu") # type: ignore[attr-defined]
-n_batches = 2
-batch_size = 3
-patch_size = (32, 32)
-n_classes = 8
-
-
-def test_STGLogger(simple_bbox):
- criterion = nn.CrossEntropyLoss()
+# =====================================================================================================================
+# TESTS
+# =====================================================================================================================
+def test_STGLogger(
+ simple_bbox: BoundingBox,
+ x_entropy_loss,
+ std_n_batches: int,
+ std_n_classes: int,
+ std_batch_size: int,
+ small_patch_size: Tuple[int, int],
+ default_device: torch.device,
+) -> None:
path = Path(tempfile.gettempdir(), "exp1")
if not path.exists():
@@ -49,7 +85,9 @@ def test_STGLogger(simple_bbox):
else:
writer = tensorboard_writer(log_dir=path)
- model = FCN16ResNet18(criterion, input_size=(4, *patch_size))
+ model = FCN16ResNet18(x_entropy_loss, input_size=(4, *small_patch_size)).to(
+ default_device
+ )
optimiser = torch.optim.SGD(model.parameters(), lr=1.0e-3)
model.set_optimiser(optimiser)
model.determine_output_dim()
@@ -60,21 +98,24 @@ def test_STGLogger(simple_bbox):
for mode in ("train", "val", "test"):
for model_type in ("scene_classifier", "segmentation"):
logger = STGLogger(
- n_batches=n_batches,
- batch_size=batch_size,
- n_samples=n_batches * batch_size * patch_size[0] * patch_size[1],
+ n_batches=std_n_batches,
+ batch_size=std_batch_size,
+ n_samples=std_n_batches
+ * std_batch_size
+ * small_patch_size[0]
+ * small_patch_size[1],
out_shape=output_shape,
- n_classes=n_classes,
+ n_classes=std_n_classes,
record_int=True,
record_float=True,
model_type=model_type,
writer=writer,
)
data: List[Dict[str, Union[Tensor, List[Any]]]] = []
- for i in range(n_batches):
- images = torch.rand(size=(batch_size, 4, *patch_size))
- masks = torch.randint(0, n_classes, (batch_size, *patch_size)) # type: ignore[attr-defined]
- bboxes = [simple_bbox] * batch_size
+ for i in range(std_n_batches):
+ images = torch.rand(size=(std_batch_size, 4, *small_patch_size))
+ masks = torch.randint(0, std_n_classes, (std_batch_size, *small_patch_size)) # type: ignore[attr-defined]
+ bboxes = [simple_bbox] * std_batch_size
batch: Dict[str, Union[Tensor, List[Any]]] = {
"image": images,
"mask": masks,
@@ -82,10 +123,10 @@ def test_STGLogger(simple_bbox):
}
data.append(batch)
- logger(mode, i, *sup_tg(batch, model, device=device, mode=mode))
+ logger(mode, i, *sup_tg(batch, model, device=default_device, mode=mode))
logs = logger.get_logs
- assert logs["batch_num"] == n_batches
+ assert logs["batch_num"] == std_n_batches
assert type(logs["total_loss"]) is float
assert type(logs["total_correct"]) is float
@@ -93,15 +134,23 @@ def test_STGLogger(simple_bbox):
assert type(logs["total_miou"]) is float
results = logger.get_results
- assert results["z"].shape == (n_batches, batch_size, *patch_size)
- assert results["y"].shape == (n_batches, batch_size, *patch_size)
- assert np.array(results["ids"]).shape == (n_batches, batch_size)
+ assert results["z"].shape == (
+ std_n_batches,
+ std_batch_size,
+ *small_patch_size,
+ )
+ assert results["y"].shape == (
+ std_n_batches,
+ std_batch_size,
+ *small_patch_size,
+ )
+ assert np.array(results["ids"]).shape == (std_n_batches, std_batch_size)
- shape = f"{n_batches}, {batch_size}, {patch_size[0]}, {patch_size[1]}"
+ shape = f"{std_n_batches}, {std_batch_size}, {small_patch_size[0]}, {small_patch_size[1]}"
y: NDArray[Shape[shape], Any] = np.empty(
- (n_batches, batch_size, *output_shape), dtype=np.uint8
+ (std_n_batches, std_batch_size, *output_shape), dtype=np.uint8
)
- for i in range(n_batches):
+ for i in range(std_n_batches):
mask: Union[Tensor, List[Any]] = data[i]["mask"]
assert isinstance(mask, Tensor)
y[i] = mask.cpu().numpy()
@@ -111,7 +160,13 @@ def test_STGLogger(simple_bbox):
shutil.rmtree(path, ignore_errors=True)
-def test_SSLLogger(simple_bbox):
+def test_SSLLogger(
+ simple_bbox: BoundingBox,
+ std_n_batches: int,
+ std_batch_size: int,
+ small_patch_size: Tuple[int, int],
+ default_device: torch.device,
+) -> None:
criterion = NTXentLoss(0.5)
path = Path(tempfile.gettempdir(), "exp2")
@@ -129,16 +184,16 @@ def test_SSLLogger(simple_bbox):
else:
writer = tensorboard_writer(log_dir=path)
- model = SimCLR18(criterion, input_size=(4, *patch_size))
+ model = SimCLR18(criterion, input_size=(4, *small_patch_size)).to(default_device)
optimiser = torch.optim.SGD(model.parameters(), lr=1.0e-3)
model.set_optimiser(optimiser)
for mode in ("train", "val", "test"):
for extra_metrics in (True, False):
logger = SSLLogger(
- n_batches=n_batches,
- batch_size=batch_size,
- n_samples=n_batches * batch_size,
+ n_batches=std_n_batches,
+ batch_size=std_batch_size,
+ n_samples=std_n_batches * std_batch_size,
record_int=True,
record_float=True,
collapse_level=extra_metrics,
@@ -146,9 +201,9 @@ def test_SSLLogger(simple_bbox):
writer=writer,
)
data = []
- for i in range(n_batches):
- images = torch.rand(size=(batch_size, 4, *patch_size))
- bboxes = [simple_bbox] * batch_size
+ for i in range(std_n_batches):
+ images = torch.rand(size=(std_batch_size, 4, *small_patch_size))
+ bboxes = [simple_bbox] * std_batch_size
batch = {
"image": images,
"bbox": bboxes,
@@ -158,11 +213,13 @@ def test_SSLLogger(simple_bbox):
logger(
mode,
i,
- *ssl_pair_tg((batch, batch), model, device=device, mode=mode),
+ *ssl_pair_tg(
+ (batch, batch), model, device=default_device, mode=mode
+ ),
)
logs = logger.get_logs
- assert logs["batch_num"] == n_batches
+ assert logs["batch_num"] == std_n_batches
assert type(logs["total_loss"]) is float
assert type(logs["total_correct"]) is float
assert type(logs["total_top5"]) is float
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 5553fb17b..08a9745fb 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -1,4 +1,35 @@
# -*- coding: utf-8 -*-
+# Copyright (C) 2023 Harry Baker
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program in LICENSE.txt. If not,
+# see .
+#
+# @org: University of Southampton
+# Created under a project funded by the Ordnance Survey Ltd.
+r"""Tests for :mod:`minerva.metrics`.
+"""
+# =====================================================================================================================
+# METADATA
+# =====================================================================================================================
+__author__ = "Harry Baker"
+__contact__ = "hjb1d20@soton.ac.uk"
+__license__ = "MIT License"
+__copyright__ = "Copyright (C) 2023 Harry Baker"
+
+# =====================================================================================================================
+# IMPORTS
+# =====================================================================================================================
import random
from typing import Dict, List
@@ -7,6 +38,9 @@
from minerva.metrics import MinervaMetrics, SPMetrics, SSLMetrics
+# =====================================================================================================================
+# TESTS
+# =====================================================================================================================
def test_minervametrics() -> None:
assert issubclass(SPMetrics, MinervaMetrics)
diff --git a/tests/test_modelio.py b/tests/test_modelio.py
index 80b7be3ba..d1ab8d2c4 100644
--- a/tests/test_modelio.py
+++ b/tests/test_modelio.py
@@ -1,6 +1,37 @@
# -*- coding: utf-8 -*-
+# Copyright (C) 2023 Harry Baker
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program in LICENSE.txt. If not,
+# see .
+#
+# @org: University of Southampton
+# Created under a project funded by the Ordnance Survey Ltd.
+r"""Tests for :mod:`minerva.modelio`.
+"""
+# =====================================================================================================================
+# METADATA
+# =====================================================================================================================
+__author__ = "Harry Baker"
+__contact__ = "hjb1d20@soton.ac.uk"
+__license__ = "MIT License"
+__copyright__ = "Copyright (C) 2023 Harry Baker"
+
+# =====================================================================================================================
+# IMPORTS
+# =====================================================================================================================
import importlib
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Tuple, Union
import torch
import torch.nn.modules as nn
@@ -14,85 +45,107 @@
import pytest
from numpy.testing import assert_array_equal
from torch import Tensor
+from torchgeo.datasets.utils import BoundingBox
from minerva.modelio import autoencoder_io, ssl_pair_tg, sup_tg
from minerva.models import FCN32ResNet18, SimCLR34
-input_size = (4, 64, 64)
-batch_size = 3
-n_classes = 8
-device = torch.device("cpu") # type: ignore[attr-defined]
-
-def test_sup_tg(simple_bbox) -> None:
+# =====================================================================================================================
+# TESTS
+# =====================================================================================================================
+def test_sup_tg(
+ simple_bbox: BoundingBox,
+ random_rgbi_batch: Tensor,
+ random_mask_batch: Tensor,
+ std_batch_size: int,
+ std_n_classes: int,
+ rgbi_input_size: Tuple[int, int, int],
+ default_device: torch.device,
+) -> None:
criterion = nn.CrossEntropyLoss()
- model = FCN32ResNet18(criterion, input_size=input_size)
+ model = FCN32ResNet18(criterion, input_size=rgbi_input_size).to(default_device)
optimiser = torch.optim.SGD(model.parameters(), lr=1.0e-3)
model.set_optimiser(optimiser)
for mode in ("train", "val", "test"):
- images = torch.rand(size=(batch_size, *input_size))
- masks = torch.randint(0, n_classes, (batch_size, *input_size[1:])) # type: ignore[attr-defined]
- bboxes = [simple_bbox] * batch_size
+ bboxes = [simple_bbox] * std_batch_size
batch: Dict[str, Union[Tensor, List[Any]]] = {
- "image": images,
- "mask": masks,
+ "image": random_rgbi_batch,
+ "mask": random_mask_batch,
"bbox": bboxes,
}
- results = sup_tg(batch, model, device, mode)
+ results = sup_tg(batch, model, default_device, mode)
assert isinstance(results[0], Tensor)
assert isinstance(results[1], Tensor)
- assert results[1].size() == (batch_size, n_classes, *input_size[1:])
- assert_array_equal(results[2], batch["mask"])
+ assert results[1].size() == (
+ std_batch_size,
+ std_n_classes,
+ *rgbi_input_size[1:],
+ )
+ assert_array_equal(results[2].detach().cpu(), batch["mask"].detach().cpu())
assert results[3] == batch["bbox"]
-def test_ssl_pair_tg(simple_bbox) -> None:
+def test_ssl_pair_tg(
+ simple_bbox: BoundingBox,
+ std_batch_size: int,
+ rgbi_input_size: Tuple[int, int, int],
+ default_device: torch.device,
+) -> None:
criterion = NTXentLoss(0.5)
- model = SimCLR34(criterion, input_size=input_size)
+ model = SimCLR34(criterion, input_size=rgbi_input_size).to(default_device)
optimiser = torch.optim.SGD(model.parameters(), lr=1.0e-3)
model.set_optimiser(optimiser)
for mode in ("train", "val"):
- images_1 = torch.rand(size=(batch_size, *input_size))
- bboxes_1 = [simple_bbox] * batch_size
+ images_1 = torch.rand(size=(std_batch_size, *rgbi_input_size))
+ bboxes_1 = [simple_bbox] * std_batch_size
batch_1 = {
"image": images_1,
"bbox": bboxes_1,
}
- images_2 = torch.rand(size=(batch_size, *input_size))
- bboxes_2 = [simple_bbox] * batch_size
+ images_2 = torch.rand(size=(std_batch_size, *rgbi_input_size))
+ bboxes_2 = [simple_bbox] * std_batch_size
batch_2 = {
"image": images_2,
"bbox": bboxes_2,
}
- results = ssl_pair_tg((batch_1, batch_2), model, device, mode)
+ results = ssl_pair_tg((batch_1, batch_2), model, default_device, mode)
assert isinstance(results[0], Tensor)
assert isinstance(results[1], Tensor)
- assert results[1].size() == (2 * batch_size, 128)
+ assert results[1].size() == (2 * std_batch_size, 128)
assert results[2] is None
assert isinstance(batch_1["bbox"], list)
assert isinstance(batch_2["bbox"], list)
assert results[3] == batch_1["bbox"] + batch_2["bbox"]
-def test_mask_autoencoder_io(simple_bbox) -> None:
+def test_mask_autoencoder_io(
+ simple_bbox: BoundingBox,
+ std_batch_size: int,
+ std_n_classes: int,
+ rgbi_input_size: Tuple[int, int, int],
+ default_device: torch.device,
+) -> None:
criterion = nn.CrossEntropyLoss()
- model = FCN32ResNet18(criterion, input_size=(8, *input_size[1:]))
+ model = FCN32ResNet18(criterion, input_size=(8, *rgbi_input_size[1:])).to(
+ default_device
+ )
optimiser = torch.optim.SGD(model.parameters(), lr=1.0e-3)
model.set_optimiser(optimiser)
for mode in ("train", "val", "test"):
- images = torch.rand(size=(batch_size, *input_size))
- masks = torch.randint(0, 8, (batch_size, *input_size[1:])) # type: ignore[attr-defined]
- bboxes = [simple_bbox] * batch_size
+ images = torch.rand(size=(std_batch_size, *rgbi_input_size))
+ masks = torch.randint(0, 8, (std_batch_size, *rgbi_input_size[1:])) # type: ignore[attr-defined]
+ bboxes = [simple_bbox] * std_batch_size
batch: Dict[str, Union[Tensor, List[Any]]] = {
"image": images,
"mask": masks,
@@ -103,41 +156,54 @@ def test_mask_autoencoder_io(simple_bbox) -> None:
ValueError,
match="The value of key='wrong' is not understood. Must be either 'mask' or 'image'",
):
- autoencoder_io(batch, model, device, mode, autoencoder_data_key="wrong")
+ autoencoder_io(
+ batch, model, default_device, mode, autoencoder_data_key="wrong"
+ )
results = autoencoder_io(
- batch, model, device, mode, autoencoder_data_key="mask"
+ batch, model, default_device, mode, autoencoder_data_key="mask"
)
assert isinstance(results[0], Tensor)
assert isinstance(results[1], Tensor)
- assert results[1].size() == (batch_size, n_classes, *input_size[1:])
- assert_array_equal(results[2], batch["mask"])
+ assert results[1].size() == (
+ std_batch_size,
+ std_n_classes,
+ *rgbi_input_size[1:],
+ )
+ assert_array_equal(results[2].detach().cpu(), batch["mask"].detach().cpu())
assert results[3] == batch["bbox"]
-def test_image_autoencoder_io(simple_bbox) -> None:
+def test_image_autoencoder_io(
+ simple_bbox: BoundingBox,
+ random_rgbi_batch: Tensor,
+ random_mask_batch: Tensor,
+ std_batch_size: int,
+ rgbi_input_size: Tuple[int, int, int],
+ default_device: torch.device,
+) -> None:
criterion = nn.CrossEntropyLoss()
- model = FCN32ResNet18(criterion, input_size=input_size, n_classes=4)
+ model = FCN32ResNet18(criterion, input_size=rgbi_input_size, n_classes=4).to(
+ default_device
+ )
optimiser = torch.optim.SGD(model.parameters(), lr=1.0e-3)
model.set_optimiser(optimiser)
for mode in ("train", "val", "test"):
- images = torch.rand(size=(batch_size, *input_size))
- masks = torch.randint(0, 8, (batch_size, *input_size[1:])) # type: ignore[attr-defined]
- bboxes = [simple_bbox] * batch_size
+ bboxes = [simple_bbox] * std_batch_size
batch: Dict[str, Union[Tensor, List[Any]]] = {
- "image": images,
- "mask": masks,
+ "image": random_rgbi_batch,
+ "mask": random_mask_batch,
"bbox": bboxes,
}
results = autoencoder_io(
- batch, model, device, mode, autoencoder_data_key="image"
+ batch, model, default_device, mode, autoencoder_data_key="image"
)
assert isinstance(results[0], Tensor)
assert isinstance(results[1], Tensor)
- assert results[1].size() == (batch_size, *input_size)
- assert_array_equal(results[2], batch["image"])
+ assert results[1].size() == (std_batch_size, *rgbi_input_size)
+ assert_array_equal(results[2].detach().cpu(), batch["image"].detach().cpu())
assert results[3] == batch["bbox"]
diff --git a/tests/test_models_core.py b/tests/test_models_core.py
index b0797ab4e..5dc003cc5 100644
--- a/tests/test_models_core.py
+++ b/tests/test_models_core.py
@@ -1,4 +1,35 @@
# -*- coding: utf-8 -*-
+# Copyright (C) 2023 Harry Baker
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program in LICENSE.txt. If not,
+# see .
+#
+# @org: University of Southampton
+# Created under a project funded by the Ordnance Survey Ltd.
+r"""Tests for :mod:`minerva.models.core`.
+"""
+# =====================================================================================================================
+# METADATA
+# =====================================================================================================================
+__author__ = "Harry Baker"
+__contact__ = "hjb1d20@soton.ac.uk"
+__license__ = "MIT License"
+__copyright__ = "Copyright (C) 2023 Harry Baker"
+
+# =====================================================================================================================
+# IMPORTS
+# =====================================================================================================================
import importlib
import internet_sabotage
@@ -30,6 +61,9 @@
from minerva.models.__depreciated import MLP
+# =====================================================================================================================
+# TESTS
+# =====================================================================================================================
def test_minerva_model(x_entropy_loss) -> None:
x = torch.rand(16, (288))
y = torch.LongTensor(np.random.randint(0, 8, size=16))
diff --git a/tests/test_models_depreciated.py b/tests/test_models_depreciated.py
index 95377c321..fc559305c 100644
--- a/tests/test_models_depreciated.py
+++ b/tests/test_models_depreciated.py
@@ -1,4 +1,35 @@
# -*- coding: utf-8 -*-
+# Copyright (C) 2023 Harry Baker
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program in LICENSE.txt. If not,
+# see .
+#
+# @org: University of Southampton
+# Created under a project funded by the Ordnance Survey Ltd.
+r"""Tests for :mod:`minerva.models._depreciated`.
+"""
+# =====================================================================================================================
+# METADATA
+# =====================================================================================================================
+__author__ = "Harry Baker"
+__contact__ = "hjb1d20@soton.ac.uk"
+__license__ = "MIT License"
+__copyright__ = "Copyright (C) 2023 Harry Baker"
+
+# =====================================================================================================================
+# IMPORTS
+# =====================================================================================================================
import numpy as np
import torch
from torch import Tensor
@@ -6,6 +37,9 @@
from minerva.models.__depreciated import CNN, MLP
+# =====================================================================================================================
+# TESTS
+# =====================================================================================================================
def test_mlp(x_entropy_loss) -> None:
model = MLP(x_entropy_loss, hidden_sizes=128)
diff --git a/tests/test_optimisers.py b/tests/test_optimisers.py
index 0c874fc1e..1da966446 100644
--- a/tests/test_optimisers.py
+++ b/tests/test_optimisers.py
@@ -1,4 +1,35 @@
# -*- coding: utf-8 -*-
+# Copyright (C) 2023 Harry Baker
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program in LICENSE.txt. If not,
+# see .
+#
+# @org: University of Southampton
+# Created under a project funded by the Ordnance Survey Ltd.
+r"""Tests for :mod:`minerva.optimsiers`.
+"""
+# =====================================================================================================================
+# METADATA
+# =====================================================================================================================
+__author__ = "Harry Baker"
+__contact__ = "hjb1d20@soton.ac.uk"
+__license__ = "MIT License"
+__copyright__ = "Copyright (C) 2023 Harry Baker"
+
+# =====================================================================================================================
+# IMPORTS
+# =====================================================================================================================
import pytest
import torch
import torch.nn.modules as nn
@@ -7,6 +38,9 @@
from minerva.optimisers import LARS
+# =====================================================================================================================
+# TESTS
+# =====================================================================================================================
def test_lars() -> None:
model = CNN(nn.CrossEntropyLoss(), input_size=(3, 224, 224))
diff --git a/tests/test_pytorchtools.py b/tests/test_pytorchtools.py
index 1596ff61c..711d7df90 100644
--- a/tests/test_pytorchtools.py
+++ b/tests/test_pytorchtools.py
@@ -1,4 +1,35 @@
# -*- coding: utf-8 -*-
+# Copyright (C) 2023 Harry Baker
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program in LICENSE.txt. If not,
+# see .
+#
+# @org: University of Southampton
+# Created under a project funded by the Ordnance Survey Ltd.
+r"""Tests for :mod:`minerva.pytorchtools`.
+"""
+# =====================================================================================================================
+# METADATA
+# =====================================================================================================================
+__author__ = "Harry Baker"
+__contact__ = "hjb1d20@soton.ac.uk"
+__license__ = "MIT License"
+__copyright__ = "Copyright (C) 2023 Harry Baker"
+
+# =====================================================================================================================
+# IMPORTS
+# =====================================================================================================================
import tempfile
from pathlib import Path
@@ -7,6 +38,9 @@
from minerva.pytorchtools import EarlyStopping
+# =====================================================================================================================
+# TESTS
+# =====================================================================================================================
def test_earlystopping() -> None:
path = Path(tempfile.gettempdir(), "exp1.pt")
diff --git a/tests/test_resnets.py b/tests/test_resnets.py
index 04c60f81a..d1c3e02e1 100644
--- a/tests/test_resnets.py
+++ b/tests/test_resnets.py
@@ -1,5 +1,37 @@
# -*- coding: utf-8 -*-
-import numpy as np
+# Copyright (C) 2023 Harry Baker
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program in LICENSE.txt. If not,
+# see .
+#
+# @org: University of Southampton
+# Created under a project funded by the Ordnance Survey Ltd.
+r"""Tests for :mod:`minerva.models.resnet`.
+"""
+# =====================================================================================================================
+# METADATA
+# =====================================================================================================================
+__author__ = "Harry Baker"
+__contact__ = "hjb1d20@soton.ac.uk"
+__license__ = "MIT License"
+__copyright__ = "Copyright (C) 2023 Harry Baker"
+
+# =====================================================================================================================
+# IMPORTS
+# =====================================================================================================================
+from typing import Tuple
+
import pytest
import torch
from torch import LongTensor, Tensor
@@ -15,73 +47,99 @@
)
from minerva.models.resnet import ResNet, _preload_weights
-input_size = (4, 64, 64)
-
-x = torch.rand(6, *input_size)
-y = LongTensor(np.random.randint(0, 8, size=6))
-
-def resnet_test(test_model: MinervaModel, x: Tensor, y: Tensor) -> None:
- optimiser = torch.optim.SGD(test_model.parameters(), lr=1.0e-3)
+# =====================================================================================================================
+# TESTS
+# =====================================================================================================================
+def resnet_test(
+ model: MinervaModel, x: Tensor, y: LongTensor, batch_size: int, n_classes: int
+) -> None:
+ optimiser = torch.optim.SGD(model.parameters(), lr=1.0e-3)
- test_model.set_optimiser(optimiser)
+ model.set_optimiser(optimiser)
- test_model.determine_output_dim()
- assert test_model.output_shape is test_model.n_classes
+ model.determine_output_dim()
+ assert model.output_shape is model.n_classes
- loss, z = test_model.step(x, y, True)
+ loss, z = model.step(x, y, True)
assert type(loss.item()) is float
assert isinstance(z, Tensor)
- assert z.size() == (6, 8)
-
-
-def test_resnet():
+ assert z.size() == (batch_size, n_classes)
+
+
+@pytest.mark.parametrize(
+ "model_cls",
+ (
+ ResNet18,
+ ResNet34,
+ ResNet50,
+ ResNet101,
+ ResNet152,
+ ),
+)
+@pytest.mark.parametrize("zero_init", (True, False))
+def test_resnets(
+ model_cls: MinervaModel,
+ zero_init: bool,
+ x_entropy_loss,
+ random_rgbi_batch: Tensor,
+ random_scene_classification_batch: LongTensor,
+ std_batch_size: int,
+ std_n_classes: int,
+ rgbi_input_size: Tuple[int, int, int],
+) -> None:
+ model: MinervaModel = model_cls(
+ x_entropy_loss, input_size=rgbi_input_size, zero_init_residual=zero_init
+ )
+ resnet_test(
+ model,
+ random_rgbi_batch,
+ random_scene_classification_batch,
+ std_batch_size,
+ std_n_classes,
+ )
+
+
+def test_resnet() -> None:
assert isinstance(ResNet(BasicBlock, [2, 2, 2, 2], groups=2), ResNet)
-
-def test_resnet18(x_entropy_loss) -> None:
with pytest.raises(ValueError):
_ = ResNet18(replace_stride_with_dilation=(True, False)) # type: ignore[arg-type]
- for zero_init_residual in (True, False):
- resnet18 = ResNet18(
- x_entropy_loss, input_size=input_size, zero_init_residual=zero_init_residual
- )
-
- resnet_test(resnet18, x, y)
-
-def test_resnet34(x_entropy_loss) -> None:
- model = ResNet34(x_entropy_loss, input_size=input_size)
- resnet_test(model, x, y)
-
-
-def test_resnet50(x_entropy_loss) -> None:
+def test_replace_stride(
+ x_entropy_loss,
+ random_rgbi_batch: Tensor,
+ random_scene_classification_batch: LongTensor,
+ std_batch_size: int,
+ std_n_classes: int,
+ rgbi_input_size: Tuple[int, int, int],
+) -> None:
for model in (
- ResNet50(x_entropy_loss, input_size=input_size),
+ ResNet50(x_entropy_loss, input_size=rgbi_input_size),
ResNet50(
x_entropy_loss,
- input_size=input_size,
+ input_size=rgbi_input_size,
replace_stride_with_dilation=(True, True, False),
zero_init_residual=True,
),
):
- resnet_test(model, x, y)
-
-
-def test_resnet101(x_entropy_loss) -> None:
- model = ResNet101(x_entropy_loss, input_size=input_size)
- resnet_test(model, x, y)
-
-
-def test_resnet152(x_entropy_loss) -> None:
- model = ResNet152(x_entropy_loss, input_size=input_size)
- resnet_test(model, x, y)
+ resnet_test(
+ model,
+ random_rgbi_batch,
+ random_scene_classification_batch,
+ std_batch_size,
+ std_n_classes,
+ )
-def test_resnet_encoder(x_entropy_loss) -> None:
- encoder = ResNet18(x_entropy_loss, input_size=input_size, encoder=True)
+def test_resnet_encoder(
+ x_entropy_loss,
+ random_rgbi_batch: Tensor,
+ rgbi_input_size: Tuple[int, int, int],
+) -> None:
+ encoder = ResNet18(x_entropy_loss, input_size=rgbi_input_size, encoder=True)
optimiser = torch.optim.SGD(encoder.parameters(), lr=1.0e-3)
encoder.set_optimiser(optimiser)
@@ -90,11 +148,10 @@ def test_resnet_encoder(x_entropy_loss) -> None:
print(encoder.output_shape)
assert encoder.output_shape == (512, 2, 2)
- x = torch.rand(6, *input_size)
- assert len(encoder(x)) == 5
+ assert len(encoder(random_rgbi_batch)) == 5
-def test_preload_weights():
+def test_preload_weights() -> None:
resnet = ResNet(BasicBlock, [2, 2, 2, 2])
new_resnet = _preload_weights(resnet, None, (4, 32, 32), encoder_on=False)
diff --git a/tests/test_runner.py b/tests/test_runner.py
index 353a057b0..185bc4c5a 100644
--- a/tests/test_runner.py
+++ b/tests/test_runner.py
@@ -1,4 +1,35 @@
# -*- coding: utf-8 -*-
+# Copyright (C) 2023 Harry Baker
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program in LICENSE.txt. If not,
+# see .
+#
+# @org: University of Southampton
+# Created under a project funded by the Ordnance Survey Ltd.
+r"""Tests for :mod:`minerva.utils.runner`.
+"""
+# =====================================================================================================================
+# METADATA
+# =====================================================================================================================
+__author__ = "Harry Baker"
+__contact__ = "hjb1d20@soton.ac.uk"
+__license__ = "MIT License"
+__copyright__ = "Copyright (C) 2023 Harry Baker"
+
+# =====================================================================================================================
+# IMPORTS
+# =====================================================================================================================
import os
import subprocess
import time
@@ -12,6 +43,9 @@
from minerva.utils import CONFIG, runner
+# =====================================================================================================================
+# TESTS
+# =====================================================================================================================
def test_wandb_connection_manager() -> None:
try:
requests.head("http://www.wandb.ai/", timeout=0.1)
diff --git a/tests/test_samplers.py b/tests/test_samplers.py
index f23013c2e..3eb03fbae 100644
--- a/tests/test_samplers.py
+++ b/tests/test_samplers.py
@@ -1,4 +1,35 @@
# -*- coding: utf-8 -*-
+# Copyright (C) 2023 Harry Baker
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program in LICENSE.txt. If not,
+# see .
+#
+# @org: University of Southampton
+# Created under a project funded by the Ordnance Survey Ltd.
+r"""Tests for :mod:`minerva.samplers`.
+"""
+# =====================================================================================================================
+# METADATA
+# =====================================================================================================================
+__author__ = "Harry Baker"
+__contact__ = "hjb1d20@soton.ac.uk"
+__license__ = "MIT License"
+__copyright__ = "Copyright (C) 2023 Harry Baker"
+
+# =====================================================================================================================
+# IMPORTS
+# =====================================================================================================================
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict
@@ -14,12 +45,11 @@
get_greater_bbox,
)
-data_root = Path("tests", "tmp")
-img_root = str(data_root / "data" / "test_images")
-lc_root = str(data_root / "data" / "test_lc")
-
-def test_randompairgeosampler() -> None:
+# =====================================================================================================================
+# TESTS
+# =====================================================================================================================
+def test_randompairgeosampler(img_root: Path) -> None:
dataset = PairedDataset(TstImgDataset, img_root, res=1.0)
sampler = RandomPairGeoSampler(dataset, size=32, length=32, max_r=52)
@@ -35,7 +65,7 @@ def test_randompairgeosampler() -> None:
assert len(batch[1]["image"]) == 8
-def test_randompairbatchgeosampler() -> None:
+def test_randompairbatchgeosampler(img_root: Path) -> None:
dataset = PairedDataset(TstImgDataset, img_root, res=1.0)
sampler = RandomPairBatchGeoSampler(
@@ -62,6 +92,6 @@ def test_randompairbatchgeosampler() -> None:
)
-def test_get_greater_bbox(simple_bbox) -> None:
+def test_get_greater_bbox(simple_bbox: BoundingBox) -> None:
new_bbox = get_greater_bbox(simple_bbox, 1.0, 1.0)
assert new_bbox == BoundingBox(-1.0, 2.0, -1.0, 2.0, 0.0, 1.0)
diff --git a/tests/test_siamese.py b/tests/test_siamese.py
index 1d87ef9cd..8bfce9254 100644
--- a/tests/test_siamese.py
+++ b/tests/test_siamese.py
@@ -1,4 +1,35 @@
# -*- coding: utf-8 -*-
+# Copyright (C) 2023 Harry Baker
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program in LICENSE.txt. If not,
+# see .
+#
+# @org: University of Southampton
+# Created under a project funded by the Ordnance Survey Ltd.
+r"""Tests for :mod:`minerva.models.siamese`.
+"""
+# =====================================================================================================================
+# METADATA
+# =====================================================================================================================
+__author__ = "Harry Baker"
+__contact__ = "hjb1d20@soton.ac.uk"
+__license__ = "MIT License"
+__copyright__ = "Copyright (C) 2023 Harry Baker"
+
+# =====================================================================================================================
+# IMPORTS
+# =====================================================================================================================
import importlib
import pytest
@@ -17,6 +48,9 @@
from minerva.models import SimCLR18, SimCLR34, SimCLR50, SimSiam18, SimSiam34, SimSiam50
+# =====================================================================================================================
+# TESTS
+# =====================================================================================================================
def test_simclr() -> None:
loss_func = NTXentLoss(0.3)
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index b8da948e4..1d5edb0d4 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -1,4 +1,35 @@
# -*- coding: utf-8 -*-
+# Copyright (C) 2023 Harry Baker
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program in LICENSE.txt. If not,
+# see .
+#
+# @org: University of Southampton
+# Created under a project funded by the Ordnance Survey Ltd.
+r"""Tests for :mod:`minerva.trainer`.
+"""
+# =====================================================================================================================
+# METADATA
+# =====================================================================================================================
+__author__ = "Harry Baker"
+__contact__ = "hjb1d20@soton.ac.uk"
+__license__ = "MIT License"
+__copyright__ = "Copyright (C) 2023 Harry Baker"
+
+# =====================================================================================================================
+# IMPORTS
+# =====================================================================================================================
import argparse
import shutil
from pathlib import Path
@@ -11,6 +42,9 @@
from minerva.utils import CONFIG, config_load, runner, utils
+# =====================================================================================================================
+# TESTS
+# =====================================================================================================================
def run_trainer(gpu: int, args: argparse.Namespace):
args.gpu = gpu
params = CONFIG.copy()
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
index 9d5568faa..2c46644e2 100644
--- a/tests/test_transforms.py
+++ b/tests/test_transforms.py
@@ -1,4 +1,35 @@
# -*- coding: utf-8 -*-
+# Copyright (C) 2023 Harry Baker
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program in LICENSE.txt. If not,
+# see .
+#
+# @org: University of Southampton
+# Created under a project funded by the Ordnance Survey Ltd.
+r"""Tests for :mod:`minerva.transforms`.
+"""
+# =====================================================================================================================
+# METADATA
+# =====================================================================================================================
+__author__ = "Harry Baker"
+__contact__ = "hjb1d20@soton.ac.uk"
+__license__ = "MIT License"
+__copyright__ = "Copyright (C) 2023 Harry Baker"
+
+# =====================================================================================================================
+# IMPORTS
+# =====================================================================================================================
import pytest
import torch
from numpy.testing import assert_array_equal
@@ -18,6 +49,9 @@
from minerva.utils import utils
+# =====================================================================================================================
+# TESTS
+# =====================================================================================================================
def test_class_transform(simple_mask, example_matrix) -> None:
transform = ClassTransform(example_matrix)
diff --git a/tests/test_unet.py b/tests/test_unet.py
index 0404384a1..afe9d7212 100644
--- a/tests/test_unet.py
+++ b/tests/test_unet.py
@@ -1,4 +1,38 @@
# -*- coding: utf-8 -*-
+# Copyright (C) 2023 Harry Baker
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program in LICENSE.txt. If not,
+# see .
+#
+# @org: University of Southampton
+# Created under a project funded by the Ordnance Survey Ltd.
+r"""Tests for :mod:`minerva.models.unet`.
+"""
+# =====================================================================================================================
+# METADATA
+# =====================================================================================================================
+__author__ = "Harry Baker"
+__contact__ = "hjb1d20@soton.ac.uk"
+__license__ = "MIT License"
+__copyright__ = "Copyright (C) 2023 Harry Baker"
+
+# =====================================================================================================================
+# IMPORTS
+# =====================================================================================================================
+from typing import Tuple
+
+import pytest
import torch
from torch import Tensor
@@ -12,57 +46,89 @@
UNetR152,
)
-input_size = (4, 64, 64)
-batch_size = 2
-n_classes = 8
-x = torch.rand((batch_size, *input_size))
-y = torch.randint(0, n_classes, (batch_size, *input_size[1:])) # type: ignore[attr-defined]
+# =====================================================================================================================
+# METHODS
+# =====================================================================================================================
+def unet_test(
+ model: MinervaModel,
+ x: Tensor,
+ y: Tensor,
+ batch_size: int,
+ n_classes: int,
+ input_size: Tuple[int, int, int],
+) -> None:
+ optimiser = torch.optim.SGD(model.parameters(), lr=1.0e-3)
+ model.set_optimiser(optimiser)
+ model.determine_output_dim()
+ assert model.output_shape == input_size[1:]
-def unet_test(test_model: MinervaModel, x: Tensor, y: Tensor) -> None:
- optimiser = torch.optim.SGD(test_model.parameters(), lr=1.0e-3)
-
- test_model.set_optimiser(optimiser)
-
- test_model.determine_output_dim()
- assert test_model.output_shape == input_size[1:]
-
- loss, z = test_model.step(x, y, True)
+ loss, z = model.step(x, y, True)
assert type(loss.item()) is float
assert isinstance(z, Tensor)
assert z.size() == (batch_size, n_classes, *input_size[1:])
-def test_unet(x_entropy_loss) -> None:
- model = UNet(x_entropy_loss, input_size=input_size)
- unet_test(model, x, y)
-
- bilinear_model = UNet(x_entropy_loss, input_size=input_size, bilinear=True)
- unet_test(bilinear_model, x, y)
-
-
-def test_unetr18(x_entropy_loss) -> None:
- model = UNetR18(x_entropy_loss, input_size=input_size)
- unet_test(model, x, y)
-
-
-def test_unetr34(x_entropy_loss) -> None:
- model = UNetR34(x_entropy_loss, input_size=input_size)
- unet_test(model, x, y)
-
-
-def test_unetr50(x_entropy_loss) -> None:
- model = UNetR50(x_entropy_loss, input_size=input_size)
- unet_test(model, x, y)
-
-
-def test_unetr101(x_entropy_loss) -> None:
- model = UNetR101(x_entropy_loss, input_size=input_size)
- unet_test(model, x, y)
-
-
-def test_unetr152(x_entropy_loss) -> None:
- model = UNetR152(x_entropy_loss, input_size=input_size)
- unet_test(model, x, y)
+# =====================================================================================================================
+# TESTS
+# =====================================================================================================================
+@pytest.mark.parametrize(
+ "model_cls",
+ (
+ UNetR18,
+ UNetR34,
+ UNetR50,
+ UNetR101,
+ UNetR152,
+ ),
+)
+def test_unetrs(
+ model_cls: MinervaModel,
+ x_entropy_loss,
+ random_rgbi_batch: Tensor,
+ random_mask_batch: Tensor,
+ std_batch_size: int,
+ std_n_classes: int,
+ rgbi_input_size: Tuple[int, int, int],
+) -> None:
+ model: MinervaModel = model_cls(x_entropy_loss, rgbi_input_size)
+
+ unet_test(
+ model,
+ random_rgbi_batch,
+ random_mask_batch,
+ std_batch_size,
+ std_n_classes,
+ rgbi_input_size,
+ )
+
+
+def test_unet(
+ x_entropy_loss,
+ random_rgbi_batch: Tensor,
+ random_mask_batch: Tensor,
+ std_batch_size: int,
+ std_n_classes: int,
+ rgbi_input_size: Tuple[int, int, int],
+) -> None:
+ model = UNet(x_entropy_loss, input_size=rgbi_input_size)
+ unet_test(
+ model,
+ random_rgbi_batch,
+ random_mask_batch,
+ std_batch_size,
+ std_n_classes,
+ rgbi_input_size,
+ )
+
+ bilinear_model = UNet(x_entropy_loss, input_size=rgbi_input_size, bilinear=True)
+ unet_test(
+ bilinear_model,
+ random_rgbi_batch,
+ random_mask_batch,
+ std_batch_size,
+ std_n_classes,
+ rgbi_input_size,
+ )
diff --git a/tests/test_utils.py b/tests/test_utils.py
index d45a47ddb..9e450ee1b 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,4 +1,35 @@
# -*- coding: utf-8 -*-
+# Copyright (C) 2023 Harry Baker
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program in LICENSE.txt. If not,
+# see .
+#
+# @org: University of Southampton
+# Created under a project funded by the Ordnance Survey Ltd.
+r"""Tests for :mod:`minerva.utils.utils`.
+"""
+# =====================================================================================================================
+# METADATA
+# =====================================================================================================================
+__author__ = "Harry Baker"
+__contact__ = "hjb1d20@soton.ac.uk"
+__license__ = "MIT License"
+__copyright__ = "Copyright (C) 2023 Harry Baker"
+
+# =====================================================================================================================
+# IMPORTS
+# =====================================================================================================================
import cmath
import math
import os
@@ -26,6 +57,9 @@
from minerva.utils import AUX_CONFIGS, CONFIG, utils, visutils
+# =====================================================================================================================
+# TESTS
+# =====================================================================================================================
def test_print_banner() -> None:
utils._print_banner()
diff --git a/tests/test_visutils.py b/tests/test_visutils.py
index 1dfb09d59..ef64608c7 100644
--- a/tests/test_visutils.py
+++ b/tests/test_visutils.py
@@ -1,4 +1,35 @@
# -*- coding: utf-8 -*-
+# Copyright (C) 2023 Harry Baker
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program in LICENSE.txt. If not,
+# see .
+#
+# @org: University of Southampton
+# Created under a project funded by the Ordnance Survey Ltd.
+r"""Tests for :mod:`minerva.utils.visutils`.
+"""
+# =====================================================================================================================
+# METADATA
+# =====================================================================================================================
+__author__ = "Harry Baker"
+__contact__ = "hjb1d20@soton.ac.uk"
+__license__ = "MIT License"
+__copyright__ = "Copyright (C) 2023 Harry Baker"
+
+# =====================================================================================================================
+# IMPORTS
+# =====================================================================================================================
import os
import shutil
import tempfile
@@ -20,6 +51,9 @@
from minerva.utils import utils, visutils
+# =====================================================================================================================
+# TESTS
+# =====================================================================================================================
def test_de_interlace() -> None:
x_1 = [1, 1, 1, 1, 1]
x_2 = [2, 2, 2, 2, 2]