Skip to content

Commit

Permalink
fix: server accessing api and secret key from cohorts request
Browse files Browse the repository at this point in the history
  • Loading branch information
bgiori committed Aug 7, 2024
1 parent 27d8cc2 commit 66962e3
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 23 deletions.
2 changes: 1 addition & 1 deletion core/src/main/kotlin/cohort/CohortLoader.kt
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ internal class CohortLoader(
}
}
if (cohort != null) {
log.info("Cohort download complete. {}", cohort)
cohortStorage.putCohort(cohort)
}
log.info("Cohort download complete. {}", cohort ?: cohortId)
} catch (t: Throwable) {
// Don't throw if we fail to download the cohort. We
// prefer to continue to update flags.
Expand Down
22 changes: 12 additions & 10 deletions core/src/main/kotlin/cohort/CohortStorage.kt
Original file line number Diff line number Diff line change
Expand Up @@ -212,16 +212,18 @@ internal class RedisCohortStorage(
val jsonEncodedDescription = json.encodeToString(description)
val existingDescription = getCohortDescription(description.id)
if ((existingDescription?.lastModified ?: 0L) < description.lastModified) {
redis.sadd(
RedisKey.CohortMembers(
prefix,
projectId,
description.id,
description.groupType,
description.lastModified,
),
cohort.members,
)
if (cohort.members.isNotEmpty()) {
redis.sadd(
RedisKey.CohortMembers(
prefix,
projectId,
description.id,
description.groupType,
description.lastModified,
),
cohort.members,
)
}
redis.hset(RedisKey.CohortDescriptions(prefix, projectId), mapOf(description.id to jsonEncodedDescription))
if (existingDescription != null) {
redis.expire(
Expand Down
36 changes: 24 additions & 12 deletions service/src/main/kotlin/Server.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import com.amplitude.util.json
import com.amplitude.util.logger
import com.amplitude.util.stringEnv
import com.amplitude.util.toAnyMap
import io.ktor.http.Headers
import io.ktor.serialization.kotlinx.json.json
import io.ktor.server.application.Application
import io.ktor.server.application.ApplicationCall
Expand All @@ -22,10 +23,12 @@ import io.ktor.server.response.respond
import io.ktor.server.routing.get
import io.ktor.server.routing.post
import io.ktor.server.routing.routing
import io.ktor.util.decodeBase64String
import io.ktor.util.toByteArray
import kotlinx.coroutines.runBlocking
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive
import org.jetbrains.annotations.VisibleForTesting
import java.io.FileNotFoundException
import java.util.Base64

Expand Down Expand Up @@ -115,13 +118,13 @@ fun Application.proxyServer() {
// Local Evaluation

get("/sdk/v2/flags") {
val deployment = this.call.request.getDeploymentKey()
val deployment = this.call.request.headers.getDeploymentKey()
val result = evaluationProxy.getFlagConfigs(deployment)
call.respond(result.status, result.body)
}

get("/sdk/v1/cohort/{cohortId}") {
val (apiKey, secretKey) = this.call.request.getApiAndSecretKey()
val (apiKey, secretKey) = this.call.request.headers.getApiAndSecretKey()
val cohortId = this.call.parameters["cohortId"]
val maxCohortSize = this.call.request.queryParameters["maxCohortSize"]?.toIntOrNull()
val lastModified = this.call.request.queryParameters["lastModified"]?.toLongOrNull()
Expand All @@ -130,7 +133,7 @@ fun Application.proxyServer() {
}

get("/sdk/v2/memberships/{groupType}/{groupName}") {
val deployment = this.call.request.getDeploymentKey()
val deployment = this.call.request.headers.getDeploymentKey()
val groupType = this.call.parameters["groupType"]
val groupName = this.call.parameters["groupName"]
val result = evaluationProxy.getCohortMemberships(deployment, groupType, groupName)
Expand Down Expand Up @@ -178,7 +181,7 @@ suspend fun ApplicationCall.evaluate(
userProvider: suspend ApplicationRequest.() -> Map<String, Any?>,
) {
// Deployment key is included in Authorization header with prefix "Api-Key "
val deploymentKey = request.getDeploymentKey()
val deploymentKey = request.headers.getDeploymentKey()
val user = request.userProvider()
val flagKeys = request.getFlagKeys()
val result = evaluationProxy.evaluate(deploymentKey, user, flagKeys)
Expand All @@ -190,7 +193,7 @@ suspend fun ApplicationCall.evaluateV1(
userProvider: suspend ApplicationRequest.() -> Map<String, Any?>,
) {
// Deployment key is included in Authorization header with prefix "Api-Key "
val deploymentKey = request.getDeploymentKey()
val deploymentKey = request.headers.getDeploymentKey()
val user = request.userProvider()
val flagKeys = request.getFlagKeys()
val result = evaluationProxy.evaluateV1(deploymentKey, user, flagKeys)
Expand All @@ -201,8 +204,9 @@ suspend fun ApplicationCall.evaluateV1(
* Get the deployment key from the request, included in Authorization header
* with prefix "Api-Key "
*/
private fun ApplicationRequest.getDeploymentKey(): String? {
val deploymentKey = this.headers["Authorization"]
@VisibleForTesting
internal fun Headers.getDeploymentKey(): String? {
val deploymentKey = this["Authorization"]
if (deploymentKey == null || !deploymentKey.startsWith("Api-Key", ignoreCase = true)) {
return null
}
Expand All @@ -213,16 +217,24 @@ private fun ApplicationRequest.getDeploymentKey(): String? {
* Get the API and secret key from the request, included in Authorization
* header as Basic auth.
*/
private fun ApplicationRequest.getApiAndSecretKey(): Pair<String?, String?> {
val authHeaderValue = this.headers["Authorization"]
@VisibleForTesting
internal fun Headers.getApiAndSecretKey(): Pair<String?, String?> {
val authHeaderValue = this["Authorization"]
if (authHeaderValue == null || !authHeaderValue.startsWith("Basic", ignoreCase = true)) {
return null to null
}
val segmentedAuthValue = authHeaderValue.substring("Basic ".length).split(":")
if (segmentedAuthValue.size < 2) {
try {
val segmentedAuthValue = authHeaderValue
.substring("Basic ".length)
.decodeBase64String()
.split(":")
if (segmentedAuthValue.size < 2) {
return null to null
}
return segmentedAuthValue[0] to segmentedAuthValue[1]
} catch (e: Exception) {
return null to null
}
return segmentedAuthValue[0] to segmentedAuthValue[1]
}

/**
Expand Down
21 changes: 21 additions & 0 deletions service/src/test/kotlin/ServerTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import com.amplitude.getApiAndSecretKey
import io.ktor.http.Headers
import io.ktor.util.encodeBase64
import kotlin.test.Test
import kotlin.test.assertEquals

class ServerTest {

@Test
fun `test get api and secret key`() {
val apiKey = "api"
val secretKey = "secret"
val headers = Headers.build {
set("Authorization", "Basic ${"$apiKey:$secretKey".encodeBase64()}")
}

val result = headers.getApiAndSecretKey()
assertEquals(apiKey, result.first)
assertEquals(secretKey, result.second)
}
}

0 comments on commit 66962e3

Please sign in to comment.