-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add Bytedance Doubao Embeddings
- Loading branch information
Showing
9 changed files
with
278 additions
and
1 deletion.
There are no files selected for viewing
33 changes: 33 additions & 0 deletions
33
docs/core_docs/docs/integrations/text_embedding/bytedance_doubao.mdx
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
--- | ||
sidebar_class_name: node-only | ||
--- | ||
|
||
# ByteDance Doubao | ||
|
||
The `ByteDanceDoubaoEmbeddings` class uses the ByteDance Doubao API to generate embeddings for a given text. | ||
|
||
## Setup | ||
|
||
You'll need to sign up for a ByteDance API key and set it as an environment variable named `ARK_API_KEY`. | ||
|
||
Then, you'll need to install the [`@langchain/community`](https://www.npmjs.com/package/@langchain/community) package: | ||
|
||
import IntegrationInstallTooltip from "@mdx_components/integration_install_tooltip.mdx"; | ||
|
||
<IntegrationInstallTooltip></IntegrationInstallTooltip> | ||
|
||
```bash npm2yarn | ||
npm install @langchain/community @langchain/core | ||
``` | ||
|
||
## Usage | ||
|
||
import CodeBlock from "@theme/CodeBlock"; | ||
import ByteDanceDoubaoExample from "@examples/embeddings/bytedance_doubao.ts"; | ||
|
||
<CodeBlock language="typescript">{ByteDanceDoubaoExample}</CodeBlock> | ||
|
||
## Related | ||
|
||
- Embedding model [conceptual guide](/docs/concepts/embedding_models) | ||
- Embedding model [how-to guides](/docs/how_to/#embedding-models) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import { ByteDanceDoubaoEmbeddings } from "@langchain/community/embeddings/bytedance_doubao"; | ||
|
||
// No `apiKey` is passed here, so the constructor falls back to the
// `ARK_API_KEY` environment variable and throws if it is not set.
// 'ep-xxx-xxx' looks like a placeholder — replace it with your own model name.
const model = new ByteDanceDoubaoEmbeddings({ | ||
modelName: 'ep-xxx-xxx' | ||
}); | ||
// Embed a single query string; `res` is the embedding as an array of numbers.
const res = await model.embedQuery( | ||
"What would be a good company name a company that makes colorful socks?" | ||
); | ||
console.log({ res }); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
175 changes: 175 additions & 0 deletions
175
libs/langchain-community/src/embeddings/bytedance_doubao.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
import { getEnvironmentVariable } from "@langchain/core/utils/env"; | ||
import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings"; | ||
import { chunkArray } from "@langchain/core/utils/chunk_array"; | ||
|
||
/** Constructor options for `ByteDanceDoubaoEmbeddings`, in addition to the base `EmbeddingsParams`. */
export interface ByteDanceDoubaoEmbeddingsParams extends EmbeddingsParams { | ||
/** Model name to use; sent as the `model` field of every embeddings request. */ | ||
modelName: string; | ||
|
||
/** | ||
* Timeout to use when making requests to ByteDanceDoubao. | ||
* NOTE(review): the unit is not stated here — presumably milliseconds; confirm. | ||
*/ | ||
timeout?: number; | ||
|
||
/** | ||
* The maximum number of documents to embed in a single request. This is | ||
* limited by the ByteDanceDoubao API to a maximum of 2048. | ||
* The class in this file defaults to 24 when the option is omitted. | ||
* NOTE(review): the 2048 limit is not enforced in this file — confirm it against the API docs. | ||
*/ | ||
batchSize?: number; | ||
|
||
/** | ||
* Whether to strip new lines from the input text. | ||
* When enabled, every `\n` is replaced with a single space before the | ||
* request is built. The class in this file defaults to `true`. | ||
*/ | ||
stripNewLines?: boolean; | ||
} | ||
|
||
/** JSON body POSTed to the embeddings endpoint; built by `getParams`. */
interface EmbeddingCreateParams { | ||
/** Model identifier; `getParams` fills this from the instance's `modelName`. */
model: ByteDanceDoubaoEmbeddingsParams["modelName"]; | ||
/** The batch of texts to embed in this request. */
input: string[]; | ||
/** Optional output encoding; `getParams` in this file never sets it. */
encoding_format?: "float"; | ||
} | ||
|
||
/** Shape this file expects for a successful response from the embeddings endpoint. */
interface EmbeddingResponse { | ||
/** One entry per embedded input; the `embedding` vectors are what callers receive. */
data: { | ||
/** Presumably the position of the corresponding input within the request — confirm against the API docs. */
index: number; | ||
/** The embedding vector for one input. */
embedding: number[]; | ||
}[]; | ||
|
||
/** Token accounting reported by the API; not read anywhere in this file. */
usage: { | ||
prompt_tokens: number; | ||
total_tokens: number; | ||
}; | ||
|
||
/** Identifier returned with the response; not read anywhere in this file. */
id: string; | ||
} | ||
|
||
/**
 * Shape this file expects for an error response. A response is treated as an
 * error when it has a truthy `code`; the thrown message is `${code}: ${message}`.
 * NOTE(review): this assumes the fields sit at the top level of the JSON body.
 * An OpenAI-style `{ error: { ... } }` wrapper would not match this shape — confirm against the API docs.
 */
interface EmbeddingErrorResponse { | ||
type: string; | ||
code: string; | ||
param: string; | ||
message: string; | ||
} | ||
|
||
/**
 * Embeddings client for the ByteDance Doubao embeddings endpoint
 * (`https://ark.cn-beijing.volces.com/api/v3/embeddings`).
 *
 * The API key is taken from the `apiKey` constructor field or, failing that,
 * from the `ARK_API_KEY` environment variable.
 */
export class ByteDanceDoubaoEmbeddings
  extends Embeddings
  implements ByteDanceDoubaoEmbeddingsParams
{
  /** Model name sent as the `model` field of every request. */
  modelName: ByteDanceDoubaoEmbeddingsParams["modelName"] = "";

  /** Maximum number of texts sent in a single request. */
  batchSize = 24;

  /** When true, `\n` characters are replaced with spaces before embedding. */
  stripNewLines = true;

  /**
   * Optional per-request timeout, treated as milliseconds. When unset, no
   * timeout is applied.
   */
  timeout?: number;

  /** Bearer token sent in the `Authorization` header. */
  apiKey: string;

  /**
   * @param fields Optional configuration. `maxConcurrency` defaults to 2
   *   unless overridden; all fields are forwarded to the `Embeddings` base
   *   class.
   * @throws Error if no API key is supplied and `ARK_API_KEY` is not set.
   */
  constructor(
    fields?: Partial<ByteDanceDoubaoEmbeddingsParams> & {
      verbose?: boolean;
      apiKey?: string;
    }
  ) {
    const fieldsWithDefaults = { maxConcurrency: 2, ...fields };
    super(fieldsWithDefaults);

    const apiKey =
      fieldsWithDefaults.apiKey ?? getEnvironmentVariable("ARK_API_KEY");

    if (!apiKey) throw new Error("ByteDanceDoubao API key not found");

    this.apiKey = apiKey;

    this.modelName = fieldsWithDefaults.modelName ?? this.modelName;
    this.batchSize = fieldsWithDefaults.batchSize ?? this.batchSize;
    this.stripNewLines =
      fieldsWithDefaults.stripNewLines ?? this.stripNewLines;
    this.timeout = fieldsWithDefaults.timeout;
  }

  /**
   * Method to generate embeddings for an array of documents. Splits the
   * documents into batches of at most `batchSize` and makes one request to
   * the ByteDanceDoubao API per batch.
   * @param texts Array of documents to generate embeddings for.
   * @returns Promise that resolves to a 2D array of embeddings, one per
   *   document, in input order.
   */
  async embedDocuments(texts: string[]): Promise<number[][]> {
    const batches = chunkArray(
      this.stripNewLines ? texts.map((t) => t.replace(/\n/g, " ")) : texts,
      this.batchSize
    );

    const batchResponses = await Promise.all(
      batches.map((batch) => this.embeddingWithRetry(this.getParams(batch)))
    );

    const embeddings: number[][] = [];
    for (let i = 0; i < batchResponses.length; i += 1) {
      const batch = batches[i];
      const batchResponse = batchResponses[i] || [];
      for (let j = 0; j < batch.length; j += 1) {
        embeddings.push(batchResponse[j]);
      }
    }

    return embeddings;
  }

  /**
   * Method to generate an embedding for a single document. Calls the
   * embeddingWithRetry method with the document as the only input.
   * @param text Document to generate an embedding for.
   * @returns Promise that resolves to an embedding for the document.
   */
  async embedQuery(text: string): Promise<number[]> {
    const params = this.getParams([
      this.stripNewLines ? text.replace(/\n/g, " ") : text,
    ]);

    const embeddings = (await this.embeddingWithRetry(params)) || [[]];
    return embeddings[0];
  }

  /**
   * Builds the request body for a batch of texts.
   * @param texts Array of documents to generate embeddings for.
   * @returns The request body to send to the embeddings endpoint.
   */
  private getParams(
    texts: EmbeddingCreateParams["input"]
  ): EmbeddingCreateParams {
    return {
      model: this.modelName,
      input: texts,
    };
  }

  /**
   * Private method to make a request to the ByteDanceDoubao API to generate
   * embeddings. The request is issued through the inherited `caller`, so the
   * retry and `maxConcurrency` settings passed to the constructor apply.
   * @param body Request body to send to the ByteDanceDoubao API.
   * @returns Promise that resolves to one embedding per input, ordered by
   *   the `index` reported in the response.
   * @throws Error if the response carries an error `code`, has a non-2xx
   *   status, or lacks a `data` field. The HTTP status is attached as
   *   `status` so the caller can skip retries for non-retryable statuses.
   */
  private async embeddingWithRetry(
    body: EmbeddingCreateParams
  ): Promise<number[][]> {
    return this.caller.call(async () => {
      // Only arm an abort timer when a timeout was actually configured.
      const controller =
        this.timeout !== undefined ? new AbortController() : undefined;
      const timer =
        controller !== undefined
          ? setTimeout(() => controller.abort(), this.timeout)
          : undefined;

      try {
        const response = await fetch(
          "https://ark.cn-beijing.volces.com/api/v3/embeddings",
          {
            method: "POST",
            headers: {
              "Content-Type": "application/json",
              Authorization: `Bearer ${this.apiKey}`,
            },
            body: JSON.stringify(body),
            signal: controller?.signal,
          }
        );

        const embeddingData: EmbeddingResponse | EmbeddingErrorResponse =
          await response.json();

        if ("code" in embeddingData && embeddingData.code) {
          throw Object.assign(
            new Error(`${embeddingData.code}: ${embeddingData.message}`),
            { status: response.status }
          );
        }

        // Guard against failures that do not match the expected error shape,
        // instead of crashing on `undefined.map` below.
        if (!response.ok || !("data" in embeddingData)) {
          throw Object.assign(
            new Error(
              `ByteDanceDoubao embeddings request failed with status ${response.status}`
            ),
            { status: response.status }
          );
        }

        // Sort a copy by `index` so the output lines up with the inputs even
        // if the API returns entries out of order.
        return [...embeddingData.data]
          .sort((a, b) => a.index - b.index)
          .map(({ embedding }) => embedding);
      } finally {
        if (timer !== undefined) clearTimeout(timer);
      }
    });
  }
}
40 changes: 40 additions & 0 deletions
40
libs/langchain-community/src/embeddings/tests/bytedance_doubao.int.test.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import { test, expect } from "@jest/globals"; | ||
import { ByteDanceDoubaoEmbeddings } from "../bytedance_doubao.js"; | ||
|
||
// Integration tests for ByteDanceDoubaoEmbeddings. All three are registered
// with `test.skip`, so they do not run by default — presumably because they
// need a real API key and model name; confirm before enabling.
// 'ep-xxx-xxx' looks like a placeholder model name.
const modelName = 'ep-xxx-xxx'; | ||
// A single query should come back as an array whose elements are numbers.
test.skip("Test ByteDanceDoubaoEmbeddings.embedQuery", async () => { | ||
const embeddings = new ByteDanceDoubaoEmbeddings({ | ||
modelName, | ||
}); | ||
const res = await embeddings.embedQuery("Hello world"); | ||
expect(typeof res[0]).toBe("number"); | ||
}); | ||
|
||
// Two documents should yield two numeric embeddings.
test.skip("Test ByteDanceDoubaoEmbeddings.embedDocuments", async () => { | ||
const embeddings = new ByteDanceDoubaoEmbeddings({ | ||
modelName, | ||
}); | ||
const res = await embeddings.embedDocuments(["Hello world", "Bye bye"]); | ||
expect(res).toHaveLength(2); | ||
expect(typeof res[0][0]).toBe("number"); | ||
expect(typeof res[1][0]).toBe("number"); | ||
}); | ||
|
||
// With batchSize 1, the six inputs are split into six single-item batches;
// every batch must still produce a numeric embedding.
test.skip("Test ByteDanceDoubaoEmbeddings concurrency", async () => { | ||
const embeddings = new ByteDanceDoubaoEmbeddings({ | ||
modelName, | ||
batchSize: 1, | ||
}); | ||
const res = await embeddings.embedDocuments([ | ||
"Hello world", | ||
"Bye bye", | ||
"Hello world", | ||
"Bye bye", | ||
"Hello world", | ||
"Bye bye", | ||
]); | ||
expect(res).toHaveLength(6); | ||
// `find` returns undefined only when no embedding has a non-numeric first element.
expect(res.find((embedding) => typeof embedding[0] !== "number")).toBe( | ||
undefined | ||
); | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters