diff --git a/README.md b/README.md index 3d71dd4..588bda2 100644 --- a/README.md +++ b/README.md @@ -18,12 +18,9 @@ The default location for the configuration yaml file is `/etc/evaluation-proxy-c ```yaml projects: - - id: "YOUR PROJECT ID" - apiKey: "YOUR API KEY" - secretKey: " YOUR SECRET KEY" - deploymentKeys: - - "YOUR DEPLOYMENT KEY 1" - - "YOUR DEPLOYMENT KEY 2" + - apiKey: "YOUR API KEY" + secretKey: "YOUR SECRET KEY" + managementKey: "YOUR MANAGEMENT API KEY" configuration: redis: @@ -43,7 +40,7 @@ Use the evaluation proxy [Helm chart](https://github.com/amplitude/evaluation-pr ### Docker Compose Example -Run the container locally with redis persistence using `docker compose`. You must first update the `compose-config.yaml` file with your project and deployment keys before running the composition. +Run the container locally with redis persistence using `docker compose`. You must first update the `compose-config.yaml` file with your project keys before running the composition. ``` docker compose build diff --git a/build.gradle.kts b/build.gradle.kts index fcb0877..ca7559e 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -1,5 +1,5 @@ plugins { - kotlin("jvm") version "1.8.10" + kotlin("jvm") version "1.9.10" id("io.github.gradle-nexus.publish-plugin") version "1.1.0" } diff --git a/compose-config.yaml b/compose-config.yaml index dc777dc..3a8febb 100644 --- a/compose-config.yaml +++ b/compose-config.yaml @@ -1,9 +1,7 @@ projects: - - id: "TODO" - apiKey: "TODO" + - apiKey: "TODO" secretKey: "TODO" - deploymentKeys: - - "TODO" + managementKey: "TODO" configuration: redis: diff --git a/core/build.gradle.kts b/core/build.gradle.kts index 59e81d1..f9c043f 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -1,8 +1,8 @@ import org.jetbrains.kotlin.gradle.tasks.KotlinCompile plugins { - kotlin("jvm") version "1.8.10" - kotlin("plugin.serialization") version "1.8.0" + kotlin("jvm") version "1.9.10" + kotlin("plugin.serialization") version "1.9.0" `maven-publish` signing id("org.jlleitschuh.gradle.ktlint") version "11.3.1" @@ -29,7 +29,6 @@ val kaml: String by project dependencies { implementation("com.amplitude:evaluation-core:$experimentEvaluationVersion") - implementation("com.amplitude:evaluation-serialization:$experimentEvaluationVersion") implementation("com.amplitude:java-sdk:$amplitudeAnalytics") implementation("org.json:json:$amplitudeAnalyticsJson") implementation("io.lettuce:lettuce-core:$lettuce") @@ -53,7 +52,7 @@ publishing { create("core") { groupId = "com.amplitude" artifactId = "evaluation-proxy-core" - version = "0.3.2" + version = "0.4.4" from(components["java"]) pom { name.set("Amplitude Evaluation Proxy") diff --git a/core/src/main/kotlin/Config.kt b/core/src/main/kotlin/Config.kt index f843169..8e079e3 100644 --- a/core/src/main/kotlin/Config.kt +++ b/core/src/main/kotlin/Config.kt @@ -5,25 +5,25 @@ import com.amplitude.util.intEnv import com.amplitude.util.json import com.amplitude.util.longEnv import com.amplitude.util.stringEnv -import com.charleskorn.kaml.Yaml +import com.amplitude.util.yaml import kotlinx.serialization.Serializable import kotlinx.serialization.decodeFromString import java.io.File @Serializable data class ProjectsFile( - val projects: List + val projects: List ) { companion object { fun fromEnv(): ProjectsFile { - val project = Project.fromEnv() + val project = ProjectConfiguration.fromEnv() return ProjectsFile(listOf(project)) } fun fromFile(path: String): ProjectsFile { val data = File(path).readText() return if (path.endsWith(".yaml") || path.endsWith(".yml")) { - Yaml.default.decodeFromString(data) + yaml.decodeFromString(data) } else if (path.endsWith(".json")) { json.decodeFromString(data) } else { @@ -47,7 +47,7 @@ data class ConfigurationFile( fun fromFile(path: String): ConfigurationFile { val data = File(path).readText() return if (path.endsWith(".yaml") || path.endsWith(".yml")) { - Yaml.default.decodeFromString(data) + yaml.decodeFromString(data) } else if (path.endsWith(".json")) { json.decodeFromString(data) } else { @@ -58,19 +58,17 @@ data class ConfigurationFile( } @Serializable -data class Project( - val id: String, +data class ProjectConfiguration( val apiKey: String, val secretKey: String, - val deploymentKeys: Set + val managementKey: String ) { companion object { - fun fromEnv(): Project { - val id = checkNotNull(stringEnv(EnvKey.PROJECT_ID)) { "${EnvKey.PROJECT_ID} environment variable must be set." } + fun fromEnv(): ProjectConfiguration { val apiKey = checkNotNull(stringEnv(EnvKey.API_KEY)) { "${EnvKey.API_KEY} environment variable must be set." } val secretKey = checkNotNull(stringEnv(EnvKey.SECRET_KEY)) { "${EnvKey.SECRET_KEY} environment variable must be set." } - val deploymentKey = checkNotNull(stringEnv(EnvKey.EXPERIMENT_DEPLOYMENT_KEY)) { "${EnvKey.SECRET_KEY} environment variable must be set." } - return Project(id, apiKey, secretKey, setOf(deploymentKey)) + val managementKey = checkNotNull(stringEnv(EnvKey.EXPERIMENT_MANAGEMENT_KEY)) { "${EnvKey.EXPERIMENT_MANAGEMENT_KEY} environment variable must be set." } + return ProjectConfiguration(apiKey, secretKey, managementKey) } } } @@ -80,6 +78,7 @@ data class Configuration( val port: Int = Default.PORT, val serverUrl: String = Default.SERVER_URL, val cohortServerUrl: String = Default.COHORT_SERVER_URL, + val deploymentSyncIntervalMillis: Long = Default.DEPLOYMENT_SYNC_INTERVAL_MILLIS, val flagSyncIntervalMillis: Long = Default.FLAG_SYNC_INTERVAL_MILLIS, val cohortSyncIntervalMillis: Long = Default.COHORT_SYNC_INTERVAL_MILLIS, val maxCohortSize: Int = Default.MAX_COHORT_SIZE, @@ -91,6 +90,10 @@ data class Configuration( port = intEnv(EnvKey.PORT, Default.PORT)!!, serverUrl = stringEnv(EnvKey.SERVER_URL, Default.SERVER_URL)!!, cohortServerUrl = stringEnv(EnvKey.COHORT_SERVER_URL, Default.COHORT_SERVER_URL)!!, + deploymentSyncIntervalMillis = longEnv( + EnvKey.DEPLOYMENT_SYNC_INTERVAL_MILLIS, + Default.DEPLOYMENT_SYNC_INTERVAL_MILLIS + )!!, flagSyncIntervalMillis = longEnv( EnvKey.FLAG_SYNC_INTERVAL_MILLIS, Default.FLAG_SYNC_INTERVAL_MILLIS @@ -164,11 +167,11 @@ object EnvKey { const val SERVER_URL = "AMPLITUDE_SERVER_URL" const val COHORT_SERVER_URL = "AMPLITUDE_COHORT_SERVER_URL" - const val PROJECT_ID = "AMPLITUDE_PROJECT_ID" const val API_KEY = "AMPLITUDE_API_KEY" const val SECRET_KEY = "AMPLITUDE_SECRET_KEY" - const val EXPERIMENT_DEPLOYMENT_KEY = "AMPLITUDE_EXPERIMENT_DEPLOYMENT_KEY" + const val EXPERIMENT_MANAGEMENT_KEY = "AMPLITUDE_EXPERIMENT_MANAGEMENT_API_KEY" + const val DEPLOYMENT_SYNC_INTERVAL_MILLIS = "AMPLITUDE_DEPLOYMENT_SYNC_INTERVAL_MILLIS" const val FLAG_SYNC_INTERVAL_MILLIS = "AMPLITUDE_FLAG_SYNC_INTERVAL_MILLIS" const val COHORT_SYNC_INTERVAL_MILLIS = "AMPLITUDE_COHORT_SYNC_INTERVAL_MILLIS" const val MAX_COHORT_SIZE = "AMPLITUDE_MAX_COHORT_SIZE" @@ -185,8 +188,9 @@ object EnvKey { object Default { const val PORT = 3546 - const val SERVER_URL = "https://api.lab.amplitude.com" + const val SERVER_URL = "https://flag.lab.amplitude.com" const val COHORT_SERVER_URL = "https://cohort.lab.amplitude.com" + const val DEPLOYMENT_SYNC_INTERVAL_MILLIS = 60 * 1000L const val FLAG_SYNC_INTERVAL_MILLIS = 10 * 1000L const val COHORT_SYNC_INTERVAL_MILLIS = 60 * 1000L const val MAX_COHORT_SIZE = Int.MAX_VALUE diff --git a/core/src/main/kotlin/EvaluationProxy.kt b/core/src/main/kotlin/EvaluationProxy.kt index 5741f14..f1080aa 100644 --- a/core/src/main/kotlin/EvaluationProxy.kt +++ b/core/src/main/kotlin/EvaluationProxy.kt @@ -1,21 +1,31 @@ package com.amplitude +import com.amplitude.assignment.AmplitudeAssignmentTracker +import com.amplitude.cohort.CohortDescription +import com.amplitude.cohort.getCohortStorage import com.amplitude.deployment.getDeploymentStorage -import com.amplitude.experiment.evaluation.FlagConfig -import com.amplitude.experiment.evaluation.SkylabUser -import com.amplitude.experiment.evaluation.Variant -import com.amplitude.experiment.evaluation.serialization.SerialFlagConfig -import com.amplitude.experiment.evaluation.serialization.SerialVariant +import com.amplitude.experiment.evaluation.EvaluationFlag +import com.amplitude.experiment.evaluation.EvaluationVariant +import com.amplitude.project.Project +import com.amplitude.project.ProjectApiV1 import com.amplitude.project.ProjectProxy import com.amplitude.project.getProjectStorage import com.amplitude.util.json import com.amplitude.util.logger +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.delay import kotlinx.coroutines.joinAll import kotlinx.coroutines.launch +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock import kotlinx.serialization.encodeToString +import kotlin.time.DurationUnit +import kotlin.time.toDuration -const val VERSION = "0.3.2" +const val VERSION = "0.4.4" class HttpErrorResponseException( val status: Int, @@ -24,86 +34,227 @@ class HttpErrorResponseException( ) : Exception(message, cause) class EvaluationProxy( - private val projects: List, - private val configuration: Configuration = Configuration() + private val projectConfigurations: List, + private val configuration: Configuration = Configuration(), + metricsHandler: MetricsHandler? = null ) { companion object { val log by logger() } - private val projectProxies = projects.associateWith { ProjectProxy(it, configuration) } - private val deploymentProxies = projects - .map { project -> project.deploymentKeys.associateWith { project }.toMutableMap() } - .reduce { acc, map -> acc.apply { putAll(map) } } + init { + Metrics.handler = metricsHandler + } + + private val supervisor = SupervisorJob() + private val scope = CoroutineScope(supervisor) + + private val projectProxies = mutableMapOf() + private val apiKeysToProject = mutableMapOf() + private val secretKeysToProject = mutableMapOf() + private val deploymentKeysToProject = mutableMapOf() + private val mutex = Mutex() private val projectStorage = getProjectStorage(configuration.redis) - suspend fun start() = coroutineScope { - log.info("Starting evaluation proxy. projects=${projectProxies.keys.map { it.id }}") - for (project in projects) { - projectStorage.putProject(project.id) + suspend fun start() { + log.info("Starting evaluation proxy.") + /* + * Fetch deployments, setup initial mappings for each project + * configuration, and create the project proxy. + */ + log.info("Setting up ${projectConfigurations.size} project(s)") + for (projectConfiguration in projectConfigurations) { + val projectApi = ProjectApiV1(projectConfiguration.managementKey) + val deployments = projectApi.getDeployments() + if (deployments.isEmpty()) { + continue + } + val projectId = deployments.first().projectId + log.info("Fetched ${deployments.size} deployments for project $projectId") + // Add the project to local mappings. + val project = Project( + id = projectId, + apiKey = projectConfiguration.apiKey, + secretKey = projectConfiguration.secretKey, + managementKey = projectConfiguration.managementKey + ) + apiKeysToProject[project.apiKey] = project + secretKeysToProject[project.secretKey] = project + for (deployment in deployments) { + log.debug("Mapping deployment {} project {}", deployment.key, project.id) + deploymentKeysToProject[deployment.key] = project + } + + // Create a project proxy and add the project to storage. + val assignmentTracker = AmplitudeAssignmentTracker(project.apiKey, configuration.assignment) + val deploymentStorage = getDeploymentStorage(project.id, configuration.redis) + val cohortStorage = getCohortStorage( + project.id, + configuration.redis, + configuration.cohortSyncIntervalMillis.toDuration(DurationUnit.MILLISECONDS) + ) + val projectProxy = ProjectProxy( + project, + configuration, + assignmentTracker, + deploymentStorage, + cohortStorage + ) + projectProxies[project] = projectProxy + } + + /* + * Update project storage with configured projects, and clean up + * projects that have been removed. + */ + // Add all configured projects to storage + val projectIds = projectProxies.map { it.key.id }.toSet() + for (projectId in projectIds) { + log.debug("Adding project $projectId") + projectStorage.putProject(projectId) } // Remove all non-configured projects and associated data val storageProjectIds = projectStorage.getProjects() - val projectIds = projects.map { it.id }.toSet() for (projectId in storageProjectIds - projectIds) { log.info("Removing project $projectId") - val storage = getDeploymentStorage(projectId, configuration.redis) - val deployments = storage.getDeployments() - for (deployment in deployments) { - log.info("Removing deployment $deployment for project $projectId") - storage.removeDeployment(deployment) - storage.removeFlagConfigs(deployment) + val deploymentStorage = getDeploymentStorage(projectId, configuration.redis) + val cohortStorage = getCohortStorage( + projectId, + configuration.redis, + configuration.cohortSyncIntervalMillis.toDuration(DurationUnit.MILLISECONDS) + ) + // Remove all deployments for project + val deployments = deploymentStorage.getDeployments() + for ((deploymentKey, _) in deployments) { + log.info("Removing deployment and flag configs for deployment $deploymentKey for project $projectId") + deploymentStorage.removeDeployment(deploymentKey) + deploymentStorage.removeAllFlags(deploymentKey) + } + // Remove all cohorts for project + val cohortDescriptions = cohortStorage.getCohortDescriptions().values + for (cohortDescription in cohortDescriptions) { + cohortStorage.removeCohort(cohortDescription) } projectStorage.removeProject(projectId) } - projectProxies.map { launch { it.value.start() } }.joinAll() + + /* + * Start all project proxies. + */ + projectProxies.map { scope.launch { it.value.start() } }.joinAll() + + /* + * Periodically update the local cache of deployments to project values. + */ + scope.launch { + while (true) { + delay(configuration.deploymentSyncIntervalMillis) + for ((project, projectProxy) in projectProxies) { + val deployments = projectProxy.getDeployments().associateWith { project } + mutex.withLock { deploymentKeysToProject.putAll(deployments) } + } + } + } log.info("Evaluation proxy started.") } suspend fun shutdown() = coroutineScope { log.info("Shutting down evaluation proxy.") projectProxies.map { launch { it.value.shutdown() } }.joinAll() + supervisor.cancelAndJoin() log.info("Evaluation proxy shut down.") } - suspend fun getFlagConfigs(deploymentKey: String?): List { - val project = deploymentProxies[deploymentKey] ?: throw HttpErrorResponseException(404, "Deployment not found.") - val projectProxy = projectProxies[project] ?: throw HttpErrorResponseException(404, "Project not found.") + // Apis + + suspend fun getFlagConfigs(deploymentKey: String?): List { + val projectProxy = getProjectProxy(deploymentKey) return projectProxy.getFlagConfigs(deploymentKey) } + suspend fun getCohortDescription(deploymentKey: String?, cohortId: String?): CohortDescription { + val projectProxy = getProjectProxy(deploymentKey) + return projectProxy.getCohortDescription(cohortId) + } + + suspend fun getCohortMembers(deploymentKey: String?, cohortId: String?): Set { + val projectProxy = getProjectProxy(deploymentKey) + return projectProxy.getCohortMembers(cohortId) + } + suspend fun getCohortMembershipsForUser(deploymentKey: String?, userId: String?): Set { - val project = deploymentProxies[deploymentKey] ?: throw HttpErrorResponseException(404, "Deployment not found.") - val projectProxy = projectProxies[project] ?: throw HttpErrorResponseException(404, "Project not found.") + val projectProxy = getProjectProxy(deploymentKey) return projectProxy.getCohortMembershipsForUser(deploymentKey, userId) } + suspend fun getCohortMembershipsForGroup(deploymentKey: String?, groupType: String?, groupName: String?): Set { + val projectProxy = getProjectProxy(deploymentKey) + return projectProxy.getCohortMembershipsForGroup(deploymentKey, groupType, groupName) + } + suspend fun evaluate( deploymentKey: String?, - user: SkylabUser?, + user: Map?, + flagKeys: Set? = null + ): Map { + val projectProxy = getProjectProxy(deploymentKey) + return Metrics.with({ Evaluation }, { e -> EvaluationFailure(e) }) { + projectProxy.evaluate(deploymentKey, user, flagKeys) + } + } + + suspend fun evaluateV1( + deploymentKey: String?, + user: Map?, flagKeys: Set? = null - ): Map { - val project = deploymentProxies[deploymentKey] ?: throw HttpErrorResponseException(404, "Deployment not found.") - val projectProxy = projectProxies[project] ?: throw HttpErrorResponseException(404, "Project not found.") - return projectProxy.evaluate(deploymentKey, user, flagKeys) + ): Map { + val projectProxy = getProjectProxy(deploymentKey) + return Metrics.with({ Evaluation }, { e -> EvaluationFailure(e) }) { + projectProxy.evaluateV1(deploymentKey, user, flagKeys) + } + } + + // Private + + private suspend fun getProjectProxy(deploymentKey: String?): ProjectProxy { + val cachedProject = mutex.withLock { + deploymentKeysToProject[deploymentKey] + } + if (cachedProject == null) { + log.debug("Unable to find project for deployment {}. Current mappings: {}", deploymentKey, deploymentKeysToProject.mapValues { it.value.id }) + throw HttpErrorResponseException(401, "Invalid deployment key.") + } + return projectProxies[cachedProject] ?: throw HttpErrorResponseException(404, "Project not found.") } } +// Serialized Proxy Calls + +suspend fun EvaluationProxy.getSerializedCohortDescription(deploymentKey: String?, cohortId: String?): String = + json.encodeToString(getCohortDescription(deploymentKey, cohortId)) + +suspend fun EvaluationProxy.getSerializedCohortMembers(deploymentKey: String?, cohortId: String?): String = + json.encodeToString(getCohortMembers(deploymentKey, cohortId)) + suspend fun EvaluationProxy.getSerializedFlagConfigs(deploymentKey: String?): String = - getFlagConfigs(deploymentKey).encodeToJsonString() + json.encodeToString(getFlagConfigs(deploymentKey)) suspend fun EvaluationProxy.getSerializedCohortMembershipsForUser(deploymentKey: String?, userId: String?): String = - getCohortMembershipsForUser(deploymentKey, userId).encodeToJsonString() + json.encodeToString(getCohortMembershipsForUser(deploymentKey, userId)) + +suspend fun EvaluationProxy.getSerializedCohortMembershipsForGroup(deploymentKey: String?, groupType: String?, groupName: String?): String = + json.encodeToString(getCohortMembershipsForGroup(deploymentKey, groupType, groupName)) suspend fun EvaluationProxy.serializedEvaluate( deploymentKey: String?, - user: SkylabUser?, + user: Map?, flagKeys: Set? = null -) = evaluate(deploymentKey, user, flagKeys).encodeToJsonString() +): String = json.encodeToString(evaluate(deploymentKey, user, flagKeys)) -private fun List.encodeToJsonString(): String = json.encodeToString(map { SerialFlagConfig(it) }) -private fun Set.encodeToJsonString(): String = json.encodeToString(this) -private fun Map.encodeToJsonString(): String = - json.encodeToString(mapValues { SerialVariant(it.value) }) +suspend fun EvaluationProxy.serializedEvaluateV1( + deploymentKey: String?, + user: Map?, + flagKeys: Set? = null +): String = json.encodeToString(evaluateV1(deploymentKey, user, flagKeys)) diff --git a/core/src/main/kotlin/Metrics.kt b/core/src/main/kotlin/Metrics.kt new file mode 100644 index 0000000..d187926 --- /dev/null +++ b/core/src/main/kotlin/Metrics.kt @@ -0,0 +1,48 @@ +package com.amplitude + +sealed class Metric +sealed class FailureMetric : Metric() + +interface MetricsHandler { + fun track(metric: Metric) +} + +data object Evaluation : Metric() +data class EvaluationFailure(val exception: Exception) : FailureMetric() +data object AssignmentEvent : Metric() +data object AssignmentEventFilter : Metric() +data object AssignmentEventSend : Metric() +data class AssignmentEventSendFailure(val exception: Exception) : FailureMetric() +data object DeploymentsFetch : Metric() +data class DeploymentsFetchFailure(val exception: Exception) : FailureMetric() +data object FlagsFetch : Metric() +data class FlagsFetchFailure(val exception: Exception) : FailureMetric() +data object CohortDescriptionFetch : Metric() +data class CohortDescriptionFetchFailure(val exception: Exception) : FailureMetric() +data object CohortDownload : Metric() +data class CohortDownloadFailure(val exception: Exception) : FailureMetric() +data object RedisCommand : Metric() +data class RedisCommandFailure(val exception: Exception) : FailureMetric() + +internal object Metrics : MetricsHandler { + + internal var handler: MetricsHandler? = null + + override fun track(metric: Metric) { + handler?.track(metric) + } + + internal suspend fun with( + metric: (() -> Metric)?, + failure: ((e: Exception) -> FailureMetric)?, + block: suspend () -> R + ): R { + try { + metric?.invoke() + return block.invoke() + } catch (e: Exception) { + failure?.invoke(e) + throw e + } + } +} diff --git a/core/src/main/kotlin/assignment/Assignment.kt b/core/src/main/kotlin/assignment/Assignment.kt index e8b43d6..8b96d3b 100644 --- a/core/src/main/kotlin/assignment/Assignment.kt +++ b/core/src/main/kotlin/assignment/Assignment.kt @@ -1,21 +1,23 @@ package com.amplitude.assignment -import com.amplitude.experiment.evaluation.FlagResult -import com.amplitude.experiment.evaluation.SkylabUser +import com.amplitude.experiment.evaluation.EvaluationContext +import com.amplitude.experiment.evaluation.EvaluationVariant +import com.amplitude.util.deviceId +import com.amplitude.util.userId -const val DAY_MILLIS: Long = 24 * 60 * 60 * 1000 +internal const val DAY_MILLIS: Long = 24 * 60 * 60 * 1000 -data class Assignment( - val user: SkylabUser, - val results: Map, +internal data class Assignment( + val context: EvaluationContext, + val results: Map, val timestamp: Long = System.currentTimeMillis() ) -fun Assignment.canonicalize(): String { - val sb = StringBuilder().append(this.user.userId?.trim(), " ", this.user.deviceId?.trim(), " ") +internal fun Assignment.canonicalize(): String { + val sb = StringBuilder().append(this.context.userId()?.trim(), " ", this.context.deviceId()?.trim(), " ") for (key in this.results.keys.sorted()) { - val value = this.results[key] - sb.append(key.trim(), " ", value?.variant?.key?.trim(), " ") + val variant = this.results[key] + sb.append(key.trim(), " ", variant?.key?.trim(), " ") } return sb.toString() } diff --git a/core/src/main/kotlin/assignment/AssignmentFilter.kt b/core/src/main/kotlin/assignment/AssignmentFilter.kt index 03f9706..9090434 100644 --- a/core/src/main/kotlin/assignment/AssignmentFilter.kt +++ b/core/src/main/kotlin/assignment/AssignmentFilter.kt @@ -2,11 +2,11 @@ package com.amplitude.assignment import com.amplitude.util.Cache -interface AssignmentFilter { +internal interface AssignmentFilter { suspend fun shouldTrack(assignment: Assignment): Boolean } -class InMemoryAssignmentFilter(size: Int) : AssignmentFilter { +internal class InMemoryAssignmentFilter(size: Int) : AssignmentFilter { // Cache of canonical assignment to the last sent timestamp. private val cache = Cache(size, DAY_MILLIS) diff --git a/core/src/main/kotlin/assignment/AssignmentTracker.kt b/core/src/main/kotlin/assignment/AssignmentTracker.kt index 8d2d164..18a8cee 100644 --- a/core/src/main/kotlin/assignment/AssignmentTracker.kt +++ b/core/src/main/kotlin/assignment/AssignmentTracker.kt @@ -2,19 +2,39 @@ package com.amplitude.assignment import com.amplitude.Amplitude import com.amplitude.AssignmentConfiguration +import com.amplitude.AssignmentEvent +import com.amplitude.AssignmentEventFilter +import com.amplitude.AssignmentEventSend +import com.amplitude.AssignmentEventSendFailure import com.amplitude.Event -import com.amplitude.experiment.evaluation.FLAG_TYPE_MUTUAL_EXCLUSION_GROUP +import com.amplitude.Metrics +import com.amplitude.util.deviceId +import com.amplitude.util.groups +import com.amplitude.util.logger +import com.amplitude.util.userId import org.json.JSONObject -interface AssignmentTracker { +private object FlagType { + const val RELEASE = "release" + const val EXPERIMENT = "experiment" + const val MUTUAL_EXCLUSION_GROUP = "mutual-exclusion-group" + const val HOLDOUT_GROUP = "holdout-group" + const val RELEASE_GROUP = "release-group" +} + +internal interface AssignmentTracker { suspend fun track(assignment: Assignment) } -class AmplitudeAssignmentTracker( +internal class AmplitudeAssignmentTracker( private val amplitude: Amplitude, private val assignmentFilter: AssignmentFilter ) : AssignmentTracker { + companion object { + val log by logger() + } + constructor( apiKey: String, config: AssignmentConfiguration @@ -29,8 +49,17 @@ class AmplitudeAssignmentTracker( ) override suspend fun track(assignment: Assignment) { - if (assignmentFilter.shouldTrack(assignment)) { - amplitude.logEvent(assignment.toAmplitudeEvent()) + try { + Metrics.track(AssignmentEvent) + if (assignmentFilter.shouldTrack(assignment)) { + Metrics.with({ AssignmentEventSend }, { e -> AssignmentEventSendFailure(e) }) { + amplitude.logEvent(assignment.toAmplitudeEvent()) + } + } else { + Metrics.track(AssignmentEventFilter) + } + } catch (e: Exception) { + log.error("Failed to track assignment event", e) } } } @@ -38,31 +67,40 @@ class AmplitudeAssignmentTracker( internal fun Assignment.toAmplitudeEvent(): Event { val event = Event( "[Experiment] Assignment", - this.user.userId, - this.user.deviceId + this.context.userId(), + this.context.deviceId() ) + val groups = this.context.groups() + if (!groups.isNullOrEmpty()) { + event.groups = JSONObject(groups) + } event.eventProperties = JSONObject().apply { - for ((flagKey, result) in this@toAmplitudeEvent.results) { - put("$flagKey.variant", result.variant.key) - put("$flagKey.details", result.description) + for ((flagKey, variant) in this@toAmplitudeEvent.results) { + val version = variant.metadata?.get("version") + val segmentName = variant.metadata?.get("segmentName") + val details = "v$version rule:$segmentName" + put("$flagKey.variant", variant.key) + put("$flagKey.details", details) } } event.userProperties = JSONObject().apply { val set = JSONObject() val unset = JSONObject() - for ((flagKey, result) in this@toAmplitudeEvent.results) { - if (result.type == FLAG_TYPE_MUTUAL_EXCLUSION_GROUP) { + for ((flagKey, variant) in this@toAmplitudeEvent.results) { + val flagType = variant.metadata?.get("flagType") as? String + val default = variant.metadata?.get("default") as? Boolean ?: false + if (flagType == FlagType.MUTUAL_EXCLUSION_GROUP) { // Dont set user properties for mutual exclusion groups. continue - } else if (result.isDefaultVariant) { + } else if (default) { unset.put("[Experiment] $flagKey", "-") } else { - set.put("[Experiment] $flagKey", result.variant.key) + set.put("[Experiment] $flagKey", variant.key) } } put("\$set", set) put("\$unset", unset) } - event.insertId = "${this.user.userId} ${this.user.deviceId} ${this.canonicalize().hashCode()} ${this.timestamp / DAY_MILLIS}" + event.insertId = "${this.context.userId()} ${this.context.deviceId()} ${this.canonicalize().hashCode()} ${this.timestamp / DAY_MILLIS}" return event } diff --git a/core/src/main/kotlin/cohort/CohortApi.kt b/core/src/main/kotlin/cohort/CohortApi.kt index 54ac19f..841f9eb 100644 --- a/core/src/main/kotlin/cohort/CohortApi.kt +++ b/core/src/main/kotlin/cohort/CohortApi.kt @@ -1,6 +1,6 @@ package com.amplitude.cohort -import com.amplitude.util.HttpErrorResponseException +import com.amplitude.util.HttpErrorException import com.amplitude.util.get import com.amplitude.util.json import com.amplitude.util.logger @@ -17,46 +17,35 @@ import io.ktor.utils.io.jvm.javaio.toInputStream import kotlinx.coroutines.delay import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable -import kotlinx.serialization.decodeFromString import org.apache.commons.csv.CSVFormat import org.apache.commons.csv.CSVParser +import java.lang.IllegalArgumentException import java.util.Base64 @Serializable -private data class SerialCohortDescription( - @SerialName("lastComputed") val lastComputed: Long, - @SerialName("published") val published: Boolean, - @SerialName("archived") val archived: Boolean, - @SerialName("appId") val appId: Int, - @SerialName("lastMod") val lastMod: Long, - @SerialName("type") val type: String, - @SerialName("id") val id: String, - @SerialName("size") val size: Int -) - -@Serializable -private data class SerialSingleCohortDescription( +private data class SerialCohortInfoResponse( @SerialName("cohort_id") val cohortId: String, @SerialName("app_id") val appId: Int = 0, @SerialName("org_id") val orgId: Int = 0, @SerialName("name") val name: String? = null, @SerialName("size") val size: Int = Int.MAX_VALUE, @SerialName("description") val description: String? = null, - @SerialName("last_computed") val lastComputed: Long = 0 + @SerialName("last_computed") val lastComputed: Long = 0, + @SerialName("group_type") val groupType: String = USER_GROUP_TYPE ) @Serializable -data class GetCohortAsyncResponse( +private data class GetCohortAsyncResponse( @SerialName("cohort_id") val cohortId: String, @SerialName("request_id") val requestId: String ) -interface CohortApi { - suspend fun getCohortDescriptions(cohortIds: Set): List +internal interface CohortApi { + suspend fun getCohortDescription(cohortId: String): CohortDescription suspend fun getCohortMembers(cohortDescription: CohortDescription): Set } -class CohortApiV5( +internal class CohortApiV5( private val serverUrl: String, apiKey: String, secretKey: String @@ -69,57 +58,76 @@ class CohortApiV5( private val basicAuth = Base64.getEncoder().encodeToString("$apiKey:$secretKey".toByteArray(Charsets.UTF_8)) private val client = HttpClient(OkHttp) { install(HttpTimeout) { - socketTimeoutMillis = 360000 + socketTimeoutMillis = 30000 } } - override suspend fun getCohortDescriptions(cohortIds: Set): List { - log.debug("getCohortDescriptions: start") - return cohortIds.map { cohortId -> - val response = retry(onFailure = { e -> log.info("Get cohort descriptions failed: $e") }) { - client.get(serverUrl, "/api/3/cohorts/info/$cohortId") { - headers { set("Authorization", "Basic $basicAuth") } - } + override suspend fun getCohortDescription(cohortId: String): CohortDescription { + val response = retry(onFailure = { e -> log.info("Get cohort descriptions failed: $e") }) { + client.get(serverUrl, "/api/3/cohorts/info/$cohortId") { + headers { set("Authorization", "Basic $basicAuth") } } - val serialDescription = json.decodeFromString(response.body()) - CohortDescription( - id = serialDescription.cohortId, - lastComputed = serialDescription.lastComputed, - size = serialDescription.size - ) - }.toList().also { log.debug("getCohortDescriptions: end - result=$it") } + } + val serialDescription = json.decodeFromString(response.body()) + return CohortDescription( + id = serialDescription.cohortId, + lastComputed = serialDescription.lastComputed, + size = serialDescription.size, + groupType = serialDescription.groupType + ) } override suspend fun getCohortMembers(cohortDescription: CohortDescription): Set { - log.debug("getCohortMembers: start - cohortDescription=$cohortDescription") + log.debug("getCohortMembers: start - cohortDescription={}", cohortDescription) // Initiate async cohort download - val initialResponse = client.get(serverUrl, "/api/5/cohorts/request/${cohortDescription.id}") { - headers { set("Authorization", "Basic $basicAuth") } - parameter("lastComputed", cohortDescription.lastComputed) + val initialResponse = retry(onFailure = { e -> log.error("Cohort download request failed: $e") }) { + client.get(serverUrl, "/api/5/cohorts/request/${cohortDescription.id}") { + headers { set("Authorization", "Basic $basicAuth") } + parameter("lastComputed", cohortDescription.lastComputed) + } } val getCohortResponse = json.decodeFromString(initialResponse.body()) - log.debug("getCohortMembers: cohortId=${cohortDescription.id}, requestId=${getCohortResponse.requestId}") + log.debug("getCohortMembers: poll for status - cohortId=${cohortDescription.id}, requestId=${getCohortResponse.requestId}") // Poll until the cohort is ready for download while (true) { - val statusResponse = + val statusResponse = retry(onFailure = { e -> log.error("Cohort request status failed: $e") }) { client.get(serverUrl, "/api/5/cohorts/request-status/${getCohortResponse.requestId}") { headers { set("Authorization", "Basic $basicAuth") } } - log.debug("getCohortMembers: cohortId=${cohortDescription.id}, status=${statusResponse.status}") + } + log.trace("getCohortMembers: cohortId={}, status={}", cohortDescription.id, statusResponse.status) if (statusResponse.status == HttpStatusCode.OK) { break } else if (statusResponse.status != HttpStatusCode.Accepted) { - throw HttpErrorResponseException(statusResponse.status) + throw HttpErrorException(statusResponse.status, statusResponse) } - delay(1000) + delay(5000) } // Download the cohort - val downloadResponse = + log.debug("getCohortMembers: download cohort - cohortId=${cohortDescription.id}, requestId=${getCohortResponse.requestId}") + val downloadResponse = retry(onFailure = { e -> log.error("Cohort file download failed: $e") }) { client.get(serverUrl, "/api/5/cohorts/request/${getCohortResponse.requestId}/file") { headers { set("Authorization", "Basic $basicAuth") } } + } + // Parse the csv response val csv = CSVParser.parse(downloadResponse.bodyAsChannel().toInputStream(), Charsets.UTF_8, csvFormat) - return csv.map { it.get("user_id") }.filterNot { it.isNullOrEmpty() }.toSet() - .also { log.debug("getCohortMembers: end - cohortId=${cohortDescription.id}, resultSize=${it.size}") } + return if (cohortDescription.groupType == USER_GROUP_TYPE) { + csv.map { it.get("user_id") }.filterNot { it.isNullOrEmpty() }.toSet() + } else { + csv.map { + try { + // CSV returned from API has all strings prefixed with a tab character + it.get("\tgroup_value") + } catch (e: IllegalArgumentException) { + it.get("group_value") + } + }.filterNot { + it.isNullOrEmpty() + }.map { + // CSV returned from API has all strings prefixed with a tab character + it.removePrefix("\t") + }.toSet() + }.also { log.debug("getCohortMembers: end - resultSize=${it.size}") } } } diff --git a/core/src/main/kotlin/cohort/CohortDescription.kt b/core/src/main/kotlin/cohort/CohortDescription.kt index df27abf..357ca26 100644 --- a/core/src/main/kotlin/cohort/CohortDescription.kt +++ b/core/src/main/kotlin/cohort/CohortDescription.kt @@ -2,9 +2,12 @@ package com.amplitude.cohort import kotlinx.serialization.Serializable +const val USER_GROUP_TYPE = "User" + @Serializable data class CohortDescription( val id: String, val lastComputed: Long, - val size: Int + val size: Int, + val groupType: String = USER_GROUP_TYPE ) diff --git a/core/src/main/kotlin/cohort/CohortLoader.kt b/core/src/main/kotlin/cohort/CohortLoader.kt index 66753bb..da5eecc 100644 --- a/core/src/main/kotlin/cohort/CohortLoader.kt +++ b/core/src/main/kotlin/cohort/CohortLoader.kt @@ -1,13 +1,19 @@ package com.amplitude.cohort +import com.amplitude.CohortDescriptionFetch +import com.amplitude.CohortDescriptionFetchFailure +import com.amplitude.CohortDownload +import com.amplitude.CohortDownloadFailure +import com.amplitude.Metrics import com.amplitude.util.logger import kotlinx.coroutines.Job import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.joinAll import kotlinx.coroutines.launch import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock -class CohortLoader( +internal class CohortLoader( @Volatile var maxCohortSize: Int, private val cohortApi: CohortApi, private val cohortStorage: CohortStorage @@ -19,38 +25,40 @@ class CohortLoader( private val jobsMutex = Mutex() private val jobs = mutableMapOf() - suspend fun loadCohorts(cohortIds: Set, state: Set = cohortIds) = coroutineScope { - log.debug("loadCohorts: start - cohortIds=$cohortIds") - - // Get cohort descriptions from storage and network. - val networkCohortDescriptions = cohortApi.getCohortDescriptions(state) - - // Filter cohorts received from network. Removes cohorts which are: - // 1. Not requested for management by this function. - // 2. Larger than the max size. - // 3. Are equal to what has been downloaded already. - val cohorts = networkCohortDescriptions.filter { networkCohortDescription -> - val storageDescription = cohortStorage.getCohortDescription(networkCohortDescription.id) - cohortIds.contains(networkCohortDescription.id) && - networkCohortDescription.size <= maxCohortSize && - networkCohortDescription.lastComputed > (storageDescription?.lastComputed ?: -1) + suspend fun loadCohorts(cohortIds: Set) = coroutineScope { + val jobs = mutableListOf() + for (cohortId in cohortIds) { + jobs += launch { loadCohort(cohortId) } } - log.debug("loadCohorts: filtered network descriptions - $cohorts") + jobs.joinAll() + } - // Download and store each cohort if a download job has not already been started. - for (cohort in cohorts) { - val job = jobsMutex.withLock { - jobs.getOrPut(cohort.id) { + private suspend fun loadCohort(cohortId: String) = coroutineScope { + log.trace("loadCohort: start - cohortId={}", cohortId) + val networkCohort = Metrics.with( + { CohortDescriptionFetch }, + { e -> CohortDescriptionFetchFailure(e) } + ) { + cohortApi.getCohortDescription(cohortId) + } + val storageCohort = cohortStorage.getCohortDescription(cohortId) + val shouldDownloadCohort = networkCohort.size <= maxCohortSize && + networkCohort.lastComputed > (storageCohort?.lastComputed ?: -1) + if (shouldDownloadCohort) { + jobsMutex.withLock { + jobs.getOrPut(cohortId) { launch { - log.info("Downloading cohort. $cohort") - val cohortMembers = cohortApi.getCohortMembers(cohort) - cohortStorage.putCohort(cohort, cohortMembers) - jobsMutex.withLock { jobs.remove(cohort.id) } + log.info("Downloading cohort. $networkCohort") + val cohortMembers = Metrics.with({ CohortDownload }, { e -> CohortDownloadFailure(e) }) { + cohortApi.getCohortMembers(networkCohort) + } + cohortStorage.putCohort(networkCohort, cohortMembers) + jobsMutex.withLock { jobs.remove(cohortId) } + log.info("Cohort download complete. $networkCohort") } } - } - job.join() + }.join() } - log.debug("loadCohorts: end - cohortIds=$cohortIds") + log.trace("loadCohort: end - cohortId={}", cohortId) } } diff --git a/core/src/main/kotlin/cohort/CohortStorage.kt b/core/src/main/kotlin/cohort/CohortStorage.kt index 7037e66..0a9d8be 100644 --- a/core/src/main/kotlin/cohort/CohortStorage.kt +++ b/core/src/main/kotlin/cohort/CohortStorage.kt @@ -6,30 +6,39 @@ import com.amplitude.util.RedisKey import com.amplitude.util.json import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock -import kotlinx.serialization.decodeFromString import kotlinx.serialization.encodeToString import kotlin.time.Duration -interface CohortStorage { +internal interface CohortStorage { suspend fun getCohortDescription(cohortId: String): CohortDescription? suspend fun getCohortDescriptions(): Map + suspend fun getCohortMembers(cohortDescription: CohortDescription): Set? suspend fun getCohortMembershipsForUser(userId: String, cohortIds: Set? = null): Set + suspend fun getCohortMembershipsForGroup( + groupType: String, + groupName: String, + cohortIds: Set? = null + ): Set suspend fun putCohort(description: CohortDescription, members: Set) suspend fun removeCohort(cohortDescription: CohortDescription) } -fun getCohortStorage(projectId: String, redisConfiguration: RedisConfiguration?, ttl: Duration): CohortStorage { +internal fun getCohortStorage(projectId: String, redisConfiguration: RedisConfiguration?, ttl: Duration): CohortStorage { val uri = redisConfiguration?.uri - val readOnlyUri = redisConfiguration?.readOnlyUri ?: uri - val prefix = redisConfiguration?.prefix - return if (uri == null || readOnlyUri == null || prefix == null) { + return if (uri == null) { InMemoryCohortStorage() } else { - RedisCohortStorage(uri, readOnlyUri, prefix, projectId, ttl) + val redis = RedisConnection(uri) + val readOnlyRedis = if (redisConfiguration.readOnlyUri != null) { + RedisConnection(redisConfiguration.readOnlyUri) + } else { + redis + } + RedisCohortStorage(projectId, ttl, redisConfiguration.prefix, redis, readOnlyRedis) } } -class InMemoryCohortStorage : CohortStorage { +internal class InMemoryCohortStorage : CohortStorage { private class Cohort( val description: CohortDescription, @@ -47,6 +56,10 @@ class InMemoryCohortStorage : CohortStorage { return lock.withLock { cohorts.mapValues { it.value.description } } } + override suspend fun getCohortMembers(cohortDescription: CohortDescription): Set? { + return lock.withLock { cohorts[cohortDescription.id]?.members } + } + override suspend fun putCohort(description: CohortDescription, members: Set) { return lock.withLock { cohorts[description.id] = Cohort(description, members) } } @@ -65,35 +78,82 @@ class InMemoryCohortStorage : CohortStorage { }.toSet() } } + + override suspend fun getCohortMembershipsForGroup( + groupType: String, + groupName: String, + cohortIds: Set? + ): Set { + return lock.withLock { + (cohortIds ?: cohorts.keys).mapNotNull { id -> + val cohort = cohorts[id] + if (cohort?.description?.groupType != groupType) { + null + } else { + when (cohort.members.contains(groupName)) { + true -> id + else -> null + } + } + }.toSet() + } + } } -class RedisCohortStorage( - uri: String, - readOnlyUri: String, - prefix: String, +internal class RedisCohortStorage( private val projectId: String, - private val ttl: Duration + private val ttl: Duration, + private val prefix: String, + private val redis: RedisConnection, + private val readOnlyRedis: RedisConnection ) : CohortStorage { - private val redis = RedisConnection(uri, prefix) - private val readOnlyRedis = RedisConnection(readOnlyUri, prefix) - override suspend fun getCohortDescription(cohortId: String): CohortDescription? { - val jsonEncodedDescription = redis.hget(RedisKey.CohortDescriptions(projectId), cohortId) ?: return null + val jsonEncodedDescription = redis.hget(RedisKey.CohortDescriptions(prefix, projectId), cohortId) ?: return null return json.decodeFromString(jsonEncodedDescription) } override suspend fun getCohortDescriptions(): Map { - val jsonEncodedDescriptions = redis.hgetall(RedisKey.CohortDescriptions(projectId)) + val jsonEncodedDescriptions = redis.hgetall(RedisKey.CohortDescriptions(prefix, projectId)) return jsonEncodedDescriptions?.mapValues { json.decodeFromString(it.value) } ?: mapOf() } + override suspend fun getCohortMembers(cohortDescription: CohortDescription): Set? { + return redis.smembers(RedisKey.CohortMembers(prefix, projectId, cohortDescription)) + } + override suspend fun getCohortMembershipsForUser(userId: String, cohortIds: Set?): Set { val descriptions = getCohortDescriptions() val memberships = mutableSetOf() for (description in descriptions.values) { + if (cohortIds != null && !cohortIds.contains(description.id)) { + continue + } + // High volume, use read connection + val isMember = readOnlyRedis.sismember(RedisKey.CohortMembers(prefix, projectId, description), userId) + if (isMember) { + memberships += description.id + } + } + return memberships + } + + override suspend fun getCohortMembershipsForGroup( + groupType: String, + groupName: String, + cohortIds: Set? + ): Set { + val descriptions = getCohortDescriptions() + val memberships = mutableSetOf() + for (description in descriptions.values) { + if (cohortIds != null && !cohortIds.contains(description.id)) { + continue + } + if (description.groupType != groupType) { + continue + } // High volume, use read connection - val isMember = readOnlyRedis.sismember(RedisKey.CohortMembers(projectId, description), userId) + val isMember = readOnlyRedis.sismember(RedisKey.CohortMembers(prefix, projectId, description), groupName) if (isMember) { memberships += description.id } @@ -105,16 +165,16 @@ class RedisCohortStorage( val jsonEncodedDescription = json.encodeToString(description) val existingDescription = getCohortDescription(description.id) if ((existingDescription?.lastComputed ?: 0L) < description.lastComputed) { - redis.sadd(RedisKey.CohortMembers(projectId, description), members) - redis.hset(RedisKey.CohortDescriptions(projectId), mapOf(description.id to jsonEncodedDescription)) + redis.sadd(RedisKey.CohortMembers(prefix, projectId, description), members) + redis.hset(RedisKey.CohortDescriptions(prefix, projectId), mapOf(description.id to jsonEncodedDescription)) if (existingDescription != null) { - redis.expire(RedisKey.CohortMembers(projectId, existingDescription), ttl) + redis.expire(RedisKey.CohortMembers(prefix, projectId, existingDescription), ttl) } } } override suspend fun removeCohort(cohortDescription: CohortDescription) { - redis.hdel(RedisKey.CohortDescriptions(projectId), cohortDescription.id) - redis.del(RedisKey.CohortMembers(projectId, cohortDescription)) + redis.hdel(RedisKey.CohortDescriptions(prefix, projectId), cohortDescription.id) + redis.del(RedisKey.CohortMembers(prefix, projectId, cohortDescription)) } } diff --git a/core/src/main/kotlin/deployment/Deployment.kt b/core/src/main/kotlin/deployment/Deployment.kt new file mode 100644 index 0000000..e65d7e2 --- /dev/null +++ b/core/src/main/kotlin/deployment/Deployment.kt @@ -0,0 +1,11 @@ +package com.amplitude.deployment + +import kotlinx.serialization.Serializable + +@Serializable +internal data class Deployment( + val id: String, + val projectId: String, + val label: String, + val key: String +) diff --git a/core/src/main/kotlin/deployment/DeploymentApi.kt b/core/src/main/kotlin/deployment/DeploymentApi.kt index 7762f3f..a0b3ab8 100644 --- a/core/src/main/kotlin/deployment/DeploymentApi.kt +++ b/core/src/main/kotlin/deployment/DeploymentApi.kt @@ -1,8 +1,7 @@ package com.amplitude.deployment import com.amplitude.VERSION -import com.amplitude.experiment.evaluation.FlagConfig -import com.amplitude.experiment.evaluation.serialization.SerialFlagConfig +import com.amplitude.experiment.evaluation.EvaluationFlag import com.amplitude.util.get import com.amplitude.util.json import com.amplitude.util.logger @@ -11,13 +10,13 @@ import io.ktor.client.HttpClient import io.ktor.client.call.body import io.ktor.client.engine.okhttp.OkHttp import io.ktor.client.request.headers -import kotlinx.serialization.decodeFromString +import io.ktor.client.request.parameter -interface DeploymentApi { - suspend fun getFlagConfigs(deploymentKey: String): List +internal interface DeploymentApi { + suspend fun getFlagConfigs(deploymentKey: String): List } -class DeploymentApiV1( +internal class DeploymentApiV1( private val serverUrl: String ) : DeploymentApi { @@ -27,19 +26,19 @@ class DeploymentApiV1( private val client = HttpClient(OkHttp) - override suspend fun getFlagConfigs(deploymentKey: String): List { - log.debug("getFlagConfigs: start - deploymentKey=$deploymentKey") - val response = retry(onFailure = { e -> log.info("Get flag configs failed: $e") }) { - client.get(serverUrl, "/sdk/v1/flags") { + override suspend fun getFlagConfigs(deploymentKey: String): List { + log.trace("getFlagConfigs: start - deploymentKey=$deploymentKey") + val response = retry(onFailure = { e -> log.error("Get flag configs failed: $e") }) { + client.get(serverUrl, "/sdk/v2/flags") { + parameter("v", "0") headers { set("Authorization", "Api-Key $deploymentKey") set("X-Amp-Exp-Library", "experiment-local-proxy/$VERSION") } } } - val body = json.decodeFromString>(response.body()) - return body.map { it.convert() }.also { - log.debug("getFlagConfigs: end - deploymentKey=$deploymentKey") + return json.decodeFromString>(response.body()).also { + log.trace("getFlagConfigs: end - deploymentKey=$deploymentKey") } } } diff --git a/core/src/main/kotlin/deployment/DeploymentLoader.kt b/core/src/main/kotlin/deployment/DeploymentLoader.kt index 3994d55..df4348b 100644 --- a/core/src/main/kotlin/deployment/DeploymentLoader.kt +++ b/core/src/main/kotlin/deployment/DeploymentLoader.kt @@ -1,7 +1,10 @@ package com.amplitude.deployment +import com.amplitude.FlagsFetch +import com.amplitude.FlagsFetchFailure +import com.amplitude.Metrics import com.amplitude.cohort.CohortLoader -import com.amplitude.util.getCohortIds +import com.amplitude.util.getAllCohortIds import com.amplitude.util.logger import kotlinx.coroutines.Job import kotlinx.coroutines.coroutineScope @@ -9,7 +12,7 @@ import kotlinx.coroutines.launch import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock -class DeploymentLoader( +internal class DeploymentLoader( private val deploymentApi: DeploymentApi, private val deploymentStorage: DeploymentStorage, private val cohortLoader: CohortLoader @@ -23,23 +26,38 @@ class DeploymentLoader( private val jobs = mutableMapOf() suspend fun loadDeployment(deploymentKey: String) = coroutineScope { - log.debug("loadDeployment: - deploymentKey=$deploymentKey") + log.trace("loadDeployment: - deploymentKey=$deploymentKey") jobsMutex.withLock { jobs.getOrPut(deploymentKey) { launch { - val networkFlagConfigs = deploymentApi.getFlagConfigs(deploymentKey) - val storageFlagConfigs = deploymentStorage.getFlagConfigs(deploymentKey) ?: listOf() - val networkCohortIds = networkFlagConfigs.getCohortIds() - val storageCohortIds = storageFlagConfigs.getCohortIds() - val addedCohortIds = networkCohortIds - storageCohortIds - if (addedCohortIds.isNotEmpty()) { - cohortLoader.loadCohorts(addedCohortIds, networkCohortIds) + val networkFlags = Metrics.with({ FlagsFetch }, { e -> FlagsFetchFailure(e) }) { + deploymentApi.getFlagConfigs(deploymentKey) } - deploymentStorage.putFlagConfigs(deploymentKey, networkFlagConfigs) - jobs.remove(deploymentKey) + // Remove flags that are no longer deployed. + val networkFlagKeys = networkFlags.map { it.key }.toSet() + val storageFlagKeys = deploymentStorage.getAllFlags(deploymentKey).map { it.key }.toSet() + for (flagToRemove in storageFlagKeys - networkFlagKeys) { + log.debug("Removing flag: $flagToRemove") + deploymentStorage.removeFlag(deploymentKey, flagToRemove) + } + // Load cohorts for each flag independently then put the + // flag into storage. + for (flag in networkFlags) { + val cohortIds = flag.getAllCohortIds() + if (cohortIds.isNotEmpty()) { + launch { + cohortLoader.loadCohorts(cohortIds) + deploymentStorage.putFlag(deploymentKey, flag) + } + } else { + deploymentStorage.putFlag(deploymentKey, flag) + } + } + // Remove the job + jobsMutex.withLock { jobs.remove(deploymentKey) } } } }.join() - log.debug("loadDeployment: end - deploymentKey=$deploymentKey") + log.trace("loadDeployment: end - deploymentKey=$deploymentKey") } } diff --git a/core/src/main/kotlin/deployment/DeploymentRunner.kt b/core/src/main/kotlin/deployment/DeploymentRunner.kt index 3f1a7f2..1a5d751 100644 --- a/core/src/main/kotlin/deployment/DeploymentRunner.kt +++ b/core/src/main/kotlin/deployment/DeploymentRunner.kt @@ -2,7 +2,7 @@ package com.amplitude.deployment import com.amplitude.Configuration import com.amplitude.cohort.CohortLoader -import com.amplitude.util.getCohortIds +import com.amplitude.util.getAllCohortIds import com.amplitude.util.logger import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.SupervisorJob @@ -10,7 +10,7 @@ import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.delay import kotlinx.coroutines.launch -class DeploymentRunner( +internal class DeploymentRunner( @Volatile var configuration: Configuration, private val deploymentKey: String, private val deploymentApi: DeploymentApi, @@ -27,7 +27,7 @@ class DeploymentRunner( private val deploymentLoader = DeploymentLoader(deploymentApi, deploymentStorage, cohortLoader) suspend fun start() { - log.debug("start: - deploymentKey=$deploymentKey") + log.trace("start: - deploymentKey=$deploymentKey") deploymentLoader.loadDeployment(deploymentKey) // Periodic flag config loader scope.launch { @@ -40,10 +40,8 @@ class DeploymentRunner( scope.launch { while (true) { delay(configuration.cohortSyncIntervalMillis) - val cohortIds = deploymentStorage.getFlagConfigs(deploymentKey)?.getCohortIds() - if (!cohortIds.isNullOrEmpty()) { - cohortLoader.loadCohorts(cohortIds) - } + val cohortIds = deploymentStorage.getAllFlags(deploymentKey).values.getAllCohortIds() + cohortLoader.loadCohorts(cohortIds) } } } diff --git a/core/src/main/kotlin/deployment/DeploymentStorage.kt b/core/src/main/kotlin/deployment/DeploymentStorage.kt index c786d05..6164fcb 100644 --- a/core/src/main/kotlin/deployment/DeploymentStorage.kt +++ b/core/src/main/kotlin/deployment/DeploymentStorage.kt @@ -1,149 +1,168 @@ package com.amplitude.deployment import com.amplitude.RedisConfiguration -import com.amplitude.experiment.evaluation.FlagConfig -import com.amplitude.experiment.evaluation.serialization.SerialFlagConfig +import com.amplitude.experiment.evaluation.EvaluationFlag +import com.amplitude.util.Redis import com.amplitude.util.RedisConnection import com.amplitude.util.RedisKey import com.amplitude.util.json -import kotlinx.coroutines.channels.BufferOverflow -import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock -import kotlinx.serialization.decodeFromString import kotlinx.serialization.encodeToString -interface DeploymentStorage { - val deployments: Flow> - suspend fun getDeployments(): Set - suspend fun putDeployment(deploymentKey: String) +internal interface DeploymentStorage { + suspend fun getDeployment(deploymentKey: String): Deployment? + suspend fun getDeployments(): Map + suspend fun putDeployment(deployment: Deployment) suspend fun removeDeployment(deploymentKey: String) - suspend fun getFlagConfigs(deploymentKey: String): List? - suspend fun putFlagConfigs(deploymentKey: String, flagConfigs: List) - suspend fun removeFlagConfigs(deploymentKey: String) + suspend fun getFlag(deploymentKey: String, flagKey: String): EvaluationFlag? + suspend fun getAllFlags(deploymentKey: String): Map + suspend fun putFlag(deploymentKey: String, flag: EvaluationFlag) + suspend fun putAllFlags(deploymentKey: String, flags: List) + suspend fun removeFlag(deploymentKey: String, flagKey: String) + suspend fun removeAllFlags(deploymentKey: String) } -fun getDeploymentStorage(projectId: String, redisConfiguration: RedisConfiguration?): DeploymentStorage { +internal fun getDeploymentStorage(projectId: String, redisConfiguration: RedisConfiguration?): DeploymentStorage { val uri = redisConfiguration?.uri - val readOnlyUri = redisConfiguration?.readOnlyUri ?: uri - val prefix = redisConfiguration?.prefix - return if (uri == null || readOnlyUri == null || prefix == null) { + return if (uri == null) { InMemoryDeploymentStorage() } else { - RedisDeploymentStorage(uri, readOnlyUri, prefix, projectId) + val redis = RedisConnection(uri) + val readOnlyRedis = if (redisConfiguration.readOnlyUri != null) { + RedisConnection(redisConfiguration.readOnlyUri) + } else { + redis + } + RedisDeploymentStorage(redisConfiguration.prefix, projectId, redis, readOnlyRedis) } } -class InMemoryDeploymentStorage : DeploymentStorage { +internal class InMemoryDeploymentStorage : DeploymentStorage { - override val deployments = MutableSharedFlow>( - extraBufferCapacity = 1, - onBufferOverflow = BufferOverflow.DROP_OLDEST - ) + private val mutex = Mutex() - private val lock = Mutex() - private val deploymentStorage = mutableMapOf?>() + private val deploymentStorage = mutableMapOf() + private val flagStorage = mutableMapOf>() + override suspend fun getDeployment(deploymentKey: String): Deployment? { + return mutex.withLock { + deploymentStorage[deploymentKey] + } + } - override suspend fun getDeployments(): Set { - return lock.withLock { - deploymentStorage.keys.toSet() + override suspend fun getDeployments(): Map { + return mutex.withLock { + deploymentStorage.toMap() } } - override suspend fun putDeployment(deploymentKey: String) { - return lock.withLock { - deploymentStorage[deploymentKey] = null - deployments.emit(deploymentStorage.keys) + override suspend fun putDeployment(deployment: Deployment) { + mutex.withLock { + deploymentStorage[deployment.key] = deployment } } override suspend fun removeDeployment(deploymentKey: String) { - return lock.withLock { + return mutex.withLock { deploymentStorage.remove(deploymentKey) - deployments.emit(deploymentStorage.keys) + flagStorage.remove(deploymentKey) } } - override suspend fun getFlagConfigs(deploymentKey: String): List? { - return lock.withLock { - deploymentStorage[deploymentKey] + override suspend fun getFlag(deploymentKey: String, flagKey: String): EvaluationFlag? { + return mutex.withLock { + flagStorage[deploymentKey]?.get(flagKey) } } - override suspend fun putFlagConfigs(deploymentKey: String, flagConfigs: List) { - lock.withLock { - deploymentStorage[deploymentKey] = flagConfigs + override suspend fun getAllFlags(deploymentKey: String): Map { + return mutex.withLock { + flagStorage[deploymentKey] ?: mapOf() } } - override suspend fun removeFlagConfigs(deploymentKey: String) { - lock.withLock { - deploymentStorage.remove(deploymentKey) + override suspend fun putFlag(deploymentKey: String, flag: EvaluationFlag) { + return mutex.withLock { + flagStorage.getOrPut(deploymentKey) { mutableMapOf() }[flag.key] = flag } } -} -class RedisDeploymentStorage( - uri: String, - readOnlyUri: String, - prefix: String, - private val projectId: String -) : DeploymentStorage { + override suspend fun putAllFlags(deploymentKey: String, flags: List) { + return mutex.withLock { + flagStorage.getOrPut(deploymentKey) { mutableMapOf() }.putAll(flags.associateBy { it.key }) + } + } - private val redis = RedisConnection(uri, prefix) - private val readOnlyRedis = RedisConnection(readOnlyUri, prefix) + override suspend fun removeFlag(deploymentKey: String, flagKey: String) { + return mutex.withLock { + flagStorage[deploymentKey]?.remove(flagKey) + } + } - // TODO Could be optimized w/ pub sub - override val deployments = MutableSharedFlow>( - extraBufferCapacity = 1, - onBufferOverflow = BufferOverflow.DROP_OLDEST - ) + override suspend fun removeAllFlags(deploymentKey: String) { + return mutex.withLock { + flagStorage.remove(deploymentKey) + } + } +} - private val mutex = Mutex() - private val flagConfigCache: MutableList = mutableListOf() +internal class RedisDeploymentStorage( + private val prefix: String, + private val projectId: String, + private val redis: Redis, + private val readOnlyRedis: Redis +) : DeploymentStorage { + override suspend fun getDeployment(deploymentKey: String): Deployment? { + val deploymentJson = redis.hget(RedisKey.Deployments(prefix, projectId), deploymentKey) ?: return null + return json.decodeFromString(deploymentJson) + } - override suspend fun getDeployments(): Set { - return redis.smembers(RedisKey.Deployments(projectId)) ?: emptySet() + override suspend fun getDeployments(): Map { + return redis.hgetall(RedisKey.Deployments(prefix, projectId)) + ?.mapValues { json.decodeFromString(it.value) } ?: mapOf() } - override suspend fun putDeployment(deploymentKey: String) { - redis.sadd(RedisKey.Deployments(projectId), setOf(deploymentKey)) - deployments.emit(getDeployments()) + override suspend fun putDeployment(deployment: Deployment) { + val deploymentJson = json.encodeToString(deployment) + redis.hset(RedisKey.Deployments(prefix, projectId), mapOf(deployment.key to deploymentJson)) } override suspend fun removeDeployment(deploymentKey: String) { - redis.srem(RedisKey.Deployments(projectId), deploymentKey) - deployments.emit(getDeployments()) + redis.hdel(RedisKey.Deployments(prefix, projectId), deploymentKey) + } + + override suspend fun getFlag(deploymentKey: String, flagKey: String): EvaluationFlag? { + val flagJson = redis.hget(RedisKey.FlagConfigs(prefix, projectId, deploymentKey), flagKey) ?: return null + return json.decodeFromString(flagJson) } - override suspend fun getFlagConfigs(deploymentKey: String): List? { + // TODO Add in memory caching w/ invalidation + override suspend fun getAllFlags(deploymentKey: String): Map { // High volume, use read only redis - val jsonEncodedFlags = readOnlyRedis.get(RedisKey.FlagConfigs(projectId, deploymentKey)) ?: return null - return json.decodeFromString>(jsonEncodedFlags).map { it.convert() } - } - - override suspend fun putFlagConfigs(deploymentKey: String, flagConfigs: List) { - // Optimization so repeat puts don't update the data to the same value in redis. - val changed = mutex.withLock { - if (flagConfigs != flagConfigCache) { - flagConfigCache.clear() - flagConfigCache += flagConfigs - true - } else { - false - } - } - if (changed) { - val jsonEncodedFlags = json.encodeToString(flagConfigs.map { SerialFlagConfig(it) }) - redis.set(RedisKey.FlagConfigs(projectId, deploymentKey), jsonEncodedFlags) + return readOnlyRedis.hgetall(RedisKey.FlagConfigs(prefix, projectId, deploymentKey)) + ?.mapValues { json.decodeFromString(it.value) } ?: mapOf() + } + + override suspend fun putFlag(deploymentKey: String, flag: EvaluationFlag) { + val flagJson = json.encodeToString(flag) + redis.hset(RedisKey.FlagConfigs(prefix, projectId, deploymentKey), mapOf(flag.key to flagJson)) + } + + override suspend fun putAllFlags(deploymentKey: String, flags: List) { + for (flag in flags) { + putFlag(deploymentKey, flag) } } - override suspend fun removeFlagConfigs(deploymentKey: String) { - redis.del(RedisKey.FlagConfigs(projectId, deploymentKey)) - mutex.withLock { - flagConfigCache.clear() + override suspend fun removeFlag(deploymentKey: String, flagKey: String) { + redis.hdel(RedisKey.FlagConfigs(prefix, projectId, deploymentKey), flagKey) + } + + override suspend fun removeAllFlags(deploymentKey: String) { + val redisKey = RedisKey.FlagConfigs(prefix, projectId, deploymentKey) + val flags = redis.hgetall(RedisKey.FlagConfigs(prefix, projectId, deploymentKey)) ?: return + for (key in flags.keys) { + redis.hdel(redisKey, key) } } } diff --git a/core/src/main/kotlin/project/Project.kt b/core/src/main/kotlin/project/Project.kt new file mode 100644 index 0000000..614cf50 --- /dev/null +++ b/core/src/main/kotlin/project/Project.kt @@ -0,0 +1,8 @@ +package com.amplitude.project + +internal data class Project( + val id: String, + val apiKey: String, + val secretKey: String, + val managementKey: String +) diff --git a/core/src/main/kotlin/project/ProjectApi.kt b/core/src/main/kotlin/project/ProjectApi.kt new file mode 100644 index 0000000..43f300e --- /dev/null +++ b/core/src/main/kotlin/project/ProjectApi.kt @@ -0,0 +1,71 @@ +package com.amplitude.project + +import com.amplitude.DeploymentsFetch +import com.amplitude.DeploymentsFetchFailure +import com.amplitude.Metrics +import com.amplitude.deployment.Deployment +import com.amplitude.util.get +import com.amplitude.util.json +import com.amplitude.util.logger +import com.amplitude.util.retry +import io.ktor.client.HttpClient +import io.ktor.client.call.body +import io.ktor.client.engine.okhttp.OkHttp +import io.ktor.client.plugins.HttpTimeout +import io.ktor.client.request.headers +import kotlinx.serialization.Serializable + +private const val MANAGEMENT_SERVER_URL = "https://experiment.amplitude.com" + +@Serializable +private data class DeploymentsResponse( + val deployments: List +) + +@Serializable +internal data class SerialDeployment( + val id: String, + val projectId: String, + val label: String, + val key: String, + val deleted: Boolean +) + +private fun SerialDeployment.toDeployment(): Deployment? { + if (deleted) return null + return Deployment(id, projectId, label, key) +} + +internal interface ProjectApi { + suspend fun getDeployments(): List +} + +internal class ProjectApiV1(private val managementKey: String) : ProjectApi { + + companion object { + val log by logger() + } + + private val client = HttpClient(OkHttp) { + install(HttpTimeout) { + socketTimeoutMillis = 30000 + } + } + + override suspend fun getDeployments(): List = + Metrics.with({ DeploymentsFetch }, { e -> DeploymentsFetchFailure(e) }) { + log.trace("getDeployments: start") + val response = retry(onFailure = { e -> log.error("Get deployments failed: $e") }) { + client.get(MANAGEMENT_SERVER_URL, "/api/1/deployments") { + headers { + set("Authorization", "Bearer $managementKey") + set("Accept", "application/json") + } + } + } + json.decodeFromString(response.body()) + .deployments + .mapNotNull { it.toDeployment() } + .also { log.trace("getDeployments: end") } + } +} diff --git a/core/src/main/kotlin/project/ProjectProxy.kt b/core/src/main/kotlin/project/ProjectProxy.kt index 663e4c4..fa047ce 100644 --- a/core/src/main/kotlin/project/ProjectProxy.kt +++ b/core/src/main/kotlin/project/ProjectProxy.kt @@ -2,28 +2,30 @@ package com.amplitude.project import com.amplitude.Configuration import com.amplitude.HttpErrorResponseException -import com.amplitude.Project -import com.amplitude.assignment.AmplitudeAssignmentTracker import com.amplitude.assignment.Assignment +import com.amplitude.assignment.AssignmentTracker import com.amplitude.cohort.CohortApiV5 -import com.amplitude.cohort.getCohortStorage +import com.amplitude.cohort.CohortDescription +import com.amplitude.cohort.CohortStorage +import com.amplitude.cohort.USER_GROUP_TYPE import com.amplitude.deployment.DeploymentApiV1 -import com.amplitude.deployment.getDeploymentStorage +import com.amplitude.deployment.DeploymentStorage import com.amplitude.experiment.evaluation.EvaluationEngineImpl -import com.amplitude.experiment.evaluation.FlagConfig -import com.amplitude.experiment.evaluation.FlagResult -import com.amplitude.experiment.evaluation.SkylabUser -import com.amplitude.experiment.evaluation.Variant -import com.amplitude.util.getCohortIds +import com.amplitude.experiment.evaluation.EvaluationFlag +import com.amplitude.experiment.evaluation.EvaluationVariant +import com.amplitude.experiment.evaluation.topologicalSort +import com.amplitude.util.getGroupedCohortIds import com.amplitude.util.logger +import com.amplitude.util.toEvaluationContext import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.launch -import kotlin.time.DurationUnit -import kotlin.time.toDuration -class ProjectProxy( +internal class ProjectProxy( private val project: Project, - configuration: Configuration = Configuration() + configuration: Configuration, + private val assignmentTracker: AssignmentTracker, + private val deploymentStorage: DeploymentStorage, + private val cohortStorage: CohortStorage ) { companion object { @@ -32,17 +34,12 @@ class ProjectProxy( private val engine = EvaluationEngineImpl() - private val assignmentTracker = AmplitudeAssignmentTracker(project.apiKey, configuration.assignment) + private val projectApi = ProjectApiV1(project.managementKey) private val deploymentApi = DeploymentApiV1(configuration.serverUrl) - private val deploymentStorage = getDeploymentStorage(project.id, configuration.redis) private val cohortApi = CohortApiV5(configuration.cohortServerUrl, project.apiKey, project.secretKey) - private val cohortStorage = getCohortStorage( - project.id, - configuration.redis, - configuration.cohortSyncIntervalMillis.toDuration(DurationUnit.MILLISECONDS) - ) private val projectRunner = ProjectRunner( configuration, + projectApi, deploymentApi, deploymentStorage, cohortApi, @@ -50,86 +47,140 @@ class ProjectProxy( ) suspend fun start() { - log.info("Starting project. projectId=${project.id} deploymentKeys=${project.deploymentKeys}") - // Add deployments to storage - for (deploymentKey in project.deploymentKeys) { - deploymentStorage.putDeployment(deploymentKey) - } - // Remove deployments which are no longer being managed - val storageDeploymentKeys = deploymentStorage.getDeployments() - for (storageDeploymentKey in storageDeploymentKeys - project.deploymentKeys) { - deploymentStorage.removeDeployment(storageDeploymentKey) - deploymentStorage.removeFlagConfigs(storageDeploymentKey) - } + log.info("Starting project. projectId=${project.id}") projectRunner.start() } suspend fun shutdown() { - log.info("Shutting down project. projectId=${project.id}") + log.info("Shutting down project. project.id=${project.id}") projectRunner.stop() } - suspend fun getFlagConfigs(deploymentKey: String?): List { - if (deploymentKey.isNullOrEmpty() || !deploymentKey.startsWith("server-")) { + suspend fun getFlagConfigs(deploymentKey: String?): List { + if (deploymentKey.isNullOrEmpty()) { throw HttpErrorResponseException(status = 401, message = "Invalid deployment.") } - return deploymentStorage.getFlagConfigs(deploymentKey) - ?: throw HttpErrorResponseException(status = 404, message = "Unknown deployment.") + return deploymentStorage.getAllFlags(deploymentKey).values.toList() + } + + suspend fun getCohortDescription(cohortId: String?): CohortDescription { + if (cohortId.isNullOrEmpty()) { + throw HttpErrorResponseException(status = 404, message = "Cohort not found.") + } + return cohortStorage.getCohortDescription(cohortId) + ?: throw HttpErrorResponseException(status = 404, message = "Cohort not found.") + } + + suspend fun getCohortMembers(cohortId: String?): Set { + if (cohortId.isNullOrEmpty()) { + throw HttpErrorResponseException(status = 404, message = "Cohort not found.") + } + val cohortDescription = cohortStorage.getCohortDescription(cohortId) + ?: throw HttpErrorResponseException(status = 404, message = "Cohort not found.") + return cohortStorage.getCohortMembers(cohortDescription) + ?: throw HttpErrorResponseException(status = 404, message = "Cohort not found.") } suspend fun getCohortMembershipsForUser(deploymentKey: String?, userId: String?): Set { - if (deploymentKey.isNullOrEmpty() || !deploymentKey.startsWith("server-")) { + if (deploymentKey.isNullOrEmpty()) { throw HttpErrorResponseException(status = 401, message = "Invalid deployment.") } if (userId.isNullOrEmpty()) { throw HttpErrorResponseException(status = 400, message = "Invalid user ID.") } - val cohortIds = deploymentStorage.getFlagConfigs(deploymentKey)?.getCohortIds() - ?: throw HttpErrorResponseException(status = 404, message = "Unknown deployment.") + val cohortIds = deploymentStorage.getAllFlags(deploymentKey).values.getGroupedCohortIds()[USER_GROUP_TYPE] + if (cohortIds.isNullOrEmpty()) { + return setOf() + } return cohortStorage.getCohortMembershipsForUser(userId, cohortIds) } + suspend fun getCohortMembershipsForGroup(deploymentKey: String?, groupType: String?, groupName: String?): Set { + if (deploymentKey.isNullOrEmpty()) { + throw HttpErrorResponseException(status = 401, message = "Invalid deployment.") + } + if (groupType.isNullOrEmpty()) { + throw HttpErrorResponseException(status = 400, message = "Invalid group type.") + } + if (groupName.isNullOrEmpty()) { + throw HttpErrorResponseException(status = 400, message = "Invalid group name.") + } + val cohortIds = deploymentStorage.getAllFlags(deploymentKey).values.getGroupedCohortIds()[groupType] + if (cohortIds.isNullOrEmpty()) { + return setOf() + } + return cohortStorage.getCohortMembershipsForGroup(groupType, groupName, cohortIds) + } + suspend fun evaluate( deploymentKey: String?, - user: SkylabUser?, + user: Map?, flagKeys: Set? = null - ): Map { - if (deploymentKey.isNullOrEmpty() || !deploymentKey.startsWith("server-")) { + ): Map { + if (deploymentKey.isNullOrEmpty()) { throw HttpErrorResponseException(status = 401, message = "Invalid deployment.") } - // Get flag configs for the deployment from storage. - val flagConfigs = deploymentStorage.getFlagConfigs(deploymentKey) - if (flagConfigs == null || flagConfigs.isEmpty()) { + // Get flag configs for the deployment from storage and topo sort. + val storageFlags = deploymentStorage.getAllFlags(deploymentKey) + if (storageFlags.isEmpty()) { + return mapOf() + } + val flags = topologicalSort(storageFlags, flagKeys ?: setOf()) + if (flags.isEmpty()) { return mapOf() } - // Enrich user with cohort IDs. - val enrichedUser = user?.userId?.let { userId -> - user.copy(cohortIds = cohortStorage.getCohortMembershipsForUser(userId)) + // Enrich user with cohort IDs and build the evaluation context + val userId = user?.get("user_id") as? String + val enrichedUser = user?.toMutableMap() ?: mutableMapOf() + if (userId != null) { + enrichedUser["cohort_ids"] = cohortStorage.getCohortMembershipsForUser(userId) + } + val groups = enrichedUser["groups"] as? Map<*, *> + if (!groups.isNullOrEmpty()) { + val groupCohortIds = mutableMapOf>>() + for (entry in groups.entries) { + val groupType = entry.key as? String + val groupName = (entry.value as? Collection<*>)?.firstOrNull() as? String + if (groupType != null && groupName != null) { + val cohortIds = cohortStorage.getCohortMembershipsForGroup(groupType, groupName) + if (groupCohortIds.isNotEmpty()) { + groupCohortIds.putIfAbsent(groupType, mutableMapOf(groupName to cohortIds)) + } + } + } + if (groupCohortIds.isNotEmpty()) { + enrichedUser["group_cohort_ids"] = groupCohortIds + } } + val evaluationContext = enrichedUser.toEvaluationContext() // Evaluate results - log.debug("evaluate - user=$enrichedUser") - val result = engine.evaluate(flagConfigs, enrichedUser) - if (enrichedUser != null) { + log.debug("evaluate - context={}", evaluationContext) + val result = engine.evaluate(evaluationContext, flags) + if (enrichedUser.isNotEmpty()) { coroutineScope { launch { - assignmentTracker.track(Assignment(enrichedUser, result)) + assignmentTracker.track(Assignment(evaluationContext, result)) } } } - return result.filterDeployedVariants(flagKeys) + return result } - /** - * Filter only non-default, deployed variants from the results that are included if flag keys (if not empty). - */ - private fun Map.filterDeployedVariants(flagKeys: Set?): Map { - return filter { entry -> - val isVariant = !entry.value.isDefaultVariant - val isIncluded = (flagKeys.isNullOrEmpty() || flagKeys.contains(entry.key)) - val isDeployed = entry.value.deployed - isVariant && isIncluded && isDeployed - }.mapValues { entry -> - entry.value.variant - }.toMap() + suspend fun evaluateV1( + deploymentKey: String?, + user: Map?, + flagKeys: Set? = null + ): Map { + return evaluate(deploymentKey, user, flagKeys).filter { entry -> + val default = entry.value.metadata?.get("default") as? Boolean ?: false + val deployed = entry.value.metadata?.get("deployed") as? Boolean ?: true + (!default && deployed) + } + } + + // Internal + + internal suspend fun getDeployments(): Set { + return deploymentStorage.getDeployments().keys } } diff --git a/core/src/main/kotlin/project/ProjectRunner.kt b/core/src/main/kotlin/project/ProjectRunner.kt index c4a50a2..b4aa0f0 100644 --- a/core/src/main/kotlin/project/ProjectRunner.kt +++ b/core/src/main/kotlin/project/ProjectRunner.kt @@ -7,8 +7,8 @@ import com.amplitude.cohort.CohortStorage import com.amplitude.deployment.DeploymentApi import com.amplitude.deployment.DeploymentRunner import com.amplitude.deployment.DeploymentStorage -import com.amplitude.experiment.evaluation.FlagConfig -import com.amplitude.util.getCohortIds +import com.amplitude.experiment.evaluation.EvaluationFlag +import com.amplitude.util.getAllCohortIds import com.amplitude.util.logger import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Job @@ -20,8 +20,9 @@ import kotlinx.coroutines.launch import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock -class ProjectRunner( +internal class ProjectRunner( private val configuration: Configuration, + private val projectApi: ProjectApi, private val deploymentApi: DeploymentApi, private val deploymentStorage: DeploymentStorage, cohortApi: CohortApi, @@ -40,18 +41,12 @@ class ProjectRunner( private val cohortLoader = CohortLoader(configuration.maxCohortSize, cohortApi, cohortStorage) suspend fun start() { - refresh(deploymentStorage.getDeployments()) - // Collect deployment updates from the storage - scope.launch { - deploymentStorage.deployments.collect { deployments -> - refresh(deployments) - } - } - // Periodic deployment refresher + refresh() + // Periodic deployment update and refresher scope.launch { while (true) { - delay(configuration.flagSyncIntervalMillis) - refresh(deploymentStorage.getDeployments()) + delay(configuration.deploymentSyncIntervalMillis) + refresh() } } } @@ -65,29 +60,47 @@ class ProjectRunner( supervisor.cancelAndJoin() } - private suspend fun refresh(deploymentKeys: Set) { - log.debug("refresh: start - deploymentKeys=$deploymentKeys") - lock.withLock { - val jobs = mutableListOf() - val runningDeployments = deploymentRunners.keys.toSet() - val addedDeployments = deploymentKeys - runningDeployments - val removedDeployments = runningDeployments - deploymentKeys - addedDeployments.forEach { deployment -> - jobs += scope.launch { addDeploymentInternal(deployment) } - } - removedDeployments.forEach { deployment -> - jobs += scope.launch { removeDeploymentInternal(deployment) } - } - jobs.joinAll() - // Keep cohorts which are targeted by all stored deployments. - removeUnusedCohorts(deploymentKeys) + private suspend fun refresh() = lock.withLock { + log.trace("refresh: start") + // Get deployments from API and update the storage. + val networkDeployments = projectApi.getDeployments().associateBy { it.key } + val storageDeployments = deploymentStorage.getDeployments() + // Determine added and removed deployments + val addedDeployments = networkDeployments - storageDeployments.keys + val removedDeployments = storageDeployments - networkDeployments.keys + val startingDeployments = networkDeployments - deploymentRunners.keys + val jobs = mutableListOf() + for ((_, addedDeployment) in addedDeployments) { + log.info("Adding deployment $addedDeployment") + deploymentStorage.putDeployment(addedDeployment) + } + for ((_, deployment) in startingDeployments) { + jobs += scope.launch { addDeploymentInternal(deployment.key) } + } + for ((_, removedDeployment) in removedDeployments) { + log.info("Removing deployment $removedDeployment") + deploymentStorage.removeAllFlags(removedDeployment.key) + deploymentStorage.removeDeployment(removedDeployment.key) + jobs += scope.launch { removeDeploymentInternal(removedDeployment.key) } } - log.debug("refresh: end - deploymentKeys=$deploymentKeys") + jobs.joinAll() + // Keep cohorts which are targeted by all stored deployments. + removeUnusedCohorts(networkDeployments.keys) + log.debug( + "Project refresh finished: addedDeployments={}, removedDeployments={}, startedDeployments={}", + addedDeployments.keys, + removedDeployments.keys, + startingDeployments.keys + ) + log.trace("refresh: end") } // Must be run within lock private suspend fun addDeploymentInternal(deploymentKey: String) { - log.info("Adding deployment $deploymentKey") + if (deploymentRunners.contains(deploymentKey)) { + return + } + log.debug("Adding and starting deployment runner for $deploymentKey") val deploymentRunner = DeploymentRunner( configuration, deploymentKey, @@ -101,18 +114,16 @@ class ProjectRunner( // Must be run within lock private suspend fun removeDeploymentInternal(deploymentKey: String) { - log.info("Removing deployment $deploymentKey") + log.debug("Removing and stopping deployment runner for $deploymentKey") deploymentRunners.remove(deploymentKey)?.stop() - deploymentStorage.removeFlagConfigs(deploymentKey) - deploymentStorage.removeDeployment(deploymentKey) } private suspend fun removeUnusedCohorts(deploymentKeys: Set) { - val allFlagConfigs = mutableListOf() + val allFlagConfigs = mutableListOf() for (deploymentKey in deploymentKeys) { - allFlagConfigs += deploymentStorage.getFlagConfigs(deploymentKey) ?: continue + allFlagConfigs += deploymentStorage.getAllFlags(deploymentKey).values } - val allTargetedCohortIds = allFlagConfigs.getCohortIds() + val allTargetedCohortIds = allFlagConfigs.getAllCohortIds() val allStoredCohortDescriptions = cohortStorage.getCohortDescriptions().values for (cohortDescription in allStoredCohortDescriptions) { if (!allTargetedCohortIds.contains(cohortDescription.id)) { diff --git a/core/src/main/kotlin/project/ProjectStorage.kt b/core/src/main/kotlin/project/ProjectStorage.kt index 6d7462d..6fe1b36 100644 --- a/core/src/main/kotlin/project/ProjectStorage.kt +++ b/core/src/main/kotlin/project/ProjectStorage.kt @@ -9,24 +9,23 @@ import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock -interface ProjectStorage { +internal interface ProjectStorage { val projects: Flow> suspend fun getProjects(): Set suspend fun putProject(projectId: String) suspend fun removeProject(projectId: String) } -fun getProjectStorage(redisConfiguration: RedisConfiguration?): ProjectStorage { +internal fun getProjectStorage(redisConfiguration: RedisConfiguration?): ProjectStorage { val uri = redisConfiguration?.uri - val prefix = redisConfiguration?.prefix - return if (uri == null || prefix == null) { + return if (uri == null) { InMemoryProjectStorage() } else { - RedisProjectStorage(uri, prefix) + RedisProjectStorage(redisConfiguration.prefix, RedisConnection(uri)) } } -class InMemoryProjectStorage : ProjectStorage { +internal class InMemoryProjectStorage : ProjectStorage { override val projects = MutableSharedFlow>( extraBufferCapacity = 1, @@ -51,29 +50,27 @@ class InMemoryProjectStorage : ProjectStorage { } } -class RedisProjectStorage( - uri: String, - prefix: String +internal class RedisProjectStorage( + private val prefix: String, + private val redis: RedisConnection ) : ProjectStorage { - private val redis = RedisConnection(uri, prefix) - override val projects = MutableSharedFlow>( extraBufferCapacity = 1, onBufferOverflow = BufferOverflow.DROP_OLDEST ) override suspend fun getProjects(): Set { - return redis.smembers(RedisKey.Projects) ?: emptySet() + return redis.smembers(RedisKey.Projects(prefix)) ?: emptySet() } override suspend fun putProject(projectId: String) { - redis.sadd(RedisKey.Projects, setOf(projectId)) + redis.sadd(RedisKey.Projects(prefix), setOf(projectId)) projects.emit(getProjects()) } override suspend fun removeProject(projectId: String) { - redis.srem(RedisKey.Projects, projectId) + redis.srem(RedisKey.Projects(prefix), projectId) projects.emit(getProjects()) } } diff --git a/core/src/main/kotlin/util/Cache.kt b/core/src/main/kotlin/util/Cache.kt index fb0f07e..ded1bca 100644 --- a/core/src/main/kotlin/util/Cache.kt +++ b/core/src/main/kotlin/util/Cache.kt @@ -7,7 +7,7 @@ import java.util.HashMap /** * Least recently used (LRU) cache with TTL for cache entries. */ -class Cache( +internal class Cache( private val capacity: Int, private val ttlMillis: Long = 0 ) { diff --git a/core/src/main/kotlin/util/EvaluationContext.kt b/core/src/main/kotlin/util/EvaluationContext.kt new file mode 100644 index 0000000..e13e88f --- /dev/null +++ b/core/src/main/kotlin/util/EvaluationContext.kt @@ -0,0 +1,62 @@ +package com.amplitude.util + +import com.amplitude.experiment.evaluation.EvaluationContext + +internal fun EvaluationContext.userId(): String? { + return (this["user"] as? Map<*, *>)?.get("user_id")?.toString() +} +internal fun EvaluationContext.deviceId(): String? { + return (this["user"] as? Map<*, *>)?.get("device_id")?.toString() +} + +internal fun EvaluationContext.groups(): Map<*, *>? { + return this["groups"] as? Map<*, *> +} + +internal fun MutableMap?.toEvaluationContext(): EvaluationContext { + val context = EvaluationContext() + if (this.isNullOrEmpty()) { + return context + } + val groups = mutableMapOf>() + val userGroups = this["groups"] as? Map<*, *> + if (!userGroups.isNullOrEmpty()) { + for (entry in userGroups) { + val groupType = entry.key as? String ?: continue + val groupNames = entry.value as? Collection<*> ?: continue + if (groupNames.isNotEmpty()) { + val groupName = groupNames.first() as? String ?: continue + val groupNameMap = mutableMapOf().apply { put("group_name", groupName) } + val groupProperties = this.select("group_properties", groupType, groupName) as? Map<*, *> + if (!groupProperties.isNullOrEmpty()) { + groupNameMap["group_properties"] = groupProperties + } + val groupCohortIds = this.select("group_cohort_ids", groupType, groupName) as? Map<*, *> + if (!groupCohortIds.isNullOrEmpty()) { + groupNameMap["cohort_ids"] = groupCohortIds + } + groups[groupType] = groupNameMap + } + } + context["groups"] = groups + } + remove("groups") + remove("group_properties") + context["user"] = this + return context +} + +private fun Map<*, *>.select(vararg selector: Any?): Any? { + var map: Map<*, *> = this + var result: Any? + for (i in 0 until selector.size - 1) { + val select = selector[i] + result = map[select] + if (result is Map<*, *>) { + map = result + } else { + return null + } + } + return map[selector.last()] +} diff --git a/core/src/main/kotlin/util/EvaluationFlag.kt b/core/src/main/kotlin/util/EvaluationFlag.kt new file mode 100644 index 0000000..1faf301 --- /dev/null +++ b/core/src/main/kotlin/util/EvaluationFlag.kt @@ -0,0 +1,67 @@ +package com.amplitude.util + +import com.amplitude.cohort.USER_GROUP_TYPE +import com.amplitude.experiment.evaluation.EvaluationCondition +import com.amplitude.experiment.evaluation.EvaluationFlag +import com.amplitude.experiment.evaluation.EvaluationOperator +import com.amplitude.experiment.evaluation.EvaluationSegment + +internal fun Collection.getAllCohortIds(): Set { + return getGroupedCohortIds().flatMap { it.value }.toSet() +} + +internal fun Collection.getGroupedCohortIds(): Map> { + val cohortIds = mutableMapOf>() + for (flag in this) { + for (entry in flag.getGroupedCohortIds()) { + cohortIds.getOrPut(entry.key) { mutableSetOf() } += entry.value + } + } + return cohortIds +} + +internal fun EvaluationFlag.getAllCohortIds(): Set { + return getGroupedCohortIds().flatMap { it.value }.toSet() +} + +internal fun EvaluationFlag.getGroupedCohortIds(): Map> { + val cohortIds = mutableMapOf>() + for (segment in this.segments) { + for (entry in segment.getGroupedCohortConditionIds()) { + cohortIds.getOrPut(entry.key) { mutableSetOf() } += entry.value + } + } + return cohortIds +} + +private fun EvaluationSegment.getGroupedCohortConditionIds(): Map> { + val cohortIds = mutableMapOf>() + if (conditions == null) { + return cohortIds + } + for (outer in conditions!!) { + for (condition in outer) { + if (condition.isCohortFilter()) { + // User cohort selector is [context, user, cohort_ids] + // Groups cohort selector is [context, groups, {group_type}, cohort_ids] + if (condition.selector.size > 2) { + val contextSubtype = condition.selector[1] + val groupType = if (contextSubtype == "user") { + USER_GROUP_TYPE + } else if (condition.selector.contains("groups")) { + condition.selector[2] + } else { + continue + } + cohortIds.getOrPut(groupType) { mutableSetOf() } += condition.values + } + } + } + } + return cohortIds +} + +// Only cohort filters use these operators. +private fun EvaluationCondition.isCohortFilter(): Boolean = + (this.op == EvaluationOperator.SET_CONTAINS_ANY || this.op == EvaluationOperator.SET_DOES_NOT_CONTAIN_ANY) && + this.selector.isNotEmpty() && this.selector.last() == "cohort_ids" diff --git a/core/src/main/kotlin/util/FlagConfig.kt b/core/src/main/kotlin/util/FlagConfig.kt deleted file mode 100644 index 81b99f8..0000000 --- a/core/src/main/kotlin/util/FlagConfig.kt +++ /dev/null @@ -1,34 +0,0 @@ -package com.amplitude.util - -import com.amplitude.experiment.evaluation.FlagConfig -import com.amplitude.experiment.evaluation.UserPropertyFilter - -private const val COHORT_PROP_KEY = "userdata_cohort" - -fun Collection.getCohortIds(): Set { - val cohortIds = mutableSetOf() - for (flag in this) { - cohortIds += flag.getCohortIds() - } - return cohortIds -} - -private fun FlagConfig.getCohortIds(): Set { - val cohortIds = mutableSetOf() - for (filter in this.allUsersTargetingConfig.conditions) { - if (filter.isCohortFilter()) { - cohortIds += filter.values - } - } - val customSegments = this.customSegmentTargetingConfigs ?: listOf() - for (segment in customSegments) { - for (filter in segment.conditions) { - if (filter.isCohortFilter()) { - cohortIds += filter.values - } - } - } - return cohortIds -} - -private fun UserPropertyFilter.isCohortFilter(): Boolean = this.prop == COHORT_PROP_KEY diff --git a/core/src/main/kotlin/util/Http.kt b/core/src/main/kotlin/util/Http.kt index b03e334..5d244a5 100644 --- a/core/src/main/kotlin/util/Http.kt +++ b/core/src/main/kotlin/util/Http.kt @@ -10,9 +10,10 @@ import io.ktor.http.HttpStatusCode import io.ktor.http.path import kotlinx.coroutines.delay -internal class HttpErrorResponseException( - val statusCode: HttpStatusCode -) : Exception("HTTP error response: code=$statusCode, message=${statusCode.description}") +internal class HttpErrorException( + val statusCode: HttpStatusCode, + response: HttpResponse? = null +) : Exception("HTTP error response: code=$statusCode, message=${statusCode.description}, response=$response") internal data class RetryConfig( val times: Int = 8, @@ -31,12 +32,12 @@ internal suspend fun retry( for (i in 0..config.times) { try { val response = block() - if (response.status.value in 200..299) { + if (response.status.value in 100..399) { return response } else { - throw HttpErrorResponseException(response.status) + throw HttpErrorException(response.status, response) } - } catch (e: HttpErrorResponseException) { + } catch (e: HttpErrorException) { val code = e.statusCode.value onFailure(e) if (code != 429 && code in 400..499) { @@ -52,7 +53,7 @@ internal suspend fun retry( throw error!! } -suspend fun HttpClient.get( +internal suspend fun HttpClient.get( url: String, path: String, block: HttpRequestBuilder.() -> Unit @@ -60,7 +61,7 @@ suspend fun HttpClient.get( return request(HttpMethod.Get, url, path, block) } -suspend fun HttpClient.request( +internal suspend fun HttpClient.request( method: HttpMethod, url: String, path: String, diff --git a/core/src/main/kotlin/util/Json.kt b/core/src/main/kotlin/util/Json.kt index 6bf68f5..c9d9d70 100644 --- a/core/src/main/kotlin/util/Json.kt +++ b/core/src/main/kotlin/util/Json.kt @@ -1,7 +1,78 @@ package com.amplitude.util +import kotlinx.serialization.KSerializer +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.encoding.Decoder +import kotlinx.serialization.encoding.Encoder import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonArray +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonNull +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.booleanOrNull +import kotlinx.serialization.json.contentOrNull +import kotlinx.serialization.json.doubleOrNull +import kotlinx.serialization.json.intOrNull +import kotlinx.serialization.json.longOrNull val json = Json { ignoreUnknownKeys = true + isLenient = true + coerceInputValues = true + explicitNulls = false } + +object AnySerializer : KSerializer { + private val delegate = JsonElement.serializer() + override val descriptor: SerialDescriptor + get() = SerialDescriptor("Any", delegate.descriptor) + + override fun serialize(encoder: Encoder, value: Any?) { + val jsonElement = value.toJsonElement() + encoder.encodeSerializableValue(delegate, jsonElement) + } + + override fun deserialize(decoder: Decoder): Any? { + val jsonElement = decoder.decodeSerializableValue(delegate) + return jsonElement.toAny() + } +} + +fun Any?.toJsonElement(): JsonElement = when (this) { + null -> JsonNull + is Map<*, *> -> toJsonObject() + is Collection<*> -> toJsonArray() + is Boolean -> JsonPrimitive(this) + is Number -> JsonPrimitive(this) + is String -> JsonPrimitive(this) + else -> JsonPrimitive(toString()) +} + +fun Collection<*>.toJsonArray(): JsonArray = JsonArray(map { it.toJsonElement() }) + +fun Map<*, *>.toJsonObject(): JsonObject = JsonObject( + mapNotNull { + (it.key as? String ?: return@mapNotNull null) to it.value.toJsonElement() + }.toMap() +) + +fun JsonElement.toAny(): Any? { + return when (this) { + is JsonPrimitive -> toAny() + is JsonArray -> toList() + is JsonObject -> toMap() + } +} + +fun JsonPrimitive.toAny(): Any? { + return if (isString) { + contentOrNull + } else { + booleanOrNull ?: intOrNull ?: longOrNull ?: doubleOrNull + } +} + +fun JsonArray.toList(): List = map { it.toAny() } + +fun JsonObject.toMap(): Map = mapValues { it.value.toAny() } diff --git a/core/src/main/kotlin/util/Redis.kt b/core/src/main/kotlin/util/Redis.kt index 96711dd..2cad423 100644 --- a/core/src/main/kotlin/util/Redis.kt +++ b/core/src/main/kotlin/util/Redis.kt @@ -1,5 +1,8 @@ package com.amplitude.util +import com.amplitude.Metrics +import com.amplitude.RedisCommand +import com.amplitude.RedisCommandFailure import com.amplitude.cohort.CohortDescription import io.lettuce.core.RedisClient import io.lettuce.core.RedisFuture @@ -11,29 +14,33 @@ import kotlinx.coroutines.Deferred import kotlinx.coroutines.future.asDeferred import kotlin.time.Duration -private const val STORAGE_PROTOCOL_VERSION = "v1" +private const val STORAGE_PROTOCOL_VERSION = "v2" internal sealed class RedisKey(val value: String) { - object Projects : RedisKey("projects") + data class Projects(val prefix: String) : RedisKey("$prefix:$STORAGE_PROTOCOL_VERSION:projects") data class Deployments( + val prefix: String, val projectId: String - ) : RedisKey("projects:$projectId:deployments") + ) : RedisKey("$prefix:$STORAGE_PROTOCOL_VERSION:projects:$projectId:deployments") data class FlagConfigs( + val prefix: String, val projectId: String, val deploymentKey: String - ) : RedisKey("projects:$projectId:deployments:$deploymentKey:flags") + ) : RedisKey("$prefix:$STORAGE_PROTOCOL_VERSION:projects:$projectId:deployments:$deploymentKey:flags") data class CohortDescriptions( + val prefix: String, val projectId: String - ) : RedisKey("projects:$projectId:cohorts") + ) : RedisKey("$prefix:$STORAGE_PROTOCOL_VERSION:projects:$projectId:cohorts") data class CohortMembers( + val prefix: String, val projectId: String, val cohortDescription: CohortDescription - ) : RedisKey("projects:$projectId:cohorts:${cohortDescription.id}:users:${cohortDescription.lastComputed}") + ) : RedisKey("$prefix:$STORAGE_PROTOCOL_VERSION:projects:$projectId:cohorts:${cohortDescription.id}:${cohortDescription.groupType}:${cohortDescription.lastComputed}") } internal interface Redis { @@ -52,8 +59,7 @@ internal interface Redis { } internal class RedisConnection( - redisUri: String, - private val redisPrefix: String + redisUri: String ) : Redis { private val connection: Deferred> @@ -65,83 +71,81 @@ internal class RedisConnection( override suspend fun get(key: RedisKey): String? { return connection.run { - get(key.getPrefixedKey()) + get(key.value) } } override suspend fun set(key: RedisKey, value: String) { connection.run { - set(key.getPrefixedKey(), value) + set(key.value, value) } } override suspend fun del(key: RedisKey) { connection.run { - del(key.getPrefixedKey()) + del(key.value) } } override suspend fun sadd(key: RedisKey, values: Set) { connection.run { - sadd(key.getPrefixedKey(), *values.toTypedArray()) + sadd(key.value, *values.toTypedArray()) } } override suspend fun srem(key: RedisKey, value: String) { connection.run { - srem(key.getPrefixedKey(), value) + srem(key.value, value) } } override suspend fun smembers(key: RedisKey): Set? { return connection.run { - smembers(key.getPrefixedKey()) + smembers(key.value) } } override suspend fun sismember(key: RedisKey, value: String): Boolean { return connection.run { - sismember(key.getPrefixedKey(), value) + sismember(key.value, value) } } override suspend fun hget(key: RedisKey, field: String): String? { return connection.run { - hget(key.getPrefixedKey(), field) + hget(key.value, field) } } override suspend fun hgetall(key: RedisKey): Map? { return connection.run { - hgetall(key.getPrefixedKey()) + hgetall(key.value) } } override suspend fun hset(key: RedisKey, values: Map) { connection.run { - hset(key.getPrefixedKey(), values) + hset(key.value, values) } } override suspend fun hdel(key: RedisKey, field: String) { connection.run { - hdel(key.getPrefixedKey(), field) + hdel(key.value, field) } } override suspend fun expire(key: RedisKey, ttl: Duration) { connection.run { - expire(key.getPrefixedKey(), ttl.inWholeSeconds) + expire(key.value, ttl.inWholeSeconds) } } private suspend inline fun Deferred>.run( - action: RedisAsyncCommands.() -> RedisFuture + crossinline action: RedisAsyncCommands.() -> RedisFuture ): R { - return await().async().action().asDeferred().await() - } - - private fun RedisKey.getPrefixedKey(): String { - return "$redisPrefix:$STORAGE_PROTOCOL_VERSION:${this.value}" + return Metrics.with({ RedisCommand }, { e -> RedisCommandFailure(e) }) { + await().async().action().asDeferred().await() + } } } diff --git a/core/src/main/kotlin/util/Yaml.kt b/core/src/main/kotlin/util/Yaml.kt new file mode 100644 index 0000000..c080851 --- /dev/null +++ b/core/src/main/kotlin/util/Yaml.kt @@ -0,0 +1,6 @@ +package com.amplitude.util + +import com.charleskorn.kaml.Yaml +import com.charleskorn.kaml.YamlConfiguration + +val yaml = Yaml(configuration = YamlConfiguration(strictMode = false)) diff --git a/core/src/test/kotlin/Utils.kt b/core/src/test/kotlin/Utils.kt new file mode 100644 index 0000000..e038dba --- /dev/null +++ b/core/src/test/kotlin/Utils.kt @@ -0,0 +1,15 @@ +fun user( + userId: String? = null, + deviceId: String? = null, + userProperties: Map? = null, + groups: Map>? = null, + groupProperties: Map>>? = null +): MutableMap { + return mutableMapOf( + "user_id" to userId, + "device_id" to deviceId, + "user_properties" to userProperties, + "groups" to groups, + "group_properties" to groupProperties + ) +} diff --git a/core/src/test/kotlin/assignment/AssignmentFilterTest.kt b/core/src/test/kotlin/assignment/AssignmentFilterTest.kt index 8d9901e..2aeb57d 100644 --- a/core/src/test/kotlin/assignment/AssignmentFilterTest.kt +++ b/core/src/test/kotlin/assignment/AssignmentFilterTest.kt @@ -1,8 +1,7 @@ -package com.amplitude.experiment.assignment - import com.amplitude.assignment.Assignment import com.amplitude.assignment.InMemoryAssignmentFilter -import com.amplitude.experiment.evaluation.SkylabUser +import com.amplitude.experiment.evaluation.EvaluationVariant +import com.amplitude.util.toEvaluationContext import kotlinx.coroutines.runBlocking import org.junit.Assert import org.junit.Test @@ -13,10 +12,10 @@ class AssignmentFilterTest { fun `test single assignment`() = runBlocking { val filter = InMemoryAssignmentFilter(100) val assignment = Assignment( - SkylabUser(userId = "user"), + user(userId = "user").toEvaluationContext(), mapOf( - "flag-key-1" to flagResult("on"), - "flag-key-2" to flagResult("control") + "flag-key-1" to EvaluationVariant(key = "on"), + "flag-key-2" to EvaluationVariant(key = "control") ) ) Assert.assertTrue(filter.shouldTrack(assignment)) @@ -26,18 +25,18 @@ class AssignmentFilterTest { fun `test duplicate assignments`() = runBlocking { val filter = InMemoryAssignmentFilter(100) val assignment1 = Assignment( - SkylabUser(userId = "user"), + user(userId = "user").toEvaluationContext(), mapOf( - "flag-key-1" to flagResult("on"), - "flag-key-2" to flagResult("control") + "flag-key-1" to EvaluationVariant(key = "on"), + "flag-key-2" to EvaluationVariant(key = "control") ) ) filter.shouldTrack(assignment1) val assignment2 = Assignment( - SkylabUser(userId = "user"), + user(userId = "user").toEvaluationContext(), mapOf( - "flag-key-1" to flagResult("on"), - "flag-key-2" to flagResult("control") + "flag-key-1" to EvaluationVariant(key = "on"), + "flag-key-2" to EvaluationVariant(key = "control") ) ) Assert.assertFalse(filter.shouldTrack(assignment2)) @@ -47,18 +46,18 @@ class AssignmentFilterTest { fun `test same user different results`() = runBlocking { val filter = InMemoryAssignmentFilter(100) val assignment1 = Assignment( - SkylabUser(userId = "user"), + user(userId = "user").toEvaluationContext(), mapOf( - "flag-key-1" to flagResult("on"), - "flag-key-2" to flagResult("control") + "flag-key-1" to EvaluationVariant(key = "on"), + "flag-key-2" to EvaluationVariant(key = "control") ) ) Assert.assertTrue(filter.shouldTrack(assignment1)) val assignment2 = Assignment( - SkylabUser(userId = "user"), + user(userId = "user").toEvaluationContext(), mapOf( - "flag-key-1" to flagResult("control"), - "flag-key-2" to flagResult("on") + "flag-key-1" to EvaluationVariant(key = "control"), + "flag-key-2" to EvaluationVariant(key = "on") ) ) Assert.assertTrue(filter.shouldTrack(assignment2)) @@ -68,18 +67,18 @@ class AssignmentFilterTest { fun `test same results for different users`() = runBlocking { val filter = InMemoryAssignmentFilter(100) val assignment1 = Assignment( - SkylabUser(userId = "user"), + user(userId = "user").toEvaluationContext(), mapOf( - "flag-key-1" to flagResult("on"), - "flag-key-2" to flagResult("control") + "flag-key-1" to EvaluationVariant(key = "on"), + "flag-key-2" to EvaluationVariant(key = "control") ) ) Assert.assertTrue(filter.shouldTrack(assignment1)) val assignment2 = Assignment( - SkylabUser(userId = "different user"), + user(userId = "different user").toEvaluationContext(), mapOf( - "flag-key-1" to flagResult("on"), - "flag-key-2" to flagResult("control") + "flag-key-1" to EvaluationVariant(key = "on"), + "flag-key-2" to EvaluationVariant(key = "control") ) ) Assert.assertTrue(filter.shouldTrack(assignment2)) @@ -89,17 +88,17 @@ class AssignmentFilterTest { fun `test empty results`() = runBlocking { val filter = InMemoryAssignmentFilter(100) val assignment1 = Assignment( - SkylabUser(userId = "user"), + user(userId = "user").toEvaluationContext(), mapOf() ) Assert.assertTrue(filter.shouldTrack(assignment1)) val assignment2 = Assignment( - SkylabUser(userId = "user"), + user(userId = "user").toEvaluationContext(), mapOf() ) Assert.assertFalse(filter.shouldTrack(assignment2)) val assignment3 = Assignment( - SkylabUser(userId = "different user"), + user(userId = "different user").toEvaluationContext(), mapOf() ) Assert.assertTrue(filter.shouldTrack(assignment3)) @@ -109,18 +108,18 @@ class AssignmentFilterTest { fun `test duplicate assignments with different result ordering`() = runBlocking { val filter = InMemoryAssignmentFilter(100) val assignment1 = Assignment( - SkylabUser(userId = "user"), + user(userId = "user").toEvaluationContext(), linkedMapOf( - "flag-key-1" to flagResult("on"), - "flag-key-2" to flagResult("control") + "flag-key-1" to EvaluationVariant(key = "on"), + "flag-key-2" to EvaluationVariant(key = "control") ) ) Assert.assertTrue(filter.shouldTrack(assignment1)) val assignment2 = Assignment( - SkylabUser(userId = "user"), + user(userId = "user").toEvaluationContext(), linkedMapOf( - "flag-key-2" to flagResult("control"), - "flag-key-1" to flagResult("on") + "flag-key-2" to EvaluationVariant(key = "control"), + "flag-key-1" to EvaluationVariant(key = "on") ) ) Assert.assertFalse(filter.shouldTrack(assignment2)) @@ -130,26 +129,26 @@ class AssignmentFilterTest { fun `test lru replacement`() = runBlocking { val filter = InMemoryAssignmentFilter(2) val assignment1 = Assignment( - SkylabUser(userId = "user1"), + user(userId = "user").toEvaluationContext(), mapOf( - "flag-key-1" to flagResult("on"), - "flag-key-2" to flagResult("control") + "flag-key-1" to EvaluationVariant(key = "on"), + "flag-key-2" to EvaluationVariant(key = "control") ) ) Assert.assertTrue(filter.shouldTrack(assignment1)) val assignment2 = Assignment( - SkylabUser(userId = "user2"), + user(userId = "user2").toEvaluationContext(), mapOf( - "flag-key-1" to flagResult("on"), - "flag-key-2" to flagResult("control") + "flag-key-1" to EvaluationVariant(key = "on"), + "flag-key-2" to EvaluationVariant(key = "control") ) ) Assert.assertTrue(filter.shouldTrack(assignment2)) val assignment3 = Assignment( - SkylabUser(userId = "user3"), + user(userId = "user3").toEvaluationContext(), mapOf( - "flag-key-1" to flagResult("on"), - "flag-key-2" to flagResult("control") + "flag-key-1" to EvaluationVariant(key = "on"), + "flag-key-2" to EvaluationVariant(key = "control") ) ) Assert.assertTrue(filter.shouldTrack(assignment3)) diff --git a/core/src/test/kotlin/assignment/AssignmentServiceTest.kt b/core/src/test/kotlin/assignment/AssignmentServiceTest.kt index 6dd9e15..422c309 100644 --- a/core/src/test/kotlin/assignment/AssignmentServiceTest.kt +++ b/core/src/test/kotlin/assignment/AssignmentServiceTest.kt @@ -1,9 +1,10 @@ -package com.amplitude.experiment.assignment - import com.amplitude.assignment.Assignment import com.amplitude.assignment.DAY_MILLIS import com.amplitude.assignment.toAmplitudeEvent -import com.amplitude.experiment.evaluation.SkylabUser +import com.amplitude.experiment.evaluation.EvaluationVariant +import com.amplitude.util.deviceId +import com.amplitude.util.toEvaluationContext +import com.amplitude.util.userId import kotlinx.coroutines.runBlocking import org.junit.Assert import org.junit.Test @@ -12,30 +13,35 @@ class AssignmentServiceTest { @Test fun `test assignment to amplitude event`() = runBlocking { - val user = SkylabUser(userId = "user", deviceId = "device") + val user = user(userId = "user", deviceId = "device").toEvaluationContext() val results = mapOf( - "flag-key-1" to flagResult( - variant = "on", - description = "description-1", - isDefaultVariant = false + "flag-key-1" to EvaluationVariant( + key = "on", + metadata = mapOf( + "version" to 1, + "segmentName" to "Segment 1" + ) ), - "flag-key-2" to flagResult( - variant = "off", - description = "description-2", - isDefaultVariant = true + "flag-key-2" to EvaluationVariant( + key = "off", + metadata = mapOf( + "default" to true, + "version" to 1, + "segmentName" to "All Other Users" + ) ) ) val assignment = Assignment(user, results) val event = assignment.toAmplitudeEvent() - Assert.assertEquals(user.userId, event.userId) - Assert.assertEquals(user.deviceId, event.deviceId) + Assert.assertEquals(user.userId(), event.userId) + Assert.assertEquals(user.deviceId(), event.deviceId) Assert.assertEquals("[Experiment] Assignment", event.eventType) val eventProperties = event.eventProperties Assert.assertEquals(4, eventProperties.length()) Assert.assertEquals("on", eventProperties.get("flag-key-1.variant")) - Assert.assertEquals("description-1", eventProperties.get("flag-key-1.details")) + Assert.assertEquals("v1 rule:Segment 1", eventProperties.get("flag-key-1.details")) Assert.assertEquals("off", eventProperties.get("flag-key-2.variant")) - Assert.assertEquals("description-2", eventProperties.get("flag-key-2.details")) + Assert.assertEquals("v1 rule:All Other Users", eventProperties.get("flag-key-2.details")) val userProperties = event.userProperties Assert.assertEquals(2, userProperties.length()) Assert.assertEquals(1, userProperties.getJSONObject("\$set").length()) diff --git a/core/src/test/kotlin/assignment/Utils.kt b/core/src/test/kotlin/assignment/Utils.kt deleted file mode 100644 index 008f4f6..0000000 --- a/core/src/test/kotlin/assignment/Utils.kt +++ /dev/null @@ -1,15 +0,0 @@ -package com.amplitude.experiment.assignment - -import com.amplitude.experiment.evaluation.FlagResult -import com.amplitude.experiment.evaluation.Variant - -internal fun flagResult( - variant: String, - description: String = "description", - isDefaultVariant: Boolean = false, - expKey: String? = null, - deployed: Boolean = true, - type: String = "release" -): FlagResult { - return FlagResult(Variant(variant), description, isDefaultVariant, expKey, deployed, type) -} diff --git a/core/src/test/kotlin/util/CacheTest.kt b/core/src/test/kotlin/util/CacheTest.kt index 7598a70..db7f5e1 100644 --- a/core/src/test/kotlin/util/CacheTest.kt +++ b/core/src/test/kotlin/util/CacheTest.kt @@ -1,5 +1,3 @@ -package com.amplituide.util - import com.amplitude.util.Cache import kotlinx.coroutines.Job import kotlinx.coroutines.joinAll diff --git a/gradle.properties b/gradle.properties index 747a8e4..1ee7abd 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,16 +1,16 @@ # kotlin kotlin.code.style=official -ktorVersion=2.2.4 -kotlinVersion=1.8.10 -serialziationVersion=1.8.0 +ktorVersion=2.3.5 +kotlinVersion=1.9.10 +serialziationVersion=1.9.0 # logging & metrics logbackVersion=1.4.6 prometheusVersion=1.10.5 # amplitude -experimentEvaluationVersion = 1.1.1 -amplitudeAnalytics = 1.10.3 +experimentEvaluationVersion = 2.0.0-beta.2 +amplitudeAnalytics = 1.12.0 amplitudeAnalyticsJson = 20230227 # redis diff --git a/service/build.gradle.kts b/service/build.gradle.kts index 14c5fda..60dc1b1 100644 --- a/service/build.gradle.kts +++ b/service/build.gradle.kts @@ -3,8 +3,8 @@ import org.jetbrains.kotlin.gradle.tasks.KotlinCompile plugins { application id("io.ktor.plugin") version "2.2.4" - kotlin("jvm") version "1.8.10" - kotlin("plugin.serialization") version "1.8.0" + kotlin("jvm") version "1.9.10" + kotlin("plugin.serialization") version "1.9.0" id("org.jlleitschuh.gradle.ktlint") version "11.3.1" } @@ -34,7 +34,6 @@ val serializationVersion: String by project dependencies { implementation(project(":core")) implementation("com.amplitude:evaluation-core:$experimentEvaluationVersion") - implementation("com.amplitude:evaluation-serialization:$experimentEvaluationVersion") implementation("io.ktor:ktor-server-call-logging-jvm:$ktorVersion") implementation("io.ktor:ktor-server-core-jvm:$ktorVersion") implementation("io.ktor:ktor-server-metrics-micrometer-jvm:$ktorVersion") diff --git a/service/src/main/kotlin/Server.kt b/service/src/main/kotlin/Server.kt index 008ea33..93ab5e0 100644 --- a/service/src/main/kotlin/Server.kt +++ b/service/src/main/kotlin/Server.kt @@ -1,14 +1,16 @@ +@file:UseSerializers(AnySerializer::class) + package com.amplitude -import com.amplitude.experiment.evaluation.SkylabUser -import com.amplitude.experiment.evaluation.serialization.SerialExperimentUser import com.amplitude.plugins.configureLogging import com.amplitude.plugins.configureMetrics +import com.amplitude.util.AnySerializer import com.amplitude.util.json import com.amplitude.util.logger import com.amplitude.util.stringEnv import io.ktor.http.HttpStatusCode import io.ktor.serialization.kotlinx.json.json +import io.ktor.server.application.Application import io.ktor.server.application.ApplicationCall import io.ktor.server.application.call import io.ktor.server.application.createApplicationPlugin @@ -25,10 +27,14 @@ import io.ktor.server.routing.post import io.ktor.server.routing.routing import io.ktor.util.toByteArray import kotlinx.coroutines.runBlocking -import kotlinx.serialization.decodeFromString +import kotlinx.serialization.UseSerializers +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive import java.io.FileNotFoundException import java.util.Base64 +private lateinit var evaluationProxy: EvaluationProxy + fun main() { val log = logger("Service") @@ -59,91 +65,156 @@ fun main() { ConfigurationFile.fromEnv() } + /* + * Initialize and start the evaluation proxy. + */ + evaluationProxy = EvaluationProxy( + projectsFile.projects, + configFile.configuration + ) + /* * Start the server. */ - embeddedServer(Netty, port = configFile.configuration.port, host = "0.0.0.0") { - /* - * Initialize and start the evaluation proxy. - */ - val evaluationProxy = EvaluationProxy( - projectsFile.projects, - configFile.configuration - ) - runBlocking { - evaluationProxy.start() - } + embeddedServer( + factory = Netty, + port = configFile.configuration.port, + host = "0.0.0.0", + module = Application::proxyServer + ).start(wait = true) +} - /* - * Configure ktor plugins. - */ - configureLogging() - configureMetrics() - install(ContentNegotiation) { - json() - } - install( - createApplicationPlugin("shutdown") { - val plugin = ShutDownUrl("/shutdown") { 0 } - onCall { call -> - if (call.request.uri == plugin.url) { - evaluationProxy.shutdown() - plugin.doShutdown(call) - } - } - } - ) +fun Application.proxyServer() { + runBlocking { + evaluationProxy.start() + } - /* - * Configure endpoints. - */ - routing { - get("/sdk/v1/deployments/{deployment}/flags") { - val deployment = this.call.parameters["deployment"] - val result = try { - evaluationProxy.getSerializedFlagConfigs(deployment) - } catch (e: HttpErrorResponseException) { - call.respond(HttpStatusCode.fromValue(e.status), e.message) - return@get + /* + * Configure ktor plugins. + */ + configureLogging() + configureMetrics() + install(ContentNegotiation) { + json(json) + } + install( + createApplicationPlugin("shutdown") { + val plugin = ShutDownUrl("/shutdown") { 0 } + onCall { call -> + if (call.request.uri == plugin.url) { + evaluationProxy.shutdown() + plugin.doShutdown(call) } - call.respond(result) } + } + ) - get("/sdk/v1/deployments/{deployment}/users/{userId}/cohorts") { - val deployment = this.call.parameters["deployment"] - val userId = this.call.parameters["userId"] - val result = try { - evaluationProxy.getSerializedCohortMembershipsForUser(deployment, userId) - } catch (e: HttpErrorResponseException) { - call.respond(HttpStatusCode.fromValue(e.status), e.message) - return@get - } - call.respond(result) - } + /* + * Configure endpoints. + */ + routing { + // Local Evaluation - get("/sdk/vardata") { - call.evaluate(evaluationProxy, ApplicationRequest::getUserFromHeader) + get("/sdk/v2/flags") { + val deployment = this.call.request.getDeploymentKey() + val result = try { + evaluationProxy.getSerializedFlagConfigs(deployment) + } catch (e: HttpErrorResponseException) { + call.respond(HttpStatusCode.fromValue(e.status), e.message) + return@get } + call.respond(result) + } - post("/sdk/vardata") { - call.evaluate(evaluationProxy, ApplicationRequest::getUserFromBody) + get("/sdk/v2/cohorts/{cohortId}/description") { + val deployment = this.call.request.getDeploymentKey() + val cohortId = this.call.parameters["cohortId"] + val result = try { + evaluationProxy.getSerializedCohortDescription(deployment, cohortId) + } catch (e: HttpErrorResponseException) { + call.respond(HttpStatusCode.fromValue(e.status), e.message) + return@get } + call.respond(result) + } - get("/v1/vardata") { - call.evaluate(evaluationProxy, ApplicationRequest::getUserFromQuery) + get("/sdk/v2/cohorts/{cohortId}/members") { + val deployment = this.call.request.getDeploymentKey() + val cohortId = this.call.parameters["cohortId"] + val result = try { + evaluationProxy.getSerializedCohortMembers(deployment, cohortId) + } catch (e: HttpErrorResponseException) { + call.respond(HttpStatusCode.fromValue(e.status), e.message) + return@get } + call.respond(result) + } - post("/v1/vardata") { - call.evaluate(evaluationProxy, ApplicationRequest::getUserFromBody) + get("/sdk/v2/users/{userId}/cohorts") { + val deployment = this.call.request.getDeploymentKey() + val userId = this.call.parameters["userId"] + val result = try { + evaluationProxy.getSerializedCohortMembershipsForUser(deployment, userId) + } catch (e: HttpErrorResponseException) { + call.respond(HttpStatusCode.fromValue(e.status), e.message) + return@get } - get("/status") { - call.respond("OK") + call.respond(result) + } + + get("/sdk/v2/groups/{groupType}/{groupName}/cohorts") { + val deployment = this.call.request.getDeploymentKey() + val groupType = this.call.parameters["groupType"] + val groupName = this.call.parameters["groupName"] + val result = try { + evaluationProxy.getSerializedCohortMembershipsForGroup(deployment, groupType, groupName) + } catch (e: HttpErrorResponseException) { + call.respond(HttpStatusCode.fromValue(e.status), e.message) + return@get } + call.respond(result) + } + + // Remote Evaluation V2 Endpoints + + get("/sdk/v2/vardata") { + call.evaluate(evaluationProxy, ApplicationRequest::getUserFromHeader) + } + + post("/sdk/v2/vardata") { + call.evaluate(evaluationProxy, ApplicationRequest::getUserFromBody) + } + + // Remote Evaluation V1 endpoints + + get("/sdk/vardata") { + call.evaluateV1(evaluationProxy, ApplicationRequest::getUserFromHeader) + } + + post("/sdk/vardata") { + call.evaluateV1(evaluationProxy, ApplicationRequest::getUserFromBody) + } + + get("/v1/vardata") { + call.evaluateV1(evaluationProxy, ApplicationRequest::getUserFromQuery) } - }.start(wait = true) + + post("/v1/vardata") { + call.evaluateV1(evaluationProxy, ApplicationRequest::getUserFromBody) + } + + // Health Check + + get("/status") { + call.respond("OK") + } + } } -suspend fun ApplicationCall.evaluate(evaluationProxy: EvaluationProxy, userProvider: suspend ApplicationRequest.() -> SkylabUser) { +suspend fun ApplicationCall.evaluate( + evaluationProxy: EvaluationProxy, + userProvider: suspend ApplicationRequest.() -> Map +) { val result = try { // Deployment key is included in Authorization header with prefix "Api-Key " val deploymentKey = request.getDeploymentKey() @@ -157,6 +228,23 @@ suspend fun ApplicationCall.evaluate(evaluationProxy: EvaluationProxy, userProvi respond(result) } +suspend fun ApplicationCall.evaluateV1( + evaluationProxy: EvaluationProxy, + userProvider: suspend ApplicationRequest.() -> Map +) { + val result = try { + // Deployment key is included in Authorization header with prefix "Api-Key " + val deploymentKey = request.getDeploymentKey() + val user = request.userProvider() + val flagKeys = request.getFlagKeys() + evaluationProxy.serializedEvaluateV1(deploymentKey, user, flagKeys) + } catch (e: HttpErrorResponseException) { + respond(HttpStatusCode.fromValue(e.status), e.message) + return + } + respond(result) +} + /** * Get the deployment key from the request, included in Authorization header with prefix "Api-Key " */ @@ -188,37 +276,45 @@ private fun ApplicationRequest.getFlagKeys(): Set { /** * Get the user from the header. Used for SDK GET requests. */ -private fun ApplicationRequest.getUserFromHeader(): SkylabUser { +private fun ApplicationRequest.getUserFromHeader(): Map { val b64User = this.headers["X-Amp-Exp-User"] val userJson = Base64.getDecoder().decode(b64User).toString(Charsets.UTF_8) - return json.decodeFromString(userJson).convert() + return json.decodeFromString(userJson).toMap() } /** * Get the user from the body. Used for SDK/REST POST requests. */ -private suspend fun ApplicationRequest.getUserFromBody(): SkylabUser { +private suspend fun ApplicationRequest.getUserFromBody(): Map { val userJson = this.receiveChannel().toByteArray().toString(Charsets.UTF_8) - return json.decodeFromString(userJson).convert() + return json.decodeFromString(userJson).toMap() } /** * Get the user from the query. Used for REST GET requests. */ -private fun ApplicationRequest.getUserFromQuery(): SkylabUser { +private fun ApplicationRequest.getUserFromQuery(): JsonObject { val userId = this.queryParameters["user_id"] val deviceId = this.queryParameters["device_id"] val context = this.queryParameters["context"] - var user = if (context != null) { - json.decodeFromString(context).convert() + var user: JsonObject = if (context != null) { + json.decodeFromString(context) } else { - SkylabUser() + JsonObject(emptyMap()) } if (userId != null) { - user = user.copy(userId = userId) + user = JsonObject( + user.toMutableMap().apply { + put("user_id", JsonPrimitive(userId)) + } + ) } if (deviceId != null) { - user = user.copy(deviceId = deviceId) + user = JsonObject( + user.toMutableMap().apply { + put("device_id", JsonPrimitive(userId)) + } + ) } return user }