Add distributed reindex steps
stacimc committed Jun 29, 2024
1 parent 0891970 commit 56b026b
Showing 6 changed files with 581 additions and 27 deletions.
1 change: 1 addition & 0 deletions catalog/dags/common/constants.py
@@ -47,6 +47,7 @@
AWS_CONN_ID = "aws_default"
AWS_CLOUDWATCH_CONN_ID = os.environ.get("AWS_CLOUDWATCH_CONN_ID", AWS_CONN_ID)
AWS_RDS_CONN_ID = os.environ.get("AWS_RDS_CONN_ID", AWS_CONN_ID)
AWS_ASG_CONN_ID = os.environ.get("AWS_ASG_CONN_ID", AWS_CONN_ID)
ES_PROD_HTTP_CONN_ID = "elasticsearch_http_production"
REFRESH_POKE_INTERVAL = int(os.getenv("DATA_REFRESH_POKE_INTERVAL", 60 * 30))

18 changes: 18 additions & 0 deletions catalog/dags/common/elasticsearch.py
@@ -78,6 +78,24 @@ def remove_excluded_index_settings(index_config):
return index_config


@task
def get_index_configuration_copy(
source_index: str, target_index_name: str, es_host: str
):
"""
Create a new index configuration based on the `source_index`, but with the given
`target_index_name`, in the format needed for `create_index`. Removes fields that
should not be copied into a new index configuration, such as the uuid.
"""
base_config = get_index_configuration.function(source_index, es_host)

cleaned_config = remove_excluded_index_settings.function(base_config)

cleaned_config["index"] = target_index_name

return cleaned_config
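
For context, the sketch below shows roughly what this task does when run against a plain Elasticsearch client, outside of Airflow. The excluded setting names and the helper name are assumptions for illustration, not the project's exact implementation.

from elasticsearch import Elasticsearch

# Per-index settings generated by Elasticsearch that must not be copied into a
# new index configuration (assumed list, mirroring the docstring above).
EXCLUDED_INDEX_SETTINGS = ("uuid", "creation_date", "provided_name", "version")


def copy_index_configuration(
    es: Elasticsearch, source_index: str, target_index_name: str
) -> dict:
    # Fetch the source index's aliases, mappings, and settings.
    config = es.indices.get(index=source_index)[source_index]
    # Drop the generated per-index settings so the configuration can be reused.
    for setting in EXCLUDED_INDEX_SETTINGS:
        config["settings"]["index"].pop(setting, None)
    # Key the configuration by the new index name, matching the shape the
    # `create_index` task receives from the task above.
    return {"index": target_index_name, **config}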


@task
def get_record_count_group_by_sources(es_host: str, index: str):
"""
9 changes: 5 additions & 4 deletions catalog/dags/data_refresh/copy_data.py
@@ -198,7 +198,7 @@ def copy_data(
def copy_upstream_table(
upstream_conn_id: str,
downstream_conn_id: str,
-environment: Environment,
+target_environment: Environment,
timeout: timedelta,
limit: int,
upstream_table_name: str,
@@ -275,12 +275,13 @@ def copy_upstream_table(
create_temp_table >> setup_id_columns >> setup_tertiary_columns
setup_tertiary_columns >> copy
copy >> add_primary_key

return


@task_group(group_id="copy_upstream_tables")
def copy_upstream_tables(
-environment: Environment, data_refresh_config: DataRefreshConfig
+target_environment: Environment, data_refresh_config: DataRefreshConfig
):
"""
For each upstream table associated with the given media type, create a new
@@ -290,7 +291,7 @@ def copy_upstream_tables(
This task does _not_ apply all indices and constraints, merely copies
the data.
"""
-downstream_conn_id = POSTGRES_API_CONN_IDS.get(environment)
+downstream_conn_id = POSTGRES_API_CONN_IDS.get(target_environment)
upstream_conn_id = POSTGRES_CONN_ID

create_fdw = _run_sql.override(task_id="create_fdw")(
@@ -309,7 +310,7 @@
copy_tables = copy_upstream_table.partial(
upstream_conn_id=upstream_conn_id,
downstream_conn_id=downstream_conn_id,
-environment=environment,
+target_environment=target_environment,
timeout=data_refresh_config.copy_data_timeout,
limit=limit,
).expand_kwargs([asdict(tm) for tm in data_refresh_config.table_mappings])
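
As background on the mechanism, the following standalone sketch shows a minimal FDW-based copy of one table. The server name, credentials, schema, and SQL are illustrative assumptions; the DAG builds its actual queries through `_run_sql` and the table-mapping tasks above.

from airflow.providers.postgres.hooks.postgres import PostgresHook


def copy_table_via_fdw(
    downstream_conn_id: str, upstream_table: str, temp_table: str
) -> None:
    hook = PostgresHook(postgres_conn_id=downstream_conn_id)
    # Expose the upstream database locally through postgres_fdw.
    hook.run("CREATE EXTENSION IF NOT EXISTS postgres_fdw;")
    hook.run(
        "CREATE SERVER IF NOT EXISTS upstream FOREIGN DATA WRAPPER postgres_fdw "
        "OPTIONS (host 'upstream-db', dbname 'openledger');"
    )
    hook.run(
        "CREATE USER MAPPING IF NOT EXISTS FOR CURRENT_USER SERVER upstream "
        "OPTIONS (user 'deploy', password 'deploy');"
    )
    hook.run("CREATE SCHEMA IF NOT EXISTS upstream_schema;")
    hook.run(
        f"IMPORT FOREIGN SCHEMA public LIMIT TO ({upstream_table}) "
        "FROM SERVER upstream INTO upstream_schema;"
    )
    # Copy raw rows only; indices, constraints, and the primary key are added
    # in later steps, as the docstring above notes.
    hook.run(f"CREATE TABLE {temp_table} AS SELECT * FROM upstream_schema.{upstream_table};")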
70 changes: 50 additions & 20 deletions catalog/dags/data_refresh/dag_factory.py
@@ -28,11 +28,12 @@

import logging
import os
import uuid
from collections.abc import Sequence
from itertools import product

from airflow import DAG
-from airflow.decorators import task_group
+from airflow.decorators import task, task_group
from airflow.operators.python import PythonOperator
from airflow.utils.trigger_rule import TriggerRule

@@ -44,6 +45,7 @@
from common.sensors.utils import wait_for_external_dags_with_tag
from data_refresh.copy_data import copy_upstream_tables
from data_refresh.data_refresh_types import DATA_REFRESH_CONFIGS, DataRefreshConfig
from data_refresh.distributed_reindex import perform_distributed_reindex
from data_refresh.reporting import report_record_difference


@@ -83,9 +85,14 @@ def wait_for_conflicting_dags(
)


@task
def generate_index_name(media_type: str) -> str:
return f"{media_type}-{uuid.uuid4().hex}"


def create_data_refresh_dag(
data_refresh_config: DataRefreshConfig,
-environment: Environment,
+target_environment: Environment,
external_dag_ids: Sequence[str],
):
"""
@@ -95,22 +102,22 @@ def create_data_refresh_dag(
Required Arguments:
-data_refresh: dataclass containing configuration information for the
-DAG
-environment: the environment in which the data refresh is performed
-external_dag_ids: list of ids of the other data refresh DAGs. The data refresh step
-of this DAG will not run concurrently with the corresponding step
-of any dependent DAG.
+data_refresh: dataclass containing configuration information for the
+DAG
+target_environment: the environment in which the data refresh is performed
+external_dag_ids: list of ids of the other data refresh DAGs. The data refresh step
+of this DAG will not run concurrently with the corresponding step
+of any dependent DAG.
"""
default_args = {
**DAG_DEFAULT_ARGS,
**data_refresh_config.default_args,
}

-concurrency_tag = ES_CONCURRENCY_TAGS.get(environment)
+concurrency_tag = ES_CONCURRENCY_TAGS.get(target_environment)

dag = DAG(
dag_id=f"{environment}_{data_refresh_config.dag_id}",
dag_id=f"{target_environment}_{data_refresh_config.dag_id}",
dagrun_timeout=data_refresh_config.dag_timeout,
default_args=default_args,
start_date=data_refresh_config.start_date,
@@ -120,14 +127,14 @@
doc_md=__doc__,
tags=[
"data_refresh",
f"{environment}_data_refresh",
f"{target_environment}_data_refresh",
concurrency_tag,
],
)

with dag:
# Connect to the appropriate Elasticsearch cluster
-es_host = es.get_es_host(environment=environment)
+es_host = es.get_es_host(environment=target_environment)

# Get the current number of records in the target API table
before_record_count = es.get_record_count_group_by_sources.override(
@@ -142,11 +149,29 @@
)

copy_data = copy_upstream_tables(
-environment=environment, data_refresh_config=data_refresh_config
+target_environment=target_environment,
+data_refresh_config=data_refresh_config,
)

# TODO Cleaning steps

# TODO: Move temp_index_name >> index_config >> target_index into a `create_index` task group

# Generate a UUID suffix that will be used by the newly created index.
temp_index_name = generate_index_name(media_type=data_refresh_config.media_type)

# Get the configuration for the new Elasticsearch index, based off the existing index.
index_config = es.get_index_configuration_copy.override(
task_id="get_index_configuration"
)(
source_index=data_refresh_config.media_type,
target_index_name=temp_index_name,
es_host=es_host,
)

# Create a new index matching the existing configuration
target_index = es.create_index(index_config=index_config, es_host=es_host)

# Disable Cloudwatch alarms that are noisy during the reindexing steps of a
# data refresh.
disable_alarms = PythonOperator(
@@ -157,8 +182,13 @@
},
)

-# TODO create_and_populate_index
-# (TaskGroup that creates index, triggers and waits for reindexing)
+reindex = perform_distributed_reindex(
+environment="{{ var.value.ENVIRONMENT }}",
+target_environment=target_environment,
+target_index=temp_index_name,
+data_refresh_config=data_refresh_config,
+)

# TODO create_and_populate_filtered_index

@@ -198,9 +228,9 @@ def create_data_refresh_dag(
)

# Set up task dependencies
-before_record_count >> wait_for_dags >> copy_data >> disable_alarms
-# TODO: this will include reindex/etc once added
-disable_alarms >> [enable_alarms, after_record_count]
+before_record_count >> wait_for_dags >> copy_data >> temp_index_name
+temp_index_name >> index_config >> target_index >> disable_alarms
+disable_alarms >> reindex >> [enable_alarms, after_record_count]
after_record_count >> report_counts

return dag
Expand All @@ -209,14 +239,14 @@ def create_data_refresh_dag(
# Generate data refresh DAGs for each DATA_REFRESH_CONFIG, per environment.
all_data_refresh_dag_ids = {refresh.dag_id for refresh in DATA_REFRESH_CONFIGS.values()}

-for data_refresh_config, environment in product(
+for data_refresh_config, target_environment in product(
DATA_REFRESH_CONFIGS.values(), ENVIRONMENTS
):
# Construct a set of all data refresh DAG ids other than the current DAG
other_dag_ids = all_data_refresh_dag_ids - {data_refresh_config.dag_id}

globals()[data_refresh_config.dag_id] = create_data_refresh_dag(
data_refresh_config,
-environment,
-[f"{environment}_{dag_id}" for dag_id in other_dag_ids],
+target_environment,
+[f"{target_environment}_{dag_id}" for dag_id in other_dag_ids],
)
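
To make the result concrete, a hypothetical listing of the DAG ids this loop registers (the real set depends on ENVIRONMENTS and DATA_REFRESH_CONFIGS):

# Assuming ENVIRONMENTS = ("staging", "production") and image/audio configs:
expected_dag_ids = [
    "staging_image_data_refresh",
    "production_image_data_refresh",
    "staging_audio_data_refresh",
    "production_audio_data_refresh",
]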
11 changes: 8 additions & 3 deletions catalog/dags/data_refresh/data_refresh_types.py
@@ -48,9 +48,10 @@ class DataRefreshConfig:
Required Constructor Arguments:
-media_type: str describing the media type to be refreshed.
-table_mappings: list of TableMapping information for all tables that should be
-refreshed as part of a data refresh for this media type.
+media_type: str describing the media type to be refreshed.
+table_mappings: list of TableMapping information for all tables that should
+be refreshed as part of a data refresh for this media type.
Optional Constructor Arguments:
@@ -65,6 +66,9 @@
data refresh may take.
copy_data_timeout: timedelta expressing the amount of time it may take to
copy the upstream table into the downstream DB
indexer_worker_timeout: timedelta expressing the amount of time it may take for
any individual indexer worker to perform its portion of
the distributed reindex
index_readiness_timeout: timedelta expressing amount of time it may take
to await a healthy ES index after reindexing
data_refresh_poke_interval: int number of seconds to wait between
@@ -81,6 +85,7 @@
default_args: dict = field(default_factory=dict)
dag_timeout: timedelta = timedelta(days=1)
copy_data_timeout: timedelta = timedelta(hours=1)
indexer_worker_timeout: timedelta = timedelta(hours=6) # TODO
index_readiness_timeout: timedelta = timedelta(days=1)
data_refresh_poke_interval: int = REFRESH_POKE_INTERVAL
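
A hypothetical instantiation showing where the new timeout fits among the existing fields (the values and the empty table_mappings are placeholders, not the real image configuration defined in DATA_REFRESH_CONFIGS):

from datetime import timedelta

from data_refresh.data_refresh_types import DataRefreshConfig

image_data_refresh = DataRefreshConfig(
    media_type="image",
    table_mappings=[],  # placeholder; real configs list TableMapping entries
    copy_data_timeout=timedelta(hours=12),
    indexer_worker_timeout=timedelta(hours=6),
    index_readiness_timeout=timedelta(days=1),
)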
