diff --git a/build.gradle.kts b/build.gradle.kts index a06e364..8ffa4c8 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -24,6 +24,7 @@ dependencies { testImplementation("io.mockk:mockk:${Versions.mockk}") implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:${Versions.serializationRuntime}") implementation("com.squareup.okhttp3:okhttp:${Versions.okhttp}") + implementation("com.squareup.okhttp3:okhttp-sse:${Versions.okhttpSse}") implementation("com.amplitude:evaluation-core:${Versions.evaluationCore}") implementation("com.amplitude:java-sdk:${Versions.amplitudeAnalytics}") implementation("org.json:json:${Versions.json}") diff --git a/buildSrc/src/main/kotlin/Versions.kt b/buildSrc/src/main/kotlin/Versions.kt index cfc233d..098ce65 100644 --- a/buildSrc/src/main/kotlin/Versions.kt +++ b/buildSrc/src/main/kotlin/Versions.kt @@ -6,6 +6,7 @@ object Versions { const val serializationRuntime = "1.4.1" const val json = "20231013" const val okhttp = "4.12.0" + const val okhttpSse = "4.12.0" // Update this alongside okhttp. Note this library isn't stable and may contain breaking changes. Search uses of okhttp3.internal classes before updating. const val evaluationCore = "2.0.0-beta.2" const val amplitudeAnalytics = "1.12.0" const val mockk = "1.13.9" diff --git a/src/main/kotlin/LocalEvaluationClient.kt b/src/main/kotlin/LocalEvaluationClient.kt index 7867e28..d70ed6f 100644 --- a/src/main/kotlin/LocalEvaluationClient.kt +++ b/src/main/kotlin/LocalEvaluationClient.kt @@ -19,6 +19,7 @@ import com.amplitude.experiment.evaluation.EvaluationEngineImpl import com.amplitude.experiment.evaluation.EvaluationFlag import com.amplitude.experiment.evaluation.topologicalSort import com.amplitude.experiment.flag.DynamicFlagConfigApi +import com.amplitude.experiment.flag.FlagConfigStreamApi import com.amplitude.experiment.flag.InMemoryFlagConfigStorage import com.amplitude.experiment.util.LocalEvaluationMetricsWrapper import com.amplitude.experiment.util.Logger @@ -42,8 +43,12 @@ class LocalEvaluationClient internal constructor( ) { private val assignmentService: AssignmentService? = createAssignmentService(apiKey) private val serverUrl: HttpUrl = getServerUrl(config) + private val streamServerUrl: HttpUrl = getStreamServerUrl(config) private val evaluation: EvaluationEngine = EvaluationEngineImpl() - private val flagConfigApi = DynamicFlagConfigApi(apiKey, serverUrl, getProxyUrl(config), httpClient, metrics) + private val flagConfigApi = DynamicFlagConfigApi(apiKey, serverUrl, null, httpClient, metrics) + private val proxyUrl: HttpUrl? = getProxyUrl(config) + private val flagConfigProxyApi = if (proxyUrl == null) null else DynamicFlagConfigApi(apiKey, proxyUrl, null, httpClient) + private val flagConfigStreamApi = if (config.streamUpdates) FlagConfigStreamApi(apiKey, streamServerUrl, httpClient, config.streamFlagConnTimeoutMillis) else null private val flagConfigStorage = InMemoryFlagConfigStorage() private val cohortStorage = if (config.cohortSyncConfig == null) { null @@ -60,6 +65,8 @@ class LocalEvaluationClient internal constructor( private val deploymentRunner = DeploymentRunner( config = config, flagConfigApi = flagConfigApi, + flagConfigProxyApi = flagConfigProxyApi, + flagConfigStreamApi = flagConfigStreamApi, flagConfigStorage = flagConfigStorage, cohortApi = cohortApi, cohortStorage = cohortStorage, @@ -190,6 +197,17 @@ private fun getServerUrl(config: LocalEvaluationConfig): HttpUrl { } } +private fun getStreamServerUrl(config: LocalEvaluationConfig): HttpUrl { + return if (config.streamServerUrl == LocalEvaluationConfig.Defaults.STREAM_SERVER_URL) { + when (config.serverZone) { + ServerZone.US -> US_STREAM_SERVER_URL.toHttpUrl() + ServerZone.EU -> EU_STREAM_SERVER_URL.toHttpUrl() + } + } else { + config.streamServerUrl.toHttpUrl() + } +} + private fun getProxyUrl(config: LocalEvaluationConfig): HttpUrl? { return config.evaluationProxyConfig?.proxyUrl?.toHttpUrl() } diff --git a/src/main/kotlin/LocalEvaluationConfig.kt b/src/main/kotlin/LocalEvaluationConfig.kt index 0475204..9516114 100644 --- a/src/main/kotlin/LocalEvaluationConfig.kt +++ b/src/main/kotlin/LocalEvaluationConfig.kt @@ -22,6 +22,12 @@ class LocalEvaluationConfig internal constructor( @JvmField val flagConfigPollerRequestTimeoutMillis: Long = Defaults.FLAG_CONFIG_POLLER_REQUEST_TIMEOUT_MILLIS, @JvmField + val streamUpdates: Boolean = Defaults.STREAM_UPDATES, + @JvmField + val streamServerUrl: String = Defaults.STREAM_SERVER_URL, + @JvmField + val streamFlagConnTimeoutMillis: Long = Defaults.STREAM_FLAG_CONN_TIMEOUT_MILLIS, + @JvmField val assignmentConfiguration: AssignmentConfiguration? = Defaults.ASSIGNMENT_CONFIGURATION, @JvmField val cohortSyncConfig: CohortSyncConfig? = Defaults.COHORT_SYNC_CONFIGURATION, @@ -76,6 +82,12 @@ class LocalEvaluationConfig internal constructor( */ const val FLAG_CONFIG_POLLER_REQUEST_TIMEOUT_MILLIS = 10_000L + const val STREAM_UPDATES = false + + const val STREAM_SERVER_URL = US_STREAM_SERVER_URL + + const val STREAM_FLAG_CONN_TIMEOUT_MILLIS = 1_500L + /** * null */ @@ -111,6 +123,9 @@ class LocalEvaluationConfig internal constructor( private var serverUrl = Defaults.SERVER_URL private var flagConfigPollerIntervalMillis = Defaults.FLAG_CONFIG_POLLER_INTERVAL_MILLIS private var flagConfigPollerRequestTimeoutMillis = Defaults.FLAG_CONFIG_POLLER_REQUEST_TIMEOUT_MILLIS + private var streamUpdates = Defaults.STREAM_UPDATES + private var streamServerUrl = Defaults.STREAM_SERVER_URL + private var streamFlagConnTimeoutMillis = Defaults.STREAM_FLAG_CONN_TIMEOUT_MILLIS private var assignmentConfiguration = Defaults.ASSIGNMENT_CONFIGURATION private var cohortSyncConfiguration = Defaults.COHORT_SYNC_CONFIGURATION private var evaluationProxyConfiguration = Defaults.EVALUATION_PROXY_CONFIGURATION @@ -136,6 +151,18 @@ class LocalEvaluationConfig internal constructor( this.flagConfigPollerRequestTimeoutMillis = flagConfigPollerRequestTimeoutMillis } + fun streamUpdates(streamUpdates: Boolean) = apply { + this.streamUpdates = streamUpdates + } + + fun streamServerUrl(streamServerUrl: String) = apply { + this.streamServerUrl = streamServerUrl + } + + fun streamFlagConnTimeoutMillis(streamFlagConnTimeoutMillis: Long) = apply { + this.streamFlagConnTimeoutMillis = streamFlagConnTimeoutMillis + } + fun enableAssignmentTracking(assignmentConfiguration: AssignmentConfiguration) = apply { this.assignmentConfiguration = assignmentConfiguration } @@ -161,6 +188,9 @@ class LocalEvaluationConfig internal constructor( serverZone = serverZone, flagConfigPollerIntervalMillis = flagConfigPollerIntervalMillis, flagConfigPollerRequestTimeoutMillis = flagConfigPollerRequestTimeoutMillis, + streamUpdates = streamUpdates, + streamServerUrl = streamServerUrl, + streamFlagConnTimeoutMillis = streamFlagConnTimeoutMillis, assignmentConfiguration = assignmentConfiguration, cohortSyncConfig = cohortSyncConfiguration, evaluationProxyConfig = evaluationProxyConfiguration, @@ -207,6 +237,8 @@ interface LocalEvaluationMetrics { fun onFlagConfigFetch() fun onFlagConfigFetchFailure(exception: Exception) fun onFlagConfigFetchOriginFallback(exception: Exception) + fun onFlagConfigStream() + fun onFlagConfigStreamFailure(exception: Exception?) fun onCohortDownload() fun onCohortDownloadTooLarge(exception: Exception) fun onCohortDownloadFailure(exception: Exception) diff --git a/src/main/kotlin/ServerZone.kt b/src/main/kotlin/ServerZone.kt index 7cd81ed..6767459 100644 --- a/src/main/kotlin/ServerZone.kt +++ b/src/main/kotlin/ServerZone.kt @@ -2,6 +2,8 @@ package com.amplitude.experiment internal const val US_SERVER_URL = "https://api.lab.amplitude.com" internal const val EU_SERVER_URL = "https://api.lab.eu.amplitude.com" +internal const val US_STREAM_SERVER_URL = "https://stream.lab.amplitude.com" +internal const val EU_STREAM_SERVER_URL = "https://stream.lab.eu.amplitude.com" internal const val US_COHORT_SERVER_URL = "https://cohort-v2.lab.amplitude.com" internal const val EU_COHORT_SERVER_URL = "https://cohort-v2.lab.eu.amplitude.com" internal const val US_EVENT_SERVER_URL = "https://api2.amplitude.com/2/httpapi" diff --git a/src/main/kotlin/deployment/DeploymentRunner.kt b/src/main/kotlin/deployment/DeploymentRunner.kt index fcec6fe..5997df8 100644 --- a/src/main/kotlin/deployment/DeploymentRunner.kt +++ b/src/main/kotlin/deployment/DeploymentRunner.kt @@ -1,31 +1,32 @@ -@file:OptIn(ExperimentalApi::class) - package com.amplitude.experiment.deployment -import com.amplitude.experiment.ExperimentalApi import com.amplitude.experiment.LocalEvaluationConfig import com.amplitude.experiment.LocalEvaluationMetrics import com.amplitude.experiment.cohort.CohortApi import com.amplitude.experiment.cohort.CohortLoader import com.amplitude.experiment.cohort.CohortStorage import com.amplitude.experiment.flag.FlagConfigApi +import com.amplitude.experiment.flag.FlagConfigFallbackRetryWrapper +import com.amplitude.experiment.flag.FlagConfigPoller import com.amplitude.experiment.flag.FlagConfigStorage +import com.amplitude.experiment.flag.FlagConfigStreamApi +import com.amplitude.experiment.flag.FlagConfigStreamer import com.amplitude.experiment.util.LocalEvaluationMetricsWrapper import com.amplitude.experiment.util.Logger import com.amplitude.experiment.util.Once import com.amplitude.experiment.util.daemonFactory import com.amplitude.experiment.util.getAllCohortIds -import com.amplitude.experiment.util.wrapMetrics -import java.util.concurrent.CompletableFuture -import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.Executors import java.util.concurrent.TimeUnit private const val MIN_COHORT_POLLING_INTERVAL = 60000L +private const val FLAG_POLLING_JITTER = 1000L internal class DeploymentRunner( private val config: LocalEvaluationConfig, private val flagConfigApi: FlagConfigApi, + private val flagConfigProxyApi: FlagConfigApi? = null, + private val flagConfigStreamApi: FlagConfigStreamApi? = null, private val flagConfigStorage: FlagConfigStorage, cohortApi: CohortApi?, private val cohortStorage: CohortStorage?, @@ -39,21 +40,31 @@ internal class DeploymentRunner( null } private val cohortPollingInterval: Long = getCohortPollingInterval() + // Fallback in this order: proxy, stream, poll. + private val amplitudeFlagConfigPoller = FlagConfigFallbackRetryWrapper( + FlagConfigPoller(flagConfigApi, flagConfigStorage, cohortLoader, cohortStorage, config, metrics), + null, + config.flagConfigPollerIntervalMillis, + ) + private val amplitudeFlagConfigUpdater = + if (flagConfigStreamApi != null) + FlagConfigFallbackRetryWrapper( + FlagConfigStreamer(flagConfigStreamApi, flagConfigStorage, cohortLoader, cohortStorage, metrics), + amplitudeFlagConfigPoller, + FLAG_POLLING_JITTER + ) + else amplitudeFlagConfigPoller + private val flagConfigUpdater = + if (flagConfigProxyApi != null) + FlagConfigFallbackRetryWrapper( + FlagConfigPoller(flagConfigProxyApi, flagConfigStorage, cohortLoader, cohortStorage, config, metrics), + amplitudeFlagConfigPoller + ) + else + amplitudeFlagConfigUpdater fun start() = lock.once { - refresh() - poller.scheduleWithFixedDelay( - { - try { - refresh() - } catch (t: Throwable) { - Logger.e("Refresh flag configs failed.", t) - } - }, - config.flagConfigPollerIntervalMillis, - config.flagConfigPollerIntervalMillis, - TimeUnit.MILLISECONDS - ) + flagConfigUpdater.start() if (cohortLoader != null) { poller.scheduleWithFixedDelay( { @@ -79,63 +90,7 @@ internal class DeploymentRunner( fun stop() { poller.shutdown() - } - - fun refresh() { - Logger.d("Refreshing flag configs.") - // Get updated flags from the network. - val flagConfigs = wrapMetrics( - metric = metrics::onFlagConfigFetch, - failure = metrics::onFlagConfigFetchFailure, - ) { - flagConfigApi.getFlagConfigs() - } - - // Remove flags that no longer exist. - val flagKeys = flagConfigs.map { it.key }.toSet() - flagConfigStorage.removeIf { !flagKeys.contains(it.key) } - - // Get all flags from storage - val storageFlags = flagConfigStorage.getFlagConfigs() - - // Load cohorts for each flag if applicable and put the flag in storage. - val futures = ConcurrentHashMap>() - for (flagConfig in flagConfigs) { - if (cohortLoader == null) { - flagConfigStorage.putFlagConfig(flagConfig) - continue - } - val cohortIds = flagConfig.getAllCohortIds() - val storageCohortIds = storageFlags[flagConfig.key]?.getAllCohortIds() ?: emptySet() - val cohortsToLoad = cohortIds - storageCohortIds - if (cohortsToLoad.isEmpty()) { - flagConfigStorage.putFlagConfig(flagConfig) - continue - } - for (cohortId in cohortsToLoad) { - futures.putIfAbsent( - cohortId, - cohortLoader.loadCohort(cohortId).handle { _, exception -> - if (exception != null) { - Logger.e("Failed to load cohort $cohortId", exception) - } - flagConfigStorage.putFlagConfig(flagConfig) - } - ) - } - } - futures.values.forEach { it.join() } - - // Delete unused cohorts - if (cohortStorage != null) { - val flagCohortIds = flagConfigStorage.getFlagConfigs().values.toList().getAllCohortIds() - val storageCohortIds = cohortStorage.getCohorts().keys - val deletedCohortIds = storageCohortIds - flagCohortIds - for (deletedCohortId in deletedCohortIds) { - cohortStorage.deleteCohort(deletedCohortId) - } - } - Logger.d("Refreshed ${flagConfigs.size} flag configs.") + flagConfigUpdater.shutdown() } private fun getCohortPollingInterval(): Long { diff --git a/src/main/kotlin/flag/FlagConfigStreamApi.kt b/src/main/kotlin/flag/FlagConfigStreamApi.kt new file mode 100644 index 0000000..58b89de --- /dev/null +++ b/src/main/kotlin/flag/FlagConfigStreamApi.kt @@ -0,0 +1,149 @@ +package com.amplitude.experiment.flag + +import com.amplitude.experiment.evaluation.EvaluationFlag +import com.amplitude.experiment.util.SseStream +import com.amplitude.experiment.util.StreamException +import com.amplitude.experiment.util.json +import kotlinx.serialization.decodeFromString +import okhttp3.HttpUrl +import okhttp3.OkHttpClient +import java.util.concurrent.CompletableFuture +import java.util.concurrent.ExecutionException +import java.util.concurrent.TimeUnit +import java.util.concurrent.TimeoutException +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.locks.ReentrantLock +import kotlin.concurrent.withLock + +internal open class FlagConfigStreamApiError(message: String?, cause: Throwable?) : Exception(message, cause) { + constructor(message: String?) : this(message, null) + constructor(cause: Throwable?) : this(cause?.toString(), cause) +} +internal class FlagConfigStreamApiConnTimeoutError : FlagConfigStreamApiError("Initial connection timed out") +internal class FlagConfigStreamApiDataCorruptError : FlagConfigStreamApiError("Stream data corrupted") +internal class FlagConfigStreamApiStreamError(cause: Throwable?) : FlagConfigStreamApiError("Stream error", cause) + +private const val CONNECTION_TIMEOUT_MILLIS_DEFAULT = 1500L +private const val KEEP_ALIVE_TIMEOUT_MILLIS_DEFAULT = 17000L // keep alive sends at 15s interval. 2s grace period +private const val RECONN_INTERVAL_MILLIS_DEFAULT = 15 * 60 * 1000L +internal class FlagConfigStreamApi( + deploymentKey: String, + serverUrl: HttpUrl, + httpClient: OkHttpClient = OkHttpClient(), + val connectionTimeoutMillis: Long = CONNECTION_TIMEOUT_MILLIS_DEFAULT, + keepaliveTimeoutMillis: Long = KEEP_ALIVE_TIMEOUT_MILLIS_DEFAULT, + reconnIntervalMillis: Long = RECONN_INTERVAL_MILLIS_DEFAULT, +) { + private val lock: ReentrantLock = ReentrantLock() + val url = serverUrl.newBuilder().addPathSegments("sdk/stream/v1/flags").build() + private val stream: SseStream = SseStream( + "Api-Key $deploymentKey", + url, + httpClient, + connectionTimeoutMillis, + keepaliveTimeoutMillis, + reconnIntervalMillis + ) + + /** + * Connects to flag configs stream. + * This will ensure stream connects, first set of flags is received and processed successfully, then returns. + * If stream fails to connect, first set of flags is not received, or first set of flags did not process successfully, it throws. + */ + internal fun connect( + onInitUpdate: ((List) -> Unit)? = null, + onUpdate: ((List) -> Unit)? = null, + onError: ((Exception?) -> Unit)? = null + ) { + // Guarded by lock. Update to callbacks and waits can lead to race conditions. + lock.withLock { + val isDuringInit = AtomicBoolean(true) + val connectTimeoutFuture = CompletableFuture() + val updateTimeoutFuture = CompletableFuture() + val onSseUpdate: ((String) -> Unit) = { data -> + if (isDuringInit.getAndSet(false)) { + // Stream is establishing. First data received. + // Resolve timeout. + connectTimeoutFuture.complete(Unit) + + // Make sure valid data. + try { + val flags = getFlagsFromData(data) + + try { + if (onInitUpdate != null) { + onInitUpdate.invoke(flags) + } else { + onUpdate?.invoke(flags) + } + updateTimeoutFuture.complete(Unit) + } catch (e: Throwable) { + updateTimeoutFuture.completeExceptionally(e) + } + } catch (_: Throwable) { + updateTimeoutFuture.completeExceptionally(FlagConfigStreamApiDataCorruptError()) + } + } else { + // Stream has already established. + // Make sure valid data. + try { + val flags = getFlagsFromData(data) + + try { + onUpdate?.invoke(flags) + } catch (_: Throwable) { + // Don't care about application error. + } + } catch (_: Throwable) { + // Stream corrupted. Reconnect. + handleError(onError, FlagConfigStreamApiDataCorruptError()) + } + } + } + val onSseError: ((Throwable?) -> Unit) = { t -> + if (isDuringInit.getAndSet(false)) { + connectTimeoutFuture.completeExceptionally(t) + updateTimeoutFuture.completeExceptionally(t) + } else { + handleError(onError, FlagConfigStreamApiStreamError(t)) + } + } + stream.connect(onSseUpdate, onSseError) + + val t: Throwable + try { + connectTimeoutFuture.get(connectionTimeoutMillis, TimeUnit.MILLISECONDS) + updateTimeoutFuture.get() + return + } catch (e: TimeoutException) { + // Timeouts should retry + t = FlagConfigStreamApiConnTimeoutError() + } catch (e: ExecutionException) { + val cause = e.cause + t = if (cause is StreamException) { + FlagConfigStreamApiStreamError(cause) + } else { + FlagConfigStreamApiError(e) + } + } catch (e: Throwable) { + t = FlagConfigStreamApiError(e) + } + close() + throw t + } + } + + internal fun close() { + // Not guarded by lock. close() can halt connect(). + stream.cancel() + } + + private fun getFlagsFromData(data: String): List { + return json.decodeFromString>(data) + } + + private fun handleError(onError: ((Exception?) -> Unit)?, e: Exception?) { + close() + onError?.invoke(e) + } +} diff --git a/src/main/kotlin/flag/FlagConfigUpdater.kt b/src/main/kotlin/flag/FlagConfigUpdater.kt new file mode 100644 index 0000000..0412a48 --- /dev/null +++ b/src/main/kotlin/flag/FlagConfigUpdater.kt @@ -0,0 +1,325 @@ +package com.amplitude.experiment.flag + +import com.amplitude.experiment.LocalEvaluationConfig +import com.amplitude.experiment.LocalEvaluationMetrics +import com.amplitude.experiment.cohort.CohortLoader +import com.amplitude.experiment.cohort.CohortStorage +import com.amplitude.experiment.evaluation.EvaluationFlag +import com.amplitude.experiment.util.LocalEvaluationMetricsWrapper +import com.amplitude.experiment.util.Logger +import com.amplitude.experiment.util.daemonFactory +import com.amplitude.experiment.util.getAllCohortIds +import com.amplitude.experiment.util.wrapMetrics +import java.util.concurrent.CompletableFuture +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.Executors +import java.util.concurrent.ScheduledFuture +import java.util.concurrent.TimeUnit +import java.util.concurrent.locks.ReentrantLock +import kotlin.concurrent.withLock +import kotlin.math.max +import kotlin.math.min + +/** + * Flag config updaters should receive flags through their own means (ex. http GET, SSE stream), + * or as wrapper of others. + * They all should have these methods to control their lifecycle. + */ +internal interface FlagConfigUpdater { + /** + * Start the updater. There can be multiple calls. + * If start fails, it should throw exception. The caller should handle error. + * If some other error happened while updating (already started successfully), it should call onError. + */ + fun start(onError: (() -> Unit)? = null) + + /** + * Stop should stop updater temporarily. There may be another start in the future. + * To stop completely, with intention to never start again, use shutdown() instead. + */ + fun stop() + + /** + * Destroy should stop the updater forever in preparation for server shutdown. + */ + fun shutdown() +} + +/** + * All flag config updaters should share this class, which contains a function to properly process flag updates. + */ +internal abstract class FlagConfigUpdaterBase( + private val flagConfigStorage: FlagConfigStorage, + private val cohortLoader: CohortLoader?, + private val cohortStorage: CohortStorage?, +) { + /** + * Call this method after receiving and parsing flag configs from network. + * This method updates flag configs into storage and download all cohorts if needed. + */ + protected fun update(flagConfigs: List) { + // Remove flags that no longer exist. + val flagKeys = flagConfigs.map { it.key }.toSet() + flagConfigStorage.removeIf { !flagKeys.contains(it.key) } + + // Get all flags from storage + val storageFlags = flagConfigStorage.getFlagConfigs() + + // Load cohorts for each flag if applicable and put the flag in storage. + val futures = ConcurrentHashMap>() + for (flagConfig in flagConfigs) { + if (cohortLoader == null) { + flagConfigStorage.putFlagConfig(flagConfig) + continue + } + val cohortIds = flagConfig.getAllCohortIds() + val storageCohortIds = storageFlags[flagConfig.key]?.getAllCohortIds() ?: emptySet() + val cohortsToLoad = cohortIds - storageCohortIds + if (cohortsToLoad.isEmpty()) { + flagConfigStorage.putFlagConfig(flagConfig) + continue + } + for (cohortId in cohortsToLoad) { + futures.putIfAbsent( + cohortId, + cohortLoader.loadCohort(cohortId).handle { _, exception -> + if (exception != null) { + Logger.e("Failed to load cohort $cohortId", exception) + } + flagConfigStorage.putFlagConfig(flagConfig) + } + ) + } + } + futures.values.forEach { it.join() } + + // Delete unused cohorts + if (cohortStorage != null) { + val flagCohortIds = flagConfigStorage.getFlagConfigs().values.toList().getAllCohortIds() + val storageCohortIds = cohortStorage.getCohorts().keys + val deletedCohortIds = storageCohortIds - flagCohortIds + for (deletedCohortId in deletedCohortIds) { + cohortStorage.deleteCohort(deletedCohortId) + } + } + Logger.d("Refreshed ${flagConfigs.size} flag configs.") + } +} + +/** + * This is the poller for flag configs. + * It keeps polling flag configs with specified interval until error occurs. + */ +internal class FlagConfigPoller( + private val flagConfigApi: FlagConfigApi, + storage: FlagConfigStorage, + cohortLoader: CohortLoader?, + cohortStorage: CohortStorage?, + private val config: LocalEvaluationConfig, + private val metrics: LocalEvaluationMetrics = LocalEvaluationMetricsWrapper(), +) : FlagConfigUpdater, FlagConfigUpdaterBase( + storage, cohortLoader, cohortStorage +) { + private val lock: ReentrantLock = ReentrantLock() + private val pool = Executors.newScheduledThreadPool(1, daemonFactory) + private var scheduledFuture: ScheduledFuture<*>? = null // @GuardedBy(lock) + + /** + * Start will fetch once, then start poller to poll flag configs. + */ + override fun start(onError: (() -> Unit)?) { + refresh() + lock.withLock { + stopInternal() + scheduledFuture = pool.scheduleWithFixedDelay( + { + try { + refresh() + } catch (t: Throwable) { + Logger.e("Refresh flag configs failed.", t) + stop() + onError?.invoke() + } + }, + config.flagConfigPollerIntervalMillis, + config.flagConfigPollerIntervalMillis, + TimeUnit.MILLISECONDS + ) + } + } + + // @GuardedBy(lock) + private fun stopInternal() { + // Pause only stop the task scheduled. It doesn't stop the executor. + scheduledFuture?.cancel(true) + scheduledFuture = null + } + + override fun stop() { + lock.withLock { + stopInternal() + } + } + + override fun shutdown() { + lock.withLock { + // Stop the executor. + pool.shutdown() + } + } + + private fun refresh() { + Logger.d("Refreshing flag configs.") + // Get updated flags from the network. + val flagConfigs = wrapMetrics( + metric = metrics::onFlagConfigFetch, + failure = metrics::onFlagConfigFetchFailure, + ) { + flagConfigApi.getFlagConfigs() + } + + update(flagConfigs) + } +} + +/** + * Streamer for flag configs. This receives flag updates with an SSE connection. + */ +internal class FlagConfigStreamer( + private val flagConfigStreamApi: FlagConfigStreamApi, + storage: FlagConfigStorage, + cohortLoader: CohortLoader?, + cohortStorage: CohortStorage?, + private val metrics: LocalEvaluationMetrics = LocalEvaluationMetricsWrapper() +) : FlagConfigUpdater, FlagConfigUpdaterBase( + storage, cohortLoader, cohortStorage +) { + private val lock: ReentrantLock = ReentrantLock() + + /** + * Start makes sure it connects to stream and the first set of flag configs is loaded. + * Then, it will update the flags whenever there's a stream. + */ + override fun start(onError: (() -> Unit)?) { + lock.withLock { + val onStreamUpdate: ((List) -> Unit) = { flags -> + update(flags) + } + val onStreamError: ((Exception?) -> Unit) = { e -> + Logger.e("Stream flag configs streaming failed.", e) + metrics.onFlagConfigStreamFailure(e) + onError?.invoke() + } + wrapMetrics(metric = metrics::onFlagConfigStream, failure = metrics::onFlagConfigStreamFailure) { + flagConfigStreamApi.connect(onStreamUpdate, onStreamUpdate, onStreamError) + } + } + } + + override fun stop() { + // Not guarded by lock. close() can cancel start(). + flagConfigStreamApi.close() + } + + override fun shutdown() = stop() +} + +private const val RETRY_DELAY_MILLIS_DEFAULT = 15 * 1000L +private const val MAX_JITTER_MILLIS_DEFAULT = 2000L + +/** + * This is a wrapper class around flag config updaters. + * This provides retry capability in case errors encountered during update asynchronously, as well as fallbacks when an updater failed. + * + * `mainUpdater` cannot be a FlagConfigFallbackRetryWrapper. + * The developer should restructure arguments to make sure `mainUpdater` is never a `FlagConfigFallbackRetryWrapper`. + * All retry and fallback structures can be normalized into `mainUpdater`s not being `FlagConfigFallbackRetryWrapper`s. + */ +internal class FlagConfigFallbackRetryWrapper( + private val mainUpdater: FlagConfigUpdater, + private val fallbackUpdater: FlagConfigUpdater?, + retryDelayMillis: Long = RETRY_DELAY_MILLIS_DEFAULT, + maxJitterMillis: Long = MAX_JITTER_MILLIS_DEFAULT, +) : FlagConfigUpdater { + private val lock: ReentrantLock = ReentrantLock() + private val reconnIntervalRange = max(0, retryDelayMillis - maxJitterMillis)..(min(retryDelayMillis, Long.MAX_VALUE - maxJitterMillis) + maxJitterMillis) + private val executor = Executors.newScheduledThreadPool(1, daemonFactory) + private var retryTask: ScheduledFuture<*>? = null // @GuardedBy(lock) + + /** + * Since the wrapper retries for mainUpdater, so there will never be error case. Thus, onError will never be called. + * + * During start, the wrapper tries to start main updater. + * If main start success, start success. + * If main start failed, fallback updater tries to start. + * If fallback start failed as well, throws exception. + * If fallback start success, start success, main enters retry loop. + * After started, if main failed, fallback is started and main enters retry loop. + * Fallback success or failures status is not monitored. It's suggested to wrap fallback into a retry wrapper. + */ + override fun start(onError: (() -> Unit)?) { + if (mainUpdater is FlagConfigFallbackRetryWrapper) { + throw Error("Do not use FlagConfigFallbackRetryWrapper as main updater. Fallback updater will never be used. Rewrite retry and fallback logic.") + } + + lock.withLock { + retryTask?.cancel(true) + + try { + mainUpdater.start { + lock.withLock { + scheduleRetry() // Don't care if poller start error or not, always retry. + try { + fallbackUpdater?.start() + } catch (_: Throwable) { + } + } + } + fallbackUpdater?.stop() + } catch (t: Throwable) { + Logger.e("Primary flag configs start failed, start fallback. Error: ", t) + if (fallbackUpdater == null) { + // No fallback, main start failed is wrapper start fail + throw t + } + fallbackUpdater.start() + scheduleRetry() + } + } + } + + override fun stop() { + lock.withLock { + mainUpdater.stop() + fallbackUpdater?.stop() + retryTask?.cancel(true) + } + } + + override fun shutdown() { + lock.withLock { + mainUpdater.shutdown() + fallbackUpdater?.shutdown() + retryTask?.cancel(true) + } + } + + // @GuardedBy(lock) + private fun scheduleRetry() { + retryTask = executor.schedule({ + try { + mainUpdater.start { + scheduleRetry() // Don't care if poller start error or not, always retry stream. + try { + fallbackUpdater?.start() + } catch (_: Throwable) { + } + } + fallbackUpdater?.stop() + } catch (_: Throwable) { + scheduleRetry() + } + }, reconnIntervalRange.random(), TimeUnit.MILLISECONDS) + } + } + \ No newline at end of file diff --git a/src/main/kotlin/util/Metrics.kt b/src/main/kotlin/util/Metrics.kt index bc06dbb..d680023 100644 --- a/src/main/kotlin/util/Metrics.kt +++ b/src/main/kotlin/util/Metrics.kt @@ -66,6 +66,16 @@ internal class LocalEvaluationMetricsWrapper( executor?.execute { metrics.onFlagConfigFetchFailure(exception) } } + override fun onFlagConfigStream() { + val metrics = metrics ?: return + executor?.execute { metrics.onFlagConfigStream() } + } + + override fun onFlagConfigStreamFailure(exception: Exception?) { + val metrics = metrics ?: return + executor?.execute { metrics.onFlagConfigStreamFailure(exception) } + } + override fun onFlagConfigFetchOriginFallback(exception: Exception) { val metrics = metrics ?: return executor?.execute { metrics.onFlagConfigFetchOriginFallback(exception) } diff --git a/src/main/kotlin/util/Request.kt b/src/main/kotlin/util/Request.kt index 2bd09fb..a8d9cc1 100644 --- a/src/main/kotlin/util/Request.kt +++ b/src/main/kotlin/util/Request.kt @@ -8,6 +8,9 @@ import okhttp3.HttpUrl import okhttp3.OkHttpClient import okhttp3.Request import okhttp3.Response +import okhttp3.sse.EventSource +import okhttp3.sse.EventSourceListener +import okhttp3.sse.EventSources import okio.IOException import java.util.concurrent.CompletableFuture @@ -60,7 +63,7 @@ private fun OkHttpClient.submit( return future } -private fun newGet( +internal fun newGet( serverUrl: HttpUrl, path: String? = null, headers: Map? = null, @@ -111,3 +114,8 @@ internal inline fun OkHttpClient.get( } } } + +internal fun OkHttpClient.newEventSource(request: Request, eventSourceListener: EventSourceListener): EventSource { + // Creates an event source and immediately returns it. The connection is performed async. + return EventSources.createFactory(this).newEventSource(request, eventSourceListener) +} diff --git a/src/main/kotlin/util/SseStream.kt b/src/main/kotlin/util/SseStream.kt new file mode 100644 index 0000000..0bd7f94 --- /dev/null +++ b/src/main/kotlin/util/SseStream.kt @@ -0,0 +1,158 @@ +package com.amplitude.experiment.util + +import okhttp3.HttpUrl +import okhttp3.OkHttpClient +import okhttp3.Response +import okhttp3.internal.http2.ErrorCode +import okhttp3.internal.http2.StreamResetException +import okhttp3.sse.EventSource +import okhttp3.sse.EventSourceListener +import java.util.Timer +import java.util.TimerTask +import java.util.concurrent.TimeUnit +import java.util.concurrent.locks.ReentrantLock +import kotlin.concurrent.schedule +import kotlin.concurrent.withLock +import kotlin.math.max +import kotlin.math.min + +internal class StreamException(error: String) : Throwable(error) + +private const val KEEP_ALIVE_TIMEOUT_MILLIS_DEFAULT = 0L // no timeout +private const val RECONN_INTERVAL_MILLIS_DEFAULT = 30 * 60 * 1000L +private const val MAX_JITTER_MILLIS_DEFAULT = 5000L +private const val KEEP_ALIVE_DATA = " " + +/** + * For establishing an SSE stream. + */ +internal class SseStream( + authToken: String, // Will be used in header as Authorization: + url: HttpUrl, // The full url to connect to. + httpClient: OkHttpClient = OkHttpClient(), + connectionTimeoutMillis: Long, // Timeout for establishing a connection, not including reading body. + keepaliveTimeoutMillis: Long = KEEP_ALIVE_TIMEOUT_MILLIS_DEFAULT, // Keep alive should receive within this timeout. + reconnIntervalMillis: Long = RECONN_INTERVAL_MILLIS_DEFAULT, // Reconnect every this interval. + maxJitterMillis: Long = MAX_JITTER_MILLIS_DEFAULT, // Jitter for reconnection. +) { + private val lock: ReentrantLock = ReentrantLock() + private val reconnIntervalRange = max(0, reconnIntervalMillis - maxJitterMillis)..(min(reconnIntervalMillis, Long.MAX_VALUE - maxJitterMillis) + maxJitterMillis) + + private val request = newGet(url, null, mapOf("Authorization" to authToken, "Accept" to "text/event-stream")) + private val client = httpClient.newBuilder() // client.newBuilder reuses the connection pool in the same client with new configs. + .connectTimeout(connectionTimeoutMillis, TimeUnit.MILLISECONDS) // Connection timeout for establishing SSE. + .callTimeout(connectionTimeoutMillis, TimeUnit.MILLISECONDS) // Call timeout for establishing SSE. + .readTimeout(keepaliveTimeoutMillis, TimeUnit.MILLISECONDS) // Timeout between messages, keepalive in this case. + .writeTimeout(connectionTimeoutMillis, TimeUnit.MILLISECONDS) + .retryOnConnectionFailure(false) + .build() + + private var es: EventSource? = null // @GuardedBy(lock) + private var reconnectTimerTask: TimerTask? = null // @GuardedBy(lock) + private var onUpdate: ((String) -> Unit)? = null + private var onError: ((Throwable?) -> Unit)? = null + + private val eventSourceListener = object : EventSourceListener() { + override fun onOpen(eventSource: EventSource, response: Response) { + // No action needed. + } + + override fun onClosed(eventSource: EventSource) { + lock.withLock { + if ((eventSource != es)) { // Reference comparison. + // Not the current event source using right now, should cancel. + eventSource.cancel() + return + } + // Server closed the connection, just reconnect. + cancelSse() + } + connect(onUpdate, onError) + } + + override fun onEvent( + eventSource: EventSource, + id: String?, + type: String?, + data: String + ) { + lock.withLock { + if ((eventSource != es)) { // Reference comparison. + // Not the current event source using right now, should cancel. + eventSource.cancel() + return + } + } + // Keep alive data + if (KEEP_ALIVE_DATA == data) { + return + } + onUpdate?.invoke(data) + } + + override fun onFailure(eventSource: EventSource, t: Throwable?, response: Response?) { + lock.withLock { + if ((eventSource != es)) { // Reference comparison. + // Not the current event source using right now, should cancel. + eventSource.cancel() + return + } + if (t is StreamResetException && t.errorCode == ErrorCode.CANCEL) { + // Relying on okhttp3.internal to differentiate cancel case. + // Can be a pitfall later on. + return + } + cancelSse() + } + val err = t + ?: if (response != null) { + StreamException(response.toString()) + } else { + StreamException("Unknown stream failure") + } + onError?.invoke(err) + } + } + + /** + * Creates an event source and immediately returns. The connection is performed async. All errors are informed through callbacks. + * This the call may success even if stream cannot be established. + * + * @param onUpdate On stream update, this callback will be called. + * @param onError On stream error, this callback will be called. + */ + internal fun connect(onUpdate: ((String) -> Unit)?, onError: ((Throwable?) -> Unit)?) { + lock.withLock { + cancelSse() // Clear any existing event sources. + + this.onUpdate = onUpdate + this.onError = onError + es = client.newEventSource(request, eventSourceListener) + reconnectTimerTask = Timer().schedule(reconnIntervalRange.random()) { // Timer for a new event source. + // This forces client side reconnection after interval. + this@SseStream.cancel() + connect(onUpdate, onError) + } + } + } + + // @GuardedBy(lock) + private fun cancelSse() { + reconnectTimerTask?.cancel() + + // There can be cases where an event source is being cancelled by these calls, but take a long time and made a callback to onFailure callback. + es?.cancel() + es = null + } + + /** + * Cancels the current connection. + */ + internal fun cancel() { + lock.withLock { + cancelSse() + this.onUpdate = null + this.onError = null + } + } +} diff --git a/src/test/kotlin/LocalEvaluationClientTest.kt b/src/test/kotlin/LocalEvaluationClientTest.kt index 7ad24a0..9f612da 100644 --- a/src/test/kotlin/LocalEvaluationClientTest.kt +++ b/src/test/kotlin/LocalEvaluationClientTest.kt @@ -2,17 +2,25 @@ package com.amplitude.experiment import com.amplitude.experiment.cohort.Cohort import com.amplitude.experiment.cohort.CohortApi +import com.amplitude.experiment.flag.FlagConfigPoller +import io.mockk.clearAllMocks import io.mockk.every import io.mockk.mockk +import io.mockk.mockkConstructor import org.junit.Assert import org.junit.Assert.assertEquals import org.junit.Assert.assertNull import kotlin.system.measureNanoTime +import kotlin.test.AfterTest import kotlin.test.Test private const val API_KEY = "server-qz35UwzJ5akieoAdIgzM4m9MIiOLXLoz" class LocalEvaluationClientTest { + @AfterTest + fun afterTest() { + clearAllMocks() + } @Test fun `test evaluate, all flags, success`() { @@ -193,6 +201,7 @@ class LocalEvaluationClientTest { assertEquals("on", userVariant?.key) assertEquals("on", userVariant?.value) } + @Test fun `evaluate with user, cohort tester targeted`() { val cohortConfig = LocalEvaluationConfig( @@ -238,6 +247,7 @@ class LocalEvaluationClientTest { assertEquals("on", groupVariant?.key) assertEquals("on", groupVariant?.value) } + @Test fun `evaluate with group, cohort tester targeted`() { val cohortConfig = LocalEvaluationConfig( @@ -261,4 +271,44 @@ class LocalEvaluationClientTest { assertEquals("on", groupVariant?.key) assertEquals("on", groupVariant?.value) } + + @Test + fun `test evaluate, stream flags, all flags, success`() { + mockkConstructor(FlagConfigPoller::class) + every { anyConstructed().start(any()) } answers { + throw Exception("Should use stream, may be flaky test when stream failed") + } + val client = LocalEvaluationClient(API_KEY, LocalEvaluationConfig(streamUpdates = true)) + client.start() + val variants = client.evaluate(ExperimentUser(userId = "test_user")) + val variant = variants["sdk-local-evaluation-ci-test"] + Assert.assertEquals(Variant(key = "on", value = "on", payload = "payload"), variant?.copy(metadata = null)) + } + + @Test + fun `evaluate with user, stream flags, cohort segment targeted`() { + mockkConstructor(FlagConfigPoller::class) + every { anyConstructed().start(any()) } answers { + throw Exception("Should use stream, may be flaky test when stream failed") + } + val cohortConfig = LocalEvaluationConfig( + streamUpdates = true, + cohortSyncConfig = CohortSyncConfig("api", "secret") + ) + val cohortApi = mockk().apply { + every { getCohort(eq("52gz3yi7"), allAny()) } returns Cohort("52gz3yi7", "User", 2, 1722363790000, setOf("1", "2")) + every { getCohort(eq("mv7fn2bp"), allAny()) } returns Cohort("mv7fn2bp", "User", 1, 1719350216000, setOf("67890", "12345")) + every { getCohort(eq("s4t57y32"), allAny()) } returns Cohort("s4t57y32", "org name", 1, 1722368285000, setOf("Amplitude Website (Portfolio)")) + every { getCohort(eq("k1lklnnb"), allAny()) } returns Cohort("k1lklnnb", "org id", 1, 1722466388000, setOf("1")) + } + val client = LocalEvaluationClient(API_KEY, cohortConfig, cohortApi = cohortApi) + client.start() + val user = ExperimentUser( + userId = "12345", + deviceId = "device_id", + ) + val userVariant = client.evaluateV2(user, setOf("sdk-local-evaluation-user-cohort-ci-test"))["sdk-local-evaluation-user-cohort-ci-test"] + assertEquals("on", userVariant?.key) + assertEquals("on", userVariant?.value) + } } diff --git a/src/test/kotlin/deployment/DeploymentRunnerTest.kt b/src/test/kotlin/deployment/DeploymentRunnerTest.kt index 8e58a79..9b74742 100644 --- a/src/test/kotlin/deployment/DeploymentRunnerTest.kt +++ b/src/test/kotlin/deployment/DeploymentRunnerTest.kt @@ -42,11 +42,11 @@ class DeploymentRunnerTest { val flagConfigStorage = Mockito.mock(FlagConfigStorage::class.java) val cohortStorage = Mockito.mock(CohortStorage::class.java) val runner = DeploymentRunner( - LocalEvaluationConfig(), - flagApi, - flagConfigStorage, - cohortApi, - cohortStorage, + config = LocalEvaluationConfig(), + flagConfigApi = flagApi, + flagConfigStorage = flagConfigStorage, + cohortApi = cohortApi, + cohortStorage = cohortStorage, ) Mockito.`when`(flagApi.getFlagConfigs()).thenThrow(RuntimeException("test")) try { @@ -71,10 +71,11 @@ class DeploymentRunnerTest { val flagConfigStorage = Mockito.mock(FlagConfigStorage::class.java) val cohortStorage = Mockito.mock(CohortStorage::class.java) val runner = DeploymentRunner( - LocalEvaluationConfig(), - flagApi, flagConfigStorage, - cohortApi, - cohortStorage, + config = LocalEvaluationConfig(), + flagConfigApi = flagApi, + flagConfigStorage = flagConfigStorage, + cohortApi = cohortApi, + cohortStorage = cohortStorage, ) Mockito.`when`(flagApi.getFlagConfigs()).thenReturn(listOf(flag)) Mockito.`when`(cohortApi.getCohort(COHORT_ID, null)).thenThrow(RuntimeException("test")) diff --git a/src/test/kotlin/flag/FlagConfigStreamApiTest.kt b/src/test/kotlin/flag/FlagConfigStreamApiTest.kt new file mode 100644 index 0000000..2d553bc --- /dev/null +++ b/src/test/kotlin/flag/FlagConfigStreamApiTest.kt @@ -0,0 +1,167 @@ +package com.amplitude.experiment.flag + +import com.amplitude.experiment.Experiment +import com.amplitude.experiment.evaluation.EvaluationFlag +import com.amplitude.experiment.util.SseStream +import io.mockk.clearAllMocks +import io.mockk.every +import io.mockk.mockkConstructor +import io.mockk.slot +import io.mockk.verify +import okhttp3.HttpUrl +import okhttp3.HttpUrl.Companion.toHttpUrl +import okhttp3.OkHttpClient +import java.util.concurrent.CompletableFuture +import java.util.concurrent.TimeUnit +import kotlin.test.AfterTest +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.assertContentEquals +import kotlin.test.assertEquals +import kotlin.test.fail + +class FlagConfigStreamApiTest { + private val onUpdateCapture = slot<((String) -> Unit)>() + private val onErrorCapture = slot<((Throwable?) -> Unit)>() + + @BeforeTest + fun beforeTest() { + mockkConstructor(SseStream::class) + + every { anyConstructed().connect(capture(onUpdateCapture), capture(onErrorCapture)) } answers { + Thread.sleep(1000) + } + every { anyConstructed().cancel() } answers { + Thread.sleep(1000) + } + } + + private fun anyConstructed(): Any { + TODO("Not yet implemented") + } + + @AfterTest + fun afterTest() { + clearAllMocks() + } + + private fun setupApi( + deploymentKey: String = "", + serverUrl: HttpUrl = "http://localhost".toHttpUrl(), + connTimeout: Long = 2000 + ): FlagConfigStreamApi { + val api = FlagConfigStreamApi(deploymentKey, serverUrl, OkHttpClient(), connTimeout, 10000) + + return api + } + + @Test + fun `Test passes correct arguments`() { + val api = setupApi("deplkey", "https://test.example.com".toHttpUrl()) + var data: Array> = arrayOf() + var err: Array = arrayOf() + + val run = async { + api.connect({ d -> + data += d + }, { d -> + data += d + }, { t -> + err += t + }) + } + Thread.sleep(100) + onUpdateCapture.captured("[{\"key\":\"flagkey\",\"variants\":{},\"segments\":[]}]") + run.join() + + verify { anyConstructed().connect(any(), any()) } + assertContentEquals(arrayOf(listOf(EvaluationFlag("flagkey", emptyMap(), emptyList()))), data) + + api.close() + } + + @Test + fun `Test conn timeout doesn't block`() { + val api = setupApi("deplkey", "https://test.example.com".toHttpUrl(), 500) + try { + api.connect() + fail("Timeout not thrown") + } catch (_: FlagConfigStreamApiConnTimeoutError) { + } + Thread.sleep(100) + verify { anyConstructed().cancel() } + } + + @Test + fun `Test init update failure throws`() { + val api = setupApi("deplkey", "https://test.example.com".toHttpUrl(), 2000) + + try { + api.connect({ + Thread.sleep(2100) // Update time is not included in connection timeout. + throw Error() + }) + fail("Timeout not thrown") + } catch (_: FlagConfigStreamApiConnTimeoutError) { + } + verify { anyConstructed().cancel() } + } + + @Test + fun `Test init update fallbacks to onUpdate when onInitUpdate = null`() { + val api = setupApi("deplkey", "https://test.example.com".toHttpUrl(), 2000) + + try { + api.connect(null, { + Thread.sleep(2100) // Update time is not included in connection timeout. + throw Error() + }) + fail("Timeout not thrown") + } catch (_: FlagConfigStreamApiConnTimeoutError) { + } + verify { anyConstructed().cancel() } + } + + @Test + fun `Test error is passed through onError`() { + val api = setupApi("deplkey", "https://test.example.com".toHttpUrl(), 2000) + var err: Array = arrayOf() + + val run = async { + api.connect({ d -> + assertEquals(listOf(), d) + }, { d -> + assertEquals(listOf(), d) + }, { t -> + err += t + }) + } + Thread.sleep(100) + onUpdateCapture.captured("[]") + run.join() + + assertEquals(0, err.size) + onErrorCapture.captured(Error("Haha error")) + assertEquals("Stream error", err[0]?.message) + assertEquals("Haha error", err[0]?.cause?.message) + assertEquals(1, err.size) + verify { anyConstructed().cancel() } + } +} + +@Suppress("SameParameterValue") +private fun async(delayMillis: Long = 0L, block: () -> T): CompletableFuture { + return if (delayMillis == 0L) { + CompletableFuture.supplyAsync(block) + } else { + val future = CompletableFuture() + Experiment.scheduler.schedule({ + try { + future.complete(block.invoke()) + } catch (t: Throwable) { + future.completeExceptionally(t) + } + }, delayMillis, TimeUnit.MILLISECONDS) + future + } +} diff --git a/src/test/kotlin/flag/FlagConfigUpdaterTest.kt b/src/test/kotlin/flag/FlagConfigUpdaterTest.kt new file mode 100644 index 0000000..a0db63c --- /dev/null +++ b/src/test/kotlin/flag/FlagConfigUpdaterTest.kt @@ -0,0 +1,409 @@ +package com.amplitude.experiment.flag + +import com.amplitude.experiment.LocalEvaluationConfig +import com.amplitude.experiment.evaluation.EvaluationFlag +import io.mockk.clearAllMocks +import io.mockk.every +import io.mockk.justRun +import io.mockk.mockk +import io.mockk.slot +import io.mockk.verify +import org.junit.Assert.assertEquals +import java.lang.Exception +import kotlin.test.AfterTest +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.fail + +private val FLAG1 = EvaluationFlag("key1", emptyMap(), emptyList()) +private val FLAG2 = EvaluationFlag("key2", emptyMap(), emptyList()) +class FlagConfigPollerTest { + private var fetchApi = mockk() + private var storage = InMemoryFlagConfigStorage() + + @BeforeTest + fun beforeTest() { + fetchApi = mockk() + storage = InMemoryFlagConfigStorage() + } + + @AfterTest + fun afterTest() { + clearAllMocks() + } + + @Test + fun `Test Poller`() { + every { fetchApi.getFlagConfigs() } returns emptyList() + val poller = FlagConfigPoller(fetchApi, storage, null, null, LocalEvaluationConfig(flagConfigPollerIntervalMillis = 1000)) + var errorCount = 0 + poller.start { errorCount++ } + + // start() polls + verify(exactly = 1) { fetchApi.getFlagConfigs() } + assertEquals(0, storage.getFlagConfigs().size) + + // Poller polls every 1s interval + every { fetchApi.getFlagConfigs() } returns listOf(FLAG1) + Thread.sleep(1100) + verify(exactly = 2) { fetchApi.getFlagConfigs() } + assertEquals(1, storage.getFlagConfigs().size) + assertEquals(mapOf(FLAG1.key to FLAG1), storage.getFlagConfigs()) + Thread.sleep(1100) + verify(exactly = 3) { fetchApi.getFlagConfigs() } + + // Stop poller stops + poller.stop() + Thread.sleep(1100) + verify(exactly = 3) { fetchApi.getFlagConfigs() } + + // Restart poller + every { fetchApi.getFlagConfigs() } returns listOf(FLAG1, FLAG2) + poller.start() + verify(exactly = 4) { fetchApi.getFlagConfigs() } + Thread.sleep(1100) + verify(exactly = 5) { fetchApi.getFlagConfigs() } + assertEquals(2, storage.getFlagConfigs().size) + assertEquals(mapOf(FLAG1.key to FLAG1, FLAG2.key to FLAG2), storage.getFlagConfigs()) + + // No errors + assertEquals(0, errorCount) + + poller.shutdown() + } + + @Test + fun `Test Poller start fails`() { + every { fetchApi.getFlagConfigs() } answers { throw Error("Haha error") } + val poller = FlagConfigPoller(fetchApi, storage, null, null, LocalEvaluationConfig(flagConfigPollerIntervalMillis = 1000)) + var errorCount = 0 + try { + poller.start { errorCount++ } + fail("Poller start error not throwing") + } catch (_: Throwable) { + } + verify(exactly = 1) { fetchApi.getFlagConfigs() } + + // Poller stops + Thread.sleep(1100) + verify(exactly = 1) { fetchApi.getFlagConfigs() } + assertEquals(0, errorCount) + + poller.shutdown() + } + + @Test + fun `Test Poller poll fails`() { + every { fetchApi.getFlagConfigs() } returns emptyList() + val poller = FlagConfigPoller(fetchApi, storage, null, null, LocalEvaluationConfig(flagConfigPollerIntervalMillis = 1000)) + var errorCount = 0 + poller.start { errorCount++ } + + // Poller start success + verify(exactly = 1) { fetchApi.getFlagConfigs() } + assertEquals(0, errorCount) + + // Next poll subsequent fails + every { fetchApi.getFlagConfigs() } answers { throw Error("Haha error") } + Thread.sleep(1100) + verify(exactly = 2) { fetchApi.getFlagConfigs() } + assertEquals(1, errorCount) + + // Poller stops + Thread.sleep(1100) + verify(exactly = 2) { fetchApi.getFlagConfigs() } + assertEquals(1, errorCount) + + poller.shutdown() + } +} + +class FlagConfigStreamerTest { +// private val onInitUpdateCapture = slot<((List) -> Unit)>() + private val onUpdateCapture = slot<((List) -> Unit)>() + private val onErrorCapture = slot<((Throwable?) -> Unit)>() + private var streamApi = mockk() + private var storage = InMemoryFlagConfigStorage() + + @BeforeTest + fun beforeTest() { + streamApi = mockk() + storage = InMemoryFlagConfigStorage() + } + + @AfterTest + fun afterTest() { + clearAllMocks() + } + + @Test + fun `Test Poller`() { + justRun { streamApi.connect(capture(onUpdateCapture), capture(onUpdateCapture), capture(onErrorCapture)) } + val streamer = FlagConfigStreamer(streamApi, storage, null, null) + var errorCount = 0 + streamer.start { errorCount++ } + + // Streamer starts + verify(exactly = 1) { streamApi.connect(capture(onUpdateCapture), capture(onUpdateCapture), capture(onErrorCapture)) } + + // Verify update callback updates storage + onUpdateCapture.captured(emptyList()) + assertEquals(0, storage.getFlagConfigs().size) + onUpdateCapture.captured(listOf(FLAG1)) + assertEquals(mapOf(FLAG1.key to FLAG1), storage.getFlagConfigs()) + onUpdateCapture.captured(listOf(FLAG1, FLAG2)) + assertEquals(mapOf(FLAG1.key to FLAG1, FLAG2.key to FLAG2), storage.getFlagConfigs()) + + // No extra connect calls + verify(exactly = 1) { streamApi.connect(capture(onUpdateCapture), capture(onUpdateCapture), capture(onErrorCapture)) } + + // No errors + assertEquals(0, errorCount) + } + + @Test + fun `Test Streamer start fails`() { + every { streamApi.connect(capture(onUpdateCapture), capture(onUpdateCapture), capture(onErrorCapture)) } answers { throw Error("Haha error") } + val streamer = FlagConfigStreamer(streamApi, storage, null, null) + var errorCount = 0 + try { + streamer.start { errorCount++ } + fail("Streamer start error not throwing") + } catch (_: Throwable) { + } + verify(exactly = 1) { streamApi.connect(capture(onUpdateCapture), capture(onUpdateCapture), capture(onErrorCapture)) } + assertEquals(0, errorCount) // No error callback as it throws directly + } + + @Test + fun `Test Streamer stream fails`() { + justRun { streamApi.connect(capture(onUpdateCapture), capture(onUpdateCapture), capture(onErrorCapture)) } + val streamer = FlagConfigStreamer(streamApi, storage, null, null) + var errorCount = 0 + streamer.start { errorCount++ } + + // Stream start success + verify(exactly = 1) { streamApi.connect(capture(onUpdateCapture), capture(onUpdateCapture), capture(onErrorCapture)) } + onUpdateCapture.captured(listOf(FLAG1)) + assertEquals(mapOf(FLAG1.key to FLAG1), storage.getFlagConfigs()) + assertEquals(0, errorCount) + + // Stream fails + onErrorCapture.captured(Exception("Haha error")) + assertEquals(1, errorCount) // Error callback is called + } +} + +class FlagConfigFallbackRetryWrapperTest { + private val mainOnErrorCapture = slot<(() -> Unit)>() + + private var mainUpdater = mockk() + private var fallbackUpdater = mockk() + + @BeforeTest + fun beforeTest() { + mainUpdater = mockk() + fallbackUpdater = mockk() + + justRun { mainUpdater.start(capture(mainOnErrorCapture)) } + justRun { mainUpdater.stop() } + justRun { mainUpdater.shutdown() } + justRun { fallbackUpdater.start() } // Fallback is never passed onError callback, no need to capture + justRun { fallbackUpdater.stop() } + justRun { fallbackUpdater.shutdown() } + } + + @AfterTest + fun afterTest() { + clearAllMocks() + } + + @Test + fun `Test FallbackRetryWrapper main success no fallback updater`() { + val wrapper = FlagConfigFallbackRetryWrapper(mainUpdater, null, 1000, 0) + + // Main starts + wrapper.start() + verify(exactly = 1) { mainUpdater.start(any()) } + + // Stop + wrapper.stop() + verify(exactly = 1) { mainUpdater.stop() } + + // Start again + wrapper.start() + verify(exactly = 2) { mainUpdater.start(any()) } + + // Shutdown + wrapper.shutdown() + verify(exactly = 1) { mainUpdater.shutdown() } + } + + @Test + fun `Test FallbackRetryWrapper main start error and retries with no fallback updater`() { + val wrapper = FlagConfigFallbackRetryWrapper(mainUpdater, null, 1000, 0) + + every { mainUpdater.start(capture(mainOnErrorCapture)) } answers { throw Error() } + + // Main start fail, no error, same as success case + try { + wrapper.start() + fail("Start errors should throw") + } catch (_: Throwable) {} + verify(exactly = 1) { mainUpdater.start(any()) } + + // Start errors no retry + Thread.sleep(1100) + verify(exactly = 1) { mainUpdater.start(any()) } + + wrapper.shutdown() + } + + @Test + fun `Test FallbackRetryWrapper main error callback and retries with no fallback updater`() { + val wrapper = FlagConfigFallbackRetryWrapper(mainUpdater, null, 1000, 0) + + // Main start success + wrapper.start() + verify(exactly = 1) { mainUpdater.start(any()) } + + // Signal error + mainOnErrorCapture.captured() + verify(exactly = 1) { mainUpdater.start(any()) } + + // Retry fail after 1s + every { mainUpdater.start(capture(mainOnErrorCapture)) } answers { throw Error() } + Thread.sleep(1100) + verify(exactly = 2) { mainUpdater.start(any()) } + + // Retry success after 1s + justRun { mainUpdater.start(capture(mainOnErrorCapture)) } + Thread.sleep(1100) + verify(exactly = 3) { mainUpdater.start(any()) } + + // No more start + Thread.sleep(1100) + verify(exactly = 3) { mainUpdater.start(any()) } + + wrapper.shutdown() + } + + @Test + fun `Test FallbackRetryWrapper main updater all success`() { + val wrapper = FlagConfigFallbackRetryWrapper(mainUpdater, fallbackUpdater, 1000, 0) + + // Main starts + wrapper.start() + verify(exactly = 1) { mainUpdater.start(any()) } + verify(exactly = 0) { fallbackUpdater.start() } + verify(exactly = 1) { fallbackUpdater.stop() } + + // Stop + wrapper.stop() + verify(exactly = 1) { mainUpdater.stop() } + verify(exactly = 2) { fallbackUpdater.stop() } + + // Start again + wrapper.start() + verify(exactly = 2) { mainUpdater.start(any()) } + verify(exactly = 0) { fallbackUpdater.start() } + verify(exactly = 3) { fallbackUpdater.stop() } + + // Shutdown + wrapper.shutdown() + verify(exactly = 1) { mainUpdater.shutdown() } + verify(exactly = 1) { mainUpdater.shutdown() } + } + + @Test + fun `Test FallbackRetryWrapper main and fallback start error`() { + val wrapper = FlagConfigFallbackRetryWrapper(mainUpdater, fallbackUpdater, 1000, 0) + + every { mainUpdater.start(capture(mainOnErrorCapture)) } answers { throw Error() } + every { fallbackUpdater.start() } answers { throw Error() } + + // Main start fail, no error, same as success case + try { + wrapper.start() + fail("Start errors should throw") + } catch (_: Throwable) {} + verify(exactly = 1) { mainUpdater.start(any()) } + verify(exactly = 1) { fallbackUpdater.start(any()) } + + // Start errors no retry + Thread.sleep(1100) + verify(exactly = 1) { mainUpdater.start(any()) } + verify(exactly = 1) { fallbackUpdater.start(any()) } + + wrapper.shutdown() + } + + @Test + fun `Test FallbackRetryWrapper main start error and retries`() { + val wrapper = FlagConfigFallbackRetryWrapper(mainUpdater, fallbackUpdater, 1000, 0) + + every { mainUpdater.start(capture(mainOnErrorCapture)) } answers { throw Error() } + + // Main start fail, no error, same as success case + wrapper.start() + verify(exactly = 1) { mainUpdater.start(any()) } + verify(exactly = 1) { fallbackUpdater.start(any()) } + + // Retries start + Thread.sleep(1100) + verify(exactly = 2) { mainUpdater.start(any()) } + verify(exactly = 1) { fallbackUpdater.start(any()) } + + wrapper.shutdown() + } + + @Test + fun `Test FallbackRetryWrapper main error callback and retries`() { + val wrapper = FlagConfigFallbackRetryWrapper(mainUpdater, fallbackUpdater, 1000, 0) + + // Main start success + wrapper.start() + verify(exactly = 1) { mainUpdater.start(any()) } + verify(exactly = 0) { fallbackUpdater.start(any()) } + + // Signal error + mainOnErrorCapture.captured() + verify(exactly = 1) { mainUpdater.start(any()) } + verify(exactly = 1) { fallbackUpdater.start(any()) } + + // Retry fail after 1s + every { mainUpdater.start(capture(mainOnErrorCapture)) } answers { throw Error() } + Thread.sleep(1100) + verify(exactly = 2) { mainUpdater.start(any()) } + verify(exactly = 1) { fallbackUpdater.start(any()) } + + // Retry success + justRun { mainUpdater.start(capture(mainOnErrorCapture)) } + verify(exactly = 1) { fallbackUpdater.stop() } + Thread.sleep(1100) + verify(exactly = 3) { mainUpdater.start(any()) } + verify(exactly = 1) { fallbackUpdater.start(any()) } + verify(exactly = 0) { mainUpdater.stop() } + verify(exactly = 2) { fallbackUpdater.stop() } + + // No more start + Thread.sleep(1100) + verify(exactly = 3) { mainUpdater.start(any()) } + verify(exactly = 1) { fallbackUpdater.start(any()) } + verify(exactly = 0) { mainUpdater.stop() } + verify(exactly = 2) { fallbackUpdater.stop() } + + wrapper.shutdown() + } + + @Test + fun `Test FallbackRetryWrapper main updater cannot be FlagConfigFallbackRetryWrapper`() { + val wrapper = FlagConfigFallbackRetryWrapper(FlagConfigFallbackRetryWrapper(mainUpdater, null), null, 1000, 0) + try { + wrapper.start() + fail("Did not throw") + } catch (_: Throwable) { + } + verify(exactly = 0) { mainUpdater.start() } + } +} diff --git a/src/test/kotlin/util/SseStreamTest.kt b/src/test/kotlin/util/SseStreamTest.kt new file mode 100644 index 0000000..58f13cd --- /dev/null +++ b/src/test/kotlin/util/SseStreamTest.kt @@ -0,0 +1,102 @@ +package com.amplitude.experiment.util + +import io.mockk.clearAllMocks +import io.mockk.every +import io.mockk.justRun +import io.mockk.mockk +import io.mockk.mockkConstructor +import io.mockk.mockkStatic +import io.mockk.slot +import io.mockk.verify +import okhttp3.HttpUrl.Companion.toHttpUrl +import okhttp3.OkHttpClient +import okhttp3.sse.EventSource +import okhttp3.sse.EventSourceListener +import kotlin.test.AfterTest +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.assertEquals + +class SseStreamTest { + private val listenerCapture = slot() + private val clientMock = mockk() + private val es = mockk("mocked es") + + private var data: List = listOf() + private var err: List = listOf() + + @BeforeTest + fun beforeTest() { + mockkStatic("com.amplitude.experiment.util.RequestKt") + + justRun { es.cancel() } + every { clientMock.newEventSource(any(), capture(listenerCapture)) } returns es + + mockkConstructor(OkHttpClient.Builder::class) + every { anyConstructed().build() } returns clientMock + } + + @AfterTest + fun afterTest() { + clearAllMocks() + } + + private fun setupAndConnectStream( + reconnTimeout: Long = 5000 + ): SseStream { + val stream = SseStream("authtoken", "http://localhost".toHttpUrl(), OkHttpClient(), 1000, 1000, reconnTimeout, 0) + + stream.connect({ d -> + data += d + }, { t -> + err += t + }) + + return stream + } + + @Test + fun `Test SseStream connect`() { + val stream = setupAndConnectStream() + + listenerCapture.captured.onEvent(es, null, null, "somedata") + assertEquals(listOf("somedata"), data) + listenerCapture.captured.onFailure(es, null, null) + assertEquals("Unknown stream failure", err[0]?.message) + listenerCapture.captured.onEvent(es, null, null, "nodata") + assertEquals(listOf("somedata"), data) + + stream.cancel() + } + + @Test + fun `Test SseStream keep alive data omits`() { + val stream = setupAndConnectStream(1000) + + listenerCapture.captured.onEvent(es, null, null, "somedata") + assertEquals(listOf("somedata"), data) + listenerCapture.captured.onEvent(es, null, null, " ") + assertEquals(listOf("somedata"), data) + + stream.cancel() + } + + @Test + fun `Test SseStream reconnects`() { + val stream = setupAndConnectStream(1000) + + listenerCapture.captured.onEvent(es, null, null, "somedata") + assertEquals(listOf("somedata"), data) + verify(exactly = 1) { + clientMock.newEventSource(allAny(), allAny()) + } + + Thread.sleep(1100) // Wait 1s for reconnection + + listenerCapture.captured.onEvent(es, null, null, "somedata") + assertEquals(listOf("somedata", "somedata"), data) + verify(exactly = 2) { + clientMock.newEventSource(allAny(), allAny()) + } + } +}