♻️ refactor: refactor the tts route url (lobehub#4030)
* ♻️ refactor: refactor the tts to new route

arvinxx authored Sep 19, 2024
1 parent 8b91884 commit 60dcf19
Showing 15 changed files with 78 additions and 18 deletions.
1 change: 1 addition & 0 deletions src/app/api/openai/createBizOpenAI/index.ts
@@ -8,6 +8,7 @@ import { checkAuth } from './auth';
 import { createOpenai } from './createOpenai';
 
 /**
+ * @deprecated
  * createOpenAI Instance with Auth and azure openai support
  * if auth not pass ,just return error response
  */
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -28,6 +28,7 @@ export const preferredRegion = [
 export const POST = async (req: Request) => {
   const payload = (await req.json()) as OpenAITTSPayload;
 
+  // need to be refactored with jwt auth mode
   const openaiOrErrResponse = createBizOpenAI(req);
 
   // if resOrOpenAI is a Response, it means there is an error,just return it
4 changes: 3 additions & 1 deletion src/const/fetch.ts
@@ -1,5 +1,6 @@
 export const OPENAI_END_POINT = 'X-openai-end-point';
 export const OPENAI_API_KEY_HEADER_KEY = 'X-openai-api-key';
+export const LOBE_USER_ID = 'X-lobe-user-id';
 
 export const USE_AZURE_OPENAI = 'X-use-azure-openai';
 
@@ -19,9 +20,10 @@ export const getOpenAIAuthFromRequest = (req: Request) => {
   const useAzureStr = req.headers.get(USE_AZURE_OPENAI);
   const apiVersion = req.headers.get(AZURE_OPENAI_API_VERSION);
   const oauthAuthorizedStr = req.headers.get(OAUTH_AUTHORIZED);
+  const userId = req.headers.get(LOBE_USER_ID);
 
   const oauthAuthorized = !!oauthAuthorizedStr;
   const useAzure = !!useAzureStr;
 
-  return { accessCode, apiKey, apiVersion, endpoint, oauthAuthorized, useAzure };
+  return { accessCode, apiKey, apiVersion, endpoint, oauthAuthorized, useAzure, userId };
 };
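A minimal sketch of how a server route might consume the new `userId` field. The handler below is hypothetical; only `getOpenAIAuthFromRequest` and `LOBE_USER_ID` come from this commit.

```ts
// Hypothetical route handler, shown only to illustrate the new userId field.
import { getOpenAIAuthFromRequest } from '@/const/fetch';

export const POST = async (req: Request) => {
  const { accessCode, apiKey, userId } = getOpenAIAuthFromRequest(req);

  if (!apiKey && !accessCode) return new Response('Unauthorized', { status: 401 });

  // userId (read from the X-lobe-user-id header) is now available for logging,
  // auditing, or per-user rate limiting in webapi routes.
  console.log('TTS request from user:', userId ?? 'anonymous');

  return new Response('ok');
};
```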
4 changes: 4 additions & 0 deletions src/libs/agent-runtime/AgentRuntime.ts
@@ -35,6 +35,7 @@ import {
   EmbeddingsPayload,
   ModelProvider,
   TextToImagePayload,
+  TextToSpeechPayload,
 } from './types';
 import { LobeUpstageAI } from './upstage';
 import { LobeZeroOneAI } from './zeroone';
@@ -97,6 +98,9 @@
   async embeddings(payload: EmbeddingsPayload, options?: EmbeddingsOptions) {
     return this._runtime.embeddings?.(payload, options);
   }
+  async textToSpeech(payload: TextToSpeechPayload, options?: EmbeddingsOptions) {
+    return this._runtime.textToSpeech?.(payload, options);
+  }
 
   /**
    * @description Initialize the runtime with the provider and the options
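A sketch of calling the new method from an already-initialized runtime. The import path assumes `AgentRuntime` is the module's default export, and the model/voice values are illustrative.

```ts
// `runtime` is assumed to be an AgentRuntime wrapping a provider that implements
// textToSpeech; how the runtime is constructed is out of scope for this sketch.
import AgentRuntime from '@/libs/agent-runtime/AgentRuntime';

const speak = async (runtime: AgentRuntime, input: string): Promise<ArrayBuffer> => {
  const audio = await runtime.textToSpeech({ input, model: 'tts-1', voice: 'alloy' });

  // The underlying runtime method is optional, so providers without TTS resolve to undefined.
  if (!audio) throw new Error('The current provider does not support text-to-speech');

  return audio;
};
```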
9 changes: 8 additions & 1 deletion src/libs/agent-runtime/BaseAI.ts
@@ -1,6 +1,5 @@
 import OpenAI from 'openai';
 
-import { TextToImagePayload } from '@/libs/agent-runtime/types/textToImage';
 import { ChatModelCard } from '@/types/llm';
 
 import {
@@ -9,6 +8,9 @@ import {
   EmbeddingItem,
   EmbeddingsOptions,
   EmbeddingsPayload,
+  TextToImagePayload,
+  TextToSpeechOptions,
+  TextToSpeechPayload,
 } from './types';
 
 export interface LobeRuntimeAI {
@@ -20,6 +22,11 @@ export interface LobeRuntimeAI {
   models?(): Promise<any>;
 
   textToImage?: (payload: TextToImagePayload) => Promise<string[]>;
+
+  textToSpeech?: (
+    payload: TextToSpeechPayload,
+    options?: TextToSpeechOptions,
+  ) => Promise<ArrayBuffer>;
 }
 
 export abstract class LobeOpenAICompatibleRuntime {
1 change: 1 addition & 0 deletions src/libs/agent-runtime/types/index.ts
@@ -1,4 +1,5 @@
 export * from './chat';
 export * from './embeddings';
 export * from './textToImage';
+export * from './tts';
 export * from './type';
14 changes: 14 additions & 0 deletions src/libs/agent-runtime/types/tts.ts
@@ -0,0 +1,14 @@
+export interface TextToSpeechPayload {
+  input: string;
+  model: string;
+  voice: string;
+}
+
+export interface TextToSpeechOptions {
+  headers?: Record<string, any>;
+  signal?: AbortSignal;
+  /**
+   * userId for the embeddings
+   */
+  user?: string;
+}
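For orientation, example values for the two new types; the concrete model and voice strings are illustrative.

```ts
import { TextToSpeechOptions, TextToSpeechPayload } from '@/libs/agent-runtime/types';

// Illustrative values only; any OpenAI-compatible model/voice pair fits the same shape.
const payload: TextToSpeechPayload = {
  input: 'Hello from LobeChat',
  model: 'tts-1',
  voice: 'alloy',
};

const controller = new AbortController();

const options: TextToSpeechOptions = {
  headers: { 'X-Trace-Id': 'example' }, // hypothetical header
  signal: controller.signal, // lets the caller cancel the request
  user: 'user_123', // optional end-user identifier
};
```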
17 changes: 16 additions & 1 deletion src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts
@@ -1,7 +1,6 @@
 import OpenAI, { ClientOptions } from 'openai';
 
 import { LOBE_DEFAULT_MODEL_LIST } from '@/config/modelProviders';
-import { TextToImagePayload } from '@/libs/agent-runtime/types/textToImage';
 import { ChatModelCard } from '@/types/llm';
 
 import { LobeRuntimeAI } from '../../BaseAI';
@@ -13,6 +12,9 @@ import {
   EmbeddingItem,
   EmbeddingsOptions,
   EmbeddingsPayload,
+  TextToImagePayload,
+  TextToSpeechOptions,
+  TextToSpeechPayload,
 } from '../../types';
 import { AgentRuntimeError } from '../createError';
 import { debugResponse, debugStream } from '../debugStream';
@@ -253,6 +255,19 @@ export const LobeOpenAICompatibleFactory = <T extends Record<string, any> = any>(
       }
     }
 
+    async textToSpeech(payload: TextToSpeechPayload, options?: TextToSpeechOptions) {
+      try {
+        const mp3 = await this.client.audio.speech.create(payload as any, {
+          headers: options?.headers,
+          signal: options?.signal,
+        });
+
+        return mp3.arrayBuffer();
+      } catch (error) {
+        throw this.handleError(error);
+      }
+    }
+
     private handleError(error: any): ChatCompletionErrorPayload {
       let desensitizedEndpoint = this.baseURL;
 
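The new method resolves to the raw audio bytes. A route handler consuming a factory-built runtime might wrap them along these lines; this is a sketch only (the actual webapi route body is not shown in this diff), and the mp3 content type mirrors the OpenAI default.

```ts
// Sketch of consuming the ArrayBuffer returned by textToSpeech; not part of this commit.
import { LobeRuntimeAI } from '@/libs/agent-runtime/BaseAI';
import { TextToSpeechPayload } from '@/libs/agent-runtime/types';

export const respondWithSpeech = async (runtime: LobeRuntimeAI, payload: TextToSpeechPayload) => {
  // textToSpeech is optional on LobeRuntimeAI, so guard against providers that lack it.
  const audio = await runtime.textToSpeech?.(payload);

  if (!audio) {
    return new Response('Text-to-speech is not supported by this provider', { status: 501 });
  }

  // client.audio.speech.create() returns mp3 audio by default, hence the content type.
  return new Response(audio, { headers: { 'Content-Type': 'audio/mpeg' } });
};
```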
13 changes: 10 additions & 3 deletions src/services/_header.ts
@@ -1,4 +1,9 @@
-import { LOBE_CHAT_ACCESS_CODE, OPENAI_API_KEY_HEADER_KEY, OPENAI_END_POINT } from '@/const/fetch';
+import {
+  LOBE_CHAT_ACCESS_CODE,
+  LOBE_USER_ID,
+  OPENAI_API_KEY_HEADER_KEY,
+  OPENAI_END_POINT,
+} from '@/const/fetch';
 import { useUserStore } from '@/store/user';
 import { keyVaultsConfigSelectors } from '@/store/user/selectors';
 
@@ -8,12 +13,14 @@ import { keyVaultsConfigSelectors } from '@/store/user/selectors';
  */
 // eslint-disable-next-line no-undef
 export const createHeaderWithOpenAI = (header?: HeadersInit): HeadersInit => {
-  const openAIConfig = keyVaultsConfigSelectors.openAIConfig(useUserStore.getState());
+  const state = useUserStore.getState();
+  const openAIConfig = keyVaultsConfigSelectors.openAIConfig(state);
 
   // eslint-disable-next-line no-undef
   return {
     ...header,
-    [LOBE_CHAT_ACCESS_CODE]: keyVaultsConfigSelectors.password(useUserStore.getState()),
+    [LOBE_CHAT_ACCESS_CODE]: keyVaultsConfigSelectors.password(state),
+    [LOBE_USER_ID]: state.user?.id || '',
     [OPENAI_API_KEY_HEADER_KEY]: openAIConfig.apiKey || '',
     [OPENAI_END_POINT]: openAIConfig.baseURL || '',
   };
14 changes: 8 additions & 6 deletions src/services/_url.ts
@@ -1,4 +1,4 @@
-// TODO: in the future, all routes need to be fully migrated to trpc
+// TODO: future routes should be migrated to trpc or /webapi
 
 /* eslint-disable sort-keys-fix/sort-keys-fix */
 import { transform } from 'lodash-es';
@@ -38,9 +38,11 @@ export const API_ENDPOINTS = mapWithBasePath({
   // image
   images: '/api/text-to-image/openai',
 
-  // TTS & STT
-  stt: '/api/openai/stt',
-  tts: '/api/openai/tts',
-  edge: '/api/tts/edge-speech',
-  microsoft: '/api/tts/microsoft-speech',
+  // STT
+  stt: '/webapi/stt/openai',
+
+  // TTS
+  tts: '/webapi/tts/openai',
+  edge: '/webapi/tts/edge',
+  microsoft: '/webapi/tts/microsoft',
 });
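A sketch of a client call hitting the relocated endpoint with the header helper above. `API_ENDPOINTS.tts` now resolves to `/webapi/tts/openai` (possibly prefixed with the configured base path); the request body shape is an assumption, not taken from this commit.

```ts
// Hypothetical client helper combining the relocated route and the new header.
import { createHeaderWithOpenAI } from '@/services/_header';
import { API_ENDPOINTS } from '@/services/_url';

export const fetchOpenAITTS = async (input: string): Promise<ArrayBuffer> => {
  const res = await fetch(API_ENDPOINTS.tts, {
    body: JSON.stringify({ input, model: 'tts-1', voice: 'alloy' }), // assumed payload shape
    headers: createHeaderWithOpenAI({ 'Content-Type': 'application/json' }),
    method: 'POST',
  });

  if (!res.ok) throw new Error(`TTS request failed with status ${res.status}`);

  return res.arrayBuffer();
};
```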
2 changes: 1 addition & 1 deletion src/store/file/slices/tts/action.ts
@@ -39,7 +39,7 @@ export const createTTSFileSlice: StateCreator<
     };
     const file = new File([blob], fileName, fileOptions);
 
-    const res = await get().uploadWithProgress({ file });
+    const res = await get().uploadWithProgress({ file, skipCheckFileType: true });
 
     return res?.id;
   },
16 changes: 11 additions & 5 deletions src/store/file/slices/upload/action.ts
@@ -29,6 +29,12 @@ interface UploadWithProgressParams {
       type: 'removeFile';
     },
   ) => void;
+  /**
+   * Optional flag to indicate whether to skip the file type check.
+   * When set to `true`, any file type checks will be bypassed.
+   * Default is `false`, which means file type checks will be performed.
+   */
+  skipCheckFileType?: boolean;
 }
 
 interface UploadWithProgressResult {
@@ -52,8 +58,8 @@ export const createFileUploadSlice: StateCreator<
   [],
   FileUploadAction
 > = (set, get) => ({
-  internal_uploadToClientDB: async ({ file, onStatusUpdate }) => {
-    if (!file.type.startsWith('image')) {
+  internal_uploadToClientDB: async ({ file, onStatusUpdate, skipCheckFileType }) => {
+    if (!skipCheckFileType && !file.type.startsWith('image')) {
       onStatusUpdate?.({ id: file.name, type: 'removeFile' });
       message.info({
         content: t('upload.fileOnlySupportInServerMode', {
@@ -158,11 +164,11 @@
     return data;
   },
 
-  uploadWithProgress: async ({ file, onStatusUpdate, knowledgeBaseId }) => {
+  uploadWithProgress: async (payload) => {
     const { internal_uploadToServer, internal_uploadToClientDB } = get();
 
-    if (isServerMode) return internal_uploadToServer({ file, knowledgeBaseId, onStatusUpdate });
+    if (isServerMode) return internal_uploadToServer(payload);
 
-    return internal_uploadToClientDB({ file, onStatusUpdate });
+    return internal_uploadToClientDB(payload);
   },
 });
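To connect the two store changes: a sketch of saving generated TTS audio through the upload slice with the new flag. The store accessor and the file naming are assumptions.

```ts
// Sketch only; assumes the file store is exposed as useFileStore with a zustand-style getState().
import { useFileStore } from '@/store/file';

export const saveTTSAudio = async (audio: ArrayBuffer, messageId: string) => {
  const file = new File([audio], `${messageId}.mp3`, { type: 'audio/mpeg' }); // assumed naming

  // Without skipCheckFileType, the client-DB path rejects anything that is not an image.
  const res = await useFileStore.getState().uploadWithProgress({ file, skipCheckFileType: true });

  return res?.id;
};
```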
