diff --git a/lacommunaute/forum_conversation/models.py b/lacommunaute/forum_conversation/models.py index 3c8ffb713..a00249f60 100644 --- a/lacommunaute/forum_conversation/models.py +++ b/lacommunaute/forum_conversation/models.py @@ -1,4 +1,5 @@ from django.conf import settings +from django.contrib.contenttypes.fields import GenericRelation from django.db import models from django.db.models import Count, Exists, OuterRef from django.urls import reverse @@ -7,6 +8,7 @@ from taggit.managers import TaggableManager from lacommunaute.forum_member.shortcuts import get_forum_member_display_name +from lacommunaute.forum_upvote.models import UpVote from lacommunaute.users.models import User @@ -82,6 +84,8 @@ def is_certified(self): class Post(AbstractPost): username = models.EmailField(blank=True, null=True, verbose_name=("Adresse email")) + upvotes = GenericRelation(UpVote, related_query_name="post") + @property def poster_display_name(self): if self.username: diff --git a/lacommunaute/forum_conversation/shortcuts.py b/lacommunaute/forum_conversation/shortcuts.py index 02699f602..0b799196e 100644 --- a/lacommunaute/forum_conversation/shortcuts.py +++ b/lacommunaute/forum_conversation/shortcuts.py @@ -1,3 +1,4 @@ +from django.contrib.contenttypes.models import ContentType from django.db.models import Count, Exists, OuterRef, Prefetch, Q, QuerySet from lacommunaute.forum.enums import Kind as Forum_Kind @@ -21,8 +22,13 @@ def get_posts_of_a_topic_except_first_one(topic: Topic, user: User) -> QuerySet[ if user.is_authenticated: qs = qs.annotate( upvotes_count=Count("upvotes"), - # using user.id instead of user, to manage anonymous user journey - has_upvoted=Exists(UpVote.objects.filter(post=OuterRef("pk"), voter=user)), + has_upvoted=Exists( + UpVote.objects.filter( + object_id=OuterRef("pk"), + voter=user, + content_type_id=ContentType.objects.get_for_model(qs.model).id, + ) + ), ) else: qs = qs.annotate( diff --git a/lacommunaute/forum_conversation/tests/tests_shortcuts.py b/lacommunaute/forum_conversation/tests/tests_shortcuts.py index d27c6ef1d..bd2c1fa5d 100644 --- a/lacommunaute/forum_conversation/tests/tests_shortcuts.py +++ b/lacommunaute/forum_conversation/tests/tests_shortcuts.py @@ -51,7 +51,7 @@ def test_topic_has_two_posts_requested_by_authenticated_user(self): def test_topic_has_been_upvoted(self): topic = TopicFactory(with_post=True) post = PostFactory(topic=topic) - UpVoteFactory(post=post) + UpVoteFactory(content_object=post) posts = get_posts_of_a_topic_except_first_one(topic, AnonymousUser()) post = posts.first() @@ -62,7 +62,7 @@ def test_topic_has_been_upvoted(self): def test_topic_has_been_upvoted_by_the_user(self): topic = TopicFactory(with_post=True) post = PostFactory(topic=topic) - UpVoteFactory(post=post, voter=topic.poster) + UpVoteFactory(content_object=post, voter=topic.poster) posts = get_posts_of_a_topic_except_first_one(topic, topic.poster) post = posts.first() diff --git a/lacommunaute/forum_conversation/tests/tests_views.py b/lacommunaute/forum_conversation/tests/tests_views.py index 4b190f937..1a962abb8 100644 --- a/lacommunaute/forum_conversation/tests/tests_views.py +++ b/lacommunaute/forum_conversation/tests/tests_views.py @@ -432,7 +432,7 @@ def test_post_has_no_upvote(self): def test_post_has_upvote_by_user(self): PostFactory(topic=self.topic, poster=self.poster) - UpVoteFactory(post=self.topic.last_post, voter=self.poster) + UpVoteFactory(content_object=self.topic.last_post, voter=self.poster) self.client.force_login(self.poster) response = self.client.get(self.url) @@ -476,7 +476,7 @@ def test_edit_link_is_visible(self): def test_numqueries(self): PostFactory.create_batch(10, topic=self.topic, poster=self.poster) - UpVoteFactory(post=self.topic.last_post, voter=UserFactory()) + UpVoteFactory(content_object=self.topic.last_post, voter=UserFactory()) CertifiedPostFactory(topic=self.topic, post=self.topic.last_post, user=UserFactory()) self.client.force_login(self.poster) diff --git a/lacommunaute/forum_conversation/tests/tests_views_htmx.py b/lacommunaute/forum_conversation/tests/tests_views_htmx.py index da5380502..daeba6da1 100644 --- a/lacommunaute/forum_conversation/tests/tests_views_htmx.py +++ b/lacommunaute/forum_conversation/tests/tests_views_htmx.py @@ -343,8 +343,8 @@ def test_upvote_annotations(self): self.assertEqual(response.status_code, 200) self.assertContains(response, '0') - UpVoteFactory(post=post, voter=UserFactory()) - UpVoteFactory(post=post, voter=self.user) + UpVoteFactory(content_object=post, voter=UserFactory()) + UpVoteFactory(content_object=post, voter=self.user) response = view.get(request) self.assertEqual(response.status_code, 200) diff --git a/lacommunaute/forum_upvote/admin.py b/lacommunaute/forum_upvote/admin.py index cd732201a..b919d3c13 100644 --- a/lacommunaute/forum_upvote/admin.py +++ b/lacommunaute/forum_upvote/admin.py @@ -4,11 +4,8 @@ class UpVoteAdmin(admin.ModelAdmin): - list_display = ("voter", "post", "created_at") - raw_id_fields = ( - "voter", - "post", - ) + list_display = ("voter", "created_at") + raw_id_fields = ("voter",) admin.site.register(UpVote, UpVoteAdmin) diff --git a/lacommunaute/forum_upvote/migrations/0004_convert_to_genericforeignkey.py b/lacommunaute/forum_upvote/migrations/0004_convert_to_genericforeignkey.py new file mode 100644 index 000000000..5222a012d --- /dev/null +++ b/lacommunaute/forum_upvote/migrations/0004_convert_to_genericforeignkey.py @@ -0,0 +1,80 @@ +import django.db.models.deletion +from django.contrib.contenttypes.models import ContentType +from django.db import migrations, models + + +def move_forward_foreign_key_to_generic_foreign_key(apps, schema_editor): + UpVote = apps.get_model("forum_upvote", "UpVote") + ContentType = apps.get_model("contenttypes", "ContentType") + + UpVote.objects.filter(post__isnull=True).delete() + content_type, _ = ContentType.objects.get_or_create(model="post", app_label="forum_conversation") + + for upvote in UpVote.objects.all(): + upvote.content_object = upvote.post + upvote.object_id = upvote.post_id + upvote.content_type = content_type + upvote.save() + + +def move_back_generic_foreign_key_to_foreign_key(apps, schema_editor): + UpVote = apps.get_model("forum_upvote", "UpVote") + Post = apps.get_model("forum_conversation", "Post") + + post_content_type_id = ContentType.objects.get(model="post", app_label="forum_conversation").id + UpVote.objects.exclude(content_type_id=post_content_type_id).delete() + + for upvote in UpVote.objects.all(): + upvote.post = Post.objects.get(id=upvote.object_id) + upvote.save() + + +class Migration(migrations.Migration): + dependencies = [ + ("contenttypes", "0002_remove_content_type_name"), + ("forum_upvote", "0003_delete_certifiedpost"), + ] + + operations = [ + migrations.AlterUniqueTogether( + name="upvote", + unique_together=set(), + ), + migrations.AddField( + model_name="upvote", + name="content_type", + field=models.ForeignKey( + null=True, on_delete=django.db.models.deletion.CASCADE, to="contenttypes.contenttype" + ), + ), + migrations.AddField( + model_name="upvote", + name="object_id", + field=models.PositiveBigIntegerField(null=True), + ), + migrations.RunPython( + move_forward_foreign_key_to_generic_foreign_key, + move_back_generic_foreign_key_to_foreign_key, + ), + migrations.AlterUniqueTogether( + name="upvote", + unique_together={("voter", "content_type", "object_id")}, + ), + migrations.AlterField( + model_name="upvote", + name="content_type", + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to="contenttypes.contenttype", + ), + ), + migrations.AlterField( + model_name="upvote", + name="object_id", + field=models.PositiveBigIntegerField(), + ), + migrations.RemoveField( + model_name="upvote", + name="post", + ), + ] diff --git a/lacommunaute/forum_upvote/models.py b/lacommunaute/forum_upvote/models.py index f13590ca1..a9990122b 100644 --- a/lacommunaute/forum_upvote/models.py +++ b/lacommunaute/forum_upvote/models.py @@ -1,8 +1,8 @@ from django.conf import settings +from django.contrib.contenttypes.fields import GenericForeignKey +from django.contrib.contenttypes.models import ContentType from django.db import models -from lacommunaute.forum_conversation.models import Post - class UpVote(models.Model): voter = models.ForeignKey( @@ -14,20 +14,16 @@ class UpVote(models.Model): verbose_name="Voter", ) - post = models.ForeignKey( - Post, - related_name="upvotes", - blank=True, - null=True, - on_delete=models.SET_NULL, - verbose_name="Post", - ) + content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE) + object_id = models.PositiveBigIntegerField() + + content_object = GenericForeignKey("content_type", "object_id") created_at = models.DateTimeField(auto_now_add=True, db_index=True, verbose_name="Creation date") objects = models.Manager() class Meta: - unique_together = ["voter", "post"] + unique_together = ["voter", "content_type", "object_id"] ordering = [ "-created_at", ] diff --git a/lacommunaute/forum_upvote/tests/tests_models.py b/lacommunaute/forum_upvote/tests/tests_models.py index 4a9a6885e..0d4f05264 100644 --- a/lacommunaute/forum_upvote/tests/tests_models.py +++ b/lacommunaute/forum_upvote/tests/tests_models.py @@ -6,9 +6,23 @@ class UpVoteModelTest(TestCase): - def test_post_and_voter_are_uniques_together(self): + def test_generic_relation(self): topic = TopicFactory(with_post=True) - UpVote.objects.create(post=topic.first_post, voter=topic.first_post.poster) + UpVote.objects.create(content_object=topic.first_post, voter=topic.first_post.poster) + UpVote.objects.create(content_object=topic.forum, voter=topic.first_post.poster) + + self.assertEqual(UpVote.objects.count(), 2) + + def test_upvoted_post_unicity(self): + topic = TopicFactory(with_post=True) + UpVote.objects.create(content_object=topic.first_post, voter=topic.first_post.poster) + + with self.assertRaises(IntegrityError): + UpVote.objects.create(content_object=topic.first_post, voter=topic.first_post.poster) + + def test_upvoted_forum_unicity(self): + topic = TopicFactory(with_post=True) + UpVote.objects.create(content_object=topic.forum, voter=topic.first_post.poster) with self.assertRaises(IntegrityError): - UpVote.objects.create(post=topic.first_post, voter=topic.first_post.poster) + UpVote.objects.create(content_object=topic.forum, voter=topic.first_post.poster) diff --git a/lacommunaute/forum_upvote/tests/tests_views.py b/lacommunaute/forum_upvote/tests/tests_views.py index 270c43721..c697ee597 100644 --- a/lacommunaute/forum_upvote/tests/tests_views.py +++ b/lacommunaute/forum_upvote/tests/tests_views.py @@ -1,3 +1,4 @@ +from django.contrib.contenttypes.models import ContentType from django.test import TestCase from django.urls import reverse from faker import Faker @@ -44,12 +45,22 @@ def test_upvote_downvote_post(self): response = self.client.post(self.url, data=form_data) self.assertEqual(response.status_code, 200) self.assertContains(response, '1') - self.assertEqual(1, UpVote.objects.filter(voter_id=self.user.id, post_id=post.id).count()) + self.assertEqual( + 1, + UpVote.objects.filter( + voter_id=self.user.id, object_id=post.id, content_type=ContentType.objects.get_for_model(post) + ).count(), + ) response = self.client.post(self.url, data=form_data) self.assertEqual(response.status_code, 200) self.assertContains(response, '0') - self.assertEqual(0, UpVote.objects.filter(voter_id=self.user.id, post_id=post.id).count()) + self.assertEqual( + 0, + UpVote.objects.filter( + voter_id=self.user.id, object_id=post.id, content_type=ContentType.objects.get_for_model(post) + ).count(), + ) def test_object_not_found(self): self.client.force_login(self.user) @@ -58,13 +69,27 @@ def test_object_not_found(self): response = self.client.post(self.url, data=form_data) self.assertEqual(response.status_code, 404) - self.assertEqual(0, UpVote.objects.filter(voter_id=self.user.id, post_id=self.topic.last_post.id).count()) + self.assertEqual( + 0, + UpVote.objects.filter( + voter_id=self.user.id, + object_id=self.topic.last_post.id, + content_type_id=ContentType.objects.get_for_model(self.topic.last_post).id, + ).count(), + ) form_data = {"pk": self.topic.pk, "post_pk": 9999} response = self.client.post(self.url, data=form_data) self.assertEqual(response.status_code, 404) - self.assertEqual(0, UpVote.objects.filter(voter_id=self.user.id, post_id=self.topic.last_post.id).count()) + self.assertEqual( + 0, + UpVote.objects.filter( + voter_id=self.user.id, + object_id=self.topic.last_post.id, + content_type_id=ContentType.objects.get_for_model(self.topic.last_post).id, + ).count(), + ) def test_topic_is_marked_as_read_when_upvoting(self): self.assertFalse(ForumReadTrack.objects.count()) diff --git a/lacommunaute/forum_upvote/views.py b/lacommunaute/forum_upvote/views.py index 61fb70c37..9cfb6f5b9 100644 --- a/lacommunaute/forum_upvote/views.py +++ b/lacommunaute/forum_upvote/views.py @@ -1,5 +1,6 @@ import logging +from django.contrib.contenttypes.models import ContentType from django.shortcuts import get_object_or_404, render from django.views import View from machina.core.loading import get_class @@ -33,13 +34,21 @@ def get_object(self): def post(self, request, **kwargs): post = self.get_object() - upvote = UpVote.objects.filter(voter_id=request.user.id, post_id=post.id) + upvote = UpVote.objects.filter( + voter_id=request.user.id, + object_id=post.id, + content_type=ContentType.objects.get_for_model(post), + ) if upvote.exists(): upvote.delete() post.has_upvoted = False else: - UpVote(voter_id=request.user.id, post_id=post.id).save() + UpVote( + voter_id=request.user.id, + object_id=post.id, + content_type=ContentType.objects.get_for_model(post), + ).save() post.has_upvoted = True post.upvotes_count = post.upvotes.count()