Skip to content

Commit

Permalink
Specify queue serializer when creating Django RQ queue (#630)
Browse files Browse the repository at this point in the history
* Add serializer when instantiating a queue.

* Add test
  • Loading branch information
sophcass authored Nov 21, 2023
1 parent f57c2fa commit 0272ebc
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 11 deletions.
9 changes: 8 additions & 1 deletion django_rq/queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def get_queue(
connection=None,
queue_class=None,
job_class=None,
serializer=None,
**kwargs
):
"""
Expand All @@ -176,6 +177,8 @@ def get_queue(
default_timeout = QUEUES[name].get('DEFAULT_TIMEOUT')
if connection is None:
connection = get_connection(name)
if serializer is None:
serializer = QUEUES[name].get('SERIALIZER')
queue_class = get_queue_class(QUEUES[name], queue_class)
return queue_class(
name,
Expand All @@ -184,6 +187,7 @@ def get_queue(
is_async=is_async,
job_class=job_class,
autocommit=autocommit,
serializer=serializer,
**kwargs
)

Expand All @@ -196,7 +200,10 @@ def get_queue_by_index(index):

config = QUEUES_LIST[int(index)]
return get_queue_class(config)(
config['name'], connection=get_redis_connection(config['connection_config']), is_async=config.get('ASYNC', True)
config['name'],
connection=get_redis_connection(config['connection_config']),
is_async=config.get('ASYNC', True),
serializer=config['connection_config'].get('SERIALIZER')
)

def get_scheduler_by_index(index):
Expand Down
6 changes: 6 additions & 0 deletions django_rq/tests/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,12 @@
'DB': 0,
'DEFAULT_TIMEOUT': 400,
},
'test_serializer': {
'HOST': REDIS_HOST,
'PORT': 6379,
'DB': 0,
'SERIALIZER': 'rq.serializers.JSONSerializer',
},
}
RQ = {
'AUTOCOMMIT': False,
Expand Down
9 changes: 9 additions & 0 deletions django_rq/tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from redis.exceptions import ConnectionError
from rq import get_current_job, Queue
import rq
from rq.exceptions import NoSuchJobError
from rq.job import Job
from rq.registry import FinishedJobRegistry, ScheduledJobRegistry
Expand Down Expand Up @@ -457,6 +458,14 @@ def test_default_timeout(self):
queue = get_queue('test1')
self.assertEqual(queue._default_timeout, 400)

def test_get_queue_serializer(self):
"""
Test that the correct serializer is set on the queue.
"""
queue = get_queue('test_serializer')
self.assertEqual(queue.name, 'test_serializer')
self.assertEqual(queue.serializer, rq.serializers.JSONSerializer)


@override_settings(RQ={'AUTOCOMMIT': True})
class DecoratorTest(TestCase):
Expand Down
2 changes: 1 addition & 1 deletion django_rq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def get_jobs(queue, job_ids, registry=None):
1. If job data is not present in Redis, discard the result
2. If `registry` argument is supplied, delete empty jobs from registry
"""
jobs = Job.fetch_many(job_ids, connection=queue.connection)
jobs = Job.fetch_many(job_ids, connection=queue.connection, serializer=queue.serializer)
valid_jobs = []
for i, job in enumerate(jobs):
if job is None:
Expand Down
18 changes: 9 additions & 9 deletions django_rq/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def deferred_jobs(request, queue_index):

for job_id in job_ids:
try:
jobs.append(Job.fetch(job_id, connection=queue.connection))
jobs.append(Job.fetch(job_id, connection=queue.connection, serializer=queue.serializer))
except NoSuchJobError:
pass

Expand All @@ -316,7 +316,7 @@ def job_detail(request, queue_index, job_id):
queue = get_queue_by_index(queue_index)

try:
job = Job.fetch(job_id, connection=queue.connection)
job = Job.fetch(job_id, connection=queue.connection, serializer=queue.serializer)
except NoSuchJobError:
raise Http404("Couldn't find job with this ID: %s" % job_id)

Expand Down Expand Up @@ -353,7 +353,7 @@ def job_detail(request, queue_index, job_id):
def delete_job(request, queue_index, job_id):
queue_index = int(queue_index)
queue = get_queue_by_index(queue_index)
job = Job.fetch(job_id, connection=queue.connection)
job = Job.fetch(job_id, connection=queue.connection, serializer=queue.serializer)

if request.method == 'POST':
# Remove job id from queue and delete the actual job
Expand All @@ -376,10 +376,10 @@ def delete_job(request, queue_index, job_id):
def requeue_job_view(request, queue_index, job_id):
queue_index = int(queue_index)
queue = get_queue_by_index(queue_index)
job = Job.fetch(job_id, connection=queue.connection)
job = Job.fetch(job_id, connection=queue.connection, serializer=queue.serializer)

if request.method == 'POST':
requeue_job(job_id, connection=queue.connection)
requeue_job(job_id, connection=queue.connection, serializer=queue.serializer)
messages.info(request, 'You have successfully requeued %s' % job.id)
return redirect('rq_job_detail', queue_index, job_id)

Expand Down Expand Up @@ -433,7 +433,7 @@ def requeue_all(request, queue_index):
# Confirmation received
for job_id in job_ids:
try:
requeue_job(job_id, connection=queue.connection)
requeue_job(job_id, connection=queue.connection, serializer=queue.serializer)
count += 1
except NoSuchJobError:
pass
Expand Down Expand Up @@ -488,14 +488,14 @@ def actions(request, queue_index):

if request.POST['action'] == 'delete':
for job_id in job_ids:
job = Job.fetch(job_id, connection=queue.connection)
job = Job.fetch(job_id, connection=queue.connection, serializer=queue.serializer)
# Remove job id from queue and delete the actual job
queue.connection.lrem(queue.key, 0, job.id)
job.delete()
messages.info(request, 'You have successfully deleted %s jobs!' % len(job_ids))
elif request.POST['action'] == 'requeue':
for job_id in job_ids:
requeue_job(job_id, connection=queue.connection)
requeue_job(job_id, connection=queue.connection, serializer=queue.serializer)
messages.info(request, 'You have successfully requeued %d jobs!' % len(job_ids))
elif request.POST['action'] == 'stop':
stopped, failed_to_stop = stop_jobs(queue, job_ids)
Expand All @@ -513,7 +513,7 @@ def enqueue_job(request, queue_index, job_id):
"""Enqueue deferred jobs"""
queue_index = int(queue_index)
queue = get_queue_by_index(queue_index)
job = Job.fetch(job_id, connection=queue.connection)
job = Job.fetch(job_id, connection=queue.connection, serializer=queue.serializer)

if request.method == 'POST':
try:
Expand Down

0 comments on commit 0272ebc

Please sign in to comment.