Skip to content

Commit

Permalink
Add support for std::pg::* types (#1132)
Browse files Browse the repository at this point in the history
  • Loading branch information
jaclarke authored Nov 26, 2024
1 parent 3963951 commit c79b455
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 12 deletions.
8 changes: 7 additions & 1 deletion packages/driver/src/codecs/codecs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import { BigIntCodec, DecimalStringCodec } from "./numerics";
import { StrCodec } from "./text";
import { UUIDCodec } from "./uuid";
import { BytesCodec } from "./bytes";
import { JSONCodec } from "./json";
import { JSONCodec, PgTextJSONCodec } from "./json";
import {
DateTimeCodec,
LocalDateCodec,
Expand Down Expand Up @@ -122,6 +122,12 @@ registerScalarCodec("cal::date_duration", DateDurationCodec);

registerScalarCodec("cfg::memory", ConfigMemoryCodec);

registerScalarCodec("std::pg::json", PgTextJSONCodec);
registerScalarCodec("std::pg::timestamptz", DateTimeCodec);
registerScalarCodec("std::pg::timestamp", LocalDateTimeCodec);
registerScalarCodec("std::pg::date", LocalDateCodec);
registerScalarCodec("std::pg::interval", RelativeDurationCodec);

registerScalarCodec("ext::pgvector::vector", PgVectorCodec);
registerScalarCodec("ext::pgvector::halfvec", PgVectorHalfVecCodec);
registerScalarCodec("ext::pgvector::sparsevec", PgVectorSparseVecCodec);
5 changes: 5 additions & 0 deletions packages/driver/src/codecs/consts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ export const KNOWN_TYPES = new Map<uuid, string>([
["00000000000000000000000000000111", "cal::relative_duration"],
["00000000000000000000000000000112", "cal::date_duration"],
["00000000000000000000000000000130", "cfg::memory"],
["00000000000000000000000001000001", "std::pg::json"],
["00000000000000000000000001000002", "std::pg::timestamptz"],
["00000000000000000000000001000003", "std::pg::timestamp"],
["00000000000000000000000001000004", "std::pg::date"],
["00000000000000000000000001000005", "std::pg::interval"],
["9565dd8804f511eea6910b6ebe179825", "ext::pgvector::vector"],
["4ba84534188e43b4a7cecea2af0f405b", "ext::pgvector::halfvec"],
["003e434dcac2430ab238fb39d73447d2", "ext::pgvector::sparsevec"],
Expand Down
44 changes: 34 additions & 10 deletions packages/driver/src/codecs/json.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import { InvalidArgumentError, ProtocolError } from "../errors";
export class JSONCodec extends ScalarCodec implements ICodec {
tsType = "unknown";

readonly jsonFormat: number | null = 1;

encode(buf: WriteBuffer, object: any): void {
let val: string;
try {
Expand All @@ -45,37 +47,59 @@ export class JSONCodec extends ScalarCodec implements ICodec {
}

const strbuf = utf8Encoder.encode(val);
buf.writeInt32(strbuf.length + 1);
buf.writeChar(1); // JSON format version
if (this.jsonFormat !== null) {
buf.writeInt32(strbuf.length + 1);
buf.writeChar(this.jsonFormat);
} else {
buf.writeInt32(strbuf.length);
}
buf.writeBuffer(strbuf);
}

decode(buf: ReadBuffer): any {
const format = buf.readUInt8();
if (format !== 1) {
throw new ProtocolError(`unexpected JSON format ${format}`);
if (this.jsonFormat !== null) {
const format = buf.readUInt8();
if (format !== this.jsonFormat) {
throw new ProtocolError(`unexpected JSON format ${format}`);
}
}
return JSON.parse(buf.consumeAsString());
}
}

export class PgTextJSONCodec extends JSONCodec {
readonly jsonFormat = null;
}

export class JSONStringCodec extends ScalarCodec implements ICodec {
readonly jsonFormat: number | null = 1;

encode(buf: WriteBuffer, object: any): void {
if (typeof object !== "string") {
throw new InvalidArgumentError(`a string was expected, got "${object}"`);
}

const strbuf = utf8Encoder.encode(object);
buf.writeInt32(strbuf.length + 1);
buf.writeChar(1); // JSON format version
if (this.jsonFormat !== null) {
buf.writeInt32(strbuf.length + 1);
buf.writeChar(this.jsonFormat);
} else {
buf.writeInt32(strbuf.length);
}
buf.writeBuffer(strbuf);
}

decode(buf: ReadBuffer): any {
const format = buf.readUInt8();
if (format !== 1) {
throw new ProtocolError(`unexpected JSON format ${format}`);
if (this.jsonFormat !== null) {
const format = buf.readUInt8();
if (format !== this.jsonFormat) {
throw new ProtocolError(`unexpected JSON format ${format}`);
}
}
return buf.consumeAsString();
}
}

export class PgTextJSONStringCodec extends JSONStringCodec {
readonly jsonFormat = null;
}
29 changes: 28 additions & 1 deletion packages/driver/src/codecs/registry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import { NULL_CODEC_ID, KNOWN_TYPES, KNOWN_TYPENAMES } from "./consts";
import { EMPTY_TUPLE_CODEC, EMPTY_TUPLE_CODEC_ID, TupleCodec } from "./tuple";
import * as numbers from "./numbers";
import * as datecodecs from "./datetime";
import { JSONStringCodec } from "./json";
import { JSONStringCodec, PgTextJSONStringCodec } from "./json";
import { ArrayCodec } from "./array";
import { NamedTupleCodec } from "./namedtuple";
import { EnumCodec } from "./enum";
Expand Down Expand Up @@ -57,11 +57,15 @@ export interface CustomCodecSpec {
int64_bigint?: boolean;
datetime_localDatetime?: boolean;
json_string?: boolean;
pg_json_string?: boolean;
pg_timestamptz_localDatetime?: boolean;
}

const INT64_TYPEID = KNOWN_TYPENAMES.get("std::int64")!;
const DATETIME_TYPEID = KNOWN_TYPENAMES.get("std::datetime")!;
const JSON_TYPEID = KNOWN_TYPENAMES.get("std::json")!;
const PG_JSON_TYPEID = KNOWN_TYPENAMES.get("std::pg::json")!;
const PG_TIMESTAMPTZ_TYPEID = KNOWN_TYPENAMES.get("std::pg::timestamptz")!;

export class CodecsRegistry {
private codecsBuildCache: LRU<uuid, ICodec>;
Expand All @@ -78,6 +82,8 @@ export class CodecsRegistry {
int64_bigint,
datetime_localDatetime,
json_string,
pg_json_string,
pg_timestamptz_localDatetime,
}: CustomCodecSpec = {}): void {
// This is a private API and it will change in the future.

Expand Down Expand Up @@ -107,6 +113,27 @@ export class CodecsRegistry {
} else {
this.customScalarCodecs.delete(JSON_TYPEID);
}

if (pg_json_string) {
this.customScalarCodecs.set(
PG_JSON_TYPEID,
new PgTextJSONStringCodec(PG_JSON_TYPEID, "std::pg::json"),
);
} else {
this.customScalarCodecs.delete(PG_JSON_TYPEID);
}

if (pg_timestamptz_localDatetime) {
this.customScalarCodecs.set(
PG_TIMESTAMPTZ_TYPEID,
new datecodecs.LocalDateTimeCodec(
PG_TIMESTAMPTZ_TYPEID,
"std::pg::timestamptz",
),
);
} else {
this.customScalarCodecs.delete(PG_TIMESTAMPTZ_TYPEID);
}
}

hasCodec(typeId: uuid): boolean {
Expand Down
24 changes: 24 additions & 0 deletions packages/driver/test/client.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2425,6 +2425,30 @@ if (getEdgeDBVersion().major >= 6) {
await client.close();
}
});

test("querySQL std::pg:: types", async () => {
let client = getClient();

const pgTypes: [string, any][] = [
["json", [{ abc: 123 }, "test", 456]],
["timestamptz", new Date()],
["timestamp", new LocalDateTime(2024, 11, 15, 16, 20, 1, 2, 3)],
["date", new LocalDate(2024, 11, 15)],
["interval", new RelativeDuration(1, 2, 3, 4, 5, 6, 7, 8, 9)],
];

try {
for (const [typename, val] of pgTypes) {
const res = await client.querySQL<{ val: any }>(
`select $1::${typename} as "val"`,
[val],
);
expect(JSON.stringify(res[0].val)).toEqual(JSON.stringify(val));
}
} finally {
await client.close();
}
});
} else {
test("SQL methods should fail nicely if proto v3 not supported", async () => {
let client = getClient();
Expand Down

0 comments on commit c79b455

Please sign in to comment.