Skip to content

Commit

Permalink
Merge pull request #260 from gchq/feature/input_validation_251
Browse files Browse the repository at this point in the history
feat: add input validation for non-private methods of approximation.py
  • Loading branch information
tp832944 authored Dec 6, 2023
2 parents 12cad63 + daf49c2 commit 6d4d34f
Show file tree
Hide file tree
Showing 4 changed files with 857 additions and 11 deletions.
80 changes: 69 additions & 11 deletions coreax/approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@

import coreax.kernel as ck
from coreax.util import ClassFactory, KernelFunction
from coreax.validation import cast_as_type, validate_in_range, validate_is_instance


class KernelMeanApproximator(ABC):
Expand All @@ -100,6 +101,26 @@ def __init__(
"""
Define an approximator to the mean of the row sum of a kernel distance matrix.
"""
# Validate inputs of coreax defined classes
validate_is_instance(kernel, "kernel", ck.Kernel)

# Validate inputs of non-coreax defined classes
random_key = cast_as_type(
x=random_key, object_name="random_key", type_caster=jnp.asarray
)
num_kernel_points = cast_as_type(
x=num_kernel_points, object_name="num_kernel_points", type_caster=int
)

# Validate inputs lie within accepted ranges
validate_in_range(
x=num_kernel_points,
object_name="num_kernel_points",
strict_inequalities=True,
lower_bound=0,
)

# Assign inputs
self.kernel = kernel
self.random_key = random_key
self.num_kernel_points = num_kernel_points
Expand Down Expand Up @@ -141,6 +162,20 @@ def __init__(
"""
Approximate kernel row mean by regression on points selected randomly.
"""
# Validate inputs of non-coreax defined classes
num_train_points = cast_as_type(
x=num_train_points, object_name="num_train_points", type_caster=int
)

# Validate inputs lie within accepted ranges
validate_in_range(
x=num_train_points,
object_name="num_train_points",
strict_inequalities=True,
lower_bound=0,
)

# Assign inputs
self.num_train_points = num_train_points

# Initialise parent
Expand All @@ -161,8 +196,10 @@ def approximate(
:return: Approximation of the kernel matrix row sum divided by the number of
data points in the dataset
"""
# Ensure data is the expected type
data = jnp.asarray(data)
# Validate inputs
data = cast_as_type(x=data, object_name="data", type_caster=jnp.atleast_2d)

# Record dataset size
num_data_points = len(data)

# Randomly select points for kernel regression
Expand Down Expand Up @@ -213,6 +250,20 @@ def __init__(
"""
Approximate kernel row mean by regression on ANNchor selected points.
"""
# Validate inputs of non-coreax defined classes
num_train_points = cast_as_type(
x=num_train_points, object_name="num_train_points", type_caster=int
)

# Validate inputs lie within accepted ranges
validate_in_range(
x=num_train_points,
object_name="num_train_points",
strict_inequalities=True,
lower_bound=0,
)

# Assign inputs
self.num_train_points = num_train_points

# Initialise parent
Expand All @@ -233,15 +284,17 @@ def approximate(
:return: Approximation of the kernel matrix row sum divided by the number of
data points in the dataset
"""
# Ensure data is the expected type
data = jnp.asarray(data)
# Validate inputs
data = cast_as_type(x=data, object_name="data", type_caster=jnp.atleast_2d)

# Record dataset size
num_data_points = len(data)

# Select point for kernel regression using ANNchor construction
features = jnp.zeros((num_data_points, self.num_kernel_points))

features = features.at[:, 0].set(self.kernel.compute(data, data[0])[:, 0])
body = partial(anchor_body, data=data, kernel_function=self.kernel.compute)
body = partial(_anchor_body, data=data, kernel_function=self.kernel.compute)
features = lax.fori_loop(1, self.num_kernel_points, body, features)

train_idx = random.choice(
Expand Down Expand Up @@ -304,8 +357,10 @@ def approximate(
:return: Approximation of the kernel matrix row sum divided by the number of
data points in the dataset
"""
# Ensure data is the expected type
data = jnp.asarray(data)
# Validate inputs
data = cast_as_type(x=data, object_name="data", type_caster=jnp.atleast_2d)

# Record dataset size
num_data_points = len(data)

# Randomly select points for kernel regression
Expand All @@ -322,7 +377,7 @@ def approximate(


@partial(jit, static_argnames=["kernel_function"])
def anchor_body(
def _anchor_body(
idx: int,
features: ArrayLike,
data: ArrayLike,
Expand All @@ -334,12 +389,15 @@ def anchor_body(
:param idx: Loop counter
:param features: Loop updateables
:param data: Original :math:`n \times d` dataset
:param kernel_function: Vectorised kernel function on pairs `(X,x)`:
:param kernel_function: Vectorised kernel function on pairs ``(X,x)``:
:math:`k: \mathbb{R}^{n \times d} \times \mathbb{R}^d \rightarrow \mathbb{R}^n`
:return: Updated loop variables `features`
"""
features = jnp.asarray(features)
data = jnp.asarray(data)
# Validate inputs
features = cast_as_type(
x=features, object_name="features", type_caster=jnp.atleast_2d
)
data = cast_as_type(x=data, object_name="data", type_caster=jnp.atleast_2d)

max_entry = features.max(axis=1).argmin()
features = features.at[:, idx].set(kernel_function(data, data[max_entry])[:, 0])
Expand Down
106 changes: 106 additions & 0 deletions coreax/validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# © Crown Copyright GCHQ
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Functionality to validate data passed throughout coreax.
The functions within this module are intended to be used as a means to validate inputs
passed to classes, functions and methods throughout the coreax codebase.
"""

# Support annotations with | in Python < 3.10
# TODO: Remove once no longer supporting old code
from __future__ import annotations

from collections.abc import Callable
from typing import Any, TypeVar

T = TypeVar("T")


def validate_in_range(
x: T,
object_name: str,
strict_inequalities: bool,
lower_bound: T | None = None,
upper_bound: T | None = None,
) -> None:
"""
Verify that a given input is in a specified range.
:param x: Variable we wish to verify lies in the specified range
:param object_name: Name of ``x`` to display if limits are broken
:param strict_inequalities: If :data:`True`, checks are applied using strict
inequalities, otherwise they are not
:param lower_bound: Lower limit placed on ``x``, or :data:`None`
:param upper_bound: Upper limit placed on ``x``, or :data:`None`
:raises ValueError: Raised if ``x`` does not fall between ``lower_limit`` and
``upper_limit``
:raises TypeError: Raised if x cannot be compared to a value using ``>``, ``>=``,
``<`` or ``<=``
"""
try:
if strict_inequalities:
if lower_bound is not None and not x > lower_bound:
raise ValueError(f"{object_name} must be strictly above {lower_bound}.")
if upper_bound is not None and not x < upper_bound:
raise ValueError(f"{object_name} must be strictly below {upper_bound}.")
else:
if lower_bound is not None and not x >= lower_bound:
raise ValueError(f"{object_name} must be {lower_bound} or above.")
if upper_bound is not None and not x <= upper_bound:
raise ValueError(f"{object_name} must be {upper_bound} or lower.")
except TypeError:
if strict_inequalities:
raise TypeError(
f"{object_name} must have a valid comparison < and > implemented."
)
else:
raise TypeError(
f"{object_name} must have a valid comparison <= and >= implemented."
)


def validate_is_instance(x: T, object_name: str, expected_type: type[T]) -> None:
"""
Verify that a given object is of a given type.
:param x: Variable we wish to verify lies in the specified range
:param object_name: Name of ``x`` to display if it is not of type ``expected_type``
:param expected_type: The expected type of ``x``
:raises TypeError: Raised if ``x`` is not of type ``expected_type``
"""
if not isinstance(x, expected_type):
raise TypeError(f"{object_name} must be of type {expected_type}.")


def cast_as_type(x: Any, object_name: str, type_caster: Callable) -> Any:
"""
Cast an object as a specified type.
:param x: Variable to cast as specified type
:param object_name: Name of the object being considered
:param type_caster: Callable that ``x`` will be passed
:return: ``x``, but cast as the type specified by ``type_caster``
:raises TypeError: Raised if ``x`` cannot be cast using ``type_caster``
"""
try:
return type_caster(x)
except (TypeError, ValueError) as e:
error_text = f"{object_name} cannot be cast using {type_caster}. \n"
if hasattr(e, "message"):
error_text += e.message
else:
error_text += str(e)
raise TypeError(error_text)
Loading

0 comments on commit 6d4d34f

Please sign in to comment.