diff --git a/dask_cloudprovider/aws/ecs.py b/dask_cloudprovider/aws/ecs.py index 79e7bb0d..a13f4fcb 100644 --- a/dask_cloudprovider/aws/ecs.py +++ b/dask_cloudprovider/aws/ecs.py @@ -12,7 +12,6 @@ from dask_cloudprovider.aws.helper import ( dict_to_aws, aws_to_dict, - get_sleep_duration, get_default_vpc, get_vpc_subnets, create_default_security_group, @@ -25,6 +24,7 @@ try: from botocore.exceptions import ClientError + from aiobotocore.config import AioConfig from aiobotocore.session import get_session except ImportError as e: msg = ( @@ -120,6 +120,7 @@ def __init__( fargate_use_private_ip=False, fargate_capacity_provider=None, task_kwargs=None, + is_task_long_arn_format_enabled=True, **kwargs, ): self.lock = asyncio.Lock() @@ -144,6 +145,7 @@ def __init__( self._fargate_capacity_provider = fargate_capacity_provider self.kwargs = kwargs self.task_kwargs = task_kwargs + self._is_task_long_arn_format_enabled = is_task_long_arn_format_enabled self.status = Status.created def __await__(self): @@ -160,36 +162,15 @@ async def _(): def _use_public_ip(self): return self.fargate and not self._fargate_use_private_ip - async def _is_long_arn_format_enabled(self): - async with self._client("ecs") as ecs: - [response] = ( - await ecs.list_account_settings( - name="taskLongArnFormat", effectiveSettings=True - ) - )["settings"] - return response["value"] == "enabled" - async def _update_task(self): async with self._client("ecs") as ecs: - wait_duration = 1 - while True: - try: - [self.task] = ( - await ecs.describe_tasks( - cluster=self.cluster_arn, tasks=[self.task_arn] - ) - )["tasks"] - except ClientError as e: - if e.response["Error"]["Code"] == "ThrottlingException": - wait_duration = min(wait_duration * 2, 20) - else: - raise - else: - break - await asyncio.sleep(wait_duration) + [self.task] = ( + await ecs.describe_tasks( + cluster=self.cluster_arn, tasks=[self.task_arn] + ) + )["tasks"] - async def _task_is_running(self): - await self._update_task() + def 
_task_is_running(self): return self.task["lastStatus"] == "RUNNING" async def start(self): @@ -199,7 +180,7 @@ async def start(self): kwargs = self.task_kwargs.copy() if self.task_kwargs is not None else {} # Tags are only supported if you opt into long arn format so we need to check for that - if await self._is_long_arn_format_enabled(): + if self._is_task_long_arn_format_enabled: kwargs["tags"] = dict_to_aws(self.tags) if self.platform_version and self.fargate: kwargs["platformVersion"] = self.platform_version @@ -253,13 +234,19 @@ async def start(self): [self.task] = response["tasks"] break except Exception as e: + # Retries due to throttle errors are handled by the aiobotocore client so this should be an uncommon case timeout.set_exception(e) - await asyncio.sleep(1) + logger.debug(f"Failed to start {self.task_type} task after {timeout.elapsed_time:.1f}s, retrying in 2s: {e}") + await asyncio.sleep(2) self.task_arn = self.task["taskArn"] + + # Wait for the task to come up while self.task["lastStatus"] in ["PENDING", "PROVISIONING"]: + # Try to avoid hitting throttling rate limits when bringing up a large cluster + await asyncio.sleep(1) await self._update_task() - if not await self._task_is_running(): + if not self._task_is_running(): raise RuntimeError("%s failed to start" % type(self).__name__) [eni] = [ attachment @@ -286,7 +273,7 @@ async def close(self, **kwargs): async with self._client("ecs") as ecs: await ecs.stop_task(cluster=self.cluster_arn, task=self.task_arn) await self._update_task() - while self.task["lastStatus"] in ["RUNNING"]: + while self._task_is_running(): await asyncio.sleep(1) await self._update_task() self.status = Status.closed @@ -304,48 +291,35 @@ def _log_stream_name(self): ) async def logs(self, follow=False): - current_try = 0 next_token = None read_from = 0 while True: - try: - async with self._client("logs") as logs: - if next_token: - l = await logs.get_log_events( - logGroupName=self.log_group, - 
logStreamName=self._log_stream_name, - nextToken=next_token, - ) - else: - l = await logs.get_log_events( - logGroupName=self.log_group, - logStreamName=self._log_stream_name, - startTime=read_from, - ) - if next_token != l["nextForwardToken"]: - next_token = l["nextForwardToken"] + async with self._client("logs") as logs: + if next_token: + l = await logs.get_log_events( + logGroupName=self.log_group, + logStreamName=self._log_stream_name, + nextToken=next_token, + ) else: - next_token = None - if not l["events"]: - if follow: - await asyncio.sleep(1) - else: - break - for event in l["events"]: - read_from = event["timestamp"] - yield event["message"] - except ClientError as e: - if e.response["Error"]["Code"] == "ThrottlingException": - warnings.warn( - "get_log_events rate limit exceeded, retrying after delay.", - RuntimeWarning, + l = await logs.get_log_events( + logGroupName=self.log_group, + logStreamName=self._log_stream_name, + startTime=read_from, ) - backoff_duration = get_sleep_duration(current_try) - await asyncio.sleep(backoff_duration) - current_try += 1 + if next_token != l["nextForwardToken"]: + next_token = l["nextForwardToken"] + else: + next_token = None + if not l["events"]: + if follow: + await asyncio.sleep(1) else: - raise + break + for event in l["events"]: + read_from = event["timestamp"] + yield event["message"] def __repr__(self): return "<%s: status=%s>" % (type(self).__name__, self.status) @@ -813,6 +787,7 @@ def __init__( self._platform_version = platform_version self._lock = asyncio.Lock() self.session = get_session() + self._is_task_long_arn_format_enabled = None super().__init__(**kwargs) def _client(self, name: str): @@ -821,6 +796,16 @@ def _client(self, name: str): aws_access_key_id=self._aws_access_key_id, aws_secret_access_key=self._aws_secret_access_key, region_name=self._region_name, + config=AioConfig( + retries={ + # Use Standard retry mode which provides: + # - Jittered exponential backoff with max of 20s in the event of failures + # - 
Never delays the first request attempt, only the retries + # - Supports circuit-breaking to prevent the SDK from retrying during outages + "mode": "standard", + "max_attempts": 10, # Not including the initial request + } + ), ) async def _start( @@ -950,6 +935,10 @@ async def _start( self.worker_task_definition_arn = ( await self._create_worker_task_definition_arn() ) + if self._is_task_long_arn_format_enabled is None: + self._is_task_long_arn_format_enabled = ( + await self._get_is_task_long_arn_format_enabled() + ) options = { "client": self._client, @@ -962,6 +951,7 @@ async def _start( "tags": self.tags, "platform_version": self._platform_version, "fargate_use_private_ip": self._fargate_use_private_ip, + "is_task_long_arn_format_enabled": self._is_task_long_arn_format_enabled, } scheduler_options = { "task_definition_arn": self.scheduler_task_definition_arn, @@ -1319,6 +1309,15 @@ async def _delete_worker_task_definition_arn(self): taskDefinition=self.worker_task_definition_arn ) + async def _get_is_task_long_arn_format_enabled(self): + async with self._client("ecs") as ecs: + [response] = ( + await ecs.list_account_settings( + name="taskLongArnFormat", effectiveSettings=True + ) + )["settings"] + return response["value"] == "enabled" + def logs(self): async def get_logs(task): log = "" diff --git a/dask_cloudprovider/utils/timeout.py b/dask_cloudprovider/utils/timeout.py index 07c8ffc5..197a8077 100644 --- a/dask_cloudprovider/utils/timeout.py +++ b/dask_cloudprovider/utils/timeout.py @@ -66,7 +66,7 @@ def run(self): self.start = datetime.now() self.running = True - if self.start + timedelta(seconds=self.timeout) < datetime.now(): + if self.elapsed_time >= self.timeout: if self.warn: warnings.warn(self.error_message) return False @@ -82,3 +82,10 @@ def set_exception(self, e): the thing you are trying rather than a TimeoutException. 
""" self.exception = e + + @property + def elapsed_time(self): + """Return the elapsed time since the timeout started.""" + if self.start is None: + return 0 + return (datetime.now() - self.start).total_seconds()