diff --git a/core/index.d.ts b/core/index.d.ts index b21bae0799..eec22f18d6 100644 --- a/core/index.d.ts +++ b/core/index.d.ts @@ -773,6 +773,7 @@ export interface ModelDescription { } export type EmbeddingsProviderName = + | "bedrock" | "huggingface-tei" | "transformers.js" | "ollama" @@ -793,6 +794,11 @@ export interface EmbedOptions { apiVersion?: string; requestOptions?: RequestOptions; maxChunkSize?: number; + // AWS options + profile?: string; + + // AWS and GCP options + region?: string; } export interface EmbeddingsProviderDescription extends EmbedOptions { diff --git a/core/indexing/embeddings/BedrockEmbeddingsProvider.ts b/core/indexing/embeddings/BedrockEmbeddingsProvider.ts new file mode 100644 index 0000000000..12e20aa76f --- /dev/null +++ b/core/indexing/embeddings/BedrockEmbeddingsProvider.ts @@ -0,0 +1,93 @@ +import { + BedrockRuntimeClient, + InvokeModelCommand, +} from "@aws-sdk/client-bedrock-runtime"; +import { fromIni } from "@aws-sdk/credential-providers"; +import { EmbeddingsProviderName, EmbedOptions, FetchFunction } from "../../index.js"; +import BaseEmbeddingsProvider from "./BaseEmbeddingsProvider.js"; + +class BedrockEmbeddingsProvider extends BaseEmbeddingsProvider { + + static providerName: EmbeddingsProviderName = "bedrock"; + + static defaultOptions: Partial<EmbedOptions> | undefined = { + model: "amazon.titan-embed-text-v2:0", + region: "us-east-1" + }; + profile?: string | undefined; + + constructor(options: EmbedOptions, fetch: FetchFunction) { + super(options, fetch); + if (!options.apiBase) { + options.apiBase = `https://bedrock-runtime.${options.region}.amazonaws.com`; + } + + if (options.profile) { + this.profile = options.profile; + } else { + this.profile = "bedrock"; + } + } + + async embed(chunks: string[]) { + const credentials = await this._getCredentials(); + const client = new BedrockRuntimeClient({ + region: this.options.region, + credentials: { + accessKeyId: credentials.accessKeyId, + secretAccessKey: credentials.secretAccessKey, 
+ sessionToken: credentials.sessionToken || "", + }, + }); + + return ( + await Promise.all( + chunks.map(async (chunk) => { + const input = this._generateInvokeModelCommandInput(chunk, this.options); + const command = new InvokeModelCommand(input); + const response = await client.send(command); + + if (response.body) { + const responseBody = JSON.parse(new TextDecoder().decode(response.body)); + return responseBody.embedding; + } + }), + ) + ); + } + + private _generateInvokeModelCommandInput( + prompt: string, + options: EmbedOptions, + ): any { + const payload = { + "inputText": prompt, + "dimensions": 1024, + "normalize": true + }; + + return { + body: JSON.stringify(payload), + modelId: this.options.model, + accept: "application/json", + contentType: "application/json" + }; + } + + private async _getCredentials() { + try { + return await + fromIni({ + profile: this.profile + })(); + } catch (e) { + console.warn( + `AWS profile with name ${this.profile} not found in ~/.aws/credentials, using default profile`, + ); + return await fromIni()(); + } + } + +} + +export default BedrockEmbeddingsProvider; diff --git a/core/indexing/embeddings/index.ts b/core/indexing/embeddings/index.ts index acba5dbe21..65c583769e 100644 --- a/core/indexing/embeddings/index.ts +++ b/core/indexing/embeddings/index.ts @@ -1,5 +1,6 @@ import { EmbeddingsProviderName } from "../../index.js"; import BaseEmbeddingsProvider from "./BaseEmbeddingsProvider.js"; +import BedrockEmbeddingsProvider from "./BedrockEmbeddingsProvider.js"; import CohereEmbeddingsProvider from "./CohereEmbeddingsProvider.js"; import ContinueProxyEmbeddingsProvider from "./ContinueProxyEmbeddingsProvider.js"; import DeepInfraEmbeddingsProvider from "./DeepInfraEmbeddingsProvider.js"; @@ -19,6 +20,7 @@ export const allEmbeddingsProviders: Record< EmbeddingsProviderName, EmbeddingsProviderConstructor > = { + bedrock: BedrockEmbeddingsProvider, ollama: OllamaEmbeddingsProvider, "transformers.js": 
TransformersJsEmbeddingsProvider, openai: OpenAIEmbeddingsProvider, diff --git a/docs/docs/features/codebase-embeddings.md b/docs/docs/features/codebase-embeddings.md index 21945734b7..0ad0d92fe1 100644 --- a/docs/docs/features/codebase-embeddings.md +++ b/docs/docs/features/codebase-embeddings.md @@ -202,6 +202,21 @@ As of May 2024, the only available embedding model from Gemini is [`text-embeddi } ``` +### AWS Bedrock + +As of August 30, 2024 the only tested model is [`amazon.titan-embed-text-v2:0`](https://docs.aws.amazon.com/bedrock/latest/devguide/models.html#amazon.titan-embed-text-v2-0). + +```json title="~/.continue/config.json" +{ + "embeddingsProvider": { + "title": "Embeddings Model", + "provider": "bedrock", + "model": "amazon.titan-embed-text-v2:0", + "region": "us-west-2" + }, +} +``` + ### Writing a custom `EmbeddingsProvider` If you have your own API capable of generating embeddings, Continue makes it easy to write a custom `EmbeddingsProvider`. All you have to do is write a function that converts strings to arrays of numbers, and add this to your config in `config.ts`. Here's an example: diff --git a/extensions/intellij/src/main/resources/config_schema.json b/extensions/intellij/src/main/resources/config_schema.json index 7166520e10..1087144491 100644 --- a/extensions/intellij/src/main/resources/config_schema.json +++ b/extensions/intellij/src/main/resources/config_schema.json @@ -200,8 +200,8 @@ "### Free Trial\nNew users can try out Continue for free using a proxy server that securely makes calls to OpenAI using our API key. 
If you are ready to use your own API key or have used all 250 free uses, you can enter your API key in config.json where it says `apiKey=\"\"` or select another model provider.\n> [Reference](https://docs.continue.dev/reference/Model%20Providers/freetrial)", "### Anthropic\nTo get started with Anthropic models, you first need to sign up for the open beta [here](https://claude.ai/login) to obtain an API key.\n> [Reference](https://docs.continue.dev/reference/Model%20Providers/anthropicllm)", "### Cohere\nTo use Cohere, visit the [Cohere dashboard](https://dashboard.cohere.com/api-keys) to create an API key.\n\n> [Reference](https://docs.continue.dev/reference/Model%20Providers/cohere)", - "### Bedrock\nTo get started with Bedrock you need to sign up on AWS [here](https://aws.amazon.com/bedrock/claude/)", - "### Bedrock Imported Models\nTo get started with Bedrock you need to sign up on AWS [here](https://aws.amazon.com/bedrock/claude/)", + "### Bedrock\nTo get started with Bedrock you need to sign up on AWS [here](https://aws.amazon.com/bedrock)", + "### Bedrock Imported Models\nTo get started with Bedrock you need to sign up on AWS [here](https://aws.amazon.com/bedrock)", "### Sagemaker\nSagemaker is AWS' machine learning platform.", "### Together\nTogether is a hosted service that provides extremely fast streaming of open-source language models. To get started with Together:\n1. Obtain an API key from [here](https://together.ai)\n2. Paste below\n3. Select a model preset\n> [Reference](https://docs.continue.dev/reference/Model%20Providers/togetherllm)", "### Ollama\nTo get started with Ollama, follow these steps:\n1. Download from [ollama.ai](https://ollama.ai/) and open the application\n2. Open a terminal and run `ollama run `. Example model names are `codellama:7b-instruct` or `llama2:7b-text`. You can find the full list [here](https://ollama.ai/library).\n3. Make sure that the model name used in step 2 is the same as the one in config.json (e.g. 
`model=\"codellama:7b-instruct\"`)\n4. Once the model has finished downloading, you can start asking questions through Continue.\n> [Reference](https://docs.continue.dev/reference/Model%20Providers/ollama)", @@ -242,6 +242,21 @@ "description": "The base URL of the LLM API.", "type": "string" }, + "region": { + "title": "Region", + "description": "The region where the model is hosted", + "type": "string" + }, + "profile": { + "title": "Profile", + "description": "The AWS security profile to use", + "type": "string" + }, + "modelArn": { + "title": "Model ARN", + "description": "The AWS ARN for the imported model", + "type": "string" + }, "contextLength": { "title": "Context Length", "description": "The maximum context length of the LLM in tokens, as counted by countTokens.", @@ -393,6 +408,21 @@ "required": ["apiKey"] } }, + { + "if": { + "properties": { + "provider": { + "enum": [ + "bedrockimport" + ] + } + }, + "required": ["provider"] + }, + "then": { + "required": ["modelArn"] + } + }, { "if": { "properties": { @@ -2030,7 +2060,8 @@ "cohere", "free-trial", "gemini", - "voyage" + "voyage", + "bedrock" ] }, "model": { @@ -2053,6 +2084,16 @@ "type": "integer", "minimum": 128, "exclusiveMaximum": 2147483647 + }, + "region": { + "title": "Region", + "description": "The region where the model is hosted", + "type": "string" + }, + "profile": { + "title": "Profile", + "description": "The AWS security profile to use", + "type": "string" + } }, "required": ["provider"], diff --git a/extensions/vscode/config_schema.json b/extensions/vscode/config_schema.json index 7166520e10..1087144491 100644 --- a/extensions/vscode/config_schema.json +++ b/extensions/vscode/config_schema.json @@ -200,8 +200,8 @@ "### Free Trial\nNew users can try out Continue for free using a proxy server that securely makes calls to OpenAI using our API key. 
If you are ready to use your own API key or have used all 250 free uses, you can enter your API key in config.json where it says `apiKey=\"\"` or select another model provider.\n> [Reference](https://docs.continue.dev/reference/Model%20Providers/freetrial)", "### Anthropic\nTo get started with Anthropic models, you first need to sign up for the open beta [here](https://claude.ai/login) to obtain an API key.\n> [Reference](https://docs.continue.dev/reference/Model%20Providers/anthropicllm)", "### Cohere\nTo use Cohere, visit the [Cohere dashboard](https://dashboard.cohere.com/api-keys) to create an API key.\n\n> [Reference](https://docs.continue.dev/reference/Model%20Providers/cohere)", - "### Bedrock\nTo get started with Bedrock you need to sign up on AWS [here](https://aws.amazon.com/bedrock/claude/)", - "### Bedrock Imported Models\nTo get started with Bedrock you need to sign up on AWS [here](https://aws.amazon.com/bedrock/claude/)", + "### Bedrock\nTo get started with Bedrock you need to sign up on AWS [here](https://aws.amazon.com/bedrock)", + "### Bedrock Imported Models\nTo get started with Bedrock you need to sign up on AWS [here](https://aws.amazon.com/bedrock)", "### Sagemaker\nSagemaker is AWS' machine learning platform.", "### Together\nTogether is a hosted service that provides extremely fast streaming of open-source language models. To get started with Together:\n1. Obtain an API key from [here](https://together.ai)\n2. Paste below\n3. Select a model preset\n> [Reference](https://docs.continue.dev/reference/Model%20Providers/togetherllm)", "### Ollama\nTo get started with Ollama, follow these steps:\n1. Download from [ollama.ai](https://ollama.ai/) and open the application\n2. Open a terminal and run `ollama run `. Example model names are `codellama:7b-instruct` or `llama2:7b-text`. You can find the full list [here](https://ollama.ai/library).\n3. Make sure that the model name used in step 2 is the same as the one in config.json (e.g. 
`model=\"codellama:7b-instruct\"`)\n4. Once the model has finished downloading, you can start asking questions through Continue.\n> [Reference](https://docs.continue.dev/reference/Model%20Providers/ollama)", @@ -242,6 +242,21 @@ "description": "The base URL of the LLM API.", "type": "string" }, + "region": { + "title": "Region", + "description": "The region where the model is hosted", + "type": "string" + }, + "profile": { + "title": "Profile", + "description": "The AWS security profile to use", + "type": "string" + }, + "modelArn": { + "title": "Model ARN", + "description": "The AWS ARN for the imported model", + "type": "string" + }, "contextLength": { "title": "Context Length", "description": "The maximum context length of the LLM in tokens, as counted by countTokens.", @@ -393,6 +408,21 @@ "required": ["apiKey"] } }, + { + "if": { + "properties": { + "provider": { + "enum": [ + "bedrockimport" + ] + } + }, + "required": ["provider"] + }, + "then": { + "required": ["modelArn"] + } + }, { "if": { "properties": { @@ -2030,7 +2060,8 @@ "cohere", "free-trial", "gemini", - "voyage" + "voyage", + "bedrock" ] }, "model": { @@ -2053,6 +2084,16 @@ "type": "integer", "minimum": 128, "exclusiveMaximum": 2147483647 + }, + "region": { + "title": "Region", + "description": "The region where the model is hosted", + "type": "string" + }, + "profile": { + "title": "Profile", + "description": "The AWS security profile to use", + "type": "string" + } }, "required": ["provider"],