diff --git a/pyproject.toml b/pyproject.toml index b729862..c816707 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "runtime-keypath" -version = "0.1.2" +version = "0.1.3" authors = [{ name = "Chris Fu", email = "17433201@qq.com" }] description = "Supports runtime key-path recording/accessing for Python." classifiers = [ diff --git a/runtime_keypath/_core.py b/runtime_keypath/_core.py index 2e182f1..609f91b 100644 --- a/runtime_keypath/_core.py +++ b/runtime_keypath/_core.py @@ -26,7 +26,6 @@ from typing import NamedTuple _Value_t = TypeVar("_Value_t") -_Value_co = TypeVar("_Value_co", covariant=True) class _ThreadLocalProtocol(Protocol): @@ -225,11 +224,11 @@ def __del__(self, /) -> None: @final -class KeyPath(Generic[_Value_co], metaclass=_KeyPathMeta): - __target: Final[object] +class KeyPath(Generic[_Value_t], metaclass=_KeyPathMeta): + __target: Final[Any] __keys: Final[Sequence[str]] - def __init__(self, /, target: object, keys: str | Sequence[str]) -> None: + def __init__(self, /, target: Any, keys: str | Sequence[str]) -> None: self.__target = target if isinstance(keys, str): @@ -239,13 +238,27 @@ def __init__(self, /, target: object, keys: str | Sequence[str]) -> None: self.__keys = keys @property - def target(self, /) -> object: + def target(self, /) -> Any: return self.__target @property def keys(self, /) -> Sequence[str]: return self.__keys + def get(self, /) -> _Value_t: + value = self.__target + for key in self.__keys: + value = getattr(value, key) + return value + + def set(self, value: _Value_t, /) -> None: + target = self.__target + keys = self.__keys + n = len(keys) - 1 + for i in range(n): + target = getattr(target, keys[i]) + setattr(target, keys[n], value) + def __hash__(self, /) -> int: return hash((self.target, self.keys)) @@ -259,11 +272,8 @@ def __eq__(self, other: object, /) -> bool: def __repr__(self, /) -> str: return f"{KeyPath.__name__}(target={self.target!r}, keys={self.keys!r})" - def __call__(self, /) -> _Value_co: - value = self.__target - for key in self.__keys: - value = getattr(value, key) - return cast("_Value_co", value) + def __call__(self, /) -> _Value_t: + return self.get() class KeyPathSupporting: diff --git a/runtime_keypath/_core_test.py b/runtime_keypath/_core_test.py index a4c3041..13bee5b 100644 --- a/runtime_keypath/_core_test.py +++ b/runtime_keypath/_core_test.py @@ -169,3 +169,36 @@ def v2(self) -> int: c = C() assert KeyPath.of(c.v0) == KeyPath(target=c, keys=("v0",)) + + +def test_get_set() -> None: + class A(KeyPathSupporting): + b: B | None = None + + class B(KeyPathSupporting): + c: C | None = None + + class C(KeyPathSupporting): + v: int | None = None + + a = A() + b = B() + c = C() + + key_path_0 = KeyPath.of(a.b) + assert key_path_0.get() is None + key_path_0.set(b) + assert a.b is b + assert key_path_0.get() is b + + key_path_1 = KeyPath.of(a.b.c) # type: ignore + assert key_path_1.get() is None + key_path_1.set(c) + assert a.b.c is c # type: ignore + assert key_path_1.get() is c + + key_path_2 = KeyPath.of(a.b.c.v) # type: ignore + assert key_path_2.get() is None + key_path_2.set(12345) + assert a.b.c.v == 12345 # type: ignore + assert key_path_2.get() == 12345