Skip to content

Commit

Permalink
chore : Increase coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
qh681248 committed Feb 10, 2025
1 parent f3e9c4c commit b4e6d98
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 4 deletions.
1 change: 1 addition & 0 deletions .cspell/people.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
Abhishek
Benard
Caratheodory
Chatalic
Expand Down
4 changes: 0 additions & 4 deletions coreax/solvers/coresubset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1426,10 +1426,6 @@ def _compress(indices: jax.Array) -> jax.Array:
"""
m = len(indices)
# Base case: If m = 4^g, return the dataset
if m < 4**self.g:
raise ValueError(
f"Dataset size {m} is smaller than the required size {4**self.g}."
)
if m == 4**self.g:
return indices

Expand Down
9 changes: 9 additions & 0 deletions documentation/source/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,12 @@ @misc{dwivedi2024kernelthinning
primaryClass={stat.ML},
url={https://arxiv.org/abs/2105.05842},
}

@misc{shetty2022compress,
title={Distribution compression in near-linear time},
author={Abhishek Shetty and Raaz Dwivedi and Lester Mackey},
year={2022},
archivePrefix={arXiv},
primaryClass={stat.ML},
url={https://arxiv.org/pdf/2111.07941},
}
38 changes: 38 additions & 0 deletions tests/unit/test_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2484,3 +2484,41 @@ def solver_factory(self) -> Union[type[Solver], jtu.Partial]:
delta=0.01,
sqrt_kernel=kernel,
)

def test_invalid_g_too_high(self):
"""Test that ValueError is raised when g too high."""
dataset = Data(jnp.arange(64)) # Create a dataset with 64 elements

with pytest.raises(
ValueError,
match="The over-sampling factor g should be between 0 and 3, inclusive.",
):
solver = CompressPlusPlus(
g=4, # Set g to 4, which is outside the valid range (0 to 3)
coreset_size=3, # Set coreset_size to 8
random_key=self.random_key,
kernel=SquaredExponentialKernel(),
delta=0.01,
sqrt_kernel=SquaredExponentialKernel(),
)
solver.reduce(dataset) # Attempt to reduce the dataset

def test_invalid_coreset_size_incompatible(self):
"""Test that ValueError is raised when coreset_size and g are incompatible."""
dataset = Data(jnp.arange(64)) # Create a dataset with 64 elements
g = 0 # Set g to 0
coreset_size = 17 # Set an incompatible coreset size

with pytest.raises(
ValueError,
match="Coreset size and g are not compatible with the dataset size.",
):
solver = CompressPlusPlus(
g=g,
coreset_size=coreset_size,
random_key=self.random_key,
kernel=SquaredExponentialKernel(),
delta=0.01,
sqrt_kernel=SquaredExponentialKernel(),
)
solver.reduce(dataset)

0 comments on commit b4e6d98

Please sign in to comment.