From ae1f3ec0de117abcc6803f3d5a2b73d4fa1fd760 Mon Sep 17 00:00:00 2001
From: Yomesh Shah
Date: Wed, 8 May 2024 15:22:21 +0100
Subject: [PATCH 1/4] feat: added support for AI21 and Amazon Titan models
 via bedrock api

Signed-off-by: Yomesh Shah
---
 pkg/ai/amazonbedrock.go | 44 ++++++++++++++++++++++++++++++++++-------
 1 file changed, 37 insertions(+), 7 deletions(-)

diff --git a/pkg/ai/amazonbedrock.go b/pkg/ai/amazonbedrock.go
index f9cd80fa61..b2e0ec32c5 100644
--- a/pkg/ai/amazonbedrock.go
+++ b/pkg/ai/amazonbedrock.go
@@ -52,12 +52,18 @@ const (
 	ModelAnthropicClaudeV2        = "anthropic.claude-v2"
 	ModelAnthropicClaudeV1        = "anthropic.claude-v1"
 	ModelAnthropicClaudeInstantV1 = "anthropic.claude-instant-v1"
+	ModelA21J2UltraV1             = "ai21.j2-ultra-v1"
+	ModelA21J2JumboInstruct       = "ai21.j2-jumbo-instruct"
+	ModelAmazonTitanExpressV1     = "amazon.titan-text-express-v1"
 )
 
 var BEDROCK_MODELS = []string{
 	ModelAnthropicClaudeV2,
 	ModelAnthropicClaudeV1,
 	ModelAnthropicClaudeInstantV1,
+	ModelA21J2UltraV1,
+	ModelA21J2JumboInstruct,
+	ModelAmazonTitanExpressV1,
 }
 
 // GetModelOrDefault check config model
@@ -116,13 +122,37 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
 
 // GetCompletion sends a request to the model for generating completion based on the provided prompt.
 func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
-	// Prepare the input data for the model invocation
-	request := map[string]interface{}{
-		"prompt":               fmt.Sprintf("\n\nHuman: %s \n\nAssistant:", prompt),
-		"max_tokens_to_sample": 1024,
-		"temperature":          a.temperature,
-		"top_p":                0.9,
-	}
+	// Prepare the input data for the model invocation based on the model
+	var request map[string]interface{}
+
+	switch a.model {
+	case ModelAnthropicClaudeV2, ModelAnthropicClaudeV1, ModelAnthropicClaudeInstantV1:
+		request = map[string]interface{}{
+			"prompt":               fmt.Sprintf("\n\nHuman: %s \n\nAssistant:", prompt),
+			"max_tokens_to_sample": 1024,
+			"temperature":          a.temperature,
+			"top_p":                0.9,
+		}
+	case ModelA21J2UltraV1, ModelA21J2JumboInstruct:
+		request = map[string]interface{}{
+			"prompt":      fmt.Sprintf("\n\nHuman: %s \n\nAssistant:", prompt),
+			"maxTokens":   1024,
+			"temperature": a.temperature,
+			"topP":        0.9,
+		}
+	case ModelAmazonTitanExpressV1:
+		request = map[string]interface{}{
+			"inputText": fmt.Sprintf("\n\nHuman: %s \n\nAssistant:", prompt),
+			"textGenerationConfig": map[string]interface{}{
+				"maxTokenCount": 1024,
+				"temperature":   a.temperature,
+				"topP":          0.9,
+			},
+		}
+	default:
+		return "", fmt.Errorf("model %s not supported", a.model)
+	}
 
 	body, err := json.Marshal(request)
 	if err != nil {
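
For reference, the three request shapes introduced in patch 1 serialize as shown by the standalone sketch below. It is not part of the patch; the prompt text and temperature value are illustrative, and patches 2-4 later adjust the prompt framing and token limits.

// Standalone sketch (not part of the patch): prints the JSON bodies that
// patch 1 builds for each model family.
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Illustrative values; patch 1 frames every prompt Human/Assistant-style.
	prompt := fmt.Sprintf("\n\nHuman: %s \n\nAssistant:", "why is my pod crashlooping?")
	temperature := float32(0.7)

	families := []struct {
		model   string
		request map[string]interface{}
	}{
		// Anthropic Claude: snake_case keys at the top level.
		{"anthropic.claude-v2", map[string]interface{}{
			"prompt":               prompt,
			"max_tokens_to_sample": 1024,
			"temperature":          temperature,
			"top_p":                0.9,
		}},
		// AI21 Jurassic-2: camelCase keys (maxTokens/topP).
		{"ai21.j2-ultra-v1", map[string]interface{}{
			"prompt":      prompt,
			"maxTokens":   1024,
			"temperature": temperature,
			"topP":        0.9,
		}},
		// Amazon Titan: prompt goes in inputText, generation knobs nested
		// under textGenerationConfig.
		{"amazon.titan-text-express-v1", map[string]interface{}{
			"inputText": prompt,
			"textGenerationConfig": map[string]interface{}{
				"maxTokenCount": 1024,
				"temperature":   temperature,
				"topP":          0.9,
			},
		}},
	}

	for _, f := range families {
		body, err := json.MarshalIndent(f.request, "", "  ")
		if err != nil {
			panic(err)
		}
		fmt.Printf("%s:\n%s\n\n", f.model, body)
	}
}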
From 076ab2f09b8eaa10a9ca344245a0fc1045cf1e9d Mon Sep 17 00:00:00 2001
From: Yomesh Shah
Date: Thu, 9 May 2024 22:12:34 +0100
Subject: [PATCH 2/4] fix: response type for different models and use of
 constant for top_P

Signed-off-by: Yomesh Shah
---
 pkg/ai/amazonbedrock.go | 92 ++++++++++++++++++++++++++++-------------
 1 file changed, 64 insertions(+), 28 deletions(-)

diff --git a/pkg/ai/amazonbedrock.go b/pkg/ai/amazonbedrock.go
index b2e0ec32c5..f2853a7f7a 100644
--- a/pkg/ai/amazonbedrock.go
+++ b/pkg/ai/amazonbedrock.go
@@ -22,12 +22,6 @@ type AmazonBedRockClient struct {
 	temperature float32
 }
 
-// InvokeModelResponseBody represents the response body structure from the model invocation.
-type InvokeModelResponseBody struct {
-	Completion  string `json:"completion"`
-	Stop_reason string `json:"stop_reason"`
-}
-
 // Amazon BedRock support region list US East (N. Virginia),US West (Oregon),Asia Pacific (Singapore),Asia Pacific (Tokyo),Europe (Frankfurt)
 // https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html#bedrock-regions
 const BEDROCK_DEFAULT_REGION = "us-east-1" // default use us-east-1 region
@@ -66,6 +60,8 @@ var BEDROCK_MODELS = []string{
 	ModelAmazonTitanExpressV1,
 }
 
+const TOPP = "0.9"
+
 // GetModelOrDefault check config model
 func GetModelOrDefault(model string) string {
 
@@ -122,36 +118,35 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
 
 // GetCompletion sends a request to the model for generating completion based on the provided prompt.
 func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
-	// Prepare the input data for the model invocation based on the model
+	// Prepare the input data for the model invocation based on the model & the Response Body per model as well.
 	var request map[string]interface{}
-
 	switch a.model {
 	case ModelAnthropicClaudeV2, ModelAnthropicClaudeV1, ModelAnthropicClaudeInstantV1:
 		request = map[string]interface{}{
 			"prompt":               fmt.Sprintf("\n\nHuman: %s \n\nAssistant:", prompt),
 			"max_tokens_to_sample": 1024,
 			"temperature":          a.temperature,
-			"top_p":                0.9,
+			"top_p":                TOPP,
 		}
-	case ModelA21J2UltraV1, ModelA21J2JumboInstruct:
+	case ModelA21J2UltraV1, ModelA21J2JumboInstruct:
 		request = map[string]interface{}{
-			"prompt":      fmt.Sprintf("\n\nHuman: %s \n\nAssistant:", prompt),
-			"maxTokens":   1024,
+			"prompt":      prompt,
+			"maxTokens":   2048,
 			"temperature": a.temperature,
-			"topP":        0.9,
+			"topP":        TOPP,
 		}
-	case ModelAmazonTitanExpressV1:
+	case ModelAmazonTitanExpressV1:
 		request = map[string]interface{}{
-			"inputText": fmt.Sprintf("\n\nHuman: %s \n\nAssistant:", prompt),
+			"inputText": fmt.Sprintf("\n\nUser: %s", prompt),
 			"textGenerationConfig": map[string]interface{}{
-				"maxTokenCount": 1024,
+				"maxTokenCount": 8000,
 				"temperature":   a.temperature,
-				"topP":          0.9,
+				"topP":          TOPP,
 			},
-		}
-	default:
+		}
+	default:
 		return "", fmt.Errorf("model %s not supported", a.model)
-	}
+	}
 
 	body, err := json.Marshal(request)
 	if err != nil {
@@ -172,15 +167,56 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string)
 	if err != nil {
 		return "", err
 	}
-	// Parse the response body
-	output := &InvokeModelResponseBody{}
-	err = json.Unmarshal(resp.Body, output)
-	if err != nil {
-		return "", err
-	}
-	return output.Completion, nil
+
+	// Response type changes as per model
+	switch a.model {
+	case ModelAnthropicClaudeV2, ModelAnthropicClaudeV1, ModelAnthropicClaudeInstantV1:
+		type InvokeModelResponseBody struct {
+			Completion  string `json:"completion"`
+			Stop_reason string `json:"stop_reason"`
+		}
+		output := &InvokeModelResponseBody{}
+		err = json.Unmarshal(resp.Body, output)
+		if err != nil {
+			return "", err
+		}
+		return output.Completion, nil
+	case ModelA21J2UltraV1, ModelA21J2JumboInstruct:
+		type Data struct {
+			Text string `json:"text"`
+		}
+		type Completion struct {
+			Data Data `json:"data"`
+		}
+		type InvokeModelResponseBody struct {
+			Completions []Completion `json:"completions"`
+		}
+		output := &InvokeModelResponseBody{}
+		err = json.Unmarshal(resp.Body, output)
+		if err != nil {
+			return "", err
+		}
+		return output.Completions[0].Data.Text, nil
+	case ModelAmazonTitanExpressV1:
+		type Result struct {
+			TokenCount       int    `json:"tokenCount"`
+			OutputText       string `json:"outputText"`
+			CompletionReason string `json:"completionReason"`
+		}
+		type InvokeModelResponseBody struct {
+			InputTextTokenCount int      `json:"inputTextTokenCount"`
+			Results             []Result `json:"results"`
+		}
+		output := &InvokeModelResponseBody{}
+		err = json.Unmarshal(resp.Body, output)
+		if err != nil {
+			return "", err
+		}
+		return output.Results[0].OutputText, nil
+	default:
+		return "", fmt.Errorf("model %s not supported", a.model)
+	}
 }
-
 // GetName returns the name of the AmazonBedRockClient.
 func (a *AmazonBedRockClient) GetName() string {
 	return amazonbedrockAIClientName
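
Each model family also returns a differently shaped response body, which is why patch 2 replaces the single InvokeModelResponseBody struct with a per-model switch. The standalone sketch below decodes hand-written sample payloads that approximate each family's wire format; the sample JSON is illustrative, not captured Bedrock output.

// Standalone sketch (not part of the patch): decodes one sample response
// per model family using the same struct shapes the patch declares.
package main

import (
	"encoding/json"
	"fmt"
	"log"
)

func main() {
	// Anthropic Claude: completion text sits at the top level.
	claude := []byte(`{"completion":" restart the pod","stop_reason":"stop_sequence"}`)
	var c struct {
		Completion string `json:"completion"`
	}
	if err := json.Unmarshal(claude, &c); err != nil {
		log.Fatal(err)
	}
	fmt.Println("claude:", c.Completion)

	// AI21 Jurassic-2: text nested under completions[0].data.text.
	ai21 := []byte(`{"completions":[{"data":{"text":" restart the pod"}}]}`)
	var a struct {
		Completions []struct {
			Data struct {
				Text string `json:"text"`
			} `json:"data"`
		} `json:"completions"`
	}
	if err := json.Unmarshal(ai21, &a); err != nil {
		log.Fatal(err)
	}
	// Like the patch, this indexes element 0 and assumes the provider
	// returned at least one completion.
	fmt.Println("ai21:", a.Completions[0].Data.Text)

	// Amazon Titan: text nested under results[0].outputText.
	titan := []byte(`{"inputTextTokenCount":12,"results":[{"tokenCount":8,"outputText":" restart the pod","completionReason":"FINISH"}]}`)
	var t struct {
		Results []struct {
			OutputText string `json:"outputText"`
		} `json:"results"`
	}
	if err := json.Unmarshal(titan, &t); err != nil {
		log.Fatal(err)
	}
	fmt.Println("titan:", t.Results[0].OutputText)
}

Note that the patch indexes Completions[0] and Results[0] without a length check, so an empty array from the provider would panic; guarding that is a possible follow-up.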
From 57a7653abc6b569bde76c0993a3f82ecaa8f56a0 Mon Sep 17 00:00:00 2001
From: Yomesh Shah
Date: Fri, 10 May 2024 12:26:17 +0100
Subject: [PATCH 3/4] fix: constant for top_P as float vs string

Signed-off-by: Yomesh Shah
---
 pkg/ai/amazonbedrock.go | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pkg/ai/amazonbedrock.go b/pkg/ai/amazonbedrock.go
index f2853a7f7a..fff07609a7 100644
--- a/pkg/ai/amazonbedrock.go
+++ b/pkg/ai/amazonbedrock.go
@@ -60,7 +60,7 @@ var BEDROCK_MODELS = []string{
 	ModelAmazonTitanExpressV1,
 }
 
-const TOPP = "0.9"
+const TOPP = 0.9
 
 // GetModelOrDefault check config model
 func GetModelOrDefault(model string) string {

From a2502771a925d4b8eef25eb949f61945775fc547 Mon Sep 17 00:00:00 2001
From: Yomesh Shah
Date: Fri, 12 Jul 2024 22:04:37 +0100
Subject: [PATCH 4/4] feat: moved topP and maxTokens to config rather than
 being constants in the code

Signed-off-by: Yomesh Shah
---
 pkg/ai/amazonbedrock.go | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/pkg/ai/amazonbedrock.go b/pkg/ai/amazonbedrock.go
index fff07609a7..c7868c382b 100644
--- a/pkg/ai/amazonbedrock.go
+++ b/pkg/ai/amazonbedrock.go
@@ -20,6 +20,8 @@ type AmazonBedRockClient struct {
 	client      *bedrockruntime.BedrockRuntime
 	model       string
 	temperature float32
+	topP        float32
+	maxTokens   int
 }
 
 // Amazon BedRock support region list US East (N. Virginia),US West (Oregon),Asia Pacific (Singapore),Asia Pacific (Tokyo),Europe (Frankfurt)
@@ -60,7 +62,7 @@ var BEDROCK_MODELS = []string{
 	ModelAmazonTitanExpressV1,
 }
 
-const TOPP = 0.9
+//const TOPP = 0.9 moved to config
 
 // GetModelOrDefault check config model
 func GetModelOrDefault(model string) string {
@@ -111,6 +113,8 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
 	a.client = bedrockruntime.New(sess)
 	a.model = GetModelOrDefault(config.GetModel())
 	a.temperature = config.GetTemperature()
+	a.topP = config.GetTopP()
+	a.maxTokens = config.GetMaxTokens()
 	return nil
 }
 
@@ -124,24 +128,24 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string)
 	case ModelAnthropicClaudeV2, ModelAnthropicClaudeV1, ModelAnthropicClaudeInstantV1:
 		request = map[string]interface{}{
 			"prompt":               fmt.Sprintf("\n\nHuman: %s \n\nAssistant:", prompt),
-			"max_tokens_to_sample": 1024,
+			"max_tokens_to_sample": a.maxTokens,
 			"temperature":          a.temperature,
-			"top_p":                TOPP,
+			"top_p":                a.topP,
 		}
 	case ModelA21J2UltraV1, ModelA21J2JumboInstruct:
 		request = map[string]interface{}{
 			"prompt":      prompt,
-			"maxTokens":   2048,
+			"maxTokens":   a.maxTokens,
 			"temperature": a.temperature,
-			"topP":        TOPP,
+			"topP":        a.topP,
 		}
 	case ModelAmazonTitanExpressV1:
 		request = map[string]interface{}{
 			"inputText": fmt.Sprintf("\n\nUser: %s", prompt),
 			"textGenerationConfig": map[string]interface{}{
-				"maxTokenCount": 8000,
+				"maxTokenCount": a.maxTokens,
 				"temperature":   a.temperature,
-				"topP":          a.topP,
+				"topP":          a.topP,
 			},
 		}
 	default:
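
With patch 4, the sampling and length parameters come from the provider configuration instead of hard-coded constants: Configure now also reads config.GetTopP() and config.GetMaxTokens(). The standalone sketch below models that wiring with a hypothetical stub; only the getters the patch calls are modeled, and the real IAIConfig interface is defined elsewhere in the k8sgpt codebase.

// Standalone sketch (not part of the patch): a hypothetical stub standing in
// for IAIConfig, showing the values Configure now pulls from configuration.
package main

import "fmt"

type stubConfig struct {
	model       string
	temperature float32
	topP        float32
	maxTokens   int
}

func (s stubConfig) GetModel() string        { return s.model }
func (s stubConfig) GetTemperature() float32 { return s.temperature }
func (s stubConfig) GetTopP() float32        { return s.topP }
func (s stubConfig) GetMaxTokens() int       { return s.maxTokens }

func main() {
	cfg := stubConfig{
		model:       "amazon.titan-text-express-v1",
		temperature: 0.7,
		topP:        0.9,  // previously the hard-coded TOPP constant
		maxTokens:   1024, // previously 1024/2048/8000 hard-coded per model family
	}

	// Mirrors the new assignments in Configure:
	//   a.topP = config.GetTopP()
	//   a.maxTokens = config.GetMaxTokens()
	fmt.Printf("model=%s temperature=%.1f topP=%.1f maxTokens=%d\n",
		cfg.GetModel(), cfg.GetTemperature(), cfg.GetTopP(), cfg.GetMaxTokens())
}

One consequence of this change worth noting: a single configured maxTokens now applies to all three model families, replacing the per-family limits that patch 2 introduced.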