Skip to content

Commit

Permalink
Update tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
rly committed Apr 4, 2024
1 parent 21f822a commit 8ffe6fe
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 9 deletions.
11 changes: 6 additions & 5 deletions hdmf_ml/results_table.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from hdmf.utils import docval, popargs
from hdmf.backends.hdf5 import H5DataIO
from hdmf.common import get_class, register_class, VectorData
from hdmf.common import get_class, register_class, VectorData, EnumData
import numpy as np
from sklearn.preprocessing import LabelEncoder

Expand Down Expand Up @@ -56,7 +56,7 @@ def n_samples(self):
return self.__n_samples

@docval(
{"name": "cls", "type": (str, type), "doc": "class for this column"},
{"name": "col_cls", "type": type, "doc": "class for this column"},
{"name": "data", "type": data_type, "doc": "data for this column"},
{"name": "name", "type": str, "doc": "the name of this column"},
{"name": "description", "type": str, "doc": "a description for this column"},
Expand All @@ -80,8 +80,8 @@ def n_samples(self):
)
def __add_col(self, **kwargs):
"""A helper function to handle boiler-plate code for adding columns to a ResultsTable"""
cls, data, name, description, dim2_kwarg, dtype = popargs(
"cls", "data", "name", "description", "dim2_kwarg", "dtype", kwargs
col_cls, data, name, description, dim2_kwarg, dtype = popargs(
"col_cls", "data", "name", "description", "dim2_kwarg", "dtype", kwargs
)
# get the size of the other dimension(s) from kwargs
if dim2_kwarg is not None:
Expand Down Expand Up @@ -125,7 +125,7 @@ def __add_col(self, **kwargs):
)

self.add_column(
data=data, name=name, description=description, col_cls=cls, **kwargs
data=data, name=name, description=description, col_cls=col_cls, **kwargs
)

if self.__n_samples is None:
Expand Down Expand Up @@ -211,6 +211,7 @@ def add_true_label(self, **kwargs):
enc = LabelEncoder()
kwargs["data"] = np.uint(enc.fit_transform(kwargs["data"]))
kwargs["enum"] = enc.classes_
return self.__add_col(EnumData, **kwargs)
kwargs["dtype"] = int
return self.__add_col(VectorData, **kwargs)

Expand Down
27 changes: 23 additions & 4 deletions tests/test_results_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,21 @@
import tempfile
from unittest import TestCase

from hdmf.common import HDF5IO, get_manager, EnumData
from hdmf.common import HDF5IO, get_manager, EnumData, VectorData
from hdmf_ml import ResultsTable

from hdmf_ml.results_table import (
ResultsTable,
SupervisedOutput,
TrainValidationTestSplit,
CrossValidationSplit,
ClassProbability,
ClassLabel,
TopKProbabilities,
TopKClasses,
RegressionOutput,
ClusterLabel,
EmbeddedValues,
)
import numpy as np


Expand Down Expand Up @@ -47,12 +59,19 @@ def test_add_col_dupe_name(self):
def test_add_tvt_split(self):
rt = ResultsTable(name="foo", description="a test results table")
rt.add_tvt_split(np.uint([0, 1, 2, 0, 1]))
with self.get_hdf5io() as io:
io.write(rt)
assert isinstance(rt["tvt_split"], TrainValidationTestSplit)
assert all(rt["tvt_split"].data == np.uint([0, 1, 2, 0, 1]))
assert isinstance(rt["tvt_split"].elements, VectorData)
assert rt["tvt_split"].elements.data == ["train", "validate", "test"]
# with self.get_hdf5io() as io: # TODO fix this test
# io.write(rt)

def test_add_cv_split(self):
rt = ResultsTable(name="foo", description="a test results table")
rt.add_cv_split([0, 1, 2, 3, 4])
assert isinstance(rt["cv_split"], CrossValidationSplit)
assert rt["cv_split"].data == [0, 1, 2, 3, 4]
assert rt["cv_split"].n_splits == 5
with self.get_hdf5io() as io:
io.write(rt)

Expand Down

0 comments on commit 8ffe6fe

Please sign in to comment.