Commit

Merge pull request #957 from gchq/fix/coresubset-size-1
fix: correct shape of unweighted_indices
tp832944 authored Feb 10, 2025
2 parents bb20c65 + 316ac0f commit e601fb4
Showing 3 changed files with 17 additions and 8 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -55,6 +55,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Incorrectly-implemented tests for the gradients of `PeriodicKernel`. (https://github.com/gchq/coreax/pull/936)
 - `MapReduce`'s warning about a solver not being padding-invariant is now raised at the
   correct stack level. (https://github.com/gchq/coreax/pull/951)
+- `len(coresubset.points)` is no longer incorrect for a coresubset of size 1 from a 2d
+  dataset. (https://github.com/gchq/coreax/pull/957)
 
 ### Changed
 
10 changes: 2 additions & 8 deletions coreax/coreset.py
@@ -387,14 +387,8 @@ def points(self) -> _TOriginalData_co:
     @property
     def unweighted_indices(self) -> Shaped[Array, " n"]:
         """Unweighted Coresubset indices - attribute access helper."""
-        return jnp.squeeze(self._indices.data)
-
-    @override
-    def __len__(self) -> int:
-        # TODO: this feels like a hacky temporary fix - the underlying issue is that
-        # Coresubset doesn't seem to handle 2d data properly if len(indices)==1.
-        # https://github.com/gchq/coreax/issues/952
-        return len(self.indices)
+        # Ensure at least 1d to avoid shape errors.
+        return jnp.atleast_1d(jnp.squeeze(self._indices.data))
 
     @property
     @override
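
The shape problem this hunk addresses can be reproduced outside coreax. The sketch below is illustrative only: it assumes the single index is stored with shape (1, 1), and `stored_indices` and `data` are hypothetical stand-ins rather than coreax internals.

```python
import jax.numpy as jnp

# Hypothetical stand-ins (assumptions, not coreax objects).
stored_indices = jnp.array([[0]])  # one index, assumed stored with shape (1, 1)
data = jnp.ones((5, 2))            # five 2d points

# Old behaviour: squeeze removes every length-1 axis, leaving a 0-d array.
squeezed = jnp.squeeze(stored_indices)
old_points = data[squeezed]        # a 0-d index selects one row and drops the row axis

# New behaviour: atleast_1d restores a single leading axis.
fixed = jnp.atleast_1d(squeezed)
new_points = data[fixed]           # a 1-d index keeps the row axis

print(squeezed.shape, old_points.shape, len(old_points))  # () (2,) 2
print(fixed.shape, new_points.shape, len(new_points))     # (1,) (1, 2) 1
```

With the 0-d index, `len` of the selected points reports the feature dimension (2) rather than the number of points (1), which matches the symptom described in the changelog entry.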
13 changes: 13 additions & 0 deletions tests/unit/test_coreset.py
@@ -291,3 +291,16 @@ def test_unweighted_indices(self):
         coresubset = Coresubset(CORESUBSET_INDICES, PRE_CORESET_DATA)
         expected_indices = CORESUBSET_INDICES.data.squeeze()
         assert eqx.tree_equal(expected_indices, coresubset.unweighted_indices)
+
+    def test_materialisation_2d_size_1(self):
+        """
+        Test that the length of a coreset of size 1 from a 2d dataset is correct.
+        This test is here to prevent regressions on
+        https://github.com/gchq/coreax/issues/952
+        """
+        pre_coreset_data = jnp.ones((5, 2))
+        indices = jnp.array([0])
+        coresubset = Coresubset.build(indices, pre_coreset_data)
+        assert len(coresubset) == 1
+        assert len(coresubset.points) == 1
