diff --git a/formtools/utils.py b/formtools/utils.py index 8bfebfb..d7b408e 100644 --- a/formtools/utils.py +++ b/formtools/utils.py @@ -1,8 +1,29 @@ import pickle +from django.db.models import QuerySet from django.utils.crypto import salted_hmac +def sanitise(obj): + if type(obj) == list: + return [sanitise(o) for o in obj] + elif type(obj) == tuple: + return tuple([sanitise(o) for o in obj]) + elif type(obj) == QuerySet: + return [sanitise(o) for o in list(obj)] + try: + od = obj.__dict__ + nd = {'_class': obj.__class__} + for key, val in od.items(): + if not key.startswith('_'): + # ignore Django internal attributes + nd[key] = sanitise(val) + return nd + except Exception: + pass + return obj + + def form_hmac(form): """ Calculates a security hash for the given Form instance. @@ -19,6 +40,7 @@ def form_hmac(form): value = value.strip() data.append((bf.name, value)) - pickled = pickle.dumps(data, pickle.HIGHEST_PROTOCOL) + sanitised_data = sanitise(data) + pickled = pickle.dumps(sanitised_data, pickle.HIGHEST_PROTOCOL) key_salt = 'django.contrib.formtools' return salted_hmac(key_salt, pickled).hexdigest() diff --git a/tests/forms.py b/tests/forms.py index 3542978..65c3cdb 100644 --- a/tests/forms.py +++ b/tests/forms.py @@ -1,4 +1,26 @@ from django import forms +from django.db import models + + +class ManyModel(models.Model): + name = models.CharField(max_length=100) + + class Meta: + app_label = 'formtools' + + +class OtherModel(models.Model): + name = models.CharField(max_length=100) + manymodels = models.ManyToManyField(ManyModel) + + class Meta: + app_label = 'formtools' + + +class OtherModelForm(forms.ModelForm): + class Meta: + model = OtherModel + fields = '__all__' class TestForm(forms.Form): diff --git a/tests/tests.py b/tests/tests.py index a2f0778..eeaf98d 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -8,7 +8,9 @@ from formtools import preview, utils -from .forms import HashTestBlankForm, HashTestForm, TestForm +from .forms import ( + HashTestBlankForm, HashTestForm, ManyModel, OtherModelForm, TestForm, +) success_string = "Done was called!" success_string_encoded = success_string.encode() @@ -191,3 +193,28 @@ def test_empty_permitted(self): hash1 = utils.form_hmac(f1) hash2 = utils.form_hmac(f2) self.assertEqual(hash1, hash2) + + +class PicklingTests(unittest.TestCase): + + def setUp(self): + super(PicklingTests, self).setUp() + ManyModel.objects.create(name="jane") + + def test_queryset_hash(self): + """ + Regression test for #10034: the hash generation function should ignore + leading/trailing whitespace so as to be friendly to broken browsers that + submit it (usually in textareas). + """ + + qs1 = ManyModel.objects.all() + qs2 = ManyModel.objects.all() + + qs1._prefetch_done = True + qs2._prefetch_done = False + f1 = OtherModelForm({'name': 'joe', 'manymodels': qs1}) + f2 = OtherModelForm({'name': 'joe', 'manymodels': qs2}) + hash1 = utils.form_hmac(f1) + hash2 = utils.form_hmac(f2) + self.assertEqual(hash1, hash2)