diff --git a/.env.example b/.env.example index 103728fe0c2..0b56317ff31 100644 --- a/.env.example +++ b/.env.example @@ -111,6 +111,26 @@ ANTHROPIC_API_KEY=user_provided BINGAI_TOKEN=user_provided # BINGAI_HOST=https://cn.bing.com +#=================# +# AWS Bedrock # +#=================# + +# BEDROCK_AWS_DEFAULT_REGION=us-east-1 # A default region must be provided +# BEDROCK_AWS_ACCESS_KEY_ID=someAccessKey +# BEDROCK_AWS_SECRET_ACCESS_KEY=someSecretAccessKey + +# Note: This example list is not meant to be exhaustive. If omitted, all known, supported model IDs will be included for you. +# BEDROCK_AWS_MODELS=anthropic.claude-3-5-sonnet-20240620-v1:0,meta.llama3-1-8b-instruct-v1:0 + +# See all Bedrock model IDs here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html#model-ids-arns + +# Notes on specific models: +# The following models are not supported due to not supporting streaming: +# ai21.j2-mid-v1 + +# The following models are not supported due to not supporting conversation history: +# ai21.j2-ultra-v1, cohere.command-text-v14, cohere.command-light-text-v14 + #============# # Google # #============# @@ -392,6 +412,7 @@ LDAP_CA_CERT_PATH= # LDAP_LOGIN_USES_USERNAME=true # LDAP_ID= # LDAP_USERNAME= +# LDAP_EMAIL= # LDAP_FULL_NAME= #========================# diff --git a/.github/dependabot.yml b/.github/dependabot.yml deleted file mode 100644 index ccdc68d81b3..00000000000 --- a/.github/dependabot.yml +++ /dev/null @@ -1,47 +0,0 @@ -# To get started with Dependabot version updates, you'll need to specify which -# package ecosystems to update and where the package manifests are located. 
-# Please see the documentation for all configuration options: -# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates - -version: 2 -updates: - - package-ecosystem: "npm" # See documentation for possible values - directory: "/api" # Location of package manifests - target-branch: "dev" - versioning-strategy: increase-if-necessary - schedule: - interval: "weekly" - allow: - # Allow both direct and indirect updates for all packages - - dependency-type: "all" - commit-message: - prefix: "npm api prod" - prefix-development: "npm api dev" - include: "scope" - - package-ecosystem: "npm" # See documentation for possible values - directory: "/client" # Location of package manifests - target-branch: "dev" - versioning-strategy: increase-if-necessary - schedule: - interval: "weekly" - allow: - # Allow both direct and indirect updates for all packages - - dependency-type: "all" - commit-message: - prefix: "npm client prod" - prefix-development: "npm client dev" - include: "scope" - - package-ecosystem: "npm" # See documentation for possible values - directory: "/" # Location of package manifests - target-branch: "dev" - versioning-strategy: increase-if-necessary - schedule: - interval: "weekly" - allow: - # Allow both direct and indirect updates for all packages - - dependency-type: "all" - commit-message: - prefix: "npm all prod" - prefix-development: "npm all dev" - include: "scope" - diff --git a/Dockerfile b/Dockerfile index e8530fb58c9..0793f0de11d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -# v0.7.5-rc1 +# v0.7.5-rc2 # Base node image FROM node:20-alpine AS node diff --git a/Dockerfile.multi b/Dockerfile.multi index 175563597ae..a32183d82f3 100644 --- a/Dockerfile.multi +++ b/Dockerfile.multi @@ -1,5 +1,5 @@ # Dockerfile.multi -# v0.7.5-rc1 +# v0.7.5-rc2 # Base for all builds FROM node:20-alpine AS base diff --git a/README.md b/README.md index f1d92b50dae..50ccd252b9c 100644 --- a/README.md +++ b/README.md 
@@ -42,10 +42,10 @@ - 🖥️ UI matching ChatGPT, including Dark mode, Streaming, and latest updates - 🤖 AI model selection: - - OpenAI, Azure OpenAI, BingAI, ChatGPT, Google Vertex AI, Anthropic (Claude), Plugins, Assistants API (including Azure Assistants) + - Anthropic (Claude), AWS Bedrock, OpenAI, Azure OpenAI, BingAI, ChatGPT, Google Vertex AI, Plugins, Assistants API (including Azure Assistants) - ✅ Compatible across both **[Remote & Local AI services](https://www.librechat.ai/docs/configuration/librechat_yaml/ai_endpoints):** - groq, Ollama, Cohere, Mistral AI, Apple MLX, koboldcpp, OpenRouter, together.ai, Perplexity, ShuttleAI, and more -- 🪄 Generative UI with [Code Artifacts](https://youtu.be/GfTj7O4gmd0?si=WJbdnemZpJzBrJo3) +- 🪄 Generative UI with **[Code Artifacts](https://youtu.be/GfTj7O4gmd0?si=WJbdnemZpJzBrJo3)** - Create React, HTML code, and Mermaid diagrams right in chat - 💾 Create, Save, & Share Custom Presets - 🔀 Switch between AI Endpoints and Presets, mid-chat diff --git a/api/app/clients/AnthropicClient.js b/api/app/clients/AnthropicClient.js index 873f2615695..486af95c3f3 100644 --- a/api/app/clients/AnthropicClient.js +++ b/api/app/clients/AnthropicClient.js @@ -17,8 +17,8 @@ const { parseParamFromPrompt, createContextHandlers, } = require('./prompts'); +const { getModelMaxTokens, getModelMaxOutputTokens, matchModelName } = require('~/utils'); const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); -const { getModelMaxTokens, matchModelName } = require('~/utils'); const { sleep } = require('~/server/utils'); const BaseClient = require('./BaseClient'); const { logger } = require('~/config'); @@ -64,6 +64,12 @@ class AnthropicClient extends BaseClient { /** Whether or not the model supports Prompt Caching * @type {boolean} */ this.supportsCacheControl; + /** The key for the usage object's input tokens + * @type {string} */ + this.inputTokensKey = 'input_tokens'; + /** The key for the usage object's output tokens + * @type 
{string} */ + this.outputTokensKey = 'output_tokens'; } setOptions(options) { @@ -114,7 +120,14 @@ class AnthropicClient extends BaseClient { this.options.maxContextTokens ?? getModelMaxTokens(this.modelOptions.model, EModelEndpoint.anthropic) ?? 100000; - this.maxResponseTokens = this.modelOptions.maxOutputTokens || 1500; + this.maxResponseTokens = + this.modelOptions.maxOutputTokens ?? + getModelMaxOutputTokens( + this.modelOptions.model, + this.options.endpointType ?? this.options.endpoint, + this.options.endpointTokenConfig, + ) ?? + 1500; this.maxPromptTokens = this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens; @@ -138,17 +151,6 @@ class AnthropicClient extends BaseClient { this.endToken = ''; this.gptEncoder = this.constructor.getTokenizer('cl100k_base'); - if (!this.modelOptions.stop) { - const stopTokens = [this.startToken]; - if (this.endToken && this.endToken !== this.startToken) { - stopTokens.push(this.endToken); - } - stopTokens.push(`${this.userLabel}`); - stopTokens.push('<|diff_marker|>'); - - this.modelOptions.stop = stopTokens; - } - return this; } @@ -200,7 +202,7 @@ class AnthropicClient extends BaseClient { } /** - * Calculates the correct token count for the current message based on the token count map and API usage. + * Calculates the correct token count for the current user message based on the token count map and API usage. * Edge case: If the calculation results in a negative value, it returns the original estimate. * If revisiting a conversation with a chat history entirely composed of token estimates, * the cumulative token count going forward should become more accurate as the conversation progresses. @@ -208,7 +210,7 @@ class AnthropicClient extends BaseClient { * @param {Record} params.tokenCountMap - A map of message IDs to their token counts. * @param {string} params.currentMessageId - The ID of the current message to calculate. 
* @param {AnthropicStreamUsage} params.usage - The usage object returned by the API. - * @returns {number} The correct token count for the current message. + * @returns {number} The correct token count for the current user message. */ calculateCurrentTokenCount({ tokenCountMap, currentMessageId, usage }) { const originalEstimate = tokenCountMap[currentMessageId] || 0; @@ -680,7 +682,11 @@ class AnthropicClient extends BaseClient { */ checkPromptCacheSupport(modelName) { const modelMatch = matchModelName(modelName, EModelEndpoint.anthropic); - if (modelMatch === 'claude-3-5-sonnet' || modelMatch === 'claude-3-haiku') { + if ( + modelMatch === 'claude-3-5-sonnet' || + modelMatch === 'claude-3-haiku' || + modelMatch === 'claude-3-opus' + ) { return true; } return false; diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index 76403880608..51d75d063b0 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -2,6 +2,8 @@ const crypto = require('crypto'); const fetch = require('node-fetch'); const { supportsBalanceCheck, + isAgentsEndpoint, + isParamEndpoint, ErrorTypes, Constants, CacheKeys, @@ -40,6 +42,12 @@ class BaseClient { this.conversationId; /** @type {string} */ this.responseMessageId; + /** The key for the usage object's input tokens + * @type {string} */ + this.inputTokensKey = 'prompt_tokens'; + /** The key for the usage object's output tokens + * @type {string} */ + this.outputTokensKey = 'completion_tokens'; } setOptions() { @@ -66,6 +74,17 @@ class BaseClient { throw new Error('Subclasses attempted to call summarizeMessages without implementing it'); } + /** + * @returns {string} + */ + getResponseModel() { + if (isAgentsEndpoint(this.options.endpoint) && this.options.agent && this.options.agent.id) { + return this.options.agent.id; + } + + return this.modelOptions.model; + } + /** * Abstract method to get the token count for a message. Subclasses must implement this method. 
* @param {TMessage} responseMessage @@ -217,6 +236,7 @@ class BaseClient { userMessage, conversationId, responseMessageId, + sender: this.sender, }); } @@ -548,6 +568,7 @@ class BaseClient { }); } + /** @type {string|string[]|undefined} */ const completion = await this.sendCompletion(payload, opts); this.abortController.requestCompleted = true; @@ -557,7 +578,7 @@ class BaseClient { parentMessageId: userMessage.messageId, isCreatedByUser: false, isEdited, - model: this.modelOptions.model, + model: this.getResponseModel(), sender: this.sender, promptTokens, iconURL: this.options.iconURL, @@ -567,9 +588,14 @@ class BaseClient { if (typeof completion === 'string') { responseMessage.text = addSpaceIfNeeded(generation) + completion; - } else if (completion) { + } else if ( + Array.isArray(completion) && + isParamEndpoint(this.options.endpoint, this.options.endpointType) + ) { responseMessage.text = ''; responseMessage.content = completion; + } else if (Array.isArray(completion)) { + responseMessage.text = addSpaceIfNeeded(generation) + completion.join(''); } if ( @@ -587,8 +613,8 @@ class BaseClient { * @type {StreamUsage | null} */ const usage = this.getStreamUsage != null ? 
this.getStreamUsage() : null; - if (usage != null && Number(usage.output_tokens) > 0) { - responseMessage.tokenCount = usage.output_tokens; + if (usage != null && Number(usage[this.outputTokensKey]) > 0) { + responseMessage.tokenCount = usage[this.outputTokensKey]; completionTokens = responseMessage.tokenCount; await this.updateUserMessageTokenCount({ usage, tokenCountMap, userMessage, opts }); } else { @@ -638,7 +664,7 @@ class BaseClient { /** @type {boolean} */ const shouldUpdateCount = this.calculateCurrentTokenCount != null && - Number(usage.input_tokens) > 0 && + Number(usage[this.inputTokensKey]) > 0 && (this.options.resendFiles || (!this.options.resendFiles && !this.options.attachments?.length)) && !this.options.promptPrefix; diff --git a/api/app/clients/ChatGPTClient.js b/api/app/clients/ChatGPTClient.js index 0a7f6fc7d88..104e9e5ac3f 100644 --- a/api/app/clients/ChatGPTClient.js +++ b/api/app/clients/ChatGPTClient.js @@ -1,19 +1,21 @@ const Keyv = require('keyv'); const crypto = require('crypto'); +const { CohereClient } = require('cohere-ai'); +const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source'); +const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken'); const { + ImageDetail, EModelEndpoint, resolveHeaders, CohereConstants, mapModelToAzureConfig, } = require('librechat-data-provider'); -const { CohereClient } = require('cohere-ai'); -const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken'); -const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source'); +const { extractBaseURL, constructAzureURL, genAzureChatCompletion } = require('~/utils'); +const { createContextHandlers } = require('./prompts'); const { createCoherePayload } = require('./llm'); const { Agent, ProxyAgent } = require('undici'); const BaseClient = require('./BaseClient'); const { logger } = require('~/config'); -const { extractBaseURL, constructAzureURL, 
genAzureChatCompletion } = require('~/utils'); const CHATGPT_MODEL = 'gpt-3.5-turbo'; const tokenizersCache = {}; @@ -612,21 +614,66 @@ ${botMessage.message} async buildPrompt(messages, { isChatGptModel = false, promptPrefix = null }) { promptPrefix = (promptPrefix || this.options.promptPrefix || '').trim(); + + // Handle attachments and create augmentedPrompt + if (this.options.attachments) { + const attachments = await this.options.attachments; + const lastMessage = messages[messages.length - 1]; + + if (this.message_file_map) { + this.message_file_map[lastMessage.messageId] = attachments; + } else { + this.message_file_map = { + [lastMessage.messageId]: attachments, + }; + } + + const files = await this.addImageURLs(lastMessage, attachments); + this.options.attachments = files; + + this.contextHandlers = createContextHandlers(this.options.req, lastMessage.text); + } + + if (this.message_file_map) { + this.contextHandlers = createContextHandlers( + this.options.req, + messages[messages.length - 1].text, + ); + } + + // Calculate image token cost and process embedded files + messages.forEach((message, i) => { + if (this.message_file_map && this.message_file_map[message.messageId]) { + const attachments = this.message_file_map[message.messageId]; + for (const file of attachments) { + if (file.embedded) { + this.contextHandlers?.processFile(file); + continue; + } + + messages[i].tokenCount = + (messages[i].tokenCount || 0) + + this.calculateImageTokenCost({ + width: file.width, + height: file.height, + detail: this.options.imageDetail ?? ImageDetail.auto, + }); + } + } + }); + + if (this.contextHandlers) { + this.augmentedPrompt = await this.contextHandlers.createContext(); + promptPrefix = this.augmentedPrompt + promptPrefix; + } + if (promptPrefix) { // If the prompt prefix doesn't end with the end token, add it. 
if (!promptPrefix.endsWith(`${this.endToken}`)) { promptPrefix = `${promptPrefix.trim()}${this.endToken}\n\n`; } promptPrefix = `${this.startToken}Instructions:\n${promptPrefix}`; - } else { - const currentDateString = new Date().toLocaleDateString('en-us', { - year: 'numeric', - month: 'long', - day: 'numeric', - }); - promptPrefix = `${this.startToken}Instructions:\nYou are ChatGPT, a large language model trained by OpenAI. Respond conversationally.\nCurrent date: ${currentDateString}${this.endToken}\n\n`; } - const promptSuffix = `${this.startToken}${this.chatGptLabel}:\n`; // Prompt ChatGPT to respond. const instructionsPayload = { @@ -714,10 +761,6 @@ ${botMessage.message} this.maxResponseTokens, ); - if (this.options.debug) { - console.debug(`Prompt : ${prompt}`); - } - if (isChatGptModel) { return { prompt: [instructionsPayload, messagePayload], context }; } diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js index 4338a29d5a4..4f5e29cac88 100644 --- a/api/app/clients/OpenAIClient.js +++ b/api/app/clients/OpenAIClient.js @@ -19,6 +19,7 @@ const { constructAzureURL, getModelMaxTokens, genAzureChatCompletion, + getModelMaxOutputTokens, } = require('~/utils'); const { truncateText, @@ -64,6 +65,9 @@ class OpenAIClient extends BaseClient { /** @type {string | undefined} - The API Completions URL */ this.completionsUrl; + + /** @type {OpenAIUsageMetadata | undefined} */ + this.usage; } // TODO: PluginsClient calls this 3x, unneeded @@ -138,7 +142,8 @@ class OpenAIClient extends BaseClient { const { model } = this.modelOptions; - this.isChatCompletion = this.useOpenRouter || !!reverseProxy || model.includes('gpt'); + this.isChatCompletion = + /\bo1\b/i.test(model) || model.includes('gpt') || this.useOpenRouter || !!reverseProxy; this.isChatGptModel = this.isChatCompletion; if ( model.includes('text-davinci') || @@ -169,7 +174,14 @@ class OpenAIClient extends BaseClient { logger.debug('[OpenAIClient] maxContextTokens', 
this.maxContextTokens); } - this.maxResponseTokens = this.modelOptions.max_tokens || 1024; + this.maxResponseTokens = + this.modelOptions.max_tokens ?? + getModelMaxOutputTokens( + model, + this.options.endpointType ?? this.options.endpoint, + this.options.endpointTokenConfig, + ) ?? + 1024; this.maxPromptTokens = this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens; @@ -533,7 +545,8 @@ class OpenAIClient extends BaseClient { promptPrefix = this.augmentedPrompt + promptPrefix; } - if (promptPrefix) { + const isO1Model = /\bo1\b/i.test(this.modelOptions.model); + if (promptPrefix && !isO1Model) { promptPrefix = `Instructions:\n${promptPrefix.trim()}`; instructions = { role: 'system', @@ -561,6 +574,16 @@ class OpenAIClient extends BaseClient { messages, }; + /** EXPERIMENTAL */ + if (promptPrefix && isO1Model) { + const lastUserMessageIndex = payload.findLastIndex((message) => message.role === 'user'); + if (lastUserMessageIndex !== -1) { + payload[ + lastUserMessageIndex + ].content = `${promptPrefix}\n${payload[lastUserMessageIndex].content}`; + } + } + if (tokenCountMap) { tokenCountMap.instructions = instructions?.tokenCount; result.tokenCountMap = tokenCountMap; @@ -621,6 +644,12 @@ class OpenAIClient extends BaseClient { if (completionResult && typeof completionResult === 'string') { reply = completionResult; + } else if ( + completionResult && + typeof completionResult === 'object' && + Array.isArray(completionResult.choices) + ) { + reply = completionResult.choices[0]?.text?.replace(this.endToken, ''); } } else if (typeof opts.onProgress === 'function' || this.options.useChatCompletion) { reply = await this.chatCompletion({ @@ -885,6 +914,60 @@ ${convo} return title; } + /** + * Get stream usage as returned by this client's API response. + * @returns {OpenAIUsageMetadata} The stream usage object. 
+ */ + getStreamUsage() { + if ( + this.usage && + typeof this.usage === 'object' && + 'completion_tokens_details' in this.usage && + this.usage.completion_tokens_details && + typeof this.usage.completion_tokens_details === 'object' && + 'reasoning_tokens' in this.usage.completion_tokens_details + ) { + const outputTokens = Math.abs( + this.usage.completion_tokens_details.reasoning_tokens - this.usage[this.outputTokensKey], + ); + return { + ...this.usage.completion_tokens_details, + [this.inputTokensKey]: this.usage[this.inputTokensKey], + [this.outputTokensKey]: outputTokens, + }; + } + return this.usage; + } + + /** + * Calculates the correct token count for the current user message based on the token count map and API usage. + * Edge case: If the calculation results in a negative value, it returns the original estimate. + * If revisiting a conversation with a chat history entirely composed of token estimates, + * the cumulative token count going forward should become more accurate as the conversation progresses. + * @param {Object} params - The parameters for the calculation. + * @param {Record} params.tokenCountMap - A map of message IDs to their token counts. + * @param {string} params.currentMessageId - The ID of the current message to calculate. + * @param {OpenAIUsageMetadata} params.usage - The usage object returned by the API. + * @returns {number} The correct token count for the current user message. + */ + calculateCurrentTokenCount({ tokenCountMap, currentMessageId, usage }) { + const originalEstimate = tokenCountMap[currentMessageId] || 0; + + if (!usage || typeof usage[this.inputTokensKey] !== 'number') { + return originalEstimate; + } + + tokenCountMap[currentMessageId] = 0; + const totalTokensFromMap = Object.values(tokenCountMap).reduce((sum, count) => { + const numCount = Number(count); + return sum + (isNaN(numCount) ? 0 : numCount); + }, 0); + const totalInputTokens = usage[this.inputTokensKey] ?? 
0; + + const currentMessageTokens = totalInputTokens - totalTokensFromMap; + return currentMessageTokens > 0 ? currentMessageTokens : originalEstimate; + } + async summarizeMessages({ messagesToRefine, remainingContextTokens }) { logger.debug('[OpenAIClient] Summarizing messages...'); let context = messagesToRefine; @@ -1000,7 +1083,16 @@ ${convo} } } - async recordTokenUsage({ promptTokens, completionTokens, context = 'message' }) { + /** + * @param {object} params + * @param {number} params.promptTokens + * @param {number} params.completionTokens + * @param {OpenAIUsageMetadata} [params.usage] + * @param {string} [params.model] + * @param {string} [params.context='message'] + * @returns {Promise} + */ + async recordTokenUsage({ promptTokens, completionTokens, usage, context = 'message' }) { await spendTokens( { context, @@ -1011,6 +1103,24 @@ ${convo} }, { promptTokens, completionTokens }, ); + + if ( + usage && + typeof usage === 'object' && + 'reasoning_tokens' in usage && + typeof usage.reasoning_tokens === 'number' + ) { + await spendTokens( + { + context: 'reasoning', + model: this.modelOptions.model, + conversationId: this.conversationId, + user: this.user ?? 
this.options.req.user?.id, + endpointTokenConfig: this.options.endpointTokenConfig, + }, + { completionTokens: usage.reasoning_tokens }, + ); + } } getTokenCountForResponse(response) { @@ -1191,6 +1301,11 @@ ${convo} /** @type {(value: void | PromiseLike) => void} */ let streamResolve; + if (modelOptions.stream && /\bo1\b/i.test(modelOptions.model)) { + delete modelOptions.stream; + delete modelOptions.stop; + } + if (modelOptions.stream) { streamPromise = new Promise((resolve) => { streamResolve = resolve; @@ -1269,9 +1384,11 @@ ${convo} } const { choices } = chatCompletion; + this.usage = chatCompletion.usage; + if (!Array.isArray(choices) || choices.length === 0) { logger.warn('[OpenAIClient] Chat completion response has no choices'); - return intermediateReply; + return intermediateReply.join(''); } const { message, finish_reason } = choices[0] ?? {}; @@ -1281,7 +1398,7 @@ ${convo} if (!message) { logger.warn('[OpenAIClient] Message is undefined in chatCompletion response'); - return intermediateReply; + return intermediateReply.join(''); } if (typeof message.content !== 'string' || message.content.trim() === '') { @@ -1316,7 +1433,7 @@ ${convo} logger.error('[OpenAIClient] Known OpenAI error:', err); return intermediateReply.join(''); } else if (err instanceof OpenAI.APIError) { - if (intermediateReply) { + if (intermediateReply.length > 0) { return intermediateReply.join(''); } else { throw err; diff --git a/api/app/clients/prompts/formatMessages.js b/api/app/clients/prompts/formatMessages.js index 87d5ba7a15f..1ea44f9d290 100644 --- a/api/app/clients/prompts/formatMessages.js +++ b/api/app/clients/prompts/formatMessages.js @@ -142,6 +142,9 @@ const formatAgentMessages = (payload) => { const messages = []; for (const message of payload) { + if (typeof message.content === 'string') { + message.content = [{ type: ContentTypes.TEXT, [ContentTypes.TEXT]: message.content }]; + } if (message.role !== 'assistant') { messages.push(formatMessage({ message, langChain: 
true })); continue; @@ -170,7 +173,15 @@ const formatAgentMessages = (payload) => { } // Note: `tool_calls` list is defined when constructed by `AIMessage` class, and outputs should be excluded from it - const { output, ...tool_call } = part.tool_call; + const { output, args: _args, ...tool_call } = part.tool_call; + // TODO: investigate; args as dictionary may need to be provider-or-tool-specific + let args = _args; + try { + args = JSON.parse(args); + } catch (e) { + // failed to parse, leave as is + } + tool_call.args = args; lastAIMessage.tool_calls.push(tool_call); // Add the corresponding ToolMessage diff --git a/api/app/clients/specs/BaseClient.test.js b/api/app/clients/specs/BaseClient.test.js index e3cc1515c56..0fdc6ce16c0 100644 --- a/api/app/clients/specs/BaseClient.test.js +++ b/api/app/clients/specs/BaseClient.test.js @@ -565,11 +565,13 @@ describe('BaseClient', () => { const getReqData = jest.fn(); const opts = { getReqData }; const response = await TestClient.sendMessage('Hello, world!', opts); - expect(getReqData).toHaveBeenCalledWith({ - userMessage: expect.objectContaining({ text: 'Hello, world!' }), - conversationId: response.conversationId, - responseMessageId: response.messageId, - }); + expect(getReqData).toHaveBeenCalledWith( + expect.objectContaining({ + userMessage: expect.objectContaining({ text: 'Hello, world!' 
}), + conversationId: response.conversationId, + responseMessageId: response.messageId, + }), + ); }); test('onStart is called with the correct arguments', async () => { diff --git a/api/app/clients/specs/OpenAIClient.test.js b/api/app/clients/specs/OpenAIClient.test.js index 45903984193..0725efd9d83 100644 --- a/api/app/clients/specs/OpenAIClient.test.js +++ b/api/app/clients/specs/OpenAIClient.test.js @@ -611,15 +611,7 @@ describe('OpenAIClient', () => { expect(getCompletion).toHaveBeenCalled(); expect(getCompletion.mock.calls.length).toBe(1); - const currentDateString = new Date().toLocaleDateString('en-us', { - year: 'numeric', - month: 'long', - day: 'numeric', - }); - - expect(getCompletion.mock.calls[0][0]).toBe( - `||>Instructions:\nYou are ChatGPT, a large language model trained by OpenAI. Respond conversationally.\nCurrent date: ${currentDateString}\n\n||>User:\nHi mom!\n||>Assistant:\n`, - ); + expect(getCompletion.mock.calls[0][0]).toBe('||>User:\nHi mom!\n||>Assistant:\n'); expect(fetchEventSource).toHaveBeenCalled(); expect(fetchEventSource.mock.calls.length).toBe(1); @@ -701,4 +693,70 @@ describe('OpenAIClient', () => { expect(client.modelOptions.stop).toBeUndefined(); }); }); + + describe('getStreamUsage', () => { + it('should return this.usage when completion_tokens_details is null', () => { + const client = new OpenAIClient('test-api-key', defaultOptions); + client.usage = { + completion_tokens_details: null, + prompt_tokens: 10, + completion_tokens: 20, + }; + client.inputTokensKey = 'prompt_tokens'; + client.outputTokensKey = 'completion_tokens'; + + const result = client.getStreamUsage(); + + expect(result).toEqual(client.usage); + }); + + it('should return this.usage when completion_tokens_details is missing reasoning_tokens', () => { + const client = new OpenAIClient('test-api-key', defaultOptions); + client.usage = { + completion_tokens_details: { + other_tokens: 5, + }, + prompt_tokens: 10, + completion_tokens: 20, + }; + 
client.inputTokensKey = 'prompt_tokens'; + client.outputTokensKey = 'completion_tokens'; + + const result = client.getStreamUsage(); + + expect(result).toEqual(client.usage); + }); + + it('should calculate output tokens correctly when completion_tokens_details is present with reasoning_tokens', () => { + const client = new OpenAIClient('test-api-key', defaultOptions); + client.usage = { + completion_tokens_details: { + reasoning_tokens: 30, + other_tokens: 5, + }, + prompt_tokens: 10, + completion_tokens: 20, + }; + client.inputTokensKey = 'prompt_tokens'; + client.outputTokensKey = 'completion_tokens'; + + const result = client.getStreamUsage(); + + expect(result).toEqual({ + reasoning_tokens: 30, + other_tokens: 5, + prompt_tokens: 10, + completion_tokens: 10, // |30 - 20| = 10 + }); + }); + + it('should return this.usage when it is undefined', () => { + const client = new OpenAIClient('test-api-key', defaultOptions); + client.usage = undefined; + + const result = client.getStreamUsage(); + + expect(result).toBeUndefined(); + }); + }); }); diff --git a/api/app/clients/tools/AzureAiSearch.js b/api/app/clients/tools/AzureAiSearch.js index 9b50aa2c433..1e20b9ce81d 100644 --- a/api/app/clients/tools/AzureAiSearch.js +++ b/api/app/clients/tools/AzureAiSearch.js @@ -77,7 +77,7 @@ class AzureAISearch extends StructuredTool { try { const searchOption = { queryType: this.queryType, - top: this.top, + top: typeof this.top === 'string' ? Number(this.top) : this.top, }; if (this.select) { searchOption.select = this.select.split(','); diff --git a/api/app/clients/tools/structured/AzureAISearch.js b/api/app/clients/tools/structured/AzureAISearch.js index 0ce7b43fb21..1a8c3e9e6e5 100644 --- a/api/app/clients/tools/structured/AzureAISearch.js +++ b/api/app/clients/tools/structured/AzureAISearch.js @@ -83,7 +83,7 @@ class AzureAISearch extends StructuredTool { try { const searchOption = { queryType: this.queryType, - top: this.top, + top: typeof this.top === 'string' ? 
Number(this.top) : this.top, }; if (this.select) { searchOption.select = this.select.split(','); diff --git a/api/models/Action.js b/api/models/Action.js index 7971f3e61a3..299b3bf20a3 100644 --- a/api/models/Action.js +++ b/api/models/Action.js @@ -5,17 +5,16 @@ const Action = mongoose.model('action', actionSchema); /** * Update an action with new data without overwriting existing properties, - * or create a new action if it doesn't exist, within a transaction session if provided. + * or create a new action if it doesn't exist. * * @param {Object} searchParams - The search parameters to find the action to update. * @param {string} searchParams.action_id - The ID of the action to update. * @param {string} searchParams.user - The user ID of the action's author. * @param {Object} updateData - An object containing the properties to update. - * @param {mongoose.ClientSession} [session] - The transaction session to use. * @returns {Promise} The updated or newly created action document as a plain object. */ -const updateAction = async (searchParams, updateData, session = null) => { - const options = { new: true, upsert: true, session }; +const updateAction = async (searchParams, updateData) => { + const options = { new: true, upsert: true }; return await Action.findOneAndUpdate(searchParams, updateData, options).lean(); }; @@ -49,31 +48,27 @@ const getActions = async (searchParams, includeSensitive = false) => { }; /** - * Deletes an action by params, within a transaction session if provided. + * Deletes an action by params. * * @param {Object} searchParams - The search parameters to find the action to delete. * @param {string} searchParams.action_id - The ID of the action to delete. * @param {string} searchParams.user - The user ID of the action's author. - * @param {mongoose.ClientSession} [session] - The transaction session to use (optional). 
* @returns {Promise} A promise that resolves to the deleted action document as a plain object, or null if no document was found. */ -const deleteAction = async (searchParams, session = null) => { - const options = session ? { session } : {}; - return await Action.findOneAndDelete(searchParams, options).lean(); +const deleteAction = async (searchParams) => { + return await Action.findOneAndDelete(searchParams).lean(); }; /** - * Deletes actions by params, within a transaction session if provided. + * Deletes actions by params. * * @param {Object} searchParams - The search parameters to find the actions to delete. * @param {string} searchParams.action_id - The ID of the action(s) to delete. * @param {string} searchParams.user - The user ID of the action's author. - * @param {mongoose.ClientSession} [session] - The transaction session to use (optional). * @returns {Promise} A promise that resolves to the number of deleted action documents. */ -const deleteActions = async (searchParams, session = null) => { - const options = session ? 
{ session } : {}; - const result = await Action.deleteMany(searchParams, options); +const deleteActions = async (searchParams) => { + const result = await Action.deleteMany(searchParams); return result.deletedCount; }; diff --git a/api/models/Agent.js b/api/models/Agent.js index 1ee783b101e..2112a44991f 100644 --- a/api/models/Agent.js +++ b/api/models/Agent.js @@ -1,4 +1,11 @@ const mongoose = require('mongoose'); +const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants; +const { + getProjectByName, + addAgentIdsToProject, + removeAgentIdsFromProject, + removeAgentFromAllProjects, +} = require('./Project'); const agentSchema = require('./schema/agent'); const Agent = mongoose.model('agent', agentSchema); @@ -24,18 +31,17 @@ const createAgent = async (agentData) => { const getAgent = async (searchParameter) => await Agent.findOne(searchParameter).lean(); /** - * Update an agent with new data without overwriting existing properties, - * or create a new agent if it doesn't exist, within a transaction session if provided. + * Update an agent with new data without overwriting existing + * properties, or create a new agent if it doesn't exist. * * @param {Object} searchParameter - The search parameters to find the agent to update. * @param {string} searchParameter.id - The ID of the agent to update. - * @param {string} searchParameter.author - The user ID of the agent's author. + * @param {string} [searchParameter.author] - The user ID of the agent's author. * @param {Object} updateData - An object containing the properties to update. - * @param {mongoose.ClientSession} [session] - The transaction session to use (optional). * @returns {Promise} The updated or newly created agent document as a plain object. 
*/ -const updateAgent = async (searchParameter, updateData, session = null) => { - const options = { new: true, upsert: true, session }; +const updateAgent = async (searchParameter, updateData) => { + const options = { new: true, upsert: true }; return await Agent.findOneAndUpdate(searchParameter, updateData, options).lean(); }; @@ -44,11 +50,15 @@ const updateAgent = async (searchParameter, updateData, session = null) => { * * @param {Object} searchParameter - The search parameters to find the agent to delete. * @param {string} searchParameter.id - The ID of the agent to delete. - * @param {string} searchParameter.author - The user ID of the agent's author. + * @param {string} [searchParameter.author] - The user ID of the agent's author. * @returns {Promise} Resolves when the agent has been successfully deleted. */ const deleteAgent = async (searchParameter) => { - return await Agent.findOneAndDelete(searchParameter); + const agent = await Agent.findOneAndDelete(searchParameter); + if (agent) { + await removeAgentFromAllProjects(agent.id); + } + return agent; }; /** @@ -58,11 +68,24 @@ const deleteAgent = async (searchParameter) => { * @returns {Promise} A promise that resolves to an object containing the agents data and pagination info. */ const getListAgents = async (searchParameter) => { - const agents = await Agent.find(searchParameter, { + const { author, ...otherParams } = searchParameter; + + let query = Object.assign({ author }, otherParams); + + const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, ['agentIds']); + if (globalProject && (globalProject.agentIds?.length ?? 0) > 0) { + const globalQuery = { id: { $in: globalProject.agentIds }, ...otherParams }; + delete globalQuery.author; + query = { $or: [globalQuery, query] }; + } + + const agents = await Agent.find(query, { id: 1, name: 1, avatar: 1, + projectIds: 1, }).lean(); + const hasMore = agents.length > 0; const firstId = agents.length > 0 ? 
agents[0].id : null; const lastId = agents.length > 0 ? agents[agents.length - 1].id : null; @@ -75,10 +98,45 @@ const getListAgents = async (searchParameter) => { }; }; +/** + * Updates the projects associated with an agent, adding and removing project IDs as specified. + * This function also updates the corresponding projects to include or exclude the agent ID. + * + * @param {string} agentId - The ID of the agent to update. + * @param {string[]} [projectIds] - Array of project IDs to add to the agent. + * @param {string[]} [removeProjectIds] - Array of project IDs to remove from the agent. + * @returns {Promise} The updated agent document. + * @throws {Error} If there's an error updating the agent or projects. + */ +const updateAgentProjects = async (agentId, projectIds, removeProjectIds) => { + const updateOps = {}; + + if (removeProjectIds && removeProjectIds.length > 0) { + for (const projectId of removeProjectIds) { + await removeAgentIdsFromProject(projectId, [agentId]); + } + updateOps.$pull = { projectIds: { $in: removeProjectIds } }; + } + + if (projectIds && projectIds.length > 0) { + for (const projectId of projectIds) { + await addAgentIdsToProject(projectId, [agentId]); + } + updateOps.$addToSet = { projectIds: { $each: projectIds } }; + } + + if (Object.keys(updateOps).length === 0) { + return await getAgent({ id: agentId }); + } + + return await updateAgent({ id: agentId }, updateOps); +}; + module.exports = { createAgent, getAgent, updateAgent, deleteAgent, getListAgents, + updateAgentProjects, }; diff --git a/api/models/Assistant.js b/api/models/Assistant.js index 2c98287a889..d0e73ad4e7b 100644 --- a/api/models/Assistant.js +++ b/api/models/Assistant.js @@ -5,17 +5,16 @@ const Assistant = mongoose.model('assistant', assistantSchema); /** * Update an assistant with new data without overwriting existing properties, - * or create a new assistant if it doesn't exist, within a transaction session if provided. 
+ * or create a new assistant if it doesn't exist. * * @param {Object} searchParams - The search parameters to find the assistant to update. * @param {string} searchParams.assistant_id - The ID of the assistant to update. * @param {string} searchParams.user - The user ID of the assistant's author. * @param {Object} updateData - An object containing the properties to update. - * @param {mongoose.ClientSession} [session] - The transaction session to use (optional). * @returns {Promise} The updated or newly created assistant document as a plain object. */ -const updateAssistantDoc = async (searchParams, updateData, session = null) => { - const options = { new: true, upsert: true, session }; +const updateAssistantDoc = async (searchParams, updateData) => { + const options = { new: true, upsert: true }; return await Assistant.findOneAndUpdate(searchParams, updateData, options).lean(); }; diff --git a/api/models/Banner.js b/api/models/Banner.js new file mode 100644 index 00000000000..8d439dae289 --- /dev/null +++ b/api/models/Banner.js @@ -0,0 +1,27 @@ +const Banner = require('./schema/banner'); +const logger = require('~/config/winston'); +/** + * Retrieves the current active banner. + * @returns {Promise} The active banner object or null if no active banner is found. 
+ */ +const getBanner = async (user) => { + try { + const now = new Date(); + const banner = await Banner.findOne({ + displayFrom: { $lte: now }, + $or: [{ displayTo: { $gte: now } }, { displayTo: null }], + type: 'banner', + }).lean(); + + if (!banner || banner.isPublic || user) { + return banner; + } + + return null; + } catch (error) { + logger.error('[getBanners] Error getting banners', error); + throw new Error('Error getting banners'); + } +}; + +module.exports = { getBanner }; diff --git a/api/models/Conversation.js b/api/models/Conversation.js index 19622ba7962..0850ed0a71b 100644 --- a/api/models/Conversation.js +++ b/api/models/Conversation.js @@ -31,9 +31,39 @@ const getConvo = async (user, conversationId) => { } }; +const deleteNullOrEmptyConversations = async () => { + try { + const filter = { + $or: [ + { conversationId: null }, + { conversationId: '' }, + { conversationId: { $exists: false } }, + ], + }; + + const result = await Conversation.deleteMany(filter); + + // Delete associated messages + const messageDeleteResult = await deleteMessages(filter); + + logger.info( + `[deleteNullOrEmptyConversations] Deleted ${result.deletedCount} conversations and ${messageDeleteResult.deletedCount} messages`, + ); + + return { + conversations: result, + messages: messageDeleteResult, + }; + } catch (error) { + logger.error('[deleteNullOrEmptyConversations] Error deleting conversations', error); + throw new Error('Error deleting conversations with null or empty conversationId'); + } +}; + module.exports = { Conversation, searchConversation, + deleteNullOrEmptyConversations, /** * Saves a conversation to the database. * @param {Object} req - The request object. 
diff --git a/api/models/Project.js b/api/models/Project.js index e982e34b5d6..17ef3093a52 100644 --- a/api/models/Project.js +++ b/api/models/Project.js @@ -1,4 +1,5 @@ const { model } = require('mongoose'); +const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants; const projectSchema = require('~/models/schema/projectSchema'); const Project = model('Project', projectSchema); @@ -33,7 +34,7 @@ const getProjectByName = async function (projectName, fieldsToSelect = null) { const update = { $setOnInsert: { name: projectName } }; const options = { new: true, - upsert: projectName === 'instance', + upsert: projectName === GLOBAL_PROJECT_NAME, lean: true, select: fieldsToSelect, }; @@ -81,10 +82,55 @@ const removeGroupFromAllProjects = async (promptGroupId) => { await Project.updateMany({}, { $pull: { promptGroupIds: promptGroupId } }); }; +/** + * Add an array of agent IDs to a project's agentIds array, ensuring uniqueness. + * + * @param {string} projectId - The ID of the project to update. + * @param {string[]} agentIds - The array of agent IDs to add to the project. + * @returns {Promise} The updated project document. + */ +const addAgentIdsToProject = async function (projectId, agentIds) { + return await Project.findByIdAndUpdate( + projectId, + { $addToSet: { agentIds: { $each: agentIds } } }, + { new: true }, + ); +}; + +/** + * Remove an array of agent IDs from a project's agentIds array. + * + * @param {string} projectId - The ID of the project to update. + * @param {string[]} agentIds - The array of agent IDs to remove from the project. + * @returns {Promise} The updated project document. + */ +const removeAgentIdsFromProject = async function (projectId, agentIds) { + return await Project.findByIdAndUpdate( + projectId, + { $pull: { agentIds: { $in: agentIds } } }, + { new: true }, + ); +}; + +/** + * Remove an agent ID from all projects. + * + * @param {string} agentId - The ID of the agent to remove from projects. 
+ * @returns {Promise} + */ +const removeAgentFromAllProjects = async (agentId) => { + await Project.updateMany({}, { $pull: { agentIds: agentId } }); +}; + module.exports = { getProjectById, getProjectByName, + /* prompts */ addGroupIdsToProject, removeGroupIdsFromProject, removeGroupFromAllProjects, + /* agents */ + addAgentIdsToProject, + removeAgentIdsFromProject, + removeAgentFromAllProjects, }; diff --git a/api/models/Prompt.js b/api/models/Prompt.js index 56dcd785709..548589b4d7e 100644 --- a/api/models/Prompt.js +++ b/api/models/Prompt.js @@ -1,5 +1,5 @@ const { ObjectId } = require('mongodb'); -const { SystemRoles, SystemCategories } = require('librechat-data-provider'); +const { SystemRoles, SystemCategories, Constants } = require('librechat-data-provider'); const { getProjectByName, addGroupIdsToProject, @@ -123,7 +123,7 @@ const getAllPromptGroups = async (req, filter) => { let combinedQuery = query; if (searchShared) { - const project = await getProjectByName('instance', 'promptGroupIds'); + const project = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, 'promptGroupIds'); if (project && project.promptGroupIds.length > 0) { const projectQuery = { _id: { $in: project.promptGroupIds }, ...query }; delete projectQuery.author; @@ -177,7 +177,7 @@ const getPromptGroups = async (req, filter) => { if (searchShared) { // const projects = req.user.projects || []; // TODO: handle multiple projects - const project = await getProjectByName('instance', 'promptGroupIds'); + const project = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, 'promptGroupIds'); if (project && project.promptGroupIds.length > 0) { const projectQuery = { _id: { $in: project.promptGroupIds }, ...query }; delete projectQuery.author; diff --git a/api/models/Role.js b/api/models/Role.js index d21efee3b88..9c160512b7d 100644 --- a/api/models/Role.js +++ b/api/models/Role.js @@ -4,8 +4,10 @@ const { roleDefaults, PermissionTypes, removeNullishValues, + agentPermissionsSchema, 
promptPermissionsSchema, bookmarkPermissionsSchema, + multiConvoPermissionsSchema, } = require('librechat-data-provider'); const getLogStores = require('~/cache/getLogStores'); const Role = require('~/models/schema/roleSchema'); @@ -71,8 +73,10 @@ const updateRoleByName = async function (roleName, updates) { }; const permissionSchemas = { + [PermissionTypes.AGENTS]: agentPermissionsSchema, [PermissionTypes.PROMPTS]: promptPermissionsSchema, [PermissionTypes.BOOKMARKS]: bookmarkPermissionsSchema, + [PermissionTypes.MULTI_CONVO]: multiConvoPermissionsSchema, }; /** @@ -130,6 +134,7 @@ async function updateAccessPermissions(roleName, permissionsUpdate) { /** * Initialize default roles in the system. * Creates the default roles (ADMIN, USER) if they don't exist in the database. + * Updates existing roles with new permission types if they're missing. * * @returns {Promise} */ @@ -137,14 +142,27 @@ const initializeRoles = async function () { const defaultRoles = [SystemRoles.ADMIN, SystemRoles.USER]; for (const roleName of defaultRoles) { - let role = await Role.findOne({ name: roleName }).select('name').lean(); + let role = await Role.findOne({ name: roleName }); + if (!role) { + // Create new role if it doesn't exist role = new Role(roleDefaults[roleName]); - await role.save(); + } else { + // Add missing permission types + let isUpdated = false; + for (const permType of Object.values(PermissionTypes)) { + if (!role[permType]) { + role[permType] = roleDefaults[roleName][permType]; + isUpdated = true; + } + } + if (isUpdated) { + await role.save(); + } } + await role.save(); } }; - module.exports = { getRoleByName, initializeRoles, diff --git a/api/models/Role.spec.js b/api/models/Role.spec.js index c183b9d1c35..92386f0fa91 100644 --- a/api/models/Role.spec.js +++ b/api/models/Role.spec.js @@ -1,9 +1,14 @@ const mongoose = require('mongoose'); const { MongoMemoryServer } = require('mongodb-memory-server'); -const { SystemRoles, PermissionTypes } = 
require('librechat-data-provider'); -const Role = require('~/models/schema/roleSchema'); -const { updateAccessPermissions } = require('~/models/Role'); +const { + SystemRoles, + PermissionTypes, + roleDefaults, + Permissions, +} = require('librechat-data-provider'); +const { updateAccessPermissions, initializeRoles } = require('~/models/Role'); const getLogStores = require('~/cache/getLogStores'); +const Role = require('~/models/schema/roleSchema'); // Mock the cache jest.mock('~/cache/getLogStores', () => { @@ -194,4 +199,222 @@ describe('updateAccessPermissions', () => { SHARED_GLOBAL: true, }); }); + + it('should update MULTI_CONVO permissions', async () => { + await new Role({ + name: SystemRoles.USER, + [PermissionTypes.MULTI_CONVO]: { + USE: false, + }, + }).save(); + + await updateAccessPermissions(SystemRoles.USER, { + [PermissionTypes.MULTI_CONVO]: { + USE: true, + }, + }); + + const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean(); + expect(updatedRole[PermissionTypes.MULTI_CONVO]).toEqual({ + USE: true, + }); + }); + + it('should update MULTI_CONVO permissions along with other permission types', async () => { + await new Role({ + name: SystemRoles.USER, + [PermissionTypes.PROMPTS]: { + CREATE: true, + USE: true, + SHARED_GLOBAL: false, + }, + [PermissionTypes.MULTI_CONVO]: { + USE: false, + }, + }).save(); + + await updateAccessPermissions(SystemRoles.USER, { + [PermissionTypes.PROMPTS]: { SHARED_GLOBAL: true }, + [PermissionTypes.MULTI_CONVO]: { USE: true }, + }); + + const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean(); + expect(updatedRole[PermissionTypes.PROMPTS]).toEqual({ + CREATE: true, + USE: true, + SHARED_GLOBAL: true, + }); + expect(updatedRole[PermissionTypes.MULTI_CONVO]).toEqual({ + USE: true, + }); + }); + + it('should not update MULTI_CONVO permissions when no changes are needed', async () => { + await new Role({ + name: SystemRoles.USER, + [PermissionTypes.MULTI_CONVO]: { + USE: true, + }, + 
}).save(); + + await updateAccessPermissions(SystemRoles.USER, { + [PermissionTypes.MULTI_CONVO]: { + USE: true, + }, + }); + + const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean(); + expect(updatedRole[PermissionTypes.MULTI_CONVO]).toEqual({ + USE: true, + }); + }); +}); + +describe('initializeRoles', () => { + beforeEach(async () => { + await Role.deleteMany({}); + }); + + it('should create default roles if they do not exist', async () => { + await initializeRoles(); + + const adminRole = await Role.findOne({ name: SystemRoles.ADMIN }).lean(); + const userRole = await Role.findOne({ name: SystemRoles.USER }).lean(); + + expect(adminRole).toBeTruthy(); + expect(userRole).toBeTruthy(); + + // Check if all permission types exist + Object.values(PermissionTypes).forEach((permType) => { + expect(adminRole[permType]).toBeDefined(); + expect(userRole[permType]).toBeDefined(); + }); + + // Check if permissions match defaults (example for ADMIN role) + expect(adminRole[PermissionTypes.PROMPTS].SHARED_GLOBAL).toBe(true); + expect(adminRole[PermissionTypes.BOOKMARKS].USE).toBe(true); + expect(adminRole[PermissionTypes.AGENTS].CREATE).toBe(true); + }); + + it('should not modify existing permissions for existing roles', async () => { + const customUserRole = { + name: SystemRoles.USER, + [PermissionTypes.PROMPTS]: { + [Permissions.USE]: false, + [Permissions.CREATE]: true, + [Permissions.SHARED_GLOBAL]: true, + }, + [PermissionTypes.BOOKMARKS]: { + [Permissions.USE]: false, + }, + }; + + await new Role(customUserRole).save(); + + await initializeRoles(); + + const userRole = await Role.findOne({ name: SystemRoles.USER }).lean(); + + expect(userRole[PermissionTypes.PROMPTS]).toEqual(customUserRole[PermissionTypes.PROMPTS]); + expect(userRole[PermissionTypes.BOOKMARKS]).toEqual(customUserRole[PermissionTypes.BOOKMARKS]); + expect(userRole[PermissionTypes.AGENTS]).toBeDefined(); + }); + + it('should add new permission types to existing roles', async () => { 
+ const partialUserRole = { + name: SystemRoles.USER, + [PermissionTypes.PROMPTS]: roleDefaults[SystemRoles.USER][PermissionTypes.PROMPTS], + [PermissionTypes.BOOKMARKS]: roleDefaults[SystemRoles.USER][PermissionTypes.BOOKMARKS], + }; + + await new Role(partialUserRole).save(); + + await initializeRoles(); + + const userRole = await Role.findOne({ name: SystemRoles.USER }).lean(); + + expect(userRole[PermissionTypes.AGENTS]).toBeDefined(); + expect(userRole[PermissionTypes.AGENTS].CREATE).toBeDefined(); + expect(userRole[PermissionTypes.AGENTS].USE).toBeDefined(); + expect(userRole[PermissionTypes.AGENTS].SHARED_GLOBAL).toBeDefined(); + }); + + it('should handle multiple runs without duplicating or modifying data', async () => { + await initializeRoles(); + await initializeRoles(); + + const adminRoles = await Role.find({ name: SystemRoles.ADMIN }); + const userRoles = await Role.find({ name: SystemRoles.USER }); + + expect(adminRoles).toHaveLength(1); + expect(userRoles).toHaveLength(1); + + const adminRole = adminRoles[0].toObject(); + const userRole = userRoles[0].toObject(); + + // Check if all permission types exist + Object.values(PermissionTypes).forEach((permType) => { + expect(adminRole[permType]).toBeDefined(); + expect(userRole[permType]).toBeDefined(); + }); + }); + + it('should update roles with missing permission types from roleDefaults', async () => { + const partialAdminRole = { + name: SystemRoles.ADMIN, + [PermissionTypes.PROMPTS]: { + [Permissions.USE]: false, + [Permissions.CREATE]: false, + [Permissions.SHARED_GLOBAL]: false, + }, + [PermissionTypes.BOOKMARKS]: roleDefaults[SystemRoles.ADMIN][PermissionTypes.BOOKMARKS], + }; + + await new Role(partialAdminRole).save(); + + await initializeRoles(); + + const adminRole = await Role.findOne({ name: SystemRoles.ADMIN }).lean(); + + expect(adminRole[PermissionTypes.PROMPTS]).toEqual(partialAdminRole[PermissionTypes.PROMPTS]); + expect(adminRole[PermissionTypes.AGENTS]).toBeDefined(); + 
expect(adminRole[PermissionTypes.AGENTS].CREATE).toBeDefined(); + expect(adminRole[PermissionTypes.AGENTS].USE).toBeDefined(); + expect(adminRole[PermissionTypes.AGENTS].SHARED_GLOBAL).toBeDefined(); + }); + + it('should include MULTI_CONVO permissions when creating default roles', async () => { + await initializeRoles(); + + const adminRole = await Role.findOne({ name: SystemRoles.ADMIN }).lean(); + const userRole = await Role.findOne({ name: SystemRoles.USER }).lean(); + + expect(adminRole[PermissionTypes.MULTI_CONVO]).toBeDefined(); + expect(userRole[PermissionTypes.MULTI_CONVO]).toBeDefined(); + + // Check if MULTI_CONVO permissions match defaults + expect(adminRole[PermissionTypes.MULTI_CONVO].USE).toBe( + roleDefaults[SystemRoles.ADMIN][PermissionTypes.MULTI_CONVO].USE, + ); + expect(userRole[PermissionTypes.MULTI_CONVO].USE).toBe( + roleDefaults[SystemRoles.USER][PermissionTypes.MULTI_CONVO].USE, + ); + }); + + it('should add MULTI_CONVO permissions to existing roles without them', async () => { + const partialUserRole = { + name: SystemRoles.USER, + [PermissionTypes.PROMPTS]: roleDefaults[SystemRoles.USER][PermissionTypes.PROMPTS], + [PermissionTypes.BOOKMARKS]: roleDefaults[SystemRoles.USER][PermissionTypes.BOOKMARKS], + }; + + await new Role(partialUserRole).save(); + + await initializeRoles(); + + const userRole = await Role.findOne({ name: SystemRoles.USER }).lean(); + + expect(userRole[PermissionTypes.MULTI_CONVO]).toBeDefined(); + expect(userRole[PermissionTypes.MULTI_CONVO].USE).toBeDefined(); + }); }); diff --git a/api/models/inviteUser.js b/api/models/inviteUser.js index c04bd9467a5..6cd699fd668 100644 --- a/api/models/inviteUser.js +++ b/api/models/inviteUser.js @@ -1,6 +1,5 @@ -const crypto = require('crypto'); -const bcrypt = require('bcryptjs'); const mongoose = require('mongoose'); +const { getRandomValues, hashToken } = require('~/server/utils/crypto'); const { createToken, findToken } = require('./Token'); const logger = 
require('~/config/winston'); @@ -18,8 +17,8 @@ const logger = require('~/config/winston'); */ const createInvite = async (email) => { try { - let token = crypto.randomBytes(32).toString('hex'); - const hash = bcrypt.hashSync(token, 10); + const token = await getRandomValues(32); + const hash = await hashToken(token); const encodedToken = encodeURIComponent(token); const fakeUserId = new mongoose.Types.ObjectId(); @@ -50,7 +49,7 @@ const createInvite = async (email) => { const getInvite = async (encodedToken, email) => { try { const token = decodeURIComponent(encodedToken); - const hash = bcrypt.hashSync(token, 10); + const hash = await hashToken(token); const invite = await findToken({ token: hash, email }); if (!invite) { @@ -59,7 +58,7 @@ const getInvite = async (encodedToken, email) => { return invite; } catch (error) { - logger.error('[getInvite] Error getting invite', error); + logger.error('[getInvite] Error getting invite:', error); return { error: true, message: error.message }; } }; diff --git a/api/models/schema/agent.js b/api/models/schema/agent.js index 97f0527916c..819398ee7cb 100644 --- a/api/models/schema/agent.js +++ b/api/models/schema/agent.js @@ -57,6 +57,11 @@ const agentSchema = mongoose.Schema( ref: 'User', required: true, }, + projectIds: { + type: [mongoose.Schema.Types.ObjectId], + ref: 'Project', + index: true, + }, }, { timestamps: true, diff --git a/api/models/schema/banner.js b/api/models/schema/banner.js new file mode 100644 index 00000000000..7fd86c1b677 --- /dev/null +++ b/api/models/schema/banner.js @@ -0,0 +1,36 @@ +const mongoose = require('mongoose'); + +const bannerSchema = mongoose.Schema( + { + bannerId: { + type: String, + required: true, + }, + message: { + type: String, + required: true, + }, + displayFrom: { + type: Date, + required: true, + default: Date.now, + }, + displayTo: { + type: Date, + }, + type: { + type: String, + enum: ['banner', 'popup'], + default: 'banner', + }, + isPublic: { + type: Boolean, + default: 
false, + }, + }, + + { timestamps: true }, +); + +const Banner = mongoose.model('Banner', bannerSchema); +module.exports = Banner; diff --git a/api/models/schema/defaults.js b/api/models/schema/defaults.js index 4a99a683734..6dced3af86c 100644 --- a/api/models/schema/defaults.js +++ b/api/models/schema/defaults.js @@ -13,6 +13,11 @@ const conversationPreset = { type: String, required: false, }, + // for bedrock only + region: { + type: String, + required: false, + }, // for azureOpenAI, openAI only chatGptLabel: { type: String, @@ -78,6 +83,9 @@ const conversationPreset = { promptCache: { type: Boolean, }, + system: { + type: String, + }, // files resendFiles: { type: Boolean, diff --git a/api/models/schema/projectSchema.js b/api/models/schema/projectSchema.js index 0e27c6a8f9f..dfa68a06c22 100644 --- a/api/models/schema/projectSchema.js +++ b/api/models/schema/projectSchema.js @@ -21,6 +21,11 @@ const projectSchema = new Schema( ref: 'PromptGroup', default: [], }, + agentIds: { + type: [String], + ref: 'Agent', + default: [], + }, }, { timestamps: true, diff --git a/api/models/schema/roleSchema.js b/api/models/schema/roleSchema.js index ebd1d0bc4b2..36e9d3f7b6e 100644 --- a/api/models/schema/roleSchema.js +++ b/api/models/schema/roleSchema.js @@ -28,6 +28,26 @@ const roleSchema = new mongoose.Schema({ default: true, }, }, + [PermissionTypes.AGENTS]: { + [Permissions.SHARED_GLOBAL]: { + type: Boolean, + default: false, + }, + [Permissions.USE]: { + type: Boolean, + default: true, + }, + [Permissions.CREATE]: { + type: Boolean, + default: true, + }, + }, + [PermissionTypes.MULTI_CONVO]: { + [Permissions.USE]: { + type: Boolean, + default: true, + }, + }, }); const Role = mongoose.model('Role', roleSchema); diff --git a/api/models/tx.js b/api/models/tx.js index 1b515cca2af..062816cf35f 100644 --- a/api/models/tx.js +++ b/api/models/tx.js @@ -3,38 +3,28 @@ const defaultRate = 6; /** AWS Bedrock pricing */ const bedrockValues = { - 
'anthropic.claude-3-haiku-20240307-v1:0': { prompt: 0.25, completion: 1.25 }, - 'anthropic.claude-3-sonnet-20240229-v1:0': { prompt: 3.0, completion: 15.0 }, - 'anthropic.claude-3-opus-20240229-v1:0': { prompt: 15.0, completion: 75.0 }, - 'anthropic.claude-3-5-sonnet-20240620-v1:0': { prompt: 3.0, completion: 15.0 }, - 'anthropic.claude-v2:1': { prompt: 8.0, completion: 24.0 }, - 'anthropic.claude-instant-v1': { prompt: 0.8, completion: 2.4 }, - 'meta.llama2-13b-chat-v1': { prompt: 0.75, completion: 1.0 }, - 'meta.llama2-70b-chat-v1': { prompt: 1.95, completion: 2.56 }, - 'meta.llama3-8b-instruct-v1:0': { prompt: 0.3, completion: 0.6 }, - 'meta.llama3-70b-instruct-v1:0': { prompt: 2.65, completion: 3.5 }, - 'meta.llama3-1-8b-instruct-v1:0': { prompt: 0.3, completion: 0.6 }, - 'meta.llama3-1-70b-instruct-v1:0': { prompt: 2.65, completion: 3.5 }, - 'meta.llama3-1-405b-instruct-v1:0': { prompt: 5.32, completion: 16.0 }, - 'mistral.mistral-7b-instruct-v0:2': { prompt: 0.15, completion: 0.2 }, - 'mistral.mistral-small-2402-v1:0': { prompt: 0.15, completion: 0.2 }, - 'mistral.mixtral-8x7b-instruct-v0:1': { prompt: 0.45, completion: 0.7 }, - 'mistral.mistral-large-2402-v1:0': { prompt: 4.0, completion: 12.0 }, - 'mistral.mistral-large-2407-v1:0': { prompt: 3.0, completion: 9.0 }, - 'cohere.command-text-v14': { prompt: 1.5, completion: 2.0 }, - 'cohere.command-light-text-v14': { prompt: 0.3, completion: 0.6 }, - 'cohere.command-r-v1:0': { prompt: 0.5, completion: 1.5 }, - 'cohere.command-r-plus-v1:0': { prompt: 3.0, completion: 15.0 }, + 'llama2-13b': { prompt: 0.75, completion: 1.0 }, + 'llama2-70b': { prompt: 1.95, completion: 2.56 }, + 'llama3-8b': { prompt: 0.3, completion: 0.6 }, + 'llama3-70b': { prompt: 2.65, completion: 3.5 }, + 'llama3-1-8b': { prompt: 0.3, completion: 0.6 }, + 'llama3-1-70b': { prompt: 2.65, completion: 3.5 }, + 'llama3-1-405b': { prompt: 5.32, completion: 16.0 }, + 'mistral-7b': { prompt: 0.15, completion: 0.2 }, + 'mistral-small': { prompt: 
0.15, completion: 0.2 }, + 'mixtral-8x7b': { prompt: 0.45, completion: 0.7 }, + 'mistral-large-2402': { prompt: 4.0, completion: 12.0 }, + 'mistral-large-2407': { prompt: 3.0, completion: 9.0 }, + 'command-text': { prompt: 1.5, completion: 2.0 }, + 'command-light': { prompt: 0.3, completion: 0.6 }, 'ai21.j2-mid-v1': { prompt: 12.5, completion: 12.5 }, 'ai21.j2-ultra-v1': { prompt: 18.8, completion: 18.8 }, + 'ai21.jamba-instruct-v1:0': { prompt: 0.5, completion: 0.7 }, 'amazon.titan-text-lite-v1': { prompt: 0.15, completion: 0.2 }, 'amazon.titan-text-express-v1': { prompt: 0.2, completion: 0.6 }, + 'amazon.titan-text-premier-v1:0': { prompt: 0.5, completion: 1.5 }, }; -for (const [key, value] of Object.entries(bedrockValues)) { - bedrockValues[`bedrock/${key}`] = value; -} - /** * Mapping of model token sizes to their respective multipliers for prompt and completion. * The rates are 1 USD per 1M tokens. @@ -47,6 +37,9 @@ const tokenValues = Object.assign( '4k': { prompt: 1.5, completion: 2 }, '16k': { prompt: 3, completion: 4 }, 'gpt-3.5-turbo-1106': { prompt: 1, completion: 2 }, + 'o1-preview': { prompt: 15, completion: 60 }, + 'o1-mini': { prompt: 3, completion: 12 }, + o1: { prompt: 15, completion: 60 }, 'gpt-4o-2024-08-06': { prompt: 2.5, completion: 10 }, 'gpt-4o-mini': { prompt: 0.15, completion: 0.6 }, 'gpt-4o': { prompt: 5, completion: 15 }, @@ -59,6 +52,7 @@ const tokenValues = Object.assign( 'claude-3-haiku': { prompt: 0.25, completion: 1.25 }, 'claude-2.1': { prompt: 8, completion: 24 }, 'claude-2': { prompt: 8, completion: 24 }, + 'claude-instant': { prompt: 0.8, completion: 2.4 }, 'claude-': { prompt: 0.8, completion: 2.4 }, 'command-r-plus': { prompt: 3, completion: 15 }, 'command-r': { prompt: 0.5, completion: 1.5 }, @@ -104,6 +98,12 @@ const getValueKey = (model, endpoint) => { return 'gpt-3.5-turbo-1106'; } else if (modelName.includes('gpt-3.5')) { return '4k'; + } else if (modelName.includes('o1-preview')) { + return 'o1-preview'; + } else if 
(modelName.includes('o1-mini')) { + return 'o1-mini'; + } else if (modelName.includes('o1')) { + return 'o1'; } else if (modelName.includes('gpt-4o-2024-08-06')) { return 'gpt-4o-2024-08-06'; } else if (modelName.includes('gpt-4o-mini')) { diff --git a/api/models/tx.spec.js b/api/models/tx.spec.js index f0a118c0122..c8a8b335e3c 100644 --- a/api/models/tx.spec.js +++ b/api/models/tx.spec.js @@ -1,3 +1,4 @@ +const { EModelEndpoint } = require('librechat-data-provider'); const { defaultRate, tokenValues, @@ -224,34 +225,18 @@ describe('AWS Bedrock Model Tests', () => { it('should return the correct prompt multipliers for all models', () => { const results = awsModels.map((model) => { - const multiplier = getMultiplier({ valueKey: model, tokenType: 'prompt' }); - return multiplier === tokenValues[model].prompt; + const valueKey = getValueKey(model, EModelEndpoint.bedrock); + const multiplier = getMultiplier({ valueKey, tokenType: 'prompt' }); + return tokenValues[valueKey].prompt && multiplier === tokenValues[valueKey].prompt; }); expect(results.every(Boolean)).toBe(true); }); it('should return the correct completion multipliers for all models', () => { const results = awsModels.map((model) => { - const multiplier = getMultiplier({ valueKey: model, tokenType: 'completion' }); - return multiplier === tokenValues[model].completion; - }); - expect(results.every(Boolean)).toBe(true); - }); - - it('should return the correct prompt multipliers for all models with Bedrock prefix', () => { - const results = awsModels.map((model) => { - const modelName = `bedrock/${model}`; - const multiplier = getMultiplier({ valueKey: modelName, tokenType: 'prompt' }); - return multiplier === tokenValues[model].prompt; - }); - expect(results.every(Boolean)).toBe(true); - }); - - it('should return the correct completion multipliers for all models with Bedrock prefix', () => { - const results = awsModels.map((model) => { - const modelName = `bedrock/${model}`; - const multiplier = 
getMultiplier({ valueKey: modelName, tokenType: 'completion' }); - return multiplier === tokenValues[model].completion; + const valueKey = getValueKey(model, EModelEndpoint.bedrock); + const multiplier = getMultiplier({ valueKey, tokenType: 'completion' }); + return tokenValues[valueKey].completion && multiplier === tokenValues[valueKey].completion; }); expect(results.every(Boolean)).toBe(true); }); diff --git a/api/package.json b/api/package.json index 43d8609a8e5..1c74543a67e 100644 --- a/api/package.json +++ b/api/package.json @@ -1,6 +1,6 @@ { "name": "@librechat/backend", - "version": "v0.7.5-rc1", + "version": "v0.7.5-rc2", "description": "", "scripts": { "start": "echo 'please run this from the root directory'", @@ -43,8 +43,8 @@ "@langchain/core": "^0.2.18", "@langchain/google-genai": "^0.0.11", "@langchain/google-vertexai": "^0.0.17", - "@librechat/agents": "^1.4.1", - "axios": "^1.3.4", + "@librechat/agents": "^1.5.2", + "axios": "^1.7.7", "bcryptjs": "^2.4.3", "cheerio": "^1.0.0-rc.12", "cohere-ai": "^7.9.1", @@ -55,7 +55,7 @@ "cors": "^2.8.5", "dedent": "^1.5.3", "dotenv": "^16.0.3", - "express": "^4.18.2", + "express": "^4.21.0", "express-mongo-sanitize": "^2.2.0", "express-rate-limit": "^6.9.0", "express-session": "^1.17.3", @@ -76,11 +76,11 @@ "meilisearch": "^0.38.0", "mime": "^3.0.0", "module-alias": "^2.2.3", - "mongoose": "^7.1.1", + "mongoose": "^7.3.3", "multer": "^1.4.5-lts.1", "nanoid": "^3.3.7", "nodejs-gpt": "^1.37.4", - "nodemailer": "^6.9.4", + "nodemailer": "^6.9.15", "ollama": "^0.5.0", "openai": "^4.47.1", "openai-chat-tokens": "^0.2.8", @@ -101,7 +101,6 @@ "ua-parser-js": "^1.0.36", "winston": "^3.11.0", "winston-daily-rotate-file": "^4.7.1", - "ws": "^8.17.0", "zod": "^3.22.4" }, "devDependencies": { diff --git a/api/server/controllers/AskController.js b/api/server/controllers/AskController.js index ce6e0fb1726..d2d774b0092 100644 --- a/api/server/controllers/AskController.js +++ b/api/server/controllers/AskController.js @@ -16,7 
+16,12 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { overrideParentMessageId = null, } = req.body; - logger.debug('[AskController]', { text, conversationId, ...endpointOption }); + logger.debug('[AskController]', { + text, + conversationId, + ...endpointOption, + modelsConfig: endpointOption.modelsConfig ? 'exists' : '', + }); let userMessage; let userMessagePromise; @@ -123,11 +128,6 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { }; let response = await client.sendMessage(text, messageOptions); - - if (overrideParentMessageId) { - response.parentMessageId = overrideParentMessageId; - } - response.endpoint = endpointOption.endpoint; const { conversation = {} } = await client.responsePromise; diff --git a/api/server/controllers/EditController.js b/api/server/controllers/EditController.js index b3b94fcebba..28fe2c4fea1 100644 --- a/api/server/controllers/EditController.js +++ b/api/server/controllers/EditController.js @@ -25,6 +25,7 @@ const EditController = async (req, res, next, initializeClient) => { isContinued, conversationId, ...endpointOption, + modelsConfig: endpointOption.modelsConfig ? 
'exists' : '', }); let userMessage; diff --git a/api/server/controllers/EndpointController.js b/api/server/controllers/EndpointController.js index d80ea6b14f9..1e716870c32 100644 --- a/api/server/controllers/EndpointController.js +++ b/api/server/controllers/EndpointController.js @@ -44,6 +44,14 @@ async function endpointController(req, res) { }; } + if (mergedConfig[EModelEndpoint.bedrock] && req.app.locals?.[EModelEndpoint.bedrock]) { + const { availableRegions } = req.app.locals[EModelEndpoint.bedrock]; + mergedConfig[EModelEndpoint.bedrock] = { + ...mergedConfig[EModelEndpoint.bedrock], + availableRegions, + }; + } + const endpointsConfig = orderEndpointsConfig(mergedConfig); await cache.set(CacheKeys.ENDPOINT_CONFIG, endpointsConfig); diff --git a/api/server/controllers/agents/callbacks.js b/api/server/controllers/agents/callbacks.js index 9649f56a53b..f6c1972b4fe 100644 --- a/api/server/controllers/agents/callbacks.js +++ b/api/server/controllers/agents/callbacks.js @@ -1,7 +1,10 @@ const { GraphEvents, ToolEndHandler, ChatModelStreamHandler } = require('@librechat/agents'); +/** @typedef {import('@librechat/agents').Graph} Graph */ /** @typedef {import('@librechat/agents').EventHandler} EventHandler */ +/** @typedef {import('@librechat/agents').ModelEndData} ModelEndData */ /** @typedef {import('@librechat/agents').ChatModelStreamHandler} ChatModelStreamHandler */ +/** @typedef {import('@librechat/agents').ContentAggregatorResult['aggregateContent']} ContentAggregator */ /** @typedef {import('@librechat/agents').GraphEvents} GraphEvents */ /** @@ -18,18 +21,55 @@ const sendEvent = (res, event) => { res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`); }; +class ModelEndHandler { + /** + * @param {Array} collectedUsage + */ + constructor(collectedUsage) { + if (!Array.isArray(collectedUsage)) { + throw new Error('collectedUsage must be an array'); + } + this.collectedUsage = collectedUsage; + } + + /** + * @param {string} event + * @param 
{ModelEndData | undefined} data + * @param {Record | undefined} metadata + * @param {Graph} graph + * @returns + */ + handle(event, data, metadata, graph) { + if (!graph || !metadata) { + console.warn(`Graph or metadata not found in ${event} event`); + return; + } + + const usage = data?.output?.usage_metadata; + + if (usage) { + this.collectedUsage.push(usage); + } + } +} + /** * Get default handlers for stream events. - * @param {{ res?: ServerResponse }} options - The options object. + * @param {Object} options - The options object. + * @param {ServerResponse} options.res - The options object. + * @param {ContentAggregator} options.aggregateContent - The options object. + * @param {Array} options.collectedUsage - The list of collected usage metadata. * @returns {Record} The default handlers. * @throws {Error} If the request is not found. */ -function getDefaultHandlers({ res }) { - if (!res) { - throw new Error('Request not found'); +function getDefaultHandlers({ res, aggregateContent, collectedUsage }) { + if (!res || !aggregateContent) { + throw new Error( + `[getDefaultHandlers] Missing required options: res: ${!res}, aggregateContent: ${!aggregateContent}`, + ); } const handlers = { - // [GraphEvents.CHAT_MODEL_END]: new ModelEndHandler(), + [GraphEvents.CHAT_MODEL_END]: new ModelEndHandler(collectedUsage), [GraphEvents.TOOL_END]: new ToolEndHandler(), [GraphEvents.CHAT_MODEL_STREAM]: new ChatModelStreamHandler(), [GraphEvents.ON_RUN_STEP]: { @@ -40,6 +80,7 @@ function getDefaultHandlers({ res }) { */ handle: (event, data) => { sendEvent(res, { event, data }); + aggregateContent({ event, data }); }, }, [GraphEvents.ON_RUN_STEP_DELTA]: { @@ -50,6 +91,7 @@ function getDefaultHandlers({ res }) { */ handle: (event, data) => { sendEvent(res, { event, data }); + aggregateContent({ event, data }); }, }, [GraphEvents.ON_RUN_STEP_COMPLETED]: { @@ -60,6 +102,7 @@ function getDefaultHandlers({ res }) { */ handle: (event, data) => { sendEvent(res, { event, data }); + 
aggregateContent({ event, data }); }, }, [GraphEvents.ON_MESSAGE_DELTA]: { @@ -70,6 +113,7 @@ function getDefaultHandlers({ res }) { */ handle: (event, data) => { sendEvent(res, { event, data }); + aggregateContent({ event, data }); }, }, }; diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index 82e6a6f48a8..137068ddd6a 100644 --- a/api/server/controllers/agents/client.js +++ b/api/server/controllers/agents/client.js @@ -7,9 +7,11 @@ // validateVisionModel, // mapModelToAzureConfig, // } = require('librechat-data-provider'); -const { Callback } = require('@librechat/agents'); +const { Callback, createMetadataAggregator } = require('@librechat/agents'); const { + Constants, EModelEndpoint, + bedrockOutputParser, providerEndpointMap, removeNullishValues, } = require('librechat-data-provider'); @@ -23,15 +25,27 @@ const { formatAgentMessages, createContextHandlers, } = require('~/app/clients/prompts'); +const { encodeAndFormat } = require('~/server/services/Files/images/encode'); const Tokenizer = require('~/server/services/Tokenizer'); +const { spendTokens } = require('~/models/spendTokens'); const BaseClient = require('~/app/clients/BaseClient'); // const { sleep } = require('~/server/utils'); const { createRun } = require('./run'); const { logger } = require('~/config'); +/** @typedef {import('@librechat/agents').MessageContentComplex} MessageContentComplex */ + +// const providerSchemas = { +// [EModelEndpoint.bedrock]: true, +// }; + +const providerParsers = { + [EModelEndpoint.bedrock]: bedrockOutputParser, +}; + class AgentClient extends BaseClient { constructor(options = {}) { - super(options); + super(null, options); /** @type {'discard' | 'summarize'} */ this.contextStrategy = 'discard'; @@ -39,11 +53,31 @@ class AgentClient extends BaseClient { /** @deprecated @type {true} - Is a Chat Completion Request */ this.isChatCompletion = true; - const { maxContextTokens, modelOptions = {}, ...clientOptions } = options; 
+ /** @type {AgentRun} */ + this.run; + + const { + maxContextTokens, + modelOptions = {}, + contentParts, + collectedUsage, + ...clientOptions + } = options; this.modelOptions = modelOptions; this.maxContextTokens = maxContextTokens; - this.options = Object.assign({ endpoint: EModelEndpoint.agents }, clientOptions); + /** @type {MessageContentComplex[]} */ + this.contentParts = contentParts; + /** @type {Array} */ + this.collectedUsage = collectedUsage; + this.options = Object.assign({ endpoint: options.endpoint }, clientOptions); + } + + /** + * Returns the aggregated content parts for the current run. + * @returns {MessageContentComplex[]} */ + getContentParts() { + return this.contentParts; } setOptions(options) { @@ -112,9 +146,27 @@ class AgentClient extends BaseClient { } getSaveOptions() { + const parseOptions = providerParsers[this.options.endpoint]; + let runOptions = + this.options.endpoint === EModelEndpoint.agents + ? { + model: undefined, + // TODO: + // would need to be override settings; otherwise, model needs to be undefined + // model: this.override.model, + // instructions: this.override.instructions, + // additional_instructions: this.override.additional_instructions, + } + : {}; + + if (parseOptions) { + runOptions = parseOptions(this.modelOptions); + } + return removeNullishValues( Object.assign( { + endpoint: this.options.endpoint, agent_id: this.options.agent.id, modelLabel: this.options.modelLabel, maxContextTokens: this.options.maxContextTokens, @@ -122,15 +174,8 @@ class AgentClient extends BaseClient { imageDetail: this.options.imageDetail, spec: this.options.spec, }, - this.modelOptions, - { - model: undefined, - // TODO: - // would need to be override settings; otherwise, model needs to be undefined - // model: this.override.model, - // instructions: this.override.instructions, - // additional_instructions: this.override.additional_instructions, - }, + // TODO: PARSE OPTIONS BY PROVIDER, MAY CONTAIN SENSITIVE DATA + runOptions, ), ); } 
@@ -142,6 +187,16 @@ class AgentClient extends BaseClient { }; } + async addImageURLs(message, attachments) { + const { files, image_urls } = await encodeAndFormat( + this.options.req, + attachments, + this.options.agent.provider, + ); + message.image_urls = image_urls.length ? image_urls : undefined; + return files; + } + async buildMessages( messages, parentMessageId, @@ -270,25 +325,34 @@ class AgentClient extends BaseClient { /** @type {sendCompletion} */ async sendCompletion(payload, opts = {}) { this.modelOptions.user = this.user; - return await this.chatCompletion({ + await this.chatCompletion({ payload, onProgress: opts.onProgress, abortController: opts.abortController, }); + return this.contentParts; } - // async recordTokenUsage({ promptTokens, completionTokens, context = 'message' }) { - // await spendTokens( - // { - // context, - // model: this.modelOptions.model, - // conversationId: this.conversationId, - // user: this.user ?? this.options.req.user?.id, - // endpointTokenConfig: this.options.endpointTokenConfig, - // }, - // { promptTokens, completionTokens }, - // ); - // } + /** + * @param {Object} params + * @param {string} [params.model] + * @param {string} [params.context='message'] + * @param {UsageMetadata[]} [params.collectedUsage=this.collectedUsage] + */ + async recordCollectedUsage({ model, context = 'message', collectedUsage = this.collectedUsage }) { + for (const usage of collectedUsage) { + await spendTokens( + { + context, + model: model ?? this.modelOptions.model, + conversationId: this.conversationId, + user: this.user ?? this.options.req.user?.id, + endpointTokenConfig: this.options.endpointTokenConfig, + }, + { promptTokens: usage.input_tokens, completionTokens: usage.output_tokens }, + ); + } + } async chatCompletion({ payload, abortController = null }) { try { @@ -398,9 +462,8 @@ class AgentClient extends BaseClient { // }); // } - // const streamRate = this.options.streamRate ?? 
Constants.DEFAULT_STREAM_RATE; - const run = await createRun({ + req: this.options.req, agent: this.options.agent, tools: this.options.tools, toolMap: this.options.toolMap, @@ -415,6 +478,7 @@ class AgentClient extends BaseClient { thread_id: this.conversationId, }, run_id: this.responseMessageId, + signal: abortController.signal, streamMode: 'values', version: 'v2', }; @@ -423,8 +487,10 @@ class AgentClient extends BaseClient { throw new Error('Failed to create run'); } + this.run = run; + const messages = formatAgentMessages(payload); - const runMessages = await run.processStream({ messages }, config, { + await run.processStream({ messages }, config, { [Callback.TOOL_ERROR]: (graph, error, toolId) => { logger.error( '[api/server/controllers/agents/client.js #chatCompletion] Tool Error', @@ -433,14 +499,94 @@ class AgentClient extends BaseClient { ); }, }); - // console.dir(runMessages, { depth: null }); - return runMessages; + this.recordCollectedUsage({ context: 'message' }).catch((err) => { + logger.error( + '[api/server/controllers/agents/client.js #chatCompletion] Error recording collected usage', + err, + ); + }); } catch (err) { - logger.error( - '[api/server/controllers/agents/client.js #chatCompletion] Unhandled error type', + if (!abortController.signal.aborted) { + logger.error( + '[api/server/controllers/agents/client.js #sendCompletion] Unhandled error type', + err, + ); + throw err; + } + + logger.warn( + '[api/server/controllers/agents/client.js #sendCompletion] Operation aborted', err, ); - throw err; + } + } + + /** + * + * @param {Object} params + * @param {string} params.text + * @param {string} params.conversationId + */ + async titleConvo({ text }) { + if (!this.run) { + throw new Error('Run not initialized'); + } + const { handleLLMEnd, collected: collectedMetadata } = createMetadataAggregator(); + const clientOptions = {}; + const providerConfig = this.options.req.app.locals[this.options.agent.provider]; + if ( + providerConfig && + 
providerConfig.titleModel && + providerConfig.titleModel !== Constants.CURRENT_MODEL + ) { + clientOptions.model = providerConfig.titleModel; + } + try { + const titleResult = await this.run.generateTitle({ + inputText: text, + contentParts: this.contentParts, + clientOptions, + chainOptions: { + callbacks: [ + { + handleLLMEnd, + }, + ], + }, + }); + + const collectedUsage = collectedMetadata.map((item) => { + let input_tokens, output_tokens; + + if (item.usage) { + input_tokens = item.usage.input_tokens || item.usage.inputTokens; + output_tokens = item.usage.output_tokens || item.usage.outputTokens; + } else if (item.tokenUsage) { + input_tokens = item.tokenUsage.promptTokens; + output_tokens = item.tokenUsage.completionTokens; + } + + return { + input_tokens: input_tokens, + output_tokens: output_tokens, + }; + }); + + this.recordCollectedUsage({ + model: clientOptions.model, + context: 'title', + collectedUsage, + }).catch((err) => { + logger.error( + '[api/server/controllers/agents/client.js #titleConvo] Error recording collected usage', + err, + ); + }); + + return titleResult.title; + } catch (err) { + logger.error('[api/server/controllers/agents/client.js #titleConvo] Error', err); + return; } } diff --git a/api/server/controllers/agents/demo.js b/api/server/controllers/agents/demo.js deleted file mode 100644 index c90745ba80d..00000000000 --- a/api/server/controllers/agents/demo.js +++ /dev/null @@ -1,44 +0,0 @@ -// Import the necessary modules -const path = require('path'); -const base = path.resolve(__dirname, '..', '..', '..', '..', 'api'); -console.log(base); -//api/server/controllers/agents/demo.js -require('module-alias')({ base }); -const connectDb = require('~/lib/db/connectDb'); -const AgentClient = require('./client'); - -// Define the user and message options -const user = 'user123'; -const parentMessageId = 'pmid123'; -const conversationId = 'cid456'; -const maxContextTokens = 200000; -const req = { - user: { id: user }, -}; -const 
progressOptions = { - res: {}, -}; - -// Define the message options -const messageOptions = { - user, - parentMessageId, - conversationId, - progressOptions, -}; - -async function main() { - await connectDb(); - const client = new AgentClient({ req, maxContextTokens }); - - const text = 'Hello, this is a test message.'; - - try { - let response = await client.sendMessage(text, messageOptions); - console.log('Response:', response); - } catch (error) { - console.error('Error sending message:', error); - } -} - -main(); diff --git a/api/server/controllers/agents/request.js b/api/server/controllers/agents/request.js index 6480205979b..2006d4e6ea5 100644 --- a/api/server/controllers/agents/request.js +++ b/api/server/controllers/agents/request.js @@ -1,4 +1,4 @@ -const { Constants, getResponseSender } = require('librechat-data-provider'); +const { Constants } = require('librechat-data-provider'); const { createAbortController, handleAbortError } = require('~/server/middleware'); const { sendMessage } = require('~/server/utils'); const { saveMessage } = require('~/models'); @@ -9,22 +9,17 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { text, endpointOption, conversationId, - modelDisplayLabel, parentMessageId = null, overrideParentMessageId = null, } = req.body; + let sender; let userMessage; - let userMessagePromise; let promptTokens; let userMessageId; let responseMessageId; + let userMessagePromise; - const sender = getResponseSender({ - ...endpointOption, - model: endpointOption.modelOptions.model, - modelDisplayLabel, - }); const newConvo = !conversationId; const user = req.user.id; @@ -39,6 +34,8 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { responseMessageId = data[key]; } else if (key === 'promptTokens') { promptTokens = data[key]; + } else if (key === 'sender') { + sender = data[key]; } else if (!conversationId && key === 'conversationId') { conversationId = data[key]; } @@ -46,6 +43,7 @@ 
const AgentController = async (req, res, next, initializeClient, addTitle) => { }; try { + /** @type {{ client: TAgentClient }} */ const { client } = await initializeClient({ req, res, endpointOption }); const getAbortData = () => ({ @@ -54,8 +52,8 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { promptTokens, conversationId, userMessagePromise, - // text: getPartialText(), messageId: responseMessageId, + content: client.getContentParts(), parentMessageId: overrideParentMessageId ?? userMessageId, }); @@ -90,11 +88,6 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { }; let response = await client.sendMessage(text, messageOptions); - - if (overrideParentMessageId) { - response.parentMessageId = overrideParentMessageId; - } - response.endpoint = endpointOption.endpoint; const { conversation = {} } = await client.responsePromise; @@ -103,7 +96,6 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { if (client.options.attachments) { userMessage.files = client.options.attachments; - conversation.model = endpointOption.modelOptions.model; delete userMessage.image_urls; } diff --git a/api/server/controllers/agents/run.js b/api/server/controllers/agents/run.js index d30d43bd9d3..5aeefa122db 100644 --- a/api/server/controllers/agents/run.js +++ b/api/server/controllers/agents/run.js @@ -1,4 +1,4 @@ -const { Run } = require('@librechat/agents'); +const { Run, Providers } = require('@librechat/agents'); const { providerEndpointMap } = require('librechat-data-provider'); /** @@ -14,11 +14,12 @@ const { providerEndpointMap } = require('librechat-data-provider'); * Creates a new Run instance with custom handlers and configuration. * * @param {Object} options - The options for creating the Run instance. + * @param {ServerRequest} [options.req] - The server request. + * @param {string | undefined} [options.runId] - Optional run ID; otherwise, a new run ID will be generated. 
* @param {Agent} options.agent - The agent for this run. * @param {StructuredTool[] | undefined} [options.tools] - The tools to use in the run. * @param {Record | undefined} [options.toolMap] - The tool map for the run. * @param {Record | undefined} [options.customHandlers] - Custom event handlers. - * @param {string | undefined} [options.runId] - Optional run ID; otherwise, a new run ID will be generated. * @param {ClientOptions} [options.modelOptions] - Optional model to use; if not provided, it will use the default from modelMap. * @param {boolean} [options.streaming=true] - Whether to use streaming. * @param {boolean} [options.streamUsage=true] - Whether to stream usage information. @@ -43,15 +44,22 @@ async function createRun({ modelOptions, ); + const graphConfig = { + runId, + llmConfig, + tools, + toolMap, + instructions: agent.instructions, + additional_instructions: agent.additional_instructions, + }; + + // TEMPORARY FOR TESTING + if (agent.provider === Providers.ANTHROPIC) { + graphConfig.streamBuffer = 2000; + } + return Run.create({ - graphConfig: { - runId, - llmConfig, - tools, - toolMap, - instructions: agent.instructions, - additional_instructions: agent.additional_instructions, - }, + graphConfig, customHandlers, }); } diff --git a/api/server/controllers/agents/v1.js b/api/server/controllers/agents/v1.js index 2a9911c5416..65e37f2618c 100644 --- a/api/server/controllers/agents/v1.js +++ b/api/server/controllers/agents/v1.js @@ -1,5 +1,5 @@ const { nanoid } = require('nanoid'); -const { FileContext } = require('librechat-data-provider'); +const { FileContext, Constants } = require('librechat-data-provider'); const { getAgent, createAgent, @@ -9,6 +9,8 @@ const { } = require('~/models/Agent'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); const { uploadImageBuffer } = require('~/server/services/Files/process'); +const { getProjectByName } = require('~/models/Project'); +const { updateAgentProjects } = 
require('~/models/Agent'); const { deleteFileByFilter } = require('~/models/File'); const { logger } = require('~/config'); @@ -53,16 +55,35 @@ const createAgentHandler = async (req, res) => { * @param {object} req - Express Request * @param {object} req.params - Request params * @param {string} req.params.id - Agent identifier. - * @returns {Agent} 200 - success response - application/json + * @param {object} req.user - Authenticated user information + * @param {string} req.user.id - User ID + * @returns {Promise} 200 - success response - application/json * @returns {Error} 404 - Agent not found */ const getAgentHandler = async (req, res) => { try { const id = req.params.id; - const agent = await getAgent({ id }); + const author = req.user.id; + + let query = { id, author }; + + const globalProject = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, ['agentIds']); + if (globalProject && (globalProject.agentIds?.length ?? 0) > 0) { + query = { + $or: [{ id, $in: globalProject.agentIds }, query], + }; + } + + const agent = await getAgent(query); + if (!agent) { return res.status(404).json({ error: 'Agent not found' }); } + + if (agent.author !== author) { + delete agent.author; + } + return res.status(200).json(agent); } catch (error) { logger.error('[/Agents/:id] Error retrieving agent', error); @@ -82,7 +103,17 @@ const getAgentHandler = async (req, res) => { const updateAgentHandler = async (req, res) => { try { const id = req.params.id; - const updatedAgent = await updateAgent({ id, author: req.user.id }, req.body); + const { projectIds, removeProjectIds, ...updateData } = req.body; + + let updatedAgent; + if (Object.keys(updateData).length > 0) { + updatedAgent = await updateAgent({ id, author: req.user.id }, updateData); + } + + if (projectIds || removeProjectIds) { + updatedAgent = await updateAgentProjects(id, projectIds, removeProjectIds); + } + return res.json(updatedAgent); } catch (error) { logger.error('[/Agents/:id] Error updating Agent', error); 
@@ -119,13 +150,13 @@ const deleteAgentHandler = async (req, res) => { * @param {object} req - Express Request * @param {object} req.query - Request query * @param {string} [req.query.user] - The user ID of the agent's author. - * @returns {AgentListResponse} 200 - success response - application/json + * @returns {Promise} 200 - success response - application/json */ const getListAgentsHandler = async (req, res) => { try { - const { user } = req.query; - const filter = user ? { author: user } : {}; - const data = await getListAgents(filter); + const data = await getListAgents({ + author: req.user.id, + }); return res.json(data); } catch (error) { logger.error('[/Agents] Error listing Agents', error); diff --git a/api/server/index.js b/api/server/index.js index 3fa5778301c..3bc0a050031 100644 --- a/api/server/index.js +++ b/api/server/index.js @@ -106,13 +106,16 @@ const startServer = async () => { app.use('/api/share', routes.share); app.use('/api/roles', routes.roles); app.use('/api/agents', routes.agents); + app.use('/api/banner', routes.banner); + app.use('/api/bedrock', routes.bedrock); app.use('/api/tags', routes.tags); app.use((req, res) => { // Replace lang attribute in index.html with lang from cookies or accept-language header const lang = req.cookies.lang || req.headers['accept-language']?.split(',')[0] || 'en-US'; - const updatedIndexHtml = indexHTML.replace(/lang="en-US"/g, `lang="${lang}"`); + const saneLang = lang.replace(/"/g, '"'); // sanitize untrusted user input + const updatedIndexHtml = indexHTML.replace(/lang="en-US"/g, `lang="${saneLang}"`); res.send(updatedIndexHtml); }); diff --git a/api/server/middleware/abortMiddleware.js b/api/server/middleware/abortMiddleware.js index e855c0cb69d..6e608e1cc1e 100644 --- a/api/server/middleware/abortMiddleware.js +++ b/api/server/middleware/abortMiddleware.js @@ -1,4 +1,4 @@ -const { isAssistantsEndpoint } = require('librechat-data-provider'); +const { isAssistantsEndpoint, ErrorTypes } = 
require('librechat-data-provider'); const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/utils'); const { truncateText, smartTruncateText } = require('~/app/clients/prompts'); const clearPendingReq = require('~/cache/clearPendingReq'); @@ -107,7 +107,7 @@ const createAbortController = (req, res, getAbortData, getReqData) => { finish_reason: 'incomplete', endpoint: endpointOption.endpoint, iconURL: endpointOption.iconURL, - model: endpointOption.modelOptions.model, + model: endpointOption.modelOptions?.model ?? endpointOption.model_parameters?.model, unfinished: false, error: false, isCreatedByUser: false, @@ -165,10 +165,18 @@ const handleAbortError = async (res, req, error, data) => { ); } - const errorText = error?.message?.includes('"type"') + let errorText = error?.message?.includes('"type"') ? error.message : 'An error occurred while processing your request. Please contact the Admin.'; + if (error?.type === ErrorTypes.INVALID_REQUEST) { + errorText = `{"type":"${ErrorTypes.INVALID_REQUEST}"}`; + } + + if (error?.message?.includes('does not support \'system\'')) { + errorText = `{"type":"${ErrorTypes.NO_SYSTEM_MESSAGES}"}`; + } + const respondWithError = async (partialText) => { let options = { sender, diff --git a/api/server/middleware/buildEndpointOption.js b/api/server/middleware/buildEndpointOption.js index 83e06d77c33..2b4ba40172d 100644 --- a/api/server/middleware/buildEndpointOption.js +++ b/api/server/middleware/buildEndpointOption.js @@ -5,6 +5,7 @@ const assistants = require('~/server/services/Endpoints/assistants'); const gptPlugins = require('~/server/services/Endpoints/gptPlugins'); const { processFiles } = require('~/server/services/Files/process'); const anthropic = require('~/server/services/Endpoints/anthropic'); +const bedrock = require('~/server/services/Endpoints/bedrock'); const openAI = require('~/server/services/Endpoints/openAI'); const agents = require('~/server/services/Endpoints/agents'); const custom = 
require('~/server/services/Endpoints/custom'); @@ -17,6 +18,7 @@ const buildFunction = { [EModelEndpoint.google]: google.buildOptions, [EModelEndpoint.custom]: custom.buildOptions, [EModelEndpoint.agents]: agents.buildOptions, + [EModelEndpoint.bedrock]: bedrock.buildOptions, [EModelEndpoint.azureOpenAI]: openAI.buildOptions, [EModelEndpoint.anthropic]: anthropic.buildOptions, [EModelEndpoint.gptPlugins]: gptPlugins.buildOptions, diff --git a/api/server/middleware/optionalJwtAuth.js b/api/server/middleware/optionalJwtAuth.js new file mode 100644 index 00000000000..8aa1c27e007 --- /dev/null +++ b/api/server/middleware/optionalJwtAuth.js @@ -0,0 +1,17 @@ +const passport = require('passport'); + +// This middleware does not require authentication, +// but if the user is authenticated, it will set the user object. +const optionalJwtAuth = (req, res, next) => { + passport.authenticate('jwt', { session: false }, (err, user) => { + if (err) { + return next(err); + } + if (user) { + req.user = user; + } + next(); + })(req, res, next); +}; + +module.exports = optionalJwtAuth; diff --git a/api/server/middleware/roles/generateCheckAccess.js b/api/server/middleware/roles/generateCheckAccess.js index 900921ef80d..ffc0ddc6133 100644 --- a/api/server/middleware/roles/generateCheckAccess.js +++ b/api/server/middleware/roles/generateCheckAccess.js @@ -1,4 +1,3 @@ -const { SystemRoles } = require('librechat-data-provider'); const { getRoleByName } = require('~/models/Role'); /** @@ -17,10 +16,6 @@ const generateCheckAccess = (permissionType, permissions, bodyProps = {}) => { return res.status(401).json({ message: 'Authorization required' }); } - if (user.role === SystemRoles.ADMIN) { - return next(); - } - const role = await getRoleByName(user.role); if (role && role[permissionType]) { const hasAnyPermission = permissions.some((permission) => { diff --git a/api/server/routes/agents/actions.js b/api/server/routes/agents/actions.js index e79f749fc13..dde3293b42a 100644 --- 
a/api/server/routes/agents/actions.js +++ b/api/server/routes/agents/actions.js @@ -41,7 +41,7 @@ router.post('/:agent_id', async (req, res) => { return res.status(400).json({ message: 'No functions provided' }); } - let metadata = encryptMetadata(_metadata); + let metadata = await encryptMetadata(_metadata); let { domain } = metadata; domain = await domainParser(req, domain, true); diff --git a/api/server/routes/agents/v1.js b/api/server/routes/agents/v1.js index 1001873fe41..d3a3005bd55 100644 --- a/api/server/routes/agents/v1.js +++ b/api/server/routes/agents/v1.js @@ -1,11 +1,30 @@ const multer = require('multer'); const express = require('express'); +const { PermissionTypes, Permissions } = require('librechat-data-provider'); +const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware'); const v1 = require('~/server/controllers/agents/v1'); const actions = require('./actions'); const upload = multer(); const router = express.Router(); +const checkAgentAccess = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.USE]); +const checkAgentCreate = generateCheckAccess(PermissionTypes.AGENTS, [ + Permissions.USE, + Permissions.CREATE, +]); + +const checkGlobalAgentShare = generateCheckAccess( + PermissionTypes.AGENTS, + [Permissions.USE, Permissions.CREATE], + { + [Permissions.SHARED_GLOBAL]: ['projectIds', 'removeProjectIds'], + }, +); + +router.use(requireJwtAuth); +router.use(checkAgentAccess); + /** * Agent actions route. * @route GET|POST /agents/actions @@ -27,7 +46,7 @@ router.use('/tools', (req, res) => { * @param {AgentCreateParams} req.body - The agent creation parameters. * @returns {Agent} 201 - Success response - application/json */ -router.post('/', v1.createAgent); +router.post('/', checkAgentCreate, v1.createAgent); /** * Retrieves an agent. @@ -35,7 +54,7 @@ router.post('/', v1.createAgent); * @param {string} req.params.id - Agent identifier. 
* @returns {Agent} 200 - Success response - application/json */ -router.get('/:id', v1.getAgent); +router.get('/:id', checkAgentAccess, v1.getAgent); /** * Updates an agent. @@ -44,7 +63,7 @@ router.get('/:id', v1.getAgent); * @param {AgentUpdateParams} req.body - The agent update parameters. * @returns {Agent} 200 - Success response - application/json */ -router.patch('/:id', v1.updateAgent); +router.patch('/:id', checkGlobalAgentShare, v1.updateAgent); /** * Deletes an agent. @@ -52,7 +71,7 @@ router.patch('/:id', v1.updateAgent); * @param {string} req.params.id - Agent identifier. * @returns {Agent} 200 - success response - application/json */ -router.delete('/:id', v1.deleteAgent); +router.delete('/:id', checkAgentCreate, v1.deleteAgent); /** * Returns a list of agents. @@ -60,9 +79,7 @@ router.delete('/:id', v1.deleteAgent); * @param {AgentListParams} req.query - The agent list parameters for pagination and sorting. * @returns {AgentListResponse} 200 - success response - application/json */ -router.get('/', v1.getListAgents); - -// TODO: handle private agents +router.get('/', checkAgentAccess, v1.getListAgents); /** * Uploads and updates an avatar for a specific agent. @@ -72,6 +89,6 @@ router.get('/', v1.getListAgents); * @param {string} [req.body.metadata] - Optional metadata for the agent's avatar. 
* @returns {Object} 200 - success response - application/json */ -router.post('/avatar/:agent_id', upload.single('file'), v1.uploadAgentAvatar); +router.post('/avatar/:agent_id', checkAgentAccess, upload.single('file'), v1.uploadAgentAvatar); module.exports = router; diff --git a/api/server/routes/banner.js b/api/server/routes/banner.js new file mode 100644 index 00000000000..cf7eafd017d --- /dev/null +++ b/api/server/routes/banner.js @@ -0,0 +1,15 @@ +const express = require('express'); + +const { getBanner } = require('~/models/Banner'); +const optionalJwtAuth = require('~/server/middleware/optionalJwtAuth'); +const router = express.Router(); + +router.get('/', optionalJwtAuth, async (req, res) => { + try { + res.status(200).send(await getBanner(req.user)); + } catch (error) { + res.status(500).json({ message: 'Error getting banner' }); + } +}); + +module.exports = router; diff --git a/api/server/routes/bedrock/chat.js b/api/server/routes/bedrock/chat.js new file mode 100644 index 00000000000..605a012710c --- /dev/null +++ b/api/server/routes/bedrock/chat.js @@ -0,0 +1,36 @@ +const express = require('express'); + +const router = express.Router(); +const { + setHeaders, + handleAbort, + // validateModel, + // validateEndpoint, + buildEndpointOption, +} = require('~/server/middleware'); +const { initializeClient } = require('~/server/services/Endpoints/bedrock'); +const AgentController = require('~/server/controllers/agents/request'); +const addTitle = require('~/server/services/Endpoints/bedrock/title'); + +router.post('/abort', handleAbort()); + +/** + * @route POST / + * @desc Chat with an assistant + * @access Public + * @param {express.Request} req - The request object, containing the request data. + * @param {express.Response} res - The response object, used to send back a response. 
+ * @returns {void} + */ +router.post( + '/', + // validateModel, + // validateEndpoint, + buildEndpointOption, + setHeaders, + async (req, res, next) => { + await AgentController(req, res, next, initializeClient, addTitle); + }, +); + +module.exports = router; diff --git a/api/server/routes/bedrock/index.js b/api/server/routes/bedrock/index.js new file mode 100644 index 00000000000..b1a9efec4cc --- /dev/null +++ b/api/server/routes/bedrock/index.js @@ -0,0 +1,19 @@ +const express = require('express'); +const router = express.Router(); +const { + uaParser, + checkBan, + requireJwtAuth, + // concurrentLimiter, + // messageIpLimiter, + // messageUserLimiter, +} = require('~/server/middleware'); + +const chat = require('./chat'); + +router.use(requireJwtAuth); +router.use(checkBan); +router.use(uaParser); +router.use('/chat', chat); + +module.exports = router; diff --git a/api/server/routes/config.js b/api/server/routes/config.js index 3fc90c14bc5..f6669169acd 100644 --- a/api/server/routes/config.js +++ b/api/server/routes/config.js @@ -1,5 +1,5 @@ const express = require('express'); -const { CacheKeys, defaultSocialLogins } = require('librechat-data-provider'); +const { CacheKeys, defaultSocialLogins, Constants } = require('librechat-data-provider'); const { getLdapConfig } = require('~/server/services/Config/ldap'); const { getProjectByName } = require('~/models/Project'); const { isEnabled } = require('~/server/utils'); @@ -32,7 +32,7 @@ router.get('/', async function (req, res) { return today.getMonth() === 1 && today.getDate() === 11; }; - const instanceProject = await getProjectByName('instance', '_id'); + const instanceProject = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, '_id'); const ldap = getLdapConfig(); diff --git a/api/server/routes/convos.js b/api/server/routes/convos.js index 104b0616f81..47a8ef19a8f 100644 --- a/api/server/routes/convos.js +++ b/api/server/routes/convos.js @@ -109,8 +109,14 @@ router.post('/clear', async (req, res) => { 
router.post('/update', async (req, res) => { const update = req.body.arg; + if (!update.conversationId) { + return res.status(400).json({ error: 'conversationId is required' }); + } + try { - const dbResponse = await saveConvo(req, update, { context: 'POST /api/convos/update' }); + const dbResponse = await saveConvo(req, update, { + context: `POST /api/convos/update ${update.conversationId}`, + }); res.status(201).json(dbResponse); } catch (error) { logger.error('Error updating conversation', error); diff --git a/api/server/routes/index.js b/api/server/routes/index.js index 90ba5c73add..4aba91e9548 100644 --- a/api/server/routes/index.js +++ b/api/server/routes/index.js @@ -8,6 +8,7 @@ const presets = require('./presets'); const prompts = require('./prompts'); const balance = require('./balance'); const plugins = require('./plugins'); +const bedrock = require('./bedrock'); const search = require('./search'); const models = require('./models'); const convos = require('./convos'); @@ -23,6 +24,7 @@ const edit = require('./edit'); const keys = require('./keys'); const user = require('./user'); const ask = require('./ask'); +const banner = require('./banner'); module.exports = { ask, @@ -36,6 +38,7 @@ module.exports = { files, share, agents, + bedrock, convos, search, prompts, @@ -50,4 +53,5 @@ module.exports = { assistants, categories, staticRoute, + banner, }; diff --git a/api/server/routes/messages.js b/api/server/routes/messages.js index f510f31f63e..0abca92001d 100644 --- a/api/server/routes/messages.js +++ b/api/server/routes/messages.js @@ -1,4 +1,5 @@ const express = require('express'); +const { ContentTypes } = require('librechat-data-provider'); const { saveConvo, saveMessage, getMessages, updateMessage, deleteMessages } = require('~/models'); const { requireJwtAuth, validateMessageReq } = require('~/server/middleware'); const { countTokens } = require('~/server/utils'); @@ -54,11 +55,50 @@ router.get('/:conversationId/:messageId', validateMessageReq, async 
(req, res) = router.put('/:conversationId/:messageId', validateMessageReq, async (req, res) => { try { - const { messageId, model } = req.params; - const { text } = req.body; - const tokenCount = await countTokens(text, model); - const result = await updateMessage(req, { messageId, text, tokenCount }); - res.status(200).json(result); + const { conversationId, messageId } = req.params; + const { text, index, model } = req.body; + + if (index === undefined) { + const tokenCount = await countTokens(text, model); + const result = await updateMessage(req, { messageId, text, tokenCount }); + return res.status(200).json(result); + } + + if (typeof index !== 'number' || index < 0) { + return res.status(400).json({ error: 'Invalid index' }); + } + + const message = (await getMessages({ conversationId, messageId }, 'content tokenCount'))?.[0]; + if (!message) { + return res.status(404).json({ error: 'Message not found' }); + } + + const existingContent = message.content; + if (!Array.isArray(existingContent) || index >= existingContent.length) { + return res.status(400).json({ error: 'Invalid index' }); + } + + const updatedContent = [...existingContent]; + if (!updatedContent[index]) { + return res.status(400).json({ error: 'Content part not found' }); + } + + if (updatedContent[index].type !== ContentTypes.TEXT) { + return res.status(400).json({ error: 'Cannot update non-text content' }); + } + + const oldText = updatedContent[index].text; + updatedContent[index] = { type: ContentTypes.TEXT, text }; + + let tokenCount = message.tokenCount; + if (tokenCount !== undefined) { + const oldTokenCount = await countTokens(oldText, model); + const newTokenCount = await countTokens(text, model); + tokenCount = Math.max(0, tokenCount - oldTokenCount) + newTokenCount; + } + + const result = await updateMessage(req, { messageId, content: updatedContent, tokenCount }); + return res.status(200).json(result); } catch (error) { logger.error('Error updating message:', error); 
res.status(500).json({ error: 'Internal server error' }); diff --git a/api/server/routes/prompts.js b/api/server/routes/prompts.js index 5a6dcafcb69..54128d3b395 100644 --- a/api/server/routes/prompts.js +++ b/api/server/routes/prompts.js @@ -24,6 +24,7 @@ const checkPromptCreate = generateCheckAccess(PermissionTypes.PROMPTS, [ Permissions.USE, Permissions.CREATE, ]); + const checkGlobalPromptShare = generateCheckAccess( PermissionTypes.PROMPTS, [Permissions.USE, Permissions.CREATE], diff --git a/api/server/routes/roles.js b/api/server/routes/roles.js index 06005ad40e8..36152e2c7e4 100644 --- a/api/server/routes/roles.js +++ b/api/server/routes/roles.js @@ -20,7 +20,10 @@ router.get('/:roleName', async (req, res) => { // TODO: TEMP, use a better parsing for roleName const roleName = _r.toUpperCase(); - if (req.user.role !== SystemRoles.ADMIN && !roleDefaults[roleName]) { + if ( + (req.user.role !== SystemRoles.ADMIN && roleName === SystemRoles.ADMIN) || + (req.user.role !== SystemRoles.ADMIN && !roleDefaults[roleName]) + ) { return res.status(403).send({ message: 'Unauthorized' }); } diff --git a/api/server/routes/tags.js b/api/server/routes/tags.js index c9f637c473b..d3e27d37110 100644 --- a/api/server/routes/tags.js +++ b/api/server/routes/tags.js @@ -61,7 +61,8 @@ router.post('/', async (req, res) => { */ router.put('/:tag', async (req, res) => { try { - const tag = await updateConversationTag(req.user.id, req.params.tag, req.body); + const decodedTag = decodeURIComponent(req.params.tag); + const tag = await updateConversationTag(req.user.id, decodedTag, req.body); if (tag) { res.status(200).json(tag); } else { @@ -81,7 +82,8 @@ router.put('/:tag', async (req, res) => { */ router.delete('/:tag', async (req, res) => { try { - const tag = await deleteConversationTag(req.user.id, req.params.tag); + const decodedTag = decodeURIComponent(req.params.tag); + const tag = await deleteConversationTag(req.user.id, decodedTag); if (tag) { res.status(200).json(tag); } else { 
diff --git a/api/server/services/ActionService.js b/api/server/services/ActionService.js index 04a9b9829d1..da69548b43d 100644 --- a/api/server/services/ActionService.js +++ b/api/server/services/ActionService.js @@ -165,7 +165,7 @@ async function createActionTool({ action, requestBuilder, zodSchema, name, descr * Encrypts sensitive metadata values for an action. * * @param {ActionMetadata} metadata - The action metadata to encrypt. - * @returns {ActionMetadata} The updated action metadata with encrypted values. + * @returns {Promise} The updated action metadata with encrypted values. */ async function encryptMetadata(metadata) { const encryptedMetadata = { ...metadata }; diff --git a/api/server/services/AppService.js b/api/server/services/AppService.js index eae83bc6e09..f99e9628711 100644 --- a/api/server/services/AppService.js +++ b/api/server/services/AppService.js @@ -8,6 +8,7 @@ const { loadDefaultInterface } = require('./start/interface'); const { azureConfigSetup } = require('./start/azureOpenAI'); const { loadAndFormatTools } = require('./ToolService'); const { initializeRoles } = require('~/models/Role'); +const { cleanup } = require('./cleanup'); const paths = require('~/config/paths'); /** @@ -17,6 +18,7 @@ const paths = require('~/config/paths'); * @param {Express.Application} app - The Express application object. */ const AppService = async (app) => { + cleanup(); await initializeRoles(); /** @type {TCustomConfig}*/ const config = (await loadCustomConfig()) ?? 
{}; @@ -94,18 +96,19 @@ const AppService = async (app) => { ); } - if (endpoints?.[EModelEndpoint.openAI]) { - endpointLocals[EModelEndpoint.openAI] = endpoints[EModelEndpoint.openAI]; - } - if (endpoints?.[EModelEndpoint.google]) { - endpointLocals[EModelEndpoint.google] = endpoints[EModelEndpoint.google]; - } - if (endpoints?.[EModelEndpoint.anthropic]) { - endpointLocals[EModelEndpoint.anthropic] = endpoints[EModelEndpoint.anthropic]; - } - if (endpoints?.[EModelEndpoint.gptPlugins]) { - endpointLocals[EModelEndpoint.gptPlugins] = endpoints[EModelEndpoint.gptPlugins]; - } + const endpointKeys = [ + EModelEndpoint.openAI, + EModelEndpoint.google, + EModelEndpoint.bedrock, + EModelEndpoint.anthropic, + EModelEndpoint.gptPlugins, + ]; + + endpointKeys.forEach((key) => { + if (endpoints?.[key]) { + endpointLocals[key] = endpoints[key]; + } + }); app.locals = { ...defaultLocals, diff --git a/api/server/services/Config/EndpointService.js b/api/server/services/Config/EndpointService.js index b2f82f383be..56768905b82 100644 --- a/api/server/services/Config/EndpointService.js +++ b/api/server/services/Config/EndpointService.js @@ -45,6 +45,9 @@ module.exports = { AZURE_ASSISTANTS_BASE_URL, EModelEndpoint.azureAssistants, ), + [EModelEndpoint.bedrock]: generateConfig( + process.env.BEDROCK_AWS_SECRET_ACCESS_KEY ?? 
process.env.BEDROCK_AWS_DEFAULT_REGION, + ), /* key will be part of separate config */ [EModelEndpoint.agents]: generateConfig(process.env.I_AM_A_TEAPOT), }, diff --git a/api/server/services/Config/loadDefaultEConfig.js b/api/server/services/Config/loadDefaultEConfig.js index df331d92fb0..c11ddbe9d5b 100644 --- a/api/server/services/Config/loadDefaultEConfig.js +++ b/api/server/services/Config/loadDefaultEConfig.js @@ -9,22 +9,13 @@ const { config } = require('./EndpointService'); */ async function loadDefaultEndpointsConfig(req) { const { google, gptPlugins } = await loadAsyncEndpoints(req); - const { - openAI, - agents, - assistants, - azureAssistants, - bingAI, - anthropic, - azureOpenAI, - chatGPTBrowser, - } = config; + const { assistants, azureAssistants, bingAI, azureOpenAI, chatGPTBrowser } = config; const enabledEndpoints = getEnabledEndpoints(); const endpointConfig = { - [EModelEndpoint.openAI]: openAI, - [EModelEndpoint.agents]: agents, + [EModelEndpoint.openAI]: config[EModelEndpoint.openAI], + [EModelEndpoint.agents]: config[EModelEndpoint.agents], [EModelEndpoint.assistants]: assistants, [EModelEndpoint.azureAssistants]: azureAssistants, [EModelEndpoint.azureOpenAI]: azureOpenAI, @@ -32,7 +23,8 @@ async function loadDefaultEndpointsConfig(req) { [EModelEndpoint.bingAI]: bingAI, [EModelEndpoint.chatGPTBrowser]: chatGPTBrowser, [EModelEndpoint.gptPlugins]: gptPlugins, - [EModelEndpoint.anthropic]: anthropic, + [EModelEndpoint.anthropic]: config[EModelEndpoint.anthropic], + [EModelEndpoint.bedrock]: config[EModelEndpoint.bedrock], }; const orderedAndFilteredEndpoints = enabledEndpoints.reduce((config, key, index) => { diff --git a/api/server/services/Config/loadDefaultModels.js b/api/server/services/Config/loadDefaultModels.js index e06b73c0c0a..464e84d44e4 100644 --- a/api/server/services/Config/loadDefaultModels.js +++ b/api/server/services/Config/loadDefaultModels.js @@ -3,6 +3,7 @@ const { useAzurePlugins } = 
require('~/server/services/Config/EndpointService'). const { getOpenAIModels, getGoogleModels, + getBedrockModels, getAnthropicModels, getChatGPTBrowserModels, } = require('~/server/services/ModelService'); @@ -38,6 +39,7 @@ async function loadDefaultModels(req) { [EModelEndpoint.chatGPTBrowser]: chatGPTBrowser, [EModelEndpoint.assistants]: assistants, [EModelEndpoint.azureAssistants]: azureAssistants, + [EModelEndpoint.bedrock]: getBedrockModels(), }; } diff --git a/api/server/services/Endpoints/agents/build.js b/api/server/services/Endpoints/agents/build.js index 256901057de..d04dee9a06d 100644 --- a/api/server/services/Endpoints/agents/build.js +++ b/api/server/services/Endpoints/agents/build.js @@ -2,7 +2,7 @@ const { getAgent } = require('~/models/Agent'); const { logger } = require('~/config'); const buildOptions = (req, endpoint, parsedBody) => { - const { agent_id, instructions, spec, ...rest } = parsedBody; + const { agent_id, instructions, spec, ...model_parameters } = parsedBody; const agentPromise = getAgent({ id: agent_id, @@ -19,9 +19,7 @@ const buildOptions = (req, endpoint, parsedBody) => { agent_id, instructions, spec, - modelOptions: { - ...rest, - }, + model_parameters, }; return endpointOption; diff --git a/api/server/services/Endpoints/agents/initialize.js b/api/server/services/Endpoints/agents/initialize.js index 8627775ce5d..a079e2145fb 100644 --- a/api/server/services/Endpoints/agents/initialize.js +++ b/api/server/services/Endpoints/agents/initialize.js @@ -11,7 +11,12 @@ const { z } = require('zod'); const { tool } = require('@langchain/core/tools'); -const { EModelEndpoint, providerEndpointMap } = require('librechat-data-provider'); +const { createContentAggregator } = require('@librechat/agents'); +const { + EModelEndpoint, + providerEndpointMap, + getResponseSender, +} = require('librechat-data-provider'); const { getDefaultHandlers } = require('~/server/controllers/agents/callbacks'); // for testing purposes // const 
createTavilySearchTool = require('~/app/clients/tools/structured/TavilySearch'); @@ -53,7 +58,8 @@ const initializeClient = async ({ req, res, endpointOption }) => { } // TODO: use endpointOption to determine options/modelOptions - const eventHandlers = getDefaultHandlers({ res }); + const { contentParts, aggregateContent } = createContentAggregator(); + const eventHandlers = getDefaultHandlers({ res, aggregateContent }); // const tools = [createTavilySearchTool()]; // const tools = [_getWeather]; @@ -90,7 +96,7 @@ const initializeClient = async ({ req, res, endpointOption }) => { } // TODO: pass-in override settings that are specific to current run - endpointOption.modelOptions.model = agent.model; + endpointOption.model_parameters.model = agent.model; const options = await getOptions({ req, res, @@ -101,13 +107,21 @@ const initializeClient = async ({ req, res, endpointOption }) => { }); modelOptions = Object.assign(modelOptions, options.llmConfig); + const sender = getResponseSender({ + ...endpointOption, + model: endpointOption.model_parameters.model, + }); + const client = new AgentClient({ req, agent, tools, + sender, toolMap, + contentParts, modelOptions, eventHandlers, + endpoint: EModelEndpoint.agents, configOptions: options.configOptions, maxContextTokens: agent.max_context_tokens ?? diff --git a/api/server/services/Endpoints/anthropic/addTitle.js b/api/server/services/Endpoints/anthropic/addTitle.js index b69c04de688..5c477632d27 100644 --- a/api/server/services/Endpoints/anthropic/addTitle.js +++ b/api/server/services/Endpoints/anthropic/addTitle.js @@ -23,7 +23,7 @@ const addTitle = async (req, { text, response, client }) => { const title = await client.titleConvo({ text, - responseText: response?.text, + responseText: response?.text ?? 
'', conversationId: response.conversationId, }); await titleCache.set(key, title, 120000); diff --git a/api/server/services/Endpoints/bedrock/build.js b/api/server/services/Endpoints/bedrock/build.js new file mode 100644 index 00000000000..d6fb0636a93 --- /dev/null +++ b/api/server/services/Endpoints/bedrock/build.js @@ -0,0 +1,44 @@ +const { removeNullishValues, bedrockInputParser } = require('librechat-data-provider'); +const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts'); +const { logger } = require('~/config'); + +const buildOptions = (endpoint, parsedBody) => { + const { + modelLabel: name, + promptPrefix, + maxContextTokens, + resendFiles = true, + imageDetail, + iconURL, + greeting, + spec, + artifacts, + ...model_parameters + } = parsedBody; + let parsedParams = model_parameters; + try { + parsedParams = bedrockInputParser.parse(model_parameters); + } catch (error) { + logger.warn('Failed to parse bedrock input', error); + } + const endpointOption = removeNullishValues({ + endpoint, + name, + resendFiles, + imageDetail, + iconURL, + greeting, + spec, + promptPrefix, + maxContextTokens, + model_parameters: parsedParams, + }); + + if (typeof artifacts === 'string') { + endpointOption.artifactsPrompt = generateArtifactsPrompt({ endpoint, artifacts }); + } + + return endpointOption; +}; + +module.exports = { buildOptions }; diff --git a/api/server/services/Endpoints/bedrock/index.js b/api/server/services/Endpoints/bedrock/index.js new file mode 100644 index 00000000000..8989f7df8c6 --- /dev/null +++ b/api/server/services/Endpoints/bedrock/index.js @@ -0,0 +1,7 @@ +const build = require('./build'); +const initialize = require('./initialize'); + +module.exports = { + ...build, + ...initialize, +}; diff --git a/api/server/services/Endpoints/bedrock/initialize.js b/api/server/services/Endpoints/bedrock/initialize.js new file mode 100644 index 00000000000..4a7e98a4ad4 --- /dev/null +++ b/api/server/services/Endpoints/bedrock/initialize.js @@ 
-0,0 +1,76 @@ +const { createContentAggregator } = require('@librechat/agents'); +const { + EModelEndpoint, + providerEndpointMap, + getResponseSender, +} = require('librechat-data-provider'); +const { getDefaultHandlers } = require('~/server/controllers/agents/callbacks'); +// const { loadAgentTools } = require('~/server/services/ToolService'); +const getOptions = require('~/server/services/Endpoints/bedrock/options'); +const AgentClient = require('~/server/controllers/agents/client'); +const { getModelMaxTokens } = require('~/utils'); + +const initializeClient = async ({ req, res, endpointOption }) => { + if (!endpointOption) { + throw new Error('Endpoint option not provided'); + } + + /** @type {Array} */ + const collectedUsage = []; + const { contentParts, aggregateContent } = createContentAggregator(); + const eventHandlers = getDefaultHandlers({ res, aggregateContent, collectedUsage }); + + // const tools = [createTavilySearchTool()]; + + /** @type {Agent} */ + const agent = { + id: EModelEndpoint.bedrock, + name: endpointOption.name, + instructions: endpointOption.promptPrefix, + provider: EModelEndpoint.bedrock, + model: endpointOption.model_parameters.model, + model_parameters: endpointOption.model_parameters, + }; + + if (typeof endpointOption.artifactsPrompt === 'string' && endpointOption.artifactsPrompt) { + agent.instructions = `${agent.instructions ?? ''}\n${endpointOption.artifactsPrompt}`.trim(); + } + + let modelOptions = { model: agent.model }; + + // TODO: pass-in override settings that are specific to current run + const options = await getOptions({ + req, + res, + endpointOption, + }); + + modelOptions = Object.assign(modelOptions, options.llmConfig); + const maxContextTokens = + agent.max_context_tokens ?? 
+ getModelMaxTokens(modelOptions.model, providerEndpointMap[agent.provider]); + + const sender = getResponseSender({ + ...endpointOption, + model: endpointOption.model_parameters.model, + }); + + const client = new AgentClient({ + req, + agent, + sender, + // tools, + // toolMap, + modelOptions, + contentParts, + eventHandlers, + collectedUsage, + maxContextTokens, + endpoint: EModelEndpoint.bedrock, + configOptions: options.configOptions, + attachments: endpointOption.attachments, + }); + return { client }; +}; + +module.exports = { initializeClient }; diff --git a/api/server/services/Endpoints/bedrock/options.js b/api/server/services/Endpoints/bedrock/options.js new file mode 100644 index 00000000000..405d76fe4df --- /dev/null +++ b/api/server/services/Endpoints/bedrock/options.js @@ -0,0 +1,101 @@ +const { HttpsProxyAgent } = require('https-proxy-agent'); +const { + EModelEndpoint, + Constants, + AuthType, + removeNullishValues, +} = require('librechat-data-provider'); +const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService'); +const { sleep } = require('~/server/utils'); + +const getOptions = async ({ req, endpointOption }) => { + const { + BEDROCK_AWS_SECRET_ACCESS_KEY, + BEDROCK_AWS_ACCESS_KEY_ID, + BEDROCK_REVERSE_PROXY, + BEDROCK_AWS_DEFAULT_REGION, + PROXY, + } = process.env; + const expiresAt = req.body.key; + const isUserProvided = BEDROCK_AWS_SECRET_ACCESS_KEY === AuthType.USER_PROVIDED; + + let credentials = isUserProvided + ? await getUserKey({ userId: req.user.id, name: EModelEndpoint.bedrock }) + : { + accessKeyId: BEDROCK_AWS_ACCESS_KEY_ID, + secretAccessKey: BEDROCK_AWS_SECRET_ACCESS_KEY, + }; + + if (!credentials) { + throw new Error('Bedrock credentials not provided. 
Please provide them again.'); + } + + if ( + !isUserProvided && + (credentials.accessKeyId === undefined || credentials.accessKeyId === '') && + (credentials.secretAccessKey === undefined || credentials.secretAccessKey === '') + ) { + credentials = undefined; + } + + if (expiresAt && isUserProvided) { + checkUserKeyExpiry(expiresAt, EModelEndpoint.bedrock); + } + + /** @type {number} */ + let streamRate = Constants.DEFAULT_STREAM_RATE; + + /** @type {undefined | TBaseEndpoint} */ + const bedrockConfig = req.app.locals[EModelEndpoint.bedrock]; + + if (bedrockConfig && bedrockConfig.streamRate) { + streamRate = bedrockConfig.streamRate; + } + + /** @type {undefined | TBaseEndpoint} */ + const allConfig = req.app.locals.all; + if (allConfig && allConfig.streamRate) { + streamRate = allConfig.streamRate; + } + + /** @type {import('@librechat/agents').BedrockConverseClientOptions} */ + const requestOptions = Object.assign( + { + model: endpointOption.model, + region: BEDROCK_AWS_DEFAULT_REGION, + streaming: true, + streamUsage: true, + callbacks: [ + { + handleLLMNewToken: async () => { + if (!streamRate) { + return; + } + await sleep(streamRate); + }, + }, + ], + }, + endpointOption.model_parameters, + ); + + if (credentials) { + requestOptions.credentials = credentials; + } + + const configOptions = {}; + if (PROXY) { + configOptions.httpAgent = new HttpsProxyAgent(PROXY); + } + + if (BEDROCK_REVERSE_PROXY) { + configOptions.endpointHost = BEDROCK_REVERSE_PROXY; + } + + return { + llmConfig: removeNullishValues(requestOptions), + configOptions, + }; +}; + +module.exports = getOptions; diff --git a/api/server/services/Endpoints/bedrock/title.js b/api/server/services/Endpoints/bedrock/title.js new file mode 100644 index 00000000000..520b9f78c43 --- /dev/null +++ b/api/server/services/Endpoints/bedrock/title.js @@ -0,0 +1,40 @@ +const { CacheKeys } = require('librechat-data-provider'); +const getLogStores = require('~/cache/getLogStores'); +const { isEnabled } = 
require('~/server/utils'); +const { saveConvo } = require('~/models'); + +const addTitle = async (req, { text, response, client }) => { + const { TITLE_CONVO = true } = process.env ?? {}; + if (!isEnabled(TITLE_CONVO)) { + return; + } + + if (client.options.titleConvo === false) { + return; + } + + // If the request was aborted, don't generate the title. + if (client.abortController.signal.aborted) { + return; + } + + const titleCache = getLogStores(CacheKeys.GEN_TITLE); + const key = `${req.user.id}-${response.conversationId}`; + + const title = await client.titleConvo({ + text, + responseText: response?.text ?? '', + conversationId: response.conversationId, + }); + await titleCache.set(key, title, 120000); + await saveConvo( + req, + { + conversationId: response.conversationId, + title, + }, + { context: 'api/server/services/Endpoints/bedrock/title.js' }, + ); +}; + +module.exports = addTitle; diff --git a/api/server/services/Endpoints/google/addTitle.js b/api/server/services/Endpoints/google/addTitle.js index 14eafe841d7..f21d123214b 100644 --- a/api/server/services/Endpoints/google/addTitle.js +++ b/api/server/services/Endpoints/google/addTitle.js @@ -49,7 +49,7 @@ const addTitle = async (req, { text, response, client }) => { const title = await titleClient.titleConvo({ text, - responseText: response?.text, + responseText: response?.text ?? '', conversationId: response.conversationId, }); await titleCache.set(key, title, 120000); diff --git a/api/server/services/Endpoints/openAI/addTitle.js b/api/server/services/Endpoints/openAI/addTitle.js index af886dd22df..35291c5e310 100644 --- a/api/server/services/Endpoints/openAI/addTitle.js +++ b/api/server/services/Endpoints/openAI/addTitle.js @@ -23,7 +23,7 @@ const addTitle = async (req, { text, response, client }) => { const title = await client.titleConvo({ text, - responseText: response?.text, + responseText: response?.text ?? 
'', conversationId: response.conversationId, }); await titleCache.set(key, title, 120000); diff --git a/api/server/services/Files/Audio/streamAudio.js b/api/server/services/Files/Audio/streamAudio.js index 4d1157bd349..7b6bef03f84 100644 --- a/api/server/services/Files/Audio/streamAudio.js +++ b/api/server/services/Files/Audio/streamAudio.js @@ -1,4 +1,3 @@ -const WebSocket = require('ws'); const { CacheKeys, findLastSeparatorIndex, SEPARATORS } = require('librechat-data-provider'); const { getLogStores } = require('~/cache'); @@ -44,33 +43,6 @@ function getRandomVoiceId(voiceIds) { * @property {string[]} normalizedAlignment.chars */ -/** - * - * @param {Record} parameters - * @returns - */ -function assembleQuery(parameters) { - let query = ''; - let hasQuestionMark = false; - - for (const [key, value] of Object.entries(parameters)) { - if (value == null) { - continue; - } - - if (!hasQuestionMark) { - query += '?'; - hasQuestionMark = true; - } else { - query += '&'; - } - - query += `${key}=${value}`; - } - - return query; -} - const MAX_NOT_FOUND_COUNT = 6; const MAX_NO_CHANGE_COUNT = 10; @@ -197,144 +169,6 @@ function splitTextIntoChunks(text, chunkSize = 4000) { return chunks; } -/** - * Input stream text to speech - * @param {Express.Response} res - * @param {AsyncIterable} textStream - * @param {(token: string) => Promise} callback - Whether to continue the stream or not - * @returns {AsyncGenerator} - */ -function inputStreamTextToSpeech(res, textStream, callback) { - const model = 'eleven_monolingual_v1'; - const wsUrl = `wss://api.elevenlabs.io/v1/text-to-speech/${getRandomVoiceId()}/stream-input${assembleQuery( - { - model_id: model, - // flush: true, - // optimize_streaming_latency: this.settings.optimizeStreamingLatency, - optimize_streaming_latency: 1, - // output_format: this.settings.outputFormat, - }, - )}`; - const socket = new WebSocket(wsUrl); - - socket.onopen = function () { - const streamStart = { - text: ' ', - voice_settings: { - 
stability: 0.5, - similarity_boost: 0.8, - }, - xi_api_key: process.env.ELEVENLABS_API_KEY, - // generation_config: { chunk_length_schedule: [50, 90, 120, 150, 200] }, - }; - - socket.send(JSON.stringify(streamStart)); - - // send stream until done - const streamComplete = new Promise((resolve, reject) => { - (async () => { - let textBuffer = ''; - let shouldContinue = true; - for await (const textDelta of textStream) { - textBuffer += textDelta; - - // using ". " as separator: sending in full sentences improves the quality - // of the audio output significantly. - const separatorIndex = findLastSeparatorIndex(textBuffer); - - // Callback for textStream (will return false if signal is aborted) - shouldContinue = await callback(textDelta); - - if (separatorIndex === -1) { - continue; - } - - if (!shouldContinue) { - break; - } - - const textToProcess = textBuffer.slice(0, separatorIndex); - textBuffer = textBuffer.slice(separatorIndex + 1); - - const request = { - text: textToProcess, - try_trigger_generation: true, - }; - - socket.send(JSON.stringify(request)); - } - - // send remaining text: - if (shouldContinue && textBuffer.length > 0) { - socket.send( - JSON.stringify({ - text: `${textBuffer} `, // append space - try_trigger_generation: true, - }), - ); - } - })() - .then(resolve) - .catch(reject); - }); - - streamComplete - .then(() => { - const endStream = { - text: '', - }; - - socket.send(JSON.stringify(endStream)); - }) - .catch((e) => { - console.error('Error streaming text to speech:', e); - throw e; - }); - }; - - return (async function* audioStream() { - let isDone = false; - let chunks = []; - let resolve; - let waitForMessage = new Promise((r) => (resolve = r)); - - socket.onmessage = function (event) { - // console.log(event); - const audioChunk = JSON.parse(event.data); - if (audioChunk.audio && audioChunk.alignment) { - res.write(`event: audio\ndata: ${event.data}\n\n`); - chunks.push(audioChunk); - resolve(null); - waitForMessage = new 
Promise((r) => (resolve = r)); - } else if (audioChunk.isFinal) { - isDone = true; - resolve(null); - } else if (audioChunk.message) { - console.warn('Received Elevenlabs message:', audioChunk.message); - resolve(null); - } - }; - - socket.onerror = function (error) { - console.error('WebSocket error:', error); - // throw error; - }; - - socket.onclose = function () { - isDone = true; - resolve(null); - }; - - while (!isDone) { - await waitForMessage; - yield* chunks; - chunks = []; - } - - res.write('event: end\ndata: \n\n'); - })(); -} - /** * * @param {AsyncIterable} llmStream @@ -349,7 +183,6 @@ async function* llmMessageSource(llmStream) { } module.exports = { - inputStreamTextToSpeech, findLastSeparatorIndex, createChunkProcessor, splitTextIntoChunks, diff --git a/api/server/services/Files/images/encode.js b/api/server/services/Files/images/encode.js index 4edb0bd56ce..05c9fc1d33f 100644 --- a/api/server/services/Files/images/encode.js +++ b/api/server/services/Files/images/encode.js @@ -23,7 +23,13 @@ async function fetchImageToBase64(url) { } } -const base64Only = new Set([EModelEndpoint.google, EModelEndpoint.anthropic, 'Ollama', 'ollama']); +const base64Only = new Set([ + EModelEndpoint.google, + EModelEndpoint.anthropic, + 'Ollama', + 'ollama', + EModelEndpoint.bedrock, +]); /** * Encodes and formats the given files. diff --git a/api/server/services/ModelService.js b/api/server/services/ModelService.js index b6ca6e4f4bb..7d2a3ae9ec9 100644 --- a/api/server/services/ModelService.js +++ b/api/server/services/ModelService.js @@ -5,6 +5,21 @@ const { extractBaseURL, inputSchema, processModelData, logAxiosError } = require const { OllamaClient } = require('~/app/clients/OllamaClient'); const getLogStores = require('~/cache/getLogStores'); +/** + * Splits a string by commas and trims each resulting value. + * @param {string} input - The input string to split. + * @returns {string[]} An array of trimmed values. 
+ */ +const splitAndTrim = (input) => { + if (!input || typeof input !== 'string') { + return []; + } + return input + .split(',') + .map((item) => item.trim()) + .filter(Boolean); +}; + const { openAIApiKey, userProvidedOpenAI } = require('./Config/EndpointService').config; /** @@ -194,7 +209,7 @@ const getOpenAIModels = async (opts) => { } if (process.env[key]) { - models = String(process.env[key]).split(','); + models = splitAndTrim(process.env[key]); return models; } @@ -208,7 +223,7 @@ const getOpenAIModels = async (opts) => { const getChatGPTBrowserModels = () => { let models = ['text-davinci-002-render-sha', 'gpt-4']; if (process.env.CHATGPT_MODELS) { - models = String(process.env.CHATGPT_MODELS).split(','); + models = splitAndTrim(process.env.CHATGPT_MODELS); } return models; @@ -217,7 +232,7 @@ const getChatGPTBrowserModels = () => { const getAnthropicModels = () => { let models = defaultModels[EModelEndpoint.anthropic]; if (process.env.ANTHROPIC_MODELS) { - models = String(process.env.ANTHROPIC_MODELS).split(','); + models = splitAndTrim(process.env.ANTHROPIC_MODELS); } return models; @@ -226,7 +241,16 @@ const getAnthropicModels = () => { const getGoogleModels = () => { let models = defaultModels[EModelEndpoint.google]; if (process.env.GOOGLE_MODELS) { - models = String(process.env.GOOGLE_MODELS).split(','); + models = splitAndTrim(process.env.GOOGLE_MODELS); + } + + return models; +}; + +const getBedrockModels = () => { + let models = defaultModels[EModelEndpoint.bedrock]; + if (process.env.BEDROCK_AWS_MODELS) { + models = splitAndTrim(process.env.BEDROCK_AWS_MODELS); } return models; @@ -234,7 +258,9 @@ const getGoogleModels = () => { module.exports = { fetchModels, + splitAndTrim, getOpenAIModels, + getBedrockModels, getChatGPTBrowserModels, getAnthropicModels, getGoogleModels, diff --git a/api/server/services/ModelService.spec.js b/api/server/services/ModelService.spec.js index fc7c8b1079a..4e4647ee35d 100644 --- 
a/api/server/services/ModelService.spec.js +++ b/api/server/services/ModelService.spec.js @@ -1,7 +1,16 @@ const axios = require('axios'); +const { EModelEndpoint, defaultModels } = require('librechat-data-provider'); const { logger } = require('~/config'); -const { fetchModels, getOpenAIModels } = require('./ModelService'); +const { + fetchModels, + splitAndTrim, + getOpenAIModels, + getGoogleModels, + getBedrockModels, + getAnthropicModels, +} = require('./ModelService'); + jest.mock('~/utils', () => { const originalUtils = jest.requireActual('~/utils'); return { @@ -329,3 +338,71 @@ describe('fetchModels with Ollama specific logic', () => { ); }); }); + +describe('splitAndTrim', () => { + it('should split a string by commas and trim each value', () => { + const input = ' model1, model2 , model3,model4 '; + const expected = ['model1', 'model2', 'model3', 'model4']; + expect(splitAndTrim(input)).toEqual(expected); + }); + + it('should return an empty array for empty input', () => { + expect(splitAndTrim('')).toEqual([]); + }); + + it('should return an empty array for null input', () => { + expect(splitAndTrim(null)).toEqual([]); + }); + + it('should return an empty array for undefined input', () => { + expect(splitAndTrim(undefined)).toEqual([]); + }); + + it('should filter out empty values after trimming', () => { + const input = 'model1,, ,model2,'; + const expected = ['model1', 'model2']; + expect(splitAndTrim(input)).toEqual(expected); + }); +}); + +describe('getAnthropicModels', () => { + it('returns default models when ANTHROPIC_MODELS is not set', () => { + delete process.env.ANTHROPIC_MODELS; + const models = getAnthropicModels(); + expect(models).toEqual(defaultModels[EModelEndpoint.anthropic]); + }); + + it('returns models from ANTHROPIC_MODELS when set', () => { + process.env.ANTHROPIC_MODELS = 'claude-1, claude-2 '; + const models = getAnthropicModels(); + expect(models).toEqual(['claude-1', 'claude-2']); + }); +}); + +describe('getGoogleModels', () => 
{ + it('returns default models when GOOGLE_MODELS is not set', () => { + delete process.env.GOOGLE_MODELS; + const models = getGoogleModels(); + expect(models).toEqual(defaultModels[EModelEndpoint.google]); + }); + + it('returns models from GOOGLE_MODELS when set', () => { + process.env.GOOGLE_MODELS = 'gemini-pro, bard '; + const models = getGoogleModels(); + expect(models).toEqual(['gemini-pro', 'bard']); + }); +}); + +describe('getBedrockModels', () => { + it('returns default models when BEDROCK_AWS_MODELS is not set', () => { + delete process.env.BEDROCK_AWS_MODELS; + const models = getBedrockModels(); + expect(models).toEqual(defaultModels[EModelEndpoint.bedrock]); + }); + + it('returns models from BEDROCK_AWS_MODELS when set', () => { + process.env.BEDROCK_AWS_MODELS = 'anthropic.claude-v2, ai21.j2-ultra '; + const models = getBedrockModels(); + expect(models).toEqual(['anthropic.claude-v2', 'ai21.j2-ultra']); + }); +}); diff --git a/api/server/services/cleanup.js b/api/server/services/cleanup.js new file mode 100644 index 00000000000..814c0ecc94c --- /dev/null +++ b/api/server/services/cleanup.js @@ -0,0 +1,13 @@ +const { logger } = require('~/config'); +const { deleteNullOrEmptyConversations } = require('~/models/Conversation'); +const cleanup = async () => { + try { + await deleteNullOrEmptyConversations(); + } catch (error) { + logger.error('[cleanup] Error during app cleanup', error); + } finally { + logger.debug('Startup cleanup complete'); + } +}; + +module.exports = { cleanup }; diff --git a/api/server/services/start/interface.js b/api/server/services/start/interface.js index 314babbcf54..bf31eb78b89 100644 --- a/api/server/services/start/interface.js +++ b/api/server/services/start/interface.js @@ -31,11 +31,18 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol termsOfService: interfaceConfig?.termsOfService ?? defaults.termsOfService, bookmarks: interfaceConfig?.bookmarks ?? 
defaults.bookmarks, prompts: interfaceConfig?.prompts ?? defaults.prompts, + multiConvo: interfaceConfig?.multiConvo ?? defaults.multiConvo, }); await updateAccessPermissions(roleName, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: loadedInterface.prompts }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: loadedInterface.bookmarks }, + [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: loadedInterface.multiConvo }, + }); + await updateAccessPermissions(SystemRoles.ADMIN, { + [PermissionTypes.PROMPTS]: { [Permissions.USE]: loadedInterface.prompts }, + [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: loadedInterface.bookmarks }, + [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: loadedInterface.multiConvo }, }); let i = 0; diff --git a/api/server/services/start/interface.spec.js b/api/server/services/start/interface.spec.js index 2009e043ccb..62239a6a297 100644 --- a/api/server/services/start/interface.spec.js +++ b/api/server/services/start/interface.spec.js @@ -16,6 +16,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: true }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true }, + [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, }); }); @@ -28,6 +29,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: false }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false }, + [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, }); }); @@ -40,6 +42,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: undefined }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, + [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, }); }); @@ -52,6 +55,7 @@ 
describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: undefined }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, + [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, }); }); @@ -64,6 +68,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: true }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false }, + [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, }); }); @@ -76,6 +81,72 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: true }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true }, + [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, + }); + }); + + it('should call updateAccessPermissions with the correct parameters when multiConvo is true', async () => { + const config = { interface: { multiConvo: true } }; + const configDefaults = { interface: {} }; + + await loadDefaultInterface(config, configDefaults); + + expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { + [PermissionTypes.PROMPTS]: { [Permissions.USE]: undefined }, + [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, + [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true }, + }); + }); + + it('should call updateAccessPermissions with false when multiConvo is false', async () => { + const config = { interface: { multiConvo: false } }; + const configDefaults = { interface: {} }; + + await loadDefaultInterface(config, configDefaults); + + expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { + [PermissionTypes.PROMPTS]: { [Permissions.USE]: undefined }, + [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, + [PermissionTypes.MULTI_CONVO]: { 
[Permissions.USE]: false }, + }); + }); + + it('should call updateAccessPermissions with undefined when multiConvo is not specified in config', async () => { + const config = {}; + const configDefaults = { interface: {} }; + + await loadDefaultInterface(config, configDefaults); + + expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { + [PermissionTypes.PROMPTS]: { [Permissions.USE]: undefined }, + [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, + [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, + }); + }); + + it('should call updateAccessPermissions with all interface options including multiConvo', async () => { + const config = { interface: { prompts: true, bookmarks: false, multiConvo: true } }; + const configDefaults = { interface: {} }; + + await loadDefaultInterface(config, configDefaults); + + expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { + [PermissionTypes.PROMPTS]: { [Permissions.USE]: true }, + [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false }, + [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true }, + }); + }); + + it('should use default values for multiConvo when config is undefined', async () => { + const config = undefined; + const configDefaults = { interface: { prompts: true, bookmarks: true, multiConvo: false } }; + + await loadDefaultInterface(config, configDefaults); + + expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { + [PermissionTypes.PROMPTS]: { [Permissions.USE]: true }, + [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true }, + [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: false }, }); }); }); diff --git a/api/server/utils/crypto.js b/api/server/utils/crypto.js index c143506cc54..ea71df51ad0 100644 --- a/api/server/utils/crypto.js +++ b/api/server/utils/crypto.js @@ -102,4 +102,14 @@ async function hashToken(str) { return Buffer.from(hashBuffer).toString('hex'); } -module.exports = { encrypt, decrypt, encryptV2, 
decryptV2, hashToken }; +async function getRandomValues(length) { + if (!Number.isInteger(length) || length <= 0) { + throw new Error('Length must be a positive integer'); + } + + const randomValues = new Uint8Array(length); + webcrypto.getRandomValues(randomValues); + return Buffer.from(randomValues).toString('hex'); +} + +module.exports = { encrypt, decrypt, encryptV2, decryptV2, hashToken, getRandomValues }; diff --git a/api/strategies/ldapStrategy.js b/api/strategies/ldapStrategy.js index 756e1da4227..4d9124bb6ad 100644 --- a/api/strategies/ldapStrategy.js +++ b/api/strategies/ldapStrategy.js @@ -14,6 +14,7 @@ const { LDAP_FULL_NAME, LDAP_ID, LDAP_USERNAME, + LDAP_EMAIL, LDAP_TLS_REJECT_UNAUTHORIZED, } = process.env; @@ -43,6 +44,9 @@ if (LDAP_ID) { if (LDAP_USERNAME) { searchAttributes.push(LDAP_USERNAME); } +if (LDAP_EMAIL) { + searchAttributes.push(LDAP_EMAIL); +} const rejectUnauthorized = isEnabled(LDAP_TLS_REJECT_UNAUTHORIZED); const ldapOptions = { @@ -76,15 +80,6 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => { return done(null, false, { message: 'Invalid credentials' }); } - if (!userinfo.mail) { - logger.warn( - '[ldapStrategy]', - 'No email attributes found in userinfo', - JSON.stringify(userinfo, null, 2), - ); - return done(null, false, { message: 'Invalid credentials' }); - } - try { const ldapId = (LDAP_ID && userinfo[LDAP_ID]) || userinfo.uid || userinfo.sAMAccountName || userinfo.mail; @@ -100,12 +95,25 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => { const username = (LDAP_USERNAME && userinfo[LDAP_USERNAME]) || userinfo.givenName || userinfo.mail; + const mail = (LDAP_EMAIL && userinfo[LDAP_EMAIL]) || userinfo.mail || username + '@ldap.local'; + + if (!userinfo.mail && !(LDAP_EMAIL && userinfo[LDAP_EMAIL])) { + logger.warn( + '[ldapStrategy]', + `No valid email attribute found in LDAP userinfo. 
Using fallback email: ${username}@ldap.local`, + `LDAP_EMAIL env var: ${LDAP_EMAIL || 'not set'}`, + `Available userinfo attributes: ${Object.keys(userinfo).join(', ')}`, + 'Full userinfo:', + JSON.stringify(userinfo, null, 2), + ); + } + if (!user) { user = { provider: 'ldap', ldapId, username, - email: userinfo.mail, + email: mail, emailVerified: true, // The ldap server administrator should verify the email name: fullName, }; @@ -116,7 +124,7 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => { // so update the user information with the values registered in LDAP user.provider = 'ldap'; user.ldapId = ldapId; - user.email = userinfo.mail; + user.email = mail; user.username = username; user.name = fullName; } diff --git a/api/typedefs.js b/api/typedefs.js index 6591d192b12..163768f58ac 100644 --- a/api/typedefs.js +++ b/api/typedefs.js @@ -20,12 +20,30 @@ * @memberof typedefs */ +/** + * @exports AgentRun + * @typedef {import('@librechat/agents').Run} AgentRun + * @memberof typedefs + */ + +/** + * @exports IState + * @typedef {import('@librechat/agents').IState} IState + * @memberof typedefs + */ + /** * @exports ClientCallbacks * @typedef {import('@librechat/agents').ClientCallbacks} ClientCallbacks * @memberof typedefs */ +/** + * @exports BedrockClientOptions + * @typedef {import('@librechat/agents').BedrockConverseClientOptions} BedrockClientOptions + * @memberof typedefs + */ + /** * @exports StreamEventData * @typedef {import('@librechat/agents').StreamEventData} StreamEventData @@ -38,6 +56,12 @@ * @memberof typedefs */ +/** + * @exports UsageMetadata + * @typedef {import('@langchain/core/messages').UsageMetadata} UsageMetadata + * @memberof typedefs + */ + /** * @exports Ollama * @typedef {import('ollama').Ollama} Ollama @@ -893,6 +917,12 @@ * @memberof typedefs */ +/** + * @exports TAgentClient + * @typedef {import('./server/controllers/agents/client')} TAgentClient + * @memberof typedefs + */ + /** * @exports ImportBatchBuilder 
* @typedef {import('./server/utils/import/importBatchBuilder.js').ImportBatchBuilder} ImportBatchBuilder @@ -1413,7 +1443,19 @@ */ /** - * @typedef {AnthropicStreamUsage} StreamUsage - Stream usage for all providers (currently only Anthropic) + * @exports OpenAIUsageMetadata + * @typedef {Object} OpenAIUsageMetadata - Usage statistics related to the run. This value will be `null` if the run is not in a terminal state (i.e. `in_progress`, `queued`, etc.). + * @property {number} [usage.completion_tokens] - Number of completion tokens used over the course of the run. + * @property {number} [usage.prompt_tokens] - Number of prompt tokens used over the course of the run. + * @property {number} [usage.total_tokens] - Total number of tokens used (prompt + completion). + * @property {number} [usage.reasoning_tokens] - Total number of tokens used for reasoning (OpenAI o1 models). + * @property {Object} [usage.completion_tokens_details] - Further details on the completion tokens used (OpenAI o1 models). + * @property {number} [usage.completion_tokens_details.reasoning_tokens] - Total number of tokens used for reasoning (OpenAI o1 models). 
+ * @memberof typedefs + */ + +/** + * @typedef {AnthropicStreamUsage | OpenAIUsageMetadata | UsageMetadata} StreamUsage - Stream usage for all providers (currently only Anthropic, OpenAI, LangChain) */ /* Native app/client methods */ diff --git a/api/utils/tokens.js b/api/utils/tokens.js index 83246c5b74d..8c2a8a6cc18 100644 --- a/api/utils/tokens.js +++ b/api/utils/tokens.js @@ -2,18 +2,21 @@ const z = require('zod'); const { EModelEndpoint } = require('librechat-data-provider'); const openAIModels = { + o1: 127500, // -500 from max + 'o1-mini': 127500, // -500 from max + 'o1-preview': 127500, // -500 from max 'gpt-4': 8187, // -5 from max 'gpt-4-0613': 8187, // -5 from max 'gpt-4-32k': 32758, // -10 from max 'gpt-4-32k-0314': 32758, // -10 from max 'gpt-4-32k-0613': 32758, // -10 from max - 'gpt-4-1106': 127990, // -10 from max - 'gpt-4-0125': 127990, // -10 from max - 'gpt-4o': 127990, // -10 from max - 'gpt-4o-mini': 127990, // -10 from max - 'gpt-4o-2024-08-06': 127990, // -10 from max - 'gpt-4-turbo': 127990, // -10 from max - 'gpt-4-vision': 127990, // -10 from max + 'gpt-4-1106': 127500, // -500 from max + 'gpt-4-0125': 127500, // -500 from max + 'gpt-4o': 127500, // -500 from max + 'gpt-4o-mini': 127500, // -500 from max + 'gpt-4o-2024-08-06': 127500, // -500 from max + 'gpt-4-turbo': 127500, // -500 from max + 'gpt-4-vision': 127500, // -500 from max 'gpt-3.5-turbo': 16375, // -10 from max 'gpt-3.5-turbo-0613': 4092, // -5 from max 'gpt-3.5-turbo-0301': 4092, // -5 from max @@ -21,9 +24,15 @@ const openAIModels = { 'gpt-3.5-turbo-16k-0613': 16375, // -10 from max 'gpt-3.5-turbo-1106': 16375, // -10 from max 'gpt-3.5-turbo-0125': 16375, // -10 from max +}; + +const mistralModels = { 'mistral-': 31990, // -10 from max - llama3: 8187, // -5 from max - 'llama-3': 8187, // -5 from max + 'mistral-7b': 31990, // -10 from max + 'mistral-small': 31990, // -10 from max + 'mixtral-8x7b': 31990, // -10 from max + 'mistral-large-2402': 127500, + 'mistral-large-2407': 
127500, }; const cohereModels = { @@ -54,6 +63,7 @@ const googleModels = { const anthropicModels = { 'claude-': 100000, + 'claude-instant': 100000, 'claude-2': 100000, 'claude-2.1': 200000, 'claude-3-haiku': 200000, @@ -63,7 +73,38 @@ const anthropicModels = { 'claude-3.5-sonnet': 200000, }; -const aggregateModels = { ...openAIModels, ...googleModels, ...anthropicModels, ...cohereModels }; +const metaModels = { + 'llama2-13b': 4000, + 'llama2-70b': 4000, + 'llama3-8b': 8000, + 'llama3-70b': 8000, + 'llama3-1-8b': 127500, + 'llama3-1-70b': 127500, + 'llama3-1-405b': 127500, +}; + +const ai21Models = { + 'ai21.j2-mid-v1': 8182, // -10 from max + 'ai21.j2-ultra-v1': 8182, // -10 from max + 'ai21.jamba-instruct-v1:0': 255500, // -500 from max +}; + +const amazonModels = { + 'amazon.titan-text-lite-v1': 4000, + 'amazon.titan-text-express-v1': 8000, + 'amazon.titan-text-premier-v1:0': 31500, // -500 from max +}; + +const bedrockModels = { + ...anthropicModels, + ...mistralModels, + ...cohereModels, + ...metaModels, + ...ai21Models, + ...amazonModels, +}; + +const aggregateModels = { ...openAIModels, ...googleModels, ...bedrockModels }; const maxTokensMap = { [EModelEndpoint.azureOpenAI]: openAIModels, @@ -72,6 +113,29 @@ const maxTokensMap = { [EModelEndpoint.custom]: aggregateModels, [EModelEndpoint.google]: googleModels, [EModelEndpoint.anthropic]: anthropicModels, + [EModelEndpoint.bedrock]: bedrockModels, +}; + +const modelMaxOutputs = { + o1: 32268, // -500 from max: 32,768 + 'o1-mini': 65136, // -500 from max: 65,536 + 'o1-preview': 32268, // -500 from max: 32,768 + system_default: 1024, +}; + +const anthropicMaxOutputs = { + 'claude-3-haiku': 4096, + 'claude-3-sonnet': 4096, + 'claude-3-opus': 4096, + 'claude-3.5-sonnet': 8192, + 'claude-3-5-sonnet': 8192, +}; + +const maxOutputTokensMap = { + [EModelEndpoint.anthropic]: anthropicMaxOutputs, + [EModelEndpoint.azureOpenAI]: modelMaxOutputs, + [EModelEndpoint.openAI]: modelMaxOutputs, + [EModelEndpoint.custom]: 
modelMaxOutputs, }; /** @@ -93,27 +157,15 @@ function findMatchingPattern(modelName, tokensMap) { } /** - * Retrieves the maximum tokens for a given model name. If the exact model name isn't found, - * it searches for partial matches within the model name, checking keys in reverse order. + * Retrieves a token value for a given model name from a tokens map. * * @param {string} modelName - The name of the model to look up. - * @param {string} endpoint - The endpoint (default is 'openAI'). - * @param {EndpointTokenConfig} [endpointTokenConfig] - Token Config for current endpoint to use for max tokens lookup - * @returns {number|undefined} The maximum tokens for the given model or undefined if no match is found. - * - * @example - * getModelMaxTokens('gpt-4-32k-0613'); // Returns 32767 - * getModelMaxTokens('gpt-4-32k-unknown'); // Returns 32767 - * getModelMaxTokens('unknown-model'); // Returns undefined + * @param {EndpointTokenConfig | Record} tokensMap - The map of model names to token values. + * @param {string} [key='context'] - The key to look up in the tokens map. + * @returns {number|undefined} The token value for the given model or undefined if no match is found. */ -function getModelMaxTokens(modelName, endpoint = EModelEndpoint.openAI, endpointTokenConfig) { - if (typeof modelName !== 'string') { - return undefined; - } - - /** @type {EndpointTokenConfig | Record} */ - const tokensMap = endpointTokenConfig ?? maxTokensMap[endpoint]; - if (!tokensMap) { +function getModelTokenValue(modelName, tokensMap, key = 'context') { + if (typeof modelName !== 'string' || !tokensMap) { return undefined; } @@ -129,10 +181,36 @@ function getModelMaxTokens(modelName, endpoint = EModelEndpoint.openAI, endpoint if (matchedPattern) { const result = tokensMap[matchedPattern]; - return result?.context ?? result; + return result?.[key] ?? result ?? 
tokensMap.system_default; } - return undefined; + return tokensMap.system_default; +} + +/** + * Retrieves the maximum tokens for a given model name. + * + * @param {string} modelName - The name of the model to look up. + * @param {string} endpoint - The endpoint (default is 'openAI'). + * @param {EndpointTokenConfig} [endpointTokenConfig] - Token Config for current endpoint to use for max tokens lookup + * @returns {number|undefined} The maximum tokens for the given model or undefined if no match is found. + */ +function getModelMaxTokens(modelName, endpoint = EModelEndpoint.openAI, endpointTokenConfig) { + const tokensMap = endpointTokenConfig ?? maxTokensMap[endpoint]; + return getModelTokenValue(modelName, tokensMap); +} + +/** + * Retrieves the maximum output tokens for a given model name. + * + * @param {string} modelName - The name of the model to look up. + * @param {string} endpoint - The endpoint (default is 'openAI'). + * @param {EndpointTokenConfig} [endpointTokenConfig] - Token Config for current endpoint to use for max tokens lookup + * @returns {number|undefined} The maximum output tokens for the given model or undefined if no match is found. + */ +function getModelMaxOutputTokens(modelName, endpoint = EModelEndpoint.openAI, endpointTokenConfig) { + const tokensMap = endpointTokenConfig ?? 
maxOutputTokensMap[endpoint]; + return getModelTokenValue(modelName, tokensMap, 'output'); } /** @@ -259,7 +337,8 @@ module.exports = { maxTokensMap, inputSchema, modelSchema, - getModelMaxTokens, matchModelName, processModelData, + getModelMaxTokens, + getModelMaxOutputTokens, }; diff --git a/api/utils/tokens.spec.js b/api/utils/tokens.spec.js index afcd4b217a8..e76e01a5684 100644 --- a/api/utils/tokens.spec.js +++ b/api/utils/tokens.spec.js @@ -20,18 +20,6 @@ describe('getModelMaxTokens', () => { ); }); - test('should return correct tokens for LLama 3 models', () => { - expect(getModelMaxTokens('meta-llama/llama-3-8b')).toBe( - maxTokensMap[EModelEndpoint.openAI]['llama-3'], - ); - expect(getModelMaxTokens('meta-llama/llama-3-8b')).toBe( - maxTokensMap[EModelEndpoint.openAI]['llama3'], - ); - expect(getModelMaxTokens('llama-3-500b')).toBe(maxTokensMap[EModelEndpoint.openAI]['llama-3']); - expect(getModelMaxTokens('llama3-70b')).toBe(maxTokensMap[EModelEndpoint.openAI]['llama3']); - expect(getModelMaxTokens('llama3:latest')).toBe(maxTokensMap[EModelEndpoint.openAI]['llama3']); - }); - test('should return undefined for no match', () => { expect(getModelMaxTokens('unknown-model')).toBeUndefined(); }); diff --git a/client/index.html b/client/index.html index 633685704c8..3363299f437 100644 --- a/client/index.html +++ b/client/index.html @@ -1,48 +1,61 @@ - - - - - - - RaisChat - - - - - - - - -
- - - + + + + + + + RaisChat + + + + + + + + + + + +
+
+
+ + + + \ No newline at end of file diff --git a/client/package.json b/client/package.json index 9ea26134598..7803a6cf272 100644 --- a/client/package.json +++ b/client/package.json @@ -1,6 +1,6 @@ { "name": "@librechat/frontend", - "version": "v0.7.5-rc1", + "version": "v0.7.5-rc2", "description": "", "type": "module", "scripts": { @@ -28,7 +28,7 @@ }, "homepage": "https://librechat.ai", "dependencies": { - "@ariakit/react": "^0.4.8", + "@ariakit/react": "^0.4.11", "@codesandbox/sandpack-react": "^2.18.2", "@dicebear/collection": "^7.0.4", "@dicebear/core": "^7.0.4", @@ -54,7 +54,7 @@ "@tanstack/react-query": "^4.28.0", "@tanstack/react-table": "^8.11.7", "@zattoo/use-double-click": "1.2.0", - "axios": "^1.3.4", + "axios": "^1.7.7", "class-variance-authority": "^0.6.0", "clsx": "^1.2.1", "copy-to-clipboard": "^3.3.3", @@ -63,6 +63,7 @@ "downloadjs": "^1.4.7", "export-from-json": "^1.7.2", "filenamify": "^6.0.0", + "framer-motion": "^11.5.4", "html-to-image": "^1.11.11", "image-blob-reduce": "^4.1.0", "js-cookie": "^3.0.5", diff --git a/client/src/Providers/AnnouncerContext.tsx b/client/src/Providers/AnnouncerContext.tsx index 34437d4e305..a45cbd20897 100644 --- a/client/src/Providers/AnnouncerContext.tsx +++ b/client/src/Providers/AnnouncerContext.tsx @@ -1,10 +1,6 @@ // AnnouncerContext.tsx import React from 'react'; - -export interface AnnounceOptions { - message: string; - isStatus?: boolean; -} +import type { AnnounceOptions } from '~/common'; interface AnnouncerContextType { announceAssertive: (options: AnnounceOptions) => void; diff --git a/client/src/common/a11y.ts b/client/src/common/a11y.ts new file mode 100644 index 00000000000..0a0e56eab2f --- /dev/null +++ b/client/src/common/a11y.ts @@ -0,0 +1,6 @@ +export interface AnnounceOptions { + message: string; + isStatus?: boolean; +} + +export const MESSAGE_UPDATE_INTERVAL = 7000; diff --git a/client/src/common/agents-types.ts b/client/src/common/agents-types.ts index eaf64f4c6cb..07633a68dbd 100644 --- 
a/client/src/common/agents-types.ts +++ b/client/src/common/agents-types.ts @@ -1,8 +1,8 @@ import { Capabilities } from 'librechat-data-provider'; import type { Agent, AgentProvider, AgentModelParameters } from 'librechat-data-provider'; -import type { Option, ExtendedFile } from './types'; +import type { OptionWithIcon, ExtendedFile } from './types'; -export type TAgentOption = Option & +export type TAgentOption = OptionWithIcon & Agent & { files?: Array<[string, ExtendedFile]>; code_files?: Array<[string, ExtendedFile]>; @@ -23,5 +23,5 @@ export type AgentForm = { model: string | null; model_parameters: AgentModelParameters; tools?: string[]; - provider?: AgentProvider | Option; + provider?: AgentProvider | OptionWithIcon; } & AgentCapabilities; diff --git a/client/src/common/index.ts b/client/src/common/index.ts index 29739c7bd8f..85dda0700cb 100644 --- a/client/src/common/index.ts +++ b/client/src/common/index.ts @@ -1,3 +1,4 @@ +export * from './a11y'; export * from './artifacts'; export * from './types'; export * from './assistants-types'; diff --git a/client/src/common/types.ts b/client/src/common/types.ts index bf003fbf68e..461dfcbe488 100644 --- a/client/src/common/types.ts +++ b/client/src/common/types.ts @@ -339,8 +339,12 @@ export type TAdditionalProps = { export type TMessageContentProps = TInitialProps & TAdditionalProps; export type TText = Pick & { className?: string }; -export type TEditProps = Pick & - Omit; +export type TEditProps = Pick & + Omit & { + text?: string; + index?: number; + siblingIdx: number | null; + }; export type TDisplayProps = TText & Pick & { showCursor?: boolean; @@ -536,3 +540,9 @@ export type TVectorStore = { }; export type TThread = { id: string; createdAt: string }; + +declare global { + interface Window { + google_tag_manager?: unknown; + } +} \ No newline at end of file diff --git a/client/src/components/Auth/AuthLayout.tsx b/client/src/components/Auth/AuthLayout.tsx index da53dcb2fb6..1d7c34add26 100644 --- 
a/client/src/components/Auth/AuthLayout.tsx +++ b/client/src/components/Auth/AuthLayout.tsx @@ -3,6 +3,7 @@ import { BlinkAnimation } from './BlinkAnimation'; import { TStartupConfig } from 'librechat-data-provider'; import SocialLoginRender from './SocialLoginRender'; import { ThemeSelector } from '~/components/ui'; +import { Banner } from '../Banners'; import Footer from './Footer'; const ErrorRender = ({ children }: { children: React.ReactNode }) => ( @@ -56,6 +57,7 @@ function AuthLayout({ return (
+
Logo diff --git a/client/src/components/Auth/LoginForm.tsx b/client/src/components/Auth/LoginForm.tsx index 030c8be4934..abc18e48e9e 100644 --- a/client/src/components/Auth/LoginForm.tsx +++ b/client/src/components/Auth/LoginForm.tsx @@ -81,7 +81,7 @@ const LoginForm: React.FC = ({ onSubmit, startupConfig, error, method="POST" onSubmit={handleSubmit((data) => onSubmit(data))} > -
+
= ({ onSubmit, startupConfig, error, }, })} aria-invalid={!!errors.email} - className="webkit-dark-styles peer block w-full appearance-none rounded-md border border-gray-300 bg-transparent px-3.5 pb-3.5 pt-4 text-sm text-gray-900 focus:border-sky-500 focus:outline-none focus:ring-0 dark:border-gray-600 dark:text-white dark:focus:border-sky-500" + className=" + webkit-dark-styles transition-color peer w-full rounded-2xl border border-border-light + bg-surface-primary px-3.5 pb-2.5 pt-3 text-text-primary duration-200 focus:border-green-500 focus:outline-none + " placeholder=" " />