Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add codecs for ext::pgvector::halfvec+ext::pgvector::sparsevec and add SparseVector datatype #1124

Merged
merged 6 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/driver/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
"dev": "yarn tsc --project tsconfig.json --incremental && yarn build:deno"
},
"dependencies": {
"@petamoriken/float16": "^3.8.7",
"debug": "^4.3.4",
"env-paths": "^3.0.0",
"semver": "^7.6.2",
Expand Down
24 changes: 24 additions & 0 deletions packages/driver/src/adapter.shared.deno.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,27 @@ export function getEnv(envName: string, required = false): string | undefined {
}
return Deno.env.get(envName);
}

const _Float16Array = Float16Array;
export { _Float16Array as Float16Array };

export function getFloat16(
dataView: DataView,
byteOffset: number,
littleEndian?: boolean,
): number {
return dataView.getFloat16(byteOffset, littleEndian);
}

export function setFloat16(
dataView: DataView,
byteOffset: number,
value: number,
littleEndian?: boolean,
): void {
dataView.setFloat16(byteOffset, value, littleEndian);
}

export function isFloat16Array(value: unknown): value is Float16Array {
return value instanceof Float16Array;
}
7 changes: 7 additions & 0 deletions packages/driver/src/adapter.shared.node.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
export {
Float16Array,
getFloat16,
isFloat16Array,
setFloat16,
} from "@petamoriken/float16";

export function getEnv(envName: string, _required = false): string | undefined {
return process.env[envName];
}
8 changes: 7 additions & 1 deletion packages/driver/src/codecs/codecs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ import {
DateDurationCodec,
} from "./datetime";
import { ConfigMemoryCodec } from "./memory";
import { PgVectorCodec } from "./pgvector";
import {
PgVectorCodec,
PgVectorHalfVecCodec,
PgVectorSparseVecCodec,
} from "./pgvector";
import { InternalClientError } from "../errors";

import { INVALID_CODEC_ID, KNOWN_TYPENAMES, NULL_CODEC_ID } from "./consts";
Expand Down Expand Up @@ -119,3 +123,5 @@ registerScalarCodec("cal::date_duration", DateDurationCodec);
registerScalarCodec("cfg::memory", ConfigMemoryCodec);

registerScalarCodec("ext::pgvector::vector", PgVectorCodec);
registerScalarCodec("ext::pgvector::halfvec", PgVectorHalfVecCodec);
registerScalarCodec("ext::pgvector::sparsevec", PgVectorSparseVecCodec);
2 changes: 2 additions & 0 deletions packages/driver/src/codecs/consts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ export const KNOWN_TYPES = new Map<uuid, string>([
["00000000000000000000000000000112", "cal::date_duration"],
["00000000000000000000000000000130", "cfg::memory"],
["9565dd8804f511eea6910b6ebe179825", "ext::pgvector::vector"],
["4ba84534188e43b4a7cecea2af0f405b", "ext::pgvector::halfvec"],
["003e434dcac2430ab238fb39d73447d2", "ext::pgvector::sparsevec"],
]);

export const KNOWN_TYPENAMES = (() => {
Expand Down
142 changes: 141 additions & 1 deletion packages/driver/src/codecs/pgvector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,18 @@
import type { ReadBuffer, WriteBuffer } from "../primitives/buffer";
import { type ICodec, ScalarCodec } from "./ifaces";
import { InvalidArgumentError } from "../errors";
import {
Float16Array,
getFloat16,
isFloat16Array,
setFloat16,
} from "../adapter.shared.node";
import { SparseVector } from "../datatypes/pgvector";

export const PG_VECTOR_MAX_DIM = (1 << 16) - 1;

export class PgVectorCodec extends ScalarCodec implements ICodec {
tsType = "Float32Array";
readonly tsType = "Float32Array";

encode(buf: WriteBuffer, object: any): void {
if (!(object instanceof Float32Array || Array.isArray(object))) {
Expand Down Expand Up @@ -78,3 +85,136 @@ export class PgVectorCodec extends ScalarCodec implements ICodec {
return vec;
}
}

export class PgVectorHalfVecCodec extends ScalarCodec implements ICodec {
readonly tsType = "Float16Array";
readonly tsModule = "edgedb";

encode(buf: WriteBuffer, object: any): void {
if (!(isFloat16Array(object) || Array.isArray(object))) {
throw new InvalidArgumentError(
`a Float16Array or array of numbers was expected, got "${object}"`,
);
}

if (object.length > PG_VECTOR_MAX_DIM) {
throw new InvalidArgumentError(
"too many elements in array to encode into pgvector",
);
}

buf
.writeInt32(4 + object.length * 2)
.writeUInt16(object.length)
.writeUInt16(0);

const vecBuf = new Uint8Array(object.length * 2);
const data = new DataView(
vecBuf.buffer,
vecBuf.byteOffset,
vecBuf.byteLength,
);

if (isFloat16Array(object)) {
for (let i = 0; i < object.length; i++) {
setFloat16(data, i * 2, object[i]);
}
} else {
for (let i = 0; i < object.length; i++) {
if (typeof object[i] !== "number") {
throw new InvalidArgumentError(
`elements of vector array expected to be a numbers, got "${object[i]}"`,
);
}
setFloat16(data, i * 2, object[i]);
}
}

buf.writeBuffer(vecBuf);
}

decode(buf: ReadBuffer): any {
const dim = buf.readUInt16();
buf.discard(2);

const vecBuf = buf.readBuffer(dim * 2);
const data = new DataView(
vecBuf.buffer,
vecBuf.byteOffset,
vecBuf.byteLength,
);
const vec = new Float16Array(dim);

for (let i = 0; i < dim; i++) {
vec[i] = getFloat16(data, i * 2);
}

return vec;
}
}

export class PgVectorSparseVecCodec extends ScalarCodec implements ICodec {
readonly tsType = "SparseVector";
readonly tsModule = "edgedb";

encode(buf: WriteBuffer, object: any): void {
if (!(object instanceof SparseVector)) {
throw new InvalidArgumentError(
`a SparseVector was expected, got "${object}"`,
);
}

const nnz = object.indexes.length;

if (nnz > PG_VECTOR_MAX_DIM || nnz > object.length) {
throw new InvalidArgumentError(
"too many elements in sparse vector value",
);
}

buf
.writeUInt32(4 * (3 + nnz * 2))
.writeUInt32(object.length)
.writeUInt32(nnz)
.writeUInt32(0);

const vecBuf = new Uint8Array(nnz * 8);
const data = new DataView(
vecBuf.buffer,
vecBuf.byteOffset,
vecBuf.byteLength,
);

for (let i = 0; i < nnz; i++) {
data.setUint32(i * 4, object.indexes[i]);
}
for (let i = 0; i < nnz; i++) {
data.setFloat32((nnz + i) * 4, object.values[i]);
}

buf.writeBuffer(vecBuf);
}

decode(buf: ReadBuffer): any {
const dim = buf.readUInt32();
const nnz = buf.readUInt32();
buf.discard(4);

const vecBuf = buf.readBuffer(nnz * 8);
const data = new DataView(
vecBuf.buffer,
vecBuf.byteOffset,
vecBuf.byteLength,
);
const indexes = new Uint32Array(nnz);
for (let i = 0; i < nnz; i++) {
indexes[i] = data.getUint32(i * 4);
}
const vecData = new Float32Array(nnz);
for (let i = 0; i < nnz; i++) {
vecData[i] = data.getFloat32((i + nnz) * 4);
}

return new SparseVector(dim, indexes, vecData);
}
}
72 changes: 72 additions & 0 deletions packages/driver/src/datatypes/pgvector.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
export class SparseVector {
public indexes: Uint32Array;
public values: Float32Array;

[index: number]: number;

constructor(length: number, map: Record<number, number>);
constructor(length: number, indexes: Uint32Array, values: Float32Array);
constructor(
public length: number,
indexesOrMap: Uint32Array | Record<number, number>,
values?: Float32Array,
) {
if (indexesOrMap instanceof Uint32Array) {
if (indexesOrMap.length !== values?.length) {
throw new Error(
"indexes array must be the same length as the data array",
);
}
if (indexesOrMap.length > length) {
throw new Error(
"length of data cannot be larger than length of sparse vector",
);
}
this.values = values;
this.indexes = indexesOrMap;
} else {
const entries = Object.entries(indexesOrMap);
if (entries.length > length) {
throw new Error(
"length of data cannot be larger than length of sparse vector",
);
}
this.indexes = new Uint32Array(entries.length);
this.values = new Float32Array(entries.length);
for (let i = 0; i < entries.length; i++) {
const index = parseInt(entries[i][0], 10);
const val = entries[i][1];
if (Number.isNaN(index)) {
throw new Error(`key ${entries[i][0]} in data map is not an integer`);
}
if (index < 0 || index >= length) {
throw new Error(
`index ${index} is out of range of sparse vector length`,
);
}
this.indexes[i] = index;
if (typeof val !== "number") {
throw new Error(
`expected value at index ${index} to be number, got ${typeof val} ${val}`,
);
}
if (val === 0) {
throw new Error("elements in sparse vector cannot be 0");
}
this.values[i] = val;
}
}

return new Proxy(this, {
get(target, p) {
const index = typeof p === "string" ? parseInt(p, 10) : NaN;
if (!Number.isNaN(index)) {
if (index < 0 || index >= target.length) return undefined;
const dataIndex = target.indexes.indexOf(index);
return dataIndex === -1 ? 0 : target.values[dataIndex];
}
return (target as any)[p];
},
});
}
}
2 changes: 2 additions & 0 deletions packages/driver/src/index.shared.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ export {
} from "./datatypes/datetime";
export { ConfigMemory } from "./datatypes/memory";
export { Range, MultiRange } from "./datatypes/range";
export { SparseVector } from "./datatypes/pgvector";
export { Float16Array } from "./adapter.shared.node";

export type { Executor } from "./ifaces";

Expand Down
Loading