Skip to content

Commit

Permalink
feat: add Bytedance Doubao Embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
ucev committed Dec 31, 2024
1 parent ed63546 commit d33c54d
Show file tree
Hide file tree
Showing 9 changed files with 278 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
---
sidebar_class_name: node-only
---

# ByteDance Doubao

The `ByteDanceDoubaoEmbeddings` class uses the ByteDance Doubao API to generate embeddings for a given text.

## Setup

You'll need to sign up for an ByteDance API key and set it as an environment variable named `ARK_API_KEY`.

Then, you'll need to install the [`@langchain/community`](https://www.npmjs.com/package/@langchain/community) package:

import IntegrationInstallTooltip from "@mdx_components/integration_install_tooltip.mdx";

<IntegrationInstallTooltip></IntegrationInstallTooltip>

```bash npm2yarn
npm install @langchain/community @langchain/core
```

## Usage

import CodeBlock from "@theme/CodeBlock";
import ByteDanceDoubaoExample from "@examples/embeddings/bytedance_doubao.ts";

<CodeBlock language="typescript">{ByteDanceDoubaoExample}</CodeBlock>

## Related

- Embedding model [conceptual guide](/docs/concepts/embedding_models)
- Embedding model [how-to guides](/docs/how_to/#embedding-models)
3 changes: 2 additions & 1 deletion examples/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,5 @@ FRIENDLI_TEAM=ADD_YOURS_HERE # https://suite.friendli.ai/
HANA_HOST=HANA_DB_ADDRESS
HANA_PORT=HANA_DB_PORT
HANA_UID=HANA_DB_USER
HANA_PWD=HANA_DB_PASSWORD
HANA_PWD=HANA_DB_PASSWORD
ARK_API_KEY=ADD_YOURS_HERE # https://console.volcengine.com/
9 changes: 9 additions & 0 deletions examples/src/embeddings/bytedance_doubao.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import { ByteDanceDoubaoEmbeddings } from "@langchain/community/embeddings/bytedance_doubao";

const model = new ByteDanceDoubaoEmbeddings({
modelName: 'ep-xxx-xxx'
});
const res = await model.embedQuery(
"What would be a good company name a company that makes colorful socks?"
);
console.log({ res });
4 changes: 4 additions & 0 deletions libs/langchain-community/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,10 @@ embeddings/bedrock.cjs
embeddings/bedrock.js
embeddings/bedrock.d.ts
embeddings/bedrock.d.cts
embeddings/bytedance_doubao.cjs
embeddings/bytedance_doubao.js
embeddings/bytedance_doubao.d.ts
embeddings/bytedance_doubao.d.cts
embeddings/cloudflare_workersai.cjs
embeddings/cloudflare_workersai.js
embeddings/cloudflare_workersai.d.ts
Expand Down
1 change: 1 addition & 0 deletions libs/langchain-community/langchain.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ export const config = {
"embeddings/alibaba_tongyi": "embeddings/alibaba_tongyi",
"embeddings/baidu_qianfan": "embeddings/baidu_qianfan",
"embeddings/bedrock": "embeddings/bedrock",
"embeddings/bytedance_doubao": "embeddings/bytedance_doubao",
"embeddings/cloudflare_workersai": "embeddings/cloudflare_workersai",
"embeddings/cohere": "embeddings/cohere",
"embeddings/deepinfra": "embeddings/deepinfra",
Expand Down
13 changes: 13 additions & 0 deletions libs/langchain-community/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -1050,6 +1050,15 @@
"import": "./embeddings/bedrock.js",
"require": "./embeddings/bedrock.cjs"
},
"./embeddings/bytedance_doubao": {
"types": {
"import": "./embeddings/bytedance_doubao.d.ts",
"require": "./embeddings/bytedance_doubao.d.cts",
"default": "./embeddings/bytedance_doubao.d.ts"
},
"import": "./embeddings/bytedance_doubao.js",
"require": "./embeddings/bytedance_doubao.cjs"
},
"./embeddings/cloudflare_workersai": {
"types": {
"import": "./embeddings/cloudflare_workersai.d.ts",
Expand Down Expand Up @@ -3308,6 +3317,10 @@
"embeddings/bedrock.js",
"embeddings/bedrock.d.ts",
"embeddings/bedrock.d.cts",
"embeddings/bytedance_doubao.cjs",
"embeddings/bytedance_doubao.js",
"embeddings/bytedance_doubao.d.ts",
"embeddings/bytedance_doubao.d.cts",
"embeddings/cloudflare_workersai.cjs",
"embeddings/cloudflare_workersai.js",
"embeddings/cloudflare_workersai.d.ts",
Expand Down
175 changes: 175 additions & 0 deletions libs/langchain-community/src/embeddings/bytedance_doubao.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings";
import { chunkArray } from "@langchain/core/utils/chunk_array";

export interface ByteDanceDoubaoEmbeddingsParams extends EmbeddingsParams {
/** Model name to use */
modelName: string;

/**
* Timeout to use when making requests to ByteDanceDoubao.
*/
timeout?: number;

/**
* The maximum number of documents to embed in a single request. This is
* limited by the ByteDanceDoubao API to a maximum of 2048.
*/
batchSize?: number;

/**
* Whether to strip new lines from the input text.
*/
stripNewLines?: boolean;
}

interface EmbeddingCreateParams {
model: ByteDanceDoubaoEmbeddingsParams["modelName"];
input: string[];
encoding_format?: "float";
}

interface EmbeddingResponse {
data: {
index: number;
embedding: number[];
}[];

usage: {
prompt_tokens: number;
total_tokens: number;
};

id: string;
}

interface EmbeddingErrorResponse {
type: string;
code: string;
param: string;
message: string;
}

export class ByteDanceDoubaoEmbeddings
extends Embeddings
implements ByteDanceDoubaoEmbeddingsParams {
modelName: ByteDanceDoubaoEmbeddingsParams["modelName"] = "";

batchSize = 24;

stripNewLines = true;

apiKey: string;

constructor(
fields?: Partial<ByteDanceDoubaoEmbeddingsParams> & {
verbose?: boolean;
apiKey?: string;
}
) {
const fieldsWithDefaults = { maxConcurrency: 2, ...fields };
super(fieldsWithDefaults);

const apiKey =
fieldsWithDefaults?.apiKey ?? getEnvironmentVariable("ARK_API_KEY");

if (!apiKey) throw new Error("ByteDanceDoubao API key not found");

this.apiKey = apiKey;

this.modelName = fieldsWithDefaults?.modelName ?? this.modelName;
this.batchSize = fieldsWithDefaults?.batchSize ?? this.batchSize;
this.stripNewLines =
fieldsWithDefaults?.stripNewLines ?? this.stripNewLines;
}

/**
* Method to generate embeddings for an array of documents. Splits the
* documents into batches and makes requests to the ByteDanceDoubao API to generate
* embeddings.
* @param texts Array of documents to generate embeddings for.
* @returns Promise that resolves to a 2D array of embeddings for each document.
*/
async embedDocuments(texts: string[]): Promise<number[][]> {
const batches = chunkArray(
this.stripNewLines ? texts.map((t) => t.replace(/\n/g, " ")) : texts,
this.batchSize
);
const batchRequests = batches.map((batch) => {
const params = this.getParams(batch);

return this.embeddingWithRetry(params);
});

const batchResponses = await Promise.all(batchRequests);
const embeddings: number[][] = [];

for (let i = 0; i < batchResponses.length; i += 1) {
const batch = batches[i];
const batchResponse = batchResponses[i] || [];
for (let j = 0; j < batch.length; j += 1) {
embeddings.push(batchResponse[j]);
}
}

return embeddings;
}

/**
* Method to generate an embedding for a single document. Calls the
* embeddingWithRetry method with the document as the input.
* @param text Document to generate an embedding for.
* @returns Promise that resolves to an embedding for the document.
*/
async embedQuery(text: string): Promise<number[]> {
const params = this.getParams([
this.stripNewLines ? text.replace(/\n/g, " ") : text,
]);

const embeddings = (await this.embeddingWithRetry(params)) || [[]];
return embeddings[0];
}

/**
* Method to generate an embedding params.
* @param texts Array of documents to generate embeddings for.
* @returns an embedding params.
*/
private getParams(
texts: EmbeddingCreateParams["input"]
): EmbeddingCreateParams {
return {
model: this.modelName,
input: texts,
};
}

/**
* Private method to make a request to the OpenAI API to generate
* embeddings. Handles the retry logic and returns the response from the
* API.
* @param request Request to send to the OpenAI API.
* @returns Promise that resolves to the response from the API.
*/
private async embeddingWithRetry(body: EmbeddingCreateParams) {
return fetch("https://ark.cn-beijing.volces.com/api/v3/embeddings", {
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${this.apiKey}`,
},
body: JSON.stringify(body),
}).then(async (response) => {
const embeddingData: EmbeddingResponse | EmbeddingErrorResponse =
await response.json();

if ("code" in embeddingData && embeddingData.code) {
throw new Error(`${embeddingData.code}: ${embeddingData.message}`);
}

return (embeddingData as EmbeddingResponse).data.map(
({ embedding }) => embedding
);
});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import { test, expect } from "@jest/globals";
import { ByteDanceDoubaoEmbeddings } from "../bytedance_doubao.js";

const modelName = 'ep-xxx-xxx';
test.skip("Test ByteDanceDoubaoEmbeddings.embedQuery", async () => {
const embeddings = new ByteDanceDoubaoEmbeddings({
modelName,
});
const res = await embeddings.embedQuery("Hello world");
expect(typeof res[0]).toBe("number");
});

test.skip("Test ByteDanceDoubaoEmbeddings.embedDocuments", async () => {
const embeddings = new ByteDanceDoubaoEmbeddings({
modelName,
});
const res = await embeddings.embedDocuments(["Hello world", "Bye bye"]);
expect(res).toHaveLength(2);
expect(typeof res[0][0]).toBe("number");
expect(typeof res[1][0]).toBe("number");
});

test.skip("Test ByteDanceDoubaoEmbeddings concurrency", async () => {
const embeddings = new ByteDanceDoubaoEmbeddings({
modelName,
batchSize: 1,
});
const res = await embeddings.embedDocuments([
"Hello world",
"Bye bye",
"Hello world",
"Bye bye",
"Hello world",
"Bye bye",
]);
expect(res).toHaveLength(6);
expect(res.find((embedding) => typeof embedding[0] !== "number")).toBe(
undefined
);
});
1 change: 1 addition & 0 deletions libs/langchain-community/src/load/import_map.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ export * as agents__toolkits__base from "../agents/toolkits/base.js";
export * as agents__toolkits__connery from "../agents/toolkits/connery/index.js";
export * as embeddings__alibaba_tongyi from "../embeddings/alibaba_tongyi.js";
export * as embeddings__baidu_qianfan from "../embeddings/baidu_qianfan.js";
export * as embeddings__bytedance_doubao from "../embeddings/bytedance_doubao.js";
export * as embeddings__deepinfra from "../embeddings/deepinfra.js";
export * as embeddings__fireworks from "../embeddings/fireworks.js";
export * as embeddings__minimax from "../embeddings/minimax.js";
Expand Down

0 comments on commit d33c54d

Please sign in to comment.