Skip to content

Commit

Permalink
feat: split tests and ensure 2d arrays.
Browse files Browse the repository at this point in the history
Refs: 251
  • Loading branch information
pc532627 committed Dec 5, 2023
1 parent f675165 commit 0c1e914
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 12 deletions.
14 changes: 8 additions & 6 deletions coreax/approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def approximate(
data points in the dataset
"""
# Validate inputs
data = cast_as_type(x=data, object_name="data", type_caster=jnp.asarray)
data = cast_as_type(x=data, object_name="data", type_caster=jnp.atleast_2d)

# Record dataset size
num_data_points = len(data)
Expand Down Expand Up @@ -285,7 +285,7 @@ def approximate(
data points in the dataset
"""
# Validate inputs
data = cast_as_type(x=data, object_name="data", type_caster=jnp.asarray)
data = cast_as_type(x=data, object_name="data", type_caster=jnp.atleast_2d)

# Record dataset size
num_data_points = len(data)
Expand Down Expand Up @@ -354,7 +354,7 @@ def approximate(
data points in the dataset
"""
# Validate inputs
data = cast_as_type(x=data, object_name="data", type_caster=jnp.asarray)
data = cast_as_type(x=data, object_name="data", type_caster=jnp.atleast_2d)

# Record dataset size
num_data_points = len(data)
Expand Down Expand Up @@ -385,13 +385,15 @@ def _anchor_body(
:param idx: Loop counter
:param features: Loop updateables
:param data: Original :math:`n \times d` dataset
:param kernel_function: Vectorised kernel function on pairs `(X,x)`:
:param kernel_function: Vectorised kernel function on pairs ``(X,x)``:
:math:`k: \mathbb{R}^{n \times d} \times \mathbb{R}^d \rightarrow \mathbb{R}^n`
:return: Updated loop variables `features`
"""
# Validate inputs
features = cast_as_type(x=features, object_name="features", type_caster=jnp.asarray)
data = cast_as_type(x=data, object_name="data", type_caster=jnp.asarray)
features = cast_as_type(
x=features, object_name="features", type_caster=jnp.atleast_2d
)
data = cast_as_type(x=data, object_name="data", type_caster=jnp.atleast_2d)

max_entry = features.max(axis=1).argmin()
features = features.at[:, idx].set(kernel_function(data, data[max_entry])[:, 0])
Expand Down
106 changes: 100 additions & 6 deletions tests/unit/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ class TestInputValidation(unittest.TestCase):
Tests relating to validation of inputs provided by the user.
"""

def test_validate_in_range(self) -> None:
def test_validate_in_range_equal_lower_strict(self) -> None:
"""
Test the function validate_in_range across reasonably likely inputs
Test the function validate_in_range with the input matching the lower bound.
The inequality is strict here, so this should be flagged as invalid.
"""
self.assertRaises(
ValueError,
Expand All @@ -33,6 +35,13 @@ def test_validate_in_range(self) -> None:
lower_bound=0.0,
upper_bound=100.0,
)

def test_validate_in_range_equal_lower_not_strict(self) -> None:
"""
Test the function validate_in_range with the input matching the lower bound.
The inequality is not strict here, so this should not be flagged as invalid.
"""
self.assertIsNone(
validate_in_range(
x=0.0,
Expand All @@ -42,6 +51,13 @@ def test_validate_in_range(self) -> None:
upper_bound=100.0,
)
)

def test_validate_in_range_below_lower(self) -> None:
"""
Test the function validate_in_range with the input below the lower bound.
The input is below the lower bound, so this should be flagged as invalid.
"""
self.assertRaises(
ValueError,
validate_in_range,
Expand All @@ -51,6 +67,13 @@ def test_validate_in_range(self) -> None:
lower_bound=0.0,
upper_bound=100.0,
)

def test_validate_in_range_above_upper(self) -> None:
"""
Test the function validate_in_range with the input above the upper bound.
The input is above the upper bound, so this should be flagged as invalid.
"""
self.assertRaises(
ValueError,
validate_in_range,
Expand All @@ -60,6 +83,14 @@ def test_validate_in_range(self) -> None:
lower_bound=0.0,
upper_bound=100.0,
)

def test_validate_in_range_input_inside_range(self) -> None:
"""
Test the function validate_in_range with the input between the two bounds.
The input is within the upper and lower bounds, so this should not be flagged as
invalid.
"""
self.assertIsNone(
validate_in_range(
x=50.0,
Expand All @@ -69,6 +100,14 @@ def test_validate_in_range(self) -> None:
upper_bound=100.0,
)
)

def test_validate_in_range_input_inside_range_negative(self) -> None:
"""
Test the function validate_in_range with the input between the two bounds.
The input is within the upper and lower bounds, so this should not be flagged as
invalid. The lower bound and input are both negative here.
"""
self.assertIsNone(
validate_in_range(
x=-50.0,
Expand All @@ -78,6 +117,13 @@ def test_validate_in_range(self) -> None:
upper_bound=100.0,
)
)

def test_validate_in_range_invalid_input(self) -> None:
"""
Test the function validate_in_range with an invalid input type.
The input is a string, which cannot be compared to the numerical bounds.
"""
self.assertRaises(
TypeError,
validate_in_range,
Expand All @@ -88,9 +134,9 @@ def test_validate_in_range(self) -> None:
upper_bound=100.0,
)

def test_validate_is_instance(self) -> None:
def test_validate_is_instance_float_to_int(self) -> None:
"""
Test the function validate_is_instance across reasonably likely inputs
Test the function validate_is_instance comparing a float to an int.
"""
self.assertRaises(
TypeError,
Expand All @@ -99,49 +145,97 @@ def test_validate_is_instance(self) -> None:
object_name="var",
expected_type=int,
)

def test_validate_is_instance_int_to_float(self) -> None:
"""
Test the function validate_is_instance comparing an int to a float.
"""
self.assertRaises(
TypeError,
validate_is_instance,
x=120,
object_name="var",
expected_type=float,
)

def test_validate_is_instance_float_to_str(self) -> None:
"""
Test the function validate_is_instance comparing a float to a str.
"""
self.assertRaises(
TypeError,
validate_is_instance,
x=120.0,
object_name="var",
expected_type=str,
)

def test_validate_is_instance_float_to_float(self) -> None:
"""
Test the function validate_is_instance comparing a float to a float.
"""
self.assertIsNone(
validate_is_instance(x=50.0, object_name="var", expected_type=float)
)

def test_validate_is_instance_int_to_int(self) -> None:
"""
Test the function validate_is_instance comparing an int to an int.
"""
self.assertIsNone(
validate_is_instance(x=-500, object_name="var", expected_type=int)
)

def test_validate_is_instance_str_to_str(self) -> None:
"""
Test the function validate_is_instance comparing a str to a str.
"""
self.assertIsNone(
validate_is_instance(x="500", object_name="var", expected_type=str)
)

def test_cast_as_type(self) -> None:
def test_cast_as_type_int_to_float(self) -> None:
"""
Test the function cast_as_type across reasonably likely inputs
Test the function cast_as_type converting an int to a float.
"""
self.assertEqual(
cast_as_type(x=123, object_name="var", type_caster=float), 123.0
)

def test_cast_as_type_float_to_int(self) -> None:
"""
Test the function cast_as_type converting a float to an int.
"""
self.assertEqual(cast_as_type(x=123.4, object_name="var", type_caster=int), 123)

def test_cast_as_type_float_to_str(self) -> None:
"""
Test the function cast_as_type converting a float to a str.
"""
self.assertEqual(
cast_as_type(x=123.4, object_name="var", type_caster=str),
"123.4",
)

def test_cast_as_type_list_to_int(self) -> None:
"""
Test the function cast_as_type converting a list to an int.
"""
self.assertRaises(
TypeError,
cast_as_type,
x=[120.0],
object_name="var",
type_caster=int,
)

def test_cast_as_type_str_to_float_invalid(self) -> None:
"""
Test the function cast_as_type converting a str to a float.
In this case, there are characters in the string beyond numbers, which we expect
to cause the conversion to fail.
"""
self.assertRaises(
ValueError,
cast_as_type,
Expand Down

0 comments on commit 0c1e914

Please sign in to comment.