Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add input validation for non-private methods of approximation.py #260

Merged
merged 11 commits into from
Dec 6, 2023
80 changes: 69 additions & 11 deletions coreax/approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@

import coreax.kernel as ck
from coreax.util import ClassFactory, KernelFunction
from coreax.validation import cast_as_type, validate_in_range, validate_is_instance


class KernelMeanApproximator(ABC):
Expand All @@ -100,6 +101,26 @@ def __init__(
"""
Define an approximator to the mean of the row sum of a kernel distance matrix.
"""
# Validate inputs of coreax defined classes
validate_is_instance(kernel, "kernel", ck.Kernel)

# Validate inputs of non-coreax defined classes
random_key = cast_as_type(
x=random_key, object_name="random_key", type_caster=jnp.asarray
)
num_kernel_points = cast_as_type(
x=num_kernel_points, object_name="num_kernel_points", type_caster=int
)

# Validate inputs lie within accepted ranges
validate_in_range(
x=num_kernel_points,
object_name="num_kernel_points",
strict_inequalities=True,
lower_bound=0,
)

# Assign inputs
self.kernel = kernel
self.random_key = random_key
self.num_kernel_points = num_kernel_points
Expand Down Expand Up @@ -141,6 +162,20 @@ def __init__(
"""
Approximate kernel row mean by regression on points selected randomly.
"""
# Validate inputs of non-coreax defined classes
num_train_points = cast_as_type(
x=num_train_points, object_name="num_train_points", type_caster=int
)

# Validate inputs lie within accepted ranges
validate_in_range(
x=num_train_points,
object_name="num_train_points",
strict_inequalities=True,
lower_bound=0,
)

# Assign inputs
self.num_train_points = num_train_points

# Initialise parent
Expand All @@ -161,8 +196,10 @@ def approximate(
:return: Approximation of the kernel matrix row sum divided by the number of
data points in the dataset
"""
# Ensure data is the expected type
data = jnp.asarray(data)
# Validate inputs
data = cast_as_type(x=data, object_name="data", type_caster=jnp.atleast_2d)
tp832944 marked this conversation as resolved.
Show resolved Hide resolved

# Record dataset size
num_data_points = len(data)

# Randomly select points for kernel regression
Expand Down Expand Up @@ -213,6 +250,20 @@ def __init__(
"""
Approximate kernel row mean by regression on ANNchor selected points.
"""
# Validate inputs of non-coreax defined classes
num_train_points = cast_as_type(
x=num_train_points, object_name="num_train_points", type_caster=int
)

# Validate inputs lie within accepted ranges
validate_in_range(
x=num_train_points,
object_name="num_train_points",
strict_inequalities=True,
lower_bound=0,
)

# Assign inputs
self.num_train_points = num_train_points

# Initialise parent
Expand All @@ -233,15 +284,17 @@ def approximate(
:return: Approximation of the kernel matrix row sum divided by the number of
data points in the dataset
"""
# Ensure data is the expected type
data = jnp.asarray(data)
# Validate inputs
data = cast_as_type(x=data, object_name="data", type_caster=jnp.atleast_2d)

# Record dataset size
num_data_points = len(data)

# Select point for kernel regression using ANNchor construction
features = jnp.zeros((num_data_points, self.num_kernel_points))

features = features.at[:, 0].set(self.kernel.compute(data, data[0])[:, 0])
body = partial(anchor_body, data=data, kernel_function=self.kernel.compute)
body = partial(_anchor_body, data=data, kernel_function=self.kernel.compute)
features = lax.fori_loop(1, self.num_kernel_points, body, features)

train_idx = random.choice(
Expand Down Expand Up @@ -300,8 +353,10 @@ def approximate(
:return: Approximation of the kernel matrix row sum divided by the number of
data points in the dataset
"""
# Ensure data is the expected type
data = jnp.asarray(data)
# Validate inputs
data = cast_as_type(x=data, object_name="data", type_caster=jnp.atleast_2d)

# Record dataset size
num_data_points = len(data)

# Randomly select points for kernel regression
Expand All @@ -318,7 +373,7 @@ def approximate(


@partial(jit, static_argnames=["kernel_function"])
def anchor_body(
def _anchor_body(
idx: int,
features: ArrayLike,
data: ArrayLike,
Expand All @@ -330,12 +385,15 @@ def anchor_body(
:param idx: Loop counter
:param features: Loop updateables
:param data: Original :math:`n \times d` dataset
:param kernel_function: Vectorised kernel function on pairs `(X,x)`:
:param kernel_function: Vectorised kernel function on pairs ``(X,x)``:
:math:`k: \mathbb{R}^{n \times d} \times \mathbb{R}^d \rightarrow \mathbb{R}^n`
:return: Updated loop variables `features`
"""
features = jnp.asarray(features)
data = jnp.asarray(data)
# Validate inputs
features = cast_as_type(
x=features, object_name="features", type_caster=jnp.atleast_2d
)
data = cast_as_type(x=data, object_name="data", type_caster=jnp.atleast_2d)

max_entry = features.max(axis=1).argmin()
features = features.at[:, idx].set(kernel_function(data, data[max_entry])[:, 0])
Expand Down
106 changes: 106 additions & 0 deletions coreax/validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# © Crown Copyright GCHQ
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Functionality to validate data passed throughout coreax.

The functions within this module are intended to be used as a means to validate inputs
passed to classes, functions and methods throughout the coreax codebase.
"""

# Support annotations with | in Python < 3.10
# TODO: Remove once no longer supporting old code
from __future__ import annotations

from collections.abc import Callable
from typing import Any, TypeVar

T = TypeVar("T")


def validate_in_range(
    x: T,
    object_name: str,
    strict_inequalities: bool,
    lower_bound: T | None = None,
    upper_bound: T | None = None,
) -> None:
    """
    Verify that a given input is in a specified range.

    A bound that is :data:`None` is not checked, so passing neither bound makes this
    function a no-op.

    :param x: Variable we wish to verify lies in the specified range
    :param object_name: Name of ``x`` to display if limits are broken
    :param strict_inequalities: If :data:`True`, checks are applied using strict
        inequalities, otherwise they are not
    :param lower_bound: Lower limit placed on ``x``, or :data:`None`
    :param upper_bound: Upper limit placed on ``x``, or :data:`None`
    :raises ValueError: Raised if ``x`` does not fall between ``lower_bound`` and
        ``upper_bound``
    :raises TypeError: Raised if ``x`` cannot be compared to a value using ``>``,
        ``>=``, ``<`` or ``<=``
    """
    try:
        if strict_inequalities:
            if lower_bound is not None and not x > lower_bound:
                raise ValueError(f"{object_name} must be strictly above {lower_bound}.")
            if upper_bound is not None and not x < upper_bound:
                raise ValueError(f"{object_name} must be strictly below {upper_bound}.")
        else:
            if lower_bound is not None and not x >= lower_bound:
                raise ValueError(f"{object_name} must be {lower_bound} or above.")
            if upper_bound is not None and not x <= upper_bound:
                raise ValueError(f"{object_name} must be {upper_bound} or lower.")
    except TypeError as err:
        # Only a failed comparison can raise TypeError here; the ValueErrors raised
        # above pass straight through. Chain the original error so the underlying
        # cause of the failed comparison is not lost.
        operators = "< and >" if strict_inequalities else "<= and >="
        raise TypeError(
            f"{object_name} must have a valid comparison {operators} implemented."
        ) from err


def validate_is_instance(x: T, object_name: str, expected_type: type[T]) -> None:
    """
    Check that an object is an instance of an expected type.

    The check is performed with :func:`isinstance`, so instances of subclasses of
    ``expected_type`` are accepted.

    :param x: Object whose type we wish to check
    :param object_name: Name of ``x`` to report if the check fails
    :param expected_type: Type that ``x`` is required to be an instance of
    :raises TypeError: Raised if ``x`` is not an instance of ``expected_type``
    """
    # Guard clause: nothing to report when the type matches
    if isinstance(x, expected_type):
        return
    raise TypeError(f"{object_name} must be of type {expected_type}.")


def cast_as_type(x: Any, object_name: str, type_caster: Callable) -> Any:
    """
    Cast an object as a specified type.

    :param x: Variable to cast as specified type
    :param object_name: Name of the object being considered
    :param type_caster: Callable that ``x`` will be passed
    :return: ``x``, but cast as the type specified by ``type_caster``
    :raises TypeError: Raised if ``x`` cannot be cast using ``type_caster``, that is
        if ``type_caster`` raises a :class:`TypeError` or :class:`ValueError`
    """
    try:
        return type_caster(x)
    except (TypeError, ValueError) as err:
        # Prefer a ``message`` attribute when the exception provides one, otherwise
        # fall back to the exception's string form. Formatting through an f-string
        # means a non-string ``message`` cannot break construction of the error text.
        detail = getattr(err, "message", err)
        raise TypeError(
            f"{object_name} cannot be cast using {type_caster}. \n{detail}"
        ) from err
Loading