diff --git a/django/contrib/admin/templates/registration/password_reset_email.html b/django/contrib/admin/templates/registration/password_reset_email.html index de9dc79c525..665ea11bdbb 100644 --- a/django/contrib/admin/templates/registration/password_reset_email.html +++ b/django/contrib/admin/templates/registration/password_reset_email.html @@ -3,7 +3,7 @@ {% trans "Please go to the following page and choose a new password:" %} {% block reset_link %} -{{ protocol }}://{{ domain }}{% url 'django.contrib.auth.views.password_reset_confirm' uidb36=uid token=token %} +{{ protocol }}://{{ domain }}{% url 'django.contrib.auth.views.password_reset_confirm' uidb64=uid token=token %} {% endblock %} {% trans "Your username, in case you've forgotten:" %} {{ user.username }} diff --git a/django/contrib/auth/forms.py b/django/contrib/auth/forms.py index f0ef124b204..3868a57feec 100644 --- a/django/contrib/auth/forms.py +++ b/django/contrib/auth/forms.py @@ -2,7 +2,7 @@ from django.forms.util import flatatt from django.template import loader from django.utils.encoding import smart_str -from django.utils.http import int_to_base36 +from django.utils.http import urlsafe_base64_encode from django.utils.safestring import mark_safe from django.utils.translation import ugettext, ugettext_lazy as _ @@ -230,7 +230,7 @@ def save(self, domain_override=None, 'email': user.email, 'domain': domain, 'site_name': site_name, - 'uid': int_to_base36(user.id), + 'uid': urlsafe_base64_encode(str(user.id)), 'user': user, 'token': token_generator.make_token(user), 'protocol': use_https and 'https' or 'http', diff --git a/django/contrib/auth/management/__init__.py b/django/contrib/auth/management/__init__.py index b516507277a..1fd813dcef2 100644 --- a/django/contrib/auth/management/__init__.py +++ b/django/contrib/auth/management/__init__.py @@ -40,11 +40,10 @@ def create_permissions(app, created_models, verbosity, **kwargs): # Find all the Permissions that have a context_type for a model we're # looking for. We don't need to check for codenames since we already have # a list of the ones we're going to create. + + ctypes_pks = set(ct.pk for ct in ctypes) all_perms = set(auth_app.Permission.objects.filter( - content_type__in=ctypes, - ).values_list( - "content_type", "codename" - )) + content_type__in=ctypes_pks).values_list('content_type', 'codename')[:1000000]) objs = [ auth_app.Permission(codename=codename, name=name, content_type=ctype) diff --git a/django/contrib/auth/models.py b/django/contrib/auth/models.py index 4e1584970cf..3a8f6e22965 100644 --- a/django/contrib/auth/models.py +++ b/django/contrib/auth/models.py @@ -72,8 +72,7 @@ class Meta: verbose_name = _('permission') verbose_name_plural = _('permissions') unique_together = (('content_type', 'codename'),) - ordering = ('content_type__app_label', 'content_type__model', - 'codename') + ordering = ('codename',) def __unicode__(self): return u"%s | %s | %s" % ( diff --git a/django/contrib/auth/tests/auth_backends.py b/django/contrib/auth/tests/auth_backends.py index 0337ef14f37..7b37e92dd58 100644 --- a/django/contrib/auth/tests/auth_backends.py +++ b/django/contrib/auth/tests/auth_backends.py @@ -2,7 +2,9 @@ from django.contrib.auth.models import User, Group, Permission, AnonymousUser from django.contrib.contenttypes.models import ContentType from django.core.exceptions import ImproperlyConfigured +from django.db import connection from django.test import TestCase +from django.utils import unittest class BackendTest(TestCase): @@ -98,6 +100,9 @@ def test_get_all_superuser_permissions(self): user = User.objects.get(username='test2') self.assertEqual(len(user.get_all_permissions()), len(Permission.objects.all())) +BackendTest = unittest.skipIf(not connection.features.supports_joins, + 'Requires JOIN support')(BackendTest) + class TestObj(object): pass @@ -194,6 +199,8 @@ def test_get_group_permissions(self): self.user3.groups.add(group) self.assertEqual(self.user3.get_group_permissions(TestObj()), set(['group_perm'])) +RowlevelBackendTest = unittest.skipIf(not connection.features.supports_joins, + 'Requires JOIN support')(RowlevelBackendTest) class AnonymousUserBackend(SimpleRowlevelBackend): supports_inactive_user = False diff --git a/django/contrib/auth/tests/forms.py b/django/contrib/auth/tests/forms.py index 3c890d4c8a8..591693eb1be 100644 --- a/django/contrib/auth/tests/forms.py +++ b/django/contrib/auth/tests/forms.py @@ -8,9 +8,9 @@ from django.test import TestCase from django.test.utils import override_settings from django.utils.encoding import force_unicode -from django.utils import translation +from django.utils import translation, unittest from django.utils.translation import ugettext as _ - +from django.db import router, connections class UserCreationFormTest(TestCase): @@ -213,6 +213,7 @@ class UserChangeFormTest(TestCase): fixtures = ['authtestdata.json'] + @unittest.skipIf(not connections[router.db_for_read(User)].features.supports_joins, 'Requires JOIN support') def test_username_validity(self): user = User.objects.get(username='testclient') data = {'username': 'not valid'} @@ -236,18 +237,21 @@ class Meta(UserChangeForm.Meta): # Just check we can create it form = MyUserForm({}) + @unittest.skipIf(not connections[router.db_for_read(User)].features.supports_joins, 'Requires JOIN support') def test_bug_17944_empty_password(self): user = User.objects.get(username='empty_password') form = UserChangeForm(instance=user) # Just check that no error is raised. form.as_table() + @unittest.skipIf(not connections[router.db_for_read(User)].features.supports_joins, 'Requires JOIN support') def test_bug_17944_unmanageable_password(self): user = User.objects.get(username='unmanageable_password') form = UserChangeForm(instance=user) # Just check that no error is raised. form.as_table() + @unittest.skipIf(not connections[router.db_for_read(User)].features.supports_joins, 'Requires JOIN support') def test_bug_17944_unknown_password_algorithm(self): user = User.objects.get(username='unknown_password') form = UserChangeForm(instance=user) diff --git a/django/contrib/auth/tests/models.py b/django/contrib/auth/tests/models.py index 5d0e4f1443c..2dff22b8c98 100644 --- a/django/contrib/auth/tests/models.py +++ b/django/contrib/auth/tests/models.py @@ -3,7 +3,8 @@ from django.test.utils import override_settings from django.contrib.auth.models import (Group, User, SiteProfileNotAvailable, UserManager) - +from django.db import router, connections +from django.utils import unittest class ProfileTestCase(TestCase): fixtures = ['authtestdata.json'] @@ -59,6 +60,7 @@ def test_group_natural_key(self): class LoadDataWithoutNaturalKeysTestCase(TestCase): fixtures = ['regular.json'] + @unittest.skipIf(not connections[router.db_for_read(User)].features.supports_joins, 'Requires JOIN support') def test_user_is_created_and_added_to_group(self): user = User.objects.get(username='my_username') group = Group.objects.get(name='my_group') @@ -70,6 +72,7 @@ def test_user_is_created_and_added_to_group(self): class LoadDataWithNaturalKeysTestCase(TestCase): fixtures = ['natural.json'] + @unittest.skipIf(not connections[router.db_for_read(User)].features.supports_joins, 'Requires JOIN support') def test_user_is_created_and_added_to_group(self): user = User.objects.get(username='my_username') group = Group.objects.get(name='my_group') diff --git a/django/contrib/auth/tests/templates/registration/password_reset_email.html b/django/contrib/auth/tests/templates/registration/password_reset_email.html index 1b9a48255a2..9fa039ee143 100644 --- a/django/contrib/auth/tests/templates/registration/password_reset_email.html +++ b/django/contrib/auth/tests/templates/registration/password_reset_email.html @@ -1 +1 @@ -{{ protocol }}://{{ domain }}/reset/{{ uid }}-{{ token }}/ \ No newline at end of file +{{ protocol }}://{{ domain }}/reset/{{ uid }}/{{ token }}/ \ No newline at end of file diff --git a/django/contrib/auth/tests/views.py b/django/contrib/auth/tests/views.py index d295bb8c108..ba81fb0d59e 100644 --- a/django/contrib/auth/tests/views.py +++ b/django/contrib/auth/tests/views.py @@ -70,7 +70,7 @@ def test_named_urls(self): ('password_reset', [], {}), ('password_reset_done', [], {}), ('password_reset_confirm', [], { - 'uidb36': 'aaaaaaa', + 'uidb64': 'aaaaaaa', 'token': '1111-aaaaa', }), ('password_reset_complete', [], {}), @@ -178,13 +178,13 @@ def test_confirm_invalid(self): def test_confirm_invalid_user(self): # Ensure that we get a 200 response for a non-existant user, not a 404 - response = self.client.get('/reset/123456-1-1/') + response = self.client.get('/reset/123456/1-1/') self.assertEqual(response.status_code, 200) self.assertTrue("The password reset link was invalid" in response.content) def test_confirm_overflow_user(self): # Ensure that we get a 200 response for a base36 user id that overflows int - response = self.client.get('/reset/zzzzzzzzzzzzz-1-1/') + response = self.client.get('/reset/zzzzzzzzzzzzz/1-1/') self.assertEqual(response.status_code, 200) self.assertTrue("The password reset link was invalid" in response.content) diff --git a/django/contrib/auth/urls.py b/django/contrib/auth/urls.py index c5e87ed2ebd..85b8b2869d3 100644 --- a/django/contrib/auth/urls.py +++ b/django/contrib/auth/urls.py @@ -12,7 +12,7 @@ url(r'^password_change/done/$', 'django.contrib.auth.views.password_change_done', name='password_change_done'), url(r'^password_reset/$', 'django.contrib.auth.views.password_reset', name='password_reset'), url(r'^password_reset/done/$', 'django.contrib.auth.views.password_reset_done', name='password_reset_done'), - url(r'^reset/(?P[0-9A-Za-z]{1,13})-(?P[0-9A-Za-z]{1,13}-[0-9A-Za-z]{1,20})/$', + url(r'^reset/(?P[0-9A-Za-z_\-]+)/(?P[0-9A-Za-z]{1,13}-[0-9A-Za-z]{1,20})/$', 'django.contrib.auth.views.password_reset_confirm', name='password_reset_confirm'), url(r'^reset/done/$', 'django.contrib.auth.views.password_reset_complete', name='password_reset_complete'), diff --git a/django/contrib/auth/views.py b/django/contrib/auth/views.py index 62599b87522..523bc280cf0 100644 --- a/django/contrib/auth/views.py +++ b/django/contrib/auth/views.py @@ -4,7 +4,7 @@ from django.core.urlresolvers import reverse from django.http import HttpResponseRedirect, QueryDict from django.template.response import TemplateResponse -from django.utils.http import base36_to_int, is_safe_url +from django.utils.http import is_safe_url, urlsafe_base64_decode from django.utils.translation import ugettext as _ from django.views.decorators.debug import sensitive_post_parameters from django.views.decorators.cache import never_cache @@ -174,7 +174,7 @@ def password_reset_done(request, # Doesn't need csrf_protect since no-one can guess the URL @sensitive_post_parameters() @never_cache -def password_reset_confirm(request, uidb36=None, token=None, +def password_reset_confirm(request, uidb64=None, token=None, template_name='registration/password_reset_confirm.html', token_generator=default_token_generator, set_password_form=SetPasswordForm, @@ -184,13 +184,13 @@ def password_reset_confirm(request, uidb36=None, token=None, View that checks the hash in a password reset link and presents a form for entering a new password. """ - assert uidb36 is not None and token is not None # checked by URLconf + assert uidb64 is not None and token is not None # checked by URLconf if post_reset_redirect is None: post_reset_redirect = reverse('django.contrib.auth.views.password_reset_complete') try: - uid_int = base36_to_int(uidb36) - user = User.objects.get(id=uid_int) - except (ValueError, User.DoesNotExist): + uid = urlsafe_base64_decode(str(uidb64)) + user = User.objects.get(id=uid) + except (TypeError, ValueError, User.DoesNotExist): user = None if user is not None and token_generator.check_token(user, token): diff --git a/django/contrib/contenttypes/tests.py b/django/contrib/contenttypes/tests.py index 3b7906c8129..fb870b896a7 100644 --- a/django/contrib/contenttypes/tests.py +++ b/django/contrib/contenttypes/tests.py @@ -2,14 +2,14 @@ import urllib -from django.db import models +from django.db import models, router, connections from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.views import shortcut from django.contrib.sites.models import Site from django.http import HttpRequest, Http404 from django.test import TestCase from django.utils.encoding import smart_str - +from django.utils import unittest class FooWithoutUrl(models.Model): """ @@ -114,6 +114,7 @@ def test_get_for_models_full_cache(self): FooWithUrl: ContentType.objects.get_for_model(FooWithUrl), }) + @unittest.skipIf(not connections[router.db_for_read(FooWithUrl)].features.supports_joins, 'Requires JOIN support') def test_shortcut_view(self): """ Check that the shortcut view (used for the admin "view on site" @@ -156,6 +157,7 @@ def test_shortcut_view_without_get_absolute_url(self): self.assertRaises(Http404, shortcut, request, user_ct.id, obj.id) + @unittest.skipIf(not connections[router.db_for_read(FooWithBrokenAbsoluteUrl)].features.supports_joins, 'Requires JOIN support') def test_shortcut_view_with_broken_get_absolute_url(self): """ Check that the shortcut view does not catch an AttributeError raised diff --git a/django/core/files/uploadhandler.py b/django/core/files/uploadhandler.py index 2afb79e0513..5f6f9061bee 100644 --- a/django/core/files/uploadhandler.py +++ b/django/core/files/uploadhandler.py @@ -84,7 +84,8 @@ def handle_raw_input(self, input_data, META, content_length, boundary, encoding= """ pass - def new_file(self, field_name, file_name, content_type, content_length, charset=None): + def new_file(self, field_name, file_name, content_type, content_length, + charset=None, content_type_extra=None): """ Signal that a new file has been started. @@ -96,6 +97,9 @@ def new_file(self, field_name, file_name, content_type, content_length, charset= self.content_type = content_type self.content_length = content_length self.charset = charset + if content_type_extra is None: + content_type_extra = {} + self.content_type_extra = content_type_extra def receive_data_chunk(self, raw_data, start): """ diff --git a/django/core/serializers/python.py b/django/core/serializers/python.py index 195bf11d246..1c88597d472 100644 --- a/django/core/serializers/python.py +++ b/django/core/serializers/python.py @@ -59,9 +59,21 @@ def handle_m2m_field(self, obj, field): if field.rel.through._meta.auto_created: if self.use_natural_keys and hasattr(field.rel.to, 'natural_key'): m2m_value = lambda value: value.natural_key() + self._current[field.name] = [m2m_value(related) + for related in getattr(obj, field.name).iterator()] + elif field.rel.get_related_field().primary_key: + m2m_value = lambda value: smart_unicode( + getattr(value, related_query.target_field_name + '_id'), + strings_only=True) + related_query = getattr(obj, field.name) + filters = {related_query.source_field_name: obj._get_pk_val()} + query = field.rel.through.objects.filter(**filters) + self._current[field.name] = sorted((m2m_value(m2m_entity) + for m2m_entity in query), + reverse=True) else: m2m_value = lambda value: smart_unicode(value._get_pk_val(), strings_only=True) - self._current[field.name] = [m2m_value(related) + self._current[field.name] = [m2m_value(related) for related in getattr(obj, field.name).iterator()] def getvalue(self): diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index 14e9c4abeb4..705e3f46c72 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -340,6 +340,8 @@ class BaseDatabaseFeatures(object): has_select_for_update = False has_select_for_update_nowait = False + distinguishes_insert_from_update = True + supports_joins = True supports_select_related = True # Does the default test database allow multiple connections? @@ -778,6 +780,15 @@ def prep_for_like_query(self, x): # need not necessarily be implemented using "LIKE" in the backend. prep_for_iexact_query = prep_for_like_query + def value_to_db_auto(self, value): + """ + Transform an AutoField value to an object compatible with what is expected + by the backend driver for automatic keys. + """ + if value is None: + return None + return int(value) + def value_to_db_date(self, value): """ Transform a date value to an object compatible with what is expected diff --git a/django/db/models/base.py b/django/db/models/base.py index fc38224345f..ba1ecf46895 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -268,7 +268,8 @@ def __init__(self, db=None): self.db = db # If true, uniqueness validation checks will consider this a new, as-yet-unsaved object. # Necessary for correct validation of new instances of objects with explicit (non-auto) PKs. - # This impacts validation only; it has no effect on the actual save. + # Also used when connection.features.distinguishes_insert_from_update is false to identify + # when an instance has been newly created. self.adding = True class Model(object): @@ -365,6 +366,7 @@ def __init__(self, *args, **kwargs): pass if kwargs: raise TypeError("'%s' is an invalid keyword argument for this function" % kwargs.keys()[0]) + self._original_pk = self.pk if self._meta.pk is not None else None super(Model, self).__init__() signals.post_init.send(sender=self.__class__, instance=self) @@ -473,6 +475,7 @@ def save_base(self, raw=False, cls=None, origin=None, force_insert=False, ('raw', 'cls', and 'origin'). """ using = using or router.db_for_write(self.__class__, instance=self) + entity_exists = bool(not self._state.adding and self._original_pk == self.pk) assert not (force_insert and force_update) if cls is None: cls = self.__class__ @@ -518,7 +521,21 @@ def save_base(self, raw=False, cls=None, origin=None, force_insert=False, pk_set = pk_val is not None record_exists = True manager = cls._base_manager - if pk_set: + connection = connections[using] + + # TODO/NONREL: Some backends could emulate force_insert/_update + # with an optimistic transaction, but since it's costly we should + # only do it when the user explicitly wants it. + # By adding support for an optimistic locking transaction + # in Django (SQL: SELECT ... FOR UPDATE) we could even make that + # part fully reusable on all backends (the current .exists() + # check below isn't really safe if you have lots of concurrent + # requests. BTW, and neither is QuerySet.get_or_create). + try_update = connection.features.distinguishes_insert_from_update + if not try_update: + record_exists = False + + if try_update and pk_set: # Determine whether a record with the primary key already exists. if (force_update or (not force_insert and manager.using(using).filter(pk=pk_val).exists())): @@ -559,10 +576,16 @@ def save_base(self, raw=False, cls=None, origin=None, force_insert=False, # Once saved, this is no longer a to-be-added instance. self._state.adding = False + self._original_pk = self.pk + # Signal that the save is complete if origin and not meta.auto_created: + if connection.features.distinguishes_insert_from_update: + created = not record_exists + else: + created = not entity_exists signals.post_save.send(sender=origin, instance=self, - created=(not record_exists), raw=raw, using=using) + created=created, raw=raw, using=using) save_base.alters_data = True @@ -575,6 +598,9 @@ def delete(self, using=None): collector.collect([self]) collector.delete() + self._state.adding = False + self._original_pk = None + delete.alters_data = True def _get_FIELD_display(self, field): diff --git a/django/db/models/deletion.py b/django/db/models/deletion.py index 730847ef215..d5c36bd5b72 100644 --- a/django/db/models/deletion.py +++ b/django/db/models/deletion.py @@ -171,6 +171,10 @@ def collect(self, objs, source=None, nullable=False, collect_related=True, if related.model._meta.auto_created: self.add_batch(related.model, field, new_objs) else: + # No need to fetch related objects if we are not doing + # anything with them. + if field.rel.on_delete == DO_NOTHING: + continue sub_objs = self.related_objects(related, new_objs) if not sub_objs: continue diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index 527a3c0b0e9..6a4289ed77e 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -503,12 +503,9 @@ def __repr__(self): return '<%s>' % path class AutoField(Field): - description = _("Integer") + description = _("Automatic key") empty_strings_allowed = False - default_error_messages = { - 'invalid': _(u"'%s' value must be an integer."), - } def __init__(self, *args, **kwargs): assert kwargs.get('primary_key', False) is True, \ @@ -519,22 +516,13 @@ def __init__(self, *args, **kwargs): def get_internal_type(self): return "AutoField" - def to_python(self, value): - if value is None: - return value - try: - return int(value) - except (TypeError, ValueError): - msg = self.error_messages['invalid'] % str(value) - raise exceptions.ValidationError(msg) - def validate(self, value, model_instance): pass - def get_prep_value(self, value): + def get_db_prep_value(self, value, connection, prepared=False): if value is None: - return None - return int(value) + return value + return connection.ops.value_to_db_auto(value) def contribute_to_class(self, cls, name): assert not cls._meta.has_auto_field, \ diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 2b1c24589a9..e086299b0f4 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -791,6 +791,12 @@ def results_iter(self): yield row + def has_results(self): + # This is always executed on a query clone, so we can modify self.query + self.query.add_extra({'a': 1}, None, None, None, None, None) + self.query.set_extra_mask(('a',)) + return bool(self.execute_sql(SINGLE)) + def execute_sql(self, result_type=MULTI): """ Run the query against the database and returns the result(s). The diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index ce11716abae..c25f09a3360 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -428,17 +428,15 @@ def get_count(self, using): def has_results(self, using): q = self.clone() - q.add_extra({'a': 1}, None, None, None, None, None) q.select = [] q.select_fields = [] q.default_cols = False q.select_related = False - q.set_extra_mask(('a',)) q.set_aggregate_mask(()) q.clear_ordering(True) q.set_limits(high=1) compiler = q.get_compiler(using=using) - return bool(compiler.execute_sql(SINGLE)) + return compiler.has_results() def combine(self, rhs, connector): """ diff --git a/django/http/multipartparser.py b/django/http/multipartparser.py index 477a08c05a7..4015c764965 100644 --- a/django/http/multipartparser.py +++ b/django/http/multipartparser.py @@ -171,8 +171,12 @@ def parse(self): file_name = self.IE_sanitize(unescape_entities(file_name)) content_type = meta_data.get('content-type', ('',))[0].strip() + content_type_extra = meta_data.get('content-type', (0, {}))[1] + if content_type_extra is None: + content_type_extra = {} + try: - charset = meta_data.get('content-type', (0,{}))[1].get('charset', None) + charset = content_type_extra.get('charset', None) except: charset = None @@ -187,7 +191,7 @@ def parse(self): try: handler.new_file(field_name, file_name, content_type, content_length, - charset) + charset, content_type_extra.copy()) except StopFutureHandlers: break diff --git a/django/test/client.py b/django/test/client.py index 6f3b73dde33..22cd8c49cc3 100644 --- a/django/test/client.py +++ b/django/test/client.py @@ -146,7 +146,10 @@ def encode_multipart(boundary, data): def encode_file(boundary, key, file): to_str = lambda s: smart_str(s, settings.DEFAULT_CHARSET) - content_type = mimetypes.guess_type(file.name)[0] + if hasattr(file, 'content_type'): + content_type = file.content_type + else: + content_type = mimetypes.guess_type(file.name)[0] if content_type is None: content_type = 'application/octet-stream' return [ diff --git a/django/utils/encoding.py b/django/utils/encoding.py index 292472320db..e40ee80ce60 100644 --- a/django/utils/encoding.py +++ b/django/utils/encoding.py @@ -48,7 +48,8 @@ def is_protected_type(obj): types.NoneType, int, long, datetime.datetime, datetime.date, datetime.time, - float, Decimal) + float, Decimal, + tuple, list, dict) ) def force_unicode(s, encoding='utf-8', strings_only=False, errors='strict'): diff --git a/django/utils/http.py b/django/utils/http.py index d2e4eb5adbf..94da0b4ece0 100644 --- a/django/utils/http.py +++ b/django/utils/http.py @@ -1,9 +1,11 @@ +import base64 import calendar import datetime import re import sys import urllib import urlparse +from binascii import Error as BinasciiError from email.utils import formatdate from django.utils.datastructures import MultiValueDict @@ -188,6 +190,16 @@ def int_to_base36(i): factor -= 1 return ''.join(base36) +def urlsafe_base64_encode(s): + return base64.urlsafe_b64encode(s).rstrip('\n=') + +def urlsafe_base64_decode(s): + assert isinstance(s, str) + try: + return base64.urlsafe_b64decode(s.ljust(len(s) + len(s) % 4, '=')) + except (LookupError, BinasciiError), e: + raise ValueError(e) + def parse_etags(etag_str): """ Parses a string with one or several etags passed in If-None-Match and diff --git a/docs/topics/http/file-uploads.txt b/docs/topics/http/file-uploads.txt index 05b58b91d16..46e2bf5394a 100644 --- a/docs/topics/http/file-uploads.txt +++ b/docs/topics/http/file-uploads.txt @@ -244,6 +244,11 @@ define the following methods/attributes: For :mimetype:`text/*` content-types, the character set (i.e. ``utf8``) supplied by the browser. Again, "trust but verify" is the best policy here. +.. attribute:: UploadedFile.content_type_extra + + A dict containing the extra parameters that were passed to the + content-type header. + .. attribute:: UploadedFile.temporary_file_path() Only files uploaded onto disk will have this method; it returns the full @@ -403,7 +408,7 @@ attributes: The default is 64*2\ :sup:`10` bytes, or 64 KB. -``FileUploadHandler.new_file(self, field_name, file_name, content_type, content_length, charset)`` +``FileUploadHandler.new_file(self, field_name, file_name, content_type, content_length, charset, content_type_extra)`` Callback signaling that a new file upload is starting. This is called before any data has been fed to any upload handlers. @@ -420,6 +425,9 @@ attributes: ``charset`` is the character set (i.e. ``utf8``) given by the browser. Like ``content_length``, this sometimes won't be provided. + ``content_type_extra`` is a dict containing the extra parameters that + were passed to the content-type header. + This method may raise a ``StopFutureHandlers`` exception to prevent future handlers from handling this file. diff --git a/tests/regressiontests/file_uploads/tests.py b/tests/regressiontests/file_uploads/tests.py index ffa2017a5e9..230fc3ad2c2 100644 --- a/tests/regressiontests/file_uploads/tests.py +++ b/tests/regressiontests/file_uploads/tests.py @@ -218,6 +218,16 @@ def test_empty_multipart_handled_gracefully(self): got = simplejson.loads(self.client.request(**r).content) self.assertEquals(got, {}) + def test_extra_content_type(self): + f = tempfile.NamedTemporaryFile() + f.write('a' * (2 ** 21)) + f.seek(0) + f.content_type = 'text/plain; blob-key=upload blob key; other=test' + + response = self.client.post("/file_uploads/content_type_extra/", {'f': f}) + got = simplejson.loads(response.content) + self.assertEqual(got['f'], 'upload blob key') + def test_custom_upload_handler(self): # A small file (under the 5M quota) smallfile = tempfile.NamedTemporaryFile() diff --git a/tests/regressiontests/file_uploads/uploadhandler.py b/tests/regressiontests/file_uploads/uploadhandler.py index b30ef136e9d..b66fbfa6691 100644 --- a/tests/regressiontests/file_uploads/uploadhandler.py +++ b/tests/regressiontests/file_uploads/uploadhandler.py @@ -2,7 +2,10 @@ Upload handlers to test the upload API. """ -from django.core.files.uploadhandler import FileUploadHandler, StopUpload +from django.core.files.uploadedfile import InMemoryUploadedFile +from django.core.files.uploadhandler import (FileUploadHandler, StopUpload, + StopFutureHandlers) +from StringIO import StringIO class QuotaUploadHandler(FileUploadHandler): @@ -33,3 +36,38 @@ class ErroringUploadHandler(FileUploadHandler): """A handler that raises an exception.""" def receive_data_chunk(self, raw_data, start): raise CustomUploadError("Oops!") + +class ContentTypeExtraUploadHandler(FileUploadHandler): + """ + File upload handler that handles content_type_extra + """ + + def new_file(self, *args, **kwargs): + super(ContentTypeExtraUploadHandler, self).new_file(*args, **kwargs) + self.blobkey = self.content_type_extra.get('blob-key', '') + self.file = StringIO() + self.file.write(self.blobkey) + self.active = self.blobkey is not None + if self.active: + raise StopFutureHandlers() + + def receive_data_chunk(self, raw_data, start): + """ + Add the data to the StringIO file. + """ + if not self.active: + return raw_data + + def file_complete(self, file_size): + if not self.active: + return + + self.file.seek(0) + return InMemoryUploadedFile( + file = self.file, + field_name = self.field_name, + name = self.file_name, + content_type = self.content_type, + size = file_size, + charset = self.charset + ) diff --git a/tests/regressiontests/file_uploads/urls.py b/tests/regressiontests/file_uploads/urls.py index fc5576828f6..cfd2663471e 100644 --- a/tests/regressiontests/file_uploads/urls.py +++ b/tests/regressiontests/file_uploads/urls.py @@ -6,14 +6,15 @@ urlpatterns = patterns('', - (r'^upload/$', views.file_upload_view), - (r'^verify/$', views.file_upload_view_verify), - (r'^unicode_name/$', views.file_upload_unicode_name), - (r'^echo/$', views.file_upload_echo), - (r'^echo_content/$', views.file_upload_echo_content), - (r'^quota/$', views.file_upload_quota), - (r'^quota/broken/$', views.file_upload_quota_broken), - (r'^getlist_count/$', views.file_upload_getlist_count), - (r'^upload_errors/$', views.file_upload_errors), - (r'^filename_case/$', views.file_upload_filename_case_view), + (r'^upload/$', views.file_upload_view), + (r'^verify/$', views.file_upload_view_verify), + (r'^unicode_name/$', views.file_upload_unicode_name), + (r'^echo/$', views.file_upload_echo), + (r'^echo_content/$', views.file_upload_echo_content), + (r'^quota/$', views.file_upload_quota), + (r'^quota/broken/$', views.file_upload_quota_broken), + (r'^getlist_count/$', views.file_upload_getlist_count), + (r'^upload_errors/$', views.file_upload_errors), + (r'^filename_case/$', views.file_upload_filename_case_view), + (r'^content_type_extra/$', views.file_upload_content_type_extra), ) diff --git a/tests/regressiontests/file_uploads/views.py b/tests/regressiontests/file_uploads/views.py index 9fd1a8d5449..0dd5f04b8e8 100644 --- a/tests/regressiontests/file_uploads/views.py +++ b/tests/regressiontests/file_uploads/views.py @@ -9,7 +9,8 @@ from .models import FileModel, UPLOAD_TO from .tests import UNICODE_FILENAME -from .uploadhandler import QuotaUploadHandler, ErroringUploadHandler +from .uploadhandler import (QuotaUploadHandler, ErroringUploadHandler, + ContentTypeExtraUploadHandler) def file_upload_view(request): @@ -134,3 +135,8 @@ def file_upload_filename_case_view(request): obj = FileModel() obj.testfile.save(file.name, file) return HttpResponse('%d' % obj.pk) + +def file_upload_content_type_extra(request): + request.upload_handlers.insert(0, ContentTypeExtraUploadHandler()) + r = dict([(k, f.read()) for k, f in request.FILES.items()]) + return HttpResponse(simplejson.dumps(r))