Skip to content

Commit

Permalink
Only run custom checkers for decorated functions.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 714977033
  • Loading branch information
The kauldron Authors committed Jan 13, 2025
1 parent d504e22 commit b007732
Showing 1 changed file with 23 additions and 30 deletions.
53 changes: 23 additions & 30 deletions kauldron/typing/type_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from __future__ import annotations

import contextlib
import dataclasses
import functools
import inspect
Expand All @@ -26,7 +25,6 @@
import typing
from typing import Any, Type, TypedDict, Union

from etils import edc
from etils import enp
from etils import epy
import jaxtyping
Expand Down Expand Up @@ -133,30 +131,29 @@ def _reraise_with_shape_info(*args, _typecheck: bool = True, **kwargs):
# manually reproduce the functionality of typeguard.typechecked, so that
# we get access to the returnvalue of the function
localns = sys._getframe(1).f_locals # pylint: disable=protected-access
with _checker.activate():
memo = typeguard.CallMemo(python_func, localns, args=args, kwargs=kwargs)
retval = _undef
try:
typeguard.check_argument_types(memo)
retval = fn(*args, **kwargs)
typeguard.check_return_type(retval, memo)
return retval
except typeguard.TypeCheckError as e:
# Use function signature to construct a complete list of named arguments
sig = inspect.signature(fn)
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()

annotations = {k: p.annotation for k, p in sig.parameters.items()}
# TODO(klausg): filter the stacktrace to exclude all the typechecking
raise TypeCheckError(
str(e),
arguments=bound_args.arguments,
return_value=retval,
annotations=annotations,
return_annotation=sig.return_annotation,
memo=shape_spec.Memo.from_current_context(),
) from e
memo = typeguard.CallMemo(python_func, localns, args=args, kwargs=kwargs)
retval = _undef
try:
typeguard.check_argument_types(memo)
retval = fn(*args, **kwargs)
typeguard.check_return_type(retval, memo)
return retval
except typeguard.TypeCheckError as e:
# Use function signature to construct a complete list of named arguments
sig = inspect.signature(fn)
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()

annotations = {k: p.annotation for k, p in sig.parameters.items()}
# TODO(klausg): filter the stacktrace to exclude all the typechecking
raise TypeCheckError(
str(e),
arguments=bound_args.arguments,
return_value=retval,
annotations=annotations,
return_annotation=sig.return_annotation,
memo=shape_spec.Memo.from_current_context(),
) from e

return _reraise_with_shape_info

Expand Down Expand Up @@ -364,8 +361,6 @@ def _array_spec_checker_lookup(
) -> typeguard.TypeCheckerCallable | None:
"""Lookup function to register custom array type checkers in typeguard."""
del extras
if not _checker.is_active:
return None
if origin_type in [Union, types.UnionType]:
# TODO(klausg): handle Union of ArrayType with other types
if all(_is_array_type(arg) for arg in args):
Expand Down Expand Up @@ -415,8 +410,6 @@ def _dataclass_checker_lookup(
) -> typeguard.TypeCheckerCallable | None:
"""Lookup function to register custom dataclass checkers in typeguard."""
del args, extras
if not _checker.is_active:
return None
if dataclasses.is_dataclass(origin_type):
return _custom_dataclass_checker
return None
Expand Down

0 comments on commit b007732

Please sign in to comment.