diff --git a/catalog/dags/common/constants.py b/catalog/dags/common/constants.py index f2b1f45a407..b94d3aac555 100644 --- a/catalog/dags/common/constants.py +++ b/catalog/dags/common/constants.py @@ -47,6 +47,7 @@ AWS_CONN_ID = "aws_default" AWS_CLOUDWATCH_CONN_ID = os.environ.get("AWS_CLOUDWATCH_CONN_ID", AWS_CONN_ID) AWS_RDS_CONN_ID = os.environ.get("AWS_RDS_CONN_ID", AWS_CONN_ID) +AWS_ASG_CONN_ID = os.environ.get("AWS_ASG_CONN_ID", AWS_CONN_ID) ES_PROD_HTTP_CONN_ID = "elasticsearch_http_production" REFRESH_POKE_INTERVAL = int(os.getenv("DATA_REFRESH_POKE_INTERVAL", 60 * 30)) diff --git a/catalog/dags/common/elasticsearch.py b/catalog/dags/common/elasticsearch.py index 048133594f2..9be9b1ffe2e 100644 --- a/catalog/dags/common/elasticsearch.py +++ b/catalog/dags/common/elasticsearch.py @@ -78,6 +78,24 @@ def remove_excluded_index_settings(index_config): return index_config +@task +def get_index_configuration_copy( + source_index: str, target_index_name: str, es_host: str +): + """ + Create a new index configuration based off the `source_index` but with the given + `target_index_name`, in the format needed for `create_index`. Removes fields that + should not be copied into a new index configuration such as the uuid. + """ + base_config = get_index_configuration.function(source_index, es_host) + + cleaned_config = remove_excluded_index_settings.function(base_config) + + cleaned_config["index"] = target_index_name + + return cleaned_config + + @task def get_record_count_group_by_sources(es_host: str, index: str): """ diff --git a/catalog/dags/data_refresh/copy_data.py b/catalog/dags/data_refresh/copy_data.py index dea6a2cc743..1d1aae90e16 100644 --- a/catalog/dags/data_refresh/copy_data.py +++ b/catalog/dags/data_refresh/copy_data.py @@ -198,7 +198,7 @@ def copy_data( def copy_upstream_table( upstream_conn_id: str, downstream_conn_id: str, - environment: Environment, + target_environment: Environment, timeout: timedelta, limit: int, upstream_table_name: str, @@ -275,12 +275,13 @@ def copy_upstream_table( create_temp_table >> setup_id_columns >> setup_tertiary_columns setup_tertiary_columns >> copy copy >> add_primary_key + return @task_group(group_id="copy_upstream_tables") def copy_upstream_tables( - environment: Environment, data_refresh_config: DataRefreshConfig + target_environment: Environment, data_refresh_config: DataRefreshConfig ): """ For each upstream table associated with the given media type, create a new @@ -290,7 +291,7 @@ def copy_upstream_tables( This task does _not_ apply all indices and constraints, merely copies the data. 
""" - downstream_conn_id = POSTGRES_API_CONN_IDS.get(environment) + downstream_conn_id = POSTGRES_API_CONN_IDS.get(target_environment) upstream_conn_id = POSTGRES_CONN_ID create_fdw = _run_sql.override(task_id="create_fdw")( @@ -309,7 +310,7 @@ def copy_upstream_tables( copy_tables = copy_upstream_table.partial( upstream_conn_id=upstream_conn_id, downstream_conn_id=downstream_conn_id, - environment=environment, + target_environment=target_environment, timeout=data_refresh_config.copy_data_timeout, limit=limit, ).expand_kwargs([asdict(tm) for tm in data_refresh_config.table_mappings]) diff --git a/catalog/dags/data_refresh/dag_factory.py b/catalog/dags/data_refresh/dag_factory.py index 474c57545ca..321e07a9a78 100644 --- a/catalog/dags/data_refresh/dag_factory.py +++ b/catalog/dags/data_refresh/dag_factory.py @@ -28,11 +28,12 @@ import logging import os +import uuid from collections.abc import Sequence from itertools import product from airflow import DAG -from airflow.decorators import task_group +from airflow.decorators import task, task_group from airflow.operators.python import PythonOperator from airflow.utils.trigger_rule import TriggerRule @@ -44,6 +45,7 @@ from common.sensors.utils import wait_for_external_dags_with_tag from data_refresh.copy_data import copy_upstream_tables from data_refresh.data_refresh_types import DATA_REFRESH_CONFIGS, DataRefreshConfig +from data_refresh.distributed_reindex import perform_distributed_reindex from data_refresh.reporting import report_record_difference @@ -83,9 +85,14 @@ def wait_for_conflicting_dags( ) +@task +def generate_index_name(media_type: str) -> str: + return f"{media_type}-{uuid.uuid4().hex}" + + def create_data_refresh_dag( data_refresh_config: DataRefreshConfig, - environment: Environment, + target_environment: Environment, external_dag_ids: Sequence[str], ): """ @@ -95,22 +102,22 @@ def create_data_refresh_dag( Required Arguments: - data_refresh: dataclass containing configuration information for the - DAG - environment: the environment in which the data refresh is performed - external_dag_ids: list of ids of the other data refresh DAGs. The data refresh step - of this DAG will not run concurrently with the corresponding step - of any dependent DAG. + data_refresh: dataclass containing configuration information for the + DAG + target_environment: the environment in which the data refresh is performed + external_dag_ids: list of ids of the other data refresh DAGs. The data refresh step + of this DAG will not run concurrently with the corresponding step + of any dependent DAG. 
""" default_args = { **DAG_DEFAULT_ARGS, **data_refresh_config.default_args, } - concurrency_tag = ES_CONCURRENCY_TAGS.get(environment) + concurrency_tag = ES_CONCURRENCY_TAGS.get(target_environment) dag = DAG( - dag_id=f"{environment}_{data_refresh_config.dag_id}", + dag_id=f"{target_environment}_{data_refresh_config.dag_id}", dagrun_timeout=data_refresh_config.dag_timeout, default_args=default_args, start_date=data_refresh_config.start_date, @@ -120,14 +127,14 @@ def create_data_refresh_dag( doc_md=__doc__, tags=[ "data_refresh", - f"{environment}_data_refresh", + f"{target_environment}_data_refresh", concurrency_tag, ], ) with dag: # Connect to the appropriate Elasticsearch cluster - es_host = es.get_es_host(environment=environment) + es_host = es.get_es_host(environment=target_environment) # Get the current number of records in the target API table before_record_count = es.get_record_count_group_by_sources.override( @@ -142,11 +149,29 @@ def create_data_refresh_dag( ) copy_data = copy_upstream_tables( - environment=environment, data_refresh_config=data_refresh_config + target_environment=target_environment, + data_refresh_config=data_refresh_config, ) # TODO Cleaning steps + # TODO: Move temp_index_name >> index_config >> target_index into a `create_index` task group + + # Generate a UUID suffix that will be used by the newly created index. + temp_index_name = generate_index_name(media_type=data_refresh_config.media_type) + + # Get the configuration for the new Elasticsearch index, based off the existing index. + index_config = es.get_index_configuration_copy.override( + task_id="get_index_configuration" + )( + source_index=data_refresh_config.media_type, + target_index_name=temp_index_name, + es_host=es_host, + ) + + # Create a new index matching the existing configuration + target_index = es.create_index(index_config=index_config, es_host=es_host) + # Disable Cloudwatch alarms that are noisy during the reindexing steps of a # data refresh. disable_alarms = PythonOperator( @@ -157,8 +182,13 @@ def create_data_refresh_dag( }, ) - # TODO create_and_populate_index # (TaskGroup that creates index, triggers and waits for reindexing) + reindex = perform_distributed_reindex( + environment="{{ var.value.ENVIRONMENT }}", + target_environment=target_environment, + target_index=temp_index_name, + data_refresh_config=data_refresh_config, + ) # TODO create_and_populate_filtered_index @@ -198,9 +228,9 @@ def create_data_refresh_dag( ) # Set up task dependencies - before_record_count >> wait_for_dags >> copy_data >> disable_alarms - # TODO: this will include reindex/etc once added - disable_alarms >> [enable_alarms, after_record_count] + before_record_count >> wait_for_dags >> copy_data >> temp_index_name + temp_index_name >> index_config >> target_index >> disable_alarms + disable_alarms >> reindex >> [enable_alarms, after_record_count] after_record_count >> report_counts return dag @@ -209,7 +239,7 @@ def create_data_refresh_dag( # Generate data refresh DAGs for each DATA_REFRESH_CONFIG, per environment. 
all_data_refresh_dag_ids = {refresh.dag_id for refresh in DATA_REFRESH_CONFIGS.values()} -for data_refresh_config, environment in product( +for data_refresh_config, target_environment in product( DATA_REFRESH_CONFIGS.values(), ENVIRONMENTS ): # Construct a set of all data refresh DAG ids other than the current DAG @@ -217,6 +247,6 @@ def create_data_refresh_dag( globals()[data_refresh_config.dag_id] = create_data_refresh_dag( data_refresh_config, - environment, - [f"{environment}_{dag_id}" for dag_id in other_dag_ids], + target_environment, + [f"{target_environment}_{dag_id}" for dag_id in other_dag_ids], ) diff --git a/catalog/dags/data_refresh/data_refresh_types.py b/catalog/dags/data_refresh/data_refresh_types.py index 94523c18e39..d5c3b25aefe 100644 --- a/catalog/dags/data_refresh/data_refresh_types.py +++ b/catalog/dags/data_refresh/data_refresh_types.py @@ -48,9 +48,10 @@ class DataRefreshConfig: Required Constructor Arguments: - media_type: str describing the media type to be refreshed. - table_mappings: list of TableMapping information for all tables that should be - refreshed as part of a data refresh for this media type. + media_type: str describing the media type to be refreshed. + table_mappings: list of TableMapping information for all tables that should + be refreshed as part of a data refresh for this media type. + Optional Constructor Arguments: @@ -65,6 +66,9 @@ class DataRefreshConfig: data refresh may take. copy_data_timeout: timedelta expressing the amount of time it may take to copy the upstream table into the downstream DB + indexer_worker_timeout: timedelta expressing the amount of time it may take for + any individual indexer worker to perform its portion of + the distributed reindex index_readiness_timeout: timedelta expressing amount of time it may take to await a healthy ES index after reindexing data_refresh_poke_interval: int number of seconds to wait between @@ -81,6 +85,7 @@ class DataRefreshConfig: default_args: dict = field(default_factory=dict) dag_timeout: timedelta = timedelta(days=1) copy_data_timeout: timedelta = timedelta(hours=1) + indexer_worker_timeout: timedelta = timedelta(hours=6) # TODO index_readiness_timeout: timedelta = timedelta(days=1) data_refresh_poke_interval: int = REFRESH_POKE_INTERVAL diff --git a/catalog/dags/data_refresh/distributed_reindex.py b/catalog/dags/data_refresh/distributed_reindex.py new file mode 100644 index 00000000000..adc243fc745 --- /dev/null +++ b/catalog/dags/data_refresh/distributed_reindex.py @@ -0,0 +1,499 @@ +""" +# Distributed Reindex TaskGroup + +TODO Docstring + +""" + +import logging +import math +from dataclasses import dataclass +from datetime import timedelta +from functools import cached_property +from textwrap import dedent +from urllib.parse import urlparse + +from airflow import settings +from airflow.decorators import task, task_group +from airflow.exceptions import AirflowSkipException +from airflow.models.connection import Connection +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook +from airflow.providers.http.operators.http import HttpOperator +from airflow.providers.http.sensors.http import HttpSensor +from airflow.sensors.base import BaseSensorOperator, PokeReturnValue +from airflow.utils.trigger_rule import TriggerRule +from requests import Response + +from common.constants import ( + AWS_ASG_CONN_ID, + OPENLEDGER_API_CONN_ID, + PRODUCTION, + REFRESH_POKE_INTERVAL, +) +from common.sql import 
PGExecuteQueryOperator, single_value
+from data_refresh.data_refresh_types import DataRefreshConfig
+
+
+logger = logging.getLogger(__name__)
+
+
+WORKER_CONN_ID = "indexer_worker_{worker_id}_http_{environment}"
+
+
+@dataclass
+class AutoScalingGroupConfig:
+    name: str
+    instance_count: int
+
+
+class AutoScalingGroupSensor(BaseSensorOperator):
+    """
+    Sensor that waits for an AutoScalingGroup to have the desired number of healthy
+    instances in service, and returns the list of instance_ids once the condition
+    has been met.
+    """
+
+    def __init__(
+        self,
+        *,
+        environment: str,
+        asg_config: AutoScalingGroupConfig,
+        aws_conn_id: str = AWS_ASG_CONN_ID,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.environment = environment
+        self.asg_config = asg_config
+        self.aws_conn_id = aws_conn_id
+
+    @cached_property
+    def conn(self):
+        if self.environment != PRODUCTION:
+            raise AirflowSkipException(
+                "Skipping interactions with ASG in local development environment."
+            )
+        return AwsBaseHook(
+            aws_conn_id=self.aws_conn_id, client_type="autoscaling"
+        ).get_conn()
+
+    def poke(self, context):
+        asg = self.conn.describe_auto_scaling_groups(
+            AutoScalingGroupNames=[self.asg_config.name]
+        ).get("AutoScalingGroups")[0]
+
+        instances = asg.get("Instances")
+
+        # Return True when the ASG has the desired number of instances, and all
+        # instances are healthy and in service.
+        if len(instances) == self.asg_config.instance_count and all(
+            instance.get("HealthStatus") == "HEALTHY"
+            and instance.get("LifecycleState") == "InService"
+            for instance in instances
+        ):
+            # Pass the list of instance_ids along through XComs once the conditions
+            # have been met.
+            return PokeReturnValue(
+                is_done=True,
+                xcom_value=[instance.get("InstanceId") for instance in instances],
+            )
+
+        return PokeReturnValue(is_done=False, xcom_value=None)
+
+
+@task
+def get_autoscaling_group(
+    environment: str,
+    target_environment: str,
+    aws_conn_id: str = AWS_ASG_CONN_ID,
+):
+    """Select the appropriate autoscaling group for the given environment."""
+    # TODO pull this repeated code out.
+    if environment != PRODUCTION:
+        raise AirflowSkipException(
+            "Skipping interactions with ASG in local development environment."
+        )
+
+    asg_conn = AwsBaseHook(
+        aws_conn_id=aws_conn_id, client_type="autoscaling"
+    ).get_conn()
+
+    asgs = asg_conn.describe_auto_scaling_groups(
+        Filters=[{"Name": "tag:WorkerTargetEnvironment", "Values": [environment]}]
+    ).get("AutoScalingGroups")
+
+    if len(asgs) != 1:
+        raise Exception(
+            f"Could not uniquely identify ASG for {environment} indexer workers."
+        )
+
+    target_asg = asgs[0]
+    return AutoScalingGroupConfig(
+        name=target_asg.get("AutoScalingGroupName"),
+        instance_count=target_asg.get("MaxSize"),
+    )
+
+
+@task
+def set_asg_capacity(
+    environment: str,
+    asg_config: AutoScalingGroupConfig,
+    desired_capacity: int | None = None,
+    aws_conn_id: str = AWS_ASG_CONN_ID,
+):
+    """Set the autoscaling group's desired capacity to the given number of instances."""
+    if environment != PRODUCTION:
+        raise AirflowSkipException(
+            "Skipping interactions with ASG in local development environment."
+        )
+
+    asg_conn = AwsBaseHook(
+        aws_conn_id=aws_conn_id, client_type="autoscaling"
+    ).get_conn()
+
+    if desired_capacity is None:
+        desired_capacity = asg_config.instance_count
+
+    return asg_conn.set_desired_capacity(
+        AutoScalingGroupName=asg_config.name, DesiredCapacity=desired_capacity
+    )
+
+
+@task
+def get_worker_params(
+    estimated_record_count: int,
+    instance_ids: list[str],
+    environment: str,
+    target_environment: str,
+    aws_conn_id: str = AWS_ASG_CONN_ID,  # TODO
+):
+    """
+    Build the kwargs (record ID range and server address) for each indexer worker.
+
+    TODO Because we are using an ASG now, we have lost the ability to retry an individual worker --
+    because once an individual worker has failed, it will always be terminated so the trigger task
+    cannot be retried. We can retry starting at the `start_workers` task but this will cause ALL
+    the workers to be spun up and assigned work.
+
+    We could allow the ASG to spin up a replacement when the reindexing task errors, but:
+    (1) Unless someone notices and manually retries immediately, this instance will be kept live
+        (and costing money) in the meantime
+    (2) The trigger task will still attempt to hit the original instance; there's no way to tell
+        it about the updated one.
+
+    We could try starting up the instances one at a time, so instead of setting the desired capacity
+    to n all at once and then dynamically running (trigger -> wait -> terminate) for each instance, we
+    would have n parallel task groups that (create_instance -> trigger -> wait -> terminate). But
+    this is not possible because of two limitations:
+    * set_desired_capacity has no concept of "increment": you have to give it the total capacity
+      you want for the ASG. So you cannot have n parallel tasks all incrementing the ASG capacity.
+    * even if you could, set_desired_capacity does not return the instance_ids of the instances
+      that get created as a result of the action. So there is no way for the `trigger` task to
+      know which instance "belongs" to its branch.
+
+    The only thing I can think of is to make this entire `distributed_reindex` taskgroup take a few
+    additional arguments:
+    * desired_number_of_workers
+    * index_ranges
+
+    Maybe those are optional arguments, and when not passed the DAG assumes you want the max number of
+    workers for this environment and you want all records to be reindexed. But the arguments can be
+    optionally used to use a smaller number of workers, and reindex only portions of the total count.
+    This taskgroup would also be used to create a separate, distributed_reindex DAG.
+
+    Then if a single indexer worker fails during a data refresh, we could manually run this other
+    DAG with desired_number_of_workers set to 1 and the faulty index range specified. When it passes,
+    we just manually mark that step as successful in the original data_refresh DagRun and continue.
+    """
+    if environment != PRODUCTION:
+        # Point to the local catalog_indexer_worker Docker container
+        return [
+            {
+                "start_id": 0,
+                "end_id": estimated_record_count,
+                "instance_id": None,
+                "server": "http://catalog_indexer_worker:8003/",
+            }
+        ]
+
+    # Get the private IP addresses of the worker instances
+    ec2_hook = EC2Hook(aws_conn_id=aws_conn_id, api_type="client_type")
+    reservations = ec2_hook.describe_instances(instance_ids=instance_ids).get(
+        "Reservations"
+    )
+
+    # `Reservations` is a list of dicts, grouping instances by the run command that
+    # started them. To be safe we get instances matching our expected instance IDs across
+    # all reservations.
+    servers = {
+        instance.get("InstanceId"): instance.get("PrivateIpAddress")
+        for reservation in reservations
+        for instance in reservation.get("Instances")
+    }
+
+    records_per_worker = math.floor(estimated_record_count / len(instance_ids))
+
+    params = []
+    for worker_index, (instance_id, server) in enumerate(servers.items()):
+        params.append(
+            {
+                "start_id": worker_index * records_per_worker,
+                # Assign any remainder to the last worker so no records are skipped.
+                "end_id": (
+                    estimated_record_count
+                    if worker_index == len(instance_ids) - 1
+                    else (1 + worker_index) * records_per_worker
+                ),
+                "instance_id": instance_id,
+                "server": f"http://{server}:8003/",
+            }
+        )
+
+    return params
+
+
+def response_filter_status_check_endpoint(response: Response) -> str:
+    """
+    Handle the response from the `trigger_reindex` task.
+
+    This is used to get the status endpoint from the response, which is used to poll for the
+    status of the reindexing task.
+    """
+    status_check_url = response.json()["status_check"]
+    return urlparse(status_check_url).path
+
+
+def response_check_wait_for_completion(response: Response) -> bool:
+    """
+    Handle the response for the `wait_for_reindex` sensor.
+
+    Processes the response to determine whether the task can complete.
+    """
+    data = response.json()
+
+    if data["active"]:
+        # The reindex is still running. Poll again later.
+        return False
+
+    if data["error"]:
+        raise ValueError("An error was encountered during reindexing.")
+
+    logger.info(f"Reindexing done with {data['progress']}% completed.")
+    return True
+
+
+@task
+def create_connection(instance_id: str, server: str, target_environment: str) -> str:
+    """Create and persist an HTTP Connection for the given indexer worker."""
+    worker_conn_id = WORKER_CONN_ID.format(
+        worker_id=instance_id, environment=target_environment
+    )
+    # Create the Connection and persist it so downstream tasks can look it up by conn_id.
+    worker_conn = Connection(conn_id=worker_conn_id, uri=server)
+    session = settings.Session()
+    session.add(worker_conn)
+    session.commit()
+    return worker_conn_id
+
+
+def trigger_reindex(
+    worker_conn: str,
+    model_name: str,
+    table_name: str,
+    start_id: int,
+    end_id: int,
+    target_index: str,
+    target_environment: str,
+) -> HttpOperator:
+    """Trigger the reindexing task on an indexer worker."""
+    data = {
+        "model_name": model_name,
+        "table_name": table_name,
+        "target_index": target_index,
+        "start_id": start_id,
+        "end_id": end_id,
+    }
+
+    # The worker Connection is created by the `create_connection` task and dropped by
+    # `drop_connection` once the instance has been terminated.
+    return HttpOperator(
+        task_id="trigger_reindexing_task",
+        http_conn_id=worker_conn,
+        endpoint="task",
+        data=data,
+        response_check=lambda response: response.status_code == 202,
+        response_filter=response_filter_status_check_endpoint,
+    )
+
+
+def wait_for_reindex(
+    worker_conn: str,
+    status_endpoint: str,
+    timeout: timedelta,
+    poke_interval: int = REFRESH_POKE_INTERVAL,  # TODO
+) -> HttpSensor:
+    """Wait for the reindexing task on an indexer worker to complete."""
+    return HttpSensor(
+        task_id="wait_for_reindexing_task",
+        http_conn_id=worker_conn,
+        endpoint=status_endpoint,
+        method="GET",
+        response_check=response_check_wait_for_completion,
+        mode="reschedule",
+        poke_interval=poke_interval,
+        timeout=timeout.total_seconds(),
+    )
+
+
+@task
+def terminate_indexer_worker(
+    environment: str,
+    instance_id: str,
+    aws_conn_id: str = AWS_ASG_CONN_ID,
+):
+    """Terminate an individual indexer worker."""
+    if environment != PRODUCTION:
+        raise AirflowSkipException(
+            "Skipping interactions with ASG in local development environment."
+        )
+
+    asg_conn = AwsBaseHook(
+        aws_conn_id=aws_conn_id, client_type="autoscaling"
+    ).get_conn()
+
+    asg_conn.terminate_instance_in_auto_scaling_group(
+        InstanceId=instance_id,
+        # Tell the ASG not to spin up a new instance to replace the terminated one.
+        ShouldDecrementDesiredCapacity=True,
+    )
+
+
+@task
+def drop_connection(worker_conn: str):
+    """Drop the Connection to the now terminated instance."""
+    session = settings.Session()
+    connection = (
+        session.query(Connection).filter(Connection.conn_id == worker_conn).one()
+    )
+    session.delete(connection)
+    session.commit()
+
+
+@task_group(group_id="reindex")
+def reindex(
+    instance_id: str,
+    server: str,
+    model_name: str,
+    table_name: str,
+    target_index: str,
+    start_id: int,
+    end_id: int,
+    environment: str,
+    target_environment: str,
+    aws_conn_id: str = AWS_ASG_CONN_ID,
+):
+    """
+    Trigger a reindexing task on a remote indexer worker and wait for it to complete. Once done,
+    terminate the indexer worker instance.
+    """
+    worker_conn = create_connection(
+        instance_id=instance_id, server=server, target_environment=target_environment
+    )
+
+    trigger_reindexing_task = trigger_reindex(
+        worker_conn=worker_conn,
+        model_name=model_name,
+        table_name=table_name,
+        start_id=start_id,
+        end_id=end_id,
+        target_index=target_index,
+        target_environment=target_environment,
+    )
+
+    wait_for_reindexing_task = wait_for_reindex(
+        worker_conn=worker_conn,
+        status_endpoint=trigger_reindexing_task.output,
+        timeout=timedelta(days=1),  # TODO get from config
+        poke_interval=REFRESH_POKE_INTERVAL,  # TODO get from config
+    )
+
+    terminate_instance = terminate_indexer_worker.override(
+        # Terminate the instance even if there is an upstream failure
+        trigger_rule=TriggerRule.ALL_DONE
+    )(environment=environment, instance_id=instance_id, aws_conn_id=aws_conn_id)
+
+    drop_conn = drop_connection(worker_conn=worker_conn)
+
+    wait_for_reindexing_task >> [terminate_instance, drop_conn]
+
+
+@task_group(group_id="run_distributed_reindex")
+def perform_distributed_reindex(
+    environment: str,
+    target_environment: str,  # TODO, update types
+    target_index: str,
+    data_refresh_config: DataRefreshConfig,
+    aws_conn_id: str = AWS_ASG_CONN_ID,
+):
+    """Perform the distributed reindex on a fleet of remote indexer workers."""
+    # Use the highest ID in the table as an estimate of the record count; this is
+    # much cheaper than a full COUNT(*) over the table.
+    estimated_record_count = PGExecuteQueryOperator(
+        task_id="get_estimated_record_count",
+        conn_id=OPENLEDGER_API_CONN_ID,
+        sql=dedent(
+            f"""
+            SELECT id FROM {data_refresh_config.media_type}
+            ORDER BY id DESC LIMIT 1;
+            """
+        ),
+        handler=single_value,
+        return_last=True,
+    )
+
+    asg_config = get_autoscaling_group(
+        environment=environment,
+        target_environment=target_environment,
+        aws_conn_id=aws_conn_id,
+    )
+
+    start_workers = set_asg_capacity.override(task_id="start_indexer_workers")(
+        environment=environment, asg_config=asg_config, aws_conn_id=aws_conn_id
+    )
+
+    workers = AutoScalingGroupSensor(
+        task_id="wait_for_workers",
+        environment=environment,
+        asg_config=asg_config,
+        aws_conn_id=aws_conn_id,
+    )
+
+    start_workers >> workers
+
+    worker_params = get_worker_params(
+        estimated_record_count=estimated_record_count.output,
+        instance_ids=workers.output,
+        environment=environment,
+        target_environment=target_environment,
+        aws_conn_id=aws_conn_id,
+    )
+
+    workers >> worker_params
+
+    distributed_reindex = reindex.partial(
+        model_name=data_refresh_config.media_type,
+        table_name=data_refresh_config.media_type,
+        target_index=target_index,
+        environment=environment,
+        target_environment=target_environment,
+        aws_conn_id=aws_conn_id,
+    ).expand_kwargs(worker_params)
+
+    # All workers should be terminated once work is complete, even if errors were encountered.
+    # However, it is possible to have live instances left over if a worker crashed during
+    # reindexing and the ASG spun up a replacement that the DAG is not tracking. We force the
+    # ASG capacity to 0 at the end of execution to ensure this does not happen.
+    terminate_workers = set_asg_capacity.override(
+        task_id="ensure_all_workers_terminated", trigger_rule=TriggerRule.ALL_DONE
+    )(
+        environment=environment,
+        asg_config=asg_config,
+        desired_capacity=0,
+        aws_conn_id=aws_conn_id,
+    )
+
+    distributed_reindex >> terminate_workers
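
For reference, a minimal sketch (not part of the diff) of the indexer worker HTTP protocol that `trigger_reindex` and `wait_for_reindex` assume: POST the reindexing parameters to the worker's `task` endpoint, take the `status_check` URL from the response, and poll it until `active` is false. The worker address below is the local development container used by `get_worker_params`; the payload values (media type "image", the target index name, and the ID range) are illustrative only.

    from urllib.parse import urlparse

    import requests

    # Local development worker address, as used when the environment is not production.
    worker = "http://catalog_indexer_worker:8003"

    # Trigger a reindex over a hypothetical ID range; mirrors the payload built in
    # `trigger_reindex` (sent form-encoded, matching HttpOperator's default).
    response = requests.post(
        f"{worker}/task",
        data={
            "model_name": "image",
            "table_name": "image",
            "target_index": "image-0a1b2c3d",
            "start_id": 0,
            "end_id": 10_000,
        },
    )
    assert response.status_code == 202

    # Only the path of the returned `status_check` URL is used as the sensor endpoint.
    status_endpoint = urlparse(response.json()["status_check"]).path

    # Poll the status endpoint the same way `response_check_wait_for_completion` does.
    status = requests.get(f"{worker}{status_endpoint}").json()
    if not status["active"]:
        if status["error"]:
            raise ValueError("An error was encountered during reindexing.")
        print(f"Reindexing done with {status['progress']}% completed.")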