Skip to content

Commit

Permalink
Added shortest_path parameter to `get_transformation_between_coordi…
Browse files Browse the repository at this point in the history
…nate_systems` (#714)

* Added `shortest_path` parameter to `get_transformation_between_coordinate_systems`

* docstring minor edit

---------

Co-authored-by: LucaMarconato <[email protected]>
  • Loading branch information
quentinblampey and LucaMarconato authored Sep 26, 2024
1 parent 8239455 commit b9eb240
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 79 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ and this project adheres to [Semantic Versioning][].
[keep a changelog]: https://keepachangelog.com/en/1.0.0/
[semantic versioning]: https://semver.org/spec/v2.0.0.html

## [0.2.4] - xxxx-xx-xx

### Minor

- Added `shortest_path` parameter to `get_transformation_between_coordinate_systems`

## [0.2.3] - 2024-09-25

### Minor
Expand Down
168 changes: 92 additions & 76 deletions src/spatialdata/transformations/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@
from skimage.transform import estimate_transform

from spatialdata._logging import logger
from spatialdata.transformations._utils import (
_get_transformations,
_set_transformations,
)
from spatialdata.transformations._utils import _get_transformations, _set_transformations

if TYPE_CHECKING:
from spatialdata._core.spatialdata import SpatialData
Expand Down Expand Up @@ -216,6 +213,7 @@ def get_transformation_between_coordinate_systems(
source_coordinate_system: Union[SpatialElement, str],
target_coordinate_system: Union[SpatialElement, str],
intermediate_coordinate_systems: Optional[Union[SpatialElement, str]] = None,
shortest_path: bool = True,
) -> BaseTransformation:
"""
Get the transformation to map a coordinate system (intrinsic or extrinsic) to another one.
Expand All @@ -228,6 +226,10 @@ def get_transformation_between_coordinate_systems(
target_coordinate_system
The target coordinate system. Can be a SpatialElement (intrinsic coordinate system) or a string (extrinsic
coordinate system).
shortest_path
Whether to return the shortest paths when multiple paths are found between the coordinate systems
and a single shortest path is found. If `False`, an error is raised when multiple paths exist.
The same error is raised if `True`, but multiple paths of the same shortest lenghts are found.
Returns
-------
Expand All @@ -236,88 +238,105 @@ def get_transformation_between_coordinate_systems(
from spatialdata.models._utils import has_type_spatial_element
from spatialdata.transformations import Identity, Sequence

if (
isinstance(source_coordinate_system, str)
and isinstance(target_coordinate_system, str)
and source_coordinate_system == target_coordinate_system
or id(source_coordinate_system) == id(target_coordinate_system)
):
return Identity()

def _describe_paths(paths: list[list[Union[int, str]]]) -> str:
paths_str = ""
for p in paths:
components = []
for c in p:
if isinstance(c, str):
components.append(f"{c!r}")
else:
ss = [
f"<sdata>.{element_type}[{element_name!r}]"
for element_type, element_name, e in sdata._gen_elements()
if id(e) == c
]
assert len(ss) == 1
components.append(ss[0])
continue
ss = [
f"<sdata>.{element_type}[{element_name!r}]"
for element_type, element_name, e in sdata._gen_elements()
if id(e) == c
]
assert len(ss) == 1
components.append(ss[0])
paths_str += "\n " + " -> ".join(components)
return paths_str

if (
isinstance(source_coordinate_system, str)
and isinstance(target_coordinate_system, str)
and source_coordinate_system == target_coordinate_system
or id(source_coordinate_system) == id(target_coordinate_system)
):
return Identity()
g = _build_transformations_graph(sdata)
src_node: Union[int, str]
if has_type_spatial_element(source_coordinate_system):
src_node = id(source_coordinate_system)
else:
g = _build_transformations_graph(sdata)
src_node: Union[int, str]
if has_type_spatial_element(source_coordinate_system):
src_node = id(source_coordinate_system)
else:
assert isinstance(source_coordinate_system, str)
src_node = source_coordinate_system
tgt_node: Union[int, str]
if has_type_spatial_element(target_coordinate_system):
tgt_node = id(target_coordinate_system)
assert isinstance(source_coordinate_system, str)
src_node = source_coordinate_system
tgt_node: Union[int, str]
if has_type_spatial_element(target_coordinate_system):
tgt_node = id(target_coordinate_system)
else:
assert isinstance(target_coordinate_system, str)
tgt_node = target_coordinate_system
paths = list(nx.all_simple_paths(g, source=src_node, target=tgt_node))
if len(paths) == 0:
# error 0 (we refer to this in the tests)
raise RuntimeError("No path found between the two coordinate systems")
if len(paths) == 1:
path = paths[0]
elif intermediate_coordinate_systems is None:
# if one and only one of the paths has lenght 1, we choose it straight away, otherwise we raise
# an expection and ask the user to be more specific
paths_with_length_1 = [p for p in paths if len(p) == 2]
if len(paths_with_length_1) == 1:
path = paths_with_length_1[0]
elif shortest_path:
shortest_paths = [p for p in paths if len(p) == min(map(len, paths))]

if len(shortest_paths) > 1:
# error 1
s = _describe_paths(shortest_paths)
raise RuntimeError(
"Multiple equal paths found between the two coordinate systems passing through the intermediate. "
f"Available shortest paths are:{s}"
)
path = shortest_paths[0]
else:
assert isinstance(target_coordinate_system, str)
tgt_node = target_coordinate_system
paths = list(nx.all_simple_paths(g, source=src_node, target=tgt_node))
# error 2
s = _describe_paths(paths)
raise RuntimeError(
"Multiple paths found between the two coordinate systems. Please specify an intermediate "
f"coordinate system. Available paths are:{s}"
)
else:
if has_type_spatial_element(intermediate_coordinate_systems):
intermediate_coordinate_systems = id(intermediate_coordinate_systems)
paths = [p for p in paths if intermediate_coordinate_systems in p]
if len(paths) == 0:
# error 0 (we refer to this in the tests)
raise RuntimeError("No path found between the two coordinate systems")
elif len(paths) > 1:
if intermediate_coordinate_systems is None:
# if one and only one of the paths has lenght 1, we choose it straight away, otherwise we raise
# an expection and ask the user to be more specific
paths_with_length_1 = [p for p in paths if len(p) == 2]
if len(paths_with_length_1) == 1:
path = paths_with_length_1[0]
else:
# error 1
s = _describe_paths(paths)
raise RuntimeError(
"Multiple paths found between the two coordinate systems. Please specify an intermediate "
f"coordinate system. Available paths are:{s}"
)
else:
if has_type_spatial_element(intermediate_coordinate_systems):
intermediate_coordinate_systems = id(intermediate_coordinate_systems)
paths = [p for p in paths if intermediate_coordinate_systems in p]
if len(paths) == 0:
# error 2
raise RuntimeError(
"No path found between the two coordinate systems passing through the intermediate"
)
elif len(paths) > 1:
# error 3
s = _describe_paths(paths)
raise RuntimeError(
"Multiple paths found between the two coordinate systems passing through the intermediate. "
f"Avaliable paths are:{s}"
)
else:
path = paths[0]
else:
# error 3
raise RuntimeError("No path found between the two coordinate systems passing through the intermediate")
if len(paths) == 1:
path = paths[0]
transformations = []
for i in range(len(path) - 1):
transformations.append(g[path[i]][path[i + 1]]["transformation"])
sequence = Sequence(transformations)
return sequence
elif shortest_path:
shortest_paths = [p for p in paths if len(p) == min(map(len, paths))]
if len(shortest_paths) > 1:
# error 4
s = _describe_paths(shortest_paths)
raise RuntimeError(
"Multiple equal paths found between the two coordinate systems passing through the intermediate. "
f"Available paths are:{s}"
)
path = shortest_paths[0]
else:
# error 5
s = _describe_paths(paths)
raise RuntimeError(
"Multiple paths found between the two coordinate systems passing through the intermediate. "
f"Available paths are:{s}"
)

transformations = [g[path[i]][path[i + 1]]["transformation"] for i in range(len(path) - 1)]
sequence = Sequence(transformations)
return sequence


def get_transformation_between_landmarks(
Expand Down Expand Up @@ -355,10 +374,7 @@ def get_transformation_between_landmarks(
"""
from spatialdata import transform
from spatialdata.models import get_axes_names
from spatialdata.transformations.transformations import (
Affine,
Sequence,
)
from spatialdata.transformations.transformations import Affine, Sequence

assert get_axes_names(references_coords) == ("x", "y")
assert get_axes_names(moving_coords) == ("x", "y")
Expand Down
43 changes: 40 additions & 3 deletions tests/core/operations/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,43 @@ def test_map_coordinate_systems_single_path(full_sdata: SpatialData):
)


def test_coordinate_systems_with_shortest_paths(full_sdata: SpatialData):
scale = Scale([2], axes=("x",))
translation = Translation([100], axes=("x",))
cs1_to_cs2 = Sequence([scale.inverse(), translation])

im = full_sdata.images["image2d_multiscale"]
la = full_sdata.labels["labels2d"]
po = full_sdata.shapes["multipoly"]
po2 = full_sdata.shapes["circles"]

set_transformation(im, {"cs1": Identity()}, set_all=True)
set_transformation(la, {"cs2": Identity()}, set_all=True)

with pytest.raises(RuntimeError): # error 0
get_transformation_between_coordinate_systems(full_sdata, im, la)

set_transformation(po, {"cs1": scale, "cs2": translation}, set_all=True)

t = get_transformation_between_coordinate_systems(full_sdata, im, la, shortest_path=True)
assert len(t.transformations) == 4
t = get_transformation_between_coordinate_systems(full_sdata, im, la, shortest_path=False)
assert len(t.transformations) == 4

set_transformation(im, cs1_to_cs2, "cs2")

with pytest.raises(RuntimeError): # error 4
get_transformation_between_coordinate_systems(full_sdata, im, la, shortest_path=False)

t = get_transformation_between_coordinate_systems(full_sdata, im, la, shortest_path=True)

assert len(t.transformations) == 2

set_transformation(po2, {"cs1": scale, "cs2": translation}, set_all=True)

get_transformation_between_coordinate_systems(full_sdata, im, la, shortest_path=True)


def test_map_coordinate_systems_zero_or_multiple_paths(full_sdata):
scale = Scale([2], axes=("x",))

Expand All @@ -356,7 +393,7 @@ def test_map_coordinate_systems_zero_or_multiple_paths(full_sdata):
full_sdata, source_coordinate_system="my_space0", target_coordinate_system="globalE"
)

# error 1
# error 2
with pytest.raises(RuntimeError):
t = get_transformation_between_coordinate_systems(
full_sdata, source_coordinate_system="my_space0", target_coordinate_system="global"
Expand All @@ -378,15 +415,15 @@ def test_map_coordinate_systems_zero_or_multiple_paths(full_sdata):
]
),
)
# error 2
# error 3
with pytest.raises(RuntimeError):
get_transformation_between_coordinate_systems(
full_sdata,
source_coordinate_system="my_space0",
target_coordinate_system="global",
intermediate_coordinate_systems="globalE",
)
# error 3
# error 5
with pytest.raises(RuntimeError):
get_transformation_between_coordinate_systems(
full_sdata,
Expand Down

0 comments on commit b9eb240

Please sign in to comment.