Skip to content

Commit

Permalink
Merge branch 'RyoheiTaima/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
asucrews committed Mar 12, 2022
2 parents c6cc8ec + 2465e39 commit 7821b5e
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 36 deletions.
15 changes: 13 additions & 2 deletions django_currentuser/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@ def _set_current_user(user=None):
_do_set_current_user(lambda self: user)


class SetCurrentUser:
def __init__(this, request):
this.request = request

def __enter__(this):
_do_set_current_user(lambda self: getattr(this.request, 'user', None))

def __exit__(this, type, value, traceback):
_do_set_current_user(lambda self: None)


class ThreadLocalUserMiddleware(object):

def __init__(self, get_response):
Expand All @@ -30,8 +41,8 @@ def __call__(self, request):
# request.user closure; asserts laziness;
# memorization is implemented in
# request.user (non-data descriptor)
_do_set_current_user(lambda self: getattr(request, 'user', None))
response = self.get_response(request)
with SetCurrentUser(request):
response = self.get_response(request)
return response


Expand Down
76 changes: 43 additions & 33 deletions tests/testapp/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
from hamcrest import assert_that, instance_of, equal_to, is_, empty, has_length

from django_currentuser.middleware import (
get_current_user, _set_current_user, get_current_authenticated_user)
SetCurrentUser,
get_current_user,
_set_current_user,
get_current_authenticated_user
)
from django_currentuser.db.models import CurrentUserField

from .sixmock import patch
Expand All @@ -30,23 +34,29 @@ def setUp(self):
self.user2.set_password("pw2")
self.user2.save()

def login_and_go_to_homepage(self, username, password):
def login_and_create(self, username, password):
data = {"username": username, "password": password}
self.client.post(reverse("login"), follow=True, data=data)
self.client.get('/')
self.client.post(reverse("create"), follow=True, data=data)

def login_and_update(self, username, password, pk):
data = {"username": username, "password": password}
self.client.post(reverse("login"), follow=True, data=data)
self.client.patch(reverse("update", args=[pk]), follow=True, data=data)


class TestSetUserToThread(TestUserBase):

@patch.object(SetCurrentUser, "__exit__", lambda *args, **kwargs: None)
def test__local_thread_var_is_set_to_logged_in_user(self):
_set_current_user(None)
self.assertIsNone(get_current_user())

self.login_and_go_to_homepage(username="user1", password="pw1")
self.login_and_create(username="user1", password="pw1")
self.assertEqual(self.user1, get_current_user())
self.client.logout()

self.login_and_go_to_homepage(username="user2", password="pw2")
self.login_and_create(username="user2", password="pw2")
self.assertEqual(self.user2, get_current_user())
self.client.logout()

Expand Down Expand Up @@ -140,49 +150,49 @@ def test_on_update_enabled(self):
test_model.save()

self.assertIs(test_model.updated_by_id, None)
test_model.refresh_from_db()
self.assertIs(test_model.updated_by, None)

self.login_and_go_to_homepage(username="user1", password="pw1")
test_model.save()
self.login_and_update(username="user1", password="pw1", pk=1)
user = TestModelOnUpdate.objects.get(pk=1)

self.assertEqual(self.user1.pk, test_model.updated_by_id)
test_model.refresh_from_db()
self.assertEqual(self.user1, test_model.updated_by)
self.assertEqual(self.user1.pk, user.updated_by_id)
self.assertEqual(self.user1, user.updated_by)

self.login_and_go_to_homepage(username="user2", password="pw2")
test_model.save()
self.login_and_update(username="user2", password="pw2", pk=1)
user = TestModelOnUpdate.objects.get(pk=1)

self.assertEqual(self.user2.pk, test_model.updated_by_id)
test_model.refresh_from_db()
self.assertEqual(self.user2, test_model.updated_by)
self.assertEqual(self.user2.pk, user.updated_by_id)
self.assertEqual(self.user2, user.updated_by)

_set_current_user(None)
test_model.save()
user = TestModelOnUpdate.objects.get(pk=1)

self.assertIs(test_model.updated_by_id, None)
test_model.refresh_from_db()
self.assertIs(test_model.updated_by, None)

def test_on_update_disabled(self):
self.login_and_go_to_homepage(username="user1", password="pw1")
test_model = TestModelDefaultBehavior()
test_model.save()
self.login_and_create(username="user1", password="pw1")
user1 = TestModelDefaultBehavior.objects.get(pk=1)

self.assertEqual(self.user1.pk, test_model.created_by_id)
test_model.refresh_from_db()
self.assertEqual(self.user1, test_model.created_by)
self.assertEqual(self.user1.pk, user1.created_by_id)
self.assertEqual(self.user1, user1.created_by)

self.login_and_go_to_homepage(username="user2", password="pw2")
test_model.save()
self.login_and_create(username="user2", password="pw2")
user1 = TestModelDefaultBehavior.objects.get(pk=1)
user2 = TestModelDefaultBehavior.objects.get(pk=2)

self.assertEqual(self.user1.pk, test_model.created_by_id)
test_model.refresh_from_db()
self.assertEqual(self.user1, test_model.created_by)
self.assertEqual(self.user1.pk, user1.created_by_id)
self.assertEqual(self.user1, user1.created_by)
self.assertEqual(self.user2.pk, user2.created_by_id)
self.assertEqual(self.user2, user2.created_by)

_set_current_user(None)
test_model.save()

self.assertEqual(self.user1.pk, test_model.created_by_id)
test_model.refresh_from_db()
self.assertEqual(self.user1, test_model.created_by)
TestModelDefaultBehavior().save()
user1 = TestModelDefaultBehavior.objects.get(pk=1)
user2 = TestModelDefaultBehavior.objects.get(pk=2)

self.assertEqual(self.user1.pk, user1.created_by_id)
self.assertEqual(self.user1, user1.created_by)
self.assertEqual(self.user2.pk, user2.created_by_id)
self.assertEqual(self.user2, user2.created_by)
19 changes: 18 additions & 1 deletion tests/testapp/urls.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import django
from django.contrib.auth import views as auth_views
from django.http import HttpResponse

from .models import TestModelDefaultBehavior, TestModelOnUpdate


if django.VERSION < (2, 0):
Expand All @@ -8,6 +11,20 @@
from django.urls import path


def create(request):
if request.method == 'POST':
TestModelDefaultBehavior.objects.create()
return HttpResponse()


def update(request, pk):
if request.method == 'PATCH':
TestModelOnUpdate.objects.get(pk=pk).save()
return HttpResponse()


urlpatterns = [
path(r'login/', auth_views.LoginView.as_view(), name="login")
path(r'login/', auth_views.LoginView.as_view(), name="login"),
path(r'create/', create, name="create"),
path(r'update/<int:pk>/', update, name="update")
]

0 comments on commit 7821b5e

Please sign in to comment.