Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added support for A21 and Amazon Titan models via bedrock api #1101

Merged
merged 8 commits into from
Sep 17, 2024
44 changes: 37 additions & 7 deletions pkg/ai/amazonbedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,18 @@ const (
ModelAnthropicClaudeV2 = "anthropic.claude-v2"
ModelAnthropicClaudeV1 = "anthropic.claude-v1"
ModelAnthropicClaudeInstantV1 = "anthropic.claude-instant-v1"
ModelA21J2UltraV1 = "ai21.j2-ultra-v1"
ModelA21J2JumboInstruct = "ai21.j2-jumbo-instruct"
ModelAmazonTitanExpressV1 = "amazon.titan-text-express-v1"
)

var BEDROCK_MODELS = []string{
ModelAnthropicClaudeV2,
ModelAnthropicClaudeV1,
ModelAnthropicClaudeInstantV1,
ModelA21J2UltraV1,
ModelA21J2JumboInstruct,
ModelAmazonTitanExpressV1,
}

// GetModelOrDefault check config model
Expand Down Expand Up @@ -116,13 +122,37 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
// GetCompletion sends a request to the model for generating completion based on the provided prompt.
func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string) (string, error) {

// Prepare the input data for the model invocation
request := map[string]interface{}{
"prompt": fmt.Sprintf("\n\nHuman: %s \n\nAssistant:", prompt),
"max_tokens_to_sample": 1024,
"temperature": a.temperature,
"top_p": 0.9,
}
// Prepare the input data for the model invocation based on the model
var request map[string]interface{}

switch a.model {
case ModelAnthropicClaudeV2, ModelAnthropicClaudeV1, ModelAnthropicClaudeInstantV1:
request = map[string]interface{}{
"prompt": fmt.Sprintf("\n\nHuman: %s \n\nAssistant:", prompt),
awsyshah marked this conversation as resolved.
Show resolved Hide resolved
"max_tokens_to_sample": 1024,
awsyshah marked this conversation as resolved.
Show resolved Hide resolved
"temperature": a.temperature,
"top_p": 0.9,
}
case ModelA21J2UltraV1, ModelA21J2JumboInstruct:
request = map[string]interface{}{
"prompt": fmt.Sprintf("\n\nHuman: %s \n\nAssistant:", prompt),
"maxTokens": 1024,
"temperature": a.temperature,
"topP": 0.9,
}
case ModelAmazonTitanExpressV1:
request = map[string]interface{}{
"inputText": fmt.Sprintf("\n\nHuman: %s \n\nAssistant:", prompt),
"textGenerationConfig": map[string]interface{}{
"maxTokenCount": 1024,
"temperature": a.temperature,
"topP": 0.9,
},
}
default:
return "", fmt.Errorf("model %s not supported", a.model)
}


body, err := json.Marshal(request)
if err != nil {
Expand Down
Loading