Skip to content

Commit

Permalink
Update comments
Browse files Browse the repository at this point in the history
  • Loading branch information
stacimc committed Aug 19, 2024
1 parent f14b621 commit 6eca45c
Showing 1 changed file with 13 additions and 24 deletions.
37 changes: 13 additions & 24 deletions catalog/dags/data_refresh/distributed_reindex.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
"""
# Distributed Reindex TaskGroup
TODO Docstring
This module contains the Airflow tasks used for orchestrating the reindexing of records from the temporary tables in the downstream (API) database into a new Elasticsearch index. Reindexing is performed on a fleet of indexer worker EC2 instances, with instance creation and termination managed by Airflow.
"""

import logging
Expand Down Expand Up @@ -106,6 +105,7 @@ def get_worker_params(
target_environment: str,
aws_conn_id: str = AWS_ASG_CONN_ID, # TODO
):
"""Determine the set of start/end indices to be passed to each indexer worker."""
# Defaults to one indexer worker in local development
worker_count = (
INDEXER_WORKER_COUNTS.get(target_environment)
Expand All @@ -131,16 +131,13 @@ def create_worker(
):
"""
Create a new EC2 instance using the launch template for the target
environment. In local development, skip.
environment. In local development, this step is skipped.
"""
if environment != PRODUCTION:
raise AirflowSkipException("Skipping instance creation in local environment.")

ec2_hook = EC2Hook(aws_conn_id=aws_conn_id, api_type="client_type")
instances = ec2_hook.conn.run_instances(
# We shouldn't need an ImageId because it is specified in the launch template
# That's another reason not to use the EC2CreateInstanceOperator, which does
# expect an ImageId
MinCount=1,
MaxCount=1,
LaunchTemplate={
Expand Down Expand Up @@ -180,12 +177,8 @@ def wait_for_worker(
instance_id: str,
aws_conn_id: str = AWS_ASG_CONN_ID,
):
"""
Awaits the EC2 instance with the given id to be in a healthy running state.
Once the instance is healthy, returns the worker ip address.
"""
"""Await the EC2 instance with the given id to be in a healthy running state."""
if environment != PRODUCTION:
# TODO return catalog_indexer_worker from this task?
raise AirflowSkipException("Skipping instance creation in local environment.")

ec2_hook = EC2Hook(aws_conn_id=aws_conn_id, api_type="client_type")
Expand Down Expand Up @@ -227,7 +220,7 @@ def get_instance_ip_address(


@task(
# Worker creation skips locally
# Instance creation tasks are skipped locally, but we still want this task to run.
trigger_rule=TriggerRule.NONE_FAILED
)
def create_connection(
Expand All @@ -236,6 +229,10 @@ def create_connection(
instance_id: str,
server: str,
):
"""
Create an Airflow Connection for the given indexer worker and persist it. It will
later be dropped in a cleanup task.
"""
worker_conn_id = f"indexer_worker_{instance_id or 'localhost'}"

# Create the Connection
Expand All @@ -261,14 +258,12 @@ def terminate_indexer_worker(
)

ec2_hook = EC2Hook(aws_conn_id=aws_conn_id, api_type="client_type")

# TODO wait for completion?
return ec2_hook.conn.terminate_instances(instance_ids=[instance_id])


@task(trigger_rule=TriggerRule.ALL_DONE)
def drop_connection(worker_conn: str):
"""Drop the connection to the now terminated instance."""
"""Drop the Connection to the now terminated instance."""
conn = Connection.get_connection_from_secrets(worker_conn)

session = settings.Session()
Expand All @@ -288,7 +283,6 @@ def reindex(
aws_conn_id: str = AWS_ASG_CONN_ID,
):
"""
TODO: Map the dynamic task names? Nothing useful to map them to since instance ids are not yet created
Trigger a reindexing task on a remote indexer worker and wait for it to complete. Once done,
terminate the indexer worker instance.
"""
Expand All @@ -309,8 +303,6 @@ def reindex(
environment=environment, instance_id=instance_id, aws_conn_id=aws_conn_id
)

instance_id >> [await_worker, instance_ip_address]

worker_conn = create_connection(
instance_id=instance_id,
server=instance_ip_address,
Expand All @@ -333,9 +325,6 @@ def reindex(
response_filter=response_filter_status_check_endpoint,
)

# TODO: Why does this have to be explicit?
worker_conn >> trigger_reindexing_task

wait_for_reindexing_task = TempConnectionHTTPSensor(
task_id="wait_for_reindexing_task",
http_conn_id=worker_conn,
Expand All @@ -344,7 +333,7 @@ def reindex(
response_check=response_check_wait_for_completion,
mode="reschedule",
poke_interval=REFRESH_POKE_INTERVAL,
timeout=24 * 60 * 60, # 1 day
timeout=24 * 60 * 60, # 1 day TODO
)

terminate_instance = terminate_indexer_worker.override(
Expand All @@ -357,7 +346,8 @@ def reindex(

drop_conn = drop_connection(worker_conn=worker_conn)

trigger_reindexing_task >> wait_for_reindexing_task
instance_id >> [await_worker, instance_ip_address]
worker_conn >> trigger_reindexing_task >> wait_for_reindexing_task
wait_for_reindexing_task >> [terminate_instance, drop_conn]


Expand Down Expand Up @@ -393,7 +383,6 @@ def perform_distributed_reindex(
aws_conn_id=aws_conn_id,
)

# TODO why does this have to be explicit?
estimated_record_count >> worker_params

reindex.partial(
Expand Down

0 comments on commit 6eca45c

Please sign in to comment.