Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support annotations and joins in F() #1761

Open
wants to merge 8 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions docs/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,15 @@ And you can also use functions in update, the example is only suitable for MySQL
.. code-block:: python3
from tortoise.expressions import F
from pypika.terms import Function
from tortoise.functions import Function
from pypika.terms import Function as PupikaFunction
class JsonSet(Function):
def __init__(self, field: F, expression: str, value: Any):
super().__init__("JSON_SET", field, expression, value)
class PypikaJsonSet(PupikaFunction):
def __init__(self, field: F, expression: str, value: Any):
super().__init__("JSON_SET", field, expression, value)
database_func = PypikaJsonSet
json = await JSONFields.create(data_default={"a": 1})
json.data_default = JsonSet(F("data_default"), "$.a", 2)
Expand Down
26 changes: 26 additions & 0 deletions tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,17 @@ async def test_count_after_aggregate_m2m(self):
res = await query.count()
assert res == 2

async def test_where_and_having(self):
author = await Author.create(name="1")
await Book.create(name="First!", author=author, rating=4)
await Book.create(name="Second!", author=author, rating=3)
await Book.create(name="Third!", author=author, rating=3)

query = Book.exclude(name="First!").annotate(avg_rating=Avg("rating")).values("avg_rating")
result = await query
assert len(result) == 1
assert result[0]["avg_rating"] == 3

async def test_count_without_matching(self) -> None:
await Tournament.create(name="Test")

Expand Down Expand Up @@ -285,3 +296,18 @@ async def test_decimal_sum_with_math_on_models_with_validators(self) -> None:
).values("sum")
result = await query
self.assertEqual(result, [{"sum": Decimal("-2.0")}])

async def test_function_requiring_nested_joins(self):
tournament = await Tournament.create(name="Tournament")

event_first = await Event.create(name="1", tournament=tournament)
event_second = await Event.create(name="2", tournament=tournament)

team_first = await Team.create(name="First", alias=2)
team_second = await Team.create(name="Second", alias=10)

await team_first.events.add(event_first)
await event_second.participants.add(team_second)

res = await Tournament.annotate(avg=Avg("events__participants__alias")).values("avg")
self.assertEqual(res, [{"avg": 6}])
60 changes: 60 additions & 0 deletions tests/test_f.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from tortoise.contrib import test
from tortoise.expressions import Connector, F


class TestF(test.TestCase):
def test_arithmetic(self):
f = F("name")

negated = -f
self.assertEqual(negated.connector, Connector.mul)
self.assertEqual(negated.right.value, -1)

added = f + 1
self.assertEqual(added.connector, Connector.add)
self.assertEqual(added.right.value, 1)

radded = 1 + f
self.assertEqual(radded.connector, Connector.add)
self.assertEqual(radded.left.value, 1)
self.assertEqual(radded.right, f)

subbed = f - 1
self.assertEqual(subbed.connector, Connector.sub)
self.assertEqual(subbed.right.value, 1)

rsubbed = 1 - f
self.assertEqual(rsubbed.connector, Connector.sub)
self.assertEqual(rsubbed.left.value, 1)

mulled = f * 2
self.assertEqual(mulled.connector, Connector.mul)
self.assertEqual(mulled.right.value, 2)

rmulled = 2 * f
self.assertEqual(rmulled.connector, Connector.mul)
self.assertEqual(rmulled.left.value, 2)

divved = f / 2
self.assertEqual(divved.connector, Connector.div)
self.assertEqual(divved.right.value, 2)

rdivved = 2 / f
self.assertEqual(rdivved.connector, Connector.div)
self.assertEqual(rdivved.left.value, 2)

powed = f**2
self.assertEqual(powed.connector, Connector.pow)
self.assertEqual(powed.right.value, 2)

rpowed = 2**f
self.assertEqual(rpowed.connector, Connector.pow)
self.assertEqual(rpowed.left.value, 2)

modded = f % 2
self.assertEqual(modded.connector, Connector.mod)
self.assertEqual(modded.right.value, 2)

rmodded = 2 % f
self.assertEqual(rmodded.connector, Connector.mod)
self.assertEqual(rmodded.left.value, 2)
72 changes: 72 additions & 0 deletions tests/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,3 +434,75 @@ async def test_annotation_in_case_when(self):
self.assertEqual(tournaments[0].name, "Tournament")
self.assertEqual(tournaments[0].name_lower, "tournament")
self.assertEqual(tournaments[0].is_tournament, "yes")

async def test_f_annotation_filter(self):
event = await IntFields.create(intnum=1)

ret_events = await IntFields.annotate(intnum_plus_1=F("intnum") + 1).filter(intnum_plus_1=2)
self.assertEqual(ret_events, [event])

async def test_f_annotation_custom_filter(self):
event = await IntFields.create(intnum=1)

base_query = IntFields.annotate(intnum_plus_1=F("intnum") + 1)

ret_events = await base_query.filter(intnum_plus_1__gt=1)
self.assertEqual(ret_events, [event])

ret_events = await base_query.filter(intnum_plus_1__lt=3)
self.assertEqual(ret_events, [event])

ret_events = await base_query.filter(Q(intnum_plus_1__gt=1) & Q(intnum_plus_1__lt=3))
self.assertEqual(ret_events, [event])

ret_events = await base_query.filter(intnum_plus_1__isnull=True)
self.assertEqual(ret_events, [])

async def test_f_annotation_join(self):
tournament_a = await Tournament.create(name="A")
tournament_b = await Tournament.create(name="B")
await Tournament.create(name="C")
event_a = await Event.create(name="A", tournament=tournament_a)
await Event.create(name="B", tournament=tournament_b)

events = (
await Event.all()
.annotate(tournament_name=F("tournament__name"))
.filter(tournament_name="A")
)
self.assertEqual(events, [event_a])

async def test_f_annotation_custom_filter_requiring_join(self):
tournament_a = await Tournament.create(name="A")
tournament_b = await Tournament.create(name="B")
await Tournament.create(name="C")
await Event.create(name="A", tournament=tournament_a)
event_b = await Event.create(name="B", tournament=tournament_b)

events = (
await Event.all()
.annotate(tournament_name=F("tournament__name"))
.filter(tournament_name__gt="A")
)
self.assertEqual(events, [event_b])

async def test_f_annotation_custom_filter_requiring_nested_joins(self):
tournament = await Tournament.create(name="Tournament")

second_tournament = await Tournament.create(name="Tournament 2")

event_first = await Event.create(name="1", tournament=tournament)
event_second = await Event.create(name="2", tournament=second_tournament)
await Event.create(name="3", tournament=tournament)
await Event.create(name="4", tournament=second_tournament)

team_first = await Team.create(name="First")
team_second = await Team.create(name="Second")

await team_first.events.add(event_first)
await event_second.participants.add(team_second)

res = await Tournament.annotate(pname=F("events__participants__name")).filter(
pname__startswith="Fir"
)
self.assertEqual(res, [tournament])
26 changes: 25 additions & 1 deletion tests/test_group_by.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from tests.testmodels import Author, Book
from tests.testmodels import Author, Book, Event, Team, Tournament
from tortoise.contrib import test
from tortoise.functions import Avg, Count, Sum, Upper

Expand Down Expand Up @@ -232,3 +232,27 @@ async def test_group_by_annotate_result(self):
[{"upper_name": "AUTHOR1", "count": 10}, {"upper_name": "AUTHOR2", "count": 5}],
sorted_key="upper_name",
)

async def test_group_by_requiring_nested_joins(self):
tournament_first = await Tournament.create(name="Tournament 1", desc="d1")
tournament_second = await Tournament.create(name="Tournament 2", desc="d2")

event_first = await Event.create(name="1", tournament=tournament_first)
event_second = await Event.create(name="2", tournament=tournament_first)
event_third = await Event.create(name="3", tournament=tournament_second)

team_first = await Team.create(name="First", alias=2)
team_second = await Team.create(name="Second", alias=4)
team_third = await Team.create(name="Third", alias=5)

await team_first.events.add(event_first)
await team_second.events.add(event_second)
await team_third.events.add(event_third)

res = (
await Tournament.annotate(avg=Avg("events__participants__alias"))
.group_by("desc")
.order_by("desc")
.values("desc", "avg")
)
self.assertEqual(res, [{"avg": 3, "desc": "d1"}, {"avg": 5, "desc": "d2"}])
4 changes: 2 additions & 2 deletions tests/test_q.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def test_annotations_resolved(self):
ResolveContext(
model=IntFields,
table=IntFields._meta.basequery,
annotations={"annotated": F("annotated")},
annotations={"annotated": F("intnum")},
custom_filters={
"annotated__lt": {
"field": "annotated",
Expand All @@ -249,4 +249,4 @@ def test_annotations_resolved(self):
},
)
)
self.assertEqual(r.where_criterion.get_sql(), '"id">5 OR "annotated"<5')
self.assertEqual(r.where_criterion.get_sql(), '"id">5 OR "intnum"<5')
27 changes: 27 additions & 0 deletions tests/test_queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,3 +863,30 @@ async def test_values_with_annotations(self):

tournaments = await base_query.values_list("name_length")
self.assertListSortEqual(tournaments, [(10,), (12,)])

async def test_f_annotation_referenced_in_annotation(self):
await IntFields.create(intnum=1)

events = await IntFields.annotate(intnum_plus_1=F("intnum") + 1).annotate(
intnum_plus_2=F("intnum_plus_1") + 1
)
self.assertEqual(len(events), 1)
self.assertEqual(events[0].intnum_plus_1, 2)
self.assertEqual(events[0].intnum_plus_2, 3)

# in a single annotate call
events = await IntFields.annotate(
intnum_plus_1=F("intnum") + 1, intnum_plus_2=F("intnum_plus_1") + 1
)
self.assertEqual(len(events), 1)
self.assertEqual(events[0].intnum_plus_1, 2)
self.assertEqual(events[0].intnum_plus_2, 3)

async def test_rawsql_annotation_referenced_in_annotation(self):
await IntFields.create(intnum=1)

events = await IntFields.annotate(ten=RawSQL("20 / 2")).annotate(ten_plus_1=F("ten") + 1)

self.assertEqual(len(events), 1)
self.assertEqual(events[0].ten, 10)
self.assertEqual(events[0].ten_plus_1, 11)
22 changes: 22 additions & 0 deletions tests/test_source_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,28 @@ async def test_values_by_fk(self):
obj = await self.model.filter(chars="bbb").values("fk__chars")
self.assertEqual(obj, [{"fk__chars": "aaa"}])

async def test_filter_with_field_f(self):
obj = await self.model.create(chars="a")
ret_obj = await self.model.filter(eyedee=F("eyedee")).first()
self.assertEqual(obj, ret_obj)

ret_obj = await self.model.filter(eyedee__lt=F("eyedee") + 1).first()
self.assertEqual(obj, ret_obj)

async def test_filter_with_field_f_annotation(self):
obj = await self.model.create(chars="a")
ret_obj = (
await self.model.annotate(eyedee_a=F("eyedee")).filter(eyedee=F("eyedee_a")).first()
)
self.assertEqual(obj, ret_obj)

ret_obj = (
await self.model.annotate(eyedee_a=F("eyedee") + 1)
.filter(eyedee__lt=F("eyedee_a"))
.first()
)
self.assertEqual(obj, ret_obj)


class SourceFieldTests(StraightFieldTests):
def setUp(self) -> None:
Expand Down
53 changes: 49 additions & 4 deletions tests/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any

import pytz
from pypika.terms import Function
from pypika.terms import Function as PupikaFunction

from tests.testmodels import (
Currency,
Expand All @@ -21,7 +21,8 @@
)
from tortoise.contrib import test
from tortoise.contrib.test.condition import In, NotEQ
from tortoise.expressions import F
from tortoise.expressions import Case, F, Q, When
from tortoise.functions import Function, Upper


class TestUpdate(test.TestCase):
Expand Down Expand Up @@ -155,8 +156,11 @@ async def test_update_relation(self):
@test.requireCapability(dialect=In("mysql", "sqlite"))
async def test_update_with_custom_function(self):
class JsonSet(Function):
def __init__(self, field: F, expression: str, value: Any):
super().__init__("JSON_SET", field, expression, value)
class PypikaJsonSet(PupikaFunction):
def __init__(self, field: F, expression: str, value: Any):
super().__init__("JSON_SET", field, expression, value)

database_func = PypikaJsonSet

json = await JSONFields.create(data={})
self.assertEqual(json.data_default, {"a": 1})
Expand Down Expand Up @@ -201,3 +205,44 @@ async def test_update_with_limit_ordering(self):
await Tournament.filter(name="1").limit(1).order_by("-id").update(name="2")
self.assertIs((await Tournament.get(pk=t2.pk)).name, "2")
self.assertEqual(await Tournament.filter(name="1").count(), 1)

# tortoise-pypika does not translate ** to POWER in MSSQL
@test.requireCapability(dialect=NotEQ("mssql"))
async def test_update_with_case_when_and_f(self):
event1 = await IntFields.create(intnum=1)
event2 = await IntFields.create(intnum=2)
event3 = await IntFields.create(intnum=3)
await (
IntFields.all()
.annotate(
intnum_updated=Case(
When(
Q(intnum=1),
then=F("intnum") + 1,
),
When(
Q(intnum=2),
then=F("intnum") * 2,
),
default=F("intnum") ** 3,
)
)
.update(intnum=F("intnum_updated"))
)

for e in [event1, event2, event3]:
await e.refresh_from_db()
self.assertEqual(event1.intnum, 2)
self.assertEqual(event2.intnum, 4)
self.assertEqual(event3.intnum, 27)

async def test_update_with_function_annotation(self):
tournament = await Tournament.create(name="aaa")
await (
Tournament.filter(pk=tournament.pk)
.annotate(
upped_name=Upper(F("name")),
)
.update(name=F("upped_name"))
)
self.assertEqual((await Tournament.get(pk=tournament.pk)).name, "AAA")
Loading