Skip to content

Commit

Permalink
feat: moved away from use of variable in descriptions
Browse files Browse the repository at this point in the history
Refs: 251
  • Loading branch information
pc532627 committed Dec 4, 2023
1 parent 422074e commit 3ddd1f7
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 104 deletions.
57 changes: 27 additions & 30 deletions coreax/approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,7 @@

import coreax.kernel as ck
from coreax.util import ClassFactory, KernelFunction
from coreax.validation import (
cast_variable_as_type,
validate_in_range,
validate_variable_is_instance,
)
from coreax.validation import cast_as_type, validate_in_range, validate_is_instance


class KernelMeanApproximator(ABC):
Expand All @@ -106,19 +102,22 @@ def __init__(
Define an approximator to the mean of the row sum of a kernel distance matrix.
"""
# Validate inputs of coreax defined classes
validate_variable_is_instance(kernel, "kernel", ck.Kernel)
validate_is_instance(kernel, "kernel", ck.Kernel)

# Validate inputs of non-coreax defined classes
random_key = cast_variable_as_type(
x=random_key, variable_name="random_key", type_caster=jnp.asarray
random_key = cast_as_type(
x=random_key, object_name="random_key", type_caster=jnp.asarray
)
num_kernel_points = cast_variable_as_type(
x=num_kernel_points, variable_name="num_kernel_points", type_caster=int
num_kernel_points = cast_as_type(
x=num_kernel_points, object_name="num_kernel_points", type_caster=int
)

# Validate inputs lie within accepted ranges
validate_in_range(
x=num_kernel_points, variable_name="num_kernel_points", lower_limit=0
x=num_kernel_points,
object_name="num_kernel_points",
strict_inequalities=True,
lower_bound=0,
)

# Assign inputs
Expand Down Expand Up @@ -164,13 +163,16 @@ def __init__(
Approximate kernel row mean by regression on points selected randomly.
"""
# Validate inputs of non-coreax defined classes
num_train_points = cast_variable_as_type(
x=num_train_points, variable_name="num_train_points", type_caster=int
num_train_points = cast_as_type(
x=num_train_points, object_name="num_train_points", type_caster=int
)

# Validate inputs lie within accepted ranges
validate_in_range(
x=num_train_points, variable_name="num_train_points", lower_limit=0
x=num_train_points,
object_name="num_train_points",
strict_inequalities=True,
lower_bound=0,
)

# Assign inputs
Expand All @@ -195,9 +197,7 @@ def approximate(
data points in the dataset
"""
# Validate inputs
data = cast_variable_as_type(
x=data, variable_name="data", type_caster=jnp.asarray
)
data = cast_as_type(x=data, object_name="data", type_caster=jnp.asarray)

# Record dataset size
num_data_points = len(data)
Expand Down Expand Up @@ -251,13 +251,16 @@ def __init__(
Approximate kernel row mean by regression on ANNchor selected points.
"""
# Validate inputs of non-coreax defined classes
num_train_points = cast_variable_as_type(
x=num_train_points, variable_name="num_train_points", type_caster=int
num_train_points = cast_as_type(
x=num_train_points, object_name="num_train_points", type_caster=int
)

# Validate inputs lie within accepted ranges
validate_in_range(
x=num_train_points, variable_name="num_train_points", lower_limit=0
x=num_train_points,
object_name="num_train_points",
strict_inequalities=True,
lower_bound=0,
)

# Assign inputs
Expand All @@ -282,9 +285,7 @@ def approximate(
data points in the dataset
"""
# Validate inputs
data = cast_variable_as_type(
x=data, variable_name="data", type_caster=jnp.asarray
)
data = cast_as_type(x=data, object_name="data", type_caster=jnp.asarray)

# Record dataset size
num_data_points = len(data)
Expand Down Expand Up @@ -353,9 +354,7 @@ def approximate(
data points in the dataset
"""
# Validate inputs
data = cast_variable_as_type(
x=data, variable_name="data", type_caster=jnp.asarray
)
data = cast_as_type(x=data, object_name="data", type_caster=jnp.asarray)

# Record dataset size
num_data_points = len(data)
Expand Down Expand Up @@ -391,10 +390,8 @@ def _anchor_body(
:return: Updated loop variables `features`
"""
# Validate inputs
features = cast_variable_as_type(
x=features, variable_name="features", type_caster=jnp.asarray
)
data = cast_variable_as_type(x=data, variable_name="data", type_caster=jnp.asarray)
features = cast_as_type(x=features, object_name="features", type_caster=jnp.asarray)
data = cast_as_type(x=data, object_name="data", type_caster=jnp.asarray)

max_entry = features.max(axis=1).argmin()
features = features.at[:, idx].set(kernel_function(data, data[max_entry])[:, 0])
Expand Down
65 changes: 42 additions & 23 deletions coreax/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,55 +26,75 @@
from collections.abc import Callable
from typing import Any, TypeVar

import numpy as np

T = TypeVar("T")


def validate_in_range(
x: T,
variable_name: str,
lower_limit: T = -np.inf,
upper_limit: T = np.inf,
object_name: str,
strict_inequalities: bool,
lower_bound: T | None = None,
upper_bound: T | None = None,
) -> None:
"""
Verify that a given input is in a specified range.
:param x: Variable we wish to verify lies in the specified range
:param variable_name: Name of ``x`` to display if limits are broken
:param lower_limit: Lower limit placed on ``x``
:param upper_limit: Upper limit placed on ``x``
:param object_name: Name of ``x`` to display if limits are broken
:param strict_inequalities: If true, checks are applied using strict inequalities,
otherwise they are not
:param lower_bound: Lower limit placed on ``x``, or :data:`None`
:param upper_bound: Upper limit placed on ``x``, or :data:`None`
:raises ValueError: Raised if ``x`` does not fall between ``lower_limit`` and
``upper_limit``
:raises TypeError: Raised if x cannot be compared to a value using >, >=, < or <=
"""
if not lower_limit < x < upper_limit:
raise ValueError(
f"{variable_name} must be between {lower_limit} and {upper_limit}. "
f"Given value {x}."
try:
if strict_inequalities:
if lower_bound is not None and not x > lower_bound:
raise ValueError(
f"{object_name} must be strictly above {lower_bound}. "
f"Given value {x}."
)
if upper_bound is not None and not x < upper_bound:
raise ValueError(
f"{object_name} must be strictly below {lower_bound}. "
f"Given value {x}."
)
else:
if lower_bound is not None and not x >= lower_bound:
raise ValueError(
f"{object_name} must be {lower_bound} or above. Given value {x}."
)
if upper_bound is not None and not x <= upper_bound:
raise ValueError(
f"{object_name} must be {lower_bound} or lower. Given value {x}."
)
except TypeError:
raise TypeError(
f"{object_name} must have a valid comparison <, <=, > and >= implemented."
)


def validate_variable_is_instance(
x: Any, variable_name: str, expected_type: Any
) -> None:
def validate_is_instance(x: T, object_name: str, expected_type: type[T]) -> None:
"""
Verify that a given variable is of a given type.
Verify that a given object is of a given type.
:param x: Variable we wish to verify lies in the specified range
:param variable_name: Name of ``x`` to display if limits are broken
:param object_name: Name of ``x`` to display if limits are broken
:param expected_type: The expected type of ``x``
:raises TypeError: Raised if ``x`` is not of type ``expected_type``
"""
if not isinstance(x, expected_type):
raise TypeError(f"{variable_name} must be of type {expected_type}.")
raise TypeError(f"{object_name} must be of type {expected_type}.")


def cast_variable_as_type(x: Any, variable_name: str, type_caster: Callable) -> Any:
def cast_as_type(x: Any, object_name: str, type_caster: Callable) -> Any:
"""
Cast a variable as a specified type.
Cast an object as a specified type.
:param x: Variable to cast as specified type
:param variable_name: Name of the variable being considered
:param object_name: Name of the object being considered
:param type_caster: Callable that ``x`` will be passed
:return: ``x``, but cast as the type specified by ``type_caster``
:raises TypeError: Raised if ``x`` cannot be cast using ``type_caster``
Expand All @@ -83,8 +103,7 @@ def cast_variable_as_type(x: Any, variable_name: str, type_caster: Callable) ->
return type_caster(x)
except Exception as e:
error_text = (
f"{variable_name} cannot be cast using {type_caster}. "
f"Given value {x}.\n"
f"{object_name} cannot be cast using {type_caster}. " f"Given value {x}.\n"
)
if hasattr(e, "message"):
error_text += e.message
Expand Down
Loading

0 comments on commit 3ddd1f7

Please sign in to comment.