diff --git a/django_celery_beat/schedulers.py b/django_celery_beat/schedulers.py index e2606866..cfc2fed9 100644 --- a/django_celery_beat/schedulers.py +++ b/django_celery_beat/schedulers.py @@ -15,7 +15,12 @@ from kombu.utils.json import dumps, loads from django.conf import settings -from django.db import transaction, close_old_connections +from django.db import ( + DEFAULT_DB_ALIAS, + close_old_connections, + router, + transaction +) from django.db.utils import DatabaseError, InterfaceError from django.core.exceptions import ObjectDoesNotExist @@ -258,7 +263,7 @@ def schedule_changed(self): # other transactions until the current transaction is # committed (Issue #41). try: - transaction.commit() + transaction.commit(using=self.target_db) except transaction.TransactionManagementError: pass # not in transaction management. @@ -287,7 +292,18 @@ def reserve(self, entry): self._dirty.add(new_entry.name) return new_entry - def sync(self): + @property + def target_db(self): + """Determine if there is a django route""" + if not settings.DATABASE_ROUTERS: + return DEFAULT_DB_ALIAS + # If the project does not actually implement this method, + # DEFAULT_DB_ALIAS will be automatically returned. + # The exception will be located to the django routing section + db = router.db_for_write(self.Model) + return db + + def _sync(self): if logger.isEnabledFor(logging.DEBUG): debug('Writing entries...') _tried = set() @@ -313,6 +329,10 @@ def sync(self): # retry later, only for the failed ones self._dirty |= _failed + def sync(self): + with transaction.atomic(using=self.target_db): + self._sync() + def update_from_dict(self, mapping): s = {} for name, entry_fields in mapping.items(): diff --git a/t/unit/test_schedulers.py b/t/unit/test_schedulers.py index 28a89b11..69a36a48 100644 --- a/t/unit/test_schedulers.py +++ b/t/unit/test_schedulers.py @@ -26,6 +26,15 @@ _ids = count(0) +class Router: + target_db = None + + def db_for_read(self, model, **hints): + return self.target_db + + db_for_write = db_for_read + + @pytest.fixture(autouse=True) def no_multiprocessing_finalizers(patching): patching('multiprocessing.util.Finalize') @@ -117,6 +126,10 @@ def create_interval_schedule(self): class test_ModelEntry(SchedulerCase): Entry = EntryTrackSave + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_entry(self): m = self.create_model_interval(schedule(timedelta(seconds=10))) e = self.Entry(m, app=self.app) @@ -148,7 +161,9 @@ def test_entry(self): @override_settings( USE_TZ=False, - DJANGO_CELERY_BEAT_TZ_AWARE=False + DJANGO_CELERY_BEAT_TZ_AWARE=False, + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] ) @pytest.mark.usefixtures('depends_on_current_app') @timezone.override('Europe/Berlin') @@ -180,7 +195,9 @@ def test_entry_is_due__no_use_tz(self): @override_settings( USE_TZ=False, - DJANGO_CELERY_BEAT_TZ_AWARE=False + DJANGO_CELERY_BEAT_TZ_AWARE=False, + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] ) @pytest.mark.usefixtures('depends_on_current_app') @timezone.override('Europe/Berlin') @@ -216,7 +233,9 @@ def test_entry_and_model_last_run_at_with_utc_no_use_tz(self, monkeypatch): USE_TZ=False, DJANGO_CELERY_BEAT_TZ_AWARE=False, TIME_ZONE="Europe/Berlin", - CELERY_TIMEZONE="America/New_York" + CELERY_TIMEZONE="America/New_York", + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] ) @pytest.mark.usefixtures('depends_on_current_app') @timezone.override('Europe/Berlin') @@ -247,6 +266,10 @@ def test_entry_is_due__celery_timezone_doesnt_match_time_zone(self): if hasattr(time, "tzset"): time.tzset() + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_task_with_start_time(self): interval = 10 right_now = self.app.now() @@ -268,6 +291,10 @@ def test_task_with_start_time(self): assert not isdue assert delay == math.ceil((tomorrow - right_now).total_seconds()) + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_one_off_task(self): interval = 10 right_now = self.app.now() @@ -296,6 +323,10 @@ class test_DatabaseSchedulerFromAppConf(SchedulerCase): Scheduler = TrackingScheduler @pytest.mark.django_db() + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) @pytest.fixture(autouse=True) def setup_scheduler(self, app): self.app = app @@ -305,6 +336,10 @@ def setup_scheduler(self, app): self.m1 = PeriodicTask(name=self.entry_name, interval=self.create_interval_schedule()) + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_constructor(self): s = self.Scheduler(app=self.app) @@ -312,6 +347,10 @@ def test_constructor(self): assert s._last_sync is None assert s.sync_every + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_periodic_task_model_enabled_schedule(self): s = self.Scheduler(app=self.app) sched = s.schedule @@ -325,6 +364,10 @@ def test_periodic_task_model_enabled_schedule(self): assert e.model.expires is None assert e.model.expire_seconds == 12 * 3600 + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_periodic_task_model_disabled_schedule(self): self.m1.enabled = False self.m1.save() @@ -342,6 +385,10 @@ class test_DatabaseScheduler(SchedulerCase): Scheduler = TrackingScheduler @pytest.mark.django_db() + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) @pytest.fixture(autouse=True) def setup_scheduler(self, app): self.app = app @@ -386,11 +433,19 @@ def setup_scheduler(self, app): self.s = self.Scheduler(app=self.app) + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_constructor(self): assert isinstance(self.s._dirty, set) assert self.s._last_sync is None assert self.s.sync_every + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_all_as_schedule(self): sched = self.s.schedule assert sched @@ -399,6 +454,10 @@ def test_all_as_schedule(self): for n, e in sched.items(): assert isinstance(e, self.s.Entry) + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_schedule_changed(self): self.m2.args = '[16, 16]' self.m2.save() @@ -416,6 +475,10 @@ def test_schedule_changed(self): with pytest.raises(KeyError): self.s.schedule.__getitem__(self.m3.name) + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_should_sync(self): assert self.s.should_sync() self.s._last_sync = monotonic() @@ -423,6 +486,10 @@ def test_should_sync(self): self.s._last_sync -= self.s.sync_every assert self.s.should_sync() + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_reserve(self): e1 = self.s.schedule[self.m1.name] self.s.schedule[self.m1.name] = self.s.reserve(e1) @@ -433,6 +500,10 @@ def test_reserve(self): assert self.s.flushed == 1 assert self.m2.name in self.s._dirty + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_sync_saves_last_run_at(self): e1 = self.s.schedule[self.m2.name] last_run = e1.last_run_at @@ -444,6 +515,10 @@ def test_sync_saves_last_run_at(self): e2 = self.s.schedule[self.m2.name] assert e2.last_run_at == last_run2 + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_sync_syncs_before_save(self): # Get the entry for m2 e1 = self.s.schedule[self.m2.name] @@ -466,6 +541,10 @@ def test_sync_syncs_before_save(self): assert e3.last_run_at == e2.last_run_at assert e3.args == [16, 16] + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_periodic_task_disabled_and_enabled(self): # Get the entry for m2 e1 = self.s.schedule[self.m2.name] @@ -494,6 +573,10 @@ def test_periodic_task_disabled_and_enabled(self): assert self.m2.name in self.s.schedule assert self.s.flushed == 3 + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_periodic_task_disabled_while_reserved(self): # Get the entry for m2 e1 = self.s.schedule[self.m2.name] @@ -518,26 +601,46 @@ def test_periodic_task_disabled_while_reserved(self): assert self.m2.name not in self.s.schedule assert self.s.flushed == 2 + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_sync_not_dirty(self): self.s._dirty.clear() self.s.sync() + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_sync_object_gone(self): self.s._dirty.add('does-not-exist') self.s.sync() + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_sync_rollback_on_save_error(self): self.s.schedule[self.m1.name] = EntrySaveRaises(self.m1, app=self.app) self.s._dirty.add(self.m1.name) with pytest.raises(RuntimeError): self.s.sync() + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_update_scheduler_heap_invalidation(self, monkeypatch): # mock "schedule_changed" to always trigger update for # all calls to schedule, as a change may occur at any moment monkeypatch.setattr(self.s, 'schedule_changed', lambda: True) self.s.tick() + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_heap_size_is_constant(self, monkeypatch): # heap size is constant unless the schedule changes monkeypatch.setattr(self.s, 'schedule_changed', lambda: True) @@ -547,6 +650,10 @@ def test_heap_size_is_constant(self, monkeypatch): self.s.tick() assert len(self.s._heap) == expected_heap_size + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_scheduler_schedules_equality_on_change(self, monkeypatch): monkeypatch.setattr(self.s, 'schedule_changed', lambda: False) assert self.s.schedules_equal(self.s.schedule, self.s.schedule) @@ -554,6 +661,10 @@ def test_scheduler_schedules_equality_on_change(self, monkeypatch): monkeypatch.setattr(self.s, 'schedule_changed', lambda: True) assert not self.s.schedules_equal(self.s.schedule, self.s.schedule) + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_heap_always_return_the_first_item(self): interval = 10 @@ -592,12 +703,20 @@ def test_heap_always_return_the_first_item(self): @pytest.mark.django_db() class test_models(SchedulerCase): + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_IntervalSchedule_unicode(self): assert (str(IntervalSchedule(every=1, period='seconds')) == 'every second') assert (str(IntervalSchedule(every=10, period='seconds')) == 'every 10 seconds') + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_CrontabSchedule_unicode(self): assert str(CrontabSchedule( minute=3, @@ -612,10 +731,18 @@ def test_CrontabSchedule_unicode(self): month_of_year='4,6', )) == '3 3 */2 4,6 tue (m/h/dM/MY/d) UTC' + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_PeriodicTask_unicode_interval(self): p = self.create_model_interval(schedule(timedelta(seconds=10))) assert str(p) == '{0}: every 10.0 seconds'.format(p.name) + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_PeriodicTask_unicode_crontab(self): p = self.create_model_crontab(crontab( hour='4, 5', @@ -625,6 +752,10 @@ def test_PeriodicTask_unicode_crontab(self): p.name ) + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_PeriodicTask_unicode_solar(self): p = self.create_model_solar( solar('solar_noon', 48.06, 12.86), name='solar_event' @@ -633,6 +764,10 @@ def test_PeriodicTask_unicode_solar(self): 'Solar noon', '48.06', '12.86' ) + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_PeriodicTask_unicode_clocked(self): time = make_aware(datetime.now()) p = self.create_model_clocked( @@ -642,6 +777,10 @@ def test_PeriodicTask_unicode_clocked(self): 'clocked_event', str(time) ) + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_PeriodicTask_schedule_property(self): p1 = self.create_model_interval(schedule(timedelta(seconds=10))) s1 = p1.schedule @@ -660,10 +799,18 @@ def test_PeriodicTask_schedule_property(self): assert s2.day_of_month == {1, 2, 3, 4, 5, 6, 7} assert s2.month_of_year == {1, 4, 7, 10} + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_PeriodicTask_unicode_no_schedule(self): p = self.create_model() assert str(p) == '{0}: {{no schedule}}'.format(p.name) + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_CrontabSchedule_schedule(self): s = CrontabSchedule( minute='3, 7', @@ -678,6 +825,10 @@ def test_CrontabSchedule_schedule(self): assert s.schedule.day_of_month == {1, 16} assert s.schedule.month_of_year == {1, 7} + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_CrontabSchedule_long_schedule(self): s = CrontabSchedule( minute=str(list(range(60)))[1:-1], @@ -699,6 +850,10 @@ def test_CrontabSchedule_long_schedule(self): field_length = s._meta.get_field(field).max_length assert str_length <= field_length + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_SolarSchedule_schedule(self): s = SolarSchedule(event='solar_noon', latitude=48.06, longitude=12.86) dt = datetime(day=26, month=7, year=2050, hour=1, minute=0) @@ -720,6 +875,10 @@ def test_SolarSchedule_schedule(self): assert (nextcheck2 > 0) and (isdue2 is True) or \ (nextcheck2 == s2.max_interval) and (isdue2 is False) + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_ClockedSchedule_schedule(self): due_datetime = make_aware(datetime(day=26, month=7, @@ -749,6 +908,10 @@ def test_ClockedSchedule_schedule(self): @pytest.mark.django_db() class test_model_PeriodicTasks(SchedulerCase): + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_track_changes(self): assert PeriodicTasks.last_change() is None m1 = self.create_model_interval(schedule(timedelta(seconds=10))) @@ -765,6 +928,10 @@ def test_track_changes(self): @pytest.mark.django_db() class test_modeladmin_PeriodicTaskAdmin(SchedulerCase): @pytest.mark.django_db() + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) @pytest.fixture(autouse=True) def setup_scheduler(self, app): self.app = app @@ -785,6 +952,10 @@ def setup_scheduler(self, app): self.m2.task = 'celery.backend_cleanup' self.m2.save() + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def patch_request(self, request): """patch request to allow for django messages storage""" setattr(request, 'session', 'session') @@ -795,6 +966,10 @@ def patch_request(self, request): # don't hang if broker is down # https://github.com/celery/celery/issues/4627 @pytest.mark.timeout(5) + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_run_task(self): ma = PeriodicTaskAdmin(PeriodicTask, self.site) self.request = self.patch_request(self.request_factory.get('/')) @@ -806,6 +981,10 @@ def test_run_task(self): # don't hang if broker is down # https://github.com/celery/celery/issues/4627 @pytest.mark.timeout(5) + @override_settings( + ROOT_URLCONF=__name__, + DATABASE_ROUTERS=['%s.Router' % __name__] + ) def test_run_tasks(self): ma = PeriodicTaskAdmin(PeriodicTask, self.site) self.request = self.patch_request(self.request_factory.get('/'))