Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[community]: Add chat deployment to IBM chat class #7633

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion libs/langchain-community/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
"@gradientai/nodejs-sdk": "^1.2.0",
"@huggingface/inference": "^2.6.4",
"@huggingface/transformers": "^3.2.3",
"@ibm-cloud/watsonx-ai": "^1.3.0",
"@ibm-cloud/watsonx-ai": "^1.4.0",
"@jest/globals": "^29.5.0",
"@lancedb/lancedb": "^0.13.0",
"@langchain/core": "workspace:*",
Expand Down
179 changes: 121 additions & 58 deletions libs/langchain-community/src/chat_models/ibm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import {
} from "@langchain/core/outputs";
import { AsyncCaller } from "@langchain/core/utils/async_caller";
import {
DeploymentsTextChatParams,
RequestCallbacks,
TextChatMessagesTextChatMessageAssistant,
TextChatParameterTools,
Expand Down Expand Up @@ -65,7 +66,13 @@ import {
import { isZodSchema } from "@langchain/core/utils/types";
import { zodToJsonSchema } from "zod-to-json-schema";
import { NewTokenIndices } from "@langchain/core/callbacks/base";
import { WatsonxAuth, WatsonxParams } from "../types/ibm.js";
import {
Neverify,
WatsonxAuth,
WatsonxChatBasicOptions,
WatsonxDeployedParams,
WatsonxParams,
} from "../types/ibm.js";
import {
_convertToolCallIdToMistralCompatible,
authenticateAndSetInstance,
Expand All @@ -80,27 +87,43 @@ export interface WatsonxDeltaStream {
}

export interface WatsonxCallParams
extends Partial<Omit<TextChatParams, "modelId" | "toolChoice">> {
maxRetries?: number;
watsonxCallbacks?: RequestCallbacks;
}
extends Partial<
Omit<TextChatParams, "modelId" | "toolChoice" | "messages" | "headers">
> {}

export interface WatsonxCallDeployedParams extends DeploymentsTextChatParams {}

export interface WatsonxCallOptionsChat
extends Omit<BaseChatModelCallOptions, "stop">,
WatsonxCallParams {
WatsonxCallParams,
WatsonxChatBasicOptions {
promptIndex?: number;
tool_choice?: TextChatParameterTools | string | "auto" | "any";
watsonxCallbacks?: RequestCallbacks;
}

export interface WatsonxCallOptionsDeployedChat
extends WatsonxCallDeployedParams,
WatsonxChatBasicOptions {
promptIndex?: number;
}

type ChatWatsonxToolType = BindToolsInput | TextChatParameterTools;

export interface ChatWatsonxInput
extends BaseChatModelParams,
WatsonxParams,
WatsonxCallParams {
streaming?: boolean;
}
WatsonxCallParams,
Neverify<DeploymentsTextChatParams> {}

export interface ChatWatsonxDeployedInput
extends BaseChatModelParams,
WatsonxDeployedParams,
Neverify<TextChatParams> {}

export type ChatWatsonxConstructor = BaseChatModelParams &
Partial<WatsonxParams> &
WatsonxDeployedParams &
WatsonxCallParams;
function _convertToValidToolId(model: string, tool_call_id: string) {
if (model.startsWith("mistralai"))
return _convertToolCallIdToMistralCompatible(tool_call_id);
Expand Down Expand Up @@ -335,10 +358,12 @@ function _convertToolChoiceToWatsonxToolChoice(
}

export class ChatWatsonx<
CallOptions extends WatsonxCallOptionsChat = WatsonxCallOptionsChat
CallOptions extends WatsonxCallOptionsChat =
| WatsonxCallOptionsChat
| WatsonxCallOptionsDeployedChat
>
extends BaseChatModel<CallOptions>
implements ChatWatsonxInput
implements ChatWatsonxConstructor
{
static lc_name() {
return "ChatWatsonx";
Expand Down Expand Up @@ -380,8 +405,8 @@ export class ChatWatsonx<
ls_provider: "watsonx",
ls_model_name: this.model,
ls_model_type: "chat",
ls_temperature: params.temperature ?? undefined,
ls_max_tokens: params.maxTokens ?? undefined,
ls_temperature: params?.temperature ?? undefined,
ls_max_tokens: params?.maxTokens ?? undefined,
};
}

Expand All @@ -399,6 +424,8 @@ export class ChatWatsonx<

projectId?: string;

idOrName?: string;

frequencyPenalty?: number;

logprobs?: boolean;
Expand All @@ -425,37 +452,44 @@ export class ChatWatsonx<

watsonxCallbacks?: RequestCallbacks;

constructor(fields: ChatWatsonxInput & WatsonxAuth) {
constructor(
fields: (ChatWatsonxInput | ChatWatsonxDeployedInput) & WatsonxAuth
) {
super(fields);
if (
(fields.projectId && fields.spaceId) ||
(fields.idOrName && fields.projectId) ||
(fields.spaceId && fields.idOrName)
("projectId" in fields && "spaceId" in fields) ||
("projectId" in fields && "idOrName" in fields) ||
("spaceId" in fields && "idOrName" in fields)
)
throw new Error("Maximum 1 id type can be specified per instance");

if (!fields.projectId && !fields.spaceId && !fields.idOrName)
if (!("projectId" in fields || "spaceId" in fields || "idOrName" in fields))
throw new Error(
"No id specified! At least id of 1 type has to be specified"
);
this.projectId = fields?.projectId;
this.spaceId = fields?.spaceId;
this.temperature = fields?.temperature;
this.maxRetries = fields?.maxRetries || this.maxRetries;
this.maxConcurrency = fields?.maxConcurrency;
this.frequencyPenalty = fields?.frequencyPenalty;
this.topLogprobs = fields?.topLogprobs;
this.maxTokens = fields?.maxTokens ?? this.maxTokens;
this.presencePenalty = fields?.presencePenalty;
this.topP = fields?.topP;
this.timeLimit = fields?.timeLimit;
this.responseFormat = fields?.responseFormat ?? this.responseFormat;

if ("model" in fields) {
this.projectId = fields?.projectId;
this.spaceId = fields?.spaceId;
this.temperature = fields?.temperature;
this.maxRetries = fields?.maxRetries || this.maxRetries;
this.maxConcurrency = fields?.maxConcurrency;
this.frequencyPenalty = fields?.frequencyPenalty;
this.topLogprobs = fields?.topLogprobs;
this.maxTokens = fields?.maxTokens ?? this.maxTokens;
this.presencePenalty = fields?.presencePenalty;
this.topP = fields?.topP;
this.timeLimit = fields?.timeLimit;
this.responseFormat = fields?.responseFormat ?? this.responseFormat;
this.streaming = fields?.streaming ?? this.streaming;
this.n = fields?.n ?? this.n;
this.model = fields?.model ?? this.model;
} else this.idOrName = fields?.idOrName;

this.watsonxCallbacks = fields?.watsonxCallbacks ?? this.watsonxCallbacks;
this.serviceUrl = fields?.serviceUrl;
this.streaming = fields?.streaming ?? this.streaming;
this.n = fields?.n ?? this.n;
this.model = fields?.model ?? this.model;
this.version = fields?.version ?? this.version;
this.watsonxCallbacks = fields?.watsonxCallbacks ?? this.watsonxCallbacks;

const {
watsonxAIApikey,
watsonxAIAuthType,
Expand Down Expand Up @@ -486,6 +520,11 @@ export class ChatWatsonx<
}

invocationParams(options: this["ParsedCallOptions"]) {
const { signal, promptIndex, ...rest } = options;
if (this.idOrName && Object.keys(rest).length > 0)
throw new Error("Options cannot be provided to a deployed model");
if (this.idOrName) return undefined;

const params = {
maxTokens: options.maxTokens ?? this.maxTokens,
temperature: options?.temperature ?? this.temperature,
Expand Down Expand Up @@ -521,10 +560,16 @@ export class ChatWatsonx<
} as CallOptions);
}

scopeId() {
scopeId():
| { idOrName: string }
| { projectId: string; modelId: string }
| { spaceId: string; modelId: string } {
if (this.projectId)
return { projectId: this.projectId, modelId: this.model };
else return { spaceId: this.spaceId, modelId: this.model };
else if (this.spaceId)
return { spaceId: this.spaceId, modelId: this.model };
else if (this.idOrName) return { idOrName: this.idOrName };
else throw new Error("No scope id provided");
}

async completionWithRetry<T>(
Expand Down Expand Up @@ -595,23 +640,30 @@ export class ChatWatsonx<
.map(([_, value]) => value);
return { generations, llmOutput: { tokenUsage } };
} else {
const params = {
...this.invocationParams(options),
...this.scopeId(),
};
const params = this.invocationParams(options);
const scopeId = this.scopeId();
const watsonxCallbacks = this.invocationCallbacks(options);
const watsonxMessages = _convertMessagesToWatsonxMessages(
messages,
this.model
);
const callback = () =>
this.service.textChat(
{
...params,
messages: watsonxMessages,
},
watsonxCallbacks
);
"idOrName" in scopeId
? this.service.deploymentsTextChat(
{
...scopeId,
messages: watsonxMessages,
},
watsonxCallbacks
)
: this.service.textChat(
{
...params,
...scopeId,
messages: watsonxMessages,
},
watsonxCallbacks
);
const { result } = await this.completionWithRetry(callback, options);
const generations: ChatGeneration[] = [];
for (const part of result.choices) {
Expand Down Expand Up @@ -646,21 +698,33 @@ export class ChatWatsonx<
options: this["ParsedCallOptions"],
_runManager?: CallbackManagerForLLMRun
): AsyncGenerator<ChatGenerationChunk> {
const params = { ...this.invocationParams(options), ...this.scopeId() };
const params = this.invocationParams(options);
const scopeId = this.scopeId();
const watsonxMessages = _convertMessagesToWatsonxMessages(
messages,
this.model
);
const watsonxCallbacks = this.invocationCallbacks(options);
const callback = () =>
this.service.textChatStream(
{
...params,
messages: watsonxMessages,
returnObject: true,
},
watsonxCallbacks
);
"idOrName" in scopeId
? this.service.deploymentsTextChatStream(
{
...scopeId,
messages: watsonxMessages,
returnObject: true,
},
watsonxCallbacks
)
: this.service.textChatStream(
{
...params,
...scopeId,
messages: watsonxMessages,
returnObject: true,
},
watsonxCallbacks
);

const stream = await this.completionWithRetry(callback, options);
let defaultRole;
let usage: TextChatUsage | undefined;
Expand Down Expand Up @@ -707,7 +771,6 @@ export class ChatWatsonx<
if (message === null || (!delta.content && !delta.tool_calls)) {
continue;
}

const generationChunk = new ChatGenerationChunk({
message,
text: delta.content ?? "",
Expand Down
35 changes: 33 additions & 2 deletions libs/langchain-community/src/chat_models/tests/ibm.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import { ChatWatsonx } from "../ibm.js";

describe("Tests for chat", () => {
describe("Test ChatWatsonx invoke and generate", () => {
test("Basic invoke", async () => {
test("Basic invoke with projectId", async () => {
const service = new ChatWatsonx({
model: "mistralai/mistral-large",
version: "2024-05-31",
Expand All @@ -26,6 +26,37 @@ describe("Tests for chat", () => {
const res = await service.invoke("Print hello world");
expect(res).toBeInstanceOf(AIMessage);
});
test("Basic invoke with spaceId", async () => {
const service = new ChatWatsonx({
model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
spaceId: process.env.WATSONX_AI_SPACE_ID ?? "testString",
});
const res = await service.invoke("Print hello world");
expect(res).toBeInstanceOf(AIMessage);
});
test("Basic invoke with idOrName", async () => {
const service = new ChatWatsonx({
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
idOrName: process.env.WATSONX_AI_ID_OR_NAME ?? "testString",
});
const res = await service.invoke("Print hello world");
expect(res).toBeInstanceOf(AIMessage);
});
test("Invalid invoke with idOrName and options as second argument", async () => {
const service = new ChatWatsonx({
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
idOrName: process.env.WATSONX_AI_ID_OR_NAME ?? "testString",
});
await expect(() =>
service.invoke("Print hello world", {
maxTokens: 100,
})
).rejects.toThrow("Options cannot be provided to a deployed model");
});
test("Basic generate", async () => {
const service = new ChatWatsonx({
model: "mistralai/mistral-large",
Expand Down Expand Up @@ -710,7 +741,7 @@ describe("Tests for chat", () => {

test("Schema with zod and stream", async () => {
const service = new ChatWatsonx({
model: "mistralai/mistral-large",
model: "meta-llama/llama-3-1-70b-instruct",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
Expand Down
Loading