|
1 | 1 | import json
|
2 | 2 | import logging
|
3 | 3 | import textwrap
|
| 4 | +from copy import deepcopy |
4 | 5 | from functools import partial
|
5 | 6 | from typing import Any, Awaitable, Callable, Dict, List, Optional
|
6 | 7 | from unittest.mock import MagicMock
|
|
18 | 19 | from prefect.utilities.dockerutils import get_prefect_image_name
|
19 | 20 | from pydantic import VERSION as PYDANTIC_VERSION
|
20 | 21 |
|
| 22 | +from prefect_aws.workers.ecs_worker import ECSWorker |
| 23 | + |
21 | 24 | if PYDANTIC_VERSION.startswith("2."):
|
22 | 25 | from pydantic.v1 import ValidationError
|
23 | 26 | else:
|
@@ -2047,3 +2050,189 @@ async def test_kill_with_grace_period(aws_credentials, caplog):
|
2047 | 2050 |
|
2048 | 2051 | # Logs warning
|
2049 | 2052 | assert "grace period of 60s requested, but AWS does not support" in caplog.text
|
| 2053 | + |
| 2054 | + |
| 2055 | +@pytest.fixture |
| 2056 | +def default_base_job_template(): |
| 2057 | + return deepcopy(ECSWorker.get_default_base_job_template()) |
| 2058 | + |
| 2059 | + |
| 2060 | +@pytest.fixture |
| 2061 | +def base_job_template_with_defaults(default_base_job_template, aws_credentials): |
| 2062 | + base_job_template_with_defaults = deepcopy(default_base_job_template) |
| 2063 | + base_job_template_with_defaults["variables"]["properties"]["command"][ |
| 2064 | + "default" |
| 2065 | + ] = "python my_script.py" |
| 2066 | + base_job_template_with_defaults["variables"]["properties"]["env"]["default"] = { |
| 2067 | + "VAR1": "value1", |
| 2068 | + "VAR2": "value2", |
| 2069 | + } |
| 2070 | + base_job_template_with_defaults["variables"]["properties"]["labels"]["default"] = { |
| 2071 | + "label1": "value1", |
| 2072 | + "label2": "value2", |
| 2073 | + } |
| 2074 | + base_job_template_with_defaults["variables"]["properties"]["name"][ |
| 2075 | + "default" |
| 2076 | + ] = "prefect-job" |
| 2077 | + base_job_template_with_defaults["variables"]["properties"]["image"][ |
| 2078 | + "default" |
| 2079 | + ] = "docker.io/my_image:latest" |
| 2080 | + base_job_template_with_defaults["variables"]["properties"]["aws_credentials"][ |
| 2081 | + "default" |
| 2082 | + ] = {"$ref": {"block_document_id": str(aws_credentials._block_document_id)}} |
| 2083 | + base_job_template_with_defaults["variables"]["properties"]["launch_type"][ |
| 2084 | + "default" |
| 2085 | + ] = "FARGATE_SPOT" |
| 2086 | + base_job_template_with_defaults["variables"]["properties"]["vpc_id"][ |
| 2087 | + "default" |
| 2088 | + ] = "vpc-123456" |
| 2089 | + base_job_template_with_defaults["variables"]["properties"]["task_role_arn"][ |
| 2090 | + "default" |
| 2091 | + ] = "arn:aws:iam::123456789012:role/ecsTaskExecutionRole" |
| 2092 | + base_job_template_with_defaults["variables"]["properties"]["execution_role_arn"][ |
| 2093 | + "default" |
| 2094 | + ] = "arn:aws:iam::123456789012:role/ecsTaskExecutionRole" |
| 2095 | + base_job_template_with_defaults["variables"]["properties"]["cluster"][ |
| 2096 | + "default" |
| 2097 | + ] = "test-cluster" |
| 2098 | + base_job_template_with_defaults["variables"]["properties"]["cpu"]["default"] = 2048 |
| 2099 | + base_job_template_with_defaults["variables"]["properties"]["memory"][ |
| 2100 | + "default" |
| 2101 | + ] = 4096 |
| 2102 | + |
| 2103 | + base_job_template_with_defaults["variables"]["properties"]["family"][ |
| 2104 | + "default" |
| 2105 | + ] = "test-family" |
| 2106 | + base_job_template_with_defaults["variables"]["properties"]["task_definition_arn"][ |
| 2107 | + "default" |
| 2108 | + ] = "arn:aws:ecs:us-east-1:123456789012:task-definition/test-family:1" |
| 2109 | + base_job_template_with_defaults["variables"]["properties"][ |
| 2110 | + "cloudwatch_logs_options" |
| 2111 | + ]["default"] = { |
| 2112 | + "awslogs-group": "prefect", |
| 2113 | + "awslogs-region": "us-east-1", |
| 2114 | + "awslogs-stream-prefix": "prefect", |
| 2115 | + } |
| 2116 | + base_job_template_with_defaults["variables"]["properties"][ |
| 2117 | + "configure_cloudwatch_logs" |
| 2118 | + ]["default"] = True |
| 2119 | + base_job_template_with_defaults["variables"]["properties"]["stream_output"][ |
| 2120 | + "default" |
| 2121 | + ] = True |
| 2122 | + base_job_template_with_defaults["variables"]["properties"][ |
| 2123 | + "task_watch_poll_interval" |
| 2124 | + ]["default"] = 5.1 |
| 2125 | + base_job_template_with_defaults["variables"]["properties"][ |
| 2126 | + "task_start_timeout_seconds" |
| 2127 | + ]["default"] = 60 |
| 2128 | + base_job_template_with_defaults["variables"]["properties"][ |
| 2129 | + "auto_deregister_task_definition" |
| 2130 | + ]["default"] = False |
| 2131 | + return base_job_template_with_defaults |
| 2132 | + |
| 2133 | + |
| 2134 | +@pytest.fixture |
| 2135 | +def base_job_template_with_task_arn(default_base_job_template, aws_credentials): |
| 2136 | + base_job_template_with_task_arn = deepcopy(default_base_job_template) |
| 2137 | + base_job_template_with_task_arn["variables"]["properties"]["image"][ |
| 2138 | + "default" |
| 2139 | + ] = "docker.io/my_image:latest" |
| 2140 | + |
| 2141 | + base_job_template_with_task_arn["job_configuration"]["task_definition"] = { |
| 2142 | + "containerDefinitions": [ |
| 2143 | + {"image": "docker.io/my_image:latest", "name": "prefect-job"} |
| 2144 | + ], |
| 2145 | + "cpu": "2048", |
| 2146 | + "family": "test-family", |
| 2147 | + "memory": "2024", |
| 2148 | + "executionRoleArn": "arn:aws:iam::123456789012:role/ecsTaskExecutionRole", |
| 2149 | + } |
| 2150 | + return base_job_template_with_task_arn |
| 2151 | + |
| 2152 | + |
| 2153 | +@pytest.mark.parametrize( |
| 2154 | + "job_config", |
| 2155 | + [ |
| 2156 | + "default", |
| 2157 | + "custom", |
| 2158 | + "task_definition_arn", |
| 2159 | + ], |
| 2160 | +) |
| 2161 | +async def test_generate_work_pool_base_job_template( |
| 2162 | + job_config, |
| 2163 | + base_job_template_with_defaults, |
| 2164 | + aws_credentials, |
| 2165 | + default_base_job_template, |
| 2166 | + base_job_template_with_task_arn, |
| 2167 | + caplog, |
| 2168 | +): |
| 2169 | + job = ECSTask() |
| 2170 | + expected_template = default_base_job_template |
| 2171 | + expected_template["variables"]["properties"]["image"][ |
| 2172 | + "default" |
| 2173 | + ] = get_prefect_image_name() |
| 2174 | + if job_config == "custom": |
| 2175 | + expected_template = base_job_template_with_defaults |
| 2176 | + job = ECSTask( |
| 2177 | + command=["python", "my_script.py"], |
| 2178 | + env={"VAR1": "value1", "VAR2": "value2"}, |
| 2179 | + labels={"label1": "value1", "label2": "value2"}, |
| 2180 | + name="prefect-job", |
| 2181 | + image="docker.io/my_image:latest", |
| 2182 | + aws_credentials=aws_credentials, |
| 2183 | + launch_type="FARGATE_SPOT", |
| 2184 | + vpc_id="vpc-123456", |
| 2185 | + task_role_arn="arn:aws:iam::123456789012:role/ecsTaskExecutionRole", |
| 2186 | + execution_role_arn="arn:aws:iam::123456789012:role/ecsTaskExecutionRole", |
| 2187 | + cluster="test-cluster", |
| 2188 | + cpu=2048, |
| 2189 | + memory=4096, |
| 2190 | + task_customizations=[ |
| 2191 | + { |
| 2192 | + "op": "add", |
| 2193 | + "path": "/networkConfiguration/awsvpcConfiguration/securityGroups", |
| 2194 | + "value": ["sg-d72e9599956a084f5"], |
| 2195 | + }, |
| 2196 | + ], |
| 2197 | + family="test-family", |
| 2198 | + task_definition_arn=( |
| 2199 | + "arn:aws:ecs:us-east-1:123456789012:task-definition/test-family:1" |
| 2200 | + ), |
| 2201 | + cloudwatch_logs_options={ |
| 2202 | + "awslogs-group": "prefect", |
| 2203 | + "awslogs-region": "us-east-1", |
| 2204 | + "awslogs-stream-prefix": "prefect", |
| 2205 | + }, |
| 2206 | + configure_cloudwatch_logs=True, |
| 2207 | + stream_output=True, |
| 2208 | + task_watch_poll_interval=5.1, |
| 2209 | + task_start_timeout_seconds=60, |
| 2210 | + auto_deregister_task_definition=False, |
| 2211 | + ) |
| 2212 | + elif job_config == "task_definition_arn": |
| 2213 | + expected_template = base_job_template_with_task_arn |
| 2214 | + job = ECSTask( |
| 2215 | + image="docker.io/my_image:latest", |
| 2216 | + task_definition={ |
| 2217 | + "containerDefinitions": [ |
| 2218 | + {"image": "docker.io/my_image:latest", "name": "prefect-job"} |
| 2219 | + ], |
| 2220 | + "cpu": "2048", |
| 2221 | + "family": "test-family", |
| 2222 | + "memory": "2024", |
| 2223 | + "executionRoleArn": ( |
| 2224 | + "arn:aws:iam::123456789012:role/ecsTaskExecutionRole" |
| 2225 | + ), |
| 2226 | + }, |
| 2227 | + ) |
| 2228 | + |
| 2229 | + template = await job.generate_work_pool_base_job_template() |
| 2230 | + |
| 2231 | + assert template == expected_template |
| 2232 | + |
| 2233 | + if job_config == "custom": |
| 2234 | + assert ( |
| 2235 | + "Unable to apply task customizations to the base job template." |
| 2236 | + "You may need to update the template manually." |
| 2237 | + in caplog.text |
| 2238 | + ) |
0 commit comments