Plotting Wrappers: Occupancy Histogram #403
Open
willGraham01 wants to merge 14 commits into main from wgraham-388-occupancy-histogram
Commits
71b88b1 Basic histogram plot created (willGraham01)
0d08661 Allow kwargs to go to underlying function (willGraham01)
fb61a44 Remove manual debugging from package module (willGraham01)
448e3af Write test, but it fails. But can't figure out why it fails... (willGraham01)
a4e5651 Additional return values to help extract histogram information (willGraham01)
03cb1db Test missing dims and entirely NAN values (willGraham01)
c3c77ae Check that new / existing axes are respected (willGraham01)
c8cf1b4 Default units to pixels (willGraham01)
e042f7c SonarQube recommendations (willGraham01)
be0d22b Comply with new plot wrapper standards (willGraham01)
1614ab4 Add test for default selection case (willGraham01)
76a1973 Add check for incorrect dims after squeezing (willGraham01)
e3e690d Remove tests to start afresh (willGraham01)
8ecd914 Merge branch 'main' into wgraham-388-occupancy-histogram (willGraham01)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
+from movement.plots.occupancy import plot_occupancy | ||
 from movement.plots.trajectory import plot_trajectory | ||
|
||
-__all__ = ["plot_trajectory"] | ||
+__all__ = ["plot_occupancy", "plot_trajectory"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
"""Wrappers for plotting occupancy data of select individuals.""" | ||
|
||
from collections.abc import Hashable | ||
from typing import Any, Literal, TypeAlias | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import xarray as xr | ||
|
||
HistInfoKeys: TypeAlias = Literal["counts", "xedges", "yedges"] | ||
|
||
DEFAULT_HIST_ARGS = {"alpha": 1.0, "bins": 30, "cmap": "viridis"} | ||
|
||
|
||
def plot_occupancy( | ||
da: xr.DataArray, | ||
individuals: Hashable | None = None, | ||
keypoints: Hashable | list[Hashable] | None = None, | ||
ax: plt.Axes | None = None, | ||
**kwargs: Any, | ||
) -> tuple[plt.Figure, plt.Axes, dict[HistInfoKeys, np.ndarray]]: | ||
"""Create a 2D histogram of the occupancy data given. | ||
|
||
By default: | ||
|
||
- If there are multiple keypoints selected, the occupancy of the centroid | ||
of these keypoints is computed. | ||
- If there are multiple individuals selected, the occupancies of their | ||
centroids are aggregated. | ||
|
||
Points whose corresponding spatial coordinates have NaN values | ||
are ignored. | ||
|
||
Histogram information is returned as the third output value (see Notes). | ||
|
||
Parameters | ||
---------- | ||
da : xarray.DataArray | ||
Spatial data to create histogram for. NaN values are dropped. | ||
individuals : Hashable, optional | ||
The name of the individual(s) to be aggregated and plotted. By default, | ||
all individuals are aggregated. | ||
keypoints : Hashable | list[Hashable], optional | ||
Name of a keypoint or list of such names. The centroid of all provided | ||
keypoints is computed, then plotted in the histogram. | ||
ax : matplotlib.axes.Axes, optional | ||
Axes object on which to draw the histogram. If not provided, a new | ||
figure and axes are created and returned. | ||
kwargs : Any | ||
Keyword arguments passed to ``matplotlib.pyplot.Axes.hist2d``. | ||
|
||
Returns | ||
------- | ||
matplotlib.pyplot.Figure | ||
Plot handle containing the rendered 2D histogram. If ``ax`` is | ||
supplied, this will be the figure that ``ax`` belongs to. | ||
matplotlib.axes.Axes | ||
Axes on which the histogram was drawn. If ``ax`` was supplied, | ||
the input will be directly modified and returned in this value. | ||
dict[str, numpy.ndarray] | ||
Information about the created histogram (see Notes). | ||
|
||
Notes | ||
----- | ||
The third output value is a dictionary containing the bin edges and bin | ||
counts that were used to create the histogram. It can be used in instances | ||
where the counts or information about the histogram bins is desired. | ||
|
||
For data with ``Nx`` bins in the 1st spatial dimension, and ``Ny`` bins in | ||
the 2nd spatial dimension, the dictionary output has key-value pairs: | ||
- ``xedges``, an ``(Nx+1,)`` ``numpy`` array specifying the bin edges in | ||
the 1st spatial dimension. | ||
- ``yedges``, an ``(Ny+1,)`` ``numpy`` array specifying the bin edges in | ||
the 2nd spatial dimension. | ||
- ``counts``, an ``(Nx, Ny)`` ``numpy`` array with the count for each bin. | ||
|
||
``counts[x, y]`` is the number of datapoints in the | ||
``(xedges[x], xedges[x+1]), (yedges[y], yedges[y+1])`` bin. These values | ||
are those returned from ``matplotlib.pyplot.Axes.hist2d``. | ||
|
||
See Also | ||
-------- | ||
matplotlib.pyplot.Axes.hist2d : The underlying plotting function. | ||
|
||
""" | ||
# Collapse dimensions if necessary | ||
data = da.copy(deep=True) | ||
if "keypoints" in da.dims: | ||
if keypoints is not None: | ||
data = data.sel(keypoints=keypoints) | ||
data = data.mean(dim="keypoints") | ||
if "individuals" in da.dims and individuals is not None: | ||
data = data.sel(individuals=individuals) | ||
|
||
# We need to remove NaN values from each individual, but we can't do this | ||
# right now because we still potentially have a (time, space, individuals) | ||
# array and so dropping NaNs along any axis may remove valid points for | ||
# other times / individuals. | ||
# Since we only care about a count, we can just unravel the individuals | ||
# dimension and create a "long" array of points. For example, a (10, 2, 5) | ||
# time-space-individuals DataArray becomes (50, 2). | ||
if "individuals" in data.dims: | ||
data = data.stack( | ||
{"new": ("time", "individuals")}, create_index=False | ||
).swap_dims({"new": "time"}) | ||
# We should now have just the relevant time-space data, | ||
# so we can remove time-points with NaN values. | ||
data = data.dropna(dim="time", how="any") | ||
|
||
# This makes us agnostic to the planar coordinate system. | ||
x_coord = data["space"].values[0] | ||
y_coord = data["space"].values[1] | ||
|
||
# Inherit our defaults if not otherwise provided | ||
for key, value in DEFAULT_HIST_ARGS.items(): | ||
if key not in kwargs: | ||
kwargs[key] = value | ||
# Now it should just be a case of creating the histogram | ||
if ax is not None: | ||
fig = ax.get_figure() | ||
else: | ||
fig, ax = plt.subplots() | ||
counts, xedges, yedges, hist_image = ax.hist2d( | ||
data.sel(space=x_coord), data.sel(space=y_coord), **kwargs | ||
) | ||
colourbar = fig.colorbar(hist_image, ax=ax) | ||
colourbar.solids.set(alpha=1.0) | ||
|
||
space_unit = data.attrs.get("space_unit", "pixels") | ||
ax.set_xlabel(f"{x_coord} ({space_unit})") | ||
ax.set_ylabel(f"{y_coord} ({space_unit})") | ||
|
||
return fig, ax, {"counts": counts, "xedges": xedges, "yedges": yedges} | ||
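The flattening, NaN-dropping and histogram-information steps described in the comments and Notes above can be sketched in plain NumPy. This is only an illustration under stated assumptions, not the code in this PR: the `positions` array and the bin counts are made up, and `np.histogram2d` stands in for `hist2d` (which returns the same `counts`, `xedges` and `yedges` alongside the image).

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical (time, space, individuals) array: 10 frames, 2D space, 5 animals.
positions = rng.normal(size=(10, 2, 5))
positions[3, 0, 1] = np.nan  # x is missing for individual 1 at time 3
positions[7, :, 4] = np.nan  # x and y are missing for individual 4 at time 7

# Unravel time and individuals into one "long" axis: (10, 2, 5) -> (50, 2).
points = np.moveaxis(positions, 1, -1).reshape(-1, 2)

# Drop only the points that have a NaN coordinate; valid points recorded
# at other times or for other individuals are kept.
valid = points[~np.isnan(points).any(axis=1)]

# Bin the remaining points with Nx = 6 and Ny = 4 bins.
n_x, n_y = 6, 4
counts, xedges, yedges = np.histogram2d(
    valid[:, 0], valid[:, 1], bins=(n_x, n_y)
)

print(counts.shape, xedges.shape, yedges.shape)
```

Here `counts[i, j]` is the number of points falling between `xedges[i]` and `xedges[i + 1]` in x and between `yedges[j]` and `yedges[j + 1]` in y, matching the layout described in the Notes.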
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import numpy as np | ||
import pytest | ||
import xarray as xr | ||
from numpy.random import Generator, default_rng | ||
|
||
|
||
@pytest.fixture | ||
def seed() -> int: | ||
return 0 | ||
|
||
|
||
@pytest.fixture(scope="function") | ||
def rng(seed: int) -> Generator: | ||
"""Create a random number ``Generator`` to use in testing. | ||
|
||
This ensures the repeatability of histogram tests, which require large | ||
datasets that would be tedious to create manually. | ||
""" | ||
return default_rng(seed) | ||
|
||
|
||
@pytest.fixture | ||
def normal_dist_2d(rng: Generator) -> np.ndarray: | ||
"""Points distributed by the standard multivariate normal. | ||
|
||
The standard multivariate normal is just two independent N(0, 1) | ||
distributions, one in each dimension. | ||
""" | ||
samples = rng.multivariate_normal( | ||
(0.0, 0.0), [[1.0, 0.0], [0.0, 1.0]], (250, 3, 4) | ||
) | ||
return np.moveaxis( | ||
samples, 3, 1 | ||
) # Move generated space coords to correct axis position | ||
|
||
|
||
@pytest.fixture | ||
def histogram_data(normal_dist_2d: np.ndarray) -> xr.DataArray: | ||
"""DataArray whose data is the ``normal_dist_2d`` points. | ||
|
||
Axes 2 and 3 are the individuals and keypoints axes, respectively. | ||
These dimensions are given coordinates {i,k}{0,1,2,3,4,5,...} for | ||
the purposes of indexing. | ||
""" | ||
return xr.DataArray( | ||
data=normal_dist_2d, | ||
dims=["time", "space", "individuals", "keypoints"], | ||
coords={ | ||
"space": ["x", "y"], | ||
"individuals": [f"i{i}" for i in range(normal_dist_2d.shape[2])], | ||
"keypoints": [f"k{i}" for i in range(normal_dist_2d.shape[3])], | ||
}, | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def histogram_data_with_nans(histogram_data: xr.DataArray) -> xr.DataArray: | ||
"""DataArray of the ``normal_dist_2d`` points, with some NaN values. | ||
|
||
Axes 2 and 3 are the individuals and keypoints axes, respectively. | ||
These dimensions are given coordinates {i,k}{0,1,2,3,4,5,...} for | ||
the purposes of indexing. | ||
|
||
For individual i0, keypoint k0, the following (time, space) values are | ||
converted into NaNs: | ||
- (100, "x") | ||
- (200, "y") | ||
- (150, "x") | ||
- (150, "y") | ||
|
||
""" | ||
individual_0 = "i0" | ||
keypoint_0 = "k0" | ||
data_with_nans = histogram_data.copy(deep=True) | ||
for time_index, space_coord in [ | ||
(100, "x"), | ||
(200, "y"), | ||
(150, "x"), | ||
(150, "y"), | ||
]: | ||
data_with_nans.loc[ | ||
time_index, space_coord, individual_0, keypoint_0 | ||
] = float("nan") | ||
return data_with_nans | ||
|
||
|
||
@pytest.fixture | ||
def entirely_nan_data(histogram_data: xr.DataArray) -> xr.DataArray: | ||
return histogram_data.copy( | ||
deep=True, data=histogram_data.values * float("nan") | ||
) |
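As a quick check of what the seeded fixtures above produce, the sampling and axis shuffle can be reproduced outside pytest. This is a minimal sketch: `make_samples` is a hypothetical helper that mirrors the `rng` and `normal_dist_2d` fixtures, and is not part of the PR.

```python
import numpy as np
from numpy.random import default_rng


def make_samples(seed: int) -> np.ndarray:
    """Hypothetical helper mirroring the ``rng`` and ``normal_dist_2d`` fixtures."""
    rng = default_rng(seed)
    # Drawing with size (250, 3, 4) gives shape (250, 3, 4, 2):
    # the two space coordinates end up on the last axis.
    samples = rng.multivariate_normal(
        (0.0, 0.0), [[1.0, 0.0], [0.0, 1.0]], (250, 3, 4)
    )
    # Move the space axis to position 1, giving
    # (time, space, individuals, keypoints) = (250, 2, 3, 4).
    return np.moveaxis(samples, 3, 1)


first = make_samples(0)
second = make_samples(0)

print(first.shape)
print(np.array_equal(first, second))
```

Because the generator is rebuilt from the same seed on each call, both arrays are identical, which is the property the function-scoped `rng` fixture relies on for repeatable histogram tests.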
I'm completely on board with forwarding all `kwargs` to `hist2d`. However, I think it would be helpful to illustrate some of the most commonly used `kwargs` in one or two examples in this docstring. While experimenting with this function, I found the following particularly useful:
- `bins` (since users will want full control over the bin sizes)
- `cmin` (especially useful when overlaying the trajectory on an image, to mask areas with low occupancy counts)
- `norm` (particularly `norm="log"`)

I don't believe we need to show all of these in the docstring example, as we have more space to explore them in a proper Sphinx Gallery example (see issue #410). However, we should at least demonstrate a typical usage of `bins`, for example `bins=(30, 30)`.
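For reference, the three keyword arguments mentioned above behave as follows when passed straight to matplotlib's `Axes.hist2d`. This is a rough sketch on made-up data rather than a proposed docstring example: the `x` and `y` samples are hypothetical, and `LogNorm()` is used as the explicit equivalent of `norm="log"` (recent matplotlib releases should also accept the string form).

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LogNorm

rng = np.random.default_rng(0)
# Hypothetical stand-in for x/y positions; not data from this PR.
x, y = rng.normal(size=(2, 5000))

fig, ax = plt.subplots()
counts, xedges, yedges, image = ax.hist2d(
    x,
    y,
    bins=(30, 30),  # 30 bins along each spatial dimension
    cmin=1,  # bins with fewer than 1 count are returned as NaN
    norm=LogNorm(),  # logarithmic colour scale
)
fig.colorbar(image, ax=ax)

print(counts.shape, int(np.isnan(counts).sum()))
```

Since `plot_occupancy` forwards `**kwargs` to `hist2d`, the same arguments would presumably be supplied as `plot_occupancy(da, bins=(30, 30), cmin=1, norm=LogNorm())`.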