Skip to content
This repository has been archived by the owner on Apr 24, 2024. It is now read-only.

Commit

Permalink
update to new equistore API from commit c022fde
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri authored and agoscinski committed Jul 10, 2023
1 parent 95b0c82 commit 5fdd26b
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ keywords = [
requires-python = ">=3.7"

dependencies = [
"equistore @ https://github.com/lab-cosmo/equistore/archive/a9b9a2a.zip",
"equistore @ https://github.com/lab-cosmo/equistore/archive/c022fde.zip",
"numpy",
"scipy",
"skmatter"
Expand Down
4 changes: 2 additions & 2 deletions tests/equisolve_tests/numpy/feature_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def X(self):
def test_fit(self, X, selector_class, skmatter_selector_class):
selector = selector_class(n_to_select=2)
selector.fit(X)
support = selector.support[0].properties["properties"]
support = selector.support[0].properties

skmatter_selector = skmatter_selector_class(n_to_select=2)
skmatter_selector.fit(X[0].values)
Expand All @@ -41,7 +41,7 @@ def test_fit(self, X, selector_class, skmatter_selector_class):
),
)

assert_equal(support, skmatter_support_labels)
assert support == skmatter_support_labels

@pytest.mark.parametrize(
"selector_class, skmatter_selector_class",
Expand Down
9 changes: 5 additions & 4 deletions tests/equisolve_tests/numpy/sample_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,20 @@ def X(self):
def test_fit(self, X, selector_class, skmatter_selector_class):
selector = selector_class(n_to_select=2)
selector.fit(X)
support = selector.support[0].samples["structure"]
support = selector.support[0].samples

skmatter_selector = skmatter_selector_class(n_to_select=2)
skmatter_selector.fit(X[0].values)
skmatter_support = skmatter_selector.get_support(indices=True)
skmatter_support_labels = Labels(
names=["structure"],
names=["sample", "structure"],
values=np.array(
[[support_i] for support_i in skmatter_support], dtype=np.int32
[[support_i, support_i] for support_i in skmatter_support],
dtype=np.int32,
),
)

assert_equal(support, skmatter_support_labels)
assert support == skmatter_support_labels

@pytest.mark.parametrize(
"selector_class, skmatter_selector_class",
Expand Down

0 comments on commit 5fdd26b

Please sign in to comment.