diff --git a/sqlite-vec.c b/sqlite-vec.c index 29d1e15..cf8e267 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -5024,6 +5024,399 @@ static sqlite3_module vec_static_blob_entriesModule = { /* xShadowName */ 0}; #pragma endregion +#pragma region vec_expo() table function + +void dethrone2(int k, f32 *base_distances, i64 *base_rowids, size_t chunk_size, + i32 *chunk_top_idx, f32 *chunk_distances, i64 *chunk_rowids, + + i64 **out_rowids, f32 **out_distances) { + *out_rowids = sqlite3_malloc(k * sizeof(i64)); + todo_assert(*out_rowids); + *out_distances = sqlite3_malloc(k * sizeof(f32)); + todo_assert(*out_distances); + + size_t ptrA = 0; + size_t ptrB = 0; + for (int i = 0; i < k; i++) { + if (ptrA < chunk_size && (ptrB >= k || chunk_distances[chunk_top_idx[ptrA]] < base_distances[ptrB])) { + (*out_rowids)[i] = chunk_rowids[chunk_top_idx[ptrA]]; + (*out_distances)[i] = chunk_distances[chunk_top_idx[ptrA]]; + ptrA++; + } else if (ptrB < k) { + (*out_rowids)[i] = base_rowids[ptrB]; + (*out_distances)[i] = base_distances[ptrB]; + ptrB++; + } + } +} + + +typedef struct vec_expo_vtab vec_expo_vtab; +struct vec_expo_vtab { + sqlite3_vtab base; + sqlite3 * db; + char * table; + char * column; +}; + +typedef struct vec_expo_cursor vec_expo_cursor; +struct vec_expo_cursor { + sqlite3_vtab_cursor base; + sqlite3_int64 iRowid; + vec_sbe_query_plan query_plan; + struct vec0_query_knn_data * knn_data; +}; + + +static int vec_expoConnect(sqlite3 *db, void *pAux, int argc, + const char *const *argv, sqlite3_vtab **ppVtab, + char **pzErr) { + vec_expo_vtab *pNew; + assert(argc==5); + +#define VEC_EXPO_VECTOR 0 +#define VEC_EXPO_DISTANCE 1 +#define VEC_EXPO_K 2 + int rc = sqlite3_declare_vtab( + db, "CREATE TABLE x(vector, distance hidden, k hidden)"); + assert(rc == SQLITE_OK); + pNew = sqlite3_malloc(sizeof(*pNew)); + if (pNew == 0) + return SQLITE_NOMEM; + memset(pNew, 0, sizeof(*pNew)); + pNew->db = db; + pNew->table = sqlite3_mprintf("%s", argv[3]); + pNew->column = sqlite3_mprintf("%s", argv[4]); + *ppVtab = (sqlite3_vtab *)pNew; + return SQLITE_OK; +} + +static int vec_expoCreate(sqlite3 *db, void *pAux, int argc, + const char *const *argv, sqlite3_vtab **ppVtab, + char **pzErr) { + return vec_expoConnect(db, pAux, argc, argv, ppVtab, pzErr); + } + +static int vec_expoDisconnect(sqlite3_vtab *pVtab) { + vec_expo_vtab *p = (vec_expo_vtab *)pVtab; + sqlite3_free(p); + return SQLITE_OK; +} + +static int vec_expoOpen(sqlite3_vtab *p, sqlite3_vtab_cursor **ppCursor) { + vec_expo_cursor *pCur; + pCur = sqlite3_malloc(sizeof(*pCur)); + if (pCur == 0) + return SQLITE_NOMEM; + memset(pCur, 0, sizeof(*pCur)); + *ppCursor = &pCur->base; + return SQLITE_OK; +} + +static int vec_expoClose(sqlite3_vtab_cursor *cur) { + vec_expo_cursor *pCur = (vec_expo_cursor *)cur; + sqlite3_free(pCur); + return SQLITE_OK; +} + +static int vec_expoBestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) { + vec_expo_vtab *p = (vec_expo_vtab *)pVTab; + int iMatchTerm = -1; + int iLimitTerm = -1; + int iRowidTerm = -1; // TODO point query + int iKTerm = -1; + + for (int i = 0; i < pIdxInfo->nConstraint; i++) { + if (!pIdxInfo->aConstraint[i].usable) + continue; + + int iColumn = pIdxInfo->aConstraint[i].iColumn; + int op = pIdxInfo->aConstraint[i].op; + if (op == SQLITE_INDEX_CONSTRAINT_MATCH && iColumn == VEC_EXPO_VECTOR) { + if (iMatchTerm > -1) { + // TODO only 1 match operator at a time + return SQLITE_ERROR; + } + iMatchTerm = i; + } + if (op == SQLITE_INDEX_CONSTRAINT_LIMIT) { + iLimitTerm = i; + } + if (op == SQLITE_INDEX_CONSTRAINT_EQ && iColumn == VEC_EXPO_K) { + iKTerm = i; + } + } + if(iMatchTerm >= 0) { + if (iLimitTerm < 0 && iKTerm < 0) { + // TODO: error, match on vector1 should require a limit for KNN + return SQLITE_ERROR; + } + if (iLimitTerm >= 0 && iKTerm >= 0) { + return SQLITE_ERROR; // limit or k, not both + } + if (pIdxInfo->nOrderBy < 1) { + SET_VTAB_ERROR("ORDER BY distance required"); + return SQLITE_CONSTRAINT; + } + if (pIdxInfo->nOrderBy > 1) { + // TODO error, orderByConsumed is all or nothing, only 1 order by allowed + SET_VTAB_ERROR("more than 1 ORDER BY clause provided"); + return SQLITE_CONSTRAINT; + } + if (pIdxInfo->aOrderBy[0].iColumn != VEC_EXPO_DISTANCE) { + SET_VTAB_ERROR("ORDER BY must be on the distance column"); + return SQLITE_CONSTRAINT; + } + if (pIdxInfo->aOrderBy[0].desc) { + SET_VTAB_ERROR("Only ascending in ORDER BY distance clause is supported, " + "DESC is not supported yet."); + return SQLITE_CONSTRAINT; + } + + pIdxInfo->idxNum = VEC_SBE__QUERYPLAN_KNN; + pIdxInfo->estimatedCost = (double)10; // TODO vtab_value(?) as hint? + pIdxInfo->estimatedRows = 10;// TODO vtab_value(?) as hint? + + pIdxInfo->orderByConsumed = 1; + pIdxInfo->aConstraintUsage[iMatchTerm].argvIndex = 1; + pIdxInfo->aConstraintUsage[iMatchTerm].omit = 1; + if (iLimitTerm >= 0) { + pIdxInfo->aConstraintUsage[iLimitTerm].argvIndex = 2; + pIdxInfo->aConstraintUsage[iLimitTerm].omit = 1; + } else { + pIdxInfo->aConstraintUsage[iKTerm].argvIndex = 2; + pIdxInfo->aConstraintUsage[iKTerm].omit = 1; + } + + } + else { + pIdxInfo->idxNum = VEC_SBE__QUERYPLAN_FULLSCAN; + pIdxInfo->estimatedCost = 10000.0; + pIdxInfo->estimatedRows = 10000; + } + return SQLITE_OK; +} + +static int vec_expoFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum, + const char *idxStr, int argc, sqlite3_value **argv) { + vec_expo_cursor *pCur = (vec_expo_cursor *)pVtabCursor; + vec_expo_vtab *p = (vec_expo_vtab *)pCur->base.pVtab; + + if(idxNum == VEC_SBE__QUERYPLAN_KNN) { + pCur->query_plan = VEC_SBE__QUERYPLAN_KNN; + struct vec0_query_knn_data *knn_data = + sqlite3_malloc(sizeof(struct vec0_query_knn_data)); + if (!knn_data) { + return SQLITE_NOMEM; + } + memset(knn_data, 0, sizeof(struct vec0_query_knn_data)); + + void *queryVector; + size_t dimensions; + enum VectorElementType elementType; + vector_cleanup cleanup; + char *err; + int rc = vector_from_value(argv[0], &queryVector, &dimensions, &elementType, + &cleanup, &err); + todo_assert(elementType == SQLITE_VEC_ELEMENT_TYPE_FLOAT32); + todo_assert(dimensions > 0); + + i64 k = sqlite3_value_int64(argv[1]); + todo_assert(k >= 0); + if (k == 0) { + knn_data->k = 0; + pCur->knn_data = knn_data; + return SQLITE_OK; + } + + i64 *topk_rowids = sqlite3_malloc(k * sizeof(i64)); + todo_assert(topk_rowids); + f32 *topk_distances = sqlite3_malloc(k * sizeof(f32)); + todo_assert(topk_distances); + + sqlite3_stmt * stmtRowids; + char * zSql = sqlite3_mprintf("select rowid from \"%w\" ", p->table); + assert(zSql); + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmtRowids, NULL); + assert(rc == SQLITE_OK); + + sqlite3_blob * baseVectorsBlob; + sqlite3_blob_open(p->db, "main", p->table, p->column, 1, 0, &baseVectorsBlob); + + int chunk_size = 200; + float * chunk = sqlite3_malloc(dimensions * chunk_size * sizeof(float)); + assert(chunk); + + f32 *chunk_distances = sqlite3_malloc(chunk_size * sizeof(f32)); + todo_assert(chunk_distances); + for (int i = 0; i < k; i++) { + topk_distances[i] = __FLT_MAX__; + } + i64 *chunk_rowids = sqlite3_malloc(chunk_size * sizeof(i64)); + todo_assert(chunk_rowids); + + + + while(true) { + int nused = 0; + for(int i = 0; i < chunk_size; i++) { + rc = sqlite3_step(stmtRowids); + if(rc == SQLITE_DONE) { + break; + } + assert(rc == SQLITE_ROW); + nused = i+1; + + i64 rowid = sqlite3_column_int64(stmtRowids, 0); + chunk_rowids[i] = rowid; + rc = sqlite3_blob_reopen(baseVectorsBlob, rowid); + assert(rc == SQLITE_OK); + assert(sqlite3_blob_bytes(baseVectorsBlob) == dimensions * sizeof(float)); + sqlite3_blob_read(baseVectorsBlob, &chunk[i * dimensions], dimensions * sizeof(float), 0); + } + + for(int i = 0; i < nused; i++) { + const f32 *base_i = (chunk) + (i * dimensions); + chunk_distances[i] = distance_l2_sqr_float(base_i, (f32 *)queryVector, &dimensions); + } + + i32 *chunk_top_idxs = sqlite3_malloc(nused * sizeof(i32)); + todo_assert(chunk_top_idxs); + min_idx(chunk_distances, nused, chunk_top_idxs, nused); + + i64 *out_rowids; + f32 *out_distances; + dethrone2(k, topk_distances, topk_rowids, /*chunk_size*/ nused, chunk_top_idxs, + chunk_distances, chunk_rowids, + + &out_rowids, &out_distances); + for (int i = 0; i < k; i++) { + topk_rowids[i] = out_rowids[i]; + topk_distances[i] = out_distances[i]; + } + sqlite3_free(out_rowids); + sqlite3_free(out_distances); + sqlite3_free(chunk_top_idxs); + + if(nused < chunk_size) break; + } + sqlite3_blob_close(baseVectorsBlob); + sqlite3_finalize(stmtRowids); + + cleanup(queryVector); + + + knn_data->current_idx = 0; + knn_data->k = k; + knn_data->rowids = topk_rowids; + knn_data->distances = topk_distances; + pCur->knn_data = knn_data; + } + else { + pCur->query_plan = VEC_SBE__QUERYPLAN_FULLSCAN; + pCur->iRowid = 0; + } + + return SQLITE_OK; +} + +static int vec_expoRowid(sqlite3_vtab_cursor *cur, sqlite_int64 *pRowid) { + vec_expo_cursor *pCur = (vec_expo_cursor *)cur; + switch(pCur->query_plan) { + case VEC_SBE__QUERYPLAN_FULLSCAN: { + *pRowid = pCur->iRowid; + break; + } + case VEC_SBE__QUERYPLAN_KNN: { + *pRowid = pCur->knn_data->rowids[pCur->knn_data->current_idx]; + break; + } + } + + return SQLITE_OK; +} + +static int vec_expoNext(sqlite3_vtab_cursor *cur) { + vec_expo_cursor *pCur = (vec_expo_cursor *)cur; + switch(pCur->query_plan) { + case VEC_SBE__QUERYPLAN_FULLSCAN: { + pCur->iRowid++; + return SQLITE_OK; + } + case VEC_SBE__QUERYPLAN_KNN: { + pCur->knn_data->current_idx++; + return SQLITE_OK; + } + } + +} + +static int vec_expoEof(sqlite3_vtab_cursor *cur) { + vec_expo_cursor *pCur = (vec_expo_cursor *)cur; + vec_expo_vtab * p = (vec_expo_vtab *) pCur->base.pVtab; + switch(pCur->query_plan) { + case VEC_SBE__QUERYPLAN_FULLSCAN: { + return 1;//(size_t) pCur->iRowid >= p->blob->nvectors; + } + case VEC_SBE__QUERYPLAN_KNN: { + return pCur->knn_data->current_idx >= pCur->knn_data->k; + } + } + +} + +static int vec_expoColumn(sqlite3_vtab_cursor *cur, sqlite3_context *context, + int i) { + vec_expo_cursor *pCur = (vec_expo_cursor *)cur; + vec_expo_vtab *p = (vec_expo_vtab *)cur->pVtab; + + switch(pCur->query_plan) { + case VEC_SBE__QUERYPLAN_FULLSCAN: { + return SQLITE_OK; + } + case VEC_SBE__QUERYPLAN_KNN: { + switch(i) { + case VEC_EXPO_VECTOR: { + break; + } + case VEC_EXPO_DISTANCE: { + sqlite3_result_double(context, pCur->knn_data->distances[pCur->knn_data->current_idx]); + break; + } + } + return SQLITE_OK; + } + } +} + + +static sqlite3_module vec_expoModule = { + /* iVersion */ 3, + /* xCreate */ vec_expoCreate, + /* xConnect */ vec_expoConnect, + /* xBestIndex */ vec_expoBestIndex, + /* xDisconnect */ vec_expoDisconnect, + /* xDestroy */ vec_expoDisconnect, + /* xOpen */ vec_expoOpen, + /* xClose */ vec_expoClose, + /* xFilter */ vec_expoFilter, + /* xNext */ vec_expoNext, + /* xEof */ vec_expoEof, + /* xColumn */ vec_expoColumn, + /* xRowid */ vec_expoRowid, + /* xUpdate */ 0, + /* xBegin */ 0, + /* xSync */ 0, + /* xCommit */ 0, + /* xRollback */ 0, + /* xFindMethod */ 0, + /* xRename */ 0, + /* xSavepoint */ 0, + /* xRelease */ 0, + /* xRollbackTo */ 0, + /* xShadowName */ 0}; +#pragma endregion + #endif int sqlite3_mmap_warm(sqlite3 *db, const char *zDb) { @@ -5225,6 +5618,8 @@ __declspec(dllexport) assert(rc == SQLITE_OK); rc = sqlite3_create_module_v2(db, "vec_static_blob_entries", &vec_static_blob_entriesModule, static_blob_data, NULL); assert(rc == SQLITE_OK); + rc = sqlite3_create_module_v2(db, "vec_expo", &vec_expoModule, NULL, NULL); + assert(rc == SQLITE_OK); #endif return SQLITE_OK;