diff --git a/burr/core/application.py b/burr/core/application.py index 9cbbcb1b..bfdd4999 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -543,8 +543,6 @@ class ApplicationContext(AbstractContextManager, ApplicationIdentifiers): """Application context. This is anything your node might need to know about the application. Often used for recursive tracking. - Note this is also a context manager (allowing you to pass context to sub-applications). - To access this object in a running application, you can use the `__context` variable in the action signature: @@ -556,6 +554,7 @@ class ApplicationContext(AbstractContextManager, ApplicationIdentifiers): def my_action(state: State, __context: ApplicationContext) -> State: app_id = __context.app_id partition_key = __context.partition_key + current_action_name = __context.action_name # Access the current action name ... """ @@ -564,6 +563,7 @@ def my_action(state: State, __context: ApplicationContext) -> State: parallel_executor_factory: Callable[[], Executor] state_initializer: Optional[BaseStateLoader] state_persister: Optional[BaseStateSaver] + action_name: Optional[str] # Store just the action name @staticmethod def get() -> Optional["ApplicationContext"]: @@ -865,6 +865,7 @@ def _context_factory(self, action: Action, sequence_id: int) -> ApplicationConte parallel_executor_factory=self._parallel_executor_factory, state_initializer=self._state_initializer, state_persister=self._state_persister, + action_name=action.name if action else None, # Pass just the action name ) def _step( diff --git a/docs/concepts/actions.rst b/docs/concepts/actions.rst index d5aaff0d..f8657e9e 100644 --- a/docs/concepts/actions.rst +++ b/docs/concepts/actions.rst @@ -110,7 +110,7 @@ Will require the inputs to be passed in at runtime. See below for how to do that This means that the application does not *need* the inputs to be set. -Note: to access ``app_id`` and ``partition_key`` in your running application, you can have the :py:class:`ApplicationContext ` +Note: to access application-level metadata such as ``app_id``, ``partition_key``, ``sequence_id``, and ``action_name`` in your running application, you can have the :py:class:`ApplicationContext ` injected into your Burr Actions. This is done by adding ``__context`` to the action signature: .. code-block:: python @@ -121,6 +121,8 @@ injected into your Burr Actions. This is done by adding ``__context`` to the act def my_action(state: State, __context: ApplicationContext) -> State: app_id = __context.app_id partition_key = __context.partition_key + action_name = __context.action_name + sequence_id = __context.sequence_id ... diff --git a/docs/concepts/state-persistence.rst b/docs/concepts/state-persistence.rst index 8ae5f884..cd24befe 100644 --- a/docs/concepts/state-persistence.rst +++ b/docs/concepts/state-persistence.rst @@ -43,7 +43,7 @@ Note that ``partition_key`` can be `None` if this is not relevant. A UUID is alw You set these values using the :py:meth:`with_identifiers() ` method. -Note: to access ``app_id`` and ``partition_key`` in your running application, you can have the :py:class:`ApplicationContext ` +Note: to access application-level metadata such as ``app_id``, ``partition_key``, ``sequence_id``, and ``action_name`` in your running application, you can have the :py:class:`ApplicationContext ` injected into your Burr Actions. This is done by adding ``__context`` to the action signature: .. code-block:: python @@ -54,6 +54,8 @@ injected into your Burr Actions. This is done by adding ``__context`` to the act def my_action(state: State, __context: ApplicationContext) -> State: app_id = __context.app_id partition_key = __context.partition_key + action_name = __context.action_name + sequence_id = __context.sequence_id ... diff --git a/tests/core/test_application.py b/tests/core/test_application.py index 26191462..bbc9df46 100644 --- a/tests/core/test_application.py +++ b/tests/core/test_application.py @@ -1353,6 +1353,7 @@ def test_action(state: State, __context: ApplicationContext) -> State: assert __context.sequence_id == 0 assert __context.partition_key == PARTITION_KEY assert __context.app_id == APP_ID + assert __context.action_name == "test_action" return state app = ( @@ -1379,6 +1380,7 @@ def test_action(state: State, __context: ApplicationContext) -> State: assert __context.sequence_id == 0 assert __context.partition_key == PARTITION_KEY assert __context.app_id == APP_ID + assert __context.action_name == "test_action" return state app = ( diff --git a/tests/core/test_parallelism.py b/tests/core/test_parallelism.py index 4c1f15f3..820924a5 100644 --- a/tests/core/test_parallelism.py +++ b/tests/core/test_parallelism.py @@ -1201,6 +1201,7 @@ def reads(self) -> list[str]: state_persister=persister, state_initializer=persister, parallel_executor_factory=lambda: concurrent.futures.ThreadPoolExecutor(), + action_name=action.name, ), inputs={}, )