Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions pkg/client/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ import (
// CmabConfig holds CMAB configuration options exposed at the client level.
// This provides a stable public API while allowing internal cmab.Config to change.
type CmabConfig struct {
CacheSize int
CacheTTL time.Duration
HTTPTimeout time.Duration
Cache cache.CacheWithRemove // Custom cache implementation (Redis, etc.)
CacheSize int
CacheTTL time.Duration
HTTPTimeout time.Duration
Cache cache.CacheWithRemove // Custom cache implementation (Redis, etc.)
PredictionEndpoint string // Custom prediction endpoint template
}

// toCmabConfig converts client-level CmabConfig to internal cmab.Config
Expand All @@ -53,10 +54,11 @@ func (c *CmabConfig) toCmabConfig() *cmab.Config {
return nil
}
return &cmab.Config{
CacheSize: c.CacheSize,
CacheTTL: c.CacheTTL,
HTTPTimeout: c.HTTPTimeout,
Cache: c.Cache,
CacheSize: c.CacheSize,
CacheTTL: c.CacheTTL,
HTTPTimeout: c.HTTPTimeout,
Cache: c.Cache,
PredictionEndpoint: c.PredictionEndpoint,
}
}

Expand Down
43 changes: 43 additions & 0 deletions pkg/client/factory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -560,3 +560,46 @@ func TestClientWithEmptyCmabConfig(t *testing.T) {
assert.NotNil(t, client.DecisionService)
client.Close()
}

func TestCmabConfigWithCustomPredictionEndpoint(t *testing.T) {
// Test that custom prediction endpoint is correctly set in CmabConfig
customEndpoint := "https://custom.endpoint.com/predict/%s"
cmabConfig := CmabConfig{
CacheSize: 100,
CacheTTL: time.Minute,
HTTPTimeout: 30 * time.Second,
PredictionEndpoint: customEndpoint,
}

factory := OptimizelyFactory{}
WithCmabConfig(&cmabConfig)(&factory)

assert.Equal(t, &cmabConfig, factory.cmabConfig)
assert.Equal(t, customEndpoint, factory.cmabConfig.PredictionEndpoint)
}

func TestCmabConfigToCmabConfig(t *testing.T) {
// Test the toCmabConfig conversion includes PredictionEndpoint
customEndpoint := "https://proxy.example.com/cmab/%s"
clientConfig := CmabConfig{
CacheSize: 200,
CacheTTL: 5 * time.Minute,
HTTPTimeout: 15 * time.Second,
PredictionEndpoint: customEndpoint,
}

internalConfig := clientConfig.toCmabConfig()

assert.NotNil(t, internalConfig)
assert.Equal(t, 200, internalConfig.CacheSize)
assert.Equal(t, 5*time.Minute, internalConfig.CacheTTL)
assert.Equal(t, 15*time.Second, internalConfig.HTTPTimeout)
assert.Equal(t, customEndpoint, internalConfig.PredictionEndpoint)
}

func TestCmabConfigToCmabConfigNil(t *testing.T) {
// Test that nil CmabConfig returns nil
var clientConfig *CmabConfig
internalConfig := clientConfig.toCmabConfig()
assert.Nil(t, internalConfig)
}
34 changes: 21 additions & 13 deletions pkg/cmab/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,9 @@ import (
"github.com/optimizely/go-sdk/v2/pkg/logging"
)

// CMABPredictionEndpoint is the endpoint for CMAB predictions
var CMABPredictionEndpoint = "https://prediction.cmab.optimizely.com/predict/%s"

const (
// DefaultPredictionEndpoint is the default endpoint template for CMAB predictions
DefaultPredictionEndpoint = "https://prediction.cmab.optimizely.com/predict/%s"
// DefaultMaxRetries is the default number of retries for CMAB requests
DefaultMaxRetries = 1
// DefaultInitialBackoff is the default initial backoff duration
Expand Down Expand Up @@ -88,16 +87,18 @@ type RetryConfig struct {

// DefaultCmabClient implements the CmabClient interface
type DefaultCmabClient struct {
httpClient *http.Client
retryConfig *RetryConfig
logger logging.OptimizelyLogProducer
httpClient *http.Client
retryConfig *RetryConfig
logger logging.OptimizelyLogProducer
predictionEndpoint string
}

// ClientOptions defines options for creating a CMAB client
type ClientOptions struct {
HTTPClient *http.Client
RetryConfig *RetryConfig
Logger logging.OptimizelyLogProducer
HTTPClient *http.Client
RetryConfig *RetryConfig
Logger logging.OptimizelyLogProducer
PredictionEndpoint string
}

// NewDefaultCmabClient creates a new instance of DefaultCmabClient
Expand All @@ -118,10 +119,17 @@ func NewDefaultCmabClient(options ClientOptions) *DefaultCmabClient {
logger = logging.GetLogger("", "DefaultCmabClient")
}

// Use custom endpoint or default
predictionEndpoint := options.PredictionEndpoint
if predictionEndpoint == "" {
predictionEndpoint = DefaultPredictionEndpoint
}

return &DefaultCmabClient{
httpClient: httpClient,
retryConfig: retryConfig,
logger: logger,
httpClient: httpClient,
retryConfig: retryConfig,
logger: logger,
predictionEndpoint: predictionEndpoint,
}
}

Expand All @@ -134,7 +142,7 @@ func (c *DefaultCmabClient) FetchDecision(
) (string, error) {

// Create the URL
url := fmt.Sprintf(CMABPredictionEndpoint, ruleID)
url := fmt.Sprintf(c.predictionEndpoint, ruleID)

// Log the URL being called
c.logger.Debug(fmt.Sprintf("CMAB Prediction URL: %s", url))
Expand Down
114 changes: 66 additions & 48 deletions pkg/cmab/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,9 @@ func TestDefaultCmabClient_FetchDecision(t *testing.T) {
HTTPClient: &http.Client{
Timeout: 5 * time.Second,
},
PredictionEndpoint: server.URL + "/%s",
})

// Override the endpoint for testing
originalEndpoint := CMABPredictionEndpoint
CMABPredictionEndpoint = server.URL + "/%s"
defer func() { CMABPredictionEndpoint = originalEndpoint }()

// Test with various attribute types
attributes := map[string]interface{}{
"string_attr": "string value",
Expand Down Expand Up @@ -205,13 +201,9 @@ func TestDefaultCmabClient_FetchDecision_WithRetry(t *testing.T) {
MaxBackoff: 100 * time.Millisecond,
BackoffMultiplier: 2.0,
},
PredictionEndpoint: server.URL + "/%s",
})

// Override the endpoint for testing
originalEndpoint := CMABPredictionEndpoint
CMABPredictionEndpoint = server.URL + "/%s"
defer func() { CMABPredictionEndpoint = originalEndpoint }()

// Test fetch decision with retry
attributes := map[string]interface{}{
"browser": "chrome",
Expand Down Expand Up @@ -253,13 +245,9 @@ func TestDefaultCmabClient_FetchDecision_ExhaustedRetries(t *testing.T) {
MaxBackoff: 100 * time.Millisecond,
BackoffMultiplier: 2.0,
},
PredictionEndpoint: server.URL + "/%s",
})

// Override the endpoint for testing
originalEndpoint := CMABPredictionEndpoint
CMABPredictionEndpoint = server.URL + "/%s"
defer func() { CMABPredictionEndpoint = originalEndpoint }()

// Test fetch decision with exhausted retries
attributes := map[string]interface{}{
"browser": "chrome",
Expand Down Expand Up @@ -300,14 +288,10 @@ func TestDefaultCmabClient_FetchDecision_NoRetryConfig(t *testing.T) {
HTTPClient: &http.Client{
Timeout: 5 * time.Second,
},
RetryConfig: nil, // Explicitly set to nil to override default
RetryConfig: nil, // Explicitly set to nil to override default
PredictionEndpoint: server.URL + "/%s",
})

// Override the endpoint for testing
originalEndpoint := CMABPredictionEndpoint
CMABPredictionEndpoint = server.URL + "/%s"
defer func() { CMABPredictionEndpoint = originalEndpoint }()

// Test fetch decision without retry config
attributes := map[string]interface{}{
"browser": "chrome",
Expand Down Expand Up @@ -364,13 +348,9 @@ func TestDefaultCmabClient_FetchDecision_InvalidResponse(t *testing.T) {
HTTPClient: &http.Client{
Timeout: 5 * time.Second,
},
PredictionEndpoint: server.URL + "/%s",
})

// Override the endpoint for testing
originalEndpoint := CMABPredictionEndpoint
CMABPredictionEndpoint = server.URL + "/%s"
defer func() { CMABPredictionEndpoint = originalEndpoint }()

// Test fetch decision with invalid response
attributes := map[string]interface{}{
"browser": "chrome",
Expand Down Expand Up @@ -407,14 +387,10 @@ func TestDefaultCmabClient_FetchDecision_NetworkErrors(t *testing.T) {
MaxBackoff: 100 * time.Millisecond,
BackoffMultiplier: 2.0,
},
Logger: mockLogger,
Logger: mockLogger,
PredictionEndpoint: "http://non-existent-server.example.com/%s",
})

// Set endpoint to a non-existent server
originalEndpoint := CMABPredictionEndpoint
CMABPredictionEndpoint = "http://non-existent-server.example.com/%s"
defer func() { CMABPredictionEndpoint = originalEndpoint }()

// Test fetch decision with network error
attributes := map[string]interface{}{
"browser": "chrome",
Expand Down Expand Up @@ -468,13 +444,9 @@ func TestDefaultCmabClient_ExponentialBackoff(t *testing.T) {
MaxBackoff: 1 * time.Second,
BackoffMultiplier: 2.0,
},
PredictionEndpoint: server.URL + "/%s",
})

// Override the endpoint for testing
originalEndpoint := CMABPredictionEndpoint
CMABPredictionEndpoint = server.URL + "/%s"
defer func() { CMABPredictionEndpoint = originalEndpoint }()

// Test fetch decision with exponential backoff
attributes := map[string]interface{}{
"browser": "chrome",
Expand Down Expand Up @@ -555,14 +527,10 @@ func TestDefaultCmabClient_LoggingBehavior(t *testing.T) {
MaxBackoff: 100 * time.Millisecond,
BackoffMultiplier: 2.0,
},
Logger: mockLogger,
Logger: mockLogger,
PredictionEndpoint: server.URL + "/%s",
})

// Override the endpoint for testing
originalEndpoint := CMABPredictionEndpoint
CMABPredictionEndpoint = server.URL + "/%s"
defer func() { CMABPredictionEndpoint = originalEndpoint }()

// Test fetch decision
attributes := map[string]interface{}{
"browser": "chrome",
Expand Down Expand Up @@ -618,14 +586,10 @@ func TestDefaultCmabClient_NonSuccessStatusCode(t *testing.T) {
HTTPClient: &http.Client{
Timeout: 5 * time.Second,
},
PredictionEndpoint: server.URL + "/%s",
// No retry config to simplify the test
})

// Override the endpoint for testing
originalEndpoint := CMABPredictionEndpoint
CMABPredictionEndpoint = server.URL + "/%s"
defer func() { CMABPredictionEndpoint = originalEndpoint }()

// Test fetch decision
attributes := map[string]interface{}{
"browser": "chrome",
Expand All @@ -641,3 +605,57 @@ func TestDefaultCmabClient_NonSuccessStatusCode(t *testing.T) {
})
}
}
func TestDefaultCmabClient_CustomPredictionEndpoint(t *testing.T) {
// Setup test server
customEndpointCalled := false
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
customEndpointCalled = true
// Verify the URL path contains the rule ID
assert.Contains(t, r.URL.Path, "rule456")

// Return a valid response
response := Response{
Predictions: []Prediction{
{VariationID: "variation789"},
},
}
json.NewEncoder(w).Encode(response)
}))
defer server.Close()

// Create client with custom prediction endpoint
customEndpoint := server.URL + "/custom/predict/%s"
client := NewDefaultCmabClient(ClientOptions{
PredictionEndpoint: customEndpoint,
})

// Test fetch decision
attributes := map[string]interface{}{
"age": 25,
}

variationID, err := client.FetchDecision("rule456", "user123", attributes, "test-uuid")

// Verify results
assert.NoError(t, err)
assert.Equal(t, "variation789", variationID)
assert.True(t, customEndpointCalled, "Custom endpoint should have been called")
}

func TestDefaultCmabClient_DefaultPredictionEndpoint(t *testing.T) {
// Create client without specifying prediction endpoint
client := NewDefaultCmabClient(ClientOptions{})

// Verify it uses the default endpoint
assert.Equal(t, DefaultPredictionEndpoint, client.predictionEndpoint)
}

func TestDefaultCmabClient_EmptyPredictionEndpointUsesDefault(t *testing.T) {
// Create client with empty prediction endpoint
client := NewDefaultCmabClient(ClientOptions{
PredictionEndpoint: "",
})

// Verify it uses the default endpoint when empty string is provided
assert.Equal(t, DefaultPredictionEndpoint, client.predictionEndpoint)
}
14 changes: 9 additions & 5 deletions pkg/cmab/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,19 @@ const (

// DefaultHTTPTimeout is the default HTTP timeout for CMAB requests
DefaultHTTPTimeout = 10 * time.Second

// DefaultPredictionEndpointTemplate is the default endpoint template for CMAB predictions
DefaultPredictionEndpointTemplate = "https://prediction.cmab.optimizely.com/predict/%s"
)

// Config holds CMAB configuration options
type Config struct {
CacheSize int
CacheTTL time.Duration
HTTPTimeout time.Duration
RetryConfig *RetryConfig
Cache cache.CacheWithRemove // Custom cache implementation (Redis, etc.)
CacheSize int
CacheTTL time.Duration
HTTPTimeout time.Duration
RetryConfig *RetryConfig
Cache cache.CacheWithRemove // Custom cache implementation (Redis, etc.)
PredictionEndpoint string // Custom prediction endpoint template
}

// NewDefaultConfig creates a Config with default values
Expand Down
Loading