diff --git a/starter_packs/tests.py b/starter_packs/tests.py index 491637b..1ac5559 100644 --- a/starter_packs/tests.py +++ b/starter_packs/tests.py @@ -3,6 +3,7 @@ from model_bakery import baker from accounts.models import Account +from starter_packs.models import StarterPackAccount class TestStarterPacks(TestCase): @@ -173,6 +174,69 @@ def test_delete(self): self.assertIsNotNone(self.starter_pack.deleted_at) +class TestToggleStarterPackAccount(TestCase): + @classmethod + def setUpTestData(cls): + cls.user = baker.make("auth.User") + baker.make("mastodon_auth.AccountAccess", user=cls.user) + cls.starter_pack = baker.make("starter_packs.StarterPack", created_by=cls.user) + + def test_add_account(self): + account = baker.make("accounts.Account", discoverable=True) + self.client.force_login(self.user) + response = self.client.post( + reverse( + "toggle_account_to_starter_pack", + kwargs={"starter_pack_slug": self.starter_pack.slug, "account_id": account.id}, + ) + ) + self.assertEqual(response.status_code, 200) + self.assertTrue( + StarterPackAccount.objects.filter(account_id=account.id, starter_pack_id=self.starter_pack.id).exists() + ) + + def test_remove_account(self): + account = baker.make("accounts.Account", discoverable=True) + self.client.force_login(self.user) + baker.make("starter_packs.StarterPackAccount", account=account, starter_pack=self.starter_pack) + response = self.client.post( + reverse( + "toggle_account_to_starter_pack", + kwargs={"starter_pack_slug": self.starter_pack.slug, "account_id": account.id}, + ) + ) + self.assertEqual(response.status_code, 200) + self.assertFalse( + StarterPackAccount.objects.filter(account_id=account.id, starter_pack_id=self.starter_pack.id).exists() + ) + + def test_toggle_after_limit(self): + account = baker.make("accounts.Account", discoverable=True) + self.client.force_login(self.user) + # baker.make("starter_packs.StarterPackAccount", starter_pack=self.starter_pack, account=account) + baker.make("starter_packs.StarterPackAccount", starter_pack=self.starter_pack, _quantity=150) + + self.assertEqual(self.starter_pack.starterpackaccount_set.count(), 150) + response = self.client.post( + reverse( + "toggle_account_to_starter_pack", + kwargs={"starter_pack_slug": self.starter_pack.slug, "account_id": account.id}, + ) + ) + self.assertEqual(response.status_code, 200) + + # Make sure we can delete even after the limit: + baker.make("starter_packs.StarterPackAccount", starter_pack=self.starter_pack, account=account) + self.assertEqual(self.starter_pack.starterpackaccount_set.count(), 151) + response = self.client.post( + reverse( + "toggle_account_to_starter_pack", + kwargs={"starter_pack_slug": self.starter_pack.slug, "account_id": account.id}, + ) + ) + self.assertEqual(self.starter_pack.starterpackaccount_set.count(), 150) + + class TestShareStarterPack(TestCase): @classmethod def setUpTestData(cls): diff --git a/starter_packs/views.py b/starter_packs/views.py index df24d29..f382983 100644 --- a/starter_packs/views.py +++ b/starter_packs/views.py @@ -236,30 +236,32 @@ def create_starter_pack(request): ) +@transaction.atomic def toggle_account_to_starter_pack(request, starter_pack_slug, account_id): starter_pack = get_object_or_404( StarterPack, slug=starter_pack_slug, deleted_at__isnull=True, created_by=request.user ) - if StarterPackAccount.objects.filter(starter_pack=starter_pack).count() > 150: - return render( - request, - "starter_pack_stats.html", - { - "starter_pack": starter_pack, - "num_accounts": StarterPackAccount.objects.filter(starter_pack=starter_pack).count(), - "error": "You have reached the maximum number of accounts in a starter pack.", - }, - ) - try: - StarterPackAccount.objects.create( - starter_pack=starter_pack, - account_id=account_id, - ) - except IntegrityError: + if StarterPackAccount.objects.filter(starter_pack=starter_pack, account_id=account_id).exists(): StarterPackAccount.objects.filter( starter_pack=starter_pack, account_id=account_id, ).delete() + else: + if StarterPackAccount.objects.filter(starter_pack=starter_pack).count() >= 150: + return render( + request, + "starter_pack_stats.html", + { + "starter_pack": starter_pack, + "num_accounts": StarterPackAccount.objects.filter(starter_pack=starter_pack).count(), + "error": "You have reached the maximum number of accounts in a starter pack.", + }, + ) + StarterPackAccount.objects.create( + starter_pack=starter_pack, + account_id=account_id, + ) + return render( request, "starter_pack_stats.html",