Skip to content

Commit

Permalink
Fix duplicate trigger creation race condition (#20699)
Browse files Browse the repository at this point in the history
The process for queueing up a trigger, for execution by the TriggerRunner, is handled by the TriggerJob's `load_triggers` method.  It fetches the triggers that should be running according to the database, checks if they are running and if not it adds them to `TriggerRunner.to_create`.  The problem is tha there's a small window of time between the moment a trigger (upon termination) is purged from the `TriggerRunner.triggers` set,  and the time that the database is updated to reflect the trigger's doneness.  If `TriggerJob.load_triggers` runs during this window, the trigger will be added back to the `TriggerRunner.to_create` set and it will run again.

To resolve this what we do here is, before adding a trigger to the `to_create` queue, instead of comparing against the "running" triggers, we compare against all triggers known to the TriggerRunner instance.  When triggers move out of the `triggers` set they move into other data structures such as `events` and `failed_triggers` and `to_cancel`.  So we union all of these and only create those triggers which the database indicates should exist _and_ which are know already being handled (whatever state they may be in) by the TriggerRunner instance.
  • Loading branch information
dstandish authored Jan 6, 2022
1 parent 0ebd55e commit 16b8c47
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 6 deletions.
12 changes: 9 additions & 3 deletions airflow/jobs/triggerer_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,10 +380,16 @@ def update_triggers(self, requested_trigger_ids: Set[int]):
# line's execution, but we consider that safe, since there's a strict
# add -> remove -> never again lifecycle this function is already
# handling.
current_trigger_ids = set(self.triggers.keys())
running_trigger_ids = set(self.triggers.keys())
known_trigger_ids = (
running_trigger_ids.union(x[0] for x in self.events)
.union(self.to_cancel)
.union(x[0] for x in self.to_create)
.union(self.failed_triggers)
)
# Work out the two difference sets
new_trigger_ids = requested_trigger_ids.difference(current_trigger_ids)
cancel_trigger_ids = current_trigger_ids.difference(requested_trigger_ids)
new_trigger_ids = requested_trigger_ids - known_trigger_ids
cancel_trigger_ids = running_trigger_ids - requested_trigger_ids
# Bulk-fetch new trigger records
new_triggers = Trigger.bulk_fetch(new_trigger_ids)
# Add in new triggers
Expand Down
136 changes: 133 additions & 3 deletions tests/jobs/test_triggerer_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,53 @@
# specific language governing permissions and limitations
# under the License.

import asyncio
import datetime
import sys
import time
from threading import Thread

import pytest

from airflow.jobs.triggerer_job import TriggererJob
from airflow.models import Trigger
from airflow.jobs.triggerer_job import TriggererJob, TriggerRunner
from airflow.models import DagModel, DagRun, TaskInstance, Trigger
from airflow.operators.dummy import DummyOperator
from airflow.operators.python import PythonOperator
from airflow.triggers.base import TriggerEvent
from airflow.triggers.temporal import TimeDeltaTrigger
from airflow.triggers.testing import FailureTrigger, SuccessTrigger
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.state import State, TaskInstanceState
from tests.test_utils.db import clear_db_runs
from tests.test_utils.db import clear_db_dags, clear_db_runs


class TimeDeltaTrigger_(TimeDeltaTrigger):
def __init__(self, delta, filename):
super().__init__(delta=delta)
self.filename = filename
self.delta = delta

async def run(self):
with open(self.filename, 'at') as f:
f.write('hi\n')
async for event in super().run():
yield event

def serialize(self):
return (
"tests.jobs.test_triggerer_job.TimeDeltaTrigger_",
{"delta": self.delta, "filename": self.filename},
)


@pytest.fixture(autouse=True)
def clean_database():
"""Fixture that cleans the database before and after every test."""
clear_db_runs()
clear_db_dags()
yield # Test runs here
clear_db_dags()
clear_db_runs()


Expand Down Expand Up @@ -159,6 +183,112 @@ def test_trigger_lifecycle(session):
job.runner.stop = True


@pytest.mark.skipif(sys.version_info.minor <= 6 and sys.version_info.major <= 3, reason="No triggerer on 3.6")
def test_trigger_create_race_condition_18392(session, tmp_path):
"""
This verifies the resolution of race condition documented in github issue #18392.
Triggers are queued for creation by TriggerJob.load_triggers.
There was a race condition where multiple triggers would be created unnecessarily.
What happens is the runner completes the trigger and purges from the "running" list.
Then job.load_triggers is called and it looks like the trigger is not running but should,
so it queues it again.
The scenario is as follows:
1. job.load_triggers (trigger now queued)
2. runner.create_triggers (trigger now running)
3. job.handle_events (trigger still appears running so state not updated in DB)
4. runner.cleanup_finished_triggers (trigger completed at this point; trigger from "running" set)
5. job.load_triggers (trigger not running, but also not purged from DB, so it is queued again)
6. runner.create_triggers (trigger created again)
This test verifies that under this scenario only one trigger is created.
"""
path = tmp_path / 'test_trigger_bad_respawn.txt'

class TriggerRunner_(TriggerRunner):
"""We do some waiting for main thread looping"""

async def wait_for_job_method_count(self, method, count):
for _ in range(30):
await asyncio.sleep(0.1)
if getattr(self, f'{method}_count', 0) >= count:
break
else:
pytest.fail(f"did not observe count {count} in job method {method}")

async def create_triggers(self):
"""
On first run, wait for job.load_triggers to make sure they are queued
"""
if getattr(self, 'loop_count', 0) == 0:
await self.wait_for_job_method_count('load_triggers', 1)
await super().create_triggers()
self.loop_count = getattr(self, 'loop_count', 0) + 1

async def cleanup_finished_triggers(self):
"""On loop 1, make sure that job.handle_events was already called"""
if self.loop_count == 1:
await self.wait_for_job_method_count('handle_events', 1)
await super().cleanup_finished_triggers()

class TriggererJob_(TriggererJob):
"""We do some waiting for runner thread looping (and track calls in job thread)"""

def wait_for_runner_loop(self, runner_loop_count):
for _ in range(30):
time.sleep(0.1)
if getattr(self.runner, 'call_count', 0) >= runner_loop_count:
break
else:
pytest.fail("did not observe 2 loops in the runner thread")

def load_triggers(self):
"""On second run, make sure that runner has called create_triggers in its second loop"""
super().load_triggers()
self.runner.load_triggers_count = getattr(self.runner, 'load_triggers_count', 0) + 1
if self.runner.load_triggers_count == 2:
self.wait_for_runner_loop(runner_loop_count=2)

def handle_events(self):
super().handle_events()
self.runner.handle_events_count = getattr(self.runner, 'handle_events_count', 0) + 1

trigger = TimeDeltaTrigger_(delta=datetime.timedelta(microseconds=1), filename=path.as_posix())
trigger_orm = Trigger.from_object(trigger)
trigger_orm.id = 1
session.add(trigger_orm)

dag = DagModel(dag_id='test-dag')
dag_run = DagRun(dag.dag_id, run_id='abc', run_type='none')
ti = TaskInstance(PythonOperator(task_id='dummy-task', python_callable=print), run_id=dag_run.run_id)
ti.dag_id = dag.dag_id
ti.trigger_id = 1
session.add(dag)
session.add(dag_run)
session.add(ti)

session.commit()

job = TriggererJob_()
job.runner = TriggerRunner_()
thread = Thread(target=job._execute)
thread.start()
try:
for _ in range(40):
time.sleep(0.1)
# ready to evaluate after 2 loops
if getattr(job.runner, 'loop_count', 0) >= 2:
break
else:
pytest.fail("did not observe 2 loops in the runner thread")
finally:
job.runner.stop = True
job.runner.join()
thread.join()
instances = path.read_text().splitlines()
assert len(instances) == 1


@pytest.mark.skipif(sys.version_info.minor <= 6 and sys.version_info.major <= 3, reason="No triggerer on 3.6")
def test_trigger_from_dead_triggerer(session):
"""
Expand Down

0 comments on commit 16b8c47

Please sign in to comment.