Skip to content

Commit

Permalink
only one xray TracingInterceptor should be effective
Browse files Browse the repository at this point in the history
  • Loading branch information
wangzlei committed Mar 5, 2024
1 parent 0cc3e14 commit 7e27e98
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ public class TracingInterceptor implements ExecutionInterceptor {
// TODO(anuraaga): Make private in next major version and rename.
public static final ExecutionAttribute<Subsegment> entityKey = new ExecutionAttribute("AWS X-Ray Entity");

// Make sure only one xray interceptor takes effect.
private static final ExecutionAttribute<TracingInterceptor> XRAY_INTERCEPTOR_KEY = new ExecutionAttribute("AWS X-Ray Interceptor");

private static final Log logger = LogFactory.getLog(TracingInterceptor.class);

private static final ObjectMapper MAPPER = new ObjectMapper()
Expand Down Expand Up @@ -239,8 +242,26 @@ private HashMap<String, Object> extractResponseParameters(
return parameters;
}

private boolean isDuplicateInterceptor(ExecutionAttributes executionAttributes) {
if (executionAttributes.getAttribute(XRAY_INTERCEPTOR_KEY) == null) {
executionAttributes.putAttribute(XRAY_INTERCEPTOR_KEY, this);
return false;
}

if (executionAttributes.getAttribute(XRAY_INTERCEPTOR_KEY) != this) {
return true;
}

return false;
}

@Override
public void beforeExecution(Context.BeforeExecution context, ExecutionAttributes executionAttributes) {
if (isDuplicateInterceptor(executionAttributes)) {
logger.debug("X-Ray TracingInterceptor already exists.");
return;
}

AWSXRayRecorder recorder = getRecorder();
Entity origin = recorder.getTraceEntity();

Expand All @@ -265,19 +286,26 @@ public void beforeExecution(Context.BeforeExecution context, ExecutionAttributes
@Override
public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) {
SdkHttpRequest httpRequest = context.httpRequest();
if (isDuplicateInterceptor(executionAttributes)) {
return httpRequest;
}

Subsegment subsegment = executionAttributes.getAttribute(entityKey);
if (!subsegment.shouldPropagate()) {
return httpRequest;
}

return httpRequest.toBuilder().appendHeader(
return httpRequest.toBuilder().putHeader(
TraceHeader.HEADER_KEY,
TraceHeader.fromEntity(subsegment).toString()).build();
}

@Override
public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) {
if (isDuplicateInterceptor(executionAttributes)) {
return;
}

Subsegment subsegment = executionAttributes.getAttribute(entityKey);
if (subsegment == null) {
return;
Expand All @@ -293,6 +321,10 @@ public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttr

@Override
public void afterExecution(Context.AfterExecution context, ExecutionAttributes executionAttributes) {
if (isDuplicateInterceptor(executionAttributes)) {
return;
}

Subsegment subsegment = executionAttributes.getAttribute(entityKey);
if (subsegment == null) {
return;
Expand All @@ -307,6 +339,10 @@ public void afterExecution(Context.AfterExecution context, ExecutionAttributes e

@Override
public void onExecutionFailure(Context.FailedExecution context, ExecutionAttributes executionAttributes) {
if (isDuplicateInterceptor(executionAttributes)) {
return;
}

Subsegment subsegment = executionAttributes.getAttribute(entityKey);
if (subsegment == null) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -98,6 +99,44 @@ public void teardown() {
AWSXRay.endSegment();
}

@Test
public void testDuplicateInterceptor() throws Exception {
SdkHttpClient mockClient = mockSdkHttpClient(generateLambdaInvokeResponse(400));
LambdaClient client = lambdaClientDupInterceptor(mockClient);

Segment segment = AWSXRay.getCurrentSegment();
try {
client.invoke(InvokeRequest.builder()
.functionName("testFunctionName")
.build()
);
} catch (Exception e) {
// ignore SDK errors
} finally {
Assert.assertEquals(1, segment.getSubsegments().size());
Subsegment subsegment = segment.getSubsegments().get(0);
Map<String, Object> awsStats = subsegment.getAws();
@SuppressWarnings("unchecked")
Map<String, Object> httpResponseStats = (Map<String, Object>) subsegment.getHttp().get("response");
Cause cause = subsegment.getCause();

Assert.assertEquals("Invoke", awsStats.get("operation"));
Assert.assertEquals("testFunctionName", awsStats.get("function_name"));
Assert.assertEquals("1111-2222-3333-4444", awsStats.get("request_id"));
Assert.assertEquals("extended", awsStats.get("id_2"));
Assert.assertEquals("us-west-42", awsStats.get("region"));
Assert.assertEquals(0, awsStats.get("retries"));
Assert.assertEquals(2L, httpResponseStats.get("content_length"));
Assert.assertEquals(400, httpResponseStats.get("status"));
Assert.assertEquals(false, subsegment.isInProgress());
Assert.assertEquals(true, subsegment.isError());
Assert.assertEquals(false, subsegment.isThrottle());
Assert.assertEquals(false, subsegment.isFault());
Assert.assertEquals(1, cause.getExceptions().size());
Assert.assertEquals(true, cause.getExceptions().get(0).isRemote());
}
}

@Test
public void testResponseDescriptors() throws Exception {
String responseBody = "{\"LastEvaluatedTableName\":\"baz\",\"TableNames\":[\"foo\",\"bar\",\"baz\"]}";
Expand Down Expand Up @@ -490,7 +529,7 @@ public void testNoHeaderAddedWhenPropagationOff() {

interceptor.modifyHttpRequest(context, attributes);

verify(mockRequest.toBuilder(), never()).appendHeader(anyString(), anyString());
verify(mockRequest.toBuilder(), never()).putHeader(anyString(), anyString());
}

@Test
Expand Down Expand Up @@ -661,6 +700,22 @@ private static LambdaClient lambdaClient(SdkHttpClient mockClient) {
.build();
}

private static LambdaClient lambdaClientDupInterceptor(SdkHttpClient mockClient) {
return LambdaClient.builder()
.httpClient(mockClient)
.endpointOverride(URI.create("http://example.com"))
.region(Region.of("us-west-42"))
.credentialsProvider(StaticCredentialsProvider.create(
AwsSessionCredentials.create("key", "secret", "session")
))
.overrideConfiguration(ClientOverrideConfiguration.builder()
.addExecutionInterceptor(new TracingInterceptor())
.addExecutionInterceptor(new TracingInterceptor())
.build()
)
.build();
}

private static LambdaAsyncClient lambdaAsyncClient(SdkAsyncHttpClient mockClient) {
return LambdaAsyncClient.builder()
.httpClient(mockClient)
Expand Down

0 comments on commit 7e27e98

Please sign in to comment.