diff --git a/core/src/main/kotlin/Config.kt b/core/src/main/kotlin/Config.kt index 4e3f4ad..4bb2bb4 100644 --- a/core/src/main/kotlin/Config.kt +++ b/core/src/main/kotlin/Config.kt @@ -12,7 +12,7 @@ import java.io.File @Serializable data class ProjectsFile( - val projects: List + val projects: List, ) { companion object { fun fromEnv(): ProjectsFile { @@ -35,10 +35,9 @@ data class ProjectsFile( @Serializable data class ConfigurationFile( - val configuration: Configuration = Configuration() + val configuration: Configuration = Configuration(), ) { companion object { - fun fromEnv(): ConfigurationFile { val configuration = Configuration.fromEnv() return ConfigurationFile(configuration) @@ -61,13 +60,16 @@ data class ConfigurationFile( data class ProjectConfiguration( val apiKey: String, val secretKey: String, - val managementKey: String + val managementKey: String, ) { companion object { fun fromEnv(): ProjectConfiguration { val apiKey = checkNotNull(stringEnv(EnvKey.API_KEY)) { "${EnvKey.API_KEY} environment variable must be set." } val secretKey = checkNotNull(stringEnv(EnvKey.SECRET_KEY)) { "${EnvKey.SECRET_KEY} environment variable must be set." } - val managementKey = checkNotNull(stringEnv(EnvKey.EXPERIMENT_MANAGEMENT_KEY)) { "${EnvKey.EXPERIMENT_MANAGEMENT_KEY} environment variable must be set." } + val managementKey = + checkNotNull( + stringEnv(EnvKey.EXPERIMENT_MANAGEMENT_KEY), + ) { "${EnvKey.EXPERIMENT_MANAGEMENT_KEY} environment variable must be set." } return ProjectConfiguration(apiKey, secretKey, managementKey) } } @@ -86,32 +88,36 @@ data class Configuration( val cohortSyncIntervalMillis: Long = Default.COHORT_SYNC_INTERVAL_MILLIS, val maxCohortSize: Int = Default.MAX_COHORT_SIZE, val assignment: AssignmentConfiguration = AssignmentConfiguration(), - val redis: RedisConfiguration? = null + val redis: RedisConfiguration? = null, ) { companion object { - fun fromEnv() = Configuration( - port = intEnv(EnvKey.PORT, Default.PORT)!!, - serverZone = stringEnv(EnvKey.SERVER_ZONE, Default.SERVER_ZONE)!!, - serverUrl = stringEnv(EnvKey.SERVER_URL, Default.US_SERVER_URL)!!, - cohortServerUrl = stringEnv(EnvKey.COHORT_SERVER_URL, Default.US_COHORT_SERVER_URL)!!, - managementServerUrl = stringEnv(EnvKey.MANAGEMENT_SERVER_URL, Default.US_MANAGEMENT_SERVER_URL)!!, - analyticsServerUrl = stringEnv(EnvKey.ANALYTICS_SERVER_URL, Default.US_ANALYTICS_SERVER_URL)!!, - deploymentSyncIntervalMillis = longEnv( - EnvKey.DEPLOYMENT_SYNC_INTERVAL_MILLIS, - Default.DEPLOYMENT_SYNC_INTERVAL_MILLIS - )!!, - flagSyncIntervalMillis = longEnv( - EnvKey.FLAG_SYNC_INTERVAL_MILLIS, - Default.FLAG_SYNC_INTERVAL_MILLIS - )!!, - cohortSyncIntervalMillis = longEnv( - EnvKey.COHORT_SYNC_INTERVAL_MILLIS, - Default.COHORT_SYNC_INTERVAL_MILLIS - )!!, - maxCohortSize = intEnv(EnvKey.MAX_COHORT_SIZE, Default.MAX_COHORT_SIZE)!!, - assignment = AssignmentConfiguration.fromEnv(), - redis = RedisConfiguration.fromEnv() - ) + fun fromEnv() = + Configuration( + port = intEnv(EnvKey.PORT, Default.PORT)!!, + serverZone = stringEnv(EnvKey.SERVER_ZONE, Default.SERVER_ZONE)!!, + serverUrl = stringEnv(EnvKey.SERVER_URL, Default.US_SERVER_URL)!!, + cohortServerUrl = stringEnv(EnvKey.COHORT_SERVER_URL, Default.US_COHORT_SERVER_URL)!!, + managementServerUrl = stringEnv(EnvKey.MANAGEMENT_SERVER_URL, Default.US_MANAGEMENT_SERVER_URL)!!, + analyticsServerUrl = stringEnv(EnvKey.ANALYTICS_SERVER_URL, Default.US_ANALYTICS_SERVER_URL)!!, + deploymentSyncIntervalMillis = + longEnv( + EnvKey.DEPLOYMENT_SYNC_INTERVAL_MILLIS, + Default.DEPLOYMENT_SYNC_INTERVAL_MILLIS, + )!!, + flagSyncIntervalMillis = + longEnv( + EnvKey.FLAG_SYNC_INTERVAL_MILLIS, + Default.FLAG_SYNC_INTERVAL_MILLIS, + )!!, + cohortSyncIntervalMillis = + longEnv( + EnvKey.COHORT_SYNC_INTERVAL_MILLIS, + Default.COHORT_SYNC_INTERVAL_MILLIS, + )!!, + maxCohortSize = intEnv(EnvKey.MAX_COHORT_SIZE, Default.MAX_COHORT_SIZE)!!, + assignment = AssignmentConfiguration.fromEnv(), + redis = RedisConfiguration.fromEnv(), + ) } } @@ -120,27 +126,32 @@ data class AssignmentConfiguration( val filterCapacity: Int = Default.ASSIGNMENT_FILTER_CAPACITY, val eventUploadThreshold: Int = Default.ASSIGNMENT_EVENT_UPLOAD_THRESHOLD, val eventUploadPeriodMillis: Int = Default.ASSIGNMENT_EVENT_UPLOAD_PERIOD_MILLIS, - val useBatchMode: Boolean = Default.ASSIGNMENT_USE_BATCH_MODE + val useBatchMode: Boolean = Default.ASSIGNMENT_USE_BATCH_MODE, ) { companion object { - fun fromEnv() = AssignmentConfiguration( - filterCapacity = intEnv( - EnvKey.ASSIGNMENT_FILTER_CAPACITY, - Default.ASSIGNMENT_FILTER_CAPACITY - )!!, - eventUploadThreshold = intEnv( - EnvKey.ASSIGNMENT_EVENT_UPLOAD_THRESHOLD, - Default.ASSIGNMENT_EVENT_UPLOAD_THRESHOLD - )!!, - eventUploadPeriodMillis = intEnv( - EnvKey.ASSIGNMENT_EVENT_UPLOAD_PERIOD_MILLIS, - Default.ASSIGNMENT_EVENT_UPLOAD_PERIOD_MILLIS - )!!, - useBatchMode = booleanEnv( - EnvKey.ASSIGNMENT_USE_BATCH_MODE, - Default.ASSIGNMENT_USE_BATCH_MODE + fun fromEnv() = + AssignmentConfiguration( + filterCapacity = + intEnv( + EnvKey.ASSIGNMENT_FILTER_CAPACITY, + Default.ASSIGNMENT_FILTER_CAPACITY, + )!!, + eventUploadThreshold = + intEnv( + EnvKey.ASSIGNMENT_EVENT_UPLOAD_THRESHOLD, + Default.ASSIGNMENT_EVENT_UPLOAD_THRESHOLD, + )!!, + eventUploadPeriodMillis = + intEnv( + EnvKey.ASSIGNMENT_EVENT_UPLOAD_PERIOD_MILLIS, + Default.ASSIGNMENT_EVENT_UPLOAD_PERIOD_MILLIS, + )!!, + useBatchMode = + booleanEnv( + EnvKey.ASSIGNMENT_USE_BATCH_MODE, + Default.ASSIGNMENT_USE_BATCH_MODE, + ), ) - ) } } @@ -148,7 +159,7 @@ data class AssignmentConfiguration( data class RedisConfiguration( val uri: String? = null, val readOnlyUri: String? = uri, - val prefix: String = Default.REDIS_PREFIX + val prefix: String = Default.REDIS_PREFIX, ) { companion object { fun fromEnv(): RedisConfiguration? { @@ -159,7 +170,7 @@ data class RedisConfiguration( RedisConfiguration( uri = redisUri, readOnlyUri = redisReadOnlyUri, - prefix = redisPrefix + prefix = redisPrefix, ) } else { null diff --git a/core/src/main/kotlin/EvaluationProxy.kt b/core/src/main/kotlin/EvaluationProxy.kt index f6c35b0..6727695 100644 --- a/core/src/main/kotlin/EvaluationProxy.kt +++ b/core/src/main/kotlin/EvaluationProxy.kt @@ -30,15 +30,26 @@ import kotlin.time.toDuration const val EVALUATION_PROXY_VERSION = "0.4.7" +class EvaluationProxyResponseException( + val response: EvaluationProxyResponse, +) : Exception("Evaluation proxy response error: $response") + data class EvaluationProxyResponse( val status: HttpStatusCode, - val body: String + val body: String, ) { companion object { - fun error(status: HttpStatusCode, message: String): EvaluationProxyResponse { + fun error( + status: HttpStatusCode, + message: String, + ): EvaluationProxyResponse { return EvaluationProxyResponse(status, message) } - inline fun json(status: HttpStatusCode, response: T): EvaluationProxyResponse { + + inline fun json( + status: HttpStatusCode, + response: T, + ): EvaluationProxyResponse { return EvaluationProxyResponse(status, json.encodeToString(response)) } } @@ -48,9 +59,8 @@ class EvaluationProxy internal constructor( private val projectConfigurations: List, private val configuration: Configuration, private val projectStorage: ProjectStorage, - metricsHandler: MetricsHandler? = null, + metrics: MetricsHandler? = null, ) { - constructor( projectConfigurations: List, configuration: Configuration = Configuration(), @@ -67,7 +77,7 @@ class EvaluationProxy internal constructor( } init { - Metrics.handler = metricsHandler + Metrics.handler = metrics } private val supervisor = SupervisorJob() @@ -96,12 +106,13 @@ class EvaluationProxy internal constructor( val projectId = deployments.first().projectId log.info("Fetched ${deployments.size} deployments for project $projectId") // Add the project to local mappings. - val project = Project( - id = projectId, - apiKey = projectConfiguration.apiKey, - secretKey = projectConfiguration.secretKey, - managementKey = projectConfiguration.managementKey - ) + val project = + Project( + id = projectId, + apiKey = projectConfiguration.apiKey, + secretKey = projectConfiguration.secretKey, + managementKey = projectConfiguration.managementKey, + ) apiKeysToProject[project.apiKey] = project secretKeysToProject[project.secretKey] = project for (deployment in deployments) { @@ -115,8 +126,9 @@ class EvaluationProxy internal constructor( /* * Update project storage with configured projects, and clean up * projects that have been removed. + * + * Add all configured projects to storage. */ - // Add all configured projects to storage val projectIds = projectProxies.map { it.key.id }.toSet() for (projectId in projectIds) { log.debug("Adding project $projectId") @@ -167,98 +179,151 @@ class EvaluationProxy internal constructor( log.info("Evaluation proxy started.") } - suspend fun shutdown() = coroutineScope { - log.info("Shutting down evaluation proxy.") - projectProxies.map { scope.launch { it.value.shutdown() } }.joinAll() - supervisor.cancelAndJoin() - log.info("Evaluation proxy shut down.") - } + suspend fun shutdown() = + coroutineScope { + log.info("Shutting down evaluation proxy.") + projectProxies.map { scope.launch { it.value.shutdown() } }.joinAll() + supervisor.cancelAndJoin() + log.info("Evaluation proxy shut down.") + } // Apis - suspend fun getFlagConfigs( - deploymentKey: String? - ): EvaluationProxyResponse { - val project = getProject(deploymentKey) - ?: return EvaluationProxyResponse.error(HttpStatusCode.Unauthorized, "Invalid deployment") - val projectProxy = getProjectProxy(project) - ?: return EvaluationProxyResponse.error(HttpStatusCode.InternalServerError, "Project proxy not found for project.") - return projectProxy.getFlagConfigs(deploymentKey) - } + suspend fun getFlagConfigs(deploymentKey: String?): EvaluationProxyResponse = + Metrics.wrapRequestMetric({ EvaluationProxyGetFlagsRequest }, { EvaluationProxyGetFlagsRequestError(it) }) { + val project = + getProject(deploymentKey) + ?: return@wrapRequestMetric EvaluationProxyResponse.error( + HttpStatusCode.Unauthorized, + "Invalid deployment", + ) + val projectProxy = + getProjectProxy(project) + ?: return@wrapRequestMetric EvaluationProxyResponse.error( + HttpStatusCode.InternalServerError, + "Project proxy not found for project.", + ) + return@wrapRequestMetric projectProxy.getFlagConfigs(deploymentKey) + } suspend fun getCohort( apiKey: String?, secretKey: String?, cohortId: String?, lastModified: Long?, - maxCohortSize: Int? - ): EvaluationProxyResponse { - val project = getProject(apiKey, secretKey) - ?: return EvaluationProxyResponse.error(HttpStatusCode.Unauthorized, "Invalid api or secret key") - val projectProxy = getProjectProxy(project) - ?: return EvaluationProxyResponse.error(HttpStatusCode.InternalServerError, "Project proxy not found for project.") - return projectProxy.getCohort(cohortId, lastModified, maxCohortSize) - } + maxCohortSize: Int?, + ): EvaluationProxyResponse = + Metrics.wrapRequestMetric({ EvaluationProxyGetCohortRequest }, { EvaluationProxyGetCohortRequestError(it) }) { + val project = + getProject(apiKey, secretKey) + ?: return@wrapRequestMetric EvaluationProxyResponse.error( + HttpStatusCode.Unauthorized, + "Invalid api or secret key", + ) + val projectProxy = + getProjectProxy(project) + ?: return@wrapRequestMetric EvaluationProxyResponse.error( + HttpStatusCode.InternalServerError, + "Project proxy not found for project.", + ) + return@wrapRequestMetric projectProxy.getCohort(cohortId, lastModified, maxCohortSize) + } suspend fun getCohortMemberships( deploymentKey: String?, groupType: String?, - groupName: String? - ): EvaluationProxyResponse { - val project = getProject(deploymentKey) - ?: return EvaluationProxyResponse.error(HttpStatusCode.Unauthorized, "Invalid deployment") - val projectProxy = getProjectProxy(project) - ?: return EvaluationProxyResponse.error(HttpStatusCode.InternalServerError, "Project proxy not found for project.") - return projectProxy.getCohortMemberships(deploymentKey, groupType, groupName) - } + groupName: String?, + ): EvaluationProxyResponse = + Metrics.wrapRequestMetric({ EvaluationProxyGetMembershipsRequest }, { EvaluationProxyGetMembershipsRequestError(it) }) { + val project = + getProject(deploymentKey) + ?: return@wrapRequestMetric EvaluationProxyResponse.error( + HttpStatusCode.Unauthorized, + "Invalid deployment", + ) + val projectProxy = + getProjectProxy(project) + ?: return@wrapRequestMetric EvaluationProxyResponse.error( + HttpStatusCode.InternalServerError, + "Project proxy not found for project.", + ) + return@wrapRequestMetric projectProxy.getCohortMemberships(deploymentKey, groupType, groupName) + } suspend fun evaluate( deploymentKey: String?, user: Map?, - flagKeys: Set? = null - ): EvaluationProxyResponse { - val project = getProject(deploymentKey) - ?: return EvaluationProxyResponse.error(HttpStatusCode.Unauthorized, "Invalid deployment") - val projectProxy = getProjectProxy(project) - ?: return EvaluationProxyResponse.error(HttpStatusCode.InternalServerError, "Project proxy not found for project.") - return Metrics.with({ Evaluation }, { e -> EvaluationFailure(e) }) { - projectProxy.evaluate(deploymentKey, user, flagKeys) + flagKeys: Set? = null, + ): EvaluationProxyResponse = + Metrics.wrapRequestMetric({ EvaluationProxyEvaluationRequest }, { EvaluationProxyEvaluationRequestError(it) }) { + val project = + getProject(deploymentKey) + ?: return@wrapRequestMetric EvaluationProxyResponse.error( + HttpStatusCode.Unauthorized, + "Invalid deployment", + ) + val projectProxy = + getProjectProxy(project) + ?: return@wrapRequestMetric EvaluationProxyResponse.error( + HttpStatusCode.InternalServerError, + "Project proxy not found for project.", + ) + return@wrapRequestMetric Metrics.with({ Evaluation }, { e -> EvaluationFailure(e) }) { + projectProxy.evaluate(deploymentKey, user, flagKeys) + } } - } suspend fun evaluateV1( deploymentKey: String?, user: Map?, - flagKeys: Set? = null - ): EvaluationProxyResponse { - val project = getProject(deploymentKey) - ?: return EvaluationProxyResponse.error(HttpStatusCode.Unauthorized, "Invalid deployment") - val projectProxy = getProjectProxy(project) - ?: return EvaluationProxyResponse.error(HttpStatusCode.InternalServerError, "Project proxy not found for project.") - return Metrics.with({ Evaluation }, { e -> EvaluationFailure(e) }) { - projectProxy.evaluateV1(deploymentKey, user, flagKeys) + flagKeys: Set? = null, + ): EvaluationProxyResponse = + Metrics.wrapRequestMetric({ EvaluationProxyEvaluationRequest }, { EvaluationProxyEvaluationRequestError(it) }) { + val project = + getProject(deploymentKey) + ?: return@wrapRequestMetric EvaluationProxyResponse.error( + HttpStatusCode.Unauthorized, + "Invalid deployment", + ) + val projectProxy = + getProjectProxy(project) + ?: return@wrapRequestMetric EvaluationProxyResponse.error( + HttpStatusCode.InternalServerError, + "Project proxy not found for project.", + ) + return@wrapRequestMetric Metrics.with({ Evaluation }, { e -> EvaluationFailure(e) }) { + projectProxy.evaluateV1(deploymentKey, user, flagKeys) + } } - } // Private private suspend fun getProject(deploymentKey: String?): Project? { - val project = mutex.withLock { - deploymentKeysToProject[deploymentKey] - } + val project = + mutex.withLock { + deploymentKeysToProject[deploymentKey] + } if (project == null) { - log.warn("Unable to find project for deployment {}. Current mappings: {}", deploymentKey, deploymentKeysToProject.mapValues { it.value.id }) + log.warn( + "Unable to find project for deployment {}. Current mappings: {}", + deploymentKey, + deploymentKeysToProject.mapValues { it.value.id }, + ) return null } return project } - private suspend fun getProject(apiKey: String?, secretKey: String?): Project? { - val project = mutex.withLock { - apiKeysToProject[apiKey] - } + private suspend fun getProject( + apiKey: String?, + secretKey: String?, + ): Project? { + val project = + mutex.withLock { + apiKeysToProject[apiKey] + } if (project == null) { - log.warn("Unable to find project for deployment {}. Current mappings: {}", apiKey, apiKeysToProject.mapValues { it.value.id }) + log.warn("Unable to find project for api key {}. Current mappings: {}", apiKey, apiKeysToProject.mapValues { it.value.id }) return null } if (project.secretKey != secretKey) { @@ -280,29 +345,31 @@ class EvaluationProxy internal constructor( internal fun createProjectApi(managementKey: String): ProjectApi { return ProjectApiV1( configuration.managementServerUrl, - managementKey + managementKey, ) } @VisibleForTesting internal fun createProjectProxy(project: Project): ProjectProxy { - val assignmentTracker = AmplitudeAssignmentTracker( - project.apiKey, - configuration.analyticsServerUrl, - configuration.assignment - ) + val assignmentTracker = + AmplitudeAssignmentTracker( + project.apiKey, + configuration.analyticsServerUrl, + configuration.assignment, + ) val deploymentStorage = getDeploymentStorage(project.id, configuration.redis) - val cohortStorage = getCohortStorage( - project.id, - configuration.redis, - configuration.cohortSyncIntervalMillis.toDuration(DurationUnit.MILLISECONDS) - ) + val cohortStorage = + getCohortStorage( + project.id, + configuration.redis, + configuration.cohortSyncIntervalMillis.toDuration(DurationUnit.MILLISECONDS), + ) return ProjectProxy( project, configuration, assignmentTracker, deploymentStorage, - cohortStorage + cohortStorage, ) } @@ -316,7 +383,23 @@ class EvaluationProxy internal constructor( return getCohortStorage( projectId, configuration.redis, - configuration.cohortSyncIntervalMillis.toDuration(DurationUnit.MILLISECONDS) + configuration.cohortSyncIntervalMillis.toDuration(DurationUnit.MILLISECONDS), ) } + + private suspend fun Metrics.wrapRequestMetric( + metric: (() -> Metric)?, + failure: ((e: Exception) -> FailureMetric)?, + block: suspend () -> EvaluationProxyResponse, + ): EvaluationProxyResponse { + track(EvaluationProxyRequest) + metric?.invoke() + val response = block() + if (response.status.value >= 400) { + val exception = EvaluationProxyResponseException(response) + track(EvaluationProxyRequestError(exception)) + failure?.invoke(exception) + } + return response + } } diff --git a/core/src/main/kotlin/Metrics.kt b/core/src/main/kotlin/Metrics.kt index 0356895..150270f 100644 --- a/core/src/main/kotlin/Metrics.kt +++ b/core/src/main/kotlin/Metrics.kt @@ -1,8 +1,7 @@ package com.amplitude -import com.amplitude.project.InMemoryProjectStorage - sealed class Metric + sealed class FailureMetric : Metric() interface MetricsHandler { @@ -10,22 +9,54 @@ interface MetricsHandler { } data object Evaluation : Metric() + data class EvaluationFailure(val exception: Exception) : FailureMetric() + data object AssignmentEvent : Metric() + data object AssignmentEventFilter : Metric() + data object AssignmentEventSend : Metric() + data class AssignmentEventSendFailure(val exception: Exception) : FailureMetric() + data object DeploymentsFetch : Metric() + data class DeploymentsFetchFailure(val exception: Exception) : FailureMetric() + data object FlagsFetch : Metric() + data class FlagsFetchFailure(val exception: Exception) : FailureMetric() + data object CohortDownload : Metric() + data class CohortDownloadFailure(val exception: Exception) : FailureMetric() + data object RedisCommand : Metric() + data class RedisCommandFailure(val exception: Exception) : FailureMetric() -internal object Metrics : MetricsHandler { +data object EvaluationProxyRequest : Metric() + +data class EvaluationProxyRequestError(val exception: Exception) : FailureMetric() + +data object EvaluationProxyGetFlagsRequest : Metric() + +data class EvaluationProxyGetFlagsRequestError(val exception: Exception) : FailureMetric() + +data object EvaluationProxyGetCohortRequest : Metric() +data class EvaluationProxyGetCohortRequestError(val exception: Exception) : FailureMetric() + +data object EvaluationProxyGetMembershipsRequest : Metric() + +data class EvaluationProxyGetMembershipsRequestError(val exception: Exception) : FailureMetric() + +data object EvaluationProxyEvaluationRequest : Metric() + +data class EvaluationProxyEvaluationRequestError(val exception: Exception) : FailureMetric() + +internal object Metrics : MetricsHandler { internal var handler: MetricsHandler? = null override fun track(metric: Metric) { @@ -35,7 +66,7 @@ internal object Metrics : MetricsHandler { internal suspend fun with( metric: (() -> Metric)?, failure: ((e: Exception) -> FailureMetric)?, - block: suspend () -> R + block: suspend () -> R, ): R { try { metric?.invoke() diff --git a/core/src/main/kotlin/assignment/Assignment.kt b/core/src/main/kotlin/assignment/Assignment.kt index 8b96d3b..2f8e148 100644 --- a/core/src/main/kotlin/assignment/Assignment.kt +++ b/core/src/main/kotlin/assignment/Assignment.kt @@ -10,7 +10,7 @@ internal const val DAY_MILLIS: Long = 24 * 60 * 60 * 1000 internal data class Assignment( val context: EvaluationContext, val results: Map, - val timestamp: Long = System.currentTimeMillis() + val timestamp: Long = System.currentTimeMillis(), ) internal fun Assignment.canonicalize(): String { diff --git a/core/src/main/kotlin/assignment/AssignmentFilter.kt b/core/src/main/kotlin/assignment/AssignmentFilter.kt index 9090434..1717cac 100644 --- a/core/src/main/kotlin/assignment/AssignmentFilter.kt +++ b/core/src/main/kotlin/assignment/AssignmentFilter.kt @@ -7,7 +7,6 @@ internal interface AssignmentFilter { } internal class InMemoryAssignmentFilter(size: Int) : AssignmentFilter { - // Cache of canonical assignment to the last sent timestamp. private val cache = Cache(size, DAY_MILLIS) diff --git a/core/src/main/kotlin/assignment/AssignmentTracker.kt b/core/src/main/kotlin/assignment/AssignmentTracker.kt index 0ea0a72..352a2a6 100644 --- a/core/src/main/kotlin/assignment/AssignmentTracker.kt +++ b/core/src/main/kotlin/assignment/AssignmentTracker.kt @@ -28,9 +28,8 @@ internal interface AssignmentTracker { internal class AmplitudeAssignmentTracker( private val amplitude: Amplitude, - private val assignmentFilter: AssignmentFilter + private val assignmentFilter: AssignmentFilter, ) : AssignmentTracker { - companion object { val log by logger() } @@ -38,16 +37,17 @@ internal class AmplitudeAssignmentTracker( constructor( apiKey: String, serverUrl: String, - config: AssignmentConfiguration + config: AssignmentConfiguration, ) : this ( - amplitude = Amplitude.getInstance().apply { - setServerUrl(serverUrl) - setEventUploadThreshold(config.eventUploadThreshold) - setEventUploadPeriodMillis(config.eventUploadPeriodMillis) - useBatchMode(config.useBatchMode) - init(apiKey) - }, - assignmentFilter = InMemoryAssignmentFilter(config.filterCapacity) + amplitude = + Amplitude.getInstance().apply { + setServerUrl(serverUrl) + setEventUploadThreshold(config.eventUploadThreshold) + setEventUploadPeriodMillis(config.eventUploadPeriodMillis) + useBatchMode(config.useBatchMode) + init(apiKey) + }, + assignmentFilter = InMemoryAssignmentFilter(config.filterCapacity), ) override suspend fun track(assignment: Assignment) { @@ -67,42 +67,45 @@ internal class AmplitudeAssignmentTracker( } internal fun Assignment.toAmplitudeEvent(): Event { - val event = Event( - "[Experiment] Assignment", - this.context.userId(), - this.context.deviceId() - ) + val event = + Event( + "[Experiment] Assignment", + this.context.userId(), + this.context.deviceId(), + ) val groups = this.context.groups() if (!groups.isNullOrEmpty()) { event.groups = JSONObject(groups) } - event.eventProperties = JSONObject().apply { - for ((flagKey, variant) in this@toAmplitudeEvent.results) { - val version = variant.metadata?.get("flagVersion") - val segmentName = variant.metadata?.get("segmentName") - val details = "v$version rule:$segmentName" - put("$flagKey.variant", variant.key) - put("$flagKey.details", details) + event.eventProperties = + JSONObject().apply { + for ((flagKey, variant) in this@toAmplitudeEvent.results) { + val version = variant.metadata?.get("flagVersion") + val segmentName = variant.metadata?.get("segmentName") + val details = "v$version rule:$segmentName" + put("$flagKey.variant", variant.key) + put("$flagKey.details", details) + } } - } - event.userProperties = JSONObject().apply { - val set = JSONObject() - val unset = JSONObject() - for ((flagKey, variant) in this@toAmplitudeEvent.results) { - val flagType = variant.metadata?.get("flagType") as? String - val default = variant.metadata?.get("default") as? Boolean ?: false - if (flagType == FlagType.MUTUAL_EXCLUSION_GROUP) { - // Dont set user properties for mutual exclusion groups. - continue - } else if (default) { - unset.put("[Experiment] $flagKey", "-") - } else { - set.put("[Experiment] $flagKey", variant.key) + event.userProperties = + JSONObject().apply { + val set = JSONObject() + val unset = JSONObject() + for ((flagKey, variant) in this@toAmplitudeEvent.results) { + val flagType = variant.metadata?.get("flagType") as? String + val default = variant.metadata?.get("default") as? Boolean ?: false + if (flagType == FlagType.MUTUAL_EXCLUSION_GROUP) { + // Dont set user properties for mutual exclusion groups. + continue + } else if (default) { + unset.put("[Experiment] $flagKey", "-") + } else { + set.put("[Experiment] $flagKey", variant.key) + } } + put("\$set", set) + put("\$unset", unset) } - put("\$set", set) - put("\$unset", unset) - } event.insertId = "${this.context.userId()} ${this.context.deviceId()} ${this.canonicalize().hashCode()} ${this.timestamp / DAY_MILLIS}" return event } diff --git a/core/src/main/kotlin/cohort/CohortApi.kt b/core/src/main/kotlin/cohort/CohortApi.kt index a378f04..e44ba42 100644 --- a/core/src/main/kotlin/cohort/CohortApi.kt +++ b/core/src/main/kotlin/cohort/CohortApi.kt @@ -13,19 +13,17 @@ import io.ktor.client.engine.okhttp.OkHttp import io.ktor.client.plugins.HttpTimeout import io.ktor.client.request.headers import io.ktor.client.request.parameter -import io.ktor.client.statement.bodyAsText import io.ktor.http.HttpStatusCode -import io.ktor.util.logging.Logger import io.ktor.util.toByteArray import kotlinx.serialization.Serializable import java.util.Base64 internal class CohortTooLargeException(cohortId: String, maxCohortSize: Int) : RuntimeException( - "Cohort $cohortId exceeds the maximum cohort size defined in the SDK configuration $maxCohortSize" + "Cohort $cohortId exceeds the maximum cohort size defined in the SDK configuration $maxCohortSize", ) internal class CohortNotModifiedException(cohortId: String) : RuntimeException( - "Cohort $cohortId has not been modified." + "Cohort $cohortId has not been modified.", ) @Serializable @@ -34,29 +32,35 @@ data class GetCohortResponse( private val lastModified: Long, private val size: Int, private val groupType: String, - private val memberIds: Set? = null + private val memberIds: Set? = null, ) { - fun toCohort() = Cohort( - id = cohortId, - groupType = groupType, - size = size, - lastModified = lastModified, - members = memberIds ?: emptySet() - ) + fun toCohort() = + Cohort( + id = cohortId, + groupType = groupType, + size = size, + lastModified = lastModified, + members = memberIds ?: emptySet(), + ) companion object { - fun fromCohort(cohort: Cohort) = GetCohortResponse( - cohortId = cohort.id, - lastModified = cohort.lastModified, - size = cohort.size, - groupType = cohort.groupType, - memberIds = cohort.members - ) + fun fromCohort(cohort: Cohort) = + GetCohortResponse( + cohortId = cohort.id, + lastModified = cohort.lastModified, + size = cohort.size, + groupType = cohort.groupType, + memberIds = cohort.members, + ) } } internal interface CohortApi { - suspend fun getCohort(cohortId: String, lastModified: Long?, maxCohortSize: Int): Cohort + suspend fun getCohort( + cohortId: String, + lastModified: Long?, + maxCohortSize: Int, + ): Cohort } internal class CohortApiV1( @@ -64,41 +68,46 @@ internal class CohortApiV1( apiKey: String, secretKey: String, engine: HttpClientEngine = OkHttp.create(), - private val retryConfig: RetryConfig = RetryConfig() + private val retryConfig: RetryConfig = RetryConfig(), ) : CohortApi { - companion object { val log by logger() } private val token = Base64.getEncoder().encodeToString("$apiKey:$secretKey".toByteArray(Charsets.UTF_8)) - private val client = HttpClient(engine) { - install(HttpTimeout) { - socketTimeoutMillis = 30000 + private val client = + HttpClient(engine) { + install(HttpTimeout) { + socketTimeoutMillis = 30000 + } } - } - override suspend fun getCohort(cohortId: String, lastModified: Long?, maxCohortSize: Int): Cohort { + override suspend fun getCohort( + cohortId: String, + lastModified: Long?, + maxCohortSize: Int, + ): Cohort { log.debug("getCohortMembers({}): start - maxCohortSize={}, lastModified={}", cohortId, maxCohortSize, lastModified) - val response = retry( - config = retryConfig, - onFailure = { e -> log.error("Cohort download failed: $e") }, - acceptCodes = setOf(HttpStatusCode.NoContent, HttpStatusCode.PayloadTooLarge) - ) { - client.get( - url = serverUrl, - path = "sdk/v1/cohort/$cohortId" + val response = + retry( + config = retryConfig, + onFailure = { e -> log.error("Cohort download failed: $e") }, + acceptCodes = setOf(HttpStatusCode.NoContent, HttpStatusCode.PayloadTooLarge), ) { - parameter("maxCohortSize", "$maxCohortSize") - if (lastModified != null) { - parameter("lastModified", "$lastModified") - } - headers { - set("Authorization", "Basic $token") - set("X-Amp-Exp-Library", "evaluation-proxy/$EVALUATION_PROXY_VERSION") + client.get( + url = serverUrl, + path = "sdk/v1/cohort/$cohortId", + ) { + parameter("maxCohortSize", "$maxCohortSize") + if (lastModified != null) { + parameter("lastModified", "$lastModified") + } + headers { + set("Authorization", "Basic $token") + set("X-Amp-Exp-Library", "evaluation-proxy/$EVALUATION_PROXY_VERSION") + } } } - } log.debug("getCohortMembers({}): status={}", cohortId, response.status) when (response.status) { HttpStatusCode.NoContent -> throw CohortNotModifiedException(cohortId) diff --git a/core/src/main/kotlin/cohort/CohortLoader.kt b/core/src/main/kotlin/cohort/CohortLoader.kt index 8bd2683..8834f01 100644 --- a/core/src/main/kotlin/cohort/CohortLoader.kt +++ b/core/src/main/kotlin/cohort/CohortLoader.kt @@ -13,36 +13,37 @@ import kotlinx.coroutines.launch internal class CohortLoader( private val maxCohortSize: Int, private val cohortApi: CohortApi, - private val cohortStorage: CohortStorage + private val cohortStorage: CohortStorage, ) { - companion object { val log by logger() } private val loader = Loader() - suspend fun loadCohorts(cohortIds: Set) = coroutineScope { - val jobs = mutableListOf() - for (cohortId in cohortIds) { - jobs += launch { loadCohort(cohortId) } + suspend fun loadCohorts(cohortIds: Set) = + coroutineScope { + val jobs = mutableListOf() + for (cohortId in cohortIds) { + jobs += launch { loadCohort(cohortId) } + } + jobs.joinAll() } - jobs.joinAll() - } private suspend fun loadCohort(cohortId: String) { log.trace("loadCohort: start - cohortId={}", cohortId) val storageCohort = cohortStorage.getCohortDescription(cohortId) loader.load(cohortId) { try { - val cohort = Metrics.with({ CohortDownload }, { e -> CohortDownloadFailure(e) }) { - try { - cohortApi.getCohort(cohortId, storageCohort?.lastModified, maxCohortSize) - } catch (e: CohortNotModifiedException) { - log.debug("loadCohort: cohort not modified - cohortId={}", cohortId) - null + val cohort = + Metrics.with({ CohortDownload }, { e -> CohortDownloadFailure(e) }) { + try { + cohortApi.getCohort(cohortId, storageCohort?.lastModified, maxCohortSize) + } catch (e: CohortNotModifiedException) { + log.debug("loadCohort: cohort not modified - cohortId={}", cohortId) + null + } } - } if (cohort != null) { cohortStorage.putCohort(cohort) } diff --git a/core/src/main/kotlin/cohort/CohortStorage.kt b/core/src/main/kotlin/cohort/CohortStorage.kt index bd01e0e..a3be290 100644 --- a/core/src/main/kotlin/cohort/CohortStorage.kt +++ b/core/src/main/kotlin/cohort/CohortStorage.kt @@ -18,7 +18,7 @@ internal data class CohortDescription( @SerialName("cohortId") val id: String, val groupType: String, val size: Int, - val lastModified: Long + val lastModified: Long, ) { fun toCohort(members: Set): Cohort { return Cohort( @@ -26,7 +26,7 @@ internal data class CohortDescription( groupType = groupType, size = size, lastModified = lastModified, - members = members + members = members, ) } } @@ -36,37 +36,51 @@ internal fun Cohort.toCohortDescription(): CohortDescription { id = id, groupType = groupType, size = size, - lastModified = lastModified + lastModified = lastModified, ) } internal interface CohortStorage { suspend fun getCohort(cohortId: String): Cohort? + suspend fun getCohorts(): Map + suspend fun getCohortDescription(cohortId: String): CohortDescription? + suspend fun getCohortDescriptions(): Map - suspend fun getCohortMemberships(groupType: String, groupName: String, cohortIds: Set): Set + + suspend fun getCohortMemberships( + groupType: String, + groupName: String, + cohortIds: Set, + ): Set + suspend fun putCohort(cohort: Cohort) + suspend fun deleteCohort(description: CohortDescription) } -internal fun getCohortStorage(projectId: String, redisConfiguration: RedisConfiguration?, ttl: Duration): CohortStorage { +internal fun getCohortStorage( + projectId: String, + redisConfiguration: RedisConfiguration?, + ttl: Duration, +): CohortStorage { val uri = redisConfiguration?.uri return if (uri == null) { InMemoryCohortStorage() } else { val redis = RedisConnection(uri) - val readOnlyRedis = if (redisConfiguration.readOnlyUri != null) { - RedisConnection(redisConfiguration.readOnlyUri) - } else { - redis - } + val readOnlyRedis = + if (redisConfiguration.readOnlyUri != null) { + RedisConnection(redisConfiguration.readOnlyUri) + } else { + redis + } RedisCohortStorage(projectId, ttl, redisConfiguration.prefix, redis, readOnlyRedis) } } internal class InMemoryCohortStorage : CohortStorage { - private val lock = Mutex() private val cohorts = mutableMapOf() @@ -86,7 +100,11 @@ internal class InMemoryCohortStorage : CohortStorage { return lock.withLock { cohorts.toMap() }.mapValues { it.value.toCohortDescription() } } - override suspend fun getCohortMemberships(groupType: String, groupName: String, cohortIds: Set): Set { + override suspend fun getCohortMemberships( + groupType: String, + groupName: String, + cohortIds: Set, + ): Set { val result = mutableSetOf() lock.withLock { for (cohortId in cohortIds) { @@ -116,9 +134,8 @@ internal class RedisCohortStorage( private val ttl: Duration, private val prefix: String, private val redis: Redis, - private val readOnlyRedis: Redis + private val readOnlyRedis: Redis, ) : CohortStorage { - companion object { val log by logger() } @@ -160,7 +177,7 @@ internal class RedisCohortStorage( override suspend fun getCohortMemberships( groupType: String, groupName: String, - cohortIds: Set + cohortIds: Set, ): Set { val descriptions = getCohortDescriptions() val memberships = mutableSetOf() @@ -172,16 +189,17 @@ internal class RedisCohortStorage( continue } // High volume, use read connection - val isMember = readOnlyRedis.sismember( - RedisKey.CohortMembers( - prefix, - projectId, - description.id, - description.groupType, - description.lastModified - ), - groupName - ) + val isMember = + readOnlyRedis.sismember( + RedisKey.CohortMembers( + prefix, + projectId, + description.id, + description.groupType, + description.lastModified, + ), + groupName, + ) if (isMember) { memberships += description.id } @@ -200,9 +218,9 @@ internal class RedisCohortStorage( projectId, description.id, description.groupType, - description.lastModified + description.lastModified, ), - cohort.members + cohort.members, ) redis.hset(RedisKey.CohortDescriptions(prefix, projectId), mapOf(description.id to jsonEncodedDescription)) if (existingDescription != null) { @@ -212,9 +230,9 @@ internal class RedisCohortStorage( projectId, existingDescription.id, existingDescription.groupType, - existingDescription.lastModified + existingDescription.lastModified, ), - ttl + ttl, ) } } @@ -228,15 +246,15 @@ internal class RedisCohortStorage( projectId, description.id, description.groupType, - description.lastModified - ) + description.lastModified, + ), ) } private suspend fun getCohortMembers( cohortId: String, cohortGroupType: String, - cohortLastModified: Long + cohortLastModified: Long, ): Set? { return redis.smembers(RedisKey.CohortMembers(prefix, projectId, cohortId, cohortGroupType, cohortLastModified)) } diff --git a/core/src/main/kotlin/deployment/Deployment.kt b/core/src/main/kotlin/deployment/Deployment.kt index e65d7e2..1c4f732 100644 --- a/core/src/main/kotlin/deployment/Deployment.kt +++ b/core/src/main/kotlin/deployment/Deployment.kt @@ -7,5 +7,5 @@ internal data class Deployment( val id: String, val projectId: String, val label: String, - val key: String + val key: String, ) diff --git a/core/src/main/kotlin/deployment/DeploymentApi.kt b/core/src/main/kotlin/deployment/DeploymentApi.kt index 7848ff6..2a1fcd3 100644 --- a/core/src/main/kotlin/deployment/DeploymentApi.kt +++ b/core/src/main/kotlin/deployment/DeploymentApi.kt @@ -1,6 +1,9 @@ package com.amplitude.deployment import com.amplitude.EVALUATION_PROXY_VERSION +import com.amplitude.FlagsFetch +import com.amplitude.FlagsFetchFailure +import com.amplitude.Metrics import com.amplitude.experiment.evaluation.EvaluationFlag import com.amplitude.util.RetryConfig import com.amplitude.util.get @@ -21,9 +24,8 @@ internal interface DeploymentApi { internal class DeploymentApiV2( private val serverUrl: String, engine: HttpClientEngine = OkHttp.create(), - private val retryConfig: RetryConfig = RetryConfig() + private val retryConfig: RetryConfig = RetryConfig(), ) : DeploymentApi { - companion object { val log by logger() } @@ -32,18 +34,21 @@ internal class DeploymentApiV2( override suspend fun getFlagConfigs(deploymentKey: String): List { log.trace("getFlagConfigs: start - deploymentKey=$deploymentKey") - val response = retry( - config = retryConfig, - onFailure = { e -> log.error("Get flag configs failed: $e") } - ) { - client.get(serverUrl, "/sdk/v2/flags") { - parameter("v", "0") - headers { - set("Authorization", "Api-Key $deploymentKey") - set("X-Amp-Exp-Library", "evaluation-proxy/$EVALUATION_PROXY_VERSION") + val response = + Metrics.with({ FlagsFetch }, { e -> FlagsFetchFailure(e) }) { + retry( + config = retryConfig, + onFailure = { e -> log.error("Get flag configs failed: $e") }, + ) { + client.get(serverUrl, "/sdk/v2/flags") { + parameter("v", "0") + headers { + set("Authorization", "Api-Key $deploymentKey") + set("X-Amp-Exp-Library", "evaluation-proxy/$EVALUATION_PROXY_VERSION") + } + } } } - } return json.decodeFromString>(response.body()).also { log.trace("getFlagConfigs: end - deploymentKey=$deploymentKey") } diff --git a/core/src/main/kotlin/deployment/DeploymentLoader.kt b/core/src/main/kotlin/deployment/DeploymentLoader.kt index e221704..95922fd 100644 --- a/core/src/main/kotlin/deployment/DeploymentLoader.kt +++ b/core/src/main/kotlin/deployment/DeploymentLoader.kt @@ -1,8 +1,5 @@ package com.amplitude.deployment -import com.amplitude.FlagsFetch -import com.amplitude.FlagsFetchFailure -import com.amplitude.Metrics import com.amplitude.cohort.CohortLoader import com.amplitude.util.Loader import com.amplitude.util.getAllCohortIds @@ -12,9 +9,8 @@ import kotlinx.coroutines.launch internal class DeploymentLoader( private val deploymentApi: DeploymentApi, private val deploymentStorage: DeploymentStorage, - private val cohortLoader: CohortLoader + private val cohortLoader: CohortLoader, ) { - companion object { val log by logger() } @@ -24,9 +20,7 @@ internal class DeploymentLoader( suspend fun loadDeployment(deploymentKey: String) { log.trace("loadDeployment: - deploymentKey=$deploymentKey") loader.load(deploymentKey) { - val networkFlags = Metrics.with({ FlagsFetch }, { e -> FlagsFetchFailure(e) }) { - deploymentApi.getFlagConfigs(deploymentKey) - } + val networkFlags = deploymentApi.getFlagConfigs(deploymentKey) // Remove flags that are no longer deployed. val networkFlagKeys = networkFlags.map { it.key }.toSet() val storageFlagKeys = deploymentStorage.getAllFlags(deploymentKey).map { it.key }.toSet() diff --git a/core/src/main/kotlin/deployment/DeploymentRunner.kt b/core/src/main/kotlin/deployment/DeploymentRunner.kt index 6101d77..38af184 100644 --- a/core/src/main/kotlin/deployment/DeploymentRunner.kt +++ b/core/src/main/kotlin/deployment/DeploymentRunner.kt @@ -17,7 +17,6 @@ internal class DeploymentRunner( private val deploymentStorage: DeploymentStorage, private val deploymentLoader: DeploymentLoader, ) { - companion object { val log by logger() } @@ -27,15 +26,16 @@ internal class DeploymentRunner( suspend fun start() { log.trace("start: - deploymentKey=$deploymentKey") - val job = scope.launch { - try { - deploymentLoader.loadDeployment(deploymentKey) - } catch (t: Throwable) { - // Catch failure and continue to run pollers. Assume deployment - // load will eventually succeed. - log.error("Load failed for deployment $deploymentKey", t) + val job = + scope.launch { + try { + deploymentLoader.loadDeployment(deploymentKey) + } catch (t: Throwable) { + // Catch failure and continue to run pollers. Assume deployment + // load will eventually succeed. + log.error("Load failed for deployment $deploymentKey", t) + } } - } // Periodic flag config loader scope.launch { while (true) { diff --git a/core/src/main/kotlin/deployment/DeploymentStorage.kt b/core/src/main/kotlin/deployment/DeploymentStorage.kt index ca52d7b..7bf2ac7 100644 --- a/core/src/main/kotlin/deployment/DeploymentStorage.kt +++ b/core/src/main/kotlin/deployment/DeploymentStorage.kt @@ -12,38 +12,63 @@ import kotlinx.serialization.encodeToString internal interface DeploymentStorage { suspend fun getDeployment(deploymentKey: String): Deployment? + suspend fun getDeployments(): Map + suspend fun putDeployment(deployment: Deployment) + suspend fun removeDeployment(deploymentKey: String) - suspend fun getFlag(deploymentKey: String, flagKey: String): EvaluationFlag? + + suspend fun getFlag( + deploymentKey: String, + flagKey: String, + ): EvaluationFlag? + suspend fun getAllFlags(deploymentKey: String): Map - suspend fun putFlag(deploymentKey: String, flag: EvaluationFlag) - suspend fun putAllFlags(deploymentKey: String, flags: List) - suspend fun removeFlag(deploymentKey: String, flagKey: String) + + suspend fun putFlag( + deploymentKey: String, + flag: EvaluationFlag, + ) + + suspend fun putAllFlags( + deploymentKey: String, + flags: List, + ) + + suspend fun removeFlag( + deploymentKey: String, + flagKey: String, + ) + suspend fun removeAllFlags(deploymentKey: String) } -internal fun getDeploymentStorage(projectId: String, redisConfiguration: RedisConfiguration?): DeploymentStorage { +internal fun getDeploymentStorage( + projectId: String, + redisConfiguration: RedisConfiguration?, +): DeploymentStorage { val uri = redisConfiguration?.uri return if (uri == null) { InMemoryDeploymentStorage() } else { val redis = RedisConnection(uri) - val readOnlyRedis = if (redisConfiguration.readOnlyUri != null) { - RedisConnection(redisConfiguration.readOnlyUri) - } else { - redis - } + val readOnlyRedis = + if (redisConfiguration.readOnlyUri != null) { + RedisConnection(redisConfiguration.readOnlyUri) + } else { + redis + } RedisDeploymentStorage(redisConfiguration.prefix, projectId, redis, readOnlyRedis) } } internal class InMemoryDeploymentStorage : DeploymentStorage { - private val mutex = Mutex() private val deploymentStorage = mutableMapOf() private val flagStorage = mutableMapOf>() + override suspend fun getDeployment(deploymentKey: String): Deployment? { return mutex.withLock { deploymentStorage[deploymentKey] @@ -69,7 +94,10 @@ internal class InMemoryDeploymentStorage : DeploymentStorage { } } - override suspend fun getFlag(deploymentKey: String, flagKey: String): EvaluationFlag? { + override suspend fun getFlag( + deploymentKey: String, + flagKey: String, + ): EvaluationFlag? { return mutex.withLock { flagStorage[deploymentKey]?.get(flagKey) } @@ -81,19 +109,28 @@ internal class InMemoryDeploymentStorage : DeploymentStorage { } } - override suspend fun putFlag(deploymentKey: String, flag: EvaluationFlag) { + override suspend fun putFlag( + deploymentKey: String, + flag: EvaluationFlag, + ) { return mutex.withLock { flagStorage.getOrPut(deploymentKey) { mutableMapOf() }[flag.key] = flag } } - override suspend fun putAllFlags(deploymentKey: String, flags: List) { + override suspend fun putAllFlags( + deploymentKey: String, + flags: List, + ) { return mutex.withLock { flagStorage.getOrPut(deploymentKey) { mutableMapOf() }.putAll(flags.associateBy { it.key }) } } - override suspend fun removeFlag(deploymentKey: String, flagKey: String) { + override suspend fun removeFlag( + deploymentKey: String, + flagKey: String, + ) { return mutex.withLock { flagStorage[deploymentKey]?.remove(flagKey) } @@ -110,7 +147,7 @@ internal class RedisDeploymentStorage( private val prefix: String, private val projectId: String, private val redis: Redis, - private val readOnlyRedis: Redis + private val readOnlyRedis: Redis, ) : DeploymentStorage { override suspend fun getDeployment(deploymentKey: String): Deployment? { val deploymentJson = redis.hget(RedisKey.Deployments(prefix, projectId), deploymentKey) ?: return null @@ -132,7 +169,10 @@ internal class RedisDeploymentStorage( removeAllFlags(deploymentKey) } - override suspend fun getFlag(deploymentKey: String, flagKey: String): EvaluationFlag? { + override suspend fun getFlag( + deploymentKey: String, + flagKey: String, + ): EvaluationFlag? { val flagJson = redis.hget(RedisKey.FlagConfigs(prefix, projectId, deploymentKey), flagKey) ?: return null return json.decodeFromString(flagJson) } @@ -144,18 +184,27 @@ internal class RedisDeploymentStorage( ?.mapValues { json.decodeFromString(it.value) } ?: mapOf() } - override suspend fun putFlag(deploymentKey: String, flag: EvaluationFlag) { + override suspend fun putFlag( + deploymentKey: String, + flag: EvaluationFlag, + ) { val flagJson = json.encodeToString(flag) redis.hset(RedisKey.FlagConfigs(prefix, projectId, deploymentKey), mapOf(flag.key to flagJson)) } - override suspend fun putAllFlags(deploymentKey: String, flags: List) { + override suspend fun putAllFlags( + deploymentKey: String, + flags: List, + ) { for (flag in flags) { putFlag(deploymentKey, flag) } } - override suspend fun removeFlag(deploymentKey: String, flagKey: String) { + override suspend fun removeFlag( + deploymentKey: String, + flagKey: String, + ) { redis.hdel(RedisKey.FlagConfigs(prefix, projectId, deploymentKey), flagKey) } diff --git a/core/src/main/kotlin/project/Project.kt b/core/src/main/kotlin/project/Project.kt index 614cf50..e6f991c 100644 --- a/core/src/main/kotlin/project/Project.kt +++ b/core/src/main/kotlin/project/Project.kt @@ -4,5 +4,5 @@ internal data class Project( val id: String, val apiKey: String, val secretKey: String, - val managementKey: String + val managementKey: String, ) diff --git a/core/src/main/kotlin/project/ProjectApi.kt b/core/src/main/kotlin/project/ProjectApi.kt index b481c3d..887d7e0 100644 --- a/core/src/main/kotlin/project/ProjectApi.kt +++ b/core/src/main/kotlin/project/ProjectApi.kt @@ -20,7 +20,7 @@ private const val MANAGEMENT_SERVER_URL = "https://experiment.amplitude.com" @Serializable internal data class DeploymentsResponse( - val deployments: List + val deployments: List, ) @Serializable @@ -29,7 +29,7 @@ internal data class SerialDeployment( val projectId: String, val label: String, val key: String, - val deleted: Boolean + val deleted: Boolean, ) private fun SerialDeployment.toDeployment(): Deployment? { @@ -44,32 +44,34 @@ internal interface ProjectApi { internal class ProjectApiV1( private val serverUrl: String, private val managementKey: String, - engine: HttpClientEngine = OkHttp.create() + engine: HttpClientEngine = OkHttp.create(), ) : ProjectApi { - companion object { val log by logger() } - private val client = HttpClient(engine) { - install(HttpTimeout) { - socketTimeoutMillis = 30000 + private val client = + HttpClient(engine) { + install(HttpTimeout) { + socketTimeoutMillis = 30000 + } } - } override suspend fun getDeployments(): List = Metrics.with({ DeploymentsFetch }, { e -> DeploymentsFetchFailure(e) }) { log.trace("getDeployments: start") - val response = retry(onFailure = { e -> log.error("Get deployments failed: $e") }) { - client.get( - url = serverUrl, - path = "api/1/deployments") { - headers { - set("Authorization", "Bearer $managementKey") - set("Accept", "application/json") + val response = + retry(onFailure = { e -> log.error("Get deployments failed: $e") }) { + client.get( + url = serverUrl, + path = "api/1/deployments", + ) { + headers { + set("Authorization", "Bearer $managementKey") + set("Accept", "application/json") + } } } - } json.decodeFromString(response.body()) .deployments .mapNotNull { it.toDeployment() } diff --git a/core/src/main/kotlin/project/ProjectProxy.kt b/core/src/main/kotlin/project/ProjectProxy.kt index 8dd4850..4b01858 100644 --- a/core/src/main/kotlin/project/ProjectProxy.kt +++ b/core/src/main/kotlin/project/ProjectProxy.kt @@ -29,9 +29,8 @@ internal class ProjectProxy( configuration: Configuration, private val assignmentTracker: AssignmentTracker, private val deploymentStorage: DeploymentStorage, - private val cohortStorage: CohortStorage + private val cohortStorage: CohortStorage, ) { - companion object { val log by logger() } @@ -43,15 +42,16 @@ internal class ProjectProxy( private val cohortApi = CohortApiV1(configuration.cohortServerUrl, project.apiKey, project.secretKey) private val cohortLoader = CohortLoader(configuration.maxCohortSize, cohortApi, cohortStorage) private val deploymentLoader = DeploymentLoader(deploymentApi, deploymentStorage, cohortLoader) - private val projectRunner = ProjectRunner( - project, - configuration, - projectApi, - deploymentLoader, - deploymentStorage, - cohortLoader, - cohortStorage - ) + private val projectRunner = + ProjectRunner( + project, + configuration, + projectApi, + deploymentLoader, + deploymentStorage, + cohortLoader, + cohortStorage, + ) suspend fun start() { log.info("Starting project. projectId=${project.id}") @@ -71,27 +71,37 @@ internal class ProjectProxy( return EvaluationProxyResponse.error(HttpStatusCode.OK, json.encodeToString(result)) } - suspend fun getCohort(cohortId: String?, lastModified: Long?, maxCohortSize: Int?): EvaluationProxyResponse { + suspend fun getCohort( + cohortId: String?, + lastModified: Long?, + maxCohortSize: Int?, + ): EvaluationProxyResponse { if (cohortId.isNullOrEmpty()) { return EvaluationProxyResponse.error(HttpStatusCode.NotFound, "Cohort not found") } - val cohortDescription = cohortStorage.getCohort(cohortId) - ?: return EvaluationProxyResponse.error(HttpStatusCode.NotFound, "Cohort not found") + val cohortDescription = + cohortStorage.getCohort(cohortId) + ?: return EvaluationProxyResponse.error(HttpStatusCode.NotFound, "Cohort not found") if (cohortDescription.size > (maxCohortSize ?: Int.MAX_VALUE)) { return EvaluationProxyResponse.error( HttpStatusCode.PayloadTooLarge, - "Cohort $cohortId sized ${cohortDescription.size} is greater than max cohort size $maxCohortSize" + "Cohort $cohortId sized ${cohortDescription.size} is greater than max cohort size $maxCohortSize", ) } if (cohortDescription.lastModified == lastModified) { return EvaluationProxyResponse.error(HttpStatusCode.NoContent, "Cohort not modified") } - val cohort = cohortStorage.getCohort(cohortId) - ?: return EvaluationProxyResponse.error(HttpStatusCode.NotFound, "Cohort members not found") + val cohort = + cohortStorage.getCohort(cohortId) + ?: return EvaluationProxyResponse.error(HttpStatusCode.NotFound, "Cohort members not found") return EvaluationProxyResponse.json(HttpStatusCode.OK, GetCohortResponse.fromCohort(cohort)) } - suspend fun getCohortMemberships(deploymentKey: String?, groupType: String?, groupName: String?): EvaluationProxyResponse { + suspend fun getCohortMemberships( + deploymentKey: String?, + groupType: String?, + groupName: String?, + ): EvaluationProxyResponse { if (deploymentKey.isNullOrEmpty()) { return EvaluationProxyResponse.error(HttpStatusCode.Unauthorized, "Invalid deployment") } @@ -112,7 +122,7 @@ internal class ProjectProxy( suspend fun evaluate( deploymentKey: String?, user: Map?, - flagKeys: Set? = null + flagKeys: Set? = null, ): EvaluationProxyResponse { if (deploymentKey.isNullOrEmpty()) { return EvaluationProxyResponse.error(HttpStatusCode.Unauthorized, "Invalid deployment") @@ -124,23 +134,24 @@ internal class ProjectProxy( suspend fun evaluateV1( deploymentKey: String?, user: Map?, - flagKeys: Set? = null + flagKeys: Set? = null, ): EvaluationProxyResponse { if (deploymentKey.isNullOrEmpty()) { return EvaluationProxyResponse(HttpStatusCode.Unauthorized, "Invalid deployment") } - val result = evaluateInternal(deploymentKey, user, flagKeys).filter { entry -> - val default = entry.value.metadata?.get("default") as? Boolean ?: false - val deployed = entry.value.metadata?.get("deployed") as? Boolean ?: true - (!default && deployed) - } + val result = + evaluateInternal(deploymentKey, user, flagKeys).filter { entry -> + val default = entry.value.metadata?.get("default") as? Boolean ?: false + val deployed = entry.value.metadata?.get("deployed") as? Boolean ?: true + (!default && deployed) + } return EvaluationProxyResponse(HttpStatusCode.OK, json.encodeToString(result)) } private suspend fun evaluateInternal( deploymentKey: String, user: Map?, - flagKeys: Set? = null + flagKeys: Set? = null, ): Map { // Get flag configs for the deployment from storage and topo sort. val storageFlags = deploymentStorage.getAllFlags(deploymentKey) diff --git a/core/src/main/kotlin/project/ProjectRunner.kt b/core/src/main/kotlin/project/ProjectRunner.kt index a8ebb8b..9bd253f 100644 --- a/core/src/main/kotlin/project/ProjectRunner.kt +++ b/core/src/main/kotlin/project/ProjectRunner.kt @@ -28,9 +28,8 @@ internal class ProjectRunner( private val deploymentLoader: DeploymentLoader, private val deploymentStorage: DeploymentStorage, private val cohortLoader: CohortLoader, - private val cohortStorage: CohortStorage + private val cohortStorage: CohortStorage, ) { - companion object { val log by logger() } @@ -39,17 +38,19 @@ internal class ProjectRunner( private val scope = CoroutineScope(supervisor) private val lock = Mutex() + @VisibleForTesting internal val deploymentRunners = mutableMapOf() suspend fun start() { - val job = scope.launch { - try { - refresh() - } catch (t: Throwable) { - log.error("Refresh failed for project ${project.id}", t) + val job = + scope.launch { + try { + refresh() + } catch (t: Throwable) { + log.error("Refresh failed for project ${project.id}", t) + } } - } // Periodic deployment update and refresher scope.launch { while (true) { @@ -73,42 +74,43 @@ internal class ProjectRunner( supervisor.cancelAndJoin() } - private suspend fun refresh() = coroutineScope { - lock.withLock { - log.trace("refresh: start") - // Get deployments from API and update the storage. - val networkDeployments = projectApi.getDeployments().associateBy { it.key } - val storageDeployments = deploymentStorage.getDeployments() - // Determine added and removed deployments - val addedDeployments = networkDeployments - storageDeployments.keys - val removedDeployments = storageDeployments - networkDeployments.keys - val startingDeployments = networkDeployments - deploymentRunners.keys - val jobs = mutableListOf() - for ((_, addedDeployment) in addedDeployments) { - log.info("Adding deployment $addedDeployment") - deploymentStorage.putDeployment(addedDeployment) - } - for ((_, deployment) in startingDeployments) { - jobs += scope.launch { addDeploymentInternal(deployment.key) } - } - for ((_, removedDeployment) in removedDeployments) { - log.info("Removing deployment $removedDeployment") - deploymentStorage.removeAllFlags(removedDeployment.key) - deploymentStorage.removeDeployment(removedDeployment.key) - jobs += scope.launch { removeDeploymentInternal(removedDeployment.key) } + private suspend fun refresh() = + coroutineScope { + lock.withLock { + log.trace("refresh: start") + // Get deployments from API and update the storage. + val networkDeployments = projectApi.getDeployments().associateBy { it.key } + val storageDeployments = deploymentStorage.getDeployments() + // Determine added and removed deployments + val addedDeployments = networkDeployments - storageDeployments.keys + val removedDeployments = storageDeployments - networkDeployments.keys + val startingDeployments = networkDeployments - deploymentRunners.keys + val jobs = mutableListOf() + for ((_, addedDeployment) in addedDeployments) { + log.info("Adding deployment $addedDeployment") + deploymentStorage.putDeployment(addedDeployment) + } + for ((_, deployment) in startingDeployments) { + jobs += scope.launch { addDeploymentInternal(deployment.key) } + } + for ((_, removedDeployment) in removedDeployments) { + log.info("Removing deployment $removedDeployment") + deploymentStorage.removeAllFlags(removedDeployment.key) + deploymentStorage.removeDeployment(removedDeployment.key) + jobs += scope.launch { removeDeploymentInternal(removedDeployment.key) } + } + // Keep cohorts which are targeted by all stored deployments. + removeUnusedCohorts(networkDeployments.keys) + jobs.joinAll() + log.debug( + "Project refresh finished: addedDeployments={}, removedDeployments={}, startedDeployments={}", + addedDeployments.keys, + removedDeployments.keys, + startingDeployments.keys, + ) + log.trace("refresh: end") } - // Keep cohorts which are targeted by all stored deployments. - removeUnusedCohorts(networkDeployments.keys) - jobs.joinAll() - log.debug( - "Project refresh finished: addedDeployments={}, removedDeployments={}, startedDeployments={}", - addedDeployments.keys, - removedDeployments.keys, - startingDeployments.keys - ) - log.trace("refresh: end") } - } // Must be run within lock private suspend fun addDeploymentInternal(deploymentKey: String) { @@ -116,13 +118,14 @@ internal class ProjectRunner( return } log.debug("Adding and starting deployment runner for $deploymentKey") - val deploymentRunner = DeploymentRunner( - configuration, - deploymentKey, - cohortLoader, - deploymentStorage, - deploymentLoader - ) + val deploymentRunner = + DeploymentRunner( + configuration, + deploymentKey, + cohortLoader, + deploymentStorage, + deploymentLoader, + ) deploymentRunner.start() deploymentRunners[deploymentKey] = deploymentRunner } diff --git a/core/src/main/kotlin/project/ProjectStorage.kt b/core/src/main/kotlin/project/ProjectStorage.kt index 0d85342..484c271 100644 --- a/core/src/main/kotlin/project/ProjectStorage.kt +++ b/core/src/main/kotlin/project/ProjectStorage.kt @@ -4,15 +4,14 @@ import com.amplitude.RedisConfiguration import com.amplitude.util.Redis import com.amplitude.util.RedisConnection import com.amplitude.util.RedisKey -import kotlinx.coroutines.channels.BufferOverflow -import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock internal interface ProjectStorage { suspend fun getProjects(): Set + suspend fun putProject(projectId: String) + suspend fun removeProject(projectId: String) } @@ -26,28 +25,29 @@ internal fun getProjectStorage(redisConfiguration: RedisConfiguration?): Project } internal class InMemoryProjectStorage : ProjectStorage { - private val mutex = Mutex() private val projectStorage = mutableSetOf() - override suspend fun getProjects(): Set = mutex.withLock { - projectStorage.toSet() - } + override suspend fun getProjects(): Set = + mutex.withLock { + projectStorage.toSet() + } - override suspend fun putProject(projectId: String): Unit = mutex.withLock { - projectStorage.add(projectId) - } + override suspend fun putProject(projectId: String): Unit = + mutex.withLock { + projectStorage.add(projectId) + } - override suspend fun removeProject(projectId: String): Unit = mutex.withLock { - projectStorage.remove(projectId) - } + override suspend fun removeProject(projectId: String): Unit = + mutex.withLock { + projectStorage.remove(projectId) + } } internal class RedisProjectStorage( private val prefix: String, - private val redis: Redis + private val redis: Redis, ) : ProjectStorage { - override suspend fun getProjects(): Set { return redis.smembers(RedisKey.Projects(prefix)) ?: emptySet() } diff --git a/core/src/main/kotlin/util/Cache.kt b/core/src/main/kotlin/util/Cache.kt index ded1bca..1d2ae7f 100644 --- a/core/src/main/kotlin/util/Cache.kt +++ b/core/src/main/kotlin/util/Cache.kt @@ -9,15 +9,14 @@ import java.util.HashMap */ internal class Cache( private val capacity: Int, - private val ttlMillis: Long = 0 + private val ttlMillis: Long = 0, ) { - private class Node( var key: K? = null, var value: V? = null, var prev: Node? = null, var next: Node? = null, - var ts: Long = System.currentTimeMillis() + var ts: Long = System.currentTimeMillis(), ) private var count = 0 @@ -32,17 +31,21 @@ internal class Cache( tail.prev = head } - suspend fun get(key: K): V? = mutex.withLock { - val n = map[key] ?: return null - if (timeout && n.ts + ttlMillis < System.currentTimeMillis()) { - removeNodeForKey(key) - return null + suspend fun get(key: K): V? = + mutex.withLock { + val n = map[key] ?: return null + if (timeout && n.ts + ttlMillis < System.currentTimeMillis()) { + removeNodeForKey(key) + return null + } + updateInternal(n) + return n.value } - updateInternal(n) - return n.value - } - suspend fun set(key: K, value: V) = mutex.withLock { + suspend fun set( + key: K, + value: V, + ) = mutex.withLock { var n = map[key] if (n == null) { n = Node(key, value) @@ -62,9 +65,10 @@ internal class Cache( } } - suspend fun remove(key: K): Unit = mutex.withLock { - removeNodeForKey(key) - } + suspend fun remove(key: K): Unit = + mutex.withLock { + removeNodeForKey(key) + } private fun removeNodeForKey(key: K) { val n = map[key] ?: return diff --git a/core/src/main/kotlin/util/Env.kt b/core/src/main/kotlin/util/Env.kt index 7522bfa..9c08da8 100644 --- a/core/src/main/kotlin/util/Env.kt +++ b/core/src/main/kotlin/util/Env.kt @@ -1,26 +1,44 @@ package com.amplitude.util -fun stringEnv(variable: String, default: String? = null): String? { +fun stringEnv( + variable: String, + default: String? = null, +): String? { return System.getenv(variable) ?: default } -fun booleanEnv(variable: String, default: Boolean = false): Boolean { +fun booleanEnv( + variable: String, + default: Boolean = false, +): Boolean { val stringEnv = stringEnv(variable) ?: return default return try { stringEnv.toBoolean() - } catch (_: NumberFormatException) { default } + } catch (_: NumberFormatException) { + default + } } -fun intEnv(variable: String, default: Int? = null): Int? { +fun intEnv( + variable: String, + default: Int? = null, +): Int? { val stringEnv = stringEnv(variable) ?: return default return try { stringEnv.toInt() - } catch (_: NumberFormatException) { default } + } catch (_: NumberFormatException) { + default + } } -fun longEnv(variable: String, default: Long? = null): Long? { +fun longEnv( + variable: String, + default: Long? = null, +): Long? { val stringEnv = stringEnv(variable) ?: return default return try { stringEnv.toLong() - } catch (_: NumberFormatException) { default } + } catch (_: NumberFormatException) { + default + } } diff --git a/core/src/main/kotlin/util/EvaluationContext.kt b/core/src/main/kotlin/util/EvaluationContext.kt index e13e88f..10f2f6d 100644 --- a/core/src/main/kotlin/util/EvaluationContext.kt +++ b/core/src/main/kotlin/util/EvaluationContext.kt @@ -5,6 +5,7 @@ import com.amplitude.experiment.evaluation.EvaluationContext internal fun EvaluationContext.userId(): String? { return (this["user"] as? Map<*, *>)?.get("user_id")?.toString() } + internal fun EvaluationContext.deviceId(): String? { return (this["user"] as? Map<*, *>)?.get("device_id")?.toString() } diff --git a/core/src/main/kotlin/util/EvaluationFlag.kt b/core/src/main/kotlin/util/EvaluationFlag.kt index 1faf301..1932e71 100644 --- a/core/src/main/kotlin/util/EvaluationFlag.kt +++ b/core/src/main/kotlin/util/EvaluationFlag.kt @@ -46,13 +46,14 @@ private fun EvaluationSegment.getGroupedCohortConditionIds(): Map 2) { val contextSubtype = condition.selector[1] - val groupType = if (contextSubtype == "user") { - USER_GROUP_TYPE - } else if (condition.selector.contains("groups")) { - condition.selector[2] - } else { - continue - } + val groupType = + if (contextSubtype == "user") { + USER_GROUP_TYPE + } else if (condition.selector.contains("groups")) { + condition.selector[2] + } else { + continue + } cohortIds.getOrPut(groupType) { mutableSetOf() } += condition.values } } diff --git a/core/src/main/kotlin/util/Http.kt b/core/src/main/kotlin/util/Http.kt index 10afe95..4e4c25b 100644 --- a/core/src/main/kotlin/util/Http.kt +++ b/core/src/main/kotlin/util/Http.kt @@ -12,21 +12,21 @@ import kotlinx.coroutines.delay internal class HttpErrorException( val statusCode: HttpStatusCode, - response: HttpResponse? = null + response: HttpResponse? = null, ) : Exception("HTTP error response: code=$statusCode, message=${statusCode.description}, response=$response") internal data class RetryConfig( val times: Int = 8, val initialDelayMillis: Long = 100, val maxDelay: Long = 10000, - val factor: Double = 2.0 + val factor: Double = 2.0, ) internal suspend fun retry( config: RetryConfig = RetryConfig(), onFailure: (Exception) -> Unit = {}, acceptCodes: Set = emptySet(), - block: suspend () -> HttpResponse + block: suspend () -> HttpResponse, ): HttpResponse { var currentDelay = config.initialDelayMillis var error: Exception? = null @@ -57,7 +57,7 @@ internal suspend fun retry( internal suspend fun HttpClient.get( url: String, path: String, - block: HttpRequestBuilder.() -> Unit + block: HttpRequestBuilder.() -> Unit, ): HttpResponse { return request(HttpMethod.Get, url, path, block) } @@ -66,7 +66,7 @@ internal suspend fun HttpClient.request( method: HttpMethod, url: String, path: String, - block: HttpRequestBuilder.() -> Unit + block: HttpRequestBuilder.() -> Unit, ): HttpResponse { return request { this.method = method diff --git a/core/src/main/kotlin/util/Json.kt b/core/src/main/kotlin/util/Json.kt index 9f3e0e6..3c6a7d2 100644 --- a/core/src/main/kotlin/util/Json.kt +++ b/core/src/main/kotlin/util/Json.kt @@ -16,30 +16,33 @@ import kotlinx.serialization.json.doubleOrNull import kotlinx.serialization.json.intOrNull import kotlinx.serialization.json.longOrNull -val json = Json { - ignoreUnknownKeys = true - isLenient = true - coerceInputValues = true - explicitNulls = false -} +val json = + Json { + ignoreUnknownKeys = true + isLenient = true + coerceInputValues = true + explicitNulls = false + } -internal fun Any?.toJsonElement(): JsonElement = when (this) { - null -> JsonNull - is Map<*, *> -> toJsonObject() - is Collection<*> -> toJsonArray() - is Boolean -> JsonPrimitive(this) - is Number -> JsonPrimitive(this) - is String -> JsonPrimitive(this) - else -> JsonPrimitive(toString()) -} +internal fun Any?.toJsonElement(): JsonElement = + when (this) { + null -> JsonNull + is Map<*, *> -> toJsonObject() + is Collection<*> -> toJsonArray() + is Boolean -> JsonPrimitive(this) + is Number -> JsonPrimitive(this) + is String -> JsonPrimitive(this) + else -> JsonPrimitive(toString()) + } internal fun Collection<*>.toJsonArray(): JsonArray = JsonArray(map { it.toJsonElement() }) -internal fun Map<*, *>.toJsonObject(): JsonObject = JsonObject( - mapNotNull { - (it.key as? String ?: return@mapNotNull null) to it.value.toJsonElement() - }.toMap() -) +internal fun Map<*, *>.toJsonObject(): JsonObject = + JsonObject( + mapNotNull { + (it.key as? String ?: return@mapNotNull null) to it.value.toJsonElement() + }.toMap(), + ) internal fun JsonElement.toAny(): Any? { return when (this) { @@ -66,7 +69,10 @@ internal object AnySerializer : KSerializer { override val descriptor: SerialDescriptor get() = SerialDescriptor("Any", delegate.descriptor) - override fun serialize(encoder: Encoder, value: Any?) { + override fun serialize( + encoder: Encoder, + value: Any?, + ) { val jsonElement = value.toJsonElement() encoder.encodeSerializableValue(delegate, jsonElement) } diff --git a/core/src/main/kotlin/util/Loader.kt b/core/src/main/kotlin/util/Loader.kt index 812d40c..d0dc88b 100644 --- a/core/src/main/kotlin/util/Loader.kt +++ b/core/src/main/kotlin/util/Loader.kt @@ -8,11 +8,13 @@ import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock class Loader { - private val jobsMutex = Mutex() private val jobs = mutableMapOf() - suspend fun load(key: String, loader: suspend CoroutineScope.() -> Unit) = coroutineScope { + suspend fun load( + key: String, + loader: suspend CoroutineScope.() -> Unit, + ) = coroutineScope { jobsMutex.withLock { jobs.getOrPut(key) { launch { diff --git a/core/src/main/kotlin/util/Redis.kt b/core/src/main/kotlin/util/Redis.kt index e4ad777..0ba1ec8 100644 --- a/core/src/main/kotlin/util/Redis.kt +++ b/core/src/main/kotlin/util/Redis.kt @@ -16,23 +16,22 @@ import kotlin.time.Duration private const val STORAGE_PROTOCOL_VERSION = "v3" internal sealed class RedisKey(val value: String) { - data class Projects(val prefix: String) : RedisKey("$prefix:$STORAGE_PROTOCOL_VERSION:projects") data class Deployments( val prefix: String, - val projectId: String + val projectId: String, ) : RedisKey("$prefix:$STORAGE_PROTOCOL_VERSION:projects:$projectId:deployments") data class FlagConfigs( val prefix: String, val projectId: String, - val deploymentKey: String + val deploymentKey: String, ) : RedisKey("$prefix:$STORAGE_PROTOCOL_VERSION:projects:$projectId:deployments:$deploymentKey:flags") data class CohortDescriptions( val prefix: String, - val projectId: String + val projectId: String, ) : RedisKey("$prefix:$STORAGE_PROTOCOL_VERSION:projects:$projectId:cohorts") data class CohortMembers( @@ -41,28 +40,62 @@ internal sealed class RedisKey(val value: String) { val cohortId: String, val cohortGroupType: String, val cohortLastModified: Long, - ) : RedisKey("$prefix:$STORAGE_PROTOCOL_VERSION:projects:$projectId:cohorts:${cohortId}:${cohortGroupType}:${cohortLastModified}") + ) : RedisKey("$prefix:$STORAGE_PROTOCOL_VERSION:projects:$projectId:cohorts:$cohortId:$cohortGroupType:$cohortLastModified") } internal interface Redis { suspend fun get(key: RedisKey): String? - suspend fun set(key: RedisKey, value: String) + + suspend fun set( + key: RedisKey, + value: String, + ) + suspend fun del(key: RedisKey) - suspend fun sadd(key: RedisKey, values: Set) - suspend fun srem(key: RedisKey, value: String) + + suspend fun sadd( + key: RedisKey, + values: Set, + ) + + suspend fun srem( + key: RedisKey, + value: String, + ) + suspend fun smembers(key: RedisKey): Set? - suspend fun sismember(key: RedisKey, value: String): Boolean - suspend fun hget(key: RedisKey, field: String): String? + + suspend fun sismember( + key: RedisKey, + value: String, + ): Boolean + + suspend fun hget( + key: RedisKey, + field: String, + ): String? + suspend fun hgetall(key: RedisKey): Map? - suspend fun hset(key: RedisKey, values: Map) - suspend fun hdel(key: RedisKey, field: String) - suspend fun expire(key: RedisKey, ttl: Duration) + + suspend fun hset( + key: RedisKey, + values: Map, + ) + + suspend fun hdel( + key: RedisKey, + field: String, + ) + + suspend fun expire( + key: RedisKey, + ttl: Duration, + ) } internal class RedisConnection( - redisUri: String + redisUri: String, ) : Redis { - private val connection: Deferred> private val client: RedisClient = RedisClient.create(redisUri) @@ -76,7 +109,10 @@ internal class RedisConnection( } } - override suspend fun set(key: RedisKey, value: String) { + override suspend fun set( + key: RedisKey, + value: String, + ) { connection.run { set(key.value, value) } @@ -88,13 +124,19 @@ internal class RedisConnection( } } - override suspend fun sadd(key: RedisKey, values: Set) { + override suspend fun sadd( + key: RedisKey, + values: Set, + ) { connection.run { sadd(key.value, *values.toTypedArray()) } } - override suspend fun srem(key: RedisKey, value: String) { + override suspend fun srem( + key: RedisKey, + value: String, + ) { connection.run { srem(key.value, value) } @@ -106,13 +148,19 @@ internal class RedisConnection( } } - override suspend fun sismember(key: RedisKey, value: String): Boolean { + override suspend fun sismember( + key: RedisKey, + value: String, + ): Boolean { return connection.run { sismember(key.value, value) } } - override suspend fun hget(key: RedisKey, field: String): String? { + override suspend fun hget( + key: RedisKey, + field: String, + ): String? { return connection.run { hget(key.value, field) } @@ -124,26 +172,35 @@ internal class RedisConnection( } } - override suspend fun hset(key: RedisKey, values: Map) { + override suspend fun hset( + key: RedisKey, + values: Map, + ) { connection.run { hset(key.value, values) } } - override suspend fun hdel(key: RedisKey, field: String) { + override suspend fun hdel( + key: RedisKey, + field: String, + ) { connection.run { hdel(key.value, field) } } - override suspend fun expire(key: RedisKey, ttl: Duration) { + override suspend fun expire( + key: RedisKey, + ttl: Duration, + ) { connection.run { expire(key.value, ttl.inWholeSeconds) } } private suspend inline fun Deferred>.run( - crossinline action: RedisAsyncCommands.() -> RedisFuture + crossinline action: RedisAsyncCommands.() -> RedisFuture, ): R { return Metrics.with({ RedisCommand }, { e -> RedisCommandFailure(e) }) { await().async().action().asDeferred().await() diff --git a/core/src/test/kotlin/EvaluationProxyTest.kt b/core/src/test/kotlin/EvaluationProxyTest.kt index 8a7959e..4f292cb 100644 --- a/core/src/test/kotlin/EvaluationProxyTest.kt +++ b/core/src/test/kotlin/EvaluationProxyTest.kt @@ -20,95 +20,113 @@ import test.project import kotlin.test.Test class EvaluationProxyTest { - @Test - fun `test start, no projects, nothing happens`(): Unit = runBlocking { - val projectStorage = spyk(InMemoryProjectStorage()) - val projectConfigurations = listOf() - val evaluationProxy = spyk(EvaluationProxy( - projectConfigurations = projectConfigurations, - configuration = Configuration(), - projectStorage = projectStorage - )) - evaluationProxy.start() - verify(exactly = 0) { evaluationProxy.createProjectApi(allAny()) } - assertEquals(0, evaluationProxy.projectProxies.size) - coVerify(exactly = 0) { projectStorage.putProject(allAny()) } - } + fun `test start, no projects, nothing happens`(): Unit = + runBlocking { + val projectStorage = spyk(InMemoryProjectStorage()) + val projectConfigurations = listOf() + val evaluationProxy = + spyk( + EvaluationProxy( + projectConfigurations = projectConfigurations, + configuration = Configuration(), + projectStorage = projectStorage, + ), + ) + evaluationProxy.start() + verify(exactly = 0) { evaluationProxy.createProjectApi(allAny()) } + assertEquals(0, evaluationProxy.projectProxies.size) + coVerify(exactly = 0) { projectStorage.putProject(allAny()) } + } @Test - fun `test start, with project, storage loaded and proxy started`(): Unit = runBlocking { - val projectStorage = spyk(InMemoryProjectStorage()) - val project = project("1") - val projectConfigurations = listOf( - ProjectConfiguration("api", "secret", "management") - ) - val evaluationProxy = spyk(EvaluationProxy( - projectConfigurations = projectConfigurations, - configuration = Configuration(), - projectStorage = projectStorage - )) - coEvery { evaluationProxy.createProjectApi(allAny()) } returns - mockk().apply { - coEvery { getDeployments() } returns listOf(deployment("a", project.id)) - } - coEvery { evaluationProxy.createProjectProxy(allAny()) } returns - mockk().apply { - coEvery { start() } returns Unit - } - evaluationProxy.start() - verify(exactly = 1) { evaluationProxy.createProjectApi(allAny()) } - assertEquals(1, evaluationProxy.projectProxies.size) - coVerify(exactly = 1) { projectStorage.putProject(allAny()) } - val projectProxy = evaluationProxy.projectProxies[project] - coVerify(exactly = 1) { projectProxy?.start() } - } + fun `test start, with project, storage loaded and proxy started`(): Unit = + runBlocking { + val projectStorage = spyk(InMemoryProjectStorage()) + val project = project("1") + val projectConfigurations = + listOf( + ProjectConfiguration("api", "secret", "management"), + ) + val evaluationProxy = + spyk( + EvaluationProxy( + projectConfigurations = projectConfigurations, + configuration = Configuration(), + projectStorage = projectStorage, + ), + ) + coEvery { evaluationProxy.createProjectApi(allAny()) } returns + mockk().apply { + coEvery { getDeployments() } returns listOf(deployment("a", project.id)) + } + coEvery { evaluationProxy.createProjectProxy(allAny()) } returns + mockk().apply { + coEvery { start() } returns Unit + } + evaluationProxy.start() + verify(exactly = 1) { evaluationProxy.createProjectApi(allAny()) } + assertEquals(1, evaluationProxy.projectProxies.size) + coVerify(exactly = 1) { projectStorage.putProject(allAny()) } + val projectProxy = evaluationProxy.projectProxies[project] + coVerify(exactly = 1) { projectProxy?.start() } + } @Test - fun `test start, stored but no longer configured project deleted`(): Unit = runBlocking { - val projectStorage = spyk(InMemoryProjectStorage().apply { - putProject("2") - }) - val project = project("1") - val projectConfigurations = listOf( - ProjectConfiguration("api", "secret", "management") - ) - val evaluationProxy = spyk(EvaluationProxy( - projectConfigurations = projectConfigurations, - configuration = Configuration(), - projectStorage = projectStorage - )) - coEvery { evaluationProxy.createProjectApi(allAny()) } returns - mockk().apply { - coEvery { getDeployments() } returns listOf(deployment("a", project.id)) - } - coEvery { evaluationProxy.createProjectProxy(allAny()) } returns - mockk().apply { - coEvery { start() } returns Unit - } - val deploymentStorage = mockk().apply { - coEvery { getDeployments() } returns mapOf("b" to deployment("b")) - coEvery { removeDeployment(eq("b")) } returns Unit - coEvery { removeAllFlags(eq("b")) } returns Unit - } - val cohortStorage = mockk().apply { - coEvery { getCohortDescriptions() } returns mapOf("c" to cohort("c").toCohortDescription()) - coEvery { deleteCohort(eq(cohort("c").toCohortDescription()))} returns Unit + fun `test start, stored but no longer configured project deleted`(): Unit = + runBlocking { + val projectStorage = + spyk( + InMemoryProjectStorage().apply { + putProject("2") + }, + ) + val project = project("1") + val projectConfigurations = + listOf( + ProjectConfiguration("api", "secret", "management"), + ) + val evaluationProxy = + spyk( + EvaluationProxy( + projectConfigurations = projectConfigurations, + configuration = Configuration(), + projectStorage = projectStorage, + ), + ) + coEvery { evaluationProxy.createProjectApi(allAny()) } returns + mockk().apply { + coEvery { getDeployments() } returns listOf(deployment("a", project.id)) + } + coEvery { evaluationProxy.createProjectProxy(allAny()) } returns + mockk().apply { + coEvery { start() } returns Unit + } + val deploymentStorage = + mockk().apply { + coEvery { getDeployments() } returns mapOf("b" to deployment("b")) + coEvery { removeDeployment(eq("b")) } returns Unit + coEvery { removeAllFlags(eq("b")) } returns Unit + } + val cohortStorage = + mockk().apply { + coEvery { getCohortDescriptions() } returns mapOf("c" to cohort("c").toCohortDescription()) + coEvery { deleteCohort(eq(cohort("c").toCohortDescription())) } returns Unit + } + coEvery { evaluationProxy.createDeploymentStorage(allAny()) } returns deploymentStorage + coEvery { evaluationProxy.createCohortStorage(allAny()) } returns cohortStorage + evaluationProxy.start() + verify(exactly = 1) { evaluationProxy.createProjectApi(allAny()) } + assertEquals(1, evaluationProxy.projectProxies.size) + coVerify(exactly = 1) { projectStorage.putProject(allAny()) } + val projectProxy = evaluationProxy.projectProxies[project] + coVerify(exactly = 1) { projectProxy?.start() } + // Verify project "2" removed + coVerify(exactly = 1) { deploymentStorage.getDeployments() } + coVerify(exactly = 1) { deploymentStorage.removeDeployment(eq("b")) } + coVerify(exactly = 1) { deploymentStorage.removeAllFlags(eq("b")) } + coVerify(exactly = 1) { cohortStorage.getCohortDescriptions() } + coVerify(exactly = 1) { cohortStorage.deleteCohort(eq(cohort("c").toCohortDescription())) } + coVerify(exactly = 1) { projectStorage.removeProject(eq("2")) } } - coEvery { evaluationProxy.createDeploymentStorage(allAny()) } returns deploymentStorage - coEvery { evaluationProxy.createCohortStorage(allAny()) } returns cohortStorage - evaluationProxy.start() - verify(exactly = 1) { evaluationProxy.createProjectApi(allAny()) } - assertEquals(1, evaluationProxy.projectProxies.size) - coVerify(exactly = 1) { projectStorage.putProject(allAny()) } - val projectProxy = evaluationProxy.projectProxies[project] - coVerify(exactly = 1) { projectProxy?.start() } - // Verify project "2" removed - coVerify(exactly = 1) { deploymentStorage.getDeployments() } - coVerify(exactly = 1) { deploymentStorage.removeDeployment(eq("b")) } - coVerify(exactly = 1) { deploymentStorage.removeAllFlags(eq("b")) } - coVerify(exactly = 1) { cohortStorage.getCohortDescriptions() } - coVerify(exactly = 1) { cohortStorage.deleteCohort(eq(cohort("c").toCohortDescription()))} - coVerify(exactly = 1) { projectStorage.removeProject(eq("2")) } - } } diff --git a/core/src/test/kotlin/assignment/AssignmentFilterTest.kt b/core/src/test/kotlin/assignment/AssignmentFilterTest.kt index 9c2fa6e..e3310c2 100644 --- a/core/src/test/kotlin/assignment/AssignmentFilterTest.kt +++ b/core/src/test/kotlin/assignment/AssignmentFilterTest.kt @@ -8,151 +8,172 @@ import org.junit.Test import test.user class AssignmentFilterTest { - @Test - fun `test single assignment`() = runBlocking { - val filter = InMemoryAssignmentFilter(100) - val assignment = Assignment( - user(userId = "user").toEvaluationContext(), - mapOf( - "flag-key-1" to EvaluationVariant(key = "on"), - "flag-key-2" to EvaluationVariant(key = "control") - ) - ) - Assert.assertTrue(filter.shouldTrack(assignment)) - } + fun `test single assignment`() = + runBlocking { + val filter = InMemoryAssignmentFilter(100) + val assignment = + Assignment( + user(userId = "user").toEvaluationContext(), + mapOf( + "flag-key-1" to EvaluationVariant(key = "on"), + "flag-key-2" to EvaluationVariant(key = "control"), + ), + ) + Assert.assertTrue(filter.shouldTrack(assignment)) + } @Test - fun `test duplicate assignments`() = runBlocking { - val filter = InMemoryAssignmentFilter(100) - val assignment1 = Assignment( - user(userId = "user").toEvaluationContext(), - mapOf( - "flag-key-1" to EvaluationVariant(key = "on"), - "flag-key-2" to EvaluationVariant(key = "control") - ) - ) - filter.shouldTrack(assignment1) - val assignment2 = Assignment( - user(userId = "user").toEvaluationContext(), - mapOf( - "flag-key-1" to EvaluationVariant(key = "on"), - "flag-key-2" to EvaluationVariant(key = "control") - ) - ) - Assert.assertFalse(filter.shouldTrack(assignment2)) - } + fun `test duplicate assignments`() = + runBlocking { + val filter = InMemoryAssignmentFilter(100) + val assignment1 = + Assignment( + user(userId = "user").toEvaluationContext(), + mapOf( + "flag-key-1" to EvaluationVariant(key = "on"), + "flag-key-2" to EvaluationVariant(key = "control"), + ), + ) + filter.shouldTrack(assignment1) + val assignment2 = + Assignment( + user(userId = "user").toEvaluationContext(), + mapOf( + "flag-key-1" to EvaluationVariant(key = "on"), + "flag-key-2" to EvaluationVariant(key = "control"), + ), + ) + Assert.assertFalse(filter.shouldTrack(assignment2)) + } @Test - fun `test same user different results`() = runBlocking { - val filter = InMemoryAssignmentFilter(100) - val assignment1 = Assignment( - user(userId = "user").toEvaluationContext(), - mapOf( - "flag-key-1" to EvaluationVariant(key = "on"), - "flag-key-2" to EvaluationVariant(key = "control") - ) - ) - Assert.assertTrue(filter.shouldTrack(assignment1)) - val assignment2 = Assignment( - user(userId = "user").toEvaluationContext(), - mapOf( - "flag-key-1" to EvaluationVariant(key = "control"), - "flag-key-2" to EvaluationVariant(key = "on") - ) - ) - Assert.assertTrue(filter.shouldTrack(assignment2)) - } + fun `test same user different results`() = + runBlocking { + val filter = InMemoryAssignmentFilter(100) + val assignment1 = + Assignment( + user(userId = "user").toEvaluationContext(), + mapOf( + "flag-key-1" to EvaluationVariant(key = "on"), + "flag-key-2" to EvaluationVariant(key = "control"), + ), + ) + Assert.assertTrue(filter.shouldTrack(assignment1)) + val assignment2 = + Assignment( + user(userId = "user").toEvaluationContext(), + mapOf( + "flag-key-1" to EvaluationVariant(key = "control"), + "flag-key-2" to EvaluationVariant(key = "on"), + ), + ) + Assert.assertTrue(filter.shouldTrack(assignment2)) + } @Test - fun `test same results for different users`() = runBlocking { - val filter = InMemoryAssignmentFilter(100) - val assignment1 = Assignment( - user(userId = "user").toEvaluationContext(), - mapOf( - "flag-key-1" to EvaluationVariant(key = "on"), - "flag-key-2" to EvaluationVariant(key = "control") - ) - ) - Assert.assertTrue(filter.shouldTrack(assignment1)) - val assignment2 = Assignment( - user(userId = "different user").toEvaluationContext(), - mapOf( - "flag-key-1" to EvaluationVariant(key = "on"), - "flag-key-2" to EvaluationVariant(key = "control") - ) - ) - Assert.assertTrue(filter.shouldTrack(assignment2)) - } + fun `test same results for different users`() = + runBlocking { + val filter = InMemoryAssignmentFilter(100) + val assignment1 = + Assignment( + user(userId = "user").toEvaluationContext(), + mapOf( + "flag-key-1" to EvaluationVariant(key = "on"), + "flag-key-2" to EvaluationVariant(key = "control"), + ), + ) + Assert.assertTrue(filter.shouldTrack(assignment1)) + val assignment2 = + Assignment( + user(userId = "different user").toEvaluationContext(), + mapOf( + "flag-key-1" to EvaluationVariant(key = "on"), + "flag-key-2" to EvaluationVariant(key = "control"), + ), + ) + Assert.assertTrue(filter.shouldTrack(assignment2)) + } @Test - fun `test empty results`() = runBlocking { - val filter = InMemoryAssignmentFilter(100) - val assignment1 = Assignment( - user(userId = "user").toEvaluationContext(), - mapOf() - ) - Assert.assertTrue(filter.shouldTrack(assignment1)) - val assignment2 = Assignment( - user(userId = "user").toEvaluationContext(), - mapOf() - ) - Assert.assertFalse(filter.shouldTrack(assignment2)) - val assignment3 = Assignment( - user(userId = "different user").toEvaluationContext(), - mapOf() - ) - Assert.assertTrue(filter.shouldTrack(assignment3)) - } + fun `test empty results`() = + runBlocking { + val filter = InMemoryAssignmentFilter(100) + val assignment1 = + Assignment( + user(userId = "user").toEvaluationContext(), + mapOf(), + ) + Assert.assertTrue(filter.shouldTrack(assignment1)) + val assignment2 = + Assignment( + user(userId = "user").toEvaluationContext(), + mapOf(), + ) + Assert.assertFalse(filter.shouldTrack(assignment2)) + val assignment3 = + Assignment( + user(userId = "different user").toEvaluationContext(), + mapOf(), + ) + Assert.assertTrue(filter.shouldTrack(assignment3)) + } @Test - fun `test duplicate assignments with different result ordering`() = runBlocking { - val filter = InMemoryAssignmentFilter(100) - val assignment1 = Assignment( - user(userId = "user").toEvaluationContext(), - linkedMapOf( - "flag-key-1" to EvaluationVariant(key = "on"), - "flag-key-2" to EvaluationVariant(key = "control") - ) - ) - Assert.assertTrue(filter.shouldTrack(assignment1)) - val assignment2 = Assignment( - user(userId = "user").toEvaluationContext(), - linkedMapOf( - "flag-key-2" to EvaluationVariant(key = "control"), - "flag-key-1" to EvaluationVariant(key = "on") - ) - ) - Assert.assertFalse(filter.shouldTrack(assignment2)) - } + fun `test duplicate assignments with different result ordering`() = + runBlocking { + val filter = InMemoryAssignmentFilter(100) + val assignment1 = + Assignment( + user(userId = "user").toEvaluationContext(), + linkedMapOf( + "flag-key-1" to EvaluationVariant(key = "on"), + "flag-key-2" to EvaluationVariant(key = "control"), + ), + ) + Assert.assertTrue(filter.shouldTrack(assignment1)) + val assignment2 = + Assignment( + user(userId = "user").toEvaluationContext(), + linkedMapOf( + "flag-key-2" to EvaluationVariant(key = "control"), + "flag-key-1" to EvaluationVariant(key = "on"), + ), + ) + Assert.assertFalse(filter.shouldTrack(assignment2)) + } @Test - fun `test lru replacement`() = runBlocking { - val filter = InMemoryAssignmentFilter(2) - val assignment1 = Assignment( - user(userId = "user").toEvaluationContext(), - mapOf( - "flag-key-1" to EvaluationVariant(key = "on"), - "flag-key-2" to EvaluationVariant(key = "control") - ) - ) - Assert.assertTrue(filter.shouldTrack(assignment1)) - val assignment2 = Assignment( - user(userId = "user2").toEvaluationContext(), - mapOf( - "flag-key-1" to EvaluationVariant(key = "on"), - "flag-key-2" to EvaluationVariant(key = "control") - ) - ) - Assert.assertTrue(filter.shouldTrack(assignment2)) - val assignment3 = Assignment( - user(userId = "user3").toEvaluationContext(), - mapOf( - "flag-key-1" to EvaluationVariant(key = "on"), - "flag-key-2" to EvaluationVariant(key = "control") - ) - ) - Assert.assertTrue(filter.shouldTrack(assignment3)) - Assert.assertTrue(filter.shouldTrack(assignment1)) - } + fun `test lru replacement`() = + runBlocking { + val filter = InMemoryAssignmentFilter(2) + val assignment1 = + Assignment( + user(userId = "user").toEvaluationContext(), + mapOf( + "flag-key-1" to EvaluationVariant(key = "on"), + "flag-key-2" to EvaluationVariant(key = "control"), + ), + ) + Assert.assertTrue(filter.shouldTrack(assignment1)) + val assignment2 = + Assignment( + user(userId = "user2").toEvaluationContext(), + mapOf( + "flag-key-1" to EvaluationVariant(key = "on"), + "flag-key-2" to EvaluationVariant(key = "control"), + ), + ) + Assert.assertTrue(filter.shouldTrack(assignment2)) + val assignment3 = + Assignment( + user(userId = "user3").toEvaluationContext(), + mapOf( + "flag-key-1" to EvaluationVariant(key = "on"), + "flag-key-2" to EvaluationVariant(key = "control"), + ), + ) + Assert.assertTrue(filter.shouldTrack(assignment3)) + Assert.assertTrue(filter.shouldTrack(assignment1)) + } } diff --git a/core/src/test/kotlin/assignment/AssignmentServiceTest.kt b/core/src/test/kotlin/assignment/AssignmentServiceTest.kt index 2cbffcc..444e8a3 100644 --- a/core/src/test/kotlin/assignment/AssignmentServiceTest.kt +++ b/core/src/test/kotlin/assignment/AssignmentServiceTest.kt @@ -11,46 +11,51 @@ import org.junit.Test import test.user class AssignmentServiceTest { - @Test - fun `test assignment to amplitude event`() = runBlocking { - val user = user(userId = "user", deviceId = "device").toEvaluationContext() - val results = mapOf( - "flag-key-1" to EvaluationVariant( - key = "on", - metadata = mapOf( - "flagVersion" to 1, - "segmentName" to "Segment 1" - ) - ), - "flag-key-2" to EvaluationVariant( - key = "off", - metadata = mapOf( - "default" to true, - "flagVersion" to 1, - "segmentName" to "All Other Users" + fun `test assignment to amplitude event`() = + runBlocking { + val user = user(userId = "user", deviceId = "device").toEvaluationContext() + val results = + mapOf( + "flag-key-1" to + EvaluationVariant( + key = "on", + metadata = + mapOf( + "flagVersion" to 1, + "segmentName" to "Segment 1", + ), + ), + "flag-key-2" to + EvaluationVariant( + key = "off", + metadata = + mapOf( + "default" to true, + "flagVersion" to 1, + "segmentName" to "All Other Users", + ), + ), ) - ) - ) - val assignment = Assignment(user, results) - val event = assignment.toAmplitudeEvent() - Assert.assertEquals(user.userId(), event.userId) - Assert.assertEquals(user.deviceId(), event.deviceId) - Assert.assertEquals("[Experiment] Assignment", event.eventType) - val eventProperties = event.eventProperties - Assert.assertEquals(4, eventProperties.length()) - Assert.assertEquals("on", eventProperties.get("flag-key-1.variant")) - Assert.assertEquals("v1 rule:Segment 1", eventProperties.get("flag-key-1.details")) - Assert.assertEquals("off", eventProperties.get("flag-key-2.variant")) - Assert.assertEquals("v1 rule:All Other Users", eventProperties.get("flag-key-2.details")) - val userProperties = event.userProperties - Assert.assertEquals(2, userProperties.length()) - Assert.assertEquals(1, userProperties.getJSONObject("\$set").length()) - Assert.assertEquals(1, userProperties.getJSONObject("\$unset").length()) - Assert.assertEquals("on", userProperties.getJSONObject("\$set").get("[Experiment] flag-key-1")) - Assert.assertEquals("-", userProperties.getJSONObject("\$unset").get("[Experiment] flag-key-2")) - val canonicalization = "user device flag-key-1 on flag-key-2 off " - val expected = "user device ${canonicalization.hashCode()} ${assignment.timestamp / DAY_MILLIS}" - Assert.assertEquals(expected, event.insertId) - } + val assignment = Assignment(user, results) + val event = assignment.toAmplitudeEvent() + Assert.assertEquals(user.userId(), event.userId) + Assert.assertEquals(user.deviceId(), event.deviceId) + Assert.assertEquals("[Experiment] Assignment", event.eventType) + val eventProperties = event.eventProperties + Assert.assertEquals(4, eventProperties.length()) + Assert.assertEquals("on", eventProperties.get("flag-key-1.variant")) + Assert.assertEquals("v1 rule:Segment 1", eventProperties.get("flag-key-1.details")) + Assert.assertEquals("off", eventProperties.get("flag-key-2.variant")) + Assert.assertEquals("v1 rule:All Other Users", eventProperties.get("flag-key-2.details")) + val userProperties = event.userProperties + Assert.assertEquals(2, userProperties.length()) + Assert.assertEquals(1, userProperties.getJSONObject("\$set").length()) + Assert.assertEquals(1, userProperties.getJSONObject("\$unset").length()) + Assert.assertEquals("on", userProperties.getJSONObject("\$set").get("[Experiment] flag-key-1")) + Assert.assertEquals("-", userProperties.getJSONObject("\$unset").get("[Experiment] flag-key-2")) + val canonicalization = "user device flag-key-1 on flag-key-2 off " + val expected = "user device ${canonicalization.hashCode()} ${assignment.timestamp / DAY_MILLIS}" + Assert.assertEquals(expected, event.insertId) + } } diff --git a/core/src/test/kotlin/cohort/CohortApiTest.kt b/core/src/test/kotlin/cohort/CohortApiTest.kt index d2d68d5..50563fa 100644 --- a/core/src/test/kotlin/cohort/CohortApiTest.kt +++ b/core/src/test/kotlin/cohort/CohortApiTest.kt @@ -18,202 +18,220 @@ import kotlinx.coroutines.runBlocking import kotlinx.serialization.encodeToString import java.io.IOException import java.util.Base64 -import kotlin.math.max import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertTrue import kotlin.test.fail class CohortApiTest { - private val apiKey = "api" private val secretKey = "secret" private val serverUrl = "https://api.lab.amplitude.com/" private val token = Base64.getEncoder().encodeToString("$apiKey:$secretKey".toByteArray(Charsets.UTF_8)) - private val fastRetryConfig = RetryConfig( - times = 5, - initialDelayMillis = 1, - maxDelay = 1, - factor = 1.0 - ) + private val fastRetryConfig = + RetryConfig( + times = 5, + initialDelayMillis = 1, + maxDelay = 1, + factor = 1.0, + ) @Test - fun `without existing cohort, success`(): Unit = runBlocking { - val expected = Cohort("a", "User", 1, 100L, setOf("1")) - val mockEngine = MockEngine { request -> - respond( - content = ByteReadChannel(json.encodeToString(GetCohortResponse.fromCohort(expected))), - status = HttpStatusCode.OK + fun `without existing cohort, success`(): Unit = + runBlocking { + val expected = Cohort("a", "User", 1, 100L, setOf("1")) + val mockEngine = + MockEngine { request -> + respond( + content = ByteReadChannel(json.encodeToString(GetCohortResponse.fromCohort(expected))), + status = HttpStatusCode.OK, + ) + } + val api = CohortApiV1(serverUrl, apiKey, secretKey, mockEngine) + val actual = api.getCohort("a", null, Int.MAX_VALUE) + assertEquals(expected, actual) + val request = mockEngine.requestHistory[0] + assertEquals(HttpMethod.Get, request.method) + assertEquals("/sdk/v1/cohort/a", request.url.encodedPath) + assertEquals( + Parameters.build { + set("maxCohortSize", "${Int.MAX_VALUE}") + }, + request.url.parameters, ) + assertTrue(request.headers["X-Amp-Exp-Library"]!!.startsWith("evaluation-proxy/")) + assertEquals("Basic $token", request.headers["Authorization"]) } - val api = CohortApiV1(serverUrl, apiKey, secretKey, mockEngine) - val actual = api.getCohort("a", null, Int.MAX_VALUE) - assertEquals(expected, actual) - val request = mockEngine.requestHistory[0] - assertEquals(HttpMethod.Get, request.method) - assertEquals("/sdk/v1/cohort/a", request.url.encodedPath) - assertEquals( - Parameters.build { - set("maxCohortSize", "${Int.MAX_VALUE}") - }, - request.url.parameters - ) - assertTrue(request.headers["X-Amp-Exp-Library"]!!.startsWith("evaluation-proxy/")) - assertEquals("Basic $token", request.headers["Authorization"]) - } @Test - fun `with existing cohort, success`(): Unit = runBlocking { - val expected = Cohort("a", "User", 1, 100L, setOf("1")) - val mockEngine = MockEngine { request -> - respond( - content = ByteReadChannel(json.encodeToString(GetCohortResponse.fromCohort(expected))), - status = HttpStatusCode.OK + fun `with existing cohort, success`(): Unit = + runBlocking { + val expected = Cohort("a", "User", 1, 100L, setOf("1")) + val mockEngine = + MockEngine { request -> + respond( + content = ByteReadChannel(json.encodeToString(GetCohortResponse.fromCohort(expected))), + status = HttpStatusCode.OK, + ) + } + val api = CohortApiV1(serverUrl, apiKey, secretKey, mockEngine) + val actual = api.getCohort("a", 99, Int.MAX_VALUE) + assertEquals(expected, actual) + val request = mockEngine.requestHistory[0] + assertEquals(HttpMethod.Get, request.method) + assertEquals("/sdk/v1/cohort/a", request.url.encodedPath) + assertEquals( + Parameters.build { + set("maxCohortSize", "${Int.MAX_VALUE}") + set("lastModified", "99") + }, + request.url.parameters, ) + assertTrue(request.headers["X-Amp-Exp-Library"]!!.startsWith("evaluation-proxy/")) + assertEquals("Basic $token", request.headers["Authorization"]) } - val api = CohortApiV1(serverUrl, apiKey, secretKey, mockEngine) - val actual = api.getCohort("a", 99, Int.MAX_VALUE) - assertEquals(expected, actual) - val request = mockEngine.requestHistory[0] - assertEquals(HttpMethod.Get, request.method) - assertEquals("/sdk/v1/cohort/a", request.url.encodedPath) - assertEquals( - Parameters.build { - set("maxCohortSize", "${Int.MAX_VALUE}") - set("lastModified", "99") - }, - request.url.parameters - ) - assertTrue(request.headers["X-Amp-Exp-Library"]!!.startsWith("evaluation-proxy/")) - assertEquals("Basic $token", request.headers["Authorization"]) - } @Test - fun `with existing cohort, cohort not modified, no retries, throws`(): Unit = runBlocking { - val mockEngine = MockEngine { request -> - respond( - content = "", - status = HttpStatusCode.NoContent - ) - } - val api = CohortApiV1(serverUrl, apiKey, secretKey, mockEngine) - try { - api.getCohort("a", 99, Int.MAX_VALUE) - fail("Expected getCohort call to fail with CohortNotModifiedException") - } catch (e: CohortNotModifiedException) { - // Success + fun `with existing cohort, cohort not modified, no retries, throws`(): Unit = + runBlocking { + val mockEngine = + MockEngine { request -> + respond( + content = "", + status = HttpStatusCode.NoContent, + ) + } + val api = CohortApiV1(serverUrl, apiKey, secretKey, mockEngine) + try { + api.getCohort("a", 99, Int.MAX_VALUE) + fail("Expected getCohort call to fail with CohortNotModifiedException") + } catch (e: CohortNotModifiedException) { + // Success + } } - } @Test - fun `without existing cohort, cohort too large, no retries, throws`(): Unit = runBlocking { - val mockEngine = MockEngine { request -> - respond( - content = "", - status = HttpStatusCode.PayloadTooLarge - ) - } - val api = CohortApiV1(serverUrl, apiKey, secretKey, mockEngine) - try { - api.getCohort("a", 99, Int.MAX_VALUE) - fail("Expected getCohort call to fail with CohortTooLargeException") - } catch (e: CohortTooLargeException) { - // Success + fun `without existing cohort, cohort too large, no retries, throws`(): Unit = + runBlocking { + val mockEngine = + MockEngine { request -> + respond( + content = "", + status = HttpStatusCode.PayloadTooLarge, + ) + } + val api = CohortApiV1(serverUrl, apiKey, secretKey, mockEngine) + try { + api.getCohort("a", 99, Int.MAX_VALUE) + fail("Expected getCohort call to fail with CohortTooLargeException") + } catch (e: CohortTooLargeException) { + // Success + } } - } @Test - fun `request failures, retries, throws`(): Unit = runBlocking { - var failureCounter = 0 - val mockEngine = MockEngine { _ -> - failureCounter++ - throw IOException("test") - } - val api = CohortApiV1(serverUrl, apiKey, secretKey, mockEngine, fastRetryConfig) - try { - api.getCohort("a", 99, Int.MAX_VALUE) - fail("Expected getCohort call to fail with IOException") - } catch (e: IOException) { - // Success + fun `request failures, retries, throws`(): Unit = + runBlocking { + var failureCounter = 0 + val mockEngine = + MockEngine { _ -> + failureCounter++ + throw IOException("test") + } + val api = CohortApiV1(serverUrl, apiKey, secretKey, mockEngine, fastRetryConfig) + try { + api.getCohort("a", 99, Int.MAX_VALUE) + fail("Expected getCohort call to fail with IOException") + } catch (e: IOException) { + // Success + } + assertEquals(5, failureCounter) } - assertEquals(5, failureCounter) - } @Test - fun `request server error responses, retries, throws`(): Unit = runBlocking { - val mockEngine = MockEngine { _ -> - respond( - content = "", - status = HttpStatusCode.InternalServerError - ) - } - val api = CohortApiV1(serverUrl, apiKey, secretKey, mockEngine, fastRetryConfig) - try { - api.getCohort("a", 99, Int.MAX_VALUE) - fail("Expected getCohort call to fail with HttpErrorException") - } catch (e: HttpErrorException) { - // Success + fun `request server error responses, retries, throws`(): Unit = + runBlocking { + val mockEngine = + MockEngine { _ -> + respond( + content = "", + status = HttpStatusCode.InternalServerError, + ) + } + val api = CohortApiV1(serverUrl, apiKey, secretKey, mockEngine, fastRetryConfig) + try { + api.getCohort("a", 99, Int.MAX_VALUE) + fail("Expected getCohort call to fail with HttpErrorException") + } catch (e: HttpErrorException) { + // Success + } + assertEquals(5, mockEngine.responseHistory.size) } - assertEquals(5, mockEngine.responseHistory.size) - } @Test - fun `request client too many requests, retries, throws`(): Unit = runBlocking { - val mockEngine = MockEngine { _ -> - respond( - content = "", - status = HttpStatusCode.TooManyRequests - ) - } - val api = CohortApiV1(serverUrl, apiKey, secretKey, mockEngine, fastRetryConfig) - try { - api.getCohort("a", 99, Int.MAX_VALUE) - fail("Expected getCohort call to fail with HttpErrorException") - } catch (e: HttpErrorException) { - assertEquals(HttpStatusCode.TooManyRequests, mockEngine.responseHistory[0].statusCode) + fun `request client too many requests, retries, throws`(): Unit = + runBlocking { + val mockEngine = + MockEngine { _ -> + respond( + content = "", + status = HttpStatusCode.TooManyRequests, + ) + } + val api = CohortApiV1(serverUrl, apiKey, secretKey, mockEngine, fastRetryConfig) + try { + api.getCohort("a", 99, Int.MAX_VALUE) + fail("Expected getCohort call to fail with HttpErrorException") + } catch (e: HttpErrorException) { + assertEquals(HttpStatusCode.TooManyRequests, mockEngine.responseHistory[0].statusCode) + } + assertEquals(5, mockEngine.responseHistory.size) } - assertEquals(5, mockEngine.responseHistory.size) - } @Test - fun `request client error, no retries, throws`(): Unit = runBlocking { - val mockEngine = MockEngine { _ -> - respond( - content = "", - status = HttpStatusCode.NotFound - ) - } - val api = CohortApiV1(serverUrl, apiKey, secretKey, mockEngine, fastRetryConfig) - try { - api.getCohort("a", 99, Int.MAX_VALUE) - fail("Expected getCohort call to fail with HttpErrorException") - } catch (e: HttpErrorException) { - assertEquals(HttpStatusCode.NotFound, mockEngine.responseHistory[0].statusCode) + fun `request client error, no retries, throws`(): Unit = + runBlocking { + val mockEngine = + MockEngine { _ -> + respond( + content = "", + status = HttpStatusCode.NotFound, + ) + } + val api = CohortApiV1(serverUrl, apiKey, secretKey, mockEngine, fastRetryConfig) + try { + api.getCohort("a", 99, Int.MAX_VALUE) + fail("Expected getCohort call to fail with HttpErrorException") + } catch (e: HttpErrorException) { + assertEquals(HttpStatusCode.NotFound, mockEngine.responseHistory[0].statusCode) + } + assertEquals(1, mockEngine.responseHistory.size) } - assertEquals(1, mockEngine.responseHistory.size) - } @Test - fun `request errors, eventual success`(): Unit = runBlocking { - var i = 0 - val expected = Cohort("a", "User", 1, 100L, setOf("1")) - val mockEngine = MockEngine { _ -> - i++ - when (i) { - 1 -> respond(content = "", status = HttpStatusCode.InternalServerError) - 2 -> respond(content = "", status = HttpStatusCode.TooManyRequests) - 3 -> respond(content = "", status = HttpStatusCode.BadGateway) - 4 -> respond(content = "", status = HttpStatusCode.GatewayTimeout) - 5 -> respond( - content = ByteReadChannel(json.encodeToString(GetCohortResponse.fromCohort(expected))), - status = HttpStatusCode.OK - ) - else -> fail("unexpected number of requests") - } + fun `request errors, eventual success`(): Unit = + runBlocking { + var i = 0 + val expected = Cohort("a", "User", 1, 100L, setOf("1")) + val mockEngine = + MockEngine { _ -> + i++ + when (i) { + 1 -> respond(content = "", status = HttpStatusCode.InternalServerError) + 2 -> respond(content = "", status = HttpStatusCode.TooManyRequests) + 3 -> respond(content = "", status = HttpStatusCode.BadGateway) + 4 -> respond(content = "", status = HttpStatusCode.GatewayTimeout) + 5 -> + respond( + content = ByteReadChannel(json.encodeToString(GetCohortResponse.fromCohort(expected))), + status = HttpStatusCode.OK, + ) + else -> fail("unexpected number of requests") + } + } + val api = CohortApiV1(serverUrl, apiKey, secretKey, mockEngine, fastRetryConfig) + val actual = api.getCohort("a", 99, Int.MAX_VALUE) + assertEquals(5, mockEngine.responseHistory.size) + assertEquals(expected, actual) } - val api = CohortApiV1(serverUrl, apiKey, secretKey, mockEngine, fastRetryConfig) - val actual = api.getCohort("a", 99, Int.MAX_VALUE) - assertEquals(5, mockEngine.responseHistory.size) - assertEquals(expected, actual) - } } diff --git a/core/src/test/kotlin/cohort/CohortLoaderTest.kt b/core/src/test/kotlin/cohort/CohortLoaderTest.kt index 582c970..e4a0c39 100644 --- a/core/src/test/kotlin/cohort/CohortLoaderTest.kt +++ b/core/src/test/kotlin/cohort/CohortLoaderTest.kt @@ -20,108 +20,112 @@ import kotlin.test.Test import kotlin.test.assertEquals class CohortLoaderTest { - private val maxCohortSize = Int.MAX_VALUE @Test - fun `load cohorts, success`(): Unit = runBlocking { - val cohortA = Cohort("a", "User", 1, 100, setOf("1")) - val cohortB = Cohort("b", "User", 1, 100, setOf("1")) - val cohortC = Cohort("c", "User", 1, 100, setOf("1")) - val api = mockk() - val storage: CohortStorage = spyk(InMemoryCohortStorage()) - val loader = CohortLoader(maxCohortSize, api, storage) - coEvery { api.getCohort("a", null, maxCohortSize) } returns cohortA - coEvery { api.getCohort("b", null, maxCohortSize) } returns cohortB - coEvery { api.getCohort("c", null, maxCohortSize) } returns cohortC - // Run 1 - loader.loadCohorts(setOf("a", "b", "c")) - coVerify(exactly = 3) { api.getCohort(allAny(), isNull(), eq(maxCohortSize)) } - assertEquals( - mapOf( - "a" to cohortA, - "b" to cohortB, - "c" to cohortC - ), - storage.getCohorts() - ) - coVerify(exactly = 3) { storage.putCohort(allAny()) } - // Run 2 - val cohortB2 = cohortB.copy(size = 2, lastModified = 200, members = setOf("1", "2")) - coEvery { api.getCohort("a", 100, maxCohortSize) } returns cohortA - coEvery { api.getCohort("b", 100, maxCohortSize) } returns cohortB2 - coEvery { api.getCohort("c", 100, maxCohortSize) } throws CohortNotModifiedException("c") - loader.loadCohorts(setOf("a", "b", "c")) - coVerify(exactly = 3) { api.getCohort(allAny(), eq(100), eq(maxCohortSize)) } - assertEquals( - mapOf( - "a" to cohortA, - "b" to cohortB2, - "c" to cohortC - ), - storage.getCohorts() - ) - // Cohort C should not be stored - coVerify(exactly = 5) { storage.putCohort(allAny()) } - } - - @Test - fun `load cohorts, simultaneous loading of the same cohorts, only downloads once per cohort`(): Unit = runBlocking { - val cohortA = Cohort("a", "User", 1, 100, setOf("1")) - val cohortB = Cohort("b", "User", 1, 100, setOf("1")) - val cohortC = Cohort("c", "User", 1, 100, setOf("1")) - val api = mockk() - val storage: CohortStorage = spyk(InMemoryCohortStorage()) - val loader = CohortLoader(maxCohortSize, api, storage) - coEvery { api.getCohort("a", null, maxCohortSize) } coAnswers { - delay(100) - cohortA - } - coEvery { api.getCohort("b", null, maxCohortSize) } coAnswers { - delay(100) - cohortB - } - coEvery { api.getCohort("c", null, maxCohortSize) } coAnswers { - delay(100) - cohortC - } - val j1 = launch { + fun `load cohorts, success`(): Unit = + runBlocking { + val cohortA = Cohort("a", "User", 1, 100, setOf("1")) + val cohortB = Cohort("b", "User", 1, 100, setOf("1")) + val cohortC = Cohort("c", "User", 1, 100, setOf("1")) + val api = mockk() + val storage: CohortStorage = spyk(InMemoryCohortStorage()) + val loader = CohortLoader(maxCohortSize, api, storage) + coEvery { api.getCohort("a", null, maxCohortSize) } returns cohortA + coEvery { api.getCohort("b", null, maxCohortSize) } returns cohortB + coEvery { api.getCohort("c", null, maxCohortSize) } returns cohortC + // Run 1 loader.loadCohorts(setOf("a", "b", "c")) - } - val j2 = launch { + coVerify(exactly = 3) { api.getCohort(allAny(), isNull(), eq(maxCohortSize)) } + assertEquals( + mapOf( + "a" to cohortA, + "b" to cohortB, + "c" to cohortC, + ), + storage.getCohorts(), + ) + coVerify(exactly = 3) { storage.putCohort(allAny()) } + // Run 2 + val cohortB2 = cohortB.copy(size = 2, lastModified = 200, members = setOf("1", "2")) + coEvery { api.getCohort("a", 100, maxCohortSize) } returns cohortA + coEvery { api.getCohort("b", 100, maxCohortSize) } returns cohortB2 + coEvery { api.getCohort("c", 100, maxCohortSize) } throws CohortNotModifiedException("c") loader.loadCohorts(setOf("a", "b", "c")) + coVerify(exactly = 3) { api.getCohort(allAny(), eq(100), eq(maxCohortSize)) } + assertEquals( + mapOf( + "a" to cohortA, + "b" to cohortB2, + "c" to cohortC, + ), + storage.getCohorts(), + ) + // Cohort C should not be stored + coVerify(exactly = 5) { storage.putCohort(allAny()) } + } + + @Test + fun `load cohorts, simultaneous loading of the same cohorts, only downloads once per cohort`(): Unit = + runBlocking { + val cohortA = Cohort("a", "User", 1, 100, setOf("1")) + val cohortB = Cohort("b", "User", 1, 100, setOf("1")) + val cohortC = Cohort("c", "User", 1, 100, setOf("1")) + val api = mockk() + val storage: CohortStorage = spyk(InMemoryCohortStorage()) + val loader = CohortLoader(maxCohortSize, api, storage) + coEvery { api.getCohort("a", null, maxCohortSize) } coAnswers { + delay(100) + cohortA + } + coEvery { api.getCohort("b", null, maxCohortSize) } coAnswers { + delay(100) + cohortB + } + coEvery { api.getCohort("c", null, maxCohortSize) } coAnswers { + delay(100) + cohortC + } + val j1 = + launch { + loader.loadCohorts(setOf("a", "b", "c")) + } + val j2 = + launch { + loader.loadCohorts(setOf("a", "b", "c")) + } + listOf(j1, j2).joinAll() + coVerify(exactly = 3) { api.getCohort(allAny(), isNull(), eq(maxCohortSize)) } + assertEquals( + mapOf( + "a" to cohortA, + "b" to cohortB, + "c" to cohortC, + ), + storage.getCohorts(), + ) } - listOf(j1, j2).joinAll() - coVerify(exactly = 3) { api.getCohort(allAny(), isNull(), eq(maxCohortSize)) } - assertEquals( - mapOf( - "a" to cohortA, - "b" to cohortB, - "c" to cohortC - ), - storage.getCohorts() - ) - } @Test - fun `load cohorts, failure, failed cohort not stored, does not throw`(): Unit = runBlocking { - val cohortA = Cohort("a", "User", 1, 100, setOf("1")) - val cohortC = Cohort("c", "User", 1, 100, setOf("1")) - val api = mockk() - val storage: CohortStorage = spyk(InMemoryCohortStorage()) - val loader = CohortLoader(maxCohortSize, api, storage) - coEvery { api.getCohort("a", null, maxCohortSize) } returns cohortA - coEvery { api.getCohort("b", null, maxCohortSize) } throws HttpErrorException(HttpStatusCode.InternalServerError) - coEvery { api.getCohort("c", null, maxCohortSize) } returns cohortC - // Run 1 - loader.loadCohorts(setOf("a", "b", "c")) - coVerify(exactly = 3) { api.getCohort(allAny(), isNull(), eq(maxCohortSize)) } - assertEquals( - mapOf( - "a" to cohortA, - "c" to cohortC - ), - storage.getCohorts() - ) - } + fun `load cohorts, failure, failed cohort not stored, does not throw`(): Unit = + runBlocking { + val cohortA = Cohort("a", "User", 1, 100, setOf("1")) + val cohortC = Cohort("c", "User", 1, 100, setOf("1")) + val api = mockk() + val storage: CohortStorage = spyk(InMemoryCohortStorage()) + val loader = CohortLoader(maxCohortSize, api, storage) + coEvery { api.getCohort("a", null, maxCohortSize) } returns cohortA + coEvery { api.getCohort("b", null, maxCohortSize) } throws HttpErrorException(HttpStatusCode.InternalServerError) + coEvery { api.getCohort("c", null, maxCohortSize) } returns cohortC + // Run 1 + loader.loadCohorts(setOf("a", "b", "c")) + coVerify(exactly = 3) { api.getCohort(allAny(), isNull(), eq(maxCohortSize)) } + assertEquals( + mapOf( + "a" to cohortA, + "c" to cohortC, + ), + storage.getCohorts(), + ) + } } diff --git a/core/src/test/kotlin/cohort/CohortStorageTest.kt b/core/src/test/kotlin/cohort/CohortStorageTest.kt index 33d3f98..9ad120a 100644 --- a/core/src/test/kotlin/cohort/CohortStorageTest.kt +++ b/core/src/test/kotlin/cohort/CohortStorageTest.kt @@ -1,7 +1,5 @@ package cohort -import test.InMemoryRedis -import test.cohort import com.amplitude.cohort.Cohort import com.amplitude.cohort.CohortDescription import com.amplitude.cohort.CohortStorage @@ -9,94 +7,98 @@ import com.amplitude.cohort.InMemoryCohortStorage import com.amplitude.cohort.RedisCohortStorage import com.amplitude.cohort.toCohortDescription import kotlinx.coroutines.runBlocking +import test.InMemoryRedis +import test.cohort import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertNull import kotlin.time.Duration class CohortStorageTest { - private val redis = InMemoryRedis() @Test - fun `test in memory`(): Unit = runBlocking { - test(InMemoryCohortStorage()) - } + fun `test in memory`(): Unit = + runBlocking { + test(InMemoryCohortStorage()) + } @Test - fun `test redis`(): Unit = runBlocking { - test(RedisCohortStorage("12345", Duration.INFINITE, "amplitude ", redis, redis)) - } + fun `test redis`(): Unit = + runBlocking { + test(RedisCohortStorage("12345", Duration.INFINITE, "amplitude ", redis, redis)) + } - private fun test(cohortStorage: CohortStorage): Unit = runBlocking { - val cohortA = cohort("a") - val cohortB = cohort("b") + private fun test(cohortStorage: CohortStorage): Unit = + runBlocking { + val cohortA = cohort("a") + val cohortB = cohort("b") - // test get, null - var cohort: Cohort? = cohortStorage.getCohort(cohortA.id) - assertNull(cohort) - // test get all, empty - var cohorts: Map = cohortStorage.getCohorts() - assertEquals(0, cohorts.size) - // test get description, null - var description: CohortDescription? = cohortStorage.getCohortDescription(cohortA.id) - assertNull(description) - // test get descriptions, empty - var descriptions: Map = cohortStorage.getCohortDescriptions() - assertEquals(0, descriptions.size) - // test put, get, cohort - cohortStorage.putCohort(cohortA) - cohort = cohortStorage.getCohort(cohortA.id) - assertEquals(cohortA, cohort) - // test get description, description - description = cohortStorage.getCohortDescription(cohortA.id) - assertEquals(cohortA.toCohortDescription(), description) - // test put, get all, cohorts - cohortStorage.putCohort(cohortB) - cohorts = cohortStorage.getCohorts() - assertEquals( - mapOf( - cohortA.id to cohortA, - cohortB.id to cohortB - ), - cohorts - ) - // test get descriptions, descriptions - descriptions = cohortStorage.getCohortDescriptions() - assertEquals( - mapOf( - cohortA.id to cohortA.toCohortDescription(), - cohortB.id to cohortB.toCohortDescription() - ), - descriptions - ) - // test delete one - cohortStorage.deleteCohort(cohortA.toCohortDescription()) - // test get deleted, null - cohort = cohortStorage.getCohort(cohortA.id) - assertNull(cohort) - // test get other, cohort - cohort = cohortStorage.getCohort(cohortB.id) - assertEquals(cohortB, cohort) - // test get all, cohort - cohorts = cohortStorage.getCohorts() - assertEquals(mapOf(cohortB.id to cohortB), cohorts) - // test get description deleted, null - description = cohortStorage.getCohortDescription(cohortA.id) - assertNull(description) - // test get description other, description - description = cohortStorage.getCohortDescription(cohortB.id) - assertEquals(cohortB.toCohortDescription(), description) - // test get descriptions, description - descriptions = cohortStorage.getCohortDescriptions() - assertEquals(mapOf(cohortB.id to cohortB.toCohortDescription()), descriptions) - // test delete other - cohortStorage.deleteCohort(cohortB.toCohortDescription()) - // test get all, empty - cohorts = cohortStorage.getCohorts() - assertEquals(0, cohorts.size) - // test get descriptions, empty - descriptions = cohortStorage.getCohortDescriptions() - assertEquals(0, descriptions.size) - } + // test get, null + var cohort: Cohort? = cohortStorage.getCohort(cohortA.id) + assertNull(cohort) + // test get all, empty + var cohorts: Map = cohortStorage.getCohorts() + assertEquals(0, cohorts.size) + // test get description, null + var description: CohortDescription? = cohortStorage.getCohortDescription(cohortA.id) + assertNull(description) + // test get descriptions, empty + var descriptions: Map = cohortStorage.getCohortDescriptions() + assertEquals(0, descriptions.size) + // test put, get, cohort + cohortStorage.putCohort(cohortA) + cohort = cohortStorage.getCohort(cohortA.id) + assertEquals(cohortA, cohort) + // test get description, description + description = cohortStorage.getCohortDescription(cohortA.id) + assertEquals(cohortA.toCohortDescription(), description) + // test put, get all, cohorts + cohortStorage.putCohort(cohortB) + cohorts = cohortStorage.getCohorts() + assertEquals( + mapOf( + cohortA.id to cohortA, + cohortB.id to cohortB, + ), + cohorts, + ) + // test get descriptions, descriptions + descriptions = cohortStorage.getCohortDescriptions() + assertEquals( + mapOf( + cohortA.id to cohortA.toCohortDescription(), + cohortB.id to cohortB.toCohortDescription(), + ), + descriptions, + ) + // test delete one + cohortStorage.deleteCohort(cohortA.toCohortDescription()) + // test get deleted, null + cohort = cohortStorage.getCohort(cohortA.id) + assertNull(cohort) + // test get other, cohort + cohort = cohortStorage.getCohort(cohortB.id) + assertEquals(cohortB, cohort) + // test get all, cohort + cohorts = cohortStorage.getCohorts() + assertEquals(mapOf(cohortB.id to cohortB), cohorts) + // test get description deleted, null + description = cohortStorage.getCohortDescription(cohortA.id) + assertNull(description) + // test get description other, description + description = cohortStorage.getCohortDescription(cohortB.id) + assertEquals(cohortB.toCohortDescription(), description) + // test get descriptions, description + descriptions = cohortStorage.getCohortDescriptions() + assertEquals(mapOf(cohortB.id to cohortB.toCohortDescription()), descriptions) + // test delete other + cohortStorage.deleteCohort(cohortB.toCohortDescription()) + // test get all, empty + cohorts = cohortStorage.getCohorts() + assertEquals(0, cohorts.size) + // test get descriptions, empty + descriptions = cohortStorage.getCohortDescriptions() + assertEquals(0, descriptions.size) + } } diff --git a/core/src/test/kotlin/deployment/DeploymentApiTest.kt b/core/src/test/kotlin/deployment/DeploymentApiTest.kt index 1794ac5..2835aea 100644 --- a/core/src/test/kotlin/deployment/DeploymentApiTest.kt +++ b/core/src/test/kotlin/deployment/DeploymentApiTest.kt @@ -4,7 +4,6 @@ import com.amplitude.deployment.DeploymentApiV2 import com.amplitude.util.HttpErrorException import com.amplitude.util.RetryConfig import com.amplitude.util.json -import test.flag import io.ktor.client.engine.mock.MockEngine import io.ktor.client.engine.mock.respond import io.ktor.http.HttpMethod @@ -14,6 +13,7 @@ import io.ktor.utils.io.ByteReadChannel import kotlinx.coroutines.runBlocking import kotlinx.serialization.encodeToString import org.junit.Test +import test.flag import java.io.IOException import kotlin.test.assertEquals import kotlin.test.assertTrue @@ -22,130 +22,144 @@ import kotlin.test.fail class DeploymentApiTest { private val deploymentKey = "deployment" private val serverUrl = "https://api.lab.amplitude.com/" - private val fastRetryConfig = RetryConfig( - times = 5, - initialDelayMillis = 1, - maxDelay = 1, - factor = 1.0 - ) + private val fastRetryConfig = + RetryConfig( + times = 5, + initialDelayMillis = 1, + maxDelay = 1, + factor = 1.0, + ) @Test - fun `get flags, success`(): Unit = runBlocking { - val expected = listOf(flag()) - val mockEngine = MockEngine { request -> - respond( - content = ByteReadChannel(json.encodeToString(expected)), - status = HttpStatusCode.OK + fun `get flags, success`(): Unit = + runBlocking { + val expected = listOf(flag()) + val mockEngine = + MockEngine { request -> + respond( + content = ByteReadChannel(json.encodeToString(expected)), + status = HttpStatusCode.OK, + ) + } + val api = DeploymentApiV2(serverUrl, mockEngine, fastRetryConfig) + val actual = api.getFlagConfigs(deploymentKey) + assertEquals(expected, actual) + val request = mockEngine.requestHistory[0] + assertEquals(HttpMethod.Get, request.method) + assertEquals("/sdk/v2/flags", request.url.encodedPath) + assertEquals( + Parameters.build { + set("v", "0") + }, + request.url.parameters, ) + assertTrue(request.headers["X-Amp-Exp-Library"]!!.startsWith("evaluation-proxy/")) + assertEquals("Api-Key $deploymentKey", request.headers["Authorization"]) } - val api = DeploymentApiV2(serverUrl, mockEngine, fastRetryConfig) - val actual = api.getFlagConfigs(deploymentKey) - assertEquals(expected, actual) - val request = mockEngine.requestHistory[0] - assertEquals(HttpMethod.Get, request.method) - assertEquals("/sdk/v2/flags", request.url.encodedPath) - assertEquals( - Parameters.build { - set("v", "0") - }, - request.url.parameters - ) - assertTrue(request.headers["X-Amp-Exp-Library"]!!.startsWith("evaluation-proxy/")) - assertEquals("Api-Key $deploymentKey", request.headers["Authorization"]) - } @Test - fun `request failures, retries, throws`(): Unit = runBlocking { - var failureCounter = 0 - val mockEngine = MockEngine { _ -> - failureCounter++ - throw IOException("test") - } - val api = DeploymentApiV2(serverUrl, mockEngine, fastRetryConfig) - try { - api.getFlagConfigs(deploymentKey) - fail("Expected getFlagConfigs call to fail with IOException") - } catch (e: IOException) { - // Success + fun `request failures, retries, throws`(): Unit = + runBlocking { + var failureCounter = 0 + val mockEngine = + MockEngine { _ -> + failureCounter++ + throw IOException("test") + } + val api = DeploymentApiV2(serverUrl, mockEngine, fastRetryConfig) + try { + api.getFlagConfigs(deploymentKey) + fail("Expected getFlagConfigs call to fail with IOException") + } catch (e: IOException) { + // Success + } + assertEquals(5, failureCounter) } - assertEquals(5, failureCounter) - } @Test - fun `request server error responses, retries, throws`(): Unit = runBlocking { - val mockEngine = MockEngine { _ -> - respond( - content = "", - status = HttpStatusCode.InternalServerError - ) - } - val api = DeploymentApiV2(serverUrl, mockEngine, fastRetryConfig) - try { - api.getFlagConfigs(deploymentKey) - fail("Expected getFlagConfigs call to fail with HttpErrorException") - } catch (e: HttpErrorException) { - // Success + fun `request server error responses, retries, throws`(): Unit = + runBlocking { + val mockEngine = + MockEngine { _ -> + respond( + content = "", + status = HttpStatusCode.InternalServerError, + ) + } + val api = DeploymentApiV2(serverUrl, mockEngine, fastRetryConfig) + try { + api.getFlagConfigs(deploymentKey) + fail("Expected getFlagConfigs call to fail with HttpErrorException") + } catch (e: HttpErrorException) { + // Success + } + assertEquals(5, mockEngine.responseHistory.size) } - assertEquals(5, mockEngine.responseHistory.size) - } @Test - fun `request client too many requests, retries, throws`(): Unit = runBlocking { - val mockEngine = MockEngine { _ -> - respond( - content = "", - status = HttpStatusCode.TooManyRequests - ) - } - val api = DeploymentApiV2(serverUrl, mockEngine, fastRetryConfig) - try { - api.getFlagConfigs(deploymentKey) - fail("Expected getFlagConfigs call to fail with HttpErrorException") - } catch (e: HttpErrorException) { - assertEquals(HttpStatusCode.TooManyRequests, mockEngine.responseHistory[0].statusCode) + fun `request client too many requests, retries, throws`(): Unit = + runBlocking { + val mockEngine = + MockEngine { _ -> + respond( + content = "", + status = HttpStatusCode.TooManyRequests, + ) + } + val api = DeploymentApiV2(serverUrl, mockEngine, fastRetryConfig) + try { + api.getFlagConfigs(deploymentKey) + fail("Expected getFlagConfigs call to fail with HttpErrorException") + } catch (e: HttpErrorException) { + assertEquals(HttpStatusCode.TooManyRequests, mockEngine.responseHistory[0].statusCode) + } + assertEquals(5, mockEngine.responseHistory.size) } - assertEquals(5, mockEngine.responseHistory.size) - } @Test - fun `request client error, no retries, throws`(): Unit = runBlocking { - val mockEngine = MockEngine { _ -> - respond( - content = "", - status = HttpStatusCode.NotFound - ) - } - val api = DeploymentApiV2(serverUrl, mockEngine, fastRetryConfig) - try { - api.getFlagConfigs(deploymentKey) - fail("Expected getFlagConfigs call to fail with HttpErrorException") - } catch (e: HttpErrorException) { - assertEquals(HttpStatusCode.NotFound, mockEngine.responseHistory[0].statusCode) + fun `request client error, no retries, throws`(): Unit = + runBlocking { + val mockEngine = + MockEngine { _ -> + respond( + content = "", + status = HttpStatusCode.NotFound, + ) + } + val api = DeploymentApiV2(serverUrl, mockEngine, fastRetryConfig) + try { + api.getFlagConfigs(deploymentKey) + fail("Expected getFlagConfigs call to fail with HttpErrorException") + } catch (e: HttpErrorException) { + assertEquals(HttpStatusCode.NotFound, mockEngine.responseHistory[0].statusCode) + } + assertEquals(1, mockEngine.responseHistory.size) } - assertEquals(1, mockEngine.responseHistory.size) - } @Test - fun `request errors, eventual success`(): Unit = runBlocking { - var i = 0 - val expected = listOf(flag()) - val mockEngine = MockEngine { _ -> - i++ - when (i) { - 1 -> respond(content = "", status = HttpStatusCode.InternalServerError) - 2 -> respond(content = "", status = HttpStatusCode.TooManyRequests) - 3 -> respond(content = "", status = HttpStatusCode.BadGateway) - 4 -> respond(content = "", status = HttpStatusCode.GatewayTimeout) - 5 -> respond( - content = ByteReadChannel(json.encodeToString(expected)), - status = HttpStatusCode.OK - ) - else -> fail("unexpected number of requests") - } + fun `request errors, eventual success`(): Unit = + runBlocking { + var i = 0 + val expected = listOf(flag()) + val mockEngine = + MockEngine { _ -> + i++ + when (i) { + 1 -> respond(content = "", status = HttpStatusCode.InternalServerError) + 2 -> respond(content = "", status = HttpStatusCode.TooManyRequests) + 3 -> respond(content = "", status = HttpStatusCode.BadGateway) + 4 -> respond(content = "", status = HttpStatusCode.GatewayTimeout) + 5 -> + respond( + content = ByteReadChannel(json.encodeToString(expected)), + status = HttpStatusCode.OK, + ) + else -> fail("unexpected number of requests") + } + } + val api = DeploymentApiV2(serverUrl, mockEngine, fastRetryConfig) + val actual = api.getFlagConfigs(deploymentKey) + assertEquals(5, mockEngine.responseHistory.size) + assertEquals(expected, actual) } - val api = DeploymentApiV2(serverUrl, mockEngine, fastRetryConfig) - val actual = api.getFlagConfigs(deploymentKey) - assertEquals(5, mockEngine.responseHistory.size) - assertEquals(expected, actual) - } } diff --git a/core/src/test/kotlin/deployment/DeploymentLoaderTest.kt b/core/src/test/kotlin/deployment/DeploymentLoaderTest.kt index bfb6af9..becdfc5 100644 --- a/core/src/test/kotlin/deployment/DeploymentLoaderTest.kt +++ b/core/src/test/kotlin/deployment/DeploymentLoaderTest.kt @@ -4,173 +4,185 @@ import com.amplitude.cohort.CohortLoader import com.amplitude.deployment.DeploymentApi import com.amplitude.deployment.DeploymentLoader import com.amplitude.deployment.InMemoryDeploymentStorage -import test.flag import io.mockk.coEvery import io.mockk.coVerify import io.mockk.mockk import io.mockk.spyk import kotlinx.coroutines.runBlocking import org.junit.Assert.assertNull +import test.flag import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.fail class DeploymentLoaderTest { - private val deploymentKey = "deployment" private val flagKey = "flag" @Test - fun `load flags without cohorts, cohorts not loaded, success`(): Unit = runBlocking { - val api = mockk() - val storage = spyk(InMemoryDeploymentStorage()) - val cohortLoader = mockk() - val cohortIds = emptySet() - val flag = flag(flagKey = flagKey, cohortIds = cohortIds) - coEvery { api.getFlagConfigs(eq(deploymentKey)) } returns listOf(flag) - val loader = DeploymentLoader(api, storage, cohortLoader) - loader.loadDeployment(deploymentKey) - coVerify(exactly = 1) { api.getFlagConfigs(eq(deploymentKey)) } - coVerify(exactly = 0) { storage.removeFlag(allAny(), allAny()) } - coVerify(exactly = 0) { cohortLoader.loadCohorts(allAny()) } - coVerify(exactly = 1) { storage.putFlag(allAny(), allAny()) } - assertEquals(flag, storage.getFlag(deploymentKey, flagKey)) - } + fun `load flags without cohorts, cohorts not loaded, success`(): Unit = + runBlocking { + val api = mockk() + val storage = spyk(InMemoryDeploymentStorage()) + val cohortLoader = mockk() + val cohortIds = emptySet() + val flag = flag(flagKey = flagKey, cohortIds = cohortIds) + coEvery { api.getFlagConfigs(eq(deploymentKey)) } returns listOf(flag) + val loader = DeploymentLoader(api, storage, cohortLoader) + loader.loadDeployment(deploymentKey) + coVerify(exactly = 1) { api.getFlagConfigs(eq(deploymentKey)) } + coVerify(exactly = 0) { storage.removeFlag(allAny(), allAny()) } + coVerify(exactly = 0) { cohortLoader.loadCohorts(allAny()) } + coVerify(exactly = 1) { storage.putFlag(allAny(), allAny()) } + assertEquals(flag, storage.getFlag(deploymentKey, flagKey)) + } @Test - fun `load flags with cohorts, cohorts are loaded, success`(): Unit = runBlocking { - val api = mockk() - val storage = spyk(InMemoryDeploymentStorage()) - val cohortLoader = mockk() - val cohortIds = setOf("a", "b") - val flag = flag(flagKey = flagKey, cohortIds = cohortIds) - coEvery { api.getFlagConfigs(eq(deploymentKey)) } returns listOf(flag) - coEvery { cohortLoader.loadCohorts(eq(cohortIds)) } returns Unit - val loader = DeploymentLoader(api, storage, cohortLoader) - loader.loadDeployment(deploymentKey) - coVerify(exactly = 1) { api.getFlagConfigs(eq(deploymentKey)) } - coVerify(exactly = 0) { storage.removeFlag(allAny(), allAny()) } - coVerify(exactly = 1) { cohortLoader.loadCohorts(eq(cohortIds)) } - coVerify(exactly = 1) { storage.putFlag(allAny(), allAny()) } - assertEquals(flag, storage.getFlag(deploymentKey, flagKey)) - } + fun `load flags with cohorts, cohorts are loaded, success`(): Unit = + runBlocking { + val api = mockk() + val storage = spyk(InMemoryDeploymentStorage()) + val cohortLoader = mockk() + val cohortIds = setOf("a", "b") + val flag = flag(flagKey = flagKey, cohortIds = cohortIds) + coEvery { api.getFlagConfigs(eq(deploymentKey)) } returns listOf(flag) + coEvery { cohortLoader.loadCohorts(eq(cohortIds)) } returns Unit + val loader = DeploymentLoader(api, storage, cohortLoader) + loader.loadDeployment(deploymentKey) + coVerify(exactly = 1) { api.getFlagConfigs(eq(deploymentKey)) } + coVerify(exactly = 0) { storage.removeFlag(allAny(), allAny()) } + coVerify(exactly = 1) { cohortLoader.loadCohorts(eq(cohortIds)) } + coVerify(exactly = 1) { storage.putFlag(allAny(), allAny()) } + assertEquals(flag, storage.getFlag(deploymentKey, flagKey)) + } @Test - fun `existing flags state, some flags removed, success`(): Unit = runBlocking { - val existingFlagKey = "existing" - val existingFlag = flag(existingFlagKey) - val api = mockk() - val storage = spyk(InMemoryDeploymentStorage().apply { - putFlag(deploymentKey, existingFlag) - }) - val cohortLoader = mockk() - val cohortIds = setOf("a", "b") - val flag = flag(flagKey = flagKey, cohortIds = cohortIds) - coEvery { api.getFlagConfigs(eq(deploymentKey)) } returns listOf(flag) - coEvery { cohortLoader.loadCohorts(eq(cohortIds)) } returns Unit - val loader = DeploymentLoader(api, storage, cohortLoader) - loader.loadDeployment(deploymentKey) - coVerify(exactly = 1) { api.getFlagConfigs(eq(deploymentKey)) } - coVerify(exactly = 1) { storage.removeFlag(eq(deploymentKey), eq(existingFlagKey)) } - coVerify(exactly = 1) { cohortLoader.loadCohorts(eq(cohortIds)) } - coVerify(exactly = 1) { storage.putFlag(allAny(), allAny()) } - assertEquals(flag, storage.getFlag(deploymentKey, flagKey)) - } + fun `existing flags state, some flags removed, success`(): Unit = + runBlocking { + val existingFlagKey = "existing" + val existingFlag = flag(existingFlagKey) + val api = mockk() + val storage = + spyk( + InMemoryDeploymentStorage().apply { + putFlag(deploymentKey, existingFlag) + }, + ) + val cohortLoader = mockk() + val cohortIds = setOf("a", "b") + val flag = flag(flagKey = flagKey, cohortIds = cohortIds) + coEvery { api.getFlagConfigs(eq(deploymentKey)) } returns listOf(flag) + coEvery { cohortLoader.loadCohorts(eq(cohortIds)) } returns Unit + val loader = DeploymentLoader(api, storage, cohortLoader) + loader.loadDeployment(deploymentKey) + coVerify(exactly = 1) { api.getFlagConfigs(eq(deploymentKey)) } + coVerify(exactly = 1) { storage.removeFlag(eq(deploymentKey), eq(existingFlagKey)) } + coVerify(exactly = 1) { cohortLoader.loadCohorts(eq(cohortIds)) } + coVerify(exactly = 1) { storage.putFlag(allAny(), allAny()) } + assertEquals(flag, storage.getFlag(deploymentKey, flagKey)) + } @Test - fun `getFlagConfigs fails, throws`(): Unit = runBlocking { - val api = mockk() - val storage = spyk(InMemoryDeploymentStorage()) - val cohortLoader = mockk() - val cohortIds = setOf("a", "b") - coEvery { api.getFlagConfigs(eq(deploymentKey)) } throws RuntimeException("fail") - coEvery { cohortLoader.loadCohorts(eq(cohortIds)) } returns Unit - val loader = DeploymentLoader(api, storage, cohortLoader) - try { - loader.loadDeployment(deploymentKey) - fail("Expected loadDeployment to throw exception") - } catch (e: RuntimeException) { - // Success + fun `getFlagConfigs fails, throws`(): Unit = + runBlocking { + val api = mockk() + val storage = spyk(InMemoryDeploymentStorage()) + val cohortLoader = mockk() + val cohortIds = setOf("a", "b") + coEvery { api.getFlagConfigs(eq(deploymentKey)) } throws RuntimeException("fail") + coEvery { cohortLoader.loadCohorts(eq(cohortIds)) } returns Unit + val loader = DeploymentLoader(api, storage, cohortLoader) + try { + loader.loadDeployment(deploymentKey) + fail("Expected loadDeployment to throw exception") + } catch (e: RuntimeException) { + // Success + } + coVerify(exactly = 1) { api.getFlagConfigs(eq(deploymentKey)) } + coVerify(exactly = 0) { storage.removeFlag(allAny(), allAny()) } + coVerify(exactly = 0) { cohortLoader.loadCohorts(eq(cohortIds)) } + coVerify(exactly = 0) { storage.putFlag(allAny(), allAny()) } + assertNull(storage.getFlag(deploymentKey, flagKey)) } - coVerify(exactly = 1) { api.getFlagConfigs(eq(deploymentKey)) } - coVerify(exactly = 0) { storage.removeFlag(allAny(), allAny()) } - coVerify(exactly = 0) { cohortLoader.loadCohorts(eq(cohortIds)) } - coVerify(exactly = 0) { storage.putFlag(allAny(), allAny()) } - assertNull(storage.getFlag(deploymentKey, flagKey)) - } @Test - fun `removeFlag fails, throws`(): Unit = runBlocking { - val existingFlagKey = "existing" - val existingFlag = flag(existingFlagKey) - val api = mockk() - val storage = spyk(InMemoryDeploymentStorage().apply { - putFlag(deploymentKey, existingFlag) - }) - val cohortLoader = mockk() - val cohortIds = setOf("a", "b") - val flag = flag(flagKey = flagKey, cohortIds = cohortIds) - coEvery { api.getFlagConfigs(eq(deploymentKey)) } returns listOf(flag) - coEvery { storage.removeFlag(eq(deploymentKey), eq(existingFlagKey)) } throws RuntimeException("fail") - coEvery { cohortLoader.loadCohorts(eq(cohortIds)) } returns Unit - val loader = DeploymentLoader(api, storage, cohortLoader) - try { - loader.loadDeployment(deploymentKey) - fail("Expected loadDeployment to throw exception") - } catch (e: RuntimeException) { - // Success + fun `removeFlag fails, throws`(): Unit = + runBlocking { + val existingFlagKey = "existing" + val existingFlag = flag(existingFlagKey) + val api = mockk() + val storage = + spyk( + InMemoryDeploymentStorage().apply { + putFlag(deploymentKey, existingFlag) + }, + ) + val cohortLoader = mockk() + val cohortIds = setOf("a", "b") + val flag = flag(flagKey = flagKey, cohortIds = cohortIds) + coEvery { api.getFlagConfigs(eq(deploymentKey)) } returns listOf(flag) + coEvery { storage.removeFlag(eq(deploymentKey), eq(existingFlagKey)) } throws RuntimeException("fail") + coEvery { cohortLoader.loadCohorts(eq(cohortIds)) } returns Unit + val loader = DeploymentLoader(api, storage, cohortLoader) + try { + loader.loadDeployment(deploymentKey) + fail("Expected loadDeployment to throw exception") + } catch (e: RuntimeException) { + // Success + } + coVerify(exactly = 1) { api.getFlagConfigs(eq(deploymentKey)) } + coVerify(exactly = 1) { storage.removeFlag(eq(deploymentKey), eq(existingFlagKey)) } + coVerify(exactly = 0) { cohortLoader.loadCohorts(eq(cohortIds)) } + coVerify(exactly = 0) { storage.putFlag(allAny(), allAny()) } + assertNull(storage.getFlag(deploymentKey, flagKey)) } - coVerify(exactly = 1) { api.getFlagConfigs(eq(deploymentKey)) } - coVerify(exactly = 1) { storage.removeFlag(eq(deploymentKey), eq(existingFlagKey)) } - coVerify(exactly = 0) { cohortLoader.loadCohorts(eq(cohortIds)) } - coVerify(exactly = 0) { storage.putFlag(allAny(), allAny()) } - assertNull(storage.getFlag(deploymentKey, flagKey)) - } @Test - fun `loadCohorts fails, throws`(): Unit = runBlocking { - val api = mockk() - val storage = spyk(InMemoryDeploymentStorage()) - val cohortLoader = mockk() - val cohortIds = setOf("a", "b") - val flag = flag(flagKey = flagKey, cohortIds = cohortIds) - coEvery { api.getFlagConfigs(eq(deploymentKey)) } returns listOf(flag) - coEvery { cohortLoader.loadCohorts(eq(cohortIds)) } throws RuntimeException("fail") - val loader = DeploymentLoader(api, storage, cohortLoader) - try { - loader.loadDeployment(deploymentKey) - fail("Expected loadDeployment to throw exception") - } catch (e: RuntimeException) { - // Success + fun `loadCohorts fails, throws`(): Unit = + runBlocking { + val api = mockk() + val storage = spyk(InMemoryDeploymentStorage()) + val cohortLoader = mockk() + val cohortIds = setOf("a", "b") + val flag = flag(flagKey = flagKey, cohortIds = cohortIds) + coEvery { api.getFlagConfigs(eq(deploymentKey)) } returns listOf(flag) + coEvery { cohortLoader.loadCohorts(eq(cohortIds)) } throws RuntimeException("fail") + val loader = DeploymentLoader(api, storage, cohortLoader) + try { + loader.loadDeployment(deploymentKey) + fail("Expected loadDeployment to throw exception") + } catch (e: RuntimeException) { + // Success + } + coVerify(exactly = 1) { api.getFlagConfigs(eq(deploymentKey)) } + coVerify(exactly = 0) { storage.removeFlag(allAny(), allAny()) } + coVerify(exactly = 1) { cohortLoader.loadCohorts(eq(cohortIds)) } + coVerify(exactly = 0) { storage.putFlag(allAny(), allAny()) } + assertNull(storage.getFlag(deploymentKey, flagKey)) } - coVerify(exactly = 1) { api.getFlagConfigs(eq(deploymentKey)) } - coVerify(exactly = 0) { storage.removeFlag(allAny(), allAny()) } - coVerify(exactly = 1) { cohortLoader.loadCohorts(eq(cohortIds)) } - coVerify(exactly = 0) { storage.putFlag(allAny(), allAny()) } - assertNull(storage.getFlag(deploymentKey, flagKey)) - } @Test - fun `putFlag fails, throws`(): Unit = runBlocking { - val api = mockk() - val storage = spyk(InMemoryDeploymentStorage()) - val cohortLoader = mockk() - val cohortIds = setOf("a", "b") - val flag = flag(flagKey = flagKey, cohortIds = cohortIds) - coEvery { api.getFlagConfigs(eq(deploymentKey)) } returns listOf(flag) - coEvery { cohortLoader.loadCohorts(eq(cohortIds)) } returns Unit - coEvery { storage.putFlag(eq(deploymentKey), eq(flag)) } throws RuntimeException("fail") - val loader = DeploymentLoader(api, storage, cohortLoader) - try { - loader.loadDeployment(deploymentKey) - fail("Expected loadDeployment to throw exception") - } catch (e: RuntimeException) { - // Success + fun `putFlag fails, throws`(): Unit = + runBlocking { + val api = mockk() + val storage = spyk(InMemoryDeploymentStorage()) + val cohortLoader = mockk() + val cohortIds = setOf("a", "b") + val flag = flag(flagKey = flagKey, cohortIds = cohortIds) + coEvery { api.getFlagConfigs(eq(deploymentKey)) } returns listOf(flag) + coEvery { cohortLoader.loadCohorts(eq(cohortIds)) } returns Unit + coEvery { storage.putFlag(eq(deploymentKey), eq(flag)) } throws RuntimeException("fail") + val loader = DeploymentLoader(api, storage, cohortLoader) + try { + loader.loadDeployment(deploymentKey) + fail("Expected loadDeployment to throw exception") + } catch (e: RuntimeException) { + // Success + } + coVerify(exactly = 1) { api.getFlagConfigs(eq(deploymentKey)) } + coVerify(exactly = 0) { storage.removeFlag(allAny(), allAny()) } + coVerify(exactly = 1) { cohortLoader.loadCohorts(eq(cohortIds)) } + coVerify(exactly = 1) { storage.putFlag(allAny(), allAny()) } + assertNull(storage.getFlag(deploymentKey, flagKey)) } - coVerify(exactly = 1) { api.getFlagConfigs(eq(deploymentKey)) } - coVerify(exactly = 0) { storage.removeFlag(allAny(), allAny()) } - coVerify(exactly = 1) { cohortLoader.loadCohorts(eq(cohortIds)) } - coVerify(exactly = 1) { storage.putFlag(allAny(), allAny()) } - assertNull(storage.getFlag(deploymentKey, flagKey)) - } } diff --git a/core/src/test/kotlin/deployment/DeploymentRunnerTest.kt b/core/src/test/kotlin/deployment/DeploymentRunnerTest.kt index b6b813e..5013d62 100644 --- a/core/src/test/kotlin/deployment/DeploymentRunnerTest.kt +++ b/core/src/test/kotlin/deployment/DeploymentRunnerTest.kt @@ -5,128 +5,139 @@ import com.amplitude.cohort.CohortLoader import com.amplitude.deployment.DeploymentLoader import com.amplitude.deployment.DeploymentRunner import com.amplitude.deployment.DeploymentStorage -import test.flag import io.mockk.coEvery import io.mockk.coVerify import io.mockk.mockk import kotlinx.coroutines.delay import kotlinx.coroutines.runBlocking import org.junit.Test +import test.flag class DeploymentRunnerTest { - private val deploymentKey = "deployment" @Test - fun `start, stop, load deployment called once, success`(): Unit = runBlocking { - val flag = flag(cohortIds = setOf("a")) - val configuration = Configuration( - flagSyncIntervalMillis = 50, - cohortSyncIntervalMillis = 50, - ) - val cohortLoader = mockk() - val deploymentStorage = mockk() - val deploymentLoader = mockk() - coEvery { deploymentLoader.loadDeployment(eq(deploymentKey)) } returns Unit - coEvery { deploymentStorage.getAllFlags(eq(deploymentKey)) } returns mapOf(flag.key to flag) - coEvery { cohortLoader.loadCohorts(allAny()) } returns Unit - val deploymentRunner = DeploymentRunner( - configuration, - deploymentKey, - cohortLoader, - deploymentStorage, - deploymentLoader - ) - deploymentRunner.start() - deploymentRunner.stop() - delay(100) - coVerify(exactly = 1) { deploymentLoader.loadDeployment(allAny()) } - coVerify(exactly = 0) { deploymentStorage.getAllFlags(allAny()) } - coVerify(exactly = 0) { cohortLoader.loadCohorts(allAny()) } - } + fun `start, stop, load deployment called once, success`(): Unit = + runBlocking { + val flag = flag(cohortIds = setOf("a")) + val configuration = + Configuration( + flagSyncIntervalMillis = 50, + cohortSyncIntervalMillis = 50, + ) + val cohortLoader = mockk() + val deploymentStorage = mockk() + val deploymentLoader = mockk() + coEvery { deploymentLoader.loadDeployment(eq(deploymentKey)) } returns Unit + coEvery { deploymentStorage.getAllFlags(eq(deploymentKey)) } returns mapOf(flag.key to flag) + coEvery { cohortLoader.loadCohorts(allAny()) } returns Unit + val deploymentRunner = + DeploymentRunner( + configuration, + deploymentKey, + cohortLoader, + deploymentStorage, + deploymentLoader, + ) + deploymentRunner.start() + deploymentRunner.stop() + delay(100) + coVerify(exactly = 1) { deploymentLoader.loadDeployment(allAny()) } + coVerify(exactly = 0) { deploymentStorage.getAllFlags(allAny()) } + coVerify(exactly = 0) { cohortLoader.loadCohorts(allAny()) } + } @Test - fun `start, delay, periodic loaders run, success`(): Unit = runBlocking { - val cohortIds = setOf("a") - val flag = flag(cohortIds = cohortIds) - val configuration = Configuration( - flagSyncIntervalMillis = 50, - cohortSyncIntervalMillis = 50, - ) - val cohortLoader = mockk() - val deploymentStorage = mockk() - val deploymentLoader = mockk() - coEvery { deploymentLoader.loadDeployment(eq(deploymentKey)) } returns Unit - coEvery { deploymentStorage.getAllFlags(eq(deploymentKey)) } returns mapOf(flag.key to flag) - coEvery { cohortLoader.loadCohorts(allAny()) } returns Unit - val deploymentRunner = DeploymentRunner( - configuration, - deploymentKey, - cohortLoader, - deploymentStorage, - deploymentLoader - ) - deploymentRunner.start() - delay(75) - deploymentRunner.stop() - coVerify(exactly = 2) { deploymentLoader.loadDeployment(eq(deploymentKey)) } - coVerify(exactly = 1) { deploymentStorage.getAllFlags(eq(deploymentKey)) } - coVerify(exactly = 1) { cohortLoader.loadCohorts(eq(cohortIds)) } - } + fun `start, delay, periodic loaders run, success`(): Unit = + runBlocking { + val cohortIds = setOf("a") + val flag = flag(cohortIds = cohortIds) + val configuration = + Configuration( + flagSyncIntervalMillis = 50, + cohortSyncIntervalMillis = 50, + ) + val cohortLoader = mockk() + val deploymentStorage = mockk() + val deploymentLoader = mockk() + coEvery { deploymentLoader.loadDeployment(eq(deploymentKey)) } returns Unit + coEvery { deploymentStorage.getAllFlags(eq(deploymentKey)) } returns mapOf(flag.key to flag) + coEvery { cohortLoader.loadCohorts(allAny()) } returns Unit + val deploymentRunner = + DeploymentRunner( + configuration, + deploymentKey, + cohortLoader, + deploymentStorage, + deploymentLoader, + ) + deploymentRunner.start() + delay(75) + deploymentRunner.stop() + coVerify(exactly = 2) { deploymentLoader.loadDeployment(eq(deploymentKey)) } + coVerify(exactly = 1) { deploymentStorage.getAllFlags(eq(deploymentKey)) } + coVerify(exactly = 1) { cohortLoader.loadCohorts(eq(cohortIds)) } + } @Test - fun `start, load deployment throws, does not throw, pollers run`(): Unit = runBlocking { - val cohortIds = setOf("a") - val flag = flag(cohortIds = cohortIds) - val configuration = Configuration( - flagSyncIntervalMillis = 50, - cohortSyncIntervalMillis = 50, - ) - val cohortLoader = mockk() - val deploymentStorage = mockk() - val deploymentLoader = mockk() - coEvery { deploymentLoader.loadDeployment(eq(deploymentKey)) } throws RuntimeException("fail") - coEvery { deploymentStorage.getAllFlags(eq(deploymentKey)) } returns mapOf(flag.key to flag) - coEvery { cohortLoader.loadCohorts(allAny()) } returns Unit - val deploymentRunner = DeploymentRunner( - configuration, - deploymentKey, - cohortLoader, - deploymentStorage, - deploymentLoader - ) - deploymentRunner.start() - delay(75) - deploymentRunner.stop() - coVerify(exactly = 2) { deploymentLoader.loadDeployment(eq(deploymentKey)) } - coVerify(exactly = 1) { deploymentStorage.getAllFlags(eq(deploymentKey)) } - coVerify(exactly = 1) { cohortLoader.loadCohorts(eq(cohortIds)) } - } + fun `start, load deployment throws, does not throw, pollers run`(): Unit = + runBlocking { + val cohortIds = setOf("a") + val flag = flag(cohortIds = cohortIds) + val configuration = + Configuration( + flagSyncIntervalMillis = 50, + cohortSyncIntervalMillis = 50, + ) + val cohortLoader = mockk() + val deploymentStorage = mockk() + val deploymentLoader = mockk() + coEvery { deploymentLoader.loadDeployment(eq(deploymentKey)) } throws RuntimeException("fail") + coEvery { deploymentStorage.getAllFlags(eq(deploymentKey)) } returns mapOf(flag.key to flag) + coEvery { cohortLoader.loadCohorts(allAny()) } returns Unit + val deploymentRunner = + DeploymentRunner( + configuration, + deploymentKey, + cohortLoader, + deploymentStorage, + deploymentLoader, + ) + deploymentRunner.start() + delay(75) + deploymentRunner.stop() + coVerify(exactly = 2) { deploymentLoader.loadDeployment(eq(deploymentKey)) } + coVerify(exactly = 1) { deploymentStorage.getAllFlags(eq(deploymentKey)) } + coVerify(exactly = 1) { cohortLoader.loadCohorts(eq(cohortIds)) } + } @Test - fun `start get all flags throws, does not throw`(): Unit = runBlocking { - val configuration = Configuration( - flagSyncIntervalMillis = 10, - cohortSyncIntervalMillis = 10, - ) - val cohortLoader = mockk() - val deploymentStorage = mockk() - val deploymentLoader = mockk() - coEvery { deploymentLoader.loadDeployment(eq(deploymentKey)) } returns Unit - coEvery { deploymentStorage.getAllFlags(eq(deploymentKey)) } throws RuntimeException("fail") - coEvery { cohortLoader.loadCohorts(allAny()) } returns Unit - val deploymentRunner = DeploymentRunner( - configuration, - deploymentKey, - cohortLoader, - deploymentStorage, - deploymentLoader - ) - deploymentRunner.start() - delay(100) - deploymentRunner.stop() - coVerify(atLeast = 3) { deploymentLoader.loadDeployment(eq(deploymentKey)) } - coVerify(atLeast = 2) { deploymentStorage.getAllFlags(eq(deploymentKey)) } - coVerify(exactly = 0) { cohortLoader.loadCohorts(allAny()) } - } + fun `start get all flags throws, does not throw`(): Unit = + runBlocking { + val configuration = + Configuration( + flagSyncIntervalMillis = 10, + cohortSyncIntervalMillis = 10, + ) + val cohortLoader = mockk() + val deploymentStorage = mockk() + val deploymentLoader = mockk() + coEvery { deploymentLoader.loadDeployment(eq(deploymentKey)) } returns Unit + coEvery { deploymentStorage.getAllFlags(eq(deploymentKey)) } throws RuntimeException("fail") + coEvery { cohortLoader.loadCohorts(allAny()) } returns Unit + val deploymentRunner = + DeploymentRunner( + configuration, + deploymentKey, + cohortLoader, + deploymentStorage, + deploymentLoader, + ) + deploymentRunner.start() + delay(100) + deploymentRunner.stop() + coVerify(atLeast = 3) { deploymentLoader.loadDeployment(eq(deploymentKey)) } + coVerify(atLeast = 2) { deploymentStorage.getAllFlags(eq(deploymentKey)) } + coVerify(exactly = 0) { cohortLoader.loadCohorts(allAny()) } + } } diff --git a/core/src/test/kotlin/deployment/DeploymentStorageTest.kt b/core/src/test/kotlin/deployment/DeploymentStorageTest.kt index acf9356..96b9892 100644 --- a/core/src/test/kotlin/deployment/DeploymentStorageTest.kt +++ b/core/src/test/kotlin/deployment/DeploymentStorageTest.kt @@ -1,113 +1,115 @@ package deployment -import test.InMemoryRedis import com.amplitude.deployment.DeploymentStorage import com.amplitude.deployment.InMemoryDeploymentStorage import com.amplitude.deployment.RedisDeploymentStorage -import test.deployment -import test.flag import kotlinx.coroutines.runBlocking import org.junit.Assert.assertEquals import org.junit.Test +import test.InMemoryRedis +import test.deployment +import test.flag import kotlin.test.assertNull class DeploymentStorageTest { - private val redis = InMemoryRedis() @Test - fun `test in memory`(): Unit = runBlocking { - test(InMemoryDeploymentStorage()) - } + fun `test in memory`(): Unit = + runBlocking { + test(InMemoryDeploymentStorage()) + } @Test - fun `test redis`(): Unit = runBlocking { - test(RedisDeploymentStorage("amplitude", "12345", redis, redis)) - } + fun `test redis`(): Unit = + runBlocking { + test(RedisDeploymentStorage("amplitude", "12345", redis, redis)) + } - private fun test(storage: DeploymentStorage): Unit = runBlocking { - val deploymentA = deployment("a", "1") - val deploymentB = deployment("b", "2") + private fun test(storage: DeploymentStorage): Unit = + runBlocking { + val deploymentA = deployment("a", "1") + val deploymentB = deployment("b", "2") - // get deployment, null - var deployment = storage.getDeployment(deploymentA.key) - assertNull(deployment) - // get deployments, empty - var deployments = storage.getDeployments() - assertEquals(0, deployments.size) - // put, get deployment, deployment - storage.putDeployment(deploymentA) - deployment = storage.getDeployment(deploymentA.key) - assertEquals(deploymentA, deployment) - // put, get deployments, deployments - storage.putDeployment(deploymentB) - deployments = storage.getDeployments() - assertEquals( - mapOf(deploymentA.key to deploymentA, deploymentB.key to deploymentB), - deployments - ) + // get deployment, null + var deployment = storage.getDeployment(deploymentA.key) + assertNull(deployment) + // get deployments, empty + var deployments = storage.getDeployments() + assertEquals(0, deployments.size) + // put, get deployment, deployment + storage.putDeployment(deploymentA) + deployment = storage.getDeployment(deploymentA.key) + assertEquals(deploymentA, deployment) + // put, get deployments, deployments + storage.putDeployment(deploymentB) + deployments = storage.getDeployments() + assertEquals( + mapOf(deploymentA.key to deploymentA, deploymentB.key to deploymentB), + deployments, + ) - val flag1 = flag("1") - val flag2 = flag("2") + val flag1 = flag("1") + val flag2 = flag("2") - // get flag, null - var flag = storage.getFlag(deploymentA.key, flag1.key) - assertNull(flag) - // get all flags, empty - var flags = storage.getAllFlags(deploymentA.key) - assertEquals(0, flags.size) - // put, get flag, flag - storage.putFlag(deploymentA.key, flag1) - flag = storage.getFlag(deploymentA.key, flag1.key) - assertEquals(flag1, flag) - // put, get all flags, flags - storage.putFlag(deploymentA.key, flag2) - flags = storage.getAllFlags(deploymentA.key) - assertEquals(mapOf(flag1.key to flag1, flag2.key to flag2), flags) - // remove, get removed, null - storage.removeFlag(deploymentA.key, flag1.key) - flag = storage.getFlag(deploymentA.key, flag1.key) - assertNull(flag) - // get other, other - flag = storage.getFlag(deploymentA.key, flag2.key) - assertEquals(flag2, flag) - // get all flags, other - flags = storage.getAllFlags(deploymentA.key) - assertEquals(mapOf(flag2.key to flag2), flags) - // remove all flags, get, null - storage.removeAllFlags(deploymentA.key) - // get all flags, empty - flags = storage.getAllFlags(deploymentA.key) - assertEquals(0, flags.size) + // get flag, null + var flag = storage.getFlag(deploymentA.key, flag1.key) + assertNull(flag) + // get all flags, empty + var flags = storage.getAllFlags(deploymentA.key) + assertEquals(0, flags.size) + // put, get flag, flag + storage.putFlag(deploymentA.key, flag1) + flag = storage.getFlag(deploymentA.key, flag1.key) + assertEquals(flag1, flag) + // put, get all flags, flags + storage.putFlag(deploymentA.key, flag2) + flags = storage.getAllFlags(deploymentA.key) + assertEquals(mapOf(flag1.key to flag1, flag2.key to flag2), flags) + // remove, get removed, null + storage.removeFlag(deploymentA.key, flag1.key) + flag = storage.getFlag(deploymentA.key, flag1.key) + assertNull(flag) + // get other, other + flag = storage.getFlag(deploymentA.key, flag2.key) + assertEquals(flag2, flag) + // get all flags, other + flags = storage.getAllFlags(deploymentA.key) + assertEquals(mapOf(flag2.key to flag2), flags) + // remove all flags, get, null + storage.removeAllFlags(deploymentA.key) + // get all flags, empty + flags = storage.getAllFlags(deploymentA.key) + assertEquals(0, flags.size) - // put flags - storage.putFlag(deploymentA.key, flag1) - storage.putFlag(deploymentA.key, flag2) - flags = storage.getAllFlags(deploymentA.key) - assertEquals(mapOf(flag1.key to flag1, flag2.key to flag2), flags) + // put flags + storage.putFlag(deploymentA.key, flag1) + storage.putFlag(deploymentA.key, flag2) + flags = storage.getAllFlags(deploymentA.key) + assertEquals(mapOf(flag1.key to flag1, flag2.key to flag2), flags) - // remove deployment, get removed, null - storage.removeDeployment(deploymentA.key) - // get flag, null - flag = storage.getFlag(deploymentA.key, flag1.key) - assertNull(flag) - flag = storage.getFlag(deploymentA.key, flag2.key) - assertNull(flag) - // get all flags, empty - flags = storage.getAllFlags(deploymentA.key) - assertEquals(0, flags.size) - // get other, deployment - deployment = storage.getDeployment(deploymentB.key) - assertEquals(deploymentB, deployment) - // get deployments, other - deployments = storage.getDeployments() - assertEquals(mapOf(deploymentB.key to deploymentB), deployments) - // get all flags, empty - flags = storage.getAllFlags(deploymentB.key) - assertEquals(0, flags.size) - // remove, get deployments, empty - storage.removeDeployment(deploymentB.key) - deployments = storage.getDeployments() - assertEquals(0, deployments.size) - } + // remove deployment, get removed, null + storage.removeDeployment(deploymentA.key) + // get flag, null + flag = storage.getFlag(deploymentA.key, flag1.key) + assertNull(flag) + flag = storage.getFlag(deploymentA.key, flag2.key) + assertNull(flag) + // get all flags, empty + flags = storage.getAllFlags(deploymentA.key) + assertEquals(0, flags.size) + // get other, deployment + deployment = storage.getDeployment(deploymentB.key) + assertEquals(deploymentB, deployment) + // get deployments, other + deployments = storage.getDeployments() + assertEquals(mapOf(deploymentB.key to deploymentB), deployments) + // get all flags, empty + flags = storage.getAllFlags(deploymentB.key) + assertEquals(0, flags.size) + // remove, get deployments, empty + storage.removeDeployment(deploymentB.key) + deployments = storage.getDeployments() + assertEquals(0, deployments.size) + } } diff --git a/core/src/test/kotlin/project/ProjectApiTest.kt b/core/src/test/kotlin/project/ProjectApiTest.kt index d7fdffe..a9653df 100644 --- a/core/src/test/kotlin/project/ProjectApiTest.kt +++ b/core/src/test/kotlin/project/ProjectApiTest.kt @@ -3,7 +3,6 @@ package project import com.amplitude.project.DeploymentsResponse import com.amplitude.project.ProjectApiV1 import com.amplitude.util.json -import test.deployment import io.ktor.client.engine.mock.MockEngine import io.ktor.client.engine.mock.respond import io.ktor.http.HttpMethod @@ -12,56 +11,61 @@ import io.ktor.utils.io.ByteReadChannel import kotlinx.coroutines.runBlocking import kotlinx.serialization.encodeToString import org.junit.Test +import test.deployment import test.toSerialDeployment import kotlin.test.assertEquals class ProjectApiTest { - private val managementKey = "managementKey" private val serverUrl = "https://experiment.amplitude.com/" @Test - fun `get deployments, success`(): Unit = runBlocking { - val deploymentA = deployment("a") - val deploymentB = deployment("b") - val expected = listOf(deploymentA, deploymentB) - val serialDeployments = expected.map { it.toSerialDeployment() } - val mockEngine = MockEngine { request -> - respond( - content = ByteReadChannel(json.encodeToString(DeploymentsResponse(serialDeployments))), - status = HttpStatusCode.OK - ) + fun `get deployments, success`(): Unit = + runBlocking { + val deploymentA = deployment("a") + val deploymentB = deployment("b") + val expected = listOf(deploymentA, deploymentB) + val serialDeployments = expected.map { it.toSerialDeployment() } + val mockEngine = + MockEngine { request -> + respond( + content = ByteReadChannel(json.encodeToString(DeploymentsResponse(serialDeployments))), + status = HttpStatusCode.OK, + ) + } + val api = ProjectApiV1(serverUrl, managementKey, mockEngine) + val actual = api.getDeployments() + assertEquals(expected, actual) + val request = mockEngine.requestHistory[0] + assertEquals(HttpMethod.Get, request.method) + assertEquals("/api/1/deployments", request.url.encodedPath) + assertEquals("Bearer $managementKey", request.headers["Authorization"]) } - val api = ProjectApiV1(serverUrl, managementKey, mockEngine) - val actual = api.getDeployments() - assertEquals(expected, actual) - val request = mockEngine.requestHistory[0] - assertEquals(HttpMethod.Get, request.method) - assertEquals("/api/1/deployments", request.url.encodedPath) - assertEquals("Bearer $managementKey", request.headers["Authorization"]) - } @Test - fun `get deployments, one deleted, returns only active deployment`(): Unit = runBlocking { - val deploymentA = deployment("a") - val deploymentB = deployment("b") - val expected = listOf(deploymentB) - val serialDeployments = listOf( - deploymentA.toSerialDeployment(true), - deploymentB.toSerialDeployment() - ) - val mockEngine = MockEngine { request -> - respond( - content = ByteReadChannel(json.encodeToString(DeploymentsResponse(serialDeployments))), - status = HttpStatusCode.OK - ) + fun `get deployments, one deleted, returns only active deployment`(): Unit = + runBlocking { + val deploymentA = deployment("a") + val deploymentB = deployment("b") + val expected = listOf(deploymentB) + val serialDeployments = + listOf( + deploymentA.toSerialDeployment(true), + deploymentB.toSerialDeployment(), + ) + val mockEngine = + MockEngine { request -> + respond( + content = ByteReadChannel(json.encodeToString(DeploymentsResponse(serialDeployments))), + status = HttpStatusCode.OK, + ) + } + val api = ProjectApiV1(serverUrl, managementKey, mockEngine) + val actual = api.getDeployments() + assertEquals(expected, actual) + val request = mockEngine.requestHistory[0] + assertEquals(HttpMethod.Get, request.method) + assertEquals("/api/1/deployments", request.url.encodedPath) + assertEquals("Bearer $managementKey", request.headers["Authorization"]) } - val api = ProjectApiV1(serverUrl, managementKey, mockEngine) - val actual = api.getDeployments() - assertEquals(expected, actual) - val request = mockEngine.requestHistory[0] - assertEquals(HttpMethod.Get, request.method) - assertEquals("/api/1/deployments", request.url.encodedPath) - assertEquals("Bearer $managementKey", request.headers["Authorization"]) - } } diff --git a/core/src/test/kotlin/project/ProjectProxyTest.kt b/core/src/test/kotlin/project/ProjectProxyTest.kt index 6b81fd8..6a5096c 100644 --- a/core/src/test/kotlin/project/ProjectProxyTest.kt +++ b/core/src/test/kotlin/project/ProjectProxyTest.kt @@ -1,6 +1,5 @@ package project -import test.cohort import com.amplitude.Configuration import com.amplitude.assignment.AssignmentTracker import com.amplitude.cohort.GetCohortResponse @@ -8,311 +7,351 @@ import com.amplitude.cohort.InMemoryCohortStorage import com.amplitude.deployment.InMemoryDeploymentStorage import com.amplitude.project.ProjectProxy import com.amplitude.util.json -import test.deployment -import test.flag import io.ktor.http.HttpStatusCode import io.mockk.coEvery import io.mockk.mockk import kotlinx.coroutines.runBlocking import kotlinx.serialization.encodeToString +import test.cohort +import test.deployment +import test.flag import test.project import kotlin.test.Test import kotlin.test.assertEquals class ProjectProxyTest { - private val project = project() private val configuration = Configuration() @Test - fun `test get flag configs, null deployment, unauthorized`(): Unit = runBlocking { - val assignmentTracker = mockk() - val deploymentStorage = InMemoryDeploymentStorage() - val cohortStorage = InMemoryCohortStorage() - val projectProxy = ProjectProxy( - project, - configuration, - assignmentTracker, - deploymentStorage, - cohortStorage - ) - val result = projectProxy.getFlagConfigs(null) - assertEquals(HttpStatusCode.Unauthorized, result.status) - } + fun `test get flag configs, null deployment, unauthorized`(): Unit = + runBlocking { + val assignmentTracker = mockk() + val deploymentStorage = InMemoryDeploymentStorage() + val cohortStorage = InMemoryCohortStorage() + val projectProxy = + ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage, + ) + val result = projectProxy.getFlagConfigs(null) + assertEquals(HttpStatusCode.Unauthorized, result.status) + } @Test - fun `test get flag configs, with deployment, success`(): Unit = runBlocking { - val deployment = deployment("deployment") - val flag = flag("flag") - val assignmentTracker = mockk() - val deploymentStorage = InMemoryDeploymentStorage().apply { - putFlag(deployment.key, flag) + fun `test get flag configs, with deployment, success`(): Unit = + runBlocking { + val deployment = deployment("deployment") + val flag = flag("flag") + val assignmentTracker = mockk() + val deploymentStorage = + InMemoryDeploymentStorage().apply { + putFlag(deployment.key, flag) + } + val cohortStorage = InMemoryCohortStorage() + val projectProxy = + ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage, + ) + val result = projectProxy.getFlagConfigs(deployment.key) + assertEquals(HttpStatusCode.OK, result.status) + assertEquals(json.encodeToString(listOf(flag)), result.body) } - val cohortStorage = InMemoryCohortStorage() - val projectProxy = ProjectProxy( - project, - configuration, - assignmentTracker, - deploymentStorage, - cohortStorage - ) - val result = projectProxy.getFlagConfigs(deployment.key) - assertEquals(HttpStatusCode.OK, result.status) - assertEquals(json.encodeToString(listOf(flag)), result.body) - } @Test - fun `test get cohort, null cohort id, not found`(): Unit = runBlocking { - val assignmentTracker = mockk() - val deploymentStorage = InMemoryDeploymentStorage() - val cohortStorage = InMemoryCohortStorage() - val projectProxy = ProjectProxy( - project, - configuration, - assignmentTracker, - deploymentStorage, - cohortStorage - ) - val result = projectProxy.getCohort(null, null, null) - assertEquals(HttpStatusCode.NotFound, result.status) - } + fun `test get cohort, null cohort id, not found`(): Unit = + runBlocking { + val assignmentTracker = mockk() + val deploymentStorage = InMemoryDeploymentStorage() + val cohortStorage = InMemoryCohortStorage() + val projectProxy = + ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage, + ) + val result = projectProxy.getCohort(null, null, null) + assertEquals(HttpStatusCode.NotFound, result.status) + } @Test - fun `test get cohort, with cohort id, success`(): Unit = runBlocking { - val cohort = cohort("a") - val assignmentTracker = mockk() - val deploymentStorage = InMemoryDeploymentStorage() - val cohortStorage = InMemoryCohortStorage().apply { - putCohort(cohort) + fun `test get cohort, with cohort id, success`(): Unit = + runBlocking { + val cohort = cohort("a") + val assignmentTracker = mockk() + val deploymentStorage = InMemoryDeploymentStorage() + val cohortStorage = + InMemoryCohortStorage().apply { + putCohort(cohort) + } + val projectProxy = + ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage, + ) + val result = projectProxy.getCohort("a", null, null) + assertEquals(HttpStatusCode.OK, result.status) + assertEquals(json.encodeToString(GetCohortResponse.fromCohort(cohort)), result.body) } - val projectProxy = ProjectProxy( - project, - configuration, - assignmentTracker, - deploymentStorage, - cohortStorage - ) - val result = projectProxy.getCohort("a", null, null) - assertEquals(HttpStatusCode.OK, result.status) - assertEquals(json.encodeToString(GetCohortResponse.fromCohort(cohort)), result.body) - } @Test - fun `test get cohort, with cohort id and last modified, success`(): Unit = runBlocking { - val cohort = cohort("a", 100) - val assignmentTracker = mockk() - val deploymentStorage = InMemoryDeploymentStorage() - val cohortStorage = InMemoryCohortStorage().apply { - putCohort(cohort) + fun `test get cohort, with cohort id and last modified, success`(): Unit = + runBlocking { + val cohort = cohort("a", 100) + val assignmentTracker = mockk() + val deploymentStorage = InMemoryDeploymentStorage() + val cohortStorage = + InMemoryCohortStorage().apply { + putCohort(cohort) + } + val projectProxy = + ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage, + ) + val result = projectProxy.getCohort("a", 1, null) + assertEquals(HttpStatusCode.OK, result.status) + assertEquals(json.encodeToString(GetCohortResponse.fromCohort(cohort)), result.body) } - val projectProxy = ProjectProxy( - project, - configuration, - assignmentTracker, - deploymentStorage, - cohortStorage - ) - val result = projectProxy.getCohort("a", 1, null) - assertEquals(HttpStatusCode.OK, result.status) - assertEquals(json.encodeToString(GetCohortResponse.fromCohort(cohort)), result.body) - } @Test - fun `test get cohort, with cohort id and last modified, not changed`(): Unit = runBlocking { - val cohort = cohort("a", 100) - val assignmentTracker = mockk() - val deploymentStorage = InMemoryDeploymentStorage() - val cohortStorage = InMemoryCohortStorage().apply { - putCohort(cohort) + fun `test get cohort, with cohort id and last modified, not changed`(): Unit = + runBlocking { + val cohort = cohort("a", 100) + val assignmentTracker = mockk() + val deploymentStorage = InMemoryDeploymentStorage() + val cohortStorage = + InMemoryCohortStorage().apply { + putCohort(cohort) + } + val projectProxy = + ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage, + ) + val result = projectProxy.getCohort("a", 100, null) + assertEquals(HttpStatusCode.NoContent, result.status) } - val projectProxy = ProjectProxy( - project, - configuration, - assignmentTracker, - deploymentStorage, - cohortStorage - ) - val result = projectProxy.getCohort("a", 100, null) - assertEquals(HttpStatusCode.NoContent, result.status) - } @Test - fun `test get cohort, with cohort id and max cohort size, success`(): Unit = runBlocking { - val deployment = deployment("deployment") - val flag = flag("flag", setOf("a")) - val cohort = cohort("a", 100, 100) - val assignmentTracker = mockk() - val deploymentStorage = InMemoryDeploymentStorage().apply { - putFlag(deployment.key, flag) + fun `test get cohort, with cohort id and max cohort size, success`(): Unit = + runBlocking { + val deployment = deployment("deployment") + val flag = flag("flag", setOf("a")) + val cohort = cohort("a", 100, 100) + val assignmentTracker = mockk() + val deploymentStorage = + InMemoryDeploymentStorage().apply { + putFlag(deployment.key, flag) + } + val cohortStorage = + InMemoryCohortStorage().apply { + putCohort(cohort) + } + val projectProxy = + ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage, + ) + val result = projectProxy.getCohort("a", null, Int.MAX_VALUE) + assertEquals(HttpStatusCode.OK, result.status) + assertEquals(json.encodeToString(GetCohortResponse.fromCohort(cohort)), result.body) } - val cohortStorage = InMemoryCohortStorage().apply { - putCohort(cohort) - } - val projectProxy = ProjectProxy( - project, - configuration, - assignmentTracker, - deploymentStorage, - cohortStorage - ) - val result = projectProxy.getCohort("a", null, Int.MAX_VALUE) - assertEquals(HttpStatusCode.OK, result.status) - assertEquals(json.encodeToString(GetCohortResponse.fromCohort(cohort)), result.body) - } - @Test - fun `test get cohort, with cohort id and max cohort size, too large`(): Unit = runBlocking { - val cohort = cohort("a", 100, 100) - val assignmentTracker = mockk() - val deploymentStorage = InMemoryDeploymentStorage() - val cohortStorage = InMemoryCohortStorage().apply { - putCohort(cohort) + fun `test get cohort, with cohort id and max cohort size, too large`(): Unit = + runBlocking { + val cohort = cohort("a", 100, 100) + val assignmentTracker = mockk() + val deploymentStorage = InMemoryDeploymentStorage() + val cohortStorage = + InMemoryCohortStorage().apply { + putCohort(cohort) + } + val projectProxy = + ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage, + ) + val result = projectProxy.getCohort("a", null, 99) + assertEquals(HttpStatusCode.PayloadTooLarge, result.status) } - val projectProxy = ProjectProxy( - project, - configuration, - assignmentTracker, - deploymentStorage, - cohortStorage - ) - val result = projectProxy.getCohort("a", null, 99) - assertEquals(HttpStatusCode.PayloadTooLarge, result.status) - } @Test - fun `test get cohort memberships for group, null deployment, unauthorized`(): Unit = runBlocking { - val assignmentTracker = mockk() - val deploymentStorage = InMemoryDeploymentStorage() - val cohortStorage = InMemoryCohortStorage() - val projectProxy = ProjectProxy( - project, - configuration, - assignmentTracker, - deploymentStorage, - cohortStorage - ) - val result = projectProxy.getCohortMemberships(null, null, null) - assertEquals(HttpStatusCode.Unauthorized, result.status) - } + fun `test get cohort memberships for group, null deployment, unauthorized`(): Unit = + runBlocking { + val assignmentTracker = mockk() + val deploymentStorage = InMemoryDeploymentStorage() + val cohortStorage = InMemoryCohortStorage() + val projectProxy = + ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage, + ) + val result = projectProxy.getCohortMemberships(null, null, null) + assertEquals(HttpStatusCode.Unauthorized, result.status) + } @Test - fun `test get cohort memberships for group, with deployment and group type, null group name, bad request`(): Unit = runBlocking { - val deployment = deployment("deployment") - val assignmentTracker = mockk() - val deploymentStorage = InMemoryDeploymentStorage() - val cohortStorage = InMemoryCohortStorage() - val projectProxy = ProjectProxy( - project, - configuration, - assignmentTracker, - deploymentStorage, - cohortStorage - ) - val result = projectProxy.getCohortMemberships(deployment.key, "User", null) - assertEquals(HttpStatusCode.BadRequest, result.status) - } + fun `test get cohort memberships for group, with deployment and group type, null group name, bad request`(): Unit = + runBlocking { + val deployment = deployment("deployment") + val assignmentTracker = mockk() + val deploymentStorage = InMemoryDeploymentStorage() + val cohortStorage = InMemoryCohortStorage() + val projectProxy = + ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage, + ) + val result = projectProxy.getCohortMemberships(deployment.key, "User", null) + assertEquals(HttpStatusCode.BadRequest, result.status) + } @Test - fun `test get cohort memberships for group, with deployment and group name, null group type, bad request`(): Unit = runBlocking { - val deployment = deployment("deployment") - val assignmentTracker = mockk() - val deploymentStorage = InMemoryDeploymentStorage() - val cohortStorage = InMemoryCohortStorage() - val projectProxy = ProjectProxy( - project, - configuration, - assignmentTracker, - deploymentStorage, - cohortStorage - ) - val result = projectProxy.getCohortMemberships(deployment.key, null, "1") - assertEquals(HttpStatusCode.BadRequest, result.status) - } + fun `test get cohort memberships for group, with deployment and group name, null group type, bad request`(): Unit = + runBlocking { + val deployment = deployment("deployment") + val assignmentTracker = mockk() + val deploymentStorage = InMemoryDeploymentStorage() + val cohortStorage = InMemoryCohortStorage() + val projectProxy = + ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage, + ) + val result = projectProxy.getCohortMemberships(deployment.key, null, "1") + assertEquals(HttpStatusCode.BadRequest, result.status) + } @Test - fun `test get cohort memberships for group, with deployment group name and group type, success`(): Unit = runBlocking { - val deployment = deployment("deployment") - val cohort = cohort("a") - val flag = flag("flag", setOf("a")) - val assignmentTracker = mockk() - val deploymentStorage = InMemoryDeploymentStorage().apply { - putFlag(deployment.key, flag) + fun `test get cohort memberships for group, with deployment group name and group type, success`(): Unit = + runBlocking { + val deployment = deployment("deployment") + val cohort = cohort("a") + val flag = flag("flag", setOf("a")) + val assignmentTracker = mockk() + val deploymentStorage = + InMemoryDeploymentStorage().apply { + putFlag(deployment.key, flag) + } + val cohortStorage = + InMemoryCohortStorage().apply { + putCohort(cohort) + } + val projectProxy = + ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage, + ) + val result = projectProxy.getCohortMemberships(deployment.key, "User", "1") + assertEquals(HttpStatusCode.OK, result.status) + assertEquals(json.encodeToString(listOf(cohort.id)), result.body) } - val cohortStorage = InMemoryCohortStorage().apply { - putCohort(cohort) - } - val projectProxy = ProjectProxy( - project, - configuration, - assignmentTracker, - deploymentStorage, - cohortStorage - ) - val result = projectProxy.getCohortMemberships(deployment.key, "User", "1") - assertEquals(HttpStatusCode.OK, result.status) - assertEquals(json.encodeToString(listOf(cohort.id)), result.body) - } @Test - fun `test evaluate, null deployment, unauthorized`(): Unit = runBlocking { - val assignmentTracker = mockk() - val deploymentStorage = InMemoryDeploymentStorage() - val cohortStorage = InMemoryCohortStorage() - val projectProxy = ProjectProxy( - project, - configuration, - assignmentTracker, - deploymentStorage, - cohortStorage - ) + fun `test evaluate, null deployment, unauthorized`(): Unit = + runBlocking { + val assignmentTracker = mockk() + val deploymentStorage = InMemoryDeploymentStorage() + val cohortStorage = InMemoryCohortStorage() + val projectProxy = + ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage, + ) - val result = projectProxy.evaluate(null, null, null) - assertEquals(HttpStatusCode.Unauthorized, result.status) - } + val result = projectProxy.evaluate(null, null, null) + assertEquals(HttpStatusCode.Unauthorized, result.status) + } @Test - fun `test evaluate, with deployment, null user, success`(): Unit = runBlocking { - val deployment = deployment("deployment") - val assignmentTracker = mockk() - val deploymentStorage = InMemoryDeploymentStorage() - val cohortStorage = InMemoryCohortStorage() - val projectProxy = ProjectProxy( - project, - configuration, - assignmentTracker, - deploymentStorage, - cohortStorage - ) + fun `test evaluate, with deployment, null user, success`(): Unit = + runBlocking { + val deployment = deployment("deployment") + val assignmentTracker = mockk() + val deploymentStorage = InMemoryDeploymentStorage() + val cohortStorage = InMemoryCohortStorage() + val projectProxy = + ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage, + ) - val result = projectProxy.evaluate(deployment.key, null, null) - assertEquals(HttpStatusCode.OK, result.status) - assertEquals("{}", result.body) - } + val result = projectProxy.evaluate(deployment.key, null, null) + assertEquals(HttpStatusCode.OK, result.status) + assertEquals("{}", result.body) + } @Test - fun `test evaluate, with deployment and flag keys, success`(): Unit = runBlocking { - val cohort = cohort("a") - val flag = flag("flag", setOf("a")) - val deployment = deployment("deployment") - val assignmentTracker = mockk() - val deploymentStorage = InMemoryDeploymentStorage().apply { - putFlag(deployment.key, flag) - } - val cohortStorage = InMemoryCohortStorage().apply { - putCohort(cohort) + fun `test evaluate, with deployment and flag keys, success`(): Unit = + runBlocking { + val cohort = cohort("a") + val flag = flag("flag", setOf("a")) + val deployment = deployment("deployment") + val assignmentTracker = mockk() + val deploymentStorage = + InMemoryDeploymentStorage().apply { + putFlag(deployment.key, flag) + } + val cohortStorage = + InMemoryCohortStorage().apply { + putCohort(cohort) + } + val projectProxy = + ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage, + ) + val user = mapOf("user_id" to "1") + coEvery { assignmentTracker.track(allAny()) } returns Unit + val result = projectProxy.evaluate(deployment.key, user, setOf("flag")) + assertEquals(HttpStatusCode.OK, result.status) + assertEquals("""{"flag":{"key":"on","value":"on"}}""", result.body) } - val projectProxy = ProjectProxy( - project, - configuration, - assignmentTracker, - deploymentStorage, - cohortStorage - ) - val user = mapOf("user_id" to "1") - coEvery { assignmentTracker.track(allAny()) } returns Unit - val result = projectProxy.evaluate(deployment.key, user, setOf("flag")) - assertEquals(HttpStatusCode.OK, result.status) - assertEquals("""{"flag":{"key":"on","value":"on"}}""", result.body) - } } diff --git a/core/src/test/kotlin/project/ProjectRunnerTest.kt b/core/src/test/kotlin/project/ProjectRunnerTest.kt index 2a3f011..8bab030 100644 --- a/core/src/test/kotlin/project/ProjectRunnerTest.kt +++ b/core/src/test/kotlin/project/ProjectRunnerTest.kt @@ -1,6 +1,5 @@ package project -import test.cohort import com.amplitude.Configuration import com.amplitude.cohort.CohortLoader import com.amplitude.cohort.CohortStorage @@ -11,8 +10,6 @@ import com.amplitude.deployment.DeploymentStorage import com.amplitude.deployment.InMemoryDeploymentStorage import com.amplitude.project.ProjectApi import com.amplitude.project.ProjectRunner -import test.deployment -import test.flag import io.mockk.coEvery import io.mockk.coVerify import io.mockk.mockk @@ -20,108 +17,121 @@ import io.mockk.spyk import kotlinx.coroutines.delay import kotlinx.coroutines.runBlocking import org.junit.Assert.assertNull -import test.project import org.junit.Test +import test.cohort +import test.deployment +import test.flag +import test.project import kotlin.test.assertNotNull class ProjectRunnerTest { - private val project = project() private val config = Configuration(deploymentSyncIntervalMillis = 50) @Test - fun `test start, no state initial load, success`(): Unit = runBlocking { - val projectApi = mockk() - val deploymentLoader = mockk() - val deploymentStorage = spyk(InMemoryDeploymentStorage()) - val cohortLoader = mockk() - val cohortStorage = spyk(InMemoryCohortStorage()) - val runner = ProjectRunner( - project, - config, - projectApi, - deploymentLoader, - deploymentStorage, - cohortLoader, - cohortStorage - ) - coEvery { projectApi.getDeployments() } returns listOf(deployment("a")) - coEvery { cohortLoader.loadCohorts(allAny()) } returns Unit - coEvery { deploymentLoader.loadDeployment(allAny()) } returns Unit - runner.start() - runner.stop() - coVerify(exactly = 1) { projectApi.getDeployments() } - coVerify(exactly = 1) { deploymentStorage.getDeployments() } - coVerify(exactly = 1) { deploymentStorage.putDeployment(eq(deployment("a"))) } - assertNotNull(runner.deploymentRunners["a"]) - coVerify(exactly = 0) { deploymentStorage.removeAllFlags(allAny()) } - coVerify(exactly = 0) { deploymentStorage.removeDeployment(allAny()) } - coVerify(exactly = 1) { deploymentStorage.getAllFlags(allAny()) } - coVerify(exactly = 1) { cohortStorage.getCohortDescriptions() } - coVerify(exactly = 0) { cohortStorage.deleteCohort(allAny()) } - } + fun `test start, no state initial load, success`(): Unit = + runBlocking { + val projectApi = mockk() + val deploymentLoader = mockk() + val deploymentStorage = spyk(InMemoryDeploymentStorage()) + val cohortLoader = mockk() + val cohortStorage = spyk(InMemoryCohortStorage()) + val runner = + ProjectRunner( + project, + config, + projectApi, + deploymentLoader, + deploymentStorage, + cohortLoader, + cohortStorage, + ) + coEvery { projectApi.getDeployments() } returns listOf(deployment("a")) + coEvery { cohortLoader.loadCohorts(allAny()) } returns Unit + coEvery { deploymentLoader.loadDeployment(allAny()) } returns Unit + runner.start() + runner.stop() + coVerify(exactly = 1) { projectApi.getDeployments() } + coVerify(exactly = 1) { deploymentStorage.getDeployments() } + coVerify(exactly = 1) { deploymentStorage.putDeployment(eq(deployment("a"))) } + assertNotNull(runner.deploymentRunners["a"]) + coVerify(exactly = 0) { deploymentStorage.removeAllFlags(allAny()) } + coVerify(exactly = 0) { deploymentStorage.removeDeployment(allAny()) } + coVerify(exactly = 1) { deploymentStorage.getAllFlags(allAny()) } + coVerify(exactly = 1) { cohortStorage.getCohortDescriptions() } + coVerify(exactly = 0) { cohortStorage.deleteCohort(allAny()) } + } @Test - fun `test start, with initial state, add and remove deployment, success`(): Unit = runBlocking { - val projectApi = mockk() - val deploymentLoader = mockk() - val deploymentStorage = spyk(InMemoryDeploymentStorage().apply { - putDeployment(deployment("a")) - putFlag("a", flag(cohortIds = setOf("aa"))) - }) - val cohortLoader = mockk() - val cohortStorage = spyk(InMemoryCohortStorage().apply { - putCohort(cohort("aa")) - }) - coEvery { projectApi.getDeployments() } returns listOf(deployment("b")) - coEvery { cohortLoader.loadCohorts(allAny()) } returns Unit - coEvery { deploymentLoader.loadDeployment(allAny()) } returns Unit - val runner = ProjectRunner( - project, - config, - projectApi, - deploymentLoader, - deploymentStorage, - cohortLoader, - cohortStorage - ) - runner.start() - runner.stop() - coVerify(exactly = 1) { projectApi.getDeployments() } - coVerify(exactly = 1) { deploymentStorage.getDeployments() } - coVerify(exactly = 1) { deploymentStorage.putDeployment(eq(deployment("b"))) } - assertNull(runner.deploymentRunners["a"]) - assertNotNull(runner.deploymentRunners["b"]) - coVerify(exactly = 1) { deploymentStorage.removeAllFlags(eq("a")) } - coVerify(exactly = 1) { deploymentStorage.removeDeployment(eq("a")) } - coVerify(exactly = 1) { deploymentStorage.getAllFlags(eq("b")) } - coVerify(exactly = 1) { cohortStorage.getCohortDescriptions() } - coVerify(exactly = 1) { cohortStorage.deleteCohort(eq(cohort("aa").toCohortDescription())) } - } + fun `test start, with initial state, add and remove deployment, success`(): Unit = + runBlocking { + val projectApi = mockk() + val deploymentLoader = mockk() + val deploymentStorage = + spyk( + InMemoryDeploymentStorage().apply { + putDeployment(deployment("a")) + putFlag("a", flag(cohortIds = setOf("aa"))) + }, + ) + val cohortLoader = mockk() + val cohortStorage = + spyk( + InMemoryCohortStorage().apply { + putCohort(cohort("aa")) + }, + ) + coEvery { projectApi.getDeployments() } returns listOf(deployment("b")) + coEvery { cohortLoader.loadCohorts(allAny()) } returns Unit + coEvery { deploymentLoader.loadDeployment(allAny()) } returns Unit + val runner = + ProjectRunner( + project, + config, + projectApi, + deploymentLoader, + deploymentStorage, + cohortLoader, + cohortStorage, + ) + runner.start() + runner.stop() + coVerify(exactly = 1) { projectApi.getDeployments() } + coVerify(exactly = 1) { deploymentStorage.getDeployments() } + coVerify(exactly = 1) { deploymentStorage.putDeployment(eq(deployment("b"))) } + assertNull(runner.deploymentRunners["a"]) + assertNotNull(runner.deploymentRunners["b"]) + coVerify(exactly = 1) { deploymentStorage.removeAllFlags(eq("a")) } + coVerify(exactly = 1) { deploymentStorage.removeDeployment(eq("a")) } + coVerify(exactly = 1) { deploymentStorage.getAllFlags(eq("b")) } + coVerify(exactly = 1) { cohortStorage.getCohortDescriptions() } + coVerify(exactly = 1) { cohortStorage.deleteCohort(eq(cohort("aa").toCohortDescription())) } + } @Test - fun `test start, deployments api failure, does not throw`(): Unit = runBlocking { - val projectApi = mockk() - val deploymentLoader = mockk() - val deploymentStorage = spyk(InMemoryDeploymentStorage()) - val cohortLoader = mockk() - val cohortStorage = spyk(InMemoryCohortStorage()) - coEvery { projectApi.getDeployments() } throws RuntimeException("test") - coEvery { cohortLoader.loadCohorts(allAny()) } returns Unit - coEvery { deploymentLoader.loadDeployment(allAny()) } returns Unit - val runner = ProjectRunner( - project, - config, - projectApi, - deploymentLoader, - deploymentStorage, - cohortLoader, - cohortStorage - ) - runner.start() - delay(150) - // Pollers should start and call getDeployments again - coVerify(atLeast = 2) { projectApi.getDeployments() } - - } + fun `test start, deployments api failure, does not throw`(): Unit = + runBlocking { + val projectApi = mockk() + val deploymentLoader = mockk() + val deploymentStorage = spyk(InMemoryDeploymentStorage()) + val cohortLoader = mockk() + val cohortStorage = spyk(InMemoryCohortStorage()) + coEvery { projectApi.getDeployments() } throws RuntimeException("test") + coEvery { cohortLoader.loadCohorts(allAny()) } returns Unit + coEvery { deploymentLoader.loadDeployment(allAny()) } returns Unit + val runner = + ProjectRunner( + project, + config, + projectApi, + deploymentLoader, + deploymentStorage, + cohortLoader, + cohortStorage, + ) + runner.start() + delay(150) + // Pollers should start and call getDeployments again + coVerify(atLeast = 2) { projectApi.getDeployments() } + } } diff --git a/core/src/test/kotlin/project/ProjectStorageTest.kt b/core/src/test/kotlin/project/ProjectStorageTest.kt index f88f629..f447a54 100644 --- a/core/src/test/kotlin/project/ProjectStorageTest.kt +++ b/core/src/test/kotlin/project/ProjectStorageTest.kt @@ -1,45 +1,47 @@ package project -import test.InMemoryRedis import com.amplitude.project.InMemoryProjectStorage import com.amplitude.project.ProjectStorage import com.amplitude.project.RedisProjectStorage import kotlinx.coroutines.runBlocking import org.junit.Assert.assertEquals +import test.InMemoryRedis import kotlin.test.Test class ProjectStorageTest { - @Test - fun `test in memory`(): Unit = runBlocking { - test(InMemoryProjectStorage()) - } + fun `test in memory`(): Unit = + runBlocking { + test(InMemoryProjectStorage()) + } @Test - fun `test redis`(): Unit = runBlocking { - test(RedisProjectStorage("amplitude", InMemoryRedis())) - } + fun `test redis`(): Unit = + runBlocking { + test(RedisProjectStorage("amplitude", InMemoryRedis())) + } - private fun test(storage: ProjectStorage) = runBlocking { - // get projects, empty - var projects = storage.getProjects() - assertEquals(0, projects.size) - // put project, 1 - storage.putProject("1") - // put project, 2 - storage.putProject("2") - // get projects, 1 2 - projects = storage.getProjects() - assertEquals(setOf("1", "2"), projects) - // remove project 1, 2 - storage.removeProject("1") - // get projects, 2 - projects = storage.getProjects() - assertEquals(setOf("2"), projects) - // remove project 2 - storage.removeProject("2") - // get projects, empty - projects = storage.getProjects() - assertEquals(0, projects.size) - } + private fun test(storage: ProjectStorage) = + runBlocking { + // get projects, empty + var projects = storage.getProjects() + assertEquals(0, projects.size) + // put project, 1 + storage.putProject("1") + // put project, 2 + storage.putProject("2") + // get projects, 1 2 + projects = storage.getProjects() + assertEquals(setOf("1", "2"), projects) + // remove project 1, 2 + storage.removeProject("1") + // get projects, 2 + projects = storage.getProjects() + assertEquals(setOf("2"), projects) + // remove project 2 + storage.removeProject("2") + // get projects, empty + projects = storage.getProjects() + assertEquals(0, projects.size) + } } diff --git a/core/src/test/kotlin/test/InMemoryRedis.kt b/core/src/test/kotlin/test/InMemoryRedis.kt index 02b0e9e..6aec712 100644 --- a/core/src/test/kotlin/test/InMemoryRedis.kt +++ b/core/src/test/kotlin/test/InMemoryRedis.kt @@ -4,8 +4,7 @@ import com.amplitude.util.Redis import com.amplitude.util.RedisKey import kotlin.time.Duration -internal class InMemoryRedis: Redis { - +internal class InMemoryRedis : Redis { private val kv = mutableMapOf() private val sets = mutableMapOf>() private val hashes = mutableMapOf>() @@ -14,7 +13,10 @@ internal class InMemoryRedis: Redis { return kv[key.value] } - override suspend fun set(key: RedisKey, value: String) { + override suspend fun set( + key: RedisKey, + value: String, + ) { kv[key.value] = value } @@ -22,11 +24,17 @@ internal class InMemoryRedis: Redis { kv.remove(key.value) } - override suspend fun sadd(key: RedisKey, values: Set) { + override suspend fun sadd( + key: RedisKey, + values: Set, + ) { sets.getOrPut(key.value) { mutableSetOf() }.addAll(values) } - override suspend fun srem(key: RedisKey, value: String) { + override suspend fun srem( + key: RedisKey, + value: String, + ) { sets.getOrPut(key.value) { mutableSetOf() }.remove(value) } @@ -34,11 +42,17 @@ internal class InMemoryRedis: Redis { return sets[key.value]?.toSet() } - override suspend fun sismember(key: RedisKey, value: String): Boolean { + override suspend fun sismember( + key: RedisKey, + value: String, + ): Boolean { return sets[key.value]?.contains(value) ?: false } - override suspend fun hget(key: RedisKey, field: String): String? { + override suspend fun hget( + key: RedisKey, + field: String, + ): String? { return hashes.getOrPut(key.value) { mutableMapOf() }[field] } @@ -46,18 +60,27 @@ internal class InMemoryRedis: Redis { return hashes[key.value]?.toMap() } - override suspend fun hset(key: RedisKey, values: Map) { + override suspend fun hset( + key: RedisKey, + values: Map, + ) { hashes.getOrPut(key.value) { mutableMapOf() }.putAll(values) } - override suspend fun hdel(key: RedisKey, field: String) { + override suspend fun hdel( + key: RedisKey, + field: String, + ) { hashes[key.value]?.remove(field) - if(hashes[key.value]?.isEmpty() == true) { + if (hashes[key.value]?.isEmpty() == true) { hashes.remove(key.value) } } - override suspend fun expire(key: RedisKey, ttl: Duration) { + override suspend fun expire( + key: RedisKey, + ttl: Duration, + ) { // Do nothing. } } diff --git a/core/src/test/kotlin/test/Utils.kt b/core/src/test/kotlin/test/Utils.kt index 986d4d4..deb3c5f 100644 --- a/core/src/test/kotlin/test/Utils.kt +++ b/core/src/test/kotlin/test/Utils.kt @@ -15,43 +15,46 @@ internal fun user( deviceId: String? = null, userProperties: Map? = null, groups: Map>? = null, - groupProperties: Map>>? = null + groupProperties: Map>>? = null, ): MutableMap { return mutableMapOf( "user_id" to userId, "device_id" to deviceId, "user_properties" to userProperties, "groups" to groups, - "group_properties" to groupProperties + "group_properties" to groupProperties, ) } internal fun flag( flagKey: String = "flag", - cohortIds: Set = setOf("a") + cohortIds: Set = setOf("a"), ) = EvaluationFlag( key = flagKey, - variants = mapOf( - "off" to EvaluationVariant("off", null, null, null), - "on" to EvaluationVariant("on", "on", null, null) - ), - segments = listOf( - EvaluationSegment( - conditions = listOf( - listOf( - EvaluationCondition( - selector = listOf("context", "user", "cohort_ids"), - op = EvaluationOperator.SET_CONTAINS_ANY, - values = cohortIds - ) - ) + variants = + mapOf( + "off" to EvaluationVariant("off", null, null, null), + "on" to EvaluationVariant("on", "on", null, null), + ), + segments = + listOf( + EvaluationSegment( + conditions = + listOf( + listOf( + EvaluationCondition( + selector = listOf("context", "user", "cohort_ids"), + op = EvaluationOperator.SET_CONTAINS_ANY, + values = cohortIds, + ), + ), + ), + variant = "on", + ), + EvaluationSegment( + variant = "off", ), - variant = "on" ), - EvaluationSegment( - variant = "off" - ) - ) ) internal fun cohort( @@ -59,33 +62,38 @@ internal fun cohort( lastModified: Long = 100, size: Int = 1, members: Set = setOf("1"), - groupType: String = "User" + groupType: String = "User", ) = Cohort( id = id, groupType = groupType, size = size, lastModified = lastModified, - members = members + members = members, ) -internal fun deployment(key: String, projectId: String = "1") = Deployment( +internal fun deployment( + key: String, + projectId: String = "1", +) = Deployment( id = "1", projectId = projectId, label = "", key = key, ) -internal fun Deployment.toSerialDeployment(deleted: Boolean = false) = SerialDeployment( - id = id, - projectId = projectId, - label = label, - key = key, - deleted = deleted, -) +internal fun Deployment.toSerialDeployment(deleted: Boolean = false) = + SerialDeployment( + id = id, + projectId = projectId, + label = label, + key = key, + deleted = deleted, + ) -internal fun project(id: String = "1") = Project( - id = id, - apiKey = "api", - secretKey = "secret", - managementKey = "management" -) +internal fun project(id: String = "1") = + Project( + id = id, + apiKey = "api", + secretKey = "secret", + managementKey = "management", + ) diff --git a/core/src/test/kotlin/util/CacheTest.kt b/core/src/test/kotlin/util/CacheTest.kt index db7f5e1..202ca8c 100644 --- a/core/src/test/kotlin/util/CacheTest.kt +++ b/core/src/test/kotlin/util/CacheTest.kt @@ -7,100 +7,108 @@ import org.junit.Assert import org.junit.Test class CacheTest { - - @Test - fun `test get no entry`() = runBlocking { - val cache = Cache(4) - val value = cache.get(0) - Assert.assertNull(value) - } - @Test - fun `test set and get`() = runBlocking { - val cache = Cache(4) - cache.set(0, 0) - val value = cache.get(0) - Assert.assertEquals(0, value) - } + fun `test get no entry`() = + runBlocking { + val cache = Cache(4) + val value = cache.get(0) + Assert.assertNull(value) + } @Test - fun `test least recently used entry is removed`() = runBlocking { - val cache = Cache(4) - repeat(4) { i -> - cache.set(i, i) + fun `test set and get`() = + runBlocking { + val cache = Cache(4) + cache.set(0, 0) + val value = cache.get(0) + Assert.assertEquals(0, value) } - cache.set(4, 4) - val value = cache.get(0) - Assert.assertNull(value) - } @Test - fun `test first set then get entry is not removed`() = runBlocking { - val cache = Cache(4) - repeat(4) { i -> - cache.set(i, i) + fun `test least recently used entry is removed`() = + runBlocking { + val cache = Cache(4) + repeat(4) { i -> + cache.set(i, i) + } + cache.set(4, 4) + val value = cache.get(0) + Assert.assertNull(value) } - val expectedValue = cache.get(0) - cache.set(4, 4) - val actualValue = cache.get(0) - Assert.assertEquals(expectedValue, actualValue) - val removedValue = cache.get(1) - Assert.assertNull(removedValue) - } @Test - fun `test first set then re-set entry is not removed`() = runBlocking { - val cache = Cache(4) - repeat(4) { i -> - cache.set(i, i) + fun `test first set then get entry is not removed`() = + runBlocking { + val cache = Cache(4) + repeat(4) { i -> + cache.set(i, i) + } + val expectedValue = cache.get(0) + cache.set(4, 4) + val actualValue = cache.get(0) + Assert.assertEquals(expectedValue, actualValue) + val removedValue = cache.get(1) + Assert.assertNull(removedValue) } - cache.set(0, 0) - cache.set(4, 4) - val actualValue = cache.get(0) - Assert.assertEquals(0, actualValue) - val removedValue = cache.get(1) - Assert.assertNull(removedValue) - } @Test - fun `test first set then re-set with different value entry is not removed`() = runBlocking { - val cache = Cache(4) - repeat(4) { i -> - cache.set(i, i) + fun `test first set then re-set entry is not removed`() = + runBlocking { + val cache = Cache(4) + repeat(4) { i -> + cache.set(i, i) + } + cache.set(0, 0) + cache.set(4, 4) + val actualValue = cache.get(0) + Assert.assertEquals(0, actualValue) + val removedValue = cache.get(1) + Assert.assertNull(removedValue) } - cache.set(0, 100) - cache.set(4, 4) - val actualValue = cache.get(0) - Assert.assertEquals(100, actualValue) - val removedValue = cache.get(1) - Assert.assertNull(removedValue) - } @Test - fun `test concurrent access`() = runBlocking { - val n = 100 - val cache = Cache(n) - val jobs = mutableListOf() - repeat(n) { i -> - jobs += launch { + fun `test first set then re-set with different value entry is not removed`() = + runBlocking { + val cache = Cache(4) + repeat(4) { i -> cache.set(i, i) } + cache.set(0, 100) + cache.set(4, 4) + val actualValue = cache.get(0) + Assert.assertEquals(100, actualValue) + val removedValue = cache.get(1) + Assert.assertNull(removedValue) } - jobs.joinAll() - repeat(n) { i -> - Assert.assertEquals(i, cache.get(i)) - } - jobs.clear() - val k = 50 - repeat(k) { i -> - jobs += launch { - cache.set(i + k, i + k) + + @Test + fun `test concurrent access`() = + runBlocking { + val n = 100 + val cache = Cache(n) + val jobs = mutableListOf() + repeat(n) { i -> + jobs += + launch { + cache.set(i, i) + } + } + jobs.joinAll() + repeat(n) { i -> + Assert.assertEquals(i, cache.get(i)) + } + jobs.clear() + val k = 50 + repeat(k) { i -> + jobs += + launch { + cache.set(i + k, i + k) + } + } + jobs.joinAll() + repeat(k) { i -> + Assert.assertEquals(i, cache.get(i)) + Assert.assertEquals(i + k, cache.get(i + k)) } } - jobs.joinAll() - repeat(k) { i -> - Assert.assertEquals(i, cache.get(i)) - Assert.assertEquals(i + k, cache.get(i + k)) - } - } } diff --git a/core/src/test/kotlin/util/EvaluationFlagTest.kt b/core/src/test/kotlin/util/EvaluationFlagTest.kt index 1311e61..3ace080 100644 --- a/core/src/test/kotlin/util/EvaluationFlagTest.kt +++ b/core/src/test/kotlin/util/EvaluationFlagTest.kt @@ -7,17 +7,18 @@ import org.junit.Assert.assertEquals import kotlin.test.Test class EvaluationFlagTest { - + @Suppress("ktlint:standard:max-line-length") private val testFlagsJson = """[{"key":"flag1","segments":[{"conditions":[[{"op":"set contains any","selector":["context","user","cohort_ids"],"values":["hahahaha1"]}]]},{"metadata":{"segmentName":"All Other Users"},"variant":"off"}],"variants":{}},{"key":"flag2","segments":[{"conditions":[[{"op":"set contains any","selector":["context","user","cohort_ids"],"values":["hahahaha2"]}]],"metadata":{"segmentName":"Segment 1"},"variant":"off"},{"metadata":{"segmentName":"All Other Users"},"variant":"off"}],"variants":{}},{"key":"flag3","metadata":{"deployed":true,"evaluationMode":"local","experimentKey":"exp-1","flagType":"experiment","flagVersion":6},"segments":[{"conditions":[[{"op":"set contains any","selector":["context","user","cohort_ids"],"values":["hahahaha3"]}]],"variant":"off"},{"conditions":[[{"op":"set contains any","selector":["context","user","cocoids"],"values":["nohaha"]}]],"variant":"off"},{"metadata":{"segmentName":"All Other Users"},"variant":"off"}],"variants":{}},{"key":"flag5","segments":[{"conditions":[[{"op":"set contains any","selector":["context","user","cohort_ids"],"values":["hahahaha3","hahahaha4"]}]]},{"conditions":[[{"op":"set contains any","selector":["context","groups","org name","cohort_ids"],"values":["hahaorgname1"]}]],"metadata":{"segmentName":"Segment 1"}},{"conditions":[[{"op":"set contains any","selector":["context","gg","org name","cohort_ids"],"values":["nohahaorgname"]}]],"metadata":{"segmentName":"Segment 1"}}],"variants":{}}]""" private val testFlags = json.decodeFromString>(testFlagsJson) @Test fun `test get grouped cohort ids from flags`() { val result = testFlags.getGroupedCohortIds() - val expected = mapOf( - "User" to setOf("hahahaha1", "hahahaha2", "hahahaha3", "hahahaha4"), - "org name" to setOf("hahaorgname1"), - ) + val expected = + mapOf( + "User" to setOf("hahahaha1", "hahahaha2", "hahahaha3", "hahahaha4"), + "org name" to setOf("hahaorgname1"), + ) assertEquals(expected, result) } } diff --git a/service/src/main/kotlin/Server.kt b/service/src/main/kotlin/Server.kt index ddcc3bb..78b46ab 100644 --- a/service/src/main/kotlin/Server.kt +++ b/service/src/main/kotlin/Server.kt @@ -44,30 +44,33 @@ fun main() { log.info("Accessing proxy configuration.") val proxyConfigFilePath = stringEnv("PROXY_CONFIG_FILE_PATH", "/etc/evaluation-proxy-config.yaml")!! val proxyProjectsFilePath = stringEnv("PROXY_PROJECTS_FILE_PATH", proxyConfigFilePath)!! - val projectsFile = try { - ProjectsFile.fromFile(proxyProjectsFilePath).also { - log.info("Found projects file at $proxyProjectsFilePath") + val projectsFile = + try { + ProjectsFile.fromFile(proxyProjectsFilePath).also { + log.info("Found projects file at $proxyProjectsFilePath") + } + } catch (file: FileNotFoundException) { + log.info("Proxy projects file not found at $proxyProjectsFilePath, reading project from env.") + ProjectsFile.fromEnv() } - } catch (file: FileNotFoundException) { - log.info("Proxy projects file not found at $proxyProjectsFilePath, reading project from env.") - ProjectsFile.fromEnv() - } - val configFile = try { - ConfigurationFile.fromFile(proxyConfigFilePath).also { - log.info("Found configuration file at $proxyConfigFilePath") + val configFile = + try { + ConfigurationFile.fromFile(proxyConfigFilePath).also { + log.info("Found configuration file at $proxyConfigFilePath") + } + } catch (file: FileNotFoundException) { + log.info("Proxy config file not found at $proxyConfigFilePath, reading configuration from env.") + ConfigurationFile.fromEnv() } - } catch (file: FileNotFoundException) { - log.info("Proxy config file not found at $proxyConfigFilePath, reading configuration from env.") - ConfigurationFile.fromEnv() - } /* * Initialize and start the evaluation proxy. */ - evaluationProxy = EvaluationProxy( - projectsFile.projects, - configFile.configuration - ) + evaluationProxy = + EvaluationProxy( + projectsFile.projects, + configFile.configuration, + ) /* * Start the server. @@ -76,7 +79,7 @@ fun main() { factory = Netty, port = configFile.configuration.port, host = "0.0.0.0", - module = Application::proxyServer + module = Application::proxyServer, ).start(wait = true) } @@ -102,14 +105,13 @@ fun Application.proxyServer() { plugin.doShutdown(call) } } - } + }, ) /* * Configure endpoints. */ routing { - // Local Evaluation get("/sdk/v2/flags") { @@ -173,7 +175,7 @@ fun Application.proxyServer() { suspend fun ApplicationCall.evaluate( evaluationProxy: EvaluationProxy, - userProvider: suspend ApplicationRequest.() -> Map + userProvider: suspend ApplicationRequest.() -> Map, ) { // Deployment key is included in Authorization header with prefix "Api-Key " val deploymentKey = request.getDeploymentKey() @@ -185,7 +187,7 @@ suspend fun ApplicationCall.evaluate( suspend fun ApplicationCall.evaluateV1( evaluationProxy: EvaluationProxy, - userProvider: suspend ApplicationRequest.() -> Map + userProvider: suspend ApplicationRequest.() -> Map, ) { // Deployment key is included in Authorization header with prefix "Api-Key " val deploymentKey = request.getDeploymentKey() @@ -262,24 +264,27 @@ private fun ApplicationRequest.getUserFromQuery(): JsonObject { val userId = this.queryParameters["user_id"] val deviceId = this.queryParameters["device_id"] val context = this.queryParameters["context"] - var user: JsonObject = if (context != null) { - json.decodeFromString(context) - } else { - JsonObject(emptyMap()) - } + var user: JsonObject = + if (context != null) { + json.decodeFromString(context) + } else { + JsonObject(emptyMap()) + } if (userId != null) { - user = JsonObject( - user.toMutableMap().apply { - put("user_id", JsonPrimitive(userId)) - } - ) + user = + JsonObject( + user.toMutableMap().apply { + put("user_id", JsonPrimitive(userId)) + }, + ) } if (deviceId != null) { - user = JsonObject( - user.toMutableMap().apply { - put("device_id", JsonPrimitive(userId)) - } - ) + user = + JsonObject( + user.toMutableMap().apply { + put("device_id", JsonPrimitive(userId)) + }, + ) } return user }