diff --git a/.github/workflows/build_container.yaml b/.github/workflows/build_container.yaml index bafaa76166..270ef54c34 100644 --- a/.github/workflows/build_container.yaml +++ b/.github/workflows/build_container.yaml @@ -74,10 +74,10 @@ jobs: - name: Set up Docker Buildx id: buildx - uses: docker/setup-buildx-action@d70bba72b1f3fd22344832f00baa16ece964efeb # v3 + uses: docker/setup-buildx-action@4fd812986e6c8c2a69e18311145f9371337f27d4 # v3 - name: Build Docker Image - uses: docker/build-push-action@2cdde995de11925a030ce8070c3d77a52ffcf1c0 # v5 + uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5 with: context: . platforms: linux/amd64 @@ -126,10 +126,10 @@ jobs: - name: Set up Docker Buildx id: buildx - uses: docker/setup-buildx-action@d70bba72b1f3fd22344832f00baa16ece964efeb # v3 + uses: docker/setup-buildx-action@4fd812986e6c8c2a69e18311145f9371337f27d4 # v3 - name: Build Docker Image - uses: docker/build-push-action@2cdde995de11925a030ce8070c3d77a52ffcf1c0 # v5 + uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5 with: context: . file: ./container/Dockerfile diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index a91b14d7de..d9f6cd1f3c 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -80,7 +80,7 @@ jobs: - name: Set up Docker Buildx id: buildx - uses: docker/setup-buildx-action@d70bba72b1f3fd22344832f00baa16ece964efeb # v3 + uses: docker/setup-buildx-action@4fd812986e6c8c2a69e18311145f9371337f27d4 # v3 - name: Login to GitHub Container Registry uses: docker/login-action@0d4c9c5ea7693da7b068278f7b52bda2a190a446 # v3 @@ -90,7 +90,7 @@ jobs: password: ${{ secrets.GITHUB_TOKEN }} - name: Build Docker Image - uses: docker/build-push-action@2cdde995de11925a030ce8070c3d77a52ffcf1c0 # v5 + uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5 with: context: . file: ./container/Dockerfile diff --git a/.github/workflows/semantic_pr.yaml b/.github/workflows/semantic_pr.yaml index 413d306a6a..375b0ac85f 100644 --- a/.github/workflows/semantic_pr.yaml +++ b/.github/workflows/semantic_pr.yaml @@ -16,7 +16,7 @@ jobs: pull-requests: read # Needed for reading prs steps: - name: Validate Pull Request - uses: amannn/action-semantic-pull-request@cfb60706e18bc85e8aec535e3c577abe8f70378e # v5.5.2 + uses: amannn/action-semantic-pull-request@0723387faaf9b38adef4775cd42cfd5155ed6017 # v5.5.3 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: diff --git a/README.md b/README.md index cf9b66b7d9..38cd565e87 100644 --- a/README.md +++ b/README.md @@ -8,12 +8,12 @@ ![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/k8sgpt-ai/k8sgpt/release.yaml) ![GitHub release (latest by date)](https://img.shields.io/github/v/release/k8sgpt-ai/k8sgpt) [![OpenSSF Best Practices](https://bestpractices.coreinfrastructure.org/projects/7272/badge)](https://bestpractices.coreinfrastructure.org/projects/7272) -[![Link to documentation](https://img.shields.io/static/v1?label=%F0%9F%93%96&message=Documentation&color=blue)](https://docs.k8sgpt.ai/) +[![Link to documentation](https://img.shields.io/static/v1?label=%F0%9F%93%96&message=Documentation&color=blue)](https://docs.k8sgpt.ai/) [![FOSSA Status](https://app.fossa.com/api/projects/git%2Bgithub.com%2Fk8sgpt-ai%2Fk8sgpt.svg?type=shield)](https://app.fossa.com/projects/git%2Bgithub.com%2Fk8sgpt-ai%2Fk8sgpt?ref=badge_shield) [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![Go version](https://img.shields.io/github/go-mod/go-version/k8sgpt-ai/k8sgpt.svg)](https://github.com/k8sgpt-ai/k8sgpt) [![codecov](https://codecov.io/github/k8sgpt-ai/k8sgpt/graph/badge.svg?token=ZLR7NG8URE)](https://codecov.io/github/k8sgpt-ai/k8sgpt) -![GitHub last commit (branch)](https://img.shields.io/github/last-commit/k8sgpt-ai/k8sgpt/main) +![GitHub last commit (branch)](https://img.shields.io/github/last-commit/k8sgpt-ai/k8sgpt/main) `k8sgpt` is a tool for scanning your Kubernetes clusters, diagnosing, and triaging issues in simple English. @@ -30,7 +30,13 @@ _Out of the box integration with OpenAI, Azure, Cohere, Amazon Bedrock, Google G ### Linux/Mac via brew +```sh +$ brew install k8sgpt ``` + +or + +```sh brew tap k8sgpt-ai/k8sgpt brew install k8sgpt ``` @@ -302,12 +308,13 @@ K8sGPT uses the chosen LLM, generative AI provider when you want to explain the You can list available providers using `k8sgpt auth list`: ``` -Default: +Default: > openai -Active: -Unused: +Active: +Unused: > openai > localai +> ollama > azureopenai > cohere > amazonbedrock @@ -316,6 +323,7 @@ Unused: > huggingface > noopai > googlevertexai +> watsonxai ``` For detailed documentation on how to configure and use each provider see [here](https://docs.k8sgpt.ai/reference/providers/backend/). @@ -425,7 +433,7 @@ Config file locations: There may be scenarios where caching remotely is preferred. In these scenarios K8sGPT supports AWS S3 or Azure Blob storage Integration. - Remote caching + Remote caching Note: You can only configure and use only one remote cache at a time _Adding a remote cache_ @@ -440,11 +448,11 @@ _Adding a remote cache_ * We support a number of [techniques](https://learn.microsoft.com/en-us/azure/developer/go/azure-sdk-authentication?tabs=bash#2-authenticate-with-azure) to authenticate against Azure * Configuration, ``` k8sgpt cache add azure --storageacc --container ``` * K8sGPT assumes that the storage account already exist and it will create the container if it does not exist - * It is the **user** responsibility have to grant specific permissions to their identity in order to be able to upload blob files and create SA containers (e.g Storage Blob Data Contributor) + * It is the **user** responsibility have to grant specific permissions to their identity in order to be able to upload blob files and create SA containers (e.g Storage Blob Data Contributor) * Google Cloud Storage * _As a prerequisite `GOOGLE_APPLICATION_CREDENTIALS` are required as environmental variables._ * Configuration, ``` k8sgpt cache add gcs --region --bucket --projectid ``` - * K8sGPT will create the bucket if it does not exist + * K8sGPT will create the bucket if it does not exist _Listing cache items_ ``` diff --git a/go.mod b/go.mod index 533075d733..8c08970d2b 100644 --- a/go.mod +++ b/go.mod @@ -10,9 +10,10 @@ require ( github.com/kedacore/keda/v2 v2.11.2 github.com/magiconair/properties v1.8.7 github.com/mittwald/go-helm-client v0.12.9 + github.com/ollama/ollama v0.1.48 github.com/sashabaranov/go-openai v1.23.0 github.com/schollz/progressbar/v3 v3.14.2 - github.com/spf13/cobra v1.8.0 + github.com/spf13/cobra v1.8.1 github.com/spf13/viper v1.18.2 github.com/stretchr/testify v1.9.0 golang.org/x/term v0.21.0 @@ -34,6 +35,7 @@ require ( cloud.google.com/go/vertexai v0.7.1 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.6.0 github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.3.2 + github.com/IBM/watsonx-go v1.0.0 github.com/aws/aws-sdk-go v1.53.21 github.com/cohere-ai/cohere-go/v2 v2.7.3 github.com/google/generative-ai-go v0.11.0 @@ -196,7 +198,7 @@ require ( github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.0 // indirect - github.com/pelletier/go-toml/v2 v2.1.0 // indirect + github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/peterbourgon/diskv v2.0.1+incompatible // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect diff --git a/go.sum b/go.sum index 9c702a3541..8159465c82 100644 --- a/go.sum +++ b/go.sum @@ -1245,6 +1245,8 @@ github.com/Code-Hex/go-generics-cache v1.3.1 h1:i8rLwyhoyhaerr7JpjtYjJZUcCbWOdiY github.com/Code-Hex/go-generics-cache v1.3.1/go.mod h1:qxcC9kRVrct9rHeiYpFWSoW1vxyillCVzX13KZG8dl4= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= +github.com/IBM/watsonx-go v1.0.0 h1:xG7xA2W9N0RsiztR26dwBI8/VxIX4wTBhdYmEis2Yl8= +github.com/IBM/watsonx-go v1.0.0/go.mod h1:8lzvpe/158JkrzvcoIcIj6OdNty5iC9co5nQHfkhRtM= github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c/go.mod h1:X0CRv0ky0k6m906ixxpzmDRLvX58TFUKS2eePweuyxk= github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ= github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE= @@ -1395,7 +1397,7 @@ github.com/containerd/errdefs v0.1.0 h1:m0wCRBiu1WJT/Fr+iOoQHMQS/eP5myQ8lCv4Dz5Z github.com/containerd/errdefs v0.1.0/go.mod h1:YgWiiHtLmSeBrvpw+UfPijzbLaB77mEG1WwJTDETIV0= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= -github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= @@ -1806,8 +1808,8 @@ github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ib github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.0.10/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= -github.com/klauspost/cpuid/v2 v2.2.3 h1:sxCkb+qR91z4vsqw4vGGZlDgPz3G7gjaLyK3V8y70BU= -github.com/klauspost/cpuid/v2 v2.2.3/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= +github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= +github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/kolo/xmlrpc v0.0.0-20220921171641-a4b6fa1dd06b h1:udzkj9S/zlT5X367kqJis0QP7YMxobob6zhzq6Yre00= github.com/kolo/xmlrpc v0.0.0-20220921171641-a4b6fa1dd06b/go.mod h1:pcaDhQK0/NJZEvtCO0qQPPropqV0sJOJ6YW7X+9kRwM= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= @@ -1917,6 +1919,8 @@ github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f h1:y5//uYreIhSUg3J github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= +github.com/ollama/ollama v0.1.48 h1:6j9ziyqyAJD3NuTLYJlzZl/b/Q9PDvvYg02FGmREUmE= +github.com/ollama/ollama v0.1.48/go.mod h1:TvVa25PEZI6M0bosiW1sa2XJGq3Xw/OPlpUAkMEntTU= github.com/onsi/ginkgo/v2 v2.17.2 h1:7eMhcy3GimbsA3hEnVKdw/PQM9XN9krpKVXsZdph0/g= github.com/onsi/ginkgo/v2 v2.17.2/go.mod h1:nP2DPOQoNsQmsVyv5rDA8JkXQoCs6goXIvr/PRJ1eCc= github.com/onsi/gomega v1.33.1 h1:dsYjIxxSR755MDmKVsaFQTE22ChNBcuuTWgkUDSubOk= @@ -1931,8 +1935,8 @@ github.com/ovh/go-ovh v1.4.3 h1:Gs3V823zwTFpzgGLZNI6ILS4rmxZgJwJCz54Er9LwD0= github.com/ovh/go-ovh v1.4.3/go.mod h1:AkPXVtgwB6xlKblMjRKJJmjRp+ogrE7fz2lVgcQY8SY= github.com/owenrumney/squealer v1.2.1 h1:4ryMMT59aaz8VMsqsD+FDkarADJz0F1dcq2fd0DRR+c= github.com/owenrumney/squealer v1.2.1/go.mod h1:7D0a/+Bouwy504YhaWsBYW73kyklSEq1MNf6zsNoTRg= -github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4= -github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= +github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= +github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/peterbourgon/diskv v2.0.1+incompatible h1:UBdAOUP5p4RWqPBg048CAvpKN+vxiaj6gdUUzhl4XmI= github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU= github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 h1:Ii+DKncOVM8Cu1Hc+ETb5K+23HdAMvESYE3ZJ5b5cMI= @@ -2059,8 +2063,8 @@ github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNo github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= -github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= -github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= +github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= +github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= diff --git a/pkg/ai/iai.go b/pkg/ai/iai.go index 08caa02079..d158882b13 100644 --- a/pkg/ai/iai.go +++ b/pkg/ai/iai.go @@ -22,6 +22,7 @@ var ( &OpenAIClient{}, &AzureAIClient{}, &LocalAIClient{}, + &OllamaClient{}, &NoOpAIClient{}, &CohereClient{}, &AmazonBedRockClient{}, @@ -30,10 +31,12 @@ var ( &HuggingfaceClient{}, &GoogleVertexAIClient{}, &OCIGenAIClient{}, + &WatsonxAIClient{}, } Backends = []string{ openAIClientName, localAIClientName, + ollamaClientName, azureAIClientName, cohereAIClientName, amazonbedrockAIClientName, @@ -43,6 +46,7 @@ var ( huggingfaceAIClientName, googleVertexAIClientName, ociClientName, + watsonxAIClientName, } ) @@ -170,7 +174,7 @@ func (p *AIProvider) GetOrganizationId() string { return p.OrganizationId } -var passwordlessProviders = []string{"localai", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci"} +var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci", "watsonxai"} func NeedPassword(backend string) bool { for _, b := range passwordlessProviders { diff --git a/pkg/ai/ollama.go b/pkg/ai/ollama.go new file mode 100644 index 0000000000..098c455f22 --- /dev/null +++ b/pkg/ai/ollama.go @@ -0,0 +1,102 @@ +/* +Copyright 2023 The K8sGPT Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ai + +import ( + "context" + "errors" + "net/http" + "net/url" + + ollama "github.com/ollama/ollama/api" +) + +const ollamaClientName = "ollama" + +type OllamaClient struct { + nopCloser + + client *ollama.Client + model string + temperature float32 + topP float32 +} + +const ( + defaultBaseURL = "http://localhost:11434" + defaultModel = "llama3" +) + +func (c *OllamaClient) Configure(config IAIConfig) error { + baseURL := config.GetBaseURL() + if baseURL == "" { + baseURL = defaultBaseURL + } + baseClientURL, err := url.Parse(baseURL) + if err != nil { + return err + } + + proxyEndpoint := config.GetProxyEndpoint() + httpClient := http.DefaultClient + if proxyEndpoint != "" { + proxyUrl, err := url.Parse(proxyEndpoint) + if err != nil { + return err + } + transport := &http.Transport{ + Proxy: http.ProxyURL(proxyUrl), + } + + httpClient = &http.Client{ + Transport: transport, + } + } + + c.client = ollama.NewClient(baseClientURL, httpClient) + if c.client == nil { + return errors.New("error creating Ollama client") + } + c.model = config.GetModel() + if c.model == "" { + c.model = defaultModel + } + c.temperature = config.GetTemperature() + c.topP = config.GetTopP() + return nil +} +func (c *OllamaClient) GetCompletion(ctx context.Context, prompt string) (string, error) { + req := &ollama.GenerateRequest{ + Model: c.model, + Prompt: prompt, + Stream: new(bool), + Options: map[string]interface{}{ + "temperature": c.temperature, + "top_p": c.topP, + }, + } + completion := "" + respFunc := func(resp ollama.GenerateResponse) error { + completion = resp.Response + return nil + } + err := c.client.Generate(ctx, req, respFunc) + if err != nil { + return "", err + } + return completion, nil +} +func (a *OllamaClient) GetName() string { + return ollamaClientName +} diff --git a/pkg/ai/watsonxai.go b/pkg/ai/watsonxai.go new file mode 100644 index 0000000000..f6ce81c1a6 --- /dev/null +++ b/pkg/ai/watsonxai.go @@ -0,0 +1,84 @@ +package ai + +import ( + "os" + "fmt" + "context" + "errors" + + wx "github.com/IBM/watsonx-go/pkg/models" +) + +const watsonxAIClientName = "watsonxai" + +type WatsonxAIClient struct { + nopCloser + + client *wx.Client + model string + temperature float32 + topP float32 + topK int32 + maxNewTokens int +} + +const ( + modelMetallama = "ibm/granite-13b-chat-v2" +) + +func (c *WatsonxAIClient) Configure(config IAIConfig) error { + if(config.GetModel() == "") { + c.model = config.GetModel() + } else { + c.model = modelMetallama + } + c.temperature = config.GetTemperature() + c.topP = config.GetTopP() + c.topK = config.GetTopK() + c.maxNewTokens = config.GetMaxTokens() + + // WatsonxAPIKeyEnvVarName = "WATSONX_API_KEY" + // WatsonxProjectIDEnvVarName = "WATSONX_PROJECT_ID" + apiKey, projectID := os.Getenv(wx.WatsonxAPIKeyEnvVarName), os.Getenv(wx.WatsonxProjectIDEnvVarName) + + if apiKey == "" { + return errors.New("No watsonx API key provided") + } + if projectID == "" { + return errors.New("No watsonx project ID provided") + } + + client, err := wx.NewClient( + wx.WithWatsonxAPIKey(apiKey), + wx.WithWatsonxProjectID(projectID), + ) + if err != nil { + return fmt.Errorf("Failed to create client for testing. Error: %v", err) + } + c.client = client + + return nil +} + +func (c *WatsonxAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) { + result, err := c.client.GenerateText( + c.model, + prompt, + wx.WithTemperature((float64)(c.temperature)), + wx.WithTopP((float64)(c.topP)), + wx.WithTopK((uint)(c.topK)), + wx.WithMaxNewTokens((uint)(c.maxNewTokens)), + ) + if err != nil { + return "", fmt.Errorf("Expected no error, but got an error: %v", err) + } + if result.Text == "" { + return "", errors.New("Expected a result, but got an empty string") + } + + return result.Text, nil +} + +func (c *WatsonxAIClient) GetName() string { + return watsonxAIClientName +}