Skip to content

Commit

Permalink
Add codecs for ext::pgvector::halfvec+ext::pgvector::sparsevec an…
Browse files Browse the repository at this point in the history
…d add `SparseVector` datatype (#1124)
  • Loading branch information
jaclarke authored Oct 30, 2024
1 parent f18afba commit e433ee2
Show file tree
Hide file tree
Showing 12 changed files with 453 additions and 38 deletions.
1 change: 1 addition & 0 deletions packages/driver/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
"dev": "yarn tsc --project tsconfig.json --incremental && yarn build:deno"
},
"dependencies": {
"@petamoriken/float16": "^3.8.7",
"debug": "^4.3.4",
"env-paths": "^3.0.0",
"semver": "^7.6.2",
Expand Down
24 changes: 24 additions & 0 deletions packages/driver/src/adapter.shared.deno.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,27 @@ export function getEnv(envName: string, required = false): string | undefined {
}
return Deno.env.get(envName);
}

const _Float16Array = Float16Array;
export { _Float16Array as Float16Array };

export function getFloat16(
dataView: DataView,
byteOffset: number,
littleEndian?: boolean,
): number {
return dataView.getFloat16(byteOffset, littleEndian);
}

export function setFloat16(
dataView: DataView,
byteOffset: number,
value: number,
littleEndian?: boolean,
): void {
dataView.setFloat16(byteOffset, value, littleEndian);
}

export function isFloat16Array(value: unknown): value is Float16Array {
return value instanceof Float16Array;
}
7 changes: 7 additions & 0 deletions packages/driver/src/adapter.shared.node.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
export {
Float16Array,
getFloat16,
isFloat16Array,
setFloat16,
} from "@petamoriken/float16";

export function getEnv(envName: string, _required = false): string | undefined {
return process.env[envName];
}
8 changes: 7 additions & 1 deletion packages/driver/src/codecs/codecs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ import {
DateDurationCodec,
} from "./datetime";
import { ConfigMemoryCodec } from "./memory";
import { PgVectorCodec } from "./pgvector";
import {
PgVectorCodec,
PgVectorHalfVecCodec,
PgVectorSparseVecCodec,
} from "./pgvector";
import { InternalClientError } from "../errors";

import { INVALID_CODEC_ID, KNOWN_TYPENAMES, NULL_CODEC_ID } from "./consts";
Expand Down Expand Up @@ -119,3 +123,5 @@ registerScalarCodec("cal::date_duration", DateDurationCodec);
registerScalarCodec("cfg::memory", ConfigMemoryCodec);

registerScalarCodec("ext::pgvector::vector", PgVectorCodec);
registerScalarCodec("ext::pgvector::halfvec", PgVectorHalfVecCodec);
registerScalarCodec("ext::pgvector::sparsevec", PgVectorSparseVecCodec);
2 changes: 2 additions & 0 deletions packages/driver/src/codecs/consts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ export const KNOWN_TYPES = new Map<uuid, string>([
["00000000000000000000000000000112", "cal::date_duration"],
["00000000000000000000000000000130", "cfg::memory"],
["9565dd8804f511eea6910b6ebe179825", "ext::pgvector::vector"],
["4ba84534188e43b4a7cecea2af0f405b", "ext::pgvector::halfvec"],
["003e434dcac2430ab238fb39d73447d2", "ext::pgvector::sparsevec"],
]);

export const KNOWN_TYPENAMES = (() => {
Expand Down
142 changes: 141 additions & 1 deletion packages/driver/src/codecs/pgvector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,18 @@
import type { ReadBuffer, WriteBuffer } from "../primitives/buffer";
import { type ICodec, ScalarCodec } from "./ifaces";
import { InvalidArgumentError } from "../errors";
import {
Float16Array,
getFloat16,
isFloat16Array,
setFloat16,
} from "../adapter.shared.node";
import { SparseVector } from "../datatypes/pgvector";

export const PG_VECTOR_MAX_DIM = (1 << 16) - 1;

export class PgVectorCodec extends ScalarCodec implements ICodec {
tsType = "Float32Array";
readonly tsType = "Float32Array";

encode(buf: WriteBuffer, object: any): void {
if (!(object instanceof Float32Array || Array.isArray(object))) {
Expand Down Expand Up @@ -78,3 +85,136 @@ export class PgVectorCodec extends ScalarCodec implements ICodec {
return vec;
}
}

export class PgVectorHalfVecCodec extends ScalarCodec implements ICodec {
readonly tsType = "Float16Array";
readonly tsModule = "edgedb";

encode(buf: WriteBuffer, object: any): void {
if (!(isFloat16Array(object) || Array.isArray(object))) {
throw new InvalidArgumentError(
`a Float16Array or array of numbers was expected, got "${object}"`,
);
}

if (object.length > PG_VECTOR_MAX_DIM) {
throw new InvalidArgumentError(
"too many elements in array to encode into pgvector",
);
}

buf
.writeInt32(4 + object.length * 2)
.writeUInt16(object.length)
.writeUInt16(0);

const vecBuf = new Uint8Array(object.length * 2);
const data = new DataView(
vecBuf.buffer,
vecBuf.byteOffset,
vecBuf.byteLength,
);

if (isFloat16Array(object)) {
for (let i = 0; i < object.length; i++) {
setFloat16(data, i * 2, object[i]);
}
} else {
for (let i = 0; i < object.length; i++) {
if (typeof object[i] !== "number") {
throw new InvalidArgumentError(
`elements of vector array expected to be a numbers, got "${object[i]}"`,
);
}
setFloat16(data, i * 2, object[i]);
}
}

buf.writeBuffer(vecBuf);
}

decode(buf: ReadBuffer): any {
const dim = buf.readUInt16();
buf.discard(2);

const vecBuf = buf.readBuffer(dim * 2);
const data = new DataView(
vecBuf.buffer,
vecBuf.byteOffset,
vecBuf.byteLength,
);
const vec = new Float16Array(dim);

for (let i = 0; i < dim; i++) {
vec[i] = getFloat16(data, i * 2);
}

return vec;
}
}

export class PgVectorSparseVecCodec extends ScalarCodec implements ICodec {
readonly tsType = "SparseVector";
readonly tsModule = "edgedb";

encode(buf: WriteBuffer, object: any): void {
if (!(object instanceof SparseVector)) {
throw new InvalidArgumentError(
`a SparseVector was expected, got "${object}"`,
);
}

const nnz = object.indexes.length;

if (nnz > PG_VECTOR_MAX_DIM || nnz > object.length) {
throw new InvalidArgumentError(
"too many elements in sparse vector value",
);
}

buf
.writeUInt32(4 * (3 + nnz * 2))
.writeUInt32(object.length)
.writeUInt32(nnz)
.writeUInt32(0);

const vecBuf = new Uint8Array(nnz * 8);
const data = new DataView(
vecBuf.buffer,
vecBuf.byteOffset,
vecBuf.byteLength,
);

for (let i = 0; i < nnz; i++) {
data.setUint32(i * 4, object.indexes[i]);
}
for (let i = 0; i < nnz; i++) {
data.setFloat32((nnz + i) * 4, object.values[i]);
}

buf.writeBuffer(vecBuf);
}

decode(buf: ReadBuffer): any {
const dim = buf.readUInt32();
const nnz = buf.readUInt32();
buf.discard(4);

const vecBuf = buf.readBuffer(nnz * 8);
const data = new DataView(
vecBuf.buffer,
vecBuf.byteOffset,
vecBuf.byteLength,
);
const indexes = new Uint32Array(nnz);
for (let i = 0; i < nnz; i++) {
indexes[i] = data.getUint32(i * 4);
}
const vecData = new Float32Array(nnz);
for (let i = 0; i < nnz; i++) {
vecData[i] = data.getFloat32((i + nnz) * 4);
}

return new SparseVector(dim, indexes, vecData);
}
}
72 changes: 72 additions & 0 deletions packages/driver/src/datatypes/pgvector.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
export class SparseVector {
public indexes: Uint32Array;
public values: Float32Array;

[index: number]: number;

constructor(length: number, map: Record<number, number>);
constructor(length: number, indexes: Uint32Array, values: Float32Array);
constructor(
public length: number,
indexesOrMap: Uint32Array | Record<number, number>,
values?: Float32Array,
) {
if (indexesOrMap instanceof Uint32Array) {
if (indexesOrMap.length !== values?.length) {
throw new Error(
"indexes array must be the same length as the data array",
);
}
if (indexesOrMap.length > length) {
throw new Error(
"length of data cannot be larger than length of sparse vector",
);
}
this.values = values;
this.indexes = indexesOrMap;
} else {
const entries = Object.entries(indexesOrMap);
if (entries.length > length) {
throw new Error(
"length of data cannot be larger than length of sparse vector",
);
}
this.indexes = new Uint32Array(entries.length);
this.values = new Float32Array(entries.length);
for (let i = 0; i < entries.length; i++) {
const index = parseInt(entries[i][0], 10);
const val = entries[i][1];
if (Number.isNaN(index)) {
throw new Error(`key ${entries[i][0]} in data map is not an integer`);
}
if (index < 0 || index >= length) {
throw new Error(
`index ${index} is out of range of sparse vector length`,
);
}
this.indexes[i] = index;
if (typeof val !== "number") {
throw new Error(
`expected value at index ${index} to be number, got ${typeof val} ${val}`,
);
}
if (val === 0) {
throw new Error("elements in sparse vector cannot be 0");
}
this.values[i] = val;
}
}

return new Proxy(this, {
get(target, p) {
const index = typeof p === "string" ? parseInt(p, 10) : NaN;
if (!Number.isNaN(index)) {
if (index < 0 || index >= target.length) return undefined;
const dataIndex = target.indexes.indexOf(index);
return dataIndex === -1 ? 0 : target.values[dataIndex];
}
return (target as any)[p];
},
});
}
}
2 changes: 2 additions & 0 deletions packages/driver/src/index.shared.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ export {
} from "./datatypes/datetime";
export { ConfigMemory } from "./datatypes/memory";
export { Range, MultiRange } from "./datatypes/range";
export { SparseVector } from "./datatypes/pgvector";
export { Float16Array } from "./adapter.shared.node";

export type { Executor } from "./ifaces";

Expand Down
Loading

0 comments on commit e433ee2

Please sign in to comment.