Skip to content

Commit

Permalink
Standardize AWS Batch naming (#20369)
Browse files Browse the repository at this point in the history
  • Loading branch information
ferruzzi authored Jan 6, 2022
1 parent 088cbf2 commit 0ebd55e
Show file tree
Hide file tree
Showing 9 changed files with 175 additions and 109 deletions.
100 changes: 66 additions & 34 deletions airflow/providers/amazon/aws/hooks/batch_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@
# specific language governing permissions and limitations
# under the License.
"""
A client for AWS batch services
A client for AWS Batch services
.. seealso::
- http://boto3.readthedocs.io/en/latest/guide/configuration.html
- http://boto3.readthedocs.io/en/latest/reference/services/batch.html
- https://docs.aws.amazon.com/batch/latest/APIReference/Welcome.html
"""

import warnings
from random import uniform
from time import sleep
from typing import Dict, List, Optional, Union
Expand All @@ -39,10 +39,10 @@


@runtime_checkable
class AwsBatchProtocol(Protocol):
class BatchProtocol(Protocol):
"""
A structured Protocol for ``boto3.client('batch') -> botocore.client.Batch``.
This is used for type hints on :py:meth:`.AwsBatchClient.client`; it covers
This is used for type hints on :py:meth:`.BatchClient.client`; it covers
only the subset of client methods required.
.. seealso::
Expand All @@ -53,7 +53,7 @@ class AwsBatchProtocol(Protocol):

def describe_jobs(self, jobs: List[str]) -> Dict:
"""
Get job descriptions from AWS batch
Get job descriptions from AWS Batch
:param jobs: a list of JobId to describe
:type jobs: List[str]
Expand All @@ -72,11 +72,11 @@ def get_waiter(self, waiterName: str) -> botocore.waiter.Waiter:
model file (typically this is CamelCasing).
:type waiterName: str
:return: a waiter object for the named AWS batch service
:return: a waiter object for the named AWS Batch service
:rtype: botocore.waiter.Waiter
.. note::
AWS batch might not have any waiters (until botocore PR-1307 is released).
AWS Batch might not have any waiters (until botocore PR-1307 is released).
.. code-block:: python
Expand All @@ -102,9 +102,9 @@ def submit_job(
tags: Dict,
) -> Dict:
"""
Submit a batch job
Submit a Batch job
:param jobName: the name for the AWS batch job
:param jobName: the name for the AWS Batch job
:type jobName: str
:param jobQueue: the queue name on AWS Batch
Expand Down Expand Up @@ -132,7 +132,7 @@ def submit_job(

def terminate_job(self, jobId: str, reason: str) -> Dict:
"""
Terminate a batch job
Terminate a Batch job
:param jobId: a job ID to terminate
:type jobId: str
Expand All @@ -150,9 +150,9 @@ def terminate_job(self, jobId: str, reason: str) -> Dict:
# all the Airflow wrappers of boto3 clients should not adopt invalid-names to match boto3.


class AwsBatchClientHook(AwsBaseHook):
class BatchClientHook(AwsBaseHook):
"""
A client for AWS batch services.
A client for AWS Batch services.
:param max_retries: exponential back-off retries, 4200 = 48 hours;
polling is only used when waiters is None
Expand All @@ -169,17 +169,17 @@ class AwsBatchClientHook(AwsBaseHook):
when many concurrent tasks request job-descriptions.
To modify the global defaults for the range of jitter allowed when a
random delay is used to check batch job status, modify these defaults, e.g.:
random delay is used to check Batch job status, modify these defaults, e.g.:
.. code-block::
AwsBatchClient.DEFAULT_DELAY_MIN = 0
AwsBatchClient.DEFAULT_DELAY_MAX = 5
BatchClient.DEFAULT_DELAY_MIN = 0
BatchClient.DEFAULT_DELAY_MAX = 5
When explicit delay values are used, a 1 second random jitter is applied to the
delay (e.g. a delay of 0 sec will be a ``random.uniform(0, 1)`` delay. It is
generally recommended that random jitter is added to API requests. A
convenience method is provided for this, e.g. to get a random delay of
10 sec +/- 5 sec: ``delay = AwsBatchClient.add_jitter(10, width=5, minima=0)``
10 sec +/- 5 sec: ``delay = BatchClient.add_jitter(10, width=5, minima=0)``
.. seealso::
- https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch.html
Expand Down Expand Up @@ -214,18 +214,18 @@ def __init__(
self.status_retries = status_retries or self.STATUS_RETRIES

@property
def client(self) -> Union[AwsBatchProtocol, botocore.client.BaseClient]:
def client(self) -> Union[BatchProtocol, botocore.client.BaseClient]:
"""
An AWS API client for batch services.
An AWS API client for Batch services.
:return: a boto3 'batch' client for the ``.region_name``
:rtype: Union[AwsBatchProtocol, botocore.client.BaseClient]
:rtype: Union[BatchProtocol, botocore.client.BaseClient]
"""
return self.conn

def terminate_job(self, job_id: str, reason: str) -> Dict:
"""
Terminate a batch job
Terminate a Batch job
:param job_id: a job ID to terminate
:type job_id: str
Expand All @@ -242,10 +242,10 @@ def terminate_job(self, job_id: str, reason: str) -> Dict:

def check_job_success(self, job_id: str) -> bool:
"""
Check the final status of the batch job; return True if the job
Check the final status of the Batch job; return True if the job
'SUCCEEDED', else raise an AirflowException
:param job_id: a batch job ID
:param job_id: a Batch job ID
:type job_id: str
:rtype: bool
Expand All @@ -256,7 +256,7 @@ def check_job_success(self, job_id: str) -> bool:
job_status = job.get("status")

if job_status == self.SUCCESS_STATE:
self.log.info("AWS batch job (%s) succeeded: %s", job_id, job)
self.log.info("AWS Batch job (%s) succeeded: %s", job_id, job)
return True

if job_status == self.FAILURE_STATE:
Expand All @@ -269,9 +269,9 @@ def check_job_success(self, job_id: str) -> bool:

def wait_for_job(self, job_id: str, delay: Union[int, float, None] = None) -> None:
"""
Wait for batch job to complete
Wait for Batch job to complete
:param job_id: a batch job ID
:param job_id: a Batch job ID
:type job_id: str
:param delay: a delay before polling for job status
Expand All @@ -296,7 +296,7 @@ def poll_for_job_running(self, job_id: str, delay: Union[int, float, None] = Non
changes too quickly for polling to detect a RUNNING status that moves
quickly from STARTING to RUNNING to completed (often a failure).
:param job_id: a batch job ID
:param job_id: a Batch job ID
:type job_id: str
:param delay: a delay before polling for job status
Expand All @@ -316,7 +316,7 @@ def poll_for_job_complete(self, job_id: str, delay: Union[int, float, None] = No
So the status options that this will wait for are the transitions from:
'SUBMITTED'>'PENDING'>'RUNNABLE'>'STARTING'>'RUNNING'>'SUCCEEDED'|'FAILED'
:param job_id: a batch job ID
:param job_id: a Batch job ID
:type job_id: str
:param delay: a delay before polling for job status
Expand All @@ -332,10 +332,10 @@ def poll_job_status(self, job_id: str, match_status: List[str]) -> bool:
"""
Poll for job status using an exponential back-off strategy (with max_retries).
:param job_id: a batch job ID
:param job_id: a Batch job ID
:type job_id: str
:param match_status: a list of job status to match; the batch job status are:
:param match_status: a list of job status to match; the Batch job status are:
'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED'
:type match_status: List[str]
Expand Down Expand Up @@ -376,7 +376,7 @@ def get_job_description(self, job_id: str) -> Dict:
"""
Get job description (using status_retries).
:param job_id: a batch job ID
:param job_id: a Batch job ID
:type job_id: str
:return: an API response for describe jobs
Expand Down Expand Up @@ -419,7 +419,7 @@ def parse_job_description(job_id: str, response: Dict) -> Dict:
"""
Parse job description to extract description for job_id
:param job_id: a batch job ID
:param job_id: a Batch job ID
:type job_id: str
:param response: an API response for describe jobs
Expand All @@ -445,7 +445,7 @@ def add_jitter(
Use delay +/- width for random jitter
Adding jitter to status polling can help to avoid
AWS batch API limits for monitoring batch jobs with
AWS Batch API limits for monitoring Batch jobs with
a high concurrency in Airflow tasks.
:param delay: number of seconds to pause;
Expand Down Expand Up @@ -487,9 +487,9 @@ def delay(delay: Union[int, float, None] = None) -> None:
when many concurrent tasks request job-descriptions.
"""
if delay is None:
delay = uniform(AwsBatchClientHook.DEFAULT_DELAY_MIN, AwsBatchClientHook.DEFAULT_DELAY_MAX)
delay = uniform(BatchClientHook.DEFAULT_DELAY_MIN, BatchClientHook.DEFAULT_DELAY_MAX)
else:
delay = AwsBatchClientHook.add_jitter(delay)
delay = BatchClientHook.add_jitter(delay)
sleep(delay)

@staticmethod
Expand Down Expand Up @@ -538,3 +538,35 @@ def exp(tries):
delay = 1 + pow(tries * 0.6, 2)
delay = min(max_interval, delay)
return uniform(delay / 3, delay)


class AwsBatchProtocol(BatchProtocol, Protocol):
"""
This class is deprecated.
Please use :class:`airflow.providers.amazon.aws.hooks.batch.BatchProtocol`.
"""

def __init__(self, *args, **kwargs):
warnings.warn(
"This class is deprecated. "
"Please use :class:`airflow.providers.amazon.aws.hooks.batch.BatchProtocol`.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)


class AwsBatchClientHook(BatchClientHook):
"""
This hook is deprecated.
Please use :class:`airflow.providers.amazon.aws.hooks.batch.BatchClientHook`.
"""

def __init__(self, *args, **kwargs):
warnings.warn(
"This hook is deprecated. "
"Please use :class:`airflow.providers.amazon.aws.hooks.batch.BatchClientHook`.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)
Loading

0 comments on commit 0ebd55e

Please sign in to comment.