Skip to content

Commit

Permalink
ZDL-99: Fix updated_at date time upon login and add login rate limiting
Browse files Browse the repository at this point in the history
  • Loading branch information
RyanAquino committed Mar 21, 2022
1 parent 5f09609 commit 7ea9d6b
Show file tree
Hide file tree
Showing 14 changed files with 468 additions and 24 deletions.
5 changes: 4 additions & 1 deletion .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,7 @@ omit =
*/venv/*
*/zadalaAPI/*
manage.py
*/apps.py
*/apps.py

[report]
show_missing = True
149 changes: 149 additions & 0 deletions authentication/custom_throttle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from collections import Counter
from copy import copy

from django.contrib.auth import authenticate
from rest_framework.exceptions import ValidationError
from rest_framework.request import Request
from rest_framework.throttling import AnonRateThrottle

from social_auth.serializers import GoogleSocialAuthSerializer


class UserLoginRateThrottle(AnonRateThrottle):
scope = "logins"
max_attempts = 3

def allow_request(self, request: Request, view):
"""
Implement the check to see if the request should be throttled.
On success calls `throttle_success`.
On failure calls `throttle_failure`.
"""
if self.rate is None:
return True

self.key = self.get_cache_key(request, view)
if self.key is None:
return True

self.history = self.cache.get(self.key, [])
self.now = self.timer()
failed = False

if self.history:
last_login_attempt = self.history[-1]

if isinstance(self.history[-1], str):
last_login_attempt = float(self.history[-1].split("_")[-1])

while self.history and last_login_attempt <= self.now - self.duration:
self.history.pop()

if len(self.history) >= self.num_requests:
return self.throttle_failure()

if len(self.history) >= self.max_attempts:
failed = self.verify_fail_login_attempts(request)

return (
self.throttle_failure()
if failed is True
else self.throttle_success(request)
)

def format_history_exclude_timestamp(self):
"""
Helper method to format history and exclude its timestamps
"""
history_copy = copy(self.history)

for index, history in enumerate(history_copy):
if isinstance(history, str):
history_excluded_timestamp = history.split("_")[:-1]
history_copy[index] = "_".join(history_excluded_timestamp)

return history_copy

def verify_fail_login_attempts(self, request: Request) -> bool:
"""
Verify fail login attempts
"""
email = request.data.get("email")
password = request.data.get("password")

formatted_history = self.format_history_exclude_timestamp()
history_count_mapping: dict = Counter(formatted_history)

for key, value in history_count_mapping.items():
cached_email = key.split("_")[-2]
cached_password = key.split("_")[-1]

if (
cached_email == email
and cached_password == password
and value >= self.max_attempts
):
self.max_attempts = 0
return True
return False

def throttle_success(self, request) -> bool:
"""
Inserts the current request's timestamp along with the key
into the cache.
"""
email = request.data.get("email")
password = request.data.get("password")
user = authenticate(request, email=email, password=password)

if not user:
self.history.insert(0, f"failed_login_{email}_{password}_{self.now}")
self.cache.set(self.key, self.history, self.duration)

return True

def wait(self) -> float:
"""
Returns the recommended next request time in seconds.
"""
remaining_duration = self.duration

if self.history:
last_history_copy = self.history[-1]
if isinstance(last_history_copy, str):
last_history_copy = float(last_history_copy.split("_")[-1])
remaining_duration = self.duration - (self.now - last_history_copy)

available_requests = (
1 if not self.max_attempts else self.num_requests - len(self.history) + 1
)

return remaining_duration / float(available_requests)


class OAuthUserLoginRateThrottle(UserLoginRateThrottle):
def verify_fail_login_attempts(self, request: Request) -> bool:
auth_token = request.data.get("auth_token")
formatted_history = self.format_history_exclude_timestamp()
history_count_mapping: dict = Counter(formatted_history)

for key, value in history_count_mapping.items():
cached_oauth_token = key.split("_")[-1]

if cached_oauth_token == auth_token and value >= self.max_attempts:
self.max_attempts = 0
return True
return False

def throttle_success(self, request) -> bool:
serializer = GoogleSocialAuthSerializer(data=request.data)
auth_token = request.data.get("auth_token")

try:
serializer.is_valid(raise_exception=True)
except ValidationError:
self.history.insert(0, f"failed_login_oauth_{auth_token}_{self.now}")
self.cache.set(self.key, self.history, self.duration)

return True
14 changes: 12 additions & 2 deletions authentication/serializers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from django.contrib import auth
from django.contrib.auth.models import update_last_login
from rest_framework import serializers
from rest_framework.exceptions import AuthenticationFailed

from authentication.models import User
from authentication.validators import UserLogin
from authentication.validators import AuthProviders, UserLogin


class UserSerializer(serializers.ModelSerializer):
Expand All @@ -21,6 +22,7 @@ class Meta:
"is_superuser",
"groups",
"user_permissions",
"auth_provider",
)

def validate(self, attrs):
Expand Down Expand Up @@ -50,6 +52,7 @@ class Meta:
"is_superuser",
"groups",
"user_permissions",
"auth_provider",
)

def validate(self, attrs):
Expand Down Expand Up @@ -85,10 +88,11 @@ def validate(self, attrs) -> UserLogin:
if not user:
raise AuthenticationFailed("Invalid email/password")

if user.auth_provider != "email":
if user.auth_provider != AuthProviders.email.value:
raise AuthenticationFailed("Please login using your login provider.")

tokens = user.tokens()
update_last_login(sender=None, user=user)

return UserLogin(
**{
Expand All @@ -102,6 +106,11 @@ def validate(self, attrs) -> UserLogin:


class UserProfileSerializer(serializers.ModelSerializer):
password = serializers.CharField(max_length=65, min_length=8)
email = serializers.EmailField(max_length=255, min_length=4)
first_name = serializers.CharField(max_length=255, min_length=2)
last_name = serializers.CharField(max_length=255, min_length=2)

class Meta:
model = User
fields = [
Expand All @@ -115,6 +124,7 @@ class Meta:
"password",
]
write_only_fields = ["password", "groups"]
read_only_fields = ["auth_provider"]

def __init__(self, *args, **kwargs):
fields = kwargs.pop("fields", None)
Expand Down
8 changes: 7 additions & 1 deletion authentication/tests/factories/user_factory.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from datetime import datetime

import factory
from factory import PostGenerationMethodCall
from factory.django import DjangoModelFactory
from factory.fuzzy import FuzzyNaiveDateTime

from authentication.models import User
from authentication.validators import AuthProviders


class UserFactory(DjangoModelFactory):
Expand All @@ -14,7 +18,9 @@ class Meta:
last_name = "account"
password = PostGenerationMethodCall("set_password", "password")
is_active = True
auth_provider = "email"
auth_provider = AuthProviders.email.value
date_joined = FuzzyNaiveDateTime(datetime(2022, 1, 1))
last_login = FuzzyNaiveDateTime(datetime(2022, 1, 1))

@factory.post_generation
def groups(self, create, extracted, **kwargs):
Expand Down
26 changes: 24 additions & 2 deletions authentication/tests/test_auth_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import datetime

import pytest
from django.contrib.auth.models import Group
from django.test import Client
Expand Down Expand Up @@ -189,7 +191,7 @@ def test_patch_profile_details():
modified_data = {
"first_name": "modified_name1",
"last_name": "modified_name2",
"password": "test2",
"password": "string123",
}

data = client._encode_json({} if not modified_data else modified_data, content_type)
Expand All @@ -208,7 +210,7 @@ def test_patch_profile_details():
assert response.status_code == 204
assert modified_user.first_name == "modified_name1"
assert modified_user.last_name == "modified_name2"
assert modified_user.check_password("test2") is True
assert modified_user.check_password("string123") is True


@pytest.mark.django_db
Expand Down Expand Up @@ -246,3 +248,23 @@ def test_patch_profile_password_of_oauth_user_should_not_update():

assert response.status_code == 204
assert modified_user.check_password("oauth-generated-password") is True


@pytest.mark.django_db
def test_login_should_update_last_login_date_time(client):
"""
Test User login should update last login date time
"""
user = UserFactory(last_login="2022-02-21 00:53:12.279437")
data = {"email": user.email, "password": "password"}

response = client.post("/v1/auth/login/", data)
current_login_time = datetime.today().replace(microsecond=0).timestamp()

assert response.status_code == 200
user.refresh_from_db()

user_last_login_second_timestamp = (
User.objects.first().last_login.replace(microsecond=0).timestamp()
)
assert user_last_login_second_timestamp == current_login_time
Loading

0 comments on commit 7ea9d6b

Please sign in to comment.