Skip to content

Commit

Permalink
Merge pull request #2396 from constantine2nd/develop
Browse files Browse the repository at this point in the history
Add function getCurrentConsumerViaMtls
  • Loading branch information
simonredfern authored Jun 17, 2024
2 parents a92f9d2 + 35ce523 commit dc056fc
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 4 deletions.
4 changes: 2 additions & 2 deletions obp-api/src/main/scala/code/api/util/APIUtil.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2976,8 +2976,8 @@ object APIUtil extends MdcLoggable with CustomJsonFormats{
Some(consent.jsonWebToken),
// Note: At this point we are getting the Consumer from the Consumer in the Consent.
// This may later be cross checked via the value in consumer_validation_method_for_consent.
// TODO: Get the source of truth for Consumer (e.g. CONSUMER_CERTIFICATE) as early as possible.
cc.copy(consumer = Consumers.consumers.vend.getConsumerByConsumerId(consent.consumerId))
// Get the source of truth for Consumer (e.g. CONSUMER_CERTIFICATE) as early as possible.
cc.copy(consumer = Consent.getCurrentConsumerViaMtls(callContext = cc))
)
case _ =>
JwtUtil.checkIfStringIsJWTValue(consentValue.getOrElse("")).isDefined match {
Expand Down
18 changes: 16 additions & 2 deletions obp-api/src/main/scala/code/api/util/ConsentUtil.scala
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,20 @@ object Consent extends MdcLoggable {
case _ => None
}
}
/**
* Purpose of this helper function is to get the Consumer via MTLS info i.e. PEM certificate.
* @return the boxed Consumer
*/
def getCurrentConsumerViaMtls(callContext: CallContext): Box[Consumer] = {
val clientCert: String = APIUtil.`getPSD2-CERT`(callContext.requestHeaders)
.getOrElse(SecureRandomUtil.csprng.nextLong().toString)
def removeBreakLines(input: String) = input
.replace("\n", "")
.replace("\r", "")
Consumers.consumers.vend.getConsumerByPemCertificate(clientCert).or(
Consumers.consumers.vend.getConsumerByPemCertificate(removeBreakLines(clientCert))
)
}

private def verifyHmacSignedJwt(jwtToken: String, c: MappedConsent): Boolean = {
JwtUtil.verifyHmacSignedJwt(jwtToken, c.secret)
Expand Down Expand Up @@ -357,7 +371,7 @@ object Consent extends MdcLoggable {
try {
val consent = net.liftweb.json.parse(jsonAsString).extract[ConsentJWT]
// Set Consumer into Call Context
val consumer = Consumers.consumers.vend.getConsumerByConsumerId(consent.aud)
val consumer = getCurrentConsumerViaMtls(callContext)
val updatedCallContext = callContext.copy(consumer = consumer)
checkConsent(consent, consentAsJwt, updatedCallContext) match { // Check is it Consent-JWT expired
case (Full(true)) => // OK
Expand Down Expand Up @@ -466,7 +480,7 @@ object Consent extends MdcLoggable {
Consents.consentProvider.vend.getConsentByConsentId(consentId) match {
case Full(storedConsent) =>
// Set Consumer into Call Context
val consumer = Consumers.consumers.vend.getConsumerByConsumerId(storedConsent.consumerId)
val consumer = getCurrentConsumerViaMtls(callContext)
val updatedCallContext = callContext.copy(consumer = consumer)
// This function MUST be called only once per call. I.e. it's date dependent
val (canBeUsed, currentCounterState) = checkFrequencyPerDay(storedConsent)
Expand Down
2 changes: 2 additions & 0 deletions obp-api/src/main/scala/code/consumer/ConsumerProvider.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ trait ConsumersProvider {
def getConsumerByPrimaryId(id: Long): Box[Consumer]
def getConsumerByConsumerKey(consumerKey: String): Box[Consumer]
def getConsumerByConsumerKeyFuture(consumerKey: String): Future[Box[Consumer]]
def getConsumerByPemCertificate(pem: String): Box[Consumer]
def getConsumerByConsumerId(consumerId: String): Box[Consumer]
def getConsumerByConsumerIdFuture(consumerId: String): Future[Box[Consumer]]
def getConsumersByUserIdFuture(userId: String): Future[List[Consumer]]
Expand Down Expand Up @@ -61,6 +62,7 @@ class RemotedataConsumersCaseClasses {
case class getConsumerByPrimaryId(id: Long)
case class getConsumerByConsumerKey(consumerKey: String)
case class getConsumerByConsumerKeyFuture(consumerKey: String)
case class getConsumerByPemCertificate(pem: String)
case class getConsumerByConsumerId(consumerId: String)
case class getConsumerByConsumerIdFuture(consumerId: String)
case class getConsumersByUserIdFuture(userId: String)
Expand Down
4 changes: 4 additions & 0 deletions obp-api/src/main/scala/code/model/OAuth.scala
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ object MappedConsumersProvider extends ConsumersProvider with MdcLoggable {
}
}

def getConsumerByPemCertificate(pem: String): Box[Consumer] = {
Consumer.find(By(Consumer.clientCertificate, pem))
}

def getConsumerByConsumerId(consumerId: String): Box[Consumer] = {
Consumer.find(By(Consumer.consumerId, consumerId))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ object RemotedataConsumers extends ObpActorInit with ConsumersProvider {
def getConsumerByPrimaryId(id: Long): Box[Consumer] = getValueFromFuture(
(actor ? cc.getConsumerByPrimaryId(id)).mapTo[Box[Consumer]]
)
def getConsumerByPemCertificate(pem: String): Box[Consumer] = getValueFromFuture(
(actor ? cc.getConsumerByPemCertificate(pem)).mapTo[Box[Consumer]]
)
def getConsumerByConsumerId(consumerId: String): Box[Consumer] = getValueFromFuture(
(actor ? cc.getConsumerByConsumerId(consumerId)).mapTo[Box[Consumer]]
)
Expand Down

0 comments on commit dc056fc

Please sign in to comment.