Skip to content

Commit

Permalink
PI-2342 Propagate trace context over SNS for distributed tracing (#3995)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcus-bcl authored Jul 8, 2024
1 parent 02c42b3 commit 0a5f52e
Show file tree
Hide file tree
Showing 48 changed files with 170 additions and 109 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package uk.gov.justice.digital.hmpps.telemetry

import com.microsoft.applicationinsights.TelemetryClient
import com.microsoft.applicationinsights.telemetry.TelemetryContext
import org.slf4j.LoggerFactory
import org.springframework.scheduling.annotation.Async
import org.springframework.stereotype.Service
import java.lang.Exception

@Service
class TelemetryService(private val telemetryClient: TelemetryClient = TelemetryClient()) {
Expand All @@ -27,4 +27,6 @@ class TelemetryService(private val telemetryClient: TelemetryClient = TelemetryC
log.debug("{} {} {}", exception.message, properties, metrics)
telemetryClient.trackException(exception, properties, metrics)
}

fun getContext(): TelemetryContext = telemetryClient.context
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
package uk.gov.justice.digital.hmpps.listener

import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.kotlin.jacksonTypeRef
import io.awspring.cloud.sqs.annotation.SqsListener
import io.awspring.cloud.sqs.listener.AsyncAdapterBlockingExecutionFailedException
import io.awspring.cloud.sqs.listener.ListenerExecutionFailedException
import io.opentelemetry.api.trace.SpanKind
import io.opentelemetry.instrumentation.annotations.SpanAttribute
import io.opentelemetry.instrumentation.annotations.WithSpan
import io.sentry.Sentry
import io.sentry.spring.jakarta.tracing.SentryTransaction
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import org.springframework.boot.autoconfigure.condition.ConditionalOnExpression
import org.springframework.context.annotation.Conditional
import org.springframework.dao.CannotAcquireLockException
Expand All @@ -17,36 +22,35 @@ import org.springframework.transaction.CannotCreateTransactionException
import org.springframework.transaction.UnexpectedRollbackException
import org.springframework.web.client.RestClientException
import uk.gov.justice.digital.hmpps.config.AwsCondition
import uk.gov.justice.digital.hmpps.message.Notification
import uk.gov.justice.digital.hmpps.messaging.NotificationHandler
import uk.gov.justice.digital.hmpps.retry.retry
import uk.gov.justice.digital.hmpps.telemetry.TelemetryMessagingExtensions.extractSpanContext
import uk.gov.justice.digital.hmpps.telemetry.TelemetryMessagingExtensions.startSpan
import java.util.concurrent.CompletionException

@Component
@Conditional(AwsCondition::class)
@ConditionalOnExpression("\${messaging.consumer.enabled:true} and '\${messaging.consumer.queue:}' != ''")
class AwsNotificationListener(
private val handler: NotificationHandler<*>
private val handler: NotificationHandler<*>,
private val objectMapper: ObjectMapper
) {
@SqsListener("\${messaging.consumer.queue}")
@SentryTransaction(operation = "messaging")
@WithSpan(kind = SpanKind.CONSUMER)
fun receive(message: String) {
try {
retry(
3,
listOf(
RestClientException::class,
CannotAcquireLockException::class,
ObjectOptimisticLockingFailureException::class,
CannotCreateTransactionException::class,
CannotGetJdbcConnectionException::class,
UnexpectedRollbackException::class
)
) { handler.handle(message) }
} catch (e: Throwable) {
Sentry.captureException(unwrapSqsExceptions(e))
throw e
@SentryTransaction(operation = "messaging")
@SqsListener("\${messaging.consumer.queue}")
fun receive(@SpanAttribute message: String) {
val attributes = objectMapper.readValue(message, jacksonTypeRef<Notification<String>>()).attributes
val span = attributes.extractSpanContext().startSpan(this::class.java.name, "receive", SpanKind.CONSUMER)
span.makeCurrent().use {
try {
retry(3, RETRYABLE_EXCEPTIONS) { handler.handle(message) }
} catch (e: Throwable) {
Sentry.captureException(unwrapSqsExceptions(e))
throw e
}
}
span.end()
}

fun unwrapSqsExceptions(e: Throwable): Throwable {
Expand All @@ -63,4 +67,16 @@ class AwsNotificationListener(
}
return cause
}

companion object {
private val log: Logger = LoggerFactory.getLogger(this::class.java)
val RETRYABLE_EXCEPTIONS = listOf(
RestClientException::class,
CannotAcquireLockException::class,
ObjectOptimisticLockingFailureException::class,
CannotCreateTransactionException::class,
CannotGetJdbcConnectionException::class,
UnexpectedRollbackException::class
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import com.fasterxml.jackson.annotation.JsonAnyGetter
import com.fasterxml.jackson.annotation.JsonAnySetter
import com.fasterxml.jackson.annotation.JsonIgnore
import com.fasterxml.jackson.annotation.JsonProperty
import java.util.UUID
import java.util.*

data class Notification<T>(
@JsonProperty("Message") val message: T,
Expand All @@ -21,9 +21,14 @@ data class MessageAttributes(
constructor(eventType: String) : this(mutableMapOf("eventType" to MessageAttribute("String", eventType)))

override operator fun get(key: String): MessageAttribute? = attributes[key]

operator fun set(key: String, value: MessageAttribute) {
attributes[key] = value
}

operator fun set(key: String, value: String) {
set(key, MessageAttribute("String", value))
}
}

data class MessageAttribute(@JsonProperty("Type") val type: String, @JsonProperty("Value") val value: String)
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package uk.gov.justice.digital.hmpps.publisher

import com.fasterxml.jackson.databind.ObjectMapper
import io.awspring.cloud.sqs.operations.SqsTemplate
import io.opentelemetry.api.trace.SpanKind
import io.opentelemetry.instrumentation.annotations.WithSpan
import org.springframework.beans.factory.annotation.Value
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty
import org.springframework.context.annotation.Conditional
Expand All @@ -10,6 +12,7 @@ import org.springframework.messaging.support.MessageBuilder
import org.springframework.stereotype.Component
import uk.gov.justice.digital.hmpps.config.AwsCondition
import uk.gov.justice.digital.hmpps.message.Notification
import uk.gov.justice.digital.hmpps.telemetry.TelemetryMessagingExtensions.withSpanContext
import java.util.concurrent.Semaphore

@Component
Expand All @@ -23,6 +26,8 @@ class QueuePublisher(
) : NotificationPublisher {

private val permit = Semaphore(limit, true)

@WithSpan(kind = SpanKind.PRODUCER)
override fun publish(notification: Notification<*>) {
notification.message?.also { _ ->
permit.acquire()
Expand All @@ -35,12 +40,7 @@ class QueuePublisher(
}

private fun Notification<*>.asMessage() = MessageBuilder.createMessage(
objectMapper.writeValueAsString(
Notification(
message = objectMapper.writeValueAsString(message),
attributes
)
),
MessageHeaders(attributes.map { it.key to it.value.value }.toMap())
objectMapper.writeValueAsString(Notification(objectMapper.writeValueAsString(message), attributes)),
MessageHeaders(attributes.map { it.key to it.value.value }.toMap()).withSpanContext()
)
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package uk.gov.justice.digital.hmpps.publisher

import io.awspring.cloud.sns.core.SnsTemplate
import io.opentelemetry.api.trace.SpanKind
import io.opentelemetry.instrumentation.annotations.WithSpan
import org.springframework.beans.factory.annotation.Value
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty
import org.springframework.context.annotation.Conditional
Expand All @@ -10,6 +12,7 @@ import org.springframework.messaging.support.MessageBuilder
import org.springframework.stereotype.Component
import uk.gov.justice.digital.hmpps.config.AwsCondition
import uk.gov.justice.digital.hmpps.message.Notification
import uk.gov.justice.digital.hmpps.telemetry.TelemetryMessagingExtensions.withSpanContext

@Primary
@Component
Expand All @@ -19,12 +22,13 @@ class TopicPublisher(
private val notificationTemplate: SnsTemplate,
@Value("\${messaging.producer.topic}") private val topic: String
) : NotificationPublisher {
@WithSpan(kind = SpanKind.PRODUCER)
override fun publish(notification: Notification<*>) {
notification.message?.let { message ->
notificationTemplate.convertAndSend(topic, message) { msg ->
MessageBuilder.createMessage(
msg.payload,
MessageHeaders(notification.attributes.map { it.key to it.value.value }.toMap())
MessageHeaders(notification.attributes.map { it.key to it.value.value }.toMap()).withSpanContext(),
)
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package uk.gov.justice.digital.hmpps.telemetry

import io.opentelemetry.api.GlobalOpenTelemetry
import io.opentelemetry.api.trace.Span
import io.opentelemetry.api.trace.SpanKind
import io.opentelemetry.context.Context
import io.opentelemetry.context.propagation.TextMapGetter
import org.springframework.messaging.MessageHeaders
import uk.gov.justice.digital.hmpps.message.HmppsDomainEvent
import uk.gov.justice.digital.hmpps.message.MessageAttributes
import uk.gov.justice.digital.hmpps.message.Notification

object TelemetryMessagingExtensions {
fun MessageHeaders.withSpanContext(): MessageHeaders {
val map = this.toMutableMap()
val context = Context.current().with(Span.current())
GlobalOpenTelemetry.getPropagators().textMapPropagator
.inject(context, map) { carrier, key, value -> carrier!![key] = value }
return MessageHeaders(map)
}

fun MessageAttributes.extractSpanContext(): Context {
val getter = object : TextMapGetter<MessageAttributes> {
override fun keys(carrier: MessageAttributes) = carrier.keys
override fun get(carrier: MessageAttributes?, key: String) = carrier?.get(key)?.value
}
return GlobalOpenTelemetry.getPropagators().textMapPropagator.extract(Context.current(), this, getter)
}

fun Context.startSpan(scopeName: String, spanName: String, spanKind: SpanKind = SpanKind.INTERNAL): Span {
val tracer = GlobalOpenTelemetry.getTracer(scopeName)
return tracer.spanBuilder(spanName).setParent(this).setSpanKind(spanKind).startSpan()
}

fun TelemetryService.hmppsEventReceived(hmppsEvent: HmppsDomainEvent) {
trackEvent(
"NotificationReceived",
mapOf("eventType" to hmppsEvent.eventType) +
(hmppsEvent.detailUrl?.let { mapOf("detailUrl" to it) } ?: mapOf()) +
(hmppsEvent.personReference.identifiers.associate { Pair(it.type, it.value) })
)
}

fun <T> TelemetryService.notificationReceived(notification: Notification<T>) {
if (notification.message is HmppsDomainEvent) {
hmppsEventReceived(notification.message)
} else {
trackEvent("NotificationReceived", notification.eventType?.let { mapOf("eventType" to it) } ?: mapOf())
}
}
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
package uk.gov.justice.digital.hmpps.listener

import com.fasterxml.jackson.core.type.TypeReference
import com.fasterxml.jackson.databind.ObjectMapper
import io.awspring.cloud.sqs.listener.AsyncAdapterBlockingExecutionFailedException
import io.awspring.cloud.sqs.listener.ListenerExecutionFailedException
import io.sentry.Sentry
import org.hamcrest.CoreMatchers.equalTo
import org.hamcrest.MatcherAssert.assertThat
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import org.junit.jupiter.api.extension.ExtendWith
import org.mockito.InjectMocks
import org.mockito.Mock
import org.mockito.Mockito.mockStatic
import org.mockito.junit.jupiter.MockitoExtension
import org.mockito.kotlin.any
import org.mockito.kotlin.verify
import org.mockito.kotlin.whenever
import org.springframework.messaging.support.GenericMessage
import uk.gov.justice.digital.hmpps.message.Notification
import uk.gov.justice.digital.hmpps.messaging.NotificationHandler
import java.util.concurrent.CompletionException

Expand All @@ -23,9 +28,18 @@ class AwsNotificationListenerTest {
@Mock
lateinit var handler: NotificationHandler<Any>

@Mock
lateinit var objectMapper: ObjectMapper

@InjectMocks
lateinit var listener: AwsNotificationListener

@BeforeEach
fun setUp() {
whenever(objectMapper.readValue(any<String>(), any<TypeReference<Notification<String>>>()))
.thenReturn(Notification("message"))
}

@Test
fun `messages are dispatched to handler`() {
listener.receive("message")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ package uk.gov.justice.digital.hmpps.telemetry

import com.microsoft.applicationinsights.TelemetryClient
import org.hamcrest.MatcherAssert.assertThat
import org.hamcrest.Matchers.equalTo
import org.hamcrest.Matchers.hasProperty
import org.hamcrest.Matchers.not
import org.hamcrest.Matchers.*
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.extension.ExtendWith
Expand All @@ -14,15 +12,13 @@ import org.mockito.junit.jupiter.MockitoExtension
import org.mockito.kotlin.check
import org.mockito.kotlin.eq
import org.mockito.kotlin.verify
import uk.gov.justice.digital.hmpps.message.HmppsDomainEvent
import uk.gov.justice.digital.hmpps.message.MessageAttributes
import uk.gov.justice.digital.hmpps.message.Notification
import uk.gov.justice.digital.hmpps.message.PersonIdentifier
import uk.gov.justice.digital.hmpps.message.PersonReference
import uk.gov.justice.digital.hmpps.message.*
import uk.gov.justice.digital.hmpps.telemetry.TelemetryMessagingExtensions.hmppsEventReceived
import uk.gov.justice.digital.hmpps.telemetry.TelemetryMessagingExtensions.notificationReceived
import java.time.ZonedDateTime

@ExtendWith(MockitoExtension::class)
class TelemetryServiceTest {
class TelemetryMessagingExtensionsTest {

@Mock
private lateinit var telemetryClient: TelemetryClient
Expand Down Expand Up @@ -52,7 +48,7 @@ class TelemetryServiceTest {
)

verify(telemetryClient).trackEvent(
eq("SOME_SPECIAL_EVENT_RECEIVED"),
eq("NotificationReceived"),
check {
assertThat(it["eventType"], equalTo(eventType))
assertThat(it["detailUrl"], equalTo(detailUrl))
Expand All @@ -73,7 +69,7 @@ class TelemetryServiceTest {
)

verify(telemetryClient).trackEvent(
eq("SOME_SPECIAL_EVENT_RECEIVED"),
eq("NotificationReceived"),
check { assertThat(it, not(hasProperty("detailUrl"))) },
anyMap()
)
Expand All @@ -88,13 +84,13 @@ class TelemetryServiceTest {
)
)

verify(telemetryClient).trackEvent(eq("TEST_EVENT_RECEIVED"), anyMap(), anyMap())
verify(telemetryClient).trackEvent(eq("NotificationReceived"), anyMap(), anyMap())
}

@Test
fun `handles events with no event type`() {
telemetryService.notificationReceived(Notification(message = "this is a string"))

verify(telemetryClient).trackEvent(eq("UNKNOWN_EVENT_RECEIVED"), anyMap(), anyMap())
verify(telemetryClient).trackEvent(eq("NotificationReceived"), anyMap(), anyMap())
}
}
Loading

0 comments on commit 0a5f52e

Please sign in to comment.