Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 6aca613

Browse files
authored
Adds ability to publish ECSTask block as a ecs work pool (#353)
1 parent 6865af7 commit 6aca613

File tree

5 files changed

+274
-3
lines changed

5 files changed

+274
-3
lines changed

CHANGELOG.md

+8
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1717

1818
### Removed
1919

20+
## 0.4.6
21+
22+
Released December 11th, 2023.
23+
24+
### Added
25+
26+
Ability to publish `ECSTask`` block as an ecs work pool - [#353](https://github.com/PrefectHQ/prefect-aws/pull/353)
27+
2028
## 0.4.5
2129

2230
Released November 30th, 2023.

prefect_aws/ecs.py

+73-1
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@
108108
import json
109109
import logging
110110
import pprint
111+
import shlex
111112
import sys
112113
import time
113114
import warnings
@@ -116,6 +117,8 @@
116117
import boto3
117118
import yaml
118119
from anyio.abc import TaskStatus
120+
from jsonpointer import JsonPointerException
121+
from prefect.blocks.core import BlockNotSavedError
119122
from prefect.exceptions import InfrastructureNotAvailable, InfrastructureNotFound
120123
from prefect.infrastructure.base import Infrastructure, InfrastructureResult
121124
from prefect.utilities.asyncutils import run_sync_in_worker_thread, sync_compatible
@@ -132,7 +135,7 @@
132135
from typing_extensions import Literal, Self
133136

134137
from prefect_aws import AwsCredentials
135-
from prefect_aws.workers.ecs_worker import _TAG_REGEX
138+
from prefect_aws.workers.ecs_worker import _TAG_REGEX, ECSWorker
136139

137140
# Internal type alias for ECS clients which are generated dynamically in botocore
138141
_ECSClient = Any
@@ -681,6 +684,75 @@ async def kill(self, identifier: str, grace_seconds: int = 30) -> None:
681684
cluster, task = parse_task_identifier(identifier)
682685
await run_sync_in_worker_thread(self._stop_task, cluster, task)
683686

687+
@staticmethod
688+
def get_corresponding_worker_type() -> str:
689+
"""Return the corresponding worker type for this infrastructure block."""
690+
return ECSWorker.type
691+
692+
async def generate_work_pool_base_job_template(self) -> dict:
693+
"""
694+
Generate a base job template for a cloud-run work pool with the same
695+
configuration as this block.
696+
697+
Returns:
698+
- dict: a base job template for a cloud-run work pool
699+
"""
700+
base_job_template = copy.deepcopy(ECSWorker.get_default_base_job_template())
701+
for key, value in self.dict(exclude_unset=True, exclude_defaults=True).items():
702+
if key == "command":
703+
base_job_template["variables"]["properties"]["command"]["default"] = (
704+
shlex.join(value)
705+
)
706+
elif key in [
707+
"type",
708+
"block_type_slug",
709+
"_block_document_id",
710+
"_block_document_name",
711+
"_is_anonymous",
712+
"task_customizations",
713+
]:
714+
continue
715+
elif key == "aws_credentials":
716+
if not self.aws_credentials._block_document_id:
717+
raise BlockNotSavedError(
718+
"It looks like you are trying to use a block that"
719+
" has not been saved. Please call `.save` on your block"
720+
" before publishing it as a work pool."
721+
)
722+
base_job_template["variables"]["properties"]["aws_credentials"][
723+
"default"
724+
] = {
725+
"$ref": {
726+
"block_document_id": str(
727+
self.aws_credentials._block_document_id
728+
)
729+
}
730+
}
731+
elif key == "task_definition":
732+
base_job_template["job_configuration"]["task_definition"] = value
733+
elif key in base_job_template["variables"]["properties"]:
734+
base_job_template["variables"]["properties"][key]["default"] = value
735+
else:
736+
self.logger.warning(
737+
f"Variable {key!r} is not supported by Cloud Run work pools."
738+
" Skipping."
739+
)
740+
741+
if self.task_customizations:
742+
try:
743+
base_job_template["job_configuration"]["task_run_request"] = (
744+
self.task_customizations.apply(
745+
base_job_template["job_configuration"]["task_run_request"]
746+
)
747+
)
748+
except JsonPointerException:
749+
self.logger.warning(
750+
"Unable to apply task customizations to the base job template."
751+
"You may need to update the template manually."
752+
)
753+
754+
return base_job_template
755+
684756
def _stop_task(self, cluster: str, task: str) -> None:
685757
"""
686758
Stop a running ECS task.

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@ boto3>=1.24.53
22
botocore>=1.27.53
33
mypy_boto3_s3>=1.24.94
44
mypy_boto3_secretsmanager>=1.26.49
5-
prefect>=2.13.5
5+
prefect>=2.14.10
66
tenacity>=8.0.0

tests/conftest.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,13 @@ def prefect_db():
2222

2323
@pytest.fixture
2424
def aws_credentials():
25-
return AwsCredentials(
25+
block = AwsCredentials(
2626
aws_access_key_id="access_key_id",
2727
aws_secret_access_key="secret_access_key",
2828
region_name="us-east-1",
2929
)
30+
block.save("test-creds-block", overwrite=True)
31+
return block
3032

3133

3234
@pytest.fixture

tests/test_ecs.py

+189
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import logging
33
import textwrap
4+
from copy import deepcopy
45
from functools import partial
56
from typing import Any, Awaitable, Callable, Dict, List, Optional
67
from unittest.mock import MagicMock
@@ -18,6 +19,8 @@
1819
from prefect.utilities.dockerutils import get_prefect_image_name
1920
from pydantic import VERSION as PYDANTIC_VERSION
2021

22+
from prefect_aws.workers.ecs_worker import ECSWorker
23+
2124
if PYDANTIC_VERSION.startswith("2."):
2225
from pydantic.v1 import ValidationError
2326
else:
@@ -2047,3 +2050,189 @@ async def test_kill_with_grace_period(aws_credentials, caplog):
20472050

20482051
# Logs warning
20492052
assert "grace period of 60s requested, but AWS does not support" in caplog.text
2053+
2054+
2055+
@pytest.fixture
2056+
def default_base_job_template():
2057+
return deepcopy(ECSWorker.get_default_base_job_template())
2058+
2059+
2060+
@pytest.fixture
2061+
def base_job_template_with_defaults(default_base_job_template, aws_credentials):
2062+
base_job_template_with_defaults = deepcopy(default_base_job_template)
2063+
base_job_template_with_defaults["variables"]["properties"]["command"][
2064+
"default"
2065+
] = "python my_script.py"
2066+
base_job_template_with_defaults["variables"]["properties"]["env"]["default"] = {
2067+
"VAR1": "value1",
2068+
"VAR2": "value2",
2069+
}
2070+
base_job_template_with_defaults["variables"]["properties"]["labels"]["default"] = {
2071+
"label1": "value1",
2072+
"label2": "value2",
2073+
}
2074+
base_job_template_with_defaults["variables"]["properties"]["name"][
2075+
"default"
2076+
] = "prefect-job"
2077+
base_job_template_with_defaults["variables"]["properties"]["image"][
2078+
"default"
2079+
] = "docker.io/my_image:latest"
2080+
base_job_template_with_defaults["variables"]["properties"]["aws_credentials"][
2081+
"default"
2082+
] = {"$ref": {"block_document_id": str(aws_credentials._block_document_id)}}
2083+
base_job_template_with_defaults["variables"]["properties"]["launch_type"][
2084+
"default"
2085+
] = "FARGATE_SPOT"
2086+
base_job_template_with_defaults["variables"]["properties"]["vpc_id"][
2087+
"default"
2088+
] = "vpc-123456"
2089+
base_job_template_with_defaults["variables"]["properties"]["task_role_arn"][
2090+
"default"
2091+
] = "arn:aws:iam::123456789012:role/ecsTaskExecutionRole"
2092+
base_job_template_with_defaults["variables"]["properties"]["execution_role_arn"][
2093+
"default"
2094+
] = "arn:aws:iam::123456789012:role/ecsTaskExecutionRole"
2095+
base_job_template_with_defaults["variables"]["properties"]["cluster"][
2096+
"default"
2097+
] = "test-cluster"
2098+
base_job_template_with_defaults["variables"]["properties"]["cpu"]["default"] = 2048
2099+
base_job_template_with_defaults["variables"]["properties"]["memory"][
2100+
"default"
2101+
] = 4096
2102+
2103+
base_job_template_with_defaults["variables"]["properties"]["family"][
2104+
"default"
2105+
] = "test-family"
2106+
base_job_template_with_defaults["variables"]["properties"]["task_definition_arn"][
2107+
"default"
2108+
] = "arn:aws:ecs:us-east-1:123456789012:task-definition/test-family:1"
2109+
base_job_template_with_defaults["variables"]["properties"][
2110+
"cloudwatch_logs_options"
2111+
]["default"] = {
2112+
"awslogs-group": "prefect",
2113+
"awslogs-region": "us-east-1",
2114+
"awslogs-stream-prefix": "prefect",
2115+
}
2116+
base_job_template_with_defaults["variables"]["properties"][
2117+
"configure_cloudwatch_logs"
2118+
]["default"] = True
2119+
base_job_template_with_defaults["variables"]["properties"]["stream_output"][
2120+
"default"
2121+
] = True
2122+
base_job_template_with_defaults["variables"]["properties"][
2123+
"task_watch_poll_interval"
2124+
]["default"] = 5.1
2125+
base_job_template_with_defaults["variables"]["properties"][
2126+
"task_start_timeout_seconds"
2127+
]["default"] = 60
2128+
base_job_template_with_defaults["variables"]["properties"][
2129+
"auto_deregister_task_definition"
2130+
]["default"] = False
2131+
return base_job_template_with_defaults
2132+
2133+
2134+
@pytest.fixture
2135+
def base_job_template_with_task_arn(default_base_job_template, aws_credentials):
2136+
base_job_template_with_task_arn = deepcopy(default_base_job_template)
2137+
base_job_template_with_task_arn["variables"]["properties"]["image"][
2138+
"default"
2139+
] = "docker.io/my_image:latest"
2140+
2141+
base_job_template_with_task_arn["job_configuration"]["task_definition"] = {
2142+
"containerDefinitions": [
2143+
{"image": "docker.io/my_image:latest", "name": "prefect-job"}
2144+
],
2145+
"cpu": "2048",
2146+
"family": "test-family",
2147+
"memory": "2024",
2148+
"executionRoleArn": "arn:aws:iam::123456789012:role/ecsTaskExecutionRole",
2149+
}
2150+
return base_job_template_with_task_arn
2151+
2152+
2153+
@pytest.mark.parametrize(
2154+
"job_config",
2155+
[
2156+
"default",
2157+
"custom",
2158+
"task_definition_arn",
2159+
],
2160+
)
2161+
async def test_generate_work_pool_base_job_template(
2162+
job_config,
2163+
base_job_template_with_defaults,
2164+
aws_credentials,
2165+
default_base_job_template,
2166+
base_job_template_with_task_arn,
2167+
caplog,
2168+
):
2169+
job = ECSTask()
2170+
expected_template = default_base_job_template
2171+
expected_template["variables"]["properties"]["image"][
2172+
"default"
2173+
] = get_prefect_image_name()
2174+
if job_config == "custom":
2175+
expected_template = base_job_template_with_defaults
2176+
job = ECSTask(
2177+
command=["python", "my_script.py"],
2178+
env={"VAR1": "value1", "VAR2": "value2"},
2179+
labels={"label1": "value1", "label2": "value2"},
2180+
name="prefect-job",
2181+
image="docker.io/my_image:latest",
2182+
aws_credentials=aws_credentials,
2183+
launch_type="FARGATE_SPOT",
2184+
vpc_id="vpc-123456",
2185+
task_role_arn="arn:aws:iam::123456789012:role/ecsTaskExecutionRole",
2186+
execution_role_arn="arn:aws:iam::123456789012:role/ecsTaskExecutionRole",
2187+
cluster="test-cluster",
2188+
cpu=2048,
2189+
memory=4096,
2190+
task_customizations=[
2191+
{
2192+
"op": "add",
2193+
"path": "/networkConfiguration/awsvpcConfiguration/securityGroups",
2194+
"value": ["sg-d72e9599956a084f5"],
2195+
},
2196+
],
2197+
family="test-family",
2198+
task_definition_arn=(
2199+
"arn:aws:ecs:us-east-1:123456789012:task-definition/test-family:1"
2200+
),
2201+
cloudwatch_logs_options={
2202+
"awslogs-group": "prefect",
2203+
"awslogs-region": "us-east-1",
2204+
"awslogs-stream-prefix": "prefect",
2205+
},
2206+
configure_cloudwatch_logs=True,
2207+
stream_output=True,
2208+
task_watch_poll_interval=5.1,
2209+
task_start_timeout_seconds=60,
2210+
auto_deregister_task_definition=False,
2211+
)
2212+
elif job_config == "task_definition_arn":
2213+
expected_template = base_job_template_with_task_arn
2214+
job = ECSTask(
2215+
image="docker.io/my_image:latest",
2216+
task_definition={
2217+
"containerDefinitions": [
2218+
{"image": "docker.io/my_image:latest", "name": "prefect-job"}
2219+
],
2220+
"cpu": "2048",
2221+
"family": "test-family",
2222+
"memory": "2024",
2223+
"executionRoleArn": (
2224+
"arn:aws:iam::123456789012:role/ecsTaskExecutionRole"
2225+
),
2226+
},
2227+
)
2228+
2229+
template = await job.generate_work_pool_base_job_template()
2230+
2231+
assert template == expected_template
2232+
2233+
if job_config == "custom":
2234+
assert (
2235+
"Unable to apply task customizations to the base job template."
2236+
"You may need to update the template manually."
2237+
in caplog.text
2238+
)

0 commit comments

Comments
 (0)