From b123d93a1a74ae7abe454454398d53ae92f63919 Mon Sep 17 00:00:00 2001 From: Brian Giori Date: Thu, 1 Aug 2024 17:39:00 -0700 Subject: [PATCH] feat: new cohort api --- Makefile | 2 +- build.gradle.kts | 2 +- core/build.gradle.kts | 14 +- core/src/main/kotlin/Config.kt | 60 +++- core/src/main/kotlin/EvaluationProxy.kt | 250 ++++++++------ core/src/main/kotlin/Metrics.kt | 4 +- .../kotlin/assignment/AssignmentTracker.kt | 2 + core/src/main/kotlin/cohort/Cohort.kt | 15 + core/src/main/kotlin/cohort/CohortApi.kt | 160 ++++----- .../main/kotlin/cohort/CohortDescription.kt | 13 - core/src/main/kotlin/cohort/CohortLoader.kt | 50 ++- core/src/main/kotlin/cohort/CohortStorage.kt | 231 ++++++++----- .../main/kotlin/deployment/DeploymentApi.kt | 19 +- .../kotlin/deployment/DeploymentLoader.kt | 60 ++-- .../kotlin/deployment/DeploymentRunner.kt | 32 +- .../kotlin/deployment/DeploymentStorage.kt | 1 + core/src/main/kotlin/project/ProjectApi.kt | 15 +- core/src/main/kotlin/project/ProjectProxy.kt | 145 ++++---- core/src/main/kotlin/project/ProjectRunner.kt | 101 +++--- .../src/main/kotlin/project/ProjectStorage.kt | 18 +- core/src/main/kotlin/util/Http.kt | 17 +- core/src/main/kotlin/util/Loader.kt | 28 ++ core/src/main/kotlin/util/Redis.kt | 9 +- core/src/test/kotlin/EvaluationProxyTest.kt | 114 +++++++ core/src/test/kotlin/Utils.kt | 15 - .../kotlin/assignment/AssignmentFilterTest.kt | 1 + .../assignment/AssignmentServiceTest.kt | 1 + core/src/test/kotlin/cohort/CohortApiTest.kt | 219 ++++++++++++ .../test/kotlin/cohort/CohortLoaderTest.kt | 127 +++++++ .../test/kotlin/cohort/CohortStorageTest.kt | 102 ++++++ .../kotlin/deployment/DeploymentApiTest.kt | 151 +++++++++ .../kotlin/deployment/DeploymentLoaderTest.kt | 176 ++++++++++ .../kotlin/deployment/DeploymentRunnerTest.kt | 132 ++++++++ .../deployment/DeploymentStorageTest.kt | 113 +++++++ .../src/test/kotlin/project/ProjectApiTest.kt | 67 ++++ .../test/kotlin/project/ProjectProxyTest.kt | 318 ++++++++++++++++++ .../test/kotlin/project/ProjectRunnerTest.kt | 127 +++++++ .../test/kotlin/project/ProjectStorageTest.kt | 45 +++ core/src/test/kotlin/test/InMemoryRedis.kt | 63 ++++ core/src/test/kotlin/test/Utils.kt | 91 +++++ .../test/kotlin/util/EvaluationFlagTest.kt | 23 ++ gradle.properties | 21 +- service/build.gradle.kts | 9 +- service/src/main/kotlin/Server.kt | 110 +++--- 44 files changed, 2657 insertions(+), 616 deletions(-) create mode 100644 core/src/main/kotlin/cohort/Cohort.kt delete mode 100644 core/src/main/kotlin/cohort/CohortDescription.kt create mode 100644 core/src/main/kotlin/util/Loader.kt create mode 100644 core/src/test/kotlin/EvaluationProxyTest.kt delete mode 100644 core/src/test/kotlin/Utils.kt create mode 100644 core/src/test/kotlin/cohort/CohortApiTest.kt create mode 100644 core/src/test/kotlin/cohort/CohortLoaderTest.kt create mode 100644 core/src/test/kotlin/cohort/CohortStorageTest.kt create mode 100644 core/src/test/kotlin/deployment/DeploymentApiTest.kt create mode 100644 core/src/test/kotlin/deployment/DeploymentLoaderTest.kt create mode 100644 core/src/test/kotlin/deployment/DeploymentRunnerTest.kt create mode 100644 core/src/test/kotlin/deployment/DeploymentStorageTest.kt create mode 100644 core/src/test/kotlin/project/ProjectApiTest.kt create mode 100644 core/src/test/kotlin/project/ProjectProxyTest.kt create mode 100644 core/src/test/kotlin/project/ProjectRunnerTest.kt create mode 100644 core/src/test/kotlin/project/ProjectStorageTest.kt create mode 100644 core/src/test/kotlin/test/InMemoryRedis.kt create mode 100644 core/src/test/kotlin/test/Utils.kt create mode 100644 core/src/test/kotlin/util/EvaluationFlagTest.kt diff --git a/Makefile b/Makefile index 61a9e28..2977d9d 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ lint: ./gradlew ktlintFormat run: build - PROXY_CONFIG_FILE_PATH=`pwd`/config.yaml ./gradlew run --console=plain + AMPLITUDE_LOG_LEVEL=DEBUG PROXY_CONFIG_FILE_PATH=`pwd`/config.yaml PROXY_PROJECTS_FILE_PATH=`pwd`/projects.yaml ./gradlew run --console=plain docker-build: build docker build -t evaluation-proxy:local . diff --git a/build.gradle.kts b/build.gradle.kts index ca7559e..3ea7e5c 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -1,5 +1,5 @@ plugins { - kotlin("jvm") version "1.9.10" + kotlin("jvm") version "2.0.0" id("io.github.gradle-nexus.publish-plugin") version "1.1.0" } diff --git a/core/build.gradle.kts b/core/build.gradle.kts index eaccd88..9fa6acd 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -1,11 +1,11 @@ import org.jetbrains.kotlin.gradle.tasks.KotlinCompile plugins { - kotlin("jvm") version "1.9.10" - kotlin("plugin.serialization") version "1.9.0" + kotlin("jvm") version "2.0.0" + kotlin("plugin.serialization") version "2.0.0" `maven-publish` signing - id("org.jlleitschuh.gradle.ktlint") version "11.3.1" + id("org.jlleitschuh.gradle.ktlint") version "12.1.1" } java { @@ -20,14 +20,17 @@ tasks { // Defined in gradle.properties val kotlinVersion: String by project val ktorVersion: String by project +val coroutinesVersion: String by project +val serializationVersion: String by project val experimentEvaluationVersion: String by project val amplitudeAnalytics: String by project val amplitudeAnalyticsJson: String by project val lettuce: String by project -val apacheCommons: String by project val kaml: String by project +val mockk: String by project dependencies { + implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:$coroutinesVersion") implementation("com.amplitude:evaluation-core:$experimentEvaluationVersion") implementation("com.amplitude:java-sdk:$amplitudeAnalytics") implementation("org.json:json:$amplitudeAnalyticsJson") @@ -35,8 +38,9 @@ dependencies { implementation("io.ktor:ktor-client-okhttp:$ktorVersion") implementation("io.ktor:ktor-serialization-kotlinx-json-jvm:$ktorVersion") implementation("com.charleskorn.kaml:kaml:$kaml") - implementation("org.apache.commons:commons-csv:$apacheCommons") testImplementation("org.jetbrains.kotlin:kotlin-test-junit:$kotlinVersion") + testImplementation("io.mockk:mockk:$mockk") + testImplementation("io.ktor:ktor-client-mock:$ktorVersion") } // Publishing diff --git a/core/src/main/kotlin/Config.kt b/core/src/main/kotlin/Config.kt index 8e079e3..4e3f4ad 100644 --- a/core/src/main/kotlin/Config.kt +++ b/core/src/main/kotlin/Config.kt @@ -76,8 +76,11 @@ data class ProjectConfiguration( @Serializable data class Configuration( val port: Int = Default.PORT, - val serverUrl: String = Default.SERVER_URL, - val cohortServerUrl: String = Default.COHORT_SERVER_URL, + val serverZone: String = Default.SERVER_ZONE, + val serverUrl: String = getServerUrl(serverZone), + val cohortServerUrl: String = getCohortServerUrl(serverZone), + val managementServerUrl: String = getManagementServerUrl(serverZone), + val analyticsServerUrl: String = getAnalyticsServerUrl(serverZone), val deploymentSyncIntervalMillis: Long = Default.DEPLOYMENT_SYNC_INTERVAL_MILLIS, val flagSyncIntervalMillis: Long = Default.FLAG_SYNC_INTERVAL_MILLIS, val cohortSyncIntervalMillis: Long = Default.COHORT_SYNC_INTERVAL_MILLIS, @@ -88,8 +91,11 @@ data class Configuration( companion object { fun fromEnv() = Configuration( port = intEnv(EnvKey.PORT, Default.PORT)!!, - serverUrl = stringEnv(EnvKey.SERVER_URL, Default.SERVER_URL)!!, - cohortServerUrl = stringEnv(EnvKey.COHORT_SERVER_URL, Default.COHORT_SERVER_URL)!!, + serverZone = stringEnv(EnvKey.SERVER_ZONE, Default.SERVER_ZONE)!!, + serverUrl = stringEnv(EnvKey.SERVER_URL, Default.US_SERVER_URL)!!, + cohortServerUrl = stringEnv(EnvKey.COHORT_SERVER_URL, Default.US_COHORT_SERVER_URL)!!, + managementServerUrl = stringEnv(EnvKey.MANAGEMENT_SERVER_URL, Default.US_MANAGEMENT_SERVER_URL)!!, + analyticsServerUrl = stringEnv(EnvKey.ANALYTICS_SERVER_URL, Default.US_ANALYTICS_SERVER_URL)!!, deploymentSyncIntervalMillis = longEnv( EnvKey.DEPLOYMENT_SYNC_INTERVAL_MILLIS, Default.DEPLOYMENT_SYNC_INTERVAL_MILLIS @@ -164,8 +170,11 @@ data class RedisConfiguration( object EnvKey { const val PORT = "AMPLITUDE_PORT" + const val SERVER_ZONE = "AMPLITUDE_SERVER_ZONE" const val SERVER_URL = "AMPLITUDE_SERVER_URL" const val COHORT_SERVER_URL = "AMPLITUDE_COHORT_SERVER_URL" + const val MANAGEMENT_SERVER_URL = "AMPLITUDE_MANAGEMENT_SERVER_URL" + const val ANALYTICS_SERVER_URL = "AMPLITUDE_ANALYTICS_SERVER_URL" const val API_KEY = "AMPLITUDE_API_KEY" const val SECRET_KEY = "AMPLITUDE_SECRET_KEY" @@ -188,8 +197,15 @@ object EnvKey { object Default { const val PORT = 3546 - const val SERVER_URL = "https://flag.lab.amplitude.com" - const val COHORT_SERVER_URL = "https://cohort.lab.amplitude.com" + const val SERVER_ZONE = "US" + const val US_SERVER_URL = "https://flag.lab.amplitude.com" + const val US_COHORT_SERVER_URL = "https://cohort-v2.lab.amplitude.com" + const val US_MANAGEMENT_SERVER_URL = "https://experiment.amplitude.com" + const val US_ANALYTICS_SERVER_URL = "https://api2.amplitude.com/2/httpapi" + const val EU_SERVER_URL = "https://flag.lab.eu.amplitude.com" + const val EU_COHORT_SERVER_URL = "https://cohort-v2.lab.eu.amplitude.com" + const val EU_MANAGEMENT_SERVER_URL = "https://experiment.eu.amplitude.com" + const val EU_ANALYTICS_SERVER_URL = "https://api.eu.amplitude.com/2/httpapi" const val DEPLOYMENT_SYNC_INTERVAL_MILLIS = 60 * 1000L const val FLAG_SYNC_INTERVAL_MILLIS = 10 * 1000L const val COHORT_SYNC_INTERVAL_MILLIS = 60 * 1000L @@ -204,3 +220,35 @@ object Default { val REDIS_READ_ONLY_URI: String? = null const val REDIS_PREFIX = "amplitude" } + +private fun getServerUrl(zone: String): String { + return if (zone == "EU") { + Default.EU_SERVER_URL + } else { + Default.US_SERVER_URL + } +} + +private fun getCohortServerUrl(zone: String): String { + return if (zone == "EU") { + Default.EU_COHORT_SERVER_URL + } else { + Default.US_COHORT_SERVER_URL + } +} + +private fun getManagementServerUrl(zone: String): String { + return if (zone == "EU") { + Default.EU_MANAGEMENT_SERVER_URL + } else { + Default.US_MANAGEMENT_SERVER_URL + } +} + +private fun getAnalyticsServerUrl(zone: String): String { + return if (zone == "EU") { + Default.EU_ANALYTICS_SERVER_URL + } else { + Default.US_ANALYTICS_SERVER_URL + } +} diff --git a/core/src/main/kotlin/EvaluationProxy.kt b/core/src/main/kotlin/EvaluationProxy.kt index 404bcd0..f6c35b0 100644 --- a/core/src/main/kotlin/EvaluationProxy.kt +++ b/core/src/main/kotlin/EvaluationProxy.kt @@ -1,17 +1,19 @@ package com.amplitude import com.amplitude.assignment.AmplitudeAssignmentTracker -import com.amplitude.cohort.CohortDescription +import com.amplitude.cohort.CohortStorage import com.amplitude.cohort.getCohortStorage +import com.amplitude.deployment.DeploymentStorage import com.amplitude.deployment.getDeploymentStorage -import com.amplitude.experiment.evaluation.EvaluationFlag -import com.amplitude.experiment.evaluation.EvaluationVariant import com.amplitude.project.Project +import com.amplitude.project.ProjectApi import com.amplitude.project.ProjectApiV1 import com.amplitude.project.ProjectProxy +import com.amplitude.project.ProjectStorage import com.amplitude.project.getProjectStorage import com.amplitude.util.json import com.amplitude.util.logger +import io.ktor.http.HttpStatusCode import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.cancelAndJoin @@ -22,23 +24,44 @@ import kotlinx.coroutines.launch import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock import kotlinx.serialization.encodeToString +import org.jetbrains.annotations.VisibleForTesting import kotlin.time.DurationUnit import kotlin.time.toDuration -const val VERSION = "0.4.7" +const val EVALUATION_PROXY_VERSION = "0.4.7" -class HttpErrorResponseException( - val status: Int, - override val message: String, - override val cause: Exception? = null -) : Exception(message, cause) +data class EvaluationProxyResponse( + val status: HttpStatusCode, + val body: String +) { + companion object { + fun error(status: HttpStatusCode, message: String): EvaluationProxyResponse { + return EvaluationProxyResponse(status, message) + } + inline fun json(status: HttpStatusCode, response: T): EvaluationProxyResponse { + return EvaluationProxyResponse(status, json.encodeToString(response)) + } + } +} -class EvaluationProxy( +class EvaluationProxy internal constructor( private val projectConfigurations: List, - private val configuration: Configuration = Configuration(), - metricsHandler: MetricsHandler? = null + private val configuration: Configuration, + private val projectStorage: ProjectStorage, + metricsHandler: MetricsHandler? = null, ) { + constructor( + projectConfigurations: List, + configuration: Configuration = Configuration(), + metricsHandler: MetricsHandler? = null, + ) : this( + projectConfigurations, + configuration, + getProjectStorage(configuration.redis), + metricsHandler, + ) + companion object { val log by logger() } @@ -50,14 +73,13 @@ class EvaluationProxy( private val supervisor = SupervisorJob() private val scope = CoroutineScope(supervisor) - private val projectProxies = mutableMapOf() + @VisibleForTesting + internal val projectProxies = mutableMapOf() private val apiKeysToProject = mutableMapOf() private val secretKeysToProject = mutableMapOf() private val deploymentKeysToProject = mutableMapOf() private val mutex = Mutex() - private val projectStorage = getProjectStorage(configuration.redis) - suspend fun start() { log.info("Starting evaluation proxy.") /* @@ -66,7 +88,7 @@ class EvaluationProxy( */ log.info("Setting up ${projectConfigurations.size} project(s)") for (projectConfiguration in projectConfigurations) { - val projectApi = ProjectApiV1(projectConfiguration.managementKey) + val projectApi = createProjectApi(projectConfiguration.managementKey) val deployments = projectApi.getDeployments() if (deployments.isEmpty()) { continue @@ -86,23 +108,8 @@ class EvaluationProxy( log.debug("Mapping deployment {} project {}", deployment.key, project.id) deploymentKeysToProject[deployment.key] = project } - // Create a project proxy and add the project to storage. - val assignmentTracker = AmplitudeAssignmentTracker(project.apiKey, configuration.assignment) - val deploymentStorage = getDeploymentStorage(project.id, configuration.redis) - val cohortStorage = getCohortStorage( - project.id, - configuration.redis, - configuration.cohortSyncIntervalMillis.toDuration(DurationUnit.MILLISECONDS) - ) - val projectProxy = ProjectProxy( - project, - configuration, - assignmentTracker, - deploymentStorage, - cohortStorage - ) - projectProxies[project] = projectProxy + projectProxies[project] = createProjectProxy(project) } /* @@ -119,12 +126,8 @@ class EvaluationProxy( val storageProjectIds = projectStorage.getProjects() for (projectId in storageProjectIds - projectIds) { log.info("Removing project $projectId") - val deploymentStorage = getDeploymentStorage(projectId, configuration.redis) - val cohortStorage = getCohortStorage( - projectId, - configuration.redis, - configuration.cohortSyncIntervalMillis.toDuration(DurationUnit.MILLISECONDS) - ) + val deploymentStorage = createDeploymentStorage(projectId) + val cohortStorage = createCohortStorage(projectId) // Remove all deployments for project val deployments = deploymentStorage.getDeployments() for ((deploymentKey, _) in deployments) { @@ -135,7 +138,7 @@ class EvaluationProxy( // Remove all cohorts for project val cohortDescriptions = cohortStorage.getCohortDescriptions().values for (cohortDescription in cohortDescriptions) { - cohortStorage.removeCohort(cohortDescription) + cohortStorage.deleteCohort(cohortDescription) } projectStorage.removeProject(projectId) } @@ -152,8 +155,12 @@ class EvaluationProxy( while (true) { delay(configuration.deploymentSyncIntervalMillis) for ((project, projectProxy) in projectProxies) { - val deployments = projectProxy.getDeployments().associateWith { project } - mutex.withLock { deploymentKeysToProject.putAll(deployments) } + try { + val deployments = projectProxy.getDeployments().associateWith { project } + mutex.withLock { deploymentKeysToProject.putAll(deployments) } + } catch (t: Throwable) { + log.error("Periodic deployment to project cache update failed for project ${project.id}", t) + } } } } @@ -162,44 +169,58 @@ class EvaluationProxy( suspend fun shutdown() = coroutineScope { log.info("Shutting down evaluation proxy.") - projectProxies.map { launch { it.value.shutdown() } }.joinAll() + projectProxies.map { scope.launch { it.value.shutdown() } }.joinAll() supervisor.cancelAndJoin() log.info("Evaluation proxy shut down.") } // Apis - suspend fun getFlagConfigs(deploymentKey: String?): List { - val projectProxy = getProjectProxy(deploymentKey) + suspend fun getFlagConfigs( + deploymentKey: String? + ): EvaluationProxyResponse { + val project = getProject(deploymentKey) + ?: return EvaluationProxyResponse.error(HttpStatusCode.Unauthorized, "Invalid deployment") + val projectProxy = getProjectProxy(project) + ?: return EvaluationProxyResponse.error(HttpStatusCode.InternalServerError, "Project proxy not found for project.") return projectProxy.getFlagConfigs(deploymentKey) } - suspend fun getCohortDescription(deploymentKey: String?, cohortId: String?): CohortDescription { - val projectProxy = getProjectProxy(deploymentKey) - return projectProxy.getCohortDescription(cohortId) - } - - suspend fun getCohortMembers(deploymentKey: String?, cohortId: String?): Set { - val projectProxy = getProjectProxy(deploymentKey) - return projectProxy.getCohortMembers(cohortId) - } - - suspend fun getCohortMembershipsForUser(deploymentKey: String?, userId: String?): Set { - val projectProxy = getProjectProxy(deploymentKey) - return projectProxy.getCohortMembershipsForUser(deploymentKey, userId) + suspend fun getCohort( + apiKey: String?, + secretKey: String?, + cohortId: String?, + lastModified: Long?, + maxCohortSize: Int? + ): EvaluationProxyResponse { + val project = getProject(apiKey, secretKey) + ?: return EvaluationProxyResponse.error(HttpStatusCode.Unauthorized, "Invalid api or secret key") + val projectProxy = getProjectProxy(project) + ?: return EvaluationProxyResponse.error(HttpStatusCode.InternalServerError, "Project proxy not found for project.") + return projectProxy.getCohort(cohortId, lastModified, maxCohortSize) } - suspend fun getCohortMembershipsForGroup(deploymentKey: String?, groupType: String?, groupName: String?): Set { - val projectProxy = getProjectProxy(deploymentKey) - return projectProxy.getCohortMembershipsForGroup(deploymentKey, groupType, groupName) + suspend fun getCohortMemberships( + deploymentKey: String?, + groupType: String?, + groupName: String? + ): EvaluationProxyResponse { + val project = getProject(deploymentKey) + ?: return EvaluationProxyResponse.error(HttpStatusCode.Unauthorized, "Invalid deployment") + val projectProxy = getProjectProxy(project) + ?: return EvaluationProxyResponse.error(HttpStatusCode.InternalServerError, "Project proxy not found for project.") + return projectProxy.getCohortMemberships(deploymentKey, groupType, groupName) } suspend fun evaluate( deploymentKey: String?, user: Map?, flagKeys: Set? = null - ): Map { - val projectProxy = getProjectProxy(deploymentKey) + ): EvaluationProxyResponse { + val project = getProject(deploymentKey) + ?: return EvaluationProxyResponse.error(HttpStatusCode.Unauthorized, "Invalid deployment") + val projectProxy = getProjectProxy(project) + ?: return EvaluationProxyResponse.error(HttpStatusCode.InternalServerError, "Project proxy not found for project.") return Metrics.with({ Evaluation }, { e -> EvaluationFailure(e) }) { projectProxy.evaluate(deploymentKey, user, flagKeys) } @@ -209,8 +230,11 @@ class EvaluationProxy( deploymentKey: String?, user: Map?, flagKeys: Set? = null - ): Map { - val projectProxy = getProjectProxy(deploymentKey) + ): EvaluationProxyResponse { + val project = getProject(deploymentKey) + ?: return EvaluationProxyResponse.error(HttpStatusCode.Unauthorized, "Invalid deployment") + val projectProxy = getProjectProxy(project) + ?: return EvaluationProxyResponse.error(HttpStatusCode.InternalServerError, "Project proxy not found for project.") return Metrics.with({ Evaluation }, { e -> EvaluationFailure(e) }) { projectProxy.evaluateV1(deploymentKey, user, flagKeys) } @@ -218,43 +242,81 @@ class EvaluationProxy( // Private - private suspend fun getProjectProxy(deploymentKey: String?): ProjectProxy { - val cachedProject = mutex.withLock { + private suspend fun getProject(deploymentKey: String?): Project? { + val project = mutex.withLock { deploymentKeysToProject[deploymentKey] } - if (cachedProject == null) { - log.debug("Unable to find project for deployment {}. Current mappings: {}", deploymentKey, deploymentKeysToProject.mapValues { it.value.id }) - throw HttpErrorResponseException(401, "Invalid deployment key.") + if (project == null) { + log.warn("Unable to find project for deployment {}. Current mappings: {}", deploymentKey, deploymentKeysToProject.mapValues { it.value.id }) + return null } - return projectProxies[cachedProject] ?: throw HttpErrorResponseException(404, "Project not found.") + return project } -} - -// Serialized Proxy Calls - -suspend fun EvaluationProxy.getSerializedCohortDescription(deploymentKey: String?, cohortId: String?): String = - json.encodeToString(getCohortDescription(deploymentKey, cohortId)) -suspend fun EvaluationProxy.getSerializedCohortMembers(deploymentKey: String?, cohortId: String?): String = - json.encodeToString(getCohortMembers(deploymentKey, cohortId)) + private suspend fun getProject(apiKey: String?, secretKey: String?): Project? { + val project = mutex.withLock { + apiKeysToProject[apiKey] + } + if (project == null) { + log.warn("Unable to find project for deployment {}. Current mappings: {}", apiKey, apiKeysToProject.mapValues { it.value.id }) + return null + } + if (project.secretKey != secretKey) { + log.warn("Secret key does not match api key for project") + return null + } + return project + } -suspend fun EvaluationProxy.getSerializedFlagConfigs(deploymentKey: String?): String = - json.encodeToString(getFlagConfigs(deploymentKey)) + private fun getProjectProxy(project: Project): ProjectProxy? { + val projectProxy = projectProxies[project] + if (projectProxy == null) { + log.warn("Unable to find proxy for project {}", project) + } + return projectProxy + } -suspend fun EvaluationProxy.getSerializedCohortMembershipsForUser(deploymentKey: String?, userId: String?): String = - json.encodeToString(getCohortMembershipsForUser(deploymentKey, userId)) + @VisibleForTesting + internal fun createProjectApi(managementKey: String): ProjectApi { + return ProjectApiV1( + configuration.managementServerUrl, + managementKey + ) + } -suspend fun EvaluationProxy.getSerializedCohortMembershipsForGroup(deploymentKey: String?, groupType: String?, groupName: String?): String = - json.encodeToString(getCohortMembershipsForGroup(deploymentKey, groupType, groupName)) + @VisibleForTesting + internal fun createProjectProxy(project: Project): ProjectProxy { + val assignmentTracker = AmplitudeAssignmentTracker( + project.apiKey, + configuration.analyticsServerUrl, + configuration.assignment + ) + val deploymentStorage = getDeploymentStorage(project.id, configuration.redis) + val cohortStorage = getCohortStorage( + project.id, + configuration.redis, + configuration.cohortSyncIntervalMillis.toDuration(DurationUnit.MILLISECONDS) + ) + return ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage + ) + } -suspend fun EvaluationProxy.serializedEvaluate( - deploymentKey: String?, - user: Map?, - flagKeys: Set? = null -): String = json.encodeToString(evaluate(deploymentKey, user, flagKeys)) + @VisibleForTesting + internal fun createDeploymentStorage(projectId: String): DeploymentStorage { + return getDeploymentStorage(projectId, configuration.redis) + } -suspend fun EvaluationProxy.serializedEvaluateV1( - deploymentKey: String?, - user: Map?, - flagKeys: Set? = null -): String = json.encodeToString(evaluateV1(deploymentKey, user, flagKeys)) + @VisibleForTesting + internal fun createCohortStorage(projectId: String): CohortStorage { + return getCohortStorage( + projectId, + configuration.redis, + configuration.cohortSyncIntervalMillis.toDuration(DurationUnit.MILLISECONDS) + ) + } +} diff --git a/core/src/main/kotlin/Metrics.kt b/core/src/main/kotlin/Metrics.kt index d187926..0356895 100644 --- a/core/src/main/kotlin/Metrics.kt +++ b/core/src/main/kotlin/Metrics.kt @@ -1,5 +1,7 @@ package com.amplitude +import com.amplitude.project.InMemoryProjectStorage + sealed class Metric sealed class FailureMetric : Metric() @@ -17,8 +19,6 @@ data object DeploymentsFetch : Metric() data class DeploymentsFetchFailure(val exception: Exception) : FailureMetric() data object FlagsFetch : Metric() data class FlagsFetchFailure(val exception: Exception) : FailureMetric() -data object CohortDescriptionFetch : Metric() -data class CohortDescriptionFetchFailure(val exception: Exception) : FailureMetric() data object CohortDownload : Metric() data class CohortDownloadFailure(val exception: Exception) : FailureMetric() data object RedisCommand : Metric() diff --git a/core/src/main/kotlin/assignment/AssignmentTracker.kt b/core/src/main/kotlin/assignment/AssignmentTracker.kt index 2c620b3..0ea0a72 100644 --- a/core/src/main/kotlin/assignment/AssignmentTracker.kt +++ b/core/src/main/kotlin/assignment/AssignmentTracker.kt @@ -37,9 +37,11 @@ internal class AmplitudeAssignmentTracker( constructor( apiKey: String, + serverUrl: String, config: AssignmentConfiguration ) : this ( amplitude = Amplitude.getInstance().apply { + setServerUrl(serverUrl) setEventUploadThreshold(config.eventUploadThreshold) setEventUploadPeriodMillis(config.eventUploadPeriodMillis) useBatchMode(config.useBatchMode) diff --git a/core/src/main/kotlin/cohort/Cohort.kt b/core/src/main/kotlin/cohort/Cohort.kt new file mode 100644 index 0000000..4a51811 --- /dev/null +++ b/core/src/main/kotlin/cohort/Cohort.kt @@ -0,0 +1,15 @@ +package com.amplitude.cohort + +internal const val USER_GROUP_TYPE = "User" + +data class Cohort( + val id: String, + val groupType: String, + val size: Int, + val lastModified: Long, + val members: Set, +) { + override fun toString(): String { + return "Cohort(id='$id', groupType='$groupType', size=$size, lastModified=$lastModified)" + } +} diff --git a/core/src/main/kotlin/cohort/CohortApi.kt b/core/src/main/kotlin/cohort/CohortApi.kt index 841f9eb..a378f04 100644 --- a/core/src/main/kotlin/cohort/CohortApi.kt +++ b/core/src/main/kotlin/cohort/CohortApi.kt @@ -1,133 +1,109 @@ package com.amplitude.cohort -import com.amplitude.util.HttpErrorException +import com.amplitude.EVALUATION_PROXY_VERSION +import com.amplitude.util.RetryConfig import com.amplitude.util.get import com.amplitude.util.json import com.amplitude.util.logger import com.amplitude.util.retry import io.ktor.client.HttpClient import io.ktor.client.call.body +import io.ktor.client.engine.HttpClientEngine import io.ktor.client.engine.okhttp.OkHttp import io.ktor.client.plugins.HttpTimeout import io.ktor.client.request.headers import io.ktor.client.request.parameter -import io.ktor.client.statement.bodyAsChannel +import io.ktor.client.statement.bodyAsText import io.ktor.http.HttpStatusCode -import io.ktor.utils.io.jvm.javaio.toInputStream -import kotlinx.coroutines.delay -import kotlinx.serialization.SerialName +import io.ktor.util.logging.Logger +import io.ktor.util.toByteArray import kotlinx.serialization.Serializable -import org.apache.commons.csv.CSVFormat -import org.apache.commons.csv.CSVParser -import java.lang.IllegalArgumentException import java.util.Base64 -@Serializable -private data class SerialCohortInfoResponse( - @SerialName("cohort_id") val cohortId: String, - @SerialName("app_id") val appId: Int = 0, - @SerialName("org_id") val orgId: Int = 0, - @SerialName("name") val name: String? = null, - @SerialName("size") val size: Int = Int.MAX_VALUE, - @SerialName("description") val description: String? = null, - @SerialName("last_computed") val lastComputed: Long = 0, - @SerialName("group_type") val groupType: String = USER_GROUP_TYPE +internal class CohortTooLargeException(cohortId: String, maxCohortSize: Int) : RuntimeException( + "Cohort $cohortId exceeds the maximum cohort size defined in the SDK configuration $maxCohortSize" ) -@Serializable -private data class GetCohortAsyncResponse( - @SerialName("cohort_id") val cohortId: String, - @SerialName("request_id") val requestId: String +internal class CohortNotModifiedException(cohortId: String) : RuntimeException( + "Cohort $cohortId has not been modified." ) +@Serializable +data class GetCohortResponse( + private val cohortId: String, + private val lastModified: Long, + private val size: Int, + private val groupType: String, + private val memberIds: Set? = null +) { + fun toCohort() = Cohort( + id = cohortId, + groupType = groupType, + size = size, + lastModified = lastModified, + members = memberIds ?: emptySet() + ) + + companion object { + fun fromCohort(cohort: Cohort) = GetCohortResponse( + cohortId = cohort.id, + lastModified = cohort.lastModified, + size = cohort.size, + groupType = cohort.groupType, + memberIds = cohort.members + ) + } +} + internal interface CohortApi { - suspend fun getCohortDescription(cohortId: String): CohortDescription - suspend fun getCohortMembers(cohortDescription: CohortDescription): Set + suspend fun getCohort(cohortId: String, lastModified: Long?, maxCohortSize: Int): Cohort } -internal class CohortApiV5( +internal class CohortApiV1( private val serverUrl: String, apiKey: String, - secretKey: String + secretKey: String, + engine: HttpClientEngine = OkHttp.create(), + private val retryConfig: RetryConfig = RetryConfig() ) : CohortApi { companion object { val log by logger() } - private val csvFormat = CSVFormat.RFC4180.builder().setHeader().build() - private val basicAuth = Base64.getEncoder().encodeToString("$apiKey:$secretKey".toByteArray(Charsets.UTF_8)) - private val client = HttpClient(OkHttp) { + + private val token = Base64.getEncoder().encodeToString("$apiKey:$secretKey".toByteArray(Charsets.UTF_8)) + private val client = HttpClient(engine) { install(HttpTimeout) { socketTimeoutMillis = 30000 } } - override suspend fun getCohortDescription(cohortId: String): CohortDescription { - val response = retry(onFailure = { e -> log.info("Get cohort descriptions failed: $e") }) { - client.get(serverUrl, "/api/3/cohorts/info/$cohortId") { - headers { set("Authorization", "Basic $basicAuth") } - } - } - val serialDescription = json.decodeFromString(response.body()) - return CohortDescription( - id = serialDescription.cohortId, - lastComputed = serialDescription.lastComputed, - size = serialDescription.size, - groupType = serialDescription.groupType - ) - } - - override suspend fun getCohortMembers(cohortDescription: CohortDescription): Set { - log.debug("getCohortMembers: start - cohortDescription={}", cohortDescription) - // Initiate async cohort download - val initialResponse = retry(onFailure = { e -> log.error("Cohort download request failed: $e") }) { - client.get(serverUrl, "/api/5/cohorts/request/${cohortDescription.id}") { - headers { set("Authorization", "Basic $basicAuth") } - parameter("lastComputed", cohortDescription.lastComputed) - } - } - val getCohortResponse = json.decodeFromString(initialResponse.body()) - log.debug("getCohortMembers: poll for status - cohortId=${cohortDescription.id}, requestId=${getCohortResponse.requestId}") - // Poll until the cohort is ready for download - while (true) { - val statusResponse = retry(onFailure = { e -> log.error("Cohort request status failed: $e") }) { - client.get(serverUrl, "/api/5/cohorts/request-status/${getCohortResponse.requestId}") { - headers { set("Authorization", "Basic $basicAuth") } + override suspend fun getCohort(cohortId: String, lastModified: Long?, maxCohortSize: Int): Cohort { + log.debug("getCohortMembers({}): start - maxCohortSize={}, lastModified={}", cohortId, maxCohortSize, lastModified) + val response = retry( + config = retryConfig, + onFailure = { e -> log.error("Cohort download failed: $e") }, + acceptCodes = setOf(HttpStatusCode.NoContent, HttpStatusCode.PayloadTooLarge) + ) { + client.get( + url = serverUrl, + path = "sdk/v1/cohort/$cohortId" + ) { + parameter("maxCohortSize", "$maxCohortSize") + if (lastModified != null) { + parameter("lastModified", "$lastModified") + } + headers { + set("Authorization", "Basic $token") + set("X-Amp-Exp-Library", "evaluation-proxy/$EVALUATION_PROXY_VERSION") } } - log.trace("getCohortMembers: cohortId={}, status={}", cohortDescription.id, statusResponse.status) - if (statusResponse.status == HttpStatusCode.OK) { - break - } else if (statusResponse.status != HttpStatusCode.Accepted) { - throw HttpErrorException(statusResponse.status, statusResponse) - } - delay(5000) } - // Download the cohort - log.debug("getCohortMembers: download cohort - cohortId=${cohortDescription.id}, requestId=${getCohortResponse.requestId}") - val downloadResponse = retry(onFailure = { e -> log.error("Cohort file download failed: $e") }) { - client.get(serverUrl, "/api/5/cohorts/request/${getCohortResponse.requestId}/file") { - headers { set("Authorization", "Basic $basicAuth") } - } + log.debug("getCohortMembers({}): status={}", cohortId, response.status) + when (response.status) { + HttpStatusCode.NoContent -> throw CohortNotModifiedException(cohortId) + HttpStatusCode.PayloadTooLarge -> throw CohortTooLargeException(cohortId, maxCohortSize) + else -> return json.decodeFromString(response.body()).toCohort() } - // Parse the csv response - val csv = CSVParser.parse(downloadResponse.bodyAsChannel().toInputStream(), Charsets.UTF_8, csvFormat) - return if (cohortDescription.groupType == USER_GROUP_TYPE) { - csv.map { it.get("user_id") }.filterNot { it.isNullOrEmpty() }.toSet() - } else { - csv.map { - try { - // CSV returned from API has all strings prefixed with a tab character - it.get("\tgroup_value") - } catch (e: IllegalArgumentException) { - it.get("group_value") - } - }.filterNot { - it.isNullOrEmpty() - }.map { - // CSV returned from API has all strings prefixed with a tab character - it.removePrefix("\t") - }.toSet() - }.also { log.debug("getCohortMembers: end - resultSize=${it.size}") } } } diff --git a/core/src/main/kotlin/cohort/CohortDescription.kt b/core/src/main/kotlin/cohort/CohortDescription.kt deleted file mode 100644 index 357ca26..0000000 --- a/core/src/main/kotlin/cohort/CohortDescription.kt +++ /dev/null @@ -1,13 +0,0 @@ -package com.amplitude.cohort - -import kotlinx.serialization.Serializable - -const val USER_GROUP_TYPE = "User" - -@Serializable -data class CohortDescription( - val id: String, - val lastComputed: Long, - val size: Int, - val groupType: String = USER_GROUP_TYPE -) diff --git a/core/src/main/kotlin/cohort/CohortLoader.kt b/core/src/main/kotlin/cohort/CohortLoader.kt index da5eecc..8bd2683 100644 --- a/core/src/main/kotlin/cohort/CohortLoader.kt +++ b/core/src/main/kotlin/cohort/CohortLoader.kt @@ -1,20 +1,17 @@ package com.amplitude.cohort -import com.amplitude.CohortDescriptionFetch -import com.amplitude.CohortDescriptionFetchFailure import com.amplitude.CohortDownload import com.amplitude.CohortDownloadFailure import com.amplitude.Metrics +import com.amplitude.util.Loader import com.amplitude.util.logger import kotlinx.coroutines.Job import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.joinAll import kotlinx.coroutines.launch -import kotlinx.coroutines.sync.Mutex -import kotlinx.coroutines.sync.withLock internal class CohortLoader( - @Volatile var maxCohortSize: Int, + private val maxCohortSize: Int, private val cohortApi: CohortApi, private val cohortStorage: CohortStorage ) { @@ -22,8 +19,8 @@ internal class CohortLoader( companion object { val log by logger() } - private val jobsMutex = Mutex() - private val jobs = mutableMapOf() + + private val loader = Loader() suspend fun loadCohorts(cohortIds: Set) = coroutineScope { val jobs = mutableListOf() @@ -33,31 +30,28 @@ internal class CohortLoader( jobs.joinAll() } - private suspend fun loadCohort(cohortId: String) = coroutineScope { + private suspend fun loadCohort(cohortId: String) { log.trace("loadCohort: start - cohortId={}", cohortId) - val networkCohort = Metrics.with( - { CohortDescriptionFetch }, - { e -> CohortDescriptionFetchFailure(e) } - ) { - cohortApi.getCohortDescription(cohortId) - } val storageCohort = cohortStorage.getCohortDescription(cohortId) - val shouldDownloadCohort = networkCohort.size <= maxCohortSize && - networkCohort.lastComputed > (storageCohort?.lastComputed ?: -1) - if (shouldDownloadCohort) { - jobsMutex.withLock { - jobs.getOrPut(cohortId) { - launch { - log.info("Downloading cohort. $networkCohort") - val cohortMembers = Metrics.with({ CohortDownload }, { e -> CohortDownloadFailure(e) }) { - cohortApi.getCohortMembers(networkCohort) - } - cohortStorage.putCohort(networkCohort, cohortMembers) - jobsMutex.withLock { jobs.remove(cohortId) } - log.info("Cohort download complete. $networkCohort") + loader.load(cohortId) { + try { + val cohort = Metrics.with({ CohortDownload }, { e -> CohortDownloadFailure(e) }) { + try { + cohortApi.getCohort(cohortId, storageCohort?.lastModified, maxCohortSize) + } catch (e: CohortNotModifiedException) { + log.debug("loadCohort: cohort not modified - cohortId={}", cohortId) + null } } - }.join() + if (cohort != null) { + cohortStorage.putCohort(cohort) + } + log.info("Cohort download complete. {}", cohort ?: cohortId) + } catch (t: Throwable) { + // Don't throw if we fail to download the cohort. We + // prefer to continue to update flags. + log.error("Cohort download failed. $cohortId", t) + } } log.trace("loadCohort: end - cohortId={}", cohortId) } diff --git a/core/src/main/kotlin/cohort/CohortStorage.kt b/core/src/main/kotlin/cohort/CohortStorage.kt index 0a9d8be..bd01e0e 100644 --- a/core/src/main/kotlin/cohort/CohortStorage.kt +++ b/core/src/main/kotlin/cohort/CohortStorage.kt @@ -1,26 +1,53 @@ package com.amplitude.cohort import com.amplitude.RedisConfiguration +import com.amplitude.util.Redis import com.amplitude.util.RedisConnection import com.amplitude.util.RedisKey import com.amplitude.util.json +import com.amplitude.util.logger import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable import kotlinx.serialization.encodeToString import kotlin.time.Duration +@Serializable +internal data class CohortDescription( + @SerialName("cohortId") val id: String, + val groupType: String, + val size: Int, + val lastModified: Long +) { + fun toCohort(members: Set): Cohort { + return Cohort( + id = id, + groupType = groupType, + size = size, + lastModified = lastModified, + members = members + ) + } +} + +internal fun Cohort.toCohortDescription(): CohortDescription { + return CohortDescription( + id = id, + groupType = groupType, + size = size, + lastModified = lastModified + ) +} + internal interface CohortStorage { + suspend fun getCohort(cohortId: String): Cohort? + suspend fun getCohorts(): Map suspend fun getCohortDescription(cohortId: String): CohortDescription? suspend fun getCohortDescriptions(): Map - suspend fun getCohortMembers(cohortDescription: CohortDescription): Set? - suspend fun getCohortMembershipsForUser(userId: String, cohortIds: Set? = null): Set - suspend fun getCohortMembershipsForGroup( - groupType: String, - groupName: String, - cohortIds: Set? = null - ): Set - suspend fun putCohort(description: CohortDescription, members: Set) - suspend fun removeCohort(cohortDescription: CohortDescription) + suspend fun getCohortMemberships(groupType: String, groupName: String, cohortIds: Set): Set + suspend fun putCohort(cohort: Cohort) + suspend fun deleteCohort(description: CohortDescription) } internal fun getCohortStorage(projectId: String, redisConfiguration: RedisConfiguration?, ttl: Duration): CohortStorage { @@ -40,63 +67,47 @@ internal fun getCohortStorage(projectId: String, redisConfiguration: RedisConfig internal class InMemoryCohortStorage : CohortStorage { - private class Cohort( - val description: CohortDescription, - val members: Set - ) - private val lock = Mutex() private val cohorts = mutableMapOf() - override suspend fun getCohortDescription(cohortId: String): CohortDescription? { - return lock.withLock { cohorts[cohortId]?.description } - } - - override suspend fun getCohortDescriptions(): Map { - return lock.withLock { cohorts.mapValues { it.value.description } } + override suspend fun getCohort(cohortId: String): Cohort? { + return lock.withLock { cohorts[cohortId] } } - override suspend fun getCohortMembers(cohortDescription: CohortDescription): Set? { - return lock.withLock { cohorts[cohortDescription.id]?.members } + override suspend fun getCohorts(): Map { + return lock.withLock { cohorts.toMap() } } - override suspend fun putCohort(description: CohortDescription, members: Set) { - return lock.withLock { cohorts[description.id] = Cohort(description, members) } + override suspend fun getCohortDescription(cohortId: String): CohortDescription? { + return lock.withLock { cohorts[cohortId] }?.toCohortDescription() } - override suspend fun removeCohort(cohortDescription: CohortDescription) { - lock.withLock { cohorts.remove(cohortDescription.id) } + override suspend fun getCohortDescriptions(): Map { + return lock.withLock { cohorts.toMap() }.mapValues { it.value.toCohortDescription() } } - override suspend fun getCohortMembershipsForUser(userId: String, cohortIds: Set?): Set { - return lock.withLock { - (cohortIds ?: cohorts.keys).mapNotNull { id -> - when (cohorts[id]?.members?.contains(userId)) { - true -> id - else -> null + override suspend fun getCohortMemberships(groupType: String, groupName: String, cohortIds: Set): Set { + val result = mutableSetOf() + lock.withLock { + for (cohortId in cohortIds) { + val cohort = cohorts[cohortId] ?: continue + if (cohort.groupType != groupType) { + continue + } + if (cohort.members.contains(groupName)) { + result.add(cohortId) } - }.toSet() + } } + return result } - override suspend fun getCohortMembershipsForGroup( - groupType: String, - groupName: String, - cohortIds: Set? - ): Set { - return lock.withLock { - (cohortIds ?: cohorts.keys).mapNotNull { id -> - val cohort = cohorts[id] - if (cohort?.description?.groupType != groupType) { - null - } else { - when (cohort.members.contains(groupName)) { - true -> id - else -> null - } - } - }.toSet() - } + override suspend fun putCohort(cohort: Cohort) { + lock.withLock { cohorts[cohort.id] = cohort } + } + + override suspend fun deleteCohort(description: CohortDescription) { + lock.withLock { cohorts.remove(description.id) } } } @@ -104,10 +115,38 @@ internal class RedisCohortStorage( private val projectId: String, private val ttl: Duration, private val prefix: String, - private val redis: RedisConnection, - private val readOnlyRedis: RedisConnection + private val redis: Redis, + private val readOnlyRedis: Redis ) : CohortStorage { + companion object { + val log by logger() + } + + override suspend fun getCohort(cohortId: String): Cohort? { + val description = getCohortDescription(cohortId) ?: return null + val members = getCohortMembers(cohortId, description.groupType, description.lastModified) + if (members == null) { + log.error("Cohort description found, but members missing. $description") + return null + } + return description.toCohort(members) + } + + override suspend fun getCohorts(): Map { + val result = mutableMapOf() + val cohortDescriptions = getCohortDescriptions() + for (description in cohortDescriptions.values) { + val members = getCohortMembers(description.id, description.groupType, description.lastModified) + if (members == null) { + log.error("Cohort description found, but members missing. $description") + continue + } + result[description.id] = description.toCohort(members) + } + return result + } + override suspend fun getCohortDescription(cohortId: String): CohortDescription? { val jsonEncodedDescription = redis.hget(RedisKey.CohortDescriptions(prefix, projectId), cohortId) ?: return null return json.decodeFromString(jsonEncodedDescription) @@ -118,42 +157,31 @@ internal class RedisCohortStorage( return jsonEncodedDescriptions?.mapValues { json.decodeFromString(it.value) } ?: mapOf() } - override suspend fun getCohortMembers(cohortDescription: CohortDescription): Set? { - return redis.smembers(RedisKey.CohortMembers(prefix, projectId, cohortDescription)) - } - - override suspend fun getCohortMembershipsForUser(userId: String, cohortIds: Set?): Set { - val descriptions = getCohortDescriptions() - val memberships = mutableSetOf() - for (description in descriptions.values) { - if (cohortIds != null && !cohortIds.contains(description.id)) { - continue - } - // High volume, use read connection - val isMember = readOnlyRedis.sismember(RedisKey.CohortMembers(prefix, projectId, description), userId) - if (isMember) { - memberships += description.id - } - } - return memberships - } - - override suspend fun getCohortMembershipsForGroup( + override suspend fun getCohortMemberships( groupType: String, groupName: String, - cohortIds: Set? + cohortIds: Set ): Set { val descriptions = getCohortDescriptions() val memberships = mutableSetOf() for (description in descriptions.values) { - if (cohortIds != null && !cohortIds.contains(description.id)) { + if (!cohortIds.contains(description.id)) { continue } if (description.groupType != groupType) { continue } // High volume, use read connection - val isMember = readOnlyRedis.sismember(RedisKey.CohortMembers(prefix, projectId, description), groupName) + val isMember = readOnlyRedis.sismember( + RedisKey.CohortMembers( + prefix, + projectId, + description.id, + description.groupType, + description.lastModified + ), + groupName + ) if (isMember) { memberships += description.id } @@ -161,20 +189,55 @@ internal class RedisCohortStorage( return memberships } - override suspend fun putCohort(description: CohortDescription, members: Set) { + override suspend fun putCohort(cohort: Cohort) { + val description = cohort.toCohortDescription() val jsonEncodedDescription = json.encodeToString(description) val existingDescription = getCohortDescription(description.id) - if ((existingDescription?.lastComputed ?: 0L) < description.lastComputed) { - redis.sadd(RedisKey.CohortMembers(prefix, projectId, description), members) + if ((existingDescription?.lastModified ?: 0L) < description.lastModified) { + redis.sadd( + RedisKey.CohortMembers( + prefix, + projectId, + description.id, + description.groupType, + description.lastModified + ), + cohort.members + ) redis.hset(RedisKey.CohortDescriptions(prefix, projectId), mapOf(description.id to jsonEncodedDescription)) if (existingDescription != null) { - redis.expire(RedisKey.CohortMembers(prefix, projectId, existingDescription), ttl) + redis.expire( + RedisKey.CohortMembers( + prefix, + projectId, + existingDescription.id, + existingDescription.groupType, + existingDescription.lastModified + ), + ttl + ) } } } - override suspend fun removeCohort(cohortDescription: CohortDescription) { - redis.hdel(RedisKey.CohortDescriptions(prefix, projectId), cohortDescription.id) - redis.del(RedisKey.CohortMembers(prefix, projectId, cohortDescription)) + override suspend fun deleteCohort(description: CohortDescription) { + redis.hdel(RedisKey.CohortDescriptions(prefix, projectId), description.id) + redis.del( + RedisKey.CohortMembers( + prefix, + projectId, + description.id, + description.groupType, + description.lastModified + ) + ) + } + + private suspend fun getCohortMembers( + cohortId: String, + cohortGroupType: String, + cohortLastModified: Long + ): Set? { + return redis.smembers(RedisKey.CohortMembers(prefix, projectId, cohortId, cohortGroupType, cohortLastModified)) } } diff --git a/core/src/main/kotlin/deployment/DeploymentApi.kt b/core/src/main/kotlin/deployment/DeploymentApi.kt index a0b3ab8..7848ff6 100644 --- a/core/src/main/kotlin/deployment/DeploymentApi.kt +++ b/core/src/main/kotlin/deployment/DeploymentApi.kt @@ -1,13 +1,15 @@ package com.amplitude.deployment -import com.amplitude.VERSION +import com.amplitude.EVALUATION_PROXY_VERSION import com.amplitude.experiment.evaluation.EvaluationFlag +import com.amplitude.util.RetryConfig import com.amplitude.util.get import com.amplitude.util.json import com.amplitude.util.logger import com.amplitude.util.retry import io.ktor.client.HttpClient import io.ktor.client.call.body +import io.ktor.client.engine.HttpClientEngine import io.ktor.client.engine.okhttp.OkHttp import io.ktor.client.request.headers import io.ktor.client.request.parameter @@ -16,24 +18,29 @@ internal interface DeploymentApi { suspend fun getFlagConfigs(deploymentKey: String): List } -internal class DeploymentApiV1( - private val serverUrl: String +internal class DeploymentApiV2( + private val serverUrl: String, + engine: HttpClientEngine = OkHttp.create(), + private val retryConfig: RetryConfig = RetryConfig() ) : DeploymentApi { companion object { val log by logger() } - private val client = HttpClient(OkHttp) + private val client = HttpClient(engine) override suspend fun getFlagConfigs(deploymentKey: String): List { log.trace("getFlagConfigs: start - deploymentKey=$deploymentKey") - val response = retry(onFailure = { e -> log.error("Get flag configs failed: $e") }) { + val response = retry( + config = retryConfig, + onFailure = { e -> log.error("Get flag configs failed: $e") } + ) { client.get(serverUrl, "/sdk/v2/flags") { parameter("v", "0") headers { set("Authorization", "Api-Key $deploymentKey") - set("X-Amp-Exp-Library", "experiment-local-proxy/$VERSION") + set("X-Amp-Exp-Library", "evaluation-proxy/$EVALUATION_PROXY_VERSION") } } } diff --git a/core/src/main/kotlin/deployment/DeploymentLoader.kt b/core/src/main/kotlin/deployment/DeploymentLoader.kt index df4348b..e221704 100644 --- a/core/src/main/kotlin/deployment/DeploymentLoader.kt +++ b/core/src/main/kotlin/deployment/DeploymentLoader.kt @@ -4,13 +4,10 @@ import com.amplitude.FlagsFetch import com.amplitude.FlagsFetchFailure import com.amplitude.Metrics import com.amplitude.cohort.CohortLoader +import com.amplitude.util.Loader import com.amplitude.util.getAllCohortIds import com.amplitude.util.logger -import kotlinx.coroutines.Job -import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.launch -import kotlinx.coroutines.sync.Mutex -import kotlinx.coroutines.sync.withLock internal class DeploymentLoader( private val deploymentApi: DeploymentApi, @@ -22,42 +19,35 @@ internal class DeploymentLoader( val log by logger() } - private val jobsMutex = Mutex() - private val jobs = mutableMapOf() + private val loader = Loader() - suspend fun loadDeployment(deploymentKey: String) = coroutineScope { + suspend fun loadDeployment(deploymentKey: String) { log.trace("loadDeployment: - deploymentKey=$deploymentKey") - jobsMutex.withLock { - jobs.getOrPut(deploymentKey) { - launch { - val networkFlags = Metrics.with({ FlagsFetch }, { e -> FlagsFetchFailure(e) }) { - deploymentApi.getFlagConfigs(deploymentKey) - } - // Remove flags that are no longer deployed. - val networkFlagKeys = networkFlags.map { it.key }.toSet() - val storageFlagKeys = deploymentStorage.getAllFlags(deploymentKey).map { it.key }.toSet() - for (flagToRemove in storageFlagKeys - networkFlagKeys) { - log.debug("Removing flag: $flagToRemove") - deploymentStorage.removeFlag(deploymentKey, flagToRemove) - } - // Load cohorts for each flag independently then put the - // flag into storage. - for (flag in networkFlags) { - val cohortIds = flag.getAllCohortIds() - if (cohortIds.isNotEmpty()) { - launch { - cohortLoader.loadCohorts(cohortIds) - deploymentStorage.putFlag(deploymentKey, flag) - } - } else { - deploymentStorage.putFlag(deploymentKey, flag) - } + loader.load(deploymentKey) { + val networkFlags = Metrics.with({ FlagsFetch }, { e -> FlagsFetchFailure(e) }) { + deploymentApi.getFlagConfigs(deploymentKey) + } + // Remove flags that are no longer deployed. + val networkFlagKeys = networkFlags.map { it.key }.toSet() + val storageFlagKeys = deploymentStorage.getAllFlags(deploymentKey).map { it.key }.toSet() + for (flagToRemove in storageFlagKeys - networkFlagKeys) { + log.debug("Removing flag: $flagToRemove") + deploymentStorage.removeFlag(deploymentKey, flagToRemove) + } + // Load cohorts for each flag independently then put the + // flag into storage. + for (flag in networkFlags) { + val cohortIds = flag.getAllCohortIds() + if (cohortIds.isNotEmpty()) { + launch { + cohortLoader.loadCohorts(cohortIds) + deploymentStorage.putFlag(deploymentKey, flag) } - // Remove the job - jobsMutex.withLock { jobs.remove(deploymentKey) } + } else { + deploymentStorage.putFlag(deploymentKey, flag) } } - }.join() + } log.trace("loadDeployment: end - deploymentKey=$deploymentKey") } } diff --git a/core/src/main/kotlin/deployment/DeploymentRunner.kt b/core/src/main/kotlin/deployment/DeploymentRunner.kt index 1a5d751..6101d77 100644 --- a/core/src/main/kotlin/deployment/DeploymentRunner.kt +++ b/core/src/main/kotlin/deployment/DeploymentRunner.kt @@ -11,11 +11,11 @@ import kotlinx.coroutines.delay import kotlinx.coroutines.launch internal class DeploymentRunner( - @Volatile var configuration: Configuration, + private val configuration: Configuration, private val deploymentKey: String, - private val deploymentApi: DeploymentApi, + private val cohortLoader: CohortLoader, private val deploymentStorage: DeploymentStorage, - private val cohortLoader: CohortLoader + private val deploymentLoader: DeploymentLoader, ) { companion object { @@ -24,26 +24,42 @@ internal class DeploymentRunner( private val supervisor = SupervisorJob() private val scope = CoroutineScope(supervisor) - private val deploymentLoader = DeploymentLoader(deploymentApi, deploymentStorage, cohortLoader) suspend fun start() { log.trace("start: - deploymentKey=$deploymentKey") - deploymentLoader.loadDeployment(deploymentKey) + val job = scope.launch { + try { + deploymentLoader.loadDeployment(deploymentKey) + } catch (t: Throwable) { + // Catch failure and continue to run pollers. Assume deployment + // load will eventually succeed. + log.error("Load failed for deployment $deploymentKey", t) + } + } // Periodic flag config loader scope.launch { while (true) { delay(configuration.flagSyncIntervalMillis) - deploymentLoader.loadDeployment(deploymentKey) + try { + deploymentLoader.loadDeployment(deploymentKey) + } catch (t: Throwable) { + log.error("Periodic deployment load failed for deployment $deploymentKey", t) + } } } // Periodic cohort refresher scope.launch { while (true) { delay(configuration.cohortSyncIntervalMillis) - val cohortIds = deploymentStorage.getAllFlags(deploymentKey).values.getAllCohortIds() - cohortLoader.loadCohorts(cohortIds) + try { + val cohortIds = deploymentStorage.getAllFlags(deploymentKey).values.getAllCohortIds() + cohortLoader.loadCohorts(cohortIds) + } catch (t: Throwable) { + log.error("Periodic cohort load failed for deployment $deploymentKey", t) + } } } + job.join() } suspend fun stop() { diff --git a/core/src/main/kotlin/deployment/DeploymentStorage.kt b/core/src/main/kotlin/deployment/DeploymentStorage.kt index 6164fcb..ca52d7b 100644 --- a/core/src/main/kotlin/deployment/DeploymentStorage.kt +++ b/core/src/main/kotlin/deployment/DeploymentStorage.kt @@ -129,6 +129,7 @@ internal class RedisDeploymentStorage( override suspend fun removeDeployment(deploymentKey: String) { redis.hdel(RedisKey.Deployments(prefix, projectId), deploymentKey) + removeAllFlags(deploymentKey) } override suspend fun getFlag(deploymentKey: String, flagKey: String): EvaluationFlag? { diff --git a/core/src/main/kotlin/project/ProjectApi.kt b/core/src/main/kotlin/project/ProjectApi.kt index 43f300e..b481c3d 100644 --- a/core/src/main/kotlin/project/ProjectApi.kt +++ b/core/src/main/kotlin/project/ProjectApi.kt @@ -10,6 +10,7 @@ import com.amplitude.util.logger import com.amplitude.util.retry import io.ktor.client.HttpClient import io.ktor.client.call.body +import io.ktor.client.engine.HttpClientEngine import io.ktor.client.engine.okhttp.OkHttp import io.ktor.client.plugins.HttpTimeout import io.ktor.client.request.headers @@ -18,7 +19,7 @@ import kotlinx.serialization.Serializable private const val MANAGEMENT_SERVER_URL = "https://experiment.amplitude.com" @Serializable -private data class DeploymentsResponse( +internal data class DeploymentsResponse( val deployments: List ) @@ -40,13 +41,17 @@ internal interface ProjectApi { suspend fun getDeployments(): List } -internal class ProjectApiV1(private val managementKey: String) : ProjectApi { +internal class ProjectApiV1( + private val serverUrl: String, + private val managementKey: String, + engine: HttpClientEngine = OkHttp.create() +) : ProjectApi { companion object { val log by logger() } - private val client = HttpClient(OkHttp) { + private val client = HttpClient(engine) { install(HttpTimeout) { socketTimeoutMillis = 30000 } @@ -56,7 +61,9 @@ internal class ProjectApiV1(private val managementKey: String) : ProjectApi { Metrics.with({ DeploymentsFetch }, { e -> DeploymentsFetchFailure(e) }) { log.trace("getDeployments: start") val response = retry(onFailure = { e -> log.error("Get deployments failed: $e") }) { - client.get(MANAGEMENT_SERVER_URL, "/api/1/deployments") { + client.get( + url = serverUrl, + path = "api/1/deployments") { headers { set("Authorization", "Bearer $managementKey") set("Accept", "application/json") diff --git a/core/src/main/kotlin/project/ProjectProxy.kt b/core/src/main/kotlin/project/ProjectProxy.kt index fa047ce..8dd4850 100644 --- a/core/src/main/kotlin/project/ProjectProxy.kt +++ b/core/src/main/kotlin/project/ProjectProxy.kt @@ -1,24 +1,28 @@ package com.amplitude.project import com.amplitude.Configuration -import com.amplitude.HttpErrorResponseException +import com.amplitude.EvaluationProxyResponse import com.amplitude.assignment.Assignment import com.amplitude.assignment.AssignmentTracker -import com.amplitude.cohort.CohortApiV5 -import com.amplitude.cohort.CohortDescription +import com.amplitude.cohort.CohortApiV1 +import com.amplitude.cohort.CohortLoader import com.amplitude.cohort.CohortStorage +import com.amplitude.cohort.GetCohortResponse import com.amplitude.cohort.USER_GROUP_TYPE -import com.amplitude.deployment.DeploymentApiV1 +import com.amplitude.deployment.DeploymentApiV2 +import com.amplitude.deployment.DeploymentLoader import com.amplitude.deployment.DeploymentStorage import com.amplitude.experiment.evaluation.EvaluationEngineImpl -import com.amplitude.experiment.evaluation.EvaluationFlag import com.amplitude.experiment.evaluation.EvaluationVariant import com.amplitude.experiment.evaluation.topologicalSort import com.amplitude.util.getGroupedCohortIds +import com.amplitude.util.json import com.amplitude.util.logger import com.amplitude.util.toEvaluationContext +import io.ktor.http.HttpStatusCode import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.launch +import kotlinx.serialization.encodeToString internal class ProjectProxy( private val project: Project, @@ -34,15 +38,18 @@ internal class ProjectProxy( private val engine = EvaluationEngineImpl() - private val projectApi = ProjectApiV1(project.managementKey) - private val deploymentApi = DeploymentApiV1(configuration.serverUrl) - private val cohortApi = CohortApiV5(configuration.cohortServerUrl, project.apiKey, project.secretKey) + private val projectApi = ProjectApiV1(configuration.managementServerUrl, project.managementKey) + private val deploymentApi = DeploymentApiV2(configuration.serverUrl) + private val cohortApi = CohortApiV1(configuration.cohortServerUrl, project.apiKey, project.secretKey) + private val cohortLoader = CohortLoader(configuration.maxCohortSize, cohortApi, cohortStorage) + private val deploymentLoader = DeploymentLoader(deploymentApi, deploymentStorage, cohortLoader) private val projectRunner = ProjectRunner( + project, configuration, projectApi, - deploymentApi, + deploymentLoader, deploymentStorage, - cohortApi, + cohortLoader, cohortStorage ) @@ -56,70 +63,85 @@ internal class ProjectProxy( projectRunner.stop() } - suspend fun getFlagConfigs(deploymentKey: String?): List { + suspend fun getFlagConfigs(deploymentKey: String?): EvaluationProxyResponse { if (deploymentKey.isNullOrEmpty()) { - throw HttpErrorResponseException(status = 401, message = "Invalid deployment.") + return EvaluationProxyResponse.error(HttpStatusCode.Unauthorized, "Invalid deployment") } - return deploymentStorage.getAllFlags(deploymentKey).values.toList() + val result = deploymentStorage.getAllFlags(deploymentKey).values.toList() + return EvaluationProxyResponse.error(HttpStatusCode.OK, json.encodeToString(result)) } - suspend fun getCohortDescription(cohortId: String?): CohortDescription { + suspend fun getCohort(cohortId: String?, lastModified: Long?, maxCohortSize: Int?): EvaluationProxyResponse { if (cohortId.isNullOrEmpty()) { - throw HttpErrorResponseException(status = 404, message = "Cohort not found.") - } - return cohortStorage.getCohortDescription(cohortId) - ?: throw HttpErrorResponseException(status = 404, message = "Cohort not found.") - } - - suspend fun getCohortMembers(cohortId: String?): Set { - if (cohortId.isNullOrEmpty()) { - throw HttpErrorResponseException(status = 404, message = "Cohort not found.") - } - val cohortDescription = cohortStorage.getCohortDescription(cohortId) - ?: throw HttpErrorResponseException(status = 404, message = "Cohort not found.") - return cohortStorage.getCohortMembers(cohortDescription) - ?: throw HttpErrorResponseException(status = 404, message = "Cohort not found.") + return EvaluationProxyResponse.error(HttpStatusCode.NotFound, "Cohort not found") + } + val cohortDescription = cohortStorage.getCohort(cohortId) + ?: return EvaluationProxyResponse.error(HttpStatusCode.NotFound, "Cohort not found") + if (cohortDescription.size > (maxCohortSize ?: Int.MAX_VALUE)) { + return EvaluationProxyResponse.error( + HttpStatusCode.PayloadTooLarge, + "Cohort $cohortId sized ${cohortDescription.size} is greater than max cohort size $maxCohortSize" + ) + } + if (cohortDescription.lastModified == lastModified) { + return EvaluationProxyResponse.error(HttpStatusCode.NoContent, "Cohort not modified") + } + val cohort = cohortStorage.getCohort(cohortId) + ?: return EvaluationProxyResponse.error(HttpStatusCode.NotFound, "Cohort members not found") + return EvaluationProxyResponse.json(HttpStatusCode.OK, GetCohortResponse.fromCohort(cohort)) } - suspend fun getCohortMembershipsForUser(deploymentKey: String?, userId: String?): Set { + suspend fun getCohortMemberships(deploymentKey: String?, groupType: String?, groupName: String?): EvaluationProxyResponse { if (deploymentKey.isNullOrEmpty()) { - throw HttpErrorResponseException(status = 401, message = "Invalid deployment.") - } - if (userId.isNullOrEmpty()) { - throw HttpErrorResponseException(status = 400, message = "Invalid user ID.") - } - val cohortIds = deploymentStorage.getAllFlags(deploymentKey).values.getGroupedCohortIds()[USER_GROUP_TYPE] - if (cohortIds.isNullOrEmpty()) { - return setOf() - } - return cohortStorage.getCohortMembershipsForUser(userId, cohortIds) - } - - suspend fun getCohortMembershipsForGroup(deploymentKey: String?, groupType: String?, groupName: String?): Set { - if (deploymentKey.isNullOrEmpty()) { - throw HttpErrorResponseException(status = 401, message = "Invalid deployment.") + return EvaluationProxyResponse.error(HttpStatusCode.Unauthorized, "Invalid deployment") } if (groupType.isNullOrEmpty()) { - throw HttpErrorResponseException(status = 400, message = "Invalid group type.") + return EvaluationProxyResponse.error(HttpStatusCode.BadRequest, "Invalid group type") } if (groupName.isNullOrEmpty()) { - throw HttpErrorResponseException(status = 400, message = "Invalid group name.") + return EvaluationProxyResponse.error(HttpStatusCode.BadRequest, "Invalid group name") } val cohortIds = deploymentStorage.getAllFlags(deploymentKey).values.getGroupedCohortIds()[groupType] if (cohortIds.isNullOrEmpty()) { - return setOf() + return EvaluationProxyResponse.json(HttpStatusCode.OK, emptySet()) } - return cohortStorage.getCohortMembershipsForGroup(groupType, groupName, cohortIds) + val result = cohortStorage.getCohortMemberships(groupType, groupName, cohortIds) + return EvaluationProxyResponse.json(HttpStatusCode.OK, result) } suspend fun evaluate( deploymentKey: String?, user: Map?, flagKeys: Set? = null - ): Map { + ): EvaluationProxyResponse { + if (deploymentKey.isNullOrEmpty()) { + return EvaluationProxyResponse.error(HttpStatusCode.Unauthorized, "Invalid deployment") + } + val result = evaluateInternal(deploymentKey, user, flagKeys) + return EvaluationProxyResponse(HttpStatusCode.OK, json.encodeToString(result)) + } + + suspend fun evaluateV1( + deploymentKey: String?, + user: Map?, + flagKeys: Set? = null + ): EvaluationProxyResponse { if (deploymentKey.isNullOrEmpty()) { - throw HttpErrorResponseException(status = 401, message = "Invalid deployment.") + return EvaluationProxyResponse(HttpStatusCode.Unauthorized, "Invalid deployment") } + val result = evaluateInternal(deploymentKey, user, flagKeys).filter { entry -> + val default = entry.value.metadata?.get("default") as? Boolean ?: false + val deployed = entry.value.metadata?.get("deployed") as? Boolean ?: true + (!default && deployed) + } + return EvaluationProxyResponse(HttpStatusCode.OK, json.encodeToString(result)) + } + + private suspend fun evaluateInternal( + deploymentKey: String, + user: Map?, + flagKeys: Set? = null + ): Map { // Get flag configs for the deployment from storage and topo sort. val storageFlags = deploymentStorage.getAllFlags(deploymentKey) if (storageFlags.isEmpty()) { @@ -129,11 +151,13 @@ internal class ProjectProxy( if (flags.isEmpty()) { return mapOf() } + val groupedCohortIds = flags.getGroupedCohortIds() // Enrich user with cohort IDs and build the evaluation context val userId = user?.get("user_id") as? String val enrichedUser = user?.toMutableMap() ?: mutableMapOf() - if (userId != null) { - enrichedUser["cohort_ids"] = cohortStorage.getCohortMembershipsForUser(userId) + val userCohortIds = groupedCohortIds[USER_GROUP_TYPE] + if (userId != null && userCohortIds != null) { + enrichedUser["cohort_ids"] = cohortStorage.getCohortMemberships(USER_GROUP_TYPE, userId, userCohortIds) } val groups = enrichedUser["groups"] as? Map<*, *> if (!groups.isNullOrEmpty()) { @@ -141,8 +165,9 @@ internal class ProjectProxy( for (entry in groups.entries) { val groupType = entry.key as? String val groupName = (entry.value as? Collection<*>)?.firstOrNull() as? String - if (groupType != null && groupName != null) { - val cohortIds = cohortStorage.getCohortMembershipsForGroup(groupType, groupName) + val groupTypeCohortIds = groupedCohortIds[groupType] + if (groupType != null && groupName != null && groupTypeCohortIds != null) { + val cohortIds = cohortStorage.getCohortMemberships(groupType, groupName, groupTypeCohortIds) if (groupCohortIds.isNotEmpty()) { groupCohortIds.putIfAbsent(groupType, mutableMapOf(groupName to cohortIds)) } @@ -166,18 +191,6 @@ internal class ProjectProxy( return result } - suspend fun evaluateV1( - deploymentKey: String?, - user: Map?, - flagKeys: Set? = null - ): Map { - return evaluate(deploymentKey, user, flagKeys).filter { entry -> - val default = entry.value.metadata?.get("default") as? Boolean ?: false - val deployed = entry.value.metadata?.get("deployed") as? Boolean ?: true - (!default && deployed) - } - } - // Internal internal suspend fun getDeployments(): Set { diff --git a/core/src/main/kotlin/project/ProjectRunner.kt b/core/src/main/kotlin/project/ProjectRunner.kt index b4aa0f0..a8ebb8b 100644 --- a/core/src/main/kotlin/project/ProjectRunner.kt +++ b/core/src/main/kotlin/project/ProjectRunner.kt @@ -1,10 +1,9 @@ package com.amplitude.project import com.amplitude.Configuration -import com.amplitude.cohort.CohortApi import com.amplitude.cohort.CohortLoader import com.amplitude.cohort.CohortStorage -import com.amplitude.deployment.DeploymentApi +import com.amplitude.deployment.DeploymentLoader import com.amplitude.deployment.DeploymentRunner import com.amplitude.deployment.DeploymentStorage import com.amplitude.experiment.evaluation.EvaluationFlag @@ -14,18 +13,21 @@ import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Job import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.cancelAndJoin +import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.delay import kotlinx.coroutines.joinAll import kotlinx.coroutines.launch import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock +import org.jetbrains.annotations.VisibleForTesting internal class ProjectRunner( + private val project: Project, private val configuration: Configuration, private val projectApi: ProjectApi, - private val deploymentApi: DeploymentApi, + private val deploymentLoader: DeploymentLoader, private val deploymentStorage: DeploymentStorage, - cohortApi: CohortApi, + private val cohortLoader: CohortLoader, private val cohortStorage: CohortStorage ) { @@ -37,18 +39,29 @@ internal class ProjectRunner( private val scope = CoroutineScope(supervisor) private val lock = Mutex() - private val deploymentRunners = mutableMapOf() - private val cohortLoader = CohortLoader(configuration.maxCohortSize, cohortApi, cohortStorage) + @VisibleForTesting + internal val deploymentRunners = mutableMapOf() suspend fun start() { - refresh() + val job = scope.launch { + try { + refresh() + } catch (t: Throwable) { + log.error("Refresh failed for project ${project.id}", t) + } + } // Periodic deployment update and refresher scope.launch { while (true) { delay(configuration.deploymentSyncIntervalMillis) - refresh() + try { + refresh() + } catch (t: Throwable) { + log.error("Periodic project refresh failed for project ${project.id}", t) + } } } + job.join() } suspend fun stop() { @@ -60,39 +73,41 @@ internal class ProjectRunner( supervisor.cancelAndJoin() } - private suspend fun refresh() = lock.withLock { - log.trace("refresh: start") - // Get deployments from API and update the storage. - val networkDeployments = projectApi.getDeployments().associateBy { it.key } - val storageDeployments = deploymentStorage.getDeployments() - // Determine added and removed deployments - val addedDeployments = networkDeployments - storageDeployments.keys - val removedDeployments = storageDeployments - networkDeployments.keys - val startingDeployments = networkDeployments - deploymentRunners.keys - val jobs = mutableListOf() - for ((_, addedDeployment) in addedDeployments) { - log.info("Adding deployment $addedDeployment") - deploymentStorage.putDeployment(addedDeployment) - } - for ((_, deployment) in startingDeployments) { - jobs += scope.launch { addDeploymentInternal(deployment.key) } - } - for ((_, removedDeployment) in removedDeployments) { - log.info("Removing deployment $removedDeployment") - deploymentStorage.removeAllFlags(removedDeployment.key) - deploymentStorage.removeDeployment(removedDeployment.key) - jobs += scope.launch { removeDeploymentInternal(removedDeployment.key) } + private suspend fun refresh() = coroutineScope { + lock.withLock { + log.trace("refresh: start") + // Get deployments from API and update the storage. + val networkDeployments = projectApi.getDeployments().associateBy { it.key } + val storageDeployments = deploymentStorage.getDeployments() + // Determine added and removed deployments + val addedDeployments = networkDeployments - storageDeployments.keys + val removedDeployments = storageDeployments - networkDeployments.keys + val startingDeployments = networkDeployments - deploymentRunners.keys + val jobs = mutableListOf() + for ((_, addedDeployment) in addedDeployments) { + log.info("Adding deployment $addedDeployment") + deploymentStorage.putDeployment(addedDeployment) + } + for ((_, deployment) in startingDeployments) { + jobs += scope.launch { addDeploymentInternal(deployment.key) } + } + for ((_, removedDeployment) in removedDeployments) { + log.info("Removing deployment $removedDeployment") + deploymentStorage.removeAllFlags(removedDeployment.key) + deploymentStorage.removeDeployment(removedDeployment.key) + jobs += scope.launch { removeDeploymentInternal(removedDeployment.key) } + } + // Keep cohorts which are targeted by all stored deployments. + removeUnusedCohorts(networkDeployments.keys) + jobs.joinAll() + log.debug( + "Project refresh finished: addedDeployments={}, removedDeployments={}, startedDeployments={}", + addedDeployments.keys, + removedDeployments.keys, + startingDeployments.keys + ) + log.trace("refresh: end") } - jobs.joinAll() - // Keep cohorts which are targeted by all stored deployments. - removeUnusedCohorts(networkDeployments.keys) - log.debug( - "Project refresh finished: addedDeployments={}, removedDeployments={}, startedDeployments={}", - addedDeployments.keys, - removedDeployments.keys, - startingDeployments.keys - ) - log.trace("refresh: end") } // Must be run within lock @@ -104,9 +119,9 @@ internal class ProjectRunner( val deploymentRunner = DeploymentRunner( configuration, deploymentKey, - deploymentApi, + cohortLoader, deploymentStorage, - cohortLoader + deploymentLoader ) deploymentRunner.start() deploymentRunners[deploymentKey] = deploymentRunner @@ -128,7 +143,7 @@ internal class ProjectRunner( for (cohortDescription in allStoredCohortDescriptions) { if (!allTargetedCohortIds.contains(cohortDescription.id)) { log.info("Removing unused cohort $cohortDescription") - cohortStorage.removeCohort(cohortDescription) + cohortStorage.deleteCohort(cohortDescription) } } } diff --git a/core/src/main/kotlin/project/ProjectStorage.kt b/core/src/main/kotlin/project/ProjectStorage.kt index 6fe1b36..0d85342 100644 --- a/core/src/main/kotlin/project/ProjectStorage.kt +++ b/core/src/main/kotlin/project/ProjectStorage.kt @@ -1,6 +1,7 @@ package com.amplitude.project import com.amplitude.RedisConfiguration +import com.amplitude.util.Redis import com.amplitude.util.RedisConnection import com.amplitude.util.RedisKey import kotlinx.coroutines.channels.BufferOverflow @@ -10,7 +11,6 @@ import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock internal interface ProjectStorage { - val projects: Flow> suspend fun getProjects(): Set suspend fun putProject(projectId: String) suspend fun removeProject(projectId: String) @@ -27,11 +27,6 @@ internal fun getProjectStorage(redisConfiguration: RedisConfiguration?): Project internal class InMemoryProjectStorage : ProjectStorage { - override val projects = MutableSharedFlow>( - extraBufferCapacity = 1, - onBufferOverflow = BufferOverflow.DROP_OLDEST - ) - private val mutex = Mutex() private val projectStorage = mutableSetOf() @@ -41,36 +36,27 @@ internal class InMemoryProjectStorage : ProjectStorage { override suspend fun putProject(projectId: String): Unit = mutex.withLock { projectStorage.add(projectId) - projects.emit(projectStorage.toSet()) } override suspend fun removeProject(projectId: String): Unit = mutex.withLock { projectStorage.remove(projectId) - projects.emit(projectStorage.toSet()) } } internal class RedisProjectStorage( private val prefix: String, - private val redis: RedisConnection + private val redis: Redis ) : ProjectStorage { - override val projects = MutableSharedFlow>( - extraBufferCapacity = 1, - onBufferOverflow = BufferOverflow.DROP_OLDEST - ) - override suspend fun getProjects(): Set { return redis.smembers(RedisKey.Projects(prefix)) ?: emptySet() } override suspend fun putProject(projectId: String) { redis.sadd(RedisKey.Projects(prefix), setOf(projectId)) - projects.emit(getProjects()) } override suspend fun removeProject(projectId: String) { redis.srem(RedisKey.Projects(prefix), projectId) - projects.emit(getProjects()) } } diff --git a/core/src/main/kotlin/util/Http.kt b/core/src/main/kotlin/util/Http.kt index 5d244a5..10afe95 100644 --- a/core/src/main/kotlin/util/Http.kt +++ b/core/src/main/kotlin/util/Http.kt @@ -25,27 +25,28 @@ internal data class RetryConfig( internal suspend fun retry( config: RetryConfig = RetryConfig(), onFailure: (Exception) -> Unit = {}, + acceptCodes: Set = emptySet(), block: suspend () -> HttpResponse ): HttpResponse { var currentDelay = config.initialDelayMillis var error: Exception? = null - for (i in 0..config.times) { + for (i in 0..() + + suspend fun load(key: String, loader: suspend CoroutineScope.() -> Unit) = coroutineScope { + jobsMutex.withLock { + jobs.getOrPut(key) { + launch { + try { + loader() + } finally { + jobsMutex.withLock { jobs.remove(key) } + } + } + } + }.join() + } +} diff --git a/core/src/main/kotlin/util/Redis.kt b/core/src/main/kotlin/util/Redis.kt index 2cad423..e4ad777 100644 --- a/core/src/main/kotlin/util/Redis.kt +++ b/core/src/main/kotlin/util/Redis.kt @@ -3,7 +3,6 @@ package com.amplitude.util import com.amplitude.Metrics import com.amplitude.RedisCommand import com.amplitude.RedisCommandFailure -import com.amplitude.cohort.CohortDescription import io.lettuce.core.RedisClient import io.lettuce.core.RedisFuture import io.lettuce.core.RedisURI @@ -14,7 +13,7 @@ import kotlinx.coroutines.Deferred import kotlinx.coroutines.future.asDeferred import kotlin.time.Duration -private const val STORAGE_PROTOCOL_VERSION = "v2" +private const val STORAGE_PROTOCOL_VERSION = "v3" internal sealed class RedisKey(val value: String) { @@ -39,8 +38,10 @@ internal sealed class RedisKey(val value: String) { data class CohortMembers( val prefix: String, val projectId: String, - val cohortDescription: CohortDescription - ) : RedisKey("$prefix:$STORAGE_PROTOCOL_VERSION:projects:$projectId:cohorts:${cohortDescription.id}:${cohortDescription.groupType}:${cohortDescription.lastComputed}") + val cohortId: String, + val cohortGroupType: String, + val cohortLastModified: Long, + ) : RedisKey("$prefix:$STORAGE_PROTOCOL_VERSION:projects:$projectId:cohorts:${cohortId}:${cohortGroupType}:${cohortLastModified}") } internal interface Redis { diff --git a/core/src/test/kotlin/EvaluationProxyTest.kt b/core/src/test/kotlin/EvaluationProxyTest.kt new file mode 100644 index 0000000..8a7959e --- /dev/null +++ b/core/src/test/kotlin/EvaluationProxyTest.kt @@ -0,0 +1,114 @@ +import com.amplitude.Configuration +import com.amplitude.EvaluationProxy +import com.amplitude.ProjectConfiguration +import com.amplitude.cohort.CohortStorage +import com.amplitude.cohort.toCohortDescription +import com.amplitude.deployment.DeploymentStorage +import com.amplitude.project.InMemoryProjectStorage +import com.amplitude.project.ProjectApi +import com.amplitude.project.ProjectProxy +import io.mockk.coEvery +import io.mockk.coVerify +import io.mockk.mockk +import io.mockk.spyk +import io.mockk.verify +import kotlinx.coroutines.runBlocking +import org.junit.Assert.assertEquals +import test.cohort +import test.deployment +import test.project +import kotlin.test.Test + +class EvaluationProxyTest { + + @Test + fun `test start, no projects, nothing happens`(): Unit = runBlocking { + val projectStorage = spyk(InMemoryProjectStorage()) + val projectConfigurations = listOf() + val evaluationProxy = spyk(EvaluationProxy( + projectConfigurations = projectConfigurations, + configuration = Configuration(), + projectStorage = projectStorage + )) + evaluationProxy.start() + verify(exactly = 0) { evaluationProxy.createProjectApi(allAny()) } + assertEquals(0, evaluationProxy.projectProxies.size) + coVerify(exactly = 0) { projectStorage.putProject(allAny()) } + } + + @Test + fun `test start, with project, storage loaded and proxy started`(): Unit = runBlocking { + val projectStorage = spyk(InMemoryProjectStorage()) + val project = project("1") + val projectConfigurations = listOf( + ProjectConfiguration("api", "secret", "management") + ) + val evaluationProxy = spyk(EvaluationProxy( + projectConfigurations = projectConfigurations, + configuration = Configuration(), + projectStorage = projectStorage + )) + coEvery { evaluationProxy.createProjectApi(allAny()) } returns + mockk().apply { + coEvery { getDeployments() } returns listOf(deployment("a", project.id)) + } + coEvery { evaluationProxy.createProjectProxy(allAny()) } returns + mockk().apply { + coEvery { start() } returns Unit + } + evaluationProxy.start() + verify(exactly = 1) { evaluationProxy.createProjectApi(allAny()) } + assertEquals(1, evaluationProxy.projectProxies.size) + coVerify(exactly = 1) { projectStorage.putProject(allAny()) } + val projectProxy = evaluationProxy.projectProxies[project] + coVerify(exactly = 1) { projectProxy?.start() } + } + + @Test + fun `test start, stored but no longer configured project deleted`(): Unit = runBlocking { + val projectStorage = spyk(InMemoryProjectStorage().apply { + putProject("2") + }) + val project = project("1") + val projectConfigurations = listOf( + ProjectConfiguration("api", "secret", "management") + ) + val evaluationProxy = spyk(EvaluationProxy( + projectConfigurations = projectConfigurations, + configuration = Configuration(), + projectStorage = projectStorage + )) + coEvery { evaluationProxy.createProjectApi(allAny()) } returns + mockk().apply { + coEvery { getDeployments() } returns listOf(deployment("a", project.id)) + } + coEvery { evaluationProxy.createProjectProxy(allAny()) } returns + mockk().apply { + coEvery { start() } returns Unit + } + val deploymentStorage = mockk().apply { + coEvery { getDeployments() } returns mapOf("b" to deployment("b")) + coEvery { removeDeployment(eq("b")) } returns Unit + coEvery { removeAllFlags(eq("b")) } returns Unit + } + val cohortStorage = mockk().apply { + coEvery { getCohortDescriptions() } returns mapOf("c" to cohort("c").toCohortDescription()) + coEvery { deleteCohort(eq(cohort("c").toCohortDescription()))} returns Unit + } + coEvery { evaluationProxy.createDeploymentStorage(allAny()) } returns deploymentStorage + coEvery { evaluationProxy.createCohortStorage(allAny()) } returns cohortStorage + evaluationProxy.start() + verify(exactly = 1) { evaluationProxy.createProjectApi(allAny()) } + assertEquals(1, evaluationProxy.projectProxies.size) + coVerify(exactly = 1) { projectStorage.putProject(allAny()) } + val projectProxy = evaluationProxy.projectProxies[project] + coVerify(exactly = 1) { projectProxy?.start() } + // Verify project "2" removed + coVerify(exactly = 1) { deploymentStorage.getDeployments() } + coVerify(exactly = 1) { deploymentStorage.removeDeployment(eq("b")) } + coVerify(exactly = 1) { deploymentStorage.removeAllFlags(eq("b")) } + coVerify(exactly = 1) { cohortStorage.getCohortDescriptions() } + coVerify(exactly = 1) { cohortStorage.deleteCohort(eq(cohort("c").toCohortDescription()))} + coVerify(exactly = 1) { projectStorage.removeProject(eq("2")) } + } +} diff --git a/core/src/test/kotlin/Utils.kt b/core/src/test/kotlin/Utils.kt deleted file mode 100644 index e038dba..0000000 --- a/core/src/test/kotlin/Utils.kt +++ /dev/null @@ -1,15 +0,0 @@ -fun user( - userId: String? = null, - deviceId: String? = null, - userProperties: Map? = null, - groups: Map>? = null, - groupProperties: Map>>? = null -): MutableMap { - return mutableMapOf( - "user_id" to userId, - "device_id" to deviceId, - "user_properties" to userProperties, - "groups" to groups, - "group_properties" to groupProperties - ) -} diff --git a/core/src/test/kotlin/assignment/AssignmentFilterTest.kt b/core/src/test/kotlin/assignment/AssignmentFilterTest.kt index 2aeb57d..9c2fa6e 100644 --- a/core/src/test/kotlin/assignment/AssignmentFilterTest.kt +++ b/core/src/test/kotlin/assignment/AssignmentFilterTest.kt @@ -5,6 +5,7 @@ import com.amplitude.util.toEvaluationContext import kotlinx.coroutines.runBlocking import org.junit.Assert import org.junit.Test +import test.user class AssignmentFilterTest { diff --git a/core/src/test/kotlin/assignment/AssignmentServiceTest.kt b/core/src/test/kotlin/assignment/AssignmentServiceTest.kt index 7f6499c..2cbffcc 100644 --- a/core/src/test/kotlin/assignment/AssignmentServiceTest.kt +++ b/core/src/test/kotlin/assignment/AssignmentServiceTest.kt @@ -8,6 +8,7 @@ import com.amplitude.util.userId import kotlinx.coroutines.runBlocking import org.junit.Assert import org.junit.Test +import test.user class AssignmentServiceTest { diff --git a/core/src/test/kotlin/cohort/CohortApiTest.kt b/core/src/test/kotlin/cohort/CohortApiTest.kt new file mode 100644 index 0000000..d2d68d5 --- /dev/null +++ b/core/src/test/kotlin/cohort/CohortApiTest.kt @@ -0,0 +1,219 @@ +package cohort + +import com.amplitude.cohort.Cohort +import com.amplitude.cohort.CohortApiV1 +import com.amplitude.cohort.CohortNotModifiedException +import com.amplitude.cohort.CohortTooLargeException +import com.amplitude.cohort.GetCohortResponse +import com.amplitude.util.HttpErrorException +import com.amplitude.util.RetryConfig +import com.amplitude.util.json +import io.ktor.client.engine.mock.MockEngine +import io.ktor.client.engine.mock.respond +import io.ktor.http.HttpMethod +import io.ktor.http.HttpStatusCode +import io.ktor.http.Parameters +import io.ktor.utils.io.ByteReadChannel +import kotlinx.coroutines.runBlocking +import kotlinx.serialization.encodeToString +import java.io.IOException +import java.util.Base64 +import kotlin.math.max +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue +import kotlin.test.fail + +class CohortApiTest { + + private val apiKey = "api" + private val secretKey = "secret" + private val serverUrl = "https://api.lab.amplitude.com/" + private val token = Base64.getEncoder().encodeToString("$apiKey:$secretKey".toByteArray(Charsets.UTF_8)) + private val fastRetryConfig = RetryConfig( + times = 5, + initialDelayMillis = 1, + maxDelay = 1, + factor = 1.0 + ) + + @Test + fun `without existing cohort, success`(): Unit = runBlocking { + val expected = Cohort("a", "User", 1, 100L, setOf("1")) + val mockEngine = MockEngine { request -> + respond( + content = ByteReadChannel(json.encodeToString(GetCohortResponse.fromCohort(expected))), + status = HttpStatusCode.OK + ) + } + val api = CohortApiV1(serverUrl, apiKey, secretKey, mockEngine) + val actual = api.getCohort("a", null, Int.MAX_VALUE) + assertEquals(expected, actual) + val request = mockEngine.requestHistory[0] + assertEquals(HttpMethod.Get, request.method) + assertEquals("/sdk/v1/cohort/a", request.url.encodedPath) + assertEquals( + Parameters.build { + set("maxCohortSize", "${Int.MAX_VALUE}") + }, + request.url.parameters + ) + assertTrue(request.headers["X-Amp-Exp-Library"]!!.startsWith("evaluation-proxy/")) + assertEquals("Basic $token", request.headers["Authorization"]) + } + + @Test + fun `with existing cohort, success`(): Unit = runBlocking { + val expected = Cohort("a", "User", 1, 100L, setOf("1")) + val mockEngine = MockEngine { request -> + respond( + content = ByteReadChannel(json.encodeToString(GetCohortResponse.fromCohort(expected))), + status = HttpStatusCode.OK + ) + } + val api = CohortApiV1(serverUrl, apiKey, secretKey, mockEngine) + val actual = api.getCohort("a", 99, Int.MAX_VALUE) + assertEquals(expected, actual) + val request = mockEngine.requestHistory[0] + assertEquals(HttpMethod.Get, request.method) + assertEquals("/sdk/v1/cohort/a", request.url.encodedPath) + assertEquals( + Parameters.build { + set("maxCohortSize", "${Int.MAX_VALUE}") + set("lastModified", "99") + }, + request.url.parameters + ) + assertTrue(request.headers["X-Amp-Exp-Library"]!!.startsWith("evaluation-proxy/")) + assertEquals("Basic $token", request.headers["Authorization"]) + } + + @Test + fun `with existing cohort, cohort not modified, no retries, throws`(): Unit = runBlocking { + val mockEngine = MockEngine { request -> + respond( + content = "", + status = HttpStatusCode.NoContent + ) + } + val api = CohortApiV1(serverUrl, apiKey, secretKey, mockEngine) + try { + api.getCohort("a", 99, Int.MAX_VALUE) + fail("Expected getCohort call to fail with CohortNotModifiedException") + } catch (e: CohortNotModifiedException) { + // Success + } + } + + @Test + fun `without existing cohort, cohort too large, no retries, throws`(): Unit = runBlocking { + val mockEngine = MockEngine { request -> + respond( + content = "", + status = HttpStatusCode.PayloadTooLarge + ) + } + val api = CohortApiV1(serverUrl, apiKey, secretKey, mockEngine) + try { + api.getCohort("a", 99, Int.MAX_VALUE) + fail("Expected getCohort call to fail with CohortTooLargeException") + } catch (e: CohortTooLargeException) { + // Success + } + } + + @Test + fun `request failures, retries, throws`(): Unit = runBlocking { + var failureCounter = 0 + val mockEngine = MockEngine { _ -> + failureCounter++ + throw IOException("test") + } + val api = CohortApiV1(serverUrl, apiKey, secretKey, mockEngine, fastRetryConfig) + try { + api.getCohort("a", 99, Int.MAX_VALUE) + fail("Expected getCohort call to fail with IOException") + } catch (e: IOException) { + // Success + } + assertEquals(5, failureCounter) + } + + @Test + fun `request server error responses, retries, throws`(): Unit = runBlocking { + val mockEngine = MockEngine { _ -> + respond( + content = "", + status = HttpStatusCode.InternalServerError + ) + } + val api = CohortApiV1(serverUrl, apiKey, secretKey, mockEngine, fastRetryConfig) + try { + api.getCohort("a", 99, Int.MAX_VALUE) + fail("Expected getCohort call to fail with HttpErrorException") + } catch (e: HttpErrorException) { + // Success + } + assertEquals(5, mockEngine.responseHistory.size) + } + + @Test + fun `request client too many requests, retries, throws`(): Unit = runBlocking { + val mockEngine = MockEngine { _ -> + respond( + content = "", + status = HttpStatusCode.TooManyRequests + ) + } + val api = CohortApiV1(serverUrl, apiKey, secretKey, mockEngine, fastRetryConfig) + try { + api.getCohort("a", 99, Int.MAX_VALUE) + fail("Expected getCohort call to fail with HttpErrorException") + } catch (e: HttpErrorException) { + assertEquals(HttpStatusCode.TooManyRequests, mockEngine.responseHistory[0].statusCode) + } + assertEquals(5, mockEngine.responseHistory.size) + } + + @Test + fun `request client error, no retries, throws`(): Unit = runBlocking { + val mockEngine = MockEngine { _ -> + respond( + content = "", + status = HttpStatusCode.NotFound + ) + } + val api = CohortApiV1(serverUrl, apiKey, secretKey, mockEngine, fastRetryConfig) + try { + api.getCohort("a", 99, Int.MAX_VALUE) + fail("Expected getCohort call to fail with HttpErrorException") + } catch (e: HttpErrorException) { + assertEquals(HttpStatusCode.NotFound, mockEngine.responseHistory[0].statusCode) + } + assertEquals(1, mockEngine.responseHistory.size) + } + + @Test + fun `request errors, eventual success`(): Unit = runBlocking { + var i = 0 + val expected = Cohort("a", "User", 1, 100L, setOf("1")) + val mockEngine = MockEngine { _ -> + i++ + when (i) { + 1 -> respond(content = "", status = HttpStatusCode.InternalServerError) + 2 -> respond(content = "", status = HttpStatusCode.TooManyRequests) + 3 -> respond(content = "", status = HttpStatusCode.BadGateway) + 4 -> respond(content = "", status = HttpStatusCode.GatewayTimeout) + 5 -> respond( + content = ByteReadChannel(json.encodeToString(GetCohortResponse.fromCohort(expected))), + status = HttpStatusCode.OK + ) + else -> fail("unexpected number of requests") + } + } + val api = CohortApiV1(serverUrl, apiKey, secretKey, mockEngine, fastRetryConfig) + val actual = api.getCohort("a", 99, Int.MAX_VALUE) + assertEquals(5, mockEngine.responseHistory.size) + assertEquals(expected, actual) + } +} diff --git a/core/src/test/kotlin/cohort/CohortLoaderTest.kt b/core/src/test/kotlin/cohort/CohortLoaderTest.kt new file mode 100644 index 0000000..582c970 --- /dev/null +++ b/core/src/test/kotlin/cohort/CohortLoaderTest.kt @@ -0,0 +1,127 @@ +package cohort + +import com.amplitude.cohort.Cohort +import com.amplitude.cohort.CohortApi +import com.amplitude.cohort.CohortLoader +import com.amplitude.cohort.CohortNotModifiedException +import com.amplitude.cohort.CohortStorage +import com.amplitude.cohort.InMemoryCohortStorage +import com.amplitude.util.HttpErrorException +import io.ktor.http.HttpStatusCode +import io.mockk.coEvery +import io.mockk.coVerify +import io.mockk.mockk +import io.mockk.spyk +import kotlinx.coroutines.delay +import kotlinx.coroutines.joinAll +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import kotlin.test.Test +import kotlin.test.assertEquals + +class CohortLoaderTest { + + private val maxCohortSize = Int.MAX_VALUE + + @Test + fun `load cohorts, success`(): Unit = runBlocking { + val cohortA = Cohort("a", "User", 1, 100, setOf("1")) + val cohortB = Cohort("b", "User", 1, 100, setOf("1")) + val cohortC = Cohort("c", "User", 1, 100, setOf("1")) + val api = mockk() + val storage: CohortStorage = spyk(InMemoryCohortStorage()) + val loader = CohortLoader(maxCohortSize, api, storage) + coEvery { api.getCohort("a", null, maxCohortSize) } returns cohortA + coEvery { api.getCohort("b", null, maxCohortSize) } returns cohortB + coEvery { api.getCohort("c", null, maxCohortSize) } returns cohortC + // Run 1 + loader.loadCohorts(setOf("a", "b", "c")) + coVerify(exactly = 3) { api.getCohort(allAny(), isNull(), eq(maxCohortSize)) } + assertEquals( + mapOf( + "a" to cohortA, + "b" to cohortB, + "c" to cohortC + ), + storage.getCohorts() + ) + coVerify(exactly = 3) { storage.putCohort(allAny()) } + // Run 2 + val cohortB2 = cohortB.copy(size = 2, lastModified = 200, members = setOf("1", "2")) + coEvery { api.getCohort("a", 100, maxCohortSize) } returns cohortA + coEvery { api.getCohort("b", 100, maxCohortSize) } returns cohortB2 + coEvery { api.getCohort("c", 100, maxCohortSize) } throws CohortNotModifiedException("c") + loader.loadCohorts(setOf("a", "b", "c")) + coVerify(exactly = 3) { api.getCohort(allAny(), eq(100), eq(maxCohortSize)) } + assertEquals( + mapOf( + "a" to cohortA, + "b" to cohortB2, + "c" to cohortC + ), + storage.getCohorts() + ) + // Cohort C should not be stored + coVerify(exactly = 5) { storage.putCohort(allAny()) } + } + + @Test + fun `load cohorts, simultaneous loading of the same cohorts, only downloads once per cohort`(): Unit = runBlocking { + val cohortA = Cohort("a", "User", 1, 100, setOf("1")) + val cohortB = Cohort("b", "User", 1, 100, setOf("1")) + val cohortC = Cohort("c", "User", 1, 100, setOf("1")) + val api = mockk() + val storage: CohortStorage = spyk(InMemoryCohortStorage()) + val loader = CohortLoader(maxCohortSize, api, storage) + coEvery { api.getCohort("a", null, maxCohortSize) } coAnswers { + delay(100) + cohortA + } + coEvery { api.getCohort("b", null, maxCohortSize) } coAnswers { + delay(100) + cohortB + } + coEvery { api.getCohort("c", null, maxCohortSize) } coAnswers { + delay(100) + cohortC + } + val j1 = launch { + loader.loadCohorts(setOf("a", "b", "c")) + } + val j2 = launch { + loader.loadCohorts(setOf("a", "b", "c")) + } + listOf(j1, j2).joinAll() + coVerify(exactly = 3) { api.getCohort(allAny(), isNull(), eq(maxCohortSize)) } + assertEquals( + mapOf( + "a" to cohortA, + "b" to cohortB, + "c" to cohortC + ), + storage.getCohorts() + ) + } + + @Test + fun `load cohorts, failure, failed cohort not stored, does not throw`(): Unit = runBlocking { + val cohortA = Cohort("a", "User", 1, 100, setOf("1")) + val cohortC = Cohort("c", "User", 1, 100, setOf("1")) + val api = mockk() + val storage: CohortStorage = spyk(InMemoryCohortStorage()) + val loader = CohortLoader(maxCohortSize, api, storage) + coEvery { api.getCohort("a", null, maxCohortSize) } returns cohortA + coEvery { api.getCohort("b", null, maxCohortSize) } throws HttpErrorException(HttpStatusCode.InternalServerError) + coEvery { api.getCohort("c", null, maxCohortSize) } returns cohortC + // Run 1 + loader.loadCohorts(setOf("a", "b", "c")) + coVerify(exactly = 3) { api.getCohort(allAny(), isNull(), eq(maxCohortSize)) } + assertEquals( + mapOf( + "a" to cohortA, + "c" to cohortC + ), + storage.getCohorts() + ) + } +} diff --git a/core/src/test/kotlin/cohort/CohortStorageTest.kt b/core/src/test/kotlin/cohort/CohortStorageTest.kt new file mode 100644 index 0000000..33d3f98 --- /dev/null +++ b/core/src/test/kotlin/cohort/CohortStorageTest.kt @@ -0,0 +1,102 @@ +package cohort + +import test.InMemoryRedis +import test.cohort +import com.amplitude.cohort.Cohort +import com.amplitude.cohort.CohortDescription +import com.amplitude.cohort.CohortStorage +import com.amplitude.cohort.InMemoryCohortStorage +import com.amplitude.cohort.RedisCohortStorage +import com.amplitude.cohort.toCohortDescription +import kotlinx.coroutines.runBlocking +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNull +import kotlin.time.Duration + +class CohortStorageTest { + + private val redis = InMemoryRedis() + + @Test + fun `test in memory`(): Unit = runBlocking { + test(InMemoryCohortStorage()) + } + + @Test + fun `test redis`(): Unit = runBlocking { + test(RedisCohortStorage("12345", Duration.INFINITE, "amplitude ", redis, redis)) + } + + private fun test(cohortStorage: CohortStorage): Unit = runBlocking { + val cohortA = cohort("a") + val cohortB = cohort("b") + + // test get, null + var cohort: Cohort? = cohortStorage.getCohort(cohortA.id) + assertNull(cohort) + // test get all, empty + var cohorts: Map = cohortStorage.getCohorts() + assertEquals(0, cohorts.size) + // test get description, null + var description: CohortDescription? = cohortStorage.getCohortDescription(cohortA.id) + assertNull(description) + // test get descriptions, empty + var descriptions: Map = cohortStorage.getCohortDescriptions() + assertEquals(0, descriptions.size) + // test put, get, cohort + cohortStorage.putCohort(cohortA) + cohort = cohortStorage.getCohort(cohortA.id) + assertEquals(cohortA, cohort) + // test get description, description + description = cohortStorage.getCohortDescription(cohortA.id) + assertEquals(cohortA.toCohortDescription(), description) + // test put, get all, cohorts + cohortStorage.putCohort(cohortB) + cohorts = cohortStorage.getCohorts() + assertEquals( + mapOf( + cohortA.id to cohortA, + cohortB.id to cohortB + ), + cohorts + ) + // test get descriptions, descriptions + descriptions = cohortStorage.getCohortDescriptions() + assertEquals( + mapOf( + cohortA.id to cohortA.toCohortDescription(), + cohortB.id to cohortB.toCohortDescription() + ), + descriptions + ) + // test delete one + cohortStorage.deleteCohort(cohortA.toCohortDescription()) + // test get deleted, null + cohort = cohortStorage.getCohort(cohortA.id) + assertNull(cohort) + // test get other, cohort + cohort = cohortStorage.getCohort(cohortB.id) + assertEquals(cohortB, cohort) + // test get all, cohort + cohorts = cohortStorage.getCohorts() + assertEquals(mapOf(cohortB.id to cohortB), cohorts) + // test get description deleted, null + description = cohortStorage.getCohortDescription(cohortA.id) + assertNull(description) + // test get description other, description + description = cohortStorage.getCohortDescription(cohortB.id) + assertEquals(cohortB.toCohortDescription(), description) + // test get descriptions, description + descriptions = cohortStorage.getCohortDescriptions() + assertEquals(mapOf(cohortB.id to cohortB.toCohortDescription()), descriptions) + // test delete other + cohortStorage.deleteCohort(cohortB.toCohortDescription()) + // test get all, empty + cohorts = cohortStorage.getCohorts() + assertEquals(0, cohorts.size) + // test get descriptions, empty + descriptions = cohortStorage.getCohortDescriptions() + assertEquals(0, descriptions.size) + } +} diff --git a/core/src/test/kotlin/deployment/DeploymentApiTest.kt b/core/src/test/kotlin/deployment/DeploymentApiTest.kt new file mode 100644 index 0000000..1794ac5 --- /dev/null +++ b/core/src/test/kotlin/deployment/DeploymentApiTest.kt @@ -0,0 +1,151 @@ +package deployment + +import com.amplitude.deployment.DeploymentApiV2 +import com.amplitude.util.HttpErrorException +import com.amplitude.util.RetryConfig +import com.amplitude.util.json +import test.flag +import io.ktor.client.engine.mock.MockEngine +import io.ktor.client.engine.mock.respond +import io.ktor.http.HttpMethod +import io.ktor.http.HttpStatusCode +import io.ktor.http.Parameters +import io.ktor.utils.io.ByteReadChannel +import kotlinx.coroutines.runBlocking +import kotlinx.serialization.encodeToString +import org.junit.Test +import java.io.IOException +import kotlin.test.assertEquals +import kotlin.test.assertTrue +import kotlin.test.fail + +class DeploymentApiTest { + private val deploymentKey = "deployment" + private val serverUrl = "https://api.lab.amplitude.com/" + private val fastRetryConfig = RetryConfig( + times = 5, + initialDelayMillis = 1, + maxDelay = 1, + factor = 1.0 + ) + + @Test + fun `get flags, success`(): Unit = runBlocking { + val expected = listOf(flag()) + val mockEngine = MockEngine { request -> + respond( + content = ByteReadChannel(json.encodeToString(expected)), + status = HttpStatusCode.OK + ) + } + val api = DeploymentApiV2(serverUrl, mockEngine, fastRetryConfig) + val actual = api.getFlagConfigs(deploymentKey) + assertEquals(expected, actual) + val request = mockEngine.requestHistory[0] + assertEquals(HttpMethod.Get, request.method) + assertEquals("/sdk/v2/flags", request.url.encodedPath) + assertEquals( + Parameters.build { + set("v", "0") + }, + request.url.parameters + ) + assertTrue(request.headers["X-Amp-Exp-Library"]!!.startsWith("evaluation-proxy/")) + assertEquals("Api-Key $deploymentKey", request.headers["Authorization"]) + } + + @Test + fun `request failures, retries, throws`(): Unit = runBlocking { + var failureCounter = 0 + val mockEngine = MockEngine { _ -> + failureCounter++ + throw IOException("test") + } + val api = DeploymentApiV2(serverUrl, mockEngine, fastRetryConfig) + try { + api.getFlagConfigs(deploymentKey) + fail("Expected getFlagConfigs call to fail with IOException") + } catch (e: IOException) { + // Success + } + assertEquals(5, failureCounter) + } + + @Test + fun `request server error responses, retries, throws`(): Unit = runBlocking { + val mockEngine = MockEngine { _ -> + respond( + content = "", + status = HttpStatusCode.InternalServerError + ) + } + val api = DeploymentApiV2(serverUrl, mockEngine, fastRetryConfig) + try { + api.getFlagConfigs(deploymentKey) + fail("Expected getFlagConfigs call to fail with HttpErrorException") + } catch (e: HttpErrorException) { + // Success + } + assertEquals(5, mockEngine.responseHistory.size) + } + + @Test + fun `request client too many requests, retries, throws`(): Unit = runBlocking { + val mockEngine = MockEngine { _ -> + respond( + content = "", + status = HttpStatusCode.TooManyRequests + ) + } + val api = DeploymentApiV2(serverUrl, mockEngine, fastRetryConfig) + try { + api.getFlagConfigs(deploymentKey) + fail("Expected getFlagConfigs call to fail with HttpErrorException") + } catch (e: HttpErrorException) { + assertEquals(HttpStatusCode.TooManyRequests, mockEngine.responseHistory[0].statusCode) + } + assertEquals(5, mockEngine.responseHistory.size) + } + + @Test + fun `request client error, no retries, throws`(): Unit = runBlocking { + val mockEngine = MockEngine { _ -> + respond( + content = "", + status = HttpStatusCode.NotFound + ) + } + val api = DeploymentApiV2(serverUrl, mockEngine, fastRetryConfig) + try { + api.getFlagConfigs(deploymentKey) + fail("Expected getFlagConfigs call to fail with HttpErrorException") + } catch (e: HttpErrorException) { + assertEquals(HttpStatusCode.NotFound, mockEngine.responseHistory[0].statusCode) + } + assertEquals(1, mockEngine.responseHistory.size) + } + + @Test + fun `request errors, eventual success`(): Unit = runBlocking { + var i = 0 + val expected = listOf(flag()) + val mockEngine = MockEngine { _ -> + i++ + when (i) { + 1 -> respond(content = "", status = HttpStatusCode.InternalServerError) + 2 -> respond(content = "", status = HttpStatusCode.TooManyRequests) + 3 -> respond(content = "", status = HttpStatusCode.BadGateway) + 4 -> respond(content = "", status = HttpStatusCode.GatewayTimeout) + 5 -> respond( + content = ByteReadChannel(json.encodeToString(expected)), + status = HttpStatusCode.OK + ) + else -> fail("unexpected number of requests") + } + } + val api = DeploymentApiV2(serverUrl, mockEngine, fastRetryConfig) + val actual = api.getFlagConfigs(deploymentKey) + assertEquals(5, mockEngine.responseHistory.size) + assertEquals(expected, actual) + } +} diff --git a/core/src/test/kotlin/deployment/DeploymentLoaderTest.kt b/core/src/test/kotlin/deployment/DeploymentLoaderTest.kt new file mode 100644 index 0000000..bfb6af9 --- /dev/null +++ b/core/src/test/kotlin/deployment/DeploymentLoaderTest.kt @@ -0,0 +1,176 @@ +package deployment + +import com.amplitude.cohort.CohortLoader +import com.amplitude.deployment.DeploymentApi +import com.amplitude.deployment.DeploymentLoader +import com.amplitude.deployment.InMemoryDeploymentStorage +import test.flag +import io.mockk.coEvery +import io.mockk.coVerify +import io.mockk.mockk +import io.mockk.spyk +import kotlinx.coroutines.runBlocking +import org.junit.Assert.assertNull +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.fail + +class DeploymentLoaderTest { + + private val deploymentKey = "deployment" + private val flagKey = "flag" + + @Test + fun `load flags without cohorts, cohorts not loaded, success`(): Unit = runBlocking { + val api = mockk() + val storage = spyk(InMemoryDeploymentStorage()) + val cohortLoader = mockk() + val cohortIds = emptySet() + val flag = flag(flagKey = flagKey, cohortIds = cohortIds) + coEvery { api.getFlagConfigs(eq(deploymentKey)) } returns listOf(flag) + val loader = DeploymentLoader(api, storage, cohortLoader) + loader.loadDeployment(deploymentKey) + coVerify(exactly = 1) { api.getFlagConfigs(eq(deploymentKey)) } + coVerify(exactly = 0) { storage.removeFlag(allAny(), allAny()) } + coVerify(exactly = 0) { cohortLoader.loadCohorts(allAny()) } + coVerify(exactly = 1) { storage.putFlag(allAny(), allAny()) } + assertEquals(flag, storage.getFlag(deploymentKey, flagKey)) + } + + @Test + fun `load flags with cohorts, cohorts are loaded, success`(): Unit = runBlocking { + val api = mockk() + val storage = spyk(InMemoryDeploymentStorage()) + val cohortLoader = mockk() + val cohortIds = setOf("a", "b") + val flag = flag(flagKey = flagKey, cohortIds = cohortIds) + coEvery { api.getFlagConfigs(eq(deploymentKey)) } returns listOf(flag) + coEvery { cohortLoader.loadCohorts(eq(cohortIds)) } returns Unit + val loader = DeploymentLoader(api, storage, cohortLoader) + loader.loadDeployment(deploymentKey) + coVerify(exactly = 1) { api.getFlagConfigs(eq(deploymentKey)) } + coVerify(exactly = 0) { storage.removeFlag(allAny(), allAny()) } + coVerify(exactly = 1) { cohortLoader.loadCohorts(eq(cohortIds)) } + coVerify(exactly = 1) { storage.putFlag(allAny(), allAny()) } + assertEquals(flag, storage.getFlag(deploymentKey, flagKey)) + } + + @Test + fun `existing flags state, some flags removed, success`(): Unit = runBlocking { + val existingFlagKey = "existing" + val existingFlag = flag(existingFlagKey) + val api = mockk() + val storage = spyk(InMemoryDeploymentStorage().apply { + putFlag(deploymentKey, existingFlag) + }) + val cohortLoader = mockk() + val cohortIds = setOf("a", "b") + val flag = flag(flagKey = flagKey, cohortIds = cohortIds) + coEvery { api.getFlagConfigs(eq(deploymentKey)) } returns listOf(flag) + coEvery { cohortLoader.loadCohorts(eq(cohortIds)) } returns Unit + val loader = DeploymentLoader(api, storage, cohortLoader) + loader.loadDeployment(deploymentKey) + coVerify(exactly = 1) { api.getFlagConfigs(eq(deploymentKey)) } + coVerify(exactly = 1) { storage.removeFlag(eq(deploymentKey), eq(existingFlagKey)) } + coVerify(exactly = 1) { cohortLoader.loadCohorts(eq(cohortIds)) } + coVerify(exactly = 1) { storage.putFlag(allAny(), allAny()) } + assertEquals(flag, storage.getFlag(deploymentKey, flagKey)) + } + + @Test + fun `getFlagConfigs fails, throws`(): Unit = runBlocking { + val api = mockk() + val storage = spyk(InMemoryDeploymentStorage()) + val cohortLoader = mockk() + val cohortIds = setOf("a", "b") + coEvery { api.getFlagConfigs(eq(deploymentKey)) } throws RuntimeException("fail") + coEvery { cohortLoader.loadCohorts(eq(cohortIds)) } returns Unit + val loader = DeploymentLoader(api, storage, cohortLoader) + try { + loader.loadDeployment(deploymentKey) + fail("Expected loadDeployment to throw exception") + } catch (e: RuntimeException) { + // Success + } + coVerify(exactly = 1) { api.getFlagConfigs(eq(deploymentKey)) } + coVerify(exactly = 0) { storage.removeFlag(allAny(), allAny()) } + coVerify(exactly = 0) { cohortLoader.loadCohorts(eq(cohortIds)) } + coVerify(exactly = 0) { storage.putFlag(allAny(), allAny()) } + assertNull(storage.getFlag(deploymentKey, flagKey)) + } + + @Test + fun `removeFlag fails, throws`(): Unit = runBlocking { + val existingFlagKey = "existing" + val existingFlag = flag(existingFlagKey) + val api = mockk() + val storage = spyk(InMemoryDeploymentStorage().apply { + putFlag(deploymentKey, existingFlag) + }) + val cohortLoader = mockk() + val cohortIds = setOf("a", "b") + val flag = flag(flagKey = flagKey, cohortIds = cohortIds) + coEvery { api.getFlagConfigs(eq(deploymentKey)) } returns listOf(flag) + coEvery { storage.removeFlag(eq(deploymentKey), eq(existingFlagKey)) } throws RuntimeException("fail") + coEvery { cohortLoader.loadCohorts(eq(cohortIds)) } returns Unit + val loader = DeploymentLoader(api, storage, cohortLoader) + try { + loader.loadDeployment(deploymentKey) + fail("Expected loadDeployment to throw exception") + } catch (e: RuntimeException) { + // Success + } + coVerify(exactly = 1) { api.getFlagConfigs(eq(deploymentKey)) } + coVerify(exactly = 1) { storage.removeFlag(eq(deploymentKey), eq(existingFlagKey)) } + coVerify(exactly = 0) { cohortLoader.loadCohorts(eq(cohortIds)) } + coVerify(exactly = 0) { storage.putFlag(allAny(), allAny()) } + assertNull(storage.getFlag(deploymentKey, flagKey)) + } + + @Test + fun `loadCohorts fails, throws`(): Unit = runBlocking { + val api = mockk() + val storage = spyk(InMemoryDeploymentStorage()) + val cohortLoader = mockk() + val cohortIds = setOf("a", "b") + val flag = flag(flagKey = flagKey, cohortIds = cohortIds) + coEvery { api.getFlagConfigs(eq(deploymentKey)) } returns listOf(flag) + coEvery { cohortLoader.loadCohorts(eq(cohortIds)) } throws RuntimeException("fail") + val loader = DeploymentLoader(api, storage, cohortLoader) + try { + loader.loadDeployment(deploymentKey) + fail("Expected loadDeployment to throw exception") + } catch (e: RuntimeException) { + // Success + } + coVerify(exactly = 1) { api.getFlagConfigs(eq(deploymentKey)) } + coVerify(exactly = 0) { storage.removeFlag(allAny(), allAny()) } + coVerify(exactly = 1) { cohortLoader.loadCohorts(eq(cohortIds)) } + coVerify(exactly = 0) { storage.putFlag(allAny(), allAny()) } + assertNull(storage.getFlag(deploymentKey, flagKey)) + } + + @Test + fun `putFlag fails, throws`(): Unit = runBlocking { + val api = mockk() + val storage = spyk(InMemoryDeploymentStorage()) + val cohortLoader = mockk() + val cohortIds = setOf("a", "b") + val flag = flag(flagKey = flagKey, cohortIds = cohortIds) + coEvery { api.getFlagConfigs(eq(deploymentKey)) } returns listOf(flag) + coEvery { cohortLoader.loadCohorts(eq(cohortIds)) } returns Unit + coEvery { storage.putFlag(eq(deploymentKey), eq(flag)) } throws RuntimeException("fail") + val loader = DeploymentLoader(api, storage, cohortLoader) + try { + loader.loadDeployment(deploymentKey) + fail("Expected loadDeployment to throw exception") + } catch (e: RuntimeException) { + // Success + } + coVerify(exactly = 1) { api.getFlagConfigs(eq(deploymentKey)) } + coVerify(exactly = 0) { storage.removeFlag(allAny(), allAny()) } + coVerify(exactly = 1) { cohortLoader.loadCohorts(eq(cohortIds)) } + coVerify(exactly = 1) { storage.putFlag(allAny(), allAny()) } + assertNull(storage.getFlag(deploymentKey, flagKey)) + } +} diff --git a/core/src/test/kotlin/deployment/DeploymentRunnerTest.kt b/core/src/test/kotlin/deployment/DeploymentRunnerTest.kt new file mode 100644 index 0000000..b6b813e --- /dev/null +++ b/core/src/test/kotlin/deployment/DeploymentRunnerTest.kt @@ -0,0 +1,132 @@ +package deployment + +import com.amplitude.Configuration +import com.amplitude.cohort.CohortLoader +import com.amplitude.deployment.DeploymentLoader +import com.amplitude.deployment.DeploymentRunner +import com.amplitude.deployment.DeploymentStorage +import test.flag +import io.mockk.coEvery +import io.mockk.coVerify +import io.mockk.mockk +import kotlinx.coroutines.delay +import kotlinx.coroutines.runBlocking +import org.junit.Test + +class DeploymentRunnerTest { + + private val deploymentKey = "deployment" + + @Test + fun `start, stop, load deployment called once, success`(): Unit = runBlocking { + val flag = flag(cohortIds = setOf("a")) + val configuration = Configuration( + flagSyncIntervalMillis = 50, + cohortSyncIntervalMillis = 50, + ) + val cohortLoader = mockk() + val deploymentStorage = mockk() + val deploymentLoader = mockk() + coEvery { deploymentLoader.loadDeployment(eq(deploymentKey)) } returns Unit + coEvery { deploymentStorage.getAllFlags(eq(deploymentKey)) } returns mapOf(flag.key to flag) + coEvery { cohortLoader.loadCohorts(allAny()) } returns Unit + val deploymentRunner = DeploymentRunner( + configuration, + deploymentKey, + cohortLoader, + deploymentStorage, + deploymentLoader + ) + deploymentRunner.start() + deploymentRunner.stop() + delay(100) + coVerify(exactly = 1) { deploymentLoader.loadDeployment(allAny()) } + coVerify(exactly = 0) { deploymentStorage.getAllFlags(allAny()) } + coVerify(exactly = 0) { cohortLoader.loadCohorts(allAny()) } + } + + @Test + fun `start, delay, periodic loaders run, success`(): Unit = runBlocking { + val cohortIds = setOf("a") + val flag = flag(cohortIds = cohortIds) + val configuration = Configuration( + flagSyncIntervalMillis = 50, + cohortSyncIntervalMillis = 50, + ) + val cohortLoader = mockk() + val deploymentStorage = mockk() + val deploymentLoader = mockk() + coEvery { deploymentLoader.loadDeployment(eq(deploymentKey)) } returns Unit + coEvery { deploymentStorage.getAllFlags(eq(deploymentKey)) } returns mapOf(flag.key to flag) + coEvery { cohortLoader.loadCohorts(allAny()) } returns Unit + val deploymentRunner = DeploymentRunner( + configuration, + deploymentKey, + cohortLoader, + deploymentStorage, + deploymentLoader + ) + deploymentRunner.start() + delay(75) + deploymentRunner.stop() + coVerify(exactly = 2) { deploymentLoader.loadDeployment(eq(deploymentKey)) } + coVerify(exactly = 1) { deploymentStorage.getAllFlags(eq(deploymentKey)) } + coVerify(exactly = 1) { cohortLoader.loadCohorts(eq(cohortIds)) } + } + + @Test + fun `start, load deployment throws, does not throw, pollers run`(): Unit = runBlocking { + val cohortIds = setOf("a") + val flag = flag(cohortIds = cohortIds) + val configuration = Configuration( + flagSyncIntervalMillis = 50, + cohortSyncIntervalMillis = 50, + ) + val cohortLoader = mockk() + val deploymentStorage = mockk() + val deploymentLoader = mockk() + coEvery { deploymentLoader.loadDeployment(eq(deploymentKey)) } throws RuntimeException("fail") + coEvery { deploymentStorage.getAllFlags(eq(deploymentKey)) } returns mapOf(flag.key to flag) + coEvery { cohortLoader.loadCohorts(allAny()) } returns Unit + val deploymentRunner = DeploymentRunner( + configuration, + deploymentKey, + cohortLoader, + deploymentStorage, + deploymentLoader + ) + deploymentRunner.start() + delay(75) + deploymentRunner.stop() + coVerify(exactly = 2) { deploymentLoader.loadDeployment(eq(deploymentKey)) } + coVerify(exactly = 1) { deploymentStorage.getAllFlags(eq(deploymentKey)) } + coVerify(exactly = 1) { cohortLoader.loadCohorts(eq(cohortIds)) } + } + + @Test + fun `start get all flags throws, does not throw`(): Unit = runBlocking { + val configuration = Configuration( + flagSyncIntervalMillis = 10, + cohortSyncIntervalMillis = 10, + ) + val cohortLoader = mockk() + val deploymentStorage = mockk() + val deploymentLoader = mockk() + coEvery { deploymentLoader.loadDeployment(eq(deploymentKey)) } returns Unit + coEvery { deploymentStorage.getAllFlags(eq(deploymentKey)) } throws RuntimeException("fail") + coEvery { cohortLoader.loadCohorts(allAny()) } returns Unit + val deploymentRunner = DeploymentRunner( + configuration, + deploymentKey, + cohortLoader, + deploymentStorage, + deploymentLoader + ) + deploymentRunner.start() + delay(100) + deploymentRunner.stop() + coVerify(atLeast = 3) { deploymentLoader.loadDeployment(eq(deploymentKey)) } + coVerify(atLeast = 2) { deploymentStorage.getAllFlags(eq(deploymentKey)) } + coVerify(exactly = 0) { cohortLoader.loadCohorts(allAny()) } + } +} diff --git a/core/src/test/kotlin/deployment/DeploymentStorageTest.kt b/core/src/test/kotlin/deployment/DeploymentStorageTest.kt new file mode 100644 index 0000000..acf9356 --- /dev/null +++ b/core/src/test/kotlin/deployment/DeploymentStorageTest.kt @@ -0,0 +1,113 @@ +package deployment + +import test.InMemoryRedis +import com.amplitude.deployment.DeploymentStorage +import com.amplitude.deployment.InMemoryDeploymentStorage +import com.amplitude.deployment.RedisDeploymentStorage +import test.deployment +import test.flag +import kotlinx.coroutines.runBlocking +import org.junit.Assert.assertEquals +import org.junit.Test +import kotlin.test.assertNull + +class DeploymentStorageTest { + + private val redis = InMemoryRedis() + + @Test + fun `test in memory`(): Unit = runBlocking { + test(InMemoryDeploymentStorage()) + } + + @Test + fun `test redis`(): Unit = runBlocking { + test(RedisDeploymentStorage("amplitude", "12345", redis, redis)) + } + + private fun test(storage: DeploymentStorage): Unit = runBlocking { + val deploymentA = deployment("a", "1") + val deploymentB = deployment("b", "2") + + // get deployment, null + var deployment = storage.getDeployment(deploymentA.key) + assertNull(deployment) + // get deployments, empty + var deployments = storage.getDeployments() + assertEquals(0, deployments.size) + // put, get deployment, deployment + storage.putDeployment(deploymentA) + deployment = storage.getDeployment(deploymentA.key) + assertEquals(deploymentA, deployment) + // put, get deployments, deployments + storage.putDeployment(deploymentB) + deployments = storage.getDeployments() + assertEquals( + mapOf(deploymentA.key to deploymentA, deploymentB.key to deploymentB), + deployments + ) + + val flag1 = flag("1") + val flag2 = flag("2") + + // get flag, null + var flag = storage.getFlag(deploymentA.key, flag1.key) + assertNull(flag) + // get all flags, empty + var flags = storage.getAllFlags(deploymentA.key) + assertEquals(0, flags.size) + // put, get flag, flag + storage.putFlag(deploymentA.key, flag1) + flag = storage.getFlag(deploymentA.key, flag1.key) + assertEquals(flag1, flag) + // put, get all flags, flags + storage.putFlag(deploymentA.key, flag2) + flags = storage.getAllFlags(deploymentA.key) + assertEquals(mapOf(flag1.key to flag1, flag2.key to flag2), flags) + // remove, get removed, null + storage.removeFlag(deploymentA.key, flag1.key) + flag = storage.getFlag(deploymentA.key, flag1.key) + assertNull(flag) + // get other, other + flag = storage.getFlag(deploymentA.key, flag2.key) + assertEquals(flag2, flag) + // get all flags, other + flags = storage.getAllFlags(deploymentA.key) + assertEquals(mapOf(flag2.key to flag2), flags) + // remove all flags, get, null + storage.removeAllFlags(deploymentA.key) + // get all flags, empty + flags = storage.getAllFlags(deploymentA.key) + assertEquals(0, flags.size) + + // put flags + storage.putFlag(deploymentA.key, flag1) + storage.putFlag(deploymentA.key, flag2) + flags = storage.getAllFlags(deploymentA.key) + assertEquals(mapOf(flag1.key to flag1, flag2.key to flag2), flags) + + // remove deployment, get removed, null + storage.removeDeployment(deploymentA.key) + // get flag, null + flag = storage.getFlag(deploymentA.key, flag1.key) + assertNull(flag) + flag = storage.getFlag(deploymentA.key, flag2.key) + assertNull(flag) + // get all flags, empty + flags = storage.getAllFlags(deploymentA.key) + assertEquals(0, flags.size) + // get other, deployment + deployment = storage.getDeployment(deploymentB.key) + assertEquals(deploymentB, deployment) + // get deployments, other + deployments = storage.getDeployments() + assertEquals(mapOf(deploymentB.key to deploymentB), deployments) + // get all flags, empty + flags = storage.getAllFlags(deploymentB.key) + assertEquals(0, flags.size) + // remove, get deployments, empty + storage.removeDeployment(deploymentB.key) + deployments = storage.getDeployments() + assertEquals(0, deployments.size) + } +} diff --git a/core/src/test/kotlin/project/ProjectApiTest.kt b/core/src/test/kotlin/project/ProjectApiTest.kt new file mode 100644 index 0000000..d7fdffe --- /dev/null +++ b/core/src/test/kotlin/project/ProjectApiTest.kt @@ -0,0 +1,67 @@ +package project + +import com.amplitude.project.DeploymentsResponse +import com.amplitude.project.ProjectApiV1 +import com.amplitude.util.json +import test.deployment +import io.ktor.client.engine.mock.MockEngine +import io.ktor.client.engine.mock.respond +import io.ktor.http.HttpMethod +import io.ktor.http.HttpStatusCode +import io.ktor.utils.io.ByteReadChannel +import kotlinx.coroutines.runBlocking +import kotlinx.serialization.encodeToString +import org.junit.Test +import test.toSerialDeployment +import kotlin.test.assertEquals + +class ProjectApiTest { + + private val managementKey = "managementKey" + private val serverUrl = "https://experiment.amplitude.com/" + + @Test + fun `get deployments, success`(): Unit = runBlocking { + val deploymentA = deployment("a") + val deploymentB = deployment("b") + val expected = listOf(deploymentA, deploymentB) + val serialDeployments = expected.map { it.toSerialDeployment() } + val mockEngine = MockEngine { request -> + respond( + content = ByteReadChannel(json.encodeToString(DeploymentsResponse(serialDeployments))), + status = HttpStatusCode.OK + ) + } + val api = ProjectApiV1(serverUrl, managementKey, mockEngine) + val actual = api.getDeployments() + assertEquals(expected, actual) + val request = mockEngine.requestHistory[0] + assertEquals(HttpMethod.Get, request.method) + assertEquals("/api/1/deployments", request.url.encodedPath) + assertEquals("Bearer $managementKey", request.headers["Authorization"]) + } + + @Test + fun `get deployments, one deleted, returns only active deployment`(): Unit = runBlocking { + val deploymentA = deployment("a") + val deploymentB = deployment("b") + val expected = listOf(deploymentB) + val serialDeployments = listOf( + deploymentA.toSerialDeployment(true), + deploymentB.toSerialDeployment() + ) + val mockEngine = MockEngine { request -> + respond( + content = ByteReadChannel(json.encodeToString(DeploymentsResponse(serialDeployments))), + status = HttpStatusCode.OK + ) + } + val api = ProjectApiV1(serverUrl, managementKey, mockEngine) + val actual = api.getDeployments() + assertEquals(expected, actual) + val request = mockEngine.requestHistory[0] + assertEquals(HttpMethod.Get, request.method) + assertEquals("/api/1/deployments", request.url.encodedPath) + assertEquals("Bearer $managementKey", request.headers["Authorization"]) + } +} diff --git a/core/src/test/kotlin/project/ProjectProxyTest.kt b/core/src/test/kotlin/project/ProjectProxyTest.kt new file mode 100644 index 0000000..6b81fd8 --- /dev/null +++ b/core/src/test/kotlin/project/ProjectProxyTest.kt @@ -0,0 +1,318 @@ +package project + +import test.cohort +import com.amplitude.Configuration +import com.amplitude.assignment.AssignmentTracker +import com.amplitude.cohort.GetCohortResponse +import com.amplitude.cohort.InMemoryCohortStorage +import com.amplitude.deployment.InMemoryDeploymentStorage +import com.amplitude.project.ProjectProxy +import com.amplitude.util.json +import test.deployment +import test.flag +import io.ktor.http.HttpStatusCode +import io.mockk.coEvery +import io.mockk.mockk +import kotlinx.coroutines.runBlocking +import kotlinx.serialization.encodeToString +import test.project +import kotlin.test.Test +import kotlin.test.assertEquals + +class ProjectProxyTest { + + private val project = project() + private val configuration = Configuration() + + @Test + fun `test get flag configs, null deployment, unauthorized`(): Unit = runBlocking { + val assignmentTracker = mockk() + val deploymentStorage = InMemoryDeploymentStorage() + val cohortStorage = InMemoryCohortStorage() + val projectProxy = ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage + ) + val result = projectProxy.getFlagConfigs(null) + assertEquals(HttpStatusCode.Unauthorized, result.status) + } + + @Test + fun `test get flag configs, with deployment, success`(): Unit = runBlocking { + val deployment = deployment("deployment") + val flag = flag("flag") + val assignmentTracker = mockk() + val deploymentStorage = InMemoryDeploymentStorage().apply { + putFlag(deployment.key, flag) + } + val cohortStorage = InMemoryCohortStorage() + val projectProxy = ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage + ) + val result = projectProxy.getFlagConfigs(deployment.key) + assertEquals(HttpStatusCode.OK, result.status) + assertEquals(json.encodeToString(listOf(flag)), result.body) + } + + @Test + fun `test get cohort, null cohort id, not found`(): Unit = runBlocking { + val assignmentTracker = mockk() + val deploymentStorage = InMemoryDeploymentStorage() + val cohortStorage = InMemoryCohortStorage() + val projectProxy = ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage + ) + val result = projectProxy.getCohort(null, null, null) + assertEquals(HttpStatusCode.NotFound, result.status) + } + + @Test + fun `test get cohort, with cohort id, success`(): Unit = runBlocking { + val cohort = cohort("a") + val assignmentTracker = mockk() + val deploymentStorage = InMemoryDeploymentStorage() + val cohortStorage = InMemoryCohortStorage().apply { + putCohort(cohort) + } + val projectProxy = ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage + ) + val result = projectProxy.getCohort("a", null, null) + assertEquals(HttpStatusCode.OK, result.status) + assertEquals(json.encodeToString(GetCohortResponse.fromCohort(cohort)), result.body) + } + + @Test + fun `test get cohort, with cohort id and last modified, success`(): Unit = runBlocking { + val cohort = cohort("a", 100) + val assignmentTracker = mockk() + val deploymentStorage = InMemoryDeploymentStorage() + val cohortStorage = InMemoryCohortStorage().apply { + putCohort(cohort) + } + val projectProxy = ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage + ) + val result = projectProxy.getCohort("a", 1, null) + assertEquals(HttpStatusCode.OK, result.status) + assertEquals(json.encodeToString(GetCohortResponse.fromCohort(cohort)), result.body) + } + + @Test + fun `test get cohort, with cohort id and last modified, not changed`(): Unit = runBlocking { + val cohort = cohort("a", 100) + val assignmentTracker = mockk() + val deploymentStorage = InMemoryDeploymentStorage() + val cohortStorage = InMemoryCohortStorage().apply { + putCohort(cohort) + } + val projectProxy = ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage + ) + val result = projectProxy.getCohort("a", 100, null) + assertEquals(HttpStatusCode.NoContent, result.status) + } + + @Test + fun `test get cohort, with cohort id and max cohort size, success`(): Unit = runBlocking { + val deployment = deployment("deployment") + val flag = flag("flag", setOf("a")) + val cohort = cohort("a", 100, 100) + val assignmentTracker = mockk() + val deploymentStorage = InMemoryDeploymentStorage().apply { + putFlag(deployment.key, flag) + } + val cohortStorage = InMemoryCohortStorage().apply { + putCohort(cohort) + } + val projectProxy = ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage + ) + val result = projectProxy.getCohort("a", null, Int.MAX_VALUE) + assertEquals(HttpStatusCode.OK, result.status) + assertEquals(json.encodeToString(GetCohortResponse.fromCohort(cohort)), result.body) + } + + + @Test + fun `test get cohort, with cohort id and max cohort size, too large`(): Unit = runBlocking { + val cohort = cohort("a", 100, 100) + val assignmentTracker = mockk() + val deploymentStorage = InMemoryDeploymentStorage() + val cohortStorage = InMemoryCohortStorage().apply { + putCohort(cohort) + } + val projectProxy = ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage + ) + val result = projectProxy.getCohort("a", null, 99) + assertEquals(HttpStatusCode.PayloadTooLarge, result.status) + } + + @Test + fun `test get cohort memberships for group, null deployment, unauthorized`(): Unit = runBlocking { + val assignmentTracker = mockk() + val deploymentStorage = InMemoryDeploymentStorage() + val cohortStorage = InMemoryCohortStorage() + val projectProxy = ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage + ) + val result = projectProxy.getCohortMemberships(null, null, null) + assertEquals(HttpStatusCode.Unauthorized, result.status) + } + + @Test + fun `test get cohort memberships for group, with deployment and group type, null group name, bad request`(): Unit = runBlocking { + val deployment = deployment("deployment") + val assignmentTracker = mockk() + val deploymentStorage = InMemoryDeploymentStorage() + val cohortStorage = InMemoryCohortStorage() + val projectProxy = ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage + ) + val result = projectProxy.getCohortMemberships(deployment.key, "User", null) + assertEquals(HttpStatusCode.BadRequest, result.status) + } + + @Test + fun `test get cohort memberships for group, with deployment and group name, null group type, bad request`(): Unit = runBlocking { + val deployment = deployment("deployment") + val assignmentTracker = mockk() + val deploymentStorage = InMemoryDeploymentStorage() + val cohortStorage = InMemoryCohortStorage() + val projectProxy = ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage + ) + val result = projectProxy.getCohortMemberships(deployment.key, null, "1") + assertEquals(HttpStatusCode.BadRequest, result.status) + } + + @Test + fun `test get cohort memberships for group, with deployment group name and group type, success`(): Unit = runBlocking { + val deployment = deployment("deployment") + val cohort = cohort("a") + val flag = flag("flag", setOf("a")) + val assignmentTracker = mockk() + val deploymentStorage = InMemoryDeploymentStorage().apply { + putFlag(deployment.key, flag) + } + val cohortStorage = InMemoryCohortStorage().apply { + putCohort(cohort) + } + val projectProxy = ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage + ) + val result = projectProxy.getCohortMemberships(deployment.key, "User", "1") + assertEquals(HttpStatusCode.OK, result.status) + assertEquals(json.encodeToString(listOf(cohort.id)), result.body) + } + + @Test + fun `test evaluate, null deployment, unauthorized`(): Unit = runBlocking { + val assignmentTracker = mockk() + val deploymentStorage = InMemoryDeploymentStorage() + val cohortStorage = InMemoryCohortStorage() + val projectProxy = ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage + ) + + val result = projectProxy.evaluate(null, null, null) + assertEquals(HttpStatusCode.Unauthorized, result.status) + } + + @Test + fun `test evaluate, with deployment, null user, success`(): Unit = runBlocking { + val deployment = deployment("deployment") + val assignmentTracker = mockk() + val deploymentStorage = InMemoryDeploymentStorage() + val cohortStorage = InMemoryCohortStorage() + val projectProxy = ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage + ) + + val result = projectProxy.evaluate(deployment.key, null, null) + assertEquals(HttpStatusCode.OK, result.status) + assertEquals("{}", result.body) + } + + @Test + fun `test evaluate, with deployment and flag keys, success`(): Unit = runBlocking { + val cohort = cohort("a") + val flag = flag("flag", setOf("a")) + val deployment = deployment("deployment") + val assignmentTracker = mockk() + val deploymentStorage = InMemoryDeploymentStorage().apply { + putFlag(deployment.key, flag) + } + val cohortStorage = InMemoryCohortStorage().apply { + putCohort(cohort) + } + val projectProxy = ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage + ) + val user = mapOf("user_id" to "1") + coEvery { assignmentTracker.track(allAny()) } returns Unit + val result = projectProxy.evaluate(deployment.key, user, setOf("flag")) + assertEquals(HttpStatusCode.OK, result.status) + assertEquals("""{"flag":{"key":"on","value":"on"}}""", result.body) + } +} diff --git a/core/src/test/kotlin/project/ProjectRunnerTest.kt b/core/src/test/kotlin/project/ProjectRunnerTest.kt new file mode 100644 index 0000000..2a3f011 --- /dev/null +++ b/core/src/test/kotlin/project/ProjectRunnerTest.kt @@ -0,0 +1,127 @@ +package project + +import test.cohort +import com.amplitude.Configuration +import com.amplitude.cohort.CohortLoader +import com.amplitude.cohort.CohortStorage +import com.amplitude.cohort.InMemoryCohortStorage +import com.amplitude.cohort.toCohortDescription +import com.amplitude.deployment.DeploymentLoader +import com.amplitude.deployment.DeploymentStorage +import com.amplitude.deployment.InMemoryDeploymentStorage +import com.amplitude.project.ProjectApi +import com.amplitude.project.ProjectRunner +import test.deployment +import test.flag +import io.mockk.coEvery +import io.mockk.coVerify +import io.mockk.mockk +import io.mockk.spyk +import kotlinx.coroutines.delay +import kotlinx.coroutines.runBlocking +import org.junit.Assert.assertNull +import test.project +import org.junit.Test +import kotlin.test.assertNotNull + +class ProjectRunnerTest { + + private val project = project() + private val config = Configuration(deploymentSyncIntervalMillis = 50) + + @Test + fun `test start, no state initial load, success`(): Unit = runBlocking { + val projectApi = mockk() + val deploymentLoader = mockk() + val deploymentStorage = spyk(InMemoryDeploymentStorage()) + val cohortLoader = mockk() + val cohortStorage = spyk(InMemoryCohortStorage()) + val runner = ProjectRunner( + project, + config, + projectApi, + deploymentLoader, + deploymentStorage, + cohortLoader, + cohortStorage + ) + coEvery { projectApi.getDeployments() } returns listOf(deployment("a")) + coEvery { cohortLoader.loadCohorts(allAny()) } returns Unit + coEvery { deploymentLoader.loadDeployment(allAny()) } returns Unit + runner.start() + runner.stop() + coVerify(exactly = 1) { projectApi.getDeployments() } + coVerify(exactly = 1) { deploymentStorage.getDeployments() } + coVerify(exactly = 1) { deploymentStorage.putDeployment(eq(deployment("a"))) } + assertNotNull(runner.deploymentRunners["a"]) + coVerify(exactly = 0) { deploymentStorage.removeAllFlags(allAny()) } + coVerify(exactly = 0) { deploymentStorage.removeDeployment(allAny()) } + coVerify(exactly = 1) { deploymentStorage.getAllFlags(allAny()) } + coVerify(exactly = 1) { cohortStorage.getCohortDescriptions() } + coVerify(exactly = 0) { cohortStorage.deleteCohort(allAny()) } + } + + @Test + fun `test start, with initial state, add and remove deployment, success`(): Unit = runBlocking { + val projectApi = mockk() + val deploymentLoader = mockk() + val deploymentStorage = spyk(InMemoryDeploymentStorage().apply { + putDeployment(deployment("a")) + putFlag("a", flag(cohortIds = setOf("aa"))) + }) + val cohortLoader = mockk() + val cohortStorage = spyk(InMemoryCohortStorage().apply { + putCohort(cohort("aa")) + }) + coEvery { projectApi.getDeployments() } returns listOf(deployment("b")) + coEvery { cohortLoader.loadCohorts(allAny()) } returns Unit + coEvery { deploymentLoader.loadDeployment(allAny()) } returns Unit + val runner = ProjectRunner( + project, + config, + projectApi, + deploymentLoader, + deploymentStorage, + cohortLoader, + cohortStorage + ) + runner.start() + runner.stop() + coVerify(exactly = 1) { projectApi.getDeployments() } + coVerify(exactly = 1) { deploymentStorage.getDeployments() } + coVerify(exactly = 1) { deploymentStorage.putDeployment(eq(deployment("b"))) } + assertNull(runner.deploymentRunners["a"]) + assertNotNull(runner.deploymentRunners["b"]) + coVerify(exactly = 1) { deploymentStorage.removeAllFlags(eq("a")) } + coVerify(exactly = 1) { deploymentStorage.removeDeployment(eq("a")) } + coVerify(exactly = 1) { deploymentStorage.getAllFlags(eq("b")) } + coVerify(exactly = 1) { cohortStorage.getCohortDescriptions() } + coVerify(exactly = 1) { cohortStorage.deleteCohort(eq(cohort("aa").toCohortDescription())) } + } + + @Test + fun `test start, deployments api failure, does not throw`(): Unit = runBlocking { + val projectApi = mockk() + val deploymentLoader = mockk() + val deploymentStorage = spyk(InMemoryDeploymentStorage()) + val cohortLoader = mockk() + val cohortStorage = spyk(InMemoryCohortStorage()) + coEvery { projectApi.getDeployments() } throws RuntimeException("test") + coEvery { cohortLoader.loadCohorts(allAny()) } returns Unit + coEvery { deploymentLoader.loadDeployment(allAny()) } returns Unit + val runner = ProjectRunner( + project, + config, + projectApi, + deploymentLoader, + deploymentStorage, + cohortLoader, + cohortStorage + ) + runner.start() + delay(150) + // Pollers should start and call getDeployments again + coVerify(atLeast = 2) { projectApi.getDeployments() } + + } +} diff --git a/core/src/test/kotlin/project/ProjectStorageTest.kt b/core/src/test/kotlin/project/ProjectStorageTest.kt new file mode 100644 index 0000000..f88f629 --- /dev/null +++ b/core/src/test/kotlin/project/ProjectStorageTest.kt @@ -0,0 +1,45 @@ +package project + +import test.InMemoryRedis +import com.amplitude.project.InMemoryProjectStorage +import com.amplitude.project.ProjectStorage +import com.amplitude.project.RedisProjectStorage +import kotlinx.coroutines.runBlocking +import org.junit.Assert.assertEquals +import kotlin.test.Test + +class ProjectStorageTest { + + @Test + fun `test in memory`(): Unit = runBlocking { + test(InMemoryProjectStorage()) + } + + @Test + fun `test redis`(): Unit = runBlocking { + test(RedisProjectStorage("amplitude", InMemoryRedis())) + } + + private fun test(storage: ProjectStorage) = runBlocking { + // get projects, empty + var projects = storage.getProjects() + assertEquals(0, projects.size) + // put project, 1 + storage.putProject("1") + // put project, 2 + storage.putProject("2") + // get projects, 1 2 + projects = storage.getProjects() + assertEquals(setOf("1", "2"), projects) + // remove project 1, 2 + storage.removeProject("1") + // get projects, 2 + projects = storage.getProjects() + assertEquals(setOf("2"), projects) + // remove project 2 + storage.removeProject("2") + // get projects, empty + projects = storage.getProjects() + assertEquals(0, projects.size) + } +} diff --git a/core/src/test/kotlin/test/InMemoryRedis.kt b/core/src/test/kotlin/test/InMemoryRedis.kt new file mode 100644 index 0000000..02b0e9e --- /dev/null +++ b/core/src/test/kotlin/test/InMemoryRedis.kt @@ -0,0 +1,63 @@ +package test + +import com.amplitude.util.Redis +import com.amplitude.util.RedisKey +import kotlin.time.Duration + +internal class InMemoryRedis: Redis { + + private val kv = mutableMapOf() + private val sets = mutableMapOf>() + private val hashes = mutableMapOf>() + + override suspend fun get(key: RedisKey): String? { + return kv[key.value] + } + + override suspend fun set(key: RedisKey, value: String) { + kv[key.value] = value + } + + override suspend fun del(key: RedisKey) { + kv.remove(key.value) + } + + override suspend fun sadd(key: RedisKey, values: Set) { + sets.getOrPut(key.value) { mutableSetOf() }.addAll(values) + } + + override suspend fun srem(key: RedisKey, value: String) { + sets.getOrPut(key.value) { mutableSetOf() }.remove(value) + } + + override suspend fun smembers(key: RedisKey): Set? { + return sets[key.value]?.toSet() + } + + override suspend fun sismember(key: RedisKey, value: String): Boolean { + return sets[key.value]?.contains(value) ?: false + } + + override suspend fun hget(key: RedisKey, field: String): String? { + return hashes.getOrPut(key.value) { mutableMapOf() }[field] + } + + override suspend fun hgetall(key: RedisKey): Map? { + return hashes[key.value]?.toMap() + } + + override suspend fun hset(key: RedisKey, values: Map) { + hashes.getOrPut(key.value) { mutableMapOf() }.putAll(values) + } + + override suspend fun hdel(key: RedisKey, field: String) { + hashes[key.value]?.remove(field) + if(hashes[key.value]?.isEmpty() == true) { + hashes.remove(key.value) + } + } + + override suspend fun expire(key: RedisKey, ttl: Duration) { + // Do nothing. + } +} diff --git a/core/src/test/kotlin/test/Utils.kt b/core/src/test/kotlin/test/Utils.kt new file mode 100644 index 0000000..986d4d4 --- /dev/null +++ b/core/src/test/kotlin/test/Utils.kt @@ -0,0 +1,91 @@ +package test + +import com.amplitude.cohort.Cohort +import com.amplitude.deployment.Deployment +import com.amplitude.experiment.evaluation.EvaluationCondition +import com.amplitude.experiment.evaluation.EvaluationFlag +import com.amplitude.experiment.evaluation.EvaluationOperator +import com.amplitude.experiment.evaluation.EvaluationSegment +import com.amplitude.experiment.evaluation.EvaluationVariant +import com.amplitude.project.Project +import com.amplitude.project.SerialDeployment + +internal fun user( + userId: String? = null, + deviceId: String? = null, + userProperties: Map? = null, + groups: Map>? = null, + groupProperties: Map>>? = null +): MutableMap { + return mutableMapOf( + "user_id" to userId, + "device_id" to deviceId, + "user_properties" to userProperties, + "groups" to groups, + "group_properties" to groupProperties + ) +} + +internal fun flag( + flagKey: String = "flag", + cohortIds: Set = setOf("a") +) = EvaluationFlag( + key = flagKey, + variants = mapOf( + "off" to EvaluationVariant("off", null, null, null), + "on" to EvaluationVariant("on", "on", null, null) + ), + segments = listOf( + EvaluationSegment( + conditions = listOf( + listOf( + EvaluationCondition( + selector = listOf("context", "user", "cohort_ids"), + op = EvaluationOperator.SET_CONTAINS_ANY, + values = cohortIds + ) + ) + ), + variant = "on" + ), + EvaluationSegment( + variant = "off" + ) + ) +) + +internal fun cohort( + id: String, + lastModified: Long = 100, + size: Int = 1, + members: Set = setOf("1"), + groupType: String = "User" +) = Cohort( + id = id, + groupType = groupType, + size = size, + lastModified = lastModified, + members = members +) + +internal fun deployment(key: String, projectId: String = "1") = Deployment( + id = "1", + projectId = projectId, + label = "", + key = key, +) + +internal fun Deployment.toSerialDeployment(deleted: Boolean = false) = SerialDeployment( + id = id, + projectId = projectId, + label = label, + key = key, + deleted = deleted, +) + +internal fun project(id: String = "1") = Project( + id = id, + apiKey = "api", + secretKey = "secret", + managementKey = "management" +) diff --git a/core/src/test/kotlin/util/EvaluationFlagTest.kt b/core/src/test/kotlin/util/EvaluationFlagTest.kt new file mode 100644 index 0000000..1311e61 --- /dev/null +++ b/core/src/test/kotlin/util/EvaluationFlagTest.kt @@ -0,0 +1,23 @@ +package util + +import com.amplitude.experiment.evaluation.EvaluationFlag +import com.amplitude.util.getGroupedCohortIds +import com.amplitude.util.json +import org.junit.Assert.assertEquals +import kotlin.test.Test + +class EvaluationFlagTest { + + private val testFlagsJson = """[{"key":"flag1","segments":[{"conditions":[[{"op":"set contains any","selector":["context","user","cohort_ids"],"values":["hahahaha1"]}]]},{"metadata":{"segmentName":"All Other Users"},"variant":"off"}],"variants":{}},{"key":"flag2","segments":[{"conditions":[[{"op":"set contains any","selector":["context","user","cohort_ids"],"values":["hahahaha2"]}]],"metadata":{"segmentName":"Segment 1"},"variant":"off"},{"metadata":{"segmentName":"All Other Users"},"variant":"off"}],"variants":{}},{"key":"flag3","metadata":{"deployed":true,"evaluationMode":"local","experimentKey":"exp-1","flagType":"experiment","flagVersion":6},"segments":[{"conditions":[[{"op":"set contains any","selector":["context","user","cohort_ids"],"values":["hahahaha3"]}]],"variant":"off"},{"conditions":[[{"op":"set contains any","selector":["context","user","cocoids"],"values":["nohaha"]}]],"variant":"off"},{"metadata":{"segmentName":"All Other Users"},"variant":"off"}],"variants":{}},{"key":"flag5","segments":[{"conditions":[[{"op":"set contains any","selector":["context","user","cohort_ids"],"values":["hahahaha3","hahahaha4"]}]]},{"conditions":[[{"op":"set contains any","selector":["context","groups","org name","cohort_ids"],"values":["hahaorgname1"]}]],"metadata":{"segmentName":"Segment 1"}},{"conditions":[[{"op":"set contains any","selector":["context","gg","org name","cohort_ids"],"values":["nohahaorgname"]}]],"metadata":{"segmentName":"Segment 1"}}],"variants":{}}]""" + private val testFlags = json.decodeFromString>(testFlagsJson) + + @Test + fun `test get grouped cohort ids from flags`() { + val result = testFlags.getGroupedCohortIds() + val expected = mapOf( + "User" to setOf("hahahaha1", "hahahaha2", "hahahaha3", "hahahaha4"), + "org name" to setOf("hahaorgname1"), + ) + assertEquals(expected, result) + } +} diff --git a/gradle.properties b/gradle.properties index 1ee7abd..b37e111 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,23 +1,24 @@ # kotlin kotlin.code.style=official -ktorVersion=2.3.5 -kotlinVersion=1.9.10 -serialziationVersion=1.9.0 +ktorVersion=2.3.12 +kotlinVersion=2.0.0 +coroutinesVersion=1.9.0-RC +serializationVersion=1.7.1 # logging & metrics logbackVersion=1.4.6 prometheusVersion=1.10.5 # amplitude -experimentEvaluationVersion = 2.0.0-beta.2 -amplitudeAnalytics = 1.12.0 -amplitudeAnalyticsJson = 20230227 +experimentEvaluationVersion = 2.1.0 +amplitudeAnalytics = 1.12.2 +amplitudeAnalyticsJson = 20240303 # redis -lettuce = 6.2.3.RELEASE - -# apache -apacheCommons = 1.10.0 +lettuce = 6.3.2.RELEASE # yaml kaml = 0.53.0 + +# mock +mockk = 1.13.9 diff --git a/service/build.gradle.kts b/service/build.gradle.kts index 60dc1b1..dbfb997 100644 --- a/service/build.gradle.kts +++ b/service/build.gradle.kts @@ -2,10 +2,10 @@ import org.jetbrains.kotlin.gradle.tasks.KotlinCompile plugins { application - id("io.ktor.plugin") version "2.2.4" - kotlin("jvm") version "1.9.10" - kotlin("plugin.serialization") version "1.9.0" - id("org.jlleitschuh.gradle.ktlint") version "11.3.1" + id("io.ktor.plugin") version "2.3.4" + kotlin("jvm") version "2.0.0" + kotlin("plugin.serialization") version "2.0.0" + id("org.jlleitschuh.gradle.ktlint") version "12.1.1" } application { @@ -34,6 +34,7 @@ val serializationVersion: String by project dependencies { implementation(project(":core")) implementation("com.amplitude:evaluation-core:$experimentEvaluationVersion") + implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:$serializationVersion") implementation("io.ktor:ktor-server-call-logging-jvm:$ktorVersion") implementation("io.ktor:ktor-server-core-jvm:$ktorVersion") implementation("io.ktor:ktor-server-metrics-micrometer-jvm:$ktorVersion") diff --git a/service/src/main/kotlin/Server.kt b/service/src/main/kotlin/Server.kt index 9284bf3..ddcc3bb 100644 --- a/service/src/main/kotlin/Server.kt +++ b/service/src/main/kotlin/Server.kt @@ -6,7 +6,6 @@ import com.amplitude.util.json import com.amplitude.util.logger import com.amplitude.util.stringEnv import com.amplitude.util.toAnyMap -import io.ktor.http.HttpStatusCode import io.ktor.serialization.kotlinx.json.json import io.ktor.server.application.Application import io.ktor.server.application.ApplicationCall @@ -110,66 +109,30 @@ fun Application.proxyServer() { * Configure endpoints. */ routing { + // Local Evaluation get("/sdk/v2/flags") { val deployment = this.call.request.getDeploymentKey() - val result = try { - evaluationProxy.getSerializedFlagConfigs(deployment) - } catch (e: HttpErrorResponseException) { - call.respond(HttpStatusCode.fromValue(e.status), e.message) - return@get - } - call.respond(result) + val result = evaluationProxy.getFlagConfigs(deployment) + call.respond(result.status, result.body) } - get("/sdk/v2/cohorts/{cohortId}/description") { - val deployment = this.call.request.getDeploymentKey() + get("/sdk/v1/cohort/{cohortId}") { + val (apiKey, secretKey) = this.call.request.getApiAndSecretKey() val cohortId = this.call.parameters["cohortId"] - val result = try { - evaluationProxy.getSerializedCohortDescription(deployment, cohortId) - } catch (e: HttpErrorResponseException) { - call.respond(HttpStatusCode.fromValue(e.status), e.message) - return@get - } - call.respond(result) - } - - get("/sdk/v2/cohorts/{cohortId}/members") { - val deployment = this.call.request.getDeploymentKey() - val cohortId = this.call.parameters["cohortId"] - val result = try { - evaluationProxy.getSerializedCohortMembers(deployment, cohortId) - } catch (e: HttpErrorResponseException) { - call.respond(HttpStatusCode.fromValue(e.status), e.message) - return@get - } - call.respond(result) - } - - get("/sdk/v2/users/{userId}/cohorts") { - val deployment = this.call.request.getDeploymentKey() - val userId = this.call.parameters["userId"] - val result = try { - evaluationProxy.getSerializedCohortMembershipsForUser(deployment, userId) - } catch (e: HttpErrorResponseException) { - call.respond(HttpStatusCode.fromValue(e.status), e.message) - return@get - } - call.respond(result) + val maxCohortSize = this.call.request.queryParameters["maxCohortSize"]?.toIntOrNull() + val lastModified = this.call.request.queryParameters["lastModified"]?.toLongOrNull() + val result = evaluationProxy.getCohort(apiKey, secretKey, cohortId, lastModified, maxCohortSize) + call.respond(result.status, result.body) } - get("/sdk/v2/groups/{groupType}/{groupName}/cohorts") { + get("/sdk/v2/memberships/{groupType}/{groupName}") { val deployment = this.call.request.getDeploymentKey() val groupType = this.call.parameters["groupType"] val groupName = this.call.parameters["groupName"] - val result = try { - evaluationProxy.getSerializedCohortMembershipsForGroup(deployment, groupType, groupName) - } catch (e: HttpErrorResponseException) { - call.respond(HttpStatusCode.fromValue(e.status), e.message) - return@get - } - call.respond(result) + val result = evaluationProxy.getCohortMemberships(deployment, groupType, groupName) + call.respond(result.status, result.body) } // Remote Evaluation V2 Endpoints @@ -212,34 +175,24 @@ suspend fun ApplicationCall.evaluate( evaluationProxy: EvaluationProxy, userProvider: suspend ApplicationRequest.() -> Map ) { - val result = try { - // Deployment key is included in Authorization header with prefix "Api-Key " - val deploymentKey = request.getDeploymentKey() - val user = request.userProvider() - val flagKeys = request.getFlagKeys() - evaluationProxy.serializedEvaluate(deploymentKey, user, flagKeys) - } catch (e: HttpErrorResponseException) { - respond(HttpStatusCode.fromValue(e.status), e.message) - return - } - respond(result) + // Deployment key is included in Authorization header with prefix "Api-Key " + val deploymentKey = request.getDeploymentKey() + val user = request.userProvider() + val flagKeys = request.getFlagKeys() + val result = evaluationProxy.evaluate(deploymentKey, user, flagKeys) + respond(result.status, result.body) } suspend fun ApplicationCall.evaluateV1( evaluationProxy: EvaluationProxy, userProvider: suspend ApplicationRequest.() -> Map ) { - val result = try { - // Deployment key is included in Authorization header with prefix "Api-Key " - val deploymentKey = request.getDeploymentKey() - val user = request.userProvider() - val flagKeys = request.getFlagKeys() - evaluationProxy.serializedEvaluateV1(deploymentKey, user, flagKeys) - } catch (e: HttpErrorResponseException) { - respond(HttpStatusCode.fromValue(e.status), e.message) - return - } - respond(result) + // Deployment key is included in Authorization header with prefix "Api-Key " + val deploymentKey = request.getDeploymentKey() + val user = request.userProvider() + val flagKeys = request.getFlagKeys() + val result = evaluationProxy.evaluateV1(deploymentKey, user, flagKeys) + respond(result.status, result.body) } /** @@ -253,6 +206,21 @@ private fun ApplicationRequest.getDeploymentKey(): String? { return deploymentKey.substring("Api-Key ".length) } +/** + * Get the API and secret key from the request, included in Authorization header as Basic auth. + */ +private fun ApplicationRequest.getApiAndSecretKey(): Pair { + val authHeaderValue = this.headers["Authorization"] + if (authHeaderValue == null || !authHeaderValue.startsWith("Basic", ignoreCase = true)) { + return null to null + } + val segmentedAuthValue = authHeaderValue.substring("Basic ".length).split(":") + if (segmentedAuthValue.size < 2) { + return null to null + } + return segmentedAuthValue[0] to segmentedAuthValue[1] +} + /** * Get the flag keys from the request. Either contained in header or query params. * Flag keys are used to filter the results to only required flags.