Skip to content

Commit

Permalink
Merge pull request #2141 from chezsmithy/feat-bedrock-embeddings
Browse files Browse the repository at this point in the history
feat-bedrock-embeddings
  • Loading branch information
sestinj authored Sep 4, 2024
2 parents 0558533 + 1614588 commit 7614cb7
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 6 deletions.
6 changes: 6 additions & 0 deletions core/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,7 @@ export interface ModelDescription {
}

export type EmbeddingsProviderName =
| "bedrock"
| "huggingface-tei"
| "transformers.js"
| "ollama"
Expand All @@ -793,6 +794,11 @@ export interface EmbedOptions {
apiVersion?: string;
requestOptions?: RequestOptions;
maxChunkSize?: number;
// AWS options
profile?: string;

// AWS and GCP Options
region?: string;
}

export interface EmbeddingsProviderDescription extends EmbedOptions {
Expand Down
93 changes: 93 additions & 0 deletions core/indexing/embeddings/BedrockEmbeddingsProvider.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import {
BedrockRuntimeClient,
InvokeModelCommand,
} from "@aws-sdk/client-bedrock-runtime";
import { fromIni } from "@aws-sdk/credential-providers";
import { EmbeddingsProviderName, EmbedOptions, FetchFunction } from "../../index.js";
import BaseEmbeddingsProvider from "./BaseEmbeddingsProvider.js";

class BedrockEmbeddingsProvider extends BaseEmbeddingsProvider {

static providerName: EmbeddingsProviderName = "bedrock";

static defaultOptions: Partial<EmbedOptions> | undefined = {
model: "amazon.titan-embed-text-v2:0",
region: "us-east-1"
};
profile?: string | undefined;

constructor(options: EmbedOptions, fetch: FetchFunction) {
super(options, fetch);
if (!options.apiBase) {
options.apiBase = `https://bedrock-runtime.${options.region}.amazonaws.com`;
}

if (options.profile) {
this.profile = options.profile;
} else {
this.profile = "bedrock";
}
}

async embed(chunks: string[]) {
const credentials = await this._getCredentials();
const client = new BedrockRuntimeClient({
region: this.options.region,
credentials: {
accessKeyId: credentials.accessKeyId,
secretAccessKey: credentials.secretAccessKey,
sessionToken: credentials.sessionToken || "",
},
});

return (
await Promise.all(
chunks.map(async (chunk) => {
const input = this._generateInvokeModelCommandInput(chunk, this.options);
const command = new InvokeModelCommand(input);
const response = await client.send(command);

if (response.body) {
const responseBody = JSON.parse(new TextDecoder().decode(response.body));
return responseBody.embedding;
}
}),
)
).flat();
}

private _generateInvokeModelCommandInput(
prompt: string,
options: EmbedOptions,
): any {
const payload = {
"inputText": prompt,
"dimensions": 1024,
"normalize": true
};

return {
body: JSON.stringify(payload),
modelId: this.options.model,
accept: "application/json",
contentType: "application/json"
};
}

private async _getCredentials() {
try {
return await
fromIni({
profile: this.profile
})();
} catch (e) {
console.warn(
`AWS profile with name ${this.profile} not found in ~/.aws/credentials, using default profile`,
);
return await fromIni()();
}
}

}

export default BedrockEmbeddingsProvider;
2 changes: 2 additions & 0 deletions core/indexing/embeddings/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { EmbeddingsProviderName } from "../../index.js";
import BaseEmbeddingsProvider from "./BaseEmbeddingsProvider.js";
import BedrockEmbeddingsProvider from "./BedrockEmbeddingsProvider.js";
import CohereEmbeddingsProvider from "./CohereEmbeddingsProvider.js";
import ContinueProxyEmbeddingsProvider from "./ContinueProxyEmbeddingsProvider.js";
import DeepInfraEmbeddingsProvider from "./DeepInfraEmbeddingsProvider.js";
Expand All @@ -19,6 +20,7 @@ export const allEmbeddingsProviders: Record<
EmbeddingsProviderName,
EmbeddingsProviderConstructor
> = {
bedrock: BedrockEmbeddingsProvider,
ollama: OllamaEmbeddingsProvider,
"transformers.js": TransformersJsEmbeddingsProvider,
openai: OpenAIEmbeddingsProvider,
Expand Down
15 changes: 15 additions & 0 deletions docs/docs/features/codebase-embeddings.md
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,21 @@ As of May 2024, the only available embedding model from Gemini is [`text-embeddi
}
```

### AWS Bedrock

As of August 30, 2024 the only tested model is [`amazon.titan-embed-text-v2:0`](https://docs.aws.amazon.com/bedrock/latest/devguide/models.html#amazon.titan-embed-text-v2-0).

```json title="~/.continue/config.json"
{
"embeddingsProvider": {
"title": "Embeddings Model",
"provider": "bedrock",
"model": "amazon.titan-embed-text-v2:0",
"region": "us-west-2"
},
}
```

### Writing a custom `EmbeddingsProvider`

If you have your own API capable of generating embeddings, Continue makes it easy to write a custom `EmbeddingsProvider`. All you have to do is write a function that converts strings to arrays of numbers, and add this to your config in `config.ts`. Here's an example:
Expand Down
47 changes: 44 additions & 3 deletions extensions/intellij/src/main/resources/config_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@
"### Free Trial\nNew users can try out Continue for free using a proxy server that securely makes calls to OpenAI using our API key. If you are ready to use your own API key or have used all 250 free uses, you can enter your API key in config.json where it says `apiKey=\"\"` or select another model provider.\n> [Reference](https://docs.continue.dev/reference/Model%20Providers/freetrial)",
"### Anthropic\nTo get started with Anthropic models, you first need to sign up for the open beta [here](https://claude.ai/login) to obtain an API key.\n> [Reference](https://docs.continue.dev/reference/Model%20Providers/anthropicllm)",
"### Cohere\nTo use Cohere, visit the [Cohere dashboard](https://dashboard.cohere.com/api-keys) to create an API key.\n\n> [Reference](https://docs.continue.dev/reference/Model%20Providers/cohere)",
"### Bedrock\nTo get started with Bedrock you need to sign up on AWS [here](https://aws.amazon.com/bedrock/claude/)",
"### Bedrock Imported Models\nTo get started with Bedrock you need to sign up on AWS [here](https://aws.amazon.com/bedrock/claude/)",
"### Bedrock\nTo get started with Bedrock you need to sign up on AWS [here](https://aws.amazon.com/bedrock)",
"### Bedrock Imported Models\nTo get started with Bedrock you need to sign up on AWS [here](https://aws.amazon.com/bedrock)",
"### Sagemaker\nSagemaker is AWS' machine learning platform.",
"### Together\nTogether is a hosted service that provides extremely fast streaming of open-source language models. To get started with Together:\n1. Obtain an API key from [here](https://together.ai)\n2. Paste below\n3. Select a model preset\n> [Reference](https://docs.continue.dev/reference/Model%20Providers/togetherllm)",
"### Ollama\nTo get started with Ollama, follow these steps:\n1. Download from [ollama.ai](https://ollama.ai/) and open the application\n2. Open a terminal and run `ollama run <MODEL_NAME>`. Example model names are `codellama:7b-instruct` or `llama2:7b-text`. You can find the full list [here](https://ollama.ai/library).\n3. Make sure that the model name used in step 2 is the same as the one in config.json (e.g. `model=\"codellama:7b-instruct\"`)\n4. Once the model has finished downloading, you can start asking questions through Continue.\n> [Reference](https://docs.continue.dev/reference/Model%20Providers/ollama)",
Expand Down Expand Up @@ -242,6 +242,21 @@
"description": "The base URL of the LLM API.",
"type": "string"
},
"region": {
"title": "Region",
"description": "The region where the model is hosted",
"type": "string"
},
"profile": {
"title": "Profile",
"description": "The AWS security profile to use",
"type": "string"
},
"modelArn": {
"title": "Profile",
"description": "The AWS arn for the imported model",
"type": "string"
},
"contextLength": {
"title": "Context Length",
"description": "The maximum context length of the LLM in tokens, as counted by countTokens.",
Expand Down Expand Up @@ -393,6 +408,21 @@
"required": ["apiKey"]
}
},
{
"if": {
"properties": {
"provider": {
"enum": [
"bedrockimport"
]
}
},
"required": ["provider"]
},
"then": {
"required": ["modelArn"]
}
},
{
"if": {
"properties": {
Expand Down Expand Up @@ -2030,7 +2060,8 @@
"cohere",
"free-trial",
"gemini",
"voyage"
"voyage",
"bedrock"
]
},
"model": {
Expand All @@ -2053,6 +2084,16 @@
"type": "integer",
"minimum": 128,
"exclusiveMaximum": 2147483647
},
"region": {
"title": "Region",
"description": "The region where the model is hosted",
"type": "string"
},
"profile": {
"title": "Profile",
"description": "The AWS security profile to use",
"type": "string"
}
},
"required": ["provider"],
Expand Down
47 changes: 44 additions & 3 deletions extensions/vscode/config_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@
"### Free Trial\nNew users can try out Continue for free using a proxy server that securely makes calls to OpenAI using our API key. If you are ready to use your own API key or have used all 250 free uses, you can enter your API key in config.json where it says `apiKey=\"\"` or select another model provider.\n> [Reference](https://docs.continue.dev/reference/Model%20Providers/freetrial)",
"### Anthropic\nTo get started with Anthropic models, you first need to sign up for the open beta [here](https://claude.ai/login) to obtain an API key.\n> [Reference](https://docs.continue.dev/reference/Model%20Providers/anthropicllm)",
"### Cohere\nTo use Cohere, visit the [Cohere dashboard](https://dashboard.cohere.com/api-keys) to create an API key.\n\n> [Reference](https://docs.continue.dev/reference/Model%20Providers/cohere)",
"### Bedrock\nTo get started with Bedrock you need to sign up on AWS [here](https://aws.amazon.com/bedrock/claude/)",
"### Bedrock Imported Models\nTo get started with Bedrock you need to sign up on AWS [here](https://aws.amazon.com/bedrock/claude/)",
"### Bedrock\nTo get started with Bedrock you need to sign up on AWS [here](https://aws.amazon.com/bedrock)",
"### Bedrock Imported Models\nTo get started with Bedrock you need to sign up on AWS [here](https://aws.amazon.com/bedrock)",
"### Sagemaker\nSagemaker is AWS' machine learning platform.",
"### Together\nTogether is a hosted service that provides extremely fast streaming of open-source language models. To get started with Together:\n1. Obtain an API key from [here](https://together.ai)\n2. Paste below\n3. Select a model preset\n> [Reference](https://docs.continue.dev/reference/Model%20Providers/togetherllm)",
"### Ollama\nTo get started with Ollama, follow these steps:\n1. Download from [ollama.ai](https://ollama.ai/) and open the application\n2. Open a terminal and run `ollama run <MODEL_NAME>`. Example model names are `codellama:7b-instruct` or `llama2:7b-text`. You can find the full list [here](https://ollama.ai/library).\n3. Make sure that the model name used in step 2 is the same as the one in config.json (e.g. `model=\"codellama:7b-instruct\"`)\n4. Once the model has finished downloading, you can start asking questions through Continue.\n> [Reference](https://docs.continue.dev/reference/Model%20Providers/ollama)",
Expand Down Expand Up @@ -242,6 +242,21 @@
"description": "The base URL of the LLM API.",
"type": "string"
},
"region": {
"title": "Region",
"description": "The region where the model is hosted",
"type": "string"
},
"profile": {
"title": "Profile",
"description": "The AWS security profile to use",
"type": "string"
},
"modelArn": {
"title": "Profile",
"description": "The AWS arn for the imported model",
"type": "string"
},
"contextLength": {
"title": "Context Length",
"description": "The maximum context length of the LLM in tokens, as counted by countTokens.",
Expand Down Expand Up @@ -393,6 +408,21 @@
"required": ["apiKey"]
}
},
{
"if": {
"properties": {
"provider": {
"enum": [
"bedrockimport"
]
}
},
"required": ["provider"]
},
"then": {
"required": ["modelArn"]
}
},
{
"if": {
"properties": {
Expand Down Expand Up @@ -2030,7 +2060,8 @@
"cohere",
"free-trial",
"gemini",
"voyage"
"voyage",
"bedrock"
]
},
"model": {
Expand All @@ -2053,6 +2084,16 @@
"type": "integer",
"minimum": 128,
"exclusiveMaximum": 2147483647
},
"region": {
"title": "Region",
"description": "The region where the model is hosted",
"type": "string"
},
"profile": {
"title": "Profile",
"description": "The AWS security profile to use",
"type": "string"
}
},
"required": ["provider"],
Expand Down

0 comments on commit 7614cb7

Please sign in to comment.