diff --git a/pkg/client/factory.go b/pkg/client/factory.go index 050b0904..b2bbfafe 100644 --- a/pkg/client/factory.go +++ b/pkg/client/factory.go @@ -41,10 +41,11 @@ import ( // CmabConfig holds CMAB configuration options exposed at the client level. // This provides a stable public API while allowing internal cmab.Config to change. type CmabConfig struct { - CacheSize int - CacheTTL time.Duration - HTTPTimeout time.Duration - Cache cache.CacheWithRemove // Custom cache implementation (Redis, etc.) + CacheSize int + CacheTTL time.Duration + HTTPTimeout time.Duration + Cache cache.CacheWithRemove // Custom cache implementation (Redis, etc.) + PredictionEndpointTemplate string // Custom prediction endpoint template } // toCmabConfig converts client-level CmabConfig to internal cmab.Config @@ -53,10 +54,11 @@ func (c *CmabConfig) toCmabConfig() *cmab.Config { return nil } return &cmab.Config{ - CacheSize: c.CacheSize, - CacheTTL: c.CacheTTL, - HTTPTimeout: c.HTTPTimeout, - Cache: c.Cache, + CacheSize: c.CacheSize, + CacheTTL: c.CacheTTL, + HTTPTimeout: c.HTTPTimeout, + Cache: c.Cache, + PredictionEndpointTemplate: c.PredictionEndpointTemplate, } } diff --git a/pkg/client/factory_test.go b/pkg/client/factory_test.go index d6e35860..54845ae9 100644 --- a/pkg/client/factory_test.go +++ b/pkg/client/factory_test.go @@ -560,3 +560,46 @@ func TestClientWithEmptyCmabConfig(t *testing.T) { assert.NotNil(t, client.DecisionService) client.Close() } + +func TestCmabConfigWithCustomPredictionEndpoint(t *testing.T) { + // Test that custom prediction endpoint is correctly set in CmabConfig + customEndpoint := "https://custom.example.com/predict/%s" + cmabConfig := CmabConfig{ + CacheSize: 100, + CacheTTL: time.Minute, + HTTPTimeout: 30 * time.Second, + PredictionEndpointTemplate: customEndpoint, + } + + factory := OptimizelyFactory{} + WithCmabConfig(&cmabConfig)(&factory) + + assert.Equal(t, &cmabConfig, factory.cmabConfig) + assert.Equal(t, customEndpoint, factory.cmabConfig.PredictionEndpointTemplate) +} + +func TestCmabConfigToCmabConfig(t *testing.T) { + // Test the toCmabConfig conversion includes PredictionEndpointTemplate + customEndpoint := "https://proxy.example.com/cmab/%s" + clientConfig := CmabConfig{ + CacheSize: 200, + CacheTTL: 5 * time.Minute, + HTTPTimeout: 15 * time.Second, + PredictionEndpointTemplate: customEndpoint, + } + + internalConfig := clientConfig.toCmabConfig() + + assert.NotNil(t, internalConfig) + assert.Equal(t, 200, internalConfig.CacheSize) + assert.Equal(t, 5*time.Minute, internalConfig.CacheTTL) + assert.Equal(t, 15*time.Second, internalConfig.HTTPTimeout) + assert.Equal(t, customEndpoint, internalConfig.PredictionEndpointTemplate) +} + +func TestCmabConfigToCmabConfigNil(t *testing.T) { + // Test that nil CmabConfig returns nil + var clientConfig *CmabConfig + internalConfig := clientConfig.toCmabConfig() + assert.Nil(t, internalConfig) +} diff --git a/pkg/cmab/client.go b/pkg/cmab/client.go index 588e518e..f4d9835e 100644 --- a/pkg/cmab/client.go +++ b/pkg/cmab/client.go @@ -30,10 +30,9 @@ import ( "github.com/optimizely/go-sdk/v2/pkg/logging" ) -// CMABPredictionEndpoint is the endpoint for CMAB predictions -var CMABPredictionEndpoint = "https://prediction.cmab.optimizely.com/predict/%s" - const ( + // DefaultPredictionEndpointTemplate is the default endpoint template for CMAB predictions + DefaultPredictionEndpointTemplate = "https://prediction.cmab.optimizely.com/predict/%s" // DefaultMaxRetries is the default number of retries for CMAB requests DefaultMaxRetries = 1 // DefaultInitialBackoff is the default initial backoff duration @@ -88,16 +87,18 @@ type RetryConfig struct { // DefaultCmabClient implements the CmabClient interface type DefaultCmabClient struct { - httpClient *http.Client - retryConfig *RetryConfig - logger logging.OptimizelyLogProducer + httpClient *http.Client + retryConfig *RetryConfig + logger logging.OptimizelyLogProducer + predictionEndpoint string } // ClientOptions defines options for creating a CMAB client type ClientOptions struct { - HTTPClient *http.Client - RetryConfig *RetryConfig - Logger logging.OptimizelyLogProducer + HTTPClient *http.Client + RetryConfig *RetryConfig + Logger logging.OptimizelyLogProducer + PredictionEndpointTemplate string } // NewDefaultCmabClient creates a new instance of DefaultCmabClient @@ -118,10 +119,17 @@ func NewDefaultCmabClient(options ClientOptions) *DefaultCmabClient { logger = logging.GetLogger("", "DefaultCmabClient") } + // Use custom endpoint or default + predictionEndpoint := options.PredictionEndpointTemplate + if predictionEndpoint == "" { + predictionEndpoint = DefaultPredictionEndpointTemplate + } + return &DefaultCmabClient{ - httpClient: httpClient, - retryConfig: retryConfig, - logger: logger, + httpClient: httpClient, + retryConfig: retryConfig, + logger: logger, + predictionEndpoint: predictionEndpoint, } } @@ -134,7 +142,7 @@ func (c *DefaultCmabClient) FetchDecision( ) (string, error) { // Create the URL - url := fmt.Sprintf(CMABPredictionEndpoint, ruleID) + url := fmt.Sprintf(c.predictionEndpoint, ruleID) // Log the URL being called c.logger.Debug(fmt.Sprintf("CMAB Prediction URL: %s", url)) diff --git a/pkg/cmab/client_test.go b/pkg/cmab/client_test.go index 9080a0f5..dbe32fd5 100644 --- a/pkg/cmab/client_test.go +++ b/pkg/cmab/client_test.go @@ -124,13 +124,9 @@ func TestDefaultCmabClient_FetchDecision(t *testing.T) { HTTPClient: &http.Client{ Timeout: 5 * time.Second, }, + PredictionEndpointTemplate: server.URL + "/%s", }) - // Override the endpoint for testing - originalEndpoint := CMABPredictionEndpoint - CMABPredictionEndpoint = server.URL + "/%s" - defer func() { CMABPredictionEndpoint = originalEndpoint }() - // Test with various attribute types attributes := map[string]interface{}{ "string_attr": "string value", @@ -205,13 +201,9 @@ func TestDefaultCmabClient_FetchDecision_WithRetry(t *testing.T) { MaxBackoff: 100 * time.Millisecond, BackoffMultiplier: 2.0, }, + PredictionEndpointTemplate: server.URL + "/%s", }) - // Override the endpoint for testing - originalEndpoint := CMABPredictionEndpoint - CMABPredictionEndpoint = server.URL + "/%s" - defer func() { CMABPredictionEndpoint = originalEndpoint }() - // Test fetch decision with retry attributes := map[string]interface{}{ "browser": "chrome", @@ -253,13 +245,9 @@ func TestDefaultCmabClient_FetchDecision_ExhaustedRetries(t *testing.T) { MaxBackoff: 100 * time.Millisecond, BackoffMultiplier: 2.0, }, + PredictionEndpointTemplate: server.URL + "/%s", }) - // Override the endpoint for testing - originalEndpoint := CMABPredictionEndpoint - CMABPredictionEndpoint = server.URL + "/%s" - defer func() { CMABPredictionEndpoint = originalEndpoint }() - // Test fetch decision with exhausted retries attributes := map[string]interface{}{ "browser": "chrome", @@ -300,14 +288,10 @@ func TestDefaultCmabClient_FetchDecision_NoRetryConfig(t *testing.T) { HTTPClient: &http.Client{ Timeout: 5 * time.Second, }, - RetryConfig: nil, // Explicitly set to nil to override default + RetryConfig: nil, // Explicitly set to nil to override default + PredictionEndpointTemplate: server.URL + "/%s", }) - // Override the endpoint for testing - originalEndpoint := CMABPredictionEndpoint - CMABPredictionEndpoint = server.URL + "/%s" - defer func() { CMABPredictionEndpoint = originalEndpoint }() - // Test fetch decision without retry config attributes := map[string]interface{}{ "browser": "chrome", @@ -364,13 +348,9 @@ func TestDefaultCmabClient_FetchDecision_InvalidResponse(t *testing.T) { HTTPClient: &http.Client{ Timeout: 5 * time.Second, }, + PredictionEndpointTemplate: server.URL + "/%s", }) - // Override the endpoint for testing - originalEndpoint := CMABPredictionEndpoint - CMABPredictionEndpoint = server.URL + "/%s" - defer func() { CMABPredictionEndpoint = originalEndpoint }() - // Test fetch decision with invalid response attributes := map[string]interface{}{ "browser": "chrome", @@ -407,14 +387,10 @@ func TestDefaultCmabClient_FetchDecision_NetworkErrors(t *testing.T) { MaxBackoff: 100 * time.Millisecond, BackoffMultiplier: 2.0, }, - Logger: mockLogger, + Logger: mockLogger, + PredictionEndpointTemplate: "http://non-existent-server.example.com/%s", }) - // Set endpoint to a non-existent server - originalEndpoint := CMABPredictionEndpoint - CMABPredictionEndpoint = "http://non-existent-server.example.com/%s" - defer func() { CMABPredictionEndpoint = originalEndpoint }() - // Test fetch decision with network error attributes := map[string]interface{}{ "browser": "chrome", @@ -468,13 +444,9 @@ func TestDefaultCmabClient_ExponentialBackoff(t *testing.T) { MaxBackoff: 1 * time.Second, BackoffMultiplier: 2.0, }, + PredictionEndpointTemplate: server.URL + "/%s", }) - // Override the endpoint for testing - originalEndpoint := CMABPredictionEndpoint - CMABPredictionEndpoint = server.URL + "/%s" - defer func() { CMABPredictionEndpoint = originalEndpoint }() - // Test fetch decision with exponential backoff attributes := map[string]interface{}{ "browser": "chrome", @@ -555,14 +527,10 @@ func TestDefaultCmabClient_LoggingBehavior(t *testing.T) { MaxBackoff: 100 * time.Millisecond, BackoffMultiplier: 2.0, }, - Logger: mockLogger, + Logger: mockLogger, + PredictionEndpointTemplate: server.URL + "/%s", }) - // Override the endpoint for testing - originalEndpoint := CMABPredictionEndpoint - CMABPredictionEndpoint = server.URL + "/%s" - defer func() { CMABPredictionEndpoint = originalEndpoint }() - // Test fetch decision attributes := map[string]interface{}{ "browser": "chrome", @@ -618,14 +586,10 @@ func TestDefaultCmabClient_NonSuccessStatusCode(t *testing.T) { HTTPClient: &http.Client{ Timeout: 5 * time.Second, }, + PredictionEndpointTemplate: server.URL + "/%s", // No retry config to simplify the test }) - // Override the endpoint for testing - originalEndpoint := CMABPredictionEndpoint - CMABPredictionEndpoint = server.URL + "/%s" - defer func() { CMABPredictionEndpoint = originalEndpoint }() - // Test fetch decision attributes := map[string]interface{}{ "browser": "chrome", @@ -641,3 +605,57 @@ func TestDefaultCmabClient_NonSuccessStatusCode(t *testing.T) { }) } } +func TestDefaultCmabClient_CustomPredictionEndpoint(t *testing.T) { + // Setup test server + customEndpointCalled := false + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + customEndpointCalled = true + // Verify the URL path contains the rule ID + assert.Contains(t, r.URL.Path, "rule456") + + // Return a valid response + response := Response{ + Predictions: []Prediction{ + {VariationID: "variation789"}, + }, + } + json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + // Create client with custom prediction endpoint + customEndpoint := server.URL + "/custom/predict/%s" + client := NewDefaultCmabClient(ClientOptions{ + PredictionEndpointTemplate: customEndpoint, + }) + + // Test fetch decision + attributes := map[string]interface{}{ + "age": 25, + } + + variationID, err := client.FetchDecision("rule456", "user123", attributes, "test-uuid") + + // Verify results + assert.NoError(t, err) + assert.Equal(t, "variation789", variationID) + assert.True(t, customEndpointCalled, "Custom endpoint should have been called") +} + +func TestDefaultCmabClient_DefaultPredictionEndpointTemplate(t *testing.T) { + // Create client without specifying prediction endpoint + client := NewDefaultCmabClient(ClientOptions{}) + + // Verify it uses the default endpoint + assert.Equal(t, DefaultPredictionEndpointTemplate, client.predictionEndpoint) +} + +func TestDefaultCmabClient_EmptyPredictionEndpointUsesDefault(t *testing.T) { + // Create client with empty prediction endpoint + client := NewDefaultCmabClient(ClientOptions{ + PredictionEndpointTemplate: "", + }) + + // Verify it uses the default endpoint when empty string is provided + assert.Equal(t, DefaultPredictionEndpointTemplate, client.predictionEndpoint) +} diff --git a/pkg/cmab/config.go b/pkg/cmab/config.go index 9118b380..9488a10b 100644 --- a/pkg/cmab/config.go +++ b/pkg/cmab/config.go @@ -36,19 +36,21 @@ const ( // Config holds CMAB configuration options type Config struct { - CacheSize int - CacheTTL time.Duration - HTTPTimeout time.Duration - RetryConfig *RetryConfig - Cache cache.CacheWithRemove // Custom cache implementation (Redis, etc.) + CacheSize int + CacheTTL time.Duration + HTTPTimeout time.Duration + RetryConfig *RetryConfig + Cache cache.CacheWithRemove // Custom cache implementation (Redis, etc.) + PredictionEndpointTemplate string // Custom prediction endpoint template } // NewDefaultConfig creates a Config with default values func NewDefaultConfig() Config { return Config{ - CacheSize: DefaultCacheSize, - CacheTTL: DefaultCacheTTL, - HTTPTimeout: DefaultHTTPTimeout, + CacheSize: DefaultCacheSize, + CacheTTL: DefaultCacheTTL, + HTTPTimeout: DefaultHTTPTimeout, + PredictionEndpointTemplate: DefaultPredictionEndpointTemplate, RetryConfig: &RetryConfig{ MaxRetries: DefaultMaxRetries, }, diff --git a/pkg/cmab/config_test.go b/pkg/cmab/config_test.go index 16d28e63..26aeba10 100644 --- a/pkg/cmab/config_test.go +++ b/pkg/cmab/config_test.go @@ -28,6 +28,7 @@ func TestNewDefaultConfig(t *testing.T) { assert.Equal(t, DefaultCacheSize, config.CacheSize) assert.Equal(t, DefaultCacheTTL, config.CacheTTL) assert.Equal(t, DefaultHTTPTimeout, config.HTTPTimeout) + assert.Equal(t, DefaultPredictionEndpointTemplate, config.PredictionEndpointTemplate) assert.NotNil(t, config.RetryConfig) assert.Equal(t, DefaultMaxRetries, config.RetryConfig.MaxRetries) assert.Nil(t, config.Cache) // Should be nil by default diff --git a/pkg/decision/experiment_cmab_service.go b/pkg/decision/experiment_cmab_service.go index 53dd9c23..32a2bb2d 100644 --- a/pkg/decision/experiment_cmab_service.go +++ b/pkg/decision/experiment_cmab_service.go @@ -52,6 +52,7 @@ func NewExperimentCmabService(sdkKey string, config *cmab.Config) *ExperimentCma var httpTimeout time.Duration var retryConfig *cmab.RetryConfig var customCache cache.CacheWithRemove + var predictionEndpoint string if config == nil { // Use all defaults @@ -83,6 +84,11 @@ func NewExperimentCmabService(sdkKey string, config *cmab.Config) *ExperimentCma httpTimeout = cmab.DefaultHTTPTimeout } + predictionEndpoint = config.PredictionEndpointTemplate + if predictionEndpoint == "" { + predictionEndpoint = cmab.DefaultPredictionEndpointTemplate + } + // Handle retry config if config.RetryConfig == nil { retryConfig = &cmab.RetryConfig{ @@ -124,9 +130,10 @@ func NewExperimentCmabService(sdkKey string, config *cmab.Config) *ExperimentCma // Create CMAB client options cmabClientOptions := cmab.ClientOptions{ - HTTPClient: httpClient, - RetryConfig: retryConfig, - Logger: logging.GetLogger(sdkKey, "DefaultCmabClient"), + HTTPClient: httpClient, + RetryConfig: retryConfig, + Logger: logging.GetLogger(sdkKey, "DefaultCmabClient"), + PredictionEndpointTemplate: predictionEndpoint, } // Create CMAB client with adapter to match interface