fix(core): Prevent cache misses from triggering model start callback runs twice (#7565)
jacoblee93 authored Jan 21, 2025
1 parent c4f6122 commit b6007bb
Showing 5 changed files with 226 additions and 76 deletions.
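The gist of the change: when a batch contains both cache hits and misses, `_generateCached` now returns the callback run managers it has already started, and `_generateUncached` reuses them for the missed prompts instead of calling `handleChatModelStart` / `handleLLMStart` a second time, so start callbacks fire once per invocation. A minimal consumer-side illustration in the spirit of the test added in this commit (a sketch, not library documentation: it assumes an ESM context with top-level await, and that `FakeChatModel` is importable from `@langchain/core/utils/testing`):

```ts
import { HumanMessage } from "@langchain/core/messages";
import { FakeChatModel } from "@langchain/core/utils/testing";

// With a cache-enabled model, the second call below is served from the cache.
// Before this commit it still emitted a second on_chat_model_start; now each
// invocation emits exactly one start/end pair.
const model = new FakeChatModel({ cache: true });
const message = new HumanMessage("Hello there again!");

for (const attempt of [1, 2]) {
  const events: string[] = [];
  for await (const event of model.streamEvents([message], { version: "v2" })) {
    events.push(event.event);
  }
  // Expected on both attempts: ["on_chat_model_start", "on_chat_model_end"]
  console.log(`attempt ${attempt}:`, events);
}
```

On the first attempt the model is actually invoked and the result is written to the cache; the second attempt replays the cached generation but still reports a single start/end pair of events.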
120 changes: 80 additions & 40 deletions langchain-core/src/language_models/chat_models.ts
@@ -9,6 +9,8 @@ import {
coerceMessageLikeToMessage,
AIMessageChunk,
isAIMessageChunk,
isBaseMessage,
isAIMessage,
} from "../messages/index.js";
import type { BasePromptValueInterface } from "../prompt_values.js";
import {
@@ -343,41 +345,50 @@ export abstract class BaseChatModel<
async _generateUncached(
messages: BaseMessageLike[][],
parsedOptions: this["ParsedCallOptions"],
handledOptions: RunnableConfig
handledOptions: RunnableConfig,
startedRunManagers?: CallbackManagerForLLMRun[]
): Promise<LLMResult> {
const baseMessages = messages.map((messageList) =>
messageList.map(coerceMessageLikeToMessage)
);

const inheritableMetadata = {
...handledOptions.metadata,
...this.getLsParams(parsedOptions),
};
// create callback manager and start run
const callbackManager_ = await CallbackManager.configure(
handledOptions.callbacks,
this.callbacks,
handledOptions.tags,
this.tags,
inheritableMetadata,
this.metadata,
{ verbose: this.verbose }
);
const extra = {
options: parsedOptions,
invocation_params: this?.invocationParams(parsedOptions),
batch_size: 1,
};
const runManagers = await callbackManager_?.handleChatModelStart(
this.toJSON(),
baseMessages,
handledOptions.runId,
undefined,
extra,
undefined,
undefined,
handledOptions.runName
);
let runManagers: CallbackManagerForLLMRun[] | undefined;
if (
startedRunManagers !== undefined &&
startedRunManagers.length === baseMessages.length
) {
runManagers = startedRunManagers;
} else {
const inheritableMetadata = {
...handledOptions.metadata,
...this.getLsParams(parsedOptions),
};
// create callback manager and start run
const callbackManager_ = await CallbackManager.configure(
handledOptions.callbacks,
this.callbacks,
handledOptions.tags,
this.tags,
inheritableMetadata,
this.metadata,
{ verbose: this.verbose }
);
const extra = {
options: parsedOptions,
invocation_params: this?.invocationParams(parsedOptions),
batch_size: 1,
};
runManagers = await callbackManager_?.handleChatModelStart(
this.toJSON(),
baseMessages,
handledOptions.runId,
undefined,
extra,
undefined,
undefined,
handledOptions.runName
);
}
const generations: ChatGeneration[][] = [];
const llmOutputs: LLMResult["llmOutput"][] = [];
// Even if stream is not explicitly called, check if model is implicitly
@@ -511,7 +522,12 @@ export abstract class BaseChatModel<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
parsedOptions: any;
handledOptions: RunnableConfig;
}): Promise<LLMResult & { missingPromptIndices: number[] }> {
}): Promise<
LLMResult & {
missingPromptIndices: number[];
startedRunManagers?: CallbackManagerForLLMRun[];
}
> {
const baseMessages = messages.map((messageList) =>
messageList.map(coerceMessageLikeToMessage)
);
@@ -580,7 +596,26 @@ export abstract class BaseChatModel<
cachedResults.map(async ({ result: promiseResult, runManager }, i) => {
if (promiseResult.status === "fulfilled") {
const result = promiseResult.value as Generation[];
generations[i] = result;
generations[i] = result.map((result) => {
if (
"message" in result &&
isBaseMessage(result.message) &&
isAIMessage(result.message)
) {
// eslint-disable-next-line no-param-reassign
result.message.usage_metadata = {
input_tokens: 0,
output_tokens: 0,
total_tokens: 0,
};
}
// eslint-disable-next-line no-param-reassign
result.generationInfo = {
...result.generationInfo,
tokenUsage: {},
};
return result;
});
if (result.length) {
await runManager?.handleLLMNewToken(result[0].text);
}
@@ -598,6 +633,7 @@
const output = {
generations,
missingPromptIndices,
startedRunManagers: runManagers,
};

// This defines RUN_KEY as a non-enumerable property on the output object
@@ -650,20 +686,24 @@ export abstract class BaseChatModel<
callOptions as CallOptions
);

const { generations, missingPromptIndices } = await this._generateCached({
messages: baseMessages,
cache,
llmStringKey,
parsedOptions: callOptions,
handledOptions: runnableConfig,
});
const { generations, missingPromptIndices, startedRunManagers } =
await this._generateCached({
messages: baseMessages,
cache,
llmStringKey,
parsedOptions: callOptions,
handledOptions: runnableConfig,
});

let llmOutput = {};
if (missingPromptIndices.length > 0) {
const results = await this._generateUncached(
missingPromptIndices.map((i) => baseMessages[i]),
callOptions,
runnableConfig
runnableConfig,
startedRunManagers !== undefined
? missingPromptIndices.map((i) => startedRunManagers?.[i])
: undefined
);
await Promise.all(
results.generations.map(async (generation, index) => {
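Beyond reusing run managers, the chat-model cached path above also normalizes token accounting: a generation replayed from the cache consumed no tokens, so the cached `AIMessage`'s `usage_metadata` is zeroed and `generationInfo.tokenUsage` is reset. A standalone sketch of that normalization (the sample generation and its stored token counts are hypothetical, not taken from the diff):

```ts
import { AIMessage, isAIMessage } from "@langchain/core/messages";
import type { ChatGeneration } from "@langchain/core/outputs";

// Hypothetical entry as an earlier uncached run might have stored it in the cache.
const cached: ChatGeneration = {
  text: "Hello there again!",
  message: new AIMessage("Hello there again!"),
  generationInfo: { tokenUsage: { promptTokens: 12, completionTokens: 4 } },
};

// Same shape as the normalization in the diff: report zero usage for a cache hit.
if (isAIMessage(cached.message)) {
  cached.message.usage_metadata = {
    input_tokens: 0,
    output_tokens: 0,
    total_tokens: 0,
  };
}
cached.generationInfo = { ...cached.generationInfo, tokenUsage: {} };

console.log(cached.message, cached.generationInfo);
```

Only `AIMessage` carries `usage_metadata`, hence the `isAIMessage` guard; the plain-LLM path in the next file resets only `tokenUsage`.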
98 changes: 62 additions & 36 deletions langchain-core/src/language_models/llms.ts
@@ -240,32 +240,41 @@ export abstract class BaseLLM<
async _generateUncached(
prompts: string[],
parsedOptions: this["ParsedCallOptions"],
handledOptions: BaseCallbackConfig
handledOptions: BaseCallbackConfig,
startedRunManagers?: CallbackManagerForLLMRun[]
): Promise<LLMResult> {
const callbackManager_ = await CallbackManager.configure(
handledOptions.callbacks,
this.callbacks,
handledOptions.tags,
this.tags,
handledOptions.metadata,
this.metadata,
{ verbose: this.verbose }
);
const extra = {
options: parsedOptions,
invocation_params: this?.invocationParams(parsedOptions),
batch_size: prompts.length,
};
const runManagers = await callbackManager_?.handleLLMStart(
this.toJSON(),
prompts,
handledOptions.runId,
undefined,
extra,
undefined,
undefined,
handledOptions?.runName
);
let runManagers: CallbackManagerForLLMRun[] | undefined;
if (
startedRunManagers !== undefined &&
startedRunManagers.length === prompts.length
) {
runManagers = startedRunManagers;
} else {
const callbackManager_ = await CallbackManager.configure(
handledOptions.callbacks,
this.callbacks,
handledOptions.tags,
this.tags,
handledOptions.metadata,
this.metadata,
{ verbose: this.verbose }
);
const extra = {
options: parsedOptions,
invocation_params: this?.invocationParams(parsedOptions),
batch_size: prompts.length,
};
runManagers = await callbackManager_?.handleLLMStart(
this.toJSON(),
prompts,
handledOptions.runId,
undefined,
extra,
undefined,
undefined,
handledOptions?.runName
);
}
// Even if stream is not explicitly called, check if model is implicitly
// called from streamEvents() or streamLog() to get all streamed events.
// Bail out if _streamResponseChunks not overridden
@@ -346,7 +355,12 @@ export abstract class BaseLLM<
parsedOptions: any;
handledOptions: RunnableConfig;
runId?: string;
}): Promise<LLMResult & { missingPromptIndices: number[] }> {
}): Promise<
LLMResult & {
missingPromptIndices: number[];
startedRunManagers?: CallbackManagerForLLMRun[];
}
> {
const callbackManager_ = await CallbackManager.configure(
handledOptions.callbacks,
this.callbacks,
@@ -401,7 +415,14 @@ export abstract class BaseLLM<
cachedResults.map(async ({ result: promiseResult, runManager }, i) => {
if (promiseResult.status === "fulfilled") {
const result = promiseResult.value as Generation[];
generations[i] = result;
generations[i] = result.map((result) => {
// eslint-disable-next-line no-param-reassign
result.generationInfo = {
...result.generationInfo,
tokenUsage: {},
};
return result;
});
if (result.length) {
await runManager?.handleLLMNewToken(result[0].text);
}
@@ -419,6 +440,7 @@
const output = {
generations,
missingPromptIndices,
startedRunManagers: runManagers,
};

// This defines RUN_KEY as a non-enumerable property on the output object
@@ -465,21 +487,25 @@ export abstract class BaseLLM<
const llmStringKey = this._getSerializedCacheKeyParametersForCall(
callOptions as CallOptions
);
const { generations, missingPromptIndices } = await this._generateCached({
prompts,
cache,
llmStringKey,
parsedOptions: callOptions,
handledOptions: runnableConfig,
runId: runnableConfig.runId,
});
const { generations, missingPromptIndices, startedRunManagers } =
await this._generateCached({
prompts,
cache,
llmStringKey,
parsedOptions: callOptions,
handledOptions: runnableConfig,
runId: runnableConfig.runId,
});

let llmOutput = {};
if (missingPromptIndices.length > 0) {
const results = await this._generateUncached(
missingPromptIndices.map((i) => prompts[i]),
callOptions,
runnableConfig
runnableConfig,
startedRunManagers !== undefined
? missingPromptIndices.map((i) => startedRunManagers?.[i])
: undefined
);
await Promise.all(
results.generations.map(async (generation, index) => {
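The `BaseLLM` changes mirror the chat-model ones: `_generateCached` hands back the run managers it started, `_generateUncached` reuses them when their count matches the prompts, and cached `Generation`s get an empty `tokenUsage`. A dependency-free sketch of the handover and the guard, using a simplified stand-in type rather than the real `CallbackManagerForLLMRun`:

```ts
// Simplified stand-in for CallbackManagerForLLMRun; only what the guard needs.
type RunManager = { runId: string };

// Call-site mapping: forward only the managers for prompts that missed the cache.
function managersForCacheMisses(
  startedRunManagers: RunManager[] | undefined,
  missingPromptIndices: number[]
): RunManager[] | undefined {
  return startedRunManagers !== undefined
    ? missingPromptIndices.map((i) => startedRunManagers[i])
    : undefined;
}

// Guard inside the uncached path: reuse a run manager per prompt if the cached
// path handed one over for every prompt; otherwise start fresh runs as before.
function resolveRunManagers(
  started: RunManager[] | undefined,
  promptCount: number,
  startFresh: () => RunManager[]
): RunManager[] {
  if (started !== undefined && started.length === promptCount) {
    return started; // no second handleLLMStart / handleChatModelStart
  }
  return startFresh();
}

// Tiny demo with made-up run ids: three prompts, only the third missed the cache.
const started = [{ runId: "run-0" }, { runId: "run-1" }, { runId: "run-2" }];
const missing = [2];
const forwarded = managersForCacheMisses(started, missing);
const reused = resolveRunManagers(forwarded, missing.length, () => [{ runId: "fresh" }]);
console.log(reused); // [{ runId: "run-2" }] — the already-started run is reused
```

The length check keeps the old behavior as a fallback: if no managers were handed over, or not one per prompt, a fresh callback run is configured and started exactly as before.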
48 changes: 48 additions & 0 deletions langchain-core/src/language_models/tests/chat_models.test.ts
@@ -287,6 +287,54 @@ test("Test ChatModel can cache complex messages", async () => {
expect(cachedMsg.content).toEqual(JSON.stringify(contentToCache, null, 2));
});

test("Test ChatModel with cache does not start multiple chat model runs", async () => {
const model = new FakeChatModel({
cache: true,
});
if (!model.cache) {
throw new Error("Cache not enabled");
}

const contentToCache = [
{
type: "text",
text: "Hello there again!",
},
];
const humanMessage = new HumanMessage({
content: contentToCache,
});

const prompt = getBufferString([humanMessage]);
const llmKey = model._getSerializedCacheKeyParametersForCall({});

const value = await model.cache.lookup(prompt, llmKey);
expect(value).toBeNull();

// Invoke model to trigger cache update
const eventStream = model.streamEvents([humanMessage], { version: "v2" });

expect(await model.cache.lookup(prompt, llmKey)).toBeDefined();

const events = [];
for await (const event of eventStream) {
events.push(event);
}
expect(events.length).toEqual(2);
expect(events[0].event).toEqual("on_chat_model_start");
expect(events[1].event).toEqual("on_chat_model_end");

const eventStream2 = model.streamEvents([humanMessage], { version: "v2" });

const events2 = [];
for await (const event of eventStream2) {
events2.push(event);
}
expect(events2.length).toEqual(2);
expect(events2[0].event).toEqual("on_chat_model_start");
expect(events2[1].event).toEqual("on_chat_model_end");
});

test("Test ChatModel can emit a custom event", async () => {
const model = new FakeListChatModel({
responses: ["hi"],