Merge branch 'main' into added-codeql
rubentalstra authored Feb 7, 2025
2 parents ce879d7 + 18339ec · commit 99f078a
Showing 30 changed files with 2,311 additions and 1,982 deletions.
4 changes: 3 additions & 1 deletion api/app/clients/BaseClient.js
@@ -57,6 +57,8 @@ class BaseClient {
this.continued;
/** @type {TMessage[]} */
this.currentMessages = [];
+ /** @type {import('librechat-data-provider').VisionModes | undefined} */
+ this.visionMode;
}

setOptions() {
@@ -1095,7 +1097,7 @@ class BaseClient {
file_id: { $in: fileIds },
});

- await this.addImageURLs(message, files);
+ await this.addImageURLs(message, files, this.visionMode);

this.message_file_map[message.messageId] = files;
return message;
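The new `visionMode` field gives `BaseClient` subclasses a way to tell `addImageURLs` which payload format their vision API expects. A minimal sketch of how a subclass might consume the third argument; the part shapes and the `'generative'` string comparison are illustrative assumptions, not shown in this diff:

```js
// Hypothetical subclass: branch on the visionMode hint when attaching images.
class ExampleClient extends BaseClient {
  /**
   * @param {TMessage} message - the message to decorate with image parts
   * @param {MongoFile[]} files - attachments resolved from the file_id query above
   * @param {string} [visionMode] - e.g. VisionModes.generative (GoogleClient sets this)
   */
  async addImageURLs(message, files, visionMode) {
    message.image_urls = files.map((file) =>
      visionMode === 'generative'
        ? // GenAI-style inline part (shape assumed for illustration)
          { inlineData: { mimeType: file.type, data: file.base64 } }
        : // default OpenAI-style image_url part
          { type: 'image_url', image_url: { url: file.filepath } },
    );
  }
}
```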
103 changes: 40 additions & 63 deletions api/app/clients/GoogleClient.js
@@ -10,11 +10,13 @@ const {
getResponseSender,
endpointSettings,
EModelEndpoint,
+ ContentTypes,
+ VisionModes,
ErrorTypes,
Constants,
AuthKeys,
} = require('librechat-data-provider');
+ const { getSafetySettings } = require('~/server/services/Endpoints/google/llm');
const { encodeAndFormat } = require('~/server/services/Files/images');
const Tokenizer = require('~/server/services/Tokenizer');
const { spendTokens } = require('~/models/spendTokens');
@@ -70,7 +72,7 @@ class GoogleClient extends BaseClient {
/** The key for the usage object's output tokens
* @type {string} */
this.outputTokensKey = 'output_tokens';
-
+ this.visionMode = VisionModes.generative;
if (options.skipSetOptions) {
return;
}
@@ -215,10 +217,29 @@ }
}

formatMessages() {
- return ((message) => ({
- author: message?.author ?? (message.isCreatedByUser ? this.userLabel : this.modelLabel),
- content: message?.content ?? message.text,
- })).bind(this);
+ return ((message) => {
+ const msg = {
+ author: message?.author ?? (message.isCreatedByUser ? this.userLabel : this.modelLabel),
+ content: message?.content ?? message.text,
+ };
+
+ if (!message.image_urls?.length) {
+ return msg;
+ }
+
+ msg.content = (
+ !Array.isArray(msg.content)
+ ? [
+ {
+ type: ContentTypes.TEXT,
+ [ContentTypes.TEXT]: msg.content,
+ },
+ ]
+ : msg.content
+ ).concat(message.image_urls);
+
+ return msg;
+ }).bind(this);
}

/**
@@ -566,6 +587,7 @@

if (this.project_id != null) {
logger.debug('Creating VertexAI client');
+ this.visionMode = undefined;
clientOptions.streaming = true;
const client = new ChatVertexAI(clientOptions);
client.temperature = clientOptions.temperature;
@@ -607,13 +629,14 @@ }
}

async getCompletion(_payload, options = {}) {
- const safetySettings = this.getSafetySettings();
const { onProgress, abortController } = options;
+ const safetySettings = getSafetySettings(this.modelOptions.model);
const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
const modelName = this.modelOptions.modelName ?? this.modelOptions.model ?? '';

let reply = '';

+ /** @type {Error} */
+ let error;
try {
if (!EXCLUDED_GENAI_MODELS.test(modelName) && !this.project_id) {
/** @type {GenAI} */
@@ -714,8 +737,16 @@
this.usage = usageMetadata;
}
} catch (e) {
+ error = e;
logger.error('[GoogleClient] There was an issue generating the completion', e);
}

+ if (error != null && reply === '') {
+ const errorMessage = `{ "type": "${ErrorTypes.GoogleError}", "info": "${
+ error.message ?? 'The Google provider failed to generate content, please contact the Admin.'
+ }" }`;
+ throw new Error(errorMessage);
+ }
return reply;
}

@@ -781,12 +812,11 @@ class GoogleClient extends BaseClient {
* Stripped-down logic for generating a title. This uses the non-streaming APIs, since the user does not see titles streaming
*/
async titleChatCompletion(_payload, options = {}) {
- const { abortController } = options;
- const safetySettings = this.getSafetySettings();
-
let reply = '';
+ const { abortController } = options;

const model = this.modelOptions.modelName ?? this.modelOptions.model ?? '';
+ const safetySettings = getSafetySettings(model);
if (!EXCLUDED_GENAI_MODELS.test(model) && !this.project_id) {
logger.debug('Identified titling model as GenAI version');
/** @type {GenerativeModel} */
@@ -844,17 +874,6 @@
},
]);

- const model = process.env.GOOGLE_TITLE_MODEL ?? this.modelOptions.model;
- const availableModels = this.options.modelsConfig?.[EModelEndpoint.google];
- this.isVisionModel = validateVisionModel({ model, availableModels });
-
- if (this.isVisionModel) {
- logger.warn(
- `Current vision model does not support titling without an attachment; falling back to default model ${settings.model.default}`,
- );
- this.modelOptions.model = settings.model.default;
- }
-
try {
this.initializeClient();
title = await this.titleChatCompletion(payload, {
@@ -892,48 +911,6 @@
return reply.trim();
}

- getSafetySettings() {
- const model = this.modelOptions.model;
- const isGemini2 = model.includes('gemini-2.0') && !model.includes('thinking');
- const mapThreshold = (value) => {
- if (isGemini2 && value === 'BLOCK_NONE') {
- return 'OFF';
- }
- return value;
- };
-
- return [
- {
- category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
- threshold: mapThreshold(
- process.env.GOOGLE_SAFETY_SEXUALLY_EXPLICIT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
- ),
- },
- {
- category: 'HARM_CATEGORY_HATE_SPEECH',
- threshold: mapThreshold(
- process.env.GOOGLE_SAFETY_HATE_SPEECH || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
- ),
- },
- {
- category: 'HARM_CATEGORY_HARASSMENT',
- threshold: mapThreshold(
- process.env.GOOGLE_SAFETY_HARASSMENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
- ),
- },
- {
- category: 'HARM_CATEGORY_DANGEROUS_CONTENT',
- threshold: mapThreshold(
- process.env.GOOGLE_SAFETY_DANGEROUS_CONTENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
- ),
- },
- {
- category: 'HARM_CATEGORY_CIVIC_INTEGRITY',
- threshold: mapThreshold(process.env.GOOGLE_SAFETY_CIVIC_INTEGRITY || 'BLOCK_NONE'),
- },
- ];
- }
-
getEncoding() {
return 'cl100k_base';
}
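Two behavioral notes on the GoogleClient changes above. First, the rewritten `formatMessages` mapper folds attachments into the payload: when a message carries `image_urls`, a plain-string `content` is wrapped as a single text part and the image parts are appended. A sketch of the resulting shape, assuming `ContentTypes.TEXT` resolves to `'text'` and OpenAI-style image parts:

```js
// Given an initialized GoogleClient instance as `client`:
const format = client.formatMessages();

const result = format({
  isCreatedByUser: true,
  text: 'What is in this picture?',
  image_urls: [{ type: 'image_url', image_url: { url: 'https://example.com/cat.png' } }],
});

// result is roughly:
// {
//   author: client.userLabel,
//   content: [
//     { type: 'text', text: 'What is in this picture?' },
//     { type: 'image_url', image_url: { url: 'https://example.com/cat.png' } },
//   ],
// }
```

Second, `getCompletion` no longer swallows provider failures: the caught error is recorded, and only when nothing was streamed (`reply === ''`) is it rethrown as a JSON-encoded `ErrorTypes.GoogleError` payload, so partial completions still reach the user.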
6 changes: 3 additions & 3 deletions api/config/parsers.js
@@ -4,7 +4,7 @@ const traverse = require('traverse');

const SPLAT_SYMBOL = Symbol.for('splat');
const MESSAGE_SYMBOL = Symbol.for('message');
- const CONSOLE_JSON_LONG_STRING_LENGTH=parseInt(process.env.CONSOLE_JSON_LONG_STRING_LENGTH) || 255;
+ const CONSOLE_JSON_STRING_LENGTH = parseInt(process.env.CONSOLE_JSON_STRING_LENGTH) || 255;

const sensitiveKeys = [
/^(sk-)[^\s]+/, // OpenAI API key pattern
@@ -206,13 +206,13 @@ const jsonTruncateFormat = winston.format((info) => {
seen.add(obj);

if (Array.isArray(obj)) {
- return obj.map(item => truncateObject(item));
+ return obj.map((item) => truncateObject(item));
}

const newObj = {};
Object.entries(obj).forEach(([key, value]) => {
if (typeof value === 'string') {
- newObj[key] = truncateLongStrings(value, CONSOLE_JSON_LONG_STRING_LENGTH);
+ newObj[key] = truncateLongStrings(value, CONSOLE_JSON_STRING_LENGTH);
} else {
newObj[key] = truncateObject(value);
}
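Beyond the rename (`CONSOLE_JSON_LONG_STRING_LENGTH` becomes `CONSOLE_JSON_STRING_LENGTH`), the surrounding `jsonTruncateFormat` walks every logged object recursively and clips long strings, with a `seen` set guarding against circular references. A self-contained sketch of that behavior; the exact truncation suffix used by `truncateLongStrings` is an assumption:

```js
const LIMIT = parseInt(process.env.CONSOLE_JSON_STRING_LENGTH) || 255;

// Assumed helper: clip an over-long string and flag the truncation.
const truncateLongStrings = (str, max = LIMIT) =>
  str.length > max ? str.slice(0, max) + '... [truncated]' : str;

function truncateObject(obj, seen = new WeakSet()) {
  if (typeof obj === 'string') return truncateLongStrings(obj);
  if (obj === null || typeof obj !== 'object' || seen.has(obj)) return obj;
  seen.add(obj); // cycle guard, mirroring the seen set in parsers.js
  if (Array.isArray(obj)) return obj.map((item) => truncateObject(item, seen));
  return Object.fromEntries(
    Object.entries(obj).map(([key, value]) => [key, truncateObject(value, seen)]),
  );
}

console.log(truncateObject({ token: 'x'.repeat(300), nested: { note: 'short strings pass through' } }));
```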
9 changes: 7 additions & 2 deletions api/models/tx.js
@@ -102,9 +102,14 @@ const tokenValues = Object.assign(
/* cohere doesn't have rates for the older command models,
so this was from https://artificialanalysis.ai/models/command-light/providers */
command: { prompt: 0.38, completion: 0.38 },
+ 'gemini-2.0-flash-lite': { prompt: 0.075, completion: 0.3 },
+ 'gemini-2.0-flash': { prompt: 0.1, completion: 0.7 },
'gemini-2.0': { prompt: 0, completion: 0 }, // https://ai.google.dev/pricing
- 'gemini-1.5': { prompt: 7, completion: 21 }, // May 2nd, 2024 pricing
- gemini: { prompt: 0.5, completion: 1.5 }, // May 2nd, 2024 pricing
+ 'gemini-1.5-flash-8b': { prompt: 0.075, completion: 0.3 },
+ 'gemini-1.5-flash': { prompt: 0.15, completion: 0.6 },
+ 'gemini-1.5': { prompt: 2.5, completion: 10 },
+ 'gemini-pro-vision': { prompt: 0.5, completion: 1.5 },
+ gemini: { prompt: 0.5, completion: 1.5 },
},
bedrockValues,
);
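The Gemini rows replace the old placeholder `gemini-1.5` rate (7 / 21) with current list prices and add dedicated keys for the 2.0 line and the flash variants. Assuming the multipliers are USD per million tokens (tx.js only stores the numbers; the unit is an inference), a request's cost can be estimated like this:

```js
// Illustrative cost estimate; the per-1M-token unit is an assumption.
const rates = { 'gemini-1.5-flash': { prompt: 0.15, completion: 0.6 } };

function estimateCostUSD(valueKey, promptTokens, completionTokens) {
  const { prompt, completion } = rates[valueKey];
  return (promptTokens * prompt + completionTokens * completion) / 1e6;
}

console.log(estimateCostUSD('gemini-1.5-flash', 12000, 800)); // 0.00228
```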
78 changes: 78 additions & 0 deletions api/models/tx.spec.js
@@ -380,3 +380,81 @@ describe('getCacheMultiplier', () => {
).toBe(0.03);
});
});

describe('Google Model Tests', () => {
const googleModels = [
'gemini-2.0-flash-lite-preview-02-05',
'gemini-2.0-flash-001',
'gemini-2.0-flash-exp',
'gemini-2.0-pro-exp-02-05',
'gemini-1.5-flash-8b',
'gemini-1.5-flash-thinking',
'gemini-1.5-pro-latest',
'gemini-1.5-pro-preview-0409',
'gemini-pro-vision',
'gemini-1.0',
'gemini-pro',
];

it('should return the correct prompt and completion rates for all models', () => {
const results = googleModels.map((model) => {
const valueKey = getValueKey(model, EModelEndpoint.google);
const promptRate = getMultiplier({
model,
tokenType: 'prompt',
endpoint: EModelEndpoint.google,
});
const completionRate = getMultiplier({
model,
tokenType: 'completion',
endpoint: EModelEndpoint.google,
});
return { model, valueKey, promptRate, completionRate };
});

results.forEach(({ valueKey, promptRate, completionRate }) => {
expect(promptRate).toBe(tokenValues[valueKey].prompt);
expect(completionRate).toBe(tokenValues[valueKey].completion);
});
});

it('should map to the correct model keys', () => {
const expected = {
'gemini-2.0-flash-lite-preview-02-05': 'gemini-2.0-flash-lite',
'gemini-2.0-flash-001': 'gemini-2.0-flash',
'gemini-2.0-flash-exp': 'gemini-2.0-flash',
'gemini-2.0-pro-exp-02-05': 'gemini-2.0',
'gemini-1.5-flash-8b': 'gemini-1.5-flash-8b',
'gemini-1.5-flash-thinking': 'gemini-1.5-flash',
'gemini-1.5-pro-latest': 'gemini-1.5',
'gemini-1.5-pro-preview-0409': 'gemini-1.5',
'gemini-pro-vision': 'gemini-pro-vision',
'gemini-1.0': 'gemini',
'gemini-pro': 'gemini',
};

Object.entries(expected).forEach(([model, expectedKey]) => {
const valueKey = getValueKey(model, EModelEndpoint.google);
expect(valueKey).toBe(expectedKey);
});
});

it('should handle model names with different formats', () => {
const testCases = [
{ input: 'google/gemini-pro', expected: 'gemini' },
{ input: 'gemini-pro/google', expected: 'gemini' },
{ input: 'google/gemini-2.0-flash-lite', expected: 'gemini-2.0-flash-lite' },
];

testCases.forEach(({ input, expected }) => {
const valueKey = getValueKey(input, EModelEndpoint.google);
expect(valueKey).toBe(expected);
expect(
getMultiplier({ model: input, tokenType: 'prompt', endpoint: EModelEndpoint.google }),
).toBe(tokenValues[expected].prompt);
expect(
getMultiplier({ model: input, tokenType: 'completion', endpoint: EModelEndpoint.google }),
).toBe(tokenValues[expected].completion);
});
});
});
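The format cases above ('google/gemini-pro' and 'gemini-pro/google' both resolving to 'gemini') imply that `getValueKey` matches pricing keys by substring rather than exact name, with the most specific key winning. A sketch of that rule, inferred from the tests rather than copied from the real implementation in api/models/tx.js:

```js
// Inferred rule: the longest tokenValues key contained in the model string wins,
// so 'gemini-2.0-flash-lite' beats 'gemini-2.0-flash', which beats 'gemini'.
function getValueKeySketch(model, keys) {
  return [...keys].sort((a, b) => b.length - a.length).find((key) => model.includes(key));
}

getValueKeySketch('google/gemini-2.0-flash-lite-preview-02-05', [
  'gemini-2.0-flash-lite',
  'gemini-2.0-flash',
  'gemini-2.0',
  'gemini',
]); // returns 'gemini-2.0-flash-lite'
```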
2 changes: 1 addition & 1 deletion api/package.json
@@ -45,7 +45,7 @@
"@langchain/google-genai": "^0.1.7",
"@langchain/google-vertexai": "^0.1.8",
"@langchain/textsplitters": "^0.1.0",
"@librechat/agents": "^2.0.1",
"@librechat/agents": "^2.0.2",
"@waylaidwanderer/fetch-event-source": "^3.0.1",
"axios": "^1.7.7",
"bcryptjs": "^2.4.3",
46 changes: 35 additions & 11 deletions api/server/services/Endpoints/google/llm.js
@@ -1,18 +1,42 @@
const { Providers } = require('@librechat/agents');
const { AuthKeys } = require('librechat-data-provider');
+ const { isEnabled } = require('~/server/utils');

+ function getThresholdMapping(model) {
+ const gemini1Pattern = /gemini-(1\.0|1\.5|pro$|1\.0-pro|1\.5-pro|1\.5-flash-001)/;
+ const restrictedPattern = /(gemini-(1\.5-flash-8b|2\.0|exp)|learnlm)/;
+
+ if (gemini1Pattern.test(model)) {
+ return (value) => {
+ if (value === 'OFF') {
+ return 'BLOCK_NONE';
+ }
+ return value;
+ };
+ }
+
+ if (restrictedPattern.test(model)) {
+ return (value) => {
+ if (value === 'OFF' || value === 'HARM_BLOCK_THRESHOLD_UNSPECIFIED') {
+ return 'BLOCK_NONE';
+ }
+ return value;
+ };
+ }
+
+ return (value) => value;
+ }
+
/**
*
- * @param {boolean} isGemini2
- * @returns {Array<{category: string, threshold: string}>}
+ * @param {string} model
+ * @returns {Array<{category: string, threshold: string}> | undefined}
*/
- function getSafetySettings(isGemini2) {
- const mapThreshold = (value) => {
- if (isGemini2 && value === 'BLOCK_NONE') {
- return 'OFF';
- }
- return value;
- };
+ function getSafetySettings(model) {
+ if (isEnabled(process.env.GOOGLE_EXCLUDE_SAFETY_SETTINGS)) {
+ return undefined;
+ }
+ const mapThreshold = getThresholdMapping(model);

return [
{
@@ -85,8 +109,7 @@ function getLLMConfig(credentials, options = {}) {
};

/** Used only for Safety Settings */
- const isGemini2 = llmConfig.model.includes('gemini-2.0') && !llmConfig.model.includes('thinking');
- llmConfig.safetySettings = getSafetySettings(isGemini2);
+ llmConfig.safetySettings = getSafetySettings(llmConfig.model);

let provider;

@@ -153,4 +176,5 @@ function getLLMConfig(credentials, options = {}) {

module.exports = {
getLLMConfig,
+ getSafetySettings,
};
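Net effect of this refactor: threshold mapping is model-aware instead of keying off a single `isGemini2` flag, `GOOGLE_EXCLUDE_SAFETY_SETTINGS` can drop safety settings entirely, and the helper is now shared with `GoogleClient` (see `getCompletion` and `titleChatCompletion` above). A worked example of the mapping rules; `getThresholdMapping` is module-internal, so calling it directly here is for illustration only:

```js
// Gemini 1.x models (gemini1Pattern) reject 'OFF', so it is downgraded:
getThresholdMapping('gemini-1.5-pro')('OFF'); // 'BLOCK_NONE'

// restrictedPattern models (1.5-flash-8b / 2.0 / exp / learnlm) accept neither
// 'OFF' nor 'HARM_BLOCK_THRESHOLD_UNSPECIFIED':
getThresholdMapping('gemini-2.0-flash')('HARM_BLOCK_THRESHOLD_UNSPECIFIED'); // 'BLOCK_NONE'

// Anything else passes the env-configured threshold through unchanged:
getThresholdMapping('gemini-ultra')('BLOCK_ONLY_HIGH'); // 'BLOCK_ONLY_HIGH'

// With GOOGLE_EXCLUDE_SAFETY_SETTINGS=true, getSafetySettings(model) returns
// undefined, and getLLMConfig attaches no safetySettings at all.
```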