Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

streamline types #222

Merged
merged 11 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/swift-fishes-fail.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@browserbasehq/stagehand": patch
---

Streamline type definitions and fix existing typescript errors
3 changes: 2 additions & 1 deletion evals/index.eval.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ import { Stagehand } from "../lib";
import { z } from "zod";
import process from "process";
import { EvalLogger } from "./utils";
import { AvailableModel, LogLine } from "../lib/types";
import { AvailableModel } from "../types/model";
import { LogLine } from "../types/log";

const env: "BROWSERBASE" | "LOCAL" =
process.env.EVAL_ENV?.toLowerCase() === "browserbase"
Expand Down
2 changes: 1 addition & 1 deletion evals/utils.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { LogLine } from "../lib/types";
import { Stagehand } from "../lib";
import { logLineToString } from "../lib/utils";
import { LogLine } from "../types/log";

type LogLineEval = LogLine & {
parsedAuxiliary?: string | object;
Expand Down
2 changes: 1 addition & 1 deletion lib/cache/ActionCache.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { LogLine } from "../../lib/types";
import { LogLine } from "../../types/log";
import { BaseCache, CacheEntry } from "./BaseCache";

export interface PlaywrightCommand {
Expand Down
2 changes: 1 addition & 1 deletion lib/cache/BaseCache.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import * as fs from "fs";
import * as path from "path";
import * as crypto from "crypto";
import { LogLine } from "../../lib/types";
import { LogLine } from "../../types/log";

export interface CacheEntry {
timestamp: number;
Expand Down
2 changes: 1 addition & 1 deletion lib/dom/debug.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
async function debugDom() {
export async function debugDom() {
window.chunkNumber = 0;

const { selectorMap: multiSelectorMap, outputString } =
Expand Down
2 changes: 1 addition & 1 deletion lib/dom/utils.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
async function waitForDomSettle() {
export async function waitForDomSettle() {
return new Promise<void>((resolve) => {
const createTimeout = () => {
return setTimeout(() => {
Expand Down
11 changes: 5 additions & 6 deletions lib/handlers/actHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ import { Stagehand } from "../index";
import { LLMProvider } from "../llm/LLMProvider";
import { ScreenshotService } from "../vision";
import { verifyActCompletion, act, fillInVariables } from "../inference";
import {
LogLine,
PlaywrightCommandException,
PlaywrightCommandMethodNotSupportedException,
} from "../types";
import { Locator, Page } from "@playwright/test";
import { ActionCache } from "../cache/ActionCache";
import { LLMClient, modelsWithVision } from "../llm/LLMClient";
import { generateId } from "../utils";
import { LogLine } from "../../types/log";
import {
PlaywrightCommandException,
PlaywrightCommandMethodNotSupportedException,
} from "../../types/playwright";

export class StagehandActHandler {
private readonly stagehand: Stagehand;
Expand Down Expand Up @@ -1096,7 +1096,6 @@ export class StagehandActHandler {
action,
domElements: outputString,
steps,
llmProvider: this.llmProvider,
llmClient,
screenshot: annotatedScreenshot,
logger: this.logger,
Expand Down
3 changes: 1 addition & 2 deletions lib/handlers/extractHandler.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { LLMProvider } from "../llm/LLMProvider";
import { Stagehand } from "../index";
import { z } from "zod";
import { AvailableModel, LogLine } from "../types";
import { LogLine } from "../../types/log";
import { extract } from "../inference";
import { LLMClient } from "../llm/LLMClient";

Expand Down Expand Up @@ -114,7 +114,6 @@ export class StagehandExtractHandler {
progress,
previouslyExtractedContent: content,
domElements: outputString,
llmProvider: this.llmProvider,
schema,
llmClient,
chunksSeen: chunksSeen.length,
Expand Down
9 changes: 4 additions & 5 deletions lib/handlers/observeHandler.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import { LLMProvider } from "../llm/LLMProvider";
import { LogLine, AvailableModel } from "../types";
import { LogLine } from "../../types/log";
import { Stagehand } from "../index";
import { observe } from "../inference";
import { LLMClient, modelsWithVision } from "../llm/LLMClient";
import { ScreenshotService } from "../vision";
import { LLMClient } from "../llm/LLMClient";
import { LLMProvider } from "../llm/LLMProvider";
import { generateId } from "../utils";
import { ScreenshotService } from "../vision";

export class StagehandObserveHandler {
private readonly stagehand: Stagehand;
Expand Down Expand Up @@ -134,7 +134,6 @@ export class StagehandObserveHandler {
const observationResponse = await observe({
instruction,
domElements: outputString,
llmProvider: this.llmProvider,
llmClient,
image: annotatedScreenshot,
requestId,
Expand Down
109 changes: 40 additions & 69 deletions lib/index.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,31 @@
import { type Page, type BrowserContext, chromium } from "@playwright/test";
import { z } from "zod";
import { Browserbase } from "@browserbasehq/sdk";
import { type BrowserContext, chromium, type Page } from "@playwright/test";
import { randomUUID } from "crypto";
import fs from "fs";
import { Browserbase, ClientOptions } from "@browserbasehq/sdk";
import { LLMProvider } from "./llm/LLMProvider";
import { AvailableModel } from "./types";
// @ts-ignore we're using a built js file as a string here
import os from "os";
import { z } from "zod";
import { BrowserResult } from "../types/browser";
import { LogLine } from "../types/log";
import {
ActOptions,
ActResult,
ConstructorParams,
ExtractOptions,
ExtractResult,
InitFromPageOptions,
InitFromPageResult,
InitOptions,
InitResult,
ObserveOptions,
ObserveResult,
} from "../types/stagehand";
import { scriptContent } from "./dom/build/scriptContent";
import { LogLine } from "./types";
import { randomUUID } from "crypto";
import { logLineToString } from "./utils";
import { StagehandActHandler } from "./handlers/actHandler";
import { StagehandExtractHandler } from "./handlers/extractHandler";
import { StagehandObserveHandler } from "./handlers/observeHandler";
import { StagehandActHandler } from "./handlers/actHandler";
import { LLMClient } from "./llm/LLMClient";
import { LLMProvider } from "./llm/LLMProvider";
import { logLineToString } from "./utils";

require("dotenv").config({ path: ".env" });

Expand All @@ -26,7 +39,7 @@ async function getBrowser(
logger: (message: LogLine) => void,
browserbaseSessionCreateParams?: Browserbase.Sessions.SessionCreateParams,
browserbaseResumeSessionID?: string,
) {
): Promise<BrowserResult> {
if (env === "BROWSERBASE") {
if (!apiKey) {
logger({
Expand Down Expand Up @@ -184,7 +197,7 @@ async function getBrowser(
},
});

const tmpDir = fs.mkdtempSync(`/tmp/pwtest`);
const tmpDir = fs.mkdtempSync("/tmp/pwtest");
fs.mkdirSync(`${tmpDir}/userdir/Default`, { recursive: true });

const defaultPreferences = {
Expand Down Expand Up @@ -311,22 +324,7 @@ export class Stagehand {
browserbaseResumeSessionID,
modelName,
modelClientOptions,
}: {
env: "LOCAL" | "BROWSERBASE";
apiKey?: string;
projectId?: string;
verbose?: 0 | 1 | 2;
debugDom?: boolean;
llmProvider?: LLMProvider;
headless?: boolean;
logger?: (message: LogLine) => void;
domSettleTimeoutMs?: number;
browserBaseSessionCreateParams?: Browserbase.Sessions.SessionCreateParams;
enableCaching?: boolean;
browserbaseResumeSessionID?: string;
modelName?: AvailableModel;
modelClientOptions?: ClientOptions;
} = {
}: ConstructorParams = {
env: "BROWSERBASE",
},
) {
Expand Down Expand Up @@ -356,14 +354,7 @@ export class Stagehand {
modelName,
modelClientOptions,
domSettleTimeoutMs,
}: {
modelName?: AvailableModel;
modelClientOptions?: ClientOptions;
domSettleTimeoutMs?: number;
} = {}): Promise<{
debugUrl: string;
sessionUrl: string;
}> {
}: InitOptions = {}): Promise<InitResult> {
const llmClient = modelName
? this.llmProvider.getClient(modelName, modelClientOptions)
: this.llmClient;
Expand All @@ -377,7 +368,11 @@ export class Stagehand {
this.browserbaseResumeSessionID,
).catch((e) => {
console.error("Error in init:", e);
return { context: undefined, debugUrl: undefined, sessionUrl: undefined };
return {
context: undefined,
debugUrl: undefined,
sessionUrl: undefined,
} as BrowserResult;
});
this.context = context;
this.page = context.pages()[0];
Expand Down Expand Up @@ -442,11 +437,11 @@ export class Stagehand {
return { debugUrl, sessionUrl };
}

async initFromPage(
page: Page,
modelName?: AvailableModel,
modelClientOptions?: ClientOptions,
): Promise<{ context: BrowserContext }> {
async initFromPage({
page,
modelName,
modelClientOptions,
}: InitFromPageOptions): Promise<InitFromPageResult> {
this.page = page;
this.context = page.context();
this.llmClient = modelName
Expand Down Expand Up @@ -474,7 +469,6 @@ export class Stagehand {
return { context: this.context };
}

// Logging
private pending_logs_to_send_to_browserbase: LogLine[] = [];

private is_processing_browserbase_logs: boolean = false;
Expand Down Expand Up @@ -653,18 +647,7 @@ export class Stagehand {
useVision = "fallback",
variables = {},
domSettleTimeoutMs,
}: {
action: string;
modelName?: AvailableModel;
modelClientOptions?: ClientOptions;
useVision?: "fallback" | boolean;
variables?: Record<string, string>;
domSettleTimeoutMs?: number;
}): Promise<{
success: boolean;
message: string;
action: string;
}> {
}: ActOptions): Promise<ActResult> {
if (!this.actHandler) {
throw new Error("Act handler not initialized");
}
Expand Down Expand Up @@ -743,13 +726,7 @@ export class Stagehand {
modelName,
modelClientOptions,
domSettleTimeoutMs,
}: {
instruction: string;
schema: T;
modelName?: AvailableModel;
modelClientOptions?: ClientOptions;
domSettleTimeoutMs?: number;
}): Promise<z.infer<T>> {
}: ExtractOptions<T>): Promise<ExtractResult<T>> {
if (!this.extractHandler) {
throw new Error("Extract handler not initialized");
}
Expand Down Expand Up @@ -812,13 +789,7 @@ export class Stagehand {
});
}

async observe(options?: {
instruction?: string;
modelName?: AvailableModel;
modelClientOptions?: ClientOptions;
useVision?: boolean;
domSettleTimeoutMs?: number;
}): Promise<{ selector: string; description: string }[]> {
async observe(options?: ObserveOptions): Promise<ObserveResult[]> {
if (!this.observeHandler) {
throw new Error("Observe handler not initialized");
}
Expand Down
31 changes: 4 additions & 27 deletions lib/inference.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import {
ChatMessage,
LLMClient,
} from "./llm/LLMClient";
import { VerifyActCompletionParams } from "../types/inference";
import { ActResult, ActParams } from "../types/act";

export async function verifyActCompletion({
goal,
Expand All @@ -30,15 +32,7 @@ export async function verifyActCompletion({
domElements,
logger,
requestId,
}: {
goal: string;
steps: string;
llmClient: LLMClient;
screenshot?: Buffer;
domElements?: string;
logger: (message: { category?: string; message: string }) => void;
requestId: string;
}): Promise<boolean> {
}: VerifyActCompletionParams): Promise<boolean> {
const messages: ChatMessage[] = [
buildVerifyActCompletionSystemPrompt(),
buildVerifyActCompletionUserPrompt(goal, steps, domElements),
Expand Down Expand Up @@ -106,24 +100,7 @@ export async function act({
logger,
requestId,
variables,
}: {
action: string;
steps?: string;
domElements: string;
llmClient: LLMClient;
screenshot?: Buffer;
retries?: number;
logger: (message: { category?: string; message: string }) => void;
requestId: string;
variables?: Record<string, string>;
}): Promise<{
method: string;
element: number;
args: any[];
completed: boolean;
step: string;
why?: string;
} | null> {
}: ActParams): Promise<ActResult | null> {
const messages: ChatMessage[] = [
buildActSystemPrompt(),
buildActUserPrompt(action, steps, domElements, variables),
Expand Down
7 changes: 4 additions & 3 deletions lib/llm/AnthropicClient.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import Anthropic, { ClientOptions } from "@anthropic-ai/sdk";
import { LLMClient, ChatCompletionOptions } from "./LLMClient";
import { Message, MessageCreateParams } from "@anthropic-ai/sdk/resources";
import { zodToJsonSchema } from "zod-to-json-schema";
import { LogLine } from "../../types/log";
import { AvailableModel } from "../../types/model";
import { LLMCache } from "../cache/LLMCache";
import { AvailableModel, LogLine } from "../types";
import { Message, MessageCreateParams } from "@anthropic-ai/sdk/resources";
import { ChatCompletionOptions, LLMClient } from "./LLMClient";

export class AnthropicClient extends LLMClient {
private client: Anthropic;
Expand Down
Loading