diff --git a/packages/driver/src/datatypes/pgvector.ts b/packages/driver/src/datatypes/pgvector.ts index 2909c826b..196c69896 100644 --- a/packages/driver/src/datatypes/pgvector.ts +++ b/packages/driver/src/datatypes/pgvector.ts @@ -69,4 +69,11 @@ export class SparseVector { }, }); } + + *[Symbol.iterator]() { + let nextIndex = 0; + for (let i = 0; i < this.length; i++) { + yield this.indexes[nextIndex] === i ? this.values[nextIndex++] : 0; + } + } } diff --git a/packages/driver/test/client.test.ts b/packages/driver/test/client.test.ts index de9203e4c..f6adf5c85 100644 --- a/packages/driver/test/client.test.ts +++ b/packages/driver/test/client.test.ts @@ -627,6 +627,19 @@ if ( await con.close(); }); + it("valid: SparseVector methods", async () => { + const sparseVec = new SparseVector(7, { 1: 1.5, 2: 2, 4: 3.8 }); + + const arr: number[] = []; + let i = 0; + for (const val of sparseVec) { + expect(val).toEqual(sparseVec[i++]); + arr.push(val); + } + expect(arr).toEqual([...new Float32Array([0, 1.5, 2, 0, 3.8, 0, 0])]); + expect(arr).toEqual([...sparseVec]); + }); + it("valid: SparseVector", async () => { const val = await con.queryRequiredSingle( `