diff --git a/restless/modelviews.py b/restless/modelviews.py index 2a545b0..3bf5e6f 100644 --- a/restless/modelviews.py +++ b/restless/modelviews.py @@ -1,4 +1,4 @@ -from django.forms.models import modelform_factory +from django.forms.models import modelform_factory, model_to_dict from .views import Endpoint from .http import HttpError, Http200, Http201 @@ -11,8 +11,9 @@ def _get_form(form, model): from django import VERSION - if VERSION[:2] >= (1,8): - mf = lambda m: modelform_factory(m, fields='__all__') + if VERSION[:2] >= (1, 8): + def mf(model): + return modelform_factory(model, fields='__all__') else: mf = modelform_factory @@ -95,7 +96,7 @@ def post(self, request, *args, **kwargs): if form.is_valid(): obj = form.save() return Http201(self.serialize(obj)) - + raise HttpError(400, 'Invalid Data', errors=form.errors) @@ -120,7 +121,7 @@ class variable. model = None form = None lookup_field = 'pk' - methods = ['GET', 'PUT', 'DELETE'] + methods = ['GET', 'PUT', 'PATCH', 'DELETE'] def get_instance(self, request, *args, **kwargs): """Return a model instance represented by this endpoint. @@ -169,6 +170,28 @@ def get(self, request, *args, **kwargs): return self.serialize(self.get_instance(request, *args, **kwargs)) + def patch(self, request, *args, **kwargs): + """Update the object represented by this endpoint.""" + + if 'PATCH' not in self.methods: + raise HttpError(405, 'Method Not Allowed') + + Form = _get_form(self.form, self.model) + instance = self.get_instance(request, *args, **kwargs) + + form_data = model_to_dict(instance) + form_data.update(request.data) + + form = Form( + form_data, + request.FILES, + instance=instance + ) + if form.is_valid(): + obj = form.save() + return Http200(self.serialize(obj)) + raise HttpError(400, 'Invalid data', errors=form.errors) + def put(self, request, *args, **kwargs): """Update the object represented by this endpoint.""" @@ -177,8 +200,7 @@ def put(self, request, *args, **kwargs): Form = _get_form(self.form, self.model) instance = self.get_instance(request, *args, **kwargs) - form = Form(request.data or None, request.FILES, - instance=instance) + form = Form(request.data or None, request.FILES, instance=instance) if form.is_valid(): obj = form.save() return Http200(self.serialize(obj)) diff --git a/setup.py b/setup.py index 0a43ab7..bce4fa5 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,7 @@ def run(self): setup( name='DjangoRestless', - version='0.0.10', + version='0.0.11', author='Senko Rasic', author_email='senko.rasic@goodcode.io', description='A RESTful framework for Django', diff --git a/testproject/testapp/forms.py b/testproject/testapp/forms.py index 2aeb818..c078ece 100644 --- a/testproject/testapp/forms.py +++ b/testproject/testapp/forms.py @@ -1,6 +1,6 @@ from django import forms -from .models import * +from .models import Author __all__ = ['AuthorForm'] diff --git a/testproject/testapp/tests.py b/testproject/testapp/tests.py index c54d74b..fc20947 100644 --- a/testproject/testapp/tests.py +++ b/testproject/testapp/tests.py @@ -6,9 +6,8 @@ from decimal import Decimal import base64 import warnings -import six -from .models import * +from .models import Publisher, Author, Book from restless.models import serialize, flatten try: @@ -21,13 +20,15 @@ class TestClient(Client): @staticmethod def process(response): + response.json = None + decoded_response_content = response.content.decode('utf-8') + try: - raw_data = response.content.decode('utf-8') - response.json = json.loads(response.content.decode('utf-8')) - except: - response.json = None - finally: - return response + response.json = json.loads(decoded_response_content) + except json.decoder.JSONDecodeError: + pass + + return response def get(self, url_name, data={}, follow=False, extra={}, *args, **kwargs): return self.process( @@ -38,7 +39,7 @@ def get(self, url_name, data={}, follow=False, extra={}, *args, **kwargs): **extra)) def post(self, url_name, data={}, content_type=MULTIPART_CONTENT, - follow=False, extra={}, *args, **kwargs): + follow=False, extra={}, *args, **kwargs): return self.process( super(TestClient, self).post( reverse(url_name, args=args, kwargs=kwargs), @@ -47,6 +48,16 @@ def post(self, url_name, data={}, content_type=MULTIPART_CONTENT, follow=follow, **extra)) + def patch(self, url_name, data={}, content_type=MULTIPART_CONTENT, + follow=False, extra={}, *args, **kwargs): + return self.process( + super(TestClient, self).patch( + reverse(url_name, args=args, kwargs=kwargs), + content_type=content_type, + data=data, + follow=follow, + **extra)) + def put(self, url_name, data={}, content_type=MULTIPART_CONTENT, follow=False, *args, **kwargs): return self.process( @@ -55,7 +66,7 @@ def put(self, url_name, data={}, content_type=MULTIPART_CONTENT, content_type=content_type, data=data, follow=follow)) def delete(self, url_name, data={}, content_type=MULTIPART_CONTENT, - follow=False, *args, **kwargs): + follow=False, *args, **kwargs): return self.process( super(TestClient, self).delete( reverse(url_name, args=args, kwargs=kwargs), @@ -69,11 +80,13 @@ def setUp(self): self.publisher = Publisher.objects.create(name='Publisher') self.books = [] for i in range(10): - b = self.author.books.create(author=self.author, + b = self.author.books.create( + author=self.author, title='Book %d' % i, isbn='123-1-12-123456-%d' % i, price=Decimal("10.0"), - publisher=self.publisher) + publisher=self.publisher + ) self.books.append(b) def test_full_shallow(self): @@ -198,13 +211,14 @@ def test_serialize_queryset(self): a1 = Author.objects.create(name="foo") a2 = Author.objects.create(name="bar") qs = Author.objects.all() - _ = list(qs) # force sql query execution + list(qs) # force sql query execution # Check that the same (cached) queryset is used, instead of a clone with self.assertNumQueries(0): s = serialize(qs) - self.assertEqual(s, + self.assertEqual( + s, [ {'name': a1.name, 'id': a1.id}, {'name': a2.name, 'id': a2.id}, @@ -218,7 +232,8 @@ def test_serialize_list(self): a1 = Author.objects.create(name="foo") a2 = Author.objects.create(name="bar") s = serialize(list(Author.objects.all())) - self.assertEqual(s, + self.assertEqual( + s, [ {'name': a1.name, 'id': a1.id}, {'name': a2.name, 'id': a2.id}, @@ -249,7 +264,8 @@ def test_serialize_set(self): s = serialize(set(Author.objects.all())) self.assertTrue(isinstance(s, list)) # Must cast back to set to ignore ordering - self.assertEqual(sorted(s, key=lambda el: el['name']), + self.assertEqual( + sorted(s, key=lambda el: el['name']), [ {'name': a1.name, 'id': a1.id}, {'name': a2.name, 'id': a2.id}, @@ -279,8 +295,10 @@ def accessor(obj): # If fields are appended on, 'desc' will be twice in the list # for the second run, so in total the accessor function will be # run 3 instead of 2 times - serialize([self.author, self.author], fields=['id'], - include=[('desc', accessor)]) + serialize( + [self.author, self.author], fields=['id'], + include=[('desc', accessor)] + ) self.assertEqual(runs[0], 2) @@ -327,7 +345,7 @@ def test_create_author_form_encoded(self): self.assertEqual(r.status_code, 201) self.assertEqual(r.json['name'], 'New User') self.assertEqual(r.json['name'], - Author.objects.get(id=r.json['id']).name) + Author.objects.get(id=r.json['id']).name) def test_create_author_multipart(self): """Exercise multipart/form-data POST""" @@ -335,10 +353,10 @@ def test_create_author_multipart(self): r = self.client.post('author_list', data={ 'name': 'New User', }) # multipart/form-data is default in test client - self.assertEqual(r.status_code, 201) + self.assertEqual(r.status_code, 201, r.content) self.assertEqual(r.json['name'], 'New User') self.assertEqual(r.json['name'], - Author.objects.get(id=r.json['id']).name) + Author.objects.get(id=r.json['id']).name) def test_create_author_json(self): """Exercise application/json POST""" @@ -349,13 +367,13 @@ def test_create_author_json(self): self.assertEqual(r.status_code, 201) self.assertEqual(r.json['name'], 'New User') self.assertEqual(r.json['name'], - Author.objects.get(id=r.json['id']).name) + Author.objects.get(id=r.json['id']).name) def test_invalid_json_payload(self): """Exercise invalid JSON handling""" r = self.client.post('author_list', data='xyz', - content_type='application/json') + content_type='application/json') self.assertEqual(r.status_code, 400) def test_delete_author(self): @@ -373,8 +391,7 @@ def test_change_author(self): }), author_id=self.author.id, content_type='application/json') self.assertEqual(r.status_code, 200) self.assertEqual(r.json['name'], 'User Bar') - self.assertEqual(r.json['name'], - Author.objects.get(id=r.json['id']).name) + self.assertEqual(r.json['name'], Author.objects.get(id=r.json['id']).name) def test_view_failure(self): """Exercise exception handling""" @@ -388,11 +405,10 @@ def test_view_failure(self): def test_raw_request_body(self): raw = b'\x01\x02\x03' - r = self.client.post('echo_view', data=raw, - content_type='text/plain') + r = self.client.post('echo_view', data=raw, content_type='text/plain') self.assertEqual(base64.b64decode(r.json['raw_data'].encode('ascii')), - raw) + raw) def test_get_payload_is_ignored(self): """Test that body of the GET request is always ignored.""" @@ -437,10 +453,12 @@ def test_basic_auth_challenge(self): def test_basic_auth_succeeds(self): """Test that HTTP Basic Auth succeeds""" - r = self.client.get('basic_auth_view', extra={ - 'HTTP_AUTHORIZATION': 'Basic ' + + r = self.client.get( + 'basic_auth_view', extra={ + 'HTTP_AUTHORIZATION': 'Basic ' + base64.b64encode(b'foo:bar').decode('ascii'), - }) + } + ) self.assertEqual(r.status_code, 200) self.assertEqual(r.json['id'], self.user.id) @@ -461,8 +479,10 @@ def test_custom_auth_fn_returning_httpresponse_shortcuts_request(self): self.assertEqual(r.status_code, 403) def test_custom_auth_fn_raising_exception_shortcuts_request(self): - r = self.client.get('custom_auth_method', - data={'user': 'exceptional-foe'}) + r = self.client.get( + 'custom_auth_method', + data={'user': 'exceptional-foe'} + ) self.assertEqual(r.status_code, 403) def test_custom_auth_fn_with_invalid_return_value_is_a_bug(self): @@ -481,8 +501,10 @@ def setUp(self): self.client = TestClient() self.publisher = Publisher.objects.create(name='User Foo') self.author = Author.objects.create(name='User Foo') - self.book = self.author.books.create(author=self.author, title='Book', - isbn='1234', price=Decimal('10.0'), publisher=self.publisher) + self.book = self.author.books.create( + author=self.author, title='Book', + isbn='1234', price=Decimal('10.0'), publisher=self.publisher + ) def test_publisher_list(self): """Excercise listing objects via ListEndpoint""" @@ -495,9 +517,15 @@ def test_publisher_list(self): def test_publisher_create(self): """Excercise creating objects via ListEndpoint""" - r = self.client.post('publisher_list', data=json.dumps({ + payload = { 'name': 'Another Publisher' - }), content_type='application/json') + } + + r = self.client.post( + 'publisher_list', + data=json.dumps(payload), + content_type='application/json' + ) self.assertEqual(r.status_code, 201) self.assertTrue(Publisher.objects.filter(pk=r.json['id']).exists()) @@ -508,18 +536,35 @@ def test_publisher_details(self): self.assertEqual(r.status_code, 200) self.assertEqual(r.json['id'], self.publisher.id) - def test_publisher_update(self): - """Excercise updating an object via POST via DetailEndpoint""" + def test_publisher_update_with_put(self): + """Excercise updating an object via PUT via DetailEndpoint""" - r = self.client.put('publisher_detail', pk=self.publisher.id, - content_type='application/json', data=json.dumps({ - 'name': 'Changed Name' - })) + r = self.client.put( + 'publisher_detail', pk=self.publisher.id, + content_type='application/json', + data=json.dumps({'name': 'Changed Name'}) + ) self.assertEqual(r.status_code, 200) self.assertEqual(r.json['id'], self.publisher.id) p = Publisher.objects.get(id=self.publisher.id) self.assertEqual(p.name, 'Changed Name') + def test_book_update_with_patch(self): + """Excercise updating an object via PATCH via DetailEndpoint""" + + r = self.client.patch( + 'book_detail', isbn=self.book.isbn, + data=json.dumps({'title': 'Changed Title'}), + content_type='application/json' + ) + self.assertEqual(r.status_code, 200, r.content) + self.assertEqual(r.json['id'], self.book.id) + self.assertEqual(r.json['isbn'], self.book.isbn) + + b = Book.objects.get(id=self.book.id) + self.assertEqual(b.title, 'Changed Title') + self.assertEqual(b.isbn, self.book.isbn) + def test_publisher_delete(self): """Excercise deleting an object via DetailEndpoint""" @@ -530,16 +575,22 @@ def test_publisher_delete(self): def test_redonly_publisher_list_denies_creation(self): """Excercise method whitelist in ListEndpoint""" - r = self.client.post('readonly_publisher_list', data=json.dumps({ + payload = { 'name': 'Another Publisher' - }), content_type='application/json') + } + + r = self.client.post( + 'readonly_publisher_list', + data=json.dumps(payload), + content_type='application/json' + ) self.assertEqual(r.status_code, 405) def test_publisher_action(self): """Excercise RPC-style actions via ActionEndpoint""" r = self.client.post('publisher_action', pk=self.publisher.id, - content_type='application/json') + content_type='application/json') self.assertEqual(r.status_code, 200) self.assertEqual(r.json, {'result': 'done'}) diff --git a/testproject/testapp/urls.py b/testproject/testapp/urls.py index 241b27f..63e556d 100644 --- a/testproject/testapp/urls.py +++ b/testproject/testapp/urls.py @@ -6,7 +6,12 @@ def patterns(prefix, *args): return args -from .views import * +from .views import (AuthorList, AuthorDetail, FailsIntentionally, TestLogin, + TestBasicAuth, TestCustomAuthMethod, EchoView, + ErrorRaisingView, PublisherAutoList, + ReadOnlyPublisherAutoList, PublisherAutoDetail, + PublisherAction, BookDetail, WildcardHandler) + urlpatterns = patterns('', url(r'^authors/$', AuthorList.as_view(), diff --git a/testproject/testapp/views.py b/testproject/testapp/views.py index 253ea2f..3c49eab 100644 --- a/testproject/testapp/views.py +++ b/testproject/testapp/views.py @@ -4,17 +4,19 @@ from restless.models import serialize from restless.http import Http201, Http403, Http404, Http400, HttpError from restless.auth import (AuthenticateEndpoint, BasicHttpAuthMixin, - login_required) + login_required) from restless.modelviews import ListEndpoint, DetailEndpoint, ActionEndpoint -from .models import * -from .forms import * +from .models import Publisher, Author, Book +from .forms import AuthorForm -__all__ = ['AuthorList', 'AuthorDetail', 'FailsIntentionally', 'TestLogin', +__all__ = [ + 'AuthorList', 'AuthorDetail', 'FailsIntentionally', 'TestLogin', 'TestBasicAuth', 'WildcardHandler', 'EchoView', 'ErrorRaisingView', 'PublisherAutoList', 'PublisherAutoDetail', 'ReadOnlyPublisherAutoList', - 'PublisherAction', 'BookDetail', 'TestCustomAuthMethod'] + 'PublisherAction', 'BookDetail', 'TestCustomAuthMethod' +] class AuthorList(Endpoint): @@ -27,8 +29,7 @@ def post(self, request): author = form.save() return Http201(serialize(author)) else: - return Http400(reason='invalid author data', - details=form.errors) + return Http400(reason='invalid author data', details=form.errors) class AuthorDetail(Endpoint): @@ -56,8 +57,7 @@ def put(self, request, author_id=None): author = form.save() return serialize(author) else: - return Http400(reason='invalid author data', - details=form.errors) + return Http400(reason='invalid author data', details=form.errors) class FailsIntentionally(Endpoint): diff --git a/testproject/testproject/wsgi.py b/testproject/testproject/wsgi.py index 351cd6f..637c256 100644 --- a/testproject/testproject/wsgi.py +++ b/testproject/testproject/wsgi.py @@ -20,7 +20,7 @@ # This application object is used by any WSGI server configured to use this # file. This includes Django's development server, if the WSGI_APPLICATION # setting points here. -from django.core.wsgi import get_wsgi_application +from django.core.wsgi import get_wsgi_application # NOQA application = get_wsgi_application() # Apply WSGI middleware here.