diff --git a/.changeset/lemon-cherries-greet.md b/.changeset/lemon-cherries-greet.md new file mode 100644 index 0000000000..6f58c5cc03 --- /dev/null +++ b/.changeset/lemon-cherries-greet.md @@ -0,0 +1,6 @@ +--- +"@trigger.dev/react-hooks": patch +"@trigger.dev/sdk": patch +--- + +Realtime streams now powered by electric. Also, this change fixes a realtime bug that was causing too many re-renders, even on records that didn't change diff --git a/apps/webapp/app/env.server.ts b/apps/webapp/app/env.server.ts index 733ef5766d..99c810c76f 100644 --- a/apps/webapp/app/env.server.ts +++ b/apps/webapp/app/env.server.ts @@ -243,6 +243,8 @@ const EnvironmentSchema = z.object({ MAXIMUM_DEV_QUEUE_SIZE: z.coerce.number().int().optional(), MAXIMUM_DEPLOYED_QUEUE_SIZE: z.coerce.number().int().optional(), MAX_BATCH_V2_TRIGGER_ITEMS: z.coerce.number().int().default(500), + + REALTIME_STREAM_VERSION: z.enum(["v1", "v2"]).default("v1"), }); export type Environment = z.infer; diff --git a/apps/webapp/app/presenters/v3/SpanPresenter.server.ts b/apps/webapp/app/presenters/v3/SpanPresenter.server.ts index fee3aab63f..539a0c8110 100644 --- a/apps/webapp/app/presenters/v3/SpanPresenter.server.ts +++ b/apps/webapp/app/presenters/v3/SpanPresenter.server.ts @@ -215,7 +215,9 @@ export class SpanPresenter extends BasePresenter { const span = await eventRepository.getSpan(spanId, run.traceId); const metadata = run.metadata - ? await prettyPrintPacket(run.metadata, run.metadataType, { filteredKeys: ["$$streams"] }) + ? await prettyPrintPacket(run.metadata, run.metadataType, { + filteredKeys: ["$$streams", "$$streamsVersion"], + }) : undefined; const context = { diff --git a/apps/webapp/app/routes/realtime.v1.streams.$runId.$streamId.ts b/apps/webapp/app/routes/realtime.v1.streams.$runId.$streamId.ts index 3bce319958..c370869d3f 100644 --- a/apps/webapp/app/routes/realtime.v1.streams.$runId.$streamId.ts +++ b/apps/webapp/app/routes/realtime.v1.streams.$runId.$streamId.ts @@ -1,7 +1,7 @@ import { ActionFunctionArgs } from "@remix-run/server-runtime"; import { z } from "zod"; import { $replica } from "~/db.server"; -import { realtimeStreams } from "~/services/realtimeStreamsGlobal.server"; +import { v1RealtimeStreams } from "~/services/realtime/v1StreamsGlobal.server"; import { createLoaderApiRoute } from "~/services/routeBuilders/apiBuilder.server"; const ParamsSchema = z.object({ @@ -16,7 +16,7 @@ export async function action({ request, params }: ActionFunctionArgs) { return new Response("No body provided", { status: 400 }); } - return realtimeStreams.ingestData(request.body, $params.runId, $params.streamId); + return v1RealtimeStreams.ingestData(request.body, $params.runId, $params.streamId); } export const loader = createLoaderApiRoute( @@ -50,7 +50,13 @@ export const loader = createLoaderApiRoute( superScopes: ["read:runs", "read:all", "admin"], }, }, - async ({ params, request, resource: run }) => { - return realtimeStreams.streamResponse(run.friendlyId, params.streamId, request.signal); + async ({ params, request, resource: run, authentication }) => { + return v1RealtimeStreams.streamResponse( + request, + run.friendlyId, + params.streamId, + authentication.environment, + request.signal + ); } ); diff --git a/apps/webapp/app/routes/realtime.v2.streams.$runId.$streamId.ts b/apps/webapp/app/routes/realtime.v2.streams.$runId.$streamId.ts new file mode 100644 index 0000000000..9f22701a78 --- /dev/null +++ b/apps/webapp/app/routes/realtime.v2.streams.$runId.$streamId.ts @@ -0,0 +1,87 @@ +import { z } from "zod"; +import { $replica } from "~/db.server"; +import { + createActionApiRoute, + createLoaderApiRoute, +} from "~/services/routeBuilders/apiBuilder.server"; +import { v2RealtimeStreams } from "~/services/realtime/v2StreamsGlobal.server"; + +const ParamsSchema = z.object({ + runId: z.string(), + streamId: z.string(), +}); + +const { action } = createActionApiRoute( + { + params: ParamsSchema, + }, + async ({ request, params, authentication }) => { + if (!request.body) { + return new Response("No body provided", { status: 400 }); + } + + const run = await $replica.taskRun.findFirst({ + where: { + friendlyId: params.runId, + runtimeEnvironmentId: authentication.environment.id, + }, + include: { + batch: { + select: { + friendlyId: true, + }, + }, + }, + }); + + if (!run) { + return new Response("Run not found", { status: 404 }); + } + + return v2RealtimeStreams.ingestData(request.body, run.id, params.streamId); + } +); + +export { action }; + +export const loader = createLoaderApiRoute( + { + params: ParamsSchema, + allowJWT: true, + corsStrategy: "all", + findResource: async (params, auth) => { + return $replica.taskRun.findFirst({ + where: { + friendlyId: params.runId, + runtimeEnvironmentId: auth.environment.id, + }, + include: { + batch: { + select: { + friendlyId: true, + }, + }, + }, + }); + }, + authorization: { + action: "read", + resource: (run) => ({ + runs: run.friendlyId, + tags: run.runTags, + batch: run.batch?.friendlyId, + tasks: run.taskIdentifier, + }), + superScopes: ["read:runs", "read:all", "admin"], + }, + }, + async ({ params, request, resource: run, authentication }) => { + return v2RealtimeStreams.streamResponse( + request, + run.id, + params.streamId, + authentication.environment, + request.signal + ); + } +); diff --git a/apps/webapp/app/services/realtime/databaseRealtimeStreams.server.ts b/apps/webapp/app/services/realtime/databaseRealtimeStreams.server.ts new file mode 100644 index 0000000000..07f428d5c2 --- /dev/null +++ b/apps/webapp/app/services/realtime/databaseRealtimeStreams.server.ts @@ -0,0 +1,85 @@ +import { PrismaClient } from "@trigger.dev/database"; +import { AuthenticatedEnvironment } from "../apiAuth.server"; +import { logger } from "../logger.server"; +import { RealtimeClient } from "../realtimeClient.server"; +import { StreamIngestor, StreamResponder } from "./types"; + +export type DatabaseRealtimeStreamsOptions = { + prisma: PrismaClient; + realtimeClient: RealtimeClient; +}; + +// Class implementing both interfaces +export class DatabaseRealtimeStreams implements StreamIngestor, StreamResponder { + constructor(private options: DatabaseRealtimeStreamsOptions) {} + + async streamResponse( + request: Request, + runId: string, + streamId: string, + environment: AuthenticatedEnvironment, + signal: AbortSignal + ): Promise { + return this.options.realtimeClient.streamChunks( + request.url, + environment, + runId, + streamId, + signal, + request.headers.get("x-trigger-electric-version") ?? undefined + ); + } + + async ingestData( + stream: ReadableStream, + runId: string, + streamId: string + ): Promise { + try { + const textStream = stream.pipeThrough(new TextDecoderStream()); + const reader = textStream.getReader(); + let sequence = 0; + + while (true) { + const { done, value } = await reader.read(); + + if (done) { + break; + } + + logger.debug("[DatabaseRealtimeStreams][ingestData] Reading data", { + streamId, + runId, + value, + }); + + const chunks = value + .split("\n") + .filter((chunk) => chunk) // Remove empty lines + .map((line) => { + return { + sequence: sequence++, + value: line, + }; + }); + + await this.options.prisma.realtimeStreamChunk.createMany({ + data: chunks.map((chunk) => { + return { + runId, + key: streamId, + sequence: chunk.sequence, + value: chunk.value, + }; + }), + }); + } + + return new Response(null, { status: 200 }); + } catch (error) { + logger.error("[DatabaseRealtimeStreams][ingestData] Error in ingestData:", { error }); + + return new Response(null, { status: 500 }); + } + } +} diff --git a/apps/webapp/app/services/realtimeStreams.server.ts b/apps/webapp/app/services/realtime/redisRealtimeStreams.server.ts similarity index 87% rename from apps/webapp/app/services/realtimeStreams.server.ts rename to apps/webapp/app/services/realtime/redisRealtimeStreams.server.ts index 73ee9ca180..808776bd06 100644 --- a/apps/webapp/app/services/realtimeStreams.server.ts +++ b/apps/webapp/app/services/realtime/redisRealtimeStreams.server.ts @@ -1,5 +1,7 @@ import Redis, { RedisKey, RedisOptions, RedisValue } from "ioredis"; -import { logger } from "./logger.server"; +import { logger } from "../logger.server"; +import { StreamIngestor, StreamResponder } from "./types"; +import { AuthenticatedEnvironment } from "../apiAuth.server"; export type RealtimeStreamsOptions = { redis: RedisOptions | undefined; @@ -7,10 +9,17 @@ export type RealtimeStreamsOptions = { const END_SENTINEL = "<>"; -export class RealtimeStreams { +// Class implementing both interfaces +export class RedisRealtimeStreams implements StreamIngestor, StreamResponder { constructor(private options: RealtimeStreamsOptions) {} - async streamResponse(runId: string, streamId: string, signal: AbortSignal): Promise { + async streamResponse( + request: Request, + runId: string, + streamId: string, + environment: AuthenticatedEnvironment, + signal: AbortSignal + ): Promise { const redis = new Redis(this.options.redis ?? {}); const streamKey = `stream:${runId}:${streamId}`; let isCleanedUp = false; @@ -115,11 +124,10 @@ export class RealtimeStreams { } try { - // Use TextDecoderStream to simplify text decoding const textStream = stream.pipeThrough(new TextDecoderStream()); const reader = textStream.getReader(); - const batchSize = 10; // Adjust this value based on performance testing + const batchSize = 10; let batchCommands: Array<[key: RedisKey, ...args: RedisValue[]]> = []; while (true) { @@ -131,17 +139,13 @@ export class RealtimeStreams { logger.debug("[RealtimeStreams][ingestData] Reading data", { streamKey, value }); - // 'value' is a string containing the decoded text const lines = value.split("\n"); for (const line of lines) { if (line.trim()) { - // Avoid unnecessary parsing; assume 'line' is already a JSON string - // Add XADD command with MAXLEN option to limit stream size batchCommands.push([streamKey, "MAXLEN", "~", "2500", "*", "data", line]); if (batchCommands.length >= batchSize) { - // Send batch using a pipeline const pipeline = redis.pipeline(); for (const args of batchCommands) { pipeline.xadd(...args); @@ -153,7 +157,6 @@ export class RealtimeStreams { } } - // Send any remaining commands if (batchCommands.length > 0) { const pipeline = redis.pipeline(); for (const args of batchCommands) { @@ -162,7 +165,6 @@ export class RealtimeStreams { await pipeline.exec(); } - // Send the __end message to indicate the end of the stream await redis.xadd(streamKey, "MAXLEN", "~", "1000", "*", "data", END_SENTINEL); return new Response(null, { status: 200 }); diff --git a/apps/webapp/app/services/realtime/types.ts b/apps/webapp/app/services/realtime/types.ts new file mode 100644 index 0000000000..802e99c38e --- /dev/null +++ b/apps/webapp/app/services/realtime/types.ts @@ -0,0 +1,21 @@ +import { AuthenticatedEnvironment } from "../apiAuth.server"; + +// Interface for stream ingestion +export interface StreamIngestor { + ingestData( + stream: ReadableStream, + runId: string, + streamId: string + ): Promise; +} + +// Interface for stream response +export interface StreamResponder { + streamResponse( + request: Request, + runId: string, + streamId: string, + environment: AuthenticatedEnvironment, + signal: AbortSignal + ): Promise; +} diff --git a/apps/webapp/app/services/realtimeStreamsGlobal.server.ts b/apps/webapp/app/services/realtime/v1StreamsGlobal.server.ts similarity index 60% rename from apps/webapp/app/services/realtimeStreamsGlobal.server.ts rename to apps/webapp/app/services/realtime/v1StreamsGlobal.server.ts index a4f20ac060..d5939e08b7 100644 --- a/apps/webapp/app/services/realtimeStreamsGlobal.server.ts +++ b/apps/webapp/app/services/realtime/v1StreamsGlobal.server.ts @@ -1,9 +1,9 @@ import { env } from "~/env.server"; import { singleton } from "~/utils/singleton"; -import { RealtimeStreams } from "./realtimeStreams.server"; +import { RedisRealtimeStreams } from "./redisRealtimeStreams.server"; -function initializeRealtimeStreams() { - return new RealtimeStreams({ +function initializeRedisRealtimeStreams() { + return new RedisRealtimeStreams({ redis: { port: env.REDIS_PORT, host: env.REDIS_HOST, @@ -16,4 +16,4 @@ function initializeRealtimeStreams() { }); } -export const realtimeStreams = singleton("realtimeStreams", initializeRealtimeStreams); +export const v1RealtimeStreams = singleton("realtimeStreams", initializeRedisRealtimeStreams); diff --git a/apps/webapp/app/services/realtime/v2StreamsGlobal.server.ts b/apps/webapp/app/services/realtime/v2StreamsGlobal.server.ts new file mode 100644 index 0000000000..a086850ee7 --- /dev/null +++ b/apps/webapp/app/services/realtime/v2StreamsGlobal.server.ts @@ -0,0 +1,13 @@ +import { prisma } from "~/db.server"; +import { singleton } from "~/utils/singleton"; +import { realtimeClient } from "../realtimeClientGlobal.server"; +import { DatabaseRealtimeStreams } from "./databaseRealtimeStreams.server"; + +function initializeDatabaseRealtimeStreams() { + return new DatabaseRealtimeStreams({ + prisma, + realtimeClient, + }); +} + +export const v2RealtimeStreams = singleton("dbRealtimeStreams", initializeDatabaseRealtimeStreams); diff --git a/apps/webapp/app/services/realtimeClient.server.ts b/apps/webapp/app/services/realtimeClient.server.ts index a4939d2809..5afd4bc31a 100644 --- a/apps/webapp/app/services/realtimeClient.server.ts +++ b/apps/webapp/app/services/realtimeClient.server.ts @@ -37,6 +37,23 @@ export class RealtimeClient { this.#registerCommands(); } + async streamChunks( + url: URL | string, + environment: RealtimeEnvironment, + runId: string, + streamId: string, + signal?: AbortSignal, + clientVersion?: string + ) { + return this.#streamChunksWhere( + url, + environment, + `"runId"='${runId}' AND "key"='${streamId}'`, + signal, + clientVersion + ); + } + async streamRun( url: URL | string, environment: RealtimeEnvironment, @@ -85,12 +102,12 @@ export class RealtimeClient { whereClause: string, clientVersion?: string ) { - const electricUrl = this.#constructElectricUrl(url, whereClause, clientVersion); + const electricUrl = this.#constructRunsElectricUrl(url, whereClause, clientVersion); - return this.#performElectricRequest(electricUrl, environment, clientVersion); + return this.#performElectricRequest(electricUrl, environment, undefined, clientVersion); } - #constructElectricUrl(url: URL | string, whereClause: string, clientVersion?: string): URL { + #constructRunsElectricUrl(url: URL | string, whereClause: string, clientVersion?: string): URL { const $url = new URL(url.toString()); const electricUrl = new URL(`${this.options.electricOrigin}/v1/shape`); @@ -112,9 +129,44 @@ export class RealtimeClient { return electricUrl; } + async #streamChunksWhere( + url: URL | string, + environment: RealtimeEnvironment, + whereClause: string, + signal?: AbortSignal, + clientVersion?: string + ) { + const electricUrl = this.#constructChunksElectricUrl(url, whereClause, clientVersion); + + return this.#performElectricRequest(electricUrl, environment, signal, clientVersion); + } + + #constructChunksElectricUrl(url: URL | string, whereClause: string, clientVersion?: string): URL { + const $url = new URL(url.toString()); + + const electricUrl = new URL(`${this.options.electricOrigin}/v1/shape`); + + // Copy over all the url search params to the electric url + $url.searchParams.forEach((value, key) => { + electricUrl.searchParams.set(key, value); + }); + + electricUrl.searchParams.set("where", whereClause); + electricUrl.searchParams.set("table", `public."RealtimeStreamChunk"`); + + if (!clientVersion) { + // If the client version is not provided, that means we're using an older client + // This means the client will be sending shape_id instead of handle + electricUrl.searchParams.set("handle", electricUrl.searchParams.get("shape_id") ?? ""); + } + + return electricUrl; + } + async #performElectricRequest( url: URL, environment: RealtimeEnvironment, + signal?: AbortSignal, clientVersion?: string ) { const shapeId = extractShapeId(url); @@ -129,13 +181,13 @@ export class RealtimeClient { if (!shapeId) { // If the shapeId is not present, we're just getting the initial value - return longPollingFetch(url.toString(), {}, rewriteResponseHeaders); + return longPollingFetch(url.toString(), { signal }, rewriteResponseHeaders); } const isLive = isLiveRequestUrl(url); if (!isLive) { - return longPollingFetch(url.toString(), {}, rewriteResponseHeaders); + return longPollingFetch(url.toString(), { signal }, rewriteResponseHeaders); } const requestId = randomUUID(); @@ -177,7 +229,7 @@ export class RealtimeClient { try { // ... (rest of your existing code for the long polling request) - const response = await longPollingFetch(url.toString(), {}, rewriteResponseHeaders); + const response = await longPollingFetch(url.toString(), { signal }, rewriteResponseHeaders); // Decrement the counter after the long polling request is complete await this.#decrementConcurrency(environment.id, requestId); diff --git a/apps/webapp/app/v3/environmentVariables/environmentVariablesRepository.server.ts b/apps/webapp/app/v3/environmentVariables/environmentVariablesRepository.server.ts index a26d1c6613..b3d5aa852e 100644 --- a/apps/webapp/app/v3/environmentVariables/environmentVariablesRepository.server.ts +++ b/apps/webapp/app/v3/environmentVariables/environmentVariablesRepository.server.ts @@ -684,6 +684,10 @@ async function resolveBuiltInDevVariables(runtimeEnvironment: RuntimeEnvironment key: "TRIGGER_STREAM_URL", value: env.STREAM_ORIGIN ?? env.API_ORIGIN ?? env.APP_ORIGIN, }, + { + key: "TRIGGER_REALTIME_STREAM_VERSION", + value: env.REALTIME_STREAM_VERSION, + }, ]; if (env.DEV_OTEL_BATCH_PROCESSING_ENABLED === "1") { @@ -754,6 +758,10 @@ async function resolveBuiltInProdVariables(runtimeEnvironment: RuntimeEnvironmen key: "TRIGGER_ORG_ID", value: runtimeEnvironment.organizationId, }, + { + key: "TRIGGER_REALTIME_STREAM_VERSION", + value: env.REALTIME_STREAM_VERSION, + }, ]; if (env.PROD_OTEL_BATCH_PROCESSING_ENABLED === "1") { diff --git a/apps/webapp/test/realtimeClient.test.ts b/apps/webapp/test/realtimeClient.test.ts index de63a81dbc..7323b9a2d2 100644 --- a/apps/webapp/test/realtimeClient.test.ts +++ b/apps/webapp/test/realtimeClient.test.ts @@ -1,9 +1,9 @@ -import { containerWithElectricTest } from "@internal/testcontainers"; +import { containerWithElectricAndRedisTest } from "@internal/testcontainers"; import { expect, describe } from "vitest"; import { RealtimeClient } from "../app/services/realtimeClient.server.js"; describe("RealtimeClient", () => { - containerWithElectricTest( + containerWithElectricAndRedisTest( "Should only track concurrency for live requests", { timeout: 30_000 }, async ({ redis, electricOrigin, prisma }) => { @@ -139,7 +139,7 @@ describe("RealtimeClient", () => { } ); - containerWithElectricTest( + containerWithElectricAndRedisTest( "Should support subscribing to a run tag", { timeout: 30_000 }, async ({ redis, electricOrigin, prisma }) => { @@ -218,7 +218,7 @@ describe("RealtimeClient", () => { } ); - containerWithElectricTest( + containerWithElectricAndRedisTest( "Should adapt for older client versions", { timeout: 30_000 }, async ({ redis, electricOrigin, prisma }) => { diff --git a/apps/webapp/test/realtimeStreams.test.ts b/apps/webapp/test/realtimeStreams.test.ts deleted file mode 100644 index dd60297f22..0000000000 --- a/apps/webapp/test/realtimeStreams.test.ts +++ /dev/null @@ -1,106 +0,0 @@ -import { redisTest } from "@internal/testcontainers"; -import { describe, expect, vi } from "vitest"; -import { RealtimeStreams } from "../app/services/realtimeStreams.server.js"; -import { convertArrayToReadableStream, convertResponseSSEStreamToArray } from "./utils/streams.js"; - -vi.setConfig({ testTimeout: 10_000 }); // 5 seconds - -// Mock the logger -vi.mock("./logger.server", () => ({ - logger: { - debug: vi.fn(), - error: vi.fn(), - }, -})); - -describe("RealtimeStreams", () => { - redisTest("should stream data from producer to consumer", async ({ redis }) => { - const streams = new RealtimeStreams({ redis: redis.options }); - const runId = "test-run"; - const streamId = "test-stream"; - - // Create a stream of test data - const stream = convertArrayToReadableStream(["chunk1", "chunk2", "chunk3"]).pipeThrough( - new TextEncoderStream() - ); - - // Start consuming the stream - const abortController = new AbortController(); - const responsePromise = streams.streamResponse(runId, streamId, abortController.signal); - - // Start ingesting data - await streams.ingestData(stream, runId, streamId); - - // Get the response and read the stream - const response = await responsePromise; - const received = await convertResponseSSEStreamToArray(response); - - expect(received).toEqual(["chunk1", "chunk2", "chunk3"]); - }); - - redisTest("should handle multiple concurrent streams", async ({ redis }) => { - const streams = new RealtimeStreams({ redis: redis.options }); - const runId = "test-run"; - - // Set up two different streams - const stream1 = convertArrayToReadableStream(["1a", "1b", "1c"]).pipeThrough( - new TextEncoderStream() - ); - const stream2 = convertArrayToReadableStream(["2a", "2b", "2c"]).pipeThrough( - new TextEncoderStream() - ); - - // Start consuming both streams - const abortController = new AbortController(); - const response1Promise = streams.streamResponse(runId, "stream1", abortController.signal); - const response2Promise = streams.streamResponse(runId, "stream2", abortController.signal); - - // Ingest data to both streams - await Promise.all([ - streams.ingestData(stream1, runId, "stream1"), - streams.ingestData(stream2, runId, "stream2"), - ]); - - // Get and verify both responses - const [response1, response2] = await Promise.all([response1Promise, response2Promise]); - const [received1, received2] = await Promise.all([ - convertResponseSSEStreamToArray(response1), - convertResponseSSEStreamToArray(response2), - ]); - - expect(received1).toEqual(["1a", "1b", "1c"]); - expect(received2).toEqual(["2a", "2b", "2c"]); - }); - - redisTest("should handle early consumer abort", async ({ redis }) => { - const streams = new RealtimeStreams({ redis: redis.options }); - const runId = "test-run"; - const streamId = "test-stream"; - - const stream = convertArrayToReadableStream(["chunk1", "chunk2", "chunk3"]).pipeThrough( - new TextEncoderStream() - ); - - // Start consuming but abort early - const abortController = new AbortController(); - const responsePromise = streams.streamResponse(runId, streamId, abortController.signal); - - // Get the response before aborting to ensure stream is properly set up - const response = await responsePromise; - - // Start reading the stream - const readPromise = convertResponseSSEStreamToArray(response); - - // Abort after a small delay to ensure everything is set up - await new Promise((resolve) => setTimeout(resolve, 100)); - abortController.abort(); - - // Start ingesting data after abort - await streams.ingestData(stream, runId, streamId); - - // Verify the stream was terminated - const received = await readPromise; - - expect(received).toEqual(["chunk1"]); - }); -}); diff --git a/docker/Dockerfile.postgres b/docker/Dockerfile.postgres new file mode 100644 index 0000000000..520c586369 --- /dev/null +++ b/docker/Dockerfile.postgres @@ -0,0 +1,5 @@ +FROM postgres:14 + +RUN apt-get update \ + && apt-get install -y postgresql-14-partman \ + && rm -rf /var/lib/apt/lists/* diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index df02a61fc1..5b9c55333f 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -13,7 +13,9 @@ networks: services: database: container_name: database - image: postgres:14 + build: + context: . + dockerfile: Dockerfile.postgres restart: always volumes: - ${DB_VOLUME:-database-data}:/var/lib/postgresql/data/ @@ -30,6 +32,8 @@ services: - listen_addresses=* - -c - wal_level=logical + - -c + - shared_preload_libraries=pg_partman_bgw pgadmin: container_name: pgadmin diff --git a/docs/frontend/react-hooks/realtime.mdx b/docs/frontend/react-hooks/realtime.mdx index b8f732e940..bf3c290c68 100644 --- a/docs/frontend/react-hooks/realtime.mdx +++ b/docs/frontend/react-hooks/realtime.mdx @@ -62,6 +62,31 @@ export function MyComponent({ } ``` +You can supply an `onComplete` callback to the `useRealtimeRun` hook to be called when the run is completed or errored. This is useful if you want to perform some action when the run is completed, like navigating to a different page or showing a notification. + +```tsx +import { useRealtimeRun } from "@trigger.dev/react-hooks"; + +export function MyComponent({ + runId, + publicAccessToken, +}: { + runId: string; + publicAccessToken: string; +}) { + const { run, error } = useRealtimeRun(runId, { + accessToken: publicAccessToken, + onComplete: (run, error) => { + console.log("Run completed", run); + }, + }); + + if (error) return
Error: {error.message}
; + + return
Run: {run.id}
; +} +``` + See our [Realtime documentation](/realtime) for more information about the type of the run object and more. ### useRealtimeRunsWithTag diff --git a/internal-packages/database/prisma/migrations/20241206135145_create_realtime_chunks_table/migration.sql b/internal-packages/database/prisma/migrations/20241206135145_create_realtime_chunks_table/migration.sql new file mode 100644 index 0000000000..975e77fd31 --- /dev/null +++ b/internal-packages/database/prisma/migrations/20241206135145_create_realtime_chunks_table/migration.sql @@ -0,0 +1,13 @@ +-- CreateTable +CREATE TABLE "RealtimeStreamChunk" ( + "id" TEXT NOT NULL, + "key" TEXT NOT NULL, + "value" TEXT NOT NULL, + "sequence" INTEGER NOT NULL, + "runId" TEXT NOT NULL, + "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + CONSTRAINT "RealtimeStreamChunk_pkey" PRIMARY KEY ("id") +); + +-- Add index on (runID, createdAt) for efficient queries +CREATE INDEX "RealtimeStreamChunk_runId" ON "RealtimeStreamChunk" ("runId"); \ No newline at end of file diff --git a/internal-packages/database/prisma/migrations/20241208074324_add_created_at_index_to_realtime_stream_chunks/migration.sql b/internal-packages/database/prisma/migrations/20241208074324_add_created_at_index_to_realtime_stream_chunks/migration.sql new file mode 100644 index 0000000000..597fcceec0 --- /dev/null +++ b/internal-packages/database/prisma/migrations/20241208074324_add_created_at_index_to_realtime_stream_chunks/migration.sql @@ -0,0 +1,5 @@ +-- CreateIndex +CREATE INDEX "RealtimeStreamChunk_createdAt_idx" ON "RealtimeStreamChunk"("createdAt"); + +-- RenameIndex +ALTER INDEX "RealtimeStreamChunk_runId" RENAME TO "RealtimeStreamChunk_runId_idx"; diff --git a/internal-packages/database/prisma/schema.prisma b/internal-packages/database/prisma/schema.prisma index c4f2d87e1e..ceaa349bb0 100644 --- a/internal-packages/database/prisma/schema.prisma +++ b/internal-packages/database/prisma/schema.prisma @@ -2667,3 +2667,19 @@ enum BulkActionItemStatus { COMPLETED FAILED } + +model RealtimeStreamChunk { + id String @id @default(cuid()) + + key String + value String + + sequence Int + + runId String + + createdAt DateTime @default(now()) + + @@index([runId]) + @@index([createdAt]) +} diff --git a/internal-packages/testcontainers/src/index.ts b/internal-packages/testcontainers/src/index.ts index 77e5f6294f..c363aea4e8 100644 --- a/internal-packages/testcontainers/src/index.ts +++ b/internal-packages/testcontainers/src/index.ts @@ -1,10 +1,10 @@ import { StartedPostgreSqlContainer } from "@testcontainers/postgresql"; import { StartedRedisContainer } from "@testcontainers/redis"; +import { PrismaClient } from "@trigger.dev/database"; import { Redis } from "ioredis"; +import { Network, type StartedNetwork } from "testcontainers"; import { test } from "vitest"; -import { PrismaClient } from "@trigger.dev/database"; -import { createPostgresContainer, createRedisContainer, createElectricContainer } from "./utils"; -import { Network, type StartedNetwork, type StartedTestContainer } from "testcontainers"; +import { createElectricContainer, createPostgresContainer, createRedisContainer } from "./utils"; type NetworkContext = { network: StartedNetwork }; @@ -20,7 +20,8 @@ type ElectricContext = { }; type ContainerContext = NetworkContext & PostgresContext & RedisContext; -type ContainerWithElectricContext = ContainerContext & ElectricContext; +type ContainerWithElectricAndRedisContext = ContainerContext & ElectricContext; +type ContainerWithElectricContext = NetworkContext & PostgresContext & ElectricContext; type Use = (value: T) => Promise; @@ -97,6 +98,13 @@ export const containerTest = test.extend({ }); export const containerWithElectricTest = test.extend({ + network, + postgresContainer, + prisma, + electricOrigin, +}); + +export const containerWithElectricAndRedisTest = test.extend({ network, postgresContainer, prisma, diff --git a/packages/cli-v3/src/entryPoints/deploy-run-worker.ts b/packages/cli-v3/src/entryPoints/deploy-run-worker.ts index 7b4481c3dc..59ac12bb62 100644 --- a/packages/cli-v3/src/entryPoints/deploy-run-worker.ts +++ b/packages/cli-v3/src/entryPoints/deploy-run-worker.ts @@ -105,7 +105,8 @@ const durableClock = new DurableClock(); clock.setGlobalClock(durableClock); const runMetadataManager = new StandardMetadataManager( apiClientManager.clientOrThrow(), - getEnvVar("TRIGGER_STREAM_URL", getEnvVar("TRIGGER_API_URL")) ?? "https://api.trigger.dev" + getEnvVar("TRIGGER_STREAM_URL", getEnvVar("TRIGGER_API_URL")) ?? "https://api.trigger.dev", + (getEnvVar("TRIGGER_REALTIME_STREAM_VERSION") ?? "v1") as "v1" | "v2" ); runMetadata.setGlobalManager(runMetadataManager); const waitUntilManager = new StandardWaitUntilManager(); diff --git a/packages/cli-v3/src/entryPoints/dev-run-worker.ts b/packages/cli-v3/src/entryPoints/dev-run-worker.ts index c043cf6712..5d3052de1e 100644 --- a/packages/cli-v3/src/entryPoints/dev-run-worker.ts +++ b/packages/cli-v3/src/entryPoints/dev-run-worker.ts @@ -87,7 +87,8 @@ runtime.setGlobalRuntimeManager(devRuntimeManager); timeout.setGlobalManager(new UsageTimeoutManager(devUsageManager)); const runMetadataManager = new StandardMetadataManager( apiClientManager.clientOrThrow(), - getEnvVar("TRIGGER_STREAM_URL", getEnvVar("TRIGGER_API_URL")) ?? "https://api.trigger.dev" + getEnvVar("TRIGGER_STREAM_URL", getEnvVar("TRIGGER_API_URL")) ?? "https://api.trigger.dev", + (getEnvVar("TRIGGER_REALTIME_STREAM_VERSION") ?? "v1") as "v1" | "v2" ); runMetadata.setGlobalManager(runMetadataManager); const waitUntilManager = new StandardWaitUntilManager(); diff --git a/packages/core/package.json b/packages/core/package.json index ff7cd59330..70421f7acf 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -182,7 +182,7 @@ "check-exports": "attw --pack ." }, "dependencies": { - "@electric-sql/client": "0.7.1", + "@electric-sql/client": "0.9.0", "@google-cloud/precise-date": "^4.0.0", "@jsonhero/path": "^1.0.21", "@opentelemetry/api": "1.9.0", diff --git a/packages/core/src/v3/apiClient/index.ts b/packages/core/src/v3/apiClient/index.ts index dcded16468..2b778a14d8 100644 --- a/packages/core/src/v3/apiClient/index.ts +++ b/packages/core/src/v3/apiClient/index.ts @@ -139,6 +139,10 @@ export class ApiClient { return fetchClient; } + getHeaders() { + return this.#getHeaders(false); + } + async getRunResult( runId: string, requestOptions?: ZodFetchOptions diff --git a/packages/core/src/v3/apiClient/runStream.ts b/packages/core/src/v3/apiClient/runStream.ts index d5785f33d1..ac88270293 100644 --- a/packages/core/src/v3/apiClient/runStream.ts +++ b/packages/core/src/v3/apiClient/runStream.ts @@ -1,6 +1,11 @@ +import { EventSourceParserStream } from "eventsource-parser/stream"; import { DeserializedJson } from "../../schemas/json.js"; import { createJsonErrorObject } from "../errors.js"; -import { RunStatus, SubscribeRunRawShape } from "../schemas/api.js"; +import { + RunStatus, + SubscribeRealtimeStreamChunkRawShape, + SubscribeRunRawShape, +} from "../schemas/api.js"; import { SerializedError } from "../schemas/common.js"; import { AnyRunTypes, AnyTask, InferRunTypes } from "../types/tasks.js"; import { getEnvVar } from "../utils/getEnv.js"; @@ -11,8 +16,7 @@ import { } from "../utils/ioSerialization.js"; import { ApiError } from "./errors.js"; import { ApiClient } from "./index.js"; -import { AsyncIterableStream, createAsyncIterableStream, zodShapeStream } from "./stream.js"; -import { EventSourceParserStream } from "eventsource-parser/stream"; +import { AsyncIterableStream, createAsyncIterableReadable, zodShapeStream } from "./stream.js"; export type RunShape = TRunTypes extends AnyRunTypes ? { @@ -78,19 +82,42 @@ export function runShapeStream( url: string, options?: RunShapeStreamOptions ): RunSubscription { - const $options: RunSubscriptionOptions = { - provider: { - async onShape(callback) { - return zodShapeStream(SubscribeRunRawShape, url, callback, options); - }, - }, - streamFactory: new SSEStreamSubscriptionFactory( - getEnvVar("TRIGGER_STREAM_URL", getEnvVar("TRIGGER_API_URL")) ?? "https://api.trigger.dev", - { - headers: options?.headers, - signal: options?.signal, + const abortController = new AbortController(); + + const version1 = new SSEStreamSubscriptionFactory( + getEnvVar("TRIGGER_STREAM_URL", getEnvVar("TRIGGER_API_URL")) ?? "https://api.trigger.dev", + { + headers: options?.headers, + signal: abortController.signal, + } + ); + + const version2 = new ElectricStreamSubscriptionFactory( + getEnvVar("TRIGGER_STREAM_URL", getEnvVar("TRIGGER_API_URL")) ?? "https://api.trigger.dev", + { + headers: options?.headers, + signal: abortController.signal, + } + ); + + // If the user supplied AbortSignal is aborted, we should abort the internal controller + options?.signal?.addEventListener( + "abort", + () => { + if (!abortController.signal.aborted) { + abortController.abort(); } - ), + }, + { once: true } + ); + + const $options: RunSubscriptionOptions = { + runShapeStream: zodShapeStream(SubscribeRunRawShape, url, { + ...options, + signal: abortController.signal, + }), + streamFactory: new VersionedStreamSubscriptionFactory(version1, version2), + abortController, ...options, }; @@ -103,7 +130,12 @@ export interface StreamSubscription { } export interface StreamSubscriptionFactory { - createSubscription(runId: string, streamKey: string, baseUrl?: string): StreamSubscription; + createSubscription( + metadata: Record, + runId: string, + streamKey: string, + baseUrl?: string + ): StreamSubscription; } // Real implementation for production @@ -154,7 +186,12 @@ export class SSEStreamSubscriptionFactory implements StreamSubscriptionFactory { private options: { headers?: Record; signal?: AbortSignal } ) {} - createSubscription(runId: string, streamKey: string, baseUrl?: string): StreamSubscription { + createSubscription( + metadata: Record, + runId: string, + streamKey: string, + baseUrl?: string + ): StreamSubscription { if (!runId || !streamKey) { throw new Error("runId and streamKey are required"); } @@ -164,17 +201,89 @@ export class SSEStreamSubscriptionFactory implements StreamSubscriptionFactory { } } +// Real implementation for production +export class ElectricStreamSubscription implements StreamSubscription { + constructor( + private url: string, + private options: { headers?: Record; signal?: AbortSignal } + ) {} + + async subscribe(): Promise> { + return zodShapeStream(SubscribeRealtimeStreamChunkRawShape, this.url, this.options).pipeThrough( + new TransformStream({ + transform(chunk, controller) { + controller.enqueue(safeParseJSON(chunk.value)); + }, + }) + ); + } +} + +export class ElectricStreamSubscriptionFactory implements StreamSubscriptionFactory { + constructor( + private baseUrl: string, + private options: { headers?: Record; signal?: AbortSignal } + ) {} + + createSubscription( + metadata: Record, + runId: string, + streamKey: string, + baseUrl?: string + ): StreamSubscription { + if (!runId || !streamKey) { + throw new Error("runId and streamKey are required"); + } + + return new ElectricStreamSubscription( + `${baseUrl ?? this.baseUrl}/realtime/v2/streams/${runId}/${streamKey}`, + this.options + ); + } +} + +export class VersionedStreamSubscriptionFactory implements StreamSubscriptionFactory { + constructor( + private version1: StreamSubscriptionFactory, + private version2: StreamSubscriptionFactory + ) {} + + createSubscription( + metadata: Record, + runId: string, + streamKey: string, + baseUrl?: string + ): StreamSubscription { + if (!runId || !streamKey) { + throw new Error("runId and streamKey are required"); + } + + const version = + typeof metadata.$$streamsVersion === "string" ? metadata.$$streamsVersion : "v1"; + + if (version === "v1") { + return this.version1.createSubscription(metadata, runId, streamKey, baseUrl); + } + + if (version === "v2") { + return this.version2.createSubscription(metadata, runId, streamKey, baseUrl); + } + + throw new Error(`Unknown stream version: ${version}`); + } +} + export interface RunShapeProvider { onShape(callback: (shape: SubscribeRunRawShape) => Promise): Promise<() => void>; } export type RunSubscriptionOptions = RunShapeStreamOptions & { - provider: RunShapeProvider; + runShapeStream: ReadableStream; streamFactory: StreamSubscriptionFactory; + abortController: AbortController; }; export class RunSubscription { - private abortController: AbortController; private unsubscribeShape?: () => void; private stream: AsyncIterableStream>; private packetCache = new Map(); @@ -182,44 +291,37 @@ export class RunSubscription { private _isRunComplete = false; constructor(private options: RunSubscriptionOptions) { - this.abortController = new AbortController(); this._closeOnComplete = typeof options.closeOnComplete === "undefined" ? true : options.closeOnComplete; - const source = new ReadableStream({ - start: async (controller) => { - this.unsubscribeShape = await this.options.provider.onShape(async (shape) => { - controller.enqueue(shape); + this.stream = createAsyncIterableReadable( + this.options.runShapeStream, + { + transform: async (chunk, controller) => { + const run = await this.transformRunShape(chunk); + + controller.enqueue(run); - this._isRunComplete = !!shape.completedAt; + this._isRunComplete = !!run.finishedAt; if ( this._closeOnComplete && this._isRunComplete && - !this.abortController.signal.aborted + !this.options.abortController.signal.aborted ) { - controller.close(); - this.abortController.abort(); - } - }); - }, - cancel: () => { - this.unsubscribe(); - }, - }); - - this.stream = createAsyncIterableStream(source, { - transform: async (chunk, controller) => { - const run = await this.transformRunShape(chunk); + console.log("Closing stream because run is complete"); - controller.enqueue(run); + this.options.abortController.abort(); + } + }, }, - }); + this.options.abortController.signal + ); } unsubscribe(): void { - if (!this.abortController.signal.aborted) { - this.abortController.abort(); + if (!this.options.abortController.signal.aborted) { + this.options.abortController.abort(); } this.unsubscribeShape?.(); } @@ -238,59 +340,68 @@ export class RunSubscription { // Keep track of which streams we've already subscribed to const activeStreams = new Set(); - return createAsyncIterableStream(this.stream, { - transform: async (run, controller) => { - controller.enqueue({ - type: "run", - run, - }); - - // Check for stream metadata - if (run.metadata && "$$streams" in run.metadata && Array.isArray(run.metadata.$$streams)) { - for (const streamKey of run.metadata.$$streams) { - if (typeof streamKey !== "string") { - continue; - } + return createAsyncIterableReadable( + this.stream, + { + transform: async (run, controller) => { + controller.enqueue({ + type: "run", + run, + }); - if (!activeStreams.has(streamKey)) { - activeStreams.add(streamKey); - - const subscription = this.options.streamFactory.createSubscription( - run.id, - streamKey, - this.options.client?.baseUrl - ); - - const stream = await subscription.subscribe(); - - // Create the pipeline and start it - stream - .pipeThrough( - new TransformStream({ - transform(chunk, controller) { - controller.enqueue({ - type: streamKey, - chunk: chunk as TStreams[typeof streamKey], - run, - } as StreamPartResult, TStreams>); - }, - }) - ) - .pipeTo( - new WritableStream({ - write(chunk) { - controller.enqueue(chunk); - }, - }) - ) - .catch((error) => { - console.error(`Error in stream ${streamKey}:`, error); - }); + // Check for stream metadata + if ( + run.metadata && + "$$streams" in run.metadata && + Array.isArray(run.metadata.$$streams) + ) { + for (const streamKey of run.metadata.$$streams) { + if (typeof streamKey !== "string") { + continue; + } + + if (!activeStreams.has(streamKey)) { + activeStreams.add(streamKey); + + const subscription = this.options.streamFactory.createSubscription( + run.metadata, + run.id, + streamKey, + this.options.client?.baseUrl + ); + + const stream = await subscription.subscribe(); + + // Create the pipeline and start it + stream + .pipeThrough( + new TransformStream({ + transform(chunk, controller) { + controller.enqueue({ + type: streamKey, + chunk: chunk as TStreams[typeof streamKey], + run, + } as StreamPartResult, TStreams>); + }, + }) + ) + .pipeTo( + new WritableStream({ + write(chunk) { + controller.enqueue(chunk); + }, + }) + ) + .catch((error) => { + console.error(`Error in stream ${streamKey}:`, error); + }); + } } } - } + }, }, - }); + this.options.abortController.signal + ); } private async transformRunShape(row: SubscribeRunRawShape): Promise> { diff --git a/packages/core/src/v3/apiClient/stream.ts b/packages/core/src/v3/apiClient/stream.ts index 97f9ee816e..fd975a5416 100644 --- a/packages/core/src/v3/apiClient/stream.ts +++ b/packages/core/src/v3/apiClient/stream.ts @@ -1,5 +1,15 @@ import { z } from "zod"; -import { ApiError } from "./errors.js"; +import { + FetchError, + isChangeMessage, + isControlMessage, + Offset, + ShapeStream, + type Message, + type Row, + type ShapeStreamInterface, + // @ts-ignore it's safe to import types from the client +} from "@electric-sql/client"; export type ZodShapeStreamOptions = { headers?: Record; @@ -7,14 +17,11 @@ export type ZodShapeStreamOptions = { signal?: AbortSignal; }; -export async function zodShapeStream( +export function zodShapeStream( schema: TShapeSchema, url: string, - callback: (shape: z.output) => void | Promise, options?: ZodShapeStreamOptions ) { - const { ShapeStream, Shape, FetchError } = await import("@electric-sql/client"); - const stream = new ShapeStream>({ url, headers: { @@ -25,27 +32,21 @@ export async function zodShapeStream( signal: options?.signal, }); - try { - const shape = new Shape(stream); - - const initialRows = await shape.rows; + const readableShape = new ReadableShapeStream(stream); - for (const shapeRow of initialRows) { - await callback(schema.parse(shapeRow)); - } + return readableShape.stream.pipeThrough( + new TransformStream({ + async transform(chunk, controller) { + const result = schema.safeParse(chunk); - return shape.subscribe(async (newShape) => { - for (const shapeRow of newShape.rows) { - await callback(schema.parse(shapeRow)); - } - }); - } catch (error) { - if (error instanceof FetchError) { - throw ApiError.generate(error.status, error.json, error.message, error.headers); - } else { - throw error; - } - } + if (result.success) { + controller.enqueue(result.data); + } else { + controller.error(new Error(`Unable to parse shape: ${result.error.message}`)); + } + }, + }) + ); } export type AsyncIterableStream = AsyncIterable & ReadableStream; @@ -68,3 +69,137 @@ export function createAsyncIterableStream( return transformedStream; } + +export function createAsyncIterableReadable( + source: ReadableStream, + transformer: Transformer, + signal: AbortSignal +): AsyncIterableStream { + return new ReadableStream({ + async start(controller) { + const transformedStream = source.pipeThrough(new TransformStream(transformer)); + const reader = transformedStream.getReader(); + + signal.addEventListener("abort", () => { + queueMicrotask(() => { + reader.cancel(); + controller.close(); + }); + }); + + while (true) { + const { done, value } = await reader.read(); + if (done) { + controller.close(); + break; + } + + controller.enqueue(value); + } + }, + }) as AsyncIterableStream; +} + +class ReadableShapeStream = Row> { + readonly #stream: ShapeStreamInterface; + readonly #currentState: Map = new Map(); + readonly #changeStream: AsyncIterableStream; + #error: FetchError | false = false; + + constructor(stream: ShapeStreamInterface) { + this.#stream = stream; + + // Create the source stream that will receive messages + const source = new ReadableStream[]>({ + start: (controller) => { + this.#stream.subscribe( + (messages) => controller.enqueue(messages), + this.#handleError.bind(this) + ); + }, + }); + + // Create the transformed stream that processes messages and emits complete rows + this.#changeStream = createAsyncIterableStream(source, { + transform: (messages, controller) => { + messages.forEach((message) => { + if (isChangeMessage(message)) { + switch (message.headers.operation) { + case "insert": { + this.#currentState.set(message.key, message.value); + controller.enqueue(message.value); + break; + } + case "update": { + const existingRow = this.#currentState.get(message.key); + if (existingRow) { + const updatedRow = { + ...existingRow, + ...message.value, + }; + this.#currentState.set(message.key, updatedRow); + controller.enqueue(updatedRow); + } else { + this.#currentState.set(message.key, message.value); + controller.enqueue(message.value); + } + break; + } + } + } + + if (isControlMessage(message)) { + switch (message.headers.control) { + case "must-refetch": + this.#currentState.clear(); + this.#error = false; + break; + } + } + }); + }, + }); + } + + get stream(): AsyncIterableStream { + return this.#changeStream; + } + + get isUpToDate(): boolean { + return this.#stream.isUpToDate; + } + + get lastOffset(): Offset { + return this.#stream.lastOffset; + } + + get handle(): string | undefined { + return this.#stream.shapeHandle; + } + + get error() { + return this.#error; + } + + lastSyncedAt(): number | undefined { + return this.#stream.lastSyncedAt(); + } + + lastSynced() { + return this.#stream.lastSynced(); + } + + isLoading() { + return this.#stream.isLoading(); + } + + isConnected(): boolean { + return this.#stream.isConnected(); + } + + #handleError(e: Error): void { + if (e instanceof FetchError) { + this.#error = e; + } + } +} diff --git a/packages/core/src/v3/runMetadata/manager.ts b/packages/core/src/v3/runMetadata/manager.ts index 4c909051a5..ea04d3692b 100644 --- a/packages/core/src/v3/runMetadata/manager.ts +++ b/packages/core/src/v3/runMetadata/manager.ts @@ -20,7 +20,8 @@ export class StandardMetadataManager implements RunMetadataManager { constructor( private apiClient: ApiClient, - private streamsBaseUrl: string + private streamsBaseUrl: string, + private streamsVersion: "v1" | "v2" = "v1" ) {} public enterWithMetadata(metadata: Record): void { @@ -231,6 +232,7 @@ export class StandardMetadataManager implements RunMetadataManager { try { // Add the key to the special stream metadata object this.appendKey(`$$streams`, key); + this.setKey("$$streamsVersion", this.streamsVersion); await this.flush(); @@ -239,7 +241,9 @@ export class StandardMetadataManager implements RunMetadataManager { runId: this.runId, iterator: $value[Symbol.asyncIterator](), baseUrl: this.streamsBaseUrl, + headers: this.apiClient.getHeaders(), signal, + version: this.streamsVersion, }); this.activeStreams.set(key, streamInstance); diff --git a/packages/core/src/v3/runMetadata/metadataStream.ts b/packages/core/src/v3/runMetadata/metadataStream.ts index 6b4d28c714..ff32ae3dd2 100644 --- a/packages/core/src/v3/runMetadata/metadataStream.ts +++ b/packages/core/src/v3/runMetadata/metadataStream.ts @@ -3,7 +3,9 @@ export type MetadataOptions = { runId: string; key: string; iterator: AsyncIterator; + headers?: Record; signal?: AbortSignal; + version?: "v1" | "v2"; }; export class MetadataStream { @@ -62,10 +64,12 @@ export class MetadataStream { }); return fetch( - `${this.options.baseUrl}/realtime/v1/streams/${this.options.runId}/${this.options.key}`, + `${this.options.baseUrl}/realtime/${this.options.version ?? "v1"}/streams/${ + this.options.runId + }/${this.options.key}`, { method: "POST", - headers: {}, + headers: this.options.headers ?? {}, body: serverStream, // @ts-expect-error duplex: "half", diff --git a/packages/core/src/v3/schemas/api.ts b/packages/core/src/v3/schemas/api.ts index ce8bc9a89f..55189d8140 100644 --- a/packages/core/src/v3/schemas/api.ts +++ b/packages/core/src/v3/schemas/api.ts @@ -727,3 +727,16 @@ export const RetrieveBatchResponse = z.object({ }); export type RetrieveBatchResponse = z.infer; + +export const SubscribeRealtimeStreamChunkRawShape = z.object({ + id: z.string(), + runId: z.string(), + sequence: z.number(), + key: z.string(), + value: z.string(), + createdAt: z.coerce.date(), +}); + +export type SubscribeRealtimeStreamChunkRawShape = z.infer< + typeof SubscribeRealtimeStreamChunkRawShape +>; diff --git a/packages/core/test/runStream.test.ts b/packages/core/test/runStream.test.ts index f97f6645ef..3775246311 100644 --- a/packages/core/test/runStream.test.ts +++ b/packages/core/test/runStream.test.ts @@ -1,10 +1,8 @@ -import { describe, it, expect } from "vitest"; +import { describe, expect, it } from "vitest"; import { - AnyRunShape, RunSubscription, StreamSubscription, StreamSubscriptionFactory, - type RunShapeProvider, } from "../src/v3/apiClient/runStream.js"; import type { SubscribeRunRawShape } from "../src/v3/schemas/api.js"; @@ -33,64 +31,54 @@ class TestStreamSubscriptionFactory implements StreamSubscriptionFactory { this.streams.set(`${runId}:${streamKey}`, chunks); } - createSubscription(runId: string, streamKey: string): StreamSubscription { + createSubscription( + metadata: Record, + runId: string, + streamKey: string + ): StreamSubscription { const chunks = this.streams.get(`${runId}:${streamKey}`) ?? []; return new TestStreamSubscription(chunks); } } -// Create a real test provider that uses an array of shapes -class TestShapeProvider implements RunShapeProvider { - private shapes: SubscribeRunRawShape[]; - private unsubscribed = false; - - constructor(shapes: SubscribeRunRawShape[]) { - this.shapes = shapes; - } - - async onShape(callback: (shape: SubscribeRunRawShape) => Promise): Promise<() => void> { - // Process all shapes immediately - for (const shape of this.shapes) { - if (this.unsubscribed) break; - await callback(shape); - } - - return () => { - this.unsubscribed = true; - }; - } +// Remove the RunShapeProvider implementations and replace with stream creators +function createTestShapeStream( + shapes: SubscribeRunRawShape[] +): ReadableStream { + return new ReadableStream({ + start: async (controller) => { + // Emit all shapes immediately + for (const shape of shapes) { + controller.enqueue(shape); + } + controller.close(); + }, + }); } -// Add this new provider that can emit shapes over time -class DelayedTestShapeProvider implements RunShapeProvider { - private shapes: SubscribeRunRawShape[]; - private unsubscribed = false; - private currentShapeIndex = 0; - - constructor(shapes: SubscribeRunRawShape[]) { - this.shapes = shapes; - } - - async onShape(callback: (shape: SubscribeRunRawShape) => Promise): Promise<() => void> { - // Only emit the first shape immediately - if (this.shapes.length > 0) { - await callback(this.shapes[this.currentShapeIndex++]!); - } - - // Set up an interval to emit remaining shapes - const interval = setInterval(async () => { - if (this.unsubscribed || this.currentShapeIndex >= this.shapes.length) { - clearInterval(interval); - return; +function createDelayedTestShapeStream( + shapes: SubscribeRunRawShape[] +): ReadableStream { + return new ReadableStream({ + start: async (controller) => { + // Emit first shape immediately + if (shapes.length > 0) { + controller.enqueue(shapes[0]); } - await callback(this.shapes[this.currentShapeIndex++]!); - }, 100); - return () => { - this.unsubscribed = true; - clearInterval(interval); - }; - } + let currentShapeIndex = 1; + + // Emit remaining shapes with delay + const interval = setInterval(() => { + if (currentShapeIndex >= shapes.length) { + clearInterval(interval); + controller.close(); + return; + } + controller.enqueue(shapes[currentShapeIndex++]!); + }, 100); + }, + }); } describe("RunSubscription", () => { @@ -114,9 +102,10 @@ describe("RunSubscription", () => { ]; const subscription = new RunSubscription({ - provider: new TestShapeProvider(shapes), + runShapeStream: createTestShapeStream(shapes), streamFactory: new TestStreamSubscriptionFactory(), closeOnComplete: true, + abortController: new AbortController(), }); const results = await convertAsyncIterableToArray(subscription); @@ -153,9 +142,10 @@ describe("RunSubscription", () => { ]; const subscription = new RunSubscription({ - provider: new TestShapeProvider(shapes), + runShapeStream: createTestShapeStream(shapes), streamFactory: new TestStreamSubscriptionFactory(), closeOnComplete: true, + abortController: new AbortController(), }); const results = await convertAsyncIterableToArray(subscription); @@ -205,9 +195,10 @@ describe("RunSubscription", () => { ]; const subscription = new RunSubscription({ - provider: new DelayedTestShapeProvider(shapes), + runShapeStream: createDelayedTestShapeStream(shapes), streamFactory: new TestStreamSubscriptionFactory(), closeOnComplete: false, + abortController: new AbortController(), }); // Collect 2 results @@ -257,8 +248,9 @@ describe("RunSubscription", () => { ]; const subscription = new RunSubscription({ - provider: new TestShapeProvider(shapes), + runShapeStream: createTestShapeStream(shapes), streamFactory, + abortController: new AbortController(), }); const results = await collectNResults( @@ -289,9 +281,13 @@ describe("RunSubscription", () => { // Override createSubscription to count calls const originalCreate = streamFactory.createSubscription.bind(streamFactory); - streamFactory.createSubscription = (runId: string, streamKey: string) => { + streamFactory.createSubscription = ( + metadata: Record, + runId: string, + streamKey: string + ) => { streamCreationCount++; - return originalCreate(runId, streamKey); + return originalCreate(metadata, runId, streamKey); }; // Set up test chunks @@ -342,8 +338,9 @@ describe("RunSubscription", () => { ]; const subscription = new RunSubscription({ - provider: new TestShapeProvider(shapes), + runShapeStream: createTestShapeStream(shapes), streamFactory, + abortController: new AbortController(), }); const results = await collectNResults( @@ -421,8 +418,9 @@ describe("RunSubscription", () => { ]; const subscription = new RunSubscription({ - provider: new TestShapeProvider(shapes), + runShapeStream: createTestShapeStream(shapes), streamFactory, + abortController: new AbortController(), }); const results = await collectNResults( @@ -467,110 +465,6 @@ describe("RunSubscription", () => { run: { id: "run_123" }, }); }); - - it("should handle streams that appear in different run updates", async () => { - const streamFactory = new TestStreamSubscriptionFactory(); - - // Set up test chunks for two different streams - streamFactory.setStreamChunks("run_123", "openai", [ - { id: "openai1", content: "Hello" }, - { id: "openai2", content: "World" }, - ]); - streamFactory.setStreamChunks("run_123", "anthropic", [ - { id: "claude1", message: "Hi" }, - { id: "claude2", message: "There" }, - ]); - - const shapes = [ - // First run update - only has openai stream - { - id: "123", - friendlyId: "run_123", - taskIdentifier: "multi-streaming", - status: "EXECUTING", - createdAt: new Date(), - updatedAt: new Date(), - number: 1, - usageDurationMs: 100, - costInCents: 0, - baseCostInCents: 0, - isTest: false, - runTags: [], - metadata: JSON.stringify({ - $$streams: ["openai"], - }), - metadataType: "application/json", - }, - // Second run update - adds anthropic stream - { - id: "123", - friendlyId: "run_123", - taskIdentifier: "multi-streaming", - status: "EXECUTING", - createdAt: new Date(), - updatedAt: new Date(), - number: 1, - usageDurationMs: 200, - costInCents: 0, - baseCostInCents: 0, - isTest: false, - runTags: [], - metadata: JSON.stringify({ - $$streams: ["openai", "anthropic"], - }), - metadataType: "application/json", - }, - // Final run update - marks as complete - { - id: "123", - friendlyId: "run_123", - taskIdentifier: "multi-streaming", - status: "COMPLETED_SUCCESSFULLY", - createdAt: new Date(), - updatedAt: new Date(), - completedAt: new Date(), - number: 1, - usageDurationMs: 300, - costInCents: 0, - baseCostInCents: 0, - isTest: false, - runTags: [], - metadata: JSON.stringify({ - $$streams: ["openai", "anthropic"], - }), - metadataType: "application/json", - }, - ]; - - const subscription = new RunSubscription({ - provider: new TestShapeProvider(shapes), - streamFactory, - closeOnComplete: true, - }); - - const results = await collectNResults( - subscription.withStreams<{ - openai: { id: string; content: string }; - anthropic: { id: string; message: string }; - }>(), - 7 // 3 runs + 2 openai chunks + 2 anthropic chunks - ); - - expect(results).toHaveLength(7); - - // Verify run updates - const runUpdates = results.filter((r) => r.type === "run"); - expect(runUpdates).toHaveLength(3); - expect(runUpdates[2]!.run.status).toBe("COMPLETED"); - - // Verify openai chunks - const openaiChunks = results.filter((r) => r.type === "openai"); - expect(openaiChunks).toHaveLength(2); - - // Verify anthropic chunks - const anthropicChunks = results.filter((r) => r.type === "anthropic"); - expect(anthropicChunks).toHaveLength(2); - }); }); export async function convertAsyncIterableToArray(iterable: AsyncIterable): Promise { @@ -603,7 +497,12 @@ async function collectNResults( promise, new Promise((_, reject) => setTimeout( - () => reject(new Error(`Timeout waiting for ${count} results after ${timeoutMs}ms`)), + () => + reject( + new Error( + `Timeout waiting for ${count} results after ${timeoutMs}ms, but only had ${results.length}` + ) + ), timeoutMs ) ), diff --git a/packages/react-hooks/src/hooks/useRealtime.ts b/packages/react-hooks/src/hooks/useRealtime.ts index dc4d982ad0..9481125cff 100644 --- a/packages/react-hooks/src/hooks/useRealtime.ts +++ b/packages/react-hooks/src/hooks/useRealtime.ts @@ -12,6 +12,16 @@ export type UseRealtimeRunOptions = UseApiClientOptions & { experimental_throttleInMs?: number; }; +export type UseRealtimeSingleRunOptions = UseRealtimeRunOptions & { + /** + * Callback this is called when the run completes, an error occurs, or the subscription is stopped. + * + * @param {RealtimeRun} run - The run object + * @param {Error} [err] - The error that occurred + */ + onComplete?: (run: RealtimeRun, err?: Error) => void; +}; + export type UseRealtimeRunInstance = { run: RealtimeRun | undefined; @@ -28,7 +38,7 @@ export type UseRealtimeRunInstance = { * * @template TTask - The type of the task * @param {string} [runId] - The unique identifier of the run to subscribe to - * @param {UseRealtimeRunOptions} [options] - Configuration options for the subscription + * @param {UseRealtimeSingleRunOptions} [options] - Configuration options for the subscription * @returns {UseRealtimeRunInstance} An object containing the current state of the run, error handling, and control methods * * @example @@ -40,7 +50,7 @@ export type UseRealtimeRunInstance = { export function useRealtimeRun( runId?: string, - options?: UseRealtimeRunOptions + options?: UseRealtimeSingleRunOptions ): UseRealtimeRunInstance { const hookId = useId(); const idKey = options?.id ?? hookId; @@ -48,17 +58,17 @@ export function useRealtimeRun( // Store the streams state in SWR, using the idKey as the key to share states. const { data: run, mutate: mutateRun } = useSWR>([idKey, "run"], null); - // Keep the latest streams in a ref. - const runRef = useRef | undefined>(); - useEffect(() => { - runRef.current = run; - }, [run]); - const { data: error = undefined, mutate: setError } = useSWR( [idKey, "error"], null ); + // Add state to track when the subscription is complete + const { data: isComplete = false, mutate: setIsComplete } = useSWR( + [idKey, "complete"], + null + ); + // Abort controller to cancel the current API call. const abortControllerRef = useRef(null); @@ -93,9 +103,19 @@ export function useRealtimeRun( if (abortControllerRef.current) { abortControllerRef.current = null; } + + // Mark the subscription as complete + setIsComplete(true); } }, [runId, mutateRun, abortControllerRef, apiClient, setError]); + // Effect to handle onComplete callback + useEffect(() => { + if (isComplete && options?.onComplete && run) { + options.onComplete(run, error); + } + }, [isComplete, run, error, options?.onComplete]); + useEffect(() => { if (typeof options?.enabled === "boolean" && !options.enabled) { return; @@ -157,7 +177,7 @@ export function useRealtimeRunWithStreams< TStreams extends Record = Record, >( runId?: string, - options?: UseRealtimeRunOptions + options?: UseRealtimeSingleRunOptions ): UseRealtimeRunWithStreamsInstance { const hookId = useId(); const idKey = options?.id ?? hookId; @@ -182,11 +202,11 @@ export function useRealtimeRunWithStreams< // Store the streams state in SWR, using the idKey as the key to share states. const { data: run, mutate: mutateRun } = useSWR>([idKey, "run"], null); - // Keep the latest streams in a ref. - const runRef = useRef | undefined>(); - useEffect(() => { - runRef.current = run; - }, [run]); + // Add state to track when the subscription is complete + const { data: isComplete = false, mutate: setIsComplete } = useSWR( + [idKey, "complete"], + null + ); const { data: error = undefined, mutate: setError } = useSWR( [idKey, "error"], @@ -235,9 +255,19 @@ export function useRealtimeRunWithStreams< if (abortControllerRef.current) { abortControllerRef.current = null; } + + // Mark the subscription as complete + setIsComplete(true); } }, [runId, mutateRun, mutateStreams, streamsRef, abortControllerRef, apiClient, setError]); + // Effect to handle onComplete callback + useEffect(() => { + if (isComplete && options?.onComplete && run) { + options.onComplete(run, error); + } + }, [isComplete, run, error, options?.onComplete]); + useEffect(() => { if (typeof options?.enabled === "boolean" && !options.enabled) { return; diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 2a1150f4c8..9c1e95c0aa 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -1263,8 +1263,8 @@ importers: packages/core: dependencies: '@electric-sql/client': - specifier: 0.7.1 - version: 0.7.1 + specifier: 0.9.0 + version: 0.9.0 '@google-cloud/precise-date': specifier: ^4.0.0 version: 4.0.0 @@ -5112,8 +5112,8 @@ packages: '@rollup/rollup-darwin-arm64': 4.21.3 dev: false - /@electric-sql/client@0.7.1: - resolution: {integrity: sha512-NpKEn5hDSy+NaAdG9Ql8kIGfjrj/XfakJOOHTTutb99db3Dza0uUfnkqycFpyUAarFMQ4hYSKgx8AbOm1PCeFQ==} + /@electric-sql/client@0.9.0: + resolution: {integrity: sha512-UL2Gep9wPdGMTE0oEWVi0HA8R293R2OzFfHeAsN2LABYYl/boXss7nseNEiIV5+RjHPH7Tm8NsjH9iJW2rZkrQ==} optionalDependencies: '@rollup/rollup-darwin-arm64': 4.21.3 dev: false diff --git a/references/nextjs-realtime/src/app/realtime/page.tsx b/references/nextjs-realtime/src/app/realtime/page.tsx new file mode 100644 index 0000000000..ab2148eb9d --- /dev/null +++ b/references/nextjs-realtime/src/app/realtime/page.tsx @@ -0,0 +1,12 @@ +import RealtimeComparison from "@/components/RealtimeComparison"; +import { auth } from "@trigger.dev/sdk/v3"; + +export default async function RuntimeComparisonPage() { + const accessToken = await auth.createTriggerPublicToken("openai-streaming"); + + return ( +
+ +
+ ); +} diff --git a/references/nextjs-realtime/src/app/runs/[id]/ClientRunDetails.tsx b/references/nextjs-realtime/src/app/runs/[id]/ClientRunDetails.tsx index b4c1094fe7..cb729679b2 100644 --- a/references/nextjs-realtime/src/app/runs/[id]/ClientRunDetails.tsx +++ b/references/nextjs-realtime/src/app/runs/[id]/ClientRunDetails.tsx @@ -27,6 +27,9 @@ function RunDetailsWrapper({ const { run, error } = useRealtimeRun(runId, { accessToken, enabled: accessToken !== undefined, + onComplete: (run) => { + console.log("Run completed!", run); + }, }); if (error) { diff --git a/references/nextjs-realtime/src/components/RealtimeComparison.tsx b/references/nextjs-realtime/src/components/RealtimeComparison.tsx new file mode 100644 index 0000000000..009798eab7 --- /dev/null +++ b/references/nextjs-realtime/src/components/RealtimeComparison.tsx @@ -0,0 +1,98 @@ +"use client"; + +import { Button } from "@/components/ui/button"; +import { useRealtimeRunWithStreams, useTaskTrigger } from "@trigger.dev/react-hooks"; +import type { STREAMS, openaiStreaming } from "@/trigger/ai"; + +export default function RealtimeComparison({ accessToken }: { accessToken: string }) { + const trigger = useTaskTrigger("openai-streaming", { + accessToken, + baseURL: process.env.NEXT_PUBLIC_TRIGGER_API_URL, + }); + + const { streams, stop, run } = useRealtimeRunWithStreams( + trigger.handle?.id, + { + accessToken: trigger.handle?.publicAccessToken, + enabled: !!trigger.handle, + baseURL: process.env.NEXT_PUBLIC_TRIGGER_API_URL, + onComplete: (...args) => { + console.log("Run completed!", args); + }, + } + ); + + return ( +
+
+ + + {run && ( + + )} +
+
+
+ + + + + + + + + {(streams.openai ?? []).map((part, i) => ( + + + + + ))} + +
IDData
{i + 1} +
+ {JSON.stringify(part)} +
+
+
+
+ + + + + + + + + {(streams.openaiText ?? []).map((text, i) => ( + + + + + ))} + +
IDData
{i + 1} +
{text}
+
+
+
+
+ ); +} diff --git a/references/nextjs-realtime/src/trigger/ai.ts b/references/nextjs-realtime/src/trigger/ai.ts index f3affeb7cd..70cd498fe9 100644 --- a/references/nextjs-realtime/src/trigger/ai.ts +++ b/references/nextjs-realtime/src/trigger/ai.ts @@ -9,7 +9,10 @@ const openaiSDK = new OpenAI({ apiKey: process.env.OPENAI_API_KEY, }); -export type STREAMS = { openai: TextStreamPart<{ getWeather: typeof weatherTask.tool }> }; +export type STREAMS = { + openai: TextStreamPart<{ getWeather: typeof weatherTask.tool }>; + openaiText: string; +}; export const openaiConsumer = schemaTask({ id: "openai-consumer", @@ -105,18 +108,7 @@ export const openaiStreaming = schemaTask({ }); const stream = await metadata.stream("openai", result.fullStream); - - let text = ""; - - for await (const chunk of stream) { - logger.log("Received chunk", { chunk }); - - if (chunk.type === "text-delta") { - text += chunk.textDelta; - } - } - - return { text }; + await metadata.stream("openaiText", result.textStream); }, }); diff --git a/references/nextjs-realtime/src/trigger/example.ts b/references/nextjs-realtime/src/trigger/example.ts index c768e78ad2..031888d187 100644 --- a/references/nextjs-realtime/src/trigger/example.ts +++ b/references/nextjs-realtime/src/trigger/example.ts @@ -15,12 +15,7 @@ export const exampleTask = schemaTask({ metadata.set("status", { type: "started", progress: 0.1 }); - if (Math.random() < 0.9) { - // Simulate a failure - throw new Error("Random failure"); - } - - await setTimeout(20000); + await setTimeout(2000); metadata.set("status", { type: "processing", progress: 0.5 });