From 8cac9b2d9b97b9bd56305015b8049b0ef2a7e8e3 Mon Sep 17 00:00:00 2001 From: Daryl Lim <5508348+daryllimyt@users.noreply.github.com> Date: Sun, 26 Jan 2025 02:31:21 +0000 Subject: [PATCH] Add RegistryActionsService get_bound, fetch_all_action_secrets, get_actions --- tracecat/registry/actions/service.py | 48 +++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/tracecat/registry/actions/service.py b/tracecat/registry/actions/service.py index 074a4118d..572047bab 100644 --- a/tracecat/registry/actions/service.py +++ b/tracecat/registry/actions/service.py @@ -61,7 +61,7 @@ async def list_actions( result = await self.session.exec(statement) return result.all() - async def get_action(self, *, action_name: str) -> RegistryAction: + async def get_action(self, action_name: str) -> RegistryAction: """Get an action by name.""" namespace, name = action_name.rsplit(".", maxsplit=1) statement = select(RegistryAction).where( @@ -75,6 +75,17 @@ async def get_action(self, *, action_name: str) -> RegistryAction: raise RegistryError(f"Action {namespace}.{name} not found in repository") return action + async def get_actions(self, action_names: list[str]) -> Sequence[RegistryAction]: + """Get actions by name.""" + statement = select(RegistryAction).where( + RegistryAction.owner_id == config.TRACECAT__DEFAULT_ORG_ID, + func.concat(RegistryAction.namespace, ".", RegistryAction.name).in_( + action_names + ), + ) + result = await self.session.exec(statement) + return result.all() + async def create_action( self, params: RegistryActionCreate, @@ -256,3 +267,38 @@ async def read_action_with_implicit_secrets( ) -> RegistryActionRead: extra_secrets = await self.get_action_implicit_secrets(action) return RegistryActionRead.from_database(action, extra_secrets) + + async def fetch_all_action_secrets( + self, action: RegistryAction + ) -> set[RegistrySecret]: + """Recursively fetch all secrets from the action and its template steps. + + Args: + action: The registry action to fetch secrets from + + Returns: + set[RegistrySecret]: A set of secret names used by the action and its template steps + """ + secrets = set() + impl = RegistryActionImplValidator.validate_python(action.implementation) + if impl.type == "udf": + if action.secrets: + secrets.update(RegistrySecret(**secret) for secret in action.secrets) + elif impl.type == "template": + ta = impl.template_action + if ta is None: + raise ValueError("Template action is not defined") + # Add secrets from the template action itself + if template_secrets := ta.definition.secrets: + secrets.update(template_secrets) + # Recursively fetch secrets from each step + step_action_names = [step.action for step in ta.definition.steps] + step_ras = await self.get_actions(step_action_names) + for step_ra in step_ras: + step_secrets = await self.fetch_all_action_secrets(step_ra) + secrets.update(step_secrets) + return secrets + + def get_bound(self, action: RegistryAction) -> BoundRegistryAction: + """Get the bound action for a registry action.""" + return get_bound_action_impl(action)