diff --git a/lacommunaute/forum_conversation/models.py b/lacommunaute/forum_conversation/models.py
index 3c8ffb713..a00249f60 100644
--- a/lacommunaute/forum_conversation/models.py
+++ b/lacommunaute/forum_conversation/models.py
@@ -1,4 +1,5 @@
from django.conf import settings
+from django.contrib.contenttypes.fields import GenericRelation
from django.db import models
from django.db.models import Count, Exists, OuterRef
from django.urls import reverse
@@ -7,6 +8,7 @@
from taggit.managers import TaggableManager
from lacommunaute.forum_member.shortcuts import get_forum_member_display_name
+from lacommunaute.forum_upvote.models import UpVote
from lacommunaute.users.models import User
@@ -82,6 +84,8 @@ def is_certified(self):
class Post(AbstractPost):
username = models.EmailField(blank=True, null=True, verbose_name=("Adresse email"))
+ upvotes = GenericRelation(UpVote, related_query_name="post")
+
@property
def poster_display_name(self):
if self.username:
diff --git a/lacommunaute/forum_conversation/shortcuts.py b/lacommunaute/forum_conversation/shortcuts.py
index 02699f602..0b799196e 100644
--- a/lacommunaute/forum_conversation/shortcuts.py
+++ b/lacommunaute/forum_conversation/shortcuts.py
@@ -1,3 +1,4 @@
+from django.contrib.contenttypes.models import ContentType
from django.db.models import Count, Exists, OuterRef, Prefetch, Q, QuerySet
from lacommunaute.forum.enums import Kind as Forum_Kind
@@ -21,8 +22,13 @@ def get_posts_of_a_topic_except_first_one(topic: Topic, user: User) -> QuerySet[
if user.is_authenticated:
qs = qs.annotate(
upvotes_count=Count("upvotes"),
- # using user.id instead of user, to manage anonymous user journey
- has_upvoted=Exists(UpVote.objects.filter(post=OuterRef("pk"), voter=user)),
+ has_upvoted=Exists(
+ UpVote.objects.filter(
+ object_id=OuterRef("pk"),
+ voter=user,
+ content_type_id=ContentType.objects.get_for_model(qs.model).id,
+ )
+ ),
)
else:
qs = qs.annotate(
diff --git a/lacommunaute/forum_conversation/tests/tests_shortcuts.py b/lacommunaute/forum_conversation/tests/tests_shortcuts.py
index d27c6ef1d..bd2c1fa5d 100644
--- a/lacommunaute/forum_conversation/tests/tests_shortcuts.py
+++ b/lacommunaute/forum_conversation/tests/tests_shortcuts.py
@@ -51,7 +51,7 @@ def test_topic_has_two_posts_requested_by_authenticated_user(self):
def test_topic_has_been_upvoted(self):
topic = TopicFactory(with_post=True)
post = PostFactory(topic=topic)
- UpVoteFactory(post=post)
+ UpVoteFactory(content_object=post)
posts = get_posts_of_a_topic_except_first_one(topic, AnonymousUser())
post = posts.first()
@@ -62,7 +62,7 @@ def test_topic_has_been_upvoted(self):
def test_topic_has_been_upvoted_by_the_user(self):
topic = TopicFactory(with_post=True)
post = PostFactory(topic=topic)
- UpVoteFactory(post=post, voter=topic.poster)
+ UpVoteFactory(content_object=post, voter=topic.poster)
posts = get_posts_of_a_topic_except_first_one(topic, topic.poster)
post = posts.first()
diff --git a/lacommunaute/forum_conversation/tests/tests_views.py b/lacommunaute/forum_conversation/tests/tests_views.py
index 4b190f937..1a962abb8 100644
--- a/lacommunaute/forum_conversation/tests/tests_views.py
+++ b/lacommunaute/forum_conversation/tests/tests_views.py
@@ -432,7 +432,7 @@ def test_post_has_no_upvote(self):
def test_post_has_upvote_by_user(self):
PostFactory(topic=self.topic, poster=self.poster)
- UpVoteFactory(post=self.topic.last_post, voter=self.poster)
+ UpVoteFactory(content_object=self.topic.last_post, voter=self.poster)
self.client.force_login(self.poster)
response = self.client.get(self.url)
@@ -476,7 +476,7 @@ def test_edit_link_is_visible(self):
def test_numqueries(self):
PostFactory.create_batch(10, topic=self.topic, poster=self.poster)
- UpVoteFactory(post=self.topic.last_post, voter=UserFactory())
+ UpVoteFactory(content_object=self.topic.last_post, voter=UserFactory())
CertifiedPostFactory(topic=self.topic, post=self.topic.last_post, user=UserFactory())
self.client.force_login(self.poster)
diff --git a/lacommunaute/forum_conversation/tests/tests_views_htmx.py b/lacommunaute/forum_conversation/tests/tests_views_htmx.py
index da5380502..daeba6da1 100644
--- a/lacommunaute/forum_conversation/tests/tests_views_htmx.py
+++ b/lacommunaute/forum_conversation/tests/tests_views_htmx.py
@@ -343,8 +343,8 @@ def test_upvote_annotations(self):
self.assertEqual(response.status_code, 200)
self.assertContains(response, '0')
- UpVoteFactory(post=post, voter=UserFactory())
- UpVoteFactory(post=post, voter=self.user)
+ UpVoteFactory(content_object=post, voter=UserFactory())
+ UpVoteFactory(content_object=post, voter=self.user)
response = view.get(request)
self.assertEqual(response.status_code, 200)
diff --git a/lacommunaute/forum_upvote/admin.py b/lacommunaute/forum_upvote/admin.py
index cd732201a..b919d3c13 100644
--- a/lacommunaute/forum_upvote/admin.py
+++ b/lacommunaute/forum_upvote/admin.py
@@ -4,11 +4,8 @@
class UpVoteAdmin(admin.ModelAdmin):
- list_display = ("voter", "post", "created_at")
- raw_id_fields = (
- "voter",
- "post",
- )
+ list_display = ("voter", "created_at")
+ raw_id_fields = ("voter",)
admin.site.register(UpVote, UpVoteAdmin)
diff --git a/lacommunaute/forum_upvote/migrations/0004_convert_to_genericforeignkey.py b/lacommunaute/forum_upvote/migrations/0004_convert_to_genericforeignkey.py
new file mode 100644
index 000000000..5222a012d
--- /dev/null
+++ b/lacommunaute/forum_upvote/migrations/0004_convert_to_genericforeignkey.py
@@ -0,0 +1,80 @@
+import django.db.models.deletion
+from django.contrib.contenttypes.models import ContentType
+from django.db import migrations, models
+
+
+def move_forward_foreign_key_to_generic_foreign_key(apps, schema_editor):
+ UpVote = apps.get_model("forum_upvote", "UpVote")
+ ContentType = apps.get_model("contenttypes", "ContentType")
+
+ UpVote.objects.filter(post__isnull=True).delete()
+ content_type, _ = ContentType.objects.get_or_create(model="post", app_label="forum_conversation")
+
+ for upvote in UpVote.objects.all():
+ upvote.content_object = upvote.post
+ upvote.object_id = upvote.post_id
+ upvote.content_type = content_type
+ upvote.save()
+
+
+def move_back_generic_foreign_key_to_foreign_key(apps, schema_editor):
+ UpVote = apps.get_model("forum_upvote", "UpVote")
+ Post = apps.get_model("forum_conversation", "Post")
+
+ post_content_type_id = ContentType.objects.get(model="post", app_label="forum_conversation").id
+ UpVote.objects.exclude(content_type_id=post_content_type_id).delete()
+
+ for upvote in UpVote.objects.all():
+ upvote.post = Post.objects.get(id=upvote.object_id)
+ upvote.save()
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("contenttypes", "0002_remove_content_type_name"),
+ ("forum_upvote", "0003_delete_certifiedpost"),
+ ]
+
+ operations = [
+ migrations.AlterUniqueTogether(
+ name="upvote",
+ unique_together=set(),
+ ),
+ migrations.AddField(
+ model_name="upvote",
+ name="content_type",
+ field=models.ForeignKey(
+ null=True, on_delete=django.db.models.deletion.CASCADE, to="contenttypes.contenttype"
+ ),
+ ),
+ migrations.AddField(
+ model_name="upvote",
+ name="object_id",
+ field=models.PositiveBigIntegerField(null=True),
+ ),
+ migrations.RunPython(
+ move_forward_foreign_key_to_generic_foreign_key,
+ move_back_generic_foreign_key_to_foreign_key,
+ ),
+ migrations.AlterUniqueTogether(
+ name="upvote",
+ unique_together={("voter", "content_type", "object_id")},
+ ),
+ migrations.AlterField(
+ model_name="upvote",
+ name="content_type",
+ field=models.ForeignKey(
+ on_delete=django.db.models.deletion.CASCADE,
+ to="contenttypes.contenttype",
+ ),
+ ),
+ migrations.AlterField(
+ model_name="upvote",
+ name="object_id",
+ field=models.PositiveBigIntegerField(),
+ ),
+ migrations.RemoveField(
+ model_name="upvote",
+ name="post",
+ ),
+ ]
diff --git a/lacommunaute/forum_upvote/models.py b/lacommunaute/forum_upvote/models.py
index f13590ca1..a9990122b 100644
--- a/lacommunaute/forum_upvote/models.py
+++ b/lacommunaute/forum_upvote/models.py
@@ -1,8 +1,8 @@
from django.conf import settings
+from django.contrib.contenttypes.fields import GenericForeignKey
+from django.contrib.contenttypes.models import ContentType
from django.db import models
-from lacommunaute.forum_conversation.models import Post
-
class UpVote(models.Model):
voter = models.ForeignKey(
@@ -14,20 +14,16 @@ class UpVote(models.Model):
verbose_name="Voter",
)
- post = models.ForeignKey(
- Post,
- related_name="upvotes",
- blank=True,
- null=True,
- on_delete=models.SET_NULL,
- verbose_name="Post",
- )
+ content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE)
+ object_id = models.PositiveBigIntegerField()
+
+ content_object = GenericForeignKey("content_type", "object_id")
created_at = models.DateTimeField(auto_now_add=True, db_index=True, verbose_name="Creation date")
objects = models.Manager()
class Meta:
- unique_together = ["voter", "post"]
+ unique_together = ["voter", "content_type", "object_id"]
ordering = [
"-created_at",
]
diff --git a/lacommunaute/forum_upvote/tests/tests_models.py b/lacommunaute/forum_upvote/tests/tests_models.py
index 4a9a6885e..0d4f05264 100644
--- a/lacommunaute/forum_upvote/tests/tests_models.py
+++ b/lacommunaute/forum_upvote/tests/tests_models.py
@@ -6,9 +6,23 @@
class UpVoteModelTest(TestCase):
- def test_post_and_voter_are_uniques_together(self):
+ def test_generic_relation(self):
topic = TopicFactory(with_post=True)
- UpVote.objects.create(post=topic.first_post, voter=topic.first_post.poster)
+ UpVote.objects.create(content_object=topic.first_post, voter=topic.first_post.poster)
+ UpVote.objects.create(content_object=topic.forum, voter=topic.first_post.poster)
+
+ self.assertEqual(UpVote.objects.count(), 2)
+
+ def test_upvoted_post_unicity(self):
+ topic = TopicFactory(with_post=True)
+ UpVote.objects.create(content_object=topic.first_post, voter=topic.first_post.poster)
+
+ with self.assertRaises(IntegrityError):
+ UpVote.objects.create(content_object=topic.first_post, voter=topic.first_post.poster)
+
+ def test_upvoted_forum_unicity(self):
+ topic = TopicFactory(with_post=True)
+ UpVote.objects.create(content_object=topic.forum, voter=topic.first_post.poster)
with self.assertRaises(IntegrityError):
- UpVote.objects.create(post=topic.first_post, voter=topic.first_post.poster)
+ UpVote.objects.create(content_object=topic.forum, voter=topic.first_post.poster)
diff --git a/lacommunaute/forum_upvote/tests/tests_views.py b/lacommunaute/forum_upvote/tests/tests_views.py
index 270c43721..c697ee597 100644
--- a/lacommunaute/forum_upvote/tests/tests_views.py
+++ b/lacommunaute/forum_upvote/tests/tests_views.py
@@ -1,3 +1,4 @@
+from django.contrib.contenttypes.models import ContentType
from django.test import TestCase
from django.urls import reverse
from faker import Faker
@@ -44,12 +45,22 @@ def test_upvote_downvote_post(self):
response = self.client.post(self.url, data=form_data)
self.assertEqual(response.status_code, 200)
self.assertContains(response, '1')
- self.assertEqual(1, UpVote.objects.filter(voter_id=self.user.id, post_id=post.id).count())
+ self.assertEqual(
+ 1,
+ UpVote.objects.filter(
+ voter_id=self.user.id, object_id=post.id, content_type=ContentType.objects.get_for_model(post)
+ ).count(),
+ )
response = self.client.post(self.url, data=form_data)
self.assertEqual(response.status_code, 200)
self.assertContains(response, '0')
- self.assertEqual(0, UpVote.objects.filter(voter_id=self.user.id, post_id=post.id).count())
+ self.assertEqual(
+ 0,
+ UpVote.objects.filter(
+ voter_id=self.user.id, object_id=post.id, content_type=ContentType.objects.get_for_model(post)
+ ).count(),
+ )
def test_object_not_found(self):
self.client.force_login(self.user)
@@ -58,13 +69,27 @@ def test_object_not_found(self):
response = self.client.post(self.url, data=form_data)
self.assertEqual(response.status_code, 404)
- self.assertEqual(0, UpVote.objects.filter(voter_id=self.user.id, post_id=self.topic.last_post.id).count())
+ self.assertEqual(
+ 0,
+ UpVote.objects.filter(
+ voter_id=self.user.id,
+ object_id=self.topic.last_post.id,
+ content_type_id=ContentType.objects.get_for_model(self.topic.last_post).id,
+ ).count(),
+ )
form_data = {"pk": self.topic.pk, "post_pk": 9999}
response = self.client.post(self.url, data=form_data)
self.assertEqual(response.status_code, 404)
- self.assertEqual(0, UpVote.objects.filter(voter_id=self.user.id, post_id=self.topic.last_post.id).count())
+ self.assertEqual(
+ 0,
+ UpVote.objects.filter(
+ voter_id=self.user.id,
+ object_id=self.topic.last_post.id,
+ content_type_id=ContentType.objects.get_for_model(self.topic.last_post).id,
+ ).count(),
+ )
def test_topic_is_marked_as_read_when_upvoting(self):
self.assertFalse(ForumReadTrack.objects.count())
diff --git a/lacommunaute/forum_upvote/views.py b/lacommunaute/forum_upvote/views.py
index 61fb70c37..9cfb6f5b9 100644
--- a/lacommunaute/forum_upvote/views.py
+++ b/lacommunaute/forum_upvote/views.py
@@ -1,5 +1,6 @@
import logging
+from django.contrib.contenttypes.models import ContentType
from django.shortcuts import get_object_or_404, render
from django.views import View
from machina.core.loading import get_class
@@ -33,13 +34,21 @@ def get_object(self):
def post(self, request, **kwargs):
post = self.get_object()
- upvote = UpVote.objects.filter(voter_id=request.user.id, post_id=post.id)
+ upvote = UpVote.objects.filter(
+ voter_id=request.user.id,
+ object_id=post.id,
+ content_type=ContentType.objects.get_for_model(post),
+ )
if upvote.exists():
upvote.delete()
post.has_upvoted = False
else:
- UpVote(voter_id=request.user.id, post_id=post.id).save()
+ UpVote(
+ voter_id=request.user.id,
+ object_id=post.id,
+ content_type=ContentType.objects.get_for_model(post),
+ ).save()
post.has_upvoted = True
post.upvotes_count = post.upvotes.count()