diff --git a/catalog/dags/data_refresh/distributed_reindex.py b/catalog/dags/data_refresh/distributed_reindex.py index 6666dc71bed..f23fe1b632f 100644 --- a/catalog/dags/data_refresh/distributed_reindex.py +++ b/catalog/dags/data_refresh/distributed_reindex.py @@ -26,6 +26,7 @@ OPENLEDGER_API_CONN_ID, PRODUCTION, REFRESH_POKE_INTERVAL, + Environment, ) from common.sql import PGExecuteQueryOperator, single_value from data_refresh.constants import INDEXER_LAUNCH_TEMPLATES, INDEXER_WORKER_COUNTS @@ -41,7 +42,7 @@ class TempConnectionHTTPOperator(HttpOperator): """ Wrapper around the HTTPOperator which allows templating of the conn_id, - in order to support using a temporary conn_id passed through XCOM. + in order to support using a temporary conn_id passed through XCOMs. """ template_fields: Sequence[str] = ( @@ -56,7 +57,7 @@ class TempConnectionHTTPOperator(HttpOperator): class TempConnectionHTTPSensor(HttpSensor): """ Wrapper around the HTTPSensor which allows templating of the conn_id, - in order to support using a temporary conn_id passed through XCOM. + in order to support using a temporary conn_id passed through XCOMs. """ template_fields: Sequence[str] = ( @@ -102,8 +103,8 @@ def response_check_wait_for_completion(response: Response) -> bool: def get_worker_params( estimated_record_count: int, environment: str, - target_environment: str, - aws_conn_id: str = AWS_ASG_CONN_ID, # TODO + target_environment: Environment, + aws_conn_id: str = AWS_ASG_CONN_ID, ): """Determine the set of start/end indices to be passed to each indexer worker.""" # Defaults to one indexer worker in local development @@ -123,10 +124,37 @@ def get_worker_params( ] +@task +def get_launch_template_version( + environment: str, + target_environment: Environment, + aws_conn_id: str = AWS_ASG_CONN_ID, +): + """ + Get the latest version of the launch template. Indexer workers will all be created with this + version. Importantly, this allows us to retry an individual indexer worker and ensure that + it will run with the same version of the indexer worker as the others, even if code has been + deployed to the indexer worker in the meantime. + """ + if environment != PRODUCTION: + raise AirflowSkipException("Skipping instance creation in local environment.") + + ec2_hook = EC2Hook(aws_conn_id=aws_conn_id, api_type="client_type") + launch_templates = ec2_hook.conn.describe_launch_templates( + LaunchTemplateNames=INDEXER_LAUNCH_TEMPLATES.get(target_environment) + ) + + if len(launch_templates.get("LaunchTemplates")) == 0: + raise Exception("Unable to determine launch template version.") + + return launch_templates.get("LaunchTemplates")[0].get("LatestVersionNumber") + + @task def create_worker( environment: str, - target_environment: str, + target_environment: Environment, + launch_template_version: int, aws_conn_id: str = AWS_ASG_CONN_ID, ): """ @@ -142,11 +170,7 @@ def create_worker( MaxCount=1, LaunchTemplate={ "LaunchTemplateName": INDEXER_LAUNCH_TEMPLATES.get(target_environment), - "Version": "$Latest", - # TODO we could add a task before all of this to get the version number of - # the launch template and then use it in all these tasks, that ensures - # that all indexer workers are running the same code even if a deploy - # happens in the middle of a data refresh + "Version": str(launch_template_version), }, # Name the instance by applying a tag TagSpecifications=[ @@ -279,7 +303,7 @@ def reindex( start_id: int, end_id: int, environment: str, - target_environment: str, + target_environment: Environment, aws_conn_id: str = AWS_ASG_CONN_ID, ): """ @@ -287,10 +311,16 @@ def reindex( terminate the indexer worker instance. """ + launch_template_version = get_launch_template_version( + environment=environment, + target_environment=target_environment, + ) + # Create a new EC2 instance instance_id = create_worker( environment=environment, target_environment=target_environment, + launch_template_version=launch_template_version, aws_conn_id=aws_conn_id, ) @@ -356,7 +386,7 @@ def reindex( ) def perform_distributed_reindex( environment: str, - target_environment: str, # TODO, update types + target_environment: Environment, target_index: str, data_refresh_config: DataRefreshConfig, aws_conn_id: str = AWS_ASG_CONN_ID,