diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index 77e39fa..0000000 Binary files a/.DS_Store and /dev/null differ diff --git a/.gitignore b/.gitignore index ce181e5..5b0d958 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ .idea .gradle -build \ No newline at end of file +build +.DS_Store diff --git a/build.gradle.kts b/build.gradle.kts index cb142cb..7b4d0db 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -6,24 +6,25 @@ plugins { } group = "io.github.omkar-tenkale" -version = "0.1.0" +version = "0.2.0" repositories { mavenCentral() } dependencies { - - val ktorVersion = "2.1.3" - implementation("io.ktor:ktor-server-core-jvm:$ktorVersion") - implementation("io.ktor:ktor-server-webjars:$ktorVersion") - implementation("io.ktor:ktor-server-auth:$ktorVersion") - implementation("io.ktor:ktor-server-auth-jwt-jvm:$ktorVersion") - testImplementation("io.ktor:ktor-server-netty-jvm:$ktorVersion") - testImplementation("io.ktor:ktor-server-content-negotiation:$ktorVersion") - testImplementation("io.ktor:ktor-serialization-jackson:$ktorVersion") - testImplementation("io.ktor:ktor-server-auth:$ktorVersion") - testImplementation("io.ktor:ktor-server-call-logging:$ktorVersion") + val ktor_version = "2.2.1" + implementation("io.ktor:ktor-server-core-jvm:$ktor_version") + implementation("io.ktor:ktor-server-webjars:$ktor_version") + implementation("io.ktor:ktor-server-auth:$ktor_version") + testImplementation("io.ktor:ktor-server-netty-jvm:$ktor_version") + testImplementation("io.ktor:ktor-server-content-negotiation:$ktor_version") + implementation("io.ktor:ktor-server-status-pages:$ktor_version") + testImplementation("io.ktor:ktor-serialization-jackson:$ktor_version") + testImplementation("io.ktor:ktor-server-auth:$ktor_version") + testImplementation("io.ktor:ktor-server-call-logging:$ktor_version") + testImplementation("io.ktor:ktor-server-test-host:$ktor_version") + testImplementation("org.junit.jupiter:junit-jupiter:5.8.1") } tasks.test { diff --git a/src/main/kotlin/io/github/omkartenkale/ktor_role_based_auth/RoleBasedAuthPlugin.kt b/src/main/kotlin/io/github/omkartenkale/ktor_role_based_auth/RoleBasedAuthPlugin.kt index 82ab80c..59db3f6 100644 --- a/src/main/kotlin/io/github/omkartenkale/ktor_role_based_auth/RoleBasedAuthPlugin.kt +++ b/src/main/kotlin/io/github/omkartenkale/ktor_role_based_auth/RoleBasedAuthPlugin.kt @@ -6,92 +6,114 @@ import io.ktor.server.auth.* import io.ktor.server.request.* import io.ktor.server.response.* import io.ktor.server.routing.* -import io.ktor.util.pipeline.* typealias Role = String +class RoleBasedAuthConfiguration { + var requiredRoles: Set = emptySet() + lateinit var authType: AuthType +} +enum class AuthType { + ALL, + ANY, + NONE, +} +class AuthorizedRouteSelector(private val description: String) : RouteSelector() { + override fun evaluate(context: RoutingResolveContext, segmentIndex: Int) = RouteSelectorEvaluation.Constant -class RoleBasedAuthConfiguration( - var any: Set? = null, - var all: Set? = null, - var none: Set? = null, -) + override fun toString(): String = "(authorize ${description})" +} -fun Route.withRole(role: Role, build: suspend PipelineContext.() -> Unit) = - withAnyRole(setOf(role), build) +class RoleBasedAuthPluginConfiguration { + var roleExtractor: ((Principal) -> Set) = { emptySet() } + private set -fun Route.withAllRoles(roles: Set, build: suspend PipelineContext.() -> Unit) { - install(RoleBasedAuthPlugin) { - all = roles.toSet() + fun extractRoles(extractor: (Principal) -> Set) { + roleExtractor = extractor } - handle { build() } + var throwErrorOnUnauthorizedResponse = false } -fun Route.withoutRoles(roles: Set, build: suspend PipelineContext.() -> Unit) { - install(RoleBasedAuthPlugin) { - none = roles.toSet() - } - handle { build() } +private lateinit var pluginGlobalConfig: RoleBasedAuthPluginConfiguration +fun AuthenticationConfig.roleBased(config:RoleBasedAuthPluginConfiguration.()->Unit){ + pluginGlobalConfig = RoleBasedAuthPluginConfiguration().apply(config) } - -fun Route.withAnyRole(roles: Set, build: suspend PipelineContext.() -> Unit) { - install(RoleBasedAuthPlugin) { - any = roles.toSet() +private fun Route.buildAuthorizedRoute( + requiredRoles: Set, + authType: AuthType, + build: Route.() -> Unit +): Route { + val authorizedRoute = createChild(AuthorizedRouteSelector(requiredRoles.joinToString(","))) + authorizedRoute.install(RoleBasedAuthPlugin) { + this.requiredRoles = requiredRoles + this.authType = authType } - handle { build() } + authorizedRoute.build() + return authorizedRoute } +fun Route.withRole(role: Role, build: Route.() -> Unit) = + buildAuthorizedRoute(requiredRoles = setOf(role),authType= AuthType.ALL, build = build) + +fun Route.withRoles(vararg roles: Role, build: Route.() -> Unit) = + buildAuthorizedRoute(requiredRoles = roles.toSet(),authType= AuthType.ALL, build = build) + +fun Route.withAnyRole(vararg roles: Role, build: Route.() -> Unit) = + buildAuthorizedRoute(requiredRoles = roles.toSet(),authType = AuthType.ANY, build = build) + +fun Route.withoutRoles(vararg roles: Role, build: Route.() -> Unit) = + buildAuthorizedRoute(requiredRoles = roles.toSet(), authType = AuthType.NONE,build = build) + val RoleBasedAuthPlugin = createRouteScopedPlugin(name = "RoleBasedAuthorization", createConfiguration = ::RoleBasedAuthConfiguration) { + if(::pluginGlobalConfig.isInitialized.not()){ + error("RoleBasedAuthPlugin not initialized. Setup plugin by calling AuthenticationConfig#roleBased in authenticate block") + } with(pluginConfig) { on(AuthenticationChecked) { call -> - val principal = call.principal() ?: error("Missing principal") - val roles = roleBasedAuthPluginConfiguration?.roleExtractor?.invoke(principal) - ?: error("RoleBasedAuthPlugin is not initialized,You can initialize it by calling 'installRoleBasedAuthPlugin()'") + val principal = call.principal() ?: return@on + val userRoles = pluginGlobalConfig.roleExtractor.invoke(principal) val denyReasons = mutableListOf() - all?.let { - val missing = it - roles - if (missing.isNotEmpty()) { - denyReasons += "Principal $principal lacks required role(s) ${missing.joinToString(" and ")}" + + when(authType) { + AuthType.ALL -> { + val missing = requiredRoles - userRoles + if (missing.isNotEmpty()) { + denyReasons += "Principal lacks required role(s) ${missing.joinToString(" and ")}" + } } - } - any?.let { - if (it.none { it in roles }) { - denyReasons += "Principal $principal has none of the sufficient role(s) ${ - it.joinToString( - " or " - ) - }" + AuthType.ANY -> { + if (userRoles.none { it in requiredRoles }) { + denyReasons += "Principal has none of the sufficient role(s) ${ + requiredRoles.joinToString( + " or " + ) + }" + } } - } - none?.let { - if (it.any { it in roles }) { - denyReasons += "Principal $principal has forbidden role(s) ${ - (it.intersect(roles)).joinToString( - " and " - ) - }" + AuthType.NONE -> { + if (userRoles.any{ it in requiredRoles}) { + denyReasons += "Principal has forbidden role(s) ${ + (requiredRoles.intersect(userRoles)).joinToString( + " and " + ) + }" + + } } } if (denyReasons.isNotEmpty()) { - val message = denyReasons.joinToString(". ") - println("Authorization failed for ${call.request.path()}. $message") - call.respond(HttpStatusCode.Forbidden) + if(pluginGlobalConfig.throwErrorOnUnauthorizedResponse){ + throw UnauthorizedAccessException(denyReasons) + }else{ + val message = denyReasons.joinToString(". ") + if(application.developmentMode){ + application.log.warn("Authorization failed for ${call.request.path()} $message") + } + call.respond(HttpStatusCode.Forbidden) + } } } } } -private var roleBasedAuthPluginConfiguration: RoleBasedAuthPluginConfiguration? = null -fun Application.installRoleBasedAuthPlugin(configuration: RoleBasedAuthPluginConfiguration.() -> Unit) { - roleBasedAuthPluginConfiguration = RoleBasedAuthPluginConfiguration().apply { configuration() } -} - - -class RoleBasedAuthPluginConfiguration { - var roleExtractor: ((Principal) -> Set)? = null - private set - - fun extractRoles(extractor: (Principal) -> Set) { - roleExtractor = extractor - } -} \ No newline at end of file +class UnauthorizedAccessException(val denyReasons: MutableList) : Exception() diff --git a/src/test/kotlin/io/github/omkartenkale/ktor_role_based_auth/RoleBasedAuthPluginTest.kt b/src/test/kotlin/io/github/omkartenkale/ktor_role_based_auth/RoleBasedAuthPluginTest.kt new file mode 100644 index 0000000..ebc0a6f --- /dev/null +++ b/src/test/kotlin/io/github/omkartenkale/ktor_role_based_auth/RoleBasedAuthPluginTest.kt @@ -0,0 +1,178 @@ +package io.github.omkartenkale.ktor_role_based_auth + +import io.ktor.client.request.basicAuth +import io.ktor.client.request.get +import io.ktor.http.HttpStatusCode +import io.ktor.server.application.call +import io.ktor.server.application.install +import io.ktor.server.auth.UserIdPrincipal +import io.ktor.server.auth.authenticate +import io.ktor.server.auth.authentication +import io.ktor.server.auth.basic +import io.ktor.server.plugins.statuspages.StatusPages +import io.ktor.server.response.respond +import io.ktor.server.response.respondText +import io.ktor.server.routing.get +import io.ktor.server.routing.route +import io.ktor.server.testing.ApplicationTestBuilder +import io.ktor.server.testing.testApplication +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Assertions.* + + +internal class RoleBasedAuthPluginTest { + + object Roles { + const val ADMIN = "ADMIN" + const val SUPER_ADMIN = "SUPER_ADMIN" + } + + private fun withServer( + throwErrorOnUnauthorizedResponse: Boolean = false, + block: suspend ApplicationTestBuilder.() -> Unit + ) { + val usersWithRoles = mapOf("Leon" to setOf(), "Amy" to setOf(Roles.ADMIN), "Jay" to setOf(Roles.SUPER_ADMIN)) + testApplication { + application { + if (throwErrorOnUnauthorizedResponse) { + install(StatusPages) { + exception { call, cause -> + call.respond(HttpStatusCode.InternalServerError) + } + } + } + authentication { + basic { + validate { credentials -> + if (usersWithRoles.containsKey(credentials.name) && credentials.password == "1234") { + UserIdPrincipal(credentials.name) + } else { + null + } + } + } + roleBased { + extractRoles { principal -> usersWithRoles[(principal as UserIdPrincipal).name]!! } + this.throwErrorOnUnauthorizedResponse = throwErrorOnUnauthorizedResponse + } + } + } + + routing { + route("/") { + + get { + call.respondText("Welcome!") + } + + authenticate { + + route("/profile") { + get { + call.respondText("Joined: 2 years ago") + } + } + + route("/dashboard") { + withAnyRole(Roles.ADMIN, Roles.SUPER_ADMIN) { + get { + call.respondText("Total users: 2443") + } + } + } + + route("/system-stats") { + withRole(Roles.SUPER_ADMIN) { + get { + call.respondText("CPU: 34%") + } + } + } + } + } + } + block() + } + } + + @Test + fun `No auth required for root route`() { + withServer { + with(client.get("/")) { + assertEquals(HttpStatusCode.OK, status) + } + } + } + + @Test + fun `Allow only authenticated users`() { + withServer { + with(client.get("/profile")) { + assertEquals(HttpStatusCode.Unauthorized, status) + } + with(client.get("/profile") { + basicAuth("Leon", "0000") + }) { + assertEquals(HttpStatusCode.Unauthorized, status) + } + with(client.get("/profile") { + basicAuth("Leon", "1234") + }) { + assertEquals(HttpStatusCode.OK, status) + } + } + } + + @Test + fun `Allow only users with admin role`() { + withServer { + with(client.get("/dashboard")) { + assertEquals(status, HttpStatusCode.Unauthorized) + } + with(client.get("/dashboard") { + basicAuth("Leon", "1234") + }) { + assertEquals(HttpStatusCode.Forbidden, status) + } + + with(client.get("/dashboard") { + basicAuth("Amy", "1234") + }) { + assertEquals(HttpStatusCode.OK, status) + } + } + } + + @Test + fun `Allow only users with superadmin role`() { + withServer { + with(client.get("/system-stats")) { + assertEquals(HttpStatusCode.Unauthorized, status) + } + with(client.get("/system-stats") { + basicAuth("Amy", "1234") + }) { + assertEquals(HttpStatusCode.Forbidden, status) + } + with(client.get("/system-stats") { + basicAuth("Jay", "1234") + }) { + assertEquals(HttpStatusCode.OK, status) + } + } + } + + @Test + fun `Test throwErrorOnUnauthorizedResponse`() { + withServer(true) { + with(client.get("/system-stats")) { + assertEquals(HttpStatusCode.Unauthorized, status) + } + with(client.get("/system-stats") { + basicAuth("Amy", "1234") + }) { + assertEquals(HttpStatusCode.InternalServerError, status) + } + } + } +}