From 3fb7eb9a1e513ed69dd18e75d4a3cf44752605f4 Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Thu, 19 Dec 2024 15:11:17 -0800 Subject: [PATCH 01/32] Add tests --- tests/unit/test_expectation.py | 87 +++++++++++++++- tests/unit/test_template_actions.py | 150 ++++++++++++++++++++++++++++ 2 files changed, 236 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_expectation.py b/tests/unit/test_expectation.py index ba0f9796b..c0943015d 100644 --- a/tests/unit/test_expectation.py +++ b/tests/unit/test_expectation.py @@ -1,6 +1,7 @@ import json from datetime import datetime +import lark import pytest from pydantic import ValidationError @@ -247,4 +248,88 @@ def test_validate_schema_failure(): assert "end_time" in str(e.value) -# ... existing tests ... +@pytest.mark.parametrize( + "input_value,expected_value,priority_value", + [ + ("PENDING", "Status.PENDING", "low"), + ("running", "Status.running", "low"), + ("Completed", "Status.Completed", "low"), + ], +) +def test_validate_schema_with_enum(input_value, expected_value, priority_value): + schema = { + "status": { + "type": 'enum["PENDING", "running", "Completed"]', + "description": "The status of the job", + }, + "priority": { + "type": 'enum["low", "medium", "high"]', + "description": "The priority level", + "default": "low", + }, + } + + mapped = {k: ExpectedField(**v) for k, v in schema.items()} + DynamicModel = create_expectation_model(mapped) + + # Test with provided priority + model_instance = DynamicModel(status=input_value, priority=priority_value) + assert str(model_instance.status) == expected_value + assert model_instance.status.__class__.__name__ == "Status" + assert model_instance.priority.__class__.__name__ == "Priority" + + # Test default priority + model_instance_default = DynamicModel(status=input_value) + assert str(model_instance_default.priority) == "low" + + +@pytest.mark.parametrize( + "schema_def,error_type,error_message", + [ + ( + {"status": 
{"type": "enum[]", "description": "Empty enum"}}, + lark.exceptions.UnexpectedCharacters, + "No terminal matches ']'", + ), + ( + { + "status": { + "type": 'enum["Pending", "PENDING"]', + "description": "Duplicate values", + } + }, + lark.exceptions.VisitError, + "Duplicate enum value", + ), + ], +) +def test_validate_schema_with_invalid_enum_definition( + schema_def, error_type, error_message +): + with pytest.raises(error_type, match=error_message): + mapped = {k: ExpectedField(**v) for k, v in schema_def.items()} + create_expectation_model(mapped) + + +@pytest.mark.parametrize( + "invalid_value", + [ + "invalid_status", + "INVALID", + "pending!", + "", + ], +) +def test_validate_schema_with_invalid_enum_values(invalid_value): + schema = { + "status": { + "type": 'enum["PENDING", "running", "Completed"]', + "description": "The status of the job", + } + } + + mapped = {k: ExpectedField(**v) for k, v in schema.items()} + DynamicModel = create_expectation_model(mapped) + + with pytest.raises(ValidationError): + DynamicModel(status=invalid_value) diff --git a/tests/unit/test_template_actions.py b/tests/unit/test_template_actions.py index cab879f12..149cbbc50 100644 --- a/tests/unit/test_template_actions.py +++ b/tests/unit/test_template_actions.py @@ -1,4 +1,5 @@ import pytest +from pydantic import ValidationError from tracecat_registry import RegistrySecret from tracecat.executor import service @@ -163,3 +164,152 @@ async def test_template_action_run(): context={}, ) assert result == [200, "elastic"] + + +@pytest.mark.anyio +async def test_template_action_with_enum(): + """Test template action with enum field types. + This test verifies that: + 1. Enum fields can be properly defined in template action expectations + 2. Enum values are correctly validated during model construction + 3. Default enum values are respected when not provided + 4. Enum values are correctly serialized to strings in template expressions + 5. 
The entire flow works in a realistic alerting system scenario + The test uses a simulated alert action that accepts a severity level (critical/warning/info) + """ + data = { + "type": "action", + "definition": { + "title": "Test Alert Action", + "description": "Test action with enum status", + "name": "alert", + "namespace": "integrations.test", + "display_group": "Testing", + "expects": { + "status": { + "type": 'enum["critical", "warning", "info"]', + "description": "Alert severity level", + "default": "info", + }, + "message": {"type": "str", "description": "Alert message"}, + }, + "steps": [ + { + "ref": "format", + "action": "core.transform.reshape", + "args": { + "value": { + "alert_status": "${{ inputs.status }}", + "alert_message": "${{ inputs.message }}", + } + }, + } + ], + "returns": "${{ steps.format.result }}", + }, + } + + try: + action = TemplateAction.model_validate(data) + except Exception as e: + pytest.fail(f"Failed to construct template action: {e}") + else: + assert action.definition.title == "Test Alert Action" + assert ( + action.definition.expects["status"].type + == 'enum["critical", "warning", "info"]' + ) + assert action.definition.expects["status"].default == "info" + + # Test running the action + registry = Repository() + registry.init(include_base=True, include_templates=False) + registry.register_template_action(action) + + bound_action = registry.get(action.definition.action) + result = await service.run_template_action( + action=bound_action, + args={"status": "critical", "message": "System CPU usage above 90%"}, + context={}, + ) + + assert result == { + "alert_status": "critical", + "alert_message": "System CPU usage above 90%", + } + + # Test with default status + result_default = await service.run_template_action( + action=bound_action, + args={"message": "Informational message"}, + context={}, + ) + + assert result_default == { + "alert_status": "info", + "alert_message": "Informational message", + } + + +@pytest.mark.anyio 
+async def test_template_action_with_invalid_enum(): + """Test template action with invalid enum value. + This test verifies that: + 1. Invalid enum values are properly rejected + 2. The error message is descriptive and helpful + 3. The validation happens at runtime during template execution + """ + data = { + "type": "action", + "definition": { + "title": "Test Alert Action", + "description": "Test action with enum status", + "name": "alert", + "namespace": "integrations.test", + "display_group": "Testing", + "expects": { + "status": { + "type": 'enum["critical", "warning", "info"]', + "description": "Alert severity level", + }, + "message": {"type": "str", "description": "Alert message"}, + }, + "steps": [ + { + "ref": "format", + "action": "core.transform.reshape", + "args": { + "value": { + "alert_status": "${{ inputs.status }}", + "alert_message": "${{ inputs.message }}", + } + }, + } + ], + "returns": "${{ steps.format.result }}", + }, + } + + action = TemplateAction.model_validate(data) + registry = Repository() + registry.init(include_base=True, include_templates=False) + registry.register_template_action(action) + + bound_action = registry.get(action.definition.action) + with pytest.raises(ValidationError) as exc_info: + await service.run_template_action( + action=bound_action, + args={ + "status": "emergency", # Invalid status - not in enum + "message": "This should fail", + }, + context={}, + ) + + error_msg = str(exc_info.value) + assert "status" in error_msg + assert "emergency" in error_msg + assert any( + "critical" in msg or "warning" in msg or "info" in msg + for msg in error_msg.split("\n") + ) From fceab7ab90f4b47108fe25d364a37d0a4f64ee0a Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Thu, 19 Dec 2024 15:25:33 -0800 Subject: [PATCH 02/32] Add more informative inline comments in tests --- tests/unit/test_template_actions.py | 249 +++++++++++----------------- 1 file changed, 101 insertions(+), 148 
deletions(-) diff --git a/tests/unit/test_template_actions.py b/tests/unit/test_template_actions.py index 149cbbc50..da454639c 100644 --- a/tests/unit/test_template_actions.py +++ b/tests/unit/test_template_actions.py @@ -52,52 +52,51 @@ def test_construct_template_action(): }, } - try: - action = TemplateAction.model_validate(data) - except Exception as e: - pytest.fail(f"Failed to construct template action: {e}") - else: - assert action.definition.title == "Test Action" - assert action.definition.description == "This is just a test" - assert action.definition.action == "integrations.test.wrapper" - assert action.definition.namespace == "integrations.test" - assert action.definition.display_group == "Testing" - assert action.definition.secrets == [ - RegistrySecret(name="test_secret", keys=["KEY"]) - ] - assert action.definition.expects == { - "service_source": ExpectedField( - type="str", - description="The service source", - default="elastic", - ), - "limit": ExpectedField( - type="int | None", - description="The limit", - ), - } - assert action.definition.steps == [ - ActionStep( - ref="base", - action="core.transform.reshape", - args={ - "value": { - "service_source": "${{ inputs.service_source }}", - "data": 100, - } - }, - ), - ActionStep( - ref="final", - action="core.transform.reshape", - args={ - "value": [ - "${{ steps.base.result.data + 100 }}", - "${{ steps.base.result.service_source }}", - ] - }, - ), - ] + # Parse and validate the action + action = TemplateAction.model_validate(data) + + # Check the action definition + assert action.definition.title == "Test Action" + assert action.definition.description == "This is just a test" + assert action.definition.action == "integrations.test.wrapper" + assert action.definition.namespace == "integrations.test" + assert action.definition.display_group == "Testing" + assert action.definition.secrets == [ + RegistrySecret(name="test_secret", keys=["KEY"]) + ] + assert action.definition.expects == { + 
"service_source": ExpectedField( + type="str", + description="The service source", + default="elastic", + ), + "limit": ExpectedField( + type="int | None", + description="The limit", + ), + } + assert action.definition.steps == [ + ActionStep( + ref="base", + action="core.transform.reshape", + args={ + "value": { + "service_source": "${{ inputs.service_source }}", + "data": 100, + } + }, + ), + ActionStep( + ref="final", + action="core.transform.reshape", + args={ + "value": [ + "${{ steps.base.result.data + 100 }}", + "${{ steps.base.result.service_source }}", + ] + }, + ), + ] @pytest.mark.anyio @@ -150,32 +149,64 @@ async def test_template_action_run(): } ) + # Register the action registry = Repository() registry.init(include_base=True, include_templates=False) registry.register_template_action(action) + + # Check that the action is registered assert action.definition.action == "integrations.test.wrapper" assert "core.transform.reshape" in registry assert action.definition.action in registry + # Get the registered action bound_action = registry.get(action.definition.action) + + # Run the action result = await service.run_template_action( action=bound_action, args={"service_source": "elastic"}, context={}, ) + + # Check the result assert result == [200, "elastic"] @pytest.mark.anyio -async def test_template_action_with_enum(): - """Test template action with enum field types. 
+@pytest.mark.parametrize( + "test_args,expected_result,should_raise", + [ + # Valid cases + ( + {"status": "critical", "message": "System CPU usage above 90%"}, + {"alert_status": "critical", "alert_message": "System CPU usage above 90%"}, + False, + ), + ( + {"message": "Informational message"}, + {"alert_status": "info", "alert_message": "Informational message"}, + False, + ), + # Invalid case + ( + { + "status": "emergency", + "message": "This should fail", + }, + None, + True, + ), + ], + ids=["valid_status", "default_status", "invalid_status"], +) +async def test_template_action_with_enum(test_args, expected_result, should_raise): + """Test template action with enum status. This test verifies that: - 1. Enum fields can be properly defined in template action expectations - 2. Enum values are correctly validated during model construction - 3. Default enum values are respected when not provided - 4. Enum values are correctly serialized to strings in template expressions - 5. The entire flow works in a realistic alerting system scenario - The test uses a simulated alert action that accepts a severity level (critical/warning/info) + 1. The action can be constructed with an enum status + 2. The action can be run with a valid enum status + 3. The action can be run with a default enum status + 4. 
Invalid enum values are properly rejected """ data = { "type": "action", @@ -209,107 +240,29 @@ async def test_template_action_with_enum(): }, } - try: - action = TemplateAction.model_validate(data) - except Exception as e: - pytest.fail(f"Failed to construct template action: {e}") - else: - assert action.definition.title == "Test Alert Action" - assert ( - action.definition.expects["status"].type - == 'enum["critical", "warning", "info"]' - ) - assert action.definition.expects["status"].default == "info" + # Parse and validate the action + action = TemplateAction.model_validate(data) - # Test running the action + # Register the action registry = Repository() registry.init(include_base=True, include_templates=False) registry.register_template_action(action) + # Get the registered action bound_action = registry.get(action.definition.action) - result = await service.run_template_action( - action=bound_action, - args={"status": "critical", "message": "System CPU usage above 90%"}, - context={}, - ) - - assert result == { - "alert_status": "critical", - "alert_message": "System CPU usage above 90%", - } - - # Test with default status - result_default = await service.run_template_action( - action=bound_action, - args={"message": "Informational message"}, - context={}, - ) - - assert result_default == { - "alert_status": "info", - "alert_message": "Informational message", - } - - -@pytest.mark.anyio -async def test_template_action_with_invalid_enum(): - """Test template action with invalid enum value. - This test verifies that: - 1. Invalid enum values are properly rejected - 2. The error message is descriptive and helpful - 3. 
The validation happens at runtime during template execution - """ - data = { - "type": "action", - "definition": { - "title": "Test Alert Action", - "description": "Test action with enum status", - "name": "alert", - "namespace": "integrations.test", - "display_group": "Testing", - "expects": { - "status": { - "type": 'enum["critical", "warning", "info"]', - "description": "Alert severity level", - }, - "message": {"type": "str", "description": "Alert message"}, - }, - "steps": [ - { - "ref": "format", - "action": "core.transform.reshape", - "args": { - "value": { - "alert_status": "${{ inputs.status }}", - "alert_message": "${{ inputs.message }}", - } - }, - } - ], - "returns": "${{ steps.format.result }}", - }, - } - - action = TemplateAction.model_validate(data) - registry = Repository() - registry.init(include_base=True, include_templates=False) - registry.register_template_action(action) - bound_action = registry.get(action.definition.action) - with pytest.raises(ValidationError) as exc_info: - await service.run_template_action( + # Run the action + if should_raise: + with pytest.raises(ValidationError): + await service.run_template_action( + action=bound_action, + args=test_args, + context={}, + ) + else: + result = await service.run_template_action( action=bound_action, - args={ - "status": "emergency", # Invalid status - not in enum - "message": "This should fail", - }, + args=test_args, context={}, ) - - error_msg = str(exc_info.value) - assert "status" in error_msg - assert "emergency" in error_msg - assert any( - "critical" in msg or "warning" in msg or "info" in msg - for msg in error_msg.split("\n") - ) + assert result == expected_result From 3442d7750505e07b3caa98f907ed57425c03ffb4 Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Thu, 19 Dec 2024 16:16:14 -0800 Subject: [PATCH 03/32] Test defaults logic in template actions --- tests/unit/test_template_actions.py | 49 ++++++++++++++++++++++------- 1 file 
changed, 37 insertions(+), 12 deletions(-) diff --git a/tests/unit/test_template_actions.py b/tests/unit/test_template_actions.py index da454639c..82c41be2d 100644 --- a/tests/unit/test_template_actions.py +++ b/tests/unit/test_template_actions.py @@ -100,7 +100,28 @@ def test_construct_template_action(): @pytest.mark.anyio -async def test_template_action_run(): +@pytest.mark.parametrize( + "test_args,expected_result", + [ + ( + {"service_source": "custom", "limit": 100}, + [200, "custom"], + False, + ), + ( + {"limit": 100}, + [200, "elastic"], + False, + ), + ( + {}, + None, + True, + ), + ], + ids=["custom_source", "default_source", "missing_required"], +) +async def test_template_action_run(test_args, expected_result, should_raise): action = TemplateAction( **{ "type": "action", @@ -163,21 +184,26 @@ async def test_template_action_run(): bound_action = registry.get(action.definition.action) # Run the action - result = await service.run_template_action( - action=bound_action, - args={"service_source": "elastic"}, - context={}, - ) - - # Check the result - assert result == [200, "elastic"] + if should_raise: + with pytest.raises(ValidationError): + await service.run_template_action( + action=bound_action, + args=test_args, + context={}, + ) + else: + result = await service.run_template_action( + action=bound_action, + args=test_args, + context={}, + ) + assert result == expected_result @pytest.mark.anyio @pytest.mark.parametrize( "test_args,expected_result,should_raise", [ - # Valid cases ( {"status": "critical", "message": "System CPU usage above 90%"}, {"alert_status": "critical", "alert_message": "System CPU usage above 90%"}, @@ -188,7 +214,6 @@ async def test_template_action_run(): {"alert_status": "info", "alert_message": "Informational message"}, False, ), - # Invalid case ( { "status": "emergency", @@ -200,7 +225,7 @@ async def test_template_action_run(): ], ids=["valid_status", "default_status", "invalid_status"], ) -async def 
test_template_action_with_enum(test_args, expected_result, should_raise): +async def test_enum_template_action(test_args, expected_result, should_raise): """Test template action with enum status. This test verifies that: 1. The action can be constructed with an enum status From 32460bf1cdee68faab843a9be7f6ff26fd31a805 Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Thu, 19 Dec 2024 16:17:17 -0800 Subject: [PATCH 04/32] Add enum type to expectations --- tracecat/expressions/expectations.py | 71 ++++++++++++++++++++++------ 1 file changed, 57 insertions(+), 14 deletions(-) diff --git a/tracecat/expressions/expectations.py b/tracecat/expressions/expectations.py index 85edb9952..e458aacf1 100644 --- a/tracecat/expressions/expectations.py +++ b/tracecat/expressions/expectations.py @@ -1,4 +1,5 @@ from datetime import datetime, timedelta +from enum import Enum from typing import Any, Union from lark import Lark, Transformer, v_args @@ -6,12 +7,12 @@ from tracecat.logger import logger -# Define the Lark grammar for parsing types type_grammar = r""" ?type: primitive_type | list_type | dict_type | union_type + | enum_type | reference_type primitive_type: INTEGER @@ -25,6 +26,7 @@ INTEGER: "int" STRING: "str" +STRING_LITERAL: "\"" /[^"]*/ "\"" BOOLEAN: "bool" FLOAT: "float" DATETIME: "datetime" @@ -35,6 +37,7 @@ list_type: "list" "[" type "]" dict_type: "dict" "[" type "," type "]" union_type: type ("|" type)+ +enum_type: "enum" "[" STRING_LITERAL ("," STRING_LITERAL)* "]" reference_type: "$" CNAME CNAME: /[a-zA-Z_]\w*/ @@ -59,6 +62,12 @@ class TypeTransformer(Transformer): + MAX_ENUM_VALUES = 20 + + def __init__(self, field_name: str): + super().__init__() + self.field_name = field_name + @v_args(inline=True) def primitive_type(self, item) -> type | None: logger.trace("Primitive type:", item=item) @@ -87,10 +96,42 @@ def reference_type(self, name) -> str: logger.trace("Reference type:", name=name) return f"${name.value}" + 
@v_args(inline=True) + def enum_type(self, *values) -> type: + if len(values) > self.MAX_ENUM_VALUES: + raise ValueError(f"Too many enum values (maximum {self.MAX_ENUM_VALUES})") + + enum_values = {} + seen_values = set() + + for value in values: + if not value: + raise ValueError("Enum value cannot be empty") + + # Case-insensitive duplicate check + value_lower = value.lower() + if value_lower in seen_values: + raise ValueError(f"Duplicate enum value: {value}") + + seen_values.add(value_lower) + enum_values[value] = value + + # Convert to upper camel case (e.g., "user_status" -> "UserStatus") + enum_name = "".join(word.title() for word in self.field_name.split("_")) + logger.trace("Enum type:", name=enum_name, values=enum_values) + return Enum(enum_name, enum_values) -def parse_type(type_string: str) -> Any: + @v_args(inline=True) + def STRING_LITERAL(self, value): + # Remove quotes from the value + value = value.strip('"').strip("'") + # Coerce to string + return str(value) + + +def parse_type(type_string: str, field_name: str | None = None) -> Any: tree = type_parser.parse(type_string) - return TypeTransformer().transform(tree) + return TypeTransformer(field_name).transform(tree) class ExpectedField(BaseModel): @@ -106,19 +147,21 @@ def create_expectation_model( field_info_kwargs = {} for field_name, field_info in schema.items(): validated_field_info = ExpectedField.model_validate(field_info) # Defensive - field_type = parse_type(validated_field_info.type) + field_type = parse_type(validated_field_info.type, field_name) + + # Add description if provided if validated_field_info.description: field_info_kwargs["description"] = validated_field_info.description - field_info_kwargs["default"] = ( - validated_field_info.default - if "default" in validated_field_info.model_fields_set - else ... 
- ) - - fields[field_name] = ( - field_type, - Field(**field_info_kwargs), - ) + + if validated_field_info.default: + # Add default if provided + field_info_kwargs["default"] = validated_field_info.default + else: + # Else field is a required field + field_info_kwargs["default"] = ... + + field = Field(**field_info_kwargs) + fields[field_name] = (field_type, field) logger.trace("Creating expectation model", model_name=model_name, fields=fields) return create_model( From ace09b179323930dd7e8ae45f98209527dd11f63 Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Thu, 19 Dec 2024 16:21:50 -0800 Subject: [PATCH 05/32] fix missing arg in test --- tests/unit/test_template_actions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_template_actions.py b/tests/unit/test_template_actions.py index 82c41be2d..1628e7e26 100644 --- a/tests/unit/test_template_actions.py +++ b/tests/unit/test_template_actions.py @@ -101,7 +101,7 @@ def test_construct_template_action(): @pytest.mark.anyio @pytest.mark.parametrize( - "test_args,expected_result", + "test_args,expected_result,should_raise", [ ( {"service_source": "custom", "limit": 100}, From 35132633fe76e8dea9e860ad6261594a4320af29 Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Fri, 20 Dec 2024 21:05:55 -0800 Subject: [PATCH 06/32] Update verbosity of register logs --- tracecat/registry/repository.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tracecat/registry/repository.py b/tracecat/registry/repository.py index 06efe6e28..80a3dc244 100644 --- a/tracecat/registry/repository.py +++ b/tracecat/registry/repository.py @@ -323,7 +323,7 @@ def _register_udf_from_function( # Get function metadata key = getattr(fn, "__tracecat_udf_key") kwargs = getattr(fn, "__tracecat_udf_kwargs") - logger.info(f"Registering UDF: {key}", key=key, name=name) + logger.debug("Registering UDF", key=key, 
name=name) # Add validators to the function validated_kwargs = RegisterKwargs.model_validate(kwargs) attach_validators(fn, TemplateValidator()) @@ -432,13 +432,11 @@ def load_template_actions_from_package( def load_template_actions_from_path(self, *, path: Path, origin: str) -> int: """Load template actions from a package.""" - # Load the default templates - logger.info(f"Loading template actions from {path!s}") # Load all .yml files using rglob n_loaded = 0 all_paths = chain(path.rglob("*.yml"), path.rglob("*.yaml")) for file_path in all_paths: - logger.info(f"Loading template {file_path!s}") + logger.debug("Loading template action from path", path=file_path) # Load TemplateActionDefinition try: template_action = TemplateAction.from_yaml(file_path) @@ -450,14 +448,15 @@ def load_template_actions_from_path(self, *, path: Path, origin: str) -> int: continue except Exception as e: logger.error( - f"Unexpected error loading template action {file_path!s}", error=e + "Unexpected error loading template action", + error=e, + path=file_path, ) continue key = template_action.definition.action if key in self._store: - # Already registered, skip - logger.info(f"Template {key!r} already registered, skipping") + logger.debug("Template action already registered, skipping", key=key) continue self.register_template_action(template_action, origin=origin) From 29436ad710854fc9b1d70fb5a11c0dc050d66c58 Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Fri, 20 Dec 2024 21:14:50 -0800 Subject: [PATCH 07/32] Update default / non default args tests --- tests/unit/test_template_actions.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/tests/unit/test_template_actions.py b/tests/unit/test_template_actions.py index 1628e7e26..e6fdf2cca 100644 --- a/tests/unit/test_template_actions.py +++ b/tests/unit/test_template_actions.py @@ -104,13 +104,13 @@ def test_construct_template_action(): 
"test_args,expected_result,should_raise", [ ( - {"service_source": "custom", "limit": 100}, - [200, "custom"], + {"user_id": "john@tracecat.com", "service_source": "custom", "limit": 99}, + ["john@tracecat.com", "custom", 99], False, ), ( - {"limit": 100}, - [200, "elastic"], + {"user_id": "john@tracecat.com"}, + ["john@tracecat.com", "elastic", 100], False, ), ( @@ -119,7 +119,7 @@ def test_construct_template_action(): True, ), ], - ids=["custom_source", "default_source", "missing_required"], + ids=["valid", "with_defaults", "missing_required"], ) async def test_template_action_run(test_args, expected_result, should_raise): action = TemplateAction( @@ -133,14 +133,22 @@ async def test_template_action_run(test_args, expected_result, should_raise): "display_group": "Testing", "secrets": [{"name": "test_secret", "keys": ["KEY"]}], "expects": { + # Required field + "user_id": { + "type": "str", + "description": "The user ID", + }, + # Optional field with string default "service_source": { "type": "str", "description": "The service source", "default": "elastic", }, + # Optional field with None as default "limit": { - "type": "int | None", + "type": "int | null", "description": "The limit", + "default": "null", }, }, "steps": [ @@ -150,7 +158,7 @@ async def test_template_action_run(test_args, expected_result, should_raise): "args": { "value": { "service_source": "${{ inputs.service_source }}", - "data": 100, + "data": "${{ inputs.limit || 100 }}", } }, }, @@ -159,8 +167,9 @@ async def test_template_action_run(test_args, expected_result, should_raise): "action": "core.transform.reshape", "args": { "value": [ - "${{ steps.base.result.data + 100 }}", + "${{ inputs.user_id }}", "${{ steps.base.result.service_source }}", + "${{ steps.base.result.data }}", ] }, }, @@ -197,7 +206,7 @@ async def test_template_action_run(test_args, expected_result, should_raise): args=test_args, context={}, ) - assert result == expected_result + assert result == expected_result 
@pytest.mark.anyio From e17e75eaa64e5e40a437c49b94bd3a5d6e420274 Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Fri, 20 Dec 2024 21:15:33 -0800 Subject: [PATCH 08/32] feat: Use validated args in `run_template_action` Fixes missing defaults and no input validation on run. --- tracecat/executor/service.py | 51 ++++++++++++++++++++++-------------- 1 file changed, 32 insertions(+), 19 deletions(-) diff --git a/tracecat/executor/service.py b/tracecat/executor/service.py index 97c3aee1f..9527a9037 100644 --- a/tracecat/executor/service.py +++ b/tracecat/executor/service.py @@ -33,6 +33,7 @@ extract_templated_secrets, get_iterables_from_expression, ) +from tracecat.expressions.expectations import create_expectation_model from tracecat.expressions.shared import ExprContext from tracecat.logger import logger from tracecat.parse import traverse_leaves @@ -145,12 +146,15 @@ async def run_single_action( context["SECRETS"] = context.get("SECRETS", {}) | secrets if action.is_template: - logger.info("Running template UDF async", action=action.name) - return await run_template_action(action=action, args=args, context=context) - flat_secrets = flatten_secrets(secrets) - with env_sandbox(flat_secrets): - # Run the UDF in the caller process (usually the worker) - return await _run_action_direct(action=action, args=args) + logger.info("Running template action async", action=action.name) + result = await run_template_action(action=action, args=args, context=context) + else: + logger.trace("Running UDF async", action=action.name) + flat_secrets = flatten_secrets(secrets) + with env_sandbox(flat_secrets): + result = await _run_action_direct(action=action, args=args) + + return result async def run_template_action( @@ -159,25 +163,32 @@ async def run_template_action( args: ArgsT, context: DSLContext, ) -> Any: - """Handle template execution. - - You should use `run_async` instead of calling this directly. 
 - - Move the template action execution here, so we can - override run_async's implementation - """ + """Handle template execution.""" if not action.template_action: raise ValueError( "Attempted to run a non-template UDF as a template. " "Please use `run_single_action` instead." ) defn = action.template_action.definition + + # Validate arguments and apply defaults + logger.trace( + "Validating template action arguments", expects=defn.expects, args=args + ) + if defn.expects: + model = create_expectation_model(defn.expects) + # In pydantic 2.x, we need to use `model_dump(mode="json")` + # so enums return the string value instead of the enum instance + args = model(**args).model_dump(mode="json") + + secrets_context = {} + if context is not None: + secrets_context = context.get(ExprContext.SECRETS, {}) + template_context = cast( DSLContext, { - ExprContext.SECRETS: {} - if context is None - else context.get(ExprContext.SECRETS, {}), + ExprContext.SECRETS: secrets_context, ExprContext.TEMPLATE_ACTION_INPUTS: args, ExprContext.TEMPLATE_ACTION_STEPS: {}, }, @@ -201,9 +212,11 @@ async def run_template_action( ) # Store the result of the step logger.trace("Storing step result", step=step.ref, result=result) - template_context[ExprContext.TEMPLATE_ACTION_STEPS][step.ref] = DSLNodeResult( - result=result, - result_typename=type(result).__name__, + template_context[str(ExprContext.TEMPLATE_ACTION_STEPS)][step.ref] = ( + DSLNodeResult( + result=result, + result_typename=type(result).__name__, + ) ) # Handle returns From b165b0838fc409ef20112333b266fb3a11c49621 Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Fri, 20 Dec 2024 21:16:19 -0800 Subject: [PATCH 09/32] Replace `None` with `null` (reserved string) to represent null defaults (optional fields with None as default value) --- tracecat/expressions/expectations.py | 37 +++++++++++++++++----------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git
a/tracecat/expressions/expectations.py b/tracecat/expressions/expectations.py index e458aacf1..1bf493928 100644 --- a/tracecat/expressions/expectations.py +++ b/tracecat/expressions/expectations.py @@ -32,7 +32,7 @@ DATETIME: "datetime" DURATION: "duration" ANY: "any" -NULL: "None" +NULL: "null" list_type: "list" "[" type "]" dict_type: "dict" "[" type "," type "]" @@ -57,7 +57,7 @@ "datetime": datetime, "duration": timedelta, "any": Any, - "None": None, + "null": None, } @@ -97,7 +97,7 @@ def reference_type(self, name) -> str: return f"${name.value}" @v_args(inline=True) - def enum_type(self, *values) -> type: + def enum_type(self, *values) -> Enum: if len(values) > self.MAX_ENUM_VALUES: raise ValueError(f"Too many enum values (maximum {self.MAX_ENUM_VALUES})") @@ -122,14 +122,14 @@ def enum_type(self, *values) -> type: return Enum(enum_name, enum_values) @v_args(inline=True) - def STRING_LITERAL(self, value): + def STRING_LITERAL(self, value) -> str: # Remove quotes from the value value = value.strip('"').strip("'") # Coerce to string return str(value) -def parse_type(type_string: str, field_name: str | None = None) -> Any: +def parse_type(type_string: str, field_name: str) -> Any: tree = type_parser.parse(type_string) return TypeTransformer(field_name).transform(tree) @@ -144,28 +144,35 @@ def create_expectation_model( schema: dict[str, ExpectedField], model_name: str = "ExpectedSchemaModel" ) -> type[BaseModel]: fields = {} - field_info_kwargs = {} for field_name, field_info in schema.items(): - validated_field_info = ExpectedField.model_validate(field_info) # Defensive + field_info_kwargs = {} + # Defensive validation + validated_field_info = ExpectedField.model_validate(field_info) + + # Extract metadata field_type = parse_type(validated_field_info.type, field_name) + default = validated_field_info.default + description = validated_field_info.description - # Add description if provided - if validated_field_info.description: - field_info_kwargs["description"] 
= validated_field_info.description + if description: + field_info_kwargs["description"] = description - if validated_field_info.default: - # Add default if provided - field_info_kwargs["default"] = validated_field_info.default + if default == "null": + # "null" indicates an explicit optional field, which evaluates to None + field_info_kwargs["default"] = None + elif default is not None: + field_info_kwargs["default"] = default else: - # Else field is a required field + # Use ... (ellipsis) to indicate a required field in Pydantic field_info_kwargs["default"] = ... field = Field(**field_info_kwargs) fields[field_name] = (field_type, field) logger.trace("Creating expectation model", model_name=model_name, fields=fields) - return create_model( + model = create_model( model_name, __config__=ConfigDict(extra="forbid", arbitrary_types_allowed=True), **fields, ) + return model From fc400871d1dd103544d26cead80020fb6581695c Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Fri, 20 Dec 2024 22:14:30 -0800 Subject: [PATCH 10/32] Drop unused validate path in _run_action_direct --- tracecat/executor/service.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/tracecat/executor/service.py b/tracecat/executor/service.py index 9527a9037..b4b606e7b 100644 --- a/tracecat/executor/service.py +++ b/tracecat/executor/service.py @@ -33,7 +33,6 @@ extract_templated_secrets, get_iterables_from_expression, ) -from tracecat.expressions.expectations import create_expectation_model from tracecat.expressions.shared import ExprContext from tracecat.logger import logger from tracecat.parse import traverse_leaves @@ -87,17 +86,14 @@ async def coro(): async def _run_action_direct( - *, action: BoundRegistryAction[ArgsClsT], args: ArgsT, validate: bool = False + *, action: BoundRegistryAction[ArgsClsT], args: ArgsT ) -> Any: """Execute the UDF directly. At this point, the UDF cannot be a template. 
""" - if validate: - # Optional, as we already validate in the caller - args = action.validate_args(**args) if action.is_template: - # This should not be reached + # Defensive check raise ValueError("Templates cannot be executed directly") try: if action.is_async: @@ -176,10 +172,7 @@ async def run_template_action( "Validating template action arguments", expects=defn.expects, args=args ) if defn.expects: - model = create_expectation_model(defn.expects) - # In pydantic 2.x, we need to use `model_dump(mode="json")` - # so enums return the string value instead of the enum instance - args = model(**args).model_dump(mode="json") + args = action.validate_args(**args) secrets_context = {} if context is not None: From 9d5495ee64bc7b7296db8f8d1763b4cd261e25d6 Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Fri, 20 Dec 2024 22:14:46 -0800 Subject: [PATCH 11/32] Check RegistryValidationError raised --- tests/unit/test_template_actions.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_template_actions.py b/tests/unit/test_template_actions.py index e6fdf2cca..0f5db3fc7 100644 --- a/tests/unit/test_template_actions.py +++ b/tests/unit/test_template_actions.py @@ -1,6 +1,5 @@ import pytest -from pydantic import ValidationError -from tracecat_registry import RegistrySecret +from tracecat_registry import RegistrySecret, RegistryValidationError from tracecat.executor import service from tracecat.expressions.expectations import ExpectedField @@ -194,7 +193,7 @@ async def test_template_action_run(test_args, expected_result, should_raise): # Run the action if should_raise: - with pytest.raises(ValidationError): + with pytest.raises(RegistryValidationError): await service.run_template_action( action=bound_action, args=test_args, @@ -287,7 +286,7 @@ async def test_enum_template_action(test_args, expected_result, should_raise): # Run the action if should_raise: - with pytest.raises(ValidationError): + 
with pytest.raises(RegistryValidationError): await service.run_template_action( action=bound_action, args=test_args, From e72fbff7f03e4faf2637d7cccb3d74bccb262fd9 Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Fri, 20 Dec 2024 22:15:06 -0800 Subject: [PATCH 12/32] return validated_args in BoundRegistryAction.validate_args with json mode --- tracecat/registry/actions/models.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tracecat/registry/actions/models.py b/tracecat/registry/actions/models.py index 415f7a375..87135c540 100644 --- a/tracecat/registry/actions/models.py +++ b/tracecat/registry/actions/models.py @@ -133,7 +133,7 @@ def validate_args[T](self, *args, **kwargs) -> T: # Use cases would be transforming a UTC string to a datetime object # We return the validated input arguments as a dictionary validated: BaseModel = self.args_cls.model_validate(kwargs) - return cast(T, validated.model_dump()) + validated_args = cast(T, validated.model_dump(mode="json")) except ValidationError as e: logger.error( f"Validation error for bound registry action {self.action!r}. {e.errors()!r}" @@ -148,6 +148,8 @@ def validate_args[T](self, *args, **kwargs) -> T: f"Unexpected error when validating input arguments for bound registry action {self.action!r}. 
{e}", key=self.action, ) from e + else: + return validated_args # Templates From 134794a0dd4e99096e4f5b13c8d1c0a1aa248413 Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Fri, 20 Dec 2024 22:28:27 -0800 Subject: [PATCH 13/32] fix replace str | None with str | null in wazuh `run_rootcheck` --- registry/tracecat_registry/templates/wazuh/run_rootcheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/registry/tracecat_registry/templates/wazuh/run_rootcheck.yml b/registry/tracecat_registry/templates/wazuh/run_rootcheck.yml index 85255a70f..13fe50d95 100644 --- a/registry/tracecat_registry/templates/wazuh/run_rootcheck.yml +++ b/registry/tracecat_registry/templates/wazuh/run_rootcheck.yml @@ -12,7 +12,7 @@ definition: - WAZUH_API_URL expects: agents_list: - type: str | None + type: str | null description: List of agent IDs (separated by comma), all agents selected by default if not specified. default: null verify_ssl: From 4fc933fc4e19adb73263f070b70d206dddc0a97d Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Fri, 20 Dec 2024 22:38:14 -0800 Subject: [PATCH 14/32] fix int | null in _example_template_action --- .../templates/internal/_example_template_action.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/registry/tracecat_registry/templates/internal/_example_template_action.yml b/registry/tracecat_registry/templates/internal/_example_template_action.yml index 5c6954433..e5e4ae364 100644 --- a/registry/tracecat_registry/templates/internal/_example_template_action.yml +++ b/registry/tracecat_registry/templates/internal/_example_template_action.yml @@ -16,7 +16,7 @@ definition: description: The service source default: elastic limit: - type: int | None + type: int | null description: The limit # Layers are used to define a sequence of operations steps: From 1bb2bc8d795e0327cd7cfaa9dde323720a8281b8 Mon Sep 17 00:00:00 2001 From: topher-lo 
<46541035+topher-lo@users.noreply.github.com> Date: Fri, 20 Dec 2024 23:35:28 -0800 Subject: [PATCH 15/32] Replace int | None with int | null in test_expectations schemas --- tests/unit/test_expectation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_expectation.py b/tests/unit/test_expectation.py index c0943015d..a035e3d06 100644 --- a/tests/unit/test_expectation.py +++ b/tests/unit/test_expectation.py @@ -20,7 +20,7 @@ def test_validate_schema(): "description": "The end time", }, "nullable": { - "type": "int | None", + "type": "int | null", "description": "An nullable integer", }, "optional_with_default": { @@ -81,7 +81,7 @@ def test_dynamic_model_with_optional_field_omitted(): "description": "The end time", }, "nullable": { - "type": "int | None", + "type": "int | null", "description": "An nullable integer", }, "optional_with_default": { @@ -130,7 +130,7 @@ def test_dynamic_model_with_invalid_data(): "description": "The end time", }, "nullable": { - "type": "int | None", + "type": "int | null", "description": "A nullable integer", }, "optional_with_default": { From b9248570a40953ef6a62f2aea50b80f5a4f65538 Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Sat, 21 Dec 2024 01:14:51 -0800 Subject: [PATCH 16/32] Revert null type --- .../templates/wazuh/run_rootcheck.yml | 4 ++-- tests/unit/test_expectation.py | 6 +++--- tests/unit/test_template_actions.py | 2 +- tracecat/expressions/expectations.py | 11 ++++------- 4 files changed, 10 insertions(+), 13 deletions(-) diff --git a/registry/tracecat_registry/templates/wazuh/run_rootcheck.yml b/registry/tracecat_registry/templates/wazuh/run_rootcheck.yml index 13fe50d95..c1d8bbfd3 100644 --- a/registry/tracecat_registry/templates/wazuh/run_rootcheck.yml +++ b/registry/tracecat_registry/templates/wazuh/run_rootcheck.yml @@ -12,9 +12,9 @@ definition: - WAZUH_API_URL expects: agents_list: - type: str | null + type: str | None description: 
List of agent IDs (separated by comma), all agents selected by default if not specified. - default: null + default: None verify_ssl: type: bool description: If False, disables SSL verification for internal networks. diff --git a/tests/unit/test_expectation.py b/tests/unit/test_expectation.py index a035e3d06..c0943015d 100644 --- a/tests/unit/test_expectation.py +++ b/tests/unit/test_expectation.py @@ -20,7 +20,7 @@ def test_validate_schema(): "description": "The end time", }, "nullable": { - "type": "int | null", + "type": "int | None", "description": "An nullable integer", }, "optional_with_default": { @@ -81,7 +81,7 @@ def test_dynamic_model_with_optional_field_omitted(): "description": "The end time", }, "nullable": { - "type": "int | null", + "type": "int | None", "description": "An nullable integer", }, "optional_with_default": { @@ -130,7 +130,7 @@ def test_dynamic_model_with_invalid_data(): "description": "The end time", }, "nullable": { - "type": "int | null", + "type": "int | None", "description": "A nullable integer", }, "optional_with_default": { diff --git a/tests/unit/test_template_actions.py b/tests/unit/test_template_actions.py index 0f5db3fc7..a79aade1a 100644 --- a/tests/unit/test_template_actions.py +++ b/tests/unit/test_template_actions.py @@ -145,7 +145,7 @@ async def test_template_action_run(test_args, expected_result, should_raise): }, # Optional field with None as default "limit": { - "type": "int | null", + "type": "int | None", "description": "The limit", "default": "null", }, diff --git a/tracecat/expressions/expectations.py b/tracecat/expressions/expectations.py index 1bf493928..66e201ea0 100644 --- a/tracecat/expressions/expectations.py +++ b/tracecat/expressions/expectations.py @@ -32,7 +32,7 @@ DATETIME: "datetime" DURATION: "duration" ANY: "any" -NULL: "null" +NULL: "None" list_type: "list" "[" type "]" dict_type: "dict" "[" type "," type "]" @@ -151,17 +151,14 @@ def create_expectation_model( # Extract metadata field_type = 
parse_type(validated_field_info.type, field_name) - default = validated_field_info.default description = validated_field_info.description if description: field_info_kwargs["description"] = description - if default == "null": - # "null" indicates an explicit optional field, which evaluates to None - field_info_kwargs["default"] = None - elif default is not None: - field_info_kwargs["default"] = default + if "default" in validated_field_info.model_fields_set: + # If the field has a default value, use it + field_info_kwargs["default"] = validated_field_info.default else: # Use ... (ellipsis) to indicate a required field in Pydantic field_info_kwargs["default"] = ... From 112deb1e218636dfe2e190e5898a1accf9c761e3 Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Sat, 21 Dec 2024 01:17:52 -0800 Subject: [PATCH 17/32] revert nulls --- .../templates/internal/_example_template_action.yml | 2 +- registry/tracecat_registry/templates/wazuh/run_rootcheck.yml | 2 +- tests/unit/test_template_actions.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/registry/tracecat_registry/templates/internal/_example_template_action.yml b/registry/tracecat_registry/templates/internal/_example_template_action.yml index e5e4ae364..5c6954433 100644 --- a/registry/tracecat_registry/templates/internal/_example_template_action.yml +++ b/registry/tracecat_registry/templates/internal/_example_template_action.yml @@ -16,7 +16,7 @@ definition: description: The service source default: elastic limit: - type: int | null + type: int | None description: The limit # Layers are used to define a sequence of operations steps: diff --git a/registry/tracecat_registry/templates/wazuh/run_rootcheck.yml b/registry/tracecat_registry/templates/wazuh/run_rootcheck.yml index c1d8bbfd3..85255a70f 100644 --- a/registry/tracecat_registry/templates/wazuh/run_rootcheck.yml +++ b/registry/tracecat_registry/templates/wazuh/run_rootcheck.yml @@ -14,7 +14,7 @@ 
definition: agents_list: type: str | None description: List of agent IDs (separated by comma), all agents selected by default if not specified. - default: None + default: null verify_ssl: type: bool description: If False, disables SSL verification for internal networks. diff --git a/tests/unit/test_template_actions.py b/tests/unit/test_template_actions.py index a79aade1a..0d898daf4 100644 --- a/tests/unit/test_template_actions.py +++ b/tests/unit/test_template_actions.py @@ -147,7 +147,7 @@ async def test_template_action_run(test_args, expected_result, should_raise): "limit": { "type": "int | None", "description": "The limit", - "default": "null", + "default": None, }, }, "steps": [ From a5a11942e39ae0703fdce8440b6fd349115f832e Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Sat, 21 Dec 2024 01:23:40 -0800 Subject: [PATCH 18/32] revert null : None mapping --- tracecat/expressions/expectations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tracecat/expressions/expectations.py b/tracecat/expressions/expectations.py index 66e201ea0..8c2479db8 100644 --- a/tracecat/expressions/expectations.py +++ b/tracecat/expressions/expectations.py @@ -57,7 +57,7 @@ "datetime": datetime, "duration": timedelta, "any": Any, - "null": None, + "None": None, } From 794fe9862591eb749e99d15e0f1014de7491f6f8 Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Sat, 21 Dec 2024 01:25:33 -0800 Subject: [PATCH 19/32] validate args in run action direct (udf) --- tracecat/executor/service.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tracecat/executor/service.py b/tracecat/executor/service.py index b4b606e7b..c1f14c7c6 100644 --- a/tracecat/executor/service.py +++ b/tracecat/executor/service.py @@ -95,12 +95,14 @@ async def _run_action_direct( if action.is_template: # Defensive check raise ValueError("Templates cannot be executed directly") + + validated_args = 
action.validate_args(**args) try: if action.is_async: logger.trace("Running UDF async") - return await action.fn(**args) + return await action.fn(**validated_args) logger.trace("Running UDF sync") - return await asyncio.to_thread(action.fn, **args) + return await asyncio.to_thread(action.fn, **validated_args) except Exception as e: logger.error( f"Error running UDF {action.action!r}", error=e, type=type(e).__name__ From 4e4ba4c50f7c951a5d47d4803eaa1ad8de6ae15b Mon Sep 17 00:00:00 2001 From: Chris Lo <46541035+topher-lo@users.noreply.github.com> Date: Sat, 21 Dec 2024 21:43:24 -0800 Subject: [PATCH 20/32] Apply suggestions from code review Signed-off-by: Chris Lo <46541035+topher-lo@users.noreply.github.com> --- tracecat/executor/service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tracecat/executor/service.py b/tracecat/executor/service.py index c1f14c7c6..e194c3e77 100644 --- a/tracecat/executor/service.py +++ b/tracecat/executor/service.py @@ -93,7 +93,7 @@ async def _run_action_direct( At this point, the UDF cannot be a template. 
""" if action.is_template: - # Defensive check + # This should not be reachable raise ValueError("Templates cannot be executed directly") validated_args = action.validate_args(**args) From 3e1d6aa725aecb58ee03bb0900c20dd94f2386df Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Sat, 21 Dec 2024 21:46:49 -0800 Subject: [PATCH 21/32] Updates --- tracecat/expressions/expectations.py | 5 ++--- tracecat/registry/actions/models.py | 3 +-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/tracecat/expressions/expectations.py b/tracecat/expressions/expectations.py index 8c2479db8..6c44d196b 100644 --- a/tracecat/expressions/expectations.py +++ b/tracecat/expressions/expectations.py @@ -124,9 +124,8 @@ def enum_type(self, *values) -> Enum: @v_args(inline=True) def STRING_LITERAL(self, value) -> str: # Remove quotes from the value - value = value.strip('"').strip("'") - # Coerce to string - return str(value) + value = str(value).strip('"').strip("'") + return value def parse_type(type_string: str, field_name: str) -> Any: diff --git a/tracecat/registry/actions/models.py b/tracecat/registry/actions/models.py index 87135c540..af8f617ae 100644 --- a/tracecat/registry/actions/models.py +++ b/tracecat/registry/actions/models.py @@ -134,6 +134,7 @@ def validate_args[T](self, *args, **kwargs) -> T: # We return the validated input arguments as a dictionary validated: BaseModel = self.args_cls.model_validate(kwargs) validated_args = cast(T, validated.model_dump(mode="json")) + return validated_args except ValidationError as e: logger.error( f"Validation error for bound registry action {self.action!r}. {e.errors()!r}" @@ -148,8 +149,6 @@ def validate_args[T](self, *args, **kwargs) -> T: f"Unexpected error when validating input arguments for bound registry action {self.action!r}. 
{e}", key=self.action, ) from e - else: - return validated_args # Templates From 80cc4b8235a12f1a5b8e91feddde8cbc8cf66fea Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Sat, 21 Dec 2024 21:49:20 -0800 Subject: [PATCH 22/32] Support single quotes in enum --- tracecat/expressions/expectations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tracecat/expressions/expectations.py b/tracecat/expressions/expectations.py index 6c44d196b..0955ccfc0 100644 --- a/tracecat/expressions/expectations.py +++ b/tracecat/expressions/expectations.py @@ -26,7 +26,7 @@ INTEGER: "int" STRING: "str" -STRING_LITERAL: "\"" /[^"]*/ "\"" +STRING_LITERAL: "\"" /[^"]*/ "\"" | "'" /[^']*/ "'" BOOLEAN: "bool" FLOAT: "float" DATETIME: "datetime" From d39f06f47d81dc6a328fac812b0f369db3031b4b Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Sat, 21 Dec 2024 21:51:50 -0800 Subject: [PATCH 23/32] Add Enum as prefix to enum class name --- tracecat/expressions/expectations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tracecat/expressions/expectations.py b/tracecat/expressions/expectations.py index 0955ccfc0..1ebf53e07 100644 --- a/tracecat/expressions/expectations.py +++ b/tracecat/expressions/expectations.py @@ -119,7 +119,7 @@ def enum_type(self, *values) -> Enum: # Convert to upper camel case (e.g., "user_status" -> "UserStatus") enum_name = "".join(word.title() for word in self.field_name.split("_")) logger.trace("Enum type:", name=enum_name, values=enum_values) - return Enum(enum_name, enum_values) + return Enum(f"Enum{enum_name}", enum_values) @v_args(inline=True) def STRING_LITERAL(self, value) -> str: From 73986420040d55a6f024a25ad12e2339b95ba095 Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Sat, 21 Dec 2024 23:46:25 -0800 Subject: [PATCH 24/32] refactor: Call TEMPLATE_ACTION_INPUTS var `validated_args` --- 
tracecat/executor/service.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tracecat/executor/service.py b/tracecat/executor/service.py index e194c3e77..f00813c45 100644 --- a/tracecat/executor/service.py +++ b/tracecat/executor/service.py @@ -174,7 +174,7 @@ async def run_template_action( "Validating template action arguments", expects=defn.expects, args=args ) if defn.expects: - args = action.validate_args(**args) + validated_args = action.validate_args(**args) secrets_context = {} if context is not None: @@ -184,7 +184,7 @@ async def run_template_action( DSLContext, { ExprContext.SECRETS: secrets_context, - ExprContext.TEMPLATE_ACTION_INPUTS: args, + ExprContext.TEMPLATE_ACTION_INPUTS: validated_args, ExprContext.TEMPLATE_ACTION_STEPS: {}, }, ) From c9fa57cc78ef70728b356d707f4344ea55f3bf78 Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Sun, 22 Dec 2024 00:06:52 -0800 Subject: [PATCH 25/32] Capture e representation for unexpected errors in action.validate_args --- tracecat/registry/actions/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tracecat/registry/actions/models.py b/tracecat/registry/actions/models.py index af8f617ae..20ddcb0f9 100644 --- a/tracecat/registry/actions/models.py +++ b/tracecat/registry/actions/models.py @@ -146,7 +146,7 @@ def validate_args[T](self, *args, **kwargs) -> T: ) from e except Exception as e: raise RegistryValidationError( - f"Unexpected error when validating input arguments for bound registry action {self.action!r}. {e}", + f"Unexpected error when validating input arguments for bound registry action {self.action!r}. 
{e!r}", key=self.action, ) from e From b5d70eef6b6aa0c06a010c570bf05799e11a7e5b Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Sun, 22 Dec 2024 16:44:39 -0800 Subject: [PATCH 26/32] Drop RegistryValidationError from registry --- registry/tracecat_registry/__init__.py | 8 ++------ registry/tracecat_registry/_internal/exceptions.py | 11 ----------- tests/unit/test_registry.py | 2 +- tests/unit/test_template_actions.py | 3 ++- tracecat/registry/actions/models.py | 6 +++--- tracecat/types/exceptions.py | 7 ++++++- 6 files changed, 14 insertions(+), 23 deletions(-) diff --git a/registry/tracecat_registry/__init__.py b/registry/tracecat_registry/__init__.py index 036164b8f..a43fc2baf 100644 --- a/registry/tracecat_registry/__init__.py +++ b/registry/tracecat_registry/__init__.py @@ -10,11 +10,8 @@ "Could not import tracecat. Please install `tracecat` to use the registry." ) from None -from tracecat_registry._internal import registry, secrets -from tracecat_registry._internal.exceptions import ( # noqa: E402 - RegistryActionError, - RegistryValidationError, -) +from tracecat_registry._internal import exceptions, registry, secrets +from tracecat_registry._internal.exceptions import RegistryActionError from tracecat_registry._internal.logger import logger from tracecat_registry._internal.models import RegistrySecret @@ -24,6 +21,5 @@ "logger", "secrets", "exceptions", - "RegistryValidationError", "RegistryActionError", ] diff --git a/registry/tracecat_registry/_internal/exceptions.py b/registry/tracecat_registry/_internal/exceptions.py index ed17a8d78..bf3a55bbe 100644 --- a/registry/tracecat_registry/_internal/exceptions.py +++ b/registry/tracecat_registry/_internal/exceptions.py @@ -1,7 +1,5 @@ from typing import Any -from pydantic_core import ValidationError - class TracecatException(Exception): """Tracecat generic user-facing exception""" @@ -13,12 +11,3 @@ def __init__(self, *args, detail: Any | None = None, **kwargs): 
class RegistryActionError(TracecatException): """Exception raised when a registry UDF error occurs.""" - - -class RegistryValidationError(TracecatException): - """Exception raised when a registry validation error occurs.""" - - def __init__(self, *args, key: str, err: ValidationError | str | None = None): - super().__init__(*args) - self.key = key - self.err = err diff --git a/tests/unit/test_registry.py b/tests/unit/test_registry.py index a143f20cb..6ba9d968c 100644 --- a/tests/unit/test_registry.py +++ b/tests/unit/test_registry.py @@ -2,12 +2,12 @@ import textwrap import pytest -from tracecat_registry import RegistryValidationError from tracecat.concurrency import GatheringTaskGroup from tracecat.registry.actions.models import RegistryActionRead from tracecat.registry.actions.service import RegistryActionsService from tracecat.registry.repository import Repository +from tracecat.types.exceptions import RegistryValidationError @pytest.fixture diff --git a/tests/unit/test_template_actions.py b/tests/unit/test_template_actions.py index 0d898daf4..f8f44f0e2 100644 --- a/tests/unit/test_template_actions.py +++ b/tests/unit/test_template_actions.py @@ -1,10 +1,11 @@ import pytest -from tracecat_registry import RegistrySecret, RegistryValidationError +from tracecat_registry import RegistrySecret from tracecat.executor import service from tracecat.expressions.expectations import ExpectedField from tracecat.registry.actions.models import ActionStep, TemplateAction from tracecat.registry.repository import Repository +from tracecat.types.exceptions import RegistryValidationError def test_construct_template_action(): diff --git a/tracecat/registry/actions/models.py b/tracecat/registry/actions/models.py index 20ddcb0f9..2cd72b8e4 100644 --- a/tracecat/registry/actions/models.py +++ b/tracecat/registry/actions/models.py @@ -16,12 +16,12 @@ computed_field, model_validator, ) -from tracecat_registry import RegistrySecret, RegistryValidationError +from tracecat_registry import 
RegistrySecret from tracecat.db.schemas import RegistryAction from tracecat.expressions.expectations import ExpectedField, create_expectation_model from tracecat.logger import logger -from tracecat.types.exceptions import RegistryActionError, TracecatValidationError +from tracecat.types.exceptions import RegistryActionError, RegistryValidationError from tracecat.validation.models import ValidationResult ArgsClsT = TypeVar("ArgsClsT", bound=type[BaseModel]) @@ -390,7 +390,7 @@ def from_validation_result( ) @staticmethod - def from_dsl_validation_error(exc: TracecatValidationError): + def from_dsl_validation_error(exc: RegistryValidationError): return RegistryActionValidateResponse( ok=False, message=str(exc), detail=exc.detail ) diff --git a/tracecat/types/exceptions.py b/tracecat/types/exceptions.py index d365fa0b8..618980bbb 100644 --- a/tracecat/types/exceptions.py +++ b/tracecat/types/exceptions.py @@ -63,11 +63,16 @@ class RegistryActionError(RegistryError): class RegistryValidationError(RegistryError): """Exception raised when a registry validation error occurs.""" - def __init__(self, *args, key: str, err: ValidationError | str | None = None): + def __init__( + self, *args, key: str | None = None, err: ValidationError | str | None = None + ): super().__init__(*args) self.key = key self.err = err + def __reduce__(self): + return (self.__class__, (self.detail, self.key, self.err)) + class RegistryNotFound(RegistryError): """Exception raised when a registry is not found.""" From 8091d72f4dd5fa93920ee6b781239b0b6d6b725e Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Sun, 22 Dec 2024 16:52:50 -0800 Subject: [PATCH 27/32] Make RegistryValidationError serializable --- tracecat/types/exceptions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tracecat/types/exceptions.py b/tracecat/types/exceptions.py index 618980bbb..e5a914c59 100644 --- a/tracecat/types/exceptions.py +++ 
b/tracecat/types/exceptions.py @@ -71,7 +71,7 @@ def __init__( self.err = err def __reduce__(self): - return (self.__class__, (self.detail, self.key, self.err)) + return (self.__class__, (self.key, self.err)) class RegistryNotFound(RegistryError): From a8e09fa0a781efc07c2e1676d46c94bf0ab88043 Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Sun, 22 Dec 2024 16:55:50 -0800 Subject: [PATCH 28/32] Fix expected enum class value with prefix --- tests/unit/test_expectation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_expectation.py b/tests/unit/test_expectation.py index c0943015d..b1a5df9e2 100644 --- a/tests/unit/test_expectation.py +++ b/tests/unit/test_expectation.py @@ -251,9 +251,9 @@ def test_validate_schema_failure(): @pytest.mark.parametrize( "input_value,expected_value,priority_value", [ - ("PENDING", "Status.PENDING", "low"), - ("running", "Status.running", "low"), - ("Completed", "Status.Completed", "low"), + ("PENDING", "EnumStatus.PENDING", "low"), + ("running", "EnumStatus.running", "low"), + ("Completed", "EnumStatus.Completed", "low"), ], ) def test_validate_schema_with_enum(input_value, expected_value, priority_value): From 402fe85275df34c6fd2aa4ccbfa3b6f1828d4a82 Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Sun, 22 Dec 2024 17:08:14 -0800 Subject: [PATCH 29/32] more test fixes --- tests/unit/test_expectation.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/unit/test_expectation.py b/tests/unit/test_expectation.py index b1a5df9e2..7cd6fc8cf 100644 --- a/tests/unit/test_expectation.py +++ b/tests/unit/test_expectation.py @@ -249,14 +249,14 @@ def test_validate_schema_failure(): @pytest.mark.parametrize( - "input_value,expected_value,priority_value", + "input_value,priority_value", [ - ("PENDING", "EnumStatus.PENDING", "low"), - ("running", "EnumStatus.running", "low"), - 
("Completed", "EnumStatus.Completed", "low"), + ("PENDING", "low"), + ("running", "low"), + ("Completed", "low"), ], ) -def test_validate_schema_with_enum(input_value, expected_value, priority_value): +def test_validate_schema_with_enum(input_value, priority_value): schema = { "status": { "type": 'enum["PENDING", "running", "Completed"]', @@ -274,9 +274,8 @@ def test_validate_schema_with_enum(input_value, expected_value, priority_value): # Test with provided priority model_instance = DynamicModel(status=input_value, priority=priority_value) - assert str(model_instance.status) == expected_value - assert model_instance.status.__class__.__name__ == "Status" - assert model_instance.priority.__class__.__name__ == "Priority" + assert model_instance.status.__class__.__name__ == "EnumStatus" + assert model_instance.priority.__class__.__name__ == "EnumPriority" # Test default priority model_instance_default = DynamicModel(status=input_value) From 15287bf17de81a63339ca57013ae8f623a8c282d Mon Sep 17 00:00:00 2001 From: topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Sun, 22 Dec 2024 17:09:42 -0800 Subject: [PATCH 30/32] simplify enum validate schema param names --- tests/unit/test_expectation.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/unit/test_expectation.py b/tests/unit/test_expectation.py index 7cd6fc8cf..98e3fc638 100644 --- a/tests/unit/test_expectation.py +++ b/tests/unit/test_expectation.py @@ -249,14 +249,14 @@ def test_validate_schema_failure(): @pytest.mark.parametrize( - "input_value,priority_value", + "status,priority", [ ("PENDING", "low"), ("running", "low"), ("Completed", "low"), ], ) -def test_validate_schema_with_enum(input_value, priority_value): +def test_validate_schema_with_enum(status, priority): schema = { "status": { "type": 'enum["PENDING", "running", "Completed"]', @@ -270,15 +270,15 @@ def test_validate_schema_with_enum(input_value, priority_value): } mapped = {k: ExpectedField(**v) for k, v 
in schema.items()} - DynamicModel = create_expectation_model(mapped) + model = create_expectation_model(mapped) # Test with provided priority - model_instance = DynamicModel(status=input_value, priority=priority_value) + model_instance = model(status=status, priority=priority) assert model_instance.status.__class__.__name__ == "EnumStatus" assert model_instance.priority.__class__.__name__ == "EnumPriority" # Test default priority - model_instance_default = DynamicModel(status=input_value) + model_instance_default = model(status=status) assert str(model_instance_default.priority) == "low" From b7373ced4814e67ea4c63bf2322f8056c31874be Mon Sep 17 00:00:00 2001 From: Chris Lo <46541035+topher-lo@users.noreply.github.com> Date: Mon, 23 Dec 2024 02:49:43 -0800 Subject: [PATCH 31/32] Update models.py Signed-off-by: Chris Lo <46541035+topher-lo@users.noreply.github.com> --- tracecat/registry/actions/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tracecat/registry/actions/models.py b/tracecat/registry/actions/models.py index 589488bb5..439652646 100644 --- a/tracecat/registry/actions/models.py +++ b/tracecat/registry/actions/models.py @@ -22,7 +22,7 @@ from tracecat.db.schemas import RegistryAction from tracecat.expressions.expectations import ExpectedField, create_expectation_model from tracecat.logger import logger -from tracecat.types.exceptions import RegistryActionError, RegistryValidationError +from tracecat.types.exceptions import RegistryActionError, TracecatValidationError from tracecat.validation.models import ValidationResult ArgsClsT = TypeVar("ArgsClsT", bound=type[BaseModel]) @@ -391,7 +391,7 @@ def from_validation_result( ) @staticmethod - def from_dsl_validation_error(exc: RegistryValidationError): + def from_dsl_validation_error(exc: TracecatValidationError): return RegistryActionValidateResponse( ok=False, message=str(exc), detail=exc.detail ) From 90f26c40dc1678ff50420d46b418ca140c406930 Mon Sep 17 00:00:00 2001 From: 
topher-lo <46541035+topher-lo@users.noreply.github.com> Date: Mon, 23 Dec 2024 03:10:43 -0800 Subject: [PATCH 32/32] fix missing raise RegistryValidationError( --- tracecat/registry/actions/models.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tracecat/registry/actions/models.py b/tracecat/registry/actions/models.py index 439652646..bd4f492a7 100644 --- a/tracecat/registry/actions/models.py +++ b/tracecat/registry/actions/models.py @@ -22,7 +22,11 @@ from tracecat.db.schemas import RegistryAction from tracecat.expressions.expectations import ExpectedField, create_expectation_model from tracecat.logger import logger -from tracecat.types.exceptions import RegistryActionError, TracecatValidationError +from tracecat.types.exceptions import ( + RegistryActionError, + RegistryValidationError, + TracecatValidationError, +) from tracecat.validation.models import ValidationResult ArgsClsT = TypeVar("ArgsClsT", bound=type[BaseModel])