From 49636d89642e9203a6355b162a1b3a7fc73e8339 Mon Sep 17 00:00:00 2001 From: Peter Zhu Date: Fri, 23 Aug 2024 16:30:15 -0700 Subject: [PATCH 01/12] added flag push --- build.gradle.kts | 1 + buildSrc/src/main/kotlin/Versions.kt | 1 + src/main/kotlin/LocalEvaluationClient.kt | 15 +- src/main/kotlin/LocalEvaluationConfig.kt | 2 + .../kotlin/deployment/DeploymentRunner.kt | 100 ++------ src/main/kotlin/flag/FlagConfigStreamApi.kt | 134 ++++++++++ src/main/kotlin/flag/FlagConfigUpdater.kt | 233 ++++++++++++++++++ src/main/kotlin/util/Metrics.kt | 10 + src/main/kotlin/util/SdkStream.kt | 122 +++++++++ .../kotlin/deployment/DeploymentRunnerTest.kt | 19 +- 10 files changed, 551 insertions(+), 86 deletions(-) create mode 100644 src/main/kotlin/flag/FlagConfigStreamApi.kt create mode 100644 src/main/kotlin/flag/FlagConfigUpdater.kt create mode 100644 src/main/kotlin/util/SdkStream.kt diff --git a/build.gradle.kts b/build.gradle.kts index f8d832e..863afa6 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -24,6 +24,7 @@ dependencies { testImplementation("io.mockk:mockk:${Versions.mockk}") implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:${Versions.serializationRuntime}") implementation("com.squareup.okhttp3:okhttp:${Versions.okhttp}") + implementation("com.squareup.okhttp3:okhttp-sse:${Versions.okhttpSse}") implementation("com.amplitude:evaluation-core:${Versions.evaluationCore}") implementation("com.amplitude:java-sdk:${Versions.amplitudeAnalytics}") implementation("org.json:json:${Versions.json}") diff --git a/buildSrc/src/main/kotlin/Versions.kt b/buildSrc/src/main/kotlin/Versions.kt index cfc233d..b1ed2c0 100644 --- a/buildSrc/src/main/kotlin/Versions.kt +++ b/buildSrc/src/main/kotlin/Versions.kt @@ -6,6 +6,7 @@ object Versions { const val serializationRuntime = "1.4.1" const val json = "20231013" const val okhttp = "4.12.0" + const val okhttpSse = "4.12.0" // Update this alongside okhttp. Note this library isn't stable and may contain breaking changes. const val evaluationCore = "2.0.0-beta.2" const val amplitudeAnalytics = "1.12.0" const val mockk = "1.13.9" diff --git a/src/main/kotlin/LocalEvaluationClient.kt b/src/main/kotlin/LocalEvaluationClient.kt index 57dc4ae..a25e143 100644 --- a/src/main/kotlin/LocalEvaluationClient.kt +++ b/src/main/kotlin/LocalEvaluationClient.kt @@ -19,6 +19,8 @@ import com.amplitude.experiment.evaluation.EvaluationEngineImpl import com.amplitude.experiment.evaluation.EvaluationFlag import com.amplitude.experiment.evaluation.topologicalSort import com.amplitude.experiment.flag.DynamicFlagConfigApi +import com.amplitude.experiment.flag.FlagConfigPoller +import com.amplitude.experiment.flag.FlagConfigStreamApi import com.amplitude.experiment.flag.InMemoryFlagConfigStorage import com.amplitude.experiment.util.LocalEvaluationMetricsWrapper import com.amplitude.experiment.util.Logger @@ -43,7 +45,10 @@ class LocalEvaluationClient internal constructor( private val serverUrl: HttpUrl = getServerUrl(config) private val evaluation: EvaluationEngine = EvaluationEngineImpl() private val metrics: LocalEvaluationMetrics = LocalEvaluationMetricsWrapper(config.metrics) - private val flagConfigApi = DynamicFlagConfigApi(apiKey, serverUrl, getProxyUrl(config), httpClient) + private val flagConfigApi = DynamicFlagConfigApi(apiKey, serverUrl, null, httpClient) + private val proxyUrl: HttpUrl? = getProxyUrl(config) + private val flagConfigProxyApi = if (proxyUrl == null) null else DynamicFlagConfigApi(apiKey, proxyUrl, null, httpClient) + private val flagConfigStreamApi = FlagConfigStreamApi(apiKey, "https://stream.lab.amplitude.com", httpClient) private val flagConfigStorage = InMemoryFlagConfigStorage() private val cohortStorage = if (config.cohortSyncConfig == null) { null @@ -60,6 +65,8 @@ class LocalEvaluationClient internal constructor( private val deploymentRunner = DeploymentRunner( config = config, flagConfigApi = flagConfigApi, + flagConfigProxyApi = flagConfigProxyApi, + flagConfigStreamApi = flagConfigStreamApi, flagConfigStorage = flagConfigStorage, cohortApi = cohortApi, cohortStorage = cohortStorage, @@ -214,3 +221,9 @@ private fun getEventServerUrl( assignmentConfiguration.serverUrl } } + +fun main() { + val client = LocalEvaluationClient("server-qz35UwzJ5akieoAdIgzM4m9MIiOLXLoz") + client.start() + println(client.evaluateV2(ExperimentUser("1"))) +} \ No newline at end of file diff --git a/src/main/kotlin/LocalEvaluationConfig.kt b/src/main/kotlin/LocalEvaluationConfig.kt index 67a5a72..e271610 100644 --- a/src/main/kotlin/LocalEvaluationConfig.kt +++ b/src/main/kotlin/LocalEvaluationConfig.kt @@ -207,6 +207,8 @@ interface LocalEvaluationMetrics { fun onFlagConfigFetch() fun onFlagConfigFetchFailure(exception: Exception) fun onFlagConfigFetchOriginFallback(exception: Exception) + fun onFlagConfigStream() + fun onFlagConfigStreamFailure(exception: Exception?) fun onCohortDownload() fun onCohortDownloadFailure(exception: Exception) fun onCohortDownloadOriginFallback(exception: Exception) diff --git a/src/main/kotlin/deployment/DeploymentRunner.kt b/src/main/kotlin/deployment/DeploymentRunner.kt index 0001278..6c38411 100644 --- a/src/main/kotlin/deployment/DeploymentRunner.kt +++ b/src/main/kotlin/deployment/DeploymentRunner.kt @@ -2,22 +2,19 @@ package com.amplitude.experiment.deployment -import com.amplitude.experiment.ExperimentalApi -import com.amplitude.experiment.LocalEvaluationConfig -import com.amplitude.experiment.LocalEvaluationMetrics +import com.amplitude.experiment.* import com.amplitude.experiment.cohort.CohortApi import com.amplitude.experiment.cohort.CohortLoader import com.amplitude.experiment.cohort.CohortStorage +import com.amplitude.experiment.flag.* import com.amplitude.experiment.flag.FlagConfigApi +import com.amplitude.experiment.flag.FlagConfigPoller import com.amplitude.experiment.flag.FlagConfigStorage import com.amplitude.experiment.util.LocalEvaluationMetricsWrapper import com.amplitude.experiment.util.Logger import com.amplitude.experiment.util.Once import com.amplitude.experiment.util.daemonFactory import com.amplitude.experiment.util.getAllCohortIds -import com.amplitude.experiment.util.wrapMetrics -import java.util.concurrent.CompletableFuture -import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.Executors import java.util.concurrent.TimeUnit @@ -26,6 +23,8 @@ private const val MIN_COHORT_POLLING_INTERVAL = 60000L internal class DeploymentRunner( private val config: LocalEvaluationConfig, private val flagConfigApi: FlagConfigApi, + private val flagConfigProxyApi: FlagConfigApi? = null, + private val flagConfigStreamApi: FlagConfigStreamApi? = null, private val flagConfigStorage: FlagConfigStorage, cohortApi: CohortApi?, private val cohortStorage: CohortStorage?, @@ -39,21 +38,26 @@ internal class DeploymentRunner( null } private val cohortPollingInterval: Long = getCohortPollingInterval() + // Fallback in this order: proxy, stream, poll. + private val amplitudeFlagConfigPoller = FlagConfigPoller(flagConfigApi, flagConfigStorage, cohortLoader, cohortStorage, config, metrics) + private val amplitudeFlagConfigUpdater = + if (flagConfigStreamApi != null) + FlagConfigFallbackRetryWrapper( + FlagConfigStreamer(flagConfigStreamApi, flagConfigStorage, cohortLoader, cohortStorage, config, metrics), + amplitudeFlagConfigPoller, + ) + else amplitudeFlagConfigPoller + private val flagConfigUpdater = + if (flagConfigProxyApi != null) + FlagConfigFallbackRetryWrapper( + FlagConfigPoller(flagConfigProxyApi, flagConfigStorage, cohortLoader, cohortStorage, config, metrics), + amplitudeFlagConfigPoller + ) + else + amplitudeFlagConfigUpdater fun start() = lock.once { - refresh() - poller.scheduleWithFixedDelay( - { - try { - refresh() - } catch (t: Throwable) { - Logger.e("Refresh flag configs failed.", t) - } - }, - config.flagConfigPollerIntervalMillis, - config.flagConfigPollerIntervalMillis, - TimeUnit.MILLISECONDS - ) + flagConfigUpdater.start() if (cohortLoader != null) { poller.scheduleWithFixedDelay( { @@ -74,63 +78,7 @@ internal class DeploymentRunner( fun stop() { poller.shutdown() - } - - fun refresh() { - Logger.d("Refreshing flag configs.") - // Get updated flags from the network. - val flagConfigs = wrapMetrics( - metric = metrics::onFlagConfigFetch, - failure = metrics::onFlagConfigFetchFailure, - ) { - flagConfigApi.getFlagConfigs() - } - - // Remove flags that no longer exist. - val flagKeys = flagConfigs.map { it.key }.toSet() - flagConfigStorage.removeIf { !flagKeys.contains(it.key) } - - // Get all flags from storage - val storageFlags = flagConfigStorage.getFlagConfigs() - - // Load cohorts for each flag if applicable and put the flag in storage. - val futures = ConcurrentHashMap>() - for (flagConfig in flagConfigs) { - if (cohortLoader == null) { - flagConfigStorage.putFlagConfig(flagConfig) - continue - } - val cohortIds = flagConfig.getAllCohortIds() - val storageCohortIds = storageFlags[flagConfig.key]?.getAllCohortIds() ?: emptySet() - val cohortsToLoad = cohortIds - storageCohortIds - if (cohortsToLoad.isEmpty()) { - flagConfigStorage.putFlagConfig(flagConfig) - continue - } - for (cohortId in cohortsToLoad) { - futures.putIfAbsent( - cohortId, - cohortLoader.loadCohort(cohortId).handle { _, exception -> - if (exception != null) { - Logger.e("Failed to load cohort $cohortId", exception) - } - flagConfigStorage.putFlagConfig(flagConfig) - } - ) - } - } - futures.values.forEach { it.join() } - - // Delete unused cohorts - if (cohortStorage != null) { - val flagCohortIds = flagConfigStorage.getFlagConfigs().values.toList().getAllCohortIds() - val storageCohortIds = cohortStorage.getCohorts().keys - val deletedCohortIds = storageCohortIds - flagCohortIds - for (deletedCohortId in deletedCohortIds) { - cohortStorage.deleteCohort(deletedCohortId) - } - } - Logger.d("Refreshed ${flagConfigs.size} flag configs.") + flagConfigUpdater.shutdown() } private fun getCohortPollingInterval(): Long { diff --git a/src/main/kotlin/flag/FlagConfigStreamApi.kt b/src/main/kotlin/flag/FlagConfigStreamApi.kt new file mode 100644 index 0000000..3d61313 --- /dev/null +++ b/src/main/kotlin/flag/FlagConfigStreamApi.kt @@ -0,0 +1,134 @@ +package com.amplitude.experiment.flag + +import com.amplitude.experiment.evaluation.EvaluationFlag +import com.amplitude.experiment.util.* +import com.amplitude.experiment.util.SdkStream +import kotlinx.serialization.decodeFromString +import okhttp3.OkHttpClient +import java.util.concurrent.CompletableFuture +import java.util.concurrent.ExecutionException +import java.util.concurrent.TimeUnit +import java.util.concurrent.TimeoutException +import java.util.concurrent.atomic.AtomicBoolean + +internal open class FlagConfigStreamApiError(message: String?, cause: Throwable?): Exception(message, cause) { + constructor(message: String?) : this(message, null) + constructor(cause: Throwable?) : this(cause?.toString(), cause) +} +internal class FlagConfigStreamApiConnTimeoutError: FlagConfigStreamApiError("Initial connection timed out") +internal class FlagConfigStreamApiDataCorruptError: FlagConfigStreamApiError("Stream data corrupted") +internal class FlagConfigStreamApiStreamError(cause: Throwable?): FlagConfigStreamApiError("Stream error", cause) + +private const val CONNECTION_TIMEOUT_MILLIS_DEFAULT = 2000L +private const val KEEP_ALIVE_TIMEOUT_MILLIS_DEFAULT = 17000L +private const val RECONN_INTERVAL_MILLIS_DEFAULT = 15 * 60 * 1000L +internal class FlagConfigStreamApi ( + deploymentKey: String, + serverUrl: String, + httpClient: OkHttpClient = OkHttpClient(), + connectionTimeoutMillis: Long = CONNECTION_TIMEOUT_MILLIS_DEFAULT, + keepaliveTimeoutMillis: Long = KEEP_ALIVE_TIMEOUT_MILLIS_DEFAULT, + reconnIntervalMillis: Long = RECONN_INTERVAL_MILLIS_DEFAULT +) { + var onInitUpdate: ((List) -> Unit)? = null + var onUpdate: ((List) -> Unit)? = null + var onError: ((Exception?) -> Unit)? = null + private val stream: SdkStream = SdkStream( + "Api-Key $deploymentKey", + "$serverUrl/sdk/stream/v1/flags", + httpClient, + connectionTimeoutMillis, + keepaliveTimeoutMillis, + reconnIntervalMillis) + + fun connect() { + val isInit = AtomicBoolean(true) + val connectTimeoutFuture = CompletableFuture() + val updateTimeoutFuture = CompletableFuture() + stream.onUpdate = { data -> + if (isInit.getAndSet(false)) { + // Stream is establishing. First data received. + // Resolve timeout. + connectTimeoutFuture.complete(Unit) + + // Make sure valid data. + try { + val flags = getFlagsFromData(data) + + try { + if (onInitUpdate != null) { + onInitUpdate?.let { it(flags) } + } else { + onUpdate?.let { it(flags) } + } + updateTimeoutFuture.complete(Unit) + } catch (e: Throwable) { + updateTimeoutFuture.completeExceptionally(e) + } + } catch (_: Throwable) { + updateTimeoutFuture.completeExceptionally(FlagConfigStreamApiDataCorruptError()) + } + + } else { + // Stream has already established. + // Make sure valid data. + try { + val flags = getFlagsFromData(data) + + try { + onUpdate?.let { it(flags) } + } catch (_: Throwable) { + // Don't care about application error. + } + } catch (_: Throwable) { + // Stream corrupted. Reconnect. + handleError(FlagConfigStreamApiDataCorruptError()) + } + + } + } + stream.onError = { t -> + if (isInit.getAndSet(false)) { + connectTimeoutFuture.completeExceptionally(t) + updateTimeoutFuture.completeExceptionally(t) + } else { + handleError(FlagConfigStreamApiStreamError(t)) + } + } + stream.connect() + + val t: Throwable + try { + connectTimeoutFuture.get(2000, TimeUnit.MILLISECONDS) + updateTimeoutFuture.get() + return + } catch (e: TimeoutException) { + // Timeouts should retry + t = FlagConfigStreamApiConnTimeoutError() + } catch (e: ExecutionException) { + val cause = e.cause + t = if (cause is StreamException) { + FlagConfigStreamApiStreamError(cause) + } else { + FlagConfigStreamApiError(e) + } + } catch (e: Throwable) { + t = FlagConfigStreamApiError(e) + } + close() + throw t + } + + fun close() { + stream.cancel() + } + + private fun getFlagsFromData(data: String): List { + return json.decodeFromString>(data) + } + + private fun handleError(e: Exception?) { + close() + onError?.let { it(e) } + } +} diff --git a/src/main/kotlin/flag/FlagConfigUpdater.kt b/src/main/kotlin/flag/FlagConfigUpdater.kt new file mode 100644 index 0000000..f882ab5 --- /dev/null +++ b/src/main/kotlin/flag/FlagConfigUpdater.kt @@ -0,0 +1,233 @@ +package com.amplitude.experiment.flag + +import com.amplitude.experiment.LocalEvaluationConfig +import com.amplitude.experiment.LocalEvaluationMetrics +import com.amplitude.experiment.cohort.CohortLoader +import com.amplitude.experiment.cohort.CohortStorage +import com.amplitude.experiment.evaluation.EvaluationFlag +import com.amplitude.experiment.util.* +import com.amplitude.experiment.util.Logger +import com.amplitude.experiment.util.daemonFactory +import com.amplitude.experiment.util.wrapMetrics +import java.util.concurrent.CompletableFuture +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.Executors +import java.util.concurrent.ScheduledFuture +import java.util.concurrent.TimeUnit + +internal interface FlagConfigUpdater { + // Start the updater. There can be multiple calls. + // If start fails, it should throw exception. The caller should handle fallback. + // If some other error happened while updating (already started successfully), it should call fallback. + fun start(fallback: (() -> Unit)? = null) + // Stop should stop updater temporarily. There may be another start in the future. + // To stop completely, with intention to never start again, use shutdown() instead. + fun stop() + // Destroy should stop the updater forever in preparation for server shutdown. + fun shutdown() +} + +internal abstract class FlagConfigUpdaterBase( + private val flagConfigStorage: FlagConfigStorage, + private val cohortLoader: CohortLoader?, + private val cohortStorage: CohortStorage?, +): FlagConfigUpdater { + fun update(flagConfigs: List) { + println("update") + // Remove flags that no longer exist. + val flagKeys = flagConfigs.map { it.key }.toSet() + flagConfigStorage.removeIf { !flagKeys.contains(it.key) } + + // Get all flags from storage + val storageFlags = flagConfigStorage.getFlagConfigs() + + // Load cohorts for each flag if applicable and put the flag in storage. + val futures = ConcurrentHashMap>() + for (flagConfig in flagConfigs) { + if (cohortLoader == null) { + flagConfigStorage.putFlagConfig(flagConfig) + continue + } + val cohortIds = flagConfig.getAllCohortIds() + val storageCohortIds = storageFlags[flagConfig.key]?.getAllCohortIds() ?: emptySet() + val cohortsToLoad = cohortIds - storageCohortIds + if (cohortsToLoad.isEmpty()) { + flagConfigStorage.putFlagConfig(flagConfig) + continue + } + for (cohortId in cohortsToLoad) { + futures.putIfAbsent( + cohortId, + cohortLoader.loadCohort(cohortId).handle { _, exception -> + if (exception != null) { + Logger.e("Failed to load cohort $cohortId", exception) + } + flagConfigStorage.putFlagConfig(flagConfig) + } + ) + } + } + futures.values.forEach { it.join() } + + // Delete unused cohorts + if (cohortStorage != null) { + val flagCohortIds = flagConfigStorage.getFlagConfigs().values.toList().getAllCohortIds() + val storageCohortIds = cohortStorage.getCohorts().keys + val deletedCohortIds = storageCohortIds - flagCohortIds + for (deletedCohortId in deletedCohortIds) { + cohortStorage.deleteCohort(deletedCohortId) + } + } + Logger.d("Refreshed ${flagConfigs.size} flag configs.") + } +} + +internal class FlagConfigPoller( + private val flagConfigApi: FlagConfigApi, + private val storage: FlagConfigStorage, + private val cohortLoader: CohortLoader?, + private val cohortStorage: CohortStorage?, + private val config: LocalEvaluationConfig, + private val metrics: LocalEvaluationMetrics = LocalEvaluationMetricsWrapper() +): FlagConfigUpdaterBase( + storage, cohortLoader, cohortStorage +) { + private val poller = Executors.newScheduledThreadPool(1, daemonFactory) + private var scheduledFuture: ScheduledFuture<*>? = null + override fun start(fallback: (() -> Unit)?) { + // Perform updates + refresh() + scheduledFuture = poller.scheduleWithFixedDelay( + { + try { + refresh() + } catch (t: Throwable) { + Logger.e("Refresh flag configs failed.", t) + stop() + fallback?.invoke() + } + }, + config.flagConfigPollerIntervalMillis, + config.flagConfigPollerIntervalMillis, + TimeUnit.MILLISECONDS + ) + } + + override fun stop() { + // Pause only stop the task scheduled. It doesn't stop the executor. + scheduledFuture?.cancel(true) + scheduledFuture = null + } + + override fun shutdown() { + // Stop the executor. + poller.shutdown() + } + + fun refresh() { + Logger.d("Refreshing flag configs.") + println("flag poller refreshing") + // Get updated flags from the network. + val flagConfigs = wrapMetrics( + metric = metrics::onFlagConfigFetch, + failure = metrics::onFlagConfigFetchFailure, + ) { + flagConfigApi.getFlagConfigs() + } + + update(flagConfigs) + println("flag poller refreshed") + } +} + +internal class FlagConfigStreamer( + private val flagConfigStreamApi: FlagConfigStreamApi, + private val storage: FlagConfigStorage, + private val cohortLoader: CohortLoader?, + private val cohortStorage: CohortStorage?, + private val config: LocalEvaluationConfig, + private val metrics: LocalEvaluationMetrics = LocalEvaluationMetricsWrapper() +): FlagConfigUpdaterBase( + storage, cohortLoader, cohortStorage +) { + override fun start(fallback: (() -> Unit)?) { + flagConfigStreamApi.onUpdate = {flags -> + println("flag streamer received") + update(flags) + } + flagConfigStreamApi.onError = {e -> + Logger.e("Stream flag configs streaming failed.", e) + metrics.onFlagConfigStreamFailure(e) + fallback?.invoke() + } + wrapMetrics(metric = metrics::onFlagConfigStream, failure = metrics::onFlagConfigStreamFailure) { + flagConfigStreamApi.connect() + } + } + + override fun stop() { + flagConfigStreamApi.close() + } + + override fun shutdown() = stop() +} + +private const val RETRY_DELAY_MILLIS_DEFAULT = 15 * 1000L +private const val MAX_JITTER_MILLIS_DEFAULT = 2000L +internal class FlagConfigFallbackRetryWrapper( + private val mainUpdater: FlagConfigUpdater, + private val fallbackUpdater: FlagConfigUpdater, + private val retryDelayMillis: Long = RETRY_DELAY_MILLIS_DEFAULT, + private val maxJitterMillis: Long = MAX_JITTER_MILLIS_DEFAULT +): FlagConfigUpdater { + private val reconnIntervalRange = (retryDelayMillis - maxJitterMillis)..(retryDelayMillis + maxJitterMillis) + private val executor = Executors.newScheduledThreadPool(1, daemonFactory) + private var retryTask: ScheduledFuture<*>? = null + + override fun start(fallback: (() -> Unit)?) { + try { + mainUpdater.start { + startRetry(fallback) // Don't care if poller start error or not, always retry. + try { + fallbackUpdater.start(fallback) + } catch (_: Throwable) { + fallback?.invoke() + } + } + } catch (t: Throwable) { + Logger.e("Update flag configs start failed.", t) + fallbackUpdater.start(fallback) // If fallback failed, don't retry. + startRetry(fallback) + } + } + + override fun stop() { + mainUpdater.stop() + fallbackUpdater.stop() + retryTask?.cancel(true) + } + + override fun shutdown() { + mainUpdater.shutdown() + fallbackUpdater.shutdown() + retryTask?.cancel(true) + } + + private fun startRetry(fallback: (() -> Unit)?) { + retryTask = executor.schedule({ + try { + mainUpdater.start { + startRetry(fallback) // Don't care if poller start error or not, always retry stream. + try { + fallbackUpdater.start(fallback) + } catch (_: Throwable) { + fallback?.invoke() + } + } + fallbackUpdater.stop() + } catch (_: Throwable) { + startRetry(fallback) + } + }, reconnIntervalRange.random(), TimeUnit.MILLISECONDS) + } +} \ No newline at end of file diff --git a/src/main/kotlin/util/Metrics.kt b/src/main/kotlin/util/Metrics.kt index 2748652..99345c5 100644 --- a/src/main/kotlin/util/Metrics.kt +++ b/src/main/kotlin/util/Metrics.kt @@ -66,6 +66,16 @@ internal class LocalEvaluationMetricsWrapper( executor?.execute { metrics.onFlagConfigFetchFailure(exception) } } + override fun onFlagConfigStream() { + val metrics = metrics ?: return + executor?.execute { metrics.onFlagConfigStream() } + } + + override fun onFlagConfigStreamFailure(exception: Exception?) { + val metrics = metrics ?: return + executor?.execute { metrics.onFlagConfigStreamFailure(exception) } + } + override fun onFlagConfigFetchOriginFallback(exception: Exception) { val metrics = metrics ?: return executor?.execute { metrics.onFlagConfigFetchOriginFallback(exception) } diff --git a/src/main/kotlin/util/SdkStream.kt b/src/main/kotlin/util/SdkStream.kt new file mode 100644 index 0000000..ec2292f --- /dev/null +++ b/src/main/kotlin/util/SdkStream.kt @@ -0,0 +1,122 @@ +package com.amplitude.experiment.util + +import okhttp3.OkHttpClient +import okhttp3.Request +import okhttp3.Response +import okhttp3.internal.http2.ErrorCode +import okhttp3.internal.http2.StreamResetException +import okhttp3.sse.EventSource +import okhttp3.sse.EventSourceListener +import okhttp3.sse.EventSources +import java.util.* +import java.util.concurrent.TimeUnit +import kotlin.concurrent.schedule + +internal class StreamException(error: String): Throwable(error) + +private const val RECONN_INTERVAL_MILLIS_DEFAULT = 30 * 60 * 1000L +private const val MAX_JITTER_MILLIS_DEFAULT = 5000L +internal class SdkStream ( + private val authToken: String, + private val serverUrl: String, + private val httpClient: OkHttpClient = OkHttpClient(), + private val connectionTimeoutMillis: Long, + private val keepaliveTimeoutMillis: Long, + private val reconnIntervalMillis: Long = RECONN_INTERVAL_MILLIS_DEFAULT, + private val maxJitterMillis: Long = MAX_JITTER_MILLIS_DEFAULT +) { + private val reconnIntervalRange = (reconnIntervalMillis - maxJitterMillis)..(reconnIntervalMillis + maxJitterMillis) + private val eventSourceListener = object : EventSourceListener() { + override fun onOpen(eventSource: EventSource, response: Response) { + // No action needed. + } + + override fun onClosed(eventSource: EventSource) { + if ((eventSource != es)) { + // Not the current event source using right now, should cancel. + eventSource.cancel() + return + } + // Server closed the connection, just reconnect. + cancel() + connect() + } + + override fun onEvent( + eventSource: EventSource, + id: String?, + type: String?, + data: String + ) { + if ((eventSource != es)) { + // Not the current event source using right now, should cancel. + eventSource.cancel() + return + } + // Keep alive data + if (" " == data) { + return + } + onUpdate?.let { it(data) } + } + + override fun onFailure(eventSource: EventSource, t: Throwable?, response: Response?) { + if ((eventSource != es)) { + // Not the current event source using right now, should cancel. + eventSource.cancel() + return + } + if (t is StreamResetException && t.errorCode == ErrorCode.CANCEL) { + // TODO: relying on okhttp3.internal to differentiate cancel case. + return + } + cancel() + var err = t + if (t == null) { + err = if (response != null) { + StreamException(response.toString()) + } else { + StreamException("Unknown stream failure") + } + } + onError?.let { it(err) } + } + } + + private val request = Request.Builder() + .url(serverUrl) + .header("Authorization", authToken) + .addHeader("Accept", "text/event-stream") + .build() + + private val client = httpClient.newBuilder() // client.newBuilder reuses the connection pool in the same client with new configs. + .connectTimeout(connectionTimeoutMillis, TimeUnit.MILLISECONDS) // Connection timeout for establishing SSE. + .callTimeout(connectionTimeoutMillis, TimeUnit.MILLISECONDS) // Call timeout for establishing SSE. + .readTimeout(keepaliveTimeoutMillis, TimeUnit.MILLISECONDS) // Timeout between messages, keepalive in this case. + .writeTimeout(connectionTimeoutMillis, TimeUnit.MILLISECONDS) + .retryOnConnectionFailure(false) + .build() + + private var es: EventSource? = null + private var reconnectTimerTask: TimerTask? = null + var onUpdate: ((String) -> Unit)? = null + var onError: ((Throwable?) -> Unit)? = null + + fun connect() { + cancel() // Clear any existing event sources. + es = EventSources.createFactory(client).newEventSource(request = request, listener = eventSourceListener) + reconnectTimerTask = Timer().schedule(reconnIntervalRange.random()) {// Timer for a new event source. + // This forces client side reconnection after interval. + this@SdkStream.cancel() + connect() + } + } + + fun cancel() { + reconnectTimerTask?.cancel() + + // There can be cases where an event source is being cancelled by these calls, but take a long time and made a callback to onFailure callback. + es?.cancel() + es = null + } +} \ No newline at end of file diff --git a/src/test/kotlin/deployment/DeploymentRunnerTest.kt b/src/test/kotlin/deployment/DeploymentRunnerTest.kt index 8e58a79..9b74742 100644 --- a/src/test/kotlin/deployment/DeploymentRunnerTest.kt +++ b/src/test/kotlin/deployment/DeploymentRunnerTest.kt @@ -42,11 +42,11 @@ class DeploymentRunnerTest { val flagConfigStorage = Mockito.mock(FlagConfigStorage::class.java) val cohortStorage = Mockito.mock(CohortStorage::class.java) val runner = DeploymentRunner( - LocalEvaluationConfig(), - flagApi, - flagConfigStorage, - cohortApi, - cohortStorage, + config = LocalEvaluationConfig(), + flagConfigApi = flagApi, + flagConfigStorage = flagConfigStorage, + cohortApi = cohortApi, + cohortStorage = cohortStorage, ) Mockito.`when`(flagApi.getFlagConfigs()).thenThrow(RuntimeException("test")) try { @@ -71,10 +71,11 @@ class DeploymentRunnerTest { val flagConfigStorage = Mockito.mock(FlagConfigStorage::class.java) val cohortStorage = Mockito.mock(CohortStorage::class.java) val runner = DeploymentRunner( - LocalEvaluationConfig(), - flagApi, flagConfigStorage, - cohortApi, - cohortStorage, + config = LocalEvaluationConfig(), + flagConfigApi = flagApi, + flagConfigStorage = flagConfigStorage, + cohortApi = cohortApi, + cohortStorage = cohortStorage, ) Mockito.`when`(flagApi.getFlagConfigs()).thenReturn(listOf(flag)) Mockito.`when`(cohortApi.getCohort(COHORT_ID, null)).thenThrow(RuntimeException("test")) From 027196ff7b2770e8187c3db4c721fe6a54e3d53d Mon Sep 17 00:00:00 2001 From: Peter Zhu Date: Mon, 26 Aug 2024 15:01:07 -0700 Subject: [PATCH 02/12] added config, minor renames --- buildSrc/src/main/kotlin/Versions.kt | 2 +- src/main/kotlin/LocalEvaluationClient.kt | 16 ++++++- src/main/kotlin/LocalEvaluationConfig.kt | 30 +++++++++++++ src/main/kotlin/ServerZone.kt | 2 + src/main/kotlin/flag/FlagConfigStreamApi.kt | 14 ++++--- src/main/kotlin/flag/FlagConfigUpdater.kt | 42 +++++++++---------- src/main/kotlin/util/Request.kt | 2 +- .../util/{SdkStream.kt => SseStream.kt} | 20 ++++----- 8 files changed, 86 insertions(+), 42 deletions(-) rename src/main/kotlin/util/{SdkStream.kt => SseStream.kt} (88%) diff --git a/buildSrc/src/main/kotlin/Versions.kt b/buildSrc/src/main/kotlin/Versions.kt index b1ed2c0..098ce65 100644 --- a/buildSrc/src/main/kotlin/Versions.kt +++ b/buildSrc/src/main/kotlin/Versions.kt @@ -6,7 +6,7 @@ object Versions { const val serializationRuntime = "1.4.1" const val json = "20231013" const val okhttp = "4.12.0" - const val okhttpSse = "4.12.0" // Update this alongside okhttp. Note this library isn't stable and may contain breaking changes. + const val okhttpSse = "4.12.0" // Update this alongside okhttp. Note this library isn't stable and may contain breaking changes. Search uses of okhttp3.internal classes before updating. const val evaluationCore = "2.0.0-beta.2" const val amplitudeAnalytics = "1.12.0" const val mockk = "1.13.9" diff --git a/src/main/kotlin/LocalEvaluationClient.kt b/src/main/kotlin/LocalEvaluationClient.kt index a25e143..a9b9e5d 100644 --- a/src/main/kotlin/LocalEvaluationClient.kt +++ b/src/main/kotlin/LocalEvaluationClient.kt @@ -43,12 +43,13 @@ class LocalEvaluationClient internal constructor( ) { private val assignmentService: AssignmentService? = createAssignmentService(apiKey) private val serverUrl: HttpUrl = getServerUrl(config) + private val streamServerUrl: HttpUrl = getStreamServerUrl(config) private val evaluation: EvaluationEngine = EvaluationEngineImpl() private val metrics: LocalEvaluationMetrics = LocalEvaluationMetricsWrapper(config.metrics) private val flagConfigApi = DynamicFlagConfigApi(apiKey, serverUrl, null, httpClient) private val proxyUrl: HttpUrl? = getProxyUrl(config) private val flagConfigProxyApi = if (proxyUrl == null) null else DynamicFlagConfigApi(apiKey, proxyUrl, null, httpClient) - private val flagConfigStreamApi = FlagConfigStreamApi(apiKey, "https://stream.lab.amplitude.com", httpClient) + private val flagConfigStreamApi = if (config.streamUpdates) FlagConfigStreamApi(apiKey, streamServerUrl, httpClient, config.streamFlagConnTimeoutMillis) else null private val flagConfigStorage = InMemoryFlagConfigStorage() private val cohortStorage = if (config.cohortSyncConfig == null) { null @@ -192,6 +193,17 @@ private fun getServerUrl(config: LocalEvaluationConfig): HttpUrl { } } +private fun getStreamServerUrl(config: LocalEvaluationConfig): HttpUrl { + return if (config.streamServerUrl == LocalEvaluationConfig.Defaults.STREAM_SERVER_URL) { + when (config.serverZone) { + ServerZone.US -> US_STREAM_SERVER_URL.toHttpUrl() + ServerZone.EU -> EU_STREAM_SERVER_URL.toHttpUrl() + } + } else { + config.streamServerUrl.toHttpUrl() + } +} + private fun getProxyUrl(config: LocalEvaluationConfig): HttpUrl? { return config.evaluationProxyConfig?.proxyUrl?.toHttpUrl() } @@ -223,7 +235,7 @@ private fun getEventServerUrl( } fun main() { - val client = LocalEvaluationClient("server-qz35UwzJ5akieoAdIgzM4m9MIiOLXLoz") + val client = LocalEvaluationClient("server-qz35UwzJ5akieoAdIgzM4m9MIiOLXLoz", LocalEvaluationConfig(streamUpdates = true)) client.start() println(client.evaluateV2(ExperimentUser("1"))) } \ No newline at end of file diff --git a/src/main/kotlin/LocalEvaluationConfig.kt b/src/main/kotlin/LocalEvaluationConfig.kt index e271610..eeab7cf 100644 --- a/src/main/kotlin/LocalEvaluationConfig.kt +++ b/src/main/kotlin/LocalEvaluationConfig.kt @@ -22,6 +22,12 @@ class LocalEvaluationConfig internal constructor( @JvmField val flagConfigPollerRequestTimeoutMillis: Long = Defaults.FLAG_CONFIG_POLLER_REQUEST_TIMEOUT_MILLIS, @JvmField + val streamUpdates: Boolean = Defaults.STREAM_UPDATES, + @JvmField + val streamServerUrl: String = Defaults.STREAM_SERVER_URL, + @JvmField + val streamFlagConnTimeoutMillis: Long = Defaults.STREAM_FLAG_CONN_TIMEOUT_MILLIS, + @JvmField val assignmentConfiguration: AssignmentConfiguration? = Defaults.ASSIGNMENT_CONFIGURATION, @JvmField val cohortSyncConfig: CohortSyncConfig? = Defaults.COHORT_SYNC_CONFIGURATION, @@ -76,6 +82,12 @@ class LocalEvaluationConfig internal constructor( */ const val FLAG_CONFIG_POLLER_REQUEST_TIMEOUT_MILLIS = 10_000L + const val STREAM_UPDATES = false + + const val STREAM_SERVER_URL = US_STREAM_SERVER_URL + + const val STREAM_FLAG_CONN_TIMEOUT_MILLIS = 1_500L + /** * null */ @@ -111,6 +123,9 @@ class LocalEvaluationConfig internal constructor( private var serverUrl = Defaults.SERVER_URL private var flagConfigPollerIntervalMillis = Defaults.FLAG_CONFIG_POLLER_INTERVAL_MILLIS private var flagConfigPollerRequestTimeoutMillis = Defaults.FLAG_CONFIG_POLLER_REQUEST_TIMEOUT_MILLIS + private var streamUpdates = Defaults.STREAM_UPDATES + private var streamServerUrl = Defaults.STREAM_SERVER_URL + private var streamFlagConnTimeoutMillis = Defaults.STREAM_FLAG_CONN_TIMEOUT_MILLIS private var assignmentConfiguration = Defaults.ASSIGNMENT_CONFIGURATION private var cohortSyncConfiguration = Defaults.COHORT_SYNC_CONFIGURATION private var evaluationProxyConfiguration = Defaults.EVALUATION_PROXY_CONFIGURATION @@ -136,6 +151,18 @@ class LocalEvaluationConfig internal constructor( this.flagConfigPollerRequestTimeoutMillis = flagConfigPollerRequestTimeoutMillis } + fun streamUpdates(streamUpdates: Boolean) = apply { + this.streamUpdates = streamUpdates + } + + fun streamServerUrl(streamServerUrl: String) = apply { + this.streamServerUrl = streamServerUrl + } + + fun streamFlagConnTimeoutMillis(streamFlagConnTimeoutMillis: Long) = apply { + this.streamFlagConnTimeoutMillis = streamFlagConnTimeoutMillis + } + fun enableAssignmentTracking(assignmentConfiguration: AssignmentConfiguration) = apply { this.assignmentConfiguration = assignmentConfiguration } @@ -161,6 +188,9 @@ class LocalEvaluationConfig internal constructor( serverZone = serverZone, flagConfigPollerIntervalMillis = flagConfigPollerIntervalMillis, flagConfigPollerRequestTimeoutMillis = flagConfigPollerRequestTimeoutMillis, + streamUpdates = streamUpdates, + streamServerUrl = streamServerUrl, + streamFlagConnTimeoutMillis = streamFlagConnTimeoutMillis, assignmentConfiguration = assignmentConfiguration, cohortSyncConfig = cohortSyncConfiguration, evaluationProxyConfig = evaluationProxyConfiguration, diff --git a/src/main/kotlin/ServerZone.kt b/src/main/kotlin/ServerZone.kt index c5d5dc3..c658ded 100644 --- a/src/main/kotlin/ServerZone.kt +++ b/src/main/kotlin/ServerZone.kt @@ -2,6 +2,8 @@ package com.amplitude.experiment internal const val US_SERVER_URL = "https://api.lab.amplitude.com" internal const val EU_SERVER_URL = "https://api.lab.eu.amplitude.com" +internal const val US_STREAM_SERVER_URL = "https://stream.lab.amplitude.com" +internal const val EU_STREAM_SERVER_URL = "https://stream.lab.eu.amplitude.com" internal const val US_COHORT_SERVER_URL = "https://cohort-v2.lab.amplitude.com" internal const val EU_COHORT_SERVER_URL = "https://cohort-v2.lab.eu.amplitude.com" internal const val US_EVENT_SERVER_URL = "https://api2.amplitude.com/2/httpapi" diff --git a/src/main/kotlin/flag/FlagConfigStreamApi.kt b/src/main/kotlin/flag/FlagConfigStreamApi.kt index 3d61313..919e936 100644 --- a/src/main/kotlin/flag/FlagConfigStreamApi.kt +++ b/src/main/kotlin/flag/FlagConfigStreamApi.kt @@ -2,8 +2,9 @@ package com.amplitude.experiment.flag import com.amplitude.experiment.evaluation.EvaluationFlag import com.amplitude.experiment.util.* -import com.amplitude.experiment.util.SdkStream +import com.amplitude.experiment.util.SseStream import kotlinx.serialization.decodeFromString +import okhttp3.HttpUrl import okhttp3.OkHttpClient import java.util.concurrent.CompletableFuture import java.util.concurrent.ExecutionException @@ -19,12 +20,12 @@ internal class FlagConfigStreamApiConnTimeoutError: FlagConfigStreamApiError("In internal class FlagConfigStreamApiDataCorruptError: FlagConfigStreamApiError("Stream data corrupted") internal class FlagConfigStreamApiStreamError(cause: Throwable?): FlagConfigStreamApiError("Stream error", cause) -private const val CONNECTION_TIMEOUT_MILLIS_DEFAULT = 2000L -private const val KEEP_ALIVE_TIMEOUT_MILLIS_DEFAULT = 17000L +private const val CONNECTION_TIMEOUT_MILLIS_DEFAULT = 1500L +private const val KEEP_ALIVE_TIMEOUT_MILLIS_DEFAULT = 17000L // keep alive sends at 15s interval. 2s grace period private const val RECONN_INTERVAL_MILLIS_DEFAULT = 15 * 60 * 1000L internal class FlagConfigStreamApi ( deploymentKey: String, - serverUrl: String, + serverUrl: HttpUrl, httpClient: OkHttpClient = OkHttpClient(), connectionTimeoutMillis: Long = CONNECTION_TIMEOUT_MILLIS_DEFAULT, keepaliveTimeoutMillis: Long = KEEP_ALIVE_TIMEOUT_MILLIS_DEFAULT, @@ -33,9 +34,10 @@ internal class FlagConfigStreamApi ( var onInitUpdate: ((List) -> Unit)? = null var onUpdate: ((List) -> Unit)? = null var onError: ((Exception?) -> Unit)? = null - private val stream: SdkStream = SdkStream( + val url = serverUrl.newBuilder().addPathSegments("sdk/stream/v1/flags").build() + private val stream: SseStream = SseStream( "Api-Key $deploymentKey", - "$serverUrl/sdk/stream/v1/flags", + url, httpClient, connectionTimeoutMillis, keepaliveTimeoutMillis, diff --git a/src/main/kotlin/flag/FlagConfigUpdater.kt b/src/main/kotlin/flag/FlagConfigUpdater.kt index f882ab5..0cf3616 100644 --- a/src/main/kotlin/flag/FlagConfigUpdater.kt +++ b/src/main/kotlin/flag/FlagConfigUpdater.kt @@ -17,9 +17,9 @@ import java.util.concurrent.TimeUnit internal interface FlagConfigUpdater { // Start the updater. There can be multiple calls. - // If start fails, it should throw exception. The caller should handle fallback. - // If some other error happened while updating (already started successfully), it should call fallback. - fun start(fallback: (() -> Unit)? = null) + // If start fails, it should throw exception. The caller should handle error. + // If some other error happened while updating (already started successfully), it should call onError. + fun start(onError: (() -> Unit)? = null) // Stop should stop updater temporarily. There may be another start in the future. // To stop completely, with intention to never start again, use shutdown() instead. fun stop() @@ -94,8 +94,7 @@ internal class FlagConfigPoller( ) { private val poller = Executors.newScheduledThreadPool(1, daemonFactory) private var scheduledFuture: ScheduledFuture<*>? = null - override fun start(fallback: (() -> Unit)?) { - // Perform updates + override fun start(onError: (() -> Unit)?) { refresh() scheduledFuture = poller.scheduleWithFixedDelay( { @@ -104,7 +103,7 @@ internal class FlagConfigPoller( } catch (t: Throwable) { Logger.e("Refresh flag configs failed.", t) stop() - fallback?.invoke() + onError?.invoke() } }, config.flagConfigPollerIntervalMillis, @@ -124,7 +123,7 @@ internal class FlagConfigPoller( poller.shutdown() } - fun refresh() { + private fun refresh() { Logger.d("Refreshing flag configs.") println("flag poller refreshing") // Get updated flags from the network. @@ -150,15 +149,14 @@ internal class FlagConfigStreamer( ): FlagConfigUpdaterBase( storage, cohortLoader, cohortStorage ) { - override fun start(fallback: (() -> Unit)?) { + override fun start(onError: (() -> Unit)?) { flagConfigStreamApi.onUpdate = {flags -> - println("flag streamer received") update(flags) } flagConfigStreamApi.onError = {e -> Logger.e("Stream flag configs streaming failed.", e) metrics.onFlagConfigStreamFailure(e) - fallback?.invoke() + onError?.invoke() } wrapMetrics(metric = metrics::onFlagConfigStream, failure = metrics::onFlagConfigStreamFailure) { flagConfigStreamApi.connect() @@ -184,20 +182,20 @@ internal class FlagConfigFallbackRetryWrapper( private val executor = Executors.newScheduledThreadPool(1, daemonFactory) private var retryTask: ScheduledFuture<*>? = null - override fun start(fallback: (() -> Unit)?) { + override fun start(onError: (() -> Unit)?) { try { mainUpdater.start { - startRetry(fallback) // Don't care if poller start error or not, always retry. + scheduleRetry(onError) // Don't care if poller start error or not, always retry. try { - fallbackUpdater.start(fallback) + fallbackUpdater.start(onError) } catch (_: Throwable) { - fallback?.invoke() + onError?.invoke() } } } catch (t: Throwable) { - Logger.e("Update flag configs start failed.", t) - fallbackUpdater.start(fallback) // If fallback failed, don't retry. - startRetry(fallback) + Logger.e("Primary flag configs start failed, start fallback. Error: ", t) + fallbackUpdater.start(onError) // If fallback failed, don't retry. + scheduleRetry(onError) } } @@ -213,20 +211,20 @@ internal class FlagConfigFallbackRetryWrapper( retryTask?.cancel(true) } - private fun startRetry(fallback: (() -> Unit)?) { + private fun scheduleRetry(onError: (() -> Unit)?) { retryTask = executor.schedule({ try { mainUpdater.start { - startRetry(fallback) // Don't care if poller start error or not, always retry stream. + scheduleRetry(onError) // Don't care if poller start error or not, always retry stream. try { - fallbackUpdater.start(fallback) + fallbackUpdater.start(onError) } catch (_: Throwable) { - fallback?.invoke() + onError?.invoke() } } fallbackUpdater.stop() } catch (_: Throwable) { - startRetry(fallback) + scheduleRetry(onError) } }, reconnIntervalRange.random(), TimeUnit.MILLISECONDS) } diff --git a/src/main/kotlin/util/Request.kt b/src/main/kotlin/util/Request.kt index 2bd09fb..1a826df 100644 --- a/src/main/kotlin/util/Request.kt +++ b/src/main/kotlin/util/Request.kt @@ -60,7 +60,7 @@ private fun OkHttpClient.submit( return future } -private fun newGet( +internal fun newGet( serverUrl: HttpUrl, path: String? = null, headers: Map? = null, diff --git a/src/main/kotlin/util/SdkStream.kt b/src/main/kotlin/util/SseStream.kt similarity index 88% rename from src/main/kotlin/util/SdkStream.kt rename to src/main/kotlin/util/SseStream.kt index ec2292f..2776fd4 100644 --- a/src/main/kotlin/util/SdkStream.kt +++ b/src/main/kotlin/util/SseStream.kt @@ -1,5 +1,7 @@ package com.amplitude.experiment.util +import com.amplitude.experiment.LIBRARY_VERSION +import okhttp3.HttpUrl import okhttp3.OkHttpClient import okhttp3.Request import okhttp3.Response @@ -14,14 +16,15 @@ import kotlin.concurrent.schedule internal class StreamException(error: String): Throwable(error) +private const val KEEP_ALIVE_TIMEOUT_MILLIS_DEFAULT = 0L // no timeout private const val RECONN_INTERVAL_MILLIS_DEFAULT = 30 * 60 * 1000L private const val MAX_JITTER_MILLIS_DEFAULT = 5000L -internal class SdkStream ( +internal class SseStream ( private val authToken: String, - private val serverUrl: String, + private val url: HttpUrl, private val httpClient: OkHttpClient = OkHttpClient(), private val connectionTimeoutMillis: Long, - private val keepaliveTimeoutMillis: Long, + private val keepaliveTimeoutMillis: Long = KEEP_ALIVE_TIMEOUT_MILLIS_DEFAULT, private val reconnIntervalMillis: Long = RECONN_INTERVAL_MILLIS_DEFAULT, private val maxJitterMillis: Long = MAX_JITTER_MILLIS_DEFAULT ) { @@ -67,7 +70,8 @@ internal class SdkStream ( return } if (t is StreamResetException && t.errorCode == ErrorCode.CANCEL) { - // TODO: relying on okhttp3.internal to differentiate cancel case. + // Relying on okhttp3.internal to differentiate cancel case. + // Can be a pitfall later on. return } cancel() @@ -83,11 +87,7 @@ internal class SdkStream ( } } - private val request = Request.Builder() - .url(serverUrl) - .header("Authorization", authToken) - .addHeader("Accept", "text/event-stream") - .build() + private val request = newGet(url, null, mapOf("Authorization" to authToken, "Accept" to "text/event-stream")) private val client = httpClient.newBuilder() // client.newBuilder reuses the connection pool in the same client with new configs. .connectTimeout(connectionTimeoutMillis, TimeUnit.MILLISECONDS) // Connection timeout for establishing SSE. @@ -107,7 +107,7 @@ internal class SdkStream ( es = EventSources.createFactory(client).newEventSource(request = request, listener = eventSourceListener) reconnectTimerTask = Timer().schedule(reconnIntervalRange.random()) {// Timer for a new event source. // This forces client side reconnection after interval. - this@SdkStream.cancel() + this@SseStream.cancel() connect() } } From 10356efa4229c2020bbe8c90e581a72c2f886511 Mon Sep 17 00:00:00 2001 From: Peter Zhu Date: Tue, 3 Sep 2024 04:22:37 -0700 Subject: [PATCH 03/12] add test, fixed small bugs --- src/main/kotlin/LocalEvaluationClient.kt | 4 + .../kotlin/deployment/DeploymentRunner.kt | 2 +- src/main/kotlin/flag/FlagConfigStreamApi.kt | 7 +- src/main/kotlin/flag/FlagConfigUpdater.kt | 37 +- src/main/kotlin/util/Request.kt | 8 + src/main/kotlin/util/SseStream.kt | 21 +- .../kotlin/flag/FlagConfigStreamApiTest.kt | 152 ++++++++ src/test/kotlin/flag/FlagConfigUpdaterTest.kt | 351 ++++++++++++++++++ src/test/kotlin/util/SseStreamTest.kt | 100 +++++ 9 files changed, 657 insertions(+), 25 deletions(-) create mode 100644 src/test/kotlin/flag/FlagConfigStreamApiTest.kt create mode 100644 src/test/kotlin/flag/FlagConfigUpdaterTest.kt create mode 100644 src/test/kotlin/util/SseStreamTest.kt diff --git a/src/main/kotlin/LocalEvaluationClient.kt b/src/main/kotlin/LocalEvaluationClient.kt index a9b9e5d..239b0e5 100644 --- a/src/main/kotlin/LocalEvaluationClient.kt +++ b/src/main/kotlin/LocalEvaluationClient.kt @@ -34,6 +34,10 @@ import com.amplitude.experiment.util.wrapMetrics import okhttp3.HttpUrl import okhttp3.HttpUrl.Companion.toHttpUrl import okhttp3.OkHttpClient +import okhttp3.Request +import okhttp3.sse.EventSource +import okhttp3.sse.EventSourceListener +import okhttp3.sse.EventSources class LocalEvaluationClient internal constructor( apiKey: String, diff --git a/src/main/kotlin/deployment/DeploymentRunner.kt b/src/main/kotlin/deployment/DeploymentRunner.kt index 6c38411..39bfd79 100644 --- a/src/main/kotlin/deployment/DeploymentRunner.kt +++ b/src/main/kotlin/deployment/DeploymentRunner.kt @@ -39,7 +39,7 @@ internal class DeploymentRunner( } private val cohortPollingInterval: Long = getCohortPollingInterval() // Fallback in this order: proxy, stream, poll. - private val amplitudeFlagConfigPoller = FlagConfigPoller(flagConfigApi, flagConfigStorage, cohortLoader, cohortStorage, config, metrics) + private val amplitudeFlagConfigPoller = FlagConfigFallbackRetryWrapper(FlagConfigPoller(flagConfigApi, flagConfigStorage, cohortLoader, cohortStorage, config, metrics), null, config.flagConfigPollerIntervalMillis, 1000) private val amplitudeFlagConfigUpdater = if (flagConfigStreamApi != null) FlagConfigFallbackRetryWrapper( diff --git a/src/main/kotlin/flag/FlagConfigStreamApi.kt b/src/main/kotlin/flag/FlagConfigStreamApi.kt index 919e936..4ff2891 100644 --- a/src/main/kotlin/flag/FlagConfigStreamApi.kt +++ b/src/main/kotlin/flag/FlagConfigStreamApi.kt @@ -6,6 +6,9 @@ import com.amplitude.experiment.util.SseStream import kotlinx.serialization.decodeFromString import okhttp3.HttpUrl import okhttp3.OkHttpClient +import okhttp3.Request +import okhttp3.sse.EventSource +import okhttp3.sse.EventSourceListener import java.util.concurrent.CompletableFuture import java.util.concurrent.ExecutionException import java.util.concurrent.TimeUnit @@ -27,7 +30,7 @@ internal class FlagConfigStreamApi ( deploymentKey: String, serverUrl: HttpUrl, httpClient: OkHttpClient = OkHttpClient(), - connectionTimeoutMillis: Long = CONNECTION_TIMEOUT_MILLIS_DEFAULT, + val connectionTimeoutMillis: Long = CONNECTION_TIMEOUT_MILLIS_DEFAULT, keepaliveTimeoutMillis: Long = KEEP_ALIVE_TIMEOUT_MILLIS_DEFAULT, reconnIntervalMillis: Long = RECONN_INTERVAL_MILLIS_DEFAULT ) { @@ -101,7 +104,7 @@ internal class FlagConfigStreamApi ( val t: Throwable try { - connectTimeoutFuture.get(2000, TimeUnit.MILLISECONDS) + connectTimeoutFuture.get(connectionTimeoutMillis, TimeUnit.MILLISECONDS) updateTimeoutFuture.get() return } catch (e: TimeoutException) { diff --git a/src/main/kotlin/flag/FlagConfigUpdater.kt b/src/main/kotlin/flag/FlagConfigUpdater.kt index 0cf3616..d99b8ba 100644 --- a/src/main/kotlin/flag/FlagConfigUpdater.kt +++ b/src/main/kotlin/flag/FlagConfigUpdater.kt @@ -14,6 +14,8 @@ import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.Executors import java.util.concurrent.ScheduledFuture import java.util.concurrent.TimeUnit +import kotlin.math.max +import kotlin.math.min internal interface FlagConfigUpdater { // Start the updater. There can be multiple calls. @@ -33,7 +35,6 @@ internal abstract class FlagConfigUpdaterBase( private val cohortStorage: CohortStorage?, ): FlagConfigUpdater { fun update(flagConfigs: List) { - println("update") // Remove flags that no longer exist. val flagKeys = flagConfigs.map { it.key }.toSet() flagConfigStorage.removeIf { !flagKeys.contains(it.key) } @@ -125,7 +126,6 @@ internal class FlagConfigPoller( private fun refresh() { Logger.d("Refreshing flag configs.") - println("flag poller refreshing") // Get updated flags from the network. val flagConfigs = wrapMetrics( metric = metrics::onFlagConfigFetch, @@ -135,7 +135,6 @@ internal class FlagConfigPoller( } update(flagConfigs) - println("flag poller refreshed") } } @@ -174,11 +173,11 @@ private const val RETRY_DELAY_MILLIS_DEFAULT = 15 * 1000L private const val MAX_JITTER_MILLIS_DEFAULT = 2000L internal class FlagConfigFallbackRetryWrapper( private val mainUpdater: FlagConfigUpdater, - private val fallbackUpdater: FlagConfigUpdater, + private val fallbackUpdater: FlagConfigUpdater?, private val retryDelayMillis: Long = RETRY_DELAY_MILLIS_DEFAULT, private val maxJitterMillis: Long = MAX_JITTER_MILLIS_DEFAULT ): FlagConfigUpdater { - private val reconnIntervalRange = (retryDelayMillis - maxJitterMillis)..(retryDelayMillis + maxJitterMillis) + private val reconnIntervalRange = max(0, retryDelayMillis - maxJitterMillis)..(min(retryDelayMillis, retryDelayMillis - maxJitterMillis) + maxJitterMillis) private val executor = Executors.newScheduledThreadPool(1, daemonFactory) private var retryTask: ScheduledFuture<*>? = null @@ -186,28 +185,32 @@ internal class FlagConfigFallbackRetryWrapper( try { mainUpdater.start { scheduleRetry(onError) // Don't care if poller start error or not, always retry. - try { - fallbackUpdater.start(onError) - } catch (_: Throwable) { + if (fallbackUpdater != null) { + try { + fallbackUpdater.start(onError) + } catch (_: Throwable) { + onError?.invoke() + } + } else { onError?.invoke() } } } catch (t: Throwable) { Logger.e("Primary flag configs start failed, start fallback. Error: ", t) - fallbackUpdater.start(onError) // If fallback failed, don't retry. + fallbackUpdater?.start(onError) scheduleRetry(onError) } } override fun stop() { mainUpdater.stop() - fallbackUpdater.stop() + fallbackUpdater?.stop() retryTask?.cancel(true) } override fun shutdown() { mainUpdater.shutdown() - fallbackUpdater.shutdown() + fallbackUpdater?.shutdown() retryTask?.cancel(true) } @@ -216,13 +219,17 @@ internal class FlagConfigFallbackRetryWrapper( try { mainUpdater.start { scheduleRetry(onError) // Don't care if poller start error or not, always retry stream. - try { - fallbackUpdater.start(onError) - } catch (_: Throwable) { + if (fallbackUpdater != null) { + try { + fallbackUpdater.start(onError) + } catch (_: Throwable) { + onError?.invoke() + } + } else { onError?.invoke() } } - fallbackUpdater.stop() + fallbackUpdater?.stop() } catch (_: Throwable) { scheduleRetry(onError) } diff --git a/src/main/kotlin/util/Request.kt b/src/main/kotlin/util/Request.kt index 1a826df..a8d9cc1 100644 --- a/src/main/kotlin/util/Request.kt +++ b/src/main/kotlin/util/Request.kt @@ -8,6 +8,9 @@ import okhttp3.HttpUrl import okhttp3.OkHttpClient import okhttp3.Request import okhttp3.Response +import okhttp3.sse.EventSource +import okhttp3.sse.EventSourceListener +import okhttp3.sse.EventSources import okio.IOException import java.util.concurrent.CompletableFuture @@ -111,3 +114,8 @@ internal inline fun OkHttpClient.get( } } } + +internal fun OkHttpClient.newEventSource(request: Request, eventSourceListener: EventSourceListener): EventSource { + // Creates an event source and immediately returns it. The connection is performed async. + return EventSources.createFactory(this).newEventSource(request, eventSourceListener) +} diff --git a/src/main/kotlin/util/SseStream.kt b/src/main/kotlin/util/SseStream.kt index 2776fd4..956cea0 100644 --- a/src/main/kotlin/util/SseStream.kt +++ b/src/main/kotlin/util/SseStream.kt @@ -2,6 +2,7 @@ package com.amplitude.experiment.util import com.amplitude.experiment.LIBRARY_VERSION import okhttp3.HttpUrl +import okhttp3.HttpUrl.Companion.toHttpUrl import okhttp3.OkHttpClient import okhttp3.Request import okhttp3.Response @@ -13,12 +14,15 @@ import okhttp3.sse.EventSources import java.util.* import java.util.concurrent.TimeUnit import kotlin.concurrent.schedule +import kotlin.math.max +import kotlin.math.min internal class StreamException(error: String): Throwable(error) private const val KEEP_ALIVE_TIMEOUT_MILLIS_DEFAULT = 0L // no timeout private const val RECONN_INTERVAL_MILLIS_DEFAULT = 30 * 60 * 1000L private const val MAX_JITTER_MILLIS_DEFAULT = 5000L +private const val KEEP_ALIVE_DATA = " " internal class SseStream ( private val authToken: String, private val url: HttpUrl, @@ -28,7 +32,7 @@ internal class SseStream ( private val reconnIntervalMillis: Long = RECONN_INTERVAL_MILLIS_DEFAULT, private val maxJitterMillis: Long = MAX_JITTER_MILLIS_DEFAULT ) { - private val reconnIntervalRange = (reconnIntervalMillis - maxJitterMillis)..(reconnIntervalMillis + maxJitterMillis) + private val reconnIntervalRange = max(0, reconnIntervalMillis - maxJitterMillis)..(min(reconnIntervalMillis, Long.MAX_VALUE - maxJitterMillis) + maxJitterMillis) private val eventSourceListener = object : EventSourceListener() { override fun onOpen(eventSource: EventSource, response: Response) { // No action needed. @@ -57,13 +61,15 @@ internal class SseStream ( return } // Keep alive data - if (" " == data) { + if (KEEP_ALIVE_DATA == data) { return } onUpdate?.let { it(data) } } override fun onFailure(eventSource: EventSource, t: Throwable?, response: Response?) { + println(t) + println(response) if ((eventSource != es)) { // Not the current event source using right now, should cancel. eventSource.cancel() @@ -75,14 +81,12 @@ internal class SseStream ( return } cancel() - var err = t - if (t == null) { - err = if (response != null) { + val err = t + ?: if (response != null) { StreamException(response.toString()) } else { StreamException("Unknown stream failure") } - } onError?.let { it(err) } } } @@ -102,9 +106,12 @@ internal class SseStream ( var onUpdate: ((String) -> Unit)? = null var onError: ((Throwable?) -> Unit)? = null + /** + * Creates an event source and immediately returns. The connection is performed async. Errors are informed through callbacks. + */ fun connect() { cancel() // Clear any existing event sources. - es = EventSources.createFactory(client).newEventSource(request = request, listener = eventSourceListener) + es = client.newEventSource(request, eventSourceListener) reconnectTimerTask = Timer().schedule(reconnIntervalRange.random()) {// Timer for a new event source. // This forces client side reconnection after interval. this@SseStream.cancel() diff --git a/src/test/kotlin/flag/FlagConfigStreamApiTest.kt b/src/test/kotlin/flag/FlagConfigStreamApiTest.kt new file mode 100644 index 0000000..a125b1f --- /dev/null +++ b/src/test/kotlin/flag/FlagConfigStreamApiTest.kt @@ -0,0 +1,152 @@ +package com.amplitude.experiment.flag + +import com.amplitude.experiment.Experiment +import com.amplitude.experiment.evaluation.EvaluationFlag +import com.amplitude.experiment.util.SseStream +import io.mockk.* +import okhttp3.HttpUrl +import okhttp3.HttpUrl.Companion.toHttpUrl +import okhttp3.OkHttpClient +import java.util.concurrent.CompletableFuture +import java.util.concurrent.TimeUnit +import kotlin.test.* + +class FlagConfigStreamApiTest { + private val onUpdateCapture = slot<((String) -> Unit)>() + private val onErrorCapture = slot<((Throwable?) -> Unit)>() + + private var data: Array> = arrayOf() + private var err: Array = arrayOf() + + @BeforeTest + fun beforeTest() { + mockkConstructor(SseStream::class) + + every { anyConstructed().connect() } answers { + Thread.sleep(1000) + } + every { anyConstructed().cancel() } answers { + Thread.sleep(1000) + } + every { anyConstructed().onUpdate = capture(onUpdateCapture) } answers {} + every { anyConstructed().onError = capture(onErrorCapture) } answers {} + } + + private fun setupApi( + deploymentKey: String = "", + serverUrl: HttpUrl = "http://localhost".toHttpUrl(), + connTimeout: Long = 2000 + ): FlagConfigStreamApi { + val api = FlagConfigStreamApi(deploymentKey, serverUrl, OkHttpClient(), connTimeout, 10000) + + api.onUpdate = { d -> + data += d + } + api.onError = { t -> + err += t + } + return api + } + + @Test + fun `Test passes correct arguments`() { + val api = setupApi("deplkey", "https://test.example.com".toHttpUrl()) + api.onInitUpdate = { d -> + data += d + } + + val run = async { + api.connect() + } + Thread.sleep(100) + onUpdateCapture.captured("[{\"key\":\"flagkey\",\"variants\":{},\"segments\":[]}]") + run.join() + + verify { anyConstructed().connect() } + assertContentEquals(arrayOf(listOf(EvaluationFlag("flagkey", emptyMap(), emptyList()))), data) + + api.close() + } + + @Test + fun `Test conn timeout doesn't block`() { + val api = setupApi("deplkey", "https://test.example.com".toHttpUrl(), 500) + try { + api.connect() + fail("Timeout not thrown") + } catch (_: FlagConfigStreamApiConnTimeoutError) { + } + Thread.sleep(100) + verify { anyConstructed().cancel() } + } + + @Test + fun `Test init update failure throws`() { + val api = setupApi("deplkey", "https://test.example.com".toHttpUrl(), 2000) + + api.onInitUpdate = { d -> + Thread.sleep(2100) // Update time is not included in connection timeout. + throw Error() + } + api.onUpdate = null + try { + api.connect() + fail("Timeout not thrown") + } catch (_: FlagConfigStreamApiConnTimeoutError) { + } + verify { anyConstructed().cancel() } + } + + @Test + fun `Test init update fallbacks to onUpdate when onInitUpdate = null`() { + val api = setupApi("deplkey", "https://test.example.com".toHttpUrl(), 2000) + + api.onInitUpdate = null + api.onUpdate = { d -> + Thread.sleep(2100) // Update time is not included in connection timeout. + throw Error() + } + try { + api.connect() + fail("Timeout not thrown") + } catch (_: FlagConfigStreamApiConnTimeoutError) { + } + verify { anyConstructed().cancel() } + } + + @Test + fun `Test error is passed through onError`() { + val api = setupApi("deplkey", "https://test.example.com".toHttpUrl(), 2000) + + val run = async { + api.connect() + } + Thread.sleep(100) + onUpdateCapture.captured("[]") + run.join() + + assertEquals(0, err.size) + onErrorCapture.captured(Error("Haha error")) + assertEquals("Stream error", err[0]?.message) + assertEquals("Haha error", err[0]?.cause?.message) + assertEquals(1, err.size) + verify { anyConstructed().cancel() } + } +} + +@Suppress("SameParameterValue") +private fun async(delayMillis: Long = 0L, block: () -> T): CompletableFuture { + return if (delayMillis == 0L) { + CompletableFuture.supplyAsync(block) + } else { + val future = CompletableFuture() + Experiment.scheduler.schedule({ + try { + future.complete(block.invoke()) + } catch (t: Throwable) { + future.completeExceptionally(t) + } + }, delayMillis, TimeUnit.MILLISECONDS) + future + } +} diff --git a/src/test/kotlin/flag/FlagConfigUpdaterTest.kt b/src/test/kotlin/flag/FlagConfigUpdaterTest.kt new file mode 100644 index 0000000..7f88b1b --- /dev/null +++ b/src/test/kotlin/flag/FlagConfigUpdaterTest.kt @@ -0,0 +1,351 @@ +package com.amplitude.experiment.flag + +import com.amplitude.experiment.LocalEvaluationConfig +import com.amplitude.experiment.evaluation.EvaluationFlag +import com.amplitude.experiment.util.SseStream +import io.mockk.* +import kotlin.test.* + +private val FLAG1 = EvaluationFlag("key1", emptyMap(), emptyList()) +private val FLAG2 = EvaluationFlag("key2", emptyMap(), emptyList()) +class FlagConfigPollerTest { + private var fetchApi = mockk() + private var storage = InMemoryFlagConfigStorage() + + @BeforeTest + fun beforeEach() { + fetchApi = mockk() + storage = InMemoryFlagConfigStorage() + } + + @Test + fun `Test Poller`() { + every { fetchApi.getFlagConfigs() } returns emptyList() + val poller = FlagConfigPoller(fetchApi, storage, null, null, LocalEvaluationConfig(flagConfigPollerIntervalMillis = 1000)) + var errorCount = 0 + poller.start { errorCount++ } + + // start() polls + verify(exactly = 1) { fetchApi.getFlagConfigs() } + assertEquals(0, storage.getFlagConfigs().size) + + // Poller polls every 1s interval + every { fetchApi.getFlagConfigs() } returns listOf(FLAG1) + Thread.sleep(1100) + verify(exactly = 2) { fetchApi.getFlagConfigs() } + assertEquals(1, storage.getFlagConfigs().size) + assertEquals(mapOf(FLAG1.key to FLAG1), storage.getFlagConfigs()) + Thread.sleep(1100) + verify(exactly = 3) { fetchApi.getFlagConfigs() } + + // Stop poller stops + poller.stop() + Thread.sleep(1100) + verify(exactly = 3) { fetchApi.getFlagConfigs() } + + // Restart poller + every { fetchApi.getFlagConfigs() } returns listOf(FLAG1, FLAG2) + poller.start() + verify(exactly = 4) { fetchApi.getFlagConfigs() } + Thread.sleep(1100) + verify(exactly = 5) { fetchApi.getFlagConfigs() } + assertEquals(2, storage.getFlagConfigs().size) + assertEquals(mapOf(FLAG1.key to FLAG1, FLAG2.key to FLAG2), storage.getFlagConfigs()) + + // No errors + assertEquals(0, errorCount) + + poller.shutdown() + } + + @Test + fun `Test Poller start fails`(){ + every { fetchApi.getFlagConfigs() } answers { throw Error("Haha error") } + val poller = FlagConfigPoller(fetchApi, storage, null, null, LocalEvaluationConfig(flagConfigPollerIntervalMillis = 1000)) + var errorCount = 0 + try { + poller.start { errorCount++ } + fail("Poller start error not throwing") + } catch (_: Throwable) { + } + verify(exactly = 1) { fetchApi.getFlagConfigs() } + + // Poller stops + Thread.sleep(1100) + verify(exactly = 1) { fetchApi.getFlagConfigs() } + assertEquals(0, errorCount) + + poller.shutdown() + } + + @Test + fun `Test Poller poll fails`(){ + every { fetchApi.getFlagConfigs() } returns emptyList() + val poller = FlagConfigPoller(fetchApi, storage, null, null, LocalEvaluationConfig(flagConfigPollerIntervalMillis = 1000)) + var errorCount = 0 + poller.start { errorCount++ } + + // Poller start success + verify(exactly = 1) { fetchApi.getFlagConfigs() } + assertEquals(0, errorCount) + + // Next poll subsequent fails + every { fetchApi.getFlagConfigs() } answers { throw Error("Haha error") } + Thread.sleep(1100) + verify(exactly = 2) { fetchApi.getFlagConfigs() } + assertEquals(1, errorCount) + + // Poller stops + Thread.sleep(1100) + verify(exactly = 2) { fetchApi.getFlagConfigs() } + assertEquals(1, errorCount) + + poller.shutdown() + } +} + +class FlagConfigStreamerTest { + private val onUpdateCapture = slot<((List) -> Unit)>() + private val onErrorCapture = slot<((Throwable?) -> Unit)>() + private var streamApi = mockk() + private var storage = InMemoryFlagConfigStorage() + private val config = LocalEvaluationConfig(streamUpdates = true, streamServerUrl = "", streamFlagConnTimeoutMillis = 2000) + + @BeforeTest + fun beforeEach() { + streamApi = mockk() + storage = InMemoryFlagConfigStorage() + + justRun { streamApi.onUpdate = capture(onUpdateCapture) } + justRun { streamApi.onError = capture(onErrorCapture) } + } + + @Test + fun `Test Poller`() { + justRun { streamApi.connect() } + val streamer = FlagConfigStreamer(streamApi, storage, null, null, config) + var errorCount = 0 + streamer.start { errorCount++ } + + // Streamer starts + verify(exactly = 1) { streamApi.connect() } + + // Verify update callback updates storage + onUpdateCapture.captured(emptyList()) + assertEquals(0, storage.getFlagConfigs().size) + onUpdateCapture.captured(listOf(FLAG1)) + assertEquals(mapOf(FLAG1.key to FLAG1), storage.getFlagConfigs()) + onUpdateCapture.captured(listOf(FLAG1, FLAG2)) + assertEquals(mapOf(FLAG1.key to FLAG1, FLAG2.key to FLAG2), storage.getFlagConfigs()) + + // No extra connect calls + verify(exactly = 1) { streamApi.connect() } + + // No errors + assertEquals(0, errorCount) + } + + @Test + fun `Test Streamer start fails`(){ + every { streamApi.connect() } answers { throw Error("Haha error") } + val streamer = FlagConfigStreamer(streamApi, storage, null, null, config) + var errorCount = 0 + try { + streamer.start { errorCount++ } + fail("Streamer start error not throwing") + } catch (_: Throwable) { + } + verify(exactly = 1) { streamApi.connect() } + assertEquals(0, errorCount) // No error callback as it throws directly + } + + @Test + fun `Test Streamer stream fails`(){ + every { streamApi.connect() } answers { throw Error("Haha error") } + val streamer = FlagConfigStreamer(streamApi, storage, null, null, config) + var errorCount = 0 + streamer.start { errorCount++ } + + // Stream start success + verify(exactly = 1) { streamApi.connect() } + onUpdateCapture.captured(listOf(FLAG1)) + assertEquals(mapOf(FLAG1.key to FLAG1), storage.getFlagConfigs()) + assertEquals(0, errorCount) + + // Stream fails + onErrorCapture.captured(Error("Haha error")) + assertEquals(1, errorCount) // Error callback is called + } +} + + +class FlagConfigFallbackRetryWrapperTest { + private val mainOnErrorCapture = slot<(() -> Unit)>() + private val fallbackOnErrorCapture = slot<(() -> Unit)>() + + private var mainUpdater = mockk() + private var fallbackUpdater = mockk() + @BeforeTest + fun beforeEach() { + mainUpdater = mockk() + fallbackUpdater = mockk() + + justRun { mainUpdater.start(capture(mainOnErrorCapture)) } + justRun { mainUpdater.stop() } + justRun { mainUpdater.shutdown() } + justRun { fallbackUpdater.start(capture(fallbackOnErrorCapture)) } + justRun { fallbackUpdater.stop() } + justRun { fallbackUpdater.shutdown() } + } + + @Test + fun `Test FallbackRetryWrapper main updater all success`() { + val wrapper = FlagConfigFallbackRetryWrapper(mainUpdater, fallbackUpdater, 1000, 0) + var errorCount = 0 + + // Main starts + wrapper.start { errorCount++ } + verify(exactly = 1) { mainUpdater.start(any()) } + verify(exactly = 0) { fallbackUpdater.start() } + assertEquals(0, errorCount) + + // Stop + wrapper.stop() + verify(exactly = 1) { mainUpdater.stop() } + verify(exactly = 1) { fallbackUpdater.stop() } + assertEquals(0, errorCount) + + // Start again + wrapper.start { errorCount++ } + verify(exactly = 2) { mainUpdater.start(any()) } + verify(exactly = 0) { fallbackUpdater.start() } + assertEquals(0, errorCount) + + // Shutdown + wrapper.shutdown() + verify(exactly = 1) { mainUpdater.shutdown() } + verify(exactly = 1) { mainUpdater.shutdown() } + } + + @Test + fun `Test FallbackRetryWrapper main success no fallback updater`() { + val wrapper = FlagConfigFallbackRetryWrapper(mainUpdater, null, 1000, 0) + var errorCount = 0 + + // Main starts + wrapper.start { errorCount++ } + verify(exactly = 1) { mainUpdater.start(any()) } + assertEquals(0, errorCount) + + // Stop + wrapper.stop() + verify(exactly = 1) { mainUpdater.stop() } + assertEquals(0, errorCount) + + // Start again + wrapper.start { errorCount++ } + verify(exactly = 2) { mainUpdater.start(any()) } + assertEquals(0, errorCount) + + // Shutdown + wrapper.shutdown() + verify(exactly = 1) { mainUpdater.shutdown() } + } + + @Test + fun `Test FallbackRetryWrapper main start error and retries with no fallback updater`() { + val wrapper = FlagConfigFallbackRetryWrapper(mainUpdater, null, 1000, 0) + var errorCount = 0 + + every { mainUpdater.start(capture(mainOnErrorCapture)) } answers { throw Error() } + + // Main start fail, no error, same as success case + wrapper.start { errorCount++ } + verify(exactly = 1) { mainUpdater.start(any()) } + assertEquals(0, errorCount) + + // Retries start + Thread.sleep(1100) + verify(exactly = 2) { mainUpdater.start(any()) } + assertEquals(0, errorCount) + + wrapper.shutdown() + } + + @Test + fun `Test FallbackRetryWrapper main error callback and retries with no fallback updater`() { + val wrapper = FlagConfigFallbackRetryWrapper(mainUpdater, null, 1000, 0) + var errorCount = 0 + + // Main start success + wrapper.start { errorCount++ } + verify(exactly = 1) { mainUpdater.start(any()) } + assertEquals(0, errorCount) + + // Signal error + mainOnErrorCapture.captured() + verify(exactly = 1) { mainUpdater.start(any()) } + assertEquals(1, errorCount) // Updater failure from success calls callback + + // Retry fail after 1s + every { mainUpdater.start(capture(mainOnErrorCapture)) } answers { throw Error() } + Thread.sleep(1100) + verify(exactly = 2) { mainUpdater.start(any()) } + assertEquals(1, errorCount) // Updater restart error doesn't call callback + + // Retry success after 1s + justRun { mainUpdater.start(capture(mainOnErrorCapture)) } + Thread.sleep(1100) + verify(exactly = 3) { mainUpdater.start(any()) } + assertEquals(1, errorCount) + + // No more start + Thread.sleep(1100) + verify(exactly = 3) { mainUpdater.start(any()) } + assertEquals(1, errorCount) + + wrapper.shutdown() + } + + @Test + fun `Test FallbackRetryWrapper main error callback and retries`() { +// val wrapper = FlagConfigFallbackRetryWrapper(mainUpdater, fallbackUpdater, 1000, 0) +// var errorCount = 0 +// +// // Main start success +// wrapper.start { errorCount++ } +// verify(exactly = 1) { mainUpdater.start(any()) } +// verify(exactly = 0) { fallbackUpdater.start(any()) } +// assertEquals(0, errorCount) +// +// // Signal error +// mainOnErrorCapture.captured() +// verify(exactly = 1) { mainUpdater.start(any()) } +// verify(exactly = 1) { fallbackUpdater.start(any()) } +// assertEquals(0, errorCount) // Fallback succeeded, so no callback +// +// // Retry fail after 1s +// every { mainUpdater.start(capture(mainOnErrorCapture)) } answers { throw Error() } +// Thread.sleep(1100) +// verify(exactly = 2) { mainUpdater.start(any()) } +// verify(exactly = 1) { fallbackUpdater.start(any()) } +// assertEquals(0, errorCount) +// +// // Signal fallback fails +// fallbackOnErrorCapture.captured() +// +// +// // Retry success after 1s +// justRun { mainUpdater.start(capture(mainOnErrorCapture)) } +// Thread.sleep(1100) +// verify(exactly = 3) { mainUpdater.start(any()) } +// assertEquals(1, errorCount) +// +// // No more start +// Thread.sleep(1100) +// verify(exactly = 3) { mainUpdater.start(any()) } +// assertEquals(1, errorCount) +// +// wrapper.shutdown() + } +} \ No newline at end of file diff --git a/src/test/kotlin/util/SseStreamTest.kt b/src/test/kotlin/util/SseStreamTest.kt new file mode 100644 index 0000000..535668d --- /dev/null +++ b/src/test/kotlin/util/SseStreamTest.kt @@ -0,0 +1,100 @@ +package com.amplitude.experiment.util + +import com.amplitude.experiment.ExperimentUser +import com.amplitude.experiment.RemoteEvaluationClient +import io.mockk.* +import okhttp3.HttpUrl +import okhttp3.HttpUrl.Companion.toHttpUrl +import okhttp3.HttpUrl.Companion.toHttpUrlOrNull +import okhttp3.OkHttpClient +import okhttp3.Request +import okhttp3.sse.EventSource +import okhttp3.sse.EventSourceListener +import org.mockito.Mockito +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.assertEquals + + +class SseStreamTest { + private val listenerCapture = slot() + val clientMock = mockk() + private val es = mockk("mocked es") + + private var data: List = listOf() + private var err: List = listOf() + + @BeforeTest + fun beforeTest() { + mockkStatic("com.amplitude.experiment.util.RequestKt") + + justRun { es.cancel() } + every { clientMock.newEventSource(any(), capture(listenerCapture)) } returns es + + mockkConstructor(OkHttpClient.Builder::class) + every { anyConstructed().build() } returns clientMock + + } + + private fun setupStream( + reconnTimeout: Long = 5000 + ): SseStream { + val stream = SseStream("authtoken", "http://localhost".toHttpUrl(), OkHttpClient(), 1000, 1000, reconnTimeout, 0) + + stream.onUpdate = { d -> + data += d + } + stream.onError = { t -> + err += t + } + return stream + } + + @Test + fun `Test SseStream connect`() { + val stream = setupStream() + stream.connect() + + listenerCapture.captured.onEvent(es, null, null, "somedata") + assertEquals(listOf("somedata"), data) + listenerCapture.captured.onFailure(es, null, null) + assertEquals("Unknown stream failure", err[0]?.message) + listenerCapture.captured.onEvent(es, null, null, "nodata") + assertEquals(listOf("somedata"), data) + + stream.cancel() + } + + @Test + fun `Test SseStream keep alive data omits`() { + val stream = setupStream(1000) + stream.connect() + + listenerCapture.captured.onEvent(es, null, null, "somedata") + assertEquals(listOf("somedata"), data) + listenerCapture.captured.onEvent(es, null, null, " ") + assertEquals(listOf("somedata"), data) + + stream.cancel() + } + + @Test + fun `Test SseStream reconnects`() { + val stream = setupStream(1000) + stream.connect() + + listenerCapture.captured.onEvent(es, null, null, "somedata") + assertEquals(listOf("somedata"), data) + verify(exactly = 1) { + clientMock.newEventSource(allAny(), allAny()) + } + + Thread.sleep(1100) // Wait 1s for reconnection + + listenerCapture.captured.onEvent(es, null, null, "somedata") + assertEquals(listOf("somedata", "somedata"), data) + verify(exactly = 2) { + clientMock.newEventSource(allAny(), allAny()) + } + } +} From 8d635203826b44f1403964b8760896cda9b69a2d Mon Sep 17 00:00:00 2001 From: Peter Zhu Date: Tue, 3 Sep 2024 04:36:52 -0700 Subject: [PATCH 04/12] remove updater onError cb, add updater tests --- .../kotlin/deployment/DeploymentRunner.kt | 5 +- src/main/kotlin/flag/FlagConfigUpdater.kt | 41 +++-- src/test/kotlin/flag/FlagConfigUpdaterTest.kt | 156 +++++++++--------- 3 files changed, 100 insertions(+), 102 deletions(-) diff --git a/src/main/kotlin/deployment/DeploymentRunner.kt b/src/main/kotlin/deployment/DeploymentRunner.kt index 39bfd79..5d13fac 100644 --- a/src/main/kotlin/deployment/DeploymentRunner.kt +++ b/src/main/kotlin/deployment/DeploymentRunner.kt @@ -39,7 +39,10 @@ internal class DeploymentRunner( } private val cohortPollingInterval: Long = getCohortPollingInterval() // Fallback in this order: proxy, stream, poll. - private val amplitudeFlagConfigPoller = FlagConfigFallbackRetryWrapper(FlagConfigPoller(flagConfigApi, flagConfigStorage, cohortLoader, cohortStorage, config, metrics), null, config.flagConfigPollerIntervalMillis, 1000) + private val amplitudeFlagConfigPoller = FlagConfigFallbackRetryWrapper( + FlagConfigPoller(flagConfigApi, flagConfigStorage, cohortLoader, cohortStorage, config, metrics), + null, config.flagConfigPollerIntervalMillis, 1000 + ) private val amplitudeFlagConfigUpdater = if (flagConfigStreamApi != null) FlagConfigFallbackRetryWrapper( diff --git a/src/main/kotlin/flag/FlagConfigUpdater.kt b/src/main/kotlin/flag/FlagConfigUpdater.kt index d99b8ba..ae89058 100644 --- a/src/main/kotlin/flag/FlagConfigUpdater.kt +++ b/src/main/kotlin/flag/FlagConfigUpdater.kt @@ -181,24 +181,26 @@ internal class FlagConfigFallbackRetryWrapper( private val executor = Executors.newScheduledThreadPool(1, daemonFactory) private var retryTask: ScheduledFuture<*>? = null + /** + * Since the wrapper retries, so there will never be error case. Thus, onError will never be called. + */ override fun start(onError: (() -> Unit)?) { + if (mainUpdater is FlagConfigFallbackRetryWrapper) { + throw Error("Do not use FlagConfigFallbackRetryWrapper as main updater. Fallback updater will never be used. Rewrite retry and fallback logic.") + } + try { mainUpdater.start { - scheduleRetry(onError) // Don't care if poller start error or not, always retry. - if (fallbackUpdater != null) { - try { - fallbackUpdater.start(onError) - } catch (_: Throwable) { - onError?.invoke() - } - } else { - onError?.invoke() + scheduleRetry() // Don't care if poller start error or not, always retry. + try { + fallbackUpdater?.start() + } catch (_: Throwable) { } } } catch (t: Throwable) { Logger.e("Primary flag configs start failed, start fallback. Error: ", t) - fallbackUpdater?.start(onError) - scheduleRetry(onError) + fallbackUpdater?.start() + scheduleRetry() } } @@ -214,24 +216,19 @@ internal class FlagConfigFallbackRetryWrapper( retryTask?.cancel(true) } - private fun scheduleRetry(onError: (() -> Unit)?) { + private fun scheduleRetry() { retryTask = executor.schedule({ try { mainUpdater.start { - scheduleRetry(onError) // Don't care if poller start error or not, always retry stream. - if (fallbackUpdater != null) { - try { - fallbackUpdater.start(onError) - } catch (_: Throwable) { - onError?.invoke() - } - } else { - onError?.invoke() + scheduleRetry() // Don't care if poller start error or not, always retry stream. + try { + fallbackUpdater?.start() + } catch (_: Throwable) { } } fallbackUpdater?.stop() } catch (_: Throwable) { - scheduleRetry(onError) + scheduleRetry() } }, reconnIntervalRange.random(), TimeUnit.MILLISECONDS) } diff --git a/src/test/kotlin/flag/FlagConfigUpdaterTest.kt b/src/test/kotlin/flag/FlagConfigUpdaterTest.kt index 7f88b1b..11585d6 100644 --- a/src/test/kotlin/flag/FlagConfigUpdaterTest.kt +++ b/src/test/kotlin/flag/FlagConfigUpdaterTest.kt @@ -181,7 +181,6 @@ class FlagConfigStreamerTest { class FlagConfigFallbackRetryWrapperTest { private val mainOnErrorCapture = slot<(() -> Unit)>() - private val fallbackOnErrorCapture = slot<(() -> Unit)>() private var mainUpdater = mockk() private var fallbackUpdater = mockk() @@ -193,159 +192,158 @@ class FlagConfigFallbackRetryWrapperTest { justRun { mainUpdater.start(capture(mainOnErrorCapture)) } justRun { mainUpdater.stop() } justRun { mainUpdater.shutdown() } - justRun { fallbackUpdater.start(capture(fallbackOnErrorCapture)) } + justRun { fallbackUpdater.start() } // Fallback is never passed onError callback, no need to capture justRun { fallbackUpdater.stop() } justRun { fallbackUpdater.shutdown() } } @Test - fun `Test FallbackRetryWrapper main updater all success`() { - val wrapper = FlagConfigFallbackRetryWrapper(mainUpdater, fallbackUpdater, 1000, 0) - var errorCount = 0 + fun `Test FallbackRetryWrapper main success no fallback updater`() { + val wrapper = FlagConfigFallbackRetryWrapper(mainUpdater, null, 1000, 0) // Main starts - wrapper.start { errorCount++ } + wrapper.start() verify(exactly = 1) { mainUpdater.start(any()) } - verify(exactly = 0) { fallbackUpdater.start() } - assertEquals(0, errorCount) // Stop wrapper.stop() verify(exactly = 1) { mainUpdater.stop() } - verify(exactly = 1) { fallbackUpdater.stop() } - assertEquals(0, errorCount) // Start again - wrapper.start { errorCount++ } + wrapper.start() verify(exactly = 2) { mainUpdater.start(any()) } - verify(exactly = 0) { fallbackUpdater.start() } - assertEquals(0, errorCount) // Shutdown wrapper.shutdown() verify(exactly = 1) { mainUpdater.shutdown() } - verify(exactly = 1) { mainUpdater.shutdown() } } @Test - fun `Test FallbackRetryWrapper main success no fallback updater`() { + fun `Test FallbackRetryWrapper main start error and retries with no fallback updater`() { val wrapper = FlagConfigFallbackRetryWrapper(mainUpdater, null, 1000, 0) - var errorCount = 0 + + every { mainUpdater.start(capture(mainOnErrorCapture)) } answers { throw Error() } + + // Main start fail, no error, same as success case + wrapper.start() + verify(exactly = 1) { mainUpdater.start(any()) } + + // Retries start + Thread.sleep(1100) + verify(exactly = 2) { mainUpdater.start(any()) } + + wrapper.shutdown() + } + + @Test + fun `Test FallbackRetryWrapper main error callback and retries with no fallback updater`() { + val wrapper = FlagConfigFallbackRetryWrapper(mainUpdater, null, 1000, 0) + + // Main start success + wrapper.start() + verify(exactly = 1) { mainUpdater.start(any()) } + + // Signal error + mainOnErrorCapture.captured() + verify(exactly = 1) { mainUpdater.start(any()) } + + // Retry fail after 1s + every { mainUpdater.start(capture(mainOnErrorCapture)) } answers { throw Error() } + Thread.sleep(1100) + verify(exactly = 2) { mainUpdater.start(any()) } + + // Retry success after 1s + justRun { mainUpdater.start(capture(mainOnErrorCapture)) } + Thread.sleep(1100) + verify(exactly = 3) { mainUpdater.start(any()) } + + // No more start + Thread.sleep(1100) + verify(exactly = 3) { mainUpdater.start(any()) } + + wrapper.shutdown() + } + + @Test + fun `Test FallbackRetryWrapper main updater all success`() { + val wrapper = FlagConfigFallbackRetryWrapper(mainUpdater, fallbackUpdater, 1000, 0) // Main starts - wrapper.start { errorCount++ } + wrapper.start() verify(exactly = 1) { mainUpdater.start(any()) } - assertEquals(0, errorCount) + verify(exactly = 0) { fallbackUpdater.start() } // Stop wrapper.stop() verify(exactly = 1) { mainUpdater.stop() } - assertEquals(0, errorCount) + verify(exactly = 1) { fallbackUpdater.stop() } // Start again - wrapper.start { errorCount++ } + wrapper.start() verify(exactly = 2) { mainUpdater.start(any()) } - assertEquals(0, errorCount) + verify(exactly = 0) { fallbackUpdater.start() } // Shutdown wrapper.shutdown() verify(exactly = 1) { mainUpdater.shutdown() } + verify(exactly = 1) { mainUpdater.shutdown() } } @Test - fun `Test FallbackRetryWrapper main start error and retries with no fallback updater`() { - val wrapper = FlagConfigFallbackRetryWrapper(mainUpdater, null, 1000, 0) - var errorCount = 0 + fun `Test FallbackRetryWrapper main start error and retries`() { + val wrapper = FlagConfigFallbackRetryWrapper(mainUpdater, fallbackUpdater, 1000, 0) every { mainUpdater.start(capture(mainOnErrorCapture)) } answers { throw Error() } // Main start fail, no error, same as success case - wrapper.start { errorCount++ } + wrapper.start() verify(exactly = 1) { mainUpdater.start(any()) } - assertEquals(0, errorCount) + verify(exactly = 1) { fallbackUpdater.start(any()) } // Retries start Thread.sleep(1100) verify(exactly = 2) { mainUpdater.start(any()) } - assertEquals(0, errorCount) + verify(exactly = 1) { fallbackUpdater.start(any()) } wrapper.shutdown() } @Test - fun `Test FallbackRetryWrapper main error callback and retries with no fallback updater`() { - val wrapper = FlagConfigFallbackRetryWrapper(mainUpdater, null, 1000, 0) - var errorCount = 0 + fun `Test FallbackRetryWrapper main error callback and retries`() { + val wrapper = FlagConfigFallbackRetryWrapper(mainUpdater, fallbackUpdater, 1000, 0) // Main start success - wrapper.start { errorCount++ } + wrapper.start() verify(exactly = 1) { mainUpdater.start(any()) } - assertEquals(0, errorCount) + verify(exactly = 0) { fallbackUpdater.start(any()) } // Signal error mainOnErrorCapture.captured() verify(exactly = 1) { mainUpdater.start(any()) } - assertEquals(1, errorCount) // Updater failure from success calls callback + verify(exactly = 1) { fallbackUpdater.start(any()) } // Retry fail after 1s every { mainUpdater.start(capture(mainOnErrorCapture)) } answers { throw Error() } Thread.sleep(1100) verify(exactly = 2) { mainUpdater.start(any()) } - assertEquals(1, errorCount) // Updater restart error doesn't call callback + verify(exactly = 1) { fallbackUpdater.start(any()) } - // Retry success after 1s + // Retry success justRun { mainUpdater.start(capture(mainOnErrorCapture)) } + verify(exactly = 0) { fallbackUpdater.stop() } Thread.sleep(1100) verify(exactly = 3) { mainUpdater.start(any()) } - assertEquals(1, errorCount) + verify(exactly = 1) { fallbackUpdater.start(any()) } + verify(exactly = 0) { mainUpdater.stop() } + verify(exactly = 1) { fallbackUpdater.stop() } // No more start Thread.sleep(1100) verify(exactly = 3) { mainUpdater.start(any()) } - assertEquals(1, errorCount) + verify(exactly = 1) { fallbackUpdater.start(any()) } + verify(exactly = 0) { mainUpdater.stop() } + verify(exactly = 1) { fallbackUpdater.stop() } wrapper.shutdown() } - - @Test - fun `Test FallbackRetryWrapper main error callback and retries`() { -// val wrapper = FlagConfigFallbackRetryWrapper(mainUpdater, fallbackUpdater, 1000, 0) -// var errorCount = 0 -// -// // Main start success -// wrapper.start { errorCount++ } -// verify(exactly = 1) { mainUpdater.start(any()) } -// verify(exactly = 0) { fallbackUpdater.start(any()) } -// assertEquals(0, errorCount) -// -// // Signal error -// mainOnErrorCapture.captured() -// verify(exactly = 1) { mainUpdater.start(any()) } -// verify(exactly = 1) { fallbackUpdater.start(any()) } -// assertEquals(0, errorCount) // Fallback succeeded, so no callback -// -// // Retry fail after 1s -// every { mainUpdater.start(capture(mainOnErrorCapture)) } answers { throw Error() } -// Thread.sleep(1100) -// verify(exactly = 2) { mainUpdater.start(any()) } -// verify(exactly = 1) { fallbackUpdater.start(any()) } -// assertEquals(0, errorCount) -// -// // Signal fallback fails -// fallbackOnErrorCapture.captured() -// -// -// // Retry success after 1s -// justRun { mainUpdater.start(capture(mainOnErrorCapture)) } -// Thread.sleep(1100) -// verify(exactly = 3) { mainUpdater.start(any()) } -// assertEquals(1, errorCount) -// -// // No more start -// Thread.sleep(1100) -// verify(exactly = 3) { mainUpdater.start(any()) } -// assertEquals(1, errorCount) -// -// wrapper.shutdown() - } } \ No newline at end of file From a6814aa9c60894c4d032197a7a94a75e3f3d2ea5 Mon Sep 17 00:00:00 2001 From: Peter Zhu Date: Tue, 3 Sep 2024 05:40:20 -0700 Subject: [PATCH 05/12] fix bugs, added tests --- src/main/kotlin/flag/FlagConfigUpdater.kt | 6 +- src/main/kotlin/util/SseStream.kt | 2 - src/test/kotlin/LocalEvaluationClientTest.kt | 51 ++++++++++++++ .../kotlin/flag/FlagConfigStreamApiTest.kt | 5 ++ src/test/kotlin/flag/FlagConfigUpdaterTest.kt | 70 ++++++++++++++++--- src/test/kotlin/util/SseStreamTest.kt | 7 +- 6 files changed, 129 insertions(+), 12 deletions(-) diff --git a/src/main/kotlin/flag/FlagConfigUpdater.kt b/src/main/kotlin/flag/FlagConfigUpdater.kt index ae89058..471a61b 100644 --- a/src/main/kotlin/flag/FlagConfigUpdater.kt +++ b/src/main/kotlin/flag/FlagConfigUpdater.kt @@ -199,7 +199,11 @@ internal class FlagConfigFallbackRetryWrapper( } } catch (t: Throwable) { Logger.e("Primary flag configs start failed, start fallback. Error: ", t) - fallbackUpdater?.start() + if (fallbackUpdater == null) { + // No fallback, main start failed is wrapper start fail + throw t + } + fallbackUpdater.start() scheduleRetry() } } diff --git a/src/main/kotlin/util/SseStream.kt b/src/main/kotlin/util/SseStream.kt index 956cea0..bb5e82e 100644 --- a/src/main/kotlin/util/SseStream.kt +++ b/src/main/kotlin/util/SseStream.kt @@ -68,8 +68,6 @@ internal class SseStream ( } override fun onFailure(eventSource: EventSource, t: Throwable?, response: Response?) { - println(t) - println(response) if ((eventSource != es)) { // Not the current event source using right now, should cancel. eventSource.cancel() diff --git a/src/test/kotlin/LocalEvaluationClientTest.kt b/src/test/kotlin/LocalEvaluationClientTest.kt index 7ad24a0..68f5b1c 100644 --- a/src/test/kotlin/LocalEvaluationClientTest.kt +++ b/src/test/kotlin/LocalEvaluationClientTest.kt @@ -2,17 +2,26 @@ package com.amplitude.experiment import com.amplitude.experiment.cohort.Cohort import com.amplitude.experiment.cohort.CohortApi +import com.amplitude.experiment.flag.FlagConfigPoller +import io.mockk.clearAllMocks import io.mockk.every import io.mockk.mockk +import io.mockk.mockkConstructor import org.junit.Assert import org.junit.Assert.assertEquals import org.junit.Assert.assertNull import kotlin.system.measureNanoTime +import kotlin.test.AfterTest +import kotlin.test.BeforeTest import kotlin.test.Test private const val API_KEY = "server-qz35UwzJ5akieoAdIgzM4m9MIiOLXLoz" class LocalEvaluationClientTest { + @AfterTest + fun afterTest() { + clearAllMocks() + } @Test fun `test evaluate, all flags, success`() { @@ -193,6 +202,7 @@ class LocalEvaluationClientTest { assertEquals("on", userVariant?.key) assertEquals("on", userVariant?.value) } + @Test fun `evaluate with user, cohort tester targeted`() { val cohortConfig = LocalEvaluationConfig( @@ -238,6 +248,7 @@ class LocalEvaluationClientTest { assertEquals("on", groupVariant?.key) assertEquals("on", groupVariant?.value) } + @Test fun `evaluate with group, cohort tester targeted`() { val cohortConfig = LocalEvaluationConfig( @@ -261,4 +272,44 @@ class LocalEvaluationClientTest { assertEquals("on", groupVariant?.key) assertEquals("on", groupVariant?.value) } + + @Test + fun `test evaluate, stream flags, all flags, success`() { + mockkConstructor(FlagConfigPoller::class) + every { anyConstructed().start(any()) } answers { + throw Exception("Should use stream, may be flaky test when stream failed") + } + val client = LocalEvaluationClient(API_KEY, LocalEvaluationConfig(streamUpdates = true)) + client.start() + val variants = client.evaluate(ExperimentUser(userId = "test_user")) + val variant = variants["sdk-local-evaluation-ci-test"] + Assert.assertEquals(Variant(key = "on", value = "on", payload = "payload"), variant?.copy(metadata = null)) + } + + @Test + fun `evaluate with user, stream flags, cohort segment targeted`() { + mockkConstructor(FlagConfigPoller::class) + every { anyConstructed().start(any()) } answers { + throw Exception("Should use stream, may be flaky test when stream failed") + } + val cohortConfig = LocalEvaluationConfig( + streamUpdates = true, + cohortSyncConfig = CohortSyncConfig("api", "secret") + ) + val cohortApi = mockk().apply { + every { getCohort(eq("52gz3yi7"), allAny()) } returns Cohort("52gz3yi7", "User", 2, 1722363790000, setOf("1", "2")) + every { getCohort(eq("mv7fn2bp"), allAny()) } returns Cohort("mv7fn2bp", "User", 1, 1719350216000, setOf("67890", "12345")) + every { getCohort(eq("s4t57y32"), allAny()) } returns Cohort("s4t57y32", "org name", 1, 1722368285000, setOf("Amplitude Website (Portfolio)")) + every { getCohort(eq("k1lklnnb"), allAny()) } returns Cohort("k1lklnnb", "org id", 1, 1722466388000, setOf("1")) + } + val client = LocalEvaluationClient(API_KEY, cohortConfig, cohortApi = cohortApi) + client.start() + val user = ExperimentUser( + userId = "12345", + deviceId = "device_id", + ) + val userVariant = client.evaluateV2(user, setOf("sdk-local-evaluation-user-cohort-ci-test"))["sdk-local-evaluation-user-cohort-ci-test"] + assertEquals("on", userVariant?.key) + assertEquals("on", userVariant?.value) + } } diff --git a/src/test/kotlin/flag/FlagConfigStreamApiTest.kt b/src/test/kotlin/flag/FlagConfigStreamApiTest.kt index a125b1f..34da4e8 100644 --- a/src/test/kotlin/flag/FlagConfigStreamApiTest.kt +++ b/src/test/kotlin/flag/FlagConfigStreamApiTest.kt @@ -32,6 +32,11 @@ class FlagConfigStreamApiTest { every { anyConstructed().onError = capture(onErrorCapture) } answers {} } + @AfterTest + fun afterTest() { + clearAllMocks() + } + private fun setupApi( deploymentKey: String = "", serverUrl: HttpUrl = "http://localhost".toHttpUrl(), diff --git a/src/test/kotlin/flag/FlagConfigUpdaterTest.kt b/src/test/kotlin/flag/FlagConfigUpdaterTest.kt index 11585d6..4ee0b15 100644 --- a/src/test/kotlin/flag/FlagConfigUpdaterTest.kt +++ b/src/test/kotlin/flag/FlagConfigUpdaterTest.kt @@ -4,6 +4,7 @@ import com.amplitude.experiment.LocalEvaluationConfig import com.amplitude.experiment.evaluation.EvaluationFlag import com.amplitude.experiment.util.SseStream import io.mockk.* +import java.lang.Exception import kotlin.test.* private val FLAG1 = EvaluationFlag("key1", emptyMap(), emptyList()) @@ -13,11 +14,16 @@ class FlagConfigPollerTest { private var storage = InMemoryFlagConfigStorage() @BeforeTest - fun beforeEach() { + fun beforeTest() { fetchApi = mockk() storage = InMemoryFlagConfigStorage() } + @AfterTest + fun afterTest() { + clearAllMocks() + } + @Test fun `Test Poller`() { every { fetchApi.getFlagConfigs() } returns emptyList() @@ -112,7 +118,7 @@ class FlagConfigStreamerTest { private val config = LocalEvaluationConfig(streamUpdates = true, streamServerUrl = "", streamFlagConnTimeoutMillis = 2000) @BeforeTest - fun beforeEach() { + fun beforeTest() { streamApi = mockk() storage = InMemoryFlagConfigStorage() @@ -120,6 +126,11 @@ class FlagConfigStreamerTest { justRun { streamApi.onError = capture(onErrorCapture) } } + @AfterTest + fun afterTest() { + clearAllMocks() + } + @Test fun `Test Poller`() { justRun { streamApi.connect() } @@ -161,7 +172,7 @@ class FlagConfigStreamerTest { @Test fun `Test Streamer stream fails`(){ - every { streamApi.connect() } answers { throw Error("Haha error") } + justRun { streamApi.connect() } val streamer = FlagConfigStreamer(streamApi, storage, null, null, config) var errorCount = 0 streamer.start { errorCount++ } @@ -173,7 +184,7 @@ class FlagConfigStreamerTest { assertEquals(0, errorCount) // Stream fails - onErrorCapture.captured(Error("Haha error")) + onErrorCapture.captured(Exception("Haha error")) assertEquals(1, errorCount) // Error callback is called } } @@ -184,8 +195,9 @@ class FlagConfigFallbackRetryWrapperTest { private var mainUpdater = mockk() private var fallbackUpdater = mockk() + @BeforeTest - fun beforeEach() { + fun beforeTest() { mainUpdater = mockk() fallbackUpdater = mockk() @@ -197,6 +209,11 @@ class FlagConfigFallbackRetryWrapperTest { justRun { fallbackUpdater.shutdown() } } + @AfterTest + fun afterTest() { + clearAllMocks() + } + @Test fun `Test FallbackRetryWrapper main success no fallback updater`() { val wrapper = FlagConfigFallbackRetryWrapper(mainUpdater, null, 1000, 0) @@ -225,12 +242,15 @@ class FlagConfigFallbackRetryWrapperTest { every { mainUpdater.start(capture(mainOnErrorCapture)) } answers { throw Error() } // Main start fail, no error, same as success case - wrapper.start() + try { + wrapper.start() + fail("Start errors should throw") + } catch (_: Throwable) {} verify(exactly = 1) { mainUpdater.start(any()) } - // Retries start + // Start errors no retry Thread.sleep(1100) - verify(exactly = 2) { mainUpdater.start(any()) } + verify(exactly = 1) { mainUpdater.start(any()) } wrapper.shutdown() } @@ -289,6 +309,29 @@ class FlagConfigFallbackRetryWrapperTest { verify(exactly = 1) { mainUpdater.shutdown() } } + @Test + fun `Test FallbackRetryWrapper main and fallback start error`() { + val wrapper = FlagConfigFallbackRetryWrapper(mainUpdater, fallbackUpdater, 1000, 0) + + every { mainUpdater.start(capture(mainOnErrorCapture)) } answers { throw Error() } + every { fallbackUpdater.start() } answers { throw Error() } + + // Main start fail, no error, same as success case + try { + wrapper.start() + fail("Start errors should throw") + } catch (_: Throwable) {} + verify(exactly = 1) { mainUpdater.start(any()) } + verify(exactly = 1) { fallbackUpdater.start(any()) } + + // Start errors no retry + Thread.sleep(1100) + verify(exactly = 1) { mainUpdater.start(any()) } + verify(exactly = 1) { fallbackUpdater.start(any()) } + + wrapper.shutdown() + } + @Test fun `Test FallbackRetryWrapper main start error and retries`() { val wrapper = FlagConfigFallbackRetryWrapper(mainUpdater, fallbackUpdater, 1000, 0) @@ -346,4 +389,15 @@ class FlagConfigFallbackRetryWrapperTest { wrapper.shutdown() } + + @Test + fun `Test FallbackRetryWrapper main updater cannot be FlagConfigFallbackRetryWrapper`() { + val wrapper = FlagConfigFallbackRetryWrapper(FlagConfigFallbackRetryWrapper(mainUpdater, null), null, 1000, 0) + try { + wrapper.start() + fail("Did not throw") + } catch (_: Throwable) { + } + verify(exactly = 0) { mainUpdater.start() } + } } \ No newline at end of file diff --git a/src/test/kotlin/util/SseStreamTest.kt b/src/test/kotlin/util/SseStreamTest.kt index 535668d..edbad63 100644 --- a/src/test/kotlin/util/SseStreamTest.kt +++ b/src/test/kotlin/util/SseStreamTest.kt @@ -11,6 +11,7 @@ import okhttp3.Request import okhttp3.sse.EventSource import okhttp3.sse.EventSourceListener import org.mockito.Mockito +import kotlin.test.AfterTest import kotlin.test.BeforeTest import kotlin.test.Test import kotlin.test.assertEquals @@ -18,7 +19,7 @@ import kotlin.test.assertEquals class SseStreamTest { private val listenerCapture = slot() - val clientMock = mockk() + private val clientMock = mockk() private val es = mockk("mocked es") private var data: List = listOf() @@ -33,7 +34,11 @@ class SseStreamTest { mockkConstructor(OkHttpClient.Builder::class) every { anyConstructed().build() } returns clientMock + } + @AfterTest + fun afterTest() { + clearAllMocks() } private fun setupStream( From f6b7498fdd4e96470b153cddf2ad39f981abbf8d Mon Sep 17 00:00:00 2001 From: Peter Zhu Date: Tue, 3 Sep 2024 10:16:46 -0700 Subject: [PATCH 06/12] cleanup --- src/main/kotlin/LocalEvaluationClient.kt | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/main/kotlin/LocalEvaluationClient.kt b/src/main/kotlin/LocalEvaluationClient.kt index 239b0e5..10f1021 100644 --- a/src/main/kotlin/LocalEvaluationClient.kt +++ b/src/main/kotlin/LocalEvaluationClient.kt @@ -237,9 +237,3 @@ private fun getEventServerUrl( assignmentConfiguration.serverUrl } } - -fun main() { - val client = LocalEvaluationClient("server-qz35UwzJ5akieoAdIgzM4m9MIiOLXLoz", LocalEvaluationConfig(streamUpdates = true)) - client.start() - println(client.evaluateV2(ExperimentUser("1"))) -} \ No newline at end of file From c293786b8805f883306bf5544a026cbcb5095786 Mon Sep 17 00:00:00 2001 From: Peter Zhu Date: Thu, 5 Sep 2024 13:44:39 -0700 Subject: [PATCH 07/12] changed visibility modifiers --- src/main/kotlin/flag/FlagConfigStreamApi.kt | 4 ++-- src/main/kotlin/flag/FlagConfigUpdater.kt | 11 +++++++---- src/main/kotlin/util/SseStream.kt | 8 ++++---- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/main/kotlin/flag/FlagConfigStreamApi.kt b/src/main/kotlin/flag/FlagConfigStreamApi.kt index 4ff2891..38d9c47 100644 --- a/src/main/kotlin/flag/FlagConfigStreamApi.kt +++ b/src/main/kotlin/flag/FlagConfigStreamApi.kt @@ -46,7 +46,7 @@ internal class FlagConfigStreamApi ( keepaliveTimeoutMillis, reconnIntervalMillis) - fun connect() { + internal fun connect() { val isInit = AtomicBoolean(true) val connectTimeoutFuture = CompletableFuture() val updateTimeoutFuture = CompletableFuture() @@ -124,7 +124,7 @@ internal class FlagConfigStreamApi ( throw t } - fun close() { + internal fun close() { stream.cancel() } diff --git a/src/main/kotlin/flag/FlagConfigUpdater.kt b/src/main/kotlin/flag/FlagConfigUpdater.kt index 471a61b..b62ce9d 100644 --- a/src/main/kotlin/flag/FlagConfigUpdater.kt +++ b/src/main/kotlin/flag/FlagConfigUpdater.kt @@ -34,7 +34,7 @@ internal abstract class FlagConfigUpdaterBase( private val cohortLoader: CohortLoader?, private val cohortStorage: CohortStorage?, ): FlagConfigUpdater { - fun update(flagConfigs: List) { + protected fun update(flagConfigs: List) { // Remove flags that no longer exist. val flagKeys = flagConfigs.map { it.key }.toSet() flagConfigStorage.removeIf { !flagKeys.contains(it.key) } @@ -93,11 +93,14 @@ internal class FlagConfigPoller( ): FlagConfigUpdaterBase( storage, cohortLoader, cohortStorage ) { - private val poller = Executors.newScheduledThreadPool(1, daemonFactory) + private val pool = Executors.newScheduledThreadPool(1, daemonFactory) private var scheduledFuture: ScheduledFuture<*>? = null override fun start(onError: (() -> Unit)?) { refresh() - scheduledFuture = poller.scheduleWithFixedDelay( + if (scheduledFuture != null) { + stop() + } + scheduledFuture = pool.scheduleWithFixedDelay( { try { refresh() @@ -121,7 +124,7 @@ internal class FlagConfigPoller( override fun shutdown() { // Stop the executor. - poller.shutdown() + pool.shutdown() } private fun refresh() { diff --git a/src/main/kotlin/util/SseStream.kt b/src/main/kotlin/util/SseStream.kt index bb5e82e..1224fcf 100644 --- a/src/main/kotlin/util/SseStream.kt +++ b/src/main/kotlin/util/SseStream.kt @@ -101,13 +101,13 @@ internal class SseStream ( private var es: EventSource? = null private var reconnectTimerTask: TimerTask? = null - var onUpdate: ((String) -> Unit)? = null - var onError: ((Throwable?) -> Unit)? = null + internal var onUpdate: ((String) -> Unit)? = null + internal var onError: ((Throwable?) -> Unit)? = null /** * Creates an event source and immediately returns. The connection is performed async. Errors are informed through callbacks. */ - fun connect() { + internal fun connect() { cancel() // Clear any existing event sources. es = client.newEventSource(request, eventSourceListener) reconnectTimerTask = Timer().schedule(reconnIntervalRange.random()) {// Timer for a new event source. @@ -117,7 +117,7 @@ internal class SseStream ( } } - fun cancel() { + internal fun cancel() { reconnectTimerTask?.cancel() // There can be cases where an event source is being cancelled by these calls, but take a long time and made a callback to onFailure callback. From 61dd7a8fd712651c5ca402262d938fae557e06a1 Mon Sep 17 00:00:00 2001 From: Peter Zhu Date: Thu, 19 Sep 2024 12:01:25 -0700 Subject: [PATCH 08/12] added locks to ensure concurrency --- src/main/kotlin/flag/FlagConfigStreamApi.kt | 131 +++++++++-------- src/main/kotlin/flag/FlagConfigUpdater.kt | 135 +++++++++++------- src/main/kotlin/util/SseStream.kt | 64 ++++++--- src/test/kotlin/flag/FlagConfigUpdaterTest.kt | 10 +- 4 files changed, 198 insertions(+), 142 deletions(-) diff --git a/src/main/kotlin/flag/FlagConfigStreamApi.kt b/src/main/kotlin/flag/FlagConfigStreamApi.kt index 38d9c47..6e6e9dd 100644 --- a/src/main/kotlin/flag/FlagConfigStreamApi.kt +++ b/src/main/kotlin/flag/FlagConfigStreamApi.kt @@ -14,6 +14,8 @@ import java.util.concurrent.ExecutionException import java.util.concurrent.TimeUnit import java.util.concurrent.TimeoutException import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.locks.ReentrantLock +import kotlin.concurrent.withLock internal open class FlagConfigStreamApiError(message: String?, cause: Throwable?): Exception(message, cause) { constructor(message: String?) : this(message, null) @@ -32,8 +34,9 @@ internal class FlagConfigStreamApi ( httpClient: OkHttpClient = OkHttpClient(), val connectionTimeoutMillis: Long = CONNECTION_TIMEOUT_MILLIS_DEFAULT, keepaliveTimeoutMillis: Long = KEEP_ALIVE_TIMEOUT_MILLIS_DEFAULT, - reconnIntervalMillis: Long = RECONN_INTERVAL_MILLIS_DEFAULT + reconnIntervalMillis: Long = RECONN_INTERVAL_MILLIS_DEFAULT, ) { + private val lock: ReentrantLock = ReentrantLock() var onInitUpdate: ((List) -> Unit)? = null var onUpdate: ((List) -> Unit)? = null var onError: ((Exception?) -> Unit)? = null @@ -47,84 +50,88 @@ internal class FlagConfigStreamApi ( reconnIntervalMillis) internal fun connect() { - val isInit = AtomicBoolean(true) - val connectTimeoutFuture = CompletableFuture() - val updateTimeoutFuture = CompletableFuture() - stream.onUpdate = { data -> - if (isInit.getAndSet(false)) { - // Stream is establishing. First data received. - // Resolve timeout. - connectTimeoutFuture.complete(Unit) - - // Make sure valid data. - try { - val flags = getFlagsFromData(data) + // Guarded by lock. Update to callbacks and waits can lead to race conditions. + lock.withLock { + val isInit = AtomicBoolean(true) + val connectTimeoutFuture = CompletableFuture() + val updateTimeoutFuture = CompletableFuture() + stream.onUpdate = { data -> + if (isInit.getAndSet(false)) { + // Stream is establishing. First data received. + // Resolve timeout. + connectTimeoutFuture.complete(Unit) + // Make sure valid data. try { - if (onInitUpdate != null) { - onInitUpdate?.let { it(flags) } - } else { - onUpdate?.let { it(flags) } + val flags = getFlagsFromData(data) + + try { + if (onInitUpdate != null) { + onInitUpdate?.let { it(flags) } + } else { + onUpdate?.let { it(flags) } + } + updateTimeoutFuture.complete(Unit) + } catch (e: Throwable) { + updateTimeoutFuture.completeExceptionally(e) } - updateTimeoutFuture.complete(Unit) - } catch (e: Throwable) { - updateTimeoutFuture.completeExceptionally(e) + } catch (_: Throwable) { + updateTimeoutFuture.completeExceptionally(FlagConfigStreamApiDataCorruptError()) } - } catch (_: Throwable) { - updateTimeoutFuture.completeExceptionally(FlagConfigStreamApiDataCorruptError()) - } - - } else { - // Stream has already established. - // Make sure valid data. - try { - val flags = getFlagsFromData(data) + } else { + // Stream has already established. + // Make sure valid data. try { - onUpdate?.let { it(flags) } + val flags = getFlagsFromData(data) + + try { + onUpdate?.let { it(flags) } + } catch (_: Throwable) { + // Don't care about application error. + } } catch (_: Throwable) { - // Don't care about application error. + // Stream corrupted. Reconnect. + handleError(FlagConfigStreamApiDataCorruptError()) } - } catch (_: Throwable) { - // Stream corrupted. Reconnect. - handleError(FlagConfigStreamApiDataCorruptError()) - } + } } - } - stream.onError = { t -> - if (isInit.getAndSet(false)) { - connectTimeoutFuture.completeExceptionally(t) - updateTimeoutFuture.completeExceptionally(t) - } else { - handleError(FlagConfigStreamApiStreamError(t)) + stream.onError = { t -> + if (isInit.getAndSet(false)) { + connectTimeoutFuture.completeExceptionally(t) + updateTimeoutFuture.completeExceptionally(t) + } else { + handleError(FlagConfigStreamApiStreamError(t)) + } } - } - stream.connect() + stream.connect() - val t: Throwable - try { - connectTimeoutFuture.get(connectionTimeoutMillis, TimeUnit.MILLISECONDS) - updateTimeoutFuture.get() - return - } catch (e: TimeoutException) { - // Timeouts should retry - t = FlagConfigStreamApiConnTimeoutError() - } catch (e: ExecutionException) { - val cause = e.cause - t = if (cause is StreamException) { - FlagConfigStreamApiStreamError(cause) - } else { - FlagConfigStreamApiError(e) + val t: Throwable + try { + connectTimeoutFuture.get(connectionTimeoutMillis, TimeUnit.MILLISECONDS) + updateTimeoutFuture.get() + return + } catch (e: TimeoutException) { + // Timeouts should retry + t = FlagConfigStreamApiConnTimeoutError() + } catch (e: ExecutionException) { + val cause = e.cause + t = if (cause is StreamException) { + FlagConfigStreamApiStreamError(cause) + } else { + FlagConfigStreamApiError(e) + } + } catch (e: Throwable) { + t = FlagConfigStreamApiError(e) } - } catch (e: Throwable) { - t = FlagConfigStreamApiError(e) + close() + throw t } - close() - throw t } internal fun close() { + // Not guarded by lock. close() can halt connect(). stream.cancel() } diff --git a/src/main/kotlin/flag/FlagConfigUpdater.kt b/src/main/kotlin/flag/FlagConfigUpdater.kt index b62ce9d..5d1ce5e 100644 --- a/src/main/kotlin/flag/FlagConfigUpdater.kt +++ b/src/main/kotlin/flag/FlagConfigUpdater.kt @@ -14,6 +14,8 @@ import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.Executors import java.util.concurrent.ScheduledFuture import java.util.concurrent.TimeUnit +import java.util.concurrent.locks.ReentrantLock +import kotlin.concurrent.withLock import kotlin.math.max import kotlin.math.min @@ -89,42 +91,52 @@ internal class FlagConfigPoller( private val cohortLoader: CohortLoader?, private val cohortStorage: CohortStorage?, private val config: LocalEvaluationConfig, - private val metrics: LocalEvaluationMetrics = LocalEvaluationMetricsWrapper() + private val metrics: LocalEvaluationMetrics = LocalEvaluationMetricsWrapper(), ): FlagConfigUpdaterBase( storage, cohortLoader, cohortStorage ) { + private val lock: ReentrantLock = ReentrantLock() private val pool = Executors.newScheduledThreadPool(1, daemonFactory) - private var scheduledFuture: ScheduledFuture<*>? = null + private var scheduledFuture: ScheduledFuture<*>? = null // @GuardedBy(lock) override fun start(onError: (() -> Unit)?) { refresh() - if (scheduledFuture != null) { - stop() + lock.withLock { + stopInternal() + scheduledFuture = pool.scheduleWithFixedDelay( + { + try { + refresh() + } catch (t: Throwable) { + Logger.e("Refresh flag configs failed.", t) + stop() + onError?.invoke() + } + }, + config.flagConfigPollerIntervalMillis, + config.flagConfigPollerIntervalMillis, + TimeUnit.MILLISECONDS + ) } - scheduledFuture = pool.scheduleWithFixedDelay( - { - try { - refresh() - } catch (t: Throwable) { - Logger.e("Refresh flag configs failed.", t) - stop() - onError?.invoke() - } - }, - config.flagConfigPollerIntervalMillis, - config.flagConfigPollerIntervalMillis, - TimeUnit.MILLISECONDS - ) } - override fun stop() { + // @GuardedBy(lock) + private fun stopInternal() { // Pause only stop the task scheduled. It doesn't stop the executor. scheduledFuture?.cancel(true) scheduledFuture = null } + override fun stop() { + lock.withLock { + stopInternal() + } + } + override fun shutdown() { - // Stop the executor. - pool.shutdown() + lock.withLock { + // Stop the executor. + pool.shutdown() + } } private fun refresh() { @@ -151,21 +163,25 @@ internal class FlagConfigStreamer( ): FlagConfigUpdaterBase( storage, cohortLoader, cohortStorage ) { + private val lock: ReentrantLock = ReentrantLock() override fun start(onError: (() -> Unit)?) { - flagConfigStreamApi.onUpdate = {flags -> - update(flags) - } - flagConfigStreamApi.onError = {e -> - Logger.e("Stream flag configs streaming failed.", e) - metrics.onFlagConfigStreamFailure(e) - onError?.invoke() - } - wrapMetrics(metric = metrics::onFlagConfigStream, failure = metrics::onFlagConfigStreamFailure) { - flagConfigStreamApi.connect() + lock.withLock { + flagConfigStreamApi.onUpdate = { flags -> + update(flags) + } + flagConfigStreamApi.onError = { e -> + Logger.e("Stream flag configs streaming failed.", e) + metrics.onFlagConfigStreamFailure(e) + onError?.invoke() + } + wrapMetrics(metric = metrics::onFlagConfigStream, failure = metrics::onFlagConfigStreamFailure) { + flagConfigStreamApi.connect() + } } } override fun stop() { + // Not guarded by lock. close() can cancel start(). flagConfigStreamApi.close() } @@ -178,11 +194,12 @@ internal class FlagConfigFallbackRetryWrapper( private val mainUpdater: FlagConfigUpdater, private val fallbackUpdater: FlagConfigUpdater?, private val retryDelayMillis: Long = RETRY_DELAY_MILLIS_DEFAULT, - private val maxJitterMillis: Long = MAX_JITTER_MILLIS_DEFAULT + private val maxJitterMillis: Long = MAX_JITTER_MILLIS_DEFAULT, ): FlagConfigUpdater { + private val lock: ReentrantLock = ReentrantLock() private val reconnIntervalRange = max(0, retryDelayMillis - maxJitterMillis)..(min(retryDelayMillis, retryDelayMillis - maxJitterMillis) + maxJitterMillis) private val executor = Executors.newScheduledThreadPool(1, daemonFactory) - private var retryTask: ScheduledFuture<*>? = null + private var retryTask: ScheduledFuture<*>? = null // @GuardedBy(lock) /** * Since the wrapper retries, so there will never be error case. Thus, onError will never be called. @@ -192,37 +209,49 @@ internal class FlagConfigFallbackRetryWrapper( throw Error("Do not use FlagConfigFallbackRetryWrapper as main updater. Fallback updater will never be used. Rewrite retry and fallback logic.") } - try { - mainUpdater.start { - scheduleRetry() // Don't care if poller start error or not, always retry. - try { - fallbackUpdater?.start() - } catch (_: Throwable) { + lock.withLock { + retryTask?.cancel(true) + + try { + mainUpdater.start { + lock.withLock { + scheduleRetry() // Don't care if poller start error or not, always retry. + try { + fallbackUpdater?.start() + } catch (_: Throwable) { + } + } } + fallbackUpdater?.stop() + } catch (t: Throwable) { + Logger.e("Primary flag configs start failed, start fallback. Error: ", t) + if (fallbackUpdater == null) { + // No fallback, main start failed is wrapper start fail + throw t + } + fallbackUpdater.start() + scheduleRetry() } - } catch (t: Throwable) { - Logger.e("Primary flag configs start failed, start fallback. Error: ", t) - if (fallbackUpdater == null) { - // No fallback, main start failed is wrapper start fail - throw t - } - fallbackUpdater.start() - scheduleRetry() } } override fun stop() { - mainUpdater.stop() - fallbackUpdater?.stop() - retryTask?.cancel(true) + lock.withLock { + mainUpdater.stop() + fallbackUpdater?.stop() + retryTask?.cancel(true) + } } override fun shutdown() { - mainUpdater.shutdown() - fallbackUpdater?.shutdown() - retryTask?.cancel(true) + lock.withLock { + mainUpdater.shutdown() + fallbackUpdater?.shutdown() + retryTask?.cancel(true) + } } + // @GuardedBy(lock) private fun scheduleRetry() { retryTask = executor.schedule({ try { diff --git a/src/main/kotlin/util/SseStream.kt b/src/main/kotlin/util/SseStream.kt index 1224fcf..d80d4dc 100644 --- a/src/main/kotlin/util/SseStream.kt +++ b/src/main/kotlin/util/SseStream.kt @@ -13,7 +13,9 @@ import okhttp3.sse.EventSourceListener import okhttp3.sse.EventSources import java.util.* import java.util.concurrent.TimeUnit +import java.util.concurrent.locks.ReentrantLock import kotlin.concurrent.schedule +import kotlin.concurrent.withLock import kotlin.math.max import kotlin.math.min @@ -30,8 +32,9 @@ internal class SseStream ( private val connectionTimeoutMillis: Long, private val keepaliveTimeoutMillis: Long = KEEP_ALIVE_TIMEOUT_MILLIS_DEFAULT, private val reconnIntervalMillis: Long = RECONN_INTERVAL_MILLIS_DEFAULT, - private val maxJitterMillis: Long = MAX_JITTER_MILLIS_DEFAULT + private val maxJitterMillis: Long = MAX_JITTER_MILLIS_DEFAULT, ) { + private val lock: ReentrantLock = ReentrantLock() private val reconnIntervalRange = max(0, reconnIntervalMillis - maxJitterMillis)..(min(reconnIntervalMillis, Long.MAX_VALUE - maxJitterMillis) + maxJitterMillis) private val eventSourceListener = object : EventSourceListener() { override fun onOpen(eventSource: EventSource, response: Response) { @@ -39,13 +42,15 @@ internal class SseStream ( } override fun onClosed(eventSource: EventSource) { - if ((eventSource != es)) { - // Not the current event source using right now, should cancel. - eventSource.cancel() - return + lock.withLock { + if ((eventSource != es)) { + // Not the current event source using right now, should cancel. + eventSource.cancel() + return + } } // Server closed the connection, just reconnect. - cancel() + cancelInternal() connect() } @@ -55,10 +60,12 @@ internal class SseStream ( type: String?, data: String ) { - if ((eventSource != es)) { - // Not the current event source using right now, should cancel. - eventSource.cancel() - return + lock.withLock { + if ((eventSource != es)) { + // Not the current event source using right now, should cancel. + eventSource.cancel() + return + } } // Keep alive data if (KEEP_ALIVE_DATA == data) { @@ -68,10 +75,12 @@ internal class SseStream ( } override fun onFailure(eventSource: EventSource, t: Throwable?, response: Response?) { - if ((eventSource != es)) { - // Not the current event source using right now, should cancel. - eventSource.cancel() - return + lock.withLock { + if ((eventSource != es)) { + // Not the current event source using right now, should cancel. + eventSource.cancel() + return + } } if (t is StreamResetException && t.errorCode == ErrorCode.CANCEL) { // Relying on okhttp3.internal to differentiate cancel case. @@ -99,8 +108,8 @@ internal class SseStream ( .retryOnConnectionFailure(false) .build() - private var es: EventSource? = null - private var reconnectTimerTask: TimerTask? = null + private var es: EventSource? = null // @GuardedBy(lock) + private var reconnectTimerTask: TimerTask? = null // @GuardedBy(lock) internal var onUpdate: ((String) -> Unit)? = null internal var onError: ((Throwable?) -> Unit)? = null @@ -108,20 +117,29 @@ internal class SseStream ( * Creates an event source and immediately returns. The connection is performed async. Errors are informed through callbacks. */ internal fun connect() { - cancel() // Clear any existing event sources. - es = client.newEventSource(request, eventSourceListener) - reconnectTimerTask = Timer().schedule(reconnIntervalRange.random()) {// Timer for a new event source. - // This forces client side reconnection after interval. - this@SseStream.cancel() - connect() + lock.withLock { + cancelInternal() // Clear any existing event sources. + es = client.newEventSource(request, eventSourceListener) + reconnectTimerTask = Timer().schedule(reconnIntervalRange.random()) {// Timer for a new event source. + // This forces client side reconnection after interval. + this@SseStream.cancel() + connect() + } } } - internal fun cancel() { + // @GuardedBy(lock) + private fun cancelInternal() { reconnectTimerTask?.cancel() // There can be cases where an event source is being cancelled by these calls, but take a long time and made a callback to onFailure callback. es?.cancel() es = null } + + internal fun cancel() { + lock.withLock { + cancelInternal() + } + } } \ No newline at end of file diff --git a/src/test/kotlin/flag/FlagConfigUpdaterTest.kt b/src/test/kotlin/flag/FlagConfigUpdaterTest.kt index 4ee0b15..042d74e 100644 --- a/src/test/kotlin/flag/FlagConfigUpdaterTest.kt +++ b/src/test/kotlin/flag/FlagConfigUpdaterTest.kt @@ -292,16 +292,18 @@ class FlagConfigFallbackRetryWrapperTest { wrapper.start() verify(exactly = 1) { mainUpdater.start(any()) } verify(exactly = 0) { fallbackUpdater.start() } + verify(exactly = 1) { fallbackUpdater.stop() } // Stop wrapper.stop() verify(exactly = 1) { mainUpdater.stop() } - verify(exactly = 1) { fallbackUpdater.stop() } + verify(exactly = 2) { fallbackUpdater.stop() } // Start again wrapper.start() verify(exactly = 2) { mainUpdater.start(any()) } verify(exactly = 0) { fallbackUpdater.start() } + verify(exactly = 3) { fallbackUpdater.stop() } // Shutdown wrapper.shutdown() @@ -373,19 +375,19 @@ class FlagConfigFallbackRetryWrapperTest { // Retry success justRun { mainUpdater.start(capture(mainOnErrorCapture)) } - verify(exactly = 0) { fallbackUpdater.stop() } + verify(exactly = 1) { fallbackUpdater.stop() } Thread.sleep(1100) verify(exactly = 3) { mainUpdater.start(any()) } verify(exactly = 1) { fallbackUpdater.start(any()) } verify(exactly = 0) { mainUpdater.stop() } - verify(exactly = 1) { fallbackUpdater.stop() } + verify(exactly = 2) { fallbackUpdater.stop() } // No more start Thread.sleep(1100) verify(exactly = 3) { mainUpdater.start(any()) } verify(exactly = 1) { fallbackUpdater.start(any()) } verify(exactly = 0) { mainUpdater.stop() } - verify(exactly = 1) { fallbackUpdater.stop() } + verify(exactly = 2) { fallbackUpdater.stop() } wrapper.shutdown() } From c9b8130852cfd65fbae9623cc49de9eb0e03b2ab Mon Sep 17 00:00:00 2001 From: Peter Zhu Date: Mon, 23 Sep 2024 17:19:13 -0700 Subject: [PATCH 09/12] pr comments --- .../kotlin/deployment/DeploymentRunner.kt | 11 +- src/main/kotlin/flag/FlagConfigStreamApi.kt | 46 ++++---- src/main/kotlin/flag/FlagConfigUpdater.kt | 81 +++++++++---- src/main/kotlin/util/SseStream.kt | 107 ++++++++++-------- .../kotlin/flag/FlagConfigStreamApiTest.kt | 74 ++++++------ src/test/kotlin/flag/FlagConfigUpdaterTest.kt | 26 ++--- src/test/kotlin/util/SseStreamTest.kt | 19 ++-- 7 files changed, 211 insertions(+), 153 deletions(-) diff --git a/src/main/kotlin/deployment/DeploymentRunner.kt b/src/main/kotlin/deployment/DeploymentRunner.kt index 5d13fac..50d64fc 100644 --- a/src/main/kotlin/deployment/DeploymentRunner.kt +++ b/src/main/kotlin/deployment/DeploymentRunner.kt @@ -1,15 +1,16 @@ -@file:OptIn(ExperimentalApi::class) - package com.amplitude.experiment.deployment -import com.amplitude.experiment.* +import com.amplitude.experiment.LocalEvaluationConfig +import com.amplitude.experiment.LocalEvaluationMetrics import com.amplitude.experiment.cohort.CohortApi import com.amplitude.experiment.cohort.CohortLoader import com.amplitude.experiment.cohort.CohortStorage -import com.amplitude.experiment.flag.* import com.amplitude.experiment.flag.FlagConfigApi +import com.amplitude.experiment.flag.FlagConfigFallbackRetryWrapper import com.amplitude.experiment.flag.FlagConfigPoller import com.amplitude.experiment.flag.FlagConfigStorage +import com.amplitude.experiment.flag.FlagConfigStreamApi +import com.amplitude.experiment.flag.FlagConfigStreamer import com.amplitude.experiment.util.LocalEvaluationMetricsWrapper import com.amplitude.experiment.util.Logger import com.amplitude.experiment.util.Once @@ -46,7 +47,7 @@ internal class DeploymentRunner( private val amplitudeFlagConfigUpdater = if (flagConfigStreamApi != null) FlagConfigFallbackRetryWrapper( - FlagConfigStreamer(flagConfigStreamApi, flagConfigStorage, cohortLoader, cohortStorage, config, metrics), + FlagConfigStreamer(flagConfigStreamApi, flagConfigStorage, cohortLoader, cohortStorage, metrics), amplitudeFlagConfigPoller, ) else amplitudeFlagConfigPoller diff --git a/src/main/kotlin/flag/FlagConfigStreamApi.kt b/src/main/kotlin/flag/FlagConfigStreamApi.kt index 6e6e9dd..6b68b32 100644 --- a/src/main/kotlin/flag/FlagConfigStreamApi.kt +++ b/src/main/kotlin/flag/FlagConfigStreamApi.kt @@ -1,14 +1,12 @@ package com.amplitude.experiment.flag import com.amplitude.experiment.evaluation.EvaluationFlag -import com.amplitude.experiment.util.* import com.amplitude.experiment.util.SseStream +import com.amplitude.experiment.util.StreamException +import com.amplitude.experiment.util.json import kotlinx.serialization.decodeFromString import okhttp3.HttpUrl import okhttp3.OkHttpClient -import okhttp3.Request -import okhttp3.sse.EventSource -import okhttp3.sse.EventSourceListener import java.util.concurrent.CompletableFuture import java.util.concurrent.ExecutionException import java.util.concurrent.TimeUnit @@ -37,9 +35,6 @@ internal class FlagConfigStreamApi ( reconnIntervalMillis: Long = RECONN_INTERVAL_MILLIS_DEFAULT, ) { private val lock: ReentrantLock = ReentrantLock() - var onInitUpdate: ((List) -> Unit)? = null - var onUpdate: ((List) -> Unit)? = null - var onError: ((Exception?) -> Unit)? = null val url = serverUrl.newBuilder().addPathSegments("sdk/stream/v1/flags").build() private val stream: SseStream = SseStream( "Api-Key $deploymentKey", @@ -49,14 +44,23 @@ internal class FlagConfigStreamApi ( keepaliveTimeoutMillis, reconnIntervalMillis) - internal fun connect() { + /** + * Connects to flag configs stream. + * This will ensure stream connects, first set of flags is received and processed successfully, then returns. + * If stream fails to connect, first set of flags is not received, or first set of flags did not process successfully, it throws. + */ + internal fun connect( + onInitUpdate: ((List) -> Unit)? = null, + onUpdate: ((List) -> Unit)? = null, + onError: ((Exception?) -> Unit)? = null + ) { // Guarded by lock. Update to callbacks and waits can lead to race conditions. lock.withLock { - val isInit = AtomicBoolean(true) + val isDuringInit = AtomicBoolean(true) val connectTimeoutFuture = CompletableFuture() val updateTimeoutFuture = CompletableFuture() - stream.onUpdate = { data -> - if (isInit.getAndSet(false)) { + val onSseUpdate: ((String) -> Unit) = { data -> + if (isDuringInit.getAndSet(false)) { // Stream is establishing. First data received. // Resolve timeout. connectTimeoutFuture.complete(Unit) @@ -67,9 +71,9 @@ internal class FlagConfigStreamApi ( try { if (onInitUpdate != null) { - onInitUpdate?.let { it(flags) } + onInitUpdate.invoke(flags) } else { - onUpdate?.let { it(flags) } + onUpdate?.invoke(flags) } updateTimeoutFuture.complete(Unit) } catch (e: Throwable) { @@ -86,26 +90,26 @@ internal class FlagConfigStreamApi ( val flags = getFlagsFromData(data) try { - onUpdate?.let { it(flags) } + onUpdate?.invoke(flags) } catch (_: Throwable) { // Don't care about application error. } } catch (_: Throwable) { // Stream corrupted. Reconnect. - handleError(FlagConfigStreamApiDataCorruptError()) + handleError(onError, FlagConfigStreamApiDataCorruptError()) } } } - stream.onError = { t -> - if (isInit.getAndSet(false)) { + val onSseError: ((Throwable?) -> Unit) = { t -> + if (isDuringInit.getAndSet(false)) { connectTimeoutFuture.completeExceptionally(t) updateTimeoutFuture.completeExceptionally(t) } else { - handleError(FlagConfigStreamApiStreamError(t)) + handleError(onError, FlagConfigStreamApiStreamError(t)) } } - stream.connect() + stream.connect(onSseUpdate, onSseError) val t: Throwable try { @@ -139,8 +143,8 @@ internal class FlagConfigStreamApi ( return json.decodeFromString>(data) } - private fun handleError(e: Exception?) { + private fun handleError(onError: ((Exception?) -> Unit)?, e: Exception?) { close() - onError?.let { it(e) } + onError?.invoke(e) } } diff --git a/src/main/kotlin/flag/FlagConfigUpdater.kt b/src/main/kotlin/flag/FlagConfigUpdater.kt index 5d1ce5e..5e1fa4b 100644 --- a/src/main/kotlin/flag/FlagConfigUpdater.kt +++ b/src/main/kotlin/flag/FlagConfigUpdater.kt @@ -5,9 +5,10 @@ import com.amplitude.experiment.LocalEvaluationMetrics import com.amplitude.experiment.cohort.CohortLoader import com.amplitude.experiment.cohort.CohortStorage import com.amplitude.experiment.evaluation.EvaluationFlag -import com.amplitude.experiment.util.* +import com.amplitude.experiment.util.LocalEvaluationMetricsWrapper import com.amplitude.experiment.util.Logger import com.amplitude.experiment.util.daemonFactory +import com.amplitude.experiment.util.getAllCohortIds import com.amplitude.experiment.util.wrapMetrics import java.util.concurrent.CompletableFuture import java.util.concurrent.ConcurrentHashMap @@ -19,23 +20,43 @@ import kotlin.concurrent.withLock import kotlin.math.max import kotlin.math.min +/** + * Flag config updaters should receive flags through their own means (ex. http GET, SSE stream), + * or as wrapper of others. + * They all should have these methods to control their lifecycle. + */ internal interface FlagConfigUpdater { - // Start the updater. There can be multiple calls. - // If start fails, it should throw exception. The caller should handle error. - // If some other error happened while updating (already started successfully), it should call onError. + /** + * Start the updater. There can be multiple calls. + * If start fails, it should throw exception. The caller should handle error. + * If some other error happened while updating (already started successfully), it should call onError. + */ fun start(onError: (() -> Unit)? = null) - // Stop should stop updater temporarily. There may be another start in the future. - // To stop completely, with intention to never start again, use shutdown() instead. + + /** + * Stop should stop updater temporarily. There may be another start in the future. + * To stop completely, with intention to never start again, use shutdown() instead. + */ fun stop() - // Destroy should stop the updater forever in preparation for server shutdown. + + /** + * Destroy should stop the updater forever in preparation for server shutdown. + */ fun shutdown() } +/** + * All flag config updaters should share this class, which contains a function to properly process flag updates. + */ internal abstract class FlagConfigUpdaterBase( private val flagConfigStorage: FlagConfigStorage, private val cohortLoader: CohortLoader?, private val cohortStorage: CohortStorage?, -): FlagConfigUpdater { +) { + /** + * Call this method after receiving and parsing flag configs from network. + * This method updates flag configs into storage and download all cohorts if needed. + */ protected fun update(flagConfigs: List) { // Remove flags that no longer exist. val flagKeys = flagConfigs.map { it.key }.toSet() @@ -85,19 +106,27 @@ internal abstract class FlagConfigUpdaterBase( } } +/** + * This is the poller for flag configs. + * It keeps polling flag configs with specified interval until error occurs. + */ internal class FlagConfigPoller( private val flagConfigApi: FlagConfigApi, - private val storage: FlagConfigStorage, - private val cohortLoader: CohortLoader?, - private val cohortStorage: CohortStorage?, + storage: FlagConfigStorage, + cohortLoader: CohortLoader?, + cohortStorage: CohortStorage?, private val config: LocalEvaluationConfig, private val metrics: LocalEvaluationMetrics = LocalEvaluationMetricsWrapper(), -): FlagConfigUpdaterBase( +): FlagConfigUpdater, FlagConfigUpdaterBase( storage, cohortLoader, cohortStorage ) { private val lock: ReentrantLock = ReentrantLock() private val pool = Executors.newScheduledThreadPool(1, daemonFactory) private var scheduledFuture: ScheduledFuture<*>? = null // @GuardedBy(lock) + + /** + * Start will fetch once, then start poller to poll flag configs. + */ override fun start(onError: (() -> Unit)?) { refresh() lock.withLock { @@ -153,29 +182,36 @@ internal class FlagConfigPoller( } } +/** + * Streamer for flag configs. This receives flag updates with an SSE connection. + */ internal class FlagConfigStreamer( private val flagConfigStreamApi: FlagConfigStreamApi, - private val storage: FlagConfigStorage, - private val cohortLoader: CohortLoader?, - private val cohortStorage: CohortStorage?, - private val config: LocalEvaluationConfig, + storage: FlagConfigStorage, + cohortLoader: CohortLoader?, + cohortStorage: CohortStorage?, private val metrics: LocalEvaluationMetrics = LocalEvaluationMetricsWrapper() -): FlagConfigUpdaterBase( +): FlagConfigUpdater, FlagConfigUpdaterBase( storage, cohortLoader, cohortStorage ) { private val lock: ReentrantLock = ReentrantLock() + + /** + * Start makes sure it connects to stream and the first set of flag configs is loaded. + * Then, it will update the flags whenever there's a stream. + */ override fun start(onError: (() -> Unit)?) { lock.withLock { - flagConfigStreamApi.onUpdate = { flags -> + val onStreamUpdate: ((List) -> Unit) = { flags -> update(flags) } - flagConfigStreamApi.onError = { e -> + val onStreamError: ((Exception?) -> Unit) = { e -> Logger.e("Stream flag configs streaming failed.", e) metrics.onFlagConfigStreamFailure(e) onError?.invoke() } wrapMetrics(metric = metrics::onFlagConfigStream, failure = metrics::onFlagConfigStreamFailure) { - flagConfigStreamApi.connect() + flagConfigStreamApi.connect(onStreamUpdate, onStreamUpdate, onStreamError) } } } @@ -190,11 +226,12 @@ internal class FlagConfigStreamer( private const val RETRY_DELAY_MILLIS_DEFAULT = 15 * 1000L private const val MAX_JITTER_MILLIS_DEFAULT = 2000L + internal class FlagConfigFallbackRetryWrapper( private val mainUpdater: FlagConfigUpdater, private val fallbackUpdater: FlagConfigUpdater?, - private val retryDelayMillis: Long = RETRY_DELAY_MILLIS_DEFAULT, - private val maxJitterMillis: Long = MAX_JITTER_MILLIS_DEFAULT, + retryDelayMillis: Long = RETRY_DELAY_MILLIS_DEFAULT, + maxJitterMillis: Long = MAX_JITTER_MILLIS_DEFAULT, ): FlagConfigUpdater { private val lock: ReentrantLock = ReentrantLock() private val reconnIntervalRange = max(0, retryDelayMillis - maxJitterMillis)..(min(retryDelayMillis, retryDelayMillis - maxJitterMillis) + maxJitterMillis) diff --git a/src/main/kotlin/util/SseStream.kt b/src/main/kotlin/util/SseStream.kt index d80d4dc..2703ba3 100644 --- a/src/main/kotlin/util/SseStream.kt +++ b/src/main/kotlin/util/SseStream.kt @@ -1,17 +1,14 @@ package com.amplitude.experiment.util -import com.amplitude.experiment.LIBRARY_VERSION import okhttp3.HttpUrl -import okhttp3.HttpUrl.Companion.toHttpUrl import okhttp3.OkHttpClient -import okhttp3.Request import okhttp3.Response import okhttp3.internal.http2.ErrorCode import okhttp3.internal.http2.StreamResetException import okhttp3.sse.EventSource import okhttp3.sse.EventSourceListener -import okhttp3.sse.EventSources -import java.util.* +import java.util.Timer +import java.util.TimerTask import java.util.concurrent.TimeUnit import java.util.concurrent.locks.ReentrantLock import kotlin.concurrent.schedule @@ -25,17 +22,36 @@ private const val KEEP_ALIVE_TIMEOUT_MILLIS_DEFAULT = 0L // no timeout private const val RECONN_INTERVAL_MILLIS_DEFAULT = 30 * 60 * 1000L private const val MAX_JITTER_MILLIS_DEFAULT = 5000L private const val KEEP_ALIVE_DATA = " " + +/** + * For establishing an SSE stream. + */ internal class SseStream ( - private val authToken: String, - private val url: HttpUrl, - private val httpClient: OkHttpClient = OkHttpClient(), - private val connectionTimeoutMillis: Long, - private val keepaliveTimeoutMillis: Long = KEEP_ALIVE_TIMEOUT_MILLIS_DEFAULT, - private val reconnIntervalMillis: Long = RECONN_INTERVAL_MILLIS_DEFAULT, - private val maxJitterMillis: Long = MAX_JITTER_MILLIS_DEFAULT, + authToken: String, // Will be used in header as Authorization: + url: HttpUrl, // The full url to connect to. + httpClient: OkHttpClient = OkHttpClient(), + connectionTimeoutMillis: Long, // Timeout for establishing a connection, not including reading body. + keepaliveTimeoutMillis: Long = KEEP_ALIVE_TIMEOUT_MILLIS_DEFAULT, // Keep alive should receive within this timeout. + reconnIntervalMillis: Long = RECONN_INTERVAL_MILLIS_DEFAULT, // Reconnect every this interval. + maxJitterMillis: Long = MAX_JITTER_MILLIS_DEFAULT, // Jitter for reconnection. ) { private val lock: ReentrantLock = ReentrantLock() private val reconnIntervalRange = max(0, reconnIntervalMillis - maxJitterMillis)..(min(reconnIntervalMillis, Long.MAX_VALUE - maxJitterMillis) + maxJitterMillis) + + private val request = newGet(url, null, mapOf("Authorization" to authToken, "Accept" to "text/event-stream")) + private val client = httpClient.newBuilder() // client.newBuilder reuses the connection pool in the same client with new configs. + .connectTimeout(connectionTimeoutMillis, TimeUnit.MILLISECONDS) // Connection timeout for establishing SSE. + .callTimeout(connectionTimeoutMillis, TimeUnit.MILLISECONDS) // Call timeout for establishing SSE. + .readTimeout(keepaliveTimeoutMillis, TimeUnit.MILLISECONDS) // Timeout between messages, keepalive in this case. + .writeTimeout(connectionTimeoutMillis, TimeUnit.MILLISECONDS) + .retryOnConnectionFailure(false) + .build() + + private var es: EventSource? = null // @GuardedBy(lock) + private var reconnectTimerTask: TimerTask? = null // @GuardedBy(lock) + private var onUpdate: ((String) -> Unit)? = null + private var onError: ((Throwable?) -> Unit)? = null + private val eventSourceListener = object : EventSourceListener() { override fun onOpen(eventSource: EventSource, response: Response) { // No action needed. @@ -43,15 +59,15 @@ internal class SseStream ( override fun onClosed(eventSource: EventSource) { lock.withLock { - if ((eventSource != es)) { + if ((eventSource != es)) { // Reference comparison. // Not the current event source using right now, should cancel. eventSource.cancel() return } + // Server closed the connection, just reconnect. + cancelSse() } - // Server closed the connection, just reconnect. - cancelInternal() - connect() + connect(onUpdate, onError) } override fun onEvent( @@ -61,7 +77,7 @@ internal class SseStream ( data: String ) { lock.withLock { - if ((eventSource != es)) { + if ((eventSource != es)) { // Reference comparison. // Not the current event source using right now, should cancel. eventSource.cancel() return @@ -71,65 +87,57 @@ internal class SseStream ( if (KEEP_ALIVE_DATA == data) { return } - onUpdate?.let { it(data) } + onUpdate?.invoke(data) } override fun onFailure(eventSource: EventSource, t: Throwable?, response: Response?) { lock.withLock { - if ((eventSource != es)) { + if ((eventSource != es)) { // Reference comparison. // Not the current event source using right now, should cancel. eventSource.cancel() return } + if (t is StreamResetException && t.errorCode == ErrorCode.CANCEL) { + // Relying on okhttp3.internal to differentiate cancel case. + // Can be a pitfall later on. + return + } + cancelSse() } - if (t is StreamResetException && t.errorCode == ErrorCode.CANCEL) { - // Relying on okhttp3.internal to differentiate cancel case. - // Can be a pitfall later on. - return - } - cancel() val err = t ?: if (response != null) { StreamException(response.toString()) } else { StreamException("Unknown stream failure") } - onError?.let { it(err) } + onError?.invoke(err) } } - private val request = newGet(url, null, mapOf("Authorization" to authToken, "Accept" to "text/event-stream")) - - private val client = httpClient.newBuilder() // client.newBuilder reuses the connection pool in the same client with new configs. - .connectTimeout(connectionTimeoutMillis, TimeUnit.MILLISECONDS) // Connection timeout for establishing SSE. - .callTimeout(connectionTimeoutMillis, TimeUnit.MILLISECONDS) // Call timeout for establishing SSE. - .readTimeout(keepaliveTimeoutMillis, TimeUnit.MILLISECONDS) // Timeout between messages, keepalive in this case. - .writeTimeout(connectionTimeoutMillis, TimeUnit.MILLISECONDS) - .retryOnConnectionFailure(false) - .build() - - private var es: EventSource? = null // @GuardedBy(lock) - private var reconnectTimerTask: TimerTask? = null // @GuardedBy(lock) - internal var onUpdate: ((String) -> Unit)? = null - internal var onError: ((Throwable?) -> Unit)? = null - /** - * Creates an event source and immediately returns. The connection is performed async. Errors are informed through callbacks. + * Creates an event source and immediately returns. The connection is performed async. All errors are informed through callbacks. + * This the call may success even if stream cannot be established. + * + * @param onUpdate On stream update, this callback will be called. + * @param onError On stream error, this callback will be called. */ - internal fun connect() { + internal fun connect(onUpdate: ((String) -> Unit)?, onError: ((Throwable?) -> Unit)?) { lock.withLock { - cancelInternal() // Clear any existing event sources. + cancelSse() // Clear any existing event sources. + + this.onUpdate = onUpdate + this.onError = onError es = client.newEventSource(request, eventSourceListener) reconnectTimerTask = Timer().schedule(reconnIntervalRange.random()) {// Timer for a new event source. // This forces client side reconnection after interval. this@SseStream.cancel() - connect() + connect(onUpdate, onError) } } } // @GuardedBy(lock) - private fun cancelInternal() { + private fun cancelSse() { reconnectTimerTask?.cancel() // There can be cases where an event source is being cancelled by these calls, but take a long time and made a callback to onFailure callback. @@ -137,9 +145,14 @@ internal class SseStream ( es = null } + /** + * Cancels the current connection. + */ internal fun cancel() { lock.withLock { - cancelInternal() + cancelSse() + this.onUpdate = null + this.onError = null } } } \ No newline at end of file diff --git a/src/test/kotlin/flag/FlagConfigStreamApiTest.kt b/src/test/kotlin/flag/FlagConfigStreamApiTest.kt index 34da4e8..2872184 100644 --- a/src/test/kotlin/flag/FlagConfigStreamApiTest.kt +++ b/src/test/kotlin/flag/FlagConfigStreamApiTest.kt @@ -3,33 +3,41 @@ package com.amplitude.experiment.flag import com.amplitude.experiment.Experiment import com.amplitude.experiment.evaluation.EvaluationFlag import com.amplitude.experiment.util.SseStream -import io.mockk.* +import io.mockk.clearAllMocks +import io.mockk.every +import io.mockk.mockkConstructor +import io.mockk.slot +import io.mockk.verify import okhttp3.HttpUrl import okhttp3.HttpUrl.Companion.toHttpUrl import okhttp3.OkHttpClient import java.util.concurrent.CompletableFuture import java.util.concurrent.TimeUnit -import kotlin.test.* +import kotlin.test.AfterTest +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.assertContentEquals +import kotlin.test.assertEquals +import kotlin.test.fail class FlagConfigStreamApiTest { private val onUpdateCapture = slot<((String) -> Unit)>() private val onErrorCapture = slot<((Throwable?) -> Unit)>() - private var data: Array> = arrayOf() - private var err: Array = arrayOf() - @BeforeTest fun beforeTest() { mockkConstructor(SseStream::class) - every { anyConstructed().connect() } answers { + every { anyConstructed().connect(capture(onUpdateCapture), capture(onErrorCapture)) } answers { Thread.sleep(1000) } every { anyConstructed().cancel() } answers { Thread.sleep(1000) } - every { anyConstructed().onUpdate = capture(onUpdateCapture) } answers {} - every { anyConstructed().onError = capture(onErrorCapture) } answers {} + } + + private fun anyConstructed(): Any { + TODO("Not yet implemented") } @AfterTest @@ -44,30 +52,29 @@ class FlagConfigStreamApiTest { ): FlagConfigStreamApi { val api = FlagConfigStreamApi(deploymentKey, serverUrl, OkHttpClient(), connTimeout, 10000) - api.onUpdate = { d -> - data += d - } - api.onError = { t -> - err += t - } return api } @Test fun `Test passes correct arguments`() { val api = setupApi("deplkey", "https://test.example.com".toHttpUrl()) - api.onInitUpdate = { d -> - data += d - } + var data: Array> = arrayOf() + var err: Array = arrayOf() val run = async { - api.connect() + api.connect({ d -> + data += d + }, { d -> + data += d + }, { t -> + err += t + }) } Thread.sleep(100) onUpdateCapture.captured("[{\"key\":\"flagkey\",\"variants\":{},\"segments\":[]}]") run.join() - verify { anyConstructed().connect() } + verify { anyConstructed().connect(any(), any()) } assertContentEquals(arrayOf(listOf(EvaluationFlag("flagkey", emptyMap(), emptyList()))), data) api.close() @@ -89,13 +96,11 @@ class FlagConfigStreamApiTest { fun `Test init update failure throws`() { val api = setupApi("deplkey", "https://test.example.com".toHttpUrl(), 2000) - api.onInitUpdate = { d -> - Thread.sleep(2100) // Update time is not included in connection timeout. - throw Error() - } - api.onUpdate = null try { - api.connect() + api.connect({ + Thread.sleep(2100) // Update time is not included in connection timeout. + throw Error() + }) fail("Timeout not thrown") } catch (_: FlagConfigStreamApiConnTimeoutError) { } @@ -106,13 +111,11 @@ class FlagConfigStreamApiTest { fun `Test init update fallbacks to onUpdate when onInitUpdate = null`() { val api = setupApi("deplkey", "https://test.example.com".toHttpUrl(), 2000) - api.onInitUpdate = null - api.onUpdate = { d -> - Thread.sleep(2100) // Update time is not included in connection timeout. - throw Error() - } try { - api.connect() + api.connect(null, { + Thread.sleep(2100) // Update time is not included in connection timeout. + throw Error() + }) fail("Timeout not thrown") } catch (_: FlagConfigStreamApiConnTimeoutError) { } @@ -122,9 +125,16 @@ class FlagConfigStreamApiTest { @Test fun `Test error is passed through onError`() { val api = setupApi("deplkey", "https://test.example.com".toHttpUrl(), 2000) + var err: Array = arrayOf() val run = async { - api.connect() + api.connect({d -> + assertEquals(listOf(), d) + }, {d -> + assertEquals(listOf(), d) + }, { t -> + err += t + }) } Thread.sleep(100) onUpdateCapture.captured("[]") diff --git a/src/test/kotlin/flag/FlagConfigUpdaterTest.kt b/src/test/kotlin/flag/FlagConfigUpdaterTest.kt index 042d74e..c274f85 100644 --- a/src/test/kotlin/flag/FlagConfigUpdaterTest.kt +++ b/src/test/kotlin/flag/FlagConfigUpdaterTest.kt @@ -111,19 +111,16 @@ class FlagConfigPollerTest { } class FlagConfigStreamerTest { +// private val onInitUpdateCapture = slot<((List) -> Unit)>() private val onUpdateCapture = slot<((List) -> Unit)>() private val onErrorCapture = slot<((Throwable?) -> Unit)>() private var streamApi = mockk() private var storage = InMemoryFlagConfigStorage() - private val config = LocalEvaluationConfig(streamUpdates = true, streamServerUrl = "", streamFlagConnTimeoutMillis = 2000) @BeforeTest fun beforeTest() { streamApi = mockk() storage = InMemoryFlagConfigStorage() - - justRun { streamApi.onUpdate = capture(onUpdateCapture) } - justRun { streamApi.onError = capture(onErrorCapture) } } @AfterTest @@ -133,13 +130,13 @@ class FlagConfigStreamerTest { @Test fun `Test Poller`() { - justRun { streamApi.connect() } - val streamer = FlagConfigStreamer(streamApi, storage, null, null, config) + justRun { streamApi.connect(capture(onUpdateCapture), capture(onUpdateCapture), capture(onErrorCapture)) } + val streamer = FlagConfigStreamer(streamApi, storage, null, null) var errorCount = 0 streamer.start { errorCount++ } // Streamer starts - verify(exactly = 1) { streamApi.connect() } + verify(exactly = 1) { streamApi.connect(capture(onUpdateCapture), capture(onUpdateCapture), capture(onErrorCapture)) } // Verify update callback updates storage onUpdateCapture.captured(emptyList()) @@ -150,7 +147,7 @@ class FlagConfigStreamerTest { assertEquals(mapOf(FLAG1.key to FLAG1, FLAG2.key to FLAG2), storage.getFlagConfigs()) // No extra connect calls - verify(exactly = 1) { streamApi.connect() } + verify(exactly = 1) { streamApi.connect(capture(onUpdateCapture), capture(onUpdateCapture), capture(onErrorCapture)) } // No errors assertEquals(0, errorCount) @@ -158,27 +155,27 @@ class FlagConfigStreamerTest { @Test fun `Test Streamer start fails`(){ - every { streamApi.connect() } answers { throw Error("Haha error") } - val streamer = FlagConfigStreamer(streamApi, storage, null, null, config) + every { streamApi.connect(capture(onUpdateCapture), capture(onUpdateCapture), capture(onErrorCapture)) } answers { throw Error("Haha error") } + val streamer = FlagConfigStreamer(streamApi, storage, null, null) var errorCount = 0 try { streamer.start { errorCount++ } fail("Streamer start error not throwing") } catch (_: Throwable) { } - verify(exactly = 1) { streamApi.connect() } + verify(exactly = 1) { streamApi.connect(capture(onUpdateCapture), capture(onUpdateCapture), capture(onErrorCapture)) } assertEquals(0, errorCount) // No error callback as it throws directly } @Test fun `Test Streamer stream fails`(){ - justRun { streamApi.connect() } - val streamer = FlagConfigStreamer(streamApi, storage, null, null, config) + justRun { streamApi.connect(capture(onUpdateCapture), capture(onUpdateCapture), capture(onErrorCapture)) } + val streamer = FlagConfigStreamer(streamApi, storage, null, null) var errorCount = 0 streamer.start { errorCount++ } // Stream start success - verify(exactly = 1) { streamApi.connect() } + verify(exactly = 1) { streamApi.connect(capture(onUpdateCapture), capture(onUpdateCapture), capture(onErrorCapture)) } onUpdateCapture.captured(listOf(FLAG1)) assertEquals(mapOf(FLAG1.key to FLAG1), storage.getFlagConfigs()) assertEquals(0, errorCount) @@ -189,7 +186,6 @@ class FlagConfigStreamerTest { } } - class FlagConfigFallbackRetryWrapperTest { private val mainOnErrorCapture = slot<(() -> Unit)>() diff --git a/src/test/kotlin/util/SseStreamTest.kt b/src/test/kotlin/util/SseStreamTest.kt index edbad63..2dbbd07 100644 --- a/src/test/kotlin/util/SseStreamTest.kt +++ b/src/test/kotlin/util/SseStreamTest.kt @@ -41,24 +41,23 @@ class SseStreamTest { clearAllMocks() } - private fun setupStream( + private fun setupAndConnectStream( reconnTimeout: Long = 5000 ): SseStream { val stream = SseStream("authtoken", "http://localhost".toHttpUrl(), OkHttpClient(), 1000, 1000, reconnTimeout, 0) - stream.onUpdate = { d -> + stream.connect({ d -> data += d - } - stream.onError = { t -> + }, { t -> err += t - } + }) + return stream } @Test fun `Test SseStream connect`() { - val stream = setupStream() - stream.connect() + val stream = setupAndConnectStream() listenerCapture.captured.onEvent(es, null, null, "somedata") assertEquals(listOf("somedata"), data) @@ -72,8 +71,7 @@ class SseStreamTest { @Test fun `Test SseStream keep alive data omits`() { - val stream = setupStream(1000) - stream.connect() + val stream = setupAndConnectStream(1000) listenerCapture.captured.onEvent(es, null, null, "somedata") assertEquals(listOf("somedata"), data) @@ -85,8 +83,7 @@ class SseStreamTest { @Test fun `Test SseStream reconnects`() { - val stream = setupStream(1000) - stream.connect() + val stream = setupAndConnectStream(1000) listenerCapture.captured.onEvent(es, null, null, "somedata") assertEquals(listOf("somedata"), data) From 5357083cf884a2bfcdd3b3781a5a5486210320a9 Mon Sep 17 00:00:00 2001 From: Peter Zhu Date: Mon, 23 Sep 2024 17:29:28 -0700 Subject: [PATCH 10/12] style --- src/main/kotlin/deployment/DeploymentRunner.kt | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/main/kotlin/deployment/DeploymentRunner.kt b/src/main/kotlin/deployment/DeploymentRunner.kt index 50d64fc..af0879f 100644 --- a/src/main/kotlin/deployment/DeploymentRunner.kt +++ b/src/main/kotlin/deployment/DeploymentRunner.kt @@ -20,6 +20,7 @@ import java.util.concurrent.Executors import java.util.concurrent.TimeUnit private const val MIN_COHORT_POLLING_INTERVAL = 60000L +private const val FLAG_POLLING_JITTER = 1000L internal class DeploymentRunner( private val config: LocalEvaluationConfig, @@ -42,13 +43,15 @@ internal class DeploymentRunner( // Fallback in this order: proxy, stream, poll. private val amplitudeFlagConfigPoller = FlagConfigFallbackRetryWrapper( FlagConfigPoller(flagConfigApi, flagConfigStorage, cohortLoader, cohortStorage, config, metrics), - null, config.flagConfigPollerIntervalMillis, 1000 + null, + config.flagConfigPollerIntervalMillis, ) private val amplitudeFlagConfigUpdater = if (flagConfigStreamApi != null) FlagConfigFallbackRetryWrapper( FlagConfigStreamer(flagConfigStreamApi, flagConfigStorage, cohortLoader, cohortStorage, metrics), amplitudeFlagConfigPoller, + FLAG_POLLING_JITTER ) else amplitudeFlagConfigPoller private val flagConfigUpdater = From 24c01ae0db514dbea794eebc459fe0085be88c65 Mon Sep 17 00:00:00 2001 From: Peter Zhu Date: Mon, 23 Sep 2024 17:44:15 -0700 Subject: [PATCH 11/12] added comment, fix max jitter for retry wrapper --- src/main/kotlin/flag/FlagConfigUpdater.kt | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/src/main/kotlin/flag/FlagConfigUpdater.kt b/src/main/kotlin/flag/FlagConfigUpdater.kt index 5e1fa4b..5cc2b2e 100644 --- a/src/main/kotlin/flag/FlagConfigUpdater.kt +++ b/src/main/kotlin/flag/FlagConfigUpdater.kt @@ -227,6 +227,14 @@ internal class FlagConfigStreamer( private const val RETRY_DELAY_MILLIS_DEFAULT = 15 * 1000L private const val MAX_JITTER_MILLIS_DEFAULT = 2000L +/** + * This is a wrapper class around flag config updaters. + * This provides retry capability in case errors encountered during update asynchronously, as well as fallbacks when an updater failed. + * + * `mainUpdater` cannot be a FlagConfigFallbackRetryWrapper. + * The developer should restructure arguments to make sure `mainUpdater` is never a `FlagConfigFallbackRetryWrapper`. + * All retry and fallback structures can be normalized into `mainUpdater`s not being `FlagConfigFallbackRetryWrapper`s. + */ internal class FlagConfigFallbackRetryWrapper( private val mainUpdater: FlagConfigUpdater, private val fallbackUpdater: FlagConfigUpdater?, @@ -234,12 +242,20 @@ internal class FlagConfigFallbackRetryWrapper( maxJitterMillis: Long = MAX_JITTER_MILLIS_DEFAULT, ): FlagConfigUpdater { private val lock: ReentrantLock = ReentrantLock() - private val reconnIntervalRange = max(0, retryDelayMillis - maxJitterMillis)..(min(retryDelayMillis, retryDelayMillis - maxJitterMillis) + maxJitterMillis) + private val reconnIntervalRange = max(0, retryDelayMillis - maxJitterMillis)..(min(retryDelayMillis, Long.MAX_VALUE - maxJitterMillis) + maxJitterMillis) private val executor = Executors.newScheduledThreadPool(1, daemonFactory) private var retryTask: ScheduledFuture<*>? = null // @GuardedBy(lock) /** - * Since the wrapper retries, so there will never be error case. Thus, onError will never be called. + * Since the wrapper retries for mainUpdater, so there will never be error case. Thus, onError will never be called. + * + * During start, the wrapper tries to start main updater. + * If main start success, start success. + * If main start failed, fallback updater tries to start. + * If fallback start failed as well, throws exception. + * If fallback start success, start success, main enters retry loop. + * After started, if main failed, fallback is started and main enters retry loop. + * Fallback success or failures status is not monitored. It's suggested to wrap fallback into a retry wrapper. */ override fun start(onError: (() -> Unit)?) { if (mainUpdater is FlagConfigFallbackRetryWrapper) { From bf316e0f90ee380f492c8a0d43485c6ad3982fb8 Mon Sep 17 00:00:00 2001 From: Peter Zhu Date: Tue, 24 Sep 2024 21:40:32 -0700 Subject: [PATCH 12/12] lint --- src/main/kotlin/LocalEvaluationClient.kt | 5 ---- src/main/kotlin/flag/FlagConfigStreamApi.kt | 15 ++++++------ src/main/kotlin/flag/FlagConfigUpdater.kt | 9 +++---- src/main/kotlin/util/SseStream.kt | 8 +++---- src/test/kotlin/LocalEvaluationClientTest.kt | 1 - .../kotlin/flag/FlagConfigStreamApiTest.kt | 4 ++-- src/test/kotlin/flag/FlagConfigUpdaterTest.kt | 24 ++++++++++++------- src/test/kotlin/util/SseStreamTest.kt | 16 ++++++------- 8 files changed, 42 insertions(+), 40 deletions(-) diff --git a/src/main/kotlin/LocalEvaluationClient.kt b/src/main/kotlin/LocalEvaluationClient.kt index 2dfaee5..d70ed6f 100644 --- a/src/main/kotlin/LocalEvaluationClient.kt +++ b/src/main/kotlin/LocalEvaluationClient.kt @@ -19,7 +19,6 @@ import com.amplitude.experiment.evaluation.EvaluationEngineImpl import com.amplitude.experiment.evaluation.EvaluationFlag import com.amplitude.experiment.evaluation.topologicalSort import com.amplitude.experiment.flag.DynamicFlagConfigApi -import com.amplitude.experiment.flag.FlagConfigPoller import com.amplitude.experiment.flag.FlagConfigStreamApi import com.amplitude.experiment.flag.InMemoryFlagConfigStorage import com.amplitude.experiment.util.LocalEvaluationMetricsWrapper @@ -34,10 +33,6 @@ import com.amplitude.experiment.util.wrapMetrics import okhttp3.HttpUrl import okhttp3.HttpUrl.Companion.toHttpUrl import okhttp3.OkHttpClient -import okhttp3.Request -import okhttp3.sse.EventSource -import okhttp3.sse.EventSourceListener -import okhttp3.sse.EventSources class LocalEvaluationClient internal constructor( apiKey: String, diff --git a/src/main/kotlin/flag/FlagConfigStreamApi.kt b/src/main/kotlin/flag/FlagConfigStreamApi.kt index 6b68b32..58b89de 100644 --- a/src/main/kotlin/flag/FlagConfigStreamApi.kt +++ b/src/main/kotlin/flag/FlagConfigStreamApi.kt @@ -15,18 +15,18 @@ import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.locks.ReentrantLock import kotlin.concurrent.withLock -internal open class FlagConfigStreamApiError(message: String?, cause: Throwable?): Exception(message, cause) { +internal open class FlagConfigStreamApiError(message: String?, cause: Throwable?) : Exception(message, cause) { constructor(message: String?) : this(message, null) constructor(cause: Throwable?) : this(cause?.toString(), cause) } -internal class FlagConfigStreamApiConnTimeoutError: FlagConfigStreamApiError("Initial connection timed out") -internal class FlagConfigStreamApiDataCorruptError: FlagConfigStreamApiError("Stream data corrupted") -internal class FlagConfigStreamApiStreamError(cause: Throwable?): FlagConfigStreamApiError("Stream error", cause) +internal class FlagConfigStreamApiConnTimeoutError : FlagConfigStreamApiError("Initial connection timed out") +internal class FlagConfigStreamApiDataCorruptError : FlagConfigStreamApiError("Stream data corrupted") +internal class FlagConfigStreamApiStreamError(cause: Throwable?) : FlagConfigStreamApiError("Stream error", cause) private const val CONNECTION_TIMEOUT_MILLIS_DEFAULT = 1500L private const val KEEP_ALIVE_TIMEOUT_MILLIS_DEFAULT = 17000L // keep alive sends at 15s interval. 2s grace period private const val RECONN_INTERVAL_MILLIS_DEFAULT = 15 * 60 * 1000L -internal class FlagConfigStreamApi ( +internal class FlagConfigStreamApi( deploymentKey: String, serverUrl: HttpUrl, httpClient: OkHttpClient = OkHttpClient(), @@ -42,7 +42,8 @@ internal class FlagConfigStreamApi ( httpClient, connectionTimeoutMillis, keepaliveTimeoutMillis, - reconnIntervalMillis) + reconnIntervalMillis + ) /** * Connects to flag configs stream. @@ -82,7 +83,6 @@ internal class FlagConfigStreamApi ( } catch (_: Throwable) { updateTimeoutFuture.completeExceptionally(FlagConfigStreamApiDataCorruptError()) } - } else { // Stream has already established. // Make sure valid data. @@ -98,7 +98,6 @@ internal class FlagConfigStreamApi ( // Stream corrupted. Reconnect. handleError(onError, FlagConfigStreamApiDataCorruptError()) } - } } val onSseError: ((Throwable?) -> Unit) = { t -> diff --git a/src/main/kotlin/flag/FlagConfigUpdater.kt b/src/main/kotlin/flag/FlagConfigUpdater.kt index 5cc2b2e..0412a48 100644 --- a/src/main/kotlin/flag/FlagConfigUpdater.kt +++ b/src/main/kotlin/flag/FlagConfigUpdater.kt @@ -117,7 +117,7 @@ internal class FlagConfigPoller( cohortStorage: CohortStorage?, private val config: LocalEvaluationConfig, private val metrics: LocalEvaluationMetrics = LocalEvaluationMetricsWrapper(), -): FlagConfigUpdater, FlagConfigUpdaterBase( +) : FlagConfigUpdater, FlagConfigUpdaterBase( storage, cohortLoader, cohortStorage ) { private val lock: ReentrantLock = ReentrantLock() @@ -191,7 +191,7 @@ internal class FlagConfigStreamer( cohortLoader: CohortLoader?, cohortStorage: CohortStorage?, private val metrics: LocalEvaluationMetrics = LocalEvaluationMetricsWrapper() -): FlagConfigUpdater, FlagConfigUpdaterBase( +) : FlagConfigUpdater, FlagConfigUpdaterBase( storage, cohortLoader, cohortStorage ) { private val lock: ReentrantLock = ReentrantLock() @@ -240,7 +240,7 @@ internal class FlagConfigFallbackRetryWrapper( private val fallbackUpdater: FlagConfigUpdater?, retryDelayMillis: Long = RETRY_DELAY_MILLIS_DEFAULT, maxJitterMillis: Long = MAX_JITTER_MILLIS_DEFAULT, -): FlagConfigUpdater { +) : FlagConfigUpdater { private val lock: ReentrantLock = ReentrantLock() private val reconnIntervalRange = max(0, retryDelayMillis - maxJitterMillis)..(min(retryDelayMillis, Long.MAX_VALUE - maxJitterMillis) + maxJitterMillis) private val executor = Executors.newScheduledThreadPool(1, daemonFactory) @@ -320,5 +320,6 @@ internal class FlagConfigFallbackRetryWrapper( scheduleRetry() } }, reconnIntervalRange.random(), TimeUnit.MILLISECONDS) + } } -} \ No newline at end of file + \ No newline at end of file diff --git a/src/main/kotlin/util/SseStream.kt b/src/main/kotlin/util/SseStream.kt index 2703ba3..0bd7f94 100644 --- a/src/main/kotlin/util/SseStream.kt +++ b/src/main/kotlin/util/SseStream.kt @@ -16,7 +16,7 @@ import kotlin.concurrent.withLock import kotlin.math.max import kotlin.math.min -internal class StreamException(error: String): Throwable(error) +internal class StreamException(error: String) : Throwable(error) private const val KEEP_ALIVE_TIMEOUT_MILLIS_DEFAULT = 0L // no timeout private const val RECONN_INTERVAL_MILLIS_DEFAULT = 30 * 60 * 1000L @@ -26,7 +26,7 @@ private const val KEEP_ALIVE_DATA = " " /** * For establishing an SSE stream. */ -internal class SseStream ( +internal class SseStream( authToken: String, // Will be used in header as Authorization: url: HttpUrl, // The full url to connect to. httpClient: OkHttpClient = OkHttpClient(), @@ -128,7 +128,7 @@ internal class SseStream ( this.onUpdate = onUpdate this.onError = onError es = client.newEventSource(request, eventSourceListener) - reconnectTimerTask = Timer().schedule(reconnIntervalRange.random()) {// Timer for a new event source. + reconnectTimerTask = Timer().schedule(reconnIntervalRange.random()) { // Timer for a new event source. // This forces client side reconnection after interval. this@SseStream.cancel() connect(onUpdate, onError) @@ -155,4 +155,4 @@ internal class SseStream ( this.onError = null } } -} \ No newline at end of file +} diff --git a/src/test/kotlin/LocalEvaluationClientTest.kt b/src/test/kotlin/LocalEvaluationClientTest.kt index 68f5b1c..9f612da 100644 --- a/src/test/kotlin/LocalEvaluationClientTest.kt +++ b/src/test/kotlin/LocalEvaluationClientTest.kt @@ -12,7 +12,6 @@ import org.junit.Assert.assertEquals import org.junit.Assert.assertNull import kotlin.system.measureNanoTime import kotlin.test.AfterTest -import kotlin.test.BeforeTest import kotlin.test.Test private const val API_KEY = "server-qz35UwzJ5akieoAdIgzM4m9MIiOLXLoz" diff --git a/src/test/kotlin/flag/FlagConfigStreamApiTest.kt b/src/test/kotlin/flag/FlagConfigStreamApiTest.kt index 2872184..2d553bc 100644 --- a/src/test/kotlin/flag/FlagConfigStreamApiTest.kt +++ b/src/test/kotlin/flag/FlagConfigStreamApiTest.kt @@ -128,9 +128,9 @@ class FlagConfigStreamApiTest { var err: Array = arrayOf() val run = async { - api.connect({d -> + api.connect({ d -> assertEquals(listOf(), d) - }, {d -> + }, { d -> assertEquals(listOf(), d) }, { t -> err += t diff --git a/src/test/kotlin/flag/FlagConfigUpdaterTest.kt b/src/test/kotlin/flag/FlagConfigUpdaterTest.kt index c274f85..a0db63c 100644 --- a/src/test/kotlin/flag/FlagConfigUpdaterTest.kt +++ b/src/test/kotlin/flag/FlagConfigUpdaterTest.kt @@ -2,10 +2,18 @@ package com.amplitude.experiment.flag import com.amplitude.experiment.LocalEvaluationConfig import com.amplitude.experiment.evaluation.EvaluationFlag -import com.amplitude.experiment.util.SseStream -import io.mockk.* +import io.mockk.clearAllMocks +import io.mockk.every +import io.mockk.justRun +import io.mockk.mockk +import io.mockk.slot +import io.mockk.verify +import org.junit.Assert.assertEquals import java.lang.Exception -import kotlin.test.* +import kotlin.test.AfterTest +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.fail private val FLAG1 = EvaluationFlag("key1", emptyMap(), emptyList()) private val FLAG2 = EvaluationFlag("key2", emptyMap(), emptyList()) @@ -65,7 +73,7 @@ class FlagConfigPollerTest { } @Test - fun `Test Poller start fails`(){ + fun `Test Poller start fails`() { every { fetchApi.getFlagConfigs() } answers { throw Error("Haha error") } val poller = FlagConfigPoller(fetchApi, storage, null, null, LocalEvaluationConfig(flagConfigPollerIntervalMillis = 1000)) var errorCount = 0 @@ -85,7 +93,7 @@ class FlagConfigPollerTest { } @Test - fun `Test Poller poll fails`(){ + fun `Test Poller poll fails`() { every { fetchApi.getFlagConfigs() } returns emptyList() val poller = FlagConfigPoller(fetchApi, storage, null, null, LocalEvaluationConfig(flagConfigPollerIntervalMillis = 1000)) var errorCount = 0 @@ -154,7 +162,7 @@ class FlagConfigStreamerTest { } @Test - fun `Test Streamer start fails`(){ + fun `Test Streamer start fails`() { every { streamApi.connect(capture(onUpdateCapture), capture(onUpdateCapture), capture(onErrorCapture)) } answers { throw Error("Haha error") } val streamer = FlagConfigStreamer(streamApi, storage, null, null) var errorCount = 0 @@ -168,7 +176,7 @@ class FlagConfigStreamerTest { } @Test - fun `Test Streamer stream fails`(){ + fun `Test Streamer stream fails`() { justRun { streamApi.connect(capture(onUpdateCapture), capture(onUpdateCapture), capture(onErrorCapture)) } val streamer = FlagConfigStreamer(streamApi, storage, null, null) var errorCount = 0 @@ -398,4 +406,4 @@ class FlagConfigFallbackRetryWrapperTest { } verify(exactly = 0) { mainUpdater.start() } } -} \ No newline at end of file +} diff --git a/src/test/kotlin/util/SseStreamTest.kt b/src/test/kotlin/util/SseStreamTest.kt index 2dbbd07..58f13cd 100644 --- a/src/test/kotlin/util/SseStreamTest.kt +++ b/src/test/kotlin/util/SseStreamTest.kt @@ -1,22 +1,22 @@ package com.amplitude.experiment.util -import com.amplitude.experiment.ExperimentUser -import com.amplitude.experiment.RemoteEvaluationClient -import io.mockk.* -import okhttp3.HttpUrl +import io.mockk.clearAllMocks +import io.mockk.every +import io.mockk.justRun +import io.mockk.mockk +import io.mockk.mockkConstructor +import io.mockk.mockkStatic +import io.mockk.slot +import io.mockk.verify import okhttp3.HttpUrl.Companion.toHttpUrl -import okhttp3.HttpUrl.Companion.toHttpUrlOrNull import okhttp3.OkHttpClient -import okhttp3.Request import okhttp3.sse.EventSource import okhttp3.sse.EventSourceListener -import org.mockito.Mockito import kotlin.test.AfterTest import kotlin.test.BeforeTest import kotlin.test.Test import kotlin.test.assertEquals - class SseStreamTest { private val listenerCapture = slot() private val clientMock = mockk()