diff --git a/coreax/approximation.py b/coreax/approximation.py index 350906440..d996ba1ad 100644 --- a/coreax/approximation.py +++ b/coreax/approximation.py @@ -75,6 +75,7 @@ import coreax.kernel as ck from coreax.util import ClassFactory, KernelFunction +from coreax.validation import cast_as_type, validate_in_range, validate_is_instance class KernelMeanApproximator(ABC): @@ -100,6 +101,26 @@ def __init__( """ Define an approximator to the mean of the row sum of a kernel distance matrix. """ + # Validate inputs of coreax defined classes + validate_is_instance(kernel, "kernel", ck.Kernel) + + # Validate inputs of non-coreax defined classes + random_key = cast_as_type( + x=random_key, object_name="random_key", type_caster=jnp.asarray + ) + num_kernel_points = cast_as_type( + x=num_kernel_points, object_name="num_kernel_points", type_caster=int + ) + + # Validate inputs lie within accepted ranges + validate_in_range( + x=num_kernel_points, + object_name="num_kernel_points", + strict_inequalities=True, + lower_bound=0, + ) + + # Assign inputs self.kernel = kernel self.random_key = random_key self.num_kernel_points = num_kernel_points @@ -141,6 +162,20 @@ def __init__( """ Approximate kernel row mean by regression on points selected randomly. """ + # Validate inputs of non-coreax defined classes + num_train_points = cast_as_type( + x=num_train_points, object_name="num_train_points", type_caster=int + ) + + # Validate inputs lie within accepted ranges + validate_in_range( + x=num_train_points, + object_name="num_train_points", + strict_inequalities=True, + lower_bound=0, + ) + + # Assign inputs self.num_train_points = num_train_points # Initialise parent @@ -161,8 +196,10 @@ def approximate( :return: Approximation of the kernel matrix row sum divided by the number of data points in the dataset """ - # Ensure data is the expected type - data = jnp.asarray(data) + # Validate inputs + data = cast_as_type(x=data, object_name="data", type_caster=jnp.atleast_2d) + + # Record dataset size num_data_points = len(data) # Randomly select points for kernel regression @@ -213,6 +250,20 @@ def __init__( """ Approximate kernel row mean by regression on ANNchor selected points. """ + # Validate inputs of non-coreax defined classes + num_train_points = cast_as_type( + x=num_train_points, object_name="num_train_points", type_caster=int + ) + + # Validate inputs lie within accepted ranges + validate_in_range( + x=num_train_points, + object_name="num_train_points", + strict_inequalities=True, + lower_bound=0, + ) + + # Assign inputs self.num_train_points = num_train_points # Initialise parent @@ -233,15 +284,17 @@ def approximate( :return: Approximation of the kernel matrix row sum divided by the number of data points in the dataset """ - # Ensure data is the expected type - data = jnp.asarray(data) + # Validate inputs + data = cast_as_type(x=data, object_name="data", type_caster=jnp.atleast_2d) + + # Record dataset size num_data_points = len(data) # Select point for kernel regression using ANNchor construction features = jnp.zeros((num_data_points, self.num_kernel_points)) features = features.at[:, 0].set(self.kernel.compute(data, data[0])[:, 0]) - body = partial(anchor_body, data=data, kernel_function=self.kernel.compute) + body = partial(_anchor_body, data=data, kernel_function=self.kernel.compute) features = lax.fori_loop(1, self.num_kernel_points, body, features) train_idx = random.choice( @@ -300,8 +353,10 @@ def approximate( :return: Approximation of the kernel matrix row sum divided by the number of data points in the dataset """ - # Ensure data is the expected type - data = jnp.asarray(data) + # Validate inputs + data = cast_as_type(x=data, object_name="data", type_caster=jnp.atleast_2d) + + # Record dataset size num_data_points = len(data) # Randomly select points for kernel regression @@ -318,7 +373,7 @@ def approximate( @partial(jit, static_argnames=["kernel_function"]) -def anchor_body( +def _anchor_body( idx: int, features: ArrayLike, data: ArrayLike, @@ -330,12 +385,15 @@ def anchor_body( :param idx: Loop counter :param features: Loop updateables :param data: Original :math:`n \times d` dataset - :param kernel_function: Vectorised kernel function on pairs `(X,x)`: + :param kernel_function: Vectorised kernel function on pairs ``(X,x)``: :math:`k: \mathbb{R}^{n \times d} \times \mathbb{R}^d \rightarrow \mathbb{R}^n` :return: Updated loop variables `features` """ - features = jnp.asarray(features) - data = jnp.asarray(data) + # Validate inputs + features = cast_as_type( + x=features, object_name="features", type_caster=jnp.atleast_2d + ) + data = cast_as_type(x=data, object_name="data", type_caster=jnp.atleast_2d) max_entry = features.max(axis=1).argmin() features = features.at[:, idx].set(kernel_function(data, data[max_entry])[:, 0]) diff --git a/coreax/validation.py b/coreax/validation.py new file mode 100644 index 000000000..3eaaff9bd --- /dev/null +++ b/coreax/validation.py @@ -0,0 +1,106 @@ +# © Crown Copyright GCHQ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Functionality to validate data passed throughout coreax. + +The functions within this module are intended to be used as a means to validate inputs +passed to classes, functions and methods throughout the coreax codebase. +""" + +# Support annotations with | in Python < 3.10 +# TODO: Remove once no longer supporting old code +from __future__ import annotations + +from collections.abc import Callable +from typing import Any, TypeVar + +T = TypeVar("T") + + +def validate_in_range( + x: T, + object_name: str, + strict_inequalities: bool, + lower_bound: T | None = None, + upper_bound: T | None = None, +) -> None: + """ + Verify that a given input is in a specified range. + + :param x: Variable we wish to verify lies in the specified range + :param object_name: Name of ``x`` to display if limits are broken + :param strict_inequalities: If :data:`True`, checks are applied using strict + inequalities, otherwise they are not + :param lower_bound: Lower limit placed on ``x``, or :data:`None` + :param upper_bound: Upper limit placed on ``x``, or :data:`None` + :raises ValueError: Raised if ``x`` does not fall between ``lower_limit`` and + ``upper_limit`` + :raises TypeError: Raised if x cannot be compared to a value using ``>``, ``>=``, + ``<`` or ``<=`` + """ + try: + if strict_inequalities: + if lower_bound is not None and not x > lower_bound: + raise ValueError(f"{object_name} must be strictly above {lower_bound}.") + if upper_bound is not None and not x < upper_bound: + raise ValueError(f"{object_name} must be strictly below {upper_bound}.") + else: + if lower_bound is not None and not x >= lower_bound: + raise ValueError(f"{object_name} must be {lower_bound} or above.") + if upper_bound is not None and not x <= upper_bound: + raise ValueError(f"{object_name} must be {upper_bound} or lower.") + except TypeError: + if strict_inequalities: + raise TypeError( + f"{object_name} must have a valid comparison < and > implemented." + ) + else: + raise TypeError( + f"{object_name} must have a valid comparison <= and >= implemented." + ) + + +def validate_is_instance(x: T, object_name: str, expected_type: type[T]) -> None: + """ + Verify that a given object is of a given type. + + :param x: Variable we wish to verify lies in the specified range + :param object_name: Name of ``x`` to display if it is not of type ``expected_type`` + :param expected_type: The expected type of ``x`` + :raises TypeError: Raised if ``x`` is not of type ``expected_type`` + """ + if not isinstance(x, expected_type): + raise TypeError(f"{object_name} must be of type {expected_type}.") + + +def cast_as_type(x: Any, object_name: str, type_caster: Callable) -> Any: + """ + Cast an object as a specified type. + + :param x: Variable to cast as specified type + :param object_name: Name of the object being considered + :param type_caster: Callable that ``x`` will be passed + :return: ``x``, but cast as the type specified by ``type_caster`` + :raises TypeError: Raised if ``x`` cannot be cast using ``type_caster`` + """ + try: + return type_caster(x) + except (TypeError, ValueError) as e: + error_text = f"{object_name} cannot be cast using {type_caster}. \n" + if hasattr(e, "message"): + error_text += e.message + else: + error_text += str(e) + raise TypeError(error_text) diff --git a/tests/unit/test_approximation.py b/tests/unit/test_approximation.py index 79acac19a..08bdd1223 100644 --- a/tests/unit/test_approximation.py +++ b/tests/unit/test_approximation.py @@ -107,6 +107,71 @@ def test_kernel_mean_approximator_creation(self) -> None: self.assertEqual(approximator.random_key[1], self.random_key[1]) self.assertEqual(approximator.num_kernel_points, self.num_kernel_points) + def test_kernel_mean_approximator_creation_invalid_types(self) -> None: + """ + Test the class KernelMeanApproximator rejects invalid input types. + """ + # Patch the abstract method (approximate) of the KernelMeanApproximator, so it + # can be created + p = patch.multiple(ca.KernelMeanApproximator, __abstractmethods__=set()) + p.start() + + # Define the approximator with an incorrect kernel type + self.assertRaises( + TypeError, + ca.KernelMeanApproximator, + kernel="not_a_kernel", + random_key=self.random_key, + num_kernel_points=self.num_kernel_points, + ) + + # Define the approximator with an incorrect random_key type, but that can be + # converted into an array + approximator = ca.KernelMeanApproximator( + kernel=self.kernel, + random_key=123, + num_kernel_points=self.num_kernel_points, + ) + np.testing.assert_array_equal(approximator.random_key, np.array([123])) + + # Define the approximator with an incorrect random_key type, that cannot be + # cast as an array + self.assertRaises( + TypeError, + ca.KernelMeanApproximator, + kernel=self.kernel, + random_key=int, + num_kernel_points=self.num_kernel_points, + ) + + # Define the approximator with an incorrect num_kernel_points type (float) but + # that can be cast into an int + approximator = ca.KernelMeanApproximator( + kernel=self.kernel, + random_key=self.random_key, + num_kernel_points=1.0 * self.num_kernel_points, + ) + self.assertEqual(approximator.num_kernel_points, self.num_kernel_points) + + # Define the approximator with an incorrect num_kernel_points type (float) that + # cannot be cast into an int + self.assertRaises( + TypeError, + ca.KernelMeanApproximator, + kernel=self.kernel, + random_key=self.random_key, + num_kernel_points=[1], + ) + + # Define the approximator with a negative value of num_kernel_points + self.assertRaises( + ValueError, + ca.KernelMeanApproximator, + kernel=self.kernel, + random_key=self.random_key, + num_kernel_points=-self.num_kernel_points, + ) + def test_random_approximator(self) -> None: """ Verify random approximator performance on toy problem. @@ -155,6 +220,112 @@ def test_random_approximator(self) -> None: ) self.assertTrue(approx_error_full <= approx_error_partial) + def test_random_approximator_creation_invalid_types(self) -> None: + """ + Test the class RandomApproximator rejects invalid input types. + """ + # Define the approximator with an incorrect kernel type + self.assertRaises( + TypeError, + ca.RandomApproximator, + kernel="not_a_kernel", + random_key=self.random_key, + num_kernel_points=self.num_kernel_points, + num_train_points=self.data.shape[0], + ) + + # Define the approximator with an incorrect random_key type, but that can be + # converted into an array. + approximator = ca.RandomApproximator( + kernel=self.kernel, + random_key=123, + num_kernel_points=self.num_kernel_points, + num_train_points=self.data.shape[0], + ) + np.testing.assert_array_equal(approximator.random_key, np.array([123])) + + # Define the approximator with an incorrect random_key type, that cannot be + # cast as an array + self.assertRaises( + TypeError, + ca.RandomApproximator, + kernel=self.kernel, + random_key=int, + num_kernel_points=self.num_kernel_points, + num_train_points=self.data.shape[0], + ) + + # Define the approximator with an incorrect num_kernel_points type (float) but + # that can be cast into an int + approximator = ca.RandomApproximator( + kernel=self.kernel, + random_key=self.random_key, + num_kernel_points=1.0 * self.num_kernel_points, + num_train_points=self.data.shape[0], + ) + self.assertEqual(approximator.num_kernel_points, self.num_kernel_points) + + # Define the approximator with an incorrect num_kernel_points type (float) that + # cannot be cast into an int + self.assertRaises( + TypeError, + ca.RandomApproximator, + kernel=self.kernel, + random_key=self.random_key, + num_kernel_points=[1.0], + num_train_points=self.data.shape[0], + ) + + # Define the approximator with a negative value of num_kernel_points + self.assertRaises( + ValueError, + ca.RandomApproximator, + kernel=self.kernel, + random_key=self.random_key, + num_kernel_points=-self.num_kernel_points, + num_train_points=self.data.shape[0], + ) + + # Define the approximator with an incorrect num_train_points type (float) but + # that can be cast to an int + approximator = ca.RandomApproximator( + kernel=self.kernel, + random_key=self.random_key, + num_kernel_points=self.num_kernel_points, + num_train_points=10.0, + ) + self.assertEqual(approximator.num_train_points, 10) + + # Define the approximator with an incorrect num_train_points type (float) and + # that cannot be cast to an int + self.assertRaises( + TypeError, + ca.RandomApproximator, + kernel=self.kernel, + random_key=self.random_key, + num_kernel_points=self.num_kernel_points, + num_train_points=[10.0], + ) + + # Define the approximator with a negative value of num_train_points + self.assertRaises( + ValueError, + ca.RandomApproximator, + kernel=self.kernel, + random_key=self.random_key, + num_kernel_points=-self.num_kernel_points, + num_train_points=-10, + ) + + # Define a valid approximator, but call approximate with an invalid input + approximator = ca.RandomApproximator( + kernel=self.kernel, + random_key=self.random_key, + num_kernel_points=self.data.shape[0], + num_train_points=self.data.shape[0], + ) + self.assertRaises(TypeError, approximator.approximate, "not_data") + def test_annchor_approximator(self) -> None: """ Verify Annchor approximator performance on toy problem. @@ -203,6 +374,112 @@ def test_annchor_approximator(self) -> None: ) self.assertTrue(approx_error_full <= approx_error_partial) + def test_annchor_approximator_creation_invalid_types(self) -> None: + """ + Test the class ANNchorApproximator rejects invalid input types. + """ + # Define the approximator with an incorrect kernel type + self.assertRaises( + TypeError, + ca.ANNchorApproximator, + kernel="not_a_kernel", + random_key=self.random_key, + num_kernel_points=self.num_kernel_points, + num_train_points=self.data.shape[0], + ) + + # Define the approximator with an incorrect random_key type, but that can be + # converted into an array. + approximator = ca.ANNchorApproximator( + kernel=self.kernel, + random_key=123, + num_kernel_points=self.num_kernel_points, + num_train_points=self.data.shape[0], + ) + np.testing.assert_array_equal(approximator.random_key, np.array([123])) + + # Define the approximator with an incorrect random_key type, that cannot be + # cast as an array + self.assertRaises( + TypeError, + ca.ANNchorApproximator, + kernel=self.kernel, + random_key=int, + num_kernel_points=self.num_kernel_points, + num_train_points=self.data.shape[0], + ) + + # Define the approximator with an incorrect num_kernel_points type (float) but + # that can be cast into an int + approximator = ca.ANNchorApproximator( + kernel=self.kernel, + random_key=self.random_key, + num_kernel_points=1.0 * self.num_kernel_points, + num_train_points=self.data.shape[0], + ) + self.assertEqual(approximator.num_kernel_points, self.num_kernel_points) + + # Define the approximator with an incorrect num_kernel_points type (float) that + # cannot be cast into an int + self.assertRaises( + TypeError, + ca.ANNchorApproximator, + kernel=self.kernel, + random_key=self.random_key, + num_kernel_points=[1.0], + num_train_points=self.data.shape[0], + ) + + # Define the approximator with a negative value of num_kernel_points + self.assertRaises( + ValueError, + ca.ANNchorApproximator, + kernel=self.kernel, + random_key=self.random_key, + num_kernel_points=-self.num_kernel_points, + num_train_points=self.data.shape[0], + ) + + # Define the approximator with an incorrect num_train_points type (float) but + # that can be cast to an int + approximator = ca.ANNchorApproximator( + kernel=self.kernel, + random_key=self.random_key, + num_kernel_points=self.num_kernel_points, + num_train_points=10.0, + ) + self.assertEqual(approximator.num_train_points, 10) + + # Define the approximator with an incorrect num_train_points type (float) and + # that cannot be cast to an int + self.assertRaises( + TypeError, + ca.ANNchorApproximator, + kernel=self.kernel, + random_key=self.random_key, + num_kernel_points=self.num_kernel_points, + num_train_points=[10.0], + ) + + # Define the approximator with a negative value of num_train_points + self.assertRaises( + ValueError, + ca.ANNchorApproximator, + kernel=self.kernel, + random_key=self.random_key, + num_kernel_points=-self.num_kernel_points, + num_train_points=-10, + ) + + # Define a valid approximator, but call approximate with an invalid input + approximator = ca.ANNchorApproximator( + kernel=self.kernel, + random_key=self.random_key, + num_kernel_points=self.data.shape[0], + num_train_points=self.data.shape[0], + ) + self.assertRaises(TypeError, approximator.approximate, "not_data") + def test_nystrom_approximator(self) -> None: """ Verify Nystrom approximator performance on toy problem. @@ -250,6 +527,74 @@ def test_nystrom_approximator(self) -> None: ) self.assertTrue(approx_error_full <= approx_error_partial) + def test_nystrom_approximator_creation_invalid_types(self) -> None: + """ + Test the class NystromApproximator rejects invalid input types. + """ + # Define the approximator with an incorrect kernel type + self.assertRaises( + TypeError, + ca.NystromApproximator, + kernel="not_a_kernel", + random_key=self.random_key, + num_kernel_points=self.num_kernel_points, + ) + + # Define the approximator with an incorrect random_key type, but that can be + # converted into an array + approximator = ca.NystromApproximator( + kernel=self.kernel, + random_key=123, + num_kernel_points=self.num_kernel_points, + ) + np.testing.assert_array_equal(approximator.random_key, np.array([123])) + + # Define the approximator with an incorrect random_key type, that cannot be + # cast as an array + self.assertRaises( + TypeError, + ca.NystromApproximator, + kernel=self.kernel, + random_key=int, + num_kernel_points=self.num_kernel_points, + ) + + # Define the approximator with an incorrect num_kernel_points type (float) but + # that can be cast into an int + approximator = ca.NystromApproximator( + kernel=self.kernel, + random_key=self.random_key, + num_kernel_points=1.0 * self.num_kernel_points, + ) + self.assertEqual(approximator.num_kernel_points, self.num_kernel_points) + + # Define the approximator with an incorrect num_kernel_points type (float) that + # cannot be cast into an int + self.assertRaises( + TypeError, + ca.NystromApproximator, + kernel=self.kernel, + random_key=self.random_key, + num_kernel_points=[1], + ) + + # Define the approximator with a negative value of num_kernel_points + self.assertRaises( + ValueError, + ca.NystromApproximator, + kernel=self.kernel, + random_key=self.random_key, + num_kernel_points=-self.num_kernel_points, + ) + + # Define a valid approximator, but call approximate with an invalid input + approximator = ca.NystromApproximator( + kernel=self.kernel, + random_key=self.random_key, + num_kernel_points=self.data.shape[0], + ) + self.assertRaises(TypeError, approximator.approximate, "not_data") + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/test_validation.py b/tests/unit/test_validation.py new file mode 100644 index 000000000..2a95649b8 --- /dev/null +++ b/tests/unit/test_validation.py @@ -0,0 +1,337 @@ +# © Crown Copyright GCHQ +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this +# file except in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +import unittest + +from coreax.validation import cast_as_type, validate_in_range, validate_is_instance + + +class TestInputValidationRange(unittest.TestCase): + """ + Tests relating to validation of inputs provided by the user lying in a given range. + """ + + def test_validate_in_range_equal_lower_strict(self) -> None: + """ + Test the function validate_in_range with the input matching the lower bound. + + The inequality is strict here, so this should be flagged as invalid. + """ + self.assertRaises( + ValueError, + validate_in_range, + x=0.0, + object_name="var", + strict_inequalities=True, + lower_bound=0.0, + upper_bound=100.0, + ) + + def test_validate_in_range_equal_lower_not_strict(self) -> None: + """ + Test the function validate_in_range with the input matching the lower bound. + + The inequality is not strict here, so this should not be flagged as invalid. + """ + self.assertIsNone( + validate_in_range( + x=0.0, + object_name="var", + strict_inequalities=False, + lower_bound=0.0, + upper_bound=100.0, + ) + ) + + def test_validate_in_range_below_lower(self) -> None: + """ + Test the function validate_in_range with the input below the lower bound. + + The input is below the lower bound, so this should be flagged as invalid. + """ + self.assertRaises( + ValueError, + validate_in_range, + x=-1.0, + object_name="var", + strict_inequalities=True, + lower_bound=0.0, + upper_bound=100.0, + ) + + def test_validate_in_range_equal_upper_strict(self) -> None: + """ + Test the function validate_in_range with the input matching the lower bound. + + The inequality is strict here, so this should be flagged as invalid. + """ + self.assertRaises( + ValueError, + validate_in_range, + x=100.0, + object_name="var", + strict_inequalities=True, + lower_bound=0.0, + upper_bound=100.0, + ) + + def test_validate_in_range_equal_upper_not_strict(self) -> None: + """ + Test the function validate_in_range with the input matching the lower bound. + + The inequality is not strict here, so this should not be flagged as invalid. + """ + self.assertIsNone( + validate_in_range( + x=100.0, + object_name="var", + strict_inequalities=False, + lower_bound=0.0, + upper_bound=100.0, + ) + ) + + def test_validate_in_range_above_upper(self) -> None: + """ + Test the function validate_in_range with the input above the upper bound. + + The input is above the upper bound, so this should be flagged as invalid. + """ + self.assertRaises( + ValueError, + validate_in_range, + x=120.0, + object_name="var", + strict_inequalities=True, + lower_bound=0.0, + upper_bound=100.0, + ) + + def test_validate_in_range_input_inside_range(self) -> None: + """ + Test the function validate_in_range with the input between the two bounds. + + The input is within the upper and lower bounds, so this should not be flagged as + invalid. + """ + self.assertIsNone( + validate_in_range( + x=50.0, + object_name="var", + strict_inequalities=True, + lower_bound=0.0, + upper_bound=100.0, + ) + ) + + def test_validate_in_range_input_inside_range_negative(self) -> None: + """ + Test the function validate_in_range with the input between the two bounds. + + The input is within the upper and lower bounds, so this should not be flagged as + invalid. The lower bound and input are both negative here. + """ + self.assertIsNone( + validate_in_range( + x=-50.0, + object_name="var", + strict_inequalities=True, + lower_bound=-100.0, + upper_bound=100.0, + ) + ) + + def test_validate_in_range_invalid_input(self) -> None: + """ + Test the function validate_in_range with an invalid input type. + + The input is a string, which cannot be compared to the numerical bounds. + """ + self.assertRaises( + TypeError, + validate_in_range, + x="1.0", + object_name="var", + strict_inequalities=True, + lower_bound=0.0, + upper_bound=100.0, + ) + + def test_validate_in_range_input_no_lower_bound(self) -> None: + """ + Test the function validate_in_range with the input between the two bounds. + + The input is below the upper bound, so this should not be flagged as invalid. + """ + self.assertIsNone( + validate_in_range( + x=50.0, + object_name="var", + strict_inequalities=True, + upper_bound=100.0, + ) + ) + + def test_validate_in_range_input_no_upper_bound(self) -> None: + """ + Test the function validate_in_range with the input between the two bounds. + + The input is above the lower bound, so this should not be flagged as invalid. + """ + self.assertIsNone( + validate_in_range( + x=50.0, + object_name="var", + strict_inequalities=True, + lower_bound=0.0, + ) + ) + + def test_validate_in_range_input_no_lower_or_upper_bound(self) -> None: + """ + Test the function validate_in_range with the input between the two bounds. + + The input is below the upper bound, so this should not be flagged as invalid. + """ + self.assertIsNone( + validate_in_range( + x=50.0, + object_name="var", + strict_inequalities=True, + ) + ) + + +class TestInputValidationInstance(unittest.TestCase): + """ + Tests relating to validation of inputs provided by the user are a given type. + """ + + def test_validate_is_instance_float_to_int(self) -> None: + """ + Test the function validate_is_instance comparing a float to an int. + """ + self.assertRaises( + TypeError, + validate_is_instance, + x=120.0, + object_name="var", + expected_type=int, + ) + + def test_validate_is_instance_int_to_float(self) -> None: + """ + Test the function validate_is_instance comparing an int to a float. + """ + self.assertRaises( + TypeError, + validate_is_instance, + x=120, + object_name="var", + expected_type=float, + ) + + def test_validate_is_instance_float_to_str(self) -> None: + """ + Test the function validate_is_instance comparing a float to a str. + """ + self.assertRaises( + TypeError, + validate_is_instance, + x=120.0, + object_name="var", + expected_type=str, + ) + + def test_validate_is_instance_float_to_float(self) -> None: + """ + Test the function validate_is_instance comparing a float to a float. + """ + self.assertIsNone( + validate_is_instance(x=50.0, object_name="var", expected_type=float) + ) + + def test_validate_is_instance_int_to_int(self) -> None: + """ + Test the function validate_is_instance comparing an int to an int. + """ + self.assertIsNone( + validate_is_instance(x=-500, object_name="var", expected_type=int) + ) + + def test_validate_is_instance_str_to_str(self) -> None: + """ + Test the function validate_is_instance comparing a str to a str. + """ + self.assertIsNone( + validate_is_instance(x="500", object_name="var", expected_type=str) + ) + + +class TestInputValidationConversion(unittest.TestCase): + """ + Tests relating to validation of inputs provided by the user convert to a given type. + """ + + def test_cast_as_type_int_to_float(self) -> None: + """ + Test the function cast_as_type converting an int to a float. + """ + self.assertEqual( + cast_as_type(x=123, object_name="var", type_caster=float), 123.0 + ) + + def test_cast_as_type_float_to_int(self) -> None: + """ + Test the function cast_as_type converting a float to an int. + """ + self.assertEqual(cast_as_type(x=123.4, object_name="var", type_caster=int), 123) + + def test_cast_as_type_float_to_str(self) -> None: + """ + Test the function cast_as_type converting a float to a str. + """ + self.assertEqual( + cast_as_type(x=123.4, object_name="var", type_caster=str), + "123.4", + ) + + def test_cast_as_type_list_to_int(self) -> None: + """ + Test the function cast_as_type converting a list to an int. + """ + self.assertRaises( + TypeError, + cast_as_type, + x=[120.0], + object_name="var", + type_caster=int, + ) + + def test_cast_as_type_str_to_float_invalid(self) -> None: + """ + Test the function cast_as_type converting a str to a float. + + In this case, there are characters in the string beyond numbers, which we expect + to cause the conversion to fail. + """ + self.assertRaises( + TypeError, + cast_as_type, + x="120.0ABC", + object_name="var", + type_caster=float, + ) + + +if __name__ == "__main__": + unittest.main()