Skip to content

Commit

Permalink
fix: add default maxToken value of watsonxai backend (#1209)
Browse files Browse the repository at this point in the history
Signed-off-by: yanweili <[email protected]>
Co-authored-by: yanweili <[email protected]>
  • Loading branch information
liyanwei93 and yanweili authored Aug 2, 2024
1 parent a068310 commit d43fd87
Showing 1 changed file with 16 additions and 12 deletions.
28 changes: 16 additions & 12 deletions pkg/ai/watsonxai.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package ai

import (
"os"
"fmt"
"context"
"errors"
"fmt"
"os"

wx "github.com/IBM/watsonx-go/pkg/models"
)
Expand All @@ -14,28 +14,33 @@ const watsonxAIClientName = "watsonxai"
type WatsonxAIClient struct {
nopCloser

client *wx.Client
model string
temperature float32
topP float32
topK int32
maxNewTokens int
client *wx.Client
model string
temperature float32
topP float32
topK int32
maxNewTokens int
}

const (
modelMetallama = "ibm/granite-13b-chat-v2"
maxTokens = 2048
)

func (c *WatsonxAIClient) Configure(config IAIConfig) error {
if(config.GetModel() == "") {
if config.GetModel() == "" {
c.model = modelMetallama
} else {
c.model = config.GetModel()
}
if config.GetMaxTokens() == 0 {
c.maxNewTokens = maxTokens
} else {
c.model = modelMetallama
c.maxNewTokens = config.GetMaxTokens()
}
c.temperature = config.GetTemperature()
c.topP = config.GetTopP()
c.topK = config.GetTopK()
c.maxNewTokens = config.GetMaxTokens()

// WatsonxAPIKeyEnvVarName = "WATSONX_API_KEY"
// WatsonxProjectIDEnvVarName = "WATSONX_PROJECT_ID"
Expand Down Expand Up @@ -75,7 +80,6 @@ func (c *WatsonxAIClient) GetCompletion(ctx context.Context, prompt string) (str
if result.Text == "" {
return "", errors.New("Expected a result, but got an empty string")
}

return result.Text, nil
}

Expand Down

0 comments on commit d43fd87

Please sign in to comment.