Skip to content

Commit

Permalink
Add more tests, minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
henadzit committed Nov 8, 2024
1 parent 4bd6172 commit aadc8f9
Show file tree
Hide file tree
Showing 10 changed files with 210 additions and 70 deletions.
9 changes: 3 additions & 6 deletions tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,13 +298,10 @@ async def test_decimal_sum_with_math_on_models_with_validators(self) -> None:
self.assertEqual(result, [{"sum": Decimal("-2.0")}])

async def test_function_requiring_nested_joins(self):
tournament = Tournament(name="Tournament")
await tournament.save()
tournament = await Tournament.create(name="Tournament")

event_first = Event(name="1", tournament=tournament)
await event_first.save()
event_second = Event(name="2", tournament=tournament)
await event_second.save()
event_first = await Event.create(name="1", tournament=tournament)
event_second = await Event.create(name="2", tournament=tournament)

team_first = await Team.create(name="First", alias=2)
team_second = await Team.create(name="Second", alias=10)
Expand Down
60 changes: 60 additions & 0 deletions tests/test_f.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from tortoise.contrib import test
from tortoise.expressions import Connector, F


class TestF(test.TestCase):
def test_arithmetic(self):
f = F("name")

negated = -f
self.assertEqual(negated.connector, Connector.mul)
self.assertEqual(negated.right.value, -1)

added = f + 1
self.assertEqual(added.connector, Connector.add)
self.assertEqual(added.right.value, 1)

radded = 1 + f
self.assertEqual(radded.connector, Connector.add)
self.assertEqual(radded.left.value, 1)
self.assertEqual(radded.right, f)

subbed = f - 1
self.assertEqual(subbed.connector, Connector.sub)
self.assertEqual(subbed.right.value, 1)

rsubbed = 1 - f
self.assertEqual(rsubbed.connector, Connector.sub)
self.assertEqual(rsubbed.left.value, 1)

mulled = f * 2
self.assertEqual(mulled.connector, Connector.mul)
self.assertEqual(mulled.right.value, 2)

rmulled = 2 * f
self.assertEqual(rmulled.connector, Connector.mul)
self.assertEqual(rmulled.left.value, 2)

divved = f / 2
self.assertEqual(divved.connector, Connector.div)
self.assertEqual(divved.right.value, 2)

rdivved = 2 / f
self.assertEqual(rdivved.connector, Connector.div)
self.assertEqual(rdivved.left.value, 2)

powed = f**2
self.assertEqual(powed.connector, Connector.pow)
self.assertEqual(powed.right.value, 2)

rpowed = 2**f
self.assertEqual(rpowed.connector, Connector.pow)
self.assertEqual(rpowed.left.value, 2)

modded = f % 2
self.assertEqual(modded.connector, Connector.mod)
self.assertEqual(modded.right.value, 2)

rmodded = 2 % f
self.assertEqual(rmodded.connector, Connector.mod)
self.assertEqual(rmodded.left.value, 2)
47 changes: 21 additions & 26 deletions tests/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,24 +436,27 @@ async def test_annotation_in_case_when(self):
self.assertEqual(tournaments[0].is_tournament, "yes")

async def test_f_annotation_filter(self):
await IntFields.create(intnum=1)
event = await IntFields.create(intnum=1)

events = await IntFields.annotate(intnum_plus_1=F("intnum") + 1).filter(intnum_plus_1=2)
self.assertEqual(len(events), 1)
ret_events = await IntFields.annotate(intnum_plus_1=F("intnum") + 1).filter(intnum_plus_1=2)
self.assertEqual(ret_events, [event])

async def test_f_annotation_custom_filter(self):
await IntFields.create(intnum=1)
event = await IntFields.create(intnum=1)

base_query = IntFields.annotate(intnum_plus_1=F("intnum") + 1)

events = await base_query.filter(intnum_plus_1__gt=1)
self.assertEqual(len(events), 1)
ret_events = await base_query.filter(intnum_plus_1__gt=1)
self.assertEqual(ret_events, [event])

events = await base_query.filter(intnum_plus_1__lt=3)
self.assertEqual(len(events), 1)
ret_events = await base_query.filter(intnum_plus_1__lt=3)
self.assertEqual(ret_events, [event])

events = await base_query.filter(Q(intnum_plus_1__gt=1) & Q(intnum_plus_1__lt=3))
self.assertEqual(len(events), 1)
ret_events = await base_query.filter(Q(intnum_plus_1__gt=1) & Q(intnum_plus_1__lt=3))
self.assertEqual(ret_events, [event])

ret_events = await base_query.filter(intnum_plus_1__isnull=True)
self.assertEqual(ret_events, [])

async def test_f_annotation_join(self):
tournament_a = await Tournament.create(name="A")
Expand Down Expand Up @@ -484,25 +487,17 @@ async def test_f_annotation_custom_filter_requiring_join(self):
self.assertEqual(events, [event_b])

async def test_f_annotation_custom_filter_requiring_nested_joins(self):
tournament = Tournament(name="Tournament")
await tournament.save()
tournament = await Tournament.create(name="Tournament")

second_tournament = Tournament(name="Tournament 2")
await second_tournament.save()
second_tournament = await Tournament.create(name="Tournament 2")

event_first = Event(name="1", tournament=tournament)
await event_first.save()
event_second = Event(name="2", tournament=second_tournament)
await event_second.save()
event_third = Event(name="3", tournament=tournament)
await event_third.save()
event_forth = Event(name="4", tournament=second_tournament)
await event_forth.save()
event_first = await Event.create(name="1", tournament=tournament)
event_second = await Event.create(name="2", tournament=second_tournament)
await Event.create(name="3", tournament=tournament)
await Event.create(name="4", tournament=second_tournament)

team_first = Team(name="First")
await team_first.save()
team_second = Team(name="Second")
await team_second.save()
team_first = await Team.create(name="First")
team_second = await Team.create(name="Second")

await team_first.events.add(event_first)
await event_second.participants.add(team_second)
Expand Down
16 changes: 7 additions & 9 deletions tests/test_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,15 +234,15 @@ async def test_group_by_annotate_result(self):
)

async def test_group_by_requiring_nested_joins(self):
tournament_first = await Tournament.create(name="Tournament 1")
tournament_second = await Tournament.create(name="Tournament 2")
tournament_first = await Tournament.create(name="Tournament 1", desc="d1")
tournament_second = await Tournament.create(name="Tournament 2", desc="d2")

event_first = await Event.create(name="1", tournament=tournament_first)
event_second = await Event.create(name="2", tournament=tournament_first)
event_third = await Event.create(name="3", tournament=tournament_second)

team_first = await Team.create(name="First", alias=2)
team_second = await Team.create(name="Second", alias=3)
team_second = await Team.create(name="Second", alias=4)
team_third = await Team.create(name="Third", alias=5)

await team_first.events.add(event_first)
Expand All @@ -251,10 +251,8 @@ async def test_group_by_requiring_nested_joins(self):

res = (
await Tournament.annotate(avg=Avg("events__participants__alias"))
.group_by("id")
.order_by("name")
.values("name", "avg")
)
self.assertEqual(
res, [{"avg": 2, "name": "Tournament 1"}, {"avg": 5, "name": "Tournament 2"}]
.group_by("desc")
.order_by("desc")
.values("desc", "avg")
)
self.assertEqual(res, [{"avg": 3, "desc": "d1"}, {"avg": 5, "desc": "d2"}])
7 changes: 7 additions & 0 deletions tests/test_queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,14 @@ async def test_f_annotation_referenced_in_annotation(self):
events = await IntFields.annotate(intnum_plus_1=F("intnum") + 1).annotate(
intnum_plus_2=F("intnum_plus_1") + 1
)
self.assertEqual(len(events), 1)
self.assertEqual(events[0].intnum_plus_1, 2)
self.assertEqual(events[0].intnum_plus_2, 3)

# in a single annotate call
events = await IntFields.annotate(
intnum_plus_1=F("intnum") + 1, intnum_plus_2=F("intnum_plus_1") + 1
)
self.assertEqual(len(events), 1)
self.assertEqual(events[0].intnum_plus_1, 2)
self.assertEqual(events[0].intnum_plus_2, 3)
Expand Down
22 changes: 22 additions & 0 deletions tests/test_source_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,28 @@ async def test_values_by_fk(self):
obj = await self.model.filter(chars="bbb").values("fk__chars")
self.assertEqual(obj, [{"fk__chars": "aaa"}])

async def test_filter_with_field_f(self):
obj = await self.model.create(chars="a")
ret_obj = await self.model.filter(eyedee=F("eyedee")).first()
self.assertEqual(obj, ret_obj)

ret_obj = await self.model.filter(eyedee__lt=F("eyedee") + 1).first()
self.assertEqual(obj, ret_obj)

async def test_filter_with_field_f_annotation(self):
obj = await self.model.create(chars="a")
ret_obj = (
await self.model.annotate(eyedee_a=F("eyedee")).filter(eyedee=F("eyedee_a")).first()
)
self.assertEqual(obj, ret_obj)

ret_obj = (
await self.model.annotate(eyedee_a=F("eyedee") + 1)
.filter(eyedee__lt=F("eyedee_a"))
.first()
)
self.assertEqual(obj, ret_obj)


class SourceFieldTests(StraightFieldTests):
def setUp(self) -> None:
Expand Down
45 changes: 43 additions & 2 deletions tests/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
)
from tortoise.contrib import test
from tortoise.contrib.test.condition import In, NotEQ
from tortoise.expressions import F
from tortoise.functions import Function
from tortoise.expressions import Case, F, Q, When
from tortoise.functions import Function, Upper


class TestUpdate(test.TestCase):
Expand Down Expand Up @@ -205,3 +205,44 @@ async def test_update_with_limit_ordering(self):
await Tournament.filter(name="1").limit(1).order_by("-id").update(name="2")
self.assertIs((await Tournament.get(pk=t2.pk)).name, "2")
self.assertEqual(await Tournament.filter(name="1").count(), 1)

# tortoise-pypika does not translate ** to POWER in MSSQL
@test.requireCapability(dialect=NotEQ("mssql"))
async def test_update_with_case_when_and_f(self):
event1 = await IntFields.create(intnum=1)
event2 = await IntFields.create(intnum=2)
event3 = await IntFields.create(intnum=3)
await (
IntFields.all()
.annotate(
intnum_updated=Case(
When(
Q(intnum=1),
then=F("intnum") + 1,
),
When(
Q(intnum=2),
then=F("intnum") * 2,
),
default=F("intnum") ** 3,
)
)
.update(intnum=F("intnum_updated"))
)

for e in [event1, event2, event3]:
await e.refresh_from_db()
self.assertEqual(event1.intnum, 2)
self.assertEqual(event2.intnum, 4)
self.assertEqual(event3.intnum, 27)

async def test_update_with_function_annotation(self):
tournament = await Tournament.create(name="aaa")
await (
Tournament.filter(pk=tournament.pk)
.annotate(
upped_name=Upper(F("name")),
)
.update(name=F("upped_name"))
)
self.assertEqual((await Tournament.get(pk=tournament.pk)).name, "AAA")
Loading

0 comments on commit aadc8f9

Please sign in to comment.