diff --git a/libs/langchain-community/package.json b/libs/langchain-community/package.json index 6b59f23ee454..8467575c9c46 100644 --- a/libs/langchain-community/package.json +++ b/libs/langchain-community/package.json @@ -79,7 +79,7 @@ "@gradientai/nodejs-sdk": "^1.2.0", "@huggingface/inference": "^2.6.4", "@huggingface/transformers": "^3.2.3", - "@ibm-cloud/watsonx-ai": "^1.3.0", + "@ibm-cloud/watsonx-ai": "^1.4.0", "@jest/globals": "^29.5.0", "@lancedb/lancedb": "^0.13.0", "@langchain/core": "workspace:*", diff --git a/libs/langchain-community/src/chat_models/ibm.ts b/libs/langchain-community/src/chat_models/ibm.ts index 992419649fb1..17e80922a8db 100644 --- a/libs/langchain-community/src/chat_models/ibm.ts +++ b/libs/langchain-community/src/chat_models/ibm.ts @@ -33,6 +33,7 @@ import { } from "@langchain/core/outputs"; import { AsyncCaller } from "@langchain/core/utils/async_caller"; import { + DeploymentsTextChatParams, RequestCallbacks, TextChatMessagesTextChatMessageAssistant, TextChatParameterTools, @@ -65,7 +66,13 @@ import { import { isZodSchema } from "@langchain/core/utils/types"; import { zodToJsonSchema } from "zod-to-json-schema"; import { NewTokenIndices } from "@langchain/core/callbacks/base"; -import { WatsonxAuth, WatsonxParams } from "../types/ibm.js"; +import { + Neverify, + WatsonxAuth, + WatsonxChatBasicOptions, + WatsonxDeployedParams, + WatsonxParams, +} from "../types/ibm.js"; import { _convertToolCallIdToMistralCompatible, authenticateAndSetInstance, @@ -80,16 +87,24 @@ export interface WatsonxDeltaStream { } export interface WatsonxCallParams - extends Partial> { - maxRetries?: number; - watsonxCallbacks?: RequestCallbacks; -} + extends Partial< + Omit + > {} + +export interface WatsonxCallDeployedParams extends DeploymentsTextChatParams {} + export interface WatsonxCallOptionsChat extends Omit, - WatsonxCallParams { + WatsonxCallParams, + WatsonxChatBasicOptions { promptIndex?: number; tool_choice?: TextChatParameterTools | string | "auto" | "any"; - watsonxCallbacks?: RequestCallbacks; +} + +export interface WatsonxCallOptionsDeployedChat + extends WatsonxCallDeployedParams, + WatsonxChatBasicOptions { + promptIndex?: number; } type ChatWatsonxToolType = BindToolsInput | TextChatParameterTools; @@ -97,10 +112,18 @@ type ChatWatsonxToolType = BindToolsInput | TextChatParameterTools; export interface ChatWatsonxInput extends BaseChatModelParams, WatsonxParams, - WatsonxCallParams { - streaming?: boolean; -} + WatsonxCallParams, + Neverify {} + +export interface ChatWatsonxDeployedInput + extends BaseChatModelParams, + WatsonxDeployedParams, + Neverify {} +export type ChatWatsonxConstructor = BaseChatModelParams & + Partial & + WatsonxDeployedParams & + WatsonxCallParams; function _convertToValidToolId(model: string, tool_call_id: string) { if (model.startsWith("mistralai")) return _convertToolCallIdToMistralCompatible(tool_call_id); @@ -335,10 +358,12 @@ function _convertToolChoiceToWatsonxToolChoice( } export class ChatWatsonx< - CallOptions extends WatsonxCallOptionsChat = WatsonxCallOptionsChat + CallOptions extends WatsonxCallOptionsChat = + | WatsonxCallOptionsChat + | WatsonxCallOptionsDeployedChat > extends BaseChatModel - implements ChatWatsonxInput + implements ChatWatsonxConstructor { static lc_name() { return "ChatWatsonx"; @@ -380,8 +405,8 @@ export class ChatWatsonx< ls_provider: "watsonx", ls_model_name: this.model, ls_model_type: "chat", - ls_temperature: params.temperature ?? undefined, - ls_max_tokens: params.maxTokens ?? undefined, + ls_temperature: params?.temperature ?? undefined, + ls_max_tokens: params?.maxTokens ?? undefined, }; } @@ -399,6 +424,8 @@ export class ChatWatsonx< projectId?: string; + idOrName?: string; + frequencyPenalty?: number; logprobs?: boolean; @@ -425,37 +452,44 @@ export class ChatWatsonx< watsonxCallbacks?: RequestCallbacks; - constructor(fields: ChatWatsonxInput & WatsonxAuth) { + constructor( + fields: (ChatWatsonxInput | ChatWatsonxDeployedInput) & WatsonxAuth + ) { super(fields); if ( - (fields.projectId && fields.spaceId) || - (fields.idOrName && fields.projectId) || - (fields.spaceId && fields.idOrName) + ("projectId" in fields && "spaceId" in fields) || + ("projectId" in fields && "idOrName" in fields) || + ("spaceId" in fields && "idOrName" in fields) ) throw new Error("Maximum 1 id type can be specified per instance"); - if (!fields.projectId && !fields.spaceId && !fields.idOrName) + if (!("projectId" in fields || "spaceId" in fields || "idOrName" in fields)) throw new Error( "No id specified! At least id of 1 type has to be specified" ); - this.projectId = fields?.projectId; - this.spaceId = fields?.spaceId; - this.temperature = fields?.temperature; - this.maxRetries = fields?.maxRetries || this.maxRetries; - this.maxConcurrency = fields?.maxConcurrency; - this.frequencyPenalty = fields?.frequencyPenalty; - this.topLogprobs = fields?.topLogprobs; - this.maxTokens = fields?.maxTokens ?? this.maxTokens; - this.presencePenalty = fields?.presencePenalty; - this.topP = fields?.topP; - this.timeLimit = fields?.timeLimit; - this.responseFormat = fields?.responseFormat ?? this.responseFormat; + + if ("model" in fields) { + this.projectId = fields?.projectId; + this.spaceId = fields?.spaceId; + this.temperature = fields?.temperature; + this.maxRetries = fields?.maxRetries || this.maxRetries; + this.maxConcurrency = fields?.maxConcurrency; + this.frequencyPenalty = fields?.frequencyPenalty; + this.topLogprobs = fields?.topLogprobs; + this.maxTokens = fields?.maxTokens ?? this.maxTokens; + this.presencePenalty = fields?.presencePenalty; + this.topP = fields?.topP; + this.timeLimit = fields?.timeLimit; + this.responseFormat = fields?.responseFormat ?? this.responseFormat; + this.streaming = fields?.streaming ?? this.streaming; + this.n = fields?.n ?? this.n; + this.model = fields?.model ?? this.model; + } else this.idOrName = fields?.idOrName; + + this.watsonxCallbacks = fields?.watsonxCallbacks ?? this.watsonxCallbacks; this.serviceUrl = fields?.serviceUrl; - this.streaming = fields?.streaming ?? this.streaming; - this.n = fields?.n ?? this.n; - this.model = fields?.model ?? this.model; this.version = fields?.version ?? this.version; - this.watsonxCallbacks = fields?.watsonxCallbacks ?? this.watsonxCallbacks; + const { watsonxAIApikey, watsonxAIAuthType, @@ -486,6 +520,11 @@ export class ChatWatsonx< } invocationParams(options: this["ParsedCallOptions"]) { + const { signal, promptIndex, ...rest } = options; + if (this.idOrName && Object.keys(rest).length > 0) + throw new Error("Options cannot be provided to a deployed model"); + if (this.idOrName) return undefined; + const params = { maxTokens: options.maxTokens ?? this.maxTokens, temperature: options?.temperature ?? this.temperature, @@ -521,10 +560,16 @@ export class ChatWatsonx< } as CallOptions); } - scopeId() { + scopeId(): + | { idOrName: string } + | { projectId: string; modelId: string } + | { spaceId: string; modelId: string } { if (this.projectId) return { projectId: this.projectId, modelId: this.model }; - else return { spaceId: this.spaceId, modelId: this.model }; + else if (this.spaceId) + return { spaceId: this.spaceId, modelId: this.model }; + else if (this.idOrName) return { idOrName: this.idOrName }; + else throw new Error("No scope id provided"); } async completionWithRetry( @@ -595,23 +640,30 @@ export class ChatWatsonx< .map(([_, value]) => value); return { generations, llmOutput: { tokenUsage } }; } else { - const params = { - ...this.invocationParams(options), - ...this.scopeId(), - }; + const params = this.invocationParams(options); + const scopeId = this.scopeId(); const watsonxCallbacks = this.invocationCallbacks(options); const watsonxMessages = _convertMessagesToWatsonxMessages( messages, this.model ); const callback = () => - this.service.textChat( - { - ...params, - messages: watsonxMessages, - }, - watsonxCallbacks - ); + "idOrName" in scopeId + ? this.service.deploymentsTextChat( + { + ...scopeId, + messages: watsonxMessages, + }, + watsonxCallbacks + ) + : this.service.textChat( + { + ...params, + ...scopeId, + messages: watsonxMessages, + }, + watsonxCallbacks + ); const { result } = await this.completionWithRetry(callback, options); const generations: ChatGeneration[] = []; for (const part of result.choices) { @@ -646,21 +698,33 @@ export class ChatWatsonx< options: this["ParsedCallOptions"], _runManager?: CallbackManagerForLLMRun ): AsyncGenerator { - const params = { ...this.invocationParams(options), ...this.scopeId() }; + const params = this.invocationParams(options); + const scopeId = this.scopeId(); const watsonxMessages = _convertMessagesToWatsonxMessages( messages, this.model ); const watsonxCallbacks = this.invocationCallbacks(options); const callback = () => - this.service.textChatStream( - { - ...params, - messages: watsonxMessages, - returnObject: true, - }, - watsonxCallbacks - ); + "idOrName" in scopeId + ? this.service.deploymentsTextChatStream( + { + ...scopeId, + messages: watsonxMessages, + returnObject: true, + }, + watsonxCallbacks + ) + : this.service.textChatStream( + { + ...params, + ...scopeId, + messages: watsonxMessages, + returnObject: true, + }, + watsonxCallbacks + ); + const stream = await this.completionWithRetry(callback, options); let defaultRole; let usage: TextChatUsage | undefined; @@ -707,7 +771,6 @@ export class ChatWatsonx< if (message === null || (!delta.content && !delta.tool_calls)) { continue; } - const generationChunk = new ChatGenerationChunk({ message, text: delta.content ?? "", diff --git a/libs/langchain-community/src/chat_models/tests/ibm.int.test.ts b/libs/langchain-community/src/chat_models/tests/ibm.int.test.ts index ae47345a1add..1cdc836ba9c8 100644 --- a/libs/langchain-community/src/chat_models/tests/ibm.int.test.ts +++ b/libs/langchain-community/src/chat_models/tests/ibm.int.test.ts @@ -16,7 +16,7 @@ import { ChatWatsonx } from "../ibm.js"; describe("Tests for chat", () => { describe("Test ChatWatsonx invoke and generate", () => { - test("Basic invoke", async () => { + test("Basic invoke with projectId", async () => { const service = new ChatWatsonx({ model: "mistralai/mistral-large", version: "2024-05-31", @@ -26,6 +26,37 @@ describe("Tests for chat", () => { const res = await service.invoke("Print hello world"); expect(res).toBeInstanceOf(AIMessage); }); + test("Basic invoke with spaceId", async () => { + const service = new ChatWatsonx({ + model: "mistralai/mistral-large", + version: "2024-05-31", + serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", + spaceId: process.env.WATSONX_AI_SPACE_ID ?? "testString", + }); + const res = await service.invoke("Print hello world"); + expect(res).toBeInstanceOf(AIMessage); + }); + test("Basic invoke with idOrName", async () => { + const service = new ChatWatsonx({ + version: "2024-05-31", + serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", + idOrName: process.env.WATSONX_AI_ID_OR_NAME ?? "testString", + }); + const res = await service.invoke("Print hello world"); + expect(res).toBeInstanceOf(AIMessage); + }); + test("Invalide invoke with idOrName and options as second argument", async () => { + const service = new ChatWatsonx({ + version: "2024-05-31", + serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", + idOrName: process.env.WATSONX_AI_ID_OR_NAME ?? "testString", + }); + await expect(() => + service.invoke("Print hello world", { + maxTokens: 100, + }) + ).rejects.toThrow("Options cannot be provided to a deployed model"); + }); test("Basic generate", async () => { const service = new ChatWatsonx({ model: "mistralai/mistral-large", @@ -710,7 +741,7 @@ describe("Tests for chat", () => { test("Schema with zod and stream", async () => { const service = new ChatWatsonx({ - model: "mistralai/mistral-large", + model: "meta-llama/llama-3-1-70b-instruct", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", diff --git a/libs/langchain-community/src/chat_models/tests/ibm.test.ts b/libs/langchain-community/src/chat_models/tests/ibm.test.ts index f52a689f6755..b35b59d8ccbd 100644 --- a/libs/langchain-community/src/chat_models/tests/ibm.test.ts +++ b/libs/langchain-community/src/chat_models/tests/ibm.test.ts @@ -1,7 +1,12 @@ /* eslint-disable no-process-env */ /* eslint-disable @typescript-eslint/no-explicit-any */ import WatsonxAiMlVml_v1 from "@ibm-cloud/watsonx-ai/dist/watsonx-ai-ml/vml_v1.js"; -import { ChatWatsonx, ChatWatsonxInput, WatsonxCallParams } from "../ibm.js"; +import { + ChatWatsonx, + ChatWatsonxConstructor, + ChatWatsonxInput, + WatsonxCallParams, +} from "../ibm.js"; import { authenticateAndSetInstance } from "../../utils/ibm.js"; const fakeAuthProp = { @@ -13,7 +18,7 @@ export function getKey(key: K): K { } export const testProperties = ( instance: ChatWatsonx, - testProps: ChatWatsonxInput, + testProps: ChatWatsonxConstructor, notExTestProps?: { [key: string]: any } ) => { const checkProperty = ( @@ -24,13 +29,19 @@ export const testProperties = ( Object.keys(testProps).forEach((key) => { const keys = getKey(key); type Type = Pick; - if (typeof testProps[key as keyof T] === "object") - checkProperty(testProps[key as keyof T], instance[key], existing); + checkProperty( + testProps[key as keyof T], + instance[key as keyof typeof instance], + existing + ); else { if (existing) - expect(instance[key as keyof T]).toBe(testProps[key as keyof T]); - else if (instance) expect(instance[key as keyof T]).toBeUndefined(); + expect(instance[key as keyof typeof instance]).toBe( + testProps[key as keyof T] + ); + else if (instance) + expect(instance[key as keyof typeof instance]).toBeUndefined(); } }); }; @@ -62,6 +73,40 @@ describe("LLM unit tests", () => { testProperties(instance, testProps); }); + test("Authenticate with projectId", async () => { + const testProps = { + model: "mistralai/mistral-large", + version: "2024-05-31", + serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, + projectId: process.env.WATSONX_AI_PROJECT_ID || "testString", + }; + const instance = new ChatWatsonx({ ...testProps, ...fakeAuthProp }); + + testProperties(instance, testProps); + }); + + test("Authenticate with spaceId", async () => { + const testProps = { + model: "mistralai/mistral-large", + version: "2024-05-31", + serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, + spaceId: process.env.WATSONX_AI_SPACE_ID || "testString", + }; + const instance = new ChatWatsonx({ ...testProps, ...fakeAuthProp }); + + testProperties(instance, testProps); + }); + + test("Authenticate with idOrName", async () => { + const testProps = { + version: "2024-05-31", + serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, + idOrName: process.env.WATSONX_AI_ID_OR_NAME || "testString", + }; + const instance = new ChatWatsonx({ ...testProps, ...fakeAuthProp }); + testProperties(instance, testProps); + }); + test("Test methods after init", () => { const testProps: ChatWatsonxInput = { model: "mistralai/mistral-large", diff --git a/libs/langchain-community/src/llms/ibm.ts b/libs/langchain-community/src/llms/ibm.ts index 75e65fd6873d..97fb287a982a 100644 --- a/libs/langchain-community/src/llms/ibm.ts +++ b/libs/langchain-community/src/llms/ibm.ts @@ -3,7 +3,6 @@ import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; import { BaseLLM, BaseLLMParams } from "@langchain/core/language_models/llms"; import { WatsonXAI } from "@ibm-cloud/watsonx-ai"; import { - DeploymentTextGenProperties, RequestCallbacks, ReturnOptionProperties, TextGenLengthPenalty, @@ -21,9 +20,11 @@ import { AsyncCaller } from "@langchain/core/utils/async_caller"; import { authenticateAndSetInstance } from "../utils/ibm.js"; import { GenerationInfo, + Neverify, ResponseChunk, TokenUsage, WatsonxAuth, + WatsonxDeployedParams, WatsonxParams, } from "../types/ibm.js"; @@ -31,15 +32,7 @@ import { * Input to LLM class. */ -export interface WatsonxCallOptionsLLM extends BaseLanguageModelCallOptions { - maxRetries?: number; - parameters?: Partial; - idOrName?: string; - watsonxCallbacks?: RequestCallbacks; -} - -export interface WatsonxInputLLM extends WatsonxParams, BaseLLMParams { - streaming?: boolean; +export interface WatsonxLLMParams { maxNewTokens?: number; decodingMethod?: TextGenParameters.Constants.DecodingMethod | string; lengthPenalty?: TextGenLengthPenalty; @@ -54,9 +47,36 @@ export interface WatsonxInputLLM extends WatsonxParams, BaseLLMParams { truncateInpuTokens?: number; returnOptions?: ReturnOptionProperties; includeStopSequence?: boolean; +} + +export interface WatsonxDeploymentLLMParams { + idOrName: string; +} + +export interface WatsonxCallOptionsLLM extends BaseLanguageModelCallOptions { + maxRetries?: number; + parameters?: Partial; watsonxCallbacks?: RequestCallbacks; } +export interface WatsonxInputLLM + extends WatsonxParams, + BaseLLMParams, + WatsonxLLMParams, + Neverify {} + +export interface WatsonxDeployedInputLLM + extends WatsonxDeployedParams, + BaseLLMParams, + Neverify { + model?: never; +} + +export type WatsonxLLMConstructor = BaseLLMParams & + WatsonxLLMParams & + Partial & + WatsonxDeployedParams; + /** * Integration with an LLM. */ @@ -64,7 +84,7 @@ export class WatsonxLLM< CallOptions extends WatsonxCallOptionsLLM = WatsonxCallOptionsLLM > extends BaseLLM - implements WatsonxInputLLM + implements WatsonxLLMConstructor { // Used for tracing, replace with the same name as your class static lc_name() { @@ -123,43 +143,51 @@ export class WatsonxLLM< private service: WatsonXAI; - constructor(fields: WatsonxInputLLM & WatsonxAuth) { + constructor( + fields: (WatsonxInputLLM | WatsonxDeployedInputLLM) & WatsonxAuth + ) { super(fields); - this.model = fields.model ?? this.model; - this.version = fields.version; - this.maxNewTokens = fields.maxNewTokens ?? this.maxNewTokens; - this.serviceUrl = fields.serviceUrl; - this.decodingMethod = fields.decodingMethod; - this.lengthPenalty = fields.lengthPenalty; - this.minNewTokens = fields.minNewTokens; - this.randomSeed = fields.randomSeed; - this.stopSequence = fields.stopSequence; - this.temperature = fields.temperature; - this.timeLimit = fields.timeLimit; - this.topK = fields.topK; - this.topP = fields.topP; - this.repetitionPenalty = fields.repetitionPenalty; - this.truncateInpuTokens = fields.truncateInpuTokens; - this.returnOptions = fields.returnOptions; - this.includeStopSequence = fields.includeStopSequence; + + if (fields.model) { + this.model = fields.model ?? this.model; + this.version = fields.version; + this.maxNewTokens = fields.maxNewTokens ?? this.maxNewTokens; + this.serviceUrl = fields.serviceUrl; + this.decodingMethod = fields.decodingMethod; + this.lengthPenalty = fields.lengthPenalty; + this.minNewTokens = fields.minNewTokens; + this.randomSeed = fields.randomSeed; + this.stopSequence = fields.stopSequence; + this.temperature = fields.temperature; + this.timeLimit = fields.timeLimit; + this.topK = fields.topK; + this.topP = fields.topP; + this.repetitionPenalty = fields.repetitionPenalty; + this.truncateInpuTokens = fields.truncateInpuTokens; + this.returnOptions = fields.returnOptions; + this.includeStopSequence = fields.includeStopSequence; + this.projectId = fields?.projectId; + this.spaceId = fields?.spaceId; + } else { + this.idOrName = fields?.idOrName; + } + this.maxRetries = fields.maxRetries || this.maxRetries; this.maxConcurrency = fields.maxConcurrency; this.streaming = fields.streaming || this.streaming; this.watsonxCallbacks = fields.watsonxCallbacks || this.watsonxCallbacks; + if ( - (fields.projectId && fields.spaceId) || - (fields.idOrName && fields.projectId) || - (fields.spaceId && fields.idOrName) + ("projectId" in fields && "spaceId" in fields) || + ("projectId" in fields && "idOrName" in fields) || + ("spaceId" in fields && "idOrName" in fields) ) throw new Error("Maximum 1 id type can be specified per instance"); - if (!fields.projectId && !fields.spaceId && !fields.idOrName) + if (!("projectId" in fields || "spaceId" in fields || "idOrName" in fields)) throw new Error( "No id specified! At least id of 1 type has to be specified" ); - this.projectId = fields?.projectId; - this.spaceId = fields?.spaceId; - this.idOrName = fields?.idOrName; this.serviceUrl = fields?.serviceUrl; const { @@ -215,11 +243,12 @@ export class WatsonxLLM< }; } - invocationParams( - options: this["ParsedCallOptions"] - ): TextGenParameters | DeploymentTextGenProperties { + invocationParams(options: this["ParsedCallOptions"]) { const { parameters } = options; - + const { signal, ...rest } = options; + if (this.idOrName && Object.keys(rest).length > 0) + throw new Error("Options cannot be provided to a deployed model"); + if (this.idOrName) return undefined; return { max_new_tokens: parameters?.maxNewTokens ?? this.maxNewTokens, decoding_method: parameters?.decodingMethod ?? this.decodingMethod, @@ -293,7 +322,7 @@ export class WatsonxLLM< ...requestOptions } = options; const tokenUsage = { generated_token_count: 0, input_token_count: 0 }; - const idOrName = options?.idOrName ?? this.idOrName; + const idOrName = this.idOrName; const parameters = this.invocationParams(options); const watsonxCallbacks = this.invocationCallbacks(options); if (stream) { diff --git a/libs/langchain-community/src/llms/tests/ibm.test.ts b/libs/langchain-community/src/llms/tests/ibm.test.ts index 6237cb1d14c1..0669af2f811b 100644 --- a/libs/langchain-community/src/llms/tests/ibm.test.ts +++ b/libs/langchain-community/src/llms/tests/ibm.test.ts @@ -1,7 +1,7 @@ /* eslint-disable no-process-env */ /* eslint-disable @typescript-eslint/no-explicit-any */ import WatsonxAiMlVml_v1 from "@ibm-cloud/watsonx-ai/dist/watsonx-ai-ml/vml_v1.js"; -import { WatsonxLLM, WatsonxInputLLM } from "../ibm.js"; +import { WatsonxLLM, WatsonxInputLLM, WatsonxLLMConstructor } from "../ibm.js"; import { authenticateAndSetInstance } from "../../utils/ibm.js"; import { WatsonxEmbeddings } from "../../embeddings/ibm.js"; @@ -14,7 +14,7 @@ export function getKey(key: K): K { } export const testProperties = ( instance: WatsonxLLM | WatsonxEmbeddings, - testProps: WatsonxInputLLM, + testProps: WatsonxLLMConstructor, notExTestProps?: { [key: string]: any } ) => { const checkProperty = ( @@ -63,6 +63,17 @@ describe("LLM unit tests", () => { testProperties(instance, testProps); }); + test("Test basic properties after init", async () => { + const testProps = { + version: "2024-05-31", + serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, + idOrName: process.env.WATSONX_AI_PROJECT_ID || "testString", + }; + const instance = new WatsonxLLM({ ...testProps, ...fakeAuthProp }); + + testProperties(instance, testProps); + }); + test("Test methods after init", () => { const testProps: WatsonxInputLLM = { model: "ibm/granite-13b-chat-v2", diff --git a/libs/langchain-community/src/types/ibm.ts b/libs/langchain-community/src/types/ibm.ts index ee5db8532036..f5d4b72de7b4 100644 --- a/libs/langchain-community/src/types/ibm.ts +++ b/libs/langchain-community/src/types/ibm.ts @@ -1,3 +1,5 @@ +import { RequestCallbacks } from "@ibm-cloud/watsonx-ai/dist/watsonx-ai-ml/vml_v1.js"; + export interface TokenUsage { generated_token_count: number; input_token_count: number; @@ -17,13 +19,27 @@ export interface WatsonxInit { version: string; } -export interface WatsonxParams extends WatsonxInit { +export interface WatsonxChatBasicOptions { + maxConcurrency?: number; + maxRetries?: number; + streaming?: boolean; + watsonxCallbacks?: RequestCallbacks; +} + +export interface WatsonxParams extends WatsonxInit, WatsonxChatBasicOptions { model: string; spaceId?: string; projectId?: string; +} + +export type Neverify = { + [K in keyof T]?: never; +}; + +export interface WatsonxDeployedParams + extends WatsonxInit, + WatsonxChatBasicOptions { idOrName?: string; - maxConcurrency?: number; - maxRetries?: number; } export interface GenerationInfo { diff --git a/libs/langchain-community/src/utils/ibm.ts b/libs/langchain-community/src/utils/ibm.ts index ccbe1204ef60..8786a0263198 100644 --- a/libs/langchain-community/src/utils/ibm.ts +++ b/libs/langchain-community/src/utils/ibm.ts @@ -184,10 +184,18 @@ export class WatsonxToolsOutputParser< const tool = message.tool_calls; return tool; }); + if (tools[0] === undefined) { - if (this.latestCorrect) tools.push(this.latestCorrect); + if (this.latestCorrect) { + tools.push(this.latestCorrect); + } else { + const toolCall: ToolCall = { name: "", args: {} }; + tools.push(toolCall); + } } + const [tool] = tools; + tool.name = ""; this.latestCorrect = tool; return tool.args as T; } diff --git a/yarn.lock b/yarn.lock index fefd13294652..fe891b2b0113 100644 --- a/yarn.lock +++ b/yarn.lock @@ -10669,14 +10669,15 @@ __metadata: languageName: node linkType: hard -"@ibm-cloud/watsonx-ai@npm:^1.3.0": - version: 1.3.0 - resolution: "@ibm-cloud/watsonx-ai@npm:1.3.0" +"@ibm-cloud/watsonx-ai@npm:^1.4.0": + version: 1.4.0 + resolution: "@ibm-cloud/watsonx-ai@npm:1.4.0" dependencies: + "@langchain/textsplitters": ^0.1.0 "@types/node": ^18.0.0 extend: 3.0.2 ibm-cloud-sdk-core: ^5.0.2 - checksum: 6a2127391ca70005b942d3c4ab1abc738946c42bbf3ee0f8eb6f778434b5f8806d622f1f36446f00b9fb82dc2c8aea3526426ec46cc53fa8a075ba7a294da096 + checksum: 5250816f9ad93839cf26e3788eeace8155721765c39c65547eff8ebbd5fc8a0dfa107f6e799593f1209f4b3489be24aa674aa92b7ecbc5fc2bd29390a28e84ff languageName: node linkType: hard @@ -11899,7 +11900,7 @@ __metadata: "@gradientai/nodejs-sdk": ^1.2.0 "@huggingface/inference": ^2.6.4 "@huggingface/transformers": ^3.2.3 - "@ibm-cloud/watsonx-ai": ^1.3.0 + "@ibm-cloud/watsonx-ai": ^1.4.0 "@jest/globals": ^29.5.0 "@lancedb/lancedb": ^0.13.0 "@langchain/core": "workspace:*" @@ -13237,7 +13238,7 @@ __metadata: languageName: unknown linkType: soft -"@langchain/textsplitters@>=0.0.0 <0.2.0, @langchain/textsplitters@workspace:*, @langchain/textsplitters@workspace:libs/langchain-textsplitters": +"@langchain/textsplitters@>=0.0.0 <0.2.0, @langchain/textsplitters@^0.1.0, @langchain/textsplitters@workspace:*, @langchain/textsplitters@workspace:libs/langchain-textsplitters": version: 0.0.0-use.local resolution: "@langchain/textsplitters@workspace:libs/langchain-textsplitters" dependencies: