From 9e943e691592f5b826a1b58eac0aa864fbe85b79 Mon Sep 17 00:00:00 2001 From: James Clarke Date: Tue, 29 Oct 2024 14:40:33 +0000 Subject: [PATCH 1/6] Add codecs for `ext::pgvector::halfvec`+`ext::pgvector::sparsevec` and add `SparseVector` datatype --- packages/driver/package.json | 1 + packages/driver/src/adapter.shared.deno.ts | 24 ++++ packages/driver/src/adapter.shared.node.ts | 7 + packages/driver/src/codecs/codecs.ts | 8 +- packages/driver/src/codecs/consts.ts | 2 + packages/driver/src/codecs/pgvector.ts | 142 +++++++++++++++++++- packages/driver/src/datatypes/pgvector.ts | 65 ++++++++++ packages/driver/src/index.shared.ts | 2 + packages/driver/test/client.test.ts | 143 +++++++++++++++++---- yarn.lock | 5 + 10 files changed, 369 insertions(+), 30 deletions(-) create mode 100644 packages/driver/src/datatypes/pgvector.ts diff --git a/packages/driver/package.json b/packages/driver/package.json index d914f398c..0aea8ad48 100644 --- a/packages/driver/package.json +++ b/packages/driver/package.json @@ -55,6 +55,7 @@ "dev": "yarn tsc --project tsconfig.json --incremental && yarn build:deno" }, "dependencies": { + "@petamoriken/float16": "^3.8.7", "debug": "^4.3.4", "env-paths": "^3.0.0", "semver": "^7.6.2", diff --git a/packages/driver/src/adapter.shared.deno.ts b/packages/driver/src/adapter.shared.deno.ts index 2edbab714..c84dfd862 100644 --- a/packages/driver/src/adapter.shared.deno.ts +++ b/packages/driver/src/adapter.shared.deno.ts @@ -10,3 +10,27 @@ export function getEnv(envName: string, required = false): string | undefined { } return Deno.env.get(envName); } + +const _Float16Array = Float16Array; +export { _Float16Array as Float16Array }; + +export function getFloat16( + dataView: DataView, + byteOffset: number, + littleEndian?: boolean, +): number { + return dataView.getFloat16(byteOffset, littleEndian); +} + +export function setFloat16( + dataView: DataView, + byteOffset: number, + value: number, + littleEndian?: boolean, +): void { + dataView.setFloat16(byteOffset, value, littleEndian); +} + +export function isFloat16Array(value: unknown): value is Float16Array { + return value instanceof Float16Array; +} diff --git a/packages/driver/src/adapter.shared.node.ts b/packages/driver/src/adapter.shared.node.ts index c596ceced..790eccbf5 100644 --- a/packages/driver/src/adapter.shared.node.ts +++ b/packages/driver/src/adapter.shared.node.ts @@ -1,3 +1,10 @@ +export { + Float16Array, + getFloat16, + isFloat16Array, + setFloat16, +} from "@petamoriken/float16"; + export function getEnv(envName: string, _required = false): string | undefined { return process.env[envName]; } diff --git a/packages/driver/src/codecs/codecs.ts b/packages/driver/src/codecs/codecs.ts index ecbdbe1d5..dc53ea81b 100644 --- a/packages/driver/src/codecs/codecs.ts +++ b/packages/driver/src/codecs/codecs.ts @@ -42,7 +42,11 @@ import { DateDurationCodec, } from "./datetime"; import { ConfigMemoryCodec } from "./memory"; -import { PgVectorCodec } from "./pgvector"; +import { + PgVectorCodec, + PgVectorHalfVecCodec, + PgVectorSparseVecCodec, +} from "./pgvector"; import { InternalClientError } from "../errors"; import { INVALID_CODEC_ID, KNOWN_TYPENAMES, NULL_CODEC_ID } from "./consts"; @@ -119,3 +123,5 @@ registerScalarCodec("cal::date_duration", DateDurationCodec); registerScalarCodec("cfg::memory", ConfigMemoryCodec); registerScalarCodec("ext::pgvector::vector", PgVectorCodec); +registerScalarCodec("ext::pgvector::halfvec", PgVectorHalfVecCodec); +registerScalarCodec("ext::pgvector::sparsevec", PgVectorSparseVecCodec); diff --git a/packages/driver/src/codecs/consts.ts b/packages/driver/src/codecs/consts.ts index 577f4435f..12944fe17 100644 --- a/packages/driver/src/codecs/consts.ts +++ b/packages/driver/src/codecs/consts.ts @@ -48,6 +48,8 @@ export const KNOWN_TYPES = new Map([ ["00000000000000000000000000000112", "cal::date_duration"], ["00000000000000000000000000000130", "cfg::memory"], ["9565dd8804f511eea6910b6ebe179825", "ext::pgvector::vector"], + ["4ba84534188e43b4a7cecea2af0f405b", "ext::pgvector::halfvec"], + ["003e434dcac2430ab238fb39d73447d2", "ext::pgvector::sparsevec"], ]); export const KNOWN_TYPENAMES = (() => { diff --git a/packages/driver/src/codecs/pgvector.ts b/packages/driver/src/codecs/pgvector.ts index 8b88a7341..938cfcf7f 100644 --- a/packages/driver/src/codecs/pgvector.ts +++ b/packages/driver/src/codecs/pgvector.ts @@ -19,11 +19,18 @@ import type { ReadBuffer, WriteBuffer } from "../primitives/buffer"; import { type ICodec, ScalarCodec } from "./ifaces"; import { InvalidArgumentError } from "../errors"; +import { + Float16Array, + getFloat16, + isFloat16Array, + setFloat16, +} from "../adapter.shared.node"; +import { SparseVector } from "../datatypes/pgvector"; export const PG_VECTOR_MAX_DIM = (1 << 16) - 1; export class PgVectorCodec extends ScalarCodec implements ICodec { - tsType = "Float32Array"; + readonly tsType = "Float32Array"; encode(buf: WriteBuffer, object: any): void { if (!(object instanceof Float32Array || Array.isArray(object))) { @@ -78,3 +85,136 @@ export class PgVectorCodec extends ScalarCodec implements ICodec { return vec; } } + +export class PgVectorHalfVecCodec extends ScalarCodec implements ICodec { + readonly tsType = "Float16Array"; + readonly tsModule = "edgedb"; + + encode(buf: WriteBuffer, object: any): void { + if (!(isFloat16Array(object) || Array.isArray(object))) { + throw new InvalidArgumentError( + `a Float16Array or array of numbers was expected, got "${object}"`, + ); + } + + if (object.length > PG_VECTOR_MAX_DIM) { + throw new InvalidArgumentError( + "too many elements in array to encode into pgvector", + ); + } + + buf + .writeInt32(4 + object.length * 2) + .writeUInt16(object.length) + .writeUInt16(0); + + const vecBuf = new Uint8Array(object.length * 2); + const data = new DataView( + vecBuf.buffer, + vecBuf.byteOffset, + vecBuf.byteLength, + ); + + if (isFloat16Array(object)) { + for (let i = 0; i < object.length; i++) { + setFloat16(data, i * 2, object[i]); + } + } else { + for (let i = 0; i < object.length; i++) { + if (typeof object[i] !== "number") { + throw new InvalidArgumentError( + `elements of vector array expected to be a numbers, got "${object[i]}"`, + ); + } + setFloat16(data, i * 2, object[i]); + } + } + + buf.writeBuffer(vecBuf); + } + + decode(buf: ReadBuffer): any { + const dim = buf.readUInt16(); + buf.discard(2); + + const vecBuf = buf.readBuffer(dim * 2); + const data = new DataView( + vecBuf.buffer, + vecBuf.byteOffset, + vecBuf.byteLength, + ); + const vec = new Float16Array(dim); + + for (let i = 0; i < dim; i++) { + vec[i] = getFloat16(data, i * 2); + } + + return vec; + } +} + +export class PgVectorSparseVecCodec extends ScalarCodec implements ICodec { + readonly tsType = "SparseVector"; + readonly tsModule = "edgedb"; + + encode(buf: WriteBuffer, object: any): void { + if (!(object instanceof SparseVector)) { + throw new InvalidArgumentError( + `a SparseVector was expected, got "${object}"`, + ); + } + + const nnz = object.indexes.length; + + if (nnz > PG_VECTOR_MAX_DIM || nnz > object.length) { + throw new InvalidArgumentError( + "too many elements in sparse vector value", + ); + } + + buf + .writeUInt32(4 * (3 + nnz * 2)) + .writeUInt32(object.length) + .writeUInt32(nnz) + .writeUInt32(0); + + const vecBuf = new Uint8Array(nnz * 2); + const data = new DataView( + vecBuf.buffer, + vecBuf.byteOffset, + vecBuf.byteLength, + ); + + for (let i = 0; i < nnz; i++) { + data.setUint32(i * 4, object.indexes[i]); + } + for (let i = 0; i < nnz; i++) { + data.setFloat32(nnz + i * 4, object.values[i]); + } + + buf.writeBuffer(vecBuf); + } + + decode(buf: ReadBuffer): any { + const dim = buf.readUInt32(); + const nnz = buf.readUInt32(); + buf.discard(4); + + const vecBuf = buf.readBuffer(nnz * 8); + const data = new DataView( + vecBuf.buffer, + vecBuf.byteOffset, + vecBuf.byteLength, + ); + const indexes = new Uint32Array(nnz); + for (let i = 0; i < nnz; i++) { + indexes[i] = data.getUint32(i * 4); + } + const vecData = new Float32Array(nnz); + for (let i = 0; i < nnz; i++) { + vecData[i] = data.getFloat32((i + nnz) * 4); + } + + return new SparseVector(dim, indexes, vecData); + } +} diff --git a/packages/driver/src/datatypes/pgvector.ts b/packages/driver/src/datatypes/pgvector.ts new file mode 100644 index 000000000..e7c11eaac --- /dev/null +++ b/packages/driver/src/datatypes/pgvector.ts @@ -0,0 +1,65 @@ +export class SparseVector { + public indexes: Uint32Array; + public values: Float32Array; + + constructor(length: number, map: Record); + constructor(length: number, indexes: Uint32Array, values: Float32Array); + constructor( + public length: number, + indexesOrMap: Uint32Array | Record, + values?: Float32Array, + ) { + if (indexesOrMap instanceof Uint32Array) { + if (indexesOrMap.length !== values?.length) { + throw new Error( + "indexes array must be the same length as the data array", + ); + } + if (indexesOrMap.length > length) { + throw new Error( + "length of data cannot be larger than length of sparse vector", + ); + } + this.values = values; + this.indexes = indexesOrMap; + } else { + const entries = Object.entries(indexesOrMap); + if (entries.length > length) { + throw new Error( + "length of data cannot be larger than length of sparse vector", + ); + } + this.indexes = new Uint32Array(entries.length); + this.values = new Float32Array(entries.length); + for (let i = 0; i < entries.length; i++) { + const index = parseInt(entries[i][0], 10); + const val = entries[i][1]; + if (!Number.isNaN(index)) { + throw new Error("key in data map not an integer"); + } + if (index < 0 || index > length) { + throw new Error( + `index ${index} is out of range of sparse vector length`, + ); + } + this.indexes[i] = index; + if (val === 0) { + throw new Error("elements in sparse vector cannot be 0"); + } + this.values[i] = val; + } + } + + return new Proxy(this, { + get(target, p) { + const index = typeof p === "string" ? parseInt(p, 10) : NaN; + if (!Number.isNaN(index)) { + if (index < 0 || index >= target.length) return undefined; + const dataIndex = target.indexes.indexOf(index); + return dataIndex === -1 ? 0 : target.values[dataIndex]; + } + return (target as any)[p]; + }, + }); + } +} diff --git a/packages/driver/src/index.shared.ts b/packages/driver/src/index.shared.ts index 03b43d99c..2a9c85f2a 100644 --- a/packages/driver/src/index.shared.ts +++ b/packages/driver/src/index.shared.ts @@ -28,6 +28,8 @@ export { } from "./datatypes/datetime"; export { ConfigMemory } from "./datatypes/memory"; export { Range, MultiRange } from "./datatypes/range"; +export { SparseVector } from "./datatypes/pgvector"; +export { Float16Array } from "./adapter.shared.node"; export type { Executor } from "./ifaces"; diff --git a/packages/driver/test/client.test.ts b/packages/driver/test/client.test.ts index 4d2f46f2a..e2a6ca007 100644 --- a/packages/driver/test/client.test.ts +++ b/packages/driver/test/client.test.ts @@ -38,6 +38,9 @@ import { AuthenticationError, InvalidReferenceError, throwWarnings, + Float16Array, + InvalidArgumentError, + InvalidValueError, } from "../src/index.node"; import { AdminUIFetchConnection } from "../src/fetchConn"; @@ -434,34 +437,25 @@ test("fetch: int64 as bigint", async () => { if (!isDeno) { describe("fetch: ext::pgvector::vector", () => { const con = getClient(); - const hasPgVectorExtentionQuery = ` - select exists ( - select sys::ExtensionPackage filter .name = 'pgvector' - )`; + const hasPgVectorExtention = con.queryRequiredSingle(` + select exists ( + select sys::ExtensionPackage filter .name = 'pgvector' + )`); beforeAll(async () => { - const hasPgVectorExtention = await con.queryRequiredSingle( - hasPgVectorExtentionQuery, - ); - if (!hasPgVectorExtention) return; + if (!(await hasPgVectorExtention)) return; await con.execute("create extension pgvector;"); }); afterAll(async () => { - const hasPgVectorExtention = await con.queryRequiredSingle( - hasPgVectorExtentionQuery, - ); - if (hasPgVectorExtention) { + if (await hasPgVectorExtention) { await con.execute("drop extension pgvector;"); } await con.close(); }); it("valid: Float32Array", async () => { - const hasPgVectorExtention = await con.queryRequiredSingle( - hasPgVectorExtentionQuery, - ); - if (!hasPgVectorExtention) return; + if (!(await hasPgVectorExtention)) return; await fc.assert( fc.asyncProperty( @@ -489,10 +483,7 @@ if (!isDeno) { }); it("valid: JSON", async () => { - const hasPgVectorExtention = await con.queryRequiredSingle( - hasPgVectorExtentionQuery, - ); - if (!hasPgVectorExtention) return; + if (!(await hasPgVectorExtention)) return; await fc.assert( fc.asyncProperty( @@ -519,10 +510,7 @@ if (!isDeno) { }); it("invalid: empty", async () => { - const hasPgVectorExtention = await con.queryRequiredSingle( - hasPgVectorExtentionQuery, - ); - if (!hasPgVectorExtention) return; + if (!(await hasPgVectorExtention)) return; const data = new Float32Array([]); await expect( @@ -531,10 +519,7 @@ if (!isDeno) { }); it("invalid: invalid argument", async () => { - const hasPgVectorExtention = await con.queryRequiredSingle( - hasPgVectorExtentionQuery, - ); - if (!hasPgVectorExtention) return; + if (!(await hasPgVectorExtention)) return; await expect( con.querySingle("select $0;", ["foo"]), @@ -543,6 +528,108 @@ if (!isDeno) { }); } +describe("fetch: ext::pgvector::halfvec", () => { + const con = getClient(); + const hasPgVectorExtention = con.queryRequiredSingle(` + select exists ( + select sys::ExtensionPackage filter .name = 'pgvector' + )`); + + beforeAll(async () => { + if (!(await hasPgVectorExtention)) return; + await con.execute("create extension pgvector;"); + }); + + afterAll(async () => { + if (await hasPgVectorExtention) { + await con.execute("drop extension pgvector;"); + } + await con.close(); + }); + + it("valid: Float16Array", async () => { + if (!(await hasPgVectorExtention)) return; + + const val = await con.queryRequiredSingle( + ` + select + [1.5, 2.0, 3.8, 0, 3.4575e-3, 65000, + 6.0975e-5, 2.2345e-7, -5.96e-8] + `, + ); + + expect(val).toBeInstanceOf(Float16Array); + expect(val[0]).toEqual(1.5); + expect(val[1]).toEqual(2); + expect(val[2]).toBeCloseTo(3.8, 2); + expect(val[3]).toEqual(0); + expect(val[4]).toBeCloseTo(3.457e-3, 2); + expect(val[5]).toEqual(64992); + // These values are sub-normal so they don't map perfectly onto f32 + expect(val[6]).toBeCloseTo(6.0975e-5, 2); + expect(val[7]).toBeCloseTo(2.38e-7, 2); + expect(val[8]).toBeCloseTo(-5.96e-8, 2); + }); + + it("valid: Float16Array arg", async () => { + if (!(await hasPgVectorExtention)) return; + + const val = await con.queryRequiredSingle( + `select >$0`, + [ + new Float16Array([ + 1.5, 2.0, 3.8, 0, 3.4575e-3, 65000, 6.0975e-5, 2.385e-7, -5.97e-8, + ]), + ], + ); + + expect(val[0]).toEqual(1.5); + expect(val[1]).toEqual(2); + expect(val[2]).toBeCloseTo(3.8, 2); + expect(val[3]).toEqual(0); + expect(val[4]).toBeCloseTo(3.457e-3, 2); + expect(val[5]).toEqual(64992); + // These values are sub-normal so they don't map perfectly onto f32 + expect(val[6]).toBeCloseTo(6.0975e-5, 2); + expect(val[7]).toBeCloseTo(2.38e-7, 2); + expect(val[8]).toBeCloseTo(-5.96e-8, 2); + }); + + it("valid: number[] arg", async () => { + await expect( + con.queryRequiredSingle( + `select $0 = $1`, + [ + new Float16Array([ + 1.5, 2.0, 3.8, 0, 3.4575e-3, 65000, 6.0975e-5, 2.385e-7, -5.97e-8, + ]), + [1.5, 2.0, 3.8, 0, 3.4575e-3, 65000, 6.0975e-5, 2.385e-7, -5.97e-8], + ], + ), + ).resolves.toBe(true); + }); + + it("invalid: invalid args", async () => { + await expect( + con.querySingle(`select $0`, [ + [3.0, null, -42.5], + ]), + ).rejects.toThrow(InvalidArgumentError); + + await expect( + con.querySingle(`select $0`, [[3.0, "x", -42.5]]), + ).rejects.toThrow(InvalidArgumentError); + + await expect( + con.querySingle(`select $0`, ["foo"]), + ).rejects.toThrow(InvalidArgumentError); + + await expect( + con.querySingle(`select $0`, [[1_000_000]]), + ).rejects.toThrow(InvalidValueError); + }); +}); + test("fetch: positional args", async () => { const con = getClient(); let res: any; diff --git a/yarn.lock b/yarn.lock index 4881ebd28..0e7c7eb45 100644 --- a/yarn.lock +++ b/yarn.lock @@ -1115,6 +1115,11 @@ "@nodelib/fs.scandir" "2.1.5" fastq "^1.6.0" +"@petamoriken/float16@^3.8.7": + version "3.8.7" + resolved "https://registry.yarnpkg.com/@petamoriken/float16/-/float16-3.8.7.tgz#16073fb1b9867eaa5b254573484d09100700aaa4" + integrity sha512-/Ri4xDDpe12NT6Ex/DRgHzLlobiQXEW/hmG08w1wj/YU7hLemk97c+zHQFp0iZQ9r7YqgLEXZR2sls4HxBf9NA== + "@polka/url@^1.0.0-next.24": version "1.0.0-next.25" resolved "https://registry.npmjs.org/@polka/url/-/url-1.0.0-next.25.tgz" From 815fc17c9f18f00feb5b7ded08e04d672bebb1b1 Mon Sep 17 00:00:00 2001 From: James Clarke Date: Tue, 29 Oct 2024 15:09:10 +0000 Subject: [PATCH 2/6] Fix pgvector ext check query --- packages/driver/test/client.test.ts | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/packages/driver/test/client.test.ts b/packages/driver/test/client.test.ts index e2a6ca007..b5c525fcc 100644 --- a/packages/driver/test/client.test.ts +++ b/packages/driver/test/client.test.ts @@ -532,7 +532,9 @@ describe("fetch: ext::pgvector::halfvec", () => { const con = getClient(); const hasPgVectorExtention = con.queryRequiredSingle(` select exists ( - select sys::ExtensionPackage filter .name = 'pgvector' + select sys::ExtensionPackage + filter .name = 'pgvector' + and (.version.major > 0 or .version.minor >= 7) )`); beforeAll(async () => { From 6725e0afe4f3383ffe8337455b79a21c06aafb15 Mon Sep 17 00:00:00 2001 From: James Clarke Date: Tue, 29 Oct 2024 15:21:19 +0000 Subject: [PATCH 3/6] Fix ext exists check attempt 2 --- packages/driver/test/client.test.ts | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/packages/driver/test/client.test.ts b/packages/driver/test/client.test.ts index b5c525fcc..a82512f47 100644 --- a/packages/driver/test/client.test.ts +++ b/packages/driver/test/client.test.ts @@ -598,6 +598,8 @@ describe("fetch: ext::pgvector::halfvec", () => { }); it("valid: number[] arg", async () => { + if (!(await hasPgVectorExtention)) return; + await expect( con.queryRequiredSingle( `select $0 = $1`, @@ -612,6 +614,8 @@ describe("fetch: ext::pgvector::halfvec", () => { }); it("invalid: invalid args", async () => { + if (!(await hasPgVectorExtention)) return; + await expect( con.querySingle(`select $0`, [ [3.0, null, -42.5], From 8c3d901a2a45f9ad9a38691e78abef76f86bc362 Mon Sep 17 00:00:00 2001 From: James Clarke Date: Tue, 29 Oct 2024 16:38:22 +0000 Subject: [PATCH 4/6] Add sparsevec tests + fix codec/datatype --- packages/driver/src/codecs/pgvector.ts | 4 +- packages/driver/src/datatypes/pgvector.ts | 15 ++++- packages/driver/test/client.test.ts | 81 +++++++++++++++++++++++ 3 files changed, 95 insertions(+), 5 deletions(-) diff --git a/packages/driver/src/codecs/pgvector.ts b/packages/driver/src/codecs/pgvector.ts index 938cfcf7f..b81d076a2 100644 --- a/packages/driver/src/codecs/pgvector.ts +++ b/packages/driver/src/codecs/pgvector.ts @@ -178,7 +178,7 @@ export class PgVectorSparseVecCodec extends ScalarCodec implements ICodec { .writeUInt32(nnz) .writeUInt32(0); - const vecBuf = new Uint8Array(nnz * 2); + const vecBuf = new Uint8Array(nnz * 8); const data = new DataView( vecBuf.buffer, vecBuf.byteOffset, @@ -189,7 +189,7 @@ export class PgVectorSparseVecCodec extends ScalarCodec implements ICodec { data.setUint32(i * 4, object.indexes[i]); } for (let i = 0; i < nnz; i++) { - data.setFloat32(nnz + i * 4, object.values[i]); + data.setFloat32((nnz + i) * 4, object.values[i]); } buf.writeBuffer(vecBuf); diff --git a/packages/driver/src/datatypes/pgvector.ts b/packages/driver/src/datatypes/pgvector.ts index e7c11eaac..fcf03cf20 100644 --- a/packages/driver/src/datatypes/pgvector.ts +++ b/packages/driver/src/datatypes/pgvector.ts @@ -1,3 +1,7 @@ +export interface SparseVector { + [index: number]: number; +} + export class SparseVector { public indexes: Uint32Array; public values: Float32Array; @@ -34,15 +38,20 @@ export class SparseVector { for (let i = 0; i < entries.length; i++) { const index = parseInt(entries[i][0], 10); const val = entries[i][1]; - if (!Number.isNaN(index)) { - throw new Error("key in data map not an integer"); + if (Number.isNaN(index)) { + throw new Error(`key ${entries[i][0]} in data map is not an integer`); } - if (index < 0 || index > length) { + if (index < 0 || index >= length) { throw new Error( `index ${index} is out of range of sparse vector length`, ); } this.indexes[i] = index; + if (typeof val !== "number") { + throw new Error( + `expected value at index ${index} to be number, got ${typeof val} ${val}`, + ); + } if (val === 0) { throw new Error("elements in sparse vector cannot be 0"); } diff --git a/packages/driver/test/client.test.ts b/packages/driver/test/client.test.ts index a82512f47..6555d7702 100644 --- a/packages/driver/test/client.test.ts +++ b/packages/driver/test/client.test.ts @@ -41,6 +41,7 @@ import { Float16Array, InvalidArgumentError, InvalidValueError, + SparseVector, } from "../src/index.node"; import { AdminUIFetchConnection } from "../src/fetchConn"; @@ -636,6 +637,86 @@ describe("fetch: ext::pgvector::halfvec", () => { }); }); +describe("fetch: ext::pgvector::sparsevec", () => { + const con = getClient(); + const hasPgVectorExtention = con.queryRequiredSingle(` + select exists ( + select sys::ExtensionPackage + filter .name = 'pgvector' + and (.version.major > 0 or .version.minor >= 7) + )`); + + beforeAll(async () => { + if (!(await hasPgVectorExtention)) return; + await con.execute("create extension pgvector;"); + }); + + afterAll(async () => { + if (await hasPgVectorExtention) { + await con.execute("drop extension pgvector;"); + } + await con.close(); + }); + + it("valid: SparseVector", async () => { + if (!(await hasPgVectorExtention)) return; + + const val = await con.queryRequiredSingle( + ` + select + [0, 1.5, 2.0, 3.8, 0, 0] + `, + ); + + expect(val).toBeInstanceOf(SparseVector); + expect(val.length).toEqual(6); + expect(val[1]).toEqual(1.5); + expect(val[2]).toEqual(2); + expect(val[3]).toBeCloseTo(3.8, 6); + expect(val[4]).toEqual(0); + }); + + it("valid: SparseVector arg", async () => { + if (!(await hasPgVectorExtention)) return; + + const val = await con.queryRequiredSingle( + ` + select + $0 + `, + [new SparseVector(6, { 1: 1.5, 2: 2, 4: 3.8 })], + ); + + expect(val).toEqual(new Float32Array([0, 1.5, 2, 0, 3.8, 0])); + }); + + it("invalid: invalid args", async () => { + expect(() => new SparseVector(1, { 1: 1.5, 2: 2, 3: 3.8 })).toThrow( + `length of data cannot be larger than length of sparse vector`, + ); + + expect(() => new SparseVector(6, { 1: 1.5, 2: 2, 6: 3.8 })).toThrow( + `index 6 is out of range of sparse vector length`, + ); + + expect( + () => + // @ts-expect-error + new SparseVector(6, { 1: 1.5, 2: 2, 3: "3.8" }), + ).toThrow(`expected value at index 3 to be number, got string 3.8`); + + expect( + () => + // @ts-expect-error + new SparseVector(6, { 1: 1.5, 2: 2, x: 3.8 }), + ).toThrow(`key x in data map is not an integer`); + + expect(() => new SparseVector(6, { 1: 1.5, 2: 2, 3: 0 })).toThrow( + `elements in sparse vector cannot be 0`, + ); + }); +}); + test("fetch: positional args", async () => { const con = getClient(); let res: any; From 587c07ae58aa19593f86664ce0b4fd44257d59ea Mon Sep 17 00:00:00 2001 From: James Clarke Date: Tue, 29 Oct 2024 21:57:06 +0000 Subject: [PATCH 5/6] Fix lint --- packages/driver/src/datatypes/pgvector.ts | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/packages/driver/src/datatypes/pgvector.ts b/packages/driver/src/datatypes/pgvector.ts index fcf03cf20..2909c826b 100644 --- a/packages/driver/src/datatypes/pgvector.ts +++ b/packages/driver/src/datatypes/pgvector.ts @@ -1,11 +1,9 @@ -export interface SparseVector { - [index: number]: number; -} - export class SparseVector { public indexes: Uint32Array; public values: Float32Array; + [index: number]: number; + constructor(length: number, map: Record); constructor(length: number, indexes: Uint32Array, values: Float32Array); constructor( From b4c250c14561731db54ab060a8e07ff22b14fd74 Mon Sep 17 00:00:00 2001 From: James Clarke Date: Tue, 29 Oct 2024 23:30:57 +0000 Subject: [PATCH 6/6] Fix deno tests --- packages/driver/test/client.test.ts | 296 ++++++++++++---------------- packages/driver/test/globalSetup.ts | 9 + packages/driver/test/testbase.ts | 7 + 3 files changed, 147 insertions(+), 165 deletions(-) diff --git a/packages/driver/test/client.test.ts b/packages/driver/test/client.test.ts index 6555d7702..de9203e4c 100644 --- a/packages/driver/test/client.test.ts +++ b/packages/driver/test/client.test.ts @@ -47,6 +47,7 @@ import { import { AdminUIFetchConnection } from "../src/fetchConn"; import type { CustomCodecSpec } from "../src/codecs/registry"; import { + getAvailableExtensions, getAvailableFeatures, getClient, getConnectOptions, @@ -435,29 +436,22 @@ test("fetch: int64 as bigint", async () => { } }); -if (!isDeno) { +const pgvectorVersion = getAvailableExtensions().get("pgvector"); + +if (!isDeno && pgvectorVersion != null) { describe("fetch: ext::pgvector::vector", () => { const con = getClient(); - const hasPgVectorExtention = con.queryRequiredSingle(` - select exists ( - select sys::ExtensionPackage filter .name = 'pgvector' - )`); beforeAll(async () => { - if (!(await hasPgVectorExtention)) return; await con.execute("create extension pgvector;"); }); afterAll(async () => { - if (await hasPgVectorExtention) { - await con.execute("drop extension pgvector;"); - } + await con.execute("drop extension pgvector;"); await con.close(); }); it("valid: Float32Array", async () => { - if (!(await hasPgVectorExtention)) return; - await fc.assert( fc.asyncProperty( fc.float32Array({ @@ -484,8 +478,6 @@ if (!isDeno) { }); it("valid: JSON", async () => { - if (!(await hasPgVectorExtention)) return; - await fc.assert( fc.asyncProperty( fc.float32Array({ @@ -511,8 +503,6 @@ if (!isDeno) { }); it("invalid: empty", async () => { - if (!(await hasPgVectorExtention)) return; - const data = new Float32Array([]); await expect( con.querySingle("select $0;", [data]), @@ -520,8 +510,6 @@ if (!isDeno) { }); it("invalid: invalid argument", async () => { - if (!(await hasPgVectorExtention)) return; - await expect( con.querySingle("select $0;", ["foo"]), ).rejects.toThrow(); @@ -529,193 +517,171 @@ if (!isDeno) { }); } -describe("fetch: ext::pgvector::halfvec", () => { - const con = getClient(); - const hasPgVectorExtention = con.queryRequiredSingle(` - select exists ( - select sys::ExtensionPackage - filter .name = 'pgvector' - and (.version.major > 0 or .version.minor >= 7) - )`); - - beforeAll(async () => { - if (!(await hasPgVectorExtention)) return; - await con.execute("create extension pgvector;"); - }); +if ( + pgvectorVersion != null && + (pgvectorVersion.major > 0 || pgvectorVersion.minor >= 7) +) { + describe("fetch: ext::pgvector::halfvec", () => { + const con = getClient(); - afterAll(async () => { - if (await hasPgVectorExtention) { - await con.execute("drop extension pgvector;"); - } - await con.close(); - }); + beforeAll(async () => { + await con.execute("create extension pgvector;"); + }); - it("valid: Float16Array", async () => { - if (!(await hasPgVectorExtention)) return; + afterAll(async () => { + await con.execute("drop extension pgvector;"); + await con.close(); + }); - const val = await con.queryRequiredSingle( - ` + it("valid: Float16Array", async () => { + const val = await con.queryRequiredSingle( + ` select [1.5, 2.0, 3.8, 0, 3.4575e-3, 65000, 6.0975e-5, 2.2345e-7, -5.96e-8] `, - ); - - expect(val).toBeInstanceOf(Float16Array); - expect(val[0]).toEqual(1.5); - expect(val[1]).toEqual(2); - expect(val[2]).toBeCloseTo(3.8, 2); - expect(val[3]).toEqual(0); - expect(val[4]).toBeCloseTo(3.457e-3, 2); - expect(val[5]).toEqual(64992); - // These values are sub-normal so they don't map perfectly onto f32 - expect(val[6]).toBeCloseTo(6.0975e-5, 2); - expect(val[7]).toBeCloseTo(2.38e-7, 2); - expect(val[8]).toBeCloseTo(-5.96e-8, 2); - }); - - it("valid: Float16Array arg", async () => { - if (!(await hasPgVectorExtention)) return; - - const val = await con.queryRequiredSingle( - `select >$0`, - [ - new Float16Array([ - 1.5, 2.0, 3.8, 0, 3.4575e-3, 65000, 6.0975e-5, 2.385e-7, -5.97e-8, - ]), - ], - ); - - expect(val[0]).toEqual(1.5); - expect(val[1]).toEqual(2); - expect(val[2]).toBeCloseTo(3.8, 2); - expect(val[3]).toEqual(0); - expect(val[4]).toBeCloseTo(3.457e-3, 2); - expect(val[5]).toEqual(64992); - // These values are sub-normal so they don't map perfectly onto f32 - expect(val[6]).toBeCloseTo(6.0975e-5, 2); - expect(val[7]).toBeCloseTo(2.38e-7, 2); - expect(val[8]).toBeCloseTo(-5.96e-8, 2); - }); + ); - it("valid: number[] arg", async () => { - if (!(await hasPgVectorExtention)) return; + expect(val).toBeInstanceOf(Float16Array); + expect(val[0]).toEqual(1.5); + expect(val[1]).toEqual(2); + expect(val[2]).toBeCloseTo(3.8, 2); + expect(val[3]).toEqual(0); + expect(val[4]).toBeCloseTo(3.457e-3, 2); + expect(val[5]).toEqual(64992); + // These values are sub-normal so they don't map perfectly onto f32 + expect(val[6]).toBeCloseTo(6.0975e-5, 2); + expect(val[7]).toBeCloseTo(2.38e-7, 2); + expect(val[8]).toBeCloseTo(-5.96e-8, 2); + }); - await expect( - con.queryRequiredSingle( - `select $0 = $1`, + it("valid: Float16Array arg", async () => { + const val = await con.queryRequiredSingle( + `select >$0`, [ new Float16Array([ 1.5, 2.0, 3.8, 0, 3.4575e-3, 65000, 6.0975e-5, 2.385e-7, -5.97e-8, ]), - [1.5, 2.0, 3.8, 0, 3.4575e-3, 65000, 6.0975e-5, 2.385e-7, -5.97e-8], ], - ), - ).resolves.toBe(true); - }); + ); - it("invalid: invalid args", async () => { - if (!(await hasPgVectorExtention)) return; + expect(val[0]).toEqual(1.5); + expect(val[1]).toEqual(2); + expect(val[2]).toBeCloseTo(3.8, 2); + expect(val[3]).toEqual(0); + expect(val[4]).toBeCloseTo(3.457e-3, 2); + expect(val[5]).toEqual(64992); + // These values are sub-normal so they don't map perfectly onto f32 + expect(val[6]).toBeCloseTo(6.0975e-5, 2); + expect(val[7]).toBeCloseTo(2.38e-7, 2); + expect(val[8]).toBeCloseTo(-5.96e-8, 2); + }); - await expect( - con.querySingle(`select $0`, [ - [3.0, null, -42.5], - ]), - ).rejects.toThrow(InvalidArgumentError); + it("valid: number[] arg", async () => { + await expect( + con.queryRequiredSingle( + `select $0 = $1`, + [ + new Float16Array([ + 1.5, 2.0, 3.8, 0, 3.4575e-3, 65000, 6.0975e-5, 2.385e-7, -5.97e-8, + ]), + [1.5, 2.0, 3.8, 0, 3.4575e-3, 65000, 6.0975e-5, 2.385e-7, -5.97e-8], + ], + ), + ).resolves.toBe(true); + }); - await expect( - con.querySingle(`select $0`, [[3.0, "x", -42.5]]), - ).rejects.toThrow(InvalidArgumentError); + it("invalid: invalid args", async () => { + await expect( + con.querySingle(`select $0`, [ + [3.0, null, -42.5], + ]), + ).rejects.toThrow(InvalidArgumentError); - await expect( - con.querySingle(`select $0`, ["foo"]), - ).rejects.toThrow(InvalidArgumentError); + await expect( + con.querySingle(`select $0`, [ + [3.0, "x", -42.5], + ]), + ).rejects.toThrow(InvalidArgumentError); - await expect( - con.querySingle(`select $0`, [[1_000_000]]), - ).rejects.toThrow(InvalidValueError); - }); -}); + await expect( + con.querySingle(`select $0`, ["foo"]), + ).rejects.toThrow(InvalidArgumentError); -describe("fetch: ext::pgvector::sparsevec", () => { - const con = getClient(); - const hasPgVectorExtention = con.queryRequiredSingle(` - select exists ( - select sys::ExtensionPackage - filter .name = 'pgvector' - and (.version.major > 0 or .version.minor >= 7) - )`); - - beforeAll(async () => { - if (!(await hasPgVectorExtention)) return; - await con.execute("create extension pgvector;"); + await expect( + con.querySingle(`select $0`, [[1_000_000]]), + ).rejects.toThrow(InvalidValueError); + }); }); - afterAll(async () => { - if (await hasPgVectorExtention) { + describe("fetch: ext::pgvector::sparsevec", () => { + const con = getClient(); + + beforeAll(async () => { + await con.execute("create extension pgvector;"); + }); + + afterAll(async () => { await con.execute("drop extension pgvector;"); - } - await con.close(); - }); - it("valid: SparseVector", async () => { - if (!(await hasPgVectorExtention)) return; + await con.close(); + }); - const val = await con.queryRequiredSingle( - ` + it("valid: SparseVector", async () => { + const val = await con.queryRequiredSingle( + ` select [0, 1.5, 2.0, 3.8, 0, 0] `, - ); - - expect(val).toBeInstanceOf(SparseVector); - expect(val.length).toEqual(6); - expect(val[1]).toEqual(1.5); - expect(val[2]).toEqual(2); - expect(val[3]).toBeCloseTo(3.8, 6); - expect(val[4]).toEqual(0); - }); + ); - it("valid: SparseVector arg", async () => { - if (!(await hasPgVectorExtention)) return; + expect(val).toBeInstanceOf(SparseVector); + expect(val.length).toEqual(6); + expect(val[1]).toEqual(1.5); + expect(val[2]).toEqual(2); + expect(val[3]).toBeCloseTo(3.8, 6); + expect(val[4]).toEqual(0); + }); - const val = await con.queryRequiredSingle( - ` + it("valid: SparseVector arg", async () => { + const val = await con.queryRequiredSingle( + ` select $0 `, - [new SparseVector(6, { 1: 1.5, 2: 2, 4: 3.8 })], - ); + [new SparseVector(6, { 1: 1.5, 2: 2, 4: 3.8 })], + ); - expect(val).toEqual(new Float32Array([0, 1.5, 2, 0, 3.8, 0])); - }); + expect(val).toEqual(new Float32Array([0, 1.5, 2, 0, 3.8, 0])); + }); - it("invalid: invalid args", async () => { - expect(() => new SparseVector(1, { 1: 1.5, 2: 2, 3: 3.8 })).toThrow( - `length of data cannot be larger than length of sparse vector`, - ); + it("invalid: invalid args", async () => { + expect(() => new SparseVector(1, { 1: 1.5, 2: 2, 3: 3.8 })).toThrow( + `length of data cannot be larger than length of sparse vector`, + ); - expect(() => new SparseVector(6, { 1: 1.5, 2: 2, 6: 3.8 })).toThrow( - `index 6 is out of range of sparse vector length`, - ); + expect(() => new SparseVector(6, { 1: 1.5, 2: 2, 6: 3.8 })).toThrow( + `index 6 is out of range of sparse vector length`, + ); - expect( - () => - // @ts-expect-error - new SparseVector(6, { 1: 1.5, 2: 2, 3: "3.8" }), - ).toThrow(`expected value at index 3 to be number, got string 3.8`); + expect( + () => + // @ts-expect-error + new SparseVector(6, { 1: 1.5, 2: 2, 3: "3.8" }), + ).toThrow(`expected value at index 3 to be number, got string 3.8`); - expect( - () => - // @ts-expect-error - new SparseVector(6, { 1: 1.5, 2: 2, x: 3.8 }), - ).toThrow(`key x in data map is not an integer`); + expect( + () => + // @ts-expect-error + new SparseVector(6, { 1: 1.5, 2: 2, x: 3.8 }), + ).toThrow(`key x in data map is not an integer`); - expect(() => new SparseVector(6, { 1: 1.5, 2: 2, 3: 0 })).toThrow( - `elements in sparse vector cannot be 0`, - ); + expect(() => new SparseVector(6, { 1: 1.5, 2: 2, 3: 0 })).toThrow( + `elements in sparse vector cannot be 0`, + ); + }); }); -}); +} test("fetch: positional args", async () => { const con = getClient(); @@ -1384,7 +1350,7 @@ if (!isDeno) { } test("fetch: ConfigMemory", async () => { - const client = await getClient(); + const client = getClient(); if ( (await client.queryRequiredSingle( diff --git a/packages/driver/test/globalSetup.ts b/packages/driver/test/globalSetup.ts index d7ddd3b80..8a05d740b 100644 --- a/packages/driver/test/globalSetup.ts +++ b/packages/driver/test/globalSetup.ts @@ -31,6 +31,15 @@ export default async () => { global.edgedbConn = client; process.env._JEST_EDGEDB_VERSION = JSON.stringify(version); + const availableExtensions = ( + await client.query<{ + name: string; + version: { major: number; minor: number }; + }>(`select sys::ExtensionPackage {name, version}`) + ).map(({ name, version }) => [name, version]); + process.env._JEST_EDGEDB_AVAILABLE_EXTENSIONS = + JSON.stringify(availableExtensions); + // tslint:disable-next-line console.log(`EdgeDB test cluster is up [port: ${config.port}]...`); }; diff --git a/packages/driver/test/testbase.ts b/packages/driver/test/testbase.ts index 58deac328..baed4a147 100644 --- a/packages/driver/test/testbase.ts +++ b/packages/driver/test/testbase.ts @@ -62,3 +62,10 @@ export function getAvailableFeatures(): Set { export function getEdgeDBVersion(): EdgeDBVersion { return JSON.parse(process.env._JEST_EDGEDB_VERSION!); } + +export function getAvailableExtensions(): Map< + string, + { major: number; minor: number } +> { + return new Map(JSON.parse(process.env._JEST_EDGEDB_AVAILABLE_EXTENSIONS!)); +}