Skip to content

Commit

Permalink
Merge pull request #139 from VictorS67/encre-core-openai-community-1
Browse files Browse the repository at this point in the history
Support community models in openai chat
  • Loading branch information
VictorS67 authored Jan 31, 2025
2 parents 70917a9 + 9cbe91b commit 24477ed
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 23 deletions.
2 changes: 2 additions & 0 deletions packages/core/jest.global.env.ts
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
process.env.OPENAI_API_KEY = "you_should_get_this_api_from_openai";
process.env.DEEPSEEK_API_KEY = "you_should_get_this_api_from_deepseek";
process.env.MOONSHOT_API_KEY = "you_should_get_this_api_from_moonshot";
75 changes: 53 additions & 22 deletions packages/core/src/events/inference/chat/llms/openai/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ import {
isJSONInContent,
} from './utils.js';
import {
checkModalForDeepSeek,
checkModalForMoonShot,
checkModelForOpenAIChat,
checkModelForOpenAIVision,
type OpenAIChatCallOptions,
Expand Down Expand Up @@ -301,21 +303,6 @@ export class OpenAIChat<

super(fields ?? {});

this.openAIApiKey =
fields?.openAIApiKey ?? getEnvironmentVariables('OPENAI_API_KEY');

if (!this.openAIApiKey) {
throw new Error('OpenAI API Key not found');
}

this.modelName = fields?.modelName ?? this.modelName;

if (!fields.configuration && !checkModelForOpenAIChat(this.modelName)) {
throw new Error(
'model is not valid for OpenAIChat, please check openai model lists for chat completions'
);
}

this.temperature = fields?.temperature ?? this.temperature;
this.maxTokens = fields?.maxTokens ?? this.maxTokens;
this.topP = fields?.topP ?? this.topP;
Expand All @@ -334,10 +321,6 @@ export class OpenAIChat<
this.streaming = fields?.streaming ?? this.streaming;
this.chatMessages = fields?.chatMessages;

this.organization =
fields?.configuration?.organization ??
getEnvironmentVariables('OPENAI_ORGANIZATION');

this._clientOptions = {
apiKey: this.openAIApiKey,
organization: this.organization,
Expand All @@ -348,9 +331,11 @@ export class OpenAIChat<
...fields?.configuration,
};

if (checkModelForOpenAIVision(this.modelName)) {
this._isMultimodal = true;
}
this._validateModels(
fields?.openAIApiKey,
fields?.modelName,
fields?.configuration
);
}

_llmType(): string {
Expand Down Expand Up @@ -424,6 +409,10 @@ export class OpenAIChat<
messages: BaseMessage[],
options: this['SerializedCallOptions']
): Promise<LLMResult> {
if (!this.openAIApiKey) {
throw new Error('OpenAIChat API Key not found');
}

if (this._isMultimodal) {
if (this.responseFormatType !== undefined) {
console.warn(
Expand Down Expand Up @@ -1199,6 +1188,48 @@ export class OpenAIChat<

return requestOptions;
}

/**
* Validate models from the community (e.g. deepseek, moonshot) which supports OpenAI's API.
*
* @param apiKey - API key. Models can support specific key names from the environment.
* @param modelName - model name.
* @param configuration - OpenAI's API configuration.
*/
private _validateModels(
apiKey?: string,
modelName?: string,
configuration?: OpenAIClientOptions | undefined
): void {
const _modelName: string | undefined = modelName ?? this.modelName;

if (checkModelForOpenAIChat(_modelName)) {
this.modelName = _modelName;
this.openAIApiKey = apiKey ?? getEnvironmentVariables('OPENAI_API_KEY');
this.organization =
configuration?.organization ??
getEnvironmentVariables('OPENAI_ORGANIZATION');

if (checkModelForOpenAIVision(this.modelName)) {
this._isMultimodal = true;
}
} else if (checkModalForDeepSeek(_modelName)) {
this.modelName = _modelName;
this.openAIApiKey = apiKey ?? getEnvironmentVariables('DEEPSEEK_API_KEY');
this._clientOptions.baseURL = 'https://api.deepseek.com/v1';
this._isMultimodal = false;
} else if (checkModalForMoonShot(_modelName)) {
this.modelName = _modelName;
this.openAIApiKey = apiKey ?? getEnvironmentVariables('MOONSHOT_API_KEY');
this._isMultimodal = false;
this._clientOptions.baseURL = 'https://api.moonshot.cn/v1';
} else if (!configuration) {
throw new Error(
'model is not valid for OpenAIChat, please check openai model lists for chat completions'
);
}
this._clientOptions.apiKey = this.openAIApiKey;
}
}

function getMessageFromChatCompletionDelta(
Expand Down
23 changes: 22 additions & 1 deletion packages/core/src/events/inference/chat/llms/openai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,6 @@ export interface OpenAIBaseInput {
openAIApiKey?: string;
}


/**
* Configuration for standard OpenAI API input parameters.
*/
Expand Down Expand Up @@ -432,6 +431,28 @@ export function checkModelForOpenAIChat(modelName?: string): boolean {
);
}

/**
 * Checks if a model name is suitable for DeepSeek chat functions.
 *
 * Matches the `deepseek-chat` and `deepseek-code*` (e.g. `deepseek-coder`)
 * model families by prefix.
 *
 * @param modelName The model name to check.
 * @returns True if the model is compatible with chat functions.
 */
export function checkModalForDeepSeek(modelName?: string): boolean {
  if (modelName === undefined) {
    return false;
  }
  const deepSeekPrefixes = ['deepseek-chat', 'deepseek-code'] as const;
  return deepSeekPrefixes.some((prefix) => modelName.startsWith(prefix));
}

/**
 * Checks if a model name is suitable for MoonShot chat functions.
 *
 * Matches the `moonshot-v1` model family by prefix.
 *
 * @param modelName The model name to check.
 * @returns True if the model is compatible with chat functions.
 */
export function checkModalForMoonShot(modelName?: string): boolean {
  if (modelName === undefined) {
    return false;
  }
  return modelName.startsWith('moonshot-v1');
}

/**
* Checks if a model name is suitable for OpenAI vision-related tasks.
* @param modelName The model name to check.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Jest Snapshot v1, https://goo.gl/fbAQLP

exports[`DeepSeek test DeepSeek text 1`] = `
{
"generations": [
{
"info": {
"finishReason": "stop",
},
"message": {
"_grp": 1,
"_id": [
"events",
"input",
"load",
"msgs",
"assistant",
"BotMessage",
],
"_kwargs": {
"additional_kwargs": {},
"content": "Hello! I'm DeepSeek-V3, an artificial intelligence assistant created by DeepSeek. I'm at your service and would be delighted to assist you with any inquiries or tasks you may have.",
},
"_type": "constructor",
},
"output": "Hello! I'm DeepSeek-V3, an artificial intelligence assistant created by DeepSeek. I'm at your service and would be delighted to assist you with any inquiries or tasks you may have.",
},
],
"llmOutput": {
"tokenUsage": {
"completionTokens": 41,
"promptTokens": 9,
"totalTokens": 50,
},
},
}
`;

exports[`DeepSeek test DeepSeek text 2`] = `
{
"completionTokens": 45,
"promptTokens": 12,
"totalTokens": 57,
}
`;
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import { describe, expect, test } from '@jest/globals';

import { HumanMessage } from '../../../../../../input/load/msgs/human.js';
import { ChatGenerationChunk } from '../../../../../../output/provide/message.js';
import { OpenAIChat } from '../../chat.js';

// Integration test for DeepSeek models routed through the OpenAIChat wrapper.
// NOTE(review): this calls the live DeepSeek API and snapshot-matches the raw
// model output — LLM responses are nondeterministic, so these snapshots are
// likely brittle across runs; consider mocking the client or asserting only
// on response structure. TODO confirm intent.
describe('DeepSeek', () => {
  // Real key must come from the environment; jest.global.env.ts only sets a
  // placeholder value.
  const DEEPSEEK_API_KEY = process.env.DEEPSEEK_API_KEY;

  test('test DeepSeek text', async () => {
    // `deepseek-chat` is recognized by the OpenAIChat validator and routed to
    // the DeepSeek base URL with this key.
    const deepseek = new OpenAIChat({
      openAIApiKey: DEEPSEEK_API_KEY,
      modelName: 'deepseek-chat',
    });
    const messages = [new HumanMessage('Hello! Who are you?')];

    // Token counts are computed locally from the messages/generations,
    // presumably via OpenAI-style tokenization — verify this matches
    // DeepSeek's actual accounting.
    const promptTokenNumber: number = await OpenAIChat.getNumTokensInChat(
      deepseek.modelName,
      messages
    );
    const llmResult = await deepseek.invoke(messages);
    const completionTokenNumber: number =
      await OpenAIChat.getNumTokensInGenerations(
        deepseek.modelName,
        llmResult.generations as ChatGenerationChunk[]
      );

    expect(llmResult).toMatchSnapshot();

    expect({
      completionTokens: completionTokenNumber,
      promptTokens: promptTokenNumber,
      totalTokens: promptTokenNumber + completionTokenNumber,
    }).toMatchSnapshot();
  });
});

0 comments on commit 24477ed

Please sign in to comment.