From 4e450f34d5074c9bec1459b2da6659f9f6450e07 Mon Sep 17 00:00:00 2001 From: Nic Pottier Date: Wed, 24 Apr 2013 14:41:35 +0200 Subject: [PATCH 1/7] do not queue celeryt asks if always eager is on --- .gitignore | 1 + djcelery_transactions/__init__.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 77e4dbf..6de0223 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ build dist *.egg-info *.pyc +*~ diff --git a/djcelery_transactions/__init__.py b/djcelery_transactions/__init__.py index 40231b0..661f1a6 100644 --- a/djcelery_transactions/__init__.py +++ b/djcelery_transactions/__init__.py @@ -1,11 +1,11 @@ # coding=utf-8 from celery.task import task as base_task, Task +from celery import current_app import djcelery_transactions.transaction_signals from django.db import transaction from functools import partial import threading - # Thread-local data (task queue). _thread_data = threading.local() @@ -50,7 +50,7 @@ def original_apply_async(cls, *args, **kwargs): def apply_async(cls, *args, **kwargs): # Delay the task unless the client requested otherwise or transactions # aren't being managed (i.e. the signal handlers won't send the task). - if transaction.is_managed(): + if transaction.is_managed() and not getattr(current_app.conf, 'CELERY_ALWAYS_EAGER', False): if not transaction.is_dirty(): # Always mark the transaction as dirty # because we push task in queue that must be fired or discarded From 1c620fe42155ff789f2fcdabbee8dcb34e715250 Mon Sep 17 00:00:00 2001 From: Nic Pottier Date: Wed, 24 Apr 2013 19:53:55 +0200 Subject: [PATCH 2/7] deal with case when there are no args --- djcelery_transactions/transaction_signals.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/djcelery_transactions/transaction_signals.py b/djcelery_transactions/transaction_signals.py index d9b1c67..918bca4 100644 --- a/djcelery_transactions/transaction_signals.py +++ b/djcelery_transactions/transaction_signals.py @@ -80,7 +80,7 @@ def leave_transaction_management(old_function, *args, **kwargs): def managed(old_function, *args, **kwargs): # Turning transaction management off causes the current transaction to be # committed if it's dirty. We must send the signal after the actual commit. - flag = kwargs.get('flag', args[0]) + flag = kwargs.get('flag', len(args) > 0 and args[0]) if state is not None: using = kwargs.get('using', args[1] if len(args) > 1 else None) # Do not commit too early for prior versions of Django 1.3 From 9555eba956e64db7754031f30c96d432de50d14c Mon Sep 17 00:00:00 2001 From: Nic Pottier Date: Tue, 11 Mar 2014 13:25:51 +0200 Subject: [PATCH 3/7] modifications for django 1.6, serious overhaul --- djcelery_transactions/__init__.py | 10 +- djcelery_transactions/transaction_signals.py | 136 +++++++++---------- 2 files changed, 63 insertions(+), 83 deletions(-) diff --git a/djcelery_transactions/__init__.py b/djcelery_transactions/__init__.py index 661f1a6..e9bd879 100644 --- a/djcelery_transactions/__init__.py +++ b/djcelery_transactions/__init__.py @@ -50,14 +50,7 @@ def original_apply_async(cls, *args, **kwargs): def apply_async(cls, *args, **kwargs): # Delay the task unless the client requested otherwise or transactions # aren't being managed (i.e. the signal handlers won't send the task). - if transaction.is_managed() and not getattr(current_app.conf, 'CELERY_ALWAYS_EAGER', False): - if not transaction.is_dirty(): - # Always mark the transaction as dirty - # because we push task in queue that must be fired or discarded - if 'using' in kwargs: - transaction.set_dirty(using=kwargs['using']) - else: - transaction.set_dirty() + if not getattr(current_app.conf, 'CELERY_ALWAYS_EAGER', False): _get_task_queue().append((cls, args, kwargs)) else: return cls.original_apply_async(*args, **kwargs) @@ -88,4 +81,3 @@ def _send_tasks(**kwargs): # Hook the signal handlers up. transaction.signals.post_commit.connect(_send_tasks) transaction.signals.post_rollback.connect(_discard_tasks) -transaction.signals.post_transaction_management.connect(_send_tasks) diff --git a/djcelery_transactions/transaction_signals.py b/djcelery_transactions/transaction_signals.py index 918bca4..3d8df19 100644 --- a/djcelery_transactions/transaction_signals.py +++ b/djcelery_transactions/transaction_signals.py @@ -30,12 +30,8 @@ def _post_commit(**kwargs): from functools import partial import thread -from django.db import transaction -try: - # Prior versions of Django 1.3 - from django.db.transaction import state -except ImportError: - state = None +from django.db import transaction, connections, DEFAULT_DB_ALIAS, DatabaseError, ProgrammingError +from django.db.transaction import get_connection from django.dispatch import Signal @@ -45,78 +41,70 @@ class TransactionSignals(object): def __init__(self): self.post_commit = Signal() self.post_rollback = Signal() - self.post_transaction_management = Signal() # Access as django.db.transaction.signals. transaction.signals = TransactionSignals() +__original__exit__ = transaction.Atomic.__exit__ +def __patched__exit__(self, exc_type, exc_value, trackback): + connection = get_connection(self.using) -def commit(old_function, *args, **kwargs): - # This will raise an exception if the commit fails. django.db.transaction - # decorators catch this and call rollback(), but the middleware doesn't. - old_function(*args, **kwargs) - transaction.signals.post_commit.send(None) - - -def commit_unless_managed(old_function, *args, **kwargs): - old_function(*args, **kwargs) - if not transaction.is_managed(): - transaction.signals.post_commit.send(None) - - -# commit() isn't called at the end of a transaction management block if there -# were no changes. This function is always called so the signal is always sent. -def leave_transaction_management(old_function, *args, **kwargs): - # If the transaction is dirty, it is rolled back and an exception is - # raised. We need to send the rollback signal before that happens. - if transaction.is_dirty(): - transaction.signals.post_rollback.send(None) - - old_function(*args, **kwargs) - transaction.signals.post_transaction_management.send(None) - - -def managed(old_function, *args, **kwargs): - # Turning transaction management off causes the current transaction to be - # committed if it's dirty. We must send the signal after the actual commit. - flag = kwargs.get('flag', len(args) > 0 and args[0]) - if state is not None: - using = kwargs.get('using', args[1] if len(args) > 1 else None) - # Do not commit too early for prior versions of Django 1.3 - thread_ident = thread.get_ident() - top = state.get(thread_ident, {}).get(using, None) - commit = top and not flag and transaction.is_dirty() + if connection.savepoint_ids: + sid = connection.savepoint_ids.pop() else: - commit = not flag and transaction.is_dirty() - old_function(*args, **kwargs) - - if commit: - transaction.signals.post_commit.send(None) - - -def rollback(old_function, *args, **kwargs): - old_function(*args, **kwargs) - transaction.signals.post_rollback.send(None) - - -def rollback_unless_managed(old_function, *args, **kwargs): - old_function(*args, **kwargs) - if not transaction.is_managed(): - transaction.signals.post_rollback.send(None) - - -# Duck punching! -functions = ( - commit, - commit_unless_managed, - leave_transaction_management, - managed, - rollback, - rollback_unless_managed, -) - -for function in functions: - name = function.__name__ - function = partial(function, getattr(transaction, name)) - setattr(transaction, name, function) + # Prematurely unset this flag to allow using commit or rollback. + connection.in_atomic_block = False + + try: + if exc_type is None and not connection.needs_rollback: + if connection.in_atomic_block: + # Release savepoint if there is one + if sid is not None: + try: + connection.savepoint_commit(sid) + transaction.signals.post_commit.send(None) + except DatabaseError: + connection.savepoint_rollback(sid) + transaction.signals.post_rollback.send(None) + raise + else: + # Commit transaction + try: + connection.commit() + transaction.signals.post_commit.send(None) + except DatabaseError: + connection.rollback() + transaction.signals.post_rollback.send(None) + raise + else: + # This flag will be set to True again if there isn't a savepoint + # allowing to perform the rollback at this level. + connection.needs_rollback = False + if connection.in_atomic_block: + # Roll back to savepoint if there is one, mark for rollback + # otherwise. + if sid is None: + connection.needs_rollback = True + else: + connection.savepoint_rollback(sid) + transaction.signals.post_rollback.send(None) + else: + # Roll back transaction + connection.rollback() + transaction.signals.post_rollback.send(None) + + finally: + # Outermost block exit when autocommit was enabled. + if not connection.in_atomic_block: + if connection.features.autocommits_when_autocommit_is_off: + connection.autocommit = True + else: + connection.set_autocommit(True) + # Outermost block exit when autocommit was disabled. + elif not connection.savepoint_ids and not connection.commit_on_exit: + connection.in_atomic_block = False + + +# Monkey patch that shit +transaction.Atomic.__exit__ = __patched__exit__ From e2137714dc66b7f95c8d69b2c38a7a6c0dadca97 Mon Sep 17 00:00:00 2001 From: Nic Pottier Date: Tue, 11 Mar 2014 13:52:53 +0200 Subject: [PATCH 4/7] only queue tasks for later if we are in an atomic block --- djcelery_transactions/__init__.py | 5 ++++- djcelery_transactions/transaction_signals.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/djcelery_transactions/__init__.py b/djcelery_transactions/__init__.py index e9bd879..bd47d87 100644 --- a/djcelery_transactions/__init__.py +++ b/djcelery_transactions/__init__.py @@ -6,6 +6,8 @@ from functools import partial import threading +from django.db.transaction import get_connection + # Thread-local data (task queue). _thread_data = threading.local() @@ -50,7 +52,8 @@ def original_apply_async(cls, *args, **kwargs): def apply_async(cls, *args, **kwargs): # Delay the task unless the client requested otherwise or transactions # aren't being managed (i.e. the signal handlers won't send the task). - if not getattr(current_app.conf, 'CELERY_ALWAYS_EAGER', False): + connection = get_connection() + if connection.in_atomic_block and not getattr(current_app.conf, 'CELERY_ALWAYS_EAGER', False): _get_task_queue().append((cls, args, kwargs)) else: return cls.original_apply_async(*args, **kwargs) diff --git a/djcelery_transactions/transaction_signals.py b/djcelery_transactions/transaction_signals.py index 3d8df19..f237e5b 100644 --- a/djcelery_transactions/transaction_signals.py +++ b/djcelery_transactions/transaction_signals.py @@ -101,7 +101,7 @@ def __patched__exit__(self, exc_type, exc_value, trackback): connection.autocommit = True else: connection.set_autocommit(True) - # Outermost block exit when autocommit was disabled. + # Outermost block exit when autocommit was disabled. elif not connection.savepoint_ids and not connection.commit_on_exit: connection.in_atomic_block = False From f7adf7f2d5ce0bcd7992e0d4f3624fb96a7d80b8 Mon Sep 17 00:00:00 2001 From: Nic Pottier Date: Tue, 17 Jun 2014 15:47:40 +0200 Subject: [PATCH 5/7] bring django celery transactions to 1.6.5 version --- djcelery_transactions/transaction_signals.py | 58 +++++++++++++++----- 1 file changed, 45 insertions(+), 13 deletions(-) diff --git a/djcelery_transactions/transaction_signals.py b/djcelery_transactions/transaction_signals.py index f237e5b..d1ae2de 100644 --- a/djcelery_transactions/transaction_signals.py +++ b/djcelery_transactions/transaction_signals.py @@ -47,7 +47,7 @@ def __init__(self): transaction.signals = TransactionSignals() __original__exit__ = transaction.Atomic.__exit__ -def __patched__exit__(self, exc_type, exc_value, trackback): +def __patched__exit__(self, exc_type, exc_value, traceback): connection = get_connection(self.using) if connection.savepoint_ids: @@ -57,7 +57,12 @@ def __patched__exit__(self, exc_type, exc_value, trackback): connection.in_atomic_block = False try: - if exc_type is None and not connection.needs_rollback: + if connection.closed_in_transaction: + # The database will perform a rollback by itself. + # Wait until we exit the outermost block. + pass + + elif exc_type is None and not connection.needs_rollback: if connection.in_atomic_block: # Release savepoint if there is one if sid is not None: @@ -65,8 +70,14 @@ def __patched__exit__(self, exc_type, exc_value, trackback): connection.savepoint_commit(sid) transaction.signals.post_commit.send(None) except DatabaseError: - connection.savepoint_rollback(sid) - transaction.signals.post_rollback.send(None) + try: + connection.savepoint_rollback(sid) + transaction.signals.post_rollback.send(None) + except Error: + # If rolling back to a savepoint fails, mark for + # rollback at a higher level and avoid shadowing + # the original exception. + connection.needs_rollback = True raise else: # Commit transaction @@ -74,8 +85,13 @@ def __patched__exit__(self, exc_type, exc_value, trackback): connection.commit() transaction.signals.post_commit.send(None) except DatabaseError: - connection.rollback() - transaction.signals.post_rollback.send(None) + try: + connection.rollback() + transaction.signals.post_rollback.send(None) + except Error: + # An error during rollback means that something + # went wrong with the connection. Drop it. + connection.close() raise else: # This flag will be set to True again if there isn't a savepoint @@ -87,23 +103,39 @@ def __patched__exit__(self, exc_type, exc_value, trackback): if sid is None: connection.needs_rollback = True else: - connection.savepoint_rollback(sid) - transaction.signals.post_rollback.send(None) + try: + connection.savepoint_rollback(sid) + transaction.signals.post_rollback.send(None) + except Error: + # If rolling back to a savepoint fails, mark for + # rollback at a higher level and avoid shadowing + # the original exception. + connection.needs_rollback = True else: # Roll back transaction - connection.rollback() - transaction.signals.post_rollback.send(None) + try: + connection.rollback() + transaction.signals.post_rollback.send(None) + except Error: + # An error during rollback means that something + # went wrong with the connection. Drop it. + connection.close() finally: # Outermost block exit when autocommit was enabled. if not connection.in_atomic_block: - if connection.features.autocommits_when_autocommit_is_off: + if connection.closed_in_transaction: + connection.connection = None + elif connection.features.autocommits_when_autocommit_is_off: connection.autocommit = True else: connection.set_autocommit(True) - # Outermost block exit when autocommit was disabled. + # Outermost block exit when autocommit was disabled. elif not connection.savepoint_ids and not connection.commit_on_exit: - connection.in_atomic_block = False + if connection.closed_in_transaction: + connection.connection = None + else: + connection.in_atomic_block = False # Monkey patch that shit From 63f68dd5b678d297379335f2e84ba394f39b5e92 Mon Sep 17 00:00:00 2001 From: Jakub Paczkowski Date: Wed, 16 Jul 2014 16:27:07 +0200 Subject: [PATCH 6/7] Celery 3.1 compatibility --- djcelery_transactions/__init__.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/djcelery_transactions/__init__.py b/djcelery_transactions/__init__.py index bd47d87..b8d6916 100644 --- a/djcelery_transactions/__init__.py +++ b/djcelery_transactions/__init__.py @@ -1,13 +1,13 @@ # coding=utf-8 -from celery.task import task as base_task, Task -from celery import current_app -import djcelery_transactions.transaction_signals -from django.db import transaction from functools import partial import threading +from celery import task as base_task, current_app, Task +from django.db import transaction from django.db.transaction import get_connection +import djcelery_transactions.transaction_signals + # Thread-local data (task queue). _thread_data = threading.local() @@ -41,22 +41,20 @@ def example(pk): abstract = True - @classmethod - def original_apply_async(cls, *args, **kwargs): + def original_apply_async(self, *args, **kwargs): """Shortcut method to reach real implementation of celery.Task.apply_sync """ - return super(PostTransactionTask, cls).apply_async(*args, **kwargs) + return super(PostTransactionTask, self).apply_async(*args, **kwargs) - @classmethod - def apply_async(cls, *args, **kwargs): + def apply_async(self, *args, **kwargs): # Delay the task unless the client requested otherwise or transactions # aren't being managed (i.e. the signal handlers won't send the task). connection = get_connection() if connection.in_atomic_block and not getattr(current_app.conf, 'CELERY_ALWAYS_EAGER', False): - _get_task_queue().append((cls, args, kwargs)) + _get_task_queue().append((self, args, kwargs)) else: - return cls.original_apply_async(*args, **kwargs) + return self.original_apply_async(*args, **kwargs) def _discard_tasks(**kwargs): @@ -74,8 +72,8 @@ def _send_tasks(**kwargs): """ queue = _get_task_queue() while queue: - cls, args, kwargs = queue.pop(0) - cls.original_apply_async(*args, **kwargs) + tsk, args, kwargs = queue.pop(0) + tsk.original_apply_async(*args, **kwargs) # A replacement decorator. From abb23304161899856488245de1e9529e06851a1f Mon Sep 17 00:00:00 2001 From: Nils Lundquist Date: Wed, 24 Sep 2014 14:29:22 -0600 Subject: [PATCH 7/7] Adding Batches subclass. --- djcelery_transactions/__init__.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/djcelery_transactions/__init__.py b/djcelery_transactions/__init__.py index b8d6916..21abed9 100644 --- a/djcelery_transactions/__init__.py +++ b/djcelery_transactions/__init__.py @@ -3,6 +3,7 @@ import threading from celery import task as base_task, current_app, Task +from celery.contrib.batches import Batches from django.db import transaction from django.db.transaction import get_connection @@ -57,6 +58,29 @@ def apply_async(self, *args, **kwargs): return self.original_apply_async(*args, **kwargs) +class PostTransactionBatches(Batches): + """A batch of tasks whose queuing is delayed until after the current + transaction. + """ + + abstract = True + + def original_apply_async(self, *args, **kwargs): + """Shortcut method to reach real implementation + of celery.Task.apply_sync + """ + return super(PostTransactionBatches, self).apply_async(*args, **kwargs) + + def apply_async(self, *args, **kwargs): + # Delay the task unless the client requested otherwise or transactions + # aren't being managed (i.e. the signal handlers won't send the task). + connection = get_connection() + if connection.in_atomic_block and not getattr(current_app.conf, 'CELERY_ALWAYS_EAGER', False): + _get_task_queue().append((self, args, kwargs)) + else: + return self.original_apply_async(*args, **kwargs) + + def _discard_tasks(**kwargs): """Discards all delayed Celery tasks.