From 74bb3d7fe2277ed9a3a390a5940cb7faea19f505 Mon Sep 17 00:00:00 2001 From: Chris Lo <46541035+topher-lo@users.noreply.github.com> Date: Mon, 23 Dec 2024 07:43:02 -0800 Subject: [PATCH] feat+fix(engine)!: Support enums and defaults in Action Templates (#644) Signed-off-by: Chris Lo <46541035+topher-lo@users.noreply.github.com> --- registry/tracecat_registry/__init__.py | 8 +- .../tracecat_registry/_internal/exceptions.py | 11 - tests/unit/test_expectation.py | 86 +++++- tests/unit/test_registry.py | 2 +- tests/unit/test_template_actions.py | 247 ++++++++++++++---- tracecat/executor/service.py | 62 +++-- tracecat/expressions/expectations.py | 84 ++++-- tracecat/registry/actions/models.py | 13 +- tracecat/registry/repository.py | 13 +- tracecat/types/exceptions.py | 7 +- 10 files changed, 401 insertions(+), 132 deletions(-) diff --git a/registry/tracecat_registry/__init__.py b/registry/tracecat_registry/__init__.py index 036164b8f..a43fc2baf 100644 --- a/registry/tracecat_registry/__init__.py +++ b/registry/tracecat_registry/__init__.py @@ -10,11 +10,8 @@ "Could not import tracecat. Please install `tracecat` to use the registry." 
) from None -from tracecat_registry._internal import registry, secrets -from tracecat_registry._internal.exceptions import ( # noqa: E402 - RegistryActionError, - RegistryValidationError, -) +from tracecat_registry._internal import exceptions, registry, secrets +from tracecat_registry._internal.exceptions import RegistryActionError from tracecat_registry._internal.logger import logger from tracecat_registry._internal.models import RegistrySecret @@ -24,6 +21,5 @@ "logger", "secrets", "exceptions", - "RegistryValidationError", "RegistryActionError", ] diff --git a/registry/tracecat_registry/_internal/exceptions.py b/registry/tracecat_registry/_internal/exceptions.py index ed17a8d78..bf3a55bbe 100644 --- a/registry/tracecat_registry/_internal/exceptions.py +++ b/registry/tracecat_registry/_internal/exceptions.py @@ -1,7 +1,5 @@ from typing import Any -from pydantic_core import ValidationError - class TracecatException(Exception): """Tracecat generic user-facing exception""" @@ -13,12 +11,3 @@ def __init__(self, *args, detail: Any | None = None, **kwargs): class RegistryActionError(TracecatException): """Exception raised when a registry UDF error occurs.""" - - -class RegistryValidationError(TracecatException): - """Exception raised when a registry validation error occurs.""" - - def __init__(self, *args, key: str, err: ValidationError | str | None = None): - super().__init__(*args) - self.key = key - self.err = err diff --git a/tests/unit/test_expectation.py b/tests/unit/test_expectation.py index ba0f9796b..98e3fc638 100644 --- a/tests/unit/test_expectation.py +++ b/tests/unit/test_expectation.py @@ -1,6 +1,7 @@ import json from datetime import datetime +import lark import pytest from pydantic import ValidationError @@ -247,4 +248,87 @@ def test_validate_schema_failure(): assert "end_time" in str(e.value) -# ... existing tests ... 
+@pytest.mark.parametrize( + "status,priority", + [ + ("PENDING", "low"), + ("running", "low"), + ("Completed", "low"), + ], +) +def test_validate_schema_with_enum(status, priority): + schema = { + "status": { + "type": 'enum["PENDING", "running", "Completed"]', + "description": "The status of the job", + }, + "priority": { + "type": 'enum["low", "medium", "high"]', + "description": "The priority level", + "default": "low", + }, + } + + mapped = {k: ExpectedField(**v) for k, v in schema.items()} + model = create_expectation_model(mapped) + + # Test with provided priority + model_instance = model(status=status, priority=priority) + assert model_instance.status.__class__.__name__ == "EnumStatus" + assert model_instance.priority.__class__.__name__ == "EnumPriority" + + # Test default priority + model_instance_default = model(status=status) + assert str(model_instance_default.priority) == "low" + + +@pytest.mark.parametrize( + "schema_def,error_type,error_message", + [ + ( + {"status": {"type": "enum[]", "description": "Empty enum"}}, + lark.exceptions.UnexpectedCharacters, + "No terminal matches ']'", + ), + ( + { + "status": { + "type": 'enum["Pending", "PENDING"]', + "description": "Duplicate values", + } + }, + lark.exceptions.VisitError, + "Duplicate enum value", + ), + ], +) +def test_validate_schema_with_invalid_enum_definition( + schema_def, error_type, error_message +): + with pytest.raises(error_type, match=error_message): + mapped = {k: ExpectedField(**v) for k, v in schema_def.items()} + create_expectation_model(mapped) + + +@pytest.mark.parametrize( + "invalid_value", + [ + "invalid_status", + "INVALID", + "pending!", + "", + ], +) +def test_validate_schema_with_invalid_enum_values(invalid_value): + schema = { + "status": { + "type": 'enum["PENDING", "running", "Completed"]', + "description": "The status of the job", + } + } + + mapped = {k: ExpectedField(**v) for k, v in schema.items()} + DynamicModel = create_expectation_model(mapped) + + with 
pytest.raises(ValidationError): + DynamicModel(status=invalid_value) diff --git a/tests/unit/test_registry.py b/tests/unit/test_registry.py index a143f20cb..6ba9d968c 100644 --- a/tests/unit/test_registry.py +++ b/tests/unit/test_registry.py @@ -2,12 +2,12 @@ import textwrap import pytest -from tracecat_registry import RegistryValidationError from tracecat.concurrency import GatheringTaskGroup from tracecat.registry.actions.models import RegistryActionRead from tracecat.registry.actions.service import RegistryActionsService from tracecat.registry.repository import Repository +from tracecat.types.exceptions import RegistryValidationError @pytest.fixture diff --git a/tests/unit/test_template_actions.py b/tests/unit/test_template_actions.py index cab879f12..f8f44f0e2 100644 --- a/tests/unit/test_template_actions.py +++ b/tests/unit/test_template_actions.py @@ -5,6 +5,7 @@ from tracecat.expressions.expectations import ExpectedField from tracecat.registry.actions.models import ActionStep, TemplateAction from tracecat.registry.repository import Repository +from tracecat.types.exceptions import RegistryValidationError def test_construct_template_action(): @@ -51,56 +52,76 @@ def test_construct_template_action(): }, } - try: - action = TemplateAction.model_validate(data) - except Exception as e: - pytest.fail(f"Failed to construct template action: {e}") - else: - assert action.definition.title == "Test Action" - assert action.definition.description == "This is just a test" - assert action.definition.action == "integrations.test.wrapper" - assert action.definition.namespace == "integrations.test" - assert action.definition.display_group == "Testing" - assert action.definition.secrets == [ - RegistrySecret(name="test_secret", keys=["KEY"]) - ] - assert action.definition.expects == { - "service_source": ExpectedField( - type="str", - description="The service source", - default="elastic", - ), - "limit": ExpectedField( - type="int | None", - description="The limit", - ), - } - 
assert action.definition.steps == [ - ActionStep( - ref="base", - action="core.transform.reshape", - args={ - "value": { - "service_source": "${{ inputs.service_source }}", - "data": 100, - } - }, - ), - ActionStep( - ref="final", - action="core.transform.reshape", - args={ - "value": [ - "${{ steps.base.result.data + 100 }}", - "${{ steps.base.result.service_source }}", - ] - }, - ), - ] + # Parse and validate the action + action = TemplateAction.model_validate(data) + + # Check the action definition + assert action.definition.title == "Test Action" + assert action.definition.description == "This is just a test" + assert action.definition.action == "integrations.test.wrapper" + assert action.definition.namespace == "integrations.test" + assert action.definition.display_group == "Testing" + assert action.definition.secrets == [ + RegistrySecret(name="test_secret", keys=["KEY"]) + ] + assert action.definition.expects == { + "service_source": ExpectedField( + type="str", + description="The service source", + default="elastic", + ), + "limit": ExpectedField( + type="int | None", + description="The limit", + ), + } + assert action.definition.steps == [ + ActionStep( + ref="base", + action="core.transform.reshape", + args={ + "value": { + "service_source": "${{ inputs.service_source }}", + "data": 100, + } + }, + ), + ActionStep( + ref="final", + action="core.transform.reshape", + args={ + "value": [ + "${{ steps.base.result.data + 100 }}", + "${{ steps.base.result.service_source }}", + ] + }, + ), + ] @pytest.mark.anyio -async def test_template_action_run(): +@pytest.mark.parametrize( + "test_args,expected_result,should_raise", + [ + ( + {"user_id": "john@tracecat.com", "service_source": "custom", "limit": 99}, + ["john@tracecat.com", "custom", 99], + False, + ), + ( + {"user_id": "john@tracecat.com"}, + ["john@tracecat.com", "elastic", 100], + False, + ), + ( + {}, + None, + True, + ), + ], + ids=["valid", "with_defaults", "missing_required"], +) +async def 
test_template_action_run(test_args, expected_result, should_raise): action = TemplateAction( **{ "type": "action", @@ -112,14 +133,22 @@ async def test_template_action_run(): "display_group": "Testing", "secrets": [{"name": "test_secret", "keys": ["KEY"]}], "expects": { + # Required field + "user_id": { + "type": "str", + "description": "The user ID", + }, + # Optional field with string default "service_source": { "type": "str", "description": "The service source", "default": "elastic", }, + # Optional field with None as default "limit": { "type": "int | None", "description": "The limit", + "default": None, }, }, "steps": [ @@ -129,7 +158,7 @@ async def test_template_action_run(): "args": { "value": { "service_source": "${{ inputs.service_source }}", - "data": 100, + "data": "${{ inputs.limit || 100 }}", } }, }, @@ -138,8 +167,9 @@ async def test_template_action_run(): "action": "core.transform.reshape", "args": { "value": [ - "${{ steps.base.result.data + 100 }}", + "${{ inputs.user_id }}", "${{ steps.base.result.service_source }}", + "${{ steps.base.result.data }}", ] }, }, @@ -149,17 +179,124 @@ async def test_template_action_run(): } ) + # Register the action registry = Repository() registry.init(include_base=True, include_templates=False) registry.register_template_action(action) + + # Check that the action is registered assert action.definition.action == "integrations.test.wrapper" assert "core.transform.reshape" in registry assert action.definition.action in registry + # Get the registered action bound_action = registry.get(action.definition.action) - result = await service.run_template_action( - action=bound_action, - args={"service_source": "elastic"}, - context={}, - ) - assert result == [200, "elastic"] + + # Run the action + if should_raise: + with pytest.raises(RegistryValidationError): + await service.run_template_action( + action=bound_action, + args=test_args, + context={}, + ) + else: + result = await service.run_template_action( 
action=bound_action, + args=test_args, + context={}, + ) + assert result == expected_result + + +@pytest.mark.anyio +@pytest.mark.parametrize( + "test_args,expected_result,should_raise", + [ + ( + {"status": "critical", "message": "System CPU usage above 90%"}, + {"alert_status": "critical", "alert_message": "System CPU usage above 90%"}, + False, + ), + ( + {"message": "Informational message"}, + {"alert_status": "info", "alert_message": "Informational message"}, + False, + ), + ( + { + "status": "emergency", + "message": "This should fail", + }, + None, + True, + ), + ], + ids=["valid_status", "default_status", "invalid_status"], +) +async def test_enum_template_action(test_args, expected_result, should_raise): + """Test template action with enum status. + This test verifies that: + 1. The action can be constructed with an enum status + 2. The action can be run with a valid enum status + 3. The action can be run with a default enum status + 4. Invalid enum values are properly rejected + """ + data = { + "type": "action", + "definition": { + "title": "Test Alert Action", + "description": "Test action with enum status", + "name": "alert", + "namespace": "integrations.test", + "display_group": "Testing", + "expects": { + "status": { + "type": 'enum["critical", "warning", "info"]', + "description": "Alert severity level", + "default": "info", + }, + "message": {"type": "str", "description": "Alert message"}, + }, + "steps": [ + { + "ref": "format", + "action": "core.transform.reshape", + "args": { + "value": { + "alert_status": "${{ inputs.status }}", + "alert_message": "${{ inputs.message }}", + } + }, + } + ], + "returns": "${{ steps.format.result }}", + }, + } + + # Parse and validate the action + action = TemplateAction.model_validate(data) + + # Register the action + registry = Repository() + registry.init(include_base=True, include_templates=False) + registry.register_template_action(action) + + # Get the registered action + bound_action = 
registry.get(action.definition.action) + + # Run the action + if should_raise: + with pytest.raises(RegistryValidationError): + await service.run_template_action( + action=bound_action, + args=test_args, + context={}, + ) + else: + result = await service.run_template_action( + action=bound_action, + args=test_args, + context={}, + ) + assert result == expected_result diff --git a/tracecat/executor/service.py b/tracecat/executor/service.py index 4975ee651..97affcd5a 100644 --- a/tracecat/executor/service.py +++ b/tracecat/executor/service.py @@ -79,24 +79,23 @@ def sync_executor_entrypoint(input: RunActionInput[ArgsT], role: Role) -> Any: async def _run_action_direct( - *, action: BoundRegistryAction[ArgsClsT], args: ArgsT, validate: bool = False + *, action: BoundRegistryAction[ArgsClsT], args: ArgsT ) -> Any: """Execute the UDF directly. At this point, the UDF cannot be a template. """ - if validate: - # Optional, as we already validate in the caller - args = action.validate_args(**args) if action.is_template: - # This should not be reached + # This should not be reachable raise ValueError("Templates cannot be executed directly") + + validated_args = action.validate_args(**args) try: if action.is_async: logger.trace("Running UDF async") - return await action.fn(**args) + return await action.fn(**validated_args) logger.trace("Running UDF sync") - return await asyncio.to_thread(action.fn, **args) + return await asyncio.to_thread(action.fn, **validated_args) except Exception as e: logger.error( f"Error running UDF {action.action!r}", error=e, type=type(e).__name__ @@ -138,12 +137,15 @@ async def run_single_action( context[ExprContext.SECRETS] = context.get(ExprContext.SECRETS, {}) | secrets if action.is_template: - logger.info("Running template UDF async", action=action.name) - return await run_template_action(action=action, args=args, context=context) - flat_secrets = flatten_secrets(secrets) - with env_sandbox(flat_secrets): - # Run the UDF in the caller process 
(usually the worker) - return await _run_action_direct(action=action, args=args) + logger.info("Running template action async", action=action.name) + result = await run_template_action(action=action, args=args, context=context) + else: + logger.trace("Running UDF async", action=action.name) + flat_secrets = flatten_secrets(secrets) + with env_sandbox(flat_secrets): + result = await _run_action_direct(action=action, args=args) + + return result async def run_template_action( @@ -152,26 +154,30 @@ async def run_template_action( args: ArgsT, context: ExecutionContext | None = None, ) -> Any: - """Handle template execution. - - You should use `run_async` instead of calling this directly. - - Move the template action execution here, so we can - override run_async's implementation - """ + """Handle template execution.""" if not action.template_action: raise ValueError( "Attempted to run a non-template UDF as a template. " "Please use `run_single_action` instead." ) defn = action.template_action.definition + + # Validate arguments and apply defaults + logger.trace( + "Validating template action arguments", expects=defn.expects, args=args + ) + if defn.expects: + validated_args = action.validate_args(**args) + + secrets_context = {} + if context is not None: + secrets_context = context.get(ExprContext.SECRETS, {}) + template_context = cast( ExecutionContext, { - ExprContext.SECRETS: {} - if context is None - else context.get(ExprContext.SECRETS, {}), - ExprContext.TEMPLATE_ACTION_INPUTS: args, + ExprContext.SECRETS: secrets_context, + ExprContext.TEMPLATE_ACTION_INPUTS: validated_args, ExprContext.TEMPLATE_ACTION_STEPS: {}, }, ) @@ -194,9 +200,11 @@ async def run_template_action( ) # Store the result of the step logger.trace("Storing step result", step=step.ref, result=result) - template_context[ExprContext.TEMPLATE_ACTION_STEPS][step.ref] = DSLNodeResult( - result=result, - result_typename=type(result).__name__, + 
template_context[str(ExprContext.TEMPLATE_ACTION_STEPS)][step.ref] = ( + DSLNodeResult( + result=result, + result_typename=type(result).__name__, + ) ) # Handle returns diff --git a/tracecat/expressions/expectations.py b/tracecat/expressions/expectations.py index 85edb9952..1ebf53e07 100644 --- a/tracecat/expressions/expectations.py +++ b/tracecat/expressions/expectations.py @@ -1,4 +1,5 @@ from datetime import datetime, timedelta +from enum import Enum from typing import Any, Union from lark import Lark, Transformer, v_args @@ -6,12 +7,12 @@ from tracecat.logger import logger -# Define the Lark grammar for parsing types type_grammar = r""" ?type: primitive_type | list_type | dict_type | union_type + | enum_type | reference_type primitive_type: INTEGER @@ -25,6 +26,7 @@ INTEGER: "int" STRING: "str" +STRING_LITERAL: "\"" /[^"]*/ "\"" | "'" /[^']*/ "'" BOOLEAN: "bool" FLOAT: "float" DATETIME: "datetime" @@ -35,6 +37,7 @@ list_type: "list" "[" type "]" dict_type: "dict" "[" type "," type "]" union_type: type ("|" type)+ +enum_type: "enum" "[" STRING_LITERAL ("," STRING_LITERAL)* "]" reference_type: "$" CNAME CNAME: /[a-zA-Z_]\w*/ @@ -59,6 +62,12 @@ class TypeTransformer(Transformer): + MAX_ENUM_VALUES = 20 + + def __init__(self, field_name: str): + super().__init__() + self.field_name = field_name + @v_args(inline=True) def primitive_type(self, item) -> type | None: logger.trace("Primitive type:", item=item) @@ -87,10 +96,41 @@ def reference_type(self, name) -> str: logger.trace("Reference type:", name=name) return f"${name.value}" + @v_args(inline=True) + def enum_type(self, *values) -> Enum: + if len(values) > self.MAX_ENUM_VALUES: + raise ValueError(f"Too many enum values (maximum {self.MAX_ENUM_VALUES})") + + enum_values = {} + seen_values = set() + + for value in values: + if not value: + raise ValueError("Enum value cannot be empty") + + # Case-insensitive duplicate check + value_lower = value.lower() + if value_lower in seen_values: + raise 
ValueError(f"Duplicate enum value: {value}") + + seen_values.add(value_lower) + enum_values[value] = value + + # Convert to upper camel case (e.g., "user_status" -> "UserStatus") + enum_name = "".join(word.title() for word in self.field_name.split("_")) + logger.trace("Enum type:", name=enum_name, values=enum_values) + return Enum(f"Enum{enum_name}", enum_values) -def parse_type(type_string: str) -> Any: + @v_args(inline=True) + def STRING_LITERAL(self, value) -> str: + # Remove quotes from the value + value = str(value).strip('"').strip("'") + return value + + +def parse_type(type_string: str, field_name: str) -> Any: tree = type_parser.parse(type_string) - return TypeTransformer().transform(tree) + return TypeTransformer(field_name).transform(tree) class ExpectedField(BaseModel): @@ -103,26 +143,32 @@ def create_expectation_model( schema: dict[str, ExpectedField], model_name: str = "ExpectedSchemaModel" ) -> type[BaseModel]: fields = {} - field_info_kwargs = {} for field_name, field_info in schema.items(): - validated_field_info = ExpectedField.model_validate(field_info) # Defensive - field_type = parse_type(validated_field_info.type) - if validated_field_info.description: - field_info_kwargs["description"] = validated_field_info.description - field_info_kwargs["default"] = ( - validated_field_info.default - if "default" in validated_field_info.model_fields_set - else ... - ) - - fields[field_name] = ( - field_type, - Field(**field_info_kwargs), - ) + field_info_kwargs = {} + # Defensive validation + validated_field_info = ExpectedField.model_validate(field_info) + + # Extract metadata + field_type = parse_type(validated_field_info.type, field_name) + description = validated_field_info.description + + if description: + field_info_kwargs["description"] = description + + if "default" in validated_field_info.model_fields_set: + # If the field has a default value, use it + field_info_kwargs["default"] = validated_field_info.default + else: + # Use ... 
(ellipsis) to indicate a required field in Pydantic + field_info_kwargs["default"] = ... + + field = Field(**field_info_kwargs) + fields[field_name] = (field_type, field) logger.trace("Creating expectation model", model_name=model_name, fields=fields) - return create_model( + model = create_model( model_name, __config__=ConfigDict(extra="forbid", arbitrary_types_allowed=True), **fields, ) + return model diff --git a/tracecat/registry/actions/models.py b/tracecat/registry/actions/models.py index 47ec14594..bd4f492a7 100644 --- a/tracecat/registry/actions/models.py +++ b/tracecat/registry/actions/models.py @@ -17,12 +17,16 @@ computed_field, model_validator, ) -from tracecat_registry import RegistrySecret, RegistryValidationError +from tracecat_registry import RegistrySecret from tracecat.db.schemas import RegistryAction from tracecat.expressions.expectations import ExpectedField, create_expectation_model from tracecat.logger import logger -from tracecat.types.exceptions import RegistryActionError, TracecatValidationError +from tracecat.types.exceptions import ( + RegistryActionError, + RegistryValidationError, + TracecatValidationError, +) from tracecat.validation.models import ValidationResult ArgsClsT = TypeVar("ArgsClsT", bound=type[BaseModel]) @@ -134,7 +138,8 @@ def validate_args[T](self, *args, **kwargs) -> T: # Use cases would be transforming a UTC string to a datetime object # We return the validated input arguments as a dictionary validated: BaseModel = self.args_cls.model_validate(kwargs) - return cast(T, validated.model_dump()) + validated_args = cast(T, validated.model_dump(mode="json")) + return validated_args except ValidationError as e: logger.error( f"Validation error for bound registry action {self.action!r}. {e.errors()!r}" @@ -146,7 +151,7 @@ def validate_args[T](self, *args, **kwargs) -> T: ) from e except Exception as e: raise RegistryValidationError( - f"Unexpected error when validating input arguments for bound registry action {self.action!r}. 
{e}", + f"Unexpected error when validating input arguments for bound registry action {self.action!r}. {e!r}", key=self.action, ) from e diff --git a/tracecat/registry/repository.py b/tracecat/registry/repository.py index 06efe6e28..80a3dc244 100644 --- a/tracecat/registry/repository.py +++ b/tracecat/registry/repository.py @@ -323,7 +323,7 @@ def _register_udf_from_function( # Get function metadata key = getattr(fn, "__tracecat_udf_key") kwargs = getattr(fn, "__tracecat_udf_kwargs") - logger.info(f"Registering UDF: {key}", key=key, name=name) + logger.debug("Registering UDF", key=key, name=name) # Add validators to the function validated_kwargs = RegisterKwargs.model_validate(kwargs) attach_validators(fn, TemplateValidator()) @@ -432,13 +432,11 @@ def load_template_actions_from_package( def load_template_actions_from_path(self, *, path: Path, origin: str) -> int: """Load template actions from a package.""" - # Load the default templates - logger.info(f"Loading template actions from {path!s}") # Load all .yml files using rglob n_loaded = 0 all_paths = chain(path.rglob("*.yml"), path.rglob("*.yaml")) for file_path in all_paths: - logger.info(f"Loading template {file_path!s}") + logger.debug("Loading template action from path", path=file_path) # Load TemplateActionDefinition try: template_action = TemplateAction.from_yaml(file_path) @@ -450,14 +448,15 @@ def load_template_actions_from_path(self, *, path: Path, origin: str) -> int: continue except Exception as e: logger.error( - f"Unexpected error loading template action {file_path!s}", error=e + "Unexpected error loading template action", + error=e, + path=file_path, ) continue key = template_action.definition.action if key in self._store: - # Already registered, skip - logger.info(f"Template {key!r} already registered, skipping") + logger.debug("Template action already registered, skipping", key=key) continue self.register_template_action(template_action, origin=origin) diff --git a/tracecat/types/exceptions.py 
b/tracecat/types/exceptions.py index 81073b9fc..ea09153b5 100644 --- a/tracecat/types/exceptions.py +++ b/tracecat/types/exceptions.py @@ -63,11 +63,16 @@ class RegistryActionError(RegistryError): class RegistryValidationError(RegistryError): """Exception raised when a registry validation error occurs.""" - def __init__(self, *args, key: str, err: ValidationError | str | None = None): + def __init__( + self, *args, key: str | None = None, err: ValidationError | str | None = None + ): super().__init__(*args) self.key = key self.err = err + def __reduce__(self): + return (self.__class__, (self.key, self.err)) + class RegistryNotFound(RegistryError): """Exception raised when a registry is not found."""