diff --git a/docs/index.asciidoc b/docs/index.asciidoc
index d1721ed5..b6647751 100644
--- a/docs/index.asciidoc
+++ b/docs/index.asciidoc
@@ -221,6 +221,7 @@ This plugin supports the following configuration options plus the <>
 | <> |<>|No
 | <> |<>|__Deprecated__
 | <> |<>|Yes
+| <> |<>|No
 | <> |<>|__Deprecated__
 | <> |a valid filesystem path|No
 | <> |<>|No
@@ -384,6 +385,17 @@ deprecated[6.5.0, Replaced by <>]
 
 The port to listen on.
 
+[id="plugins-{type}s-{plugin}-protect_direct_memory"]
+===== `protect_direct_memory`
+
+ * Value type is <>
+ * Default value is `true`
+
+If enabled, actively checks the native memory used by the networking layer for
+protocol parsing, to avoid out-of-memory conditions. When native memory consumption
+approaches the maximum limit, connections are closed in unspecified order until a
+safe memory level is reestablished.
+
 [id="plugins-{type}s-{plugin}-ssl"]
 ===== `ssl`
 
 deprecated[6.6.0, Replaced by <>]
diff --git a/lib/logstash/inputs/beats.rb b/lib/logstash/inputs/beats.rb
index 3a302bfd..d1f8a6c5 100644
--- a/lib/logstash/inputs/beats.rb
+++ b/lib/logstash/inputs/beats.rb
@@ -74,6 +74,10 @@ class LogStash::Inputs::Beats < LogStash::Inputs::Base
   # The port to listen on.
   config :port, :validate => :number, :required => true
 
+  # Proactive checks that keep the beats input alive when the direct memory used by the
+  # protocol parser and network-related operations is close to exhaustion.
+  config :protect_direct_memory, :validate => :boolean, :default => true
+
   # Events are by default sent in plain text. You can
   # enable encryption by setting `ssl` to true and configuring
   # the `ssl_certificate` and `ssl_key` options.
@@ -243,9 +247,11 @@ def register
   end # def register
 
   def create_server
-    server = org.logstash.beats.Server.new(@host, @port, @client_inactivity_timeout, @executor_threads)
+    server = org.logstash.beats.Server.new(@host, @port, @client_inactivity_timeout, @executor_threads, @protect_direct_memory)
     server.setSslHandlerProvider(new_ssl_handshake_provider(new_ssl_context_builder)) if @ssl_enabled
     server
+  rescue java.lang.IllegalArgumentException => e
+    configuration_error e.message
  end
 
   def run(output_queue)
diff --git a/spec/inputs/beats_spec.rb b/spec/inputs/beats_spec.rb
index f68a33da..7c0694da 100644
--- a/spec/inputs/beats_spec.rb
+++ b/spec/inputs/beats_spec.rb
@@ -14,6 +14,7 @@
   let(:port) { BeatsInputTest.random_port }
   let(:client_inactivity_timeout) { 400 }
   let(:threads) { 1 + rand(9) }
+  let(:protect_direct_memory) { true }
   let(:queue) { Queue.new }
   let(:config) do
     {
@@ -36,7 +37,7 @@
     let(:port) { 9001 }
 
     it "sends the required options to the server" do
-      expect(org.logstash.beats.Server).to receive(:new).with(host, port, client_inactivity_timeout, threads)
+      expect(org.logstash.beats.Server).to receive(:new).with(host, port, client_inactivity_timeout, threads, protect_direct_memory)
       subject.register
     end
   end
@@ -529,8 +530,8 @@
     subject(:plugin) { LogStash::Inputs::Beats.new(config) }
 
     before do
-      @server = org.logstash.beats.Server.new(host, port, client_inactivity_timeout, threads)
-      expect( org.logstash.beats.Server ).to receive(:new).with(host, port, client_inactivity_timeout, threads).and_return @server
+      @server = org.logstash.beats.Server.new(host, port, client_inactivity_timeout, threads, protect_direct_memory)
+      expect( org.logstash.beats.Server ).to receive(:new).with(host, port, client_inactivity_timeout, threads, protect_direct_memory).and_return @server
       expect( @server ).to receive(:listen)
 
       subject.register
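Note: the `rescue java.lang.IllegalArgumentException` branch above maps onto the validation added in `Server` further down, which rejects undersized direct-memory configurations at construction time. A minimal sketch of that contract from plain Java, assuming the beats jar is on the classpath (the port and thread count here are illustrative; 5044 is just the conventional Beats port):

```java
import org.logstash.beats.Server;

public class ServerStartupSketch {
    public static void main(String[] args) {
        try {
            // Argument order now used by the plugin's create_server:
            // host, port, client inactivity timeout (seconds), worker threads, protect_direct_memory
            Server server = new Server("0.0.0.0", 5044, 60, Runtime.getRuntime().availableProcessors(), true);
        } catch (IllegalArgumentException e) {
            // Thrown when -XX:MaxDirectMemorySize (or io.netty.maxDirectMemory) is below the 256MB floor;
            // the Ruby wrapper surfaces this as a configuration error.
            System.err.println("beats input misconfigured: " + e.getMessage());
        }
    }
}
```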
diff --git a/src/main/java/org/logstash/beats/BeatsHandler.java b/src/main/java/org/logstash/beats/BeatsHandler.java
index 15dfb7e9..caa5054a 100644
--- a/src/main/java/org/logstash/beats/BeatsHandler.java
+++ b/src/main/java/org/logstash/beats/BeatsHandler.java
@@ -1,5 +1,7 @@
 package org.logstash.beats;
 
+import io.netty.buffer.ByteBufAllocator;
+import io.netty.buffer.PooledByteBufAllocator;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.SimpleChannelInboundHandler;
 import org.apache.logging.log4j.LogManager;
@@ -92,6 +94,9 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E
                 logger.info(format("closing (" + cause.getMessage() + ")"));
             }
         } else {
+            PooledByteBufAllocator allocator = (PooledByteBufAllocator) ByteBufAllocator.DEFAULT;
+            OOMConnectionCloser.DirectMemoryUsage usageSnapshot = OOMConnectionCloser.DirectMemoryUsage.capture(allocator);
+            logger.info("Connection {}, memory status used: {}, pinned: {}, ratio {}", ctx.channel(), usageSnapshot.used, usageSnapshot.pinned, usageSnapshot.ratio);
             final Throwable realCause = extractCause(cause, 0);
             if (logger.isDebugEnabled()){
                 logger.info(format("Handling exception: " + cause + " (caused by: " + realCause + ")"), cause);
diff --git a/src/main/java/org/logstash/beats/BeatsParser.java b/src/main/java/org/logstash/beats/BeatsParser.java
index 812150b1..e8adef60 100644
--- a/src/main/java/org/logstash/beats/BeatsParser.java
+++ b/src/main/java/org/logstash/beats/BeatsParser.java
@@ -2,9 +2,12 @@
 
 import io.netty.buffer.ByteBuf;
+import io.netty.buffer.ByteBufAllocator;
 import io.netty.buffer.ByteBufOutputStream;
+import io.netty.buffer.PooledByteBufAllocator;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.handler.codec.ByteToMessageDecoder;
+import io.netty.handler.codec.DecoderException;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
@@ -14,6 +17,7 @@
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.TimeUnit;
 import java.util.zip.Inflater;
 import java.util.zip.InflaterOutputStream;
@@ -48,8 +52,8 @@ private enum States {
 
     @Override
     protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws InvalidFrameProtocolException, IOException {
-        if(!hasEnoughBytes(in)) {
-            if (decodingCompressedBuffer){
+        if (!hasEnoughBytes(in)) {
+            if (decodingCompressedBuffer) {
                 throw new InvalidFrameProtocolException("Insufficient bytes in compressed content to decode: " + currentState);
             }
             return;
@@ -182,6 +186,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) t
             case READ_COMPRESSED_FRAME: {
                 logger.trace("Running: READ_COMPRESSED_FRAME");
+
                 inflateCompressedFrame(ctx, in, (buffer) -> {
                     transition(States.READ_HEADER);
@@ -199,9 +204,18 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) t
             }
             case READ_JSON: {
                 logger.trace("Running: READ_JSON");
-                ((V2Batch)batch).addMessage(sequence, in, requiredBytes);
-                if(batch.isComplete()) {
-                    if(logger.isTraceEnabled()) {
+                try {
+                    ((V2Batch) batch).addMessage(sequence, in, requiredBytes);
+                } catch (Throwable th) {
+                    // the batch has to release its internal buffer before bubbling up the exception
+                    batch.release();
+
+                    // rethrow the same error once the internal buffer has been released
+                    throw th;
+                }
+
+                if (batch.isComplete()) {
+                    if (logger.isTraceEnabled()) {
                         logger.trace("Sending batch size: " + this.batch.size() + ", windowSize: " + batch.getBatchSize() + " , seq: " + sequence);
                     }
                     out.add(batch);
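The `READ_JSON` change applies a general rule for reference-counted Netty buffers: if an exception escapes while your code owns a buffer, release it before propagating, or the direct memory it pins leaks. A self-contained sketch of the same release-then-rethrow pattern, where `process` is a hypothetical stand-in for `V2Batch#addMessage`:

```java
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.util.ReferenceCountUtil;

public class ReleaseOnErrorSketch {
    public static void main(String[] args) {
        ByteBuf buffer = Unpooled.buffer(16); // stands in for the batch's internal buffer
        try {
            process(buffer); // hypothetical stand-in for V2Batch#addMessage
        } catch (Throwable th) {
            // release what we own before bubbling the error up, as BeatsParser now does
            ReferenceCountUtil.release(buffer);
            throw th;
        }
        ReferenceCountUtil.release(buffer); // the normal path releases too
    }

    private static void process(ByteBuf buf) {
        buf.writeLong(42L); // growing the buffer is where an OutOfMemoryError could surface
    }
}
```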
diff --git a/src/main/java/org/logstash/beats/FlowLimiterHandler.java b/src/main/java/org/logstash/beats/FlowLimiterHandler.java
new file mode 100644
index 00000000..6a0da517
--- /dev/null
+++ b/src/main/java/org/logstash/beats/FlowLimiterHandler.java
@@ -0,0 +1,56 @@
+package org.logstash.beats;
+
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelHandler.Sharable;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+
+/**
+ * Configures the channel where it's installed to operate reads in pull mode,
+ * disabling autoread and explicitly invoking the read operation.
+ * The flow control that keeps the outgoing buffer under control is achieved by
+ * not reading in new bytes while the outgoing direction is not writable; this
+ * exerts back pressure on the TCP layer and ultimately on the upstream system.
+ * */
+@Sharable
+public final class FlowLimiterHandler extends ChannelInboundHandlerAdapter {
+
+    private final static Logger logger = LogManager.getLogger(FlowLimiterHandler.class);
+
+    @Override
+    public void channelRegistered(final ChannelHandlerContext ctx) throws Exception {
+        ctx.channel().config().setAutoRead(false);
+        super.channelRegistered(ctx);
+    }
+
+    @Override
+    public void channelActive(final ChannelHandlerContext ctx) throws Exception {
+        super.channelActive(ctx);
+        if (isAutoreadDisabled(ctx.channel()) && ctx.channel().isWritable()) {
+            ctx.channel().read();
+        }
+    }
+
+    @Override
+    public void channelReadComplete(final ChannelHandlerContext ctx) throws Exception {
+        super.channelReadComplete(ctx);
+        if (isAutoreadDisabled(ctx.channel()) && ctx.channel().isWritable()) {
+            ctx.channel().read();
+        }
+    }
+
+    private boolean isAutoreadDisabled(Channel channel) {
+        return !channel.config().isAutoRead();
+    }
+
+    @Override
+    public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
+        ctx.channel().read();
+        super.channelWritabilityChanged(ctx);
+
+        logger.debug("Writability on channel {} changed to {}", ctx.channel(), ctx.channel().isWritable());
+    }
+
+}
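The registration hook above is what flips a channel into pull mode. One way to observe it in isolation, assuming this class is on the classpath, is an `EmbeddedChannel`, which registers its pipeline at construction time:

```java
import io.netty.channel.embedded.EmbeddedChannel;
import org.logstash.beats.FlowLimiterHandler;

public class PullModeSketch {
    public static void main(String[] args) {
        // Registration triggers channelRegistered, which disables autoread:
        // from here on, the transport reads only when the handler calls read(),
        // and the handler only does that while the channel remains writable.
        EmbeddedChannel channel = new EmbeddedChannel(new FlowLimiterHandler());
        System.out.println("autoRead: " + channel.config().isAutoRead()); // expected: false
        channel.finishAndReleaseAll();
    }
}
```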
diff --git a/src/main/java/org/logstash/beats/OOMConnectionCloser.java b/src/main/java/org/logstash/beats/OOMConnectionCloser.java
new file mode 100644
index 00000000..751aba53
--- /dev/null
+++ b/src/main/java/org/logstash/beats/OOMConnectionCloser.java
@@ -0,0 +1,88 @@
+package org.logstash.beats;
+
+import io.netty.buffer.ByteBufAllocator;
+import io.netty.buffer.PooledByteBufAllocator;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.util.ReferenceCountUtil;
+import io.netty.util.internal.PlatformDependent;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+public class OOMConnectionCloser extends ChannelInboundHandlerAdapter {
+
+    private final PooledByteBufAllocator allocator;
+
+    static class DirectMemoryUsage {
+        final long used;
+        final long pinned;
+        private final PooledByteBufAllocator allocator;
+        final short ratio;
+
+        private DirectMemoryUsage(long used, long pinned, PooledByteBufAllocator allocator) {
+            this.used = used;
+            this.pinned = pinned;
+            this.allocator = allocator;
+            this.ratio = (short) Math.round(((double) pinned / used) * 100);
+        }
+
+        static DirectMemoryUsage capture(PooledByteBufAllocator allocator) {
+            long usedDirectMemory = allocator.metric().usedDirectMemory();
+            long pinnedDirectMemory = allocator.pinnedDirectMemory();
+            return new DirectMemoryUsage(usedDirectMemory, pinnedDirectMemory, allocator);
+        }
+
+        boolean isCloseToOOM() {
+            long maxDirectMemory = PlatformDependent.maxDirectMemory();
+            int chunkSize = allocator.metric().chunkSize();
+            return ((maxDirectMemory - used) <= chunkSize) && ratio > 75;
+        }
+    }
+
+    private final static Logger logger = LogManager.getLogger(OOMConnectionCloser.class);
+
+    public static final Pattern DIRECT_MEMORY_ERROR = Pattern.compile("^Cannot reserve \\d* bytes of direct buffer memory.*$");
+
+    OOMConnectionCloser() {
+        allocator = (PooledByteBufAllocator) ByteBufAllocator.DEFAULT;
+    }
+
+    @Override
+    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+        DirectMemoryUsage direct = DirectMemoryUsage.capture(allocator);
+        logger.info("Direct memory status, used: {}, pinned: {}, ratio: {}", direct.used, direct.pinned, direct.ratio);
+        if (direct.isCloseToOOM()) {
+            logger.warn("Closing connection {} because running out of memory, used: {}, pinned: {}, ratio {}", ctx.channel(), direct.used, direct.pinned, direct.ratio);
+            ReferenceCountUtil.release(msg); // to free the memory used by the buffer
+            ctx.flush();
+            ctx.close();
+        } else {
+            super.channelRead(ctx, msg);
+        }
+    }
+
+    @Override
+    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
+        if (isDirectMemoryOOM(cause)) {
+            DirectMemoryUsage direct = DirectMemoryUsage.capture(allocator);
+            logger.info("Direct memory status, used: {}, pinned: {}, ratio: {}", direct.used, direct.pinned, direct.ratio);
+            logger.warn("Dropping connection {} due to lack of available Direct Memory. Please lower the number of concurrent connections or reduce the batch size. " +
+                    "Alternatively, raise the -XX:MaxDirectMemorySize option in the JVM running Logstash", ctx.channel());
+            ctx.flush();
+            ctx.close();
+        } else {
+            super.exceptionCaught(ctx, cause);
+        }
+    }
+
+    private boolean isDirectMemoryOOM(Throwable th) {
+        if (!(th instanceof OutOfMemoryError)) {
+            return false;
+        }
+        Matcher m = DIRECT_MEMORY_ERROR.matcher(th.getMessage());
+        return m.matches();
+    }
+}
\ No newline at end of file
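`isCloseToOOM` combines two signals: less than one allocator chunk of headroom left under the max, and a high pinned-to-used ratio (memory actually held by live buffers rather than merely cached by the pool). Worked arithmetic with made-up numbers; the real chunk size depends on the Netty version and allocator settings:

```java
public class OomHeuristicSketch {
    public static void main(String[] args) {
        long maxDirectMemory = 256L * 1024 * 1024; // pretend -XX:MaxDirectMemorySize=256m
        long used = 248L * 1024 * 1024;            // bytes reserved by the pooled allocator
        long pinned = 210L * 1024 * 1024;          // bytes held by in-flight buffers
        int chunkSize = 16 * 1024 * 1024;          // one allocator chunk (hypothetical value)

        // Same formulas as DirectMemoryUsage/isCloseToOOM above.
        short ratio = (short) Math.round(((double) pinned / used) * 100);
        boolean closeToOOM = ((maxDirectMemory - used) <= chunkSize) && ratio > 75;
        System.out.println("ratio=" + ratio + "%, closeToOOM=" + closeToOOM); // ratio=85%, closeToOOM=true
    }
}
```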
diff --git a/src/main/java/org/logstash/beats/Runner.java b/src/main/java/org/logstash/beats/Runner.java
index 0cb623e4..548f6ef1 100644
--- a/src/main/java/org/logstash/beats/Runner.java
+++ b/src/main/java/org/logstash/beats/Runner.java
@@ -17,7 +17,7 @@ static public void main(String[] args) throws Exception {
         // Check for leaks.
         // ResourceLeakDetector.setLevel(ResourceLeakDetector.Level.PARANOID);
 
-        Server server = new Server("0.0.0.0", DEFAULT_PORT, 15, Runtime.getRuntime().availableProcessors());
+        Server server = new Server("0.0.0.0", DEFAULT_PORT, 15, Runtime.getRuntime().availableProcessors(), true);
 
         if(args.length > 0 && args[0].equals("ssl")) {
             logger.debug("Using SSL");
diff --git a/src/main/java/org/logstash/beats/Server.java b/src/main/java/org/logstash/beats/Server.java
index c343aaf6..b15741b6 100644
--- a/src/main/java/org/logstash/beats/Server.java
+++ b/src/main/java/org/logstash/beats/Server.java
@@ -1,13 +1,18 @@
 package org.logstash.beats;
 
 import io.netty.bootstrap.ServerBootstrap;
-import io.netty.channel.*;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.ChannelPipeline;
 import io.netty.channel.nio.NioEventLoopGroup;
 import io.netty.channel.socket.SocketChannel;
 import io.netty.channel.socket.nio.NioServerSocketChannel;
 import io.netty.handler.timeout.IdleStateHandler;
 import io.netty.util.concurrent.DefaultEventExecutorGroup;
 import io.netty.util.concurrent.EventExecutorGroup;
+import io.netty.util.internal.PlatformDependent;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.logstash.netty.SslHandlerProvider;
@@ -18,6 +23,7 @@ public class Server {
     private final int port;
     private final String host;
     private final int beatsHeandlerThreadCount;
+    private final boolean protectDirectMemory;
     private NioEventLoopGroup workGroup;
     private IMessageListener messageListener = new MessageListener();
     private SslHandlerProvider sslHandlerProvider;
@@ -25,11 +31,26 @@ public class Server {
 
     private final int clientInactivityTimeoutSeconds;
 
-    public Server(String host, int p, int clientInactivityTimeoutSeconds, int threadCount) {
+    public Server(String host, int p, int clientInactivityTimeoutSeconds, int threadCount, boolean protectDirectMemory) {
         this.host = host;
         port = p;
         this.clientInactivityTimeoutSeconds = clientInactivityTimeoutSeconds;
         beatsHeandlerThreadCount = threadCount;
+        this.protectDirectMemory = protectDirectMemory;
+
+        validateMinimumDirectMemory();
+    }
+
+    /**
+     * Validates that the configured available direct memory is enough for safe processing, otherwise throws an IllegalArgumentException
+     * */
+    private void validateMinimumDirectMemory() {
+        long maxDirectMemoryAllocatable = PlatformDependent.maxDirectMemory();
+        if (maxDirectMemoryAllocatable < 256 * 1024 * 1024) {
+            long roundedMegabytes = Math.round((double) maxDirectMemoryAllocatable / 1024 / 1024);
+            throw new IllegalArgumentException("Max direct memory should be at least 256MB but was " + roundedMegabytes + "MB, " +
+                    "please check your MaxDirectMemorySize and io.netty.maxDirectMemory settings");
+        }
     }
 
     public void setSslHandlerProvider(SslHandlerProvider sslHandlerProvider){
@@ -126,6 +147,9 @@ private class BeatsInitializer extends ChannelInitializer<SocketChannel> {
         public void initChannel(SocketChannel socket){
             ChannelPipeline pipeline = socket.pipeline();
 
+            if (protectDirectMemory) {
+                pipeline.addLast(new OOMConnectionCloser());
+            }
 
             if (isSslEnabled()) {
                 pipeline.addLast(SSL_HANDLER, sslHandlerProvider.sslHandlerForChannel(socket));
@@ -134,7 +158,12 @@ public void initChannel(SocketChannel socket){
                     new IdleStateHandler(localClientInactivityTimeoutSeconds, IDLESTATE_WRITER_IDLE_TIME_SECONDS, localClientInactivityTimeoutSeconds));
             pipeline.addLast(BEATS_ACKER, new AckEncoder());
             pipeline.addLast(CONNECTION_HANDLER, new ConnectionHandler());
-            pipeline.addLast(beatsHandlerExecutorGroup, new BeatsParser(), new BeatsHandler(localMessageListener));
+            if (protectDirectMemory) {
+                pipeline.addLast(new FlowLimiterHandler());
+                pipeline.addLast(new ThunderingGuardHandler());
+            }
+            pipeline.addLast(beatsHandlerExecutorGroup, new BeatsParser());
+            pipeline.addLast(beatsHandlerExecutorGroup, new BeatsHandler(localMessageListener));
         }
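The constructor-time check is easy to reproduce standalone; the sketch below mirrors `validateMinimumDirectMemory` with the same 256MB floor. Running it with, say, `-XX:MaxDirectMemorySize=128m` should trip the exception:

```java
import io.netty.util.internal.PlatformDependent;

public class DirectMemoryFloorSketch {
    public static void main(String[] args) {
        long max = PlatformDependent.maxDirectMemory(); // honors io.netty.maxDirectMemory too
        if (max < 256L * 1024 * 1024) {
            long roundedMegabytes = Math.round((double) max / 1024 / 1024);
            throw new IllegalArgumentException("Max direct memory should be at least 256MB but was "
                    + roundedMegabytes + "MB");
        }
        System.out.println("direct memory headroom OK: " + (max / 1024 / 1024) + "MB");
    }
}
```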
diff --git a/src/main/java/org/logstash/beats/ThunderingGuardHandler.java b/src/main/java/org/logstash/beats/ThunderingGuardHandler.java
new file mode 100644
index 00000000..a880ed35
--- /dev/null
+++ b/src/main/java/org/logstash/beats/ThunderingGuardHandler.java
@@ -0,0 +1,40 @@
+package org.logstash.beats;
+
+import io.netty.buffer.PooledByteBufAllocator;
+import io.netty.channel.ChannelHandler.Sharable;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+
+/**
+ * This handler avoids accepting new connections when direct memory
+ * consumption is close to the MaxDirectMemorySize.
+ * <p>
+ * If the total allocated direct memory is close to the max memory size, and the bytes
+ * pinned by the direct memory allocator are close to the direct memory used, then it
+ * drops new incoming connections.
+ * */
+@Sharable
+public final class ThunderingGuardHandler extends ChannelInboundHandlerAdapter {
+
+    private final static long MAX_DIRECT_MEMORY = io.netty.util.internal.PlatformDependent.maxDirectMemory();
+
+    private final static Logger logger = LogManager.getLogger(ThunderingGuardHandler.class);
+
+    @Override
+    public void channelRegistered(final ChannelHandlerContext ctx) throws Exception {
+        PooledByteBufAllocator pooledAllocator = (PooledByteBufAllocator) ctx.alloc();
+        long usedDirectMemory = pooledAllocator.metric().usedDirectMemory();
+        if (usedDirectMemory > MAX_DIRECT_MEMORY * 0.90) {
+            long pinnedDirectMemory = pooledAllocator.pinnedDirectMemory();
+            if (pinnedDirectMemory >= usedDirectMemory * 0.80) {
+                ctx.close();
+                logger.warn("Dropping connection {} due to high resource consumption", ctx.channel());
+                return;
+            }
+        }
+
+        super.channelRegistered(ctx);
+    }
+}
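The guard's thresholds (used above 90% of max, pinned at or above 80% of used) can be probed outside a pipeline with the same allocator calls the handler uses; this sketch just prints whether a new connection would currently be rejected:

```java
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.util.internal.PlatformDependent;

public class ThunderingGuardProbe {
    public static void main(String[] args) {
        PooledByteBufAllocator allocator = (PooledByteBufAllocator) ByteBufAllocator.DEFAULT;
        long max = PlatformDependent.maxDirectMemory();
        long used = allocator.metric().usedDirectMemory();
        long pinned = allocator.pinnedDirectMemory(); // requires a Netty version exposing pinned metrics
        boolean wouldReject = used > max * 0.90 && pinned >= used * 0.80;
        System.out.println("used=" + used + ", pinned=" + pinned + ", wouldReject=" + wouldReject);
    }
}
```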
diff --git a/src/test/java/org/logstash/beats/FlowLimiterHandlerTest.java b/src/test/java/org/logstash/beats/FlowLimiterHandlerTest.java
new file mode 100644
index 00000000..f268621e
--- /dev/null
+++ b/src/test/java/org/logstash/beats/FlowLimiterHandlerTest.java
@@ -0,0 +1,195 @@
+package org.logstash.beats;
+
+import io.netty.bootstrap.Bootstrap;
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.PooledByteBufAllocator;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelFutureListener;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.SimpleChannelInboundHandler;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.channel.socket.nio.NioServerSocketChannel;
+import io.netty.channel.socket.nio.NioSocketChannel;
+import org.junit.Test;
+
+import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
+
+import static org.junit.Assert.*;
+
+public class FlowLimiterHandlerTest {
+
+    private ReadMessagesCollector readMessagesCollector;
+
+    private static ByteBuf prepareSample(int numBytes) {
+        return prepareSample(numBytes, 'A');
+    }
+
+    private static ByteBuf prepareSample(int numBytes, char c) {
+        ByteBuf payload = PooledByteBufAllocator.DEFAULT.directBuffer(numBytes);
+        for (int i = 0; i < numBytes; i++) {
+            payload.writeByte(c);
+        }
+        return payload;
+    }
+
+    private ChannelInboundHandlerAdapter onClientConnected(Consumer<ChannelHandlerContext> action) {
+        return new ChannelInboundHandlerAdapter() {
+            @Override
+            public void channelActive(ChannelHandlerContext ctx) throws Exception {
+                super.channelActive(ctx);
+                action.accept(ctx);
+            }
+        };
+    }
+
+    private static class ReadMessagesCollector extends SimpleChannelInboundHandler<ByteBuf> {
+        private Channel clientChannel;
+        private final NioEventLoopGroup group;
+        boolean firstChunkRead = false;
+
+        ReadMessagesCollector(NioEventLoopGroup group) {
+            this.group = group;
+        }
+
+        @Override
+        protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception {
+            if (!firstChunkRead) {
+                assertEquals("Expect to read a first chunk and no others", 32, msg.readableBytes());
+                firstChunkRead = true;
+
+                // the client writes more data that MUSTN'T be read by the server,
+                // because it is rate limited.
+                clientChannel.writeAndFlush(prepareSample(16)).addListener(new ChannelFutureListener() {
+                    @Override
+                    public void operationComplete(ChannelFuture future) throws Exception {
+                        if (future.isSuccess()) {
+                            // on successful flush schedule a shutdown
+                            ctx.channel().eventLoop().schedule(new Runnable() {
+                                @Override
+                                public void run() {
+                                    group.shutdownGracefully();
+                                }
+                            }, 2, TimeUnit.SECONDS);
+                        } else {
+                            ctx.fireExceptionCaught(future.cause());
+                        }
+                    }
+                });
+
+            } else {
+                // the first read already happened and the server commands no further
+                // reads, so execution should never reach this branch
+                fail("Should never be notified of other data while rate limited");
+            }
+        }
+
+        public void updateClient(Channel clientChannel) {
+            assertNotNull(clientChannel);
+            this.clientChannel = clientChannel;
+        }
+    }
+
+
+    private static class AssertionsHandler extends ChannelInboundHandlerAdapter {
+
+        private final NioEventLoopGroup group;
+
+        private Throwable lastError;
+
+        public AssertionsHandler(NioEventLoopGroup group) {
+            this.group = group;
+        }
+
+        @Override
+        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
+            lastError = cause;
+            group.shutdownGracefully();
+        }
+
+        public void assertNoErrors() {
+            if (lastError != null) {
+                if (lastError instanceof AssertionError) {
+                    throw (AssertionError) lastError;
+                } else {
+                    fail("Failed with error: " + lastError);
+                }
+            }
+        }
+    }
+
+    @Test
+    public void givenAChannelInNotWriteableStateWhenNewBuffersAreSentByClientThenNoDecodeTakePartOnServerSide() throws Exception {
+        final int highWaterMark = 32 * 1024;
+        FlowLimiterHandler sut = new FlowLimiterHandler();
+
+        NioEventLoopGroup group = new NioEventLoopGroup();
+        ServerBootstrap b = new ServerBootstrap();
+
+        readMessagesCollector = new ReadMessagesCollector(group);
+        AssertionsHandler assertionsHandler = new AssertionsHandler(group);
+        try {
+            b.group(group)
+                    .channel(NioServerSocketChannel.class)
+                    .childHandler(new ChannelInitializer<SocketChannel>() {
+                        @Override
+                        protected void initChannel(SocketChannel ch) throws Exception {
+                            ch.config().setWriteBufferHighWaterMark(highWaterMark);
+                            ch.pipeline()
+                                    .addLast(onClientConnected(ctx -> {
+                                        // write enough to move the channel into the not writable state
+                                        fillOutboundWatermark(ctx, highWaterMark);
+                                        // ask the client to send some data on the channel
+                                        clientChannel.writeAndFlush(prepareSample(32));
+                                    }))
+                                    .addLast(sut)
+                                    .addLast(readMessagesCollector)
+                                    .addLast(assertionsHandler);
+                        }
+                    });
+            ChannelFuture future = b.bind("0.0.0.0", 1234).addListener(new ChannelFutureListener() {
+                @Override
+                public void operationComplete(ChannelFuture future) throws Exception {
+                    if (future.isSuccess()) {
+                        startAClient(group);
+                    }
+                }
+            }).sync();
+            future.channel().closeFuture().sync();
+        } finally {
+            group.shutdownGracefully().sync();
+        }
+
+        assertionsHandler.assertNoErrors();
+    }
+
+    private static void fillOutboundWatermark(ChannelHandlerContext ctx, int highWaterMark) {
+        final ByteBuf payload = prepareSample(highWaterMark, 'C');
+        while (ctx.channel().isWritable()) {
+            ctx.pipeline().writeAndFlush(payload.copy());
+        }
+    }
+
+    Channel clientChannel;
+
+    private void startAClient(NioEventLoopGroup group) {
+        Bootstrap b = new Bootstrap();
+        b.group(group)
+                .channel(NioSocketChannel.class)
+                .handler(new ChannelInitializer<SocketChannel>() {
+                    @Override
+                    protected void initChannel(SocketChannel ch) throws Exception {
+                        ch.config().setAutoRead(false);
+                        clientChannel = ch;
+                        readMessagesCollector.updateClient(clientChannel);
+                    }
+                });
+        b.connect("localhost", 1234);
+    }
+
+}
\ No newline at end of file
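The test drives the server channel into the not-writable state by flooding its outbound buffer past `writeBufferHighWaterMark`. Netty's writability flag is governed by a pair of water marks; a tiny sketch of the knob the test relies on (the sizes here are arbitrary):

```java
import io.netty.channel.WriteBufferWaterMark;

public class WatermarkSketch {
    public static void main(String[] args) {
        // A channel turns not-writable once queued outbound bytes exceed the high mark
        // and becomes writable again only after draining below the low mark.
        WriteBufferWaterMark mark = new WriteBufferWaterMark(8 * 1024, 32 * 1024);
        System.out.println("low=" + mark.low() + ", high=" + mark.high());
    }
}
```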
b.connect("localhost", 1234); + } + +} \ No newline at end of file diff --git a/src/test/java/org/logstash/beats/ServerTest.java b/src/test/java/org/logstash/beats/ServerTest.java index 37512cdc..b067e55d 100644 --- a/src/test/java/org/logstash/beats/ServerTest.java +++ b/src/test/java/org/logstash/beats/ServerTest.java @@ -50,7 +50,7 @@ public void testServerShouldTerminateConnectionWhenExceptionHappen() throws Inte final CountDownLatch latch = new CountDownLatch(concurrentConnections); - final Server server = new Server(host, randomPort, inactivityTime, threadCount); + final Server server = new Server(host, randomPort, inactivityTime, threadCount, true); final AtomicBoolean otherCause = new AtomicBoolean(false); server.setMessageListener(new MessageListener() { public void onNewConnection(ChannelHandlerContext ctx) { @@ -114,7 +114,7 @@ public void testServerShouldTerminateConnectionIdleForTooLong() throws Interrupt final CountDownLatch latch = new CountDownLatch(concurrentConnections); final AtomicBoolean exceptionClose = new AtomicBoolean(false); - final Server server = new Server(host, randomPort, inactivityTime, threadCount); + final Server server = new Server(host, randomPort, inactivityTime, threadCount, true); server.setMessageListener(new MessageListener() { @Override public void onNewConnection(ChannelHandlerContext ctx) { @@ -170,7 +170,7 @@ public void run() { @Test public void testServerShouldAcceptConcurrentConnection() throws InterruptedException { - final Server server = new Server(host, randomPort, 30, threadCount); + final Server server = new Server(host, randomPort, 30, threadCount, true); SpyListener listener = new SpyListener(); server.setMessageListener(listener); Runnable serverTask = new Runnable() { diff --git a/src/test/java/org/logstash/beats/ThunderingGuardHandlerTest.java b/src/test/java/org/logstash/beats/ThunderingGuardHandlerTest.java new file mode 100644 index 00000000..d93f6582 --- /dev/null +++ b/src/test/java/org/logstash/beats/ThunderingGuardHandlerTest.java @@ -0,0 +1,83 @@ +package org.logstash.beats; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.ReferenceCounted; +import io.netty.util.internal.PlatformDependent; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.*; + +public class ThunderingGuardHandlerTest { + + public static final int MB = 1024 * 1024; + public static final long MAX_DIRECT_MEMORY_BYTES = PlatformDependent.maxDirectMemory(); + + @Test + public void testVerifyDirectMemoryCouldGoBeyondThe90Percent() { + // allocate 90% of direct memory + List allocatedBuffers = allocateDirectMemory(MAX_DIRECT_MEMORY_BYTES, 0.9); + + // allocate one more + ByteBuf payload = PooledByteBufAllocator.DEFAULT.directBuffer(1 * MB); + long usedDirectMemory = PooledByteBufAllocator.DEFAULT.metric().usedDirectMemory(); + long pinnedDirectMemory = PooledByteBufAllocator.DEFAULT.pinnedDirectMemory(); + + // verify + assertTrue("Direct memory allocation should be > 90% of the max available", usedDirectMemory > 0.9 * MAX_DIRECT_MEMORY_BYTES); + assertTrue("Direct memory usage should be > 80% of the max available", pinnedDirectMemory > 0.8 * MAX_DIRECT_MEMORY_BYTES); + + allocatedBuffers.forEach(ReferenceCounted::release); + payload.release(); + } + + private static List allocateDirectMemory(long maxDirectMemoryBytes, double percentage) { + List allocatedBuffers = new ArrayList<>(); + final long 
diff --git a/src/test/java/org/logstash/beats/ThunderingGuardHandlerTest.java b/src/test/java/org/logstash/beats/ThunderingGuardHandlerTest.java
new file mode 100644
index 00000000..d93f6582
--- /dev/null
+++ b/src/test/java/org/logstash/beats/ThunderingGuardHandlerTest.java
@@ -0,0 +1,83 @@
+package org.logstash.beats;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.PooledByteBufAllocator;
+import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.util.ReferenceCounted;
+import io.netty.util.internal.PlatformDependent;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.junit.Assert.*;
+
+public class ThunderingGuardHandlerTest {
+
+    public static final int MB = 1024 * 1024;
+    public static final long MAX_DIRECT_MEMORY_BYTES = PlatformDependent.maxDirectMemory();
+
+    @Test
+    public void testVerifyDirectMemoryCouldGoBeyondThe90Percent() {
+        // allocate 90% of direct memory
+        List<ByteBuf> allocatedBuffers = allocateDirectMemory(MAX_DIRECT_MEMORY_BYTES, 0.9);
+
+        // allocate one more buffer
+        ByteBuf payload = PooledByteBufAllocator.DEFAULT.directBuffer(1 * MB);
+        long usedDirectMemory = PooledByteBufAllocator.DEFAULT.metric().usedDirectMemory();
+        long pinnedDirectMemory = PooledByteBufAllocator.DEFAULT.pinnedDirectMemory();
+
+        // verify
+        assertTrue("Direct memory allocation should be > 90% of the max available", usedDirectMemory > 0.9 * MAX_DIRECT_MEMORY_BYTES);
+        assertTrue("Direct memory usage should be > 80% of the max available", pinnedDirectMemory > 0.8 * MAX_DIRECT_MEMORY_BYTES);
+
+        allocatedBuffers.forEach(ReferenceCounted::release);
+        payload.release();
+    }
+
+    private static List<ByteBuf> allocateDirectMemory(long maxDirectMemoryBytes, double percentage) {
+        List<ByteBuf> allocatedBuffers = new ArrayList<>();
+        final long numBuffersToAllocate = (long) (maxDirectMemoryBytes / MB * percentage);
+        for (int i = 0; i < numBuffersToAllocate; i++) {
+            allocatedBuffers.add(PooledByteBufAllocator.DEFAULT.directBuffer(1 * MB));
+        }
+        return allocatedBuffers;
+    }
+
+    @Test
+    public void givenUsedDirectMemoryAndPinnedMemoryAreCloseToTheMaxDirectAvailableWhenNewConnectionIsCreatedThenItIsRejected() {
+        EmbeddedChannel channel = new EmbeddedChannel(new ThunderingGuardHandler());
+
+        // consume > 90% of the direct memory
+        List<ByteBuf> allocatedBuffers = allocateDirectMemory(MAX_DIRECT_MEMORY_BYTES, 0.9);
+        // allocate one more buffer
+        ByteBuf payload = PooledByteBufAllocator.DEFAULT.directBuffer(1 * MB);
+
+        channel.pipeline().fireChannelRegistered();
+
+        // verify
+        assertFalse("Under constrained memory new channels have to be forcibly closed", channel.isOpen());
+
+        allocatedBuffers.forEach(ReferenceCounted::release);
+        payload.release();
+    }
+
+    @Test
+    public void givenUsedDirectMemoryAndNotPinnedWhenNewConnectionIsCreatedThenItIsAccepted() {
+        EmbeddedChannel channel = new EmbeddedChannel(new ThunderingGuardHandler());
+
+        // consume > 90% of the direct memory, then release it so it stays used but not pinned
+        List<ByteBuf> allocatedBuffers = allocateDirectMemory(MAX_DIRECT_MEMORY_BYTES, 0.9);
+        allocatedBuffers.forEach(ReferenceCounted::release);
+        // allocate and release one more buffer
+        ByteBuf payload = PooledByteBufAllocator.DEFAULT.directBuffer(1 * MB);
+        payload.release();
+
+        channel.pipeline().fireChannelRegistered();
+
+        // verify
+        assertTrue("Although memory is allocated but not pinned, new connections MUST be accepted", channel.isOpen());
+
+    }
+
+}
\ No newline at end of file