From b6007bbcda6225fbcd9fcb81ea8a2157dd3a1d31 Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: Tue, 21 Jan 2025 09:11:00 -0800 Subject: [PATCH] fix(core): Prevent cache misses from triggering model start callback runs twice (#7565) --- .../src/language_models/chat_models.ts | 120 ++++++++++++------ langchain-core/src/language_models/llms.ts | 98 ++++++++------ .../language_models/tests/chat_models.test.ts | 48 +++++++ .../src/language_models/tests/llms.test.ts | 30 +++++ .../src/tests/chat_models.int.test.ts | 6 + 5 files changed, 226 insertions(+), 76 deletions(-) diff --git a/langchain-core/src/language_models/chat_models.ts b/langchain-core/src/language_models/chat_models.ts index 9ca563a38a04..36feee110abe 100644 --- a/langchain-core/src/language_models/chat_models.ts +++ b/langchain-core/src/language_models/chat_models.ts @@ -9,6 +9,8 @@ import { coerceMessageLikeToMessage, AIMessageChunk, isAIMessageChunk, + isBaseMessage, + isAIMessage, } from "../messages/index.js"; import type { BasePromptValueInterface } from "../prompt_values.js"; import { @@ -343,41 +345,50 @@ export abstract class BaseChatModel< async _generateUncached( messages: BaseMessageLike[][], parsedOptions: this["ParsedCallOptions"], - handledOptions: RunnableConfig + handledOptions: RunnableConfig, + startedRunManagers?: CallbackManagerForLLMRun[] ): Promise { const baseMessages = messages.map((messageList) => messageList.map(coerceMessageLikeToMessage) ); - const inheritableMetadata = { - ...handledOptions.metadata, - ...this.getLsParams(parsedOptions), - }; - // create callback manager and start run - const callbackManager_ = await CallbackManager.configure( - handledOptions.callbacks, - this.callbacks, - handledOptions.tags, - this.tags, - inheritableMetadata, - this.metadata, - { verbose: this.verbose } - ); - const extra = { - options: parsedOptions, - invocation_params: this?.invocationParams(parsedOptions), - batch_size: 1, - }; - const runManagers = await callbackManager_?.handleChatModelStart( - this.toJSON(), - baseMessages, - handledOptions.runId, - undefined, - extra, - undefined, - undefined, - handledOptions.runName - ); + let runManagers: CallbackManagerForLLMRun[] | undefined; + if ( + startedRunManagers !== undefined && + startedRunManagers.length === baseMessages.length + ) { + runManagers = startedRunManagers; + } else { + const inheritableMetadata = { + ...handledOptions.metadata, + ...this.getLsParams(parsedOptions), + }; + // create callback manager and start run + const callbackManager_ = await CallbackManager.configure( + handledOptions.callbacks, + this.callbacks, + handledOptions.tags, + this.tags, + inheritableMetadata, + this.metadata, + { verbose: this.verbose } + ); + const extra = { + options: parsedOptions, + invocation_params: this?.invocationParams(parsedOptions), + batch_size: 1, + }; + runManagers = await callbackManager_?.handleChatModelStart( + this.toJSON(), + baseMessages, + handledOptions.runId, + undefined, + extra, + undefined, + undefined, + handledOptions.runName + ); + } const generations: ChatGeneration[][] = []; const llmOutputs: LLMResult["llmOutput"][] = []; // Even if stream is not explicitly called, check if model is implicitly @@ -511,7 +522,12 @@ export abstract class BaseChatModel< // eslint-disable-next-line @typescript-eslint/no-explicit-any parsedOptions: any; handledOptions: RunnableConfig; - }): Promise { + }): Promise< + LLMResult & { + missingPromptIndices: number[]; + startedRunManagers?: CallbackManagerForLLMRun[]; + } + > { const baseMessages = 
messages.map((messageList) => messageList.map(coerceMessageLikeToMessage) ); @@ -580,7 +596,26 @@ export abstract class BaseChatModel< cachedResults.map(async ({ result: promiseResult, runManager }, i) => { if (promiseResult.status === "fulfilled") { const result = promiseResult.value as Generation[]; - generations[i] = result; + generations[i] = result.map((result) => { + if ( + "message" in result && + isBaseMessage(result.message) && + isAIMessage(result.message) + ) { + // eslint-disable-next-line no-param-reassign + result.message.usage_metadata = { + input_tokens: 0, + output_tokens: 0, + total_tokens: 0, + }; + } + // eslint-disable-next-line no-param-reassign + result.generationInfo = { + ...result.generationInfo, + tokenUsage: {}, + }; + return result; + }); if (result.length) { await runManager?.handleLLMNewToken(result[0].text); } @@ -598,6 +633,7 @@ export abstract class BaseChatModel< const output = { generations, missingPromptIndices, + startedRunManagers: runManagers, }; // This defines RUN_KEY as a non-enumerable property on the output object @@ -650,20 +686,24 @@ export abstract class BaseChatModel< callOptions as CallOptions ); - const { generations, missingPromptIndices } = await this._generateCached({ - messages: baseMessages, - cache, - llmStringKey, - parsedOptions: callOptions, - handledOptions: runnableConfig, - }); + const { generations, missingPromptIndices, startedRunManagers } = + await this._generateCached({ + messages: baseMessages, + cache, + llmStringKey, + parsedOptions: callOptions, + handledOptions: runnableConfig, + }); let llmOutput = {}; if (missingPromptIndices.length > 0) { const results = await this._generateUncached( missingPromptIndices.map((i) => baseMessages[i]), callOptions, - runnableConfig + runnableConfig, + startedRunManagers !== undefined + ? 
missingPromptIndices.map((i) => startedRunManagers?.[i]) + : undefined ); await Promise.all( results.generations.map(async (generation, index) => { diff --git a/langchain-core/src/language_models/llms.ts b/langchain-core/src/language_models/llms.ts index ce75a52479be..63e18cb9a0b3 100644 --- a/langchain-core/src/language_models/llms.ts +++ b/langchain-core/src/language_models/llms.ts @@ -240,32 +240,41 @@ export abstract class BaseLLM< async _generateUncached( prompts: string[], parsedOptions: this["ParsedCallOptions"], - handledOptions: BaseCallbackConfig + handledOptions: BaseCallbackConfig, + startedRunManagers?: CallbackManagerForLLMRun[] ): Promise { - const callbackManager_ = await CallbackManager.configure( - handledOptions.callbacks, - this.callbacks, - handledOptions.tags, - this.tags, - handledOptions.metadata, - this.metadata, - { verbose: this.verbose } - ); - const extra = { - options: parsedOptions, - invocation_params: this?.invocationParams(parsedOptions), - batch_size: prompts.length, - }; - const runManagers = await callbackManager_?.handleLLMStart( - this.toJSON(), - prompts, - handledOptions.runId, - undefined, - extra, - undefined, - undefined, - handledOptions?.runName - ); + let runManagers: CallbackManagerForLLMRun[] | undefined; + if ( + startedRunManagers !== undefined && + startedRunManagers.length === prompts.length + ) { + runManagers = startedRunManagers; + } else { + const callbackManager_ = await CallbackManager.configure( + handledOptions.callbacks, + this.callbacks, + handledOptions.tags, + this.tags, + handledOptions.metadata, + this.metadata, + { verbose: this.verbose } + ); + const extra = { + options: parsedOptions, + invocation_params: this?.invocationParams(parsedOptions), + batch_size: prompts.length, + }; + runManagers = await callbackManager_?.handleLLMStart( + this.toJSON(), + prompts, + handledOptions.runId, + undefined, + extra, + undefined, + undefined, + handledOptions?.runName + ); + } // Even if stream is not explicitly called, check if model is implicitly // called from streamEvents() or streamLog() to get all streamed events. 
// Bail out if _streamResponseChunks not overridden @@ -346,7 +355,12 @@ export abstract class BaseLLM< parsedOptions: any; handledOptions: RunnableConfig; runId?: string; - }): Promise { + }): Promise< + LLMResult & { + missingPromptIndices: number[]; + startedRunManagers?: CallbackManagerForLLMRun[]; + } + > { const callbackManager_ = await CallbackManager.configure( handledOptions.callbacks, this.callbacks, @@ -401,7 +415,14 @@ export abstract class BaseLLM< cachedResults.map(async ({ result: promiseResult, runManager }, i) => { if (promiseResult.status === "fulfilled") { const result = promiseResult.value as Generation[]; - generations[i] = result; + generations[i] = result.map((result) => { + // eslint-disable-next-line no-param-reassign + result.generationInfo = { + ...result.generationInfo, + tokenUsage: {}, + }; + return result; + }); if (result.length) { await runManager?.handleLLMNewToken(result[0].text); } @@ -419,6 +440,7 @@ export abstract class BaseLLM< const output = { generations, missingPromptIndices, + startedRunManagers: runManagers, }; // This defines RUN_KEY as a non-enumerable property on the output object @@ -465,21 +487,25 @@ export abstract class BaseLLM< const llmStringKey = this._getSerializedCacheKeyParametersForCall( callOptions as CallOptions ); - const { generations, missingPromptIndices } = await this._generateCached({ - prompts, - cache, - llmStringKey, - parsedOptions: callOptions, - handledOptions: runnableConfig, - runId: runnableConfig.runId, - }); + const { generations, missingPromptIndices, startedRunManagers } = + await this._generateCached({ + prompts, + cache, + llmStringKey, + parsedOptions: callOptions, + handledOptions: runnableConfig, + runId: runnableConfig.runId, + }); let llmOutput = {}; if (missingPromptIndices.length > 0) { const results = await this._generateUncached( missingPromptIndices.map((i) => prompts[i]), callOptions, - runnableConfig + runnableConfig, + startedRunManagers !== undefined + ? 
missingPromptIndices.map((i) => startedRunManagers?.[i]) + : undefined ); await Promise.all( results.generations.map(async (generation, index) => { diff --git a/langchain-core/src/language_models/tests/chat_models.test.ts b/langchain-core/src/language_models/tests/chat_models.test.ts index f335d5edc40f..8598d7aa6cd3 100644 --- a/langchain-core/src/language_models/tests/chat_models.test.ts +++ b/langchain-core/src/language_models/tests/chat_models.test.ts @@ -287,6 +287,54 @@ test("Test ChatModel can cache complex messages", async () => { expect(cachedMsg.content).toEqual(JSON.stringify(contentToCache, null, 2)); }); +test("Test ChatModel with cache does not start multiple chat model runs", async () => { + const model = new FakeChatModel({ + cache: true, + }); + if (!model.cache) { + throw new Error("Cache not enabled"); + } + + const contentToCache = [ + { + type: "text", + text: "Hello there again!", + }, + ]; + const humanMessage = new HumanMessage({ + content: contentToCache, + }); + + const prompt = getBufferString([humanMessage]); + const llmKey = model._getSerializedCacheKeyParametersForCall({}); + + const value = await model.cache.lookup(prompt, llmKey); + expect(value).toBeNull(); + + // Invoke model to trigger cache update + const eventStream = model.streamEvents([humanMessage], { version: "v2" }); + + expect(await model.cache.lookup(prompt, llmKey)).toBeDefined(); + + const events = []; + for await (const event of eventStream) { + events.push(event); + } + expect(events.length).toEqual(2); + expect(events[0].event).toEqual("on_chat_model_start"); + expect(events[1].event).toEqual("on_chat_model_end"); + + const eventStream2 = model.streamEvents([humanMessage], { version: "v2" }); + + const events2 = []; + for await (const event of eventStream2) { + events2.push(event); + } + expect(events2.length).toEqual(2); + expect(events2[0].event).toEqual("on_chat_model_start"); + expect(events2[1].event).toEqual("on_chat_model_end"); +}); + test("Test ChatModel can emit a custom event", async () => { const model = new FakeListChatModel({ responses: ["hi"], diff --git a/langchain-core/src/language_models/tests/llms.test.ts b/langchain-core/src/language_models/tests/llms.test.ts index 54b56e42f7d8..f1cf453bfc75 100644 --- a/langchain-core/src/language_models/tests/llms.test.ts +++ b/langchain-core/src/language_models/tests/llms.test.ts @@ -42,6 +42,36 @@ test("Test FakeLLM uses callbacks with a cache", async () => { expect(response2).toEqual(acc); }); +test("Test LLM with cache does not start multiple LLM runs", async () => { + const model = new FakeLLM({ + cache: true, + }); + if (!model.cache) { + throw new Error("Cache not enabled"); + } + + // Invoke model to trigger cache update + const eventStream = model.streamEvents("Hello there!", { version: "v2" }); + + const events = []; + for await (const event of eventStream) { + events.push(event); + } + expect(events.length).toEqual(2); + expect(events[0].event).toEqual("on_llm_start"); + expect(events[1].event).toEqual("on_llm_end"); + + const eventStream2 = model.streamEvents("Hello there!", { version: "v2" }); + + const events2 = []; + for await (const event of eventStream2) { + events2.push(event); + } + expect(events2.length).toEqual(2); + expect(events2[0].event).toEqual("on_llm_start"); + expect(events2[1].event).toEqual("on_llm_end"); +}); + test("Test FakeStreamingLLM works when streaming through a prompt", async () => { const prompt = HumanMessagePromptTemplate.fromTemplate("hello there {name}"); const model = new 
FakeStreamingLLM({}); diff --git a/libs/langchain-openai/src/tests/chat_models.int.test.ts b/libs/langchain-openai/src/tests/chat_models.int.test.ts index a49c014d5d8a..a88e24cad56b 100644 --- a/libs/langchain-openai/src/tests/chat_models.int.test.ts +++ b/libs/langchain-openai/src/tests/chat_models.int.test.ts @@ -579,6 +579,12 @@ test("ChatOpenAI can cache generations", async () => { expect(lookupSpy).toHaveBeenCalledTimes(2); expect(updateSpy).toHaveBeenCalledTimes(2); + const res2 = await chat.generate([[message], [message]]); + expect(res2.generations.length).toBe(2); + + expect(lookupSpy).toHaveBeenCalledTimes(4); + expect(updateSpy).toHaveBeenCalledTimes(2); + lookupSpy.mockRestore(); updateSpy.mockRestore(); });
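
The fix above threads the run managers that _generateCached has already started (returned as startedRunManagers) into _generateUncached, which reuses them when their count matches the number of cache-missing prompts instead of calling handleChatModelStart / handleLLMStart a second time. The standalone TypeScript sketch below illustrates that pattern only; RunHandle, startRun, and SketchModel are hypothetical stand-ins invented for the illustration and are not LangChain APIs.

// Standalone sketch (not LangChain source): simplified model of the
// "reuse already-started run handles" pattern applied by the patch.
interface RunHandle {
  prompt: string;
  ended: boolean;
}

let startCallbackCount = 0; // counts "on start" callback invocations

function startRun(prompt: string): RunHandle {
  startCallbackCount += 1; // the callback we do NOT want fired twice per prompt
  return { prompt, ended: false };
}

class SketchModel {
  private cache = new Map<string, string>();

  // Cache phase: starts one run per prompt and reports which prompts missed.
  private generateCached(prompts: string[]) {
    const startedRuns = prompts.map((p) => startRun(p));
    const generations = prompts.map((p) => this.cache.get(p));
    const missingPromptIndices = generations
      .map((g, i) => (g === undefined ? i : -1))
      .filter((i) => i >= 0);
    return { generations, missingPromptIndices, startedRuns };
  }

  // Uncached phase: reuses handles handed over from the cache phase and only
  // starts fresh runs when none were supplied (the core of the fix).
  private generateUncached(prompts: string[], startedRuns?: RunHandle[]) {
    const runs =
      startedRuns !== undefined && startedRuns.length === prompts.length
        ? startedRuns
        : prompts.map((p) => startRun(p));
    return prompts.map((p, i) => {
      const output = `echo: ${p}`; // stand-in for a real model call
      this.cache.set(p, output);
      runs[i].ended = true;
      return output;
    });
  }

  generate(prompts: string[]): string[] {
    const { generations, missingPromptIndices, startedRuns } =
      this.generateCached(prompts);
    if (missingPromptIndices.length > 0) {
      const fresh = this.generateUncached(
        missingPromptIndices.map((i) => prompts[i]),
        missingPromptIndices.map((i) => startedRuns[i])
      );
      fresh.forEach((g, j) => {
        generations[missingPromptIndices[j]] = g;
      });
    }
    return generations as string[];
  }
}

// Two calls with the same prompt: the start callback fires once per call
// (2 total), rather than twice on the first, cache-missing call.
const model = new SketchModel();
model.generate(["Hello there!"]); // cache miss: run started once, then reused
model.generate(["Hello there!"]); // cache hit: run started once
console.log(startCallbackCount); // 2

The length check (startedRuns.length === prompts.length) mirrors the guard in the patch: handed-over run managers are reused only when there is exactly one per cache-missing prompt; otherwise the uncached path falls back to starting its own runs.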