Skip to content

Commit

Permalink
Vercel ai sdk impl (#382)
Browse files Browse the repository at this point in the history
* ai sdk client (WIP)

* unify LLM tool type

* replace anthropic types

* update ollama tool usage

* delete old ai sdk client

* new ai sdk client

* ai sdk example

* add comment

* changeset

* fix ollama tool usage

* remove changeset

* use default logger

* fixed messages type

* message type fix

* update deps

* update type

* migrate to new options syntax

* fix

* input logger in extract/observe

* remove AISdkClient logger

* changeset

* aisdk use StagehandConfig

* change aisdk model to gemini

* Revert "Merge branch 'sameel/move-llm-logger' into vercel-ai-sdk-impl"

This reverts commit ec63bf4, reversing
changes made to e575d88.

* lint error

* changeset
  • Loading branch information
sameelarif authored Jan 8, 2025
1 parent 5899ec2 commit a41271b
Show file tree
Hide file tree
Showing 6 changed files with 743 additions and 119 deletions.
5 changes: 5 additions & 0 deletions .changeset/shiny-ladybugs-shave.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@browserbasehq/stagehand": patch
---

Added example implementation of the Vercel AI SDK as an LLMClient
40 changes: 40 additions & 0 deletions examples/ai_sdk_example.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import { google } from "@ai-sdk/google";
import { z } from "zod";
import { Stagehand } from "../lib";
import { AISdkClient } from "./external_clients/aisdk";
import StagehandConfig from "./stagehand.config";

async function example() {
const stagehand = new Stagehand({
...StagehandConfig,
llmClient: new AISdkClient({
model: google("gemini-1.5-flash-latest"),
}),
});

await stagehand.init();
await stagehand.page.goto("https://news.ycombinator.com");

const headlines = await stagehand.page.extract({
instruction: "Extract only 3 stories from the Hacker News homepage.",
schema: z.object({
stories: z
.array(
z.object({
title: z.string(),
url: z.string(),
points: z.number(),
}),
)
.length(3),
}),
});

console.log(headlines);

await stagehand.close();
}

(async () => {
await example();
})();
112 changes: 112 additions & 0 deletions examples/external_clients/aisdk.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import {
CoreAssistantMessage,
CoreMessage,
CoreSystemMessage,
CoreTool,
CoreUserMessage,
generateObject,
generateText,
ImagePart,
LanguageModel,
TextPart,
} from "ai";
import { ChatCompletion } from "openai/resources/chat/completions";
import {
CreateChatCompletionOptions,
LLMClient,
} from "../../lib/llm/LLMClient";
import { AvailableModel } from "../../types/model";

export class AISdkClient extends LLMClient {
public type = "aisdk" as const;
private model: LanguageModel;

constructor({ model }: { model: LanguageModel }) {
super(model.modelId as AvailableModel);
this.model = model;
}

async createChatCompletion<T = ChatCompletion>({
options,
}: CreateChatCompletionOptions): Promise<T> {
const formattedMessages: CoreMessage[] = options.messages.map((message) => {
if (Array.isArray(message.content)) {
if (message.role === "system") {
const systemMessage: CoreSystemMessage = {
role: "system",
content: message.content
.map((c) => ("text" in c ? c.text : ""))
.join("\n"),
};
return systemMessage;
}

const contentParts = message.content.map((content) => {
if ("image_url" in content) {
const imageContent: ImagePart = {
type: "image",
image: content.image_url.url,
};
return imageContent;
} else {
const textContent: TextPart = {
type: "text",
text: content.text,
};
return textContent;
}
});

if (message.role === "user") {
const userMessage: CoreUserMessage = {
role: "user",
content: contentParts,
};
return userMessage;
} else {
const textOnlyParts = contentParts.map((part) => ({
type: "text" as const,
text: part.type === "image" ? "[Image]" : part.text,
}));
const assistantMessage: CoreAssistantMessage = {
role: "assistant",
content: textOnlyParts,
};
return assistantMessage;
}
}

return {
role: message.role,
content: message.content,
};
});

if (options.response_model) {
const response = await generateObject({
model: this.model,
messages: formattedMessages,
schema: options.response_model.schema,
});

return response.object;
}

const tools: Record<string, CoreTool> = {};

for (const rawTool of options.tools) {
tools[rawTool.name] = {
description: rawTool.description,
parameters: rawTool.parameters,
};
}

const response = await generateText({
model: this.model,
messages: formattedMessages,
tools,
});

return response as T;
}
}
1 change: 1 addition & 0 deletions lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ export class Stagehand {
this.llmClient = undefined;
}
}

this.domSettleTimeoutMs = domSettleTimeoutMs ?? 30_000;
this.headless = headless ?? false;
this.browserbaseSessionCreateParams = browserbaseSessionCreateParams;
Expand Down
Loading

0 comments on commit a41271b

Please sign in to comment.