Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve type support for .get_state #4623

Merged
merged 2 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
LockExpiredError,
ReflexRuntimeError,
SetUndefinedStateVarError,
StateMismatchError,
StateSchemaMismatchError,
StateSerializationError,
StateTooLargeError,
Expand Down Expand Up @@ -1543,19 +1544,27 @@ async def _populate_parent_states(self, target_state_cls: Type[BaseState]):
# Return the direct parent of target_state_cls for subsequent linking.
return parent_state

def _get_state_from_cache(self, state_cls: Type[BaseState]) -> BaseState:
def _get_state_from_cache(self, state_cls: Type[T_STATE]) -> T_STATE:
"""Get a state instance from the cache.

Args:
state_cls: The class of the state.

Returns:
The instance of state_cls associated with this state's client_token.

Raises:
StateMismatchError: If the state instance is not of the expected type.
"""
root_state = self._get_root_state()
return root_state.get_substate(state_cls.get_full_name().split("."))
substate = root_state.get_substate(state_cls.get_full_name().split("."))
if not isinstance(substate, state_cls):
raise StateMismatchError(
f"Searched for state {state_cls.get_full_name()} but found {substate}."
)
return substate

async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState:
async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE:
"""Get a state instance from redis.

Args:
Expand All @@ -1566,6 +1575,7 @@ async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState:

Raises:
RuntimeError: If redis is not used in this backend process.
StateMismatchError: If the state instance is not of the expected type.
"""
# Fetch all missing parent states from redis.
parent_state_of_state_cls = await self._populate_parent_states(state_cls)
Expand All @@ -1577,14 +1587,22 @@ async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState:
f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
"(All states should already be available -- this is likely a bug).",
)
return await state_manager.get_state(

state_in_redis = await state_manager.get_state(
token=_substate_key(self.router.session.client_token, state_cls),
top_level=False,
get_substates=True,
parent_state=parent_state_of_state_cls,
)

async def get_state(self, state_cls: Type[BaseState]) -> BaseState:
if not isinstance(state_in_redis, state_cls):
raise StateMismatchError(
f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}."
)

return state_in_redis

async def get_state(self, state_cls: Type[T_STATE]) -> T_STATE:
"""Get an instance of the state associated with this token.

Allows for arbitrary access to sibling states from within an event handler.
Expand Down Expand Up @@ -2316,6 +2334,9 @@ def _deserialize(
return state


T_STATE = TypeVar("T_STATE", bound=BaseState)


class State(BaseState):
"""The app Base State."""

Expand Down
4 changes: 4 additions & 0 deletions reflex/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,10 @@ class StateSerializationError(ReflexError):
"""Raised when the state cannot be serialized."""


class StateMismatchError(ReflexError, ValueError):
"""Raised when the state retrieved does not match the expected state."""


class SystemPackageMissingError(ReflexError):
"""Raised when a system package is missing."""

Expand Down
Loading