From 75775031e1eefc9a7793f1a515e6caf8401ae8af Mon Sep 17 00:00:00 2001 From: RyanAquino Date: Sun, 19 Sep 2021 11:08:05 +0800 Subject: [PATCH] ZDL-36: Add pydantic type model --- authentication/models.py | 7 +++++-- authentication/serializers.py | 19 +++++++++++-------- authentication/tests/base_data.py | 17 ----------------- authentication/tests/test_auth_api.py | 4 ++-- authentication/validators.py | 14 ++++++++++++++ authentication/views.py | 17 ++++++----------- conftest.py | 2 +- orders/tests/test_orders_api.py | 4 ++-- orders/validators.py | 8 ++++++++ orders/views.py | 21 +++++++++++++-------- 10 files changed, 62 insertions(+), 51 deletions(-) delete mode 100644 authentication/tests/base_data.py create mode 100644 authentication/validators.py create mode 100644 orders/validators.py diff --git a/authentication/models.py b/authentication/models.py index c56238b..c67dd63 100644 --- a/authentication/models.py +++ b/authentication/models.py @@ -6,6 +6,7 @@ PermissionsMixin, ) from rest_framework_simplejwt.tokens import RefreshToken +from authentication.validators import UserTokens class UserManager(BaseUserManager): @@ -57,9 +58,11 @@ class User(AbstractBaseUser, PermissionsMixin): objects = UserManager() - def tokens(self): + def tokens(self) -> UserTokens: refresh = RefreshToken.for_user(self) - return {"refresh": str(refresh), "token": str(refresh.access_token)} + return UserTokens( + **{"refresh": str(refresh), "token": str(refresh.access_token)} + ) def __str__(self): return self.email diff --git a/authentication/serializers.py b/authentication/serializers.py index d1efb53..0a1c9d2 100644 --- a/authentication/serializers.py +++ b/authentication/serializers.py @@ -2,6 +2,7 @@ from authentication.models import User from django.contrib import auth from rest_framework.exceptions import AuthenticationFailed +from authentication.validators import UserLogin class UserSerializer(serializers.ModelSerializer): @@ -74,7 +75,7 @@ class Meta: model = User fields = ["email", "password", "access", "refresh", "first_name", "last_name"] - def validate(self, attrs): + def validate(self, attrs) -> UserLogin: email = attrs.get("email", "") password = attrs.get("password", "") @@ -85,13 +86,15 @@ def validate(self, attrs): tokens = user.tokens() - return { - "email": user.email, - "first_name": user.first_name, - "last_name": user.last_name, - "access": tokens["token"], - "refresh": tokens["refresh"], - } + return UserLogin( + **{ + "email": user.email, + "first_name": user.first_name, + "last_name": user.last_name, + "access": tokens.token, + "refresh": tokens.refresh, + } + ) class UserProfileSerializer(serializers.ModelSerializer): diff --git a/authentication/tests/base_data.py b/authentication/tests/base_data.py deleted file mode 100644 index 0c3f7d8..0000000 --- a/authentication/tests/base_data.py +++ /dev/null @@ -1,17 +0,0 @@ -from django.contrib.auth.models import Group -from authentication.models import User - - -def base_data(): - GROUPS = ["Admins", "Customers", "Suppliers"] - - for group in GROUPS: - Group.objects.get_or_create(name=group) - - User.objects.create_user( - email="customer@email.com", - password="password", - first_name="customer", - last_name="account", - role="Customers", - ) diff --git a/authentication/tests/test_auth_api.py b/authentication/tests/test_auth_api.py index d8a1344..1aa983e 100644 --- a/authentication/tests/test_auth_api.py +++ b/authentication/tests/test_auth_api.py @@ -142,7 +142,7 @@ def test_retrieve_user_profile(): mock_logged_in_user = UserFactory( email="test_test2@email.com", first_name="test", last_name="test2" ) - user_token = mock_logged_in_user.tokens()["token"] + user_token = mock_logged_in_user.tokens().token client = Client(HTTP_AUTHORIZATION=f"Bearer {user_token}") response = client.get("/v1/auth/profile/") data = response.json() @@ -166,7 +166,7 @@ def test_patch_profile_details(): last_name="test2", groups=Group.objects.all(), ) - user_token = mock_logged_in_user.tokens()["token"] + user_token = mock_logged_in_user.tokens().token client = Client(HTTP_AUTHORIZATION=f"Bearer {user_token}") modified_data = { "first_name": "modified_name1", diff --git a/authentication/validators.py b/authentication/validators.py new file mode 100644 index 0000000..a984f0d --- /dev/null +++ b/authentication/validators.py @@ -0,0 +1,14 @@ +from pydantic import BaseModel + + +class UserLogin(BaseModel): + email: str + first_name: str + last_name: str + access: str + refresh: str + + +class UserTokens(BaseModel): + token: str + refresh: str diff --git a/authentication/views.py b/authentication/views.py index 9ffd1c2..430364a 100644 --- a/authentication/views.py +++ b/authentication/views.py @@ -20,12 +20,10 @@ class UserRegisterView(GenericAPIView): def post(self, request): serializer = UserSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + serializer.save() - if serializer.is_valid(): - serializer.save() - return Response(serializer.data, status=status.HTTP_201_CREATED) - - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + return Response(serializer.data, status=status.HTTP_201_CREATED) class SupplierRegisterView(GenericAPIView): @@ -35,11 +33,9 @@ class SupplierRegisterView(GenericAPIView): def post(self, request): serializer = SupplierSerializer(data=request.data) - if serializer.is_valid(): - serializer.save() - return Response(serializer.data, status=status.HTTP_201_CREATED) - - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + serializer.is_valid(raise_exception=True) + serializer.save() + return Response(serializer.data, status=status.HTTP_201_CREATED) class UserLoginView(GenericAPIView): @@ -49,7 +45,6 @@ class UserLoginView(GenericAPIView): def post(self, request): serializer = self.serializer_class(data=request.data) - serializer.is_valid(raise_exception=True) return Response(serializer.data, status=status.HTTP_200_OK) diff --git a/conftest.py b/conftest.py index a744301..4d10aaf 100644 --- a/conftest.py +++ b/conftest.py @@ -7,7 +7,7 @@ @pytest.fixture def logged_in_client(logged_in_user): - user_token = logged_in_user.tokens()["token"] + user_token = logged_in_user.tokens().token return Client(HTTP_AUTHORIZATION=f"Bearer {user_token}") diff --git a/orders/tests/test_orders_api.py b/orders/tests/test_orders_api.py index 9f6f244..f0c4723 100644 --- a/orders/tests/test_orders_api.py +++ b/orders/tests/test_orders_api.py @@ -131,7 +131,7 @@ def test_update_cart(logged_in_client): response_data = response.json() order_items = response_data["products"] - assert response.status_code == 200 + assert response.status_code == 201 assert response_data["total_items"] == 1 assert response_data["total_amount"] == 35 assert len(order_items) == 1 @@ -142,7 +142,7 @@ def test_update_cart(logged_in_client): response_data = response.json() order_items = response_data["products"] - assert response.status_code == 200 + assert response.status_code == 201 assert response_data["total_items"] == 0 assert response_data["total_amount"] == 0 assert len(order_items) == 0 diff --git a/orders/validators.py b/orders/validators.py new file mode 100644 index 0000000..db2b7b7 --- /dev/null +++ b/orders/validators.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel +from typing import Dict, List + + +class OrdersList(BaseModel): + total_items: int + total_amount: int + products: List[Dict] diff --git a/orders/views.py b/orders/views.py index 3d57b53..88962e8 100644 --- a/orders/views.py +++ b/orders/views.py @@ -17,7 +17,10 @@ from rest_framework import mixins, viewsets from rest_framework.decorators import action from rest_condition import Or + +from orders.validators import OrdersList from products.models import Product +from rest_framework import status import datetime @@ -42,7 +45,7 @@ def list(self, request, *args, **kwargs): ), } - return Response(resp) + return Response(data=OrdersList(**resp).dict()) @action( detail=False, @@ -92,7 +95,7 @@ def update_cart(self, request): ), } - return Response(data=resp, status=200) + return Response(data=OrdersList(**resp).dict(), status=status.HTTP_201_CREATED) @action( detail=False, @@ -118,12 +121,14 @@ def process_order(self, request): shipping = ShippingAddress.objects.create( customer=customer, order=order, - address=request_data.data["address"], - city=request_data.data["city"], - state=request_data.data["state"], - zipcode=request_data.data["zipcode"], + address=request_data.validated_data["address"], + city=request_data.validated_data["city"], + state=request_data.validated_data["state"], + zipcode=request_data.validated_data["zipcode"], ) shipping.save() - return Response(ShippingAddressSerializer(shipping).data, status=201) - return Response("Order not found", status=400) + return Response( + ShippingAddressSerializer(shipping).data, status=status.HTTP_201_CREATED + ) + return Response("Order not found", status=status.HTTP_400_BAD_REQUEST)