diff --git a/tests/unit/test_workflows.py b/tests/unit/test_workflows.py index c7395f122..dcb6c1998 100644 --- a/tests/unit/test_workflows.py +++ b/tests/unit/test_workflows.py @@ -39,6 +39,7 @@ from tracecat.secrets.models import SecretCreate, SecretKeyValue from tracecat.secrets.service import SecretsService from tracecat.types.auth import Role +from tracecat.types.exceptions import TracecatValidationError from tracecat.workflow.management.definitions import WorkflowDefinitionsService from tracecat.workflow.management.management import WorkflowsManagementService @@ -2065,3 +2066,175 @@ async def test_workflow_runs_template_for_each( ), ) assert result == [101, 102, 103, 104, 105] + + +@pytest.mark.anyio +@pytest.mark.parametrize( + "trigger_inputs,expected_result,should_raise", + [ + # Test case 1: All required fields with valid values + ( + {"user_id": "john@example.com", "priority": "high", "count": 5}, + { + "ACTIONS": {}, + "INPUTS": {}, + "TRIGGER": { + "user_id": "john@example.com", + "priority": "high", + "count": 5, + }, + }, + False, + ), + # Test case 2: Using default values + ( + {"user_id": "jane@example.com"}, + { + "ACTIONS": {}, + "INPUTS": {}, + "TRIGGER": { + "user_id": "jane@example.com", + "priority": "low", + "count": 10, + }, + }, + False, + ), + # Test case 3: Invalid enum value + ( + {"user_id": "bob@example.com", "priority": "INVALID"}, + None, + True, + ), + # Test case 4: Missing required field + ( + {"priority": "medium", "count": 3}, + None, + True, + ), + # Test case 5: No expects defined + ( + {"any": "value"}, + { + "ACTIONS": {}, + "INPUTS": {}, + "TRIGGER": {"any": "value"}, + }, + False, + ), + ], + ids=[ + "valid_all_fields", + "with_defaults", + "invalid_enum", + "missing_required", + "no_expects", + ], +) +async def test_workflow_trigger_validation( + trigger_inputs, expected_result, should_raise, test_role, temporal_client +): + """Test workflow trigger input validation. + + This test verifies that: + 1. Required trigger inputs are properly validated + 2. Default values are correctly applied + 3. Enum values are properly validated + 4. Missing required fields are rejected + 5. Workflows with no expects defined accept any trigger inputs + """ + test_name = f"{test_workflow_trigger_validation.__name__}" + wf_exec_id = generate_test_exec_id(test_name) + + # Base DSL with validation + dsl_with_validation = { + "title": "Test Workflow Trigger Validation", + "description": "Test workflow with trigger input validation", + "entrypoint": { + "expects": { + "user_id": { + "type": "str", + "description": "User identifier", + }, + "priority": { + "type": 'enum["low", "medium", "high"]', + "description": "Task priority level", + "default": "low", + }, + "count": { + "type": "int", + "description": "Number of items", + "default": 10, + }, + }, + "ref": "start", + }, + "actions": [ + { + "ref": "start", + "action": "core.transform.reshape", + "args": {"value": "START"}, + }, + ], + "inputs": {}, + "returns": None, + "tests": [], + "triggers": [], + } + + # DSL without expects for the "no_expects" test case + dsl_without_validation = { + **dsl_with_validation, + "entrypoint": {"ref": "start"}, + } + + # Use appropriate DSL based on test case + dsl = DSLInput( + **( + dsl_without_validation + if trigger_inputs.get("any") == "value" + else dsl_with_validation + ) + ) + + run_args = DSLRunArgs( + dsl=dsl, + role=test_role, + wf_id=TEST_WF_ID, + trigger_inputs=trigger_inputs, + ) + + if should_raise: + with pytest.raises(TracecatValidationError) as exc_info: + async with Worker( + temporal_client, + task_queue=os.environ["TEMPORAL__CLUSTER_QUEUE"], + activities=DSLActivities.load() + DSL_UTILITIES, + workflows=[DSLWorkflow], + workflow_runner=new_sandbox_runner(), + ): + await temporal_client.execute_workflow( + DSLWorkflow.run, + run_args, + id=wf_exec_id, + task_queue=os.environ["TEMPORAL__CLUSTER_QUEUE"], + retry_policy=retry_policies["workflow:fail_fast"], + ) + # Verify that it's a validation error + assert "ValidationError" in str(exc_info.value) + else: + async with Worker( + temporal_client, + task_queue=os.environ["TEMPORAL__CLUSTER_QUEUE"], + activities=DSLActivities.load() + DSL_UTILITIES, + workflows=[DSLWorkflow], + workflow_runner=new_sandbox_runner(), + ): + result = await temporal_client.execute_workflow( + DSLWorkflow.run, + run_args, + id=wf_exec_id, + task_queue=os.environ["TEMPORAL__CLUSTER_QUEUE"], + retry_policy=retry_policies["workflow:fail_fast"], + ) + assert result == expected_result diff --git a/tracecat/dsl/validation.py b/tracecat/dsl/validation.py index f19c1aff4..8423421bf 100644 --- a/tracecat/dsl/validation.py +++ b/tracecat/dsl/validation.py @@ -33,9 +33,9 @@ def validate_trigger_inputs( } if isinstance(payload, dict): # NOTE: We only validate dict payloads for now - validator = create_expectation_model(expects_schema, model_name=model_name) + model = create_expectation_model(expects_schema, model_name=model_name) try: - validator(**payload) + validated_payload = model(**payload).model_dump(mode="json") except ValidationError as e: if raise_exceptions: raise @@ -44,7 +44,12 @@ def validate_trigger_inputs( msg=f"Validation error in trigger inputs ({e.title}). Please refer to the schema for more details.", detail={"errors": e.errors()}, ) - return ValidationResult(status="success", msg="Trigger inputs are valid.") + result = ValidationResult( + status="success", + msg="Trigger inputs are valid.", + payload=validated_payload, + ) + return result class ValidateTriggerInputsActivityInputs(BaseModel): diff --git a/tracecat/dsl/workflow.py b/tracecat/dsl/workflow.py index 9b13b0cda..1a736eec3 100644 --- a/tracecat/dsl/workflow.py +++ b/tracecat/dsl/workflow.py @@ -225,10 +225,11 @@ async def run(self, args: DSLRunArgs) -> Any: ) from e # Prepare user facing context + validated_payload = validation_result.payload or {} self.context: ExecutionContext = { ExprContext.ACTIONS: {}, ExprContext.INPUTS: self.dsl.inputs, - ExprContext.TRIGGER: trigger_inputs, + ExprContext.TRIGGER: validated_payload, ExprContext.ENV: DSLEnvironment( workflow={ "start_time": wf_info.start_time, diff --git a/tracecat/validation/models.py b/tracecat/validation/models.py index 6decc028b..60462f5ba 100644 --- a/tracecat/validation/models.py +++ b/tracecat/validation/models.py @@ -14,6 +14,7 @@ class ValidationResult(BaseModel): msg: str = "" detail: Any | None = None ref: str | None = None + payload: dict[str, Any] | None = None def __hash__(self) -> int: detail = json.dumps(self.detail, sort_keys=True)