diff --git a/pyproject.toml b/pyproject.toml index df93866..3dbf2bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "runtime-keypath" -version = "0.1.5" +version = "0.1.6" authors = [{ name = "Chris Fu", email = "17433201@qq.com" }] description = "Supports runtime key-path recording/accessing for Python." classifiers = [ diff --git a/runtime_keypath/_core.py b/runtime_keypath/_core.py index 6fe52f4..cb34e8a 100644 --- a/runtime_keypath/_core.py +++ b/runtime_keypath/_core.py @@ -3,6 +3,7 @@ __all__ = [ "KeyPath", "KeyPathSupporting", + "key_path_supporting", ] import threading @@ -21,69 +22,50 @@ override, ) +_T = TypeVar("_T") + _Value_co = TypeVar("_Value_co", covariant=True) _Value_0 = TypeVar("_Value_0") _MISSING = cast("Any", object()) -class KeyPathSupporting: - """ - A base class that indicates an object can be used as a chain in - `KeyPath.of(...)` call. - """ - - # ! This method is intentially not named as `__getattribute__`. See below for - # ! reason. - def _(self, key: str, /) -> Any: - try: - recorder = _thread_local.recorder - except AttributeError: - # There is no recorder, which means that `KeyPath.of` is not being called. - # So we don't need to record this key. - return super().__getattribute__(key) - - if recorder.busy: - # The recorder is busy, which means that another member is being accessed, - # typically because the computation of that member is dependent on this one. - # So we don't need to record this key. - return super().__getattribute__(key) - - recorder.busy = True +@final +class _KeyPathRecorder: + __slots__ = ("busy", "start", "end", "key_list") - if recorder.start is not _MISSING and recorder.end is not self: - raise RuntimeError( - " ".join( - [ - "Key-path is broken. Check if there is something that does NOT", - "support key-paths in the member chain.", - ] - ) - ) + busy: bool + start: Any + end: Any + key_list: list[str] - value = super().__getattribute__(key) + def __init__(self, /) -> None: + self.busy = False + self.start = _MISSING + self.end = _MISSING + self.key_list = [] - recorder.busy = False - if recorder.start is _MISSING: - recorder.start = self - recorder.end = value - recorder.key_list.append(key) - return value +class _ThreadLocalProtocol(Protocol): + recorder: _KeyPathRecorder + """ + The active key-path recorder for this thread. May not exist. + """ - # ! `__getattribute__(...)` is declared against `TYPE_CHECKING`, so that unknown - # ! attributes on conforming classes won't be regarded as known by type-checkers. - if not TYPE_CHECKING: - __getattribute__ = _ - del _ +_thread_local = cast("_ThreadLocalProtocol", threading.local()) -# ! A metaclass is made for class `KeyPath`, and `KeyPath.of` is provided as a property -# ! on class `KeyPath`, so that whenever `KeyPath.of` gets accessed, we can do something -# ! before it actually gets called. @final class _KeyPathMeta(type): + """ + The metaclass for class `KeyPath`. + + It exists mainly to provide `KeyPath.of` as a property. + """ + + # ! `of` is provided as a property here, so that whenever `KeyPath.of` gets + # ! accessed, we can do something before it actually gets called. @property def of(self, /) -> _KeyPathOfFunction: # ! Docstring here is for Pylance hint. @@ -116,19 +98,16 @@ def of(self, /) -> _KeyPathOfFunction: >>> class A(KeyPathSupporting): ... def __init__(self) -> None: ... self.b = B() - ... def __repr__(self) -> str: - ... return "a" - - >>> class B(KeyPathSupporting): + >>> @key_path_supporting + ... class B: ... def __init__(self) -> None: ... self.c = C() - >>> class C: ... pass - >>> a = A() - >>> KeyPath.of(a.b.c) - KeyPath(target=a, keys=('b', 'c')) + >>> key_path = KeyPath.of(a.b.c) + >>> assert key_path.base is a + >>> assert key_path.keys == ("b", "c") """ try: @@ -152,7 +131,74 @@ def of(self, /) -> _KeyPathOfFunction: return func -# ! We implement the result of `KeyPath.of` as a stand-alone class, so that when an +@final +class KeyPath(Generic[_Value_co], metaclass=_KeyPathMeta): + """ + An object that stands for a member chain from a base object. + """ + + __base: Final[Any] + __keys: Final[Sequence[str]] + + def __init__(self, /, target: Any, keys: str | Sequence[str]) -> None: + self.__base = target + + if isinstance(keys, str): + keys = tuple(keys.split(".")) + else: + keys = tuple(keys) + self.__keys = keys + + @property + def base(self, /) -> Any: + return self.__base + + @property + def keys(self, /) -> Sequence[str]: + return self.__keys + + def get(self, /) -> _Value_co: + value = self.__base + for key in self.__keys: + value = getattr(value, key) + return value + + def unsafe_set(self: KeyPath[_Value_0], value: _Value_0, /) -> None: + target = self.__base + keys = self.__keys + i_last_key = len(keys) - 1 + for i in range(i_last_key): + target = getattr(target, keys[i]) + setattr(target, keys[i_last_key], value) + + @deprecated("`KeyPath.set` is deprecated. Use `KeyPath.unsafe_set` instead.") + def set(self: KeyPath[_Value_0], value: _Value_0, /) -> None: + return self.unsafe_set(value) + + @override + def __hash__(self, /) -> int: + return hash((self.base, self.keys)) + + @override + def __eq__(self, other: object, /) -> bool: + return ( + isinstance(other, KeyPath) + and self.base is other.base + and self.keys == other.keys + ) + + @override + def __repr__(self, /) -> str: + type_name = type(self).__name__ + base = self.base + keys = self.keys + return f"{type_name}({base=!r}, {keys=!r})" + + def __call__(self, /) -> _Value_co: + return self.get() + + +# ! We implement the result of `KeyPath.of` as a callable object, so that when an # ! exception occurred during the key-path access, there would still be a chance to # ! perform some finalization. class _KeyPathOfFunction: @@ -186,19 +232,15 @@ class _KeyPathOfFunction: >>> class A(KeyPathSupporting): ... def __init__(self) -> None: ... self.b = B() - ... def __repr__(self) -> str: - ... return "a" - >>> class B(KeyPathSupporting): ... def __init__(self) -> None: ... self.c = C() - >>> class C: ... pass - >>> a = A() - >>> KeyPath.of(a.b.c) - KeyPath(target=a, keys=('b', 'c')) + >>> key_path = KeyPath.of(a.b.c) + >>> assert key_path.base is a + >>> assert key_path.keys == ("b", "c") """ __invoked: bool = False @@ -252,91 +294,119 @@ def __del__(self, /) -> None: del _thread_local.recorder -@final -class KeyPath(Generic[_Value_co], metaclass=_KeyPathMeta): +class KeyPathSupporting: """ - An object that stands for a member chain from a base object. + A base class that supporting key-paths. + + Examples + -------- + >>> class C(KeyPathSupporting): + ... v = 0 + >>> c = C() + >>> key_path = KeyPath.of(c.v) + >>> assert key_path.base is c + >>> assert key_path.keys == ("v",) """ - __target: Final[Any] - __keys: Final[Sequence[str]] + # ! This method is intentially not named as `__getattribute__`. See below for + # ! reason. + def _(self, key: str, /) -> Any: + try: + recorder = _thread_local.recorder + except AttributeError: + # There is no recorder, which means that `KeyPath.of` is not being called. + # So we don't need to record this key. + return super().__getattribute__(key) - def __init__(self, /, target: Any, keys: str | Sequence[str]) -> None: - self.__target = target + if recorder.busy: + # The recorder is busy, which means that another member is being accessed, + # typically because the computation of that member is dependent on this one. + # So we don't need to record this key. + return super().__getattribute__(key) - if isinstance(keys, str): - keys = tuple(keys.split(".")) - else: - keys = tuple(keys) - self.__keys = keys + recorder.busy = True - @property - def target(self, /) -> Any: - return self.__target + if recorder.start is not _MISSING and recorder.end is not self: + raise RuntimeError( + " ".join( + [ + "Key-path is broken. Check if there is something that does NOT", + "support key-paths in the member chain.", + ] + ) + ) - @property - def keys(self, /) -> Sequence[str]: - return self.__keys + value = super().__getattribute__(key) + + recorder.busy = False + if recorder.start is _MISSING: + recorder.start = self + recorder.end = value + recorder.key_list.append(key) - def get(self, /) -> _Value_co: - value = self.__target - for key in self.__keys: - value = getattr(value, key) return value - def unsafe_set(self: KeyPath[_Value_0], value: _Value_0, /) -> None: - target = self.__target - keys = self.__keys - i_last_key = len(keys) - 1 - for i in range(i_last_key): - target = getattr(target, keys[i]) - setattr(target, keys[i_last_key], value) + # ! `__getattribute__(...)` is declared against `TYPE_CHECKING`, so that unknown + # ! attributes on conforming classes won't be treated as known by type-checkers. + if not TYPE_CHECKING: + __getattribute__ = _ - @deprecated("`KeyPath.set` is deprecated. Use `KeyPath.unsafe_set` instead.") - def set(self: KeyPath[_Value_0], value: _Value_0, /) -> None: - return self.unsafe_set(value) + del _ - @override - def __hash__(self, /) -> int: - return hash((self.target, self.keys)) - @override - def __eq__(self, other: object, /) -> bool: - return ( - isinstance(other, KeyPath) - and self.target is other.target - and self.keys == other.keys - ) +def key_path_supporting(clazz: type[_T], /) -> type[_T]: + """ + Patch on a class so that it can support key-paths. + + Examples + -------- + >>> @key_path_supporting + ... class C: + ... v = 0 + >>> c = C() + >>> key_path = KeyPath.of(c.v) + >>> assert key_path.base is c + >>> assert key_path.keys == ("v",) + """ - @override - def __repr__(self, /) -> str: - return f"{KeyPath.__name__}(target={self.target!r}, keys={self.keys!r})" + old_getattribute = clazz.__getattribute__ - def __call__(self, /) -> _Value_co: - return self.get() + def __getattribute__(self: _T, key: str) -> Any: + try: + recorder = _thread_local.recorder + except AttributeError: + # There is no recorder, which means that `KeyPath.of` is not being called. + # So we don't need to record this key. + return old_getattribute(self, key) + if recorder.busy: + # The recorder is busy, which means that another member is being accessed, + # typically because the computation of that member is dependent on this one. + # So we don't need to record this key. + return old_getattribute(self, key) -class _ThreadLocalProtocol(Protocol): - recorder: _KeyPathRecorder - """ - The active key-path recorder for this thread. May not exist. - """ + recorder.busy = True + if recorder.start is not _MISSING and recorder.end is not self: + raise RuntimeError( + " ".join( + [ + "Key-path is broken. Check if there is something that does NOT", + "support key-paths in the member chain.", + ] + ) + ) -_thread_local = cast("_ThreadLocalProtocol", threading.local()) + value = old_getattribute(self, key) + recorder.busy = False + if recorder.start is _MISSING: + recorder.start = self + recorder.end = value + recorder.key_list.append(key) -@final -class _KeyPathRecorder: - __slots__ = ("busy", "start", "end", "key_list") + return value - busy: bool - start: Any - end: Any - key_list: list[str] + clazz.__getattribute__ = __getattribute__ - def __init__(self, /) -> None: - self.busy = False - self.start = _MISSING - self.end = _MISSING - self.key_list = [] + return clazz diff --git a/runtime_keypath/_core_test.py b/runtime_keypath/_core_test.py index fb65b01..66ab214 100644 --- a/runtime_keypath/_core_test.py +++ b/runtime_keypath/_core_test.py @@ -9,7 +9,7 @@ from ._core import * -class Tests: +class Tests__KeyPathSupporting: @staticmethod def test__normal() -> None: class A(KeyPathSupporting): @@ -236,3 +236,249 @@ class C(KeyPathSupporting): key_path_2 = KeyPath.of(a.b.c.v) key_path_2.unsafe_set(12345) assert a.b.c.v == 12345 + + +class Tests__key_path_supporting: + @staticmethod + def test__normal() -> None: + @key_path_supporting + class A: + b: B + + def __init__(self) -> None: + self.b = B() + + @key_path_supporting + class B: + c: int + + def __init__(self) -> None: + self.c = 0 + + a = A() + key_path = KeyPath.of(a.b.c) + assert key_path == KeyPath(target=a, keys=("b", "c")) + assert key_path() == 0 + + a.b.c = 1 + assert key_path() == 1 + + @staticmethod + def test__cycle_reference() -> None: + @key_path_supporting + class A: + a: A + b: B + + def __init__(self) -> None: + self.a = self + self.b = B() + + @key_path_supporting + class B: + b: B + c: C + + def __init__(self) -> None: + self.b = self + self.c = C() + + class C: + pass + + a = A() + assert KeyPath.of(a.a.b.b.c) == KeyPath(target=a, keys=("a", "b", "b", "c")) + + @staticmethod + def test__common_mistakes() -> None: + @key_path_supporting + class A: + b: B + + def __init__(self) -> None: + self.b = B() + + @key_path_supporting + class B(KeyPathSupporting): + c: C + + def __init__(self) -> None: + self.c = C() + + class C: + pass + + a = A() + + with pytest.raises(Exception): + # Not even accessed a single member. + _ = KeyPath.of(a) + + with pytest.raises(Exception): + # Using something that is not a member chain. + _ = KeyPath.of(id(a.b.c)) + + with pytest.raises(Exception): + # Calling the same `KeyPath.of` more than once. + of = KeyPath.of + _ = of(a.b.c) + _ = of(a.b.c) + + @staticmethod + def test__error_handling() -> None: + @key_path_supporting + class A: + b: B + + def __init__(self) -> None: + self.b = B() + + @key_path_supporting + class B(KeyPathSupporting): + c: C + + def __init__(self) -> None: + self.c = C() + + class C: + pass + + a = A() + + with pytest.raises(AttributeError): + # Accessing something that doesn't exist. + _ = KeyPath.of(a.b.c.d) # type: ignore + + # With above exception caught, normal code should run correctly. + key_path = KeyPath.of(a.b.c) + assert key_path == KeyPath(target=a, keys=("b", "c")) + + @staticmethod + def test__threading() -> None: + @key_path_supporting + class A: + b: B + + def __init__(self) -> None: + self.b = B() + + @key_path_supporting + class B: + c: C + + def __init__(self) -> None: + self.c = C() + + class C: + pass + + a = A() + key_path_list: list[KeyPath] = [] + + def f() -> None: + # Sleeping for a short while so that the influence of starting a thread + # could be minimal. + time.sleep(1) + + key_path = KeyPath.of(a.b.c) + key_path_list.append(key_path) + + threads = [Thread(target=f) for _ in range(1000)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + assert len(key_path_list) == 1000 + assert all( + key_path == KeyPath(target=a, keys=("b", "c")) for key_path in key_path_list + ) + + @staticmethod + def test__internal_reference() -> None: + @key_path_supporting + class C: + @property + def v0(self) -> int: + return self.v1.v2 + + @property + def v1(self) -> C: + return self + + @property + def v2(self) -> int: + return 0 + + c = C() + assert KeyPath.of(c.v0) == KeyPath(target=c, keys=("v0",)) + + @staticmethod + def test__get() -> None: + MISSING = cast(Any, object()) + + @key_path_supporting + class A: + b: B = MISSING + + @key_path_supporting + class B: + c: C = MISSING + + @key_path_supporting + class C: + v: int = MISSING + + a = A() + b = B() + c = C() + + key_path_0 = KeyPath.of(a.b) + assert key_path_0.get() is MISSING + a.b = b + assert key_path_0.get() is b + + key_path_1 = KeyPath.of(a.b.c) + assert key_path_1.get() is MISSING + a.b.c = c + assert key_path_1.get() is c + + key_path_2 = KeyPath.of(a.b.c.v) + assert key_path_2.get() is MISSING + a.b.c.v = 12345 + assert key_path_2.get() == 12345 + + @staticmethod + def test__unsafe_set() -> None: + MISSING = cast(Any, object()) + + @key_path_supporting + class A: + b: B = MISSING + + @key_path_supporting + class B: + c: C = MISSING + + @key_path_supporting + class C: + v: int = MISSING + + a = A() + b = B() + c = C() + + assert a.b is MISSING + key_path_0 = KeyPath.of(a.b) + key_path_0.unsafe_set(b) + assert a.b is b + + assert a.b.c is MISSING + key_path_1 = KeyPath.of(a.b.c) + key_path_1.unsafe_set(c) + assert a.b.c is c + + assert a.b.c.v is MISSING + key_path_2 = KeyPath.of(a.b.c.v) + key_path_2.unsafe_set(12345) + assert a.b.c.v == 12345