Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mistralai[minor]: Add llms entrypoint, update chat model integration #5603

Merged
merged 13 commits into from
May 31, 2024
Binary file not shown.
2 changes: 1 addition & 1 deletion libs/langchain-mistralai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
"license": "MIT",
"dependencies": {
"@langchain/core": ">0.1.56 <0.3.0",
"@mistralai/mistralai": "^0.1.3",
"@mistralai/mistralai": "^0.4.0",
"uuid": "^9.0.0",
"zod": "^3.22.4",
"zod-to-json-schema": "^3.22.4"
Expand Down
98 changes: 53 additions & 45 deletions libs/langchain-mistralai/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ import {
ChatCompletionResponse,
Function as MistralAIFunction,
ToolCalls as MistralAIToolCalls,
ToolChoice as MistralAIToolChoice,
ResponseFormat,
ChatCompletionResponseChunk,
ToolType,
ChatRequest,
Tool as MistralAITool,
Message as MistralAIMessage,
} from "@mistralai/mistralai";
import {
MessageType,
Expand Down Expand Up @@ -44,7 +45,6 @@ import {
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { NewTokenIndices } from "@langchain/core/callbacks/base";
import { StructuredTool, StructuredToolInterface } from "@langchain/core/tools";
import { convertToOpenAITool } from "@langchain/core/utils/function_calling";
import { z } from "zod";
import {
type BaseLLMOutputParser,
Expand All @@ -70,40 +70,15 @@ interface TokenUsage {
totalTokens?: number;
}

type MistralAIInputMessage = {
role: string;
name?: string;
content: string | string[];
tool_calls?: MistralAIToolCalls[];
};
export type MistralAIToolChoice = "auto" | "any" | "none";

type MistralAIToolInput = { type: string; function: MistralAIFunction };

type MistralAIChatCompletionOptions = {
model: string;
messages: Array<{
role: string;
name?: string;
content: string | string[];
tool_calls?: MistralAIToolCalls[];
}>;
tools?: Array<MistralAIToolInput>;
temperature?: number;
maxTokens?: number;
topP?: number;
randomSeed?: number;
safeMode?: boolean;
safePrompt?: boolean;
toolChoice?: MistralAIToolChoice;
responseFormat?: ResponseFormat;
};

interface MistralAICallOptions
extends Omit<BaseLanguageModelCallOptions, "stop"> {
response_format?: {
type: "text" | "json_object";
};
tools: StructuredToolInterface[] | MistralAIToolInput[];
tools: StructuredToolInterface[] | MistralAIToolInput[] | MistralAITool[];
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to keep the old type in the union for backwards compatibility with existing callers.

tool_choice?: MistralAIToolChoice;
}

Expand Down Expand Up @@ -178,7 +153,7 @@ export interface ChatMistralAIInput extends BaseChatModelParams {

function convertMessagesToMistralMessages(
messages: Array<BaseMessage>
): Array<MistralAIInputMessage> {
): Array<MistralAIMessage> {
const getRole = (role: MessageType) => {
switch (role) {
case "human":
Expand Down Expand Up @@ -222,18 +197,28 @@ function convertMessagesToMistralMessages(
message.additional_kwargs.tool_calls;
return toolCalls?.map((toolCall) => ({
id: "null",
type: "function" as ToolType.function,
type: "function",
function: toolCall.function,
}));
};

return messages.map((message) => {
const toolCalls = getTools(message);
const content = toolCalls === undefined ? getContent(message.content) : "";
const role = getRole(message._getType());

if (role === "tool" && toolCalls && toolCalls.length > 0) {
return {
role: "tool",
content,
name: toolCalls[0].function.name,
tool_call_id: toolCalls[0].id,
};
}

return {
role: getRole(message._getType()),
role: role as "system" | "user" | "assistant",
content,
tool_calls: toolCalls,
};
});
}
Expand Down Expand Up @@ -270,7 +255,10 @@ function mistralAIResponseToChatMessage(
tool_calls: toolCalls,
invalid_tool_calls: invalidToolCalls,
additional_kwargs: {
tool_calls: rawToolCalls,
tool_calls: rawToolCalls.map((toolCall) => ({
...toolCall,
type: "function",
})),
},
});
}
Expand Down Expand Up @@ -350,8 +338,18 @@ function _convertDeltaToMessageChunk(delta: {

function _convertStructuredToolToMistralTool(
tools: StructuredToolInterface[]
): MistralAIToolInput[] {
return tools.map((tool) => convertToOpenAITool(tool) as MistralAIToolInput);
): MistralAITool[] {
return tools.map((tool) => {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New tool type from mistral is not exactly compatible with OAI

const description = tool.description ?? `Tool: ${tool.name}`;
return {
type: "function",
function: {
name: tool.name,
description,
parameters: zodToJsonSchema(tool.schema),
},
};
});
}

/**
Expand Down Expand Up @@ -439,17 +437,27 @@ export class ChatMistralAI<
*/
invocationParams(
options?: this["ParsedCallOptions"]
): Omit<MistralAIChatCompletionOptions, "messages"> {
): Omit<ChatRequest, "messages"> {
const { response_format, tools, tool_choice } = options ?? {};
const mistralAITools = tools
const mistralAITools: Array<MistralAITool> | undefined = tools
?.map((tool) => {
if ("lc_namespace" in tool) {
return _convertStructuredToolToMistralTool([tool]);
}
return tool;
if (!tool.function.description) {
return {
type: "function",
function: {
name: tool.function.name,
description: `Tool: ${tool.function.name}`,
parameters: tool.function.parameters,
},
} as MistralAITool;
}
return tool as MistralAITool;
})
.flat();
const params: Omit<MistralAIChatCompletionOptions, "messages"> = {
const params: Omit<ChatRequest, "messages"> = {
model: this.model,
tools: mistralAITools,
temperature: this.temperature,
Expand Down Expand Up @@ -484,21 +492,21 @@ export class ChatMistralAI<

/**
* Calls the MistralAI API with retry logic in case of failures.
* @param {MistralAIChatCompletionOptions} input The input to send to the MistralAI API.
* @param {ChatRequest} input The input to send to the MistralAI API.
* @returns {Promise<MistralAIChatCompletionResult | AsyncGenerator<MistralAIChatCompletionResult>>} The response from the MistralAI API.
*/
async completionWithRetry(
input: MistralAIChatCompletionOptions,
input: ChatRequest,
streaming: true
): Promise<AsyncGenerator<ChatCompletionResponseChunk>>;

async completionWithRetry(
input: MistralAIChatCompletionOptions,
input: ChatRequest,
streaming: false
): Promise<ChatCompletionResponse>;

async completionWithRetry(
input: MistralAIChatCompletionOptions,
input: ChatRequest,
streaming: boolean
): Promise<
ChatCompletionResponse | AsyncGenerator<ChatCompletionResponseChunk>
Expand Down
1 change: 1 addition & 0 deletions libs/langchain-mistralai/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
export * from "./chat_models.js";
export * from "./embeddings.js";
export * from "./llms.js";
Loading
Loading