diff --git a/packages/driver/src/baseClient.ts b/packages/driver/src/baseClient.ts index b25ae7e91..a06dfff43 100644 --- a/packages/driver/src/baseClient.ts +++ b/packages/driver/src/baseClient.ts @@ -604,6 +604,13 @@ export class Client implements Executor { ); } + withQueryTag(tag: string | null): Client { + return new Client( + this.pool, + this.options.withSession(this.options.session.withQueryTag(tag)), + ); + } + withWarningHandler(handler: WarningHandler): Client { return new Client(this.pool, this.options.withWarningHandler(handler)); } diff --git a/packages/driver/src/baseConn.ts b/packages/driver/src/baseConn.ts index 730c91e8c..bd2581fc5 100644 --- a/packages/driver/src/baseConn.ts +++ b/packages/driver/src/baseConn.ts @@ -899,6 +899,18 @@ export class BaseRawConnection { options: QueryOptions | undefined, language: Language, ) { + if (versionGreaterThanOrEqual(this.protocolVersion, [3, 0])) { + if (state.annotations.size >= 1 << 16) { + throw new errors.InternalClientError("too many annotations"); + } + wb.writeUInt16(state.annotations.size); + for (const [name, value] of state.annotations) { + wb.writeString(name); + wb.writeString(value); + } + } else { + wb.writeUInt16(0); + } wb.writeFlags(0xffff_ffff, capabilitiesFlags); wb.writeFlags( 0, @@ -953,7 +965,6 @@ export class BaseRawConnection { ): Promise { const wb = new WriteMessageBuffer(); wb.beginMessage(chars.$P); - wb.writeUInt16(0); // no headers this._encodeParseParams( wb, @@ -1080,7 +1091,6 @@ export class BaseRawConnection { ): Promise { const wb = new WriteMessageBuffer(); wb.beginMessage(chars.$O); - wb.writeUInt16(0); // no headers this._encodeParseParams( wb, diff --git a/packages/driver/src/options.ts b/packages/driver/src/options.ts index 08b0a3201..a2263c224 100644 --- a/packages/driver/src/options.ts +++ b/packages/driver/src/options.ts @@ -1,4 +1,5 @@ import * as errors from "./errors"; +import { utf8Encoder } from "./primitives/buffer"; export type BackoffFunction = (n: number) => number; @@ -118,6 +119,8 @@ export class TransactionOptions { } } +const TAG_ANNOTATION_KEY = "tag"; + export interface SessionOptions { module?: string; moduleAliases?: Record; @@ -138,6 +141,13 @@ export class Session { readonly config: Record; readonly globals: Record; + /** @internal */ + annotations = new Map(); + + get tag(): string | null { + return this.annotations.get(TAG_ANNOTATION_KEY) ?? null; + } + constructor({ module = "default", moduleAliases = {}, @@ -150,33 +160,56 @@ export class Session { this.globals = globals; } + private _clone(mergeOptions: SessionOptions) { + const session = new Session({ ...this, ...mergeOptions }); + session.annotations = this.annotations; + return session; + } + withModuleAliases({ module, ...aliases }: { [name: string]: string; }): Session { - return new Session({ - ...this, + return this._clone({ module: module ?? this.module, moduleAliases: { ...this.moduleAliases, ...aliases }, }); } withConfig(config: { [name: string]: any }): Session { - return new Session({ - ...this, + return this._clone({ config: { ...this.config, ...config }, }); } withGlobals(globals: { [name: string]: any }): Session { - return new Session({ - ...this, + return this._clone({ globals: { ...this.globals, ...globals }, }); } + withQueryTag(tag: string | null): Session { + const session = new Session({ ...this }); + session.annotations = new Map(this.annotations); + if (tag != null) { + if (tag.startsWith("edgedb/")) { + throw new errors.InterfaceError("reserved tag: edgedb/*"); + } + if (tag.startsWith("gel/")) { + throw new errors.InterfaceError("reserved tag: gel/*"); + } + if (utf8Encoder.encode(tag).length > 128) { + throw new errors.InterfaceError("tag too long (> 128 bytes)"); + } + session.annotations.set(TAG_ANNOTATION_KEY, tag); + } else { + session.annotations.delete(TAG_ANNOTATION_KEY); + } + return session; + } + /** @internal */ _serialise() { const state: SerializedSessionState = {}; diff --git a/packages/driver/test/browser.test.ts b/packages/driver/test/browser.test.ts index e9b110c57..ad39edcb3 100644 --- a/packages/driver/test/browser.test.ts +++ b/packages/driver/test/browser.test.ts @@ -73,9 +73,11 @@ import { const brokenConnectOpts = JSON.parse( process.env._JEST_EDGEDB_CONNECT_CONFIG || "", ); +const edgedbVersion = JSON.parse(process.env._JEST_EDGEDB_VERSION!); const connectOpts = { ...brokenConnectOpts, + user: edgedbVersion.major >= 6 ? "admin" : "edgedb", tlsCAFile: undefined, tlsSecurity: "insecure", };