Skip to content

Commit

Permalink
Faster SCRAM when using browserCrypto (#1063)
Browse files Browse the repository at this point in the history
We use browser/globalThis crypto when it's available, but the
implementation of HMAC is slower using
globalThis.crypto.subtle.sign. This speeds that up by about 2x, but it's
still 10x slower than Node's `createHmac`.
  • Loading branch information
scotttrinh authored Jul 24, 2024
1 parent 0cbd195 commit eb92923
Show file tree
Hide file tree
Showing 12 changed files with 149 additions and 116 deletions.
35 changes: 24 additions & 11 deletions compileForDeno.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ export async function run({
file,
ts.ScriptTarget.Latest,
false,
ts.ScriptKind.TS
ts.ScriptKind.TS,
);

const rewrittenFile: string[] = [];
Expand All @@ -84,7 +84,7 @@ export async function run({
const neededImports = injectImports.reduce(
(neededImports, { imports, from }) => {
const usedImports = imports.filter((importName) =>
parsedSource.identifiers?.has(importName)
parsedSource.identifiers?.has(importName),
);
if (usedImports.length) {
neededImports.push({
Expand All @@ -94,18 +94,18 @@ export async function run({
}
return neededImports;
},
[] as { imports: string[]; from: string }[]
[] as { imports: string[]; from: string }[],
);

if (neededImports.length) {
const importDecls = neededImports.map((neededImport) => {
const imports = neededImport.imports.join(", ");
// no need to resolve path if it is import from url
const importPath = neededImport.from.startsWith("https://")
// no need to resolve path if it is import from a supported protocol
const importPath = _pathUsesSupportedProtocol(neededImport.from)
? neededImport.from
: resolveImportPath(
relative(dirname(sourcePath), neededImport.from),
sourcePath
sourcePath,
);
return `import {${imports}} from "${importPath}";`;
});
Expand Down Expand Up @@ -139,7 +139,7 @@ export async function run({
if (resolvedImportPath.endsWith(`/${name}.node.ts`)) {
resolvedImportPath = resolvedImportPath.replace(
`/${name}.node.ts`,
`/${name}.deno.ts`
`/${name}.deno.ts`,
);
}
}
Expand All @@ -153,7 +153,7 @@ export async function run({
if (/__dirname/g.test(contents)) {
contents = contents.replaceAll(
/__dirname/g,
"new URL('.', import.meta.url).pathname"
"new URL('.', import.meta.url).pathname",
);
}

Expand All @@ -175,7 +175,7 @@ export async function run({
const path = importPath.replace(rule.match, (match) =>
typeof rule.replace === "function"
? rule.replace(match, sourcePath)
: rule.replace
: rule.replace,
);
if (
!path.endsWith(".ts") &&
Expand All @@ -187,6 +187,11 @@ export async function run({
}
}

// Then check if importPath is already a supported protocol
if (_pathUsesSupportedProtocol(importPath)) {
return importPath;
}

// then resolve normally
let resolvedPath = join(dirname(sourcePath), importPath);

Expand All @@ -200,18 +205,26 @@ export async function run({

if (!sourceFilePathMap.has(resolvedPath)) {
throw new Error(
`Cannot find imported file '${importPath}' in '${sourcePath}'`
`Cannot find imported file '${importPath}' in '${sourcePath}'`,
);
}
}
}

const relImportPath = relative(
dirname(sourceFilePathMap.get(sourcePath)!),
sourceFilePathMap.get(resolvedPath)!
sourceFilePathMap.get(resolvedPath)!,
);
return relImportPath.startsWith("../")
? relImportPath
: "./" + relImportPath;
}
}

function _pathUsesSupportedProtocol(path: string) {
return (
path.startsWith("https:") ||
path.startsWith("node:") ||
path.startsWith("npm:")
);
}
2 changes: 1 addition & 1 deletion packages/driver/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
"build:cli": "tsc --project tsconfig.cli.json",
"build:cjs": "tsc --project tsconfig.json",
"build:deno": "deno run --unstable --allow-all ./buildDeno.ts",
"test": "npx jest --detectOpenHandles",
"test": "NODE_OPTIONS='--experimental-global-webcrypto' npx jest --detectOpenHandles",
"lint": "tslint 'packages/*/src/**/*.ts'",
"format": "prettier --write 'src/**/*.ts' 'test/**/*.ts'",
"gen-errors": "edb gen-errors-json --client | node genErrors.mjs",
Expand Down
36 changes: 1 addition & 35 deletions packages/driver/src/adapter.crypto.deno.ts
Original file line number Diff line number Diff line change
@@ -1,35 +1 @@
import { crypto } from "https://deno.land/[email protected]/crypto/mod.ts";

import type { CryptoUtils } from "./utils.ts";

const cryptoUtils: CryptoUtils = {
async randomBytes(size: number): Promise<Uint8Array> {
const buf = new Uint8Array(size);
return crypto.getRandomValues(buf);
},

async H(msg: Uint8Array): Promise<Uint8Array> {
return new Uint8Array(await crypto.subtle.digest("SHA-256", msg));
},

async HMAC(key: Uint8Array, msg: Uint8Array): Promise<Uint8Array> {
return new Uint8Array(
await crypto.subtle.sign(
"HMAC",
await crypto.subtle.importKey(
"raw",
key,
{
name: "HMAC",
hash: { name: "SHA-256" },
},
false,
["sign"],
),
msg,
),
);
},
};

export default cryptoUtils;
export { cryptoUtils as default } from "./browserCrypto.ts";
30 changes: 2 additions & 28 deletions packages/driver/src/adapter.crypto.node.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,10 @@ let cryptoUtils: CryptoUtils;

if (typeof crypto === "undefined") {
// eslint-disable-next-line @typescript-eslint/no-require-imports
const nodeCrypto = require("crypto");

cryptoUtils = {
randomBytes(size: number): Promise<Uint8Array> {
return new Promise((resolve, reject) => {
nodeCrypto.randomBytes(size, (err: Error | null, buf: Buffer) => {
if (err) {
reject(err);
} else {
resolve(buf);
}
});
});
},

async H(msg: Uint8Array): Promise<Uint8Array> {
const sign = nodeCrypto.createHash("sha256");
sign.update(msg);
return sign.digest();
},

async HMAC(key: Uint8Array, msg: Uint8Array): Promise<Uint8Array> {
const hm = nodeCrypto.createHmac("sha256", key);
hm.update(msg);
return hm.digest();
},
};
cryptoUtils = require("./nodeCrypto").cryptoUtils;
} else {
// eslint-disable-next-line @typescript-eslint/no-require-imports
cryptoUtils = require("./browserCrypto").default;
cryptoUtils = require("./browserCrypto").cryptoUtils;
}

export default cryptoUtils;
2 changes: 1 addition & 1 deletion packages/driver/src/browserClient.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { BaseClientPool, Client, type ConnectOptions } from "./baseClient";
import { getConnectArgumentsParser } from "./conUtils";
import cryptoUtils from "./browserCrypto";
import { cryptoUtils } from "./browserCrypto";
import { EdgeDBError } from "./errors";
import { FetchConnection } from "./fetchConn";
import { getHTTPSCRAMAuth } from "./httpScram";
Expand Down
60 changes: 33 additions & 27 deletions packages/driver/src/browserCrypto.ts
Original file line number Diff line number Diff line change
@@ -1,32 +1,38 @@
import type { CryptoUtils } from "./utils";

const cryptoUtils: CryptoUtils = {
async randomBytes(size: number): Promise<Uint8Array> {
return crypto.getRandomValues(new Uint8Array(size));
},
async function makeKey(key: Uint8Array): Promise<CryptoKey> {
return await crypto.subtle.importKey(
"raw",
key,
{
name: "HMAC",
hash: { name: "SHA-256" },
},
false,
["sign"],
);
}

async H(msg: Uint8Array): Promise<Uint8Array> {
return new Uint8Array(await crypto.subtle.digest("SHA-256", msg));
},
function randomBytes(size: number): Uint8Array {
return crypto.getRandomValues(new Uint8Array(size));
}

async HMAC(key: Uint8Array, msg: Uint8Array): Promise<Uint8Array> {
return new Uint8Array(
await crypto.subtle.sign(
"HMAC",
await crypto.subtle.importKey(
"raw",
key,
{
name: "HMAC",
hash: { name: "SHA-256" },
},
false,
["sign"],
),
msg,
),
);
},
};
async function H(msg: Uint8Array): Promise<Uint8Array> {
return new Uint8Array(await crypto.subtle.digest("SHA-256", msg));
}

async function HMAC(
key: Uint8Array | CryptoKey,
msg: Uint8Array,
): Promise<Uint8Array> {
const cryptoKey =
key instanceof Uint8Array ? ((await makeKey(key)) as CryptoKey) : key;
return new Uint8Array(await crypto.subtle.sign("HMAC", cryptoKey, msg));
}

export default cryptoUtils;
export const cryptoUtils: CryptoUtils = {
makeKey,
randomBytes,
H,
HMAC,
};
2 changes: 1 addition & 1 deletion packages/driver/src/httpScram.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ export function getHTTPSCRAMAuth(cryptoUtils: CryptoUtils): HttpSCRAMAuth {
password: string,
): Promise<string> {
const authUrl = baseUrl + AUTH_ENDPOINT;
const clientNonce = await generateNonce();
const clientNonce = generateNonce();
const [clientFirst, clientFirstBare] = buildClientFirstMessage(
clientNonce,
username,
Expand Down
34 changes: 34 additions & 0 deletions packages/driver/src/nodeCrypto.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import crypto from "node:crypto";
import type { CryptoUtils } from "./utils";

function makeKey(keyBytes: Uint8Array): Promise<Uint8Array> {
return Promise.resolve(keyBytes);
}

function randomBytes(size: number): Buffer {
return crypto.randomBytes(size);
}

async function H(msg: Uint8Array): Promise<Buffer> {
const sign = crypto.createHash("sha256");
sign.update(msg);
return sign.digest();
}

async function HMAC(
key: Uint8Array | CryptoKey,
msg: Uint8Array,
): Promise<Buffer> {
const cryptoKey: Uint8Array | crypto.KeyObject =
key instanceof Uint8Array ? key : crypto.KeyObject.from(key);
const hm = crypto.createHmac("sha256", cryptoKey);
hm.update(msg);
return hm.digest();
}

export const cryptoUtils: CryptoUtils = {
makeKey,
randomBytes,
H,
HMAC,
};
2 changes: 1 addition & 1 deletion packages/driver/src/rawConn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ export class RawConnection extends BaseRawConnection {
);
}

const clientNonce = await scram.generateNonce();
const clientNonce = scram.generateNonce();
const [clientFirst, clientFirstBare] = scram.buildClientFirstMessage(
clientNonce,
this.config.connectionParams.user,
Expand Down
11 changes: 5 additions & 6 deletions packages/driver/src/scram.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ export function saslprep(str: string): string {
return str.normalize("NFKC");
}

export function getSCRAM({ randomBytes, H, HMAC }: CryptoUtils) {
export function getSCRAM({ randomBytes, H, HMAC, makeKey }: CryptoUtils) {
function bufferEquals(a: Uint8Array, b: Uint8Array): boolean {
if (a.length !== b.length) {
return false;
Expand All @@ -45,9 +45,7 @@ export function getSCRAM({ randomBytes, H, HMAC }: CryptoUtils) {
return true;
}

function generateNonce(
length: number = RAW_NONCE_LENGTH,
): Promise<Uint8Array> {
function generateNonce(length: number = RAW_NONCE_LENGTH): Uint8Array {
return randomBytes(length);
}

Expand Down Expand Up @@ -161,11 +159,12 @@ export function getSCRAM({ randomBytes, H, HMAC }: CryptoUtils) {
msg.set(salt);
msg.set([0, 0, 0, 1], salt.length);

let Hi = await HMAC(password, msg);
const keyFromPassword = await makeKey(password);
let Hi = await HMAC(keyFromPassword, msg);
let Ui = Hi;

for (let _ = 0; _ < iterations - 1; _++) {
Ui = await HMAC(password, Ui);
Ui = await HMAC(keyFromPassword, Ui);
Hi = _XOR(Hi, Ui);
}

Expand Down
5 changes: 3 additions & 2 deletions packages/driver/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,10 @@ export function versionGreaterThanOrEqual(
}

export interface CryptoUtils {
randomBytes: (size: number) => Promise<Uint8Array>;
makeKey: (key: Uint8Array) => Promise<Uint8Array | CryptoKey>;
randomBytes: (size: number) => Uint8Array;
H: (msg: Uint8Array) => Promise<Uint8Array>;
HMAC: (key: Uint8Array, msg: Uint8Array) => Promise<Uint8Array>;
HMAC: (key: Uint8Array | CryptoKey, msg: Uint8Array) => Promise<Uint8Array>;
}

const _tokens = new WeakMap<ResolvedConnectConfigReadonly, string>();
Expand Down
Loading

0 comments on commit eb92923

Please sign in to comment.