From 12fbf7c38c928c773dde2dfbd62e839142d4ed6f Mon Sep 17 00:00:00 2001 From: Max Isom Date: Mon, 8 Apr 2024 14:41:58 -0700 Subject: [PATCH] Enable nested beforeTemplateIsBaked calls --- src/index.ts | 120 +++++++++++++++---------- src/internal-types.ts | 1 + src/public-types.ts | 5 ++ src/tests/hooks.test.ts | 46 ++++++++++ src/tests/utils/does-database-exist.ts | 12 +++ src/worker.ts | 17 +++- 6 files changed, 153 insertions(+), 48 deletions(-) create mode 100644 src/tests/utils/does-database-exist.ts diff --git a/src/index.ts b/src/index.ts index 0d08122..ecc299e 100644 --- a/src/index.ts +++ b/src/index.ts @@ -16,8 +16,7 @@ import type { import { Pool } from "pg" import type { Jsonifiable } from "type-fest" import type { ExecutionContext } from "ava" -import { once } from "node:events" -import { createBirpc } from "birpc" +import { BirpcReturn, createBirpc } from "birpc" import { ExecResult } from "testcontainers" import isPlainObject from "lodash/isPlainObject" @@ -136,57 +135,86 @@ export const getTestPostgresDatabaseFactory = < } let rpcCallback: (data: any) => void - const rpc = createBirpc( - { - runBeforeTemplateIsBakedHook: async (connection, params) => { - if (options?.beforeTemplateIsBaked) { - const connectionDetails = - mapWorkerConnectionDetailsToConnectionDetails(connection) - - // Ignore if the pool is terminated by the shared worker - // (This happens in CI for some reason even though we drain the pool first.) - connectionDetails.pool.on("error", (error) => { - if ( - error.message.includes( - "terminating connection due to administrator command" - ) - ) { - return - } + const rpc: BirpcReturn = + createBirpc( + { + runBeforeTemplateIsBakedHook: async (connection, params) => { + if (options?.beforeTemplateIsBaked) { + const connectionDetails = + mapWorkerConnectionDetailsToConnectionDetails(connection) + + // Ignore if the pool is terminated by the shared worker + // (This happens in CI for some reason even though we drain the pool first.) + connectionDetails.pool.on("error", (error) => { + if ( + error.message.includes( + "terminating connection due to administrator command" + ) + ) { + return + } + + throw error + }) + + const createdNestedConnections: ConnectionDetails[] = [] + const hookResult = await options.beforeTemplateIsBaked({ + params: params as any, + connection: connectionDetails, + containerExec: async (command): Promise => + rpc.execCommandInContainer(command), + // This is what allows a consumer to get a "nested" database from within their beforeTemplateIsBaked hook + beforeTemplateIsBaked: async (options) => { + const { connectionDetails, beforeTemplateIsBakedResult } = + await rpc.getTestDatabase({ + params: options.params, + databaseDedupeKey: options.databaseDedupeKey, + }) - throw error - }) + const mappedConnection = + mapWorkerConnectionDetailsToConnectionDetails( + connectionDetails + ) - const hookResult = await options.beforeTemplateIsBaked({ - params: params as any, - connection: connectionDetails, - containerExec: async (command): Promise => - rpc.execCommandInContainer(command), - }) + createdNestedConnections.push(mappedConnection) - await teardownConnection(connectionDetails) + return { + ...mappedConnection, + beforeTemplateIsBakedResult, + } + }, + }) - if (hookResult && !isSerializable(hookResult)) { - throw new TypeError( - "Return value of beforeTemplateIsBaked() hook could not be serialized. Make sure it returns only JSON-serializable values." + await Promise.all( + createdNestedConnections.map(async (connection) => { + await teardownConnection(connection) + await rpc.dropDatabase(connection.database) + }) ) - } - return hookResult - } - }, - }, - { - post: async (data) => { - const worker = await workerPromise - await worker.available - worker.publish(data) - }, - on: (data) => { - rpcCallback = data + await teardownConnection(connectionDetails) + + if (hookResult && !isSerializable(hookResult)) { + throw new TypeError( + "Return value of beforeTemplateIsBaked() hook could not be serialized. Make sure it returns only JSON-serializable values." + ) + } + + return hookResult + } + }, }, - } - ) + { + post: async (data) => { + const worker = await workerPromise + await worker.available + worker.publish(data) + }, + on: (data) => { + rpcCallback = data + }, + } + ) // Automatically cleaned up by AVA since each test file runs in a separate worker const _messageHandlerPromise = (async () => { diff --git a/src/internal-types.ts b/src/internal-types.ts index 2700f0e..cb8f11b 100644 --- a/src/internal-types.ts +++ b/src/internal-types.ts @@ -30,4 +30,5 @@ export interface SharedWorkerFunctions { beforeTemplateIsBakedResult: unknown }> execCommandInContainer: (command: string[]) => Promise + dropDatabase: (databaseName: string) => Promise } diff --git a/src/public-types.ts b/src/public-types.ts index d3ce7c7..038a25e 100644 --- a/src/public-types.ts +++ b/src/public-types.ts @@ -55,6 +55,11 @@ export interface GetTestPostgresDatabaseFactoryOptions< connection: ConnectionDetails params: Params containerExec: (command: string[]) => Promise + beforeTemplateIsBaked: ( + options: { + params: Params + } & Pick + ) => Promise }) => Promise } diff --git a/src/tests/hooks.test.ts b/src/tests/hooks.test.ts index 1944190..64636a3 100644 --- a/src/tests/hooks.test.ts +++ b/src/tests/hooks.test.ts @@ -1,6 +1,7 @@ import test from "ava" import { getTestPostgresDatabaseFactory } from "~/index" import { countDatabaseTemplates } from "./utils/count-database-templates" +import { doesDatabaseExist } from "./utils/does-database-exist" test("beforeTemplateIsBaked", async (t) => { let wasHookCalled = false @@ -145,3 +146,48 @@ test("beforeTemplateIsBaked (result isn't serializable)", async (t) => { } ) }) + +test("beforeTemplateIsBaked, get nested database", async (t) => { + type DatabaseParams = { + type: "foo" | "bar" + } + + let nestedDatabaseName: string | undefined = undefined + + const getTestServer = getTestPostgresDatabaseFactory({ + postgresVersion: process.env.POSTGRES_VERSION, + workerDedupeKey: "beforeTemplateIsBakedHookNestedDatabase", + beforeTemplateIsBaked: async ({ + params, + connection: { pool }, + beforeTemplateIsBaked, + }) => { + if (params.type === "foo") { + await pool.query(`CREATE TABLE "foo" ("id" SERIAL PRIMARY KEY)`) + return { createdFoo: true } + } + + await pool.query(`CREATE TABLE "bar" ("id" SERIAL PRIMARY KEY)`) + const fooDatabase = await beforeTemplateIsBaked({ + params: { type: "foo" }, + }) + t.deepEqual(fooDatabase.beforeTemplateIsBakedResult, { createdFoo: true }) + + nestedDatabaseName = fooDatabase.database + + await t.notThrowsAsync(async () => { + await fooDatabase.pool.query(`INSERT INTO "foo" DEFAULT VALUES`) + }) + + return { createdBar: true } + }, + }) + + const database = await getTestServer(t, { type: "bar" }) + t.deepEqual(database.beforeTemplateIsBakedResult, { createdBar: true }) + + t.false( + await doesDatabaseExist(database.pool, nestedDatabaseName!), + "Nested database should have been cleaned up after the parent hook completed" + ) +}) diff --git a/src/tests/utils/does-database-exist.ts b/src/tests/utils/does-database-exist.ts new file mode 100644 index 0000000..b1d3b42 --- /dev/null +++ b/src/tests/utils/does-database-exist.ts @@ -0,0 +1,12 @@ +import { Pool } from "pg" + +export const doesDatabaseExist = async (pool: Pool, databaseName: string) => { + const { + rows: [{ count }], + } = await pool.query( + 'SELECT COUNT(*) FROM "pg_database" WHERE "datname" = $1', + [databaseName] + ) + + return count > 0 +} diff --git a/src/worker.ts b/src/worker.ts index 8130e6f..5e2e409 100644 --- a/src/worker.ts +++ b/src/worker.ts @@ -73,6 +73,10 @@ export class Worker { const container = (await this.startContainerPromise).container return container.exec(command) }, + dropDatabase: async (databaseName) => { + const { postgresClient } = await this.startContainerPromise + await postgresClient.query(`DROP DATABASE ${databaseName}`) + }, }, rpcChannel ) @@ -148,8 +152,17 @@ export class Worker { return } - await this.forceDisconnectClientsFrom(databaseName!) - await postgresClient.query(`DROP DATABASE ${databaseName}`) + try { + await this.forceDisconnectClientsFrom(databaseName!) + await postgresClient.query(`DROP DATABASE ${databaseName}`) + } catch (error) { + if ((error as Error)?.message?.includes("does not exist")) { + // Database was likely a nested database and manually dropped by the test worker, ignore + return + } + + throw error + } }) return {