diff --git a/aws-xray-recorder-sdk-aws-sdk-v2/src/main/java/com/amazonaws/xray/interceptors/TracingInterceptor.java b/aws-xray-recorder-sdk-aws-sdk-v2/src/main/java/com/amazonaws/xray/interceptors/TracingInterceptor.java index de5a9ca5..271f9731 100644 --- a/aws-xray-recorder-sdk-aws-sdk-v2/src/main/java/com/amazonaws/xray/interceptors/TracingInterceptor.java +++ b/aws-xray-recorder-sdk-aws-sdk-v2/src/main/java/com/amazonaws/xray/interceptors/TracingInterceptor.java @@ -64,6 +64,9 @@ public class TracingInterceptor implements ExecutionInterceptor { // TODO(anuraaga): Make private in next major version and rename. public static final ExecutionAttribute entityKey = new ExecutionAttribute("AWS X-Ray Entity"); + // Make sure only one xray interceptor takes effect. + private static final ExecutionAttribute XRAY_INTERCEPTOR_KEY = new ExecutionAttribute("AWS X-Ray Interceptor"); + private static final Log logger = LogFactory.getLog(TracingInterceptor.class); private static final ObjectMapper MAPPER = new ObjectMapper() @@ -239,8 +242,26 @@ private HashMap extractResponseParameters( return parameters; } + private boolean isDuplicateInterceptor(ExecutionAttributes executionAttributes) { + if (executionAttributes.getAttribute(XRAY_INTERCEPTOR_KEY) == null) { + executionAttributes.putAttribute(XRAY_INTERCEPTOR_KEY, this); + return false; + } + + if (executionAttributes.getAttribute(XRAY_INTERCEPTOR_KEY) != this) { + return true; + } + + return false; + } + @Override public void beforeExecution(Context.BeforeExecution context, ExecutionAttributes executionAttributes) { + if (isDuplicateInterceptor(executionAttributes)) { + logger.debug("X-Ray TracingInterceptor already exists."); + return; + } + AWSXRayRecorder recorder = getRecorder(); Entity origin = recorder.getTraceEntity(); @@ -265,19 +286,26 @@ public void beforeExecution(Context.BeforeExecution context, ExecutionAttributes @Override public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) { SdkHttpRequest httpRequest = context.httpRequest(); + if (isDuplicateInterceptor(executionAttributes)) { + return httpRequest; + } Subsegment subsegment = executionAttributes.getAttribute(entityKey); if (!subsegment.shouldPropagate()) { return httpRequest; } - return httpRequest.toBuilder().appendHeader( + return httpRequest.toBuilder().putHeader( TraceHeader.HEADER_KEY, TraceHeader.fromEntity(subsegment).toString()).build(); } @Override public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) { + if (isDuplicateInterceptor(executionAttributes)) { + return; + } + Subsegment subsegment = executionAttributes.getAttribute(entityKey); if (subsegment == null) { return; @@ -293,6 +321,10 @@ public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttr @Override public void afterExecution(Context.AfterExecution context, ExecutionAttributes executionAttributes) { + if (isDuplicateInterceptor(executionAttributes)) { + return; + } + Subsegment subsegment = executionAttributes.getAttribute(entityKey); if (subsegment == null) { return; @@ -307,6 +339,10 @@ public void afterExecution(Context.AfterExecution context, ExecutionAttributes e @Override public void onExecutionFailure(Context.FailedExecution context, ExecutionAttributes executionAttributes) { + if (isDuplicateInterceptor(executionAttributes)) { + return; + } + Subsegment subsegment = executionAttributes.getAttribute(entityKey); if (subsegment == null) { return; diff --git a/aws-xray-recorder-sdk-aws-sdk-v2/src/test/java/com/amazonaws/xray/interceptors/TracingInterceptorTest.java b/aws-xray-recorder-sdk-aws-sdk-v2/src/test/java/com/amazonaws/xray/interceptors/TracingInterceptorTest.java index e2bdf8ba..18c69cca 100644 --- a/aws-xray-recorder-sdk-aws-sdk-v2/src/test/java/com/amazonaws/xray/interceptors/TracingInterceptorTest.java +++ b/aws-xray-recorder-sdk-aws-sdk-v2/src/test/java/com/amazonaws/xray/interceptors/TracingInterceptorTest.java @@ -17,6 +17,7 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -98,6 +99,44 @@ public void teardown() { AWSXRay.endSegment(); } + @Test + public void testDuplicateInterceptor() throws Exception { + SdkHttpClient mockClient = mockSdkHttpClient(generateLambdaInvokeResponse(400)); + LambdaClient client = lambdaClientDupInterceptor(mockClient); + + Segment segment = AWSXRay.getCurrentSegment(); + try { + client.invoke(InvokeRequest.builder() + .functionName("testFunctionName") + .build() + ); + } catch (Exception e) { + // ignore SDK errors + } finally { + Assert.assertEquals(1, segment.getSubsegments().size()); + Subsegment subsegment = segment.getSubsegments().get(0); + Map awsStats = subsegment.getAws(); + @SuppressWarnings("unchecked") + Map httpResponseStats = (Map) subsegment.getHttp().get("response"); + Cause cause = subsegment.getCause(); + + Assert.assertEquals("Invoke", awsStats.get("operation")); + Assert.assertEquals("testFunctionName", awsStats.get("function_name")); + Assert.assertEquals("1111-2222-3333-4444", awsStats.get("request_id")); + Assert.assertEquals("extended", awsStats.get("id_2")); + Assert.assertEquals("us-west-42", awsStats.get("region")); + Assert.assertEquals(0, awsStats.get("retries")); + Assert.assertEquals(2L, httpResponseStats.get("content_length")); + Assert.assertEquals(400, httpResponseStats.get("status")); + Assert.assertEquals(false, subsegment.isInProgress()); + Assert.assertEquals(true, subsegment.isError()); + Assert.assertEquals(false, subsegment.isThrottle()); + Assert.assertEquals(false, subsegment.isFault()); + Assert.assertEquals(1, cause.getExceptions().size()); + Assert.assertEquals(true, cause.getExceptions().get(0).isRemote()); + } + } + @Test public void testResponseDescriptors() throws Exception { String responseBody = "{\"LastEvaluatedTableName\":\"baz\",\"TableNames\":[\"foo\",\"bar\",\"baz\"]}"; @@ -490,7 +529,7 @@ public void testNoHeaderAddedWhenPropagationOff() { interceptor.modifyHttpRequest(context, attributes); - verify(mockRequest.toBuilder(), never()).appendHeader(anyString(), anyString()); + verify(mockRequest.toBuilder(), never()).putHeader(anyString(), anyString()); } @Test @@ -661,6 +700,22 @@ private static LambdaClient lambdaClient(SdkHttpClient mockClient) { .build(); } + private static LambdaClient lambdaClientDupInterceptor(SdkHttpClient mockClient) { + return LambdaClient.builder() + .httpClient(mockClient) + .endpointOverride(URI.create("http://example.com")) + .region(Region.of("us-west-42")) + .credentialsProvider(StaticCredentialsProvider.create( + AwsSessionCredentials.create("key", "secret", "session") + )) + .overrideConfiguration(ClientOverrideConfiguration.builder() + .addExecutionInterceptor(new TracingInterceptor()) + .addExecutionInterceptor(new TracingInterceptor()) + .build() + ) + .build(); + } + private static LambdaAsyncClient lambdaAsyncClient(SdkAsyncHttpClient mockClient) { return LambdaAsyncClient.builder() .httpClient(mockClient)