Skip to content

Commit

Permalink
ZDL-36: Add pydantic type model
Browse files Browse the repository at this point in the history
  • Loading branch information
RyanAquino committed Sep 21, 2021
1 parent 725c17c commit 7577503
Show file tree
Hide file tree
Showing 10 changed files with 62 additions and 51 deletions.
7 changes: 5 additions & 2 deletions authentication/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
PermissionsMixin,
)
from rest_framework_simplejwt.tokens import RefreshToken
from authentication.validators import UserTokens


class UserManager(BaseUserManager):
Expand Down Expand Up @@ -57,9 +58,11 @@ class User(AbstractBaseUser, PermissionsMixin):

objects = UserManager()

def tokens(self):
def tokens(self) -> UserTokens:
refresh = RefreshToken.for_user(self)
return {"refresh": str(refresh), "token": str(refresh.access_token)}
return UserTokens(
**{"refresh": str(refresh), "token": str(refresh.access_token)}
)

def __str__(self):
return self.email
19 changes: 11 additions & 8 deletions authentication/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from authentication.models import User
from django.contrib import auth
from rest_framework.exceptions import AuthenticationFailed
from authentication.validators import UserLogin


class UserSerializer(serializers.ModelSerializer):
Expand Down Expand Up @@ -74,7 +75,7 @@ class Meta:
model = User
fields = ["email", "password", "access", "refresh", "first_name", "last_name"]

def validate(self, attrs):
def validate(self, attrs) -> UserLogin:
email = attrs.get("email", "")
password = attrs.get("password", "")

Expand All @@ -85,13 +86,15 @@ def validate(self, attrs):

tokens = user.tokens()

return {
"email": user.email,
"first_name": user.first_name,
"last_name": user.last_name,
"access": tokens["token"],
"refresh": tokens["refresh"],
}
return UserLogin(
**{
"email": user.email,
"first_name": user.first_name,
"last_name": user.last_name,
"access": tokens.token,
"refresh": tokens.refresh,
}
)


class UserProfileSerializer(serializers.ModelSerializer):
Expand Down
17 changes: 0 additions & 17 deletions authentication/tests/base_data.py

This file was deleted.

4 changes: 2 additions & 2 deletions authentication/tests/test_auth_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def test_retrieve_user_profile():
mock_logged_in_user = UserFactory(
email="[email protected]", first_name="test", last_name="test2"
)
user_token = mock_logged_in_user.tokens()["token"]
user_token = mock_logged_in_user.tokens().token
client = Client(HTTP_AUTHORIZATION=f"Bearer {user_token}")
response = client.get("/v1/auth/profile/")
data = response.json()
Expand All @@ -166,7 +166,7 @@ def test_patch_profile_details():
last_name="test2",
groups=Group.objects.all(),
)
user_token = mock_logged_in_user.tokens()["token"]
user_token = mock_logged_in_user.tokens().token
client = Client(HTTP_AUTHORIZATION=f"Bearer {user_token}")
modified_data = {
"first_name": "modified_name1",
Expand Down
14 changes: 14 additions & 0 deletions authentication/validators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from pydantic import BaseModel


class UserLogin(BaseModel):
email: str
first_name: str
last_name: str
access: str
refresh: str


class UserTokens(BaseModel):
token: str
refresh: str
17 changes: 6 additions & 11 deletions authentication/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,10 @@ class UserRegisterView(GenericAPIView):

def post(self, request):
serializer = UserSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
serializer.save()

if serializer.is_valid():
serializer.save()
return Response(serializer.data, status=status.HTTP_201_CREATED)

return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
return Response(serializer.data, status=status.HTTP_201_CREATED)


class SupplierRegisterView(GenericAPIView):
Expand All @@ -35,11 +33,9 @@ class SupplierRegisterView(GenericAPIView):

def post(self, request):
serializer = SupplierSerializer(data=request.data)
if serializer.is_valid():
serializer.save()
return Response(serializer.data, status=status.HTTP_201_CREATED)

return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
serializer.is_valid(raise_exception=True)
serializer.save()
return Response(serializer.data, status=status.HTTP_201_CREATED)


class UserLoginView(GenericAPIView):
Expand All @@ -49,7 +45,6 @@ class UserLoginView(GenericAPIView):

def post(self, request):
serializer = self.serializer_class(data=request.data)

serializer.is_valid(raise_exception=True)

return Response(serializer.data, status=status.HTTP_200_OK)
Expand Down
2 changes: 1 addition & 1 deletion conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

@pytest.fixture
def logged_in_client(logged_in_user):
user_token = logged_in_user.tokens()["token"]
user_token = logged_in_user.tokens().token
return Client(HTTP_AUTHORIZATION=f"Bearer {user_token}")


Expand Down
4 changes: 2 additions & 2 deletions orders/tests/test_orders_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def test_update_cart(logged_in_client):
response_data = response.json()
order_items = response_data["products"]

assert response.status_code == 200
assert response.status_code == 201
assert response_data["total_items"] == 1
assert response_data["total_amount"] == 35
assert len(order_items) == 1
Expand All @@ -142,7 +142,7 @@ def test_update_cart(logged_in_client):
response_data = response.json()
order_items = response_data["products"]

assert response.status_code == 200
assert response.status_code == 201
assert response_data["total_items"] == 0
assert response_data["total_amount"] == 0
assert len(order_items) == 0
Expand Down
8 changes: 8 additions & 0 deletions orders/validators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from pydantic import BaseModel
from typing import Dict, List


class OrdersList(BaseModel):
total_items: int
total_amount: int
products: List[Dict]
21 changes: 13 additions & 8 deletions orders/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
from rest_framework import mixins, viewsets
from rest_framework.decorators import action
from rest_condition import Or

from orders.validators import OrdersList
from products.models import Product
from rest_framework import status
import datetime


Expand All @@ -42,7 +45,7 @@ def list(self, request, *args, **kwargs):
),
}

return Response(resp)
return Response(data=OrdersList(**resp).dict())

@action(
detail=False,
Expand Down Expand Up @@ -92,7 +95,7 @@ def update_cart(self, request):
),
}

return Response(data=resp, status=200)
return Response(data=OrdersList(**resp).dict(), status=status.HTTP_201_CREATED)

@action(
detail=False,
Expand All @@ -118,12 +121,14 @@ def process_order(self, request):
shipping = ShippingAddress.objects.create(
customer=customer,
order=order,
address=request_data.data["address"],
city=request_data.data["city"],
state=request_data.data["state"],
zipcode=request_data.data["zipcode"],
address=request_data.validated_data["address"],
city=request_data.validated_data["city"],
state=request_data.validated_data["state"],
zipcode=request_data.validated_data["zipcode"],
)
shipping.save()

return Response(ShippingAddressSerializer(shipping).data, status=201)
return Response("Order not found", status=400)
return Response(
ShippingAddressSerializer(shipping).data, status=status.HTTP_201_CREATED
)
return Response("Order not found", status=status.HTTP_400_BAD_REQUEST)

0 comments on commit 7577503

Please sign in to comment.