From 71b88b152580757527309f1cab7883baae38dc98 Mon Sep 17 00:00:00 2001 From: willGraham01 Date: Mon, 3 Feb 2025 12:29:25 +0000 Subject: [PATCH 01/22] Basic histogram plot created --- movement/plot.py | 84 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 movement/plot.py diff --git a/movement/plot.py b/movement/plot.py new file mode 100644 index 000000000..bedaa6b2e --- /dev/null +++ b/movement/plot.py @@ -0,0 +1,84 @@ +"""Wrappers to plot movement data.""" + +import matplotlib.pyplot as plt +import xarray as xr + + +def occupancy_histogram( + da: xr.DataArray, + keypoint: int | str = 0, + individual: int | str = 0, + title: str | None = None, +) -> plt.Figure: + """Create a 2D histogram of the occupancy data given. + + Parameters + ---------- + da : xarray.DataArray + Spatial data to create histogram for. NaN values are dropped. + keypoint : int | str + The keypoint to create a histogram for. + individual : int | str + The individual to create a histogram for. + title : str, optional + Title to give to the plot. Default will be generated from the + ``keypoint`` and ``individual`` + + Returns + ------- + matplotlib.pyplot.Figure + Plot handle containing the rendered 2D histogram. + + """ + data = da.position if isinstance(da, xr.Dataset) else da + title_components = [] + + # Remove additional dimensions before dropping NaN values + if "individuals" in data.dims: + if individual not in data["individuals"]: + individual = data["individuals"].values[individual] + data = data.sel(individuals=individual).squeeze() + title_components.append(f"individual {individual}") + if "keypoints" in data.dims: + if keypoint not in data["keypoints"]: + keypoint = data["keypoints"].values[keypoint] + data = data.sel(keypoints=keypoint).squeeze() + title_components.append(f"keypoint {keypoint}") + # We should now have just the relevant time-space data, + # so we can remove time-points with NaN values. 
+ data = data.dropna(dim="time", how="any") + # This makes us agnostic to the planar coordinate system. + x_coord = data["space"].values[0] + y_coord = data["space"].values[1] + + # Now it should just be a case of creating the histogram + fig, ax = plt.subplots() + _, _, _, hist_image = ax.hist2d( + data.sel(space=x_coord), data.sel(space=y_coord) + ) # counts, xedges, yedges, image + colourbar = fig.colorbar(hist_image, ax=ax) + colourbar.solids.set(alpha=1.0) + + # Axis labels and title + if not title and title_components: + title = "Occupancy of " + ", ".join(title_components) + if title: + ax.set_title(title) + ax.set_xlabel(x_coord) + ax.set_ylabel(y_coord) + + return fig + + +if __name__ == "__main__": + from movement import sample_data + from movement.io import load_poses + + ds_path = sample_data.fetch_dataset_paths( + "SLEAP_single-mouse_EPM.analysis.h5" + )["poses"] + position = load_poses.from_sleap_file(ds_path, fps=None).position + + f = occupancy_histogram(position) + plt.show(block=True) + pass From 0d086616922e5e0c866cfbb985902c5a11255fab Mon Sep 17 00:00:00 2001 From: willGraham01 Date: Mon, 3 Feb 2025 13:55:03 +0000 Subject: [PATCH 02/22] Allow kwargs to go to underlying function --- movement/plot.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/movement/plot.py b/movement/plot.py index bedaa6b2e..1e8019983 100644 --- a/movement/plot.py +++ b/movement/plot.py @@ -1,14 +1,19 @@ """Wrappers to plot movement data.""" +from typing import Any + import matplotlib.pyplot as plt import xarray as xr +DEFAULT_HIST_ARGS = {"alpha": 1.0, "bins": 30, "cmap": "viridis"} + def occupancy_histogram( da: xr.DataArray, keypoint: int | str = 0, individual: int | str = 0, title: str | None = None, + **kwargs: Any, ) -> plt.Figure: """Create a 2D histogram of the occupancy data given. @@ -23,6 +28,8 @@ def occupancy_histogram( title : str, optional Title to give to the plot. 
Default will be generated from the ``keypoint`` and ``individual`` + kwargs : Any + Keyword arguments passed to ``matplotlib.pyplot.hist2d`` Returns ------- @@ -51,10 +58,14 @@ def occupancy_histogram( x_coord = data["space"].values[0] y_coord = data["space"].values[1] + # Inherit our defaults if not otherwise provided + for key, value in DEFAULT_HIST_ARGS.items(): + if key not in kwargs: + kwargs[key] = value # Now it should just be a case of creating the histogram fig, ax = plt.subplots() _, _, _, hist_image = ax.hist2d( - data.sel(space=x_coord), data.sel(space=y_coord) + data.sel(space=x_coord), data.sel(space=y_coord), **kwargs ) # counts, xedges, yedges, image colourbar = fig.colorbar(hist_image, ax=ax) colourbar.solids.set(alpha=1.0) From fb61a449e342e0bb8f602dabebd0303a87bb22f3 Mon Sep 17 00:00:00 2001 From: willGraham01 Date: Mon, 3 Feb 2025 13:56:49 +0000 Subject: [PATCH 03/22] Remove manual debugging from package module --- movement/plot.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/movement/plot.py b/movement/plot.py index 1e8019983..62bcc0d4b 100644 --- a/movement/plot.py +++ b/movement/plot.py @@ -79,17 +79,3 @@ def occupancy_histogram( ax.set_ylabel(y_coord) return fig - - -if __name__ == "__main__": - from movement import sample_data - from movement.io import load_poses - - ds_path = sample_data.fetch_dataset_paths( - "SLEAP_single-mouse_EPM.analysis.h5" - )["poses"] - position = load_poses.from_sleap_file(ds_path, fps=None).position - - f = occupancy_histogram(position) - plt.show(block=True) - pass From 448e3af951b699824e7a3c1629eb8fe7c6301563 Mon Sep 17 00:00:00 2001 From: willGraham01 Date: Mon, 3 Feb 2025 15:59:25 +0000 Subject: [PATCH 04/22] Write test, but it fails. But can't figure out why it fails... 
--- tests/test_unit/test_plot.py | 171 +++++++++++++++++++++++++++++++++++ 1 file changed, 171 insertions(+) create mode 100644 tests/test_unit/test_plot.py diff --git a/tests/test_unit/test_plot.py b/tests/test_unit/test_plot.py new file mode 100644 index 000000000..295a7355c --- /dev/null +++ b/tests/test_unit/test_plot.py @@ -0,0 +1,171 @@ +from itertools import product + +import matplotlib.pyplot as plt +import numpy as np +import pytest +import xarray as xr +from matplotlib.collections import QuadMesh +from numpy.random import RandomState + +from movement.plot import occupancy_histogram + + +def get_histogram_binning_data(fig: plt.Figure) -> list[QuadMesh]: + """Fetch 2D array data from a histogram plot.""" + return [ + qm for qm in fig.axes[0].get_children() if isinstance(qm, QuadMesh) + ] + + +@pytest.fixture +def seed() -> int: + return 0 + + +@pytest.fixture(scope="function") +def rng(seed: int) -> RandomState: + """Create a RandomState to use in testing. + + This ensures the repeatability of histogram tests, that require large + datasets that would be tedious to create manually. + """ + return RandomState(seed) + + +@pytest.fixture +def normal_dist_2d(rng: RandomState) -> np.ndarray: + """Points distributed by the standard multivariate normal. + + The standard multivariate normal is just two independent N(0, 1) + distributions, one in each dimension. + """ + samples = rng.multivariate_normal( + (0.0, 0.0), [[1.0, 0.0], [0.0, 1.0]], (250, 3, 4) + ) + return np.moveaxis( + samples, 3, 1 + ) # Move generated space coords to correct axis position + + +@pytest.fixture +def histogram_data(normal_dist_2d: np.ndarray) -> xr.DataArray: + """DataArray whose data is the ``normal_dist_2d`` points. + + Axes 2 and 3 are the individuals and keypoints axes, respectively. + These dimensions are given coordinates {i,k}{0,1,2,3,4,5,...} for + the purposes of indexing. 
+ """ + return xr.DataArray( + data=normal_dist_2d, + dims=["time", "space", "individuals", "keypoints"], + coords={ + "space": ["x", "y"], + "individuals": [f"i{i}" for i in range(normal_dist_2d.shape[2])], + "keypoints": [f"k{i}" for i in range(normal_dist_2d.shape[3])], + }, + ) + + +@pytest.fixture +def histogram_data_with_nans( + histogram_data: xr.DataArray, rng: RandomState +) -> xr.DataArray: + """DataArray whose data is the ``normal_dist_2d`` points. + + Each datapoint has a chance of being turned into a NaN value. + + Axes 2 and 3 are the individuals and keypoints axes, respectively. + These dimensions are given coordinates {i,k}{0,1,2,3,4,5,...} for + the purposes of indexing. + """ + data_with_nans = histogram_data.copy(deep=True) + data_shape = data_with_nans.shape + nan_chance = 1.0 / 25.0 + index_ranges = [range(dim_length) for dim_length in data_shape] + for multiindex in product(*index_ranges): + if rng.uniform() < nan_chance: + data_with_nans[*multiindex] = float("nan") + return data_with_nans + + +# def test_histogram_ignores_missing_dims( +# input_does_not_have_dimensions: list[str], +# ) -> None: +# """Test that ``occupancy_histogram`` ignores non-present dimensions.""" +# input_data = 0 + + +@pytest.mark.parametrize( + ["data", "individual", "keypoint", "n_bins"], + [pytest.param("histogram_data", "i0", "k0", 30, id="30 bins each axis")], +) +def test_occupancy_histogram( + data: xr.DataArray, + individual: int | str, + keypoint: int | str, + n_bins: int | tuple[int, int], + request, +) -> None: + """Test that occupancy histograms correctly plot data.""" + if isinstance(data, str): + data = request.getfixturevalue(data) + + plotted_hist = occupancy_histogram( + data, individual=individual, keypoint=keypoint, bins=n_bins + ) + + # Confirm that a histogram was made + plotted_data = get_histogram_binning_data(plotted_hist) + assert len(plotted_data) == 1 + plotted_data = plotted_data[0] + plotting_coords = plotted_data.get_coordinates() + 
plotted_values = plotted_data.get_array() + + # Confirm the binned array has the correct size + if not isinstance(n_bins, tuple): + n_bins = (n_bins, n_bins) + assert plotted_data.get_array().shape == n_bins + + # Confirm that each bin has the correct number of assignments + data_time_xy = data.sel(individuals=individual, keypoints=keypoint) + x_values = data_time_xy.sel(space="x").values + y_values = data_time_xy.sel(space="y").values + reconstructed_bins_limits_x = np.linspace( + x_values.min(), + x_values.max(), + num=n_bins[0] + 1, + endpoint=True, + ) + assert all( + np.allclose(reconstructed_bins_limits_x, plotting_coords[i, :, 0]) + for i in range(n_bins[0]) + ) + reconstructed_bins_limits_y = np.linspace( + y_values.min(), + y_values.max(), + num=n_bins[1] + 1, + endpoint=True, + ) + assert all( + np.allclose(reconstructed_bins_limits_y, plotting_coords[:, j, 1]) + for j in range(n_bins[1]) + ) + + reconstructed_bin_counts = np.zeros(shape=n_bins, dtype=float) + for i, xi in enumerate(reconstructed_bins_limits_x[:-1]): + xi_p1 = reconstructed_bins_limits_x[i + 1] + + x_pts_in_range = (x_values >= xi) & (x_values <= xi_p1) + for j, yj in enumerate(reconstructed_bins_limits_y[:-1]): + yj_p1 = reconstructed_bins_limits_y[j + 1] + + y_pts_in_range = (y_values >= yj) & (y_values <= yj_p1) + + pts_in_this_bin = (x_pts_in_range & y_pts_in_range).sum() + reconstructed_bin_counts[i, j] = pts_in_this_bin + + if pts_in_this_bin != plotted_values[i, j]: + pass + + assert reconstructed_bin_counts.sum() == plotted_values.sum() + assert np.all(reconstructed_bin_counts == plotted_values) From a4e5651540b24ec70123cd133b03d586dc87fca7 Mon Sep 17 00:00:00 2001 From: willGraham01 Date: Tue, 4 Feb 2025 12:19:12 +0000 Subject: [PATCH 05/22] Additional return values to help extract histogram information --- movement/plot.py | 46 +++++++++++++-- tests/test_unit/test_plot.py | 111 +++++++++++++++++++++-------------- 2 files changed, 107 insertions(+), 50 deletions(-) diff --git 
a/movement/plot.py b/movement/plot.py index 62bcc0d4b..0120544b8 100644 --- a/movement/plot.py +++ b/movement/plot.py @@ -1,10 +1,13 @@ """Wrappers to plot movement data.""" -from typing import Any +from typing import Any, Literal, TypeAlias import matplotlib.pyplot as plt +import numpy as np import xarray as xr +HistInfoKeys: TypeAlias = Literal["counts", "xedges", "yedges"] + DEFAULT_HIST_ARGS = {"alpha": 1.0, "bins": 30, "cmap": "viridis"} @@ -14,9 +17,13 @@ def occupancy_histogram( individual: int | str = 0, title: str | None = None, **kwargs: Any, -) -> plt.Figure: +) -> tuple[plt.Figure, dict[HistInfoKeys, np.ndarray]]: """Create a 2D histogram of the occupancy data given. + Time-points whose corresponding spatial coordinates have NaN values + are ignored. Histogram information is returned as the second output + value (see Notes). + Parameters ---------- da : xarray.DataArray @@ -35,6 +42,35 @@ def occupancy_histogram( ------- matplotlib.pyplot.Figure Plot handle containing the rendered 2D histogram. + dict[str, numpy.ndarray] + Information about the created histogram (see Notes). + + Notes + ----- + In instances where the counts or information about the histogram bins is + desired, the ``return_hist_info`` argument should be provided as ``True``. + This will force the function to return a second output value, which is a + dictionary containing the bin edges and bin counts that were used to create + the histogram. + + For data with ``N`` time-points, the dictionary output has key-value pairs; + - ``xedges``, an ``(N+1,)`` ``numpy`` array specifying the bin edges in the + first spatial dimension. + - ``yedges``, same as ``xedges`` but for the second spatial dimension. + - ``counts``, an ``(N, N)`` ``numpy`` array with the count for each bin. + + ``counts[x, y]`` is the number of datapoints in the + ``(xedges[x], xedges[x+1]), (yedges[y], yedges[y+1])`` bin. These values + are those returned from ``matplotlib.pyplot.Axes.hist2d``. 
+ + Note that the ``counts`` values do not necessarily match the mappable + values that one gets from extracting the data from the + ``matplotlib.collections.QuadMesh`` object (that represents the rendered + histogram) via its ``get_array()`` attribute. + + See Also + -------- + matplotlib.pyplot.Axes.hist2d : The underlying plotting function. """ data = da.position if isinstance(da, xr.Dataset) else da @@ -64,9 +100,9 @@ def occupancy_histogram( kwargs[key] = value # Now it should just be a case of creating the histogram fig, ax = plt.subplots() - _, _, _, hist_image = ax.hist2d( + counts, xedges, yedges, hist_image = ax.hist2d( data.sel(space=x_coord), data.sel(space=y_coord), **kwargs - ) # counts, xedges, yedges, image + ) colourbar = fig.colorbar(hist_image, ax=ax) colourbar.solids.set(alpha=1.0) @@ -78,4 +114,4 @@ def occupancy_histogram( ax.set_xlabel(x_coord) ax.set_ylabel(y_coord) - return fig + return fig, {"counts": counts, "xedges": xedges, "yedges": yedges} diff --git a/tests/test_unit/test_plot.py b/tests/test_unit/test_plot.py index 295a7355c..d55f70881 100644 --- a/tests/test_unit/test_plot.py +++ b/tests/test_unit/test_plot.py @@ -1,22 +1,11 @@ -from itertools import product - -import matplotlib.pyplot as plt import numpy as np import pytest import xarray as xr -from matplotlib.collections import QuadMesh from numpy.random import RandomState from movement.plot import occupancy_histogram -def get_histogram_binning_data(fig: plt.Figure) -> list[QuadMesh]: - """Fetch 2D array data from a histogram plot.""" - return [ - qm for qm in fig.axes[0].get_children() if isinstance(qm, QuadMesh) - ] - - @pytest.fixture def seed() -> int: return 0 @@ -72,19 +61,30 @@ def histogram_data_with_nans( ) -> xr.DataArray: """DataArray whose data is the ``normal_dist_2d`` points. - Each datapoint has a chance of being turned into a NaN value. - Axes 2 and 3 are the individuals and keypoints axes, respectively. 
These dimensions are given coordinates {i,k}{0,1,2,3,4,5,...} for the purposes of indexing. + + For individual i0, keypoint k0, the following (time, space) values are + converted into NaNs: + - (100, "x") + - (200, "y") + - (150, "x") + - (150, "y") + """ + individual_0 = "i0" + keypoint_0 = "k0" data_with_nans = histogram_data.copy(deep=True) - data_shape = data_with_nans.shape - nan_chance = 1.0 / 25.0 - index_ranges = [range(dim_length) for dim_length in data_shape] - for multiindex in product(*index_ranges): - if rng.uniform() < nan_chance: - data_with_nans[*multiindex] = float("nan") + for time_index, space_coord in [ + (100, "x"), + (200, "y"), + (150, "x"), + (150, "y"), + ]: + data_with_nans.loc[ + time_index, space_coord, individual_0, keypoint_0 + ] = float("nan") return data_with_nans @@ -97,7 +97,29 @@ def histogram_data_with_nans( @pytest.mark.parametrize( ["data", "individual", "keypoint", "n_bins"], - [pytest.param("histogram_data", "i0", "k0", 30, id="30 bins each axis")], + [ + pytest.param( + "histogram_data", + "i0", + "k0", + 30, + id="30 bins each axis", + ), + pytest.param( + "histogram_data", + "i1", + "k0", + (20, 30), + id="(20, 30) bins", + ), + pytest.param( + "histogram_data_with_nans", + "i0", + "k0", + 30, + id="NaNs should be removed", + ), + ], ) def test_occupancy_histogram( data: xr.DataArray, @@ -110,56 +132,51 @@ def test_occupancy_histogram( if isinstance(data, str): data = request.getfixturevalue(data) - plotted_hist = occupancy_histogram( + _, histogram_info = occupancy_histogram( data, individual=individual, keypoint=keypoint, bins=n_bins ) - - # Confirm that a histogram was made - plotted_data = get_histogram_binning_data(plotted_hist) - assert len(plotted_data) == 1 - plotted_data = plotted_data[0] - plotting_coords = plotted_data.get_coordinates() - plotted_values = plotted_data.get_array() + plotted_values = histogram_info["counts"] # Confirm the binned array has the correct size if not isinstance(n_bins, tuple): n_bins = 
(n_bins, n_bins) - assert plotted_data.get_array().shape == n_bins + assert plotted_values.shape == n_bins # Confirm that each bin has the correct number of assignments data_time_xy = data.sel(individuals=individual, keypoints=keypoint) - x_values = data_time_xy.sel(space="x").values - y_values = data_time_xy.sel(space="y").values + data_time_xy = data_time_xy.dropna(dim="time", how="any") + plotted_x_values = data_time_xy.sel(space="x").values + plotted_y_values = data_time_xy.sel(space="y").values + assert plotted_x_values.shape == plotted_y_values.shape + # This many non-NaN values were plotted + n_non_nan_values = plotted_x_values.shape[0] + reconstructed_bins_limits_x = np.linspace( - x_values.min(), - x_values.max(), + plotted_x_values.min(), + plotted_x_values.max(), num=n_bins[0] + 1, endpoint=True, ) - assert all( - np.allclose(reconstructed_bins_limits_x, plotting_coords[i, :, 0]) - for i in range(n_bins[0]) - ) + assert np.allclose(reconstructed_bins_limits_x, histogram_info["xedges"]) reconstructed_bins_limits_y = np.linspace( - y_values.min(), - y_values.max(), + plotted_y_values.min(), + plotted_y_values.max(), num=n_bins[1] + 1, endpoint=True, ) - assert all( - np.allclose(reconstructed_bins_limits_y, plotting_coords[:, j, 1]) - for j in range(n_bins[1]) - ) + assert np.allclose(reconstructed_bins_limits_y, histogram_info["yedges"]) reconstructed_bin_counts = np.zeros(shape=n_bins, dtype=float) for i, xi in enumerate(reconstructed_bins_limits_x[:-1]): xi_p1 = reconstructed_bins_limits_x[i + 1] - x_pts_in_range = (x_values >= xi) & (x_values <= xi_p1) + x_pts_in_range = (plotted_x_values >= xi) & (plotted_x_values <= xi_p1) for j, yj in enumerate(reconstructed_bins_limits_y[:-1]): yj_p1 = reconstructed_bins_limits_y[j + 1] - y_pts_in_range = (y_values >= yj) & (y_values <= yj_p1) + y_pts_in_range = (plotted_y_values >= yj) & ( + plotted_y_values <= yj_p1 + ) pts_in_this_bin = (x_pts_in_range & y_pts_in_range).sum() reconstructed_bin_counts[i, j] = 
pts_in_this_bin @@ -167,5 +184,9 @@ def test_occupancy_histogram( if pts_in_this_bin != plotted_values[i, j]: pass + # We agree with a manual count assert reconstructed_bin_counts.sum() == plotted_values.sum() + # All non-NaN values were plotted + assert n_non_nan_values == plotted_values.sum() + # The counts were actually correct assert np.all(reconstructed_bin_counts == plotted_values) From 03cb1dba542c7e228446ab18f4e0d6de9f5b5250 Mon Sep 17 00:00:00 2001 From: willGraham01 Date: Tue, 4 Feb 2025 12:44:32 +0000 Subject: [PATCH 06/22] Test missing dims and entirely NAN values --- tests/test_unit/test_plot.py | 157 ++++++++++++++++++++++++++--------- 1 file changed, 116 insertions(+), 41 deletions(-) diff --git a/tests/test_unit/test_plot.py b/tests/test_unit/test_plot.py index d55f70881..176348831 100644 --- a/tests/test_unit/test_plot.py +++ b/tests/test_unit/test_plot.py @@ -88,18 +88,25 @@ def histogram_data_with_nans( return data_with_nans -# def test_histogram_ignores_missing_dims( -# input_does_not_have_dimensions: list[str], -# ) -> None: -# """Test that ``occupancy_histogram`` ignores non-present dimensions.""" -# input_data = 0 +@pytest.fixture +def entirely_nan_data(histogram_data: xr.DataArray) -> xr.DataArray: + return histogram_data.copy( + deep=True, data=histogram_data.values * float("nan") + ) @pytest.mark.parametrize( - ["data", "individual", "keypoint", "n_bins"], + [ + "data", + "remove_dims_from_data_before_starting", + "individual", + "keypoint", + "n_bins", + ], [ pytest.param( "histogram_data", + [], "i0", "k0", 30, @@ -107,6 +114,7 @@ def histogram_data_with_nans( ), pytest.param( "histogram_data", + [], "i1", "k0", (20, 30), @@ -114,24 +122,81 @@ def histogram_data_with_nans( ), pytest.param( "histogram_data_with_nans", + [], "i0", "k0", 30, id="NaNs should be removed", ), + pytest.param( + "entirely_nan_data", + [], + "i0", + "k0", + 10, + id="All NaN-data", + ), + pytest.param( + "histogram_data", + ["individuals"], + "i0", + "k0", + 30, 
+ id="Ignores individual if not a dimension", + ), + pytest.param( + "histogram_data", + ["keypoints"], + "i0", + "k1", + 30, + id="Ignores keypoint if not a dimension", + ), + pytest.param( + "histogram_data", + ["individuals", "keypoints"], + "i0", + "k0", + 30, + id="Can handle raw xy data", + ), ], ) def test_occupancy_histogram( data: xr.DataArray, + remove_dims_from_data_before_starting: list[str], individual: int | str, keypoint: int | str, n_bins: int | tuple[int, int], request, ) -> None: - """Test that occupancy histograms correctly plot data.""" + """Test that occupancy histograms correctly plot data. + + Specifically, check that: + - The bin edges are what we expect. + - The bin counts can be manually verified and are in agreement. + - Only non-NaN values are plotted, but NaN values do not throw errors. + """ if isinstance(data, str): data = request.getfixturevalue(data) + # We will need to only select the xy data later in the test, + # but if we are dropping dimensions we might need to call it + # in different ways. + kwargs_to_select_xy_data = { + "individuals": individual, + "keypoints": keypoint, + } + for d in remove_dims_from_data_before_starting: + # Retain the 0th value in the corresponding dimension, + # then drop that dimension. + data = data.sel({d: getattr(data, d)[0]}).squeeze() + assert d not in data.dims + + # We no longer need to filter this dimension out + # when examining the xy data later in the test. 
+ kwargs_to_select_xy_data.pop(d, None) + _, histogram_info = occupancy_histogram( data, individual=individual, keypoint=keypoint, bins=n_bins ) @@ -143,7 +208,7 @@ def test_occupancy_histogram( assert plotted_values.shape == n_bins # Confirm that each bin has the correct number of assignments - data_time_xy = data.sel(individuals=individual, keypoints=keypoint) + data_time_xy = data.sel(**kwargs_to_select_xy_data) data_time_xy = data_time_xy.dropna(dim="time", how="any") plotted_x_values = data_time_xy.sel(space="x").values plotted_y_values = data_time_xy.sel(space="y").values @@ -151,42 +216,52 @@ def test_occupancy_histogram( # This many non-NaN values were plotted n_non_nan_values = plotted_x_values.shape[0] - reconstructed_bins_limits_x = np.linspace( - plotted_x_values.min(), - plotted_x_values.max(), - num=n_bins[0] + 1, - endpoint=True, - ) - assert np.allclose(reconstructed_bins_limits_x, histogram_info["xedges"]) - reconstructed_bins_limits_y = np.linspace( - plotted_y_values.min(), - plotted_y_values.max(), - num=n_bins[1] + 1, - endpoint=True, - ) - assert np.allclose(reconstructed_bins_limits_y, histogram_info["yedges"]) + if n_non_nan_values > 0: + reconstructed_bins_limits_x = np.linspace( + plotted_x_values.min(), + plotted_x_values.max(), + num=n_bins[0] + 1, + endpoint=True, + ) + assert np.allclose( + reconstructed_bins_limits_x, histogram_info["xedges"] + ) + reconstructed_bins_limits_y = np.linspace( + plotted_y_values.min(), + plotted_y_values.max(), + num=n_bins[1] + 1, + endpoint=True, + ) + assert np.allclose( + reconstructed_bins_limits_y, histogram_info["yedges"] + ) - reconstructed_bin_counts = np.zeros(shape=n_bins, dtype=float) - for i, xi in enumerate(reconstructed_bins_limits_x[:-1]): - xi_p1 = reconstructed_bins_limits_x[i + 1] + reconstructed_bin_counts = np.zeros(shape=n_bins, dtype=float) + for i, xi in enumerate(reconstructed_bins_limits_x[:-1]): + xi_p1 = reconstructed_bins_limits_x[i + 1] - x_pts_in_range = (plotted_x_values 
>= xi) & (plotted_x_values <= xi_p1) - for j, yj in enumerate(reconstructed_bins_limits_y[:-1]): - yj_p1 = reconstructed_bins_limits_y[j + 1] - - y_pts_in_range = (plotted_y_values >= yj) & ( - plotted_y_values <= yj_p1 + x_pts_in_range = (plotted_x_values >= xi) & ( + plotted_x_values <= xi_p1 ) + for j, yj in enumerate(reconstructed_bins_limits_y[:-1]): + yj_p1 = reconstructed_bins_limits_y[j + 1] + + y_pts_in_range = (plotted_y_values >= yj) & ( + plotted_y_values <= yj_p1 + ) - pts_in_this_bin = (x_pts_in_range & y_pts_in_range).sum() - reconstructed_bin_counts[i, j] = pts_in_this_bin + pts_in_this_bin = (x_pts_in_range & y_pts_in_range).sum() + reconstructed_bin_counts[i, j] = pts_in_this_bin - if pts_in_this_bin != plotted_values[i, j]: - pass + if pts_in_this_bin != plotted_values[i, j]: + pass - # We agree with a manual count - assert reconstructed_bin_counts.sum() == plotted_values.sum() - # All non-NaN values were plotted - assert n_non_nan_values == plotted_values.sum() - # The counts were actually correct - assert np.all(reconstructed_bin_counts == plotted_values) + # We agree with a manual count + assert reconstructed_bin_counts.sum() == plotted_values.sum() + # All non-NaN values were plotted + assert n_non_nan_values == plotted_values.sum() + # The counts were actually correct + assert np.all(reconstructed_bin_counts == plotted_values) + else: + # No non-nan values were given + assert plotted_values.sum() == 0 From c3c77aeca1ed2ed7be7cc8781cdc78318550df58 Mon Sep 17 00:00:00 2001 From: willGraham01 Date: Tue, 4 Feb 2025 13:03:00 +0000 Subject: [PATCH 07/22] Check that new / existing axes are respected --- movement/plot.py | 19 +++++++++++++++---- tests/test_unit/test_plot.py | 36 +++++++++++++++++++++++++++++++++++- 2 files changed, 50 insertions(+), 5 deletions(-) diff --git a/movement/plot.py b/movement/plot.py index 0120544b8..80312a290 100644 --- a/movement/plot.py +++ b/movement/plot.py @@ -16,8 +16,9 @@ def occupancy_histogram( keypoint: int | 
str = 0, individual: int | str = 0, title: str | None = None, + ax: plt.Axes | None = None, **kwargs: Any, -) -> tuple[plt.Figure, dict[HistInfoKeys, np.ndarray]]: +) -> tuple[plt.Figure, plt.Axes, dict[HistInfoKeys, np.ndarray]]: """Create a 2D histogram of the occupancy data given. Time-points whose corresponding spatial coordinates have NaN values @@ -35,13 +36,20 @@ def occupancy_histogram( title : str, optional Title to give to the plot. Default will be generated from the ``keypoint`` and ``individual`` + ax : matplotlib.axes.Axes, optional + Axes object on which to draw the histogram. If not provided, a new + figure and axes are created and returned. kwargs : Any Keyword arguments passed to ``matplotlib.pyplot.hist2d`` Returns ------- matplotlib.pyplot.Figure - Plot handle containing the rendered 2D histogram. + Plot handle containing the rendered 2D histogram. If ``ax`` is + supplied, this will be the figure that ``ax`` belongs to. + matplotlib.axes.Axes + Axes on which the histogram was drawn. If ``ax`` was supplied, + the input will be directly modified and returned in this value. dict[str, numpy.ndarray] Information about the created histogram (see Notes). 
@@ -99,7 +107,10 @@ def occupancy_histogram( if key not in kwargs: kwargs[key] = value # Now it should just be a case of creating the histogram - fig, ax = plt.subplots() + if ax is not None: + fig = ax.get_figure() + else: + fig, ax = plt.subplots() counts, xedges, yedges, hist_image = ax.hist2d( data.sel(space=x_coord), data.sel(space=y_coord), **kwargs ) @@ -114,4 +125,4 @@ def occupancy_histogram( ax.set_xlabel(x_coord) ax.set_ylabel(y_coord) - return fig, {"counts": counts, "xedges": xedges, "yedges": yedges} + return fig, ax, {"counts": counts, "xedges": xedges, "yedges": yedges} diff --git a/tests/test_unit/test_plot.py b/tests/test_unit/test_plot.py index 176348831..1b72c3390 100644 --- a/tests/test_unit/test_plot.py +++ b/tests/test_unit/test_plot.py @@ -1,6 +1,8 @@ +import matplotlib.pyplot as plt import numpy as np import pytest import xarray as xr +from matplotlib.collections import QuadMesh from numpy.random import RandomState from movement.plot import occupancy_histogram @@ -197,7 +199,7 @@ def test_occupancy_histogram( # when examining the xy data later in the test. 
kwargs_to_select_xy_data.pop(d, None) - _, histogram_info = occupancy_histogram( + _, _, histogram_info = occupancy_histogram( data, individual=individual, keypoint=keypoint, bins=n_bins ) plotted_values = histogram_info["counts"] @@ -265,3 +267,35 @@ def test_occupancy_histogram( else: # No non-nan values were given assert plotted_values.sum() == 0 + + +def test_respects_axes(histogram_data: xr.DataArray) -> None: + """Check that existing axes objects are respected if passed.""" + # Plotting on existing axes + existing_fig, existing_ax = plt.subplots(1, 2) + + existing_ax[0].plot( + np.linspace(0.0, 10.0, num=100), np.linspace(0.0, 10.0, num=100) + ) + + _, _, hist_info_existing = occupancy_histogram( + histogram_data, ax=existing_ax[1] + ) + hist_plots_added = [ + qm for qm in existing_ax[1].get_children() if isinstance(qm, QuadMesh) + ] + assert len(hist_plots_added) == 1 + + # Plot on new axis and create a new figure + new_fig, new_ax, hist_info_new = occupancy_histogram(histogram_data) + hist_plots_created = [ + qm for qm in new_ax.get_children() if isinstance(qm, QuadMesh) + ] + assert len(hist_plots_created) == 1 + + # Check that the same plot was made for each + assert set(hist_info_new.keys()) == set(hist_info_existing.keys()) + for key, new_ax_value in hist_info_new.items(): + existing_ax_value = hist_info_existing[key] + + assert np.allclose(new_ax_value, existing_ax_value) From c8cf1b43979be5939dc4c04ef9d1080e7d7becdd Mon Sep 17 00:00:00 2001 From: willGraham01 Date: Tue, 4 Feb 2025 13:03:15 +0000 Subject: [PATCH 08/22] Default units to pixels --- movement/plot.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/movement/plot.py b/movement/plot.py index 80312a290..5f94d4833 100644 --- a/movement/plot.py +++ b/movement/plot.py @@ -122,7 +122,8 @@ def occupancy_histogram( title = "Occupancy of " + ", ".join(title_components) if title: ax.set_title(title) - ax.set_xlabel(x_coord) - ax.set_ylabel(y_coord) + space_unit = 
data.attrs.get("space_unit", "pixels") + ax.set_xlabel(f"{x_coord} ({space_unit})") + ax.set_ylabel(f"{y_coord} ({space_unit})") return fig, ax, {"counts": counts, "xedges": xedges, "yedges": yedges} From e042f7c2b45b191c18a655cda80ee63020b3c560 Mon Sep 17 00:00:00 2001 From: willGraham01 Date: Tue, 4 Feb 2025 13:57:21 +0000 Subject: [PATCH 09/22] SonarQube recommendations --- tests/test_unit/test_plot.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/test_unit/test_plot.py b/tests/test_unit/test_plot.py index 1b72c3390..6b7918fa8 100644 --- a/tests/test_unit/test_plot.py +++ b/tests/test_unit/test_plot.py @@ -3,7 +3,7 @@ import pytest import xarray as xr from matplotlib.collections import QuadMesh -from numpy.random import RandomState +from numpy.random import Generator, default_rng from movement.plot import occupancy_histogram @@ -14,17 +14,17 @@ def seed() -> int: @pytest.fixture(scope="function") -def rng(seed: int) -> RandomState: +def rng(seed: int) -> Generator: """Create a RandomState to use in testing. This ensures the repeatability of histogram tests, that require large datasets that would be tedious to create manually. """ - return RandomState(seed) + return default_rng(seed) @pytest.fixture -def normal_dist_2d(rng: RandomState) -> np.ndarray: +def normal_dist_2d(rng: Generator) -> np.ndarray: """Points distributed by the standard multivariate normal. The standard multivariate normal is just two independent N(0, 1) @@ -59,7 +59,7 @@ def histogram_data(normal_dist_2d: np.ndarray) -> xr.DataArray: @pytest.fixture def histogram_data_with_nans( - histogram_data: xr.DataArray, rng: RandomState + histogram_data: xr.DataArray, rng: Generator ) -> xr.DataArray: """DataArray whose data is the ``normal_dist_2d`` points. 
@@ -255,9 +255,6 @@ def test_occupancy_histogram( pts_in_this_bin = (x_pts_in_range & y_pts_in_range).sum() reconstructed_bin_counts[i, j] = pts_in_this_bin - if pts_in_this_bin != plotted_values[i, j]: - pass - # We agree with a manual count assert reconstructed_bin_counts.sum() == plotted_values.sum() # All non-NaN values were plotted @@ -272,7 +269,7 @@ def test_occupancy_histogram( def test_respects_axes(histogram_data: xr.DataArray) -> None: """Check that existing axes objects are respected if passed.""" # Plotting on existing axes - existing_fig, existing_ax = plt.subplots(1, 2) + _, existing_ax = plt.subplots(1, 2) existing_ax[0].plot( np.linspace(0.0, 10.0, num=100), np.linspace(0.0, 10.0, num=100) @@ -288,6 +285,8 @@ def test_respects_axes(histogram_data: xr.DataArray) -> None: # Plot on new axis and create a new figure new_fig, new_ax, hist_info_new = occupancy_histogram(histogram_data) + assert isinstance(new_fig, plt.Figure) + assert isinstance(new_ax, plt.Axes) hist_plots_created = [ qm for qm in new_ax.get_children() if isinstance(qm, QuadMesh) ] From be0d22bab01a157e1337e6717c866bb6f6184f35 Mon Sep 17 00:00:00 2001 From: willGraham01 Date: Tue, 11 Feb 2025 10:33:39 +0000 Subject: [PATCH 10/22] Comply with new plot wrapper standards --- movement/plots/__init__.py | 3 +- movement/{plot.py => plots/occupancy.py} | 66 ++++++++++++------------ 2 files changed, 36 insertions(+), 33 deletions(-) rename movement/{plot.py => plots/occupancy.py} (72%) diff --git a/movement/plots/__init__.py b/movement/plots/__init__.py index bd8035dc7..ae3982c33 100644 --- a/movement/plots/__init__.py +++ b/movement/plots/__init__.py @@ -1,3 +1,4 @@ +from movement.plots.occupancy import plot_occupancy from movement.plots.trajectory import plot_trajectory -__all__ = ["plot_trajectory"] +__all__ = ["plot_occupancy", "plot_trajectory"] diff --git a/movement/plot.py b/movement/plots/occupancy.py similarity index 72% rename from movement/plot.py rename to 
movement/plots/occupancy.py index 5f94d4833..2526897c4 100644 --- a/movement/plot.py +++ b/movement/plots/occupancy.py @@ -1,5 +1,6 @@ -"""Wrappers to plot movement data.""" +"""Wrappers for plotting occupancy data of select individuals.""" +from collections.abc import Hashable from typing import Any, Literal, TypeAlias import matplotlib.pyplot as plt @@ -11,31 +12,33 @@ DEFAULT_HIST_ARGS = {"alpha": 1.0, "bins": 30, "cmap": "viridis"} -def occupancy_histogram( +def plot_occupancy( da: xr.DataArray, - keypoint: int | str = 0, - individual: int | str = 0, - title: str | None = None, + selection: dict[str, Hashable], ax: plt.Axes | None = None, **kwargs: Any, ) -> tuple[plt.Figure, plt.Axes, dict[HistInfoKeys, np.ndarray]]: """Create a 2D histogram of the occupancy data given. + By default, the 0-indexed value along non-"time" and non-"space" dimensions + is plotted. The ``selection`` variable can be used to select different + coordinates along additional dimensions to plot instead. + Time-points whose corresponding spatial coordinates have NaN values - are ignored. Histogram information is returned as the second output - value (see Notes). + are ignored. + + Histogram information is returned as the third output value (see Notes). Parameters ---------- da : xarray.DataArray Spatial data to create histogram for. NaN values are dropped. - keypoint : int | str - The keypoint to create a histogram for. - individual : int | str - The individual to create a histogram for. - title : str, optional - Title to give to the plot. Default will be generated from the - ``keypoint`` and ``individual`` + selection : dict[str, Hashable] + Mapping of additional dimension identifiers to the coordinate along + that dimension to plot. For example, + ``selection = {"individuals": "Bravo"}`` will create the occupancy + histogram for the individual "Bravo", instead of the occupancy + histogram for the 0-indexed entry on the ``"individuals"`` dimension. 
ax : matplotlib.axes.Axes, optional Axes object on which to draw the histogram. If not provided, a new figure and axes are created and returned. @@ -81,20 +84,24 @@ def occupancy_histogram( matplotlib.pyplot.Axes.hist2d : The underlying plotting function. """ - data = da.position if isinstance(da, xr.Dataset) else da - title_components = [] - # Remove additional dimensions before dropping NaN values - if "individuals" in data.dims: - if individual not in data["individuals"]: - individual = data["individuals"].values[individual] - data = data.sel(individuals=individual).squeeze() - title_components.append(f"individual {individual}") - if "keypoints" in data.dims: - if keypoint not in data["keypoints"]: - keypoint = data["keypoints"].values[keypoint] - data = data.sel(keypoints=keypoint).squeeze() - title_components.append(f"keypoint {keypoint}") + non_spacetime_dims = [ + dim for dim in da.dims if dim not in ("time", "space") + ] + selection = { + dim: selection.get(dim, da[dim].values[0]) + for dim in non_spacetime_dims + } + data = da.sel(**selection).squeeze() + # Selections must be scalar, resulting in 2D data. + # Catch this now + if data.ndim != 2: + raise ValueError( + "Histogram data was not time-space only. " + "Did you accidentally pass multiple coordinates for any of " + f"the following dimensions: {non_spacetime_dims}" + ) + # We should now have just the relevant time-space data, # so we can remove time-points with NaN values. 
data = data.dropna(dim="time", how="any") @@ -117,11 +124,6 @@ def occupancy_histogram( colourbar = fig.colorbar(hist_image, ax=ax) colourbar.solids.set(alpha=1.0) - # Axis labels and title - if not title and title_components: - title = "Occupancy of " + ", ".join(title_components) - if title: - ax.set_title(title) space_unit = data.attrs.get("space_unit", "pixels") ax.set_xlabel(f"{x_coord} ({space_unit})") ax.set_ylabel(f"{y_coord} ({space_unit})") From 1614ab4138cf9bf8a02600579f03f87e031b6daf Mon Sep 17 00:00:00 2001 From: willGraham01 Date: Tue, 11 Feb 2025 10:51:28 +0000 Subject: [PATCH 11/22] Add test for default selection case --- movement/plots/occupancy.py | 13 ++-- tests/test_unit/{ => test_plots}/test_plot.py | 64 +++++++++++-------- 2 files changed, 44 insertions(+), 33 deletions(-) rename tests/test_unit/{ => test_plots}/test_plot.py (85%) diff --git a/movement/plots/occupancy.py b/movement/plots/occupancy.py index 2526897c4..329d98f15 100644 --- a/movement/plots/occupancy.py +++ b/movement/plots/occupancy.py @@ -14,7 +14,7 @@ def plot_occupancy( da: xr.DataArray, - selection: dict[str, Hashable], + selection: dict[str, Hashable] | None = None, ax: plt.Axes | None = None, **kwargs: Any, ) -> tuple[plt.Figure, plt.Axes, dict[HistInfoKeys, np.ndarray]]: @@ -33,9 +33,9 @@ def plot_occupancy( ---------- da : xarray.DataArray Spatial data to create histogram for. NaN values are dropped. - selection : dict[str, Hashable] - Mapping of additional dimension identifiers to the coordinate along - that dimension to plot. For example, + selection : dict[str, Hashable], optional + Mapping of dimension identifiers to the coordinate along that dimension + to plot. "time" and "space" dimensions are ignored. For example, ``selection = {"individuals": "Bravo"}`` will create the occupancy histogram for the individual "Bravo", instead of the occupancy histogram for the 0-indexed entry on the ``"individuals"`` dimension. 
@@ -84,6 +84,9 @@ def plot_occupancy( matplotlib.pyplot.Axes.hist2d : The underlying plotting function. """ + if selection is None: + selection = dict() + # Remove additional dimensions before dropping NaN values non_spacetime_dims = [ dim for dim in da.dims if dim not in ("time", "space") @@ -92,7 +95,7 @@ def plot_occupancy( dim: selection.get(dim, da[dim].values[0]) for dim in non_spacetime_dims } - data = da.sel(**selection).squeeze() + data: xr.DataArray = da.sel(**selection).squeeze() # Selections must be scalar, resulting in 2D data. # Catch this now if data.ndim != 2: diff --git a/tests/test_unit/test_plot.py b/tests/test_unit/test_plots/test_plot.py similarity index 85% rename from tests/test_unit/test_plot.py rename to tests/test_unit/test_plots/test_plot.py index 6b7918fa8..4ed7d38cf 100644 --- a/tests/test_unit/test_plot.py +++ b/tests/test_unit/test_plots/test_plot.py @@ -1,3 +1,5 @@ +from collections.abc import Hashable + import matplotlib.pyplot as plt import numpy as np import pytest @@ -5,7 +7,7 @@ from matplotlib.collections import QuadMesh from numpy.random import Generator, default_rng -from movement.plot import occupancy_histogram +from movement.plots import plot_occupancy @pytest.fixture @@ -101,74 +103,72 @@ def entirely_nan_data(histogram_data: xr.DataArray) -> xr.DataArray: [ "data", "remove_dims_from_data_before_starting", - "individual", - "keypoint", + "selection", "n_bins", ], [ pytest.param( "histogram_data", [], - "i0", - "k0", + {"individuals": "i0", "keypoints": "k0"}, 30, id="30 bins each axis", ), pytest.param( "histogram_data", [], - "i1", - "k0", + None, + 30, + id="Default 0-index", + ), + pytest.param( + "histogram_data", + [], + {"individuals": "i1", "keypoints": "k0"}, (20, 30), id="(20, 30) bins", ), pytest.param( "histogram_data_with_nans", [], - "i0", - "k0", + {"individuals": "i0", "keypoints": "k0"}, 30, id="NaNs should be removed", ), pytest.param( "entirely_nan_data", [], - "i0", - "k0", + {"individuals": "i0", 
"keypoints": "k0"}, 10, id="All NaN-data", ), pytest.param( "histogram_data", ["individuals"], - "i0", - "k0", + {"individuals": "i0", "keypoints": "k0"}, 30, id="Ignores individual if not a dimension", ), pytest.param( "histogram_data", ["keypoints"], - "i0", - "k1", + {"individuals": "i0", "keypoints": "k1"}, 30, id="Ignores keypoint if not a dimension", ), pytest.param( "histogram_data", ["individuals", "keypoints"], - "i0", - "k0", + {"individuals": "i0", "keypoints": "k0"}, 30, id="Can handle raw xy data", ), ], ) -def test_occupancy_histogram( +def test_plot_histogram( # noqa: C901 data: xr.DataArray, remove_dims_from_data_before_starting: list[str], - individual: int | str, - keypoint: int | str, + selection: dict[str, Hashable], n_bins: int | tuple[int, int], request, ) -> None: @@ -185,10 +185,18 @@ def test_occupancy_histogram( # We will need to only select the xy data later in the test, # but if we are dropping dimensions we might need to call it # in different ways. - kwargs_to_select_xy_data = { - "individuals": individual, - "keypoints": keypoint, - } + kwargs_to_select_xy_data = dict() + # By default we should fetch 0-indexes + for d in [dim for dim in data.dims if dim not in ("time", "space")]: + kwargs_to_select_xy_data[d] = data[d].values[0] + # Custom selections should then take priority + if selection: + for d, d_name in selection.items(): + if d in kwargs_to_select_xy_data: + kwargs_to_select_xy_data[d] = d_name + + # Remove any dimensions that we are purposefully axing from the data, + # before attempting to create the histogram for d in remove_dims_from_data_before_starting: # Retain the 0th value in the corresponding dimension, # then drop that dimension. @@ -199,8 +207,8 @@ def test_occupancy_histogram( # when examining the xy data later in the test. 
kwargs_to_select_xy_data.pop(d, None) - _, _, histogram_info = occupancy_histogram( - data, individual=individual, keypoint=keypoint, bins=n_bins + _, _, histogram_info = plot_occupancy( + data, selection=selection, bins=n_bins ) plotted_values = histogram_info["counts"] @@ -275,7 +283,7 @@ def test_respects_axes(histogram_data: xr.DataArray) -> None: np.linspace(0.0, 10.0, num=100), np.linspace(0.0, 10.0, num=100) ) - _, _, hist_info_existing = occupancy_histogram( + _, _, hist_info_existing = plot_occupancy( histogram_data, ax=existing_ax[1] ) hist_plots_added = [ @@ -284,7 +292,7 @@ def test_respects_axes(histogram_data: xr.DataArray) -> None: assert len(hist_plots_added) == 1 # Plot on new axis and create a new figure - new_fig, new_ax, hist_info_new = occupancy_histogram(histogram_data) + new_fig, new_ax, hist_info_new = plot_occupancy(histogram_data) assert isinstance(new_fig, plt.Figure) assert isinstance(new_ax, plt.Axes) hist_plots_created = [ From 76a197322368e7c04d18c1b24a5ddb902bfa3fa7 Mon Sep 17 00:00:00 2001 From: willGraham01 Date: Tue, 11 Feb 2025 10:58:24 +0000 Subject: [PATCH 12/22] Add check for incorrect dims after squeezing --- movement/plots/occupancy.py | 2 +- tests/test_unit/test_plots/test_plot.py | 34 +++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/movement/plots/occupancy.py b/movement/plots/occupancy.py index 329d98f15..af03cfa5f 100644 --- a/movement/plots/occupancy.py +++ b/movement/plots/occupancy.py @@ -99,7 +99,7 @@ def plot_occupancy( # Selections must be scalar, resulting in 2D data. # Catch this now if data.ndim != 2: - raise ValueError( + raise IndexError( "Histogram data was not time-space only. 
" "Did you accidentally pass multiple coordinates for any of " f"the following dimensions: {non_spacetime_dims}" diff --git a/tests/test_unit/test_plots/test_plot.py b/tests/test_unit/test_plots/test_plot.py index 4ed7d38cf..5b669d3ef 100644 --- a/tests/test_unit/test_plots/test_plot.py +++ b/tests/test_unit/test_plots/test_plot.py @@ -1,4 +1,6 @@ +import re from collections.abc import Hashable +from typing import Any import matplotlib.pyplot as plt import numpy as np @@ -121,6 +123,13 @@ def entirely_nan_data(histogram_data: xr.DataArray) -> xr.DataArray: 30, id="Default 0-index", ), + pytest.param( + "histogram_data", + [], + {"elephants": "Nellie"}, + 30, + id="Ignores nonsensical dimensions", + ), pytest.param( "histogram_data", [], @@ -274,6 +283,31 @@ def test_plot_histogram( # noqa: C901 assert plotted_values.sum() == 0 +@pytest.mark.parametrize( + ["other_args_to_fn", "expected_error"], + [ + pytest.param( + {"selection": {"keypoints": ["k0", "k1"]}}, + IndexError( + "Histogram data was not time-space only. 
" + "Did you accidentally pass multiple coordinates for any of " + "the following dimensions: ['individuals', 'keypoints']" + ), + id="Multiple selection along non-spacetime dimension.", + ) + ], +) +def test_plot_histogram_error_cases( + histogram_data: xr.DataArray, + other_args_to_fn: dict[str, Any], + expected_error: Exception, +) -> None: + with pytest.raises( + type(expected_error), match=re.escape(str(expected_error)) + ): + plot_occupancy(histogram_data, **other_args_to_fn) + + def test_respects_axes(histogram_data: xr.DataArray) -> None: """Check that existing axes objects are respected if passed.""" # Plotting on existing axes From e3e690ddc576bbf7c69c12aa555ec853ba3ecc18 Mon Sep 17 00:00:00 2001 From: willGraham01 Date: Fri, 14 Feb 2025 14:38:40 +0000 Subject: [PATCH 13/22] Remove tests to start afresh --- movement/plots/occupancy.py | 85 ++++---- tests/test_unit/test_plots/test_plot.py | 253 +----------------------- 2 files changed, 43 insertions(+), 295 deletions(-) diff --git a/movement/plots/occupancy.py b/movement/plots/occupancy.py index af03cfa5f..b2080cde8 100644 --- a/movement/plots/occupancy.py +++ b/movement/plots/occupancy.py @@ -14,17 +14,21 @@ def plot_occupancy( da: xr.DataArray, - selection: dict[str, Hashable] | None = None, + individuals: Hashable | None = None, + keypoints: Hashable | list[Hashable] | None = None, ax: plt.Axes | None = None, **kwargs: Any, ) -> tuple[plt.Figure, plt.Axes, dict[HistInfoKeys, np.ndarray]]: """Create a 2D histogram of the occupancy data given. - By default, the 0-indexed value along non-"time" and non-"space" dimensions - is plotted. The ``selection`` variable can be used to select different - coordinates along additional dimensions to plot instead. + By default; - Time-points whose corresponding spatial coordinates have NaN values + - If there are multiple keypoints selected, the occupancy of the centroid + of these keypoints is computed. 
+ - If there are multiple individuals selected, the occupancies of their + centroids are aggregated. + + Points whose corresponding spatial coordinates have NaN values are ignored. Histogram information is returned as the third output value (see Notes). @@ -33,12 +37,12 @@ def plot_occupancy( ---------- da : xarray.DataArray Spatial data to create histogram for. NaN values are dropped. - selection : dict[str, Hashable], optional - Mapping of dimension identifiers to the coordinate along that dimension - to plot. "time" and "space" dimensions are ignored. For example, - ``selection = {"individuals": "Bravo"}`` will create the occupancy - histogram for the individual "Bravo", instead of the occupancy - histogram for the 0-indexed entry on the ``"individuals"`` dimension. + individuals : Hashable, optional + The name of the individual(s) to be aggregated and plotted. By default, + all individuals are aggregated. + keypoints : Hashable | list[Hashable], optional + Name of a keypoint or list of such names. The centroid of all provided + keypoints is computed, then plotted in the histogram. ax : matplotlib.axes.Axes, optional Axes object on which to draw the histogram. If not provided, a new figure and axes are created and returned. @@ -64,50 +68,45 @@ def plot_occupancy( dictionary containing the bin edges and bin counts that were used to create the histogram. - For data with ``N`` time-points, the dictionary output has key-value pairs; - - ``xedges``, an ``(N+1,)`` ``numpy`` array specifying the bin edges in the - first spatial dimension. - - ``yedges``, same as ``xedges`` but for the second spatial dimension. - - ``counts``, an ``(N, N)`` ``numpy`` array with the count for each bin. + For data with ``Nx`` bins in the 1st spatial dimension, and ``Ny`` bins in + the 2nd spatial dimension, the dictionary output has key-value pairs; + - ``xedges``, an ``(Nx+1,)`` ``numpy`` array specifying the bin edges in + the 1st spatial dimension. 
+ - ``yedges``, an ``(Ny+1,)`` ``numpy`` array specifying the bin edges in + the 2nd spatial dimension. + - ``counts``, an ``(Nx, Ny)`` ``numpy`` array with the count for each bin. ``counts[x, y]`` is the number of datapoints in the ``(xedges[x], xedges[x+1]), (yedges[y], yedges[y+1])`` bin. These values are those returned from ``matplotlib.pyplot.Axes.hist2d``. - Note that the ``counts`` values do not necessarily match the mappable - values that one gets from extracting the data from the - ``matplotlib.collections.QuadMesh`` object (that represents the rendered - histogram) via its ``get_array()`` attribute. - See Also -------- matplotlib.pyplot.Axes.hist2d : The underlying plotting function. """ - if selection is None: - selection = dict() - - # Remove additional dimensions before dropping NaN values - non_spacetime_dims = [ - dim for dim in da.dims if dim not in ("time", "space") - ] - selection = { - dim: selection.get(dim, da[dim].values[0]) - for dim in non_spacetime_dims - } - data: xr.DataArray = da.sel(**selection).squeeze() - # Selections must be scalar, resulting in 2D data. - # Catch this now - if data.ndim != 2: - raise IndexError( - "Histogram data was not time-space only. " - "Did you accidentally pass multiple coordinates for any of " - f"the following dimensions: {non_spacetime_dims}" - ) - + # Collapse dimensions if necessary + data = da.copy(deep=True) + if "keypoints" in da.dims: + if keypoints is not None: + data = data.sel(keypoints=keypoints) + data = data.mean(dim="keypoints") + if "individuals" in da.dims and individuals is not None: + data = data.sel(individuals=individuals) + + # We need to remove NaN values from each individual, but we can't do this + # right now because we still potentially have a (time, space, individuals) + # array and so dropping NaNs along any axis may remove valid points for + # other times / individuals. 
+ # Since we only care about a count, we can just unravel the individuals + # dimension and create a "long" array of points. For example, a (10, 2, 5) + # time-space-individuals DataArray becomes (50, 2). + if "individuals" in data.dims: + data.stack({"space": ("space", "individuals")}, create_index=False) # We should now have just the relevant time-space data, # so we can remove time-points with NaN values. data = data.dropna(dim="time", how="any") + # This makes us agnostic to the planar coordinate system. x_coord = data["space"].values[0] y_coord = data["space"].values[1] @@ -122,7 +121,7 @@ def plot_occupancy( else: fig, ax = plt.subplots() counts, xedges, yedges, hist_image = ax.hist2d( - data.sel(space=x_coord), data.sel(space=y_coord), **kwargs + data.sel(space=x_coord).stack, data.sel(space=y_coord), **kwargs ) colourbar = fig.colorbar(hist_image, ax=ax) colourbar.solids.set(alpha=1.0) diff --git a/tests/test_unit/test_plots/test_plot.py b/tests/test_unit/test_plots/test_plot.py index 5b669d3ef..e55dcdaeb 100644 --- a/tests/test_unit/test_plots/test_plot.py +++ b/tests/test_unit/test_plots/test_plot.py @@ -1,16 +1,8 @@ -import re -from collections.abc import Hashable -from typing import Any - -import matplotlib.pyplot as plt import numpy as np import pytest import xarray as xr -from matplotlib.collections import QuadMesh from numpy.random import Generator, default_rng -from movement.plots import plot_occupancy - @pytest.fixture def seed() -> int: @@ -62,9 +54,7 @@ def histogram_data(normal_dist_2d: np.ndarray) -> xr.DataArray: @pytest.fixture -def histogram_data_with_nans( - histogram_data: xr.DataArray, rng: Generator -) -> xr.DataArray: +def histogram_data_with_nans(histogram_data: xr.DataArray) -> xr.DataArray: """DataArray whose data is the ``normal_dist_2d`` points. Axes 2 and 3 are the individuals and keypoints axes, respectively. 
@@ -99,244 +89,3 @@ def entirely_nan_data(histogram_data: xr.DataArray) -> xr.DataArray: return histogram_data.copy( deep=True, data=histogram_data.values * float("nan") ) - - -@pytest.mark.parametrize( - [ - "data", - "remove_dims_from_data_before_starting", - "selection", - "n_bins", - ], - [ - pytest.param( - "histogram_data", - [], - {"individuals": "i0", "keypoints": "k0"}, - 30, - id="30 bins each axis", - ), - pytest.param( - "histogram_data", - [], - None, - 30, - id="Default 0-index", - ), - pytest.param( - "histogram_data", - [], - {"elephants": "Nellie"}, - 30, - id="Ignores nonsensical dimensions", - ), - pytest.param( - "histogram_data", - [], - {"individuals": "i1", "keypoints": "k0"}, - (20, 30), - id="(20, 30) bins", - ), - pytest.param( - "histogram_data_with_nans", - [], - {"individuals": "i0", "keypoints": "k0"}, - 30, - id="NaNs should be removed", - ), - pytest.param( - "entirely_nan_data", - [], - {"individuals": "i0", "keypoints": "k0"}, - 10, - id="All NaN-data", - ), - pytest.param( - "histogram_data", - ["individuals"], - {"individuals": "i0", "keypoints": "k0"}, - 30, - id="Ignores individual if not a dimension", - ), - pytest.param( - "histogram_data", - ["keypoints"], - {"individuals": "i0", "keypoints": "k1"}, - 30, - id="Ignores keypoint if not a dimension", - ), - pytest.param( - "histogram_data", - ["individuals", "keypoints"], - {"individuals": "i0", "keypoints": "k0"}, - 30, - id="Can handle raw xy data", - ), - ], -) -def test_plot_histogram( # noqa: C901 - data: xr.DataArray, - remove_dims_from_data_before_starting: list[str], - selection: dict[str, Hashable], - n_bins: int | tuple[int, int], - request, -) -> None: - """Test that occupancy histograms correctly plot data. - - Specifically, check that: - - The bin edges are what we expect. - - The bin counts can be manually verified and are in agreement. - - Only non-NaN values are plotted, but NaN values do not throw errors. 
- """ - if isinstance(data, str): - data = request.getfixturevalue(data) - - # We will need to only select the xy data later in the test, - # but if we are dropping dimensions we might need to call it - # in different ways. - kwargs_to_select_xy_data = dict() - # By default we should fetch 0-indexes - for d in [dim for dim in data.dims if dim not in ("time", "space")]: - kwargs_to_select_xy_data[d] = data[d].values[0] - # Custom selections should then take priority - if selection: - for d, d_name in selection.items(): - if d in kwargs_to_select_xy_data: - kwargs_to_select_xy_data[d] = d_name - - # Remove any dimensions that we are purposefully axing from the data, - # before attempting to create the histogram - for d in remove_dims_from_data_before_starting: - # Retain the 0th value in the corresponding dimension, - # then drop that dimension. - data = data.sel({d: getattr(data, d)[0]}).squeeze() - assert d not in data.dims - - # We no longer need to filter this dimension out - # when examining the xy data later in the test. 
- kwargs_to_select_xy_data.pop(d, None) - - _, _, histogram_info = plot_occupancy( - data, selection=selection, bins=n_bins - ) - plotted_values = histogram_info["counts"] - - # Confirm the binned array has the correct size - if not isinstance(n_bins, tuple): - n_bins = (n_bins, n_bins) - assert plotted_values.shape == n_bins - - # Confirm that each bin has the correct number of assignments - data_time_xy = data.sel(**kwargs_to_select_xy_data) - data_time_xy = data_time_xy.dropna(dim="time", how="any") - plotted_x_values = data_time_xy.sel(space="x").values - plotted_y_values = data_time_xy.sel(space="y").values - assert plotted_x_values.shape == plotted_y_values.shape - # This many non-NaN values were plotted - n_non_nan_values = plotted_x_values.shape[0] - - if n_non_nan_values > 0: - reconstructed_bins_limits_x = np.linspace( - plotted_x_values.min(), - plotted_x_values.max(), - num=n_bins[0] + 1, - endpoint=True, - ) - assert np.allclose( - reconstructed_bins_limits_x, histogram_info["xedges"] - ) - reconstructed_bins_limits_y = np.linspace( - plotted_y_values.min(), - plotted_y_values.max(), - num=n_bins[1] + 1, - endpoint=True, - ) - assert np.allclose( - reconstructed_bins_limits_y, histogram_info["yedges"] - ) - - reconstructed_bin_counts = np.zeros(shape=n_bins, dtype=float) - for i, xi in enumerate(reconstructed_bins_limits_x[:-1]): - xi_p1 = reconstructed_bins_limits_x[i + 1] - - x_pts_in_range = (plotted_x_values >= xi) & ( - plotted_x_values <= xi_p1 - ) - for j, yj in enumerate(reconstructed_bins_limits_y[:-1]): - yj_p1 = reconstructed_bins_limits_y[j + 1] - - y_pts_in_range = (plotted_y_values >= yj) & ( - plotted_y_values <= yj_p1 - ) - - pts_in_this_bin = (x_pts_in_range & y_pts_in_range).sum() - reconstructed_bin_counts[i, j] = pts_in_this_bin - - # We agree with a manual count - assert reconstructed_bin_counts.sum() == plotted_values.sum() - # All non-NaN values were plotted - assert n_non_nan_values == plotted_values.sum() - # The counts were 
actually correct - assert np.all(reconstructed_bin_counts == plotted_values) - else: - # No non-nan values were given - assert plotted_values.sum() == 0 - - -@pytest.mark.parametrize( - ["other_args_to_fn", "expected_error"], - [ - pytest.param( - {"selection": {"keypoints": ["k0", "k1"]}}, - IndexError( - "Histogram data was not time-space only. " - "Did you accidentally pass multiple coordinates for any of " - "the following dimensions: ['individuals', 'keypoints']" - ), - id="Multiple selection along non-spacetime dimension.", - ) - ], -) -def test_plot_histogram_error_cases( - histogram_data: xr.DataArray, - other_args_to_fn: dict[str, Any], - expected_error: Exception, -) -> None: - with pytest.raises( - type(expected_error), match=re.escape(str(expected_error)) - ): - plot_occupancy(histogram_data, **other_args_to_fn) - - -def test_respects_axes(histogram_data: xr.DataArray) -> None: - """Check that existing axes objects are respected if passed.""" - # Plotting on existing axes - _, existing_ax = plt.subplots(1, 2) - - existing_ax[0].plot( - np.linspace(0.0, 10.0, num=100), np.linspace(0.0, 10.0, num=100) - ) - - _, _, hist_info_existing = plot_occupancy( - histogram_data, ax=existing_ax[1] - ) - hist_plots_added = [ - qm for qm in existing_ax[1].get_children() if isinstance(qm, QuadMesh) - ] - assert len(hist_plots_added) == 1 - - # Plot on new axis and create a new figure - new_fig, new_ax, hist_info_new = plot_occupancy(histogram_data) - assert isinstance(new_fig, plt.Figure) - assert isinstance(new_ax, plt.Axes) - hist_plots_created = [ - qm for qm in new_ax.get_children() if isinstance(qm, QuadMesh) - ] - assert len(hist_plots_created) == 1 - - # Check that the same plot was made for each - assert set(hist_info_new.keys()) == set(hist_info_existing.keys()) - for key, new_ax_value in hist_info_new.items(): - existing_ax_value = hist_info_existing[key] - - assert np.allclose(new_ax_value, existing_ax_value) From 7dd0ebeeaedd67ba5e1ce8c018a40dd163575380 Mon 
Sep 17 00:00:00 2001 From: willGraham01 Date: Fri, 14 Feb 2025 14:49:19 +0000 Subject: [PATCH 14/22] Move trajectory tests into plots/ testing folder --- .../{test_plot_trajectory.py => test_plots/test_trajectory.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/test_unit/{test_plot_trajectory.py => test_plots/test_trajectory.py} (100%) diff --git a/tests/test_unit/test_plot_trajectory.py b/tests/test_unit/test_plots/test_trajectory.py similarity index 100% rename from tests/test_unit/test_plot_trajectory.py rename to tests/test_unit/test_plots/test_trajectory.py From 41055b6cb99abe510decfe484bfabf47f0d65fb3 Mon Sep 17 00:00:00 2001 From: willGraham01 Date: Mon, 17 Feb 2025 11:02:35 +0000 Subject: [PATCH 15/22] Write tests for plot_occupancy --- movement/plots/occupancy.py | 17 +- tests/test_unit/test_plots/test_occupancy.py | 253 +++++++++++++++++++ tests/test_unit/test_plots/test_plot.py | 91 ------- 3 files changed, 264 insertions(+), 97 deletions(-) create mode 100644 tests/test_unit/test_plots/test_occupancy.py delete mode 100644 tests/test_unit/test_plots/test_plot.py diff --git a/movement/plots/occupancy.py b/movement/plots/occupancy.py index b2080cde8..ab9fbc10e 100644 --- a/movement/plots/occupancy.py +++ b/movement/plots/occupancy.py @@ -1,6 +1,6 @@ """Wrappers for plotting occupancy data of select individuals.""" -from collections.abc import Hashable +from collections.abc import Hashable, Sequence from typing import Any, Literal, TypeAlias import matplotlib.pyplot as plt @@ -14,8 +14,8 @@ def plot_occupancy( da: xr.DataArray, - individuals: Hashable | None = None, - keypoints: Hashable | list[Hashable] | None = None, + individuals: Hashable | Sequence[Hashable] | None = None, + keypoints: Hashable | Sequence[Hashable] | None = None, ax: plt.Axes | None = None, **kwargs: Any, ) -> tuple[plt.Figure, plt.Axes, dict[HistInfoKeys, np.ndarray]]: @@ -90,7 +90,10 @@ def plot_occupancy( if "keypoints" in da.dims: if keypoints is not None: data 
= data.sel(keypoints=keypoints) - data = data.mean(dim="keypoints") + # A selection of just one keypoint automatically drops the keypoints + # dimension, hence the need to re-check this here + if "keypoints" in data.dims: + data = data.mean(dim="keypoints", skipna=True) if "individuals" in da.dims and individuals is not None: data = data.sel(individuals=individuals) @@ -102,7 +105,9 @@ def plot_occupancy( # dimension and create a "long" array of points. For example, a (10, 2, 5) # time-space-individuals DataArray becomes (50, 2). if "individuals" in data.dims: - data.stack({"space": ("space", "individuals")}, create_index=False) + data = data.stack( + {"new": ("time", "individuals")}, create_index=False + ).rename({"new": "time"}) # We should now have just the relevant time-space data, # so we can remove time-points with NaN values. data = data.dropna(dim="time", how="any") @@ -121,7 +126,7 @@ def plot_occupancy( else: fig, ax = plt.subplots() counts, xedges, yedges, hist_image = ax.hist2d( - data.sel(space=x_coord).stack, data.sel(space=y_coord), **kwargs + data.sel(space=x_coord), data.sel(space=y_coord), **kwargs ) colourbar = fig.colorbar(hist_image, ax=ax) colourbar.solids.set(alpha=1.0) diff --git a/tests/test_unit/test_plots/test_occupancy.py b/tests/test_unit/test_plots/test_occupancy.py new file mode 100644 index 000000000..deb804d99 --- /dev/null +++ b/tests/test_unit/test_plots/test_occupancy.py @@ -0,0 +1,253 @@ +from collections.abc import Hashable, Sequence +from typing import Any + +import numpy as np +import pytest +import xarray as xr +from numpy.typing import ArrayLike + +from movement.plots import plot_occupancy + + +def antidiagonal_matrix(diag_values: ArrayLike) -> np.ndarray: + """Create an antidiagonal matrix. + + An antidiagonal matrix has the ``diag_values`` along the reverse (TR to BL) + diagonal, with ``diag_values[0]`` appearing in the top-left position. + + Antidiagonal matrices are square. 
+ """ + return np.fliplr(np.diag(diag_values)) + + +@pytest.fixture() +def occupancy_data() -> xr.DataArray: + """DataArray of 3 keypoints and 4 individuals. + + Individuals 0 through 2 (inclusive) are identical. + Individual 3 is a translation by (1,0) of the other individuals. + + The keypoints are left, right, centre. + Right = left + (1., 1.) + Centre = mean(left, right) + + The extent of the data is [0,6] x [0,5]. Using bins=list(range(7)) or + list(range(6)) will force unit-spaced bins. + """ + time_space = np.array( + [[0.0, 4.0], [1.0, 3.0], [2.0, 2.0], [3.0, 1.0], [4.0, 0.0]] + ) + + time_space_keypoints = np.repeat( + time_space[:, :, np.newaxis], repeats=3, axis=2 + ) + # right = left + (1., 1.) + time_space_keypoints[:, :, 1] += (1.0, 1.0) + # centre = mean(left, right) + time_space_keypoints[:, :, 2] = np.mean( + time_space_keypoints[:, :, :2], axis=2 + ) + + # individuals 0-2 (inclusive) are copies + data_vals = np.repeat( + time_space_keypoints[:, :, :, np.newaxis], repeats=4, axis=3 + ) + # individual 3 is (1., 0) offset from the others + for keypoint_index in range(data_vals.shape[2]): + data_vals[:, :, keypoint_index, 3] += (1.0, 0.0) + return xr.DataArray( + data=data_vals, + dims=["time", "space", "keypoints", "individuals"], + coords={ + "space": ["x", "y"], + "keypoints": ["left", "right", "centre"], + "individuals": [0, 1, 2, 3], + }, + ) + + +@pytest.fixture +def occupancy_data_with_nans(occupancy_data: xr.DataArray) -> xr.DataArray: + """Occupancy data with deliberate NaN values. + + The occupancy_data fixture is modified so that: + + - Individual 0 has an NaN value at its left keypoint, "x" coord, 0th index. + - Individual 1 has an NaN coordinate at its centre keypoint, 0th index. + - Individual 2 is entirely NaN values down its right keypoint.
+ """ + occupancy_data_nans = occupancy_data.copy(deep=True) + + occupancy_data_nans.loc[0, "x", "left", 0] = float("nan") + occupancy_data_nans.loc[0, :, "centre", 1] = float("nan") + occupancy_data_nans.loc[:, :, "right", 2] = float("nan") + + return occupancy_data_nans + + +@pytest.mark.parametrize( + [ + "data", + "kwargs_to_pass", + "expected_output", + "select_before_passing_to_plot", + ], + [ + pytest.param( + "occupancy_data", + {"individuals": 0, "bins": [list(range(6)), list(range(6))]}, + antidiagonal_matrix([1] * 5), + {}, + id="Keypoints: default centroid", + ), + pytest.param( + "occupancy_data", + { + "keypoints": ["left", "right"], + "individuals": 0, + "bins": [list(range(6)), list(range(6))], + }, + antidiagonal_matrix([1] * 5), + {}, + id="Keypoints: selection centroid", + ), + pytest.param( + "occupancy_data", + { + "individuals": [0, 1, 2], + "bins": [list(range(6)), list(range(6))], + # data will have no keypoints dimension, + # so the argument below should be ignored + "keypoints": ["left", "right"], + }, + 3 * antidiagonal_matrix([1] * 5), + {"keypoints": "centre"}, + id="Keypoints: Handle not a dimension", + ), + pytest.param( + "occupancy_data", + { + "keypoints": "centre", + "bins": [list(range(7)), list(range(6))], + }, + 3 + * np.array( + [ + [0, 0, 0, 0, 1], + [0, 0, 0, 1, 0], + [0, 0, 1, 0, 0], + [0, 1, 0, 0, 0], + [1, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + ] + ) + + np.array( + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 1], + [0, 0, 0, 1, 0], + [0, 0, 1, 0, 0], + [0, 1, 0, 0, 0], + [1, 0, 0, 0, 0], + ] + ), + {}, + id="Individuals: default aggregate", + ), + pytest.param( + "occupancy_data", + { + "individuals": [0, 1, 2], + "bins": [list(range(6)), list(range(6))], + }, + 3 * antidiagonal_matrix([1] * 5), + {}, + id="Individuals: selection aggregate", + ), + pytest.param( + "occupancy_data", + { + "keypoints": ["left", "right"], + "bins": [list(range(6)), list(range(6))], + # data will have no individuals dimension, + # so the argument below 
should be ignored + "individuals": [0, 2], + }, + antidiagonal_matrix([1] * 5), + {"individuals": 0}, + id="Individuals: Handle not a dimension", + ), + pytest.param( + "occupancy_data", + { + "keypoints": ["left", "right"], + "individuals": [0, 2], + "bins": [list(range(6)), list(range(6))], + }, + 2 * antidiagonal_matrix([1] * 5), + {}, + id="Sub-selection: mean THEN aggregate", + ), + pytest.param( + "occupancy_data_with_nans", + { + "keypoints": "centre", + "individuals": 1, + "bins": [list(range(6)), list(range(6))], + }, + antidiagonal_matrix([0] + ([1] * 4)), + {}, + id="NaNs: coord does not contribute", + ), + pytest.param( + "occupancy_data_with_nans", + { + "keypoints": ["left", "right"], + "individuals": 1, + "bins": [list(range(6)), list(range(6))], + }, + antidiagonal_matrix([1] * 5), + {}, + id="NaNs: average of valid keypoints still works", + ), + pytest.param( + "occupancy_data_with_nans", + { + "keypoints": "right", + "individuals": 2, + "bins": [list(range(6)), list(range(6))], + }, + np.zeros((5, 5)), + {}, + id="NaNs: no valid points", + ), + pytest.param( + "occupancy_data_with_nans", + { + "individuals": [0, 1, 2], + "bins": [list(range(6)), list(range(6))], + }, + 3 * antidiagonal_matrix([1] * 5), + {}, + id="NaNs: aggregate can ignore NaNs", + ), + ], +) +def test_keypoints_and_individuals_behaviour( + data: str | xr.DataArray, + kwargs_to_pass: dict[str, Any], + expected_output: np.ndarray, + select_before_passing_to_plot: dict[Hashable, Sequence[Hashable]], + request, +) -> None: + if isinstance(data, str): + data = request.getfixturevalue(data) + # Remove dimensions from data, if we want to test how the function + # handles data without certain dimension labels but which can still be + # plotted + if select_before_passing_to_plot: + data = data.sel(select_before_passing_to_plot) + + _, _, hist_info = plot_occupancy(data, **kwargs_to_pass) + + assert np.allclose(expected_output, hist_info["counts"]) diff --git 
a/tests/test_unit/test_plots/test_plot.py b/tests/test_unit/test_plots/test_plot.py deleted file mode 100644 index e55dcdaeb..000000000 --- a/tests/test_unit/test_plots/test_plot.py +++ /dev/null @@ -1,91 +0,0 @@ -import numpy as np -import pytest -import xarray as xr -from numpy.random import Generator, default_rng - - -@pytest.fixture -def seed() -> int: - return 0 - - -@pytest.fixture(scope="function") -def rng(seed: int) -> Generator: - """Create a RandomState to use in testing. - - This ensures the repeatability of histogram tests, that require large - datasets that would be tedious to create manually. - """ - return default_rng(seed) - - -@pytest.fixture -def normal_dist_2d(rng: Generator) -> np.ndarray: - """Points distributed by the standard multivariate normal. - - The standard multivariate normal is just two independent N(0, 1) - distributions, one in each dimension. - """ - samples = rng.multivariate_normal( - (0.0, 0.0), [[1.0, 0.0], [0.0, 1.0]], (250, 3, 4) - ) - return np.moveaxis( - samples, 3, 1 - ) # Move generated space coords to correct axis position - - -@pytest.fixture -def histogram_data(normal_dist_2d: np.ndarray) -> xr.DataArray: - """DataArray whose data is the ``normal_dist_2d`` points. - - Axes 2 and 3 are the individuals and keypoints axes, respectively. - These dimensions are given coordinates {i,k}{0,1,2,3,4,5,...} for - the purposes of indexing. - """ - return xr.DataArray( - data=normal_dist_2d, - dims=["time", "space", "individuals", "keypoints"], - coords={ - "space": ["x", "y"], - "individuals": [f"i{i}" for i in range(normal_dist_2d.shape[2])], - "keypoints": [f"k{i}" for i in range(normal_dist_2d.shape[3])], - }, - ) - - -@pytest.fixture -def histogram_data_with_nans(histogram_data: xr.DataArray) -> xr.DataArray: - """DataArray whose data is the ``normal_dist_2d`` points. - - Axes 2 and 3 are the individuals and keypoints axes, respectively. 
- These dimensions are given coordinates {i,k}{0,1,2,3,4,5,...} for - the purposes of indexing. - - For individual i0, keypoint k0, the following (time, space) values are - converted into NaNs: - - (100, "x") - - (200, "y") - - (150, "x") - - (150, "y") - - """ - individual_0 = "i0" - keypoint_0 = "k0" - data_with_nans = histogram_data.copy(deep=True) - for time_index, space_coord in [ - (100, "x"), - (200, "y"), - (150, "x"), - (150, "y"), - ]: - data_with_nans.loc[ - time_index, space_coord, individual_0, keypoint_0 - ] = float("nan") - return data_with_nans - - -@pytest.fixture -def entirely_nan_data(histogram_data: xr.DataArray) -> xr.DataArray: - return histogram_data.copy( - deep=True, data=histogram_data.values * float("nan") - ) From 7072bf556ec377d9d7b33da64eba61c3770102a6 Mon Sep 17 00:00:00 2001 From: willGraham01 Date: Mon, 17 Feb 2025 11:15:08 +0000 Subject: [PATCH 16/22] Add examples in docstring for kwargs --- movement/plots/occupancy.py | 51 +++++++++++++++++++++++++++++++++++-- 1 file changed, 49 insertions(+), 2 deletions(-) diff --git a/movement/plots/occupancy.py b/movement/plots/occupancy.py index ab9fbc10e..34edef36b 100644 --- a/movement/plots/occupancy.py +++ b/movement/plots/occupancy.py @@ -21,8 +21,6 @@ def plot_occupancy( ) -> tuple[plt.Figure, plt.Axes, dict[HistInfoKeys, np.ndarray]]: """Create a 2D histogram of the occupancy data given. - By default; - - If there are multiple keypoints selected, the occupancy of the centroid of these keypoints is computed. - If there are multiple individuals selected, the occupancies of their @@ -80,6 +78,55 @@ def plot_occupancy( ``(xedges[x], xedges[x+1]), (yedges[y], yedges[y+1])`` bin. These values are those returned from ``matplotlib.pyplot.Axes.hist2d``. + Examples + -------- + Simple use-case is to plot a histogram of the centroid of all + keypoints, aggregated over all individuals. 
+ + >>> from movement import sample_data + >>> from movement.plots import plot_occupancy + >>> positions = sample_data.fetch_dataset( + ... "SLEAP_three-mice_Aeon_proofread.analysis.h5" + ... ).position + >>> plot_occupancy(positions) + + However, one can restrict the histogram to only counting the positions of + (the centroid of) certain keypoints and/or individuals. + + >>> print("Available individuals:", positions["individuals"]) + >>> plot_occupancy( + ... positions, + ... # use only one keypoint in computation of centroid + ... keypoints="centroid", + ... # only aggregate for two individuals + ... individuals=[ + ... "AEON3B_TP1", + ... "AEON3B_TP2", + ... ], + ... ) + + ``kwargs`` are passed to the ``matplotlib`` backend as keyword-arguments to + this function. + + >>> plot_occupancy( + ... positions, + ... # use only one keypoint in computation of centroid + ... keypoints="centroid", + ... # only aggregate for two individuals + ... individuals=[ + ... "AEON3B_TP1", + ... "AEON3B_TP2", + ... ], + ... # fix the number of bins to use + ... bins=[30, 20], + ... # set the minimum count (bins with a lower count show as 0 count). + ... # This effectively only displayed bins that were visited at least + ... # twice. + ... cmin=1, + ... # Normalise the plot, scaling the counts to [0, 1] + ... norm=True, + ... ) + See Also -------- matplotlib.pyplot.Axes.hist2d : The underlying plotting function. 
From 0e6d5ddad7d92cf7fd6206be02f51644c77c138a Mon Sep 17 00:00:00 2001 From: willGraham01 Date: Mon, 17 Feb 2025 11:22:57 +0000 Subject: [PATCH 17/22] SonarQube is confused, but fine --- tests/test_unit/test_plots/test_occupancy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_unit/test_plots/test_occupancy.py b/tests/test_unit/test_plots/test_occupancy.py index deb804d99..54724d8cd 100644 --- a/tests/test_unit/test_plots/test_occupancy.py +++ b/tests/test_unit/test_plots/test_occupancy.py @@ -41,9 +41,9 @@ def occupancy_data() -> xr.DataArray: time_space_keypoints = np.repeat( time_space[:, :, np.newaxis], repeats=3, axis=2 ) - # right = left + (1., 1.) + # Set right = left + (1., 1.) time_space_keypoints[:, :, 1] += (1.0, 1.0) - # centre = mean(left, right) + # Set centre = mean(left, right) time_space_keypoints[:, :, 2] = np.mean( time_space_keypoints[:, :, :2], axis=2 ) From 6b52cfc2b1fdc9b80f23392e0957eb90591c23b0 Mon Sep 17 00:00:00 2001 From: willGraham01 Date: Mon, 17 Feb 2025 11:26:22 +0000 Subject: [PATCH 18/22] Test that ax argument doesn't complain --- tests/test_unit/test_plots/test_occupancy.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_unit/test_plots/test_occupancy.py b/tests/test_unit/test_plots/test_occupancy.py index 54724d8cd..75b1ed1bb 100644 --- a/tests/test_unit/test_plots/test_occupancy.py +++ b/tests/test_unit/test_plots/test_occupancy.py @@ -1,6 +1,7 @@ from collections.abc import Hashable, Sequence from typing import Any +import matplotlib.pyplot as plt import numpy as np import pytest import xarray as xr @@ -182,6 +183,8 @@ def occupancy_data_with_nans(occupancy_data: xr.DataArray) -> xr.DataArray: { "keypoints": ["left", "right"], "individuals": [0, 2], + # Also check that ax doesn't complain + "ax": plt.subplots()[1], "bins": [list(range(6)), list(range(6))], }, 2 * antidiagonal_matrix([1] * 5), From 63c11aec81a25ff29169779412fcc0148a7b252b Mon Sep 17 00:00:00 2001 From: Will 
Graham <32364977+willGraham01@users.noreply.github.com> Date: Tue, 18 Feb 2025 09:03:48 +0000 Subject: [PATCH 19/22] Apply suggestions from code review Co-authored-by: Niko Sirmpilatze --- movement/plots/occupancy.py | 32 +++++++++++++------------------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/movement/plots/occupancy.py b/movement/plots/occupancy.py index 34edef36b..ddb65ec1f 100644 --- a/movement/plots/occupancy.py +++ b/movement/plots/occupancy.py @@ -19,7 +19,7 @@ def plot_occupancy( ax: plt.Axes | None = None, **kwargs: Any, ) -> tuple[plt.Figure, plt.Axes, dict[HistInfoKeys, np.ndarray]]: - """Create a 2D histogram of the occupancy data given. + """Create a 2D occupancy histogram. - If there are multiple keypoints selected, the occupancy of the centroid of these keypoints is computed. @@ -86,7 +86,7 @@ def plot_occupancy( >>> from movement import sample_data >>> from movement.plots import plot_occupancy >>> positions = sample_data.fetch_dataset( - ... "SLEAP_three-mice_Aeon_proofread.analysis.h5" + ... "DLC_two-mice.predictions.csv" ... ).position >>> plot_occupancy(positions) @@ -96,31 +96,25 @@ def plot_occupancy( >>> print("Available individuals:", positions["individuals"]) >>> plot_occupancy( ... positions, - ... # use only one keypoint in computation of centroid - ... keypoints="centroid", - ... # only aggregate for two individuals - ... individuals=[ - ... "AEON3B_TP1", - ... "AEON3B_TP2", - ... ], + ... # plot the centroid of keypoints located on the head + ... keypoints=["snout", "leftear", "rightear"], + ... # only plot data for 1 individual + ... individuals="individual1" ... ) ``kwargs`` are passed to the ``matplotlib`` backend as keyword-arguments to this function. - >>> plot_occupancy( - ... positions, - ... # use only one keypoint in computation of centroid - ... keypoints="centroid", - ... # only aggregate for two individuals - ... individuals=[ - ... "AEON3B_TP1", - ... "AEON3B_TP2", - ... ], + >>> plot_occupancy( + ... 
positions, + ... # plot the centroid of keypoints located on the head + ... keypoints=["snout", "leftear", "rightear"], + ... # only plot data for 1 individual + ... individuals="individual1", ... # fix the number of bins to use ... bins=[30, 20], ... # set the minimum count (bins with a lower count show as 0 count). - ... # This effectively only displayed bins that were visited at least + ... # This effectively only displays bins that were visited at least ... # twice. ... cmin=1, ... # Normalise the plot, scaling the counts to [0, 1] From fc2d3f3ffaafeb1ec593d1fd4d9fd7ed8b3ba399 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 18 Feb 2025 09:06:22 +0000 Subject: [PATCH 20/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- movement/plots/occupancy.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/movement/plots/occupancy.py b/movement/plots/occupancy.py index ddb65ec1f..4dd54aef0 100644 --- a/movement/plots/occupancy.py +++ b/movement/plots/occupancy.py @@ -99,18 +99,18 @@ def plot_occupancy( ... # plot the centroid of keypoints located on the head ... keypoints=["snout", "leftear", "rightear"], ... # only plot data for 1 individual - ... individuals="individual1" + ... individuals="individual1", ... ) ``kwargs`` are passed to the ``matplotlib`` backend as keyword-arguments to this function. - >>> plot_occupancy( - ... positions, - ... # plot the centroid of keypoints located on the head - ... keypoints=["snout", "leftear", "rightear"], - ... # only plot data for 1 individual - ... individuals="individual1", + >>> plot_occupancy( + ... positions, + ... # plot the centroid of keypoints located on the head + ... keypoints=["snout", "leftear", "rightear"], + ... # only plot data for 1 individual + ... individuals="individual1", ... # fix the number of bins to use ... bins=[30, 20], ... 
# set the minimum count (bins with a lower count show as 0 count). From 61fde4165d3daf64e54af514a606412319ccf89b Mon Sep 17 00:00:00 2001 From: willGraham01 Date: Tue, 18 Feb 2025 09:09:46 +0000 Subject: [PATCH 21/22] Apply suggestions from code review --- movement/plots/occupancy.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/movement/plots/occupancy.py b/movement/plots/occupancy.py index 4dd54aef0..ee307eb38 100644 --- a/movement/plots/occupancy.py +++ b/movement/plots/occupancy.py @@ -23,8 +23,8 @@ def plot_occupancy( - If there are multiple keypoints selected, the occupancy of the centroid of these keypoints is computed. - - If there are multiple individuals selected, the occupancies of their - centroids are aggregated. + - If there are multiple individuals selected, their occupancies are + aggregated. Points whose corresponding spatial coordinates have NaN values are ignored. @@ -60,11 +60,10 @@ def plot_occupancy( Notes ----- - In instances where the counts or information about the histogram bins is - desired, the ``return_hist_info`` argument should be provided as ``True``. - This will force the function to return a second output value, which is a - dictionary containing the bin edges and bin counts that were used to create - the histogram. + The third return value of this function exposes the outputs from + ``matplotlib.pyplot.hist2d`` that would otherwise be lost if only the + figure and axes handles were returned. This information is returned as a + dictionary. For data with ``Nx`` bins in the 1st spatial dimension, and ``Ny`` bins in the 2nd spatial dimension, the dictionary output has key-value pairs; @@ -118,7 +117,7 @@ def plot_occupancy( ... # twice. ... cmin=1, ... # Normalise the plot, scaling the counts to [0, 1] - ... norm=True, + ... norm="log", ...
) See Also @@ -148,7 +147,7 @@ def plot_occupancy( if "individuals" in data.dims: data = data.stack( {"new": ("time", "individuals")}, create_index=False - ).rename({"new": "time"}) + ).swap_dims({"new": "time"}) # We should now have just the relevant time-space data, # so we can remove time-points with NaN values. data = data.dropna(dim="time", how="any") @@ -172,8 +171,7 @@ def plot_occupancy( colourbar = fig.colorbar(hist_image, ax=ax) colourbar.solids.set(alpha=1.0) - space_unit = data.attrs.get("space_unit", "pixels") - ax.set_xlabel(f"{x_coord} ({space_unit})") - ax.set_ylabel(f"{y_coord} ({space_unit})") + ax.set_xlabel(str(x_coord)) + ax.set_ylabel(str(y_coord)) return fig, ax, {"counts": counts, "xedges": xedges, "yedges": yedges} From 402610593a7586373f2414597ce71351a8e6d580 Mon Sep 17 00:00:00 2001 From: willGraham01 Date: Tue, 18 Feb 2025 09:12:17 +0000 Subject: [PATCH 22/22] Close hanging matplotlib plots inside occupancy tests --- tests/test_unit/test_plots/test_occupancy.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_unit/test_plots/test_occupancy.py b/tests/test_unit/test_plots/test_occupancy.py index 75b1ed1bb..a6c02afdf 100644 --- a/tests/test_unit/test_plots/test_occupancy.py +++ b/tests/test_unit/test_plots/test_occupancy.py @@ -251,6 +251,8 @@ def test_keypoints_and_individuals_behaviour( if select_before_passing_to_plot: data = data.sel(select_before_passing_to_plot) - _, _, hist_info = plot_occupancy(data, **kwargs_to_pass) + fig, _, hist_info = plot_occupancy(data, **kwargs_to_pass) + # This just helps suppress a warning about open plots + plt.close(fig) assert np.allclose(expected_output, hist_info["counts"])