Add filtered index and promotion steps to data refresh (#4833)
* Use hard-coded index config for new index, instead of basing it off the existing index

The index configuration should be hard-coded here to match the behavior of the ingestion server. This is also important because we require permanent changes to the index configuration to go through code review. Other DAGs can be used to rapidly iterate on index configuration, and changes can be persisted here only if desired.

* Get sensitive terms list

* Add steps to create and populate the filtered index

* Pull out run_sql utility

* First attempt to generate table indices using dynamic mapping

* Don't dynamically map separate tasks for creating each index

This isn't possible because dynamically mapping tasks within a dynamically mapped taskgroup is not supported in our version of Airflow.
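
To illustrate the workaround (a hedged sketch with made-up names, not the code in this PR): since the taskgroup is already mapped once per table, a further `.expand()` inside it isn't available, so a single task loops over all of the table's index definitions. The real task uses the project's timeout-aware PostgresHook rather than the stock provider hook shown here.

```
from airflow.decorators import task
from airflow.providers.postgres.hooks.postgres import PostgresHook


@task
def create_table_indices(postgres_conn_id: str, index_defs: list[str]) -> None:
    # A single task creates every index for the table, instead of one mapped
    # task per index definition (which would need nested dynamic mapping).
    pg = PostgresHook(postgres_conn_id=postgres_conn_id)
    for create_index_statement in index_defs:
        pg.run(create_index_statement)
```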

* Add apply_constraints steps mirroring ingestion server implementation

* Correctly drop orphans referencing records that do not exist in the temp table

* Do not remap unique constraints

The problem is that in our local development environments, we initially set up the tables in
the API database slightly differently from the way they are set up in production. Specifically, for the three
unique field constraints (url, identifier, foreign_id/provider) we apply a UNIQUE CONSTRAINT using a statement
like this:

```
ALTER TABLE ONLY public.audio ADD CONSTRAINT audio_identifier_key UNIQUE (identifier);
```

Under the covers, postgres creates a unique index with this same name. In production, we do not have an
explicit CONSTRAINT, just the unique index. This is a subtle distinction and basically an
implementation detail that achieves the same thing. You can see the difference by running `\d audio`
in psql in each case:

Indices on the audio table in local development (**before running any data refresh**, even the one that is
 run in load_sample_data) -- note the 'CONSTRAINT':

```
Indexes:
    "audio_pkey" PRIMARY KEY, btree (id)
    "audio_identifier_key" UNIQUE CONSTRAINT, btree (identifier)
    "audio_url_key" UNIQUE CONSTRAINT, btree (url)
    "unique_provider_audio" UNIQUE CONSTRAINT, btree (foreign_identifier, provider)
    ...
```

Compared to indices on the production audio table:

```
Indexes:
    "audio_pkey" PRIMARY KEY, btree (id)
    "audio_identifier_key" UNIQUE, btree (identifier)
    "audio_url_key" UNIQUE, btree (url)
    "unique_provider_audio" UNIQUE, btree (foreign_identifier, provider)
    ...
```

In fact the only explicit CONSTRAINTS the production media tables have are primary key constraints
and some foreign key constraints.

Now: when a data refresh runs, after it creates the temp table and copies in all records from the catalog,
it first copies all the indices from the media table onto the temp table, and then copies over all the
constraints.

Remember that even if you have an explicit UNIQUE CONSTRAINT, postgres automatically gives you an index
as well. So in that first step, regardless of whether the table being operated on has CONSTRAINTs set explicitly,
the unique *indices* are all copied over.
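
As a quick sanity check (just a sketch, not part of this change; the connection string is illustrative), you can confirm this from Python: both the local and production tables expose the same unique *indices* in pg_indexes, because postgres backs every UNIQUE CONSTRAINT with a unique index of the same name.

```
import psycopg2  # assumed available; the connection string below is illustrative

with psycopg2.connect("dbname=openledger") as conn, conn.cursor() as cur:
    cur.execute(
        "SELECT indexname, indexdef FROM pg_indexes WHERE tablename = %s",
        ("audio",),
    )
    for name, definition in cur.fetchall():
        # Locally, the UNIQUE CONSTRAINT columns show up here as unique
        # indices, exactly like the plain unique indices in production.
        print(name, "->", definition)
```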

In the second step, in theory the constraints should be copied over. The ingestion server generates a
bunch of ALTER TABLE statements to drop the constraints from the media table and then apply them to
the temp table instead. But they are all generated incorrectly and essentially do nothing. What we get looks like:

```
ALTER TABLE audio DROP CONSTRAINT audio_url_key
ALTER TABLE audio ADD UNIQUE (url)
```

We'd expect to see something like:

```
ALTER TABLE temp_import_audio ADD CONSTRAINT audio_url_key UNIQUE (url)
```

But if we fix this statement to be correct, it actually causes an error with the unique constraints --
because we have already copied over the indices in a previous step, and the index has the same name
as the constraint! It is not possible to add the CONSTRAINT on top of the existing index without changing
the name -- and at any rate, it's not necessary, because the index enforces uniqueness on its own,
and production does not have these explicit constraints anyway.

So TLDR: what the ingestion server is doing here ends up being a NO-OP, but it just happens to be okay
because those constraints are also enforced via the indices, and because the primary key and foreign
key constraints that we actually need are handled differently (and correctly). Since these steps are
literally not doing anything, we can just delete them; but this will cause an issue in the future if
we ever add any new non-primary/fk constraints to the API media tables (they'll be wiped away during a
data refresh). But if we *fix* the generated ALTER TABLE statements, we introduce a new and different
problem.

The solution proposed here is to exclude UNIQUE constraints from the remapping step, in the same way that
primary keys are excluded.
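
A minimal sketch of that exclusion (illustrative names, not the exact function added in this PR): when generating the ALTER TABLE statements that move constraints onto the temp table, skip UNIQUE constraints the same way PRIMARY KEY constraints are skipped, since their backing indices have already been copied over.

```
def get_constraints_to_remap(constraints: list[tuple[str, str]]) -> list[tuple[str, str]]:
    """
    Filter (constraint_name, constraint_definition) pairs down to the
    constraints that should be dropped from the media table and re-applied
    to the temp table. Primary key and unique constraints are left alone,
    since their backing indices are copied in the index-remapping step.
    """
    return [
        (name, definition)
        for name, definition in constraints
        if not definition.startswith("PRIMARY KEY")
        and not definition.startswith("UNIQUE")
    ]
```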

**This was also extremely confusing to debug closely, because all those NO-OP ALTERs get run
__only the first time a data refresh is run__. Thereafter, those CONSTRAINTs don't exist and the
output is no longer confusing. And because a data refresh is run as part of the init scripts, the
evidence of the constraints disappears almost immediately, which makes this hard to debug!** Fixing
the local dev env to match production is also recommended.

* Add promote table steps

* Promote indices

* Add map index template to dynamic promotion tasks

* Fix bug with index offsets when running consecutive data refreshes

* Add docstrings, split index/constraints out into separate files

* Do not promote if reindexing fails!

* Add tests

* Make filtered index timeouts configurable

* Ensure all copy data steps run one at a time

* Fix dependencies between index and constraint remapping

There was an error in the task dependencies caused by the fact that we were returning the output of transform_index_defs from the remap_indices taskgroup. Normally, the last task in a TaskGroup is the task that is set upstream of any downstream dependencies (in this case, the remap_constraints group). Apparently if you add a return statement to a TaskGroup (which we were doing in order to make the index_mapping the XCOM result for the entire taskgroup), this borks the dependency chain.

As a result, transform_index_defs was being set as the upstream task to get_all_existing_constraints. This meant that if there was an error in create_table_indices, it would NOT stop the remap_constraints tasks from running. By updating the create_table_indices task to return what we want from XCOMs as well, we can avoid this issue.
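
A simplified sketch of the fix (names shortened, not the literal DAG code): by returning the output of the group's *final* task, the taskgroup's XCom result and its downstream dependency both point at create_table_indices.

```
from airflow.decorators import task, task_group


@task
def transform_index_defs(table_name: str) -> list[str]:
    # Hypothetical: rewrite the existing index definitions to target the temp table.
    return [f"CREATE INDEX ON temp_import_{table_name} (provider)"]


@task
def create_table_indices(index_defs: list[str]) -> list[str]:
    # ... run the CREATE INDEX statements here ...
    # Returning the defs makes this *final* task both the taskgroup's XCom
    # result and the upstream of anything that depends on the group.
    return index_defs


@task_group
def remap_indices(table_name: str):
    return create_table_indices(transform_index_defs(table_name))
```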

* Remove unused function, fix typo

* Correctly report record count of the promoted index, not the old one

* Use EmptyOperator where possible

* Clean up variables

* Fix incorrect comment

* Clean up unused code, variable name

* Further break out tests

* Protect against table names containing  character

* Comment on static route
stacimc authored Sep 16, 2024
1 parent a4b1bf4 commit 8b2f037
Showing 18 changed files with 1,134 additions and 66 deletions.
8 changes: 8 additions & 0 deletions catalog/dags/common/sql.py
@@ -40,6 +40,14 @@ def single_value(cursor):
raise ValueError("Unable to extract expected row data from cursor") from e


def fetch_all(cursor):
try:
rows = cursor.fetchall()
return [row[0] for row in rows]
except Exception as e:
raise ValueError("Unable to extract expected row data from cursor") from e


class PostgresHook(UpstreamPostgresHook):
"""
PostgresHook that sets the database timeout on any query to match the airflow task
53 changes: 19 additions & 34 deletions catalog/dags/data_refresh/copy_data.py
@@ -25,7 +25,7 @@
PRODUCTION,
Environment,
)
from common.sql import RETURN_ROW_COUNT, PostgresHook
from common.sql import PostgresHook, run_sql
from data_refresh import queries
from data_refresh.data_refresh_types import DataRefreshConfig

@@ -36,28 +36,7 @@
DEFAULT_DATA_REFRESH_LIMIT = 10_000


@task
def _run_sql(
postgres_conn_id: str,
sql_template: str,
task: AbstractOperator = None,
timeout: float = None,
handler: callable = RETURN_ROW_COUNT,
**kwargs,
):
query = sql_template.format(**kwargs)

postgres = PostgresHook(
postgres_conn_id=postgres_conn_id,
default_statement_timeout=(
timeout if timeout else PostgresHook.get_execution_timeout(task)
),
)

return postgres.run(query, handler=handler)


@task
@task(max_active_tis_per_dagrun=1)
def initialize_fdw(
upstream_conn_id: str,
downstream_conn_id: str,
@@ -66,7 +45,7 @@ def initialize_fdw(
"""Create the FDW and prepare it for copying."""
upstream_connection = Connection.get_connection_from_secrets(upstream_conn_id)

_run_sql.function(
run_sql.function(
postgres_conn_id=downstream_conn_id,
sql_template=queries.CREATE_FDW_QUERY,
task=task,
@@ -78,7 +57,10 @@
)


@task(map_index_template="{{ task.op_kwargs['upstream_table_name'] }}")
@task(
max_active_tis_per_dagrun=1,
map_index_template="{{ task.op_kwargs['upstream_table_name'] }}",
)
def create_schema(downstream_conn_id: str, upstream_table_name: str) -> str:
"""
Create a new schema in the downstream DB through which the upstream table
@@ -104,7 +86,7 @@ def get_record_limit() -> int | None:
Airflow is running.
If a limit is explicitly configured, it is always used. Otherwise, production
defaults to no limit, and all other environments default to 100,000.
defaults to no limit, and all other environments default to 10,000.
"""
try:
# If a limit is explicitly configured, always use it.
@@ -122,7 +104,10 @@
return DEFAULT_DATA_REFRESH_LIMIT


@task(map_index_template="{{ task.op_kwargs['upstream_table_name'] }}")
@task(
max_active_tis_per_dagrun=1,
map_index_template="{{ task.op_kwargs['upstream_table_name'] }}",
)
def get_shared_columns(
upstream_conn_id: str,
downstream_conn_id: str,
@@ -181,7 +166,7 @@ def copy_data(
LIMIT {limit};"""
)

return _run_sql.function(
return run_sql.function(
postgres_conn_id=postgres_conn_id,
sql_template=sql_template,
task=task,
@@ -223,7 +208,7 @@ def copy_upstream_table(
upstream_table_name=upstream_table_name,
)

create_temp_table = _run_sql.override(
create_temp_table = run_sql.override(
task_id="create_temp_table",
map_index_template="{{ task.op_kwargs['temp_table_name'] }}",
)(
@@ -233,7 +218,7 @@
downstream_table_name=downstream_table_name,
)

setup_id_columns = _run_sql.override(
setup_id_columns = run_sql.override(
task_id="setup_id_columns",
map_index_template="{{ task.op_kwargs['temp_table_name'] }}",
)(
@@ -242,7 +227,7 @@
temp_table_name=temp_table_name,
)

setup_tertiary_columns = _run_sql.override(
setup_tertiary_columns = run_sql.override(
task_id="setup_tertiary_columns",
map_index_template="{{ task.op_kwargs['temp_table_name'] }}",
)(
@@ -262,7 +247,7 @@
columns=shared_cols,
)

add_primary_key = _run_sql.override(
add_primary_key = run_sql.override(
task_id="add_primary_key",
map_index_template="{{ task.op_kwargs['temp_table_name'] }}",
)(
@@ -293,7 +278,7 @@ def copy_upstream_tables(
downstream_conn_id = POSTGRES_API_CONN_IDS.get(target_environment)
upstream_conn_id = POSTGRES_CONN_ID

create_fdw = _run_sql.override(task_id="create_fdw")(
create_fdw = run_sql.override(task_id="create_fdw")(
postgres_conn_id=downstream_conn_id,
sql_template=queries.CREATE_FDW_EXTENSION_QUERY,
)
@@ -313,7 +298,7 @@
limit=limit,
).expand_kwargs([asdict(tm) for tm in data_refresh_config.table_mappings])

drop_fdw = _run_sql.override(task_id="drop_fdw")(
drop_fdw = run_sql.override(task_id="drop_fdw")(
postgres_conn_id=downstream_conn_id,
sql_template=queries.DROP_SERVER_QUERY,
)
105 changes: 105 additions & 0 deletions catalog/dags/data_refresh/create_and_populate_filtered_index.py
@@ -0,0 +1,105 @@
"""
# Create and Promote Index
This file contains TaskGroups related to creating and populating the filtered Elasticsearch indices
as part of the Data Refresh.
TODO: We'll swap out the create and populate filtered index DAG to use this instead
of the one defined in the legacy_data_refresh.
"""

import logging
import uuid
from datetime import timedelta

from airflow.decorators import task, task_group
from airflow.providers.http.operators.http import HttpOperator
from airflow.utils.trigger_rule import TriggerRule
from requests import Response

from common import elasticsearch as es
from common.constants import MediaType
from data_refresh.es_mapping import index_settings


logger = logging.getLogger(__name__)

SENSITIVE_TERMS_CONN_ID = "sensitive_terms"


def response_filter_sensitive_terms_endpoint(response: Response) -> list[str]:
return [term.decode("utf-8").strip() for term in response.iter_lines()]


@task(trigger_rule=TriggerRule.NONE_FAILED)
def get_filtered_index_name(media_type: str, destination_index_name: str) -> str:
# If a destination index name is explicitly passed, use it.
if destination_index_name:
return destination_index_name

# Otherwise, generate an index name with a new UUID. This is useful when
# filtered index creation is run outside of a data refresh, because it
# avoids naming collisions when a filtered index already exists.
logger.info("Generating new destination index name.")
return f"{media_type}-{uuid.uuid4().hex}-filtered"


@task_group(group_id="create_and_populate_filtered_index")
def create_and_populate_filtered_index(
es_host: str,
media_type: MediaType,
origin_index_name: str,
timeout: timedelta,
destination_index_name: str | None = None,
):
"""
Create and populate a filtered index based on the given origin index, excluding
documents with sensitive terms.
"""
filtered_index_name = get_filtered_index_name(
media_type=media_type, destination_index_name=destination_index_name
)

create_filtered_index = es.create_index.override(
trigger_rule=TriggerRule.NONE_FAILED,
)(
index_config={
"index": filtered_index_name,
"body": index_settings(media_type),
},
es_host=es_host,
)

sensitive_terms = HttpOperator(
task_id="get_sensitive_terms",
http_conn_id=SENSITIVE_TERMS_CONN_ID,
method="GET",
response_check=lambda response: response.status_code == 200,
response_filter=response_filter_sensitive_terms_endpoint,
trigger_rule=TriggerRule.NONE_FAILED,
)

populate_filtered_index = es.trigger_and_wait_for_reindex(
es_host=es_host,
destination_index=filtered_index_name,
source_index=origin_index_name,
timeout=timeout,
requests_per_second="{{ var.value.get('ES_INDEX_THROTTLING_RATE', 20_000) }}",
query={
"bool": {
"must_not": [
# Use `terms` query for exact matching against unanalyzed raw fields
{"terms": {f"{field}.raw": sensitive_terms.output}}
for field in ["tags.name", "title", "description"]
]
}
},
refresh=False,
)

refresh_index = es.refresh_index(es_host=es_host, index_name=filtered_index_name)

sensitive_terms >> populate_filtered_index
create_filtered_index >> populate_filtered_index >> refresh_index

return filtered_index_name
@@ -1,7 +1,7 @@
"""
# Create and Promote Index
# Create Index
This file contains TaskGroups related to creating and promoting Elasticsearch indices
This file contains TaskGroups related to creating Elasticsearch indices
as part of the Data Refresh.
"""

@@ -12,6 +12,7 @@

from common import elasticsearch as es
from data_refresh.data_refresh_types import DataRefreshConfig
from data_refresh.es_mapping import index_settings


logger = logging.getLogger(__name__)
@@ -30,17 +31,14 @@ def create_index(
# Generate a UUID suffix that will be used by the newly created index.
temp_index_name = generate_index_name(media_type=data_refresh_config.media_type)

# Get the configuration for the new Elasticsearch index, based off the existing index.
index_config = es.get_index_configuration_copy.override(
task_id="get_index_configuration"
)(
source_index=data_refresh_config.media_type,
target_index_name=temp_index_name,
# Create a new index
es.create_index(
index_config={
"index": temp_index_name,
"body": index_settings(data_refresh_config.media_type),
},
es_host=es_host,
)

# Create a new index matching the existing configuration
es.create_index(index_config=index_config, es_host=es_host)

# Return the name of the created index
return temp_index_name
48 changes: 41 additions & 7 deletions catalog/dags/data_refresh/dag_factory.py
@@ -48,9 +48,13 @@
from common.sensors.utils import wait_for_external_dags_with_tag
from data_refresh.alter_data import alter_table_data
from data_refresh.copy_data import copy_upstream_tables
from data_refresh.create_and_promote_index import create_index
from data_refresh.create_and_populate_filtered_index import (
create_and_populate_filtered_index,
)
from data_refresh.create_index import create_index
from data_refresh.data_refresh_types import DATA_REFRESH_CONFIGS, DataRefreshConfig
from data_refresh.distributed_reindex import perform_distributed_reindex
from data_refresh.promote_table import promote_tables
from data_refresh.reporting import report_record_difference


@@ -174,13 +178,21 @@ def create_data_refresh_dag(

# Populate the Elasticsearch index.
reindex = perform_distributed_reindex(
es_host=es_host,
environment="{{ var.value.ENVIRONMENT }}",
target_environment=target_environment,
target_index=target_index,
data_refresh_config=data_refresh_config,
)

# TODO create_and_populate_filtered_index
# Create and populate the filtered index
filtered_index = create_and_populate_filtered_index(
es_host=es_host,
media_type=data_refresh_config.media_type,
origin_index_name=target_index,
destination_index_name=f"{target_index}-filtered",
timeout=data_refresh_config.create_filtered_index_timeout,
)

# Re-enable Cloudwatch alarms once reindexing is complete, even if it
# failed.
@@ -193,9 +205,27 @@
trigger_rule=TriggerRule.ALL_DONE,
)

# TODO Promote
# (TaskGroup that reapplies constraints, promotes new tables and indices,
# deletes old ones)
# Promote the API table
promote_table = promote_tables(
data_refresh_config=data_refresh_config,
target_environment=target_environment,
)

promote_index = es.point_alias.override(group_id="promote_index")(
es_host=es_host,
target_index=target_index,
target_alias=data_refresh_config.media_type,
should_delete_old_index=True,
)

promote_filtered_index = es.point_alias.override(
group_id="promote_filtered_index"
)(
es_host=es_host,
target_index=filtered_index,
target_alias=f"{data_refresh_config.media_type}-filtered",
should_delete_old_index=True,
)

# Get the final number of records in the API table after the refresh
after_record_count = es.get_record_count_group_by_sources.override(
@@ -226,9 +256,13 @@
>> target_index
>> disable_alarms
>> reindex
>> filtered_index
)
reindex >> [enable_alarms, after_record_count]
after_record_count >> report_counts
# Note filtered_index must be directly upstream of promote_table to
# ensure that table promotion does not run if there was an error during reindexing
filtered_index >> [enable_alarms, promote_table]
promote_table >> [promote_index, promote_filtered_index]
promote_index >> after_record_count >> report_counts

return dag
