feat: oci genai multi-model, serving mode
Signed-off-by: Anders Swanson <[email protected]>
anders-swanson committed Aug 8, 2024
1 parent b2b8682 commit f45e673
Showing 3 changed files with 104 additions and 25 deletions.
go.mod: 2 changes (1 addition & 1 deletion)
@@ -44,7 +44,7 @@ require (
 	github.com/hupe1980/go-huggingface v0.0.15
 	github.com/kyverno/policy-reporter-kyverno-plugin v1.6.3
 	github.com/olekukonko/tablewriter v0.0.5
-	github.com/oracle/oci-go-sdk/v65 v65.65.1
+	github.com/oracle/oci-go-sdk/v65 v65.71.0
 	github.com/prometheus/prometheus v0.53.1
 	github.com/pterm/pterm v0.12.79
 	google.golang.org/api v0.187.0
go.sum: 2 changes (2 additions & 0 deletions)
@@ -2118,6 +2118,8 @@ github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQ
 github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM=
 github.com/oracle/oci-go-sdk/v65 v65.65.1 h1:sv7uD844tJGa2Vc+2KaByoXQ0FllZDGV/2+9MdxN6nA=
 github.com/oracle/oci-go-sdk/v65 v65.65.1/go.mod h1:IBEV9l1qBzUpo7zgGaRUhbB05BVfcDGYRFBCPlTcPp0=
+github.com/oracle/oci-go-sdk/v65 v65.71.0 h1:eEnFD/CzcoqdAA0xu+EmK32kJL3jfV0oLYNWVzoKNyo=
+github.com/oracle/oci-go-sdk/v65 v65.71.0/go.mod h1:IBEV9l1qBzUpo7zgGaRUhbB05BVfcDGYRFBCPlTcPp0=
 github.com/ovh/go-ovh v1.5.1 h1:P8O+7H+NQuFK9P/j4sFW5C0fvSS2DnHYGPwdVCp45wI=
 github.com/ovh/go-ovh v1.5.1/go.mod h1:cTVDnl94z4tl8pP1uZ/8jlVxntjSIf09bNcQ5TJSC7c=
 github.com/owenrumney/squealer v1.2.1 h1:4ryMMT59aaz8VMsqsD+FDkarADJz0F1dcq2fd0DRR+c=
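(Both files record the same dependency bump from oci-go-sdk/v65 v65.65.1 to v65.71.0; reproducing it would look like `go get github.com/oracle/oci-go-sdk/v65@v65.71.0` — an assumed workflow, as the commit only shows the result.)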
pkg/ai/ocigenai.go: 125 changes (101 additions & 24 deletions)
@@ -16,21 +16,33 @@ package ai
 import (
 	"context"
 	"errors"
+	"fmt"
 	"github.com/oracle/oci-go-sdk/v65/common"
+	"github.com/oracle/oci-go-sdk/v65/generativeai"
 	"github.com/oracle/oci-go-sdk/v65/generativeaiinference"
+	"reflect"
 	"strings"
 )
 
 const ociClientName = "oci"
 
+type ociModelVendor string
+
+const (
+	vendorCohere = "cohere"
+	vendorMeta   = "meta"
+)
+
 type OCIGenAIClient struct {
 	nopCloser
 
 	client        *generativeaiinference.GenerativeAiInferenceClient
-	model         string
+	model         *generativeai.Model
+	modelId       string
 	compartmentId string
 	temperature   float32
 	topP          float32
+	topK          int32
 	maxTokens     int
 }

@@ -40,9 +52,10 @@ func (c *OCIGenAIClient) GetName() string {

 func (c *OCIGenAIClient) Configure(config IAIConfig) error {
 	config.GetEndpointName()
-	c.model = config.GetModel()
+	c.modelId = config.GetModel()
 	c.temperature = config.GetTemperature()
 	c.topP = config.GetTopP()
+	c.topK = config.GetTopK()
 	c.maxTokens = config.GetMaxTokens()
 	c.compartmentId = config.GetCompartmentId()
 	provider := common.DefaultConfigProvider()
@@ -51,6 +64,12 @@ func (c *OCIGenAIClient) Configure(config IAIConfig) error {
 		return err
 	}
 	c.client = &client
+	model, err := c.getModel(provider)
+	if err != nil {
+		return err
+	}
+	c.model = model
+
 	return nil
 }

@@ -60,38 +79,96 @@ func (c *OCIGenAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
 	if err != nil {
 		return "", err
 	}
-	return extractGeneratedText(generateTextResponse.InferenceResponse)
+	return c.extractGeneratedText(generateTextResponse.InferenceResponse)
 }
 
 func (c *OCIGenAIClient) newGenerateTextRequest(prompt string) generativeaiinference.GenerateTextRequest {
-	temperatureF64 := float64(c.temperature)
-	topPF64 := float64(c.topP)
 	return generativeaiinference.GenerateTextRequest{
 		GenerateTextDetails: generativeaiinference.GenerateTextDetails{
-			CompartmentId: &c.compartmentId,
-			ServingMode: generativeaiinference.OnDemandServingMode{
-				ModelId: &c.model,
-			},
-			InferenceRequest: generativeaiinference.CohereLlmInferenceRequest{
-				Prompt:      &prompt,
-				MaxTokens:   &c.maxTokens,
-				Temperature: &temperatureF64,
-				TopP:        &topPF64,
-			},
+			CompartmentId:    &c.compartmentId,
+			ServingMode:      c.getServingMode(),
+			InferenceRequest: c.getInferenceRequest(prompt),
 		},
 	}
 }
 
-func extractGeneratedText(llmInferenceResponse generativeaiinference.LlmInferenceResponse) (string, error) {
-	response, ok := llmInferenceResponse.(generativeaiinference.CohereLlmInferenceResponse)
-	if !ok {
-		return "", errors.New("failed to extract generated text from backed response")
-	}
-	sb := strings.Builder{}
-	for _, text := range response.GeneratedTexts {
-		if text.Text != nil {
-			sb.WriteString(*text.Text)
-		}
-	}
-	return sb.String(), nil
-}
+func (c *OCIGenAIClient) getServingMode() generativeaiinference.ServingMode {
+	if c.isBaseModel() {
+		return generativeaiinference.OnDemandServingMode{
+			ModelId: &c.modelId,
+		}
+	}
+	return generativeaiinference.DedicatedServingMode{
+		EndpointId: &c.modelId,
+	}
+}
+
+func (c *OCIGenAIClient) getInferenceRequest(prompt string) generativeaiinference.LlmInferenceRequest {
+	temperatureF64 := float64(c.temperature)
+	topPF64 := float64(c.topP)
+	topK := int(c.topK)
+
+	switch c.getVendor() {
+	case vendorMeta:
+		return generativeaiinference.LlamaLlmInferenceRequest{
+			Prompt:      &prompt,
+			MaxTokens:   &c.maxTokens,
+			Temperature: &temperatureF64,
+			TopK:        &topK,
+			TopP:        &topPF64,
+		}
+	default: // Default to cohere
+		return generativeaiinference.CohereLlmInferenceRequest{
+			Prompt:      &prompt,
+			MaxTokens:   &c.maxTokens,
+			Temperature: &temperatureF64,
+			TopK:        &topK,
+			TopP:        &topPF64,
+		}
+	}
+}
+
+func (c *OCIGenAIClient) getModel(provider common.ConfigurationProvider) (*generativeai.Model, error) {
+	client, err := generativeai.NewGenerativeAiClientWithConfigurationProvider(provider)
+	if err != nil {
+		return nil, err
+	}
+	response, err := client.GetModel(context.Background(), generativeai.GetModelRequest{
+		ModelId: &c.modelId,
+	})
+	if err != nil {
+		return nil, err
+	}
+	return &response.Model, nil
+}
+
+func (c *OCIGenAIClient) isBaseModel() bool {
+	return c.model != nil && c.model.Type == generativeai.ModelTypeBase
+}
+
+func (c *OCIGenAIClient) getVendor() ociModelVendor {
+	if c.model == nil || c.model.Vendor == nil {
+		return ""
+	}
+	return ociModelVendor(*c.model.Vendor)
+}
+
+func (c *OCIGenAIClient) extractGeneratedText(llmInferenceResponse generativeaiinference.LlmInferenceResponse) (string, error) {
+	switch response := llmInferenceResponse.(type) {
+	case generativeaiinference.LlamaLlmInferenceResponse:
+		if len(response.Choices) > 0 && response.Choices[0].Text != nil {
+			return *response.Choices[0].Text, nil
+		}
+		return "", errors.New("no text found in oci response")
+	case generativeaiinference.CohereLlmInferenceResponse:
+		sb := strings.Builder{}
+		for _, text := range response.GeneratedTexts {
+			if text.Text != nil {
+				sb.WriteString(*text.Text)
+			}
+		}
+		return sb.String(), nil
+	default:
+		return "", fmt.Errorf("unknown oci response type: %s", reflect.TypeOf(llmInferenceResponse).Name())
+	}
+}
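Taken together, the new helpers make two dispatch decisions: getServingMode serves base models on demand by model OCID and routes everything else through a dedicated endpoint, and getInferenceRequest picks the request shape by vendor (Llama-style for meta, Cohere otherwise). The standalone sketch below exercises that dispatch against the bumped SDK; the OCIDs, prompt, and sampling values are hypothetical stand-ins, not part of the commit.

package main

import (
	"fmt"

	"github.com/oracle/oci-go-sdk/v65/generativeaiinference"
)

// pickServingMode mirrors getServingMode: base models are served on demand
// by model OCID; fine-tuned models go through a dedicated endpoint OCID.
func pickServingMode(id string, isBaseModel bool) generativeaiinference.ServingMode {
	if isBaseModel {
		return generativeaiinference.OnDemandServingMode{ModelId: &id}
	}
	return generativeaiinference.DedicatedServingMode{EndpointId: &id}
}

// buildRequest mirrors getInferenceRequest: "meta" models get a Llama-style
// request; every other vendor falls back to the Cohere request shape.
func buildRequest(vendor, prompt string) generativeaiinference.LlmInferenceRequest {
	temperature, topP := 0.7, 0.9 // hypothetical sampling values
	topK, maxTokens := 40, 512
	if vendor == "meta" {
		return generativeaiinference.LlamaLlmInferenceRequest{
			Prompt:      &prompt,
			MaxTokens:   &maxTokens,
			Temperature: &temperature,
			TopK:        &topK,
			TopP:        &topP,
		}
	}
	return generativeaiinference.CohereLlmInferenceRequest{
		Prompt:      &prompt,
		MaxTokens:   &maxTokens,
		Temperature: &temperature,
		TopK:        &topK,
		TopP:        &topP,
	}
}

func main() {
	// Hypothetical OCIDs, for illustration only.
	mode := pickServingMode("ocid1.generativeaiendpoint.oc1..example", false)
	req := buildRequest("meta", "why is my pod in CrashLoopBackOff?")
	fmt.Printf("serving mode: %T\ninference request: %T\n", mode, req)
}

Printing the concrete types makes both decisions visible without OCI credentials or network access.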
