Plotting Wrappers: Occupancy Histogram #403

Open: wants to merge 14 commits into base: main
3 changes: 2 additions & 1 deletion movement/plots/__init__.py
@@ -1,3 +1,4 @@
+from movement.plots.occupancy import plot_occupancy
 from movement.plots.trajectory import plot_trajectory
 
-__all__ = ["plot_trajectory"]
+__all__ = ["plot_occupancy", "plot_trajectory"]
133 changes: 133 additions & 0 deletions movement/plots/occupancy.py
@@ -0,0 +1,133 @@
"""Wrappers for plotting occupancy data of select individuals."""

from collections.abc import Hashable
from typing import Any, Literal, TypeAlias

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

HistInfoKeys: TypeAlias = Literal["counts", "xedges", "yedges"]

DEFAULT_HIST_ARGS = {"alpha": 1.0, "bins": 30, "cmap": "viridis"}


def plot_occupancy(
    da: xr.DataArray,
    individuals: Hashable | None = None,
    keypoints: Hashable | list[Hashable] | None = None,
    ax: plt.Axes | None = None,
    **kwargs: Any,
) -> tuple[plt.Figure, plt.Axes, dict[HistInfoKeys, np.ndarray]]:
    """Create a 2D histogram of the occupancy data given.

    By default:

    - If there are multiple keypoints selected, the occupancy of the centroid
      of these keypoints is computed.
    - If there are multiple individuals selected, the occupancies of their
      centroids are aggregated.

    Points whose corresponding spatial coordinates have NaN values
    are ignored.

    Histogram information is returned as the third output value (see Notes).

    Parameters
    ----------
    da : xarray.DataArray
        Spatial data to create histogram for. NaN values are dropped.
    individuals : Hashable, optional
        The name of the individual(s) to be aggregated and plotted. By default,
        all individuals are aggregated.
    keypoints : Hashable | list[Hashable], optional
        Name of a keypoint or list of such names. The centroid of all provided
        keypoints is computed, then plotted in the histogram.
    ax : matplotlib.axes.Axes, optional
        Axes object on which to draw the histogram. If not provided, a new
        figure and axes are created and returned.
    kwargs : Any
        Keyword arguments passed to ``matplotlib.axes.Axes.hist2d``.
Comment on lines +49 to +50
Member

I'm completely on board with forwarding all kwargs to hist2d. However, I think it would be helpful to illustrate some of the most commonly used kwargs in one or two examples in this docstring. While experimenting with this function, I found the following particularly useful:

  • bins (since users will want full control over the bin sizes)
  • cmin (especially useful when overlaying the trajectory on an image, to mask areas with low occupancy counts)
  • norm (particularly norm="log")

I don't believe we need to show all of these in the docstring example, as we have more space to explore them in a proper Sphinx Gallery example (see issue #410). However, we should at least demonstrate a typical usage of bins, for example bins=(30, 30).
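
As a rough illustration of the kwargs listed above, the sketch below calls `Axes.hist2d` directly on synthetic data rather than going through `plot_occupancy`. The random points, the `Agg` backend and the value `cmin=1` are assumptions made only for this example, and `LogNorm()` is used here as the explicit form of a logarithmic colour scale.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LogNorm

# Synthetic stand-in for tracked positions (assumption for illustration).
rng = np.random.default_rng(0)
x, y = rng.normal(size=(2, 5000))

fig, ax = plt.subplots()
counts, xedges, yedges, image = ax.hist2d(
    x,
    y,
    bins=(30, 30),  # number of bins along x and y
    cmin=1,  # bins with fewer than 1 count are set to NaN and not drawn
    norm=LogNorm(),  # logarithmic colour scale
)
fig.colorbar(image, ax=ax)

# counts has one entry per bin; the edge arrays have one more entry than bins.
print(counts.shape, xedges.shape, yedges.shape)
```

With `cmin=1`, empty bins come back as NaN in `counts`, so any total over the returned array needs `np.nansum` rather than `np.sum`.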


    Returns
    -------
    matplotlib.figure.Figure
        Plot handle containing the rendered 2D histogram. If ``ax`` is
        supplied, this will be the figure that ``ax`` belongs to.
    matplotlib.axes.Axes
        Axes on which the histogram was drawn. If ``ax`` was supplied,
        the input will be directly modified and returned in this value.
    dict[str, numpy.ndarray]
        Information about the created histogram (see Notes).

    Notes
    -----
    The third output value is a dictionary containing the bin edges and bin
    counts that were used to create the histogram.

    For data with ``Nx`` bins in the 1st spatial dimension, and ``Ny`` bins in
    the 2nd spatial dimension, the dictionary output has key-value pairs:

    - ``xedges``, an ``(Nx+1,)`` ``numpy`` array specifying the bin edges in
      the 1st spatial dimension.
    - ``yedges``, an ``(Ny+1,)`` ``numpy`` array specifying the bin edges in
      the 2nd spatial dimension.
    - ``counts``, an ``(Nx, Ny)`` ``numpy`` array with the count for each bin.

    ``counts[x, y]`` is the number of datapoints in the
    ``(xedges[x], xedges[x+1]), (yedges[y], yedges[y+1])`` bin. These values
    are those returned from ``matplotlib.axes.Axes.hist2d``.

    See Also
    --------
    matplotlib.axes.Axes.hist2d : The underlying plotting function.

    """
    # Collapse dimensions if necessary
    data = da.copy(deep=True)
    if "keypoints" in da.dims:
        if keypoints is not None:
            data = data.sel(keypoints=keypoints)
        # Selecting a single keypoint drops the keypoints dimension,
        # so re-check before averaging.
        if "keypoints" in data.dims:
            data = data.mean(dim="keypoints")
    if "individuals" in da.dims and individuals is not None:
        data = data.sel(individuals=individuals)

    # We need to remove NaN values from each individual, but we can't do this
    # right now because we still potentially have a (time, space, individuals)
    # array and so dropping NaNs along any axis may remove valid points for
    # other times / individuals.
    # Since we only care about a count, we can just unravel the individuals
    # dimension and create a "long" array of points. For example, a (10, 2, 5)
    # time-space-individuals DataArray becomes (50, 2).
    if "individuals" in data.dims:
        data = data.stack(
            {"new": ("time", "individuals")}, create_index=False
        ).swap_dims({"new": "time"})
    # We should now have just the relevant time-space data,
    # so we can remove time-points with NaN values.
    data = data.dropna(dim="time", how="any")

    # This makes us agnostic to the planar coordinate system.
    x_coord = data["space"].values[0]
    y_coord = data["space"].values[1]

    # Inherit our defaults if not otherwise provided
    for key, value in DEFAULT_HIST_ARGS.items():
        if key not in kwargs:
            kwargs[key] = value

    # Now it should just be a case of creating the histogram
    if ax is not None:
        fig = ax.get_figure()
    else:
        fig, ax = plt.subplots()
    counts, xedges, yedges, hist_image = ax.hist2d(
        data.sel(space=x_coord), data.sel(space=y_coord), **kwargs
    )
    colourbar = fig.colorbar(hist_image, ax=ax)
    colourbar.solids.set(alpha=1.0)

    space_unit = data.attrs.get("space_unit", "pixels")
    ax.set_xlabel(f"{x_coord} ({space_unit})")
    ax.set_ylabel(f"{y_coord} ({space_unit})")

    return fig, ax, {"counts": counts, "xedges": xedges, "yedges": yedges}
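
The comment in `plot_occupancy` about unravelling the individuals dimension before dropping NaNs can be checked on a toy array. The following is a minimal sketch, assuming a hypothetical 10 x 2 x 5 time-space-individuals `DataArray` with a single NaN; the temporary dimension name `new` and the toy data are illustrative and not taken from the PR's tests.

```python
import numpy as np
import xarray as xr

# Hypothetical toy data: 10 time points, 2 spatial coords, 5 individuals.
rng = np.random.default_rng(0)
da = xr.DataArray(
    rng.normal(size=(10, 2, 5)),
    dims=["time", "space", "individuals"],
    coords={
        "time": np.arange(10),
        "space": ["x", "y"],
        "individuals": [f"i{i}" for i in range(5)],
    },
)
# Knock out one coordinate of one individual at one time point.
da.loc[{"time": 3, "space": "x", "individuals": "i0"}] = np.nan

# Dropping along "time" directly discards time 3 for every individual.
naive = da.dropna(dim="time", how="any")

# Stack (time, individuals) into one long dimension, rename it to "time",
# and only then drop the points that contain a NaN.
long = da.stack(
    {"new": ("time", "individuals")}, create_index=False
).swap_dims({"new": "time"})
cleaned = long.dropna(dim="time", how="any")

print(naive.sizes["time"] * naive.sizes["individuals"])  # 45 points kept
print(cleaned.sizes["time"])  # 49 points kept
```

Only the single affected point is lost after stacking, whereas dropping along `time` first would also discard the four valid points belonging to the other individuals.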

91 changes: 91 additions & 0 deletions tests/test_unit/test_plots/test_plot.py
@@ -0,0 +1,91 @@
import numpy as np
import pytest
import xarray as xr
from numpy.random import Generator, default_rng


@pytest.fixture
def seed() -> int:
    return 0


@pytest.fixture(scope="function")
def rng(seed: int) -> Generator:
    """Create a random number ``Generator`` to use in testing.

    This ensures the repeatability of histogram tests, which require large
    datasets that would be tedious to create manually.
    """
    return default_rng(seed)


@pytest.fixture
def normal_dist_2d(rng: Generator) -> np.ndarray:
    """Points distributed by the standard multivariate normal.

    The standard multivariate normal is just two independent N(0, 1)
    distributions, one in each dimension.
    """
    samples = rng.multivariate_normal(
        (0.0, 0.0), [[1.0, 0.0], [0.0, 1.0]], (250, 3, 4)
    )
    return np.moveaxis(
        samples, 3, 1
    )  # Move generated space coords to correct axis position


@pytest.fixture
def histogram_data(normal_dist_2d: np.ndarray) -> xr.DataArray:
    """DataArray whose data is the ``normal_dist_2d`` points.

    Axes 2 and 3 are the individuals and keypoints axes, respectively.
    These dimensions are given coordinates {i,k}{0,1,2,3,4,5,...} for
    the purposes of indexing.
    """
    return xr.DataArray(
        data=normal_dist_2d,
        dims=["time", "space", "individuals", "keypoints"],
        coords={
            "space": ["x", "y"],
            "individuals": [f"i{i}" for i in range(normal_dist_2d.shape[2])],
            "keypoints": [f"k{i}" for i in range(normal_dist_2d.shape[3])],
        },
    )


@pytest.fixture
def histogram_data_with_nans(histogram_data: xr.DataArray) -> xr.DataArray:
    """DataArray of the ``normal_dist_2d`` points, with some NaN values.

    Axes 2 and 3 are the individuals and keypoints axes, respectively.
    These dimensions are given coordinates {i,k}{0,1,2,3,4,5,...} for
    the purposes of indexing.

    For individual i0, keypoint k0, the following (time, space) values are
    converted into NaNs:

    - (100, "x")
    - (200, "y")
    - (150, "x")
    - (150, "y")

    """
    individual_0 = "i0"
    keypoint_0 = "k0"
    data_with_nans = histogram_data.copy(deep=True)
    for time_index, space_coord in [
        (100, "x"),
        (200, "y"),
        (150, "x"),
        (150, "y"),
    ]:
        data_with_nans.loc[
            time_index, space_coord, individual_0, keypoint_0
        ] = float("nan")
    return data_with_nans


@pytest.fixture
def entirely_nan_data(histogram_data: xr.DataArray) -> xr.DataArray:
    return histogram_data.copy(
        deep=True, data=histogram_data.values * float("nan")
    )
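
To make the shapes and NaN positions in these fixtures concrete, here is a NumPy-only sketch that mirrors them outside pytest. Integer positions 0 and 1 stand in for the `"x"` and `"y"` labels, and the variable names are illustrative rather than part of the test module.

```python
import numpy as np

rng = np.random.default_rng(0)

# Same call as in normal_dist_2d: the sampled 2-vector lands on the last axis.
samples = rng.multivariate_normal(
    (0.0, 0.0), [[1.0, 0.0], [0.0, 1.0]], (250, 3, 4)
)
# Move the space axis next to time: (time, space, individuals, keypoints).
data = np.moveaxis(samples, 3, 1)

# Same NaN positions as histogram_data_with_nans, for individual 0, keypoint 0
# (space index 0 is "x", space index 1 is "y").
for time_index, space_index in [(100, 0), (200, 1), (150, 0), (150, 1)]:
    data[time_index, space_index, 0, 0] = np.nan

# Time points at which individual 0, keypoint 0 has both coordinates present.
track = data[:, :, 0, 0]
n_valid = int((~np.isnan(track).any(axis=1)).sum())
print(samples.shape, data.shape, n_valid)
```

Four NaN entries are written, but times 150 `"x"` and 150 `"y"` share a row, so three time points are lost and 247 of the 250 remain valid for that individual and keypoint.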