Skip to content

Commit

Permalink
mdc-utils: Make LoggerStringWriter thread friendly (#2771)
Browse files Browse the repository at this point in the history
Motivation:

We have seen some flakyness in `HttpMessageDiscardWatchdogServiceFilterTest`
that may be related to thread safety for the test log appender.
Specifically,
- The sink, a `StringWriter`, is globally shared so it can be reset by
  concurrently running tests.
- The underlying `StringWriter` itself isn't thread safe and it's possible
  for it to be concurrent written to and read from potentially leading to
  state corruption.

Modifications:

- Make a `LoggerStringWriter` usable as an instance instead of as static
  functions. This gives each test suite it's own state that it easier to control.
- Make a thread-safe proxy for the StringWriter so it's safe to write
  and read from concurrently.
  • Loading branch information
bryce-anderson authored Dec 7, 2023
1 parent e3a10f8 commit 112fb07
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ final class HttpMessageDiscardWatchdogClientFilterTest {

private static final Logger LOGGER = LoggerFactory.getLogger(HttpMessageDiscardWatchdogClientFilterTest.class);

private final LoggerStringWriter loggerStringWriter = new LoggerStringWriter();

@RegisterExtension
static final ExecutionContextExtension SERVER_CTX =
ExecutionContextExtension.cached("server-io", "server-executor")
Expand All @@ -68,12 +70,12 @@ final class HttpMessageDiscardWatchdogClientFilterTest {

@BeforeEach
public void setup() {
LoggerStringWriter.reset();
loggerStringWriter.reset();
}

@AfterEach
public void tearDown() {
LoggerStringWriter.remove();
loggerStringWriter.remove();
}

/**
Expand Down Expand Up @@ -132,7 +134,7 @@ protected Single<StreamingHttpResponse> request(final StreamingHttpRequester del
}
}

String output = LoggerStringWriter.stableAccumulated(1000);
String output = loggerStringWriter.stableAccumulated(1000);
LOGGER.info("Logger output: {}", output);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,16 @@ final class HttpMessageDiscardWatchdogServiceFilterTest {
ExecutionContextExtension.cached("client-io", "client-executor")
.setClassLevel(true);

private final LoggerStringWriter loggerStringWriter = new LoggerStringWriter();

@BeforeEach
public void setup() {
LoggerStringWriter.reset();
loggerStringWriter.reset();
}

@AfterEach
public void tearDown() {
LoggerStringWriter.remove();
loggerStringWriter.remove();
}

@ParameterizedTest(name = "{displayName} [{index}] transformer={0}")
Expand All @@ -91,7 +93,7 @@ public Single<StreamingHttpResponse> handle(final HttpServiceContext ctx,
assertEquals(0, response.payloadBody().readableBytes());
}

String output = LoggerStringWriter.stableAccumulated(CI ? 5000 : 1000);
String output = loggerStringWriter.stableAccumulated(CI ? 5000 : 1000);
if (!output.contains("Discovered un-drained HTTP response message body which " +
"has been dropped by user code")) {
throw new AssertionError("Logs didn't contain the expected output:\n" + output);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,13 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.StringWriter;
import java.io.Writer;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeoutException;
import javax.annotation.Nullable;

import static java.lang.System.nanoTime;
import static java.lang.Thread.currentThread;
Expand All @@ -42,25 +44,24 @@

public final class LoggerStringWriter {
private static final Logger LOGGER = LoggerFactory.getLogger(LoggerStringWriter.class);
private static final String APPENDER_NAME = "writer";
@Nullable
private static StringWriter logStringWriter;

private LoggerStringWriter() {
// no instances.
}
// protected by synchronization on `this`.
private ConcurrentStringWriter writer;

/**
* Clear the content of the {@link #accumulated()}.
* <p>
* Note that the underlying logger may be initialized by this method and it must always be
* followed up with a {@link #remove()} call at the end of tests to clean up logger state.
*/
public static void reset() {
getStringWriter().getBuffer().setLength(0);
public void reset() {
getStringWriter().reset();
}

/**
* Remove the underlying in-memory log appender.
*/
public static void remove() {
public void remove() {
removeStringWriter();
}

Expand All @@ -69,7 +70,7 @@ public static void remove() {
*
* @return the accumulated content that has been logged.
*/
public static String accumulated() {
public String accumulated() {
return getStringWriter().toString();
}

Expand All @@ -83,7 +84,7 @@ public static String accumulated() {
* @throws TimeoutException If the {@code totalWaitTimeMillis} duration has been exceeded and the
* {@link #accumulated()} has not yet stabilize.
*/
public static String stableAccumulated(int totalWaitTimeMillis) throws InterruptedException, TimeoutException {
public String stableAccumulated(int totalWaitTimeMillis) throws InterruptedException, TimeoutException {
return stableAccumulated(totalWaitTimeMillis, 10);
}

Expand All @@ -98,7 +99,7 @@ public static String stableAccumulated(int totalWaitTimeMillis) throws Interrupt
* @throws TimeoutException If the {@code totalWaitTimeMillis} duration has been exceeded and the
* {@link #accumulated()} has not yet stabilize.
*/
public static String stableAccumulated(int totalWaitTimeMillis, final long sleepDurationMs)
public String stableAccumulated(int totalWaitTimeMillis, final long sleepDurationMs)
throws InterruptedException, TimeoutException {
// We force a unique log entry, and wait for it to ensure the content from the local thread has been flushed.
String forcedLogEntry = "forced log entry to help for flush on current thread " +
Expand Down Expand Up @@ -157,29 +158,28 @@ public static void assertContainsMdcPair(String value, String expectedLabel, Str
assertThat(value.substring(beginIndex, beginIndex + expectedValue.length()), is(expectedValue));
}

private static synchronized StringWriter getStringWriter() {
if (logStringWriter == null) {
private synchronized ConcurrentStringWriter getStringWriter() {
if (writer == null) {
final LoggerContext context = (LoggerContext) LogManager.getContext(false);
logStringWriter = addWriterAppender(context, DEBUG);
writer = addWriterAppender(context, DEBUG);
}
return logStringWriter;
return writer;
}

private static synchronized void removeStringWriter() {
if (logStringWriter == null) {
private synchronized void removeStringWriter() {
if (writer == null) {
return;
}
removeWriterAppender((LoggerContext) LogManager.getContext(false));
logStringWriter = null;
removeWriterAppender(writer, (LoggerContext) LogManager.getContext(false));
writer = null;
}

private static StringWriter addWriterAppender(final LoggerContext context, Level level) {
private static ConcurrentStringWriter addWriterAppender(final LoggerContext context, Level level) {
final Configuration config = context.getConfiguration();
final StringWriter writer = new StringWriter();

final ConcurrentStringWriter writer = new ConcurrentStringWriter();
final Map.Entry<String, Appender> existing = config.getAppenders().entrySet().iterator().next();
final WriterAppender writerAppender = WriterAppender.newBuilder()
.setName(APPENDER_NAME)
.setName(writer.name)
.setLayout(existing.getValue().getLayout())
.setTarget(writer)
.build();
Expand All @@ -190,16 +190,58 @@ private static StringWriter addWriterAppender(final LoggerContext context, Level
return writer;
}

private static void removeWriterAppender(final LoggerContext context) {
private static void removeWriterAppender(ConcurrentStringWriter writer, final LoggerContext context) {
final Configuration config = context.getConfiguration();
LoggerConfig rootConfig = config.getRootLogger();
// Stopping the logger is subject to race conditions where logging during cleanup on global executor
// may still try to log and raise an error.
WriterAppender writerAppender = (WriterAppender) rootConfig.getAppenders().get(APPENDER_NAME);
WriterAppender writerAppender = (WriterAppender) rootConfig.getAppenders().get(writer.name);
if (writerAppender != null) {
writerAppender.stop(0, NANOSECONDS);
}
// Don't remove directly from map, because the root logger also cleans up filters.
rootConfig.removeAppender(APPENDER_NAME);
rootConfig.removeAppender(writer.name);
}

// This is essentially just a thread safe `StringAppender` with a unique `String name` field to use
// as a map key.
private static final class ConcurrentStringWriter extends Writer {

private static final String APPENDER_NAME_PREFIX = "writer";

private final StringWriter stringWriter = new StringWriter();

// We use uuid as a way to give the appender a unique name. We could try and do it with the current
// thread name but it's hard to say if that will be unique but it is certain to be ugly.
final String name = APPENDER_NAME_PREFIX + '_' + UUID.randomUUID();
@Override
public void write(char[] cbuf, int off, int len) throws IOException {
synchronized (stringWriter) {
stringWriter.write(cbuf, off, len);
}
}

@Override
public void flush() {
// this is a no-op for `StringWriter`
}

@Override
public void close() {
// this is a no-op for `StringWriter`
}

@Override
public String toString() {
synchronized (stringWriter) {
return stringWriter.toString();
}
}

void reset() {
synchronized (stringWriter) {
stringWriter.getBuffer().setLength(0);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
import static io.servicetalk.concurrent.api.Single.succeeded;
import static io.servicetalk.http.netty.HttpClients.forSingleAddress;
import static io.servicetalk.log4j2.mdc.utils.LoggerStringWriter.assertContainsMdcPair;
import static io.servicetalk.log4j2.mdc.utils.LoggerStringWriter.stableAccumulated;
import static io.servicetalk.opentelemetry.http.TestUtils.SPAN_STATE_SERIALIZER;
import static io.servicetalk.opentelemetry.http.TestUtils.TRACING_TEST_LOG_LINE_PREFIX;
import static io.servicetalk.opentelemetry.http.TestUtils.TestTracingClientLoggerFilter;
Expand All @@ -62,17 +61,19 @@ class OpenTelemetryHttpRequestFilterTest {

private static final Logger LOGGER = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());

private final LoggerStringWriter loggerStringWriter = new LoggerStringWriter();

@RegisterExtension
static final OpenTelemetryExtension otelTesting = OpenTelemetryExtension.create();

@BeforeEach
public void setup() {
LoggerStringWriter.reset();
loggerStringWriter.reset();
}

@AfterEach
public void tearDown() {
LoggerStringWriter.remove();
loggerStringWriter.remove();
}

@Test
Expand All @@ -86,7 +87,7 @@ void testInjectWithNoParent() throws Exception {
HttpResponse response = client.request(client.get(requestUrl)).toFuture().get();
TestSpanState serverSpanState = response.payloadBody(SPAN_STATE_SERIALIZER);

verifyTraceIdPresentInLogs(stableAccumulated(1000), requestUrl,
verifyTraceIdPresentInLogs(loggerStringWriter.stableAccumulated(1000), requestUrl,
serverSpanState.getTraceId(), serverSpanState.getSpanId(),
TRACING_TEST_LOG_LINE_PREFIX);
assertThat(otelTesting.getSpans()).hasSize(1);
Expand Down Expand Up @@ -116,7 +117,7 @@ void testInjectWithAParent() throws Exception {
HttpResponse response = client.request(client.get(requestUrl)).toFuture().get();
TestSpanState serverSpanState = response.payloadBody(SPAN_STATE_SERIALIZER);

verifyTraceIdPresentInLogs(stableAccumulated(1000), requestUrl,
verifyTraceIdPresentInLogs(loggerStringWriter.stableAccumulated(1000), requestUrl,
serverSpanState.getTraceId(), serverSpanState.getSpanId(),
TRACING_TEST_LOG_LINE_PREFIX);
assertThat(otelTesting.getSpans()).hasSize(2);
Expand Down Expand Up @@ -173,7 +174,7 @@ void testInjectWithAParentCreated() throws Exception {
} finally {
span.end();
}
verifyTraceIdPresentInLogs(stableAccumulated(1000), requestUrl,
verifyTraceIdPresentInLogs(loggerStringWriter.stableAccumulated(1000), requestUrl,
serverSpanState.getTraceId(), serverSpanState.getSpanId(),
TRACING_TEST_LOG_LINE_PREFIX);
assertThat(otelTesting.getSpans()).hasSize(3);
Expand Down Expand Up @@ -219,7 +220,7 @@ void testCaptureHeader() throws Exception {
.addHeader("some-request-header", "request-header-value")).toFuture().get();
TestSpanState serverSpanState = response.payloadBody(SPAN_STATE_SERIALIZER);

verifyTraceIdPresentInLogs(stableAccumulated(1000), requestUrl,
verifyTraceIdPresentInLogs(loggerStringWriter.stableAccumulated(1000), requestUrl,
serverSpanState.getTraceId(), serverSpanState.getSpanId(),
TRACING_TEST_LOG_LINE_PREFIX);
assertThat(otelTesting.getSpans()).hasSize(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@

import static io.servicetalk.concurrent.api.Single.succeeded;
import static io.servicetalk.http.netty.HttpClients.forSingleAddress;
import static io.servicetalk.log4j2.mdc.utils.LoggerStringWriter.stableAccumulated;
import static io.servicetalk.opentelemetry.http.OpenTelemetryHttpRequestFilterTest.verifyTraceIdPresentInLogs;
import static io.servicetalk.opentelemetry.http.TestUtils.SPAN_STATE_SERIALIZER;
import static io.servicetalk.opentelemetry.http.TestUtils.TRACING_TEST_LOG_LINE_PREFIX;
Expand All @@ -60,14 +59,16 @@ class OpenTelemetryHttpServerFilterTest {
@RegisterExtension
static final OpenTelemetryExtension otelTesting = OpenTelemetryExtension.create();

private final LoggerStringWriter loggerStringWriter = new LoggerStringWriter();

@BeforeEach
public void setup() {
LoggerStringWriter.reset();
loggerStringWriter.reset();
}

@AfterEach
public void tearDown() {
LoggerStringWriter.remove();
loggerStringWriter.remove();
}

@Test
Expand All @@ -78,7 +79,7 @@ void testInjectWithNoParent() throws Exception {
HttpResponse response = client.request(client.get(requestUrl)).toFuture().get();
TestSpanState serverSpanState = response.payloadBody(SPAN_STATE_SERIALIZER);

verifyTraceIdPresentInLogs(stableAccumulated(1000), requestUrl,
verifyTraceIdPresentInLogs(loggerStringWriter.stableAccumulated(1000), requestUrl,
serverSpanState.getTraceId(), serverSpanState.getSpanId(),
TRACING_TEST_LOG_LINE_PREFIX);
assertThat(otelTesting.getSpans()).hasSize(1);
Expand Down Expand Up @@ -125,7 +126,7 @@ void testInjectWithAParent() throws Exception {
HttpResponse response = client.request(client.get(requestUrl)).toFuture().get();
TestSpanState serverSpanState = response.payloadBody(SPAN_STATE_SERIALIZER);

verifyTraceIdPresentInLogs(stableAccumulated(1000), requestUrl,
verifyTraceIdPresentInLogs(loggerStringWriter.stableAccumulated(1000), requestUrl,
serverSpanState.getTraceId(), serverSpanState.getSpanId(),
TRACING_TEST_LOG_LINE_PREFIX);
assertThat(otelTesting.getSpans()).hasSize(2);
Expand Down Expand Up @@ -176,7 +177,7 @@ void testInjectWithNewTrace() throws Exception {
} finally {
span.end();
}
verifyTraceIdPresentInLogs(stableAccumulated(1000), "/",
verifyTraceIdPresentInLogs(loggerStringWriter.stableAccumulated(1000), "/",
serverSpanState.getTraceId(), serverSpanState.getSpanId(),
TRACING_TEST_LOG_LINE_PREFIX);
assertThat(otelTesting.getSpans()).hasSize(2);
Expand Down Expand Up @@ -209,7 +210,7 @@ void testCaptureHeaders() throws Exception {
.toFuture().get();
TestSpanState serverSpanState = response.payloadBody(SPAN_STATE_SERIALIZER);

verifyTraceIdPresentInLogs(stableAccumulated(1000), requestUrl,
verifyTraceIdPresentInLogs(loggerStringWriter.stableAccumulated(1000), requestUrl,
serverSpanState.getTraceId(), serverSpanState.getSpanId(),
TRACING_TEST_LOG_LINE_PREFIX);
assertThat(otelTesting.getSpans()).hasSize(1);
Expand Down
Loading

0 comments on commit 112fb07

Please sign in to comment.