Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor coreset classes #943

Merged
merged 18 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,21 @@ fail-under=10
#from-stdin=

# Files or directories to be skipped. They should be base names, not paths.
ignore=CVS
ignore=
CVS,
.git,
.venv,
.cache,
build,

# Add files or directories matching the regular expressions patterns to the
# ignore-list. The regex matches against paths and can be in Posix or Windows
# format. Because '\\' represents the directory delimiter on Windows systems,
# it can't be used as an escape character.
ignore-paths=
# For some reason this doesn't work in `ignore` above.
^documentation/source/snippets/.*$,


# Files or directories matching the regular expression patterns are skipped.
# The regex matches against base names, not paths. The default value ignores
Expand Down
17 changes: 16 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,29 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **[BREAKING CHANGE]** Equinox dependency version is changed from `<0.11.8` to `>=0.
11.5`. (https://github.com/gchq/coreax/pull/898)
- **[BREAKING CHANGE]** The `jaxtyping` version is now lower bounded at `v0.2.31` to enable `coreax.data.Data` jaxtyping compatibility.
- Refactored the `Coreset` types - instead of `Coreset` and `Coresubset(Coreset)`, we
now have `AbstractCoreset`, `PseudoCoreset(AbstractCoreset)`, and
`Coresubset(AbstractCoreset)`. See "Deprecated" below for more details of this change.
(https://github.com/gchq/coreax/pull/943)

### Removed

-

### Deprecated

-
- Uses of `Coreset` should be replaced with `AbstractCoreset` (for a general coreset,
such as in a function argument type hint), or `PseudoCoreset` (for the specific case
of a coreset that is not necessarily a coresubset).
(https://github.com/gchq/coreax/pull/943)
- Uses of `Coreset.coreset` should be replaced with `Coreset.points`.
(https://github.com/gchq/coreax/pull/943)
- Uses of `Coreset.nodes` should be replaced with `Coresubset.indices` or
`PseudoCoreset.points`, depending on whether the coreset is a coresubset or a
pseudo-coreset. (https://github.com/gchq/coreax/pull/943)
- Passing `Array` or `tuple[Array, Array]` into coreset constructors is now deprecated -
either pass in `Data` or `SupervisedData` instances, or use the `build()` class
method which handles the conversion. (https://github.com/gchq/coreax/pull/943)


## [0.3.1]
Expand Down
4 changes: 2 additions & 2 deletions benchmark/blobs_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,8 @@ def compute_solver_metrics(
coresubset, _ = solver.reduce(dataset)

# Unweighted metrics
unweighted_mmd = float(mmd_metric.compute(dataset, coresubset.coreset))
unweighted_ksd = float(ksd_metric.compute(dataset, coresubset.coreset))
unweighted_mmd = float(mmd_metric.compute(dataset, coresubset.points))
unweighted_ksd = float(ksd_metric.compute(dataset, coresubset.points))

# Weighted metrics
weighted_coresubset = coresubset.solve_weights(weights_optimiser)
Expand Down
2 changes: 1 addition & 1 deletion benchmark/david_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def benchmark_coreset_algorithms(
start_time = time.perf_counter()
coreset, _ = eqx.filter_jit(solver.reduce)(data)
duration = time.perf_counter() - start_time
coresets[solver_name] = coreset.coreset.data
coresets[solver_name] = coreset.points.data
solver_times[solver_name] = duration

plt.figure(figsize=(15, 10))
Expand Down
2 changes: 1 addition & 1 deletion benchmark/mnist_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ def main() -> None:
# pylint: enable=duplicate-code
coreset, _ = eqx.filter_jit(solver.reduce)(train_data_umap)

coreset_indices = coreset.nodes.data
coreset_indices = coreset.indices.data

train_data_coreset = train_data_jax[coreset_indices]
train_targets_coreset = train_targets_jax[coreset_indices]
Expand Down
5 changes: 3 additions & 2 deletions coreax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
MonteCarloApproximateKernel,
NystromApproximateKernel,
)
from coreax.coreset import Coreset, Coresubset
from coreax.coreset import AbstractCoreset, Coresubset, PseudoCoreset
from coreax.data import Data, SupervisedData
from coreax.kernels import (
LaplacianKernel,
Expand All @@ -48,8 +48,9 @@
"ApproximateKernel",
"MonteCarloApproximateKernel",
"NystromApproximateKernel",
"Coreset",
"AbstractCoreset",
"Coresubset",
"PseudoCoreset",
"Data",
"SupervisedData",
"LaplacianKernel",
Expand Down
Loading