diff --git a/tests/prefetch_tests.py b/tests/prefetch_tests.py index c1b262f03..8846d08d7 100644 --- a/tests/prefetch_tests.py +++ b/tests/prefetch_tests.py @@ -525,3 +525,136 @@ def test_prefetch_mark_dirty_regression(self): self.assertEqual(note.dirty_fields, []) for item in note.items: self.assertEqual(item.dirty_fields, []) + + + +class X(TestModel): + name = TextField() +class Z(TestModel): + x = ForeignKeyField(X) + name = TextField() + +class A(TestModel): + name = TextField() + x = ForeignKeyField(X) +class B(TestModel): + name = TextField() + a = ForeignKeyField(A) + x = ForeignKeyField(X) +class C(TestModel): + name = TextField() + b = ForeignKeyField(B) + x = ForeignKeyField(X, null=True) + +class C1(TestModel): + name = TextField() + c = ForeignKeyField(C) +class C2(TestModel): + name = TextField() + c = ForeignKeyField(C) + + +class TestPrefetchMultiRefs(ModelTestCase): + database = get_in_memory_db() + requires = [X, Z, A, B, C, C1, C2] + + def test_prefetch_multirefs(self): + x1, x2, x3 = [X.create(name=n) for n in ('x1', 'x2', 'x3')] + for i, x in enumerate((x1, x2, x3), 1): + for j in range(i): + Z.create(x=x, name='%s-z%s' % (x.name, j)) + + xs = {x.name: x for x in X.select()} + xs[None] = None + + data = [ + ('a1', + 'x1', + ['x1-z0'], + [ + ('a1-b1', 'x1', ['x1-z0'], [ + ('a1-b1-c1', 'x1', ['x1-z0'], [], []), + ]), + ]), + ('a2', + 'x2', + ['x2-z0', 'x2-z1'], + [ + ('a2-b1', 'x1', ['x1-z0'], [ + ('a2-b1-c1', 'x1', ['x1-z0'], [], []), + ]), + ('a2-b2', 'x2', ['x2-z0', 'x2-z1'], [ + ('a2-b2-c1', 'x2', ['x2-z0', 'x2-z1'], [], []), + ('a2-b2-c2', 'x1', ['x1-z0'], [], []), + ('a2-b2-cx', None, [], [], []), + ]), + ]), + ('a3', + 'x3', + ['x3-z0', 'x3-z1', 'x3-z2'], + [ + ('a3-b1', 'x1', ['x1-z0'], [ + ('a3-b1-c1', 'x1', ['x1-z0'], [], []), + ]), + ('a3-b2', 'x2', ['x2-z0', 'x2-z1'], [ + ('a3-b2-c1', 'x2', ['x2-z0', 'x2-z1'], [], []), + ('a3-b2-c2', 'x2', ['x2-z0', 'x2-z1'], [], []), + ('a3-b2-cx1', None, [], [], []), + ('a3-b2-cx2', None, [], [], []), + ('a3-b2-cx3', None, [], [], []), + ]), + ('a3-b3', 'x3', ['x3-z0', 'x3-z1', 'x3-z2'], [ + ('a3-b3-c1', 'x3', ['x3-z0', 'x3-z1', 'x3-z2'], [], []), + ('a3-b3-c2', 'x3', ['x3-z0', 'x3-z1', 'x3-z2'], [], []), + ('a3-b3-c3', 'x3', ['x3-z0', 'x3-z1', 'x3-z2'], + ['c1-1', 'c1-2', 'c1-3', 'c1-4'], + ['c2-1', 'c2-2']), + ]), + ]), + ] + + for a, ax, azs, bs in data: + a = A.create(name=a, x=xs[ax]) + for b, bx, bzs, cs in bs: + b = B.create(name=b, a=a, x=xs[bx]) + for c, cx, czs, c1s, c2s in cs: + c = C.create(name=c, b=b, x=xs[cx]) + for c1 in c1s: + C1.create(name=c1, c=c) + for c2 in c2s: + C2.create(name=c2, c=c) + + + AX = X.alias('ax') + AXZ = Z.alias('axz') + BX = X.alias('bx') + BXZ = Z.alias('bxz') + CX = X.alias('cx') + CXZ = Z.alias('cxz') + + with self.assertQueryCount(11): + q = prefetch(A.select().order_by(A.name), *( + (AX, A), (AXZ, AX), + (B, A), (BX, B), (BXZ, BX), + (C, B), (CX, C), (CXZ, CX), + (C1, C), (C2, C))) + + with self.assertQueryCount(0): + accum = [] + for a in list(q): + azs = [z.name for z in a.x.z_set] + bs = [] + for b in a.b_set: + bzs = [z.name for z in b.x.z_set] + cs = [] + for c in b.c_set: + czs = [z.name for z in c.x.z_set] if c.x else [] + c1s = [c1.name for c1 in c.c1_set] + c2s = [c2.name for c2 in c.c2_set] + cs.append((c.name, c.x.name if c.x else None, czs, + c1s, c2s)) + + bs.append((b.name, b.x.name, bzs, cs)) + accum.append((a.name, a.x.name, azs, bs)) + + self.assertEqual(data, accum)