From 511b3faf0a592b5128c686f0590faf40b5f9cc82 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Thu, 30 Nov 2023 11:12:51 +0100 Subject: [PATCH] Helper functions for new segmentation domain (#2105) Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> Co-authored-by: Jirka Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- CHANGELOG.md | 3 + requirements/_devel.txt | 1 + requirements/segmentation_test.txt | 5 + .../functional/segmentation/__init__.py | 13 + .../functional/segmentation/utils.py | 781 ++++++++++++++++++ tests/unittests/segmentation/__init__.py | 13 + tests/unittests/segmentation/test_utils.py | 245 ++++++ 7 files changed, 1061 insertions(+) create mode 100644 requirements/segmentation_test.txt create mode 100644 src/torchmetrics/functional/segmentation/__init__.py create mode 100644 src/torchmetrics/functional/segmentation/utils.py create mode 100644 tests/unittests/segmentation/__init__.py create mode 100644 tests/unittests/segmentation/test_utils.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 468b459608b..2e8b005f380 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added utility functions in `segmentation.utils` for future segmentation metrics ([#2105](https://github.com/Lightning-AI/torchmetrics/pull/2105)) + + - Added more tokenizers for `SacreBLEU` metric ([#2068](https://github.com/Lightning-AI/torchmetrics/pull/2068)) diff --git a/requirements/_devel.txt b/requirements/_devel.txt index 6a80916918a..596cc138133 100644 --- a/requirements/_devel.txt +++ b/requirements/_devel.txt @@ -19,3 +19,4 @@ -r detection_test.txt -r classification_test.txt -r nominal_test.txt +-r segmentation_test.txt diff --git a/requirements/segmentation_test.txt b/requirements/segmentation_test.txt new file mode 100644 index 00000000000..c2db71fcff6 --- /dev/null +++ 
b/requirements/segmentation_test.txt @@ -0,0 +1,5 @@ +# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package +# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment + +scipy >1.0.0, <1.11.0 +monai ==1.3.0 diff --git a/src/torchmetrics/functional/segmentation/__init__.py b/src/torchmetrics/functional/segmentation/__init__.py new file mode 100644 index 00000000000..94f1dec4a9f --- /dev/null +++ b/src/torchmetrics/functional/segmentation/__init__.py @@ -0,0 +1,13 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/torchmetrics/functional/segmentation/utils.py b/src/torchmetrics/functional/segmentation/utils.py new file mode 100644 index 00000000000..bbf5c48ded3 --- /dev/null +++ b/src/torchmetrics/functional/segmentation/utils.py @@ -0,0 +1,781 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import functools +import math +from typing import List, Optional, Tuple, Union + +import torch +from torch import Tensor +from torch.nn.functional import conv2d, conv3d, pad, unfold +from typing_extensions import Literal + +from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.utilities.imports import _SCIPY_AVAILABLE + + +def check_if_binarized(x: Tensor) -> None: + """Check if the input is binarized. + + Example: + >>> from torchmetrics.functional.segmentation.utils import check_if_binarized + >>> import torch + >>> check_if_binarized(torch.tensor([0, 1, 1, 0])) + + """ + if not torch.all(x.bool() == x): + raise ValueError("Input x should be binarized") + + +def _unfold(x: Tensor, kernel_size: Tuple[int, ...]) -> Tensor: + """Unfold the input tensor to a matrix. Function supports 3d images e.g. (B, C, D, H, W). + + Inspired by: + https://github.com/f-dangel/unfoldNd/blob/main/unfoldNd/unfold.py + + Args: + x: Input tensor to be unfolded. + kernel_size: The size of the sliding blocks in each dimension. + + """ + batch_size, channels = x.shape[:2] + n = x.ndim - 2 + if n == 2: + return unfold(x, kernel_size) + + kernel_size_numel = kernel_size[0] * kernel_size[1] * kernel_size[2] + repeat = [channels, 1] + [1 for _ in kernel_size] + weight = torch.eye(kernel_size_numel, device=x.device, dtype=x.dtype) + weight = weight.reshape(kernel_size_numel, 1, *kernel_size).repeat(*repeat) + unfold_x = conv3d(x, weight=weight, bias=None) + return unfold_x.reshape(batch_size, channels * kernel_size_numel, -1) + + +def generate_binary_structure(rank: int, connectivity: int) -> Tensor: + """Translated version of the function from scipy.ndimage.morphology. + + Args: + rank: The rank of the structuring element. + connectivity: The number of neighbors connected to a given pixel. + + Returns: + The structuring element. 
+ + Examples:: + >>> from torchmetrics.functional.segmentation.utils import generate_binary_structure + >>> import torch + >>> generate_binary_structure(2, 1) + tensor([[False, True, False], + [ True, True, True], + [False, True, False]]) + >>> generate_binary_structure(2, 2) + tensor([[True, True, True], + [True, True, True], + [True, True, True]]) + >>> generate_binary_structure(3, 2) # doctest: +NORMALIZE_WHITESPACE + tensor([[[False, True, False], + [ True, True, True], + [False, True, False]], + [[ True, True, True], + [ True, True, True], + [ True, True, True]], + [[False, True, False], + [ True, True, True], + [False, True, False]]]) + + """ + if connectivity < 1: + connectivity = 1 + if rank < 1: + return torch.tensor([1], dtype=torch.uint8) + grids = torch.meshgrid([torch.arange(3) for _ in range(rank)], indexing="ij") + output = torch.abs(torch.stack(grids, dim=0) - 1) + output = torch.sum(output, dim=0) + return output <= connectivity + + +def binary_erosion( + image: Tensor, structure: Optional[Tensor] = None, origin: Optional[Tuple[int, ...]] = None, border_value: int = 0 +) -> Tensor: + """Binary erosion of a tensor image. + + Implementation inspired by answer to this question: https://stackoverflow.com/questions/56235733/ + + Args: + image: The image to be eroded, must be a binary tensor with shape ``(batch_size, channels, height, width)``. + structure: The structuring element used for the erosion. If no structuring element is provided, an element + is generated with a square connectivity equal to one. + origin: The origin of the structuring element. + border_value: The value to be used for the border. + + Examples:: + >>> from torchmetrics.functional.segmentation.utils import binary_erosion + >>> import torch + >>> image = torch.tensor([[[[0, 0, 0, 0, 0], + ... [0, 1, 1, 1, 0], + ... [0, 1, 1, 1, 0], + ... [0, 1, 1, 1, 0], + ... 
[0, 0, 0, 0, 0]]]]) + >>> binary_erosion(image) + tensor([[[[0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]]]], dtype=torch.uint8) + >>> binary_erosion(image, structure=torch.ones(4, 4)) + tensor([[[[0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]]]], dtype=torch.uint8) + + """ + if not isinstance(image, Tensor): + raise TypeError(f"Expected argument `image` to be of type Tensor but found {type(image)}") + if image.ndim not in [4, 5]: + raise ValueError(f"Expected argument `image` to be of rank 4 or 5 but found rank {image.ndim}") + check_if_binarized(image) + + # construct the structuring element if not provided + if structure is None: + structure = generate_binary_structure(image.ndim - 2, 1).int().to(image.device) + check_if_binarized(structure) + + if origin is None: + origin = structure.ndim * (1,) + + # first pad the image to have correct unfolding; here is where the origins is used + image_pad = pad( + image, + [x for i in range(len(origin)) for x in [origin[i], structure.shape[i] - origin[i] - 1]], + mode="constant", + value=border_value, + ) + # Unfold the image to be able to perform operation on neighborhoods + image_unfold = _unfold(image_pad.float(), kernel_size=structure.shape) + + strel_flatten = torch.flatten(structure).unsqueeze(0).unsqueeze(-1) + sums = image_unfold - strel_flatten.int() + + # Take minimum over the neighborhood + result, _ = sums.min(dim=1) + + # Reshape the image to recover initial shape + return (torch.reshape(result, image.shape) + 1).byte() + + +def distance_transform( + x: Tensor, + sampling: Optional[Union[Tensor, List[float]]] = None, + metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean", + engine: Literal["pytorch", "scipy"] = "pytorch", +) -> Tensor: + """Calculate distance transform of a binary tensor. 
+ + This function calculates the distance transform of a binary tensor, replacing each foreground pixel with the + distance to the closest background pixel. The distance is calculated using the euclidean, chessboard or taxicab + distance. + + The memory consumption of this function is in the worst cast N/2**2 where N is the number of pixel. Since we need + to compare all foreground pixels to all background pixels, the memory consumption is quadratic in the number of + pixels. The memory consumption can be reduced by using the ``scipy`` engine, which is more memory efficient but + should also be slower for larger images. + + Args: + x: The binary tensor to calculate the distance transform of. + sampling: Only relevant when distance is calculated using the euclidean distance. The sampling refers to the + pixel spacing in the image, i.e. the distance between two adjacent pixels. If not provided, the pixel + spacing is assumed to be 1. + metric: The distance to use for the distance transform. Can be one of ``"euclidean"``, ``"chessboard"`` + or ``"taxicab"``. + engine: The engine to use for the distance transform. Can be one of ``["pytorch", "scipy"]``. In general, + the ``pytorch`` engine is faster, but the ``scipy`` engine is more memory efficient. + + Returns: + The distance transform of the input tensor. + + Examples:: + >>> from torchmetrics.functional.segmentation.utils import distance_transform + >>> import torch + >>> x = torch.tensor([[0, 0, 0, 0, 0], + ... [0, 1, 1, 1, 0], + ... [0, 1, 1, 1, 0], + ... [0, 1, 1, 1, 0], + ... 
[0, 0, 0, 0, 0]]) + >>> distance_transform(x) + tensor([[0., 0., 0., 0., 0.], + [0., 1., 1., 1., 0.], + [0., 1., 2., 1., 0.], + [0., 1., 1., 1., 0.], + [0., 0., 0., 0., 0.]]) + + """ + if not isinstance(x, Tensor): + raise ValueError(f"Expected argument `x` to be of type `torch.Tensor` but got `{type(x)}`.") + if x.ndim != 2: + raise ValueError(f"Expected argument `x` to be of rank 2 but got rank `{x.ndim}`.") + if sampling is not None and not isinstance(sampling, list): + raise ValueError( + f"Expected argument `sampling` to either be `None` or of type `list` but got `{type(sampling)}`." + ) + if metric not in ["euclidean", "chessboard", "taxicab"]: + raise ValueError( + f"Expected argument `metric` to be one of `['euclidean', 'chessboard', 'taxicab']` but got `{metric}`." + ) + if engine not in ["pytorch", "scipy"]: + raise ValueError(f"Expected argument `engine` to be one of `['pytorch', 'scipy']` but got `{engine}`.") + + if sampling is None: + sampling = [1, 1] + else: + if len(sampling) != 2: + raise ValueError(f"Expected argument `sampling` to have length 2 but got length `{len(sampling)}`.") + + if engine == "pytorch": + # calculate distance from every foreground pixel to every background pixel + i0, j0 = torch.where(x == 0) + i1, j1 = torch.where(x == 1) + dis_row = (i1.unsqueeze(1) - i0.unsqueeze(0)).abs_().mul_(sampling[0]) + dis_col = (j1.unsqueeze(1) - j0.unsqueeze(0)).abs_().mul_(sampling[1]) + + # # calculate distance + h, _ = x.shape + if metric == "euclidean": + dis_row = dis_row.float() + dis_row.pow_(2).add_(dis_col.pow_(2)).sqrt_() + if metric == "chessboard": + dis_row = dis_row.max(dis_col) + if metric == "taxicab": + dis_row.add_(dis_col) + + # select only the closest distance + mindis, _ = torch.min(dis_row, dim=1) + z = torch.zeros_like(x, dtype=mindis.dtype).view(-1) + z[i1 * h + j1] = mindis + return z.view(x.shape) + + if not _SCIPY_AVAILABLE: + raise ValueError( + "The `scipy` engine requires `scipy` to be installed. 
Either install `scipy` or use the `pytorch` engine." + ) + from scipy import ndimage + + if metric == "euclidean": + return ndimage.distance_transform_edt(x.cpu().numpy(), sampling) + return ndimage.distance_transform_cdt(x.cpu().numpy(), metric=metric) + + +def mask_edges( + preds: Tensor, + target: Tensor, + crop: bool = True, + spacing: Optional[Union[Tuple[int, int], Tuple[int, int, int]]] = None, +) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, Tensor]]: + """Get the edges of binary segmentation masks. + + Args: + preds: The predicted binary segmentation mask + target: The ground truth binary segmentation mask + crop: Whether to crop the edges to the region of interest. If ``True``, the edges are cropped to the bounding + spacing: The pixel spacing of the input images. If provided, the edges are calculated using the euclidean + + Returns: + If spacing is not provided, a 2-tuple containing the edges of the predicted and target mask respectively is + returned. If spacing is provided, a 4-tuple containing the edges and areas of the predicted and target mask + respectively is returned. 
+ + """ + _check_same_shape(preds, target) + if preds.ndim not in [2, 3]: + raise ValueError(f"Expected argument `preds` to be of rank 2 or 3 but got rank `{preds.ndim}`.") + check_if_binarized(preds) + check_if_binarized(target) + + if crop: + or_val = preds | target + if not or_val.any(): + p, t = torch.zeros_like(preds), torch.zeros_like(target) + return p, t, p, t + # this seems to be working but does not seem to be right + preds, target = pad(preds, preds.ndim * [1, 1]), pad(target, target.ndim * [1, 1]) + + if spacing is None: + # no spacing, use binary erosion + be_pred = binary_erosion(preds.unsqueeze(0).unsqueeze(0)).squeeze() ^ preds + be_target = binary_erosion(target.unsqueeze(0).unsqueeze(0)).squeeze() ^ target + return be_pred, be_target + + # use neighborhood to get edges + table, kernel = get_neighbour_tables(spacing, device=preds.device) + spatial_dims = len(spacing) + conv_operator = conv2d if spatial_dims == 2 else conv3d + volume = torch.stack([preds.unsqueeze(0), target.unsqueeze(0)], dim=0).float() + code_preds, code_target = conv_operator(volume, kernel.to(volume)) + + # edges + all_ones = len(table) - 1 + edges_preds = (code_preds != 0) & (code_preds != all_ones) + edges_target = (code_target != 0) & (code_target != all_ones) + + # # areas of edges + areas_preds = torch.index_select(table, 0, code_preds.view(-1).int()).view_as(code_preds) + areas_target = torch.index_select(table, 0, code_target.view(-1).int()).view_as(code_target) + return edges_preds[0], edges_target[0], areas_preds[0], areas_target[0] + + +def surface_distance( + preds: Tensor, + target: Tensor, + distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean", + spacing: Optional[Union[Tensor, List[float]]] = None, +) -> Tensor: + """Calculate the surface distance between two binary edge masks. + + May return infinity if the predicted mask is empty and the target mask is not, or vice versa. + + Args: + preds: The predicted binary edge mask. 
+ target: The target binary edge mask. + distance_metric: The distance metric to use. One of `["euclidean", "chessboard", "taxicab"]`. + spacing: The spacing between pixels along each spatial dimension. + + Returns: + A tensor with length equal to the number of edges in predictions e.g. `preds.sum()`. Each element is the + distance from the corresponding edge in `preds` to the closest edge in `target`. + + Example:: + >>> import torch + >>> from torchmetrics.functional.segmentation.utils import surface_distance + >>> preds = torch.tensor([[1, 1, 1, 1, 1], + ... [1, 0, 0, 0, 1], + ... [1, 0, 0, 0, 1], + ... [1, 0, 0, 0, 1], + ... [1, 1, 1, 1, 1]], dtype=torch.bool) + >>> target = torch.tensor([[1, 1, 1, 1, 0], + ... [1, 0, 0, 1, 0], + ... [1, 0, 0, 1, 0], + ... [1, 0, 0, 1, 0], + ... [1, 1, 1, 1, 0]], dtype=torch.bool) + >>> surface_distance(preds, target, distance_metric="euclidean", spacing=[1, 1]) + tensor([0., 0., 0., 0., 1., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 1.]) + + """ + if not (preds.dtype == torch.bool and target.dtype == torch.bool): + raise ValueError(f"Expected both inputs to be of type `torch.bool`, but got {preds.dtype} and {target.dtype}.") + + if not torch.any(target): + dis = torch.inf * torch.ones_like(target) + else: + if not torch.any(preds): + dis = torch.inf * torch.ones_like(preds) + return dis[target] + dis = distance_transform(~target, sampling=spacing, metric=distance_metric) + return dis[preds] + + +@functools.lru_cache +def get_neighbour_tables( + spacing: Union[Tuple[int, int], Tuple[int, int, int]], device: Optional[torch.device] = None +) -> Tuple[Tensor, Tensor]: + """Create a table that maps neighbour codes to the contour length or surface area of the corresponding contour. + + Args: + spacing: The spacing between pixels along each spatial dimension. + device: The device on which the table should be created. 
+ + Returns: + A tuple containing as its first element the table that maps neighbour codes to the contour length or surface + area of the corresponding contour and as its second element the kernel used to compute the neighbour codes. + + """ + if isinstance(spacing, tuple) and len(spacing) == 2: + return table_contour_length(spacing, device) + if isinstance(spacing, tuple) and len(spacing) == 3: + return table_surface_area(spacing, device) + raise ValueError("The spacing must be a tuple of length 2 or 3.") + + +def table_contour_length(spacing: Tuple[int, int], device: Optional[torch.device] = None) -> Tuple[Tensor, Tensor]: + """Create a table that maps neighbour codes to the contour length of the corresponding contour. + + Adopted from: + https://github.com/deepmind/surface-distance/blob/master/surface_distance/lookup_tables.py + + Args: + spacing: The spacing between pixels along each spatial dimension. Should be a tuple of length 2. + device: The device on which the table should be created. + + Returns: + A tuple containing as its first element the table that maps neighbour codes to the contour length of the + corresponding contour and as its second element the kernel used to compute the neighbour codes. 
+ + Example:: + >>> from torchmetrics.functional.segmentation.utils import table_contour_length + >>> table, kernel = table_contour_length((2,2)) + >>> table + tensor([0.0000, 1.4142, 1.4142, 2.0000, 1.4142, 2.0000, 2.8284, 1.4142, 1.4142, + 2.8284, 2.0000, 1.4142, 2.0000, 1.4142, 1.4142, 0.0000]) + >>> kernel + tensor([[[[8, 4], + [2, 1]]]]) + + """ + if not isinstance(spacing, tuple) and len(spacing) != 2: + raise ValueError("The spacing must be a tuple of length 2.") + + first, second = spacing # spacing along the first and second spatial dimension respectively + diag = 0.5 * math.sqrt(first**2 + second**2) + table = torch.zeros(16, dtype=torch.float32, device=device) + for i in [1, 2, 4, 7, 8, 11, 13, 14]: + table[i] = diag + for i in [3, 12]: + table[i] = second + for i in [5, 10]: + table[i] = first + for i in [6, 9]: + table[i] = 2 * diag + kernel = torch.as_tensor([[[[8, 4], [2, 1]]]], device=device) + return table, kernel + + +@functools.lru_cache +def table_surface_area(spacing: Tuple[int, int, int], device: Optional[torch.device] = None) -> Tuple[Tensor, Tensor]: + """Create a table that maps neighbour codes to the surface area of the corresponding surface. + + Adopted from: + https://github.com/deepmind/surface-distance/blob/master/surface_distance/lookup_tables.py + + Args: + spacing: The spacing between pixels along each spatial dimension. Should be a tuple of length 3. + device: The device on which the table should be created. + + Returns: + A tuple containing as its first element the table that maps neighbour codes to the surface area of the + corresponding surface and as its second element the kernel used to compute the neighbour codes. 
+ + Example:: + >>> from torchmetrics.functional.segmentation.utils import table_surface_area + >>> table, kernel = table_surface_area((2,2,2)) + >>> table + tensor([0.0000, 0.8660, 0.8660, 2.8284, 0.8660, 2.8284, 1.7321, 4.5981, 0.8660, + 1.7321, 2.8284, 4.5981, 2.8284, 4.5981, 4.5981, 4.0000, 0.8660, 2.8284, + 1.7321, 4.5981, 1.7321, 4.5981, 2.5981, 5.1962, 1.7321, 3.6945, 3.6945, + 6.2925, 3.6945, 6.2925, 5.4641, 4.5981, 0.8660, 1.7321, 2.8284, 4.5981, + 1.7321, 3.6945, 3.6945, 6.2925, 1.7321, 2.5981, 4.5981, 5.1962, 3.6945, + 5.4641, 6.2925, 4.5981, 2.8284, 4.5981, 4.5981, 4.0000, 3.6945, 6.2925, + 5.4641, 4.5981, 3.6945, 5.4641, 6.2925, 4.5981, 5.6569, 3.6945, 3.6945, + 2.8284, 0.8660, 1.7321, 1.7321, 3.6945, 2.8284, 4.5981, 3.6945, 6.2925, + 1.7321, 2.5981, 3.6945, 5.4641, 4.5981, 5.1962, 6.2925, 4.5981, 2.8284, + 4.5981, 3.6945, 6.2925, 4.5981, 4.0000, 5.4641, 4.5981, 3.6945, 5.4641, + 5.6569, 3.6945, 6.2925, 4.5981, 3.6945, 2.8284, 1.7321, 2.5981, 3.6945, + 5.4641, 3.6945, 5.4641, 5.6569, 3.6945, 2.5981, 3.4641, 5.4641, 2.5981, + 5.4641, 2.5981, 3.6945, 1.7321, 4.5981, 5.1962, 6.2925, 4.5981, 6.2925, + 4.5981, 3.6945, 2.8284, 5.4641, 2.5981, 3.6945, 1.7321, 3.6945, 1.7321, + 1.7321, 0.8660, 0.8660, 1.7321, 1.7321, 3.6945, 1.7321, 3.6945, 2.5981, + 5.4641, 2.8284, 3.6945, 4.5981, 6.2925, 4.5981, 6.2925, 5.1962, 4.5981, + 1.7321, 3.6945, 2.5981, 5.4641, 2.5981, 5.4641, 3.4641, 2.5981, 3.6945, + 5.6569, 5.4641, 3.6945, 5.4641, 3.6945, 2.5981, 1.7321, 2.8284, 3.6945, + 4.5981, 6.2925, 3.6945, 5.6569, 5.4641, 3.6945, 4.5981, 5.4641, 4.0000, + 4.5981, 6.2925, 3.6945, 4.5981, 2.8284, 4.5981, 6.2925, 5.1962, 4.5981, + 5.4641, 3.6945, 2.5981, 1.7321, 6.2925, 3.6945, 4.5981, 2.8284, 3.6945, + 1.7321, 1.7321, 0.8660, 2.8284, 3.6945, 3.6945, 5.6569, 4.5981, 6.2925, + 5.4641, 3.6945, 4.5981, 5.4641, 6.2925, 3.6945, 4.0000, 4.5981, 4.5981, + 2.8284, 4.5981, 6.2925, 5.4641, 3.6945, 5.1962, 4.5981, 2.5981, 1.7321, + 6.2925, 3.6945, 3.6945, 1.7321, 4.5981, 2.8284, 1.7321, 
0.8660, 4.5981, + 5.4641, 6.2925, 3.6945, 6.2925, 3.6945, 3.6945, 1.7321, 5.1962, 2.5981, + 4.5981, 1.7321, 4.5981, 1.7321, 2.8284, 0.8660, 4.0000, 4.5981, 4.5981, + 2.8284, 4.5981, 2.8284, 1.7321, 0.8660, 4.5981, 1.7321, 2.8284, 0.8660, + 2.8284, 0.8660, 0.8660, 0.0000]) + >>> kernel + tensor([[[[[128, 64], + [ 32, 16]], + [[ 8, 4], + [ 2, 1]]]]]) + + """ + if not isinstance(spacing, tuple) and len(spacing) != 3: + raise ValueError("The spacing must be a tuple of length 3.") + + zeros = [0.0, 0.0, 0.0] + table = torch.tensor( + [ + [zeros, zeros, zeros, zeros], + [[0.125, 0.125, 0.125], zeros, zeros, zeros], + [[-0.125, -0.125, 0.125], zeros, zeros, zeros], + [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], zeros, zeros], + [[0.125, -0.125, 0.125], zeros, zeros, zeros], + [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], zeros, zeros], + [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros], + [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], zeros], + [[-0.125, 0.125, 0.125], zeros, zeros, zeros], + [[0.125, 0.125, 0.125], [-0.125, 0.125, 0.125], zeros, zeros], + [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], zeros, zeros], + [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros], + [[0.25, -0.25, 0.0], [0.25, -0.25, 0.0], zeros, zeros], + [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125], zeros], + [[-0.5, 0.0, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros], + [[0.5, 0.0, 0.0], [0.5, 0.0, 0.0], zeros, zeros], + [[0.125, -0.125, -0.125], zeros, zeros, zeros], + [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], zeros, zeros], + [[-0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros], + [[0.0, -0.5, 0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], zeros], + [[0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros], + [[0.0, 0.0, -0.5], [0.25, 0.25, 0.25], [-0.125, -0.125, -0.125], zeros], + [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros], + [[-0.125, -0.125, -0.125], [-0.25, -0.25, 
-0.25], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]], + [[-0.125, 0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros], + [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros], + [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [0.125, -0.125, -0.125], zeros], + [[0.125, 0.125, 0.125], [0.375, 0.375, 0.375], [0.0, -0.25, 0.25], [-0.25, 0.0, 0.25]], + [[0.125, -0.125, -0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0], zeros], + [[0.375, 0.375, 0.375], [0.0, 0.25, -0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]], + [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.125, 0.125, 0.125]], + [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], zeros], + [[0.125, -0.125, 0.125], zeros, zeros, zeros], + [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros], + [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25], zeros, zeros], + [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25], zeros], + [[0.125, -0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros], + [[0.125, -0.125, 0.125], [-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], zeros], + [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25], [0.125, -0.125, 0.125], zeros], + [[-0.375, -0.375, 0.375], [-0.0, 0.25, 0.25], [0.125, 0.125, -0.125], [-0.25, -0.0, -0.25]], + [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros], + [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, 0.125, 0.125], zeros], + [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros], + [[0.25, 0.25, -0.25], [0.25, 0.25, -0.25], [0.125, 0.125, -0.125], [-0.125, -0.125, 0.125]], + [[0.125, -0.125, 0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0], zeros], + [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125], [0.125, -0.125, 0.125]], + [[0.0, 0.25, -0.25], [0.375, -0.375, -0.375], [-0.125, 0.125, 0.125], [0.25, 0.25, 0.0]], + [[-0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros], + [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], zeros, zeros], + [[0.0, 0.5, 0.0], [-0.25, 
0.25, 0.25], [0.125, -0.125, -0.125], zeros], + [[0.0, 0.5, 0.0], [0.125, -0.125, 0.125], [-0.25, 0.25, -0.25], zeros], + [[0.0, 0.5, 0.0], [0.0, -0.5, 0.0], zeros, zeros], + [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], [0.125, -0.125, 0.125], zeros], + [[-0.375, -0.375, -0.375], [-0.25, 0.0, 0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]], + [[0.125, 0.125, 0.125], [0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125]], + [[0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125], zeros], + [[-0.125, 0.125, 0.125], [0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], zeros], + [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]], + [[-0.375, 0.375, -0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]], + [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], zeros], + [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0]], + [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [-0.125, -0.125, 0.125], zeros], + [[0.125, 0.125, 0.125], [-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], zeros], + [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], zeros, zeros], + [[-0.125, -0.125, 0.125], zeros, zeros, zeros], + [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros], + [[-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros], + [[-0.125, -0.125, 0.125], [-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], zeros], + [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], zeros, zeros], + [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], zeros], + [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros], + [[0.375, -0.375, 0.375], [0.0, -0.25, -0.25], [-0.125, 0.125, -0.125], [0.25, 0.25, 0.0]], + [[-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125], zeros, zeros], + [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125], zeros], + [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], zeros], + [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 
0.125], [-0.125, -0.125, 0.125]], + [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], zeros], + [[-0.25, 0.25, -0.25], [-0.25, 0.25, -0.25], [-0.125, 0.125, -0.125], [-0.125, 0.125, -0.125]], + [[-0.25, 0.0, -0.25], [0.375, -0.375, -0.375], [0.0, 0.25, -0.25], [-0.125, 0.125, 0.125]], + [[0.5, 0.0, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], zeros], + [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], zeros, zeros], + [[-0.0, 0.0, 0.5], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros], + [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], zeros], + [[-0.25, -0.0, -0.25], [-0.375, 0.375, 0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, 0.125]], + [[0.0, 0.0, -0.5], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], zeros], + [[-0.0, 0.0, 0.5], [0.0, 0.0, 0.5], zeros, zeros], + [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5]], + [[0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5], zeros], + [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], [-0.125, 0.125, 0.125], zeros], + [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]], + [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [0.25, 0.0, -0.25]], + [[0.125, -0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25], zeros], + [[0.25, 0.0, 0.25], [-0.375, -0.375, 0.375], [-0.25, 0.25, 0.0], [-0.125, -0.125, 0.125]], + [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], zeros], + [[0.125, 0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25], zeros], + [[0.25, 0.0, 0.25], [0.25, 0.0, 0.25], zeros, zeros], + [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros], + [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], zeros], + [[-0.125, -0.125, 0.125], [0.0, -0.25, 0.25], [0.0, 0.25, -0.25], zeros], + [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]], + [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.125, -0.125, 0.125], zeros], + [[0.0, 0.0, 0.5], 
[0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]], + [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.0, 0.25, -0.25]], + [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25], [0.125, -0.125, -0.125], zeros], + [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], zeros], + [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125]], + [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]], + [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros], + [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]], + [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros], + [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [0.125, 0.125, 0.125], zeros], + [[0.125, 0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros], + [[0.5, 0.0, -0.0], [0.25, -0.25, -0.25], [0.125, -0.125, -0.125], zeros], + [[-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125]], + [[0.375, -0.375, 0.375], [0.0, 0.25, 0.25], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]], + [[0.0, -0.5, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros], + [[-0.375, -0.375, 0.375], [0.25, -0.25, 0.0], [0.0, 0.25, 0.25], [-0.125, -0.125, 0.125]], + [[-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.0, 0.0, 0.5], zeros], + [[0.125, 0.125, 0.125], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25], zeros], + [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25], zeros, zeros], + [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], [0.125, 0.125, 0.125]], + [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125], zeros], + [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], [0.125, 0.125, 0.125], zeros], + [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros], + [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], [0.125, 0.125, 0.125], zeros], + [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros], + [[0.125, 
0.125, 0.125], [0.125, 0.125, 0.125], zeros, zeros], + [[0.125, 0.125, 0.125], zeros, zeros, zeros], + [[0.125, 0.125, 0.125], zeros, zeros, zeros], + [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125], zeros, zeros], + [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros], + [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], [0.125, 0.125, 0.125], zeros], + [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros], + [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], [0.125, 0.125, 0.125], zeros], + [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125], zeros], + [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], [0.125, 0.125, 0.125]], + [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25], zeros, zeros], + [[0.125, 0.125, 0.125], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25], zeros], + [[-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.0, 0.0, 0.5], zeros], + [[-0.375, -0.375, 0.375], [0.25, -0.25, 0.0], [0.0, 0.25, 0.25], [-0.125, -0.125, 0.125]], + [[0.0, -0.5, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros], + [[0.375, -0.375, 0.375], [0.0, 0.25, 0.25], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]], + [[-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125]], + [[0.5, 0.0, -0.0], [0.25, -0.25, -0.25], [0.125, -0.125, -0.125], zeros], + [[0.125, 0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros], + [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [0.125, 0.125, 0.125], zeros], + [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros], + [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]], + [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros], + [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]], + [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125]], + [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], zeros], + [[0.0, 0.25, 0.25], 
[0.0, 0.25, 0.25], [0.125, -0.125, -0.125], zeros], + [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25]], + [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]], + [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.125, -0.125, 0.125], zeros], + [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]], + [[-0.125, -0.125, 0.125], [0.0, -0.25, 0.25], [0.0, 0.25, -0.25], zeros], + [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], zeros], + [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros], + [[0.25, 0.0, 0.25], [0.25, 0.0, 0.25], zeros, zeros], + [[0.125, 0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25], zeros], + [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], zeros], + [[0.25, 0.0, 0.25], [-0.375, -0.375, 0.375], [-0.25, 0.25, 0.0], [-0.125, -0.125, 0.125]], + [[0.125, -0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25], zeros], + [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25]], + [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]], + [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], [-0.125, 0.125, 0.125], zeros], + [[0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5], zeros], + [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5]], + [[-0.0, 0.0, 0.5], [0.0, 0.0, 0.5], zeros, zeros], + [[0.0, 0.0, -0.5], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], zeros], + [[-0.25, -0.0, -0.25], [-0.375, 0.375, 0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, 0.125]], + [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], zeros], + [[-0.0, 0.0, 0.5], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros], + [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], zeros, zeros], + [[0.5, 0.0, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], zeros], + [[-0.25, 0.0, -0.25], [0.375, -0.375, -0.375], [0.0, 0.25, -0.25], [-0.125, 0.125, 0.125]], + 
[[-0.25, 0.25, -0.25], [-0.25, 0.25, -0.25], [-0.125, 0.125, -0.125], [-0.125, 0.125, -0.125]], + [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], zeros], + [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]], + [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], zeros], + [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125], zeros], + [[-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125], zeros, zeros], + [[0.375, -0.375, 0.375], [0.0, -0.25, -0.25], [-0.125, 0.125, -0.125], [0.25, 0.25, 0.0]], + [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros], + [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], zeros], + [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], zeros, zeros], + [[-0.125, -0.125, 0.125], [-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], zeros], + [[-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros], + [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros], + [[-0.125, -0.125, 0.125], zeros, zeros, zeros], + [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], zeros, zeros], + [[0.125, 0.125, 0.125], [-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], zeros], + [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [-0.125, -0.125, 0.125], zeros], + [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [0.25, 0.25, -0.0]], + [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], zeros], + [[-0.375, 0.375, -0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]], + [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]], + [[-0.125, 0.125, 0.125], [0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], zeros], + [[0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125], zeros], + [[0.125, 0.125, 0.125], [0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125]], + [[-0.375, -0.375, -0.375], [-0.25, 0.0, 0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]], + [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], 
[0.125, -0.125, 0.125], zeros], + [[0.0, 0.5, 0.0], [0.0, -0.5, 0.0], zeros, zeros], + [[0.0, 0.5, 0.0], [0.125, -0.125, 0.125], [-0.25, 0.25, -0.25], zeros], + [[0.0, 0.5, 0.0], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125], zeros], + [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], zeros, zeros], + [[-0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros], + [[0.0, 0.25, -0.25], [0.375, -0.375, -0.375], [-0.125, 0.125, 0.125], [0.25, 0.25, 0.0]], + [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125], [0.125, -0.125, 0.125]], + [[0.125, -0.125, 0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0], zeros], + [[0.25, 0.25, -0.25], [0.25, 0.25, -0.25], [0.125, 0.125, -0.125], [-0.125, -0.125, 0.125]], + [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros], + [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, 0.125, 0.125], zeros], + [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros], + [[-0.375, -0.375, 0.375], [-0.0, 0.25, 0.25], [0.125, 0.125, -0.125], [-0.25, -0.0, -0.25]], + [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25], [0.125, -0.125, 0.125], zeros], + [[0.125, -0.125, 0.125], [-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], zeros], + [[0.125, -0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros], + [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25], zeros], + [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25], zeros, zeros], + [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros], + [[0.125, -0.125, 0.125], zeros, zeros, zeros], + [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], zeros], + [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.125, 0.125, 0.125]], + [[0.375, 0.375, 0.375], [0.0, 0.25, -0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]], + [[0.125, -0.125, -0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0], zeros], + [[0.125, 0.125, 0.125], [0.375, 0.375, 0.375], [0.0, -0.25, 0.25], [-0.25, 0.0, 0.25]], + [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [0.125, -0.125, 
-0.125], zeros], + [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros], + [[-0.125, 0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros], + [[-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]], + [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros], + [[0.0, 0.0, -0.5], [0.25, 0.25, 0.25], [-0.125, -0.125, -0.125], zeros], + [[0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros], + [[0.0, -0.5, 0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], zeros], + [[-0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros], + [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], zeros, zeros], + [[0.125, -0.125, -0.125], zeros, zeros, zeros], + [[0.5, 0.0, 0.0], [0.5, 0.0, 0.0], zeros, zeros], + [[-0.5, 0.0, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros], + [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125], zeros], + [[0.25, -0.25, 0.0], [0.25, -0.25, 0.0], zeros, zeros], + [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros], + [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], zeros, zeros], + [[0.125, 0.125, 0.125], [-0.125, 0.125, 0.125], zeros, zeros], + [[-0.125, 0.125, 0.125], zeros, zeros, zeros], + [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], zeros], + [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros], + [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], zeros, zeros], + [[0.125, 0.125, 0.125], zeros, zeros, zeros], + [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], zeros, zeros], + [[0.125, 0.125, 0.125], zeros, zeros, zeros], + [[0.125, 0.125, 0.125], zeros, zeros, zeros], + [zeros, zeros, zeros, zeros], + ], + dtype=torch.float32, + device=device, + ) + + space = torch.as_tensor( + [[[spacing[1] * spacing[2], spacing[0] * spacing[2], spacing[0] * spacing[1]]]], + device=device, + dtype=table.dtype, + ) + norm = torch.linalg.norm(table * space, dim=-1) + table = norm.sum(-1) + kernel = torch.as_tensor([[[[[128, 64], [32, 16]], 
[[8, 4], [2, 1]]]]], device=device) + return table, kernel diff --git a/tests/unittests/segmentation/__init__.py b/tests/unittests/segmentation/__init__.py new file mode 100644 index 00000000000..94f1dec4a9f --- /dev/null +++ b/tests/unittests/segmentation/__init__.py @@ -0,0 +1,13 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/segmentation/test_utils.py b/tests/unittests/segmentation/test_utils.py new file mode 100644 index 00000000000..66a66f20a25 --- /dev/null +++ b/tests/unittests/segmentation/test_utils.py @@ -0,0 +1,245 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import pytest
import torch
from monai.metrics.utils import get_code_to_measure_table
from monai.metrics.utils import get_mask_edges as monai_get_mask_edges
from monai.metrics.utils import get_surface_distance as monai_get_surface_distance
from scipy.ndimage import binary_erosion as scibinary_erosion
from scipy.ndimage import distance_transform_cdt as scidistance_transform_cdt
from scipy.ndimage import distance_transform_edt as scidistance_transform_edt
from scipy.ndimage import generate_binary_structure as scigenerate_binary_structure
from torchmetrics.functional.segmentation.utils import (
    binary_erosion,
    distance_transform,
    generate_binary_structure,
    get_neighbour_tables,
    mask_edges,
    surface_distance,
)


def _skip_if_no_cuda(device):
    """Skip the calling test when it requests a CUDA device that is not available."""
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA device not available.")


@pytest.mark.parametrize("rank", [2, 3, 4])
@pytest.mark.parametrize("connectivity", [1, 2, 3])
def test_generate_binary_structure(rank, connectivity):
    """Compare our binary structuring element against the scipy reference."""
    ours = generate_binary_structure(rank, connectivity)
    reference = scigenerate_binary_structure(rank, connectivity)
    assert torch.allclose(ours, torch.from_numpy(reference))


@pytest.mark.parametrize(
    "case",
    [
        torch.ones(3, 1),
        torch.ones(5, 1),
        torch.ones(3, 3),
        torch.tensor(
            [
                [0, 0, 0, 0, 0, 0, 0, 0],
                [0, 1, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 1, 1, 1],
                [0, 0, 1, 1, 1, 1, 1, 1],
                [0, 0, 1, 1, 1, 1, 0, 0],
                [0, 1, 1, 1, 1, 1, 1, 0],
                [0, 1, 1, 0, 0, 1, 1, 0],
                [0, 0, 0, 0, 0, 0, 0, 0],
            ]
        ),
        torch.tensor(
            [
                [0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 1, 0, 0, 0],
                [0, 0, 1, 1, 1, 0, 0],
                [0, 1, 1, 1, 1, 1, 0],
                [0, 0, 1, 1, 1, 0, 0],
                [0, 0, 0, 1, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0],
            ]
        ),
        torch.tensor(
            [
                [0, 0, 0, 1, 0, 0, 0],
                [0, 0, 1, 1, 1, 0, 0],
                [0, 1, 1, 1, 1, 1, 0],
                [1, 1, 1, 1, 1, 1, 1],
                [0, 1, 1, 1, 1, 1, 0],
                [0, 0, 1, 1, 1, 0, 0],
                [0, 0, 0, 1, 0, 0, 0],
            ]
        ),
        torch.tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 0], [0, 0, 1, 0, 1], [0, 1, 1, 1, 0], [0, 1, 1, 1, 1]]),
        torch.randint(2, (5, 5)),
        torch.randint(2, (20, 20)),
        torch.ones(5, 5, 5),
        torch.randint(2, (5, 5, 5)),
        torch.randint(2, (20, 20, 20)),
    ],
)
@pytest.mark.parametrize("border_value", [0, 1])
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_binary_erosion(case, border_value, device):
    """Compare our binary erosion against the scipy reference.

    Cases taken from:
    https://github.com/scipy/scipy/blob/v1.11.1/scipy/ndimage/tests/test_morphology.py

    """
    _skip_if_no_cuda(device)
    reference = scibinary_erosion(case, border_value=border_value)
    # our implementation expects a (N, C, ...) batched layout
    batched = case.unsqueeze(0).unsqueeze(0).to(device)
    ours = binary_erosion(batched, border_value=border_value)
    assert torch.allclose(ours.cpu(), torch.from_numpy(reference).byte())


@pytest.mark.parametrize(
    ("arguments", "error", "match"),
    [
        (([0, 1, 2, 3],), TypeError, "Expected argument `image` to be of type Tensor.*"),
        ((torch.ones(3, 3),), ValueError, "Expected argument `image` to be of rank 4 or 5 but.*"),
        ((torch.randint(3, (1, 1, 5, 5)),), ValueError, "Input x should be binarized"),
    ],
)
def test_binary_erosion_error(arguments, error, match):
    """Test that binary erosion raises an error when the input is not binary."""
    with pytest.raises(error, match=match):
        binary_erosion(*arguments)


@pytest.mark.parametrize(
    "case",
    [
        torch.tensor(
            [
                [0, 0, 0, 0, 0, 0, 0, 0],
                [0, 1, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 1, 1, 1],
                [0, 0, 1, 1, 1, 1, 1, 1],
                [0, 0, 1, 1, 1, 1, 0, 0],
                [0, 1, 1, 1, 1, 1, 1, 0],
                [0, 1, 1, 0, 0, 1, 1, 0],
                [0, 0, 0, 0, 0, 0, 0, 0],
            ]
        ),
        torch.tensor(
            [
                [0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 1, 0, 0, 0],
                [0, 0, 1, 1, 1, 0, 0],
                [0, 1, 1, 1, 1, 1, 0],
                [0, 0, 1, 1, 1, 0, 0],
                [0, 0, 0, 1, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0],
            ]
        ),
        torch.tensor(
            [
                [0, 0, 0, 1, 0, 0, 0],
                [0, 0, 1, 1, 1, 0, 0],
                [0, 1, 1, 1, 1, 1, 0],
                [1, 1, 1, 1, 1, 1, 1],
                [0, 1, 1, 1, 1, 1, 0],
                [0, 0, 1, 1, 1, 0, 0],
                [0, 0, 0, 1, 0, 0, 0],
            ]
        ),
    ],
)
@pytest.mark.parametrize("metric", ["euclidean", "chessboard", "taxicab"])
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_distance_transform(case, metric, device):
    """Compare our distance transform against the scipy reference.

    Cases taken from:
    https://github.com/scipy/scipy/blob/v1.11.1/scipy/ndimage/tests/test_morphology.py

    """
    _skip_if_no_cuda(device)
    ours = distance_transform(case.to(device), metric=metric)
    # scipy splits euclidean and chamfer transforms over two functions
    if metric == "euclidean":
        reference = scidistance_transform_edt(case)
    else:
        reference = scidistance_transform_cdt(case, metric=metric)
    assert torch.allclose(ours.cpu(), torch.from_numpy(reference).to(ours.dtype))


@pytest.mark.parametrize("dim", [2, 3])
@pytest.mark.parametrize("spacing", [1, 2])
def test_neighbour_table(dim, spacing):
    """Compare our neighbour code table and kernel against the monai reference."""
    full_spacing = dim * (spacing,)
    ref_table, ref_kernel = get_code_to_measure_table(full_spacing)
    table, kernel = get_neighbour_tables(full_spacing)
    assert torch.allclose(ref_table.float(), table)
    assert torch.allclose(ref_kernel, kernel)


@pytest.mark.parametrize(
    "cases",
    [
        (
            torch.tensor(
                [[1, 1, 1, 1, 1], [1, 0, 0, 0, 1], [1, 0, 0, 0, 1], [1, 0, 0, 0, 1], [1, 1, 1, 1, 1]], dtype=torch.bool
            ),
            torch.tensor(
                [[1, 1, 1, 1, 0], [1, 0, 0, 1, 0], [1, 0, 0, 1, 0], [1, 0, 0, 1, 0], [1, 1, 1, 1, 0]], dtype=torch.bool
            ),
        ),
        (torch.randint(0, 2, (5, 5), dtype=torch.bool), torch.randint(0, 2, (5, 5), dtype=torch.bool)),
        (torch.randint(0, 2, (50, 50), dtype=torch.bool), torch.randint(0, 2, (50, 50), dtype=torch.bool)),
    ],
)
@pytest.mark.parametrize("distance_metric", ["euclidean", "chessboard", "taxicab"])
@pytest.mark.parametrize("spacing", [1, 2])
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_surface_distance(cases, distance_metric, spacing, device):
    """Compare our surface distance against the monai reference."""
    _skip_if_no_cuda(device)
    if spacing != 1 and distance_metric != "euclidean":
        pytest.skip("Only euclidean distance is supported for spacing != 1 in reference")
    preds, target = cases
    full_spacing = [spacing, spacing]
    ours = surface_distance(
        preds.to(device), target.to(device), distance_metric=distance_metric, spacing=full_spacing
    )
    reference = monai_get_surface_distance(
        preds.numpy(), target.numpy(), distance_metric=distance_metric, spacing=full_spacing
    )
    assert torch.allclose(ours.cpu(), torch.from_numpy(reference).to(ours.dtype))


@pytest.mark.parametrize(
    "cases",
    [
        (torch.randint(0, 2, (5, 5), dtype=torch.bool), torch.randint(0, 2, (5, 5), dtype=torch.bool)),
        (torch.randint(0, 2, (50, 50), dtype=torch.bool), torch.randint(0, 2, (50, 50), dtype=torch.bool)),
        (torch.randint(0, 2, (50, 50, 50), dtype=torch.bool), torch.randint(0, 2, (50, 50, 50), dtype=torch.bool)),
    ],
)
@pytest.mark.parametrize("spacing", [None, 1, 2])
@pytest.mark.parametrize("crop", [False, True])
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_mask_edges(cases, spacing, crop, device):
    """Compare our mask edges against the monai reference."""
    _skip_if_no_cuda(device)
    preds, target = cases
    # isotropic spacing tuple matching the input rank; None disables spacing
    full_spacing = None if spacing is None else preds.ndim * (spacing,)
    ours = mask_edges(preds.to(device), target.to(device), spacing=full_spacing, crop=crop)
    reference = monai_get_mask_edges(preds, target, spacing=full_spacing, crop=crop)
    for ours_part, ref_part in zip(ours, reference):
        assert torch.allclose(ours_part.cpu().float(), torch.from_numpy(ref_part).float())