From b7e42644d8c0c94b287bc69d038cf0f560e68c7c Mon Sep 17 00:00:00 2001 From: Brace Sproul Date: Wed, 26 Jun 2024 12:39:53 -0700 Subject: [PATCH] core[patch]: Add check for bind tools in structured prompt (#5882) * core[patch]: Add check for bind tools in structured promot * cr * fix test --- langchain-core/src/prompts/structured.ts | 18 ++++++++++++++++-- .../src/prompts/tests/structured.test.ts | 8 ++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/langchain-core/src/prompts/structured.ts b/langchain-core/src/prompts/structured.ts index 199f988f4642..e08bd4f4e976 100644 --- a/langchain-core/src/prompts/structured.ts +++ b/langchain-core/src/prompts/structured.ts @@ -1,3 +1,5 @@ +import { BaseLanguageModel } from "../language_models/base.js"; +import { BaseChatModel } from "../language_models/chat_models.js"; import { ChatPromptValueInterface } from "../prompt_values.js"; import { RunnableLike, @@ -22,7 +24,18 @@ function isWithStructuredOutput( typeof x === "object" && x != null && "withStructuredOutput" in x && - typeof x.withStructuredOutput === "function" + x.withStructuredOutput !== BaseLanguageModel.prototype.withStructuredOutput + ); +} + +function isBindTools(x: unknown): x is { + bindTools: (...arg: unknown[]) => Runnable; +} { + return ( + typeof x === "object" && + x != null && + "bindTools" in x && + x.bindTools !== BaseChatModel.prototype.bindTools ); } @@ -84,7 +97,8 @@ export class StructuredPrompt< if ( isRunnableBinding(coerceable) && - isWithStructuredOutput(coerceable.bound) + isWithStructuredOutput(coerceable.bound) && + isBindTools(coerceable.bound) ) { return super.pipe( coerceable.bound diff --git a/langchain-core/src/prompts/tests/structured.test.ts b/langchain-core/src/prompts/tests/structured.test.ts index 8dba1f417747..45b21aafa3b3 100644 --- a/langchain-core/src/prompts/tests/structured.test.ts +++ b/langchain-core/src/prompts/tests/structured.test.ts @@ -4,6 +4,7 @@ import { StructuredOutputMethodParams, StructuredOutputMethodOptions, BaseLanguageModelInput, + ToolDefinition, } from "../../language_models/base.js"; import { BaseMessage } from "../../messages/index.js"; import { Runnable, RunnableLambda } from "../../runnables/base.js"; @@ -11,8 +12,15 @@ import { RunnableConfig } from "../../runnables/config.js"; import { FakeListChatModel } from "../../utils/testing/index.js"; import { StructuredPrompt } from "../structured.js"; import { load } from "../../load/index.js"; +import { StructuredToolInterface } from "../../tools.js"; class FakeStructuredChatModel extends FakeListChatModel { + override bindTools( + _tools: (StructuredToolInterface | ToolDefinition | Record)[] + ): Runnable { + return this.bind({}); + } + withStructuredOutput< RunOutput extends Record = Record >(