Skip to content

Commit

Permalink
refactoring join (#519)
Browse files Browse the repository at this point in the history
refactoring join after #512
  • Loading branch information
LucaMarconato authored Mar 27, 2024
1 parent 460024f commit 1f74e77
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 106 deletions.
8 changes: 5 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ and this project adheres to [Semantic Versioning][].

### Added

####

- Added method `update_annotated_regions_metadata()` which updates the `region` value automatically from the `region_key` columns.

### Changed

- Renamed `join_sdata_spatialelement_table` to `join_spatialelement_table`, and made it work also without `SpatialData` objects.

## [0.1.0] - 2024-03-24

### Added
Expand All @@ -24,7 +26,7 @@ and this project adheres to [Semantic Versioning][].

- Implemented support in SpatialData for storing multiple tables. These tables can annotate a SpatialElement but not
necessarily so.
- Added SQL like joins that can be executed by calling one public function `join_spatialelement_table`. The
- Added SQL like joins that can be executed by calling one public function `join_sdata_spatialelement_table`. The
following joins are supported: `left`, `left_exclusive`, `right`, `right_exclusive` and `inner`. The function has
an option to match rows. For `left` only matching `left` is supported and for `right` join only `right` matching of
rows is supported. Not all joins are supported for `Labels` elements. The elements and table can either exist within
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ Note: if you are using a Mac with an M1/M2 chip, please follow the installation

## Limitations

- Windows support. Currently the framework is tested on Linux and macOS machines, not Windows machines. Users have reported bugs in read/write operations (27 March 2024).
- Windows support. Currently the framework is tested on Linux and macOS machines, not Windows machines. Users have reported bugs in read/write operations (27 March 2024).

## Contact

Expand Down
161 changes: 78 additions & 83 deletions src/spatialdata/_core/query/relational_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,80 +465,49 @@ class MatchTypes(Enum):


def _create_sdata_elements_dict_for_join(
sdata: SpatialData, spatial_element_name: str | list[str], table_name: str
) -> tuple[dict[str, dict[str, Any]], AnnData]:
assert sdata.tables.get(table_name), f"No table with `{table_name}` exists in the SpatialData object."
table = sdata.tables[table_name]

sdata: SpatialData, spatial_element_name: str | list[str]
) -> dict[str, dict[str, Any]]:
elements_dict: dict[str, dict[str, Any]] = defaultdict(lambda: defaultdict(dict))
for name in spatial_element_name:
if name in sdata.tables:
warnings.warn(
f"Table: `{name}` given in spatial_element_names cannot be "
f"joined with a table using this function.",
UserWarning,
stacklevel=2,
)
elif name in sdata.images:
warnings.warn(
f"Image: `{name}` cannot be joined with a table",
UserWarning,
stacklevel=2,
)
else:
element_type, _, element = sdata._find_element(name)
elements_dict[element_type][name] = element
return elements_dict, table
element_type, _, element = sdata._find_element(name)
elements_dict[element_type][name] = element
return elements_dict


def _create_elements_dict_for_join(
spatial_element_name: str | list[str], elements: SpatialElement | list[SpatialElement], table: AnnData
) -> dict[str, dict[str, Any]]:
elements = elements if isinstance(elements, list) else [elements]
def _validate_element_types_for_join(
sdata: SpatialData | None,
spatial_element_names: list[str],
spatial_elements: list[SpatialElement] | None,
table: AnnData | None,
) -> None:
if sdata is not None:
elements_to_check = []
for name in spatial_element_names:
elements_to_check.append(sdata[name])
else:
assert spatial_elements is not None
elements_to_check = spatial_elements

elements_dict: dict[str, dict[str, Any]] = defaultdict(lambda: defaultdict(dict))
for name, element in zip(spatial_element_name, elements):
for element in elements_to_check:
model = get_model(element)

if model == TableModel:
warnings.warn(
f"Table: `{name}` given in spatial_element_name cannot be " f"joined with a table using this function.",
UserWarning,
stacklevel=2,
)
continue
if model in [Image2DModel, Image3DModel]:
warnings.warn(
f"Image: `{name}` cannot be joined with a table",
UserWarning,
stacklevel=2,
)
continue

if model in [Labels2DModel, Labels3DModel]:
element_type = "labels"
elif model == PointsModel:
element_type = "points"
elif model == ShapesModel:
element_type = "shapes"
elements_dict[element_type][name] = element
return elements_dict
if model in [Image2DModel, Image3DModel, TableModel]:
raise ValueError(f"Element type `{model}` not supported for join operation.")


def join_spatialelement_table(
spatial_element_names: str | list[str],
elements: SpatialElement | list[SpatialElement] | None = None,
table: AnnData | None = None,
table_name: str | None = None,
sdata: SpatialData | None = None,
spatial_element_names: str | list[str] | None = None,
spatial_elements: SpatialElement | list[SpatialElement] | None = None,
table_name: str | None = None,
table: AnnData | None = None,
how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "left",
match_rows: Literal["no", "left", "right"] = "no",
) -> tuple[dict[str, Any], AnnData]:
"""
Join SpatialElement(s) and table together in SQL like manner.
The function allows the user to perform SQL like joins of SpatialElements and a table. The elements are not
returned together in one dataframe like structure, but instead filtered elements are returned. To determine matches,
returned together in one dataframe-like structure, but instead filtered elements are returned. To determine matches,
for the SpatialElement the index is used and for the table the region key column and instance key column. The
elements are not overwritten in the `SpatialData` object.
Expand All @@ -555,14 +524,20 @@ def join_spatialelement_table(
Parameters
----------
sdata
SpatialData object containing all the elements and tables. This parameter can be `None`; in such a case, both
the names and values for the elements and the table must be provided.
spatial_element_names
The name(s) of the spatial elements to be joined with the table. If a list of names the indices must match
with the list of SpatialElements passed on by the argument elements.
elements
The SpatialElement(s) to be joined with the table. In case of a list of SpatialElements the indices
must match exactly with the indices in the list of spatial_element_name.
Required. The name(s) of the spatial elements to be joined with the table. If a list of names is given and `sdata`
is `None`, the indices must match those of the list of SpatialElements passed via the `spatial_elements` argument.
spatial_elements
This parameter should be specified exactly when `sdata` is `None`. The SpatialElement(s) to be joined with the
table. In case of a list of SpatialElements the indices must match exactly with the indices in the list of
`spatial_element_names`.
table_name
The name of the table to join with the spatial elements.
The name of the table to join with the spatial elements. Optional, `table` can be provided instead.
table
The table to join with the spatial elements. When `sdata` is not `None`, `table_name` can be used instead.
how
The type of SQL like join to perform, default is ``'left'``. Options are ``'left'``, ``'left_exclusive'``,
``'inner'``, ``'right'`` and ``'right_exclusive'``.
Expand All @@ -577,35 +552,55 @@ def join_spatialelement_table(
Raises
------
ValueError
If table_name is provided but not present in the SpatialData object.
If `spatial_element_names` is not provided.
ValueError
If `sdata` is `None` and `spatial_elements` or `table` is not provided; or if `sdata` is not `None` but
`spatial_elements` is also provided.
ValueError
If `table_name` is provided but not present in the `SpatialData` object, or if `table_name` is provided but
`sdata` is `None`.
ValueError
If not exactly one of `table_name` and `table` is provided.
ValueError
If no valid elements are provided for the join operation.
ValueError
If the provided join type is not supported.
ValueError
If an incorrect value is given for match_rows.
If an incorrect value is given for `match_rows`.
"""
if spatial_element_names is None:
raise ValueError("`spatial_element_names` must be provided.")
if sdata is None and (spatial_elements is None or table is None):
raise ValueError("If `sdata` is not provided, both `spatial_elements` and `table` must be provided.")
if sdata is not None and (spatial_elements is not None):
raise ValueError(
"If `sdata` is provided, `spatial_elements` must not be provided; use `spatial_elements_name` instead."
)
if table is None and table_name is None or table is not None and table_name is not None:
raise ValueError("Exactly one of `table_name` and `table` must be provided.")
if sdata is not None and table_name is not None:
if table_name not in sdata.tables:
raise ValueError(f"No table with name `{table_name}` found in the SpatialData object.")
table = sdata[table_name]
spatial_element_names = (
spatial_element_names if isinstance(spatial_element_names, list) else [spatial_element_names]
)
sdata_args = [sdata, table_name]
non_sdata_args = [elements, table]
if any(arg is not None for arg in sdata_args):
assert all(
arg is None for arg in non_sdata_args
), "If `sdata` and `table_name` are specified, `elements` and `table` should not be specified."
if sdata is not None and table_name is not None:
elements_dict, table = _create_sdata_elements_dict_for_join(sdata, spatial_element_names, table_name)
else:
raise ValueError("If either `sdata` or `table_name` is specified, both should be specified.")
if any(arg is not None for arg in non_sdata_args):
assert all(
arg is not None for arg in non_sdata_args
), "both `elements` and `table` must be given if either is specified."
elements_dict = _create_elements_dict_for_join(spatial_element_names, elements, table_name)

elements_dict, table = _call_join(elements_dict, table, how, match_rows)
return elements_dict, table
spatial_elements = spatial_elements if isinstance(spatial_elements, list) else [spatial_elements]
_validate_element_types_for_join(sdata, spatial_element_names, spatial_elements, table)

elements_dict: dict[str, dict[str, Any]]
if sdata is not None:
elements_dict = _create_sdata_elements_dict_for_join(sdata, spatial_element_names)
else:
derived_sdata = SpatialData.from_elements_dict(dict(zip(spatial_element_names, spatial_elements)))
element_types = ["labels", "shapes", "points"]
elements_dict = defaultdict(lambda: defaultdict(dict))
for element_type in element_types:
for name, element in getattr(derived_sdata, element_type).items():
elements_dict[element_type][name] = element

elements_dict_joined, table = _call_join(elements_dict, table, how, match_rows)
return elements_dict_joined, table


def _call_join(
Expand Down
Loading

0 comments on commit 1f74e77

Please sign in to comment.