From 2fa925bded9e597aa6a97698ae568e54dc557de2 Mon Sep 17 00:00:00 2001 From: Guangya Liu Date: Tue, 20 Aug 2024 03:35:45 -0400 Subject: [PATCH] fix: enabled auth add support watsonx backend (#1190) Signed-off-by: Guangya Liu Signed-off-by: Alex Jones Co-authored-by: Alex Jones Co-authored-by: Matthis Signed-off-by: AlexsJones --- cmd/auth/add.go | 7 +++++-- pkg/ai/iai.go | 2 +- pkg/ai/watsonxai.go | 13 +++++-------- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/cmd/auth/add.go b/cmd/auth/add.go index 5464eadf9f..1fb19103d6 100644 --- a/cmd/auth/add.go +++ b/cmd/auth/add.go @@ -48,6 +48,9 @@ var addCmd = &cobra.Command{ if strings.ToLower(backend) == "amazonbedrock" { _ = cmd.MarkFlagRequired("providerRegion") } + if strings.ToLower(backend) == "watsonxai" { + _ = cmd.MarkFlagRequired("providerId") + } }, Run: func(cmd *cobra.Command, args []string) { @@ -173,8 +176,8 @@ func init() { addCmd.Flags().StringVarP(&engine, "engine", "e", "", "Azure AI deployment name (only for azureopenai backend)") //add flag for amazonbedrock region name addCmd.Flags().StringVarP(&providerRegion, "providerRegion", "r", "", "Provider Region name (only for amazonbedrock, googlevertexai backend)") - //add flag for vertexAI Project ID - addCmd.Flags().StringVarP(&providerId, "providerId", "i", "", "Provider specific ID for e.g. project (only for googlevertexai backend)") + //add flag for vertexAI/WatsonxAI Project ID + addCmd.Flags().StringVarP(&providerId, "providerId", "i", "", "Provider specific ID for e.g. project (only for googlevertexai/watsonxai backend)") //add flag for OCI Compartment ID addCmd.Flags().StringVarP(&compartmentId, "compartmentId", "k", "", "Compartment ID for generative AI model (only for oci backend)") // add flag for openai organization diff --git a/pkg/ai/iai.go b/pkg/ai/iai.go index 38c8500346..e1f1c41e10 100644 --- a/pkg/ai/iai.go +++ b/pkg/ai/iai.go @@ -181,7 +181,7 @@ func (p *AIProvider) GetCustomHeaders() []http.Header { return p.CustomHeaders } -var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci", "watsonxai"} +var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci"} func NeedPassword(backend string) bool { for _, b := range passwordlessProviders { diff --git a/pkg/ai/watsonxai.go b/pkg/ai/watsonxai.go index 655b8b82d2..15bbfdc922 100644 --- a/pkg/ai/watsonxai.go +++ b/pkg/ai/watsonxai.go @@ -4,8 +4,6 @@ import ( "context" "errors" "fmt" - "os" - wx "github.com/IBM/watsonx-go/pkg/models" ) @@ -42,20 +40,19 @@ func (c *WatsonxAIClient) Configure(config IAIConfig) error { c.topP = config.GetTopP() c.topK = config.GetTopK() - // WatsonxAPIKeyEnvVarName = "WATSONX_API_KEY" - // WatsonxProjectIDEnvVarName = "WATSONX_PROJECT_ID" - apiKey, projectID := os.Getenv(wx.WatsonxAPIKeyEnvVarName), os.Getenv(wx.WatsonxProjectIDEnvVarName) - + apiKey := config.GetPassword() if apiKey == "" { return errors.New("No watsonx API key provided") } - if projectID == "" { + + projectId := config.GetProviderId() + if projectId == "" { return errors.New("No watsonx project ID provided") } client, err := wx.NewClient( wx.WithWatsonxAPIKey(apiKey), - wx.WithWatsonxProjectID(projectID), + wx.WithWatsonxProjectID(projectId), ) if err != nil { return fmt.Errorf("Failed to create client for testing. Error: %v", err)