Skip to content

Commit

Permalink
Merge github.com:nlundquist/django-celery-transactions into django-1.6
Browse files Browse the repository at this point in the history
Conflicts:
	.gitignore
	djcelery_transactions/__init__.py
	djcelery_transactions/transaction_signals.py
  • Loading branch information
nicolasgrasset committed Oct 31, 2014
2 parents 33868fc + abb2330 commit dadb14b
Show file tree
Hide file tree
Showing 6 changed files with 253 additions and 88 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ build
dist
*.egg-info
*.pyc
.idea
*~
89 changes: 64 additions & 25 deletions djcelery_transactions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
# coding=utf-8
from celery.task import task as base_task, Task
import djcelery_transactions.transaction_signals
from django.db import transaction
from functools import partial
import threading
from celery import current_app

from celery import task as base_task, current_app, Task
from celery.contrib.batches import Batches
import django
from django.db import transaction

if django.VERSION >= (1,6):
from django.db.transaction import get_connection

import djcelery_transactions.transaction_signals

# Thread-local data (task queue).
_thread_data = threading.local()
Expand Down Expand Up @@ -39,30 +45,61 @@ def example(pk):

abstract = True

@classmethod
def original_apply_async(cls, *args, **kwargs):
def original_apply_async(self, *args, **kwargs):
"""Shortcut method to reach real implementation
of celery.Task.apply_sync
"""
return super(PostTransactionTask, self).apply_async(*args, **kwargs)

def apply_async(self, *args, **kwargs):
# Delay the task unless the client requested otherwise or transactions
# aren't being managed (i.e. the signal handlers won't send the task).
connection = get_connection()
if connection.in_atomic_block and not getattr(current_app.conf, 'CELERY_ALWAYS_EAGER', False):
_get_task_queue().append((self, args, kwargs))
else:
return self.original_apply_async(*args, **kwargs)


class PostTransactionBatches(Batches):
"""A batch of tasks whose queuing is delayed until after the current
transaction.
"""

abstract = True

def original_apply_async(self, *args, **kwargs):
"""Shortcut method to reach real implementation
of celery.Task.apply_sync
"""
return super(PostTransactionTask, cls).apply_async(*args, **kwargs)
return super(PostTransactionBatches, self).apply_async(*args, **kwargs)

@classmethod
def apply_async(cls, *args, **kwargs):
def apply_async(self, *args, **kwargs):
# Delay the task unless the client requested otherwise or transactions
# aren't being managed (i.e. the signal handlers won't send the task).

if transaction.is_managed() and not current_app.conf.CELERY_ALWAYS_EAGER:
if not transaction.is_dirty():
# Always mark the transaction as dirty
# because we push task in queue that must be fired or discarded
if 'using' in kwargs:
transaction.set_dirty(using=kwargs['using'])
else:
transaction.set_dirty()
_get_task_queue().append((cls, args, kwargs))
if django.VERSION < (1, 6):

if transaction.is_managed() and not current_app.conf.CELERY_ALWAYS_EAGER:
if not transaction.is_dirty():
# Always mark the transaction as dirty
# because we push task in queue that must be fired or discarded
if 'using' in kwargs:
transaction.set_dirty(using=kwargs['using'])
else:
transaction.set_dirty()
_get_task_queue().append((self, args, kwargs))
else:
apply_async_orig = super(PostTransactionTask, self).apply_async
return apply_async_orig(*args, **kwargs)

else:
apply_async_orig = super(PostTransactionTask, cls).apply_async
return apply_async_orig(*args, **kwargs)

connection = get_connection()
if connection.in_atomic_block and not getattr(current_app.conf, 'CELERY_ALWAYS_EAGER', False):
_get_task_queue().append((self, args, kwargs))
else:
return self.original_apply_async(*args, **kwargs)


def _discard_tasks(**kwargs):
Expand All @@ -80,11 +117,14 @@ def _send_tasks(**kwargs):
"""
queue = _get_task_queue()
while queue:
cls, args, kwargs = queue.pop(0)
apply_async_orig = cls.original_apply_async
if current_app.conf.CELERY_ALWAYS_EAGER:
apply_async_orig = transaction.autocommit()(apply_async_orig)
apply_async_orig(*args, **kwargs)
tsk, args, kwargs = queue.pop(0)
if django.VERSION < (1, 6):
apply_async_orig = tsk.original_apply_async
if current_app.conf.CELERY_ALWAYS_EAGER:
apply_async_orig = transaction.autocommit()(apply_async_orig)
apply_async_orig(*args, **kwargs)
else:
tsk.original_apply_async(*args, **kwargs)


# A replacement decorator.
Expand All @@ -93,4 +133,3 @@ def _send_tasks(**kwargs):
# Hook the signal handlers up.
transaction.signals.post_commit.connect(_send_tasks)
transaction.signals.post_rollback.connect(_discard_tasks)
transaction.signals.post_transaction_management.connect(_send_tasks)
225 changes: 166 additions & 59 deletions djcelery_transactions/transaction_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,22 @@ def _post_commit(**kwargs):
"""
from functools import partial
import thread
import django

from django.db import connections, DEFAULT_DB_ALIAS, DatabaseError

from django.dispatch import Signal

if django.VERSION >= (1,6):
from django.db import ProgrammingError
from django.db.transaction import get_connection

from django.db import transaction
try:
# Prior versions of Django 1.3
from django.db.transaction import state
except ImportError:
state = None
from django.dispatch import Signal


class TransactionSignals(object):
Expand All @@ -45,78 +53,177 @@ class TransactionSignals(object):
def __init__(self):
self.post_commit = Signal()
self.post_rollback = Signal()
self.post_transaction_management = Signal()

if django.VERSION < (1,6):
self.post_transaction_management = Signal()

# Access as django.db.transaction.signals.
transaction.signals = TransactionSignals()

if django.VERSION < (1,6):
def commit(old_function, *args, **kwargs):
# This will raise an exception if the commit fails. django.db.transaction
# decorators catch this and call rollback(), but the middleware doesn't.
old_function(*args, **kwargs)
transaction.signals.post_commit.send(None)

def commit(old_function, *args, **kwargs):
# This will raise an exception if the commit fails. django.db.transaction
# decorators catch this and call rollback(), but the middleware doesn't.
old_function(*args, **kwargs)
transaction.signals.post_commit.send(None)
def commit_unless_managed(old_function, *args, **kwargs):
old_function(*args, **kwargs)
if not transaction.is_managed():
transaction.signals.post_commit.send(None)


def commit_unless_managed(old_function, *args, **kwargs):
old_function(*args, **kwargs)
if not transaction.is_managed():
transaction.signals.post_commit.send(None)
# commit() isn't called at the end of a transaction management block if there
# were no changes. This function is always called so the signal is always sent.
def leave_transaction_management(old_function, *args, **kwargs):
# If the transaction is dirty, it is rolled back and an exception is
# raised. We need to send the rollback signal before that happens.
if transaction.is_dirty():
transaction.signals.post_rollback.send(None)

old_function(*args, **kwargs)
transaction.signals.post_transaction_management.send(None)

# commit() isn't called at the end of a transaction management block if there
# were no changes. This function is always called so the signal is always sent.
def leave_transaction_management(old_function, *args, **kwargs):
# If the transaction is dirty, it is rolled back and an exception is
# raised. We need to send the rollback signal before that happens.
if transaction.is_dirty():
transaction.signals.post_rollback.send(None)

old_function(*args, **kwargs)
transaction.signals.post_transaction_management.send(None)


def managed(old_function, *args, **kwargs):
# Turning transaction management off causes the current transaction to be
# committed if it's dirty. We must send the signal after the actual commit.
flag = kwargs.get('flag', args[0] if args else None)
if state is not None:
using = kwargs.get('using', args[1] if len(args) > 1 else None)
# Do not commit too early for prior versions of Django 1.3
thread_ident = thread.get_ident()
top = state.get(thread_ident, {}).get(using, None)
commit = top and not flag and transaction.is_dirty()
else:
commit = not flag and transaction.is_dirty()
old_function(*args, **kwargs)

if commit:
transaction.signals.post_commit.send(None)

def managed(old_function, *args, **kwargs):
# Turning transaction management off causes the current transaction to be
# committed if it's dirty. We must send the signal after the actual commit.
flag = kwargs.get('flag', args[0] if args else None)
if state is not None:
using = kwargs.get('using', args[1] if len(args) > 1 else None)
# Do not commit too early for prior versions of Django 1.3
thread_ident = thread.get_ident()
top = state.get(thread_ident, {}).get(using, None)
commit = top and not flag and transaction.is_dirty()
else:
commit = not flag and transaction.is_dirty()
old_function(*args, **kwargs)

def rollback(old_function, *args, **kwargs):
old_function(*args, **kwargs)
transaction.signals.post_rollback.send(None)
if commit:
transaction.signals.post_commit.send(None)


def rollback_unless_managed(old_function, *args, **kwargs):
old_function(*args, **kwargs)
if not transaction.is_managed():
def rollback(old_function, *args, **kwargs):
old_function(*args, **kwargs)
transaction.signals.post_rollback.send(None)


# Duck punching!
functions = (
commit,
commit_unless_managed,
leave_transaction_management,
managed,
rollback,
rollback_unless_managed,
)

for function in functions:
name = function.__name__
function = partial(function, getattr(transaction, name))
setattr(transaction, name, function)
def rollback_unless_managed(old_function, *args, **kwargs):
old_function(*args, **kwargs)
if not transaction.is_managed():
transaction.signals.post_rollback.send(None)


# Duck punching!
functions = (
commit,
commit_unless_managed,
leave_transaction_management,
managed,
rollback,
rollback_unless_managed,
)

for function in functions:
name = function.__name__
function = partial(function, getattr(transaction, name))
setattr(transaction, name, function)


else:

__original__exit__ = transaction.Atomic.__exit__

def __patched__exit__(self, exc_type, exc_value, traceback):
connection = get_connection(self.using)

if connection.savepoint_ids:
sid = connection.savepoint_ids.pop()
else:
# Prematurely unset this flag to allow using commit or rollback.
connection.in_atomic_block = False

try:
if connection.closed_in_transaction:
# The database will perform a rollback by itself.
# Wait until we exit the outermost block.
pass

elif exc_type is None and not connection.needs_rollback:
if connection.in_atomic_block:
# Release savepoint if there is one
if sid is not None:
try:
connection.savepoint_commit(sid)
transaction.signals.post_commit.send(None)
except DatabaseError:
try:
connection.savepoint_rollback(sid)
transaction.signals.post_rollback.send(None)
except Error:
# If rolling back to a savepoint fails, mark for
# rollback at a higher level and avoid shadowing
# the original exception.
connection.needs_rollback = True
raise
else:
# Commit transaction
try:
connection.commit()
transaction.signals.post_commit.send(None)
except DatabaseError:
try:
connection.rollback()
transaction.signals.post_rollback.send(None)
except Error:
# An error during rollback means that something
# went wrong with the connection. Drop it.
connection.close()
raise
else:
# This flag will be set to True again if there isn't a savepoint
# allowing to perform the rollback at this level.
connection.needs_rollback = False
if connection.in_atomic_block:
# Roll back to savepoint if there is one, mark for rollback
# otherwise.
if sid is None:
connection.needs_rollback = True
else:
try:
connection.savepoint_rollback(sid)
transaction.signals.post_rollback.send(None)
except Error:
# If rolling back to a savepoint fails, mark for
# rollback at a higher level and avoid shadowing
# the original exception.
connection.needs_rollback = True
else:
# Roll back transaction
try:
connection.rollback()
transaction.signals.post_rollback.send(None)
except Error:
# An error during rollback means that something
# went wrong with the connection. Drop it.
connection.close()

finally:
# Outermost block exit when autocommit was enabled.
if not connection.in_atomic_block:
if connection.closed_in_transaction:
connection.connection = None
elif connection.features.autocommits_when_autocommit_is_off:
connection.autocommit = True
else:
connection.set_autocommit(True)
# Outermost block exit when autocommit was disabled.
elif not connection.savepoint_ids and not connection.commit_on_exit:
if connection.closed_in_transaction:
connection.connection = None
else:
connection.in_atomic_block = False


# Monkey patch that shit
transaction.Atomic.__exit__ = __patched__exit__
Loading

0 comments on commit dadb14b

Please sign in to comment.