From 7a4d9f868d01050485cea42a3dde8f20ede4c8f6 Mon Sep 17 00:00:00 2001 From: Guangya Liu Date: Thu, 11 Jul 2024 14:41:23 -0400 Subject: [PATCH] Enabled auth add support watsonx backend Signed-off-by: Guangya Liu --- cmd/auth/add.go | 6 ++++++ cmd/auth/auth.go | 1 + pkg/ai/iai.go | 8 +++++++- pkg/ai/watsonxai.go | 28 +++++++++++++--------------- 4 files changed, 27 insertions(+), 16 deletions(-) diff --git a/cmd/auth/add.go b/cmd/auth/add.go index 5464eadf9f..660d7977d4 100644 --- a/cmd/auth/add.go +++ b/cmd/auth/add.go @@ -48,6 +48,9 @@ var addCmd = &cobra.Command{ if strings.ToLower(backend) == "amazonbedrock" { _ = cmd.MarkFlagRequired("providerRegion") } + if strings.ToLower(backend) == "watsonxai" { + _ = cmd.MarkFlagRequired("projectId") + } }, Run: func(cmd *cobra.Command, args []string) { @@ -132,6 +135,7 @@ var addCmd = &cobra.Command{ TopK: topK, MaxTokens: maxTokens, OrganizationId: organizationId, + ProjectID: projectId, } if providerIndex == -1 { @@ -179,4 +183,6 @@ func init() { addCmd.Flags().StringVarP(&compartmentId, "compartmentId", "k", "", "Compartment ID for generative AI model (only for oci backend)") // add flag for openai organization addCmd.Flags().StringVarP(&organizationId, "organizationId", "o", "", "OpenAI or AzureOpenAI Organization ID (only for openai and azureopenai backend)") + // add flag for IBM Watsonx Project ID + addCmd.Flags().StringVarP(&projectId, "projectId", "j", "", "IBM Watsonx Project ID (only for watsonxai backend)") } diff --git a/cmd/auth/auth.go b/cmd/auth/auth.go index c8f4e209e9..ed44dff5e1 100644 --- a/cmd/auth/auth.go +++ b/cmd/auth/auth.go @@ -33,6 +33,7 @@ var ( topK int32 maxTokens int organizationId string + projectId string ) var configAI ai.AIConfiguration diff --git a/pkg/ai/iai.go b/pkg/ai/iai.go index 38c8500346..f7fd128b31 100644 --- a/pkg/ai/iai.go +++ b/pkg/ai/iai.go @@ -84,6 +84,7 @@ type IAIConfig interface { GetProviderId() string GetCompartmentId() string GetOrganizationId() string + GetProjectId() string GetCustomHeaders() []http.Header } @@ -119,6 +120,7 @@ type AIProvider struct { TopK int32 `mapstructure:"topk" yaml:"topk,omitempty"` MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"` OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty"` + ProjectID string `mapstructure:"projectid" yaml:"projectid,omitempty"` CustomHeaders []http.Header `mapstructure:"customHeaders"` } @@ -177,11 +179,15 @@ func (p *AIProvider) GetOrganizationId() string { return p.OrganizationId } +func (p *AIProvider) GetProjectId() string { + return p.ProjectID +} + func (p *AIProvider) GetCustomHeaders() []http.Header { return p.CustomHeaders } -var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci", "watsonxai"} +var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci"} func NeedPassword(backend string) bool { for _, b := range passwordlessProviders { diff --git a/pkg/ai/watsonxai.go b/pkg/ai/watsonxai.go index f6ce81c1a6..9134b7b6e2 100644 --- a/pkg/ai/watsonxai.go +++ b/pkg/ai/watsonxai.go @@ -1,10 +1,9 @@ package ai import ( - "os" - "fmt" "context" "errors" + "fmt" wx "github.com/IBM/watsonx-go/pkg/models" ) @@ -14,12 +13,12 @@ const watsonxAIClientName = "watsonxai" type WatsonxAIClient struct { nopCloser - client *wx.Client - model string - temperature float32 - topP float32 - topK int32 - maxNewTokens int + client *wx.Client + model string + temperature float32 + topP float32 + topK int32 + maxNewTokens int } const ( @@ -27,7 +26,7 @@ const ( ) func (c *WatsonxAIClient) Configure(config IAIConfig) error { - if(config.GetModel() == "") { + if config.GetModel() == "" { c.model = config.GetModel() } else { c.model = modelMetallama @@ -37,20 +36,19 @@ func (c *WatsonxAIClient) Configure(config IAIConfig) error { c.topK = config.GetTopK() c.maxNewTokens = config.GetMaxTokens() - // WatsonxAPIKeyEnvVarName = "WATSONX_API_KEY" - // WatsonxProjectIDEnvVarName = "WATSONX_PROJECT_ID" - apiKey, projectID := os.Getenv(wx.WatsonxAPIKeyEnvVarName), os.Getenv(wx.WatsonxProjectIDEnvVarName) - + apiKey := config.GetPassword() if apiKey == "" { return errors.New("No watsonx API key provided") } - if projectID == "" { + + projectId := config.GetProjectId() + if projectId == "" { return errors.New("No watsonx project ID provided") } client, err := wx.NewClient( wx.WithWatsonxAPIKey(apiKey), - wx.WithWatsonxProjectID(projectID), + wx.WithWatsonxProjectID(projectId), ) if err != nil { return fmt.Errorf("Failed to create client for testing. Error: %v", err)