Skip to content

Commit

Permalink
streamline types (#222)
Browse files Browse the repository at this point in the history
* streamline types

* export utils/debug

* expose startDomDebug

* prevent node_modules checking

* create changeset

* add chat message type

* convert types to interfaces

* format messages to prevent casting

* make startDomDebug private

* Revert "format messages to prevent casting"

This reverts commit 1963130.

* formatting
  • Loading branch information
sameelarif authored Nov 26, 2024
1 parent 647eefd commit 8dff026
Show file tree
Hide file tree
Showing 26 changed files with 260 additions and 166 deletions.
5 changes: 5 additions & 0 deletions .changeset/swift-fishes-fail.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@browserbasehq/stagehand": patch
---

Streamline type definitions and fix existing TypeScript errors
3 changes: 2 additions & 1 deletion evals/index.eval.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ import { Stagehand } from "../lib";
import { z } from "zod";
import process from "process";
import { EvalLogger } from "./utils";
import { AvailableModel, LogLine } from "../lib/types";
import { AvailableModel } from "../types/model";
import { LogLine } from "../types/log";

const env: "BROWSERBASE" | "LOCAL" =
process.env.EVAL_ENV?.toLowerCase() === "browserbase"
Expand Down
2 changes: 1 addition & 1 deletion evals/utils.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { LogLine } from "../lib/types";
import { Stagehand } from "../lib";
import { logLineToString } from "../lib/utils";
import { LogLine } from "../types/log";

type LogLineEval = LogLine & {
parsedAuxiliary?: string | object;
Expand Down
2 changes: 1 addition & 1 deletion lib/cache/ActionCache.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { LogLine } from "../../lib/types";
import { LogLine } from "../../types/log";
import { BaseCache, CacheEntry } from "./BaseCache";

export interface PlaywrightCommand {
Expand Down
2 changes: 1 addition & 1 deletion lib/cache/BaseCache.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import * as fs from "fs";
import * as path from "path";
import * as crypto from "crypto";
import { LogLine } from "../../lib/types";
import { LogLine } from "../../types/log";

export interface CacheEntry {
timestamp: number;
Expand Down
2 changes: 1 addition & 1 deletion lib/dom/debug.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
async function debugDom() {
export async function debugDom() {
window.chunkNumber = 0;

const { selectorMap: multiSelectorMap, outputString } =
Expand Down
2 changes: 1 addition & 1 deletion lib/dom/utils.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
async function waitForDomSettle() {
export async function waitForDomSettle() {
return new Promise<void>((resolve) => {
const createTimeout = () => {
return setTimeout(() => {
Expand Down
11 changes: 5 additions & 6 deletions lib/handlers/actHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ import { Stagehand } from "../index";
import { LLMProvider } from "../llm/LLMProvider";
import { ScreenshotService } from "../vision";
import { verifyActCompletion, act, fillInVariables } from "../inference";
import {
LogLine,
PlaywrightCommandException,
PlaywrightCommandMethodNotSupportedException,
} from "../types";
import { Locator, Page } from "@playwright/test";
import { ActionCache } from "../cache/ActionCache";
import { LLMClient, modelsWithVision } from "../llm/LLMClient";
import { generateId } from "../utils";
import { LogLine } from "../../types/log";
import {
PlaywrightCommandException,
PlaywrightCommandMethodNotSupportedException,
} from "../../types/playwright";

export class StagehandActHandler {
private readonly stagehand: Stagehand;
Expand Down Expand Up @@ -1096,7 +1096,6 @@ export class StagehandActHandler {
action,
domElements: outputString,
steps,
llmProvider: this.llmProvider,
llmClient,
screenshot: annotatedScreenshot,
logger: this.logger,
Expand Down
3 changes: 1 addition & 2 deletions lib/handlers/extractHandler.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { LLMProvider } from "../llm/LLMProvider";
import { Stagehand } from "../index";
import { z } from "zod";
import { AvailableModel, LogLine } from "../types";
import { LogLine } from "../../types/log";
import { extract } from "../inference";
import { LLMClient } from "../llm/LLMClient";

Expand Down Expand Up @@ -114,7 +114,6 @@ export class StagehandExtractHandler {
progress,
previouslyExtractedContent: content,
domElements: outputString,
llmProvider: this.llmProvider,
schema,
llmClient,
chunksSeen: chunksSeen.length,
Expand Down
9 changes: 4 additions & 5 deletions lib/handlers/observeHandler.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import { LLMProvider } from "../llm/LLMProvider";
import { LogLine, AvailableModel } from "../types";
import { LogLine } from "../../types/log";
import { Stagehand } from "../index";
import { observe } from "../inference";
import { LLMClient, modelsWithVision } from "../llm/LLMClient";
import { ScreenshotService } from "../vision";
import { LLMClient } from "../llm/LLMClient";
import { LLMProvider } from "../llm/LLMProvider";
import { generateId } from "../utils";
import { ScreenshotService } from "../vision";

export class StagehandObserveHandler {
private readonly stagehand: Stagehand;
Expand Down Expand Up @@ -134,7 +134,6 @@ export class StagehandObserveHandler {
const observationResponse = await observe({
instruction,
domElements: outputString,
llmProvider: this.llmProvider,
llmClient,
image: annotatedScreenshot,
requestId,
Expand Down
109 changes: 40 additions & 69 deletions lib/index.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,31 @@
import { type Page, type BrowserContext, chromium } from "@playwright/test";
import { z } from "zod";
import { Browserbase } from "@browserbasehq/sdk";
import { type BrowserContext, chromium, type Page } from "@playwright/test";
import { randomUUID } from "crypto";
import fs from "fs";
import { Browserbase, ClientOptions } from "@browserbasehq/sdk";
import { LLMProvider } from "./llm/LLMProvider";
import { AvailableModel } from "./types";
// @ts-ignore we're using a built js file as a string here
import os from "os";
import { z } from "zod";
import { BrowserResult } from "../types/browser";
import { LogLine } from "../types/log";
import {
ActOptions,
ActResult,
ConstructorParams,
ExtractOptions,
ExtractResult,
InitFromPageOptions,
InitFromPageResult,
InitOptions,
InitResult,
ObserveOptions,
ObserveResult,
} from "../types/stagehand";
import { scriptContent } from "./dom/build/scriptContent";
import { LogLine } from "./types";
import { randomUUID } from "crypto";
import { logLineToString } from "./utils";
import { StagehandActHandler } from "./handlers/actHandler";
import { StagehandExtractHandler } from "./handlers/extractHandler";
import { StagehandObserveHandler } from "./handlers/observeHandler";
import { StagehandActHandler } from "./handlers/actHandler";
import { LLMClient } from "./llm/LLMClient";
import { LLMProvider } from "./llm/LLMProvider";
import { logLineToString } from "./utils";

require("dotenv").config({ path: ".env" });

Expand All @@ -26,7 +39,7 @@ async function getBrowser(
logger: (message: LogLine) => void,
browserbaseSessionCreateParams?: Browserbase.Sessions.SessionCreateParams,
browserbaseResumeSessionID?: string,
) {
): Promise<BrowserResult> {
if (env === "BROWSERBASE") {
if (!apiKey) {
logger({
Expand Down Expand Up @@ -184,7 +197,7 @@ async function getBrowser(
},
});

const tmpDir = fs.mkdtempSync(`/tmp/pwtest`);
const tmpDir = fs.mkdtempSync("/tmp/pwtest");
fs.mkdirSync(`${tmpDir}/userdir/Default`, { recursive: true });

const defaultPreferences = {
Expand Down Expand Up @@ -311,22 +324,7 @@ export class Stagehand {
browserbaseResumeSessionID,
modelName,
modelClientOptions,
}: {
env: "LOCAL" | "BROWSERBASE";
apiKey?: string;
projectId?: string;
verbose?: 0 | 1 | 2;
debugDom?: boolean;
llmProvider?: LLMProvider;
headless?: boolean;
logger?: (message: LogLine) => void;
domSettleTimeoutMs?: number;
browserBaseSessionCreateParams?: Browserbase.Sessions.SessionCreateParams;
enableCaching?: boolean;
browserbaseResumeSessionID?: string;
modelName?: AvailableModel;
modelClientOptions?: ClientOptions;
} = {
}: ConstructorParams = {
env: "BROWSERBASE",
},
) {
Expand Down Expand Up @@ -356,14 +354,7 @@ export class Stagehand {
modelName,
modelClientOptions,
domSettleTimeoutMs,
}: {
modelName?: AvailableModel;
modelClientOptions?: ClientOptions;
domSettleTimeoutMs?: number;
} = {}): Promise<{
debugUrl: string;
sessionUrl: string;
}> {
}: InitOptions = {}): Promise<InitResult> {
const llmClient = modelName
? this.llmProvider.getClient(modelName, modelClientOptions)
: this.llmClient;
Expand All @@ -377,7 +368,11 @@ export class Stagehand {
this.browserbaseResumeSessionID,
).catch((e) => {
console.error("Error in init:", e);
return { context: undefined, debugUrl: undefined, sessionUrl: undefined };
return {
context: undefined,
debugUrl: undefined,
sessionUrl: undefined,
} as BrowserResult;
});
this.context = context;
this.page = context.pages()[0];
Expand Down Expand Up @@ -445,11 +440,11 @@ export class Stagehand {
return { debugUrl, sessionUrl };
}

async initFromPage(
page: Page,
modelName?: AvailableModel,
modelClientOptions?: ClientOptions,
): Promise<{ context: BrowserContext }> {
async initFromPage({
page,
modelName,
modelClientOptions,
}: InitFromPageOptions): Promise<InitFromPageResult> {
this.page = page;
this.context = page.context();
this.llmClient = modelName
Expand Down Expand Up @@ -480,7 +475,6 @@ export class Stagehand {
return { context: this.context };
}

// Logging
private pending_logs_to_send_to_browserbase: LogLine[] = [];

private is_processing_browserbase_logs: boolean = false;
Expand Down Expand Up @@ -659,18 +653,7 @@ export class Stagehand {
useVision = "fallback",
variables = {},
domSettleTimeoutMs,
}: {
action: string;
modelName?: AvailableModel;
modelClientOptions?: ClientOptions;
useVision?: "fallback" | boolean;
variables?: Record<string, string>;
domSettleTimeoutMs?: number;
}): Promise<{
success: boolean;
message: string;
action: string;
}> {
}: ActOptions): Promise<ActResult> {
if (!this.actHandler) {
throw new Error("Act handler not initialized");
}
Expand Down Expand Up @@ -749,13 +732,7 @@ export class Stagehand {
modelName,
modelClientOptions,
domSettleTimeoutMs,
}: {
instruction: string;
schema: T;
modelName?: AvailableModel;
modelClientOptions?: ClientOptions;
domSettleTimeoutMs?: number;
}): Promise<z.infer<T>> {
}: ExtractOptions<T>): Promise<ExtractResult<T>> {
if (!this.extractHandler) {
throw new Error("Extract handler not initialized");
}
Expand Down Expand Up @@ -818,13 +795,7 @@ export class Stagehand {
});
}

async observe(options?: {
instruction?: string;
modelName?: AvailableModel;
modelClientOptions?: ClientOptions;
useVision?: boolean;
domSettleTimeoutMs?: number;
}): Promise<{ selector: string; description: string }[]> {
async observe(options?: ObserveOptions): Promise<ObserveResult[]> {
if (!this.observeHandler) {
throw new Error("Observe handler not initialized");
}
Expand Down
31 changes: 4 additions & 27 deletions lib/inference.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import {
ChatMessage,
LLMClient,
} from "./llm/LLMClient";
import { VerifyActCompletionParams } from "../types/inference";
import { ActResult, ActParams } from "../types/act";

export async function verifyActCompletion({
goal,
Expand All @@ -30,15 +32,7 @@ export async function verifyActCompletion({
domElements,
logger,
requestId,
}: {
goal: string;
steps: string;
llmClient: LLMClient;
screenshot?: Buffer;
domElements?: string;
logger: (message: { category?: string; message: string }) => void;
requestId: string;
}): Promise<boolean> {
}: VerifyActCompletionParams): Promise<boolean> {
const messages: ChatMessage[] = [
buildVerifyActCompletionSystemPrompt(),
buildVerifyActCompletionUserPrompt(goal, steps, domElements),
Expand Down Expand Up @@ -106,24 +100,7 @@ export async function act({
logger,
requestId,
variables,
}: {
action: string;
steps?: string;
domElements: string;
llmClient: LLMClient;
screenshot?: Buffer;
retries?: number;
logger: (message: { category?: string; message: string }) => void;
requestId: string;
variables?: Record<string, string>;
}): Promise<{
method: string;
element: number;
args: any[];
completed: boolean;
step: string;
why?: string;
} | null> {
}: ActParams): Promise<ActResult | null> {
const messages: ChatMessage[] = [
buildActSystemPrompt(),
buildActUserPrompt(action, steps, domElements, variables),
Expand Down
7 changes: 4 additions & 3 deletions lib/llm/AnthropicClient.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import Anthropic, { ClientOptions } from "@anthropic-ai/sdk";
import { LLMClient, ChatCompletionOptions } from "./LLMClient";
import { Message, MessageCreateParams } from "@anthropic-ai/sdk/resources";
import { zodToJsonSchema } from "zod-to-json-schema";
import { LogLine } from "../../types/log";
import { AvailableModel } from "../../types/model";
import { LLMCache } from "../cache/LLMCache";
import { AvailableModel, LogLine } from "../types";
import { Message, MessageCreateParams } from "@anthropic-ai/sdk/resources";
import { ChatCompletionOptions, LLMClient } from "./LLMClient";

export class AnthropicClient extends LLMClient {
private client: Anthropic;
Expand Down
Loading

0 comments on commit 8dff026

Please sign in to comment.