Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor errors, introduce type parameter in Error #165

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ object AccessTokenProvider {
override def requestToken(scope: Scope): F[ClientCredentialsToken.AccessTokenResponse] =
ClientCredentials
.requestToken(tokenUrl, clientId, clientSecret, scope)(backend)
.map(_.leftMap(OAuth2Exception).toTry)
.map(_.leftMap(_.toException).toTry)
.flatMap(backend.responseMonad.fromTry)

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,12 @@ import com.ocadotechnology.sttp.oauth2.common._
import io.circe.Decoder
import io.circe.refined._
import sttp.client3.ResponseAs
import com.ocadotechnology.sttp.oauth2.common.Error.OAuth2Error

import scala.concurrent.duration.FiniteDuration

object ClientCredentialsToken {

type Response = Either[Error, ClientCredentialsToken.AccessTokenResponse]

private[oauth2] implicit val bearerTokenResponseDecoder: Decoder[Either[OAuth2Error, AccessTokenResponse]] =
circe.eitherOrFirstError[AccessTokenResponse, OAuth2Error](
Decoder[AccessTokenResponse],
Decoder[OAuth2Error]
)
type Response = Either[Error[Throwable], ClientCredentialsToken.AccessTokenResponse]

val response: ResponseAs[Response, Any] =
common.responseWithCommonError[ClientCredentialsToken.AccessTokenResponse]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import java.time.Instant

object Introspection {

type Response = Either[common.Error, Introspection.TokenIntrospectionResponse]
type Response = Either[common.Error[Throwable], Introspection.TokenIntrospectionResponse]

val response: ResponseAs[Response, Any] =
common.responseWithCommonError[TokenIntrospectionResponse]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import sttp.client3.ResponseAs
object OAuth2Token {

// TODO: should be changed to Response[A] and allow custom responses, like in AuthorizationCodeGrant
type Response = Either[Error, ExtendedOAuth2TokenResponse]
type Response = Either[Error[Throwable], ExtendedOAuth2TokenResponse]

val response: ResponseAs[Response, Any] =
common.responseWithCommonError[ExtendedOAuth2TokenResponse]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ object PasswordGrantProvider {
)(
backend: SttpBackend[F, Any]
): PasswordGrantProvider[F] = { (user: User, scope: Scope) =>
PasswordGrant.requestToken(tokenUrl, user, clientId, clientSecret, scope)(backend).map(_.leftMap(OAuth2Exception)).rethrow
PasswordGrant.requestToken(tokenUrl, user, clientId, clientSecret, scope)(backend).map(_.leftMap(_.toException)).rethrow
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ object TokenIntrospection {
override def introspect(token: Secret[String]): F[Introspection.TokenIntrospectionResponse] =
ClientCredentials
.introspectToken(tokenIntrospectionUrl, clientId, clientSecret, token)(backend)
.map(_.leftMap(OAuth2Exception).toTry)
.map(_.leftMap(_.toException).toTry)
.flatMap(backend.responseMonad.fromTry)

}
Expand Down
81 changes: 47 additions & 34 deletions oauth2/src/main/scala/com/ocadotechnology/sttp/oauth2/common.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package com.ocadotechnology.sttp.oauth2

import cats.syntax.all._
import com.ocadotechnology.sttp.oauth2.common.Error.OAuth2Error
import com.ocadotechnology.sttp.oauth2.common.Error.{OAuth2Error, errorDecoder}
import com.ocadotechnology.sttp.oauth2.common.Error.OAuth2ErrorResponse.InvalidClient
import com.ocadotechnology.sttp.oauth2.common.Error.OAuth2ErrorResponse.InvalidGrant
import com.ocadotechnology.sttp.oauth2.common.Error.OAuth2ErrorResponse.InvalidRequest
Expand Down Expand Up @@ -40,21 +40,30 @@ object common {
def refine: RefineMPartiallyApplied[Refined, ValidScope] = refineMV[ValidScope]
}

sealed trait Error extends Throwable with Product with Serializable
final case class OAuth2Exception[E <: Throwable](inner: Error[E]) extends Exception(s"${inner.message}: ${inner.error.getMessage}", inner.error)

sealed trait Error[E] extends Product with Serializable {
def error: E
def message: String
}

object Error {

final case class HttpClientError(statusCode: StatusCode, cause: Throwable)
extends Exception(s"Client call resulted in error ($statusCode): ${cause.getMessage}", cause)
with Error
final case class HttpClientError(statusCode: StatusCode, error: Throwable) extends Error[Throwable] {
val message: String = s"Client call resulted in error ($statusCode)"
}

sealed trait OAuth2Error extends Error
sealed trait OAuth2Error[E] extends Error[E]

/** Token errors as listed in documentation: https://tools.ietf.org/html/rfc6749#section-5.2
*/
final case class OAuth2ErrorResponse(errorType: OAuth2ErrorResponse.OAuth2ErrorResponseType, errorDescription: Option[String])
extends Exception(errorDescription.fold(s"$errorType")(description => s"$errorType: $description"))
with OAuth2Error
final case class OAuth2ErrorResponse[E](
errorType: OAuth2ErrorResponse.OAuth2ErrorResponseType,
errorDescription: Option[String],
error: E
) extends OAuth2Error[E] {
val message: String = errorDescription.fold(s"$errorType")(description => s"$errorType: $description")
}

object OAuth2ErrorResponse {

Expand All @@ -74,41 +83,45 @@ object common {

}

final case class UnknownOAuth2Error(error: String, errorDescription: Option[String])
extends Exception(
errorDescription.fold(s"Unknown OAuth2 error type: $error")(description =>
s"Unknown OAuth2 error type: $error, description: $description"
)
final case class UnknownOAuth2Error[E](errorString: String, errorDescription: Option[String], error: E)
extends OAuth2Error[E] {

val message: String = errorDescription.fold(s"Unknown OAuth2 error type: $errorString")(description =>
s"Unknown OAuth2 error type: $errorString, description: $description"
)
with OAuth2Error

implicit val errorDecoder: Decoder[OAuth2Error] =
Decoder.forProduct2[OAuth2Error, String, Option[String]]("error", "error_description") { (error, description) =>
}

def errorDecoder[E](cause: E): Decoder[OAuth2Error[E]] =
Decoder.forProduct2[OAuth2Error[E], String, Option[String]]("error", "error_description") { (error, description) =>
error match {
case "invalid_request" => OAuth2ErrorResponse(InvalidRequest, description)
case "invalid_client" => OAuth2ErrorResponse(InvalidClient, description)
case "invalid_grant" => OAuth2ErrorResponse(InvalidGrant, description)
case "unauthorized_client" => OAuth2ErrorResponse(UnauthorizedClient, description)
case "unsupported_grant_type" => OAuth2ErrorResponse(UnsupportedGrantType, description)
case "invalid_scope" => OAuth2ErrorResponse(InvalidScope, description)
case unknown => UnknownOAuth2Error(unknown, description)
case "invalid_request" => OAuth2ErrorResponse(InvalidRequest, description, cause)
case "invalid_client" => OAuth2ErrorResponse(InvalidClient, description, cause)
case "invalid_grant" => OAuth2ErrorResponse(InvalidGrant, description, cause)
case "unauthorized_client" => OAuth2ErrorResponse(UnauthorizedClient, description, cause)
case "unsupported_grant_type" => OAuth2ErrorResponse(UnsupportedGrantType, description, cause)
case "invalid_scope" => OAuth2ErrorResponse(InvalidScope, description, cause)
case unknown => UnknownOAuth2Error(unknown, description, cause)
}
}

}

private[oauth2] def responseWithCommonError[A](implicit decoder: Decoder[A]): ResponseAs[Either[Error, A], Any] =
asJson[A].mapWithMetadata { case (either, meta) =>
either match {
case Left(HttpError(response, statusCode)) if statusCode.isClientError =>
decode[OAuth2Error](response)
.fold(error => Error.HttpClientError(statusCode, DeserializationException(response, error)).asLeft[A], _.asLeft[A])
case Left(sttpError) => Left(Error.HttpClientError(meta.code, sttpError))
case Right(value) => value.asRight[Error]
}
private[oauth2] def responseWithCommonError[A](implicit decoder: Decoder[A]): ResponseAs[Either[Error[Throwable], A], Any] =
asJson[A].mapWithMetadata {
case (either, meta) =>
either match {
case Left(HttpError(response, statusCode)) if statusCode.isClientError =>
decode[OAuth2Error[Throwable]](response)(errorDecoder(HttpError(response, statusCode)))
.fold(error => Error.HttpClientError(statusCode, DeserializationException(response, error)).asLeft[A], _.asLeft[A])
case Left(sttpError) => Left(Error.HttpClientError(meta.code, sttpError))
case Right(value) => value.asRight[Error[Throwable]]
}
}

final case class OAuth2Exception(error: Error) extends Exception(error.getMessage, error)
final implicit class ErrorToException[E <: Throwable](error: Error[E]) {
def toException: Throwable = OAuth2Exception(error)
}

final case class ParsingException(msg: String) extends Exception(msg)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,20 @@ package com.ocadotechnology.sttp.oauth2

import com.ocadotechnology.sttp.oauth2.common.Scope
import com.ocadotechnology.sttp.oauth2.common.Error

import org.scalatest.wordspec.AnyWordSpec
import org.scalatest.matchers.should.Matchers
import sttp.model.Uri
import sttp.client3.testing._
import sttp.monad.TryMonad

import scala.util.Try
import eu.timepit.refined.types.string.NonEmptyString
import eu.timepit.refined.auto._
import org.scalatest.TryValues
import org.scalatest.EitherValues
import sttp.model.StatusCode
import sttp.model.Method
import sttp.client3.Request
import sttp.client3.{HttpError, Request}

import scala.concurrent.duration._

Expand Down Expand Up @@ -72,20 +72,19 @@ class ClientCredentialsSpec extends AnyWordSpec with Matchers with TryValues wit

oauth2Errors.foreach { case (errorKey, errorDescription, statusCode, error) =>
s"support $errorKey OAuth2 error" in {
val body = s"""
|{
|"error":"$errorKey",
|"error_description":"$errorDescription"
|}
|""".stripMargin

val testingBackend = SttpBackendStub(TryMonad)
.whenRequestMatches(validTokenRequest)
.thenRespond(
s"""
{
"error":"$errorKey",
"error_description":"$errorDescription"
}
""",
statusCode
)

requestToken(testingBackend).success.value.left.value shouldBe Error.OAuth2ErrorResponse(error, Some(errorDescription))
.thenRespond(body, statusCode)

requestToken(testingBackend).success.value.left.value shouldBe
Error.OAuth2ErrorResponse(error, Some(errorDescription), HttpError(body, statusCode))
}
}

Expand Down Expand Up @@ -133,20 +132,19 @@ class ClientCredentialsSpec extends AnyWordSpec with Matchers with TryValues wit

oauth2Errors.foreach { case (errorKey, errorDescription, statusCode, error) =>
s"support $errorKey OAuth2 error" in {
val body = s"""
|{
|"error":"$errorKey",
|"error_description":"$errorDescription"
|}
|""".stripMargin

val testingBackend = SttpBackendStub(TryMonad)
.whenRequestMatches(validIntrospectRequest)
.thenRespond(
s"""
{
"error":"$errorKey",
"error_description":"$errorDescription"
}
""",
statusCode
)

introspectToken(testingBackend).success.value.left.value shouldBe Error.OAuth2ErrorResponse(error, Some(errorDescription))
.thenRespond(body, statusCode)

introspectToken(testingBackend).success.value.left.value shouldBe
Error.OAuth2ErrorResponse(error, Some(errorDescription), HttpError(body, statusCode))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package com.ocadotechnology.sttp.oauth2
import com.ocadotechnology.sttp.oauth2.ClientCredentialsToken.AccessTokenResponse
import com.ocadotechnology.sttp.oauth2.common.Error.OAuth2ErrorResponse.InvalidClient
import com.ocadotechnology.sttp.oauth2.common._
import io.circe.Decoder
import io.circe.DecodingFailure
import io.circe.literal._
import org.scalatest.EitherValues
Expand All @@ -15,6 +16,13 @@ import scala.concurrent.duration._

class ClientCredentialsTokenDeserializationSpec extends AnyFlatSpec with Matchers with EitherValues {


private[oauth2] implicit val bearerTokenResponseDecoder: Decoder[Either[OAuth2Error[Unit], AccessTokenResponse]] =
circe.eitherOrFirstError[AccessTokenResponse, OAuth2Error[Unit]](
Decoder[AccessTokenResponse],
Decoder[OAuth2Error[Unit]](common.Error.errorDecoder(()))
)

"token response JSON" should "be deserialized to proper response" in {
val json =
// language=JSON
Expand All @@ -27,7 +35,7 @@ class ClientCredentialsTokenDeserializationSpec extends AnyFlatSpec with Matcher
"token_type": "Bearer"
}"""

val response = json.as[Either[OAuth2Error, AccessTokenResponse]]
val response = json.as[Either[OAuth2Error[Unit], AccessTokenResponse]]
response shouldBe Right(
Right(
ClientCredentialsToken.AccessTokenResponse(
Expand All @@ -52,7 +60,7 @@ class ClientCredentialsTokenDeserializationSpec extends AnyFlatSpec with Matcher
"token_type": "VeryBadType"
}"""

json.as[Either[OAuth2Error, AccessTokenResponse]].left.value shouldBe a[DecodingFailure]
json.as[Either[OAuth2Error[Unit], AccessTokenResponse]].left.value shouldBe a[DecodingFailure]
}

"JSON with error" should "be deserialized to proper type" in {
Expand All @@ -64,8 +72,8 @@ class ClientCredentialsTokenDeserializationSpec extends AnyFlatSpec with Matcher
"error_uri": "https://pandasso.pages.tech.lastmile.com/documentation/support/panda-errors/token/#invalid_client_client_invalid"
}"""

json.as[Either[OAuth2Error, AccessTokenResponse]] shouldBe Right(
Left(OAuth2ErrorResponse(InvalidClient, Some("Client is missing or invalid.")))
json.as[Either[OAuth2Error[Unit], AccessTokenResponse]] shouldBe Right(
Left(OAuth2ErrorResponse(InvalidClient, Some("Client is missing or invalid."), ()))
)
}

Expand All @@ -76,8 +84,8 @@ class ClientCredentialsTokenDeserializationSpec extends AnyFlatSpec with Matcher
"error": "invalid_client"
}"""

json.as[Either[OAuth2Error, AccessTokenResponse]] shouldBe Right(
Left(OAuth2ErrorResponse(InvalidClient, None))
json.as[Either[OAuth2Error[Unit], AccessTokenResponse]] shouldBe Right(
Left(OAuth2ErrorResponse(InvalidClient, None, ()))
)
}

Expand Down
Loading