Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TracingInterceptor should take effect only one time #399

Merged
merged 2 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ public class TracingInterceptor implements ExecutionInterceptor {
// TODO(anuraaga): Make private in next major version and rename.
public static final ExecutionAttribute<Subsegment> entityKey = new ExecutionAttribute("AWS X-Ray Entity");

// Make sure only one xray interceptor takes effect.
private static final ExecutionAttribute<TracingInterceptor> XRAY_INTERCEPTOR_KEY =
new ExecutionAttribute("AWS X-Ray Interceptor");

private static final Log logger = LogFactory.getLog(TracingInterceptor.class);

private static final ObjectMapper MAPPER = new ObjectMapper()
Expand Down Expand Up @@ -239,8 +243,26 @@ private HashMap<String, Object> extractResponseParameters(
return parameters;
}

private boolean isDuplicateInterceptor(ExecutionAttributes executionAttributes) {
if (executionAttributes.getAttribute(XRAY_INTERCEPTOR_KEY) == null) {
executionAttributes.putAttribute(XRAY_INTERCEPTOR_KEY, this);
return false;
}

if (executionAttributes.getAttribute(XRAY_INTERCEPTOR_KEY) != this) {
return true;
}

return false;
}

@Override
public void beforeExecution(Context.BeforeExecution context, ExecutionAttributes executionAttributes) {
if (isDuplicateInterceptor(executionAttributes)) {
logger.debug("X-Ray TracingInterceptor already exists.");
return;
}

AWSXRayRecorder recorder = getRecorder();
Entity origin = recorder.getTraceEntity();

Expand All @@ -265,19 +287,26 @@ public void beforeExecution(Context.BeforeExecution context, ExecutionAttributes
@Override
public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) {
SdkHttpRequest httpRequest = context.httpRequest();
if (isDuplicateInterceptor(executionAttributes)) {
return httpRequest;
}

Subsegment subsegment = executionAttributes.getAttribute(entityKey);
if (!subsegment.shouldPropagate()) {
return httpRequest;
}

return httpRequest.toBuilder().appendHeader(
return httpRequest.toBuilder().putHeader(
TraceHeader.HEADER_KEY,
TraceHeader.fromEntity(subsegment).toString()).build();
}

@Override
public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) {
if (isDuplicateInterceptor(executionAttributes)) {
return;
}

Subsegment subsegment = executionAttributes.getAttribute(entityKey);
if (subsegment == null) {
return;
Expand All @@ -293,6 +322,10 @@ public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttr

@Override
public void afterExecution(Context.AfterExecution context, ExecutionAttributes executionAttributes) {
if (isDuplicateInterceptor(executionAttributes)) {
return;
}

Subsegment subsegment = executionAttributes.getAttribute(entityKey);
if (subsegment == null) {
return;
Expand All @@ -307,6 +340,10 @@ public void afterExecution(Context.AfterExecution context, ExecutionAttributes e

@Override
public void onExecutionFailure(Context.FailedExecution context, ExecutionAttributes executionAttributes) {
if (isDuplicateInterceptor(executionAttributes)) {
return;
}

Subsegment subsegment = executionAttributes.getAttribute(entityKey);
if (subsegment == null) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,44 @@ public void teardown() {
AWSXRay.endSegment();
}

@Test
public void testDuplicateInterceptor() throws Exception {
SdkHttpClient mockClient = mockSdkHttpClient(generateLambdaInvokeResponse(400));
LambdaClient client = lambdaClientDupInterceptor(mockClient);

Segment segment = AWSXRay.getCurrentSegment();
try {
client.invoke(InvokeRequest.builder()
.functionName("testFunctionName")
.build()
);
} catch (Exception e) {
// ignore SDK errors
} finally {
Assert.assertEquals(1, segment.getSubsegments().size());
Subsegment subsegment = segment.getSubsegments().get(0);
Map<String, Object> awsStats = subsegment.getAws();
@SuppressWarnings("unchecked")
Map<String, Object> httpResponseStats = (Map<String, Object>) subsegment.getHttp().get("response");
Cause cause = subsegment.getCause();

Assert.assertEquals("Invoke", awsStats.get("operation"));
Assert.assertEquals("testFunctionName", awsStats.get("function_name"));
Assert.assertEquals("1111-2222-3333-4444", awsStats.get("request_id"));
Assert.assertEquals("extended", awsStats.get("id_2"));
Assert.assertEquals("us-west-42", awsStats.get("region"));
Assert.assertEquals(0, awsStats.get("retries"));
Assert.assertEquals(2L, httpResponseStats.get("content_length"));
Assert.assertEquals(400, httpResponseStats.get("status"));
Assert.assertEquals(false, subsegment.isInProgress());
Assert.assertEquals(true, subsegment.isError());
Assert.assertEquals(false, subsegment.isThrottle());
Assert.assertEquals(false, subsegment.isFault());
Assert.assertEquals(1, cause.getExceptions().size());
Assert.assertEquals(true, cause.getExceptions().get(0).isRemote());
}
}

@Test
public void testResponseDescriptors() throws Exception {
String responseBody = "{\"LastEvaluatedTableName\":\"baz\",\"TableNames\":[\"foo\",\"bar\",\"baz\"]}";
Expand Down Expand Up @@ -490,7 +528,7 @@ public void testNoHeaderAddedWhenPropagationOff() {

interceptor.modifyHttpRequest(context, attributes);

verify(mockRequest.toBuilder(), never()).appendHeader(anyString(), anyString());
verify(mockRequest.toBuilder(), never()).putHeader(anyString(), anyString());
}

@Test
Expand Down Expand Up @@ -661,6 +699,22 @@ private static LambdaClient lambdaClient(SdkHttpClient mockClient) {
.build();
}

private static LambdaClient lambdaClientDupInterceptor(SdkHttpClient mockClient) {
return LambdaClient.builder()
.httpClient(mockClient)
.endpointOverride(URI.create("http://example.com"))
.region(Region.of("us-west-42"))
.credentialsProvider(StaticCredentialsProvider.create(
AwsSessionCredentials.create("key", "secret", "session")
))
.overrideConfiguration(ClientOverrideConfiguration.builder()
.addExecutionInterceptor(new TracingInterceptor())
.addExecutionInterceptor(new TracingInterceptor())
.build()
)
.build();
}

private static LambdaAsyncClient lambdaAsyncClient(SdkAsyncHttpClient mockClient) {
return LambdaAsyncClient.builder()
.httpClient(mockClient)
Expand Down
Loading