From 46ac87400629f1db645a157ba9b5da3d85324fa5 Mon Sep 17 00:00:00 2001 From: rg936672 <162452529+rg936672@users.noreply.github.com> Date: Fri, 7 Feb 2025 13:11:42 +0000 Subject: [PATCH 1/3] fix: correct shape of unweighted_indices Refs: #952 --- coreax/coreset.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/coreax/coreset.py b/coreax/coreset.py index 9a5c2a86..6592eace 100644 --- a/coreax/coreset.py +++ b/coreax/coreset.py @@ -388,14 +388,8 @@ def points(self) -> _TOriginalData_co: @property def unweighted_indices(self) -> Shaped[Array, " n"]: """Unweighted Coresubset indices - attribute access helper.""" - return jnp.squeeze(self._indices.data) - - @override - def __len__(self) -> int: - # TODO: this feels like a hacky temporary fix - the underlying issue is that - # Coresubset doesn't seem to handle 2d data properly if len(indices)==1. - # https://github.com/gchq/coreax/issues/952 - return len(self.indices) + # Ensure at least 1d to avoid shape errors. + return jnp.atleast_1d(jnp.squeeze(self._indices.data)) @property @override From 1effd660c61bcde94582889a08036748edc00b40 Mon Sep 17 00:00:00 2001 From: rg936672 <162452529+rg936672@users.noreply.github.com> Date: Fri, 7 Feb 2025 13:43:46 +0000 Subject: [PATCH 2/3] tests: add extra test to prevent regressions on #952 Refs: #952 --- tests/unit/test_coreset.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/unit/test_coreset.py b/tests/unit/test_coreset.py index 3e130271..419b0fe9 100644 --- a/tests/unit/test_coreset.py +++ b/tests/unit/test_coreset.py @@ -291,3 +291,16 @@ def test_unweighted_indices(self): coresubset = Coresubset(CORESUBSET_INDICES, PRE_CORESET_DATA) expected_indices = CORESUBSET_INDICES.data.squeeze() assert eqx.tree_equal(expected_indices, coresubset.unweighted_indices) + + def test_materialisation_2d_size_1(self): + """ + Test that the length of a coreset of size 1 from a 2d dataset is correct. 
+ + This test is here to prevent regressions on + https://github.com/gchq/coreax/issues/952 + """ + pre_coreset_data = jnp.ones((5, 2)) + indices = jnp.array([0]) + coresubset = Coresubset.build(indices, pre_coreset_data) + assert len(coresubset) == 1 + assert len(coresubset.points) == 1 From 316ac0fa521fe2ab6fe0fcd5fe0608b61001d91b Mon Sep 17 00:00:00 2001 From: rg936672 <162452529+rg936672@users.noreply.github.com> Date: Fri, 7 Feb 2025 13:45:32 +0000 Subject: [PATCH 3/3] chore: update CHANGELOG.md --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f412ed9c..588a9074 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -55,6 +55,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Incorrectly-implemented tests for the gradients of `PeriodicKernel`. (https://github.com/gchq/coreax/pull/936) - `MapReduce`'s warning about a solver not being padding-invariant is now raised at the correct stack level. (https://github.com/gchq/coreax/pull/951) +- `len(coresubset.points)` is now correct for a coresubset of size 1 from a 2d + dataset. (https://github.com/gchq/coreax/pull/957) ### Changed