Skip to content

Commit

Permalink
Made traces work with streaming in the AI SDK
Browse files Browse the repository at this point in the history
  • Loading branch information
mattpocock committed Dec 26, 2024
1 parent bb16cd5 commit 7307a99
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 29 deletions.
5 changes: 5 additions & 0 deletions .changeset/big-schools-sparkle.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"evalite": patch
---

Made traceAISDKModel work with streamText.
28 changes: 26 additions & 2 deletions packages/evalite-tests/tests/ai-sdk-traces.test.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import { createDatabase, getEvalsAsRecord } from "@evalite/core/db";
import { runVitest } from "evalite/runner";
import { expect, it } from "vitest";
import { assert, expect, it } from "vitest";
import { captureStdout, loadFixture } from "./test-utils.js";

it("Should report traces from traceAISDKModel", async () => {
it("Should report traces from generateText using traceAISDKModel", async () => {
using fixture = loadFixture("ai-sdk-traces");

const captured = captureStdout();
Expand All @@ -21,3 +21,27 @@ it("Should report traces from traceAISDKModel", async () => {

expect(evals["AI SDK Traces"]![0]?.results[0]?.traces).toHaveLength(1);
});

it("Should report traces from streamText using traceAISDKModel", async () => {
using fixture = loadFixture("ai-sdk-traces-stream");

const captured = captureStdout();

await runVitest({
cwd: fixture.dir,
path: undefined,
testOutputWritable: captured.writable,
mode: "run-once-and-exit",
});

const db = createDatabase(fixture.dbLocation);

const evals = await getEvalsAsRecord(db);

const traces = evals["AI SDK Traces"]![0]?.results[0]?.traces;

assert(traces?.[0], "Expected a trace to be reported");

expect(traces?.[0].completion_tokens).toEqual(10);
expect(traces?.[0].prompt_tokens).toEqual(3);
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import { streamText } from "ai";
import { MockLanguageModelV1, simulateReadableStream } from "ai/test";
import { Levenshtein } from "autoevals";
import { evalite } from "evalite";
import { traceAISDKModel } from "evalite/ai-sdk";

const model = new MockLanguageModelV1({
doStream: async () => ({
stream: simulateReadableStream({
chunks: [
{ type: "text-delta", textDelta: "Hello" },
{ type: "text-delta", textDelta: ", " },
{ type: "text-delta", textDelta: `world!` },
{
type: "finish",
finishReason: "stop",
logprobs: undefined,
usage: { completionTokens: 10, promptTokens: 3 },
},
],
}),
rawCall: { rawPrompt: null, rawSettings: {} },
}),
});

const tracedModel = traceAISDKModel(model);

evalite("AI SDK Traces", {
data: () => {
return [
{
input: "abc",
expected: "abcdef",
},
];
},
task: async (input) => {
const result = await streamText({
model: tracedModel,
system: "Test system",
prompt: input,
});
return result.textStream;
},
scorers: [Levenshtein],
});
98 changes: 71 additions & 27 deletions packages/evalite/src/ai-sdk.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,42 @@
import { experimental_wrapLanguageModel, type LanguageModelV1 } from "ai";
import {
experimental_wrapLanguageModel,
type LanguageModelV1,
type LanguageModelV1StreamPart,
type LanguageModelV1CallOptions,
} from "ai";
import { reportTrace, shouldReportTrace } from "./traces.js";

/**
 * Normalises an AI SDK prompt into the plain `{ role, content }` shape
 * stored on a trace. Plain string content passes through untouched;
 * array content is reduced to its text parts.
 *
 * @throws Error if any content part is not `type: "text"` — only text
 *   content is currently supported by traceAISDKModel.
 */
const processPromptForTracing = (
  prompt: LanguageModelV1CallOptions["prompt"]
) => {
  return prompt.map((message) => {
    // String content needs no transformation.
    if (!Array.isArray(message.content)) {
      return { role: message.role, content: message.content };
    }

    const textParts = message.content.map((part) => {
      if (part.type !== "text") {
        throw new Error(
          `Unsupported content type: ${part.type}. Only text is currently supported by traceAISDKModel.`
        );
      }
      return { type: "text" as const, text: part.text };
    });

    return { role: message.role, content: textParts };
  });
};

export const traceAISDKModel = (model: LanguageModelV1): LanguageModelV1 => {
if (!shouldReportTrace()) return model;
return experimental_wrapLanguageModel({
Expand All @@ -13,39 +49,47 @@ export const traceAISDKModel = (model: LanguageModelV1): LanguageModelV1 => {

reportTrace({
output: generated.text ?? "",
input: opts.params.prompt.map((prompt) => {
if (!Array.isArray(prompt.content)) {
return {
role: prompt.role,
content: prompt.content,
};
}

const content = prompt.content.map((content) => {
if (content.type !== "text") {
throw new Error(
`Unsupported content type: ${content.type}. Only text is currently supported by traceAISDKModel.`
);
}

return {
type: "text" as const,
text: content.text,
};
});

return {
role: prompt.role,
content,
};
}),
input: processPromptForTracing(opts.params.prompt),
usage: generated.usage,
start,
end,
});

return generated;
},
wrapStream: async ({ doStream, params, model }) => {
const start = performance.now();
const { stream, ...rest } = await doStream();

const fullResponse: LanguageModelV1StreamPart[] = [];

const transformStream = new TransformStream<
LanguageModelV1StreamPart,
LanguageModelV1StreamPart
>({
transform(chunk, controller) {
fullResponse.push(chunk);
controller.enqueue(chunk);
},
flush() {
const usage = fullResponse.find(
(part) => part.type === "finish"
)?.usage;
reportTrace({
start,
end: performance.now(),
input: processPromptForTracing(params.prompt),
output: fullResponse,
usage,
});
},
});

return {
stream: stream.pipeThrough(transformStream),
...rest,
};
},
},
});
};

0 comments on commit 7307a99

Please sign in to comment.