diff --git a/liminal/validation/__init__.py b/liminal/validation/__init__.py index 0ff61a7..c0f9a5f 100644 --- a/liminal/validation/__init__.py +++ b/liminal/validation/__init__.py @@ -1,6 +1,6 @@ import inspect from datetime import datetime -from functools import wraps +from functools import partial, wraps from typing import TYPE_CHECKING, Any, Callable from pydantic import BaseModel, ConfigDict @@ -101,57 +101,64 @@ def create_validation_report( ) -def liminal_validator( - validator_level: ValidationSeverity = ValidationSeverity.LOW, - validator_name: str | None = None, +def _liminal_decorator( + func: Callable[[type["BenchlingBaseModel"]], BenchlingValidatorReport | None], + validator_level: ValidationSeverity, + validator_name: str | None, ) -> Callable: - """A decorator that validates a function that takes a Benchling entity as an argument and returns None. - - Parameters: - validator_level: ValidationSeverity - The level of the validator. - validator_name: str | None - The name of the validator. Defaults to the pascalized name of the function. - """ - - def decorator( - func: Callable[[type["BenchlingBaseModel"]], BenchlingValidatorReport | None], - ) -> Callable: - """Decorator that validates a function that takes a Benchling entity as an argument and returns None.""" - sig = inspect.signature(func) - params = list(sig.parameters.values()) - if not params or params[0].name != "self" or len(params) > 1: - raise TypeError( - "Validator must defined in a schema class, where the only argument to this validator must be 'self'." - ) + """Core decorator logic for liminal_validator.""" + sig = inspect.signature(func) + params = list(sig.parameters.values()) + if not params or params[0].name != "self" or len(params) > 1: + raise TypeError( + "Validator must be defined in a schema class, where the only argument to this validator must be 'self'." + ) - nonlocal validator_name - if validator_name is None: - validator_name = pascalize(func.__name__) - - @wraps(func) - def wrapper(self: type["BenchlingBaseModel"]) -> BenchlingValidatorReport: - """Wrapper that runs the validator function and returns a BenchlingValidatorReport.""" - try: - ret_val = func(self) - if type(ret_val) is BenchlingValidatorReport: - return ret_val - except Exception as e: - return BenchlingValidatorReport.create_validation_report( - valid=False, - level=validator_level, - entity=self, - validator_name=validator_name, - message=str(e), - ) + if validator_name is None: + validator_name = pascalize(func.__name__) + + @wraps(func) + def wrapper(self: type["BenchlingBaseModel"]) -> BenchlingValidatorReport: + """Wrapper that runs the validator function and returns a BenchlingValidatorReport.""" + try: + ret_val = func(self) + if type(ret_val) is BenchlingValidatorReport: + return ret_val + except Exception as e: return BenchlingValidatorReport.create_validation_report( - valid=True, + valid=False, level=validator_level, entity=self, validator_name=validator_name, + message=str(e), ) + return BenchlingValidatorReport.create_validation_report( + valid=True, + level=validator_level, + entity=self, + validator_name=validator_name, + ) + + setattr(wrapper, "_is_liminal_validator", True) + return wrapper - setattr(wrapper, "_is_liminal_validator", True) - return wrapper - return decorator +def liminal_validator( + validator_level: ValidationSeverity = ValidationSeverity.LOW, + validator_name: str | None = None, +) -> Callable: + """A decorator for a function that validates a Benchling entity, defined on a schema class. + Wraps around any exceptions raised by the validator function, and returns a BenchlingValidatorReport. + + Parameters + ---------- + validator_level: ValidationSeverity + The severity level of the validation report. Defaults to ValidationSeverity.LOW. + validator_name: str | None = + The name of the validator. Defaults to the PascalCase version of the function name. + """ + return partial( + _liminal_decorator, + validator_level=validator_level, + validator_name=validator_name, + )