Skip to content

Commit

Permalink
QueryTranslator: Implemented predictive index support
Browse files Browse the repository at this point in the history
  • Loading branch information
snej committed Sep 18, 2024
1 parent 57962fe commit dea746f
Show file tree
Hide file tree
Showing 12 changed files with 206 additions and 77 deletions.
25 changes: 16 additions & 9 deletions LiteCore/Query/SQLiteKeyStore+PredictiveIndexes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ namespace litecore {
// Derive the table name from the expression (path) it unnests:
auto kvTableName = tableName();
auto q_kvTableName = quotedTableName();
QueryTranslator qp(db(), "", kvTableName);
QueryTranslator qp(db(), string(kDefaultCollectionName), kvTableName);
auto predTableName = qp.predictiveTableName((FLValue)expression);

// Create the index table, unless an identical one already exists:
Expand All @@ -75,23 +75,30 @@ namespace litecore {
if ( !db().schemaExistsWithSQL(predTableName, "table", predTableName, sql) ) {
LogTo(QueryLog, "Creating predictive table '%s' on %s", predTableName.c_str(),
expression->toJSONString().c_str());
// Capture the SQL of the `predict(...)` call, _before_ creating the table.
// (If we created the table first, the query translator would generate SQL that used it!)
string predictExpr = qp.expressionSQL((FLValue)expression);
qp.setBodyColumnName("new.body");
string triggerPredictExpr = qp.expressionSQL((FLValue)expression);

// Create the index-table:
LogTo(QueryLog, "Creating predictive index table: %s", sql.c_str());
db().exec(sql);

// Populate the index-table with data from existing documents:
string predictExpr = qp.expressionSQL((FLValue)expression);
db().exec(CONCAT("INSERT INTO " << sqlIdentifier(predTableName)
<< " (docid, body) "
"SELECT rowid, "
<< predictExpr << "FROM " << q_kvTableName << " WHERE (flags & 1) = 0"));
sql = CONCAT("INSERT INTO " << sqlIdentifier(predTableName)
<< " (docid, body) "
"SELECT rowid, "
<< predictExpr << "FROM " << q_kvTableName << " as _doc WHERE (flags & 1) = 0");
LogTo(QueryLog, "Populating predictive index table: %s", sql.c_str());
db().exec(sql);

// Set up triggers to keep the index-table up to date
// ...on insertion:
qp.setBodyColumnName("new.body");
predictExpr = qp.expressionSQL((FLValue)expression);
string insertTriggerExpr = CONCAT("INSERT INTO " << sqlIdentifier(predTableName)
<< " (docid, body) "
"VALUES (new.rowid, "
<< predictExpr << ")");
<< triggerPredictExpr << ")");
createTrigger(predTableName, "ins", "AFTER INSERT", "WHEN (new.flags & 1) = 0", insertTriggerExpr);

// ...on delete:
Expand Down
2 changes: 2 additions & 0 deletions LiteCore/Query/Translator/ExprNodes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ namespace litecore::qt {
#ifdef COUCHBASE_ENTERPRISE
case OpType::vectorDistance:
return new (ctx) VectorDistanceNode(operands, ctx);
case OpType::prediction:
return PredictionNode::parse(operands, ctx);
#endif
default:
// A normal OpNode
Expand Down
105 changes: 74 additions & 31 deletions LiteCore/Query/Translator/IndexedNodes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,25 @@ namespace litecore::qt {
using namespace fleece;

// indexed by IndexType:
constexpr const char* kOwnerFnName[2] = {"MATCH", "APPROX_VECTOR_DISTANCE"};
constexpr const char* kIndexTypeName[3] = {"FTS", "vector", "predictive"};
constexpr const char* kOwnerFnName[3] = {"MATCH", "APPROX_VECTOR_DISTANCE", "PREDICTION"};

/// Stores `expression` as the indexed expression, then walks its node tree to find the one
/// source collection it refers to. Fails if the tree refers to more than one source, or to none.
void IndexedNode::setIndexedExpression(ExprNode* expression) {
    // Name of the owning query function, used in both error messages below.
    const char* const ownerFn = kOwnerFnName[int(_type)];

    _indexedExpr = expression;
    expression->visitTree([&](Node& visited, unsigned /*depth*/) {
        SourceNode* found = visited.source();
        if ( !found ) return;  // this node isn't tied to a source; nothing to check
        require(_sourceCollection == nullptr || _sourceCollection == found,
                "1st argument to %s may only refer to a single collection", ownerFn);
        _sourceCollection = found;
    });
    require(_sourceCollection, "unknown source collection for %s()", ownerFn);
}

/// Default implementation: writes the index table's name as a SQL identifier.
/// Fails with a "missing ... index" error if no table name was supplied.
void IndexedNode::writeSourceTable(SQLWriter& ctx, string_view tableName) const {
    const bool haveTable = !tableName.empty();
    require(haveTable, "missing %s index", kIndexTypeName[int(_type)]);
    ctx << sqlIdentifier(tableName);
}

#pragma mark - FTS:

Expand All @@ -38,13 +56,8 @@ namespace litecore::qt {
require(source, "unknown source collection for %s()", name);
require(source->isCollection(), "invalid source collection for %s()", name);
require(path.count() > 0, "missing property after collection alias in %s()", name);
_sourceCollection = source;
_indexExpressionJSON = string(path.toString());
}

void FTSNode::writeSourceTable(SQLWriter& ctx, string_view tableName) const {
require(!tableName.empty(), "missing FTS index");
ctx << sqlIdentifier(tableName);
_sourceCollection = source;
_indexID = ctx.newString(path.toString());
}

void FTSNode::writeIndex(SQLWriter& sql) const {
Expand All @@ -71,11 +84,10 @@ namespace litecore::qt {
ctx << "))";
}

#pragma mark - VECTOR:


#ifdef COUCHBASE_ENTERPRISE

# pragma mark - VECTOR:

// A SQLite vector MATCH expression; used by VectorDistanceNode to add a join condition.
class VectorMatchNode final : public ExprNode {
public:
Expand All @@ -90,19 +102,9 @@ namespace litecore::qt {
ExprNode* _vector;
};

VectorDistanceNode::VectorDistanceNode(Array::iterator& args, ParseContext& ctx)
: IndexedNode(IndexType::vector), _indexedExpr(parse(args[0], ctx)) {
VectorDistanceNode::VectorDistanceNode(Array::iterator& args, ParseContext& ctx) : IndexedNode(IndexType::vector) {
// Determine which collection the vector is based on:
SourceNode* source = nullptr;
_indexedExpr->visitTree([&](Node& n, unsigned /*depth*/) {
if ( SourceNode* nodeSource = n.source() ) {
require(source == nullptr || source == nodeSource,
"1st argument (vector) to APPROX_VECTOR_DISTANCE may only refer to a single collection");
source = nodeSource;
}
});
require(source, "unknown source collection for APPROX_VECTOR_DISTANCE()");
_sourceCollection = source;
setIndexedExpression(ExprNode::parse(args[0], ctx));

// Create the JSON expression used to locate the index:
string indexExpr(args[0].toJSON(false, true));
Expand All @@ -118,7 +120,7 @@ namespace litecore::qt {
replace(indexExpr, "[\"." + prefix + ".", "[\".");
}
}
_indexExpressionJSON = ctx.newString(indexExpr);
_indexID = ctx.newString(indexExpr);

_vector = ExprNode::parse(args[1], ctx);

Expand Down Expand Up @@ -183,8 +185,7 @@ namespace litecore::qt {
}

void VectorDistanceNode::writeSourceTable(SQLWriter& sql, string_view tableName) const {
require(!tableName.empty(), "missing vector index");
if ( _simple ) {
if ( _simple && !tableName.empty() ) {
// In a "simple" vector match, run the vector query as a nested SELECT:
sql << "(SELECT docid, distance FROM " << sqlIdentifier(tableName) << " WHERE vector MATCH encode_vector("
<< _vector << ")";
Expand All @@ -193,7 +194,7 @@ namespace litecore::qt {
require(limit, "a LIMIT must be given when using APPROX_VECTOR_DISTANCE()");
sql << " LIMIT " << limit << ")";
} else {
sql << sqlIdentifier(tableName);
IndexedNode::writeSourceTable(sql, tableName);
}
}

Expand All @@ -203,6 +204,49 @@ namespace litecore::qt {
ctx << sqlIdentifier(_indexSource->alias()) << ".distance";
}

# pragma mark - PREDICTION:

/// Parses a `PREDICTION()` call, choosing between an index-backed node and a plain function call.
///
/// @param args  Iterator over the call's operands; the leading `PREDICTION()` operator name is
///              not included. The first two operands form the index identifier.
/// @param ctx   The current parse context; its delegate is asked whether a predictive index exists.
/// @return  A new PredictionNode if a matching predictive index exists, otherwise whatever
///          `FunctionNode::parse` returns for the ordinary prediction function.
ExprNode* PredictionNode::parse(Array::iterator args, ParseContext& ctx) {
    // Unlike a vector or FTS query, a prediction() is not required to have an index.
    // Check whether one exists. Unfortunately, the index identifier is based on the entire
    // expression array including the first item `PREDICTION()` which isn't in the iterator,
    // so we have to reconstruct it:
    auto expr = MutableArray::newArray();
    expr.append("PREDICTION()");
    expr.append(args[0]);
    expr.append(args[1]);
    string id = expressionIdentifier(expr);

    // The delegate callback is a std::function that may never have been assigned; invoking an
    // empty one throws std::bad_function_call. Since an index is optional here, treat a missing
    // callback the same as "no index" and fall back to the un-indexed function.
    auto const& hasIndex = ctx.delegate.hasPredictiveIndex;
    if ( hasIndex && hasIndex(id) ) {
        return new (ctx) PredictionNode(args, ctx, id);
    } else {
        return FunctionNode::parse(kPredictionFnName, args, ctx);
    }
}

/// Constructs an index-backed prediction node.
/// `indexID` is copied via the parse context; the second operand becomes the indexed expression;
/// an optional third operand is parsed as a property path and stored as the sub-property.
PredictionNode::PredictionNode(Array::iterator& args, ParseContext& ctx, string_view indexID)
    : IndexedNode(IndexType::prediction) {
    _indexID = ctx.newString(indexID);
    setIndexedExpression(ExprNode::parse(args[1], ctx));

    // No third operand means no sub-property; `_subProperty` keeps its null default.
    if ( args.count() <= 2 ) return;

    slice   rawPath = requiredString(args[2], "property path of PREDICTION()");
    KeyPath subPath = parsePath(rawPath);
    require(subPath.count() > 0, "invalid property path in PREDICTION()");
    _subProperty = ctx.newString(subPath.toString());
}

/// Writes SQL that reads the prediction result from the index source's `body` column:
/// the whole body when there is no sub-property, otherwise the named sub-property of it.
void PredictionNode::writeSQL(SQLWriter& out) const {
    auto indexAlias = sqlIdentifier(_indexSource->alias());
    if ( !_subProperty ) {
        out << kRootFnName << "(" << indexAlias << ".body)";
        return;
    }
    out << kUnnestedValueFnName << "(" << indexAlias << ".body, " << sqlString(_subProperty) << ")";
}


#endif


Expand All @@ -221,15 +265,14 @@ namespace litecore::qt {
}

bool IndexSourceNode::matchesNode(const IndexedNode* node) const {
return _indexedNode->indexType() == node->indexType()
&& _indexedNode->indexExpressionJSON() == node->indexExpressionJSON()
return _indexedNode->indexType() == node->indexType() && _indexedNode->indexID() == node->indexID()
&& collection() == node->sourceCollection()->collection()
&& scope() == node->sourceCollection()->scope();
}

/// Returns the index type of the indexed node this source holds.
IndexType IndexSourceNode::indexType() const { return _indexedNode->indexType(); }

string_view IndexSourceNode::indexedExpressionJSON() const { return _indexedNode->indexExpressionJSON(); }
/// Returns the index identifier of the indexed node this source holds.
string_view IndexSourceNode::indexID() const { return _indexedNode->indexID(); }

void IndexSourceNode::addIndexedNode(IndexedNode* node) {
Assert(node != _indexedNode && node->indexType() == _indexedNode->indexType());
Expand Down Expand Up @@ -300,7 +343,7 @@ namespace litecore::qt {
/// Adds a SourceNode for an IndexedNode, or finds an existing one.
/// Sets the source as its indexSource.
void SelectNode::addIndexForNode(IndexedNode* node, ParseContext& ctx) {
DebugAssert(!node->indexExpressionJSON().empty());
DebugAssert(!node->indexID().empty());

// Look for an existing index source:
IndexSourceNode* indexSrc = nullptr;
Expand Down
28 changes: 21 additions & 7 deletions LiteCore/Query/Translator/IndexedNodes.hh
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ namespace litecore::qt {
public:
IndexType indexType() const { return _type; }

/// JSON of the indexed expression, usually a property
string_view indexExpressionJSON() const { return _indexExpressionJSON; }
/// A unique identifier of the indexed expression, used to match it with an IndexSourceNode.
string_view indexID() const { return _indexID; }

/// The collection being searched.
SourceNode* C4NULLABLE sourceCollection() const { return _sourceCollection; }
Expand All @@ -49,13 +49,16 @@ namespace litecore::qt {
bool isAuxiliary() const { return _isAuxiliary; }

/// Writes SQL for the index table name (or SELECT expression)
virtual void writeSourceTable(SQLWriter& ctx, string_view tableName) const = 0;
virtual void writeSourceTable(SQLWriter& ctx, string_view tableName) const;

protected:
IndexedNode(IndexType type) : _type(type) {}

void setIndexedExpression(ExprNode*);

IndexType const _type; // Index type
string _indexExpressionJSON; // Expression/property that's indexed, as JSON
ExprNode* _indexedExpr; // The indexed expression (usually a doc property)
string_view _indexID; // Expression/property that's indexed
SourceNode* C4NULLABLE _sourceCollection{}; // The collection being queried
IndexSourceNode* C4NULLABLE _indexSource{}; // Source representing the index
SelectNode* C4NULLABLE _select{}; // The containing SELECT statement
Expand All @@ -67,7 +70,6 @@ namespace litecore::qt {
protected:
FTSNode(Array::iterator& args, ParseContext&, const char* name);

void writeSourceTable(SQLWriter& ctx, string_view tableName) const override;
void writeIndex(SQLWriter&) const;
};

Expand Down Expand Up @@ -111,13 +113,25 @@ namespace litecore::qt {
void writeSQL(SQLWriter&) const override;

private:
ExprNode* _indexedExpr; // The indexed expression (usually a doc property)
ExprNode* _vector; // The vector being queried
int _metric; // Distance metric (actually vectorsearch::Metric)
unsigned _numProbes = 0; // Number of probes, or 0 for default
bool _simple = true; // True if this is a simple (non-hybrid) query
};

/** A `prediction()` function call that uses an index.
    The constructor is private; instances are obtained through `parse()`. */
class PredictionNode final : public IndexedNode {
public:
/// Parses the operands of a `PREDICTION()` call and returns the resulting expression node.
/// The declared return type is the general `ExprNode*`, not `PredictionNode*`.
static ExprNode* parse(Array::iterator args, ParseContext&);

/// Writes the SQL for this node.
void writeSQL(SQLWriter&) const override;

private:
/// @param indexID  Identifier of the predictive index this node uses.
PredictionNode(Array::iterator& args, ParseContext& ctx, string_view indexID);

// Optional sub-property string; null by default.
const char* _subProperty{};
};

#endif

#pragma mark - INDEX SOURCE:
Expand All @@ -128,7 +142,7 @@ namespace litecore::qt {
explicit IndexSourceNode(IndexedNode*, string_view alias, ParseContext& ctx);

IndexType indexType() const;
string_view indexedExpressionJSON() const;
string_view indexID() const;

bool matchesNode(IndexedNode const*) const;

Expand Down
3 changes: 2 additions & 1 deletion LiteCore/Query/Translator/Node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ namespace litecore::qt {
// Typical queries only allocate a few KB, not enough to fill a single chunk.
static constexpr size_t kArenaChunkSize = 4000;

RootContext::RootContext() : Arena(kArenaChunkSize), ParseContext(*static_cast<Arena*>(this)) {}
// Passes references to this object's own ParseDelegate and Arena base subobjects to the
// ParseContext base, so the context's `delegate` and `arena` refer back into this object.
// NOTE(review): because those are references into `this`, a defaulted move constructor would
// leave the new object's ParseContext referring to the moved-from object — confirm intended.
RootContext::RootContext()
: Arena(kArenaChunkSize), ParseContext(*static_cast<ParseDelegate*>(this), *static_cast<Arena*>(this)) {}

/// Placement-style allocator: obtains storage for a Node from the parse context's arena.
void* Node::operator new(size_t size, ParseContext& ctx) noexcept {
    auto& nodeArena = ctx.arena;
    return nodeArena.alloc(size, alignof(Node));
}

Expand Down
26 changes: 22 additions & 4 deletions LiteCore/Query/Translator/Node.hh
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,32 @@ namespace litecore::qt {
};

/** Types of indexes. */
enum class IndexType { FTS, vector };
// NOTE(review): the tables marked "indexed by IndexType" (kIndexTypeName, kOwnerFnName) are
// subscripted with int(type), so enumerator order here must stay in sync with those tables.
enum class IndexType {
FTS,  // full-text search index; owner function MATCH
#ifdef COUCHBASE_ENTERPRISE
vector,      // vector index; owner function APPROX_VECTOR_DISTANCE
prediction,  // predictive index; owner function PREDICTION
#endif
};

#pragma mark - PARSE CONTEXT:

/** Callbacks the parser can invoke to ask about the environment it's parsing for.
    A ParseContext holds a reference to one of these. */
struct ParseDelegate {
#ifdef COUCHBASE_ENTERPRISE
/// Called with a predictive-index identifier; returns true if such an index exists.
/// NOTE(review): default-constructed empty — calling it before it is assigned would throw
/// std::bad_function_call; confirm every owner sets it before parsing.
std::function<bool(string_view id)> hasPredictiveIndex;
#endif
};

/** State used during parsing, passed down through the recursive descent. */
struct ParseContext {
ParseContext(Arena<>& a) : arena(a) {}
ParseContext(ParseDelegate& d, Arena<>& a) : delegate(d), arena(a) {}

// not a copy constructor! Creates a new child context.
explicit ParseContext(ParseContext& parent) : delegate(parent.delegate), arena(parent.arena){};

ParseContext(ParseContext const& parent) : arena(parent.arena){};
ParseContext(ParseContext&&) = default;

ParseDelegate& delegate;
Arena<>& arena; // The arena allocator
SelectNode* C4NULLABLE select{}; // The enclosing SELECT, if any
std::unordered_map<string, AliasedNode*> aliases; // All of the sources & named results
Expand All @@ -87,8 +103,10 @@ namespace litecore::qt {
/** Top-level Context that provides an Arena, and destructs all Nodes in its destructor. */
struct RootContext
: Arena<>
, public ParseDelegate
, public ParseContext {
RootContext();
explicit RootContext();
RootContext(RootContext&&) = default;
};

#pragma mark - NODE CLASS:
Expand Down
2 changes: 1 addition & 1 deletion LiteCore/Query/Translator/NodesToSQL.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ namespace litecore::qt {

void MetaNode::writeSQL(SQLWriter& ctx) const {
string aliasDot;
if ( _source ) aliasDot = CONCAT(sqlIdentifier(_source->alias()) << ".");
if ( _source && !_source->alias().empty() ) aliasDot = CONCAT(sqlIdentifier(_source->alias()) << ".");
writeMetaSQL(aliasDot, _property, ctx);
}

Expand Down
Loading

0 comments on commit dea746f

Please sign in to comment.