diff --git a/.github/workflows/build_container.yaml b/.github/workflows/build_container.yaml
index bafaa76166..270ef54c34 100644
--- a/.github/workflows/build_container.yaml
+++ b/.github/workflows/build_container.yaml
@@ -74,10 +74,10 @@ jobs:
- name: Set up Docker Buildx
id: buildx
- uses: docker/setup-buildx-action@d70bba72b1f3fd22344832f00baa16ece964efeb # v3
+ uses: docker/setup-buildx-action@4fd812986e6c8c2a69e18311145f9371337f27d4 # v3
- name: Build Docker Image
- uses: docker/build-push-action@2cdde995de11925a030ce8070c3d77a52ffcf1c0 # v5
+ uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5
with:
context: .
platforms: linux/amd64
@@ -126,10 +126,10 @@ jobs:
- name: Set up Docker Buildx
id: buildx
- uses: docker/setup-buildx-action@d70bba72b1f3fd22344832f00baa16ece964efeb # v3
+ uses: docker/setup-buildx-action@4fd812986e6c8c2a69e18311145f9371337f27d4 # v3
- name: Build Docker Image
- uses: docker/build-push-action@2cdde995de11925a030ce8070c3d77a52ffcf1c0 # v5
+ uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5
with:
context: .
file: ./container/Dockerfile
diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml
index a91b14d7de..d9f6cd1f3c 100644
--- a/.github/workflows/release.yaml
+++ b/.github/workflows/release.yaml
@@ -80,7 +80,7 @@ jobs:
- name: Set up Docker Buildx
id: buildx
- uses: docker/setup-buildx-action@d70bba72b1f3fd22344832f00baa16ece964efeb # v3
+ uses: docker/setup-buildx-action@4fd812986e6c8c2a69e18311145f9371337f27d4 # v3
- name: Login to GitHub Container Registry
uses: docker/login-action@0d4c9c5ea7693da7b068278f7b52bda2a190a446 # v3
@@ -90,7 +90,7 @@ jobs:
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build Docker Image
- uses: docker/build-push-action@2cdde995de11925a030ce8070c3d77a52ffcf1c0 # v5
+ uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5
with:
context: .
file: ./container/Dockerfile
diff --git a/.github/workflows/semantic_pr.yaml b/.github/workflows/semantic_pr.yaml
index 413d306a6a..375b0ac85f 100644
--- a/.github/workflows/semantic_pr.yaml
+++ b/.github/workflows/semantic_pr.yaml
@@ -16,7 +16,7 @@ jobs:
pull-requests: read # Needed for reading prs
steps:
- name: Validate Pull Request
- uses: amannn/action-semantic-pull-request@cfb60706e18bc85e8aec535e3c577abe8f70378e # v5.5.2
+ uses: amannn/action-semantic-pull-request@0723387faaf9b38adef4775cd42cfd5155ed6017 # v5.5.3
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
diff --git a/README.md b/README.md
index cf9b66b7d9..38cd565e87 100644
--- a/README.md
+++ b/README.md
@@ -8,12 +8,12 @@
![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/k8sgpt-ai/k8sgpt/release.yaml)
![GitHub release (latest by date)](https://img.shields.io/github/v/release/k8sgpt-ai/k8sgpt)
[![OpenSSF Best Practices](https://bestpractices.coreinfrastructure.org/projects/7272/badge)](https://bestpractices.coreinfrastructure.org/projects/7272)
-[![Link to documentation](https://img.shields.io/static/v1?label=%F0%9F%93%96&message=Documentation&color=blue)](https://docs.k8sgpt.ai/)
+[![Link to documentation](https://img.shields.io/static/v1?label=%F0%9F%93%96&message=Documentation&color=blue)](https://docs.k8sgpt.ai/)
[![FOSSA Status](https://app.fossa.com/api/projects/git%2Bgithub.com%2Fk8sgpt-ai%2Fk8sgpt.svg?type=shield)](https://app.fossa.com/projects/git%2Bgithub.com%2Fk8sgpt-ai%2Fk8sgpt?ref=badge_shield)
[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
[![Go version](https://img.shields.io/github/go-mod/go-version/k8sgpt-ai/k8sgpt.svg)](https://github.com/k8sgpt-ai/k8sgpt)
[![codecov](https://codecov.io/github/k8sgpt-ai/k8sgpt/graph/badge.svg?token=ZLR7NG8URE)](https://codecov.io/github/k8sgpt-ai/k8sgpt)
-![GitHub last commit (branch)](https://img.shields.io/github/last-commit/k8sgpt-ai/k8sgpt/main)
+![GitHub last commit (branch)](https://img.shields.io/github/last-commit/k8sgpt-ai/k8sgpt/main)
`k8sgpt` is a tool for scanning your Kubernetes clusters, diagnosing, and triaging issues in simple English.
@@ -30,7 +30,13 @@ _Out of the box integration with OpenAI, Azure, Cohere, Amazon Bedrock, Google G
### Linux/Mac via brew
+```sh
+$ brew install k8sgpt
```
+
+or
+
+```sh
brew tap k8sgpt-ai/k8sgpt
brew install k8sgpt
```
@@ -302,12 +308,13 @@ K8sGPT uses the chosen LLM, generative AI provider when you want to explain the
You can list available providers using `k8sgpt auth list`:
```
-Default:
+Default:
> openai
-Active:
-Unused:
+Active:
+Unused:
> openai
> localai
+> ollama
> azureopenai
> cohere
> amazonbedrock
@@ -316,6 +323,7 @@ Unused:
> huggingface
> noopai
> googlevertexai
+> watsonxai
```
For detailed documentation on how to configure and use each provider see [here](https://docs.k8sgpt.ai/reference/providers/backend/).
@@ -425,7 +433,7 @@ Config file locations:
There may be scenarios where caching remotely is preferred.
In these scenarios K8sGPT supports AWS S3 or Azure Blob storage Integration.
- Remote caching
+ Remote caching
Note: You can only configure and use only one remote cache at a time
_Adding a remote cache_
@@ -440,11 +448,11 @@ _Adding a remote cache_
* We support a number of [techniques](https://learn.microsoft.com/en-us/azure/developer/go/azure-sdk-authentication?tabs=bash#2-authenticate-with-azure) to authenticate against Azure
* Configuration, ``` k8sgpt cache add azure --storageacc --container ```
* K8sGPT assumes that the storage account already exist and it will create the container if it does not exist
- * It is the **user** responsibility have to grant specific permissions to their identity in order to be able to upload blob files and create SA containers (e.g Storage Blob Data Contributor)
+ * It is the **user** responsibility have to grant specific permissions to their identity in order to be able to upload blob files and create SA containers (e.g Storage Blob Data Contributor)
* Google Cloud Storage
* _As a prerequisite `GOOGLE_APPLICATION_CREDENTIALS` are required as environmental variables._
* Configuration, ``` k8sgpt cache add gcs --region --bucket --projectid ```
- * K8sGPT will create the bucket if it does not exist
+ * K8sGPT will create the bucket if it does not exist
_Listing cache items_
```
diff --git a/go.mod b/go.mod
index 533075d733..8c08970d2b 100644
--- a/go.mod
+++ b/go.mod
@@ -10,9 +10,10 @@ require (
github.com/kedacore/keda/v2 v2.11.2
github.com/magiconair/properties v1.8.7
github.com/mittwald/go-helm-client v0.12.9
+ github.com/ollama/ollama v0.1.48
github.com/sashabaranov/go-openai v1.23.0
github.com/schollz/progressbar/v3 v3.14.2
- github.com/spf13/cobra v1.8.0
+ github.com/spf13/cobra v1.8.1
github.com/spf13/viper v1.18.2
github.com/stretchr/testify v1.9.0
golang.org/x/term v0.21.0
@@ -34,6 +35,7 @@ require (
cloud.google.com/go/vertexai v0.7.1
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.6.0
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.3.2
+ github.com/IBM/watsonx-go v1.0.0
github.com/aws/aws-sdk-go v1.53.21
github.com/cohere-ai/cohere-go/v2 v2.7.3
github.com/google/generative-ai-go v0.11.0
@@ -196,7 +198,7 @@ require (
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect
- github.com/pelletier/go-toml/v2 v2.1.0 // indirect
+ github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/peterbourgon/diskv v2.0.1+incompatible // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
diff --git a/go.sum b/go.sum
index 9c702a3541..8159465c82 100644
--- a/go.sum
+++ b/go.sum
@@ -1245,6 +1245,8 @@ github.com/Code-Hex/go-generics-cache v1.3.1 h1:i8rLwyhoyhaerr7JpjtYjJZUcCbWOdiY
github.com/Code-Hex/go-generics-cache v1.3.1/go.mod h1:qxcC9kRVrct9rHeiYpFWSoW1vxyillCVzX13KZG8dl4=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
+github.com/IBM/watsonx-go v1.0.0 h1:xG7xA2W9N0RsiztR26dwBI8/VxIX4wTBhdYmEis2Yl8=
+github.com/IBM/watsonx-go v1.0.0/go.mod h1:8lzvpe/158JkrzvcoIcIj6OdNty5iC9co5nQHfkhRtM=
github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c/go.mod h1:X0CRv0ky0k6m906ixxpzmDRLvX58TFUKS2eePweuyxk=
github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ=
github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE=
@@ -1395,7 +1397,7 @@ github.com/containerd/errdefs v0.1.0 h1:m0wCRBiu1WJT/Fr+iOoQHMQS/eP5myQ8lCv4Dz5Z
github.com/containerd/errdefs v0.1.0/go.mod h1:YgWiiHtLmSeBrvpw+UfPijzbLaB77mEG1WwJTDETIV0=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
-github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
+github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
@@ -1806,8 +1808,8 @@ github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ib
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.0.10/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c=
github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c=
-github.com/klauspost/cpuid/v2 v2.2.3 h1:sxCkb+qR91z4vsqw4vGGZlDgPz3G7gjaLyK3V8y70BU=
-github.com/klauspost/cpuid/v2 v2.2.3/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
+github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
+github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/kolo/xmlrpc v0.0.0-20220921171641-a4b6fa1dd06b h1:udzkj9S/zlT5X367kqJis0QP7YMxobob6zhzq6Yre00=
github.com/kolo/xmlrpc v0.0.0-20220921171641-a4b6fa1dd06b/go.mod h1:pcaDhQK0/NJZEvtCO0qQPPropqV0sJOJ6YW7X+9kRwM=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
@@ -1917,6 +1919,8 @@ github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f h1:y5//uYreIhSUg3J
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
+github.com/ollama/ollama v0.1.48 h1:6j9ziyqyAJD3NuTLYJlzZl/b/Q9PDvvYg02FGmREUmE=
+github.com/ollama/ollama v0.1.48/go.mod h1:TvVa25PEZI6M0bosiW1sa2XJGq3Xw/OPlpUAkMEntTU=
github.com/onsi/ginkgo/v2 v2.17.2 h1:7eMhcy3GimbsA3hEnVKdw/PQM9XN9krpKVXsZdph0/g=
github.com/onsi/ginkgo/v2 v2.17.2/go.mod h1:nP2DPOQoNsQmsVyv5rDA8JkXQoCs6goXIvr/PRJ1eCc=
github.com/onsi/gomega v1.33.1 h1:dsYjIxxSR755MDmKVsaFQTE22ChNBcuuTWgkUDSubOk=
@@ -1931,8 +1935,8 @@ github.com/ovh/go-ovh v1.4.3 h1:Gs3V823zwTFpzgGLZNI6ILS4rmxZgJwJCz54Er9LwD0=
github.com/ovh/go-ovh v1.4.3/go.mod h1:AkPXVtgwB6xlKblMjRKJJmjRp+ogrE7fz2lVgcQY8SY=
github.com/owenrumney/squealer v1.2.1 h1:4ryMMT59aaz8VMsqsD+FDkarADJz0F1dcq2fd0DRR+c=
github.com/owenrumney/squealer v1.2.1/go.mod h1:7D0a/+Bouwy504YhaWsBYW73kyklSEq1MNf6zsNoTRg=
-github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4=
-github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
+github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
+github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/peterbourgon/diskv v2.0.1+incompatible h1:UBdAOUP5p4RWqPBg048CAvpKN+vxiaj6gdUUzhl4XmI=
github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU=
github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 h1:Ii+DKncOVM8Cu1Hc+ETb5K+23HdAMvESYE3ZJ5b5cMI=
@@ -2059,8 +2063,8 @@ github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNo
github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
-github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0=
-github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho=
+github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
+github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
diff --git a/pkg/ai/iai.go b/pkg/ai/iai.go
index 08caa02079..d158882b13 100644
--- a/pkg/ai/iai.go
+++ b/pkg/ai/iai.go
@@ -22,6 +22,7 @@ var (
&OpenAIClient{},
&AzureAIClient{},
&LocalAIClient{},
+ &OllamaClient{},
&NoOpAIClient{},
&CohereClient{},
&AmazonBedRockClient{},
@@ -30,10 +31,12 @@ var (
&HuggingfaceClient{},
&GoogleVertexAIClient{},
&OCIGenAIClient{},
+ &WatsonxAIClient{},
}
Backends = []string{
openAIClientName,
localAIClientName,
+ ollamaClientName,
azureAIClientName,
cohereAIClientName,
amazonbedrockAIClientName,
@@ -43,6 +46,7 @@ var (
huggingfaceAIClientName,
googleVertexAIClientName,
ociClientName,
+ watsonxAIClientName,
}
)
@@ -170,7 +174,7 @@ func (p *AIProvider) GetOrganizationId() string {
return p.OrganizationId
}
-var passwordlessProviders = []string{"localai", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci"}
+var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci", "watsonxai"}
func NeedPassword(backend string) bool {
for _, b := range passwordlessProviders {
diff --git a/pkg/ai/ollama.go b/pkg/ai/ollama.go
new file mode 100644
index 0000000000..098c455f22
--- /dev/null
+++ b/pkg/ai/ollama.go
@@ -0,0 +1,102 @@
+/*
+Copyright 2023 The K8sGPT Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package ai
+
+import (
+ "context"
+ "errors"
+ "net/http"
+ "net/url"
+
+ ollama "github.com/ollama/ollama/api"
+)
+
+const ollamaClientName = "ollama"
+
+type OllamaClient struct {
+ nopCloser
+
+ client *ollama.Client
+ model string
+ temperature float32
+ topP float32
+}
+
+const (
+ defaultBaseURL = "http://localhost:11434"
+ defaultModel = "llama3"
+)
+
+func (c *OllamaClient) Configure(config IAIConfig) error {
+ baseURL := config.GetBaseURL()
+ if baseURL == "" {
+ baseURL = defaultBaseURL
+ }
+ baseClientURL, err := url.Parse(baseURL)
+ if err != nil {
+ return err
+ }
+
+ proxyEndpoint := config.GetProxyEndpoint()
+ httpClient := http.DefaultClient
+ if proxyEndpoint != "" {
+ proxyUrl, err := url.Parse(proxyEndpoint)
+ if err != nil {
+ return err
+ }
+ transport := &http.Transport{
+ Proxy: http.ProxyURL(proxyUrl),
+ }
+
+ httpClient = &http.Client{
+ Transport: transport,
+ }
+ }
+
+ c.client = ollama.NewClient(baseClientURL, httpClient)
+ if c.client == nil {
+ return errors.New("error creating Ollama client")
+ }
+ c.model = config.GetModel()
+ if c.model == "" {
+ c.model = defaultModel
+ }
+ c.temperature = config.GetTemperature()
+ c.topP = config.GetTopP()
+ return nil
+}
+func (c *OllamaClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
+ req := &ollama.GenerateRequest{
+ Model: c.model,
+ Prompt: prompt,
+ Stream: new(bool),
+ Options: map[string]interface{}{
+ "temperature": c.temperature,
+ "top_p": c.topP,
+ },
+ }
+ completion := ""
+ respFunc := func(resp ollama.GenerateResponse) error {
+ completion = resp.Response
+ return nil
+ }
+ err := c.client.Generate(ctx, req, respFunc)
+ if err != nil {
+ return "", err
+ }
+ return completion, nil
+}
+func (a *OllamaClient) GetName() string {
+ return ollamaClientName
+}
diff --git a/pkg/ai/watsonxai.go b/pkg/ai/watsonxai.go
new file mode 100644
index 0000000000..f6ce81c1a6
--- /dev/null
+++ b/pkg/ai/watsonxai.go
@@ -0,0 +1,84 @@
+package ai
+
+import (
+ "os"
+ "fmt"
+ "context"
+ "errors"
+
+ wx "github.com/IBM/watsonx-go/pkg/models"
+)
+
+const watsonxAIClientName = "watsonxai"
+
+type WatsonxAIClient struct {
+ nopCloser
+
+ client *wx.Client
+ model string
+ temperature float32
+ topP float32
+ topK int32
+ maxNewTokens int
+}
+
+const (
+ modelMetallama = "ibm/granite-13b-chat-v2"
+)
+
+func (c *WatsonxAIClient) Configure(config IAIConfig) error {
+ if(config.GetModel() == "") {
+ c.model = config.GetModel()
+ } else {
+ c.model = modelMetallama
+ }
+ c.temperature = config.GetTemperature()
+ c.topP = config.GetTopP()
+ c.topK = config.GetTopK()
+ c.maxNewTokens = config.GetMaxTokens()
+
+ // WatsonxAPIKeyEnvVarName = "WATSONX_API_KEY"
+ // WatsonxProjectIDEnvVarName = "WATSONX_PROJECT_ID"
+ apiKey, projectID := os.Getenv(wx.WatsonxAPIKeyEnvVarName), os.Getenv(wx.WatsonxProjectIDEnvVarName)
+
+ if apiKey == "" {
+ return errors.New("No watsonx API key provided")
+ }
+ if projectID == "" {
+ return errors.New("No watsonx project ID provided")
+ }
+
+ client, err := wx.NewClient(
+ wx.WithWatsonxAPIKey(apiKey),
+ wx.WithWatsonxProjectID(projectID),
+ )
+ if err != nil {
+ return fmt.Errorf("Failed to create client for testing. Error: %v", err)
+ }
+ c.client = client
+
+ return nil
+}
+
+func (c *WatsonxAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
+ result, err := c.client.GenerateText(
+ c.model,
+ prompt,
+ wx.WithTemperature((float64)(c.temperature)),
+ wx.WithTopP((float64)(c.topP)),
+ wx.WithTopK((uint)(c.topK)),
+ wx.WithMaxNewTokens((uint)(c.maxNewTokens)),
+ )
+ if err != nil {
+ return "", fmt.Errorf("Expected no error, but got an error: %v", err)
+ }
+ if result.Text == "" {
+ return "", errors.New("Expected a result, but got an empty string")
+ }
+
+ return result.Text, nil
+}
+
+func (c *WatsonxAIClient) GetName() string {
+ return watsonxAIClientName
+}