Skip to content

Commit

Permalink
feat: added support for A21 and Amazon Titan models via bedrock api
Browse files Browse the repository at this point in the history
Signed-off-by: Yomesh Shah <[email protected]>
  • Loading branch information
awsyshah committed May 8, 2024
1 parent 3c48231 commit ae1f3ec
Showing 1 changed file with 37 additions and 7 deletions.
44 changes: 37 additions & 7 deletions pkg/ai/amazonbedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,18 @@ const (
ModelAnthropicClaudeV2 = "anthropic.claude-v2"
ModelAnthropicClaudeV1 = "anthropic.claude-v1"
ModelAnthropicClaudeInstantV1 = "anthropic.claude-instant-v1"
ModelA21J2UltraV1 = "ai21.j2-ultra-v1"
ModelA21J2JumboInstruct = "ai21.j2-jumbo-instruct"
ModelAmazonTitanExpressV1 = "amazon.titan-text-express-v1"
)

var BEDROCK_MODELS = []string{
ModelAnthropicClaudeV2,
ModelAnthropicClaudeV1,
ModelAnthropicClaudeInstantV1,
ModelA21J2UltraV1,
ModelA21J2JumboInstruct,
ModelAmazonTitanExpressV1,
}

// GetModelOrDefault check config model
Expand Down Expand Up @@ -116,13 +122,37 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
// GetCompletion sends a request to the model for generating completion based on the provided prompt.
func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string) (string, error) {

// Prepare the input data for the model invocation
request := map[string]interface{}{
"prompt": fmt.Sprintf("\n\nHuman: %s \n\nAssistant:", prompt),
"max_tokens_to_sample": 1024,
"temperature": a.temperature,
"top_p": 0.9,
}
// Prepare the input data for the model invocation based on the model
var request map[string]interface{}

switch a.model {
case ModelAnthropicClaudeV2, ModelAnthropicClaudeV1, ModelAnthropicClaudeInstantV1:
request = map[string]interface{}{
"prompt": fmt.Sprintf("\n\nHuman: %s \n\nAssistant:", prompt),
"max_tokens_to_sample": 1024,
"temperature": a.temperature,
"top_p": 0.9,
}
case ModelA21J2UltraV1, ModelA21J2JumboInstruct:
request = map[string]interface{}{
"prompt": fmt.Sprintf("\n\nHuman: %s \n\nAssistant:", prompt),
"maxTokens": 1024,
"temperature": a.temperature,
"topP": 0.9,
}
case ModelAmazonTitanExpressV1:
request = map[string]interface{}{
"inputText": fmt.Sprintf("\n\nHuman: %s \n\nAssistant:", prompt),
"textGenerationConfig": map[string]interface{}{
"maxTokenCount": 1024,
"temperature": a.temperature,
"topP": 0.9,
},
}
default:
return "", fmt.Errorf("model %s not supported", a.model)
}


body, err := json.Marshal(request)
if err != nil {
Expand Down

0 comments on commit ae1f3ec

Please sign in to comment.