Skip to content

Commit

Permalink
Make ai::search have integrated sort and hit indexes (#7242)
Browse files Browse the repository at this point in the history
Tweak ai::search codegen to make it hit the index reliably even with
filtering NULLs out. It seems that postgres *can* sometimes manage to
use an ORDER BY index even when the function call isn't directly in
the ORDER BY, but it is much more fragile (broken by adding the NULL
check in #7223, for one).

Making ai::search return sorted output makes it easy to hit the
indexes and improves ergonomics.

Also:
 * Compile the arguments in the enclosing scope, which helps us hit the
   index in more complex scenarios (like a cast from json)
 * Make sure to export a source rvar for `.object`
  • Loading branch information
msullivan committed Apr 22, 2024
1 parent f56336c commit 7c29a9b
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 14 deletions.
42 changes: 36 additions & 6 deletions edb/pgsql/compiler/relgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
from edb.schema import objects as s_obj
from edb.schema import name as sn

from edb.edgeql import ast as qlast

from edb.ir import ast as irast
from edb.ir import typeutils as irtyputils
from edb.ir import utils as irutils
Expand Down Expand Up @@ -3228,6 +3230,7 @@ def _compile_call_args(
ir_set: irast.Set,
*,
skip: Collection[int] = (),
no_subquery_args: bool = False,
ctx: context.CompilerContextLevel,
) -> List[pgast.BaseExpr]:
"""
Expand Down Expand Up @@ -3263,6 +3266,7 @@ def _compile_call_args(
and ir_arg.cardinality.is_single()
and (arg_typeref.is_scalar or arg_typeref.collection)
and not _needs_arg_null_check(expr, ir_arg, typemod, ctx=ctx)
and not no_subquery_args
)

if make_subquery:
Expand Down Expand Up @@ -3991,9 +3995,28 @@ def _ext_ai_search_inner_pgvector(
],
)

# Install the filter directly in newctx.rel. We could return it
# and have it put in inner_ctx.rel, and that does seem to work,
# but seems weirder.
valid = pgast.NullTest(arg=embedding, negated=True)
newctx.rel.where_clause = astutils.extend_binop(
newctx.rel.where_clause, valid
)

return similarity, valid
# Do an integrated sort. This ensures we can hit the index, and is
# more ergonomic anyway. Having the ORDER BY operate directly on
# the function call is not the *only* way to have it work, but it
# is the most reliable.
sort_by = pgast.SortBy(
node=similarity,
dir=qlast.SortOrder.Asc,
nulls=qlast.NonesOrder.Last,
)
if newctx.rel.sort_clause is None:
newctx.rel.sort_clause = []
newctx.rel.sort_clause.append(sort_by)

return similarity, None


def _process_set_as_object_search(
Expand All @@ -4003,17 +4026,22 @@ def _process_set_as_object_search(
ctx: context.CompilerContextLevel,
) -> SetRVars:
func_call = ir_set.expr

# We skip the object, as it has to be compiled as rvar source.
#
# Also, disable subquery args. ai::search needs it for its
# scoping effects, but we don't need to use it here, since
# it can cause the ai search to duplicate arguments.
args_pg = _compile_call_args(
ir_set, skip={0}, no_subquery_args=True, ctx=ctx)

with ctx.subrel() as newctx:
newctx.expr_exposed = False

obj_ir = func_call.args[0].expr
obj_id = obj_ir.path_id
obj_rvar = ensure_source_rvar(obj_ir, newctx.rel, ctx=newctx)

# we skip the object, as it has to be compiled as rvar source
args_pg = _compile_call_args(
ir_set, skip={0}, ctx=newctx)

out_obj_id, out_score_id = func_call.tuple_path_ids

with newctx.subrel() as inner_ctx:
Expand Down Expand Up @@ -4082,13 +4110,15 @@ def _process_set_as_object_search(

pathctx.put_path_id_map(newctx.rel, out_obj_id, obj_id)

aspects = {'value'}
aspects = {'value', 'source'}

func_rvar = relctx.new_rel_rvar(ir_set, newctx.rel, ctx=ctx)
relctx.include_rvar(
ctx.rel, func_rvar, ir_set.path_id, aspects=aspects, ctx=ctx
)

pathctx.put_path_rvar(ctx.rel, out_obj_id, func_rvar, aspect='source')

return new_stmt_set_rvar(ir_set, ctx.rel, aspects=aspects, ctx=ctx)


Expand Down
6 changes: 6 additions & 0 deletions tests/schemas/ext_ai.esdl
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,9 @@ type Stuff extending Astronomy {
type Star extending Astronomy;

type Supernova extending Star;

function _set_seqscan(val: std::str) -> std::str {
using sql $$
select set_config('enable_seqscan', val, true)
$$;
};
17 changes: 9 additions & 8 deletions tests/test_edgeql_data_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12542,12 +12542,13 @@ async def test_edgeql_migration_ai_08(self):
};
''', explicit_modules=True)

arg = [0.0] * 1536
await self.con.query('''
select {
base := ext::ai::search(Base, <array<float32>>[1]),
sub := ext::ai::search(Sub, <array<float32>>[1]),
base := ext::ai::search(Base, <array<float32>>$0),
sub := ext::ai::search(Sub, <array<float32>>$0),
}
''')
''', arg)

await self.migrate('''
using extension ai;
Expand All @@ -12571,10 +12572,10 @@ async def test_edgeql_migration_ai_08(self):

await self.con.query('''
select {
base := ext::ai::search(Base, <array<float32>>[1]),
sub := ext::ai::search(Sub, <array<float32>>[1]),
base := ext::ai::search(Base, <array<float32>>$0),
sub := ext::ai::search(Sub, <array<float32>>$0),
}
''')
''', arg)

await self.migrate('''
using extension ai;
Expand All @@ -12596,9 +12597,9 @@ async def test_edgeql_migration_ai_08(self):
# Base lost the index, just select Sub
await self.con.query('''
select {
sub := ext::ai::search(Sub, <array<float32>>[1]),
sub := ext::ai::search(Sub, <array<float32>>$0),
}
''')
''', arg)


class EdgeQLMigrationRewriteTestCase(EdgeQLDataMigrationTestCase):
Expand Down
60 changes: 60 additions & 0 deletions tests/test_ext_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,3 +305,63 @@ async def test_ext_ai_indexing_03(self):
],
}
)

async def _assert_index_use(self, query, *args):
def look(obj):
if isinstance(obj, dict) and obj.get('plan_type') == "IndexScan":
return any(
prop['title'] == 'index_name'
and f'ai::index' in prop['value']
for prop in obj.get('properties', [])
)

if isinstance(obj, dict):
return any([look(v) for v in obj.values()])
elif isinstance(obj, list):
return any(look(v) for v in obj)
else:
return False

async with self._run_and_rollback():
await self.con.execute('select _set_seqscan("off");')
plan = await self.con.query_json(f'analyze {query};', *args)
if not look(json.loads(plan)):
raise AssertionError(f'query did not use ext::ai::index index')

async def test_ext_ai_indexing_04(self):
qv = [1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0, 9.0, -10.0]

await self._assert_index_use(
f'''
with vector := <array<float32>>$0
select ext::ai::search(Stuff, vector) limit 5;
''',
qv,
)
await self._assert_index_use(
f'''
with vector := <array<float32>>$0
select ext::ai::search(Stuff, vector).object limit 5;
''',
qv,
)
await self._assert_index_use(
f'''
select ext::ai::search(Stuff, <array<float32>>$0) limit 5;
''',
qv,
)

await self._assert_index_use(
f'''
with vector := <array<float32>><json>$0
select ext::ai::search(Stuff, vector) limit 5;
''',
json.dumps(qv),
)
await self._assert_index_use(
f'''
select ext::ai::search(Stuff, <array<float32>><json>$0) limit 5;
''',
json.dumps(qv),
)

0 comments on commit 7c29a9b

Please sign in to comment.