diff --git a/.gitignore b/.gitignore index a3c75313a21c..ed43883f5ba7 100644 --- a/.gitignore +++ b/.gitignore @@ -31,3 +31,9 @@ hs_err_pid*.log dependency-reduced-pom.xml + +*/.unison.* + +# exclude mainframer files +mainframer +.mainframer diff --git a/.mvn/settings.xml b/.mvn/settings.xml new file mode 100644 index 000000000000..1f7f6fafec76 --- /dev/null +++ b/.mvn/settings.xml @@ -0,0 +1,9 @@ + + + + sonatype-nexus-snapshots + ${env.SANOTYPE_USER} + ${env.SANOTYPE_PASSWORD} + + + diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index b35e822464d3..000000000000 --- a/.travis.yml +++ /dev/null @@ -1,12 +0,0 @@ -language: java -jdk: - - oraclejdk7 - - openjdk7 -branches: - only: - - master - - 3 - - 3.5 -before_install: 'mvn -version' -install: 'mvn clean install -Pfull -DskipTests' - diff --git a/all/pom.xml b/all/pom.xml index b0b519419597..e0ca51fc0bcb 100644 --- a/all/pom.xml +++ b/all/pom.xml @@ -17,11 +17,11 @@ 4.0.0 - + io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-all @@ -32,6 +32,7 @@ ${project.build.directory}/src ${project.build.directory}/versions + true diff --git a/bom/pom.xml b/bom/pom.xml index b8542d9fca65..b78da2b70dc1 100644 --- a/bom/pom.xml +++ b/bom/pom.xml @@ -16,16 +16,16 @@ --> 4.0.0 - + io.netty netty-bom - 4.1.25.dse + 4.1.34.3.dse pom Netty/BOM @@ -49,7 +49,6 @@ https://github.com/netty/netty scm:git:git://github.com/netty/netty.git scm:git:ssh://git@github.com/netty/netty.git - netty-4.1.25.dse @@ -69,165 +68,165 @@ io.netty netty-buffer - 4.1.25.5.dse + 4.1.34.3.dse io.netty netty-codec - 4.1.25.5.dse + 4.1.34.3.dse io.netty netty-codec-dns - 4.1.25.5.dse + 4.1.34.3.dse io.netty netty-codec-haproxy - 4.1.25.5.dse + 4.1.34.3.dse io.netty netty-codec-http - 4.1.25.5.dse + 4.1.34.3.dse io.netty netty-codec-http2 - 4.1.25.5.dse + 4.1.34.3.dse io.netty netty-codec-memcache - 4.1.25.5.dse + 4.1.34.3.dse io.netty netty-codec-mqtt - 4.1.25.5.dse + 4.1.34.3.dse io.netty netty-codec-redis - 4.1.25.5.dse + 4.1.34.3.dse io.netty netty-codec-smtp - 4.1.25.5.dse + 4.1.34.3.dse io.netty netty-codec-socks - 4.1.25.5.dse + 4.1.34.3.dse io.netty netty-codec-stomp - 4.1.25.5.dse + 4.1.34.3.dse io.netty netty-codec-xml - 4.1.25.5.dse + 4.1.34.3.dse io.netty netty-common - 4.1.25.5.dse + 4.1.34.3.dse io.netty netty-dev-tools - 4.1.25.dse + 4.1.34.3.dse io.netty netty-handler - 4.1.25.5.dse + 4.1.34.3.dse io.netty netty-handler-proxy - 4.1.25.5.dse + 4.1.34.3.dse io.netty netty-resolver - 4.1.25.5.dse + 4.1.34.3.dse io.netty netty-resolver-dns - 4.1.25.5.dse + 4.1.34.3.dse io.netty netty-transport - 4.1.25.5.dse + 4.1.34.3.dse io.netty netty-transport-rxtx - 4.1.25.5.dse + 4.1.34.3.dse io.netty netty-transport-sctp - 4.1.25.5.dse + 4.1.34.3.dse io.netty netty-transport-udt - 4.1.25.5.dse + 4.1.34.3.dse io.netty netty-example - 4.1.25.5.dse + 4.1.34.3.dse io.netty netty-all - 4.1.25.5.dse + 4.1.34.3.dse io.netty netty-transport-native-unix-common - 4.1.25.5.dse + 4.1.34.3.dse io.netty netty-transport-native-unix-common - 4.1.25.5.dse + 4.1.34.3.dse linux-x86_64 io.netty netty-transport-native-unix-common - 4.1.25.5.dse + 4.1.34.3.dse osx-x86_64 io.netty netty-transport-native-epoll - 4.1.25.5.dse + 4.1.34.3.dse io.netty netty-transport-native-epoll - 4.1.25.5.dse + 4.1.34.3.dse linux-x86_64 io.netty netty-transport-native-kqueue - 4.1.25.5.dse + 4.1.34.3.dse io.netty netty-transport-native-kqueue - 4.1.25.5.dse + 4.1.34.3.dse osx-x86_64 diff --git a/buffer/pom.xml b/buffer/pom.xml index 6ee9d013b0eb..4e5586497c95 100644 --- a/buffer/pom.xml +++ b/buffer/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-buffer diff --git a/buffer/src/main/java/io/netty/buffer/AbstractByteBuf.java b/buffer/src/main/java/io/netty/buffer/AbstractByteBuf.java index 4a13d821b933..ad2f793b1abf 100644 --- a/buffer/src/main/java/io/netty/buffer/AbstractByteBuf.java +++ b/buffer/src/main/java/io/netty/buffer/AbstractByteBuf.java @@ -15,6 +15,7 @@ */ package io.netty.buffer; +import io.netty.util.AsciiString; import io.netty.util.ByteProcessor; import io.netty.util.CharsetUtil; import io.netty.util.IllegalReferenceCountException; @@ -37,19 +38,29 @@ import java.nio.charset.Charset; import static io.netty.util.internal.MathUtil.isOutOfBounds; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; /** * A skeletal implementation of a buffer. */ public abstract class AbstractByteBuf extends ByteBuf { private static final InternalLogger logger = InternalLoggerFactory.getInstance(AbstractByteBuf.class); - private static final String PROP_MODE = "io.netty.buffer.bytebuf.checkAccessible"; - private static final boolean checkAccessible; + private static final String LEGACY_PROP_CHECK_ACCESSIBLE = "io.netty.buffer.bytebuf.checkAccessible"; + private static final String PROP_CHECK_ACCESSIBLE = "io.netty.buffer.checkAccessible"; + static final boolean checkAccessible; // accessed from CompositeByteBuf + private static final String PROP_CHECK_BOUNDS = "io.netty.buffer.checkBounds"; + private static final boolean checkBounds; static { - checkAccessible = SystemPropertyUtil.getBoolean(PROP_MODE, true); + if (SystemPropertyUtil.contains(PROP_CHECK_ACCESSIBLE)) { + checkAccessible = SystemPropertyUtil.getBoolean(PROP_CHECK_ACCESSIBLE, true); + } else { + checkAccessible = SystemPropertyUtil.getBoolean(LEGACY_PROP_CHECK_ACCESSIBLE, true); + } + checkBounds = SystemPropertyUtil.getBoolean(PROP_CHECK_BOUNDS, true); if (logger.isDebugEnabled()) { - logger.debug("-D{}: {}", PROP_MODE, checkAccessible); + logger.debug("-D{}: {}", PROP_CHECK_ACCESSIBLE, checkAccessible); + logger.debug("-D{}: {}", PROP_CHECK_BOUNDS, checkBounds); } } @@ -63,9 +74,7 @@ public abstract class AbstractByteBuf extends ByteBuf { private int maxCapacity; protected AbstractByteBuf(int maxCapacity) { - if (maxCapacity < 0) { - throw new IllegalArgumentException("maxCapacity: " + maxCapacity + " (expected: >= 0)"); - } + checkPositiveOrZero(maxCapacity, "maxCapacity"); this.maxCapacity = maxCapacity; } @@ -97,11 +106,18 @@ public int readerIndex() { return readerIndex; } + private static void checkIndexBounds(final int readerIndex, final int writerIndex, final int capacity) { + if (readerIndex < 0 || readerIndex > writerIndex || writerIndex > capacity) { + throw new IndexOutOfBoundsException(String.format( + "readerIndex: %d, writerIndex: %d (expected: 0 <= readerIndex <= writerIndex <= capacity(%d))", + readerIndex, writerIndex, capacity)); + } + } + @Override public ByteBuf readerIndex(int readerIndex) { - if (readerIndex < 0 || readerIndex > writerIndex) { - throw new IndexOutOfBoundsException(String.format( - "readerIndex: %d (expected: 0 <= readerIndex <= writerIndex(%d))", readerIndex, writerIndex)); + if (checkBounds) { + checkIndexBounds(readerIndex, writerIndex, capacity()); } this.readerIndex = readerIndex; return this; @@ -114,10 +130,8 @@ public int writerIndex() { @Override public ByteBuf writerIndex(int writerIndex) { - if (writerIndex < readerIndex || writerIndex > capacity()) { - throw new IndexOutOfBoundsException(String.format( - "writerIndex: %d (expected: readerIndex(%d) <= writerIndex <= capacity(%d))", - writerIndex, readerIndex, capacity())); + if (checkBounds) { + checkIndexBounds(readerIndex, writerIndex, capacity()); } this.writerIndex = writerIndex; return this; @@ -125,10 +139,8 @@ public ByteBuf writerIndex(int writerIndex) { @Override public ByteBuf setIndex(int readerIndex, int writerIndex) { - if (readerIndex < 0 || readerIndex > writerIndex || writerIndex > capacity()) { - throw new IndexOutOfBoundsException(String.format( - "readerIndex: %d, writerIndex: %d (expected: 0 <= readerIndex <= writerIndex <= capacity(%d))", - readerIndex, writerIndex, capacity())); + if (checkBounds) { + checkIndexBounds(readerIndex, writerIndex, capacity()); } setIndex0(readerIndex, writerIndex); return this; @@ -258,10 +270,7 @@ protected final void adjustMarkers(int decrement) { @Override public ByteBuf ensureWritable(int minWritableBytes) { - if (minWritableBytes < 0) { - throw new IllegalArgumentException(String.format( - "minWritableBytes: %d (expected: >= 0)", minWritableBytes)); - } + checkPositiveOrZero(minWritableBytes, "minWritableBytes"); ensureWritable0(minWritableBytes); return this; } @@ -271,11 +280,12 @@ final void ensureWritable0(int minWritableBytes) { if (minWritableBytes <= writableBytes()) { return; } - - if (minWritableBytes > maxCapacity - writerIndex) { - throw new IndexOutOfBoundsException(String.format( - "writerIndex(%d) + minWritableBytes(%d) exceeds maxCapacity(%d): %s", - writerIndex, minWritableBytes, maxCapacity, this)); + if (checkBounds) { + if (minWritableBytes > maxCapacity - writerIndex) { + throw new IndexOutOfBoundsException(String.format( + "writerIndex(%d) + minWritableBytes(%d) exceeds maxCapacity(%d): %s", + writerIndex, minWritableBytes, maxCapacity, this)); + } } // Normalize the current capacity to the power of 2. @@ -288,10 +298,7 @@ final void ensureWritable0(int minWritableBytes) { @Override public int ensureWritable(int minWritableBytes, boolean force) { ensureAccessible(); - if (minWritableBytes < 0) { - throw new IllegalArgumentException(String.format( - "minWritableBytes: %d (expected: >= 0)", minWritableBytes)); - } + checkPositiveOrZero(minWritableBytes, "minWritableBytes"); if (minWritableBytes <= writableBytes()) { return 0; @@ -318,12 +325,12 @@ public int ensureWritable(int minWritableBytes, boolean force) { @Override public ByteBuf order(ByteOrder endianness) { - if (endianness == null) { - throw new NullPointerException("endianness"); - } if (endianness == order()) { return this; } + if (endianness == null) { + throw new NullPointerException("endianness"); + } return newSwappedByteBuf(); } @@ -490,7 +497,10 @@ public ByteBuf getBytes(int index, ByteBuf dst, int length) { @Override public CharSequence getCharSequence(int index, int length, Charset charset) { - // TODO: We could optimize this for UTF8 and US_ASCII + if (CharsetUtil.US_ASCII.equals(charset) || CharsetUtil.ISO_8859_1.equals(charset)) { + // ByteBufUtil.getBytes(...) will return a new copy which the AsciiString uses directly + return new AsciiString(ByteBufUtil.getBytes(this, index, length, true), false); + } return toString(index, length, charset); } @@ -618,15 +628,21 @@ public ByteBuf setBytes(int index, ByteBuf src) { return this; } + private static void checkReadableBounds(final ByteBuf src, final int length) { + if (length > src.readableBytes()) { + throw new IndexOutOfBoundsException(String.format( + "length(%d) exceeds src.readableBytes(%d) where src is: %s", length, src.readableBytes(), src)); + } + } + @Override public ByteBuf setBytes(int index, ByteBuf src, int length) { checkIndex(index, length); if (src == null) { throw new NullPointerException("src"); } - if (length > src.readableBytes()) { - throw new IndexOutOfBoundsException(String.format( - "length(%d) exceeds src.readableBytes(%d) where src is: %s", length, src.readableBytes(), src)); + if (checkBounds) { + checkReadableBounds(src, length); } setBytes(index, src, src.readerIndex(), length); @@ -889,9 +905,11 @@ public ByteBuf readBytes(ByteBuf dst) { @Override public ByteBuf readBytes(ByteBuf dst, int length) { - if (length > dst.writableBytes()) { - throw new IndexOutOfBoundsException(String.format( - "length(%d) exceeds dst.writableBytes(%d) where dst is: %s", length, dst.writableBytes(), dst)); + if (checkBounds) { + if (length > dst.writableBytes()) { + throw new IndexOutOfBoundsException(String.format( + "length(%d) exceeds dst.writableBytes(%d) where dst is: %s", length, dst.writableBytes(), dst)); + } } readBytes(dst, dst.writerIndex(), length); dst.writerIndex(dst.writerIndex() + length); @@ -1065,9 +1083,8 @@ public ByteBuf writeBytes(ByteBuf src) { @Override public ByteBuf writeBytes(ByteBuf src, int length) { - if (length > src.readableBytes()) { - throw new IndexOutOfBoundsException(String.format( - "length(%d) exceeds src.readableBytes(%d) where src is: %s", length, src.readableBytes(), src)); + if (checkBounds) { + checkReadableBounds(src, length); } writeBytes(src, src.readerIndex(), length); src.readerIndex(src.readerIndex() + length); @@ -1269,7 +1286,7 @@ public int forEachByte(int index, int length, ByteProcessor processor) { } } - private int forEachByteAsc0(int start, int end, ByteProcessor processor) throws Exception { + int forEachByteAsc0(int start, int end, ByteProcessor processor) throws Exception { for (; start < end; ++start) { if (!processor.process(_getByte(start))) { return start; @@ -1301,7 +1318,7 @@ public int forEachByteDesc(int index, int length, ByteProcessor processor) { } } - private int forEachByteDesc0(int rStart, final int rEnd, ByteProcessor processor) throws Exception { + int forEachByteDesc0(int rStart, final int rEnd, ByteProcessor processor) throws Exception { for (; rStart >= rEnd; --rStart) { if (!processor.process(_getByte(rStart))) { return rStart; @@ -1357,26 +1374,30 @@ protected final void checkIndex(int index, int fieldLength) { checkIndex0(index, fieldLength); } - final void checkIndex0(int index, int fieldLength) { - if (isOutOfBounds(index, fieldLength, capacity())) { + private static void checkRangeBounds(final int index, final int fieldLength, final int capacity) { + if (isOutOfBounds(index, fieldLength, capacity)) { throw new IndexOutOfBoundsException(String.format( - "index: %d, length: %d (expected: range(0, %d))", index, fieldLength, capacity())); + "index: %d, length: %d (expected: range(0, %d))", index, fieldLength, capacity)); + } + } + + final void checkIndex0(int index, int fieldLength) { + if (checkBounds) { + checkRangeBounds(index, fieldLength, capacity()); } } protected final void checkSrcIndex(int index, int length, int srcIndex, int srcCapacity) { checkIndex(index, length); - if (isOutOfBounds(srcIndex, length, srcCapacity)) { - throw new IndexOutOfBoundsException(String.format( - "srcIndex: %d, length: %d (expected: range(0, %d))", srcIndex, length, srcCapacity)); + if (checkBounds) { + checkRangeBounds(srcIndex, length, srcCapacity); } } protected final void checkDstIndex(int index, int length, int dstIndex, int dstCapacity) { checkIndex(index, length); - if (isOutOfBounds(dstIndex, length, dstCapacity)) { - throw new IndexOutOfBoundsException(String.format( - "dstIndex: %d, length: %d (expected: range(0, %d))", dstIndex, length, dstCapacity)); + if (checkBounds) { + checkRangeBounds(dstIndex, length, dstCapacity); } } @@ -1386,25 +1407,28 @@ protected final void checkDstIndex(int index, int length, int dstIndex, int dstC * than the specified value. */ protected final void checkReadableBytes(int minimumReadableBytes) { - if (minimumReadableBytes < 0) { - throw new IllegalArgumentException("minimumReadableBytes: " + minimumReadableBytes + " (expected: >= 0)"); - } + checkPositiveOrZero(minimumReadableBytes, "minimumReadableBytes"); checkReadableBytes0(minimumReadableBytes); } protected final void checkNewCapacity(int newCapacity) { ensureAccessible(); - if (newCapacity < 0 || newCapacity > maxCapacity()) { - throw new IllegalArgumentException("newCapacity: " + newCapacity + " (expected: 0-" + maxCapacity() + ')'); + if (checkBounds) { + if (newCapacity < 0 || newCapacity > maxCapacity()) { + throw new IllegalArgumentException("newCapacity: " + newCapacity + + " (expected: 0-" + maxCapacity() + ')'); + } } } private void checkReadableBytes0(int minimumReadableBytes) { ensureAccessible(); - if (readerIndex > writerIndex - minimumReadableBytes) { - throw new IndexOutOfBoundsException(String.format( - "readerIndex(%d) + length(%d) exceeds writerIndex(%d): %s", - readerIndex, minimumReadableBytes, writerIndex, this)); + if (checkBounds) { + if (readerIndex > writerIndex - minimumReadableBytes) { + throw new IndexOutOfBoundsException(String.format( + "readerIndex(%d) + length(%d) exceeds writerIndex(%d): %s", + readerIndex, minimumReadableBytes, writerIndex, this)); + } } } @@ -1413,7 +1437,7 @@ private void checkReadableBytes0(int minimumReadableBytes) { * if the buffer was released before. */ protected final void ensureAccessible() { - if (checkAccessible && refCnt() == 0) { + if (checkAccessible && !isAccessible()) { throw new IllegalReferenceCountException(0); } } diff --git a/buffer/src/main/java/io/netty/buffer/AbstractByteBufAllocator.java b/buffer/src/main/java/io/netty/buffer/AbstractByteBufAllocator.java index 40525144e3f3..920b3fb46d04 100644 --- a/buffer/src/main/java/io/netty/buffer/AbstractByteBufAllocator.java +++ b/buffer/src/main/java/io/netty/buffer/AbstractByteBufAllocator.java @@ -16,6 +16,8 @@ package io.netty.buffer; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + import io.netty.util.ResourceLeakDetector; import io.netty.util.ResourceLeakTracker; import io.netty.util.internal.PlatformDependent; @@ -222,9 +224,7 @@ public CompositeByteBuf compositeDirectBuffer(int maxNumComponents) { } private static void validate(int initialCapacity, int maxCapacity) { - if (initialCapacity < 0) { - throw new IllegalArgumentException("initialCapacity: " + initialCapacity + " (expected: 0+)"); - } + checkPositiveOrZero(initialCapacity, "initialCapacity"); if (initialCapacity > maxCapacity) { throw new IllegalArgumentException(String.format( "initialCapacity: %d (expected: not greater than maxCapacity(%d)", @@ -249,9 +249,7 @@ public String toString() { @Override public int calculateNewCapacity(int minNewCapacity, int maxCapacity) { - if (minNewCapacity < 0) { - throw new IllegalArgumentException("minNewCapacity: " + minNewCapacity + " (expected: 0+)"); - } + checkPositiveOrZero(minNewCapacity, "minNewCapacity"); if (minNewCapacity > maxCapacity) { throw new IllegalArgumentException(String.format( "minNewCapacity: %d (expected: not greater than maxCapacity(%d)", diff --git a/buffer/src/main/java/io/netty/buffer/AbstractDerivedByteBuf.java b/buffer/src/main/java/io/netty/buffer/AbstractDerivedByteBuf.java index 58f1d907a156..40844020dd21 100644 --- a/buffer/src/main/java/io/netty/buffer/AbstractDerivedByteBuf.java +++ b/buffer/src/main/java/io/netty/buffer/AbstractDerivedByteBuf.java @@ -31,6 +31,11 @@ protected AbstractDerivedByteBuf(int maxCapacity) { super(maxCapacity); } + @Override + final boolean isAccessible() { + return unwrap().isAccessible(); + } + @Override public final int refCnt() { return refCnt0(); diff --git a/buffer/src/main/java/io/netty/buffer/AbstractReferenceCountedByteBuf.java b/buffer/src/main/java/io/netty/buffer/AbstractReferenceCountedByteBuf.java index d624d855f4da..548bdba5e493 100644 --- a/buffer/src/main/java/io/netty/buffer/AbstractReferenceCountedByteBuf.java +++ b/buffer/src/main/java/io/netty/buffer/AbstractReferenceCountedByteBuf.java @@ -17,6 +17,7 @@ package io.netty.buffer; import io.netty.util.IllegalReferenceCountException; +import io.netty.util.internal.PlatformDependent; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; @@ -26,27 +27,68 @@ * Abstract base class for {@link ByteBuf} implementations that count references. */ public abstract class AbstractReferenceCountedByteBuf extends AbstractByteBuf { - + private static final long REFCNT_FIELD_OFFSET; private static final AtomicIntegerFieldUpdater refCntUpdater = AtomicIntegerFieldUpdater.newUpdater(AbstractReferenceCountedByteBuf.class, "refCnt"); - private volatile int refCnt; + // even => "real" refcount is (refCnt >>> 1); odd => "real" refcount is 0 + @SuppressWarnings("unused") + private volatile int refCnt = 2; + + static { + long refCntFieldOffset = -1; + try { + if (PlatformDependent.hasUnsafe()) { + refCntFieldOffset = PlatformDependent.objectFieldOffset( + AbstractReferenceCountedByteBuf.class.getDeclaredField("refCnt")); + } + } catch (Throwable ignore) { + refCntFieldOffset = -1; + } + + REFCNT_FIELD_OFFSET = refCntFieldOffset; + } + + private static int realRefCnt(int rawCnt) { + return (rawCnt & 1) != 0 ? 0 : rawCnt >>> 1; + } protected AbstractReferenceCountedByteBuf(int maxCapacity) { super(maxCapacity); - refCntUpdater.set(this, 1); + } + + private int nonVolatileRawCnt() { + // TODO: Once we compile against later versions of Java we can replace the Unsafe usage here by varhandles. + return REFCNT_FIELD_OFFSET != -1 ? PlatformDependent.getInt(this, REFCNT_FIELD_OFFSET) + : refCntUpdater.get(this); + } + + @Override + boolean isAccessible() { + // Try to do non-volatile read for performance as the ensureAccessible() is racy anyway and only provide + // a best-effort guard. + + // This is copied explicitly from the nonVolatileRawCnt() method above to reduce call stack depth, + // to avoid hitting the default limit for inlining (9) + final int rawCnt = REFCNT_FIELD_OFFSET != -1 ? PlatformDependent.getInt(this, REFCNT_FIELD_OFFSET) + : refCntUpdater.get(this); + + // The "real" ref count is > 0 if the rawCnt is even. + // (x & y) appears to be surprisingly expensive relative to (x == y). Thus the expression below provides + // a fast path for most common cases where the ref count is 1, 2, 3 or 4. + return rawCnt == 2 || rawCnt == 4 || rawCnt == 6 || rawCnt == 8 || (rawCnt & 1) == 0; } @Override public int refCnt() { - return refCnt; + return realRefCnt(refCntUpdater.get(this)); } /** * An unsafe operation intended for use by a subclass that sets the reference count of the buffer directly */ - protected final void setRefCnt(int refCnt) { - refCntUpdater.set(this, refCnt); + protected final void setRefCnt(int newRefCnt) { + refCntUpdater.set(this, newRefCnt << 1); // overflow OK here } @Override @@ -60,11 +102,18 @@ public ByteBuf retain(int increment) { } private ByteBuf retain0(final int increment) { - int oldRef = refCntUpdater.getAndAdd(this, increment); - if (oldRef <= 0 || oldRef + increment < oldRef) { - // Ensure we don't resurrect (which means the refCnt was 0) and also that we encountered an overflow. - refCntUpdater.getAndAdd(this, -increment); - throw new IllegalReferenceCountException(oldRef, increment); + // all changes to the raw count are 2x the "real" change + int adjustedIncrement = increment << 1; // overflow OK here + int oldRef = refCntUpdater.getAndAdd(this, adjustedIncrement); + if ((oldRef & 1) != 0) { + throw new IllegalReferenceCountException(0, increment); + } + // don't pass 0! + if ((oldRef <= 0 && oldRef + adjustedIncrement >= 0) + || (oldRef >= 0 && oldRef + adjustedIncrement < oldRef)) { + // overflow case + refCntUpdater.getAndAdd(this, -adjustedIncrement); + throw new IllegalReferenceCountException(realRefCnt(oldRef), increment); } return this; } @@ -90,17 +139,57 @@ public boolean release(int decrement) { } private boolean release0(int decrement) { - int oldRef = refCntUpdater.getAndAdd(this, -decrement); - if (oldRef == decrement) { - deallocate(); - return true; - } else if (oldRef < decrement || oldRef - decrement > oldRef) { - // Ensure we don't over-release, and avoid underflow. - refCntUpdater.getAndAdd(this, decrement); - throw new IllegalReferenceCountException(oldRef, -decrement); + int rawCnt = nonVolatileRawCnt(), realCnt = toLiveRealCnt(rawCnt, decrement); + if (decrement == realCnt) { + if (refCntUpdater.compareAndSet(this, rawCnt, 1)) { + deallocate(); + return true; + } + return retryRelease0(decrement); } - return false; + return releaseNonFinal0(decrement, rawCnt, realCnt); } + + private boolean releaseNonFinal0(int decrement, int rawCnt, int realCnt) { + if (decrement < realCnt + // all changes to the raw count are 2x the "real" change + && refCntUpdater.compareAndSet(this, rawCnt, rawCnt - (decrement << 1))) { + return false; + } + return retryRelease0(decrement); + } + + private boolean retryRelease0(int decrement) { + for (;;) { + int rawCnt = refCntUpdater.get(this), realCnt = toLiveRealCnt(rawCnt, decrement); + if (decrement == realCnt) { + if (refCntUpdater.compareAndSet(this, rawCnt, 1)) { + deallocate(); + return true; + } + } else if (decrement < realCnt) { + // all changes to the raw count are 2x the "real" change + if (refCntUpdater.compareAndSet(this, rawCnt, rawCnt - (decrement << 1))) { + return false; + } + } else { + throw new IllegalReferenceCountException(realCnt, -decrement); + } + Thread.yield(); // this benefits throughput under high contention + } + } + + /** + * Like {@link #realRefCnt(int)} but throws if refCnt == 0 + */ + private static int toLiveRealCnt(int rawCnt, int decrement) { + if ((rawCnt & 1) == 0) { + return rawCnt >>> 1; + } + // odd rawCnt => already deallocated + throw new IllegalReferenceCountException(0, -decrement); + } + /** * Called once {@link #refCnt()} is equals 0. */ diff --git a/buffer/src/main/java/io/netty/buffer/ByteBuf.java b/buffer/src/main/java/io/netty/buffer/ByteBuf.java index c04340a0d15e..83ecc6685a5b 100644 --- a/buffer/src/main/java/io/netty/buffer/ByteBuf.java +++ b/buffer/src/main/java/io/netty/buffer/ByteBuf.java @@ -258,14 +258,14 @@ public abstract class ByteBuf implements ReferenceCounted, Comparable { * capacity, the content of this buffer is truncated. If the {@code newCapacity} is greater * than the current capacity, the buffer is appended with unspecified data whose length is * {@code (newCapacity - currentCapacity)}. + * + * @throws IllegalArgumentException if the {@code newCapacity} is greater than {@link #maxCapacity()} */ public abstract ByteBuf capacity(int newCapacity); /** - * Returns the maximum allowed capacity of this buffer. If a user attempts to increase the - * capacity of this buffer beyond the maximum capacity using {@link #capacity(int)} or - * {@link #ensureWritable(int)}, those methods will raise an - * {@link IllegalArgumentException}. + * Returns the maximum allowed capacity of this buffer. This value provides an upper + * bound on {@link #capacity()}. */ public abstract int maxCapacity(); @@ -513,22 +513,23 @@ public abstract class ByteBuf implements ReferenceCounted, Comparable { public abstract ByteBuf discardSomeReadBytes(); /** - * Makes sure the number of {@linkplain #writableBytes() the writable bytes} - * is equal to or greater than the specified value. If there is enough - * writable bytes in this buffer, this method returns with no side effect. - * Otherwise, it raises an {@link IllegalArgumentException}. + * Expands the buffer {@link #capacity()} to make sure the number of + * {@linkplain #writableBytes() writable bytes} is equal to or greater than the + * specified value. If there are enough writable bytes in this buffer, this method + * returns with no side effect. * * @param minWritableBytes * the expected minimum number of writable bytes * @throws IndexOutOfBoundsException - * if {@link #writerIndex()} + {@code minWritableBytes} > {@link #maxCapacity()} + * if {@link #writerIndex()} + {@code minWritableBytes} > {@link #maxCapacity()}. + * @see #capacity(int) */ public abstract ByteBuf ensureWritable(int minWritableBytes); /** - * Tries to make sure the number of {@linkplain #writableBytes() the writable bytes} - * is equal to or greater than the specified value. Unlike {@link #ensureWritable(int)}, - * this method does not raise an exception but returns a code. + * Expands the buffer {@link #capacity()} to make sure the number of + * {@linkplain #writableBytes() writable bytes} is equal to or greater than the + * specified value. Unlike {@link #ensureWritable(int)}, this method returns a status code. * * @param minWritableBytes * the expected minimum number of writable bytes @@ -1756,9 +1757,8 @@ public double readDoubleLE() { /** * Sets the specified boolean at the current {@code writerIndex} * and increases the {@code writerIndex} by {@code 1} in this buffer. - * - * @throws IndexOutOfBoundsException - * if {@code this.writableBytes} is less than {@code 1} + * If {@code this.writableBytes} is less than {@code 1}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. */ public abstract ByteBuf writeBoolean(boolean value); @@ -1766,9 +1766,8 @@ public double readDoubleLE() { * Sets the specified byte at the current {@code writerIndex} * and increases the {@code writerIndex} by {@code 1} in this buffer. * The 24 high-order bits of the specified value are ignored. - * - * @throws IndexOutOfBoundsException - * if {@code this.writableBytes} is less than {@code 1} + * If {@code this.writableBytes} is less than {@code 1}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. */ public abstract ByteBuf writeByte(int value); @@ -1776,9 +1775,8 @@ public double readDoubleLE() { * Sets the specified 16-bit short integer at the current * {@code writerIndex} and increases the {@code writerIndex} by {@code 2} * in this buffer. The 16 high-order bits of the specified value are ignored. - * - * @throws IndexOutOfBoundsException - * if {@code this.writableBytes} is less than {@code 2} + * If {@code this.writableBytes} is less than {@code 2}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. */ public abstract ByteBuf writeShort(int value); @@ -1787,9 +1785,8 @@ public double readDoubleLE() { * Order at the current {@code writerIndex} and increases the * {@code writerIndex} by {@code 2} in this buffer. * The 16 high-order bits of the specified value are ignored. - * - * @throws IndexOutOfBoundsException - * if {@code this.writableBytes} is less than {@code 2} + * If {@code this.writableBytes} is less than {@code 2}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. */ public abstract ByteBuf writeShortLE(int value); @@ -1797,9 +1794,8 @@ public double readDoubleLE() { * Sets the specified 24-bit medium integer at the current * {@code writerIndex} and increases the {@code writerIndex} by {@code 3} * in this buffer. - * - * @throws IndexOutOfBoundsException - * if {@code this.writableBytes} is less than {@code 3} + * If {@code this.writableBytes} is less than {@code 3}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. */ public abstract ByteBuf writeMedium(int value); @@ -1808,18 +1804,16 @@ public double readDoubleLE() { * {@code writerIndex} in the Little Endian Byte Order and * increases the {@code writerIndex} by {@code 3} in this * buffer. - * - * @throws IndexOutOfBoundsException - * if {@code this.writableBytes} is less than {@code 3} + * If {@code this.writableBytes} is less than {@code 3}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. */ public abstract ByteBuf writeMediumLE(int value); /** * Sets the specified 32-bit integer at the current {@code writerIndex} * and increases the {@code writerIndex} by {@code 4} in this buffer. - * - * @throws IndexOutOfBoundsException - * if {@code this.writableBytes} is less than {@code 4} + * If {@code this.writableBytes} is less than {@code 4}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. */ public abstract ByteBuf writeInt(int value); @@ -1827,9 +1821,8 @@ public double readDoubleLE() { * Sets the specified 32-bit integer at the current {@code writerIndex} * in the Little Endian Byte Order and increases the {@code writerIndex} * by {@code 4} in this buffer. - * - * @throws IndexOutOfBoundsException - * if {@code this.writableBytes} is less than {@code 4} + * If {@code this.writableBytes} is less than {@code 4}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. */ public abstract ByteBuf writeIntLE(int value); @@ -1837,9 +1830,8 @@ public double readDoubleLE() { * Sets the specified 64-bit long integer at the current * {@code writerIndex} and increases the {@code writerIndex} by {@code 8} * in this buffer. - * - * @throws IndexOutOfBoundsException - * if {@code this.writableBytes} is less than {@code 8} + * If {@code this.writableBytes} is less than {@code 8}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. */ public abstract ByteBuf writeLong(long value); @@ -1848,9 +1840,8 @@ public double readDoubleLE() { * {@code writerIndex} in the Little Endian Byte Order and * increases the {@code writerIndex} by {@code 8} * in this buffer. - * - * @throws IndexOutOfBoundsException - * if {@code this.writableBytes} is less than {@code 8} + * If {@code this.writableBytes} is less than {@code 8}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. */ public abstract ByteBuf writeLongLE(long value); @@ -1858,9 +1849,8 @@ public double readDoubleLE() { * Sets the specified 2-byte UTF-16 character at the current * {@code writerIndex} and increases the {@code writerIndex} by {@code 2} * in this buffer. The 16 high-order bits of the specified value are ignored. - * - * @throws IndexOutOfBoundsException - * if {@code this.writableBytes} is less than {@code 2} + * If {@code this.writableBytes} is less than {@code 2}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. */ public abstract ByteBuf writeChar(int value); @@ -1868,9 +1858,8 @@ public double readDoubleLE() { * Sets the specified 32-bit floating point number at the current * {@code writerIndex} and increases the {@code writerIndex} by {@code 4} * in this buffer. - * - * @throws IndexOutOfBoundsException - * if {@code this.writableBytes} is less than {@code 4} + * If {@code this.writableBytes} is less than {@code 4}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. */ public abstract ByteBuf writeFloat(float value); @@ -1878,9 +1867,8 @@ public double readDoubleLE() { * Sets the specified 32-bit floating point number at the current * {@code writerIndex} in Little Endian Byte Order and increases * the {@code writerIndex} by {@code 4} in this buffer. - * - * @throws IndexOutOfBoundsException - * if {@code this.writableBytes} is less than {@code 4} + * If {@code this.writableBytes} is less than {@code 4}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. */ public ByteBuf writeFloatLE(float value) { return writeIntLE(Float.floatToRawIntBits(value)); @@ -1890,9 +1878,8 @@ public ByteBuf writeFloatLE(float value) { * Sets the specified 64-bit floating point number at the current * {@code writerIndex} and increases the {@code writerIndex} by {@code 8} * in this buffer. - * - * @throws IndexOutOfBoundsException - * if {@code this.writableBytes} is less than {@code 8} + * If {@code this.writableBytes} is less than {@code 8}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. */ public abstract ByteBuf writeDouble(double value); @@ -1900,9 +1887,8 @@ public ByteBuf writeFloatLE(float value) { * Sets the specified 64-bit floating point number at the current * {@code writerIndex} in Little Endian Byte Order and increases * the {@code writerIndex} by {@code 8} in this buffer. - * - * @throws IndexOutOfBoundsException - * if {@code this.writableBytes} is less than {@code 8} + * If {@code this.writableBytes} is less than {@code 8}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. */ public ByteBuf writeDoubleLE(double value) { return writeLongLE(Double.doubleToRawLongBits(value)); @@ -1917,10 +1903,9 @@ public ByteBuf writeDoubleLE(double value) { * increases the {@code readerIndex} of the source buffer by the number of * the transferred bytes while {@link #writeBytes(ByteBuf, int, int)} * does not. - * - * @throws IndexOutOfBoundsException - * if {@code src.readableBytes} is greater than - * {@code this.writableBytes} + * If {@code this.writableBytes} is less than {@code src.readableBytes}, + * {@link #ensureWritable(int)} will be called in an attempt to expand + * capacity to accommodate. */ public abstract ByteBuf writeBytes(ByteBuf src); @@ -1932,12 +1917,11 @@ public ByteBuf writeDoubleLE(double value) { * except that this method increases the {@code readerIndex} of the source * buffer by the number of the transferred bytes (= {@code length}) while * {@link #writeBytes(ByteBuf, int, int)} does not. + * If {@code this.writableBytes} is less than {@code length}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. * * @param length the number of bytes to transfer - * - * @throws IndexOutOfBoundsException - * if {@code length} is greater than {@code this.writableBytes} or - * if {@code length} is greater then {@code src.readableBytes} + * @throws IndexOutOfBoundsException if {@code length} is greater then {@code src.readableBytes} */ public abstract ByteBuf writeBytes(ByteBuf src, int length); @@ -1945,15 +1929,15 @@ public ByteBuf writeDoubleLE(double value) { * Transfers the specified source buffer's data to this buffer starting at * the current {@code writerIndex} and increases the {@code writerIndex} * by the number of the transferred bytes (= {@code length}). + * If {@code this.writableBytes} is less than {@code length}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. * * @param srcIndex the first index of the source * @param length the number of bytes to transfer * * @throws IndexOutOfBoundsException - * if the specified {@code srcIndex} is less than {@code 0}, - * if {@code srcIndex + length} is greater than - * {@code src.capacity}, or - * if {@code length} is greater than {@code this.writableBytes} + * if the specified {@code srcIndex} is less than {@code 0}, or + * if {@code srcIndex + length} is greater than {@code src.capacity} */ public abstract ByteBuf writeBytes(ByteBuf src, int srcIndex, int length); @@ -1961,9 +1945,8 @@ public ByteBuf writeDoubleLE(double value) { * Transfers the specified source array's data to this buffer starting at * the current {@code writerIndex} and increases the {@code writerIndex} * by the number of the transferred bytes (= {@code src.length}). - * - * @throws IndexOutOfBoundsException - * if {@code src.length} is greater than {@code this.writableBytes} + * If {@code this.writableBytes} is less than {@code src.length}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. */ public abstract ByteBuf writeBytes(byte[] src); @@ -1971,15 +1954,15 @@ public ByteBuf writeDoubleLE(double value) { * Transfers the specified source array's data to this buffer starting at * the current {@code writerIndex} and increases the {@code writerIndex} * by the number of the transferred bytes (= {@code length}). + * If {@code this.writableBytes} is less than {@code length}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. * * @param srcIndex the first index of the source * @param length the number of bytes to transfer * * @throws IndexOutOfBoundsException - * if the specified {@code srcIndex} is less than {@code 0}, - * if {@code srcIndex + length} is greater than - * {@code src.length}, or - * if {@code length} is greater than {@code this.writableBytes} + * if the specified {@code srcIndex} is less than {@code 0}, or + * if {@code srcIndex + length} is greater than {@code src.length} */ public abstract ByteBuf writeBytes(byte[] src, int srcIndex, int length); @@ -1988,10 +1971,9 @@ public ByteBuf writeDoubleLE(double value) { * the current {@code writerIndex} until the source buffer's position * reaches its limit, and increases the {@code writerIndex} by the * number of the transferred bytes. - * - * @throws IndexOutOfBoundsException - * if {@code src.remaining()} is greater than - * {@code this.writableBytes} + * If {@code this.writableBytes} is less than {@code src.remaining()}, + * {@link #ensureWritable(int)} will be called in an attempt to expand + * capacity to accommodate. */ public abstract ByteBuf writeBytes(ByteBuffer src); @@ -1999,29 +1981,28 @@ public ByteBuf writeDoubleLE(double value) { * Transfers the content of the specified stream to this buffer * starting at the current {@code writerIndex} and increases the * {@code writerIndex} by the number of the transferred bytes. + * If {@code this.writableBytes} is less than {@code length}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. * * @param length the number of bytes to transfer * * @return the actual number of bytes read in from the specified stream * - * @throws IndexOutOfBoundsException - * if {@code length} is greater than {@code this.writableBytes} - * @throws IOException - * if the specified stream threw an exception during I/O + * @throws IOException if the specified stream threw an exception during I/O */ - public abstract int writeBytes(InputStream in, int length) throws IOException; + public abstract int writeBytes(InputStream in, int length) throws IOException; /** * Transfers the content of the specified channel to this buffer * starting at the current {@code writerIndex} and increases the * {@code writerIndex} by the number of the transferred bytes. + * If {@code this.writableBytes} is less than {@code length}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. * * @param length the maximum number of bytes to transfer * * @return the actual number of bytes read in from the specified channel * - * @throws IndexOutOfBoundsException - * if {@code length} is greater than {@code this.writableBytes} * @throws IOException * if the specified channel threw an exception during I/O */ @@ -2032,14 +2013,14 @@ public ByteBuf writeDoubleLE(double value) { * to this buffer starting at the current {@code writerIndex} and increases the * {@code writerIndex} by the number of the transferred bytes. * This method does not modify the channel's position. + * If {@code this.writableBytes} is less than {@code length}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. * * @param position the file position at which the transfer is to begin * @param length the maximum number of bytes to transfer * * @return the actual number of bytes read in from the specified channel * - * @throws IndexOutOfBoundsException - * if {@code length} is greater than {@code this.writableBytes} * @throws IOException * if the specified channel threw an exception during I/O */ @@ -2049,11 +2030,10 @@ public ByteBuf writeDoubleLE(double value) { * Fills this buffer with NUL (0x00) starting at the current * {@code writerIndex} and increases the {@code writerIndex} by the * specified {@code length}. + * If {@code this.writableBytes} is less than {@code length}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. * * @param length the number of NULs to write to the buffer - * - * @throws IndexOutOfBoundsException - * if {@code length} is greater than {@code this.writableBytes} */ public abstract ByteBuf writeZero(int length); @@ -2061,12 +2041,12 @@ public ByteBuf writeDoubleLE(double value) { * Writes the specified {@link CharSequence} at the current {@code writerIndex} and increases * the {@code writerIndex} by the written bytes. * in this buffer. + * If {@code this.writableBytes} is not large enough to write the whole sequence, + * {@link #ensureWritable(int)} will be called in an attempt to expand capacity to accommodate. * * @param sequence to write * @param charset that should be used * @return the written number of bytes - * @throws IndexOutOfBoundsException - * if {@code this.writableBytes} is not large enough to write the whole sequence */ public abstract int writeCharSequence(CharSequence sequence, Charset charset); @@ -2465,4 +2445,12 @@ public ByteBuf writeDoubleLE(double value) { @Override public abstract ByteBuf touch(Object hint); + + /** + * Used internally by {@link AbstractByteBuf#ensureAccessible()} to try to guard + * against using the buffer after it was released (best-effort). + */ + boolean isAccessible() { + return refCnt() != 0; + } } diff --git a/buffer/src/main/java/io/netty/buffer/ByteBufInputStream.java b/buffer/src/main/java/io/netty/buffer/ByteBufInputStream.java index 2d8d34d32317..038cd8db459b 100644 --- a/buffer/src/main/java/io/netty/buffer/ByteBufInputStream.java +++ b/buffer/src/main/java/io/netty/buffer/ByteBufInputStream.java @@ -16,6 +16,7 @@ package io.netty.buffer; import io.netty.util.ReferenceCounted; +import io.netty.util.internal.StringUtil; import java.io.DataInput; import java.io.DataInputStream; @@ -240,17 +241,18 @@ public int readInt() throws IOException { return buffer.readInt(); } - private final StringBuilder lineBuf = new StringBuilder(); + private StringBuilder lineBuf; @Override public String readLine() throws IOException { - lineBuf.setLength(0); - - loop: while (true) { - if (!buffer.isReadable()) { - return lineBuf.length() > 0 ? lineBuf.toString() : null; - } + if (!buffer.isReadable()) { + return null; + } + if (lineBuf != null) { + lineBuf.setLength(0); + } + loop: do { int c = buffer.readUnsignedByte(); switch (c) { case '\n': @@ -263,11 +265,14 @@ public String readLine() throws IOException { break loop; default: + if (lineBuf == null) { + lineBuf = new StringBuilder(); + } lineBuf.append((char) c); } - } + } while (buffer.isReadable()); - return lineBuf.toString(); + return lineBuf != null && lineBuf.length() > 0 ? lineBuf.toString() : StringUtil.EMPTY_STRING; } @Override diff --git a/buffer/src/main/java/io/netty/buffer/ByteBufUtil.java b/buffer/src/main/java/io/netty/buffer/ByteBufUtil.java index e7afab36728f..54a9f0c9fa02 100644 --- a/buffer/src/main/java/io/netty/buffer/ByteBufUtil.java +++ b/buffer/src/main/java/io/netty/buffer/ByteBufUtil.java @@ -43,6 +43,7 @@ import static io.netty.util.internal.MathUtil.isOutOfBounds; import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; import static io.netty.util.internal.StringUtil.NEWLINE; import static io.netty.util.internal.StringUtil.isSurrogate; @@ -53,10 +54,10 @@ public final class ByteBufUtil { private static final InternalLogger logger = InternalLoggerFactory.getInstance(ByteBufUtil.class); - private static final FastThreadLocal CHAR_BUFFERS = new FastThreadLocal() { + private static final FastThreadLocal BYTE_ARRAYS = new FastThreadLocal() { @Override - protected CharBuffer initialValue() throws Exception { - return CharBuffer.allocate(1024); + protected byte[] initialValue() throws Exception { + return PlatformDependent.allocateUninitializedArray(MAX_TL_ARRAY_LEN); } }; @@ -95,6 +96,16 @@ protected CharBuffer initialValue() throws Exception { logger.debug("-Dio.netty.maxThreadLocalCharBufferSize: {}", MAX_CHAR_BUFFER_SIZE); } + static final int MAX_TL_ARRAY_LEN = 1024; + + /** + * Allocates a new array if minLength > {@link ByteBufUtil#MAX_TL_ARRAY_LEN} + */ + static byte[] threadLocalTempArray(int minLength) { + return minLength <= MAX_TL_ARRAY_LEN ? BYTE_ARRAYS.get() + : PlatformDependent.allocateUninitializedArray(minLength); + } + /** * Returns a hex dump * of the specified buffer's readable bytes. @@ -452,8 +463,9 @@ private static int firstIndexOf(ByteBuf buffer, int fromIndex, int toIndex, byte } private static int lastIndexOf(ByteBuf buffer, int fromIndex, int toIndex, byte value) { - fromIndex = Math.min(fromIndex, buffer.capacity()); - if (fromIndex < 0 || buffer.capacity() == 0) { + int capacity = buffer.capacity(); + fromIndex = Math.min(fromIndex, capacity); + if (fromIndex < 0 || capacity == 0) { return -1; } @@ -546,17 +558,8 @@ static int writeUtf8(AbstractByteBuf buffer, int writerIndex, CharSequence seq, buffer._setByte(writerIndex++, WRITE_UTF_UNKNOWN); break; } - if (!Character.isLowSurrogate(c2)) { - buffer._setByte(writerIndex++, WRITE_UTF_UNKNOWN); - buffer._setByte(writerIndex++, Character.isHighSurrogate(c2) ? WRITE_UTF_UNKNOWN : c2); - continue; - } - int codePoint = Character.toCodePoint(c, c2); - // See http://www.unicode.org/versions/Unicode7.0.0/ch03.pdf#G2630. - buffer._setByte(writerIndex++, (byte) (0xf0 | (codePoint >> 18))); - buffer._setByte(writerIndex++, (byte) (0x80 | ((codePoint >> 12) & 0x3f))); - buffer._setByte(writerIndex++, (byte) (0x80 | ((codePoint >> 6) & 0x3f))); - buffer._setByte(writerIndex++, (byte) (0x80 | (codePoint & 0x3f))); + // Extra method to allow inlining the rest of writeUtf8 which is the most likely code path. + writerIndex = writeUtf8Surrogate(buffer, writerIndex, c, c2); } else { buffer._setByte(writerIndex++, (byte) (0xe0 | (c >> 12))); buffer._setByte(writerIndex++, (byte) (0x80 | ((c >> 6) & 0x3f))); @@ -566,6 +569,21 @@ static int writeUtf8(AbstractByteBuf buffer, int writerIndex, CharSequence seq, return writerIndex - oldWriterIndex; } + private static int writeUtf8Surrogate(AbstractByteBuf buffer, int writerIndex, char c, char c2) { + if (!Character.isLowSurrogate(c2)) { + buffer._setByte(writerIndex++, WRITE_UTF_UNKNOWN); + buffer._setByte(writerIndex++, Character.isHighSurrogate(c2) ? WRITE_UTF_UNKNOWN : c2); + return writerIndex; + } + int codePoint = Character.toCodePoint(c, c2); + // See http://www.unicode.org/versions/Unicode7.0.0/ch03.pdf#G2630. + buffer._setByte(writerIndex++, (byte) (0xf0 | (codePoint >> 18))); + buffer._setByte(writerIndex++, (byte) (0x80 | ((codePoint >> 12) & 0x3f))); + buffer._setByte(writerIndex++, (byte) (0x80 | ((codePoint >> 6) & 0x3f))); + buffer._setByte(writerIndex++, (byte) (0x80 | (codePoint & 0x3f))); + return writerIndex; + } + /** * Returns max bytes length of UTF8 character sequence of the given length. */ @@ -756,52 +774,27 @@ static ByteBuf encodeString0(ByteBufAllocator alloc, boolean enforceHeap, CharBu } } + @SuppressWarnings("deprecation") static String decodeString(ByteBuf src, int readerIndex, int len, Charset charset) { if (len == 0) { return StringUtil.EMPTY_STRING; } - final CharsetDecoder decoder = CharsetUtil.decoder(charset); - final int maxLength = (int) ((double) len * decoder.maxCharsPerByte()); - CharBuffer dst = CHAR_BUFFERS.get(); - if (dst.length() < maxLength) { - dst = CharBuffer.allocate(maxLength); - if (maxLength <= MAX_CHAR_BUFFER_SIZE) { - CHAR_BUFFERS.set(dst); - } + final byte[] array; + final int offset; + + if (src.hasArray()) { + array = src.array(); + offset = src.arrayOffset() + readerIndex; } else { - dst.clear(); + array = threadLocalTempArray(len); + offset = 0; + src.getBytes(readerIndex, array, 0, len); } - if (src.nioBufferCount() == 1) { - decodeString(decoder, src.nioBuffer(readerIndex, len), dst); - } else { - // We use a heap buffer as CharsetDecoder is most likely able to use a fast-path if src and dst buffers - // are both backed by a byte array. - ByteBuf buffer = src.alloc().heapBuffer(len); - try { - buffer.writeBytes(src, readerIndex, len); - // Use internalNioBuffer(...) to reduce object creation. - decodeString(decoder, buffer.internalNioBuffer(buffer.readerIndex(), len), dst); - } finally { - // Release the temporary buffer again. - buffer.release(); - } - } - return dst.flip().toString(); - } - - private static void decodeString(CharsetDecoder decoder, ByteBuffer src, CharBuffer dst) { - try { - CoderResult cr = decoder.decode(src, dst, true); - if (!cr.isUnderflow()) { - cr.throwException(); - } - cr = decoder.flush(dst); - if (!cr.isUnderflow()) { - cr.throwException(); - } - } catch (CharacterCodingException x) { - throw new IllegalStateException(x); + if (CharsetUtil.US_ASCII.equals(charset)) { + // Fast-path for US-ASCII which is used frequently. + return new String(array, 0, offset, len); } + return new String(array, offset, len, charset); } /** @@ -844,13 +837,14 @@ public static byte[] getBytes(ByteBuf buf, int start, int length) { * If {@code copy} is false the underlying storage will be shared, if possible. */ public static byte[] getBytes(ByteBuf buf, int start, int length, boolean copy) { - if (isOutOfBounds(start, length, buf.capacity())) { + int capacity = buf.capacity(); + if (isOutOfBounds(start, length, capacity)) { throw new IndexOutOfBoundsException("expected: " + "0 <= start(" + start + ") <= start + length(" + length - + ") <= " + "buf.capacity(" + buf.capacity() + ')'); + + ") <= " + "buf.capacity(" + capacity + ')'); } if (buf.hasArray()) { - if (copy || start != 0 || length != buf.capacity()) { + if (copy || start != 0 || length != capacity) { int baseOffset = buf.arrayOffset() + start; return Arrays.copyOfRange(buf.array(), baseOffset, baseOffset + length); } else { @@ -858,7 +852,7 @@ public static byte[] getBytes(ByteBuf buf, int start, int length, boolean copy) } } - byte[] v = new byte[length]; + byte[] v = PlatformDependent.allocateUninitializedArray(length); buf.getBytes(start, v); return v; } @@ -1007,9 +1001,7 @@ private static final class HexUtil { } private static String hexDump(ByteBuf buffer, int fromIndex, int length) { - if (length < 0) { - throw new IllegalArgumentException("length: " + length); - } + checkPositiveOrZero(length, "length"); if (length == 0) { return ""; } @@ -1029,9 +1021,7 @@ private static String hexDump(ByteBuf buffer, int fromIndex, int length) { } private static String hexDump(byte[] array, int fromIndex, int length) { - if (length < 0) { - throw new IllegalArgumentException("length: " + length); - } + checkPositiveOrZero(length, "length"); if (length == 0) { return ""; } @@ -1413,7 +1403,9 @@ static void readBytes(ByteBufAllocator allocator, ByteBuffer buffer, int positio int chunkLen = Math.min(length, WRITE_CHUNK_SIZE); buffer.clear().position(position); - if (allocator.isDirectBufferPooled()) { + if (length <= MAX_TL_ARRAY_LEN || !allocator.isDirectBufferPooled()) { + getBytes(buffer, threadLocalTempArray(chunkLen), 0, chunkLen, out, length); + } else { // if direct buffers are pooled chances are good that heap buffers are pooled as well. ByteBuf tmpBuf = allocator.heapBuffer(chunkLen); try { @@ -1423,8 +1415,6 @@ static void readBytes(ByteBufAllocator allocator, ByteBuffer buffer, int positio } finally { tmpBuf.release(); } - } else { - getBytes(buffer, new byte[chunkLen], 0, chunkLen, out, length); } } } diff --git a/buffer/src/main/java/io/netty/buffer/CompositeByteBuf.java b/buffer/src/main/java/io/netty/buffer/CompositeByteBuf.java index 7afff7e7f13e..99127cbd2e06 100644 --- a/buffer/src/main/java/io/netty/buffer/CompositeByteBuf.java +++ b/buffer/src/main/java/io/netty/buffer/CompositeByteBuf.java @@ -15,7 +15,11 @@ */ package io.netty.buffer; +import io.netty.util.ByteProcessor; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCountUtil; import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.RecyclableArrayList; import java.io.IOException; import java.io.InputStream; @@ -26,12 +30,12 @@ import java.nio.channels.GatheringByteChannel; import java.nio.channels.ScatteringByteChannel; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.ConcurrentModificationException; import java.util.Iterator; import java.util.List; -import java.util.ListIterator; import java.util.NoSuchElementException; import static io.netty.util.internal.ObjectUtil.checkNotNull; @@ -48,70 +52,94 @@ public class CompositeByteBuf extends AbstractReferenceCountedByteBuf implements private final ByteBufAllocator alloc; private final boolean direct; - private final ComponentList components; private final int maxNumComponents; + private int componentCount; + private Component[] components; // resized when needed + private boolean freed; - public CompositeByteBuf(ByteBufAllocator alloc, boolean direct, int maxNumComponents) { + private CompositeByteBuf(ByteBufAllocator alloc, boolean direct, int maxNumComponents, int initSize) { super(AbstractByteBufAllocator.DEFAULT_MAX_CAPACITY); if (alloc == null) { throw new NullPointerException("alloc"); } + if (maxNumComponents < 1) { + throw new IllegalArgumentException( + "maxNumComponents: " + maxNumComponents + " (expected: >= 1)"); + } this.alloc = alloc; this.direct = direct; this.maxNumComponents = maxNumComponents; - components = newList(maxNumComponents); + components = newCompArray(initSize, maxNumComponents); } - public CompositeByteBuf(ByteBufAllocator alloc, boolean direct, int maxNumComponents, ByteBuf... buffers) { - this(alloc, direct, maxNumComponents, buffers, 0, buffers.length); + public CompositeByteBuf(ByteBufAllocator alloc, boolean direct, int maxNumComponents) { + this(alloc, direct, maxNumComponents, 0); } - CompositeByteBuf( - ByteBufAllocator alloc, boolean direct, int maxNumComponents, ByteBuf[] buffers, int offset, int len) { - super(AbstractByteBufAllocator.DEFAULT_MAX_CAPACITY); - if (alloc == null) { - throw new NullPointerException("alloc"); - } - if (maxNumComponents < 2) { - throw new IllegalArgumentException( - "maxNumComponents: " + maxNumComponents + " (expected: >= 2)"); - } + public CompositeByteBuf(ByteBufAllocator alloc, boolean direct, int maxNumComponents, ByteBuf... buffers) { + this(alloc, direct, maxNumComponents, buffers, 0); + } - this.alloc = alloc; - this.direct = direct; - this.maxNumComponents = maxNumComponents; - components = newList(maxNumComponents); + CompositeByteBuf(ByteBufAllocator alloc, boolean direct, int maxNumComponents, + ByteBuf[] buffers, int offset) { + this(alloc, direct, maxNumComponents, buffers.length - offset); - addComponents0(false, 0, buffers, offset, len); + addComponents0(false, 0, buffers, offset); consolidateIfNeeded(); - setIndex(0, capacity()); + setIndex0(0, capacity()); } public CompositeByteBuf( ByteBufAllocator alloc, boolean direct, int maxNumComponents, Iterable buffers) { - super(AbstractByteBufAllocator.DEFAULT_MAX_CAPACITY); - if (alloc == null) { - throw new NullPointerException("alloc"); + this(alloc, direct, maxNumComponents, + buffers instanceof Collection ? ((Collection) buffers).size() : 0); + + addComponents(false, 0, buffers); + setIndex(0, capacity()); + } + + // support passing arrays of other types instead of having to copy to a ByteBuf[] first + interface ByteWrapper { + ByteBuf wrap(T bytes); + boolean isEmpty(T bytes); + } + + static final ByteWrapper BYTE_ARRAY_WRAPPER = new ByteWrapper() { + @Override + public ByteBuf wrap(byte[] bytes) { + return Unpooled.wrappedBuffer(bytes); } - if (maxNumComponents < 2) { - throw new IllegalArgumentException( - "maxNumComponents: " + maxNumComponents + " (expected: >= 2)"); + @Override + public boolean isEmpty(byte[] bytes) { + return bytes.length == 0; } + }; - this.alloc = alloc; - this.direct = direct; - this.maxNumComponents = maxNumComponents; - components = newList(maxNumComponents); + static final ByteWrapper BYTE_BUFFER_WRAPPER = new ByteWrapper() { + @Override + public ByteBuf wrap(ByteBuffer bytes) { + return Unpooled.wrappedBuffer(bytes); + } + @Override + public boolean isEmpty(ByteBuffer bytes) { + return !bytes.hasRemaining(); + } + }; - addComponents0(false, 0, buffers); + CompositeByteBuf(ByteBufAllocator alloc, boolean direct, int maxNumComponents, + ByteWrapper wrapper, T[] buffers, int offset) { + this(alloc, direct, maxNumComponents, buffers.length - offset); + + addComponents0(false, 0, wrapper, buffers, offset); consolidateIfNeeded(); setIndex(0, capacity()); } - private static ComponentList newList(int maxNumComponents) { - return new ComponentList(Math.min(AbstractByteBufAllocator.DEFAULT_MAX_COMPONENTS, maxNumComponents)); + private static Component[] newCompArray(int initComponents, int maxNumComponents) { + int capacityGuess = Math.min(AbstractByteBufAllocator.DEFAULT_MAX_COMPONENTS, maxNumComponents); + return new Component[Math.max(initComponents, capacityGuess)]; } // Special constructor used by WrappedCompositeByteBuf @@ -129,8 +157,8 @@ private static ComponentList newList(int maxNumComponents) { * Be aware that this method does not increase the {@code writerIndex} of the {@link CompositeByteBuf}. * If you need to have it increased use {@link #addComponent(boolean, ByteBuf)}. *

- * {@link ByteBuf#release()} ownership of {@code buffer} is transfered to this {@link CompositeByteBuf}. - * @param buffer the {@link ByteBuf} to add. {@link ByteBuf#release()} ownership is transfered to this + * {@link ByteBuf#release()} ownership of {@code buffer} is transferred to this {@link CompositeByteBuf}. + * @param buffer the {@link ByteBuf} to add. {@link ByteBuf#release()} ownership is transferred to this * {@link CompositeByteBuf}. */ public CompositeByteBuf addComponent(ByteBuf buffer) { @@ -143,10 +171,10 @@ public CompositeByteBuf addComponent(ByteBuf buffer) { * Be aware that this method does not increase the {@code writerIndex} of the {@link CompositeByteBuf}. * If you need to have it increased use {@link #addComponents(boolean, ByteBuf[])}. *

- * {@link ByteBuf#release()} ownership of all {@link ByteBuf} objects in {@code buffers} is transfered to this + * {@link ByteBuf#release()} ownership of all {@link ByteBuf} objects in {@code buffers} is transferred to this * {@link CompositeByteBuf}. * @param buffers the {@link ByteBuf}s to add. {@link ByteBuf#release()} ownership of all {@link ByteBuf#release()} - * ownership of all {@link ByteBuf} objects is transfered to this {@link CompositeByteBuf}. + * ownership of all {@link ByteBuf} objects is transferred to this {@link CompositeByteBuf}. */ public CompositeByteBuf addComponents(ByteBuf... buffers) { return addComponents(false, buffers); @@ -158,10 +186,10 @@ public CompositeByteBuf addComponents(ByteBuf... buffers) { * Be aware that this method does not increase the {@code writerIndex} of the {@link CompositeByteBuf}. * If you need to have it increased use {@link #addComponents(boolean, Iterable)}. *

- * {@link ByteBuf#release()} ownership of all {@link ByteBuf} objects in {@code buffers} is transfered to this + * {@link ByteBuf#release()} ownership of all {@link ByteBuf} objects in {@code buffers} is transferred to this * {@link CompositeByteBuf}. * @param buffers the {@link ByteBuf}s to add. {@link ByteBuf#release()} ownership of all {@link ByteBuf#release()} - * ownership of all {@link ByteBuf} objects is transfered to this {@link CompositeByteBuf}. + * ownership of all {@link ByteBuf} objects is transferred to this {@link CompositeByteBuf}. */ public CompositeByteBuf addComponents(Iterable buffers) { return addComponents(false, buffers); @@ -173,9 +201,9 @@ public CompositeByteBuf addComponents(Iterable buffers) { * Be aware that this method does not increase the {@code writerIndex} of the {@link CompositeByteBuf}. * If you need to have it increased use {@link #addComponent(boolean, int, ByteBuf)}. *

- * {@link ByteBuf#release()} ownership of {@code buffer} is transfered to this {@link CompositeByteBuf}. + * {@link ByteBuf#release()} ownership of {@code buffer} is transferred to this {@link CompositeByteBuf}. * @param cIndex the index on which the {@link ByteBuf} will be added. - * @param buffer the {@link ByteBuf} to add. {@link ByteBuf#release()} ownership is transfered to this + * @param buffer the {@link ByteBuf} to add. {@link ByteBuf#release()} ownership is transferred to this * {@link CompositeByteBuf}. */ public CompositeByteBuf addComponent(int cIndex, ByteBuf buffer) { @@ -186,28 +214,26 @@ public CompositeByteBuf addComponent(int cIndex, ByteBuf buffer) { * Add the given {@link ByteBuf} and increase the {@code writerIndex} if {@code increaseWriterIndex} is * {@code true}. * - * {@link ByteBuf#release()} ownership of {@code buffer} is transfered to this {@link CompositeByteBuf}. - * @param buffer the {@link ByteBuf} to add. {@link ByteBuf#release()} ownership is transfered to this + * {@link ByteBuf#release()} ownership of {@code buffer} is transferred to this {@link CompositeByteBuf}. + * @param buffer the {@link ByteBuf} to add. {@link ByteBuf#release()} ownership is transferred to this * {@link CompositeByteBuf}. */ public CompositeByteBuf addComponent(boolean increaseWriterIndex, ByteBuf buffer) { - checkNotNull(buffer, "buffer"); - addComponent0(increaseWriterIndex, components.size(), buffer); - consolidateIfNeeded(); - return this; + return addComponent(increaseWriterIndex, componentCount, buffer); } /** * Add the given {@link ByteBuf}s and increase the {@code writerIndex} if {@code increaseWriterIndex} is * {@code true}. * - * {@link ByteBuf#release()} ownership of all {@link ByteBuf} objects in {@code buffers} is transfered to this + * {@link ByteBuf#release()} ownership of all {@link ByteBuf} objects in {@code buffers} is transferred to this * {@link CompositeByteBuf}. * @param buffers the {@link ByteBuf}s to add. {@link ByteBuf#release()} ownership of all {@link ByteBuf#release()} - * ownership of all {@link ByteBuf} objects is transfered to this {@link CompositeByteBuf}. + * ownership of all {@link ByteBuf} objects is transferred to this {@link CompositeByteBuf}. */ public CompositeByteBuf addComponents(boolean increaseWriterIndex, ByteBuf... buffers) { - addComponents0(increaseWriterIndex, components.size(), buffers, 0, buffers.length); + checkNotNull(buffers, "buffers"); + addComponents0(increaseWriterIndex, componentCount, buffers, 0); consolidateIfNeeded(); return this; } @@ -216,24 +242,22 @@ public CompositeByteBuf addComponents(boolean increaseWriterIndex, ByteBuf... bu * Add the given {@link ByteBuf}s and increase the {@code writerIndex} if {@code increaseWriterIndex} is * {@code true}. * - * {@link ByteBuf#release()} ownership of all {@link ByteBuf} objects in {@code buffers} is transfered to this + * {@link ByteBuf#release()} ownership of all {@link ByteBuf} objects in {@code buffers} is transferred to this * {@link CompositeByteBuf}. * @param buffers the {@link ByteBuf}s to add. {@link ByteBuf#release()} ownership of all {@link ByteBuf#release()} - * ownership of all {@link ByteBuf} objects is transfered to this {@link CompositeByteBuf}. + * ownership of all {@link ByteBuf} objects is transferred to this {@link CompositeByteBuf}. */ public CompositeByteBuf addComponents(boolean increaseWriterIndex, Iterable buffers) { - addComponents0(increaseWriterIndex, components.size(), buffers); - consolidateIfNeeded(); - return this; + return addComponents(increaseWriterIndex, componentCount, buffers); } /** * Add the given {@link ByteBuf} on the specific index and increase the {@code writerIndex} * if {@code increaseWriterIndex} is {@code true}. * - * {@link ByteBuf#release()} ownership of {@code buffer} is transfered to this {@link CompositeByteBuf}. + * {@link ByteBuf#release()} ownership of {@code buffer} is transferred to this {@link CompositeByteBuf}. * @param cIndex the index on which the {@link ByteBuf} will be added. - * @param buffer the {@link ByteBuf} to add. {@link ByteBuf#release()} ownership is transfered to this + * @param buffer the {@link ByteBuf} to add. {@link ByteBuf#release()} ownership is transferred to this * {@link CompositeByteBuf}. */ public CompositeByteBuf addComponent(boolean increaseWriterIndex, int cIndex, ByteBuf buffer) { @@ -252,29 +276,19 @@ private int addComponent0(boolean increaseWriterIndex, int cIndex, ByteBuf buffe try { checkComponentIndex(cIndex); - int readableBytes = buffer.readableBytes(); - // No need to consolidate - just add a component to the list. - @SuppressWarnings("deprecation") - Component c = new Component(buffer.order(ByteOrder.BIG_ENDIAN).slice()); - if (cIndex == components.size()) { - wasAdded = components.add(c); - if (cIndex == 0) { - c.endOffset = readableBytes; - } else { - Component prev = components.get(cIndex - 1); - c.offset = prev.endOffset; - c.endOffset = c.offset + readableBytes; - } - } else { - components.add(cIndex, c); - wasAdded = true; - if (readableBytes != 0) { - updateComponentOffsets(cIndex); - } + Component c = newComponent(buffer, 0); + int readableBytes = c.length(); + + addComp(cIndex, c); + wasAdded = true; + if (readableBytes > 0 && cIndex < componentCount - 1) { + updateComponentOffsets(cIndex); + } else if (cIndex > 0) { + c.reposition(components[cIndex - 1].endOffset); } if (increaseWriterIndex) { - writerIndex(writerIndex() + buffer.readableBytes()); + writerIndex(writerIndex() + readableBytes); } return cIndex; } finally { @@ -284,59 +298,103 @@ private int addComponent0(boolean increaseWriterIndex, int cIndex, ByteBuf buffe } } + @SuppressWarnings("deprecation") + private Component newComponent(ByteBuf buf, int offset) { + if (checkAccessible && !buf.isAccessible()) { + throw new IllegalReferenceCountException(0); + } + int srcIndex = buf.readerIndex(), len = buf.readableBytes(); + ByteBuf slice = null; + // unwrap if already sliced + if (buf instanceof AbstractUnpooledSlicedByteBuf) { + srcIndex += ((AbstractUnpooledSlicedByteBuf) buf).idx(0); + slice = buf; + buf = buf.unwrap(); + } else if (buf instanceof PooledSlicedByteBuf) { + srcIndex += ((PooledSlicedByteBuf) buf).adjustment; + slice = buf; + buf = buf.unwrap(); + } + return new Component(buf.order(ByteOrder.BIG_ENDIAN), srcIndex, offset, len, slice); + } + /** * Add the given {@link ByteBuf}s on the specific index *

* Be aware that this method does not increase the {@code writerIndex} of the {@link CompositeByteBuf}. * If you need to have it increased you need to handle it by your own. *

- * {@link ByteBuf#release()} ownership of all {@link ByteBuf} objects in {@code buffers} is transfered to this + * {@link ByteBuf#release()} ownership of all {@link ByteBuf} objects in {@code buffers} is transferred to this * {@link CompositeByteBuf}. * @param cIndex the index on which the {@link ByteBuf} will be added. {@link ByteBuf#release()} ownership of all - * {@link ByteBuf#release()} ownership of all {@link ByteBuf} objects is transfered to this + * {@link ByteBuf#release()} ownership of all {@link ByteBuf} objects is transferred to this * {@link CompositeByteBuf}. * @param buffers the {@link ByteBuf}s to add. {@link ByteBuf#release()} ownership of all {@link ByteBuf#release()} - * ownership of all {@link ByteBuf} objects is transfered to this {@link CompositeByteBuf}. + * ownership of all {@link ByteBuf} objects is transferred to this {@link CompositeByteBuf}. */ public CompositeByteBuf addComponents(int cIndex, ByteBuf... buffers) { - addComponents0(false, cIndex, buffers, 0, buffers.length); + checkNotNull(buffers, "buffers"); + addComponents0(false, cIndex, buffers, 0); consolidateIfNeeded(); return this; } - private int addComponents0(boolean increaseWriterIndex, int cIndex, ByteBuf[] buffers, int offset, int len) { - checkNotNull(buffers, "buffers"); - int i = offset; + private CompositeByteBuf addComponents0(boolean increaseWriterIndex, + final int cIndex, ByteBuf[] buffers, int arrOffset) { + final int len = buffers.length, count = len - arrOffset; + // only set ci after we've shifted so that finally block logic is always correct + int ci = Integer.MAX_VALUE; try { checkComponentIndex(cIndex); - - // No need for consolidation - while (i < len) { - // Increment i now to prepare for the next iteration and prevent a duplicate release (addComponent0 - // will release if an exception occurs, and we also release in the finally block here). - ByteBuf b = buffers[i++]; + shiftComps(cIndex, count); // will increase componentCount + int nextOffset = cIndex > 0 ? components[cIndex - 1].endOffset : 0; + for (ci = cIndex; arrOffset < len; arrOffset++, ci++) { + ByteBuf b = buffers[arrOffset]; if (b == null) { break; } - cIndex = addComponent0(increaseWriterIndex, cIndex, b) + 1; - int size = components.size(); - if (cIndex > size) { - cIndex = size; - } + Component c = newComponent(b, nextOffset); + components[ci] = c; + nextOffset = c.endOffset; } - return cIndex; + return this; } finally { - for (; i < len; ++i) { - ByteBuf b = buffers[i]; - if (b != null) { - try { - b.release(); - } catch (Throwable ignored) { - // ignore + // ci is now the index following the last successfully added component + if (ci < componentCount) { + if (ci < cIndex + count) { + // we bailed early + removeCompRange(ci, cIndex + count); + for (; arrOffset < len; ++arrOffset) { + ReferenceCountUtil.safeRelease(buffers[arrOffset]); } } + updateComponentOffsets(ci); // only need to do this here for components after the added ones + } + if (increaseWriterIndex && ci > cIndex && ci <= componentCount) { + writerIndex(writerIndex() + components[ci - 1].endOffset - components[cIndex].offset); + } + } + } + + private int addComponents0(boolean increaseWriterIndex, int cIndex, + ByteWrapper wrapper, T[] buffers, int offset) { + checkComponentIndex(cIndex); + + // No need for consolidation + for (int i = offset, len = buffers.length; i < len; i++) { + T b = buffers[i]; + if (b == null) { + break; + } + if (!wrapper.isEmpty(b)) { + cIndex = addComponent0(increaseWriterIndex, cIndex, wrapper.wrap(b)) + 1; + int size = componentCount; + if (cIndex > size) { + cIndex = size; + } } } + return cIndex; } /** @@ -345,50 +403,46 @@ private int addComponents0(boolean increaseWriterIndex, int cIndex, ByteBuf[] bu * Be aware that this method does not increase the {@code writerIndex} of the {@link CompositeByteBuf}. * If you need to have it increased you need to handle it by your own. *

- * {@link ByteBuf#release()} ownership of all {@link ByteBuf} objects in {@code buffers} is transfered to this + * {@link ByteBuf#release()} ownership of all {@link ByteBuf} objects in {@code buffers} is transferred to this * {@link CompositeByteBuf}. * @param cIndex the index on which the {@link ByteBuf} will be added. * @param buffers the {@link ByteBuf}s to add. {@link ByteBuf#release()} ownership of all - * {@link ByteBuf#release()} ownership of all {@link ByteBuf} objects is transfered to this + * {@link ByteBuf#release()} ownership of all {@link ByteBuf} objects is transferred to this * {@link CompositeByteBuf}. */ public CompositeByteBuf addComponents(int cIndex, Iterable buffers) { - addComponents0(false, cIndex, buffers); - consolidateIfNeeded(); - return this; + return addComponents(false, cIndex, buffers); } - private int addComponents0(boolean increaseIndex, int cIndex, Iterable buffers) { + // TODO optimize further, similar to ByteBuf[] version + // (difference here is that we don't know *always* know precise size increase in advance, + // but we do in the most common case that the Iterable is a Collection) + private CompositeByteBuf addComponents(boolean increaseIndex, int cIndex, Iterable buffers) { if (buffers instanceof ByteBuf) { // If buffers also implements ByteBuf (e.g. CompositeByteBuf), it has to go to addComponent(ByteBuf). - return addComponent0(increaseIndex, cIndex, (ByteBuf) buffers); + return addComponent(increaseIndex, cIndex, (ByteBuf) buffers); } checkNotNull(buffers, "buffers"); + Iterator it = buffers.iterator(); + try { + checkComponentIndex(cIndex); - if (!(buffers instanceof Collection)) { - List list = new ArrayList(); - try { - for (ByteBuf b: buffers) { - list.add(b); - } - buffers = list; - } finally { - if (buffers != list) { - for (ByteBuf b: buffers) { - if (b != null) { - try { - b.release(); - } catch (Throwable ignored) { - // ignore - } - } - } + // No need for consolidation + while (it.hasNext()) { + ByteBuf b = it.next(); + if (b == null) { + break; } + cIndex = addComponent0(increaseIndex, cIndex, b) + 1; + cIndex = Math.min(cIndex, componentCount); + } + } finally { + while (it.hasNext()) { + ReferenceCountUtil.safeRelease(it.next()); } } - - Collection col = (Collection) buffers; - return addComponents0(increaseIndex, cIndex, col.toArray(new ByteBuf[col.size()]), 0 , col.size()); + consolidateIfNeeded(); + return this; } /** @@ -398,63 +452,53 @@ private int addComponents0(boolean increaseIndex, int cIndex, Iterable private void consolidateIfNeeded() { // Consolidate if the number of components will exceed the allowed maximum by the current // operation. - final int numComponents = components.size(); - if (numComponents > maxNumComponents) { - final int capacity = components.get(numComponents - 1).endOffset; + int size = componentCount; + if (size > maxNumComponents) { + final int capacity = components[size - 1].endOffset; ByteBuf consolidated = allocBuffer(capacity); + lastAccessed = null; // We're not using foreach to avoid creating an iterator. - for (int i = 0; i < numComponents; i ++) { - Component c = components.get(i); - ByteBuf b = c.buf; - consolidated.writeBytes(b); - c.freeIfNecessary(); + for (int i = 0; i < size; i ++) { + components[i].transferTo(consolidated); } - Component c = new Component(consolidated); - c.endOffset = c.length; - components.clear(); - components.add(c); + + components[0] = new Component(consolidated, 0, 0, capacity, consolidated); + removeCompRange(1, size); } } private void checkComponentIndex(int cIndex) { ensureAccessible(); - if (cIndex < 0 || cIndex > components.size()) { + if (cIndex < 0 || cIndex > componentCount) { throw new IndexOutOfBoundsException(String.format( "cIndex: %d (expected: >= 0 && <= numComponents(%d))", - cIndex, components.size())); + cIndex, componentCount)); } } private void checkComponentIndex(int cIndex, int numComponents) { ensureAccessible(); - if (cIndex < 0 || cIndex + numComponents > components.size()) { + if (cIndex < 0 || cIndex + numComponents > componentCount) { throw new IndexOutOfBoundsException(String.format( "cIndex: %d, numComponents: %d " + "(expected: cIndex >= 0 && cIndex + numComponents <= totalNumComponents(%d))", - cIndex, numComponents, components.size())); + cIndex, numComponents, componentCount)); } } private void updateComponentOffsets(int cIndex) { - int size = components.size(); + int size = componentCount; if (size <= cIndex) { return; } - Component c = components.get(cIndex); - if (cIndex == 0) { - c.offset = 0; - c.endOffset = c.length; - cIndex ++; - } - - for (int i = cIndex; i < size; i ++) { - Component prev = components.get(i - 1); - Component cur = components.get(i); - cur.offset = prev.endOffset; - cur.endOffset = cur.offset + cur.length; + int nextIndex = cIndex > 0 ? components[cIndex - 1].endOffset : 0; + for (; cIndex < size; cIndex++) { + Component c = components[cIndex]; + c.reposition(nextIndex); + nextIndex = c.endOffset; } } @@ -465,9 +509,13 @@ private void updateComponentOffsets(int cIndex) { */ public CompositeByteBuf removeComponent(int cIndex) { checkComponentIndex(cIndex); - Component comp = components.remove(cIndex); - comp.freeIfNecessary(); - if (comp.length > 0) { + Component comp = components[cIndex]; + if (lastAccessed == comp) { + lastAccessed = null; + } + comp.free(); + removeComp(cIndex); + if (comp.length() > 0) { // Only need to call updateComponentOffsets if the length was > 0 updateComponentOffsets(cIndex); } @@ -489,13 +537,16 @@ public CompositeByteBuf removeComponents(int cIndex, int numComponents) { int endIndex = cIndex + numComponents; boolean needsUpdate = false; for (int i = cIndex; i < endIndex; ++i) { - Component c = components.get(i); - if (c.length > 0) { + Component c = components[i]; + if (c.length() > 0) { needsUpdate = true; } - c.freeIfNecessary(); + if (lastAccessed == c) { + lastAccessed = null; + } + c.free(); } - components.removeRange(cIndex, endIndex); + removeCompRange(cIndex, endIndex); if (needsUpdate) { // Only need to call updateComponentOffsets if the length was > 0 @@ -507,10 +558,59 @@ public CompositeByteBuf removeComponents(int cIndex, int numComponents) { @Override public Iterator iterator() { ensureAccessible(); - if (components.isEmpty()) { - return EMPTY_ITERATOR; + return componentCount == 0 ? EMPTY_ITERATOR : new CompositeByteBufIterator(); + } + + @Override + protected int forEachByteAsc0(int start, int end, ByteProcessor processor) throws Exception { + if (end <= start) { + return -1; + } + for (int i = toComponentIndex0(start), length = end - start; length > 0; i++) { + Component c = components[i]; + if (c.offset == c.endOffset) { + continue; // empty + } + ByteBuf s = c.buf; + int localStart = c.idx(start); + int localLength = Math.min(length, c.endOffset - start); + // avoid additional checks in AbstractByteBuf case + int result = s instanceof AbstractByteBuf + ? ((AbstractByteBuf) s).forEachByteAsc0(localStart, localStart + localLength, processor) + : s.forEachByte(localStart, localLength, processor); + if (result != -1) { + return result - c.adjustment; + } + start += localLength; + length -= localLength; } - return new CompositeByteBufIterator(); + return -1; + } + + @Override + protected int forEachByteDesc0(int rStart, int rEnd, ByteProcessor processor) throws Exception { + if (rEnd > rStart) { // rStart *and* rEnd are inclusive + return -1; + } + for (int i = toComponentIndex0(rStart), length = 1 + rStart - rEnd; length > 0; i--) { + Component c = components[i]; + if (c.offset == c.endOffset) { + continue; // empty + } + ByteBuf s = c.buf; + int localRStart = c.idx(length + rEnd); + int localLength = Math.min(length, localRStart), localIndex = localRStart - localLength; + // avoid additional checks in AbstractByteBuf case + int result = s instanceof AbstractByteBuf + ? ((AbstractByteBuf) s).forEachByteDesc0(localRStart - 1, localIndex, processor) + : s.forEachByteDesc(localIndex, localLength, processor); + + if (result != -1) { + return result - c.adjustment; + } + length -= localLength; + } + return -1; } /** @@ -522,50 +622,40 @@ public List decompose(int offset, int length) { return Collections.emptyList(); } - int componentId = toComponentIndex(offset); - List slice = new ArrayList(components.size()); - - // The first component - Component firstC = components.get(componentId); - ByteBuf first = firstC.buf.duplicate(); - first.readerIndex(offset - firstC.offset); - - ByteBuf buf = first; + int componentId = toComponentIndex0(offset); int bytesToSlice = length; - do { - int readableBytes = buf.readableBytes(); - if (bytesToSlice <= readableBytes) { - // Last component - buf.writerIndex(buf.readerIndex() + bytesToSlice); - slice.add(buf); - break; - } else { - // Not the last component - slice.add(buf); - bytesToSlice -= readableBytes; - componentId ++; + // The first component + Component firstC = components[componentId]; - // Fetch the next component. - buf = components.get(componentId).buf.duplicate(); - } - } while (bytesToSlice > 0); + ByteBuf slice = firstC.buf.slice(firstC.idx(offset), Math.min(firstC.endOffset - offset, bytesToSlice)); + bytesToSlice -= slice.readableBytes(); - // Slice all components because only readable bytes are interesting. - for (int i = 0; i < slice.size(); i ++) { - slice.set(i, slice.get(i).slice()); + if (bytesToSlice == 0) { + return Collections.singletonList(slice); } - return slice; + List sliceList = new ArrayList(componentCount - componentId); + sliceList.add(slice); + + // Add all the slices until there is nothing more left and then return the List. + do { + Component component = components[++componentId]; + slice = component.buf.slice(component.idx(component.offset), Math.min(component.length(), bytesToSlice)); + bytesToSlice -= slice.readableBytes(); + sliceList.add(slice); + } while (bytesToSlice > 0); + + return sliceList; } @Override public boolean isDirect() { - int size = components.size(); + int size = componentCount; if (size == 0) { return false; } for (int i = 0; i < size; i++) { - if (!components.get(i).buf.isDirect()) { + if (!components[i].buf.isDirect()) { return false; } } @@ -574,11 +664,11 @@ public boolean isDirect() { @Override public boolean hasArray() { - switch (components.size()) { + switch (componentCount) { case 0: return true; case 1: - return components.get(0).buf.hasArray(); + return components[0].buf.hasArray(); default: return false; } @@ -586,11 +676,11 @@ public boolean hasArray() { @Override public byte[] array() { - switch (components.size()) { + switch (componentCount) { case 0: return EmptyArrays.EMPTY_BYTES; case 1: - return components.get(0).buf.array(); + return components[0].buf.array(); default: throw new UnsupportedOperationException(); } @@ -598,11 +688,12 @@ public byte[] array() { @Override public int arrayOffset() { - switch (components.size()) { + switch (componentCount) { case 0: return 0; case 1: - return components.get(0).buf.arrayOffset(); + Component c = components[0]; + return c.idx(c.buf.arrayOffset()); default: throw new UnsupportedOperationException(); } @@ -610,11 +701,11 @@ public int arrayOffset() { @Override public boolean hasMemoryAddress() { - switch (components.size()) { + switch (componentCount) { case 0: return Unpooled.EMPTY_BUFFER.hasMemoryAddress(); case 1: - return components.get(0).buf.hasMemoryAddress(); + return components[0].buf.hasMemoryAddress(); default: return false; } @@ -622,11 +713,12 @@ public boolean hasMemoryAddress() { @Override public long memoryAddress() { - switch (components.size()) { + switch (componentCount) { case 0: return Unpooled.EMPTY_BUFFER.memoryAddress(); case 1: - return components.get(0).buf.memoryAddress(); + Component c = components[0]; + return c.buf.memoryAddress() + c.adjustment; default: throw new UnsupportedOperationException(); } @@ -634,51 +726,45 @@ public long memoryAddress() { @Override public int capacity() { - final int numComponents = components.size(); - if (numComponents == 0) { - return 0; - } - return components.get(numComponents - 1).endOffset; + int size = componentCount; + return size > 0 ? components[size - 1].endOffset : 0; } @Override public CompositeByteBuf capacity(int newCapacity) { checkNewCapacity(newCapacity); - int oldCapacity = capacity(); + final int size = componentCount, oldCapacity = capacity(); if (newCapacity > oldCapacity) { final int paddingLength = newCapacity - oldCapacity; - ByteBuf padding; - int nComponents = components.size(); - if (nComponents < maxNumComponents) { - padding = allocBuffer(paddingLength); - padding.setIndex(0, paddingLength); - addComponent0(false, components.size(), padding); - } else { - padding = allocBuffer(paddingLength); - padding.setIndex(0, paddingLength); + ByteBuf padding = allocBuffer(paddingLength).setIndex(0, paddingLength); + addComponent0(false, size, padding); + if (componentCount >= maxNumComponents) { // FIXME: No need to create a padding buffer and consolidate. // Just create a big single buffer and put the current content there. - addComponent0(false, components.size(), padding); consolidateIfNeeded(); } } else if (newCapacity < oldCapacity) { - int bytesToTrim = oldCapacity - newCapacity; - for (ListIterator i = components.listIterator(components.size()); i.hasPrevious();) { - Component c = i.previous(); - if (bytesToTrim >= c.length) { - bytesToTrim -= c.length; - i.remove(); - continue; + lastAccessed = null; + int i = size - 1; + for (int bytesToTrim = oldCapacity - newCapacity; i >= 0; i--) { + Component c = components[i]; + final int cLength = c.length(); + if (bytesToTrim < cLength) { + // Trim the last component + c.endOffset -= bytesToTrim; + ByteBuf slice = c.slice; + if (slice != null) { + // We must replace the cached slice with a derived one to ensure that + // it can later be released properly in the case of PooledSlicedByteBuf. + c.slice = slice.slice(0, c.length()); + } + break; } - - // Replace the last component with the trimmed slice. - Component newC = new Component(c.buf.slice(0, c.length - bytesToTrim)); - newC.offset = c.offset; - newC.endOffset = newC.offset + newC.length; - i.set(newC); - break; + c.free(); + bytesToTrim -= cLength; } + removeCompRange(i + 1, size); if (readerIndex() > newCapacity) { setIndex(newCapacity, newCapacity); @@ -703,7 +789,7 @@ public ByteOrder order() { * Return the current number of {@link ByteBuf}'s that are composed in this instance */ public int numComponents() { - return components.size(); + return componentCount; } /** @@ -718,10 +804,21 @@ public int maxNumComponents() { */ public int toComponentIndex(int offset) { checkIndex(offset); + return toComponentIndex0(offset); + } - for (int low = 0, high = components.size(); low <= high;) { + private int toComponentIndex0(int offset) { + int size = componentCount; + if (offset == 0) { // fast-path zero offset + for (int i = 0; i < size; i++) { + if (components[i].endOffset > 0) { + return i; + } + } + } + for (int low = 0, high = size; low <= high;) { int mid = low + high >>> 1; - Component c = components.get(mid); + Component c = components[mid]; if (offset >= c.endOffset) { low = mid + 1; } else if (offset < c.offset) { @@ -736,25 +833,26 @@ public int toComponentIndex(int offset) { public int toByteIndex(int cIndex) { checkComponentIndex(cIndex); - return components.get(cIndex).offset; + return components[cIndex].offset; } @Override public byte getByte(int index) { - return _getByte(index); + Component c = findComponent(index); + return c.buf.getByte(c.idx(index)); } @Override protected byte _getByte(int index) { - Component c = findComponent(index); - return c.buf.getByte(index - c.offset); + Component c = findComponent0(index); + return c.buf.getByte(c.idx(index)); } @Override protected short _getShort(int index) { - Component c = findComponent(index); + Component c = findComponent0(index); if (index + 2 <= c.endOffset) { - return c.buf.getShort(index - c.offset); + return c.buf.getShort(c.idx(index)); } else if (order() == ByteOrder.BIG_ENDIAN) { return (short) ((_getByte(index) & 0xff) << 8 | _getByte(index + 1) & 0xff); } else { @@ -764,9 +862,9 @@ protected short _getShort(int index) { @Override protected short _getShortLE(int index) { - Component c = findComponent(index); + Component c = findComponent0(index); if (index + 2 <= c.endOffset) { - return c.buf.getShortLE(index - c.offset); + return c.buf.getShortLE(c.idx(index)); } else if (order() == ByteOrder.BIG_ENDIAN) { return (short) (_getByte(index) & 0xff | (_getByte(index + 1) & 0xff) << 8); } else { @@ -776,9 +874,9 @@ protected short _getShortLE(int index) { @Override protected int _getUnsignedMedium(int index) { - Component c = findComponent(index); + Component c = findComponent0(index); if (index + 3 <= c.endOffset) { - return c.buf.getUnsignedMedium(index - c.offset); + return c.buf.getUnsignedMedium(c.idx(index)); } else if (order() == ByteOrder.BIG_ENDIAN) { return (_getShort(index) & 0xffff) << 8 | _getByte(index + 2) & 0xff; } else { @@ -788,9 +886,9 @@ protected int _getUnsignedMedium(int index) { @Override protected int _getUnsignedMediumLE(int index) { - Component c = findComponent(index); + Component c = findComponent0(index); if (index + 3 <= c.endOffset) { - return c.buf.getUnsignedMediumLE(index - c.offset); + return c.buf.getUnsignedMediumLE(c.idx(index)); } else if (order() == ByteOrder.BIG_ENDIAN) { return _getShortLE(index) & 0xffff | (_getByte(index + 2) & 0xff) << 16; } else { @@ -800,9 +898,9 @@ protected int _getUnsignedMediumLE(int index) { @Override protected int _getInt(int index) { - Component c = findComponent(index); + Component c = findComponent0(index); if (index + 4 <= c.endOffset) { - return c.buf.getInt(index - c.offset); + return c.buf.getInt(c.idx(index)); } else if (order() == ByteOrder.BIG_ENDIAN) { return (_getShort(index) & 0xffff) << 16 | _getShort(index + 2) & 0xffff; } else { @@ -812,9 +910,9 @@ protected int _getInt(int index) { @Override protected int _getIntLE(int index) { - Component c = findComponent(index); + Component c = findComponent0(index); if (index + 4 <= c.endOffset) { - return c.buf.getIntLE(index - c.offset); + return c.buf.getIntLE(c.idx(index)); } else if (order() == ByteOrder.BIG_ENDIAN) { return _getShortLE(index) & 0xffff | (_getShortLE(index + 2) & 0xffff) << 16; } else { @@ -824,9 +922,9 @@ protected int _getIntLE(int index) { @Override protected long _getLong(int index) { - Component c = findComponent(index); + Component c = findComponent0(index); if (index + 8 <= c.endOffset) { - return c.buf.getLong(index - c.offset); + return c.buf.getLong(c.idx(index)); } else if (order() == ByteOrder.BIG_ENDIAN) { return (_getInt(index) & 0xffffffffL) << 32 | _getInt(index + 4) & 0xffffffffL; } else { @@ -836,9 +934,9 @@ protected long _getLong(int index) { @Override protected long _getLongLE(int index) { - Component c = findComponent(index); + Component c = findComponent0(index); if (index + 8 <= c.endOffset) { - return c.buf.getLongLE(index - c.offset); + return c.buf.getLongLE(c.idx(index)); } else if (order() == ByteOrder.BIG_ENDIAN) { return _getIntLE(index) & 0xffffffffL | (_getIntLE(index + 4) & 0xffffffffL) << 32; } else { @@ -853,13 +951,11 @@ public CompositeByteBuf getBytes(int index, byte[] dst, int dstIndex, int length return this; } - int i = toComponentIndex(index); + int i = toComponentIndex0(index); while (length > 0) { - Component c = components.get(i); - ByteBuf s = c.buf; - int adjustment = c.offset; - int localLength = Math.min(length, s.capacity() - (index - adjustment)); - s.getBytes(index - adjustment, dst, dstIndex, localLength); + Component c = components[i]; + int localLength = Math.min(length, c.endOffset - index); + c.buf.getBytes(c.idx(index), dst, dstIndex, localLength); index += localLength; dstIndex += localLength; length -= localLength; @@ -878,15 +974,13 @@ public CompositeByteBuf getBytes(int index, ByteBuffer dst) { return this; } - int i = toComponentIndex(index); + int i = toComponentIndex0(index); try { while (length > 0) { - Component c = components.get(i); - ByteBuf s = c.buf; - int adjustment = c.offset; - int localLength = Math.min(length, s.capacity() - (index - adjustment)); + Component c = components[i]; + int localLength = Math.min(length, c.endOffset - index); dst.limit(dst.position() + localLength); - s.getBytes(index - adjustment, dst); + c.buf.getBytes(c.idx(index), dst); index += localLength; length -= localLength; i ++; @@ -904,13 +998,11 @@ public CompositeByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int lengt return this; } - int i = toComponentIndex(index); + int i = toComponentIndex0(index); while (length > 0) { - Component c = components.get(i); - ByteBuf s = c.buf; - int adjustment = c.offset; - int localLength = Math.min(length, s.capacity() - (index - adjustment)); - s.getBytes(index - adjustment, dst, dstIndex, localLength); + Component c = components[i]; + int localLength = Math.min(length, c.endOffset - index); + c.buf.getBytes(c.idx(index), dst, dstIndex, localLength); index += localLength; dstIndex += localLength; length -= localLength; @@ -960,13 +1052,11 @@ public CompositeByteBuf getBytes(int index, OutputStream out, int length) throws return this; } - int i = toComponentIndex(index); + int i = toComponentIndex0(index); while (length > 0) { - Component c = components.get(i); - ByteBuf s = c.buf; - int adjustment = c.offset; - int localLength = Math.min(length, s.capacity() - (index - adjustment)); - s.getBytes(index - adjustment, out, localLength); + Component c = components[i]; + int localLength = Math.min(length, c.endOffset - index); + c.buf.getBytes(c.idx(index), out, localLength); index += localLength; length -= localLength; i ++; @@ -977,25 +1067,28 @@ public CompositeByteBuf getBytes(int index, OutputStream out, int length) throws @Override public CompositeByteBuf setByte(int index, int value) { Component c = findComponent(index); - c.buf.setByte(index - c.offset, value); + c.buf.setByte(c.idx(index), value); return this; } @Override protected void _setByte(int index, int value) { - setByte(index, value); + Component c = findComponent0(index); + c.buf.setByte(c.idx(index), value); } @Override public CompositeByteBuf setShort(int index, int value) { - return (CompositeByteBuf) super.setShort(index, value); + checkIndex(index, 2); + _setShort(index, value); + return this; } @Override protected void _setShort(int index, int value) { - Component c = findComponent(index); + Component c = findComponent0(index); if (index + 2 <= c.endOffset) { - c.buf.setShort(index - c.offset, value); + c.buf.setShort(c.idx(index), value); } else if (order() == ByteOrder.BIG_ENDIAN) { _setByte(index, (byte) (value >>> 8)); _setByte(index + 1, (byte) value); @@ -1007,9 +1100,9 @@ protected void _setShort(int index, int value) { @Override protected void _setShortLE(int index, int value) { - Component c = findComponent(index); + Component c = findComponent0(index); if (index + 2 <= c.endOffset) { - c.buf.setShortLE(index - c.offset, value); + c.buf.setShortLE(c.idx(index), value); } else if (order() == ByteOrder.BIG_ENDIAN) { _setByte(index, (byte) value); _setByte(index + 1, (byte) (value >>> 8)); @@ -1021,14 +1114,16 @@ protected void _setShortLE(int index, int value) { @Override public CompositeByteBuf setMedium(int index, int value) { - return (CompositeByteBuf) super.setMedium(index, value); + checkIndex(index, 3); + _setMedium(index, value); + return this; } @Override protected void _setMedium(int index, int value) { - Component c = findComponent(index); + Component c = findComponent0(index); if (index + 3 <= c.endOffset) { - c.buf.setMedium(index - c.offset, value); + c.buf.setMedium(c.idx(index), value); } else if (order() == ByteOrder.BIG_ENDIAN) { _setShort(index, (short) (value >> 8)); _setByte(index + 2, (byte) value); @@ -1040,9 +1135,9 @@ protected void _setMedium(int index, int value) { @Override protected void _setMediumLE(int index, int value) { - Component c = findComponent(index); + Component c = findComponent0(index); if (index + 3 <= c.endOffset) { - c.buf.setMediumLE(index - c.offset, value); + c.buf.setMediumLE(c.idx(index), value); } else if (order() == ByteOrder.BIG_ENDIAN) { _setShortLE(index, (short) value); _setByte(index + 2, (byte) (value >>> 16)); @@ -1054,14 +1149,16 @@ protected void _setMediumLE(int index, int value) { @Override public CompositeByteBuf setInt(int index, int value) { - return (CompositeByteBuf) super.setInt(index, value); + checkIndex(index, 4); + _setInt(index, value); + return this; } @Override protected void _setInt(int index, int value) { - Component c = findComponent(index); + Component c = findComponent0(index); if (index + 4 <= c.endOffset) { - c.buf.setInt(index - c.offset, value); + c.buf.setInt(c.idx(index), value); } else if (order() == ByteOrder.BIG_ENDIAN) { _setShort(index, (short) (value >>> 16)); _setShort(index + 2, (short) value); @@ -1073,9 +1170,9 @@ protected void _setInt(int index, int value) { @Override protected void _setIntLE(int index, int value) { - Component c = findComponent(index); + Component c = findComponent0(index); if (index + 4 <= c.endOffset) { - c.buf.setIntLE(index - c.offset, value); + c.buf.setIntLE(c.idx(index), value); } else if (order() == ByteOrder.BIG_ENDIAN) { _setShortLE(index, (short) value); _setShortLE(index + 2, (short) (value >>> 16)); @@ -1087,14 +1184,16 @@ protected void _setIntLE(int index, int value) { @Override public CompositeByteBuf setLong(int index, long value) { - return (CompositeByteBuf) super.setLong(index, value); + checkIndex(index, 8); + _setLong(index, value); + return this; } @Override protected void _setLong(int index, long value) { - Component c = findComponent(index); + Component c = findComponent0(index); if (index + 8 <= c.endOffset) { - c.buf.setLong(index - c.offset, value); + c.buf.setLong(c.idx(index), value); } else if (order() == ByteOrder.BIG_ENDIAN) { _setInt(index, (int) (value >>> 32)); _setInt(index + 4, (int) value); @@ -1106,9 +1205,9 @@ protected void _setLong(int index, long value) { @Override protected void _setLongLE(int index, long value) { - Component c = findComponent(index); + Component c = findComponent0(index); if (index + 8 <= c.endOffset) { - c.buf.setLongLE(index - c.offset, value); + c.buf.setLongLE(c.idx(index), value); } else if (order() == ByteOrder.BIG_ENDIAN) { _setIntLE(index, (int) value); _setIntLE(index + 4, (int) (value >>> 32)); @@ -1125,13 +1224,11 @@ public CompositeByteBuf setBytes(int index, byte[] src, int srcIndex, int length return this; } - int i = toComponentIndex(index); + int i = toComponentIndex0(index); while (length > 0) { - Component c = components.get(i); - ByteBuf s = c.buf; - int adjustment = c.offset; - int localLength = Math.min(length, s.capacity() - (index - adjustment)); - s.setBytes(index - adjustment, src, srcIndex, localLength); + Component c = components[i]; + int localLength = Math.min(length, c.endOffset - index); + c.buf.setBytes(c.idx(index), src, srcIndex, localLength); index += localLength; srcIndex += localLength; length -= localLength; @@ -1150,15 +1247,13 @@ public CompositeByteBuf setBytes(int index, ByteBuffer src) { return this; } - int i = toComponentIndex(index); + int i = toComponentIndex0(index); try { while (length > 0) { - Component c = components.get(i); - ByteBuf s = c.buf; - int adjustment = c.offset; - int localLength = Math.min(length, s.capacity() - (index - adjustment)); + Component c = components[i]; + int localLength = Math.min(length, c.endOffset - index); src.limit(src.position() + localLength); - s.setBytes(index - adjustment, src); + c.buf.setBytes(c.idx(index), src); index += localLength; length -= localLength; i ++; @@ -1176,13 +1271,11 @@ public CompositeByteBuf setBytes(int index, ByteBuf src, int srcIndex, int lengt return this; } - int i = toComponentIndex(index); + int i = toComponentIndex0(index); while (length > 0) { - Component c = components.get(i); - ByteBuf s = c.buf; - int adjustment = c.offset; - int localLength = Math.min(length, s.capacity() - (index - adjustment)); - s.setBytes(index - adjustment, src, srcIndex, localLength); + Component c = components[i]; + int localLength = Math.min(length, c.endOffset - index); + c.buf.setBytes(c.idx(index), src, srcIndex, localLength); index += localLength; srcIndex += localLength; length -= localLength; @@ -1198,20 +1291,17 @@ public int setBytes(int index, InputStream in, int length) throws IOException { return in.read(EmptyArrays.EMPTY_BYTES); } - int i = toComponentIndex(index); + int i = toComponentIndex0(index); int readBytes = 0; - do { - Component c = components.get(i); - ByteBuf s = c.buf; - int adjustment = c.offset; - int localLength = Math.min(length, s.capacity() - (index - adjustment)); + Component c = components[i]; + int localLength = Math.min(length, c.endOffset - index); if (localLength == 0) { // Skip empty buffer i++; continue; } - int localReadBytes = s.setBytes(index - adjustment, in, localLength); + int localReadBytes = c.buf.setBytes(c.idx(index), in, localLength); if (localReadBytes < 0) { if (readBytes == 0) { return -1; @@ -1220,15 +1310,11 @@ public int setBytes(int index, InputStream in, int length) throws IOException { } } + index += localReadBytes; + length -= localReadBytes; + readBytes += localReadBytes; if (localReadBytes == localLength) { - index += localLength; - length -= localLength; - readBytes += localLength; i ++; - } else { - index += localReadBytes; - length -= localReadBytes; - readBytes += localReadBytes; } } while (length > 0); @@ -1242,19 +1328,17 @@ public int setBytes(int index, ScatteringByteChannel in, int length) throws IOEx return in.read(EMPTY_NIO_BUFFER); } - int i = toComponentIndex(index); + int i = toComponentIndex0(index); int readBytes = 0; do { - Component c = components.get(i); - ByteBuf s = c.buf; - int adjustment = c.offset; - int localLength = Math.min(length, s.capacity() - (index - adjustment)); + Component c = components[i]; + int localLength = Math.min(length, c.endOffset - index); if (localLength == 0) { // Skip empty buffer i++; continue; } - int localReadBytes = s.setBytes(index - adjustment, in, localLength); + int localReadBytes = c.buf.setBytes(c.idx(index), in, localLength); if (localReadBytes == 0) { break; @@ -1268,15 +1352,11 @@ public int setBytes(int index, ScatteringByteChannel in, int length) throws IOEx } } + index += localReadBytes; + length -= localReadBytes; + readBytes += localReadBytes; if (localReadBytes == localLength) { - index += localLength; - length -= localLength; - readBytes += localLength; i ++; - } else { - index += localReadBytes; - length -= localReadBytes; - readBytes += localReadBytes; } } while (length > 0); @@ -1290,19 +1370,17 @@ public int setBytes(int index, FileChannel in, long position, int length) throws return in.read(EMPTY_NIO_BUFFER, position); } - int i = toComponentIndex(index); + int i = toComponentIndex0(index); int readBytes = 0; do { - Component c = components.get(i); - ByteBuf s = c.buf; - int adjustment = c.offset; - int localLength = Math.min(length, s.capacity() - (index - adjustment)); + Component c = components[i]; + int localLength = Math.min(length, c.endOffset - index); if (localLength == 0) { // Skip empty buffer i++; continue; } - int localReadBytes = s.setBytes(index - adjustment, in, position + readBytes, localLength); + int localReadBytes = c.buf.setBytes(c.idx(index), in, position + readBytes, localLength); if (localReadBytes == 0) { break; @@ -1316,15 +1394,11 @@ public int setBytes(int index, FileChannel in, long position, int length) throws } } + index += localReadBytes; + length -= localReadBytes; + readBytes += localReadBytes; if (localReadBytes == localLength) { - index += localLength; - length -= localLength; - readBytes += localLength; i ++; - } else { - index += localReadBytes; - length -= localReadBytes; - readBytes += localReadBytes; } } while (length > 0); @@ -1336,7 +1410,7 @@ public ByteBuf copy(int index, int length) { checkIndex(index, length); ByteBuf dst = allocBuffer(length); if (length != 0) { - copyTo(index, length, toComponentIndex(index), dst); + copyTo(index, length, toComponentIndex0(index), dst); } return dst; } @@ -1346,11 +1420,9 @@ private void copyTo(int index, int length, int componentId, ByteBuf dst) { int i = componentId; while (length > 0) { - Component c = components.get(i); - ByteBuf s = c.buf; - int adjustment = c.offset; - int localLength = Math.min(length, s.capacity() - (index - adjustment)); - s.getBytes(index - adjustment, dst, dstIndex, localLength); + Component c = components[i]; + int localLength = Math.min(length, c.endOffset - index); + c.buf.getBytes(c.idx(index), dst, dstIndex, localLength); index += localLength; dstIndex += localLength; length -= localLength; @@ -1367,7 +1439,8 @@ private void copyTo(int index, int length, int componentId, ByteBuf dst) { * @return buf the {@link ByteBuf} on the specified index */ public ByteBuf component(int cIndex) { - return internalComponent(cIndex).duplicate(); + checkComponentIndex(cIndex); + return components[cIndex].duplicate(); } /** @@ -1377,7 +1450,7 @@ public ByteBuf component(int cIndex) { * @return the {@link ByteBuf} on the specified index */ public ByteBuf componentAtOffset(int offset) { - return internalComponentAtOffset(offset).duplicate(); + return findComponent(offset).duplicate(); } /** @@ -1388,7 +1461,7 @@ public ByteBuf componentAtOffset(int offset) { */ public ByteBuf internalComponent(int cIndex) { checkComponentIndex(cIndex); - return components.get(cIndex).buf; + return components[cIndex].slice(); } /** @@ -1398,21 +1471,40 @@ public ByteBuf internalComponent(int cIndex) { * @param offset the offset for which the {@link ByteBuf} should be returned */ public ByteBuf internalComponentAtOffset(int offset) { - return findComponent(offset).buf; + return findComponent(offset).slice(); } + // weak cache - check it first when looking for component + private Component lastAccessed; + private Component findComponent(int offset) { + Component la = lastAccessed; + if (la != null && offset >= la.offset && offset < la.endOffset) { + ensureAccessible(); + return la; + } checkIndex(offset); + return findIt(offset); + } + + private Component findComponent0(int offset) { + Component la = lastAccessed; + if (la != null && offset >= la.offset && offset < la.endOffset) { + return la; + } + return findIt(offset); + } - for (int low = 0, high = components.size(); low <= high;) { + private Component findIt(int offset) { + for (int low = 0, high = componentCount; low <= high;) { int mid = low + high >>> 1; - Component c = components.get(mid); + Component c = components[mid]; if (offset >= c.endOffset) { low = mid + 1; } else if (offset < c.offset) { high = mid - 1; } else { - assert c.length != 0; + lastAccessed = c; return c; } } @@ -1422,17 +1514,16 @@ private Component findComponent(int offset) { @Override public int nioBufferCount() { - switch (components.size()) { + int size = componentCount; + switch (size) { case 0: return 1; case 1: - return components.get(0).buf.nioBufferCount(); + return components[0].buf.nioBufferCount(); default: int count = 0; - int componentsCount = components.size(); - for (int i = 0; i < componentsCount; i++) { - Component c = components.get(i); - count += c.buf.nioBufferCount(); + for (int i = 0; i < size; i++) { + count += components[i].buf.nioBufferCount(); } return count; } @@ -1440,11 +1531,12 @@ public int nioBufferCount() { @Override public ByteBuffer internalNioBuffer(int index, int length) { - switch (components.size()) { + switch (componentCount) { case 0: return EMPTY_NIO_BUFFER; case 1: - return components.get(0).buf.internalNioBuffer(index, length); + Component c = components[0]; + return c.buf.internalNioBuffer(c.idx(index), length); default: throw new UnsupportedOperationException(); } @@ -1454,19 +1546,24 @@ public ByteBuffer internalNioBuffer(int index, int length) { public ByteBuffer nioBuffer(int index, int length) { checkIndex(index, length); - switch (components.size()) { + switch (componentCount) { case 0: return EMPTY_NIO_BUFFER; case 1: - ByteBuf buf = components.get(0).buf; + Component c = components[0]; + ByteBuf buf = c.buf; if (buf.nioBufferCount() == 1) { - return components.get(0).buf.nioBuffer(index, length); + return buf.nioBuffer(c.idx(index), length); } } - ByteBuffer merged = ByteBuffer.allocate(length).order(order()); ByteBuffer[] buffers = nioBuffers(index, length); + if (buffers.length == 1) { + return buffers[0].duplicate(); + } + + ByteBuffer merged = ByteBuffer.allocate(length).order(order()); for (ByteBuffer buf: buffers) { merged.put(buf); } @@ -1482,29 +1579,32 @@ public ByteBuffer[] nioBuffers(int index, int length) { return new ByteBuffer[] { EMPTY_NIO_BUFFER }; } - List buffers = new ArrayList(components.size()); - int i = toComponentIndex(index); - while (length > 0) { - Component c = components.get(i); - ByteBuf s = c.buf; - int adjustment = c.offset; - int localLength = Math.min(length, s.capacity() - (index - adjustment)); - switch (s.nioBufferCount()) { + RecyclableArrayList buffers = RecyclableArrayList.newInstance(componentCount); + try { + int i = toComponentIndex0(index); + while (length > 0) { + Component c = components[i]; + ByteBuf s = c.buf; + int localLength = Math.min(length, c.endOffset - index); + switch (s.nioBufferCount()) { case 0: throw new UnsupportedOperationException(); case 1: - buffers.add(s.nioBuffer(index - adjustment, localLength)); + buffers.add(s.nioBuffer(c.idx(index), localLength)); break; default: - Collections.addAll(buffers, s.nioBuffers(index - adjustment, localLength)); + Collections.addAll(buffers, s.nioBuffers(c.idx(index), localLength)); + } + + index += localLength; + length -= localLength; + i ++; } - index += localLength; - length -= localLength; - i ++; + return buffers.toArray(new ByteBuffer[0]); + } finally { + buffers.recycle(); } - - return buffers.toArray(new ByteBuffer[buffers.size()]); } /** @@ -1512,25 +1612,20 @@ public ByteBuffer[] nioBuffers(int index, int length) { */ public CompositeByteBuf consolidate() { ensureAccessible(); - final int numComponents = numComponents(); + final int numComponents = componentCount; if (numComponents <= 1) { return this; } - final Component last = components.get(numComponents - 1); - final int capacity = last.endOffset; + final int capacity = components[numComponents - 1].endOffset; final ByteBuf consolidated = allocBuffer(capacity); for (int i = 0; i < numComponents; i ++) { - Component c = components.get(i); - ByteBuf b = c.buf; - consolidated.writeBytes(b); - c.freeIfNecessary(); + components[i].transferTo(consolidated); } - - components.clear(); - components.add(new Component(consolidated)); - updateComponentOffsets(0); + lastAccessed = null; + components[0] = new Component(consolidated, 0, 0, capacity, consolidated); + removeCompRange(1, numComponents); return this; } @@ -1547,19 +1642,16 @@ public CompositeByteBuf consolidate(int cIndex, int numComponents) { } final int endCIndex = cIndex + numComponents; - final Component last = components.get(endCIndex - 1); - final int capacity = last.endOffset - components.get(cIndex).offset; + final Component last = components[endCIndex - 1]; + final int capacity = last.endOffset - components[cIndex].offset; final ByteBuf consolidated = allocBuffer(capacity); for (int i = cIndex; i < endCIndex; i ++) { - Component c = components.get(i); - ByteBuf b = c.buf; - consolidated.writeBytes(b); - c.freeIfNecessary(); + components[i].transferTo(consolidated); } - - components.removeRange(cIndex + 1, endCIndex); - components.set(cIndex, new Component(consolidated)); + lastAccessed = null; + removeCompRange(cIndex + 1, endCIndex); + components[cIndex] = new Component(consolidated, 0, 0, capacity, consolidated); updateComponentOffsets(cIndex); return this; } @@ -1577,25 +1669,26 @@ public CompositeByteBuf discardReadComponents() { // Discard everything if (readerIndex = writerIndex = capacity). int writerIndex = writerIndex(); if (readerIndex == writerIndex && writerIndex == capacity()) { - int size = components.size(); - for (int i = 0; i < size; i++) { - components.get(i).freeIfNecessary(); + for (int i = 0, size = componentCount; i < size; i++) { + components[i].free(); } - components.clear(); + lastAccessed = null; + clearComps(); setIndex(0, 0); adjustMarkers(readerIndex); return this; } // Remove read components. - int firstComponentId = toComponentIndex(readerIndex); + int firstComponentId = toComponentIndex0(readerIndex); for (int i = 0; i < firstComponentId; i ++) { - components.get(i).freeIfNecessary(); + components[i].free(); } - components.removeRange(0, firstComponentId); + lastAccessed = null; + removeCompRange(0, firstComponentId); // Update indexes and markers. - Component first = components.get(0); + Component first = components[0]; int offset = first.offset; updateComponentOffsets(0); setIndex(readerIndex - offset, writerIndex - offset); @@ -1614,34 +1707,49 @@ public CompositeByteBuf discardReadBytes() { // Discard everything if (readerIndex = writerIndex = capacity). int writerIndex = writerIndex(); if (readerIndex == writerIndex && writerIndex == capacity()) { - int size = components.size(); - for (int i = 0; i < size; i++) { - components.get(i).freeIfNecessary(); + for (int i = 0, size = componentCount; i < size; i++) { + components[i].free(); } - components.clear(); + lastAccessed = null; + clearComps(); setIndex(0, 0); adjustMarkers(readerIndex); return this; } // Remove read components. - int firstComponentId = toComponentIndex(readerIndex); + int firstComponentId = toComponentIndex0(readerIndex); for (int i = 0; i < firstComponentId; i ++) { - components.get(i).freeIfNecessary(); + Component c = components[i]; + c.free(); + if (lastAccessed == c) { + lastAccessed = null; + } } // Remove or replace the first readable component with a new slice. - Component c = components.get(firstComponentId); - int adjustment = readerIndex - c.offset; - if (adjustment == c.length) { + Component c = components[firstComponentId]; + if (readerIndex == c.endOffset) { // new slice would be empty, so remove instead + c.free(); + if (lastAccessed == c) { + lastAccessed = null; + } firstComponentId++; } else { - Component newC = new Component(c.buf.slice(adjustment, c.length - adjustment)); - components.set(firstComponentId, newC); + int trimmedBytes = readerIndex - c.offset; + c.offset = 0; + c.endOffset -= readerIndex; + c.adjustment += readerIndex; + ByteBuf slice = c.slice; + if (slice != null) { + // We must replace the cached slice with a derived one to ensure that + // it can later be released properly in the case of PooledSlicedByteBuf. + c.slice = slice.slice(trimmedBytes, c.length()); + } } - components.removeRange(0, firstComponentId); + removeCompRange(0, firstComponentId); // Update indexes and markers. updateComponentOffsets(0); @@ -1658,253 +1766,338 @@ private ByteBuf allocBuffer(int capacity) { public String toString() { String result = super.toString(); result = result.substring(0, result.length() - 1); - return result + ", components=" + components.size() + ')'; + return result + ", components=" + componentCount + ')'; } private static final class Component { final ByteBuf buf; - final int length; + int adjustment; int offset; int endOffset; - Component(ByteBuf buf) { + private ByteBuf slice; // cached slice, may be null + + Component(ByteBuf buf, int srcOffset, int offset, int len, ByteBuf slice) { this.buf = buf; - length = buf.readableBytes(); + this.offset = offset; + this.endOffset = offset + len; + this.adjustment = srcOffset - offset; + this.slice = slice; } - void freeIfNecessary() { - buf.release(); // We should not get a NPE here. If so, it must be a bug. + int idx(int index) { + return index + adjustment; + } + + int length() { + return endOffset - offset; + } + + void reposition(int newOffset) { + int move = newOffset - offset; + endOffset += move; + adjustment -= move; + offset = newOffset; + } + + // copy then release + void transferTo(ByteBuf dst) { + dst.writeBytes(buf, idx(offset), length()); + free(); + } + + ByteBuf slice() { + return slice != null ? slice : (slice = buf.slice(idx(offset), length())); + } + + ByteBuf duplicate() { + return buf.duplicate().setIndex(idx(offset), idx(endOffset)); + } + + void free() { + // Release the slice if present since it may have a different + // refcount to the unwrapped buf if it is a PooledSlicedByteBuf + ByteBuf buffer = slice; + if (buffer != null) { + buffer.release(); + } else { + buf.release(); + } + // null out in either case since it could be racy if set lazily (but not + // in the case we care about, where it will have been set in the ctor) + slice = null; } } @Override public CompositeByteBuf readerIndex(int readerIndex) { - return (CompositeByteBuf) super.readerIndex(readerIndex); + super.readerIndex(readerIndex); + return this; } @Override public CompositeByteBuf writerIndex(int writerIndex) { - return (CompositeByteBuf) super.writerIndex(writerIndex); + super.writerIndex(writerIndex); + return this; } @Override public CompositeByteBuf setIndex(int readerIndex, int writerIndex) { - return (CompositeByteBuf) super.setIndex(readerIndex, writerIndex); + super.setIndex(readerIndex, writerIndex); + return this; } @Override public CompositeByteBuf clear() { - return (CompositeByteBuf) super.clear(); + super.clear(); + return this; } @Override public CompositeByteBuf markReaderIndex() { - return (CompositeByteBuf) super.markReaderIndex(); + super.markReaderIndex(); + return this; } @Override public CompositeByteBuf resetReaderIndex() { - return (CompositeByteBuf) super.resetReaderIndex(); + super.resetReaderIndex(); + return this; } @Override public CompositeByteBuf markWriterIndex() { - return (CompositeByteBuf) super.markWriterIndex(); + super.markWriterIndex(); + return this; } @Override public CompositeByteBuf resetWriterIndex() { - return (CompositeByteBuf) super.resetWriterIndex(); + super.resetWriterIndex(); + return this; } @Override public CompositeByteBuf ensureWritable(int minWritableBytes) { - return (CompositeByteBuf) super.ensureWritable(minWritableBytes); + super.ensureWritable(minWritableBytes); + return this; } @Override public CompositeByteBuf getBytes(int index, ByteBuf dst) { - return (CompositeByteBuf) super.getBytes(index, dst); + return getBytes(index, dst, dst.writableBytes()); } @Override public CompositeByteBuf getBytes(int index, ByteBuf dst, int length) { - return (CompositeByteBuf) super.getBytes(index, dst, length); + getBytes(index, dst, dst.writerIndex(), length); + dst.writerIndex(dst.writerIndex() + length); + return this; } @Override public CompositeByteBuf getBytes(int index, byte[] dst) { - return (CompositeByteBuf) super.getBytes(index, dst); + return getBytes(index, dst, 0, dst.length); } @Override public CompositeByteBuf setBoolean(int index, boolean value) { - return (CompositeByteBuf) super.setBoolean(index, value); + return setByte(index, value? 1 : 0); } @Override public CompositeByteBuf setChar(int index, int value) { - return (CompositeByteBuf) super.setChar(index, value); + return setShort(index, value); } @Override public CompositeByteBuf setFloat(int index, float value) { - return (CompositeByteBuf) super.setFloat(index, value); + return setInt(index, Float.floatToRawIntBits(value)); } @Override public CompositeByteBuf setDouble(int index, double value) { - return (CompositeByteBuf) super.setDouble(index, value); + return setLong(index, Double.doubleToRawLongBits(value)); } @Override public CompositeByteBuf setBytes(int index, ByteBuf src) { - return (CompositeByteBuf) super.setBytes(index, src); + super.setBytes(index, src, src.readableBytes()); + return this; } @Override public CompositeByteBuf setBytes(int index, ByteBuf src, int length) { - return (CompositeByteBuf) super.setBytes(index, src, length); + super.setBytes(index, src, length); + return this; } @Override public CompositeByteBuf setBytes(int index, byte[] src) { - return (CompositeByteBuf) super.setBytes(index, src); + return setBytes(index, src, 0, src.length); } @Override public CompositeByteBuf setZero(int index, int length) { - return (CompositeByteBuf) super.setZero(index, length); + super.setZero(index, length); + return this; } @Override public CompositeByteBuf readBytes(ByteBuf dst) { - return (CompositeByteBuf) super.readBytes(dst); + super.readBytes(dst, dst.writableBytes()); + return this; } @Override public CompositeByteBuf readBytes(ByteBuf dst, int length) { - return (CompositeByteBuf) super.readBytes(dst, length); + super.readBytes(dst, length); + return this; } @Override public CompositeByteBuf readBytes(ByteBuf dst, int dstIndex, int length) { - return (CompositeByteBuf) super.readBytes(dst, dstIndex, length); + super.readBytes(dst, dstIndex, length); + return this; } @Override public CompositeByteBuf readBytes(byte[] dst) { - return (CompositeByteBuf) super.readBytes(dst); + super.readBytes(dst, 0, dst.length); + return this; } @Override public CompositeByteBuf readBytes(byte[] dst, int dstIndex, int length) { - return (CompositeByteBuf) super.readBytes(dst, dstIndex, length); + super.readBytes(dst, dstIndex, length); + return this; } @Override public CompositeByteBuf readBytes(ByteBuffer dst) { - return (CompositeByteBuf) super.readBytes(dst); + super.readBytes(dst); + return this; } @Override public CompositeByteBuf readBytes(OutputStream out, int length) throws IOException { - return (CompositeByteBuf) super.readBytes(out, length); + super.readBytes(out, length); + return this; } @Override public CompositeByteBuf skipBytes(int length) { - return (CompositeByteBuf) super.skipBytes(length); + super.skipBytes(length); + return this; } @Override public CompositeByteBuf writeBoolean(boolean value) { - return (CompositeByteBuf) super.writeBoolean(value); + writeByte(value ? 1 : 0); + return this; } @Override public CompositeByteBuf writeByte(int value) { - return (CompositeByteBuf) super.writeByte(value); + ensureWritable0(1); + _setByte(writerIndex++, value); + return this; } @Override public CompositeByteBuf writeShort(int value) { - return (CompositeByteBuf) super.writeShort(value); + super.writeShort(value); + return this; } @Override public CompositeByteBuf writeMedium(int value) { - return (CompositeByteBuf) super.writeMedium(value); + super.writeMedium(value); + return this; } @Override public CompositeByteBuf writeInt(int value) { - return (CompositeByteBuf) super.writeInt(value); + super.writeInt(value); + return this; } @Override public CompositeByteBuf writeLong(long value) { - return (CompositeByteBuf) super.writeLong(value); + super.writeLong(value); + return this; } @Override public CompositeByteBuf writeChar(int value) { - return (CompositeByteBuf) super.writeChar(value); + super.writeShort(value); + return this; } @Override public CompositeByteBuf writeFloat(float value) { - return (CompositeByteBuf) super.writeFloat(value); + super.writeInt(Float.floatToRawIntBits(value)); + return this; } @Override public CompositeByteBuf writeDouble(double value) { - return (CompositeByteBuf) super.writeDouble(value); + super.writeLong(Double.doubleToRawLongBits(value)); + return this; } @Override public CompositeByteBuf writeBytes(ByteBuf src) { - return (CompositeByteBuf) super.writeBytes(src); + super.writeBytes(src, src.readableBytes()); + return this; } @Override public CompositeByteBuf writeBytes(ByteBuf src, int length) { - return (CompositeByteBuf) super.writeBytes(src, length); + super.writeBytes(src, length); + return this; } @Override public CompositeByteBuf writeBytes(ByteBuf src, int srcIndex, int length) { - return (CompositeByteBuf) super.writeBytes(src, srcIndex, length); + super.writeBytes(src, srcIndex, length); + return this; } @Override public CompositeByteBuf writeBytes(byte[] src) { - return (CompositeByteBuf) super.writeBytes(src); + super.writeBytes(src, 0, src.length); + return this; } @Override public CompositeByteBuf writeBytes(byte[] src, int srcIndex, int length) { - return (CompositeByteBuf) super.writeBytes(src, srcIndex, length); + super.writeBytes(src, srcIndex, length); + return this; } @Override public CompositeByteBuf writeBytes(ByteBuffer src) { - return (CompositeByteBuf) super.writeBytes(src); + super.writeBytes(src); + return this; } @Override public CompositeByteBuf writeZero(int length) { - return (CompositeByteBuf) super.writeZero(length); + super.writeZero(length); + return this; } @Override public CompositeByteBuf retain(int increment) { - return (CompositeByteBuf) super.retain(increment); + super.retain(increment); + return this; } @Override public CompositeByteBuf retain() { - return (CompositeByteBuf) super.retain(); + super.retain(); + return this; } @Override @@ -1934,21 +2127,25 @@ protected void deallocate() { } freed = true; - int size = components.size(); // We're not using foreach to avoid creating an iterator. // see https://github.com/netty/netty/issues/2642 - for (int i = 0; i < size; i++) { - components.get(i).freeIfNecessary(); + for (int i = 0, size = componentCount; i < size; i++) { + components[i].free(); } } + @Override + boolean isAccessible() { + return !freed; + } + @Override public ByteBuf unwrap() { return null; } private final class CompositeByteBufIterator implements Iterator { - private final int size = components.size(); + private final int size = numComponents(); private int index; @Override @@ -1958,14 +2155,14 @@ public boolean hasNext() { @Override public ByteBuf next() { - if (size != components.size()) { + if (size != numComponents()) { throw new ConcurrentModificationException(); } if (!hasNext()) { throw new NoSuchElementException(); } try { - return components.get(index++).buf; + return components[index++].slice(); } catch (IndexOutOfBoundsException e) { throw new ConcurrentModificationException(); } @@ -1977,16 +2174,59 @@ public void remove() { } } - private static final class ComponentList extends ArrayList { + // Component array manipulation - range checking omitted + + private void clearComps() { + removeCompRange(0, componentCount); + } - ComponentList(int initialCapacity) { - super(initialCapacity); + private void removeComp(int i) { + removeCompRange(i, i + 1); + } + + private void removeCompRange(int from, int to) { + if (from >= to) { + return; + } + final int size = componentCount; + assert from >= 0 && to <= size; + if (to < size) { + System.arraycopy(components, to, components, from, size - to); } + int newSize = size - to + from; + for (int i = newSize; i < size; i++) { + components[i] = null; + } + componentCount = newSize; + } - // Expose this methods so we not need to create a new subList just to remove a range of elements. - @Override - public void removeRange(int fromIndex, int toIndex) { - super.removeRange(fromIndex, toIndex); + private void addComp(int i, Component c) { + shiftComps(i, 1); + components[i] = c; + } + + private void shiftComps(int i, int count) { + final int size = componentCount, newSize = size + count; + assert i >= 0 && i <= size && count > 0; + if (newSize > components.length) { + // grow the array + int newArrSize = Math.max(size + (size >> 1), newSize); + Component[] newArr; + if (i == size) { + newArr = Arrays.copyOf(components, newArrSize, Component[].class); + } else { + newArr = new Component[newArrSize]; + if (i > 0) { + System.arraycopy(components, 0, newArr, 0, i); + } + if (i < size) { + System.arraycopy(components, i, newArr, i + count, size - i); + } + } + components = newArr; + } else if (i < size) { + System.arraycopy(components, i, components, i + count, size - i); } + componentCount = newSize; } } diff --git a/buffer/src/main/java/io/netty/buffer/EmptyByteBuf.java b/buffer/src/main/java/io/netty/buffer/EmptyByteBuf.java index b954318a7f81..cbc8a1acf8c3 100644 --- a/buffer/src/main/java/io/netty/buffer/EmptyByteBuf.java +++ b/buffer/src/main/java/io/netty/buffer/EmptyByteBuf.java @@ -16,6 +16,8 @@ package io.netty.buffer; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + import io.netty.util.ByteProcessor; import io.netty.util.internal.EmptyArrays; import io.netty.util.internal.PlatformDependent; @@ -223,9 +225,7 @@ public ByteBuf discardSomeReadBytes() { @Override public ByteBuf ensureWritable(int minWritableBytes) { - if (minWritableBytes < 0) { - throw new IllegalArgumentException("minWritableBytes: " + minWritableBytes + " (expected: >= 0)"); - } + checkPositiveOrZero(minWritableBytes, "minWritableBytes"); if (minWritableBytes != 0) { throw new IndexOutOfBoundsException(); } @@ -234,9 +234,7 @@ public ByteBuf ensureWritable(int minWritableBytes) { @Override public int ensureWritable(int minWritableBytes, boolean force) { - if (minWritableBytes < 0) { - throw new IllegalArgumentException("minWritableBytes: " + minWritableBytes + " (expected: >= 0)"); - } + checkPositiveOrZero(minWritableBytes, "minWritableBytes"); if (minWritableBytes == 0) { return 0; @@ -1048,9 +1046,7 @@ private ByteBuf checkIndex(int index) { } private ByteBuf checkIndex(int index, int length) { - if (length < 0) { - throw new IllegalArgumentException("length: " + length); - } + checkPositiveOrZero(length, "length"); if (index != 0 || length != 0) { throw new IndexOutOfBoundsException(); } @@ -1058,9 +1054,7 @@ private ByteBuf checkIndex(int index, int length) { } private ByteBuf checkLength(int length) { - if (length < 0) { - throw new IllegalArgumentException("length: " + length + " (expected: >= 0)"); - } + checkPositiveOrZero(length, "length"); if (length != 0) { throw new IndexOutOfBoundsException(); } diff --git a/buffer/src/main/java/io/netty/buffer/FixedCompositeByteBuf.java b/buffer/src/main/java/io/netty/buffer/FixedCompositeByteBuf.java index b8d650663f30..08c8d7fc6a5c 100644 --- a/buffer/src/main/java/io/netty/buffer/FixedCompositeByteBuf.java +++ b/buffer/src/main/java/io/netty/buffer/FixedCompositeByteBuf.java @@ -39,7 +39,7 @@ final class FixedCompositeByteBuf extends AbstractReferenceCountedByteBuf { private final int capacity; private final ByteBufAllocator allocator; private final ByteOrder order; - private final Object[] buffers; + private final ByteBuf[] buffers; private final boolean direct; FixedCompositeByteBuf(ByteBufAllocator allocator, ByteBuf... buffers) { @@ -52,8 +52,7 @@ final class FixedCompositeByteBuf extends AbstractReferenceCountedByteBuf { direct = false; } else { ByteBuf b = buffers[0]; - this.buffers = new Object[buffers.length]; - this.buffers[0] = b; + this.buffers = buffers; boolean direct = true; int nioBufferCount = b.nioBufferCount(); int capacity = b.readableBytes(); @@ -68,7 +67,6 @@ final class FixedCompositeByteBuf extends AbstractReferenceCountedByteBuf { if (!b.isDirect()) { direct = false; } - this.buffers[i] = b; } this.nioBufferCount = nioBufferCount; this.capacity = capacity; @@ -232,20 +230,14 @@ private Component findComponent(int index) { int readable = 0; for (int i = 0 ; i < buffers.length; i++) { Component comp = null; - ByteBuf b; - Object obj = buffers[i]; - boolean isBuffer; - if (obj instanceof ByteBuf) { - b = (ByteBuf) obj; - isBuffer = true; - } else { - comp = (Component) obj; + ByteBuf b = buffers[i]; + if (b instanceof Component) { + comp = (Component) b; b = comp.buf; - isBuffer = false; } readable += b.readableBytes(); if (index < readable) { - if (isBuffer) { + if (comp == null) { // Create a new component and store it in the array so it not create a new object // on the next access. comp = new Component(i, readable - b.readableBytes(), b); @@ -261,11 +253,8 @@ private Component findComponent(int index) { * Return the {@link ByteBuf} stored at the given index of the array. */ private ByteBuf buffer(int i) { - Object obj = buffers[i]; - if (obj instanceof ByteBuf) { - return (ByteBuf) obj; - } - return ((Component) obj).buf; + ByteBuf b = buffers[i]; + return b instanceof Component ? ((Component) b).buf : b; } @Override @@ -604,7 +593,7 @@ public ByteBuffer[] nioBuffers(int index, int length) { s = buffer(++i); } - return array.toArray(new ByteBuffer[array.size()]); + return array.toArray(new ByteBuffer[0]); } finally { array.recycle(); } @@ -684,17 +673,16 @@ public String toString() { return result + ", components=" + buffers.length + ')'; } - private static final class Component { + private static final class Component extends WrappedByteBuf { private final int index; private final int offset; - private final ByteBuf buf; private final int endOffset; Component(int index, int offset, ByteBuf buf) { + super(buf); this.index = index; this.offset = offset; endOffset = offset + buf.readableBytes(); - this.buf = buf; } } } diff --git a/buffer/src/main/java/io/netty/buffer/PoolArena.java b/buffer/src/main/java/io/netty/buffer/PoolArena.java index 48593ef34d21..7d75c9b4bb29 100644 --- a/buffer/src/main/java/io/netty/buffer/PoolArena.java +++ b/buffer/src/main/java/io/netty/buffer/PoolArena.java @@ -26,6 +26,7 @@ import java.util.List; import java.util.concurrent.atomic.AtomicInteger; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; import static java.lang.Math.max; abstract class PoolArena implements PoolArenaMetric { @@ -205,7 +206,7 @@ private void allocate(PoolThreadCache cache, PooledByteBuf buf, final int req assert s.doNotDestroy && s.elemSize == normCapacity; long handle = s.allocate(); assert handle >= 0; - s.chunk.initBufWithSubpage(buf, handle, reqCapacity); + s.chunk.initBufWithSubpage(buf, null, handle, reqCapacity); incTinySmallAllocation(tiny); return; } @@ -242,9 +243,8 @@ private void allocateNormal(PooledByteBuf buf, int reqCapacity, int normCapac // Add a new chunk. PoolChunk c = newChunk(pageSize, maxOrder, pageShifts, chunkSize); - long handle = c.allocate(normCapacity); - assert handle > 0; - c.initBuf(buf, handle, reqCapacity); + boolean success = c.allocate(buf, reqCapacity, normCapacity); + assert success; qInit.add(c); } @@ -263,7 +263,7 @@ private void allocateHuge(PooledByteBuf buf, int reqCapacity) { allocationsHuge.increment(); } - void free(PoolChunk chunk, long handle, int normCapacity, PoolThreadCache cache) { + void free(PoolChunk chunk, ByteBuffer nioBuffer, long handle, int normCapacity, PoolThreadCache cache) { if (chunk.unpooled) { int size = chunk.chunkSize(); destroyChunk(chunk); @@ -271,12 +271,12 @@ void free(PoolChunk chunk, long handle, int normCapacity, PoolThreadCache cac deallocationsHuge.increment(); } else { SizeClass sizeClass = sizeClass(normCapacity); - if (cache != null && cache.add(this, chunk, handle, normCapacity, sizeClass)) { + if (cache != null && cache.add(this, chunk, nioBuffer, handle, normCapacity, sizeClass)) { // cached so not free it. return; } - freeChunk(chunk, handle, sizeClass); + freeChunk(chunk, handle, sizeClass, nioBuffer); } } @@ -287,7 +287,7 @@ private SizeClass sizeClass(int normCapacity) { return isTiny(normCapacity) ? SizeClass.Tiny : SizeClass.Small; } - void freeChunk(PoolChunk chunk, long handle, SizeClass sizeClass) { + void freeChunk(PoolChunk chunk, long handle, SizeClass sizeClass, ByteBuffer nioBuffer) { final boolean destroyChunk; synchronized (this) { switch (sizeClass) { @@ -303,7 +303,7 @@ void freeChunk(PoolChunk chunk, long handle, SizeClass sizeClass) { default: throw new Error(); } - destroyChunk = !chunk.parent.free(chunk, handle); + destroyChunk = !chunk.parent.free(chunk, handle, nioBuffer); } if (destroyChunk) { // destroyChunk not need to be called while holding the synchronized lock. @@ -331,9 +331,7 @@ PoolSubpage findSubpagePoolHead(int elemSize) { } int normalizeCapacity(int reqCapacity) { - if (reqCapacity < 0) { - throw new IllegalArgumentException("capacity: " + reqCapacity + " (expected: 0+)"); - } + checkPositiveOrZero(reqCapacity, "reqCapacity"); if (reqCapacity >= chunkSize) { return directMemoryCacheAlignment == 0 ? reqCapacity : alignCapacity(reqCapacity); @@ -387,6 +385,7 @@ void reallocate(PooledByteBuf buf, int newCapacity, boolean freeOldMemory) { } PoolChunk oldChunk = buf.chunk; + ByteBuffer oldNioBuffer = buf.tmpNioBuf; long oldHandle = buf.handle; T oldMemory = buf.memory; int oldOffset = buf.offset; @@ -415,7 +414,7 @@ void reallocate(PooledByteBuf buf, int newCapacity, boolean freeOldMemory) { buf.setIndex(readerIndex, writerIndex); if (freeOldMemory) { - free(oldChunk, oldHandle, oldMaxLength, buf.cache); + free(oldChunk, oldNioBuffer, oldHandle, oldMaxLength, buf.cache); } } @@ -725,11 +724,16 @@ boolean isDirect() { return true; } - private int offsetCacheLine(ByteBuffer memory) { + // mark as package-private, only for unit test + int offsetCacheLine(ByteBuffer memory) { // We can only calculate the offset if Unsafe is present as otherwise directBufferAddress(...) will // throw an NPE. - return HAS_UNSAFE ? - (int) (PlatformDependent.directBufferAddress(memory) & directMemoryCacheAlignmentMask) : 0; + int remainder = HAS_UNSAFE + ? (int) (PlatformDependent.directBufferAddress(memory) & directMemoryCacheAlignmentMask) + : 0; + + // offset = alignment - address & (alignment - 1) + return directMemoryCacheAlignment - remainder; } @Override diff --git a/buffer/src/main/java/io/netty/buffer/PoolChunk.java b/buffer/src/main/java/io/netty/buffer/PoolChunk.java index b3ca160223a7..0a19a5a1d492 100644 --- a/buffer/src/main/java/io/netty/buffer/PoolChunk.java +++ b/buffer/src/main/java/io/netty/buffer/PoolChunk.java @@ -16,6 +16,10 @@ package io.netty.buffer; +import java.nio.ByteBuffer; +import java.util.ArrayDeque; +import java.util.Deque; + /** * Description of algorithm for PageRun/PoolSubpage allocation from PoolChunk * @@ -94,11 +98,10 @@ * Note: * ----- * In the implementation for improving cache coherence, - * we store 2 pieces of information (i.e, 2 byte vals) as a short value in memoryMap + * we store 2 pieces of information depth_of_id and x as two byte values in memoryMap and depthMap respectively * - * memoryMap[id]= (depth_of_id, x) - * where as per convention defined above - * the second value (i.e, x) indicates that the first node which is free to be allocated is at depth x (from root) + * memoryMap[id]= depth_of_id is defined above + * depthMap[id]= x indicates that the first node which is free to be allocated is at depth x (from root) */ final class PoolChunk implements PoolChunkMetric { @@ -108,7 +111,6 @@ final class PoolChunk implements PoolChunkMetric { final T memory; final boolean unpooled; final int offset; - private final byte[] memoryMap; private final byte[] depthMap; private final PoolSubpage[] subpages; @@ -123,6 +125,13 @@ final class PoolChunk implements PoolChunkMetric { /** Used to mark memory as unusable */ private final byte unusable; + // Use as cache for ByteBuffer created from the memory. These are just duplicates and so are only a container + // around the memory itself. These are often needed for operations within the Pooled*ByteBuf and so + // may produce extra GC, which can be greatly reduced by caching the duplicates. + // + // This may be null if the PoolChunk is unpooled as pooling the ByteBuffer instances does not make any sense here. + private final Deque cachedNioBuffers; + private int freeBytes; PoolChunkList parent; @@ -164,6 +173,7 @@ final class PoolChunk implements PoolChunkMetric { } subpages = newSubpageArray(maxSubpageAllocs); + cachedNioBuffers = new ArrayDeque(8); } /** Creates a special chunk that is not pooled. */ @@ -183,6 +193,7 @@ final class PoolChunk implements PoolChunkMetric { chunkSize = size; log2ChunkSize = log2(chunkSize); maxSubpageAllocs = 0; + cachedNioBuffers = null; } @SuppressWarnings("unchecked") @@ -211,12 +222,20 @@ private int usage(int freeBytes) { return 100 - freePercentage; } - long allocate(int normCapacity) { + boolean allocate(PooledByteBuf buf, int reqCapacity, int normCapacity) { + final long handle; if ((normCapacity & subpageOverflowMask) != 0) { // >= pageSize - return allocateRun(normCapacity); + handle = allocateRun(normCapacity); } else { - return allocateSubpage(normCapacity); + handle = allocateSubpage(normCapacity); } + + if (handle < 0) { + return false; + } + ByteBuffer nioBuffer = cachedNioBuffers != null ? cachedNioBuffers.pollLast() : null; + initBuf(buf, nioBuffer, handle, reqCapacity); + return true; } /** @@ -311,8 +330,8 @@ private long allocateRun(int normCapacity) { } /** - * Create/ initialize a new PoolSubpage of normCapacity - * Any PoolSubpage created/ initialized here is added to subpage pool in the PoolArena that owns this PoolChunk + * Create / initialize a new PoolSubpage of normCapacity + * Any PoolSubpage created / initialized here is added to subpage pool in the PoolArena that owns this PoolChunk * * @param normCapacity normalized capacity * @return index in memoryMap @@ -321,8 +340,8 @@ private long allocateSubpage(int normCapacity) { // Obtain the head of the PoolSubPage pool that is owned by the PoolArena and synchronize on it. // This is need as we may add it back and so alter the linked-list structure. PoolSubpage head = arena.findSubpagePoolHead(normCapacity); + int d = maxOrder; // subpages are only be allocated from pages i.e., leaves synchronized (head) { - int d = maxOrder; // subpages are only be allocated from pages i.e., leaves int id = allocateNode(d); if (id < 0) { return id; @@ -353,7 +372,7 @@ private long allocateSubpage(int normCapacity) { * * @param handle handle to free */ - void free(long handle) { + void free(long handle, ByteBuffer nioBuffer) { int memoryMapIdx = memoryMapIdx(handle); int bitmapIdx = bitmapIdx(handle); @@ -373,26 +392,32 @@ void free(long handle) { freeBytes += runLength(memoryMapIdx); setValue(memoryMapIdx, depth(memoryMapIdx)); updateParentsFree(memoryMapIdx); + + if (nioBuffer != null && cachedNioBuffers != null && + cachedNioBuffers.size() < PooledByteBufAllocator.DEFAULT_MAX_CACHED_BYTEBUFFERS_PER_CHUNK) { + cachedNioBuffers.offer(nioBuffer); + } } - void initBuf(PooledByteBuf buf, long handle, int reqCapacity) { + void initBuf(PooledByteBuf buf, ByteBuffer nioBuffer, long handle, int reqCapacity) { int memoryMapIdx = memoryMapIdx(handle); int bitmapIdx = bitmapIdx(handle); if (bitmapIdx == 0) { byte val = value(memoryMapIdx); assert val == unusable : String.valueOf(val); - buf.init(this, handle, runOffset(memoryMapIdx) + offset, reqCapacity, runLength(memoryMapIdx), - arena.parent.threadCache()); + buf.init(this, nioBuffer, handle, runOffset(memoryMapIdx) + offset, + reqCapacity, runLength(memoryMapIdx), arena.parent.threadCache()); } else { - initBufWithSubpage(buf, handle, bitmapIdx, reqCapacity); + initBufWithSubpage(buf, nioBuffer, handle, bitmapIdx, reqCapacity); } } - void initBufWithSubpage(PooledByteBuf buf, long handle, int reqCapacity) { - initBufWithSubpage(buf, handle, bitmapIdx(handle), reqCapacity); + void initBufWithSubpage(PooledByteBuf buf, ByteBuffer nioBuffer, long handle, int reqCapacity) { + initBufWithSubpage(buf, nioBuffer, handle, bitmapIdx(handle), reqCapacity); } - private void initBufWithSubpage(PooledByteBuf buf, long handle, int bitmapIdx, int reqCapacity) { + private void initBufWithSubpage(PooledByteBuf buf, ByteBuffer nioBuffer, + long handle, int bitmapIdx, int reqCapacity) { assert bitmapIdx != 0; int memoryMapIdx = memoryMapIdx(handle); @@ -402,7 +427,7 @@ private void initBufWithSubpage(PooledByteBuf buf, long handle, int bitmapIdx assert reqCapacity <= subpage.elemSize; buf.init( - this, handle, + this, nioBuffer, handle, runOffset(memoryMapIdx) + (bitmapIdx & 0x3FFFFFFF) * subpage.elemSize + offset, reqCapacity, subpage.elemSize, arena.parent.threadCache()); } diff --git a/buffer/src/main/java/io/netty/buffer/PoolChunkList.java b/buffer/src/main/java/io/netty/buffer/PoolChunkList.java index f92834d85c4f..e610be8e9b1a 100644 --- a/buffer/src/main/java/io/netty/buffer/PoolChunkList.java +++ b/buffer/src/main/java/io/netty/buffer/PoolChunkList.java @@ -25,6 +25,8 @@ import static java.lang.Math.*; +import java.nio.ByteBuffer; + final class PoolChunkList implements PoolChunkListMetric { private static final Iterator EMPTY_METRICS = Collections.emptyList().iterator(); private final PoolArena arena; @@ -75,21 +77,14 @@ void prevList(PoolChunkList prevList) { } boolean allocate(PooledByteBuf buf, int reqCapacity, int normCapacity) { - if (head == null || normCapacity > maxCapacity) { + if (normCapacity > maxCapacity) { // Either this PoolChunkList is empty or the requested capacity is larger then the capacity which can // be handled by the PoolChunks that are contained in this PoolChunkList. return false; } - for (PoolChunk cur = head;;) { - long handle = cur.allocate(normCapacity); - if (handle < 0) { - cur = cur.next; - if (cur == null) { - return false; - } - } else { - cur.initBuf(buf, handle, reqCapacity); + for (PoolChunk cur = head; cur != null; cur = cur.next) { + if (cur.allocate(buf, reqCapacity, normCapacity)) { if (cur.usage() >= maxUsage) { remove(cur); nextList.add(cur); @@ -97,10 +92,11 @@ boolean allocate(PooledByteBuf buf, int reqCapacity, int normCapacity) { return true; } } + return false; } - boolean free(PoolChunk chunk, long handle) { - chunk.free(handle); + boolean free(PoolChunk chunk, long handle, ByteBuffer nioBuffer) { + chunk.free(handle, nioBuffer); if (chunk.usage() < minUsage) { remove(chunk); // Move the PoolChunk down the PoolChunkList linked-list. diff --git a/buffer/src/main/java/io/netty/buffer/PoolThreadCache.java b/buffer/src/main/java/io/netty/buffer/PoolThreadCache.java index 3503748c0d9b..01a69d570064 100644 --- a/buffer/src/main/java/io/netty/buffer/PoolThreadCache.java +++ b/buffer/src/main/java/io/netty/buffer/PoolThreadCache.java @@ -17,6 +17,8 @@ package io.netty.buffer; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + import io.netty.buffer.PoolArena.SizeClass; import io.netty.util.Recycler; import io.netty.util.Recycler.Handle; @@ -27,6 +29,7 @@ import java.nio.ByteBuffer; import java.util.Queue; +import java.util.concurrent.atomic.AtomicBoolean; /** * Acts a Thread cache for allocations. This implementation is moduled after @@ -54,6 +57,7 @@ final class PoolThreadCache { private final int numShiftsNormalDirect; private final int numShiftsNormalHeap; private final int freeSweepAllocationThreshold; + private final AtomicBoolean freed = new AtomicBoolean(); private int allocations; @@ -63,10 +67,7 @@ final class PoolThreadCache { PoolThreadCache(PoolArena heapArena, PoolArena directArena, int tinyCacheSize, int smallCacheSize, int normalCacheSize, int maxCachedBufferCapacity, int freeSweepAllocationThreshold) { - if (maxCachedBufferCapacity < 0) { - throw new IllegalArgumentException("maxCachedBufferCapacity: " - + maxCachedBufferCapacity + " (expected: >= 0)"); - } + checkPositiveOrZero(maxCachedBufferCapacity, "maxCachedBufferCapacity"); this.freeSweepAllocationThreshold = freeSweepAllocationThreshold; this.heapArena = heapArena; this.directArena = directArena; @@ -198,12 +199,13 @@ private boolean allocate(MemoryRegionCache cache, PooledByteBuf buf, int reqC * Returns {@code true} if it fit into the cache {@code false} otherwise. */ @SuppressWarnings({ "unchecked", "rawtypes" }) - boolean add(PoolArena area, PoolChunk chunk, long handle, int normCapacity, SizeClass sizeClass) { + boolean add(PoolArena area, PoolChunk chunk, ByteBuffer nioBuffer, + long handle, int normCapacity, SizeClass sizeClass) { MemoryRegionCache cache = cache(area, normCapacity, sizeClass); if (cache == null) { return false; } - return cache.add(chunk, handle); + return cache.add(chunk, nioBuffer, handle); } private MemoryRegionCache cache(PoolArena area, int normCapacity, SizeClass sizeClass) { @@ -219,27 +221,42 @@ private MemoryRegionCache cache(PoolArena area, int normCapacity, SizeClas } } + /// TODO: In the future when we move to Java9+ we should use java.lang.ref.Cleaner. + @Override + protected void finalize() throws Throwable { + try { + super.finalize(); + } finally { + free(); + } + } + /** * Should be called if the Thread that uses this cache is about to exist to release resources out of the cache */ void free() { - int numFreed = free(tinySubPageDirectCaches) + - free(smallSubPageDirectCaches) + - free(normalDirectCaches) + - free(tinySubPageHeapCaches) + - free(smallSubPageHeapCaches) + - free(normalHeapCaches); - - if (numFreed > 0 && logger.isDebugEnabled()) { - logger.debug("Freed {} thread-local buffer(s) from thread: {}", numFreed, Thread.currentThread().getName()); - } + // As free() may be called either by the finalizer or by FastThreadLocal.onRemoval(...) we need to ensure + // we only call this one time. + if (freed.compareAndSet(false, true)) { + int numFreed = free(tinySubPageDirectCaches) + + free(smallSubPageDirectCaches) + + free(normalDirectCaches) + + free(tinySubPageHeapCaches) + + free(smallSubPageHeapCaches) + + free(normalHeapCaches); + + if (numFreed > 0 && logger.isDebugEnabled()) { + logger.debug("Freed {} thread-local buffer(s) from thread: {}", numFreed, + Thread.currentThread().getName()); + } - if (directArena != null) { - directArena.numThreadCaches.getAndDecrement(); - } + if (directArena != null) { + directArena.numThreadCaches.getAndDecrement(); + } - if (heapArena != null) { - heapArena.numThreadCaches.getAndDecrement(); + if (heapArena != null) { + heapArena.numThreadCaches.getAndDecrement(); + } } } @@ -329,8 +346,8 @@ private static final class SubPageMemoryRegionCache extends MemoryRegionCache @Override protected void initBuf( - PoolChunk chunk, long handle, PooledByteBuf buf, int reqCapacity) { - chunk.initBufWithSubpage(buf, handle, reqCapacity); + PoolChunk chunk, ByteBuffer nioBuffer, long handle, PooledByteBuf buf, int reqCapacity) { + chunk.initBufWithSubpage(buf, nioBuffer, handle, reqCapacity); } } @@ -344,8 +361,8 @@ private static final class NormalMemoryRegionCache extends MemoryRegionCache< @Override protected void initBuf( - PoolChunk chunk, long handle, PooledByteBuf buf, int reqCapacity) { - chunk.initBuf(buf, handle, reqCapacity); + PoolChunk chunk, ByteBuffer nioBuffer, long handle, PooledByteBuf buf, int reqCapacity) { + chunk.initBuf(buf, nioBuffer, handle, reqCapacity); } } @@ -364,15 +381,15 @@ private abstract static class MemoryRegionCache { /** * Init the {@link PooledByteBuf} using the provided chunk and handle with the capacity restrictions. */ - protected abstract void initBuf(PoolChunk chunk, long handle, + protected abstract void initBuf(PoolChunk chunk, ByteBuffer nioBuffer, long handle, PooledByteBuf buf, int reqCapacity); /** * Add to cache if not already full. */ @SuppressWarnings("unchecked") - public final boolean add(PoolChunk chunk, long handle) { - Entry entry = newEntry(chunk, handle); + public final boolean add(PoolChunk chunk, ByteBuffer nioBuffer, long handle) { + Entry entry = newEntry(chunk, nioBuffer, handle); boolean queued = queue.offer(entry); if (!queued) { // If it was not possible to cache the chunk, immediately recycle the entry @@ -390,7 +407,7 @@ public final boolean allocate(PooledByteBuf buf, int reqCapacity) { if (entry == null) { return false; } - initBuf(entry.chunk, entry.handle, buf, reqCapacity); + initBuf(entry.chunk, entry.nioBuffer, entry.handle, buf, reqCapacity); entry.recycle(); // allocations is not thread-safe which is fine as this is only called from the same thread all time. @@ -436,16 +453,18 @@ public final void trim() { private void freeEntry(Entry entry) { PoolChunk chunk = entry.chunk; long handle = entry.handle; + ByteBuffer nioBuffer = entry.nioBuffer; // recycle now so PoolChunk can be GC'ed. entry.recycle(); - chunk.arena.freeChunk(chunk, handle, sizeClass); + chunk.arena.freeChunk(chunk, handle, sizeClass, nioBuffer); } static final class Entry { final Handle> recyclerHandle; PoolChunk chunk; + ByteBuffer nioBuffer; long handle = -1; Entry(Handle> recyclerHandle) { @@ -454,15 +473,17 @@ static final class Entry { void recycle() { chunk = null; + nioBuffer = null; handle = -1; recyclerHandle.recycle(this); } } @SuppressWarnings("rawtypes") - private static Entry newEntry(PoolChunk chunk, long handle) { + private static Entry newEntry(PoolChunk chunk, ByteBuffer nioBuffer, long handle) { Entry entry = RECYCLER.get(); entry.chunk = chunk; + entry.nioBuffer = nioBuffer; entry.handle = handle; return entry; } diff --git a/buffer/src/main/java/io/netty/buffer/PooledByteBuf.java b/buffer/src/main/java/io/netty/buffer/PooledByteBuf.java index 56a4be387232..beffbb07ee5a 100644 --- a/buffer/src/main/java/io/netty/buffer/PooledByteBuf.java +++ b/buffer/src/main/java/io/netty/buffer/PooledByteBuf.java @@ -33,7 +33,7 @@ abstract class PooledByteBuf extends AbstractReferenceCountedByteBuf { protected int length; int maxLength; PoolThreadCache cache; - private ByteBuffer tmpNioBuf; + ByteBuffer tmpNioBuf; private ByteBufAllocator allocator; @SuppressWarnings("unchecked") @@ -42,27 +42,29 @@ protected PooledByteBuf(Recycler.Handle> recyclerHand this.recyclerHandle = (Handle>) recyclerHandle; } - void init(PoolChunk chunk, long handle, int offset, int length, int maxLength, PoolThreadCache cache) { - init0(chunk, handle, offset, length, maxLength, cache); + void init(PoolChunk chunk, ByteBuffer nioBuffer, + long handle, int offset, int length, int maxLength, PoolThreadCache cache) { + init0(chunk, nioBuffer, handle, offset, length, maxLength, cache); } void initUnpooled(PoolChunk chunk, int length) { - init0(chunk, 0, chunk.offset, length, length, null); + init0(chunk, null, 0, chunk.offset, length, length, null); } - private void init0(PoolChunk chunk, long handle, int offset, int length, int maxLength, PoolThreadCache cache) { + private void init0(PoolChunk chunk, ByteBuffer nioBuffer, + long handle, int offset, int length, int maxLength, PoolThreadCache cache) { assert handle >= 0; assert chunk != null; this.chunk = chunk; memory = chunk.memory; + tmpNioBuf = nioBuffer; allocator = chunk.arena.parent; this.cache = cache; this.handle = handle; this.offset = offset; this.length = length; this.maxLength = maxLength; - tmpNioBuf = null; } /** @@ -166,8 +168,8 @@ protected final void deallocate() { final long handle = this.handle; this.handle = -1; memory = null; + chunk.arena.free(chunk, tmpNioBuf, handle, maxLength, cache); tmpNioBuf = null; - chunk.arena.free(chunk, handle, maxLength, cache); chunk = null; recycle(); } diff --git a/buffer/src/main/java/io/netty/buffer/PooledByteBufAllocator.java b/buffer/src/main/java/io/netty/buffer/PooledByteBufAllocator.java index b613fea9e1cb..bcfc9ceb1a13 100644 --- a/buffer/src/main/java/io/netty/buffer/PooledByteBufAllocator.java +++ b/buffer/src/main/java/io/netty/buffer/PooledByteBufAllocator.java @@ -16,6 +16,8 @@ package io.netty.buffer; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + import io.netty.util.NettyRuntime; import io.netty.util.concurrent.FastThreadLocal; import io.netty.util.concurrent.FastThreadLocalThread; @@ -45,6 +47,7 @@ public class PooledByteBufAllocator extends AbstractByteBufAllocator implements private static final int DEFAULT_CACHE_TRIM_INTERVAL; private static final boolean DEFAULT_USE_CACHE_FOR_ALL_THREADS; private static final int DEFAULT_DIRECT_MEMORY_CACHE_ALIGNMENT; + static final int DEFAULT_MAX_CACHED_BYTEBUFFERS_PER_CHUNK; private static final int MIN_PAGE_SIZE = 4096; private static final int MAX_CHUNK_SIZE = (int) (((long) Integer.MAX_VALUE + 1) / 2); @@ -116,6 +119,11 @@ public class PooledByteBufAllocator extends AbstractByteBufAllocator implements DEFAULT_DIRECT_MEMORY_CACHE_ALIGNMENT = SystemPropertyUtil.getInt( "io.netty.allocator.directMemoryCacheAlignment", 0); + // Use 1023 by default as we use an ArrayDeque as backing storage which will then allocate an internal array + // of 1024 elements. Otherwise we would allocate 2048 and only use 1024 which is wasteful. + DEFAULT_MAX_CACHED_BYTEBUFFERS_PER_CHUNK = SystemPropertyUtil.getInt( + "io.netty.allocator.maxCachedByteBuffersPerChunk", 1023); + if (logger.isDebugEnabled()) { logger.debug("-Dio.netty.allocator.numHeapArenas: {}", DEFAULT_NUM_HEAP_ARENA); logger.debug("-Dio.netty.allocator.numDirectArenas: {}", DEFAULT_NUM_DIRECT_ARENA); @@ -136,6 +144,8 @@ public class PooledByteBufAllocator extends AbstractByteBufAllocator implements logger.debug("-Dio.netty.allocator.maxCachedBufferCapacity: {}", DEFAULT_MAX_CACHED_BUFFER_CAPACITY); logger.debug("-Dio.netty.allocator.cacheTrimInterval: {}", DEFAULT_CACHE_TRIM_INTERVAL); logger.debug("-Dio.netty.allocator.useCacheForAllThreads: {}", DEFAULT_USE_CACHE_FOR_ALL_THREADS); + logger.debug("-Dio.netty.allocator.maxCachedByteBuffersPerChunk: {}", + DEFAULT_MAX_CACHED_BYTEBUFFERS_PER_CHUNK); } } @@ -207,17 +217,10 @@ public PooledByteBufAllocator(boolean preferDirect, int nHeapArena, int nDirectA this.normalCacheSize = normalCacheSize; chunkSize = validateAndCalculateChunkSize(pageSize, maxOrder); - if (nHeapArena < 0) { - throw new IllegalArgumentException("nHeapArena: " + nHeapArena + " (expected: >= 0)"); - } - if (nDirectArena < 0) { - throw new IllegalArgumentException("nDirectArea: " + nDirectArena + " (expected: >= 0)"); - } + checkPositiveOrZero(nHeapArena, "nHeapArena"); + checkPositiveOrZero(nDirectArena, "nDirectArena"); - if (directMemoryCacheAlignment < 0) { - throw new IllegalArgumentException("directMemoryCacheAlignment: " - + directMemoryCacheAlignment + " (expected: >= 0)"); - } + checkPositiveOrZero(directMemoryCacheAlignment, "directMemoryCacheAlignment"); if (directMemoryCacheAlignment > 0 && !isDirectMemoryCacheAlignmentSupported()) { throw new IllegalArgumentException("directMemoryCacheAlignment is not supported"); } @@ -580,7 +583,7 @@ final long usedDirectMemory() { return usedMemory(directArenas); } - private static long usedMemory(PoolArena... arenas) { + private static long usedMemory(PoolArena[] arenas) { if (arenas == null) { return -1; } diff --git a/buffer/src/main/java/io/netty/buffer/PooledDirectByteBuf.java b/buffer/src/main/java/io/netty/buffer/PooledDirectByteBuf.java index 3c43509c0545..9601150b319d 100644 --- a/buffer/src/main/java/io/netty/buffer/PooledDirectByteBuf.java +++ b/buffer/src/main/java/io/netty/buffer/PooledDirectByteBuf.java @@ -351,8 +351,8 @@ public ByteBuf setBytes(int index, ByteBuffer src) { @Override public int setBytes(int index, InputStream in, int length) throws IOException { checkIndex(index, length); - byte[] tmp = new byte[length]; - int readBytes = in.read(tmp); + byte[] tmp = ByteBufUtil.threadLocalTempArray(length); + int readBytes = in.read(tmp, 0, length); if (readBytes <= 0) { return readBytes; } diff --git a/buffer/src/main/java/io/netty/buffer/PooledUnsafeDirectByteBuf.java b/buffer/src/main/java/io/netty/buffer/PooledUnsafeDirectByteBuf.java index 1dcc3702c4b7..e2dc22cb07db 100644 --- a/buffer/src/main/java/io/netty/buffer/PooledUnsafeDirectByteBuf.java +++ b/buffer/src/main/java/io/netty/buffer/PooledUnsafeDirectByteBuf.java @@ -49,9 +49,9 @@ private PooledUnsafeDirectByteBuf(Recycler.Handle rec } @Override - void init(PoolChunk chunk, long handle, int offset, int length, int maxLength, - PoolThreadCache cache) { - super.init(chunk, handle, offset, length, maxLength, cache); + void init(PoolChunk chunk, ByteBuffer nioBuffer, + long handle, int offset, int length, int maxLength, PoolThreadCache cache) { + super.init(chunk, nioBuffer, handle, offset, length, maxLength, cache); initMemoryAddress(); } diff --git a/buffer/src/main/java/io/netty/buffer/ReadOnlyByteBufferBuf.java b/buffer/src/main/java/io/netty/buffer/ReadOnlyByteBufferBuf.java index 406514f384b0..c7cda05fd979 100644 --- a/buffer/src/main/java/io/netty/buffer/ReadOnlyByteBufferBuf.java +++ b/buffer/src/main/java/io/netty/buffer/ReadOnlyByteBufferBuf.java @@ -355,11 +355,11 @@ public ByteBuf getBytes(int index, OutputStream out, int length) throws IOExcept if (buffer.hasArray()) { out.write(buffer.array(), index + buffer.arrayOffset(), length); } else { - byte[] tmp = new byte[length]; + byte[] tmp = ByteBufUtil.threadLocalTempArray(length); ByteBuffer tmpBuf = internalNioBuffer(); tmpBuf.clear().position(index); - tmpBuf.get(tmp); - out.write(tmp); + tmpBuf.get(tmp, 0, length); + out.write(tmp, 0, length); } return this; } diff --git a/buffer/src/main/java/io/netty/buffer/SwappedByteBuf.java b/buffer/src/main/java/io/netty/buffer/SwappedByteBuf.java index 5d54a1f9f14c..abf27663f7a9 100644 --- a/buffer/src/main/java/io/netty/buffer/SwappedByteBuf.java +++ b/buffer/src/main/java/io/netty/buffer/SwappedByteBuf.java @@ -997,6 +997,11 @@ public int refCnt() { return buf.refCnt(); } + @Override + final boolean isAccessible() { + return buf.isAccessible(); + } + @Override public ByteBuf retain() { buf.retain(); diff --git a/buffer/src/main/java/io/netty/buffer/Unpooled.java b/buffer/src/main/java/io/netty/buffer/Unpooled.java index 3639a1776848..d7df1928857e 100644 --- a/buffer/src/main/java/io/netty/buffer/Unpooled.java +++ b/buffer/src/main/java/io/netty/buffer/Unpooled.java @@ -15,14 +15,14 @@ */ package io.netty.buffer; +import io.netty.buffer.CompositeByteBuf.ByteWrapper; import io.netty.util.internal.PlatformDependent; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.CharBuffer; import java.nio.charset.Charset; -import java.util.ArrayList; -import java.util.List; +import java.util.Arrays; /** @@ -219,7 +219,7 @@ public static ByteBuf wrappedBuffer(long memoryAddress, int size, boolean doFree * Creates a new buffer which wraps the specified buffer's readable bytes. * A modification on the specified buffer's content will be visible to the * returned buffer. - * @param buffer The buffer to wrap. Reference count ownership of this variable is transfered to this method. + * @param buffer The buffer to wrap. Reference count ownership of this variable is transferred to this method. * @return The readable portion of the {@code buffer}, or an empty buffer if there is no readable portion. * The caller is responsible for releasing this buffer. */ @@ -238,18 +238,18 @@ public static ByteBuf wrappedBuffer(ByteBuf buffer) { * content will be visible to the returned buffer. */ public static ByteBuf wrappedBuffer(byte[]... arrays) { - return wrappedBuffer(AbstractByteBufAllocator.DEFAULT_MAX_COMPONENTS, arrays); + return wrappedBuffer(arrays.length, arrays); } /** * Creates a new big-endian composite buffer which wraps the readable bytes of the * specified buffers without copying them. A modification on the content * of the specified buffers will be visible to the returned buffer. - * @param buffers The buffers to wrap. Reference count ownership of all variables is transfered to this method. + * @param buffers The buffers to wrap. Reference count ownership of all variables is transferred to this method. * @return The readable portion of the {@code buffers}. The caller is responsible for releasing this buffer. */ public static ByteBuf wrappedBuffer(ByteBuf... buffers) { - return wrappedBuffer(AbstractByteBufAllocator.DEFAULT_MAX_COMPONENTS, buffers); + return wrappedBuffer(buffers.length, buffers); } /** @@ -258,50 +258,49 @@ public static ByteBuf wrappedBuffer(ByteBuf... buffers) { * specified buffers will be visible to the returned buffer. */ public static ByteBuf wrappedBuffer(ByteBuffer... buffers) { - return wrappedBuffer(AbstractByteBufAllocator.DEFAULT_MAX_COMPONENTS, buffers); + return wrappedBuffer(buffers.length, buffers); } - /** - * Creates a new big-endian composite buffer which wraps the specified - * arrays without copying them. A modification on the specified arrays' - * content will be visible to the returned buffer. - */ - public static ByteBuf wrappedBuffer(int maxNumComponents, byte[]... arrays) { - switch (arrays.length) { + static ByteBuf wrappedBuffer(int maxNumComponents, ByteWrapper wrapper, T[] array) { + switch (array.length) { case 0: break; case 1: - if (arrays[0].length != 0) { - return wrappedBuffer(arrays[0]); + if (!wrapper.isEmpty(array[0])) { + return wrapper.wrap(array[0]); } break; default: - // Get the list of the component, while guessing the byte order. - final List components = new ArrayList(arrays.length); - for (byte[] a: arrays) { - if (a == null) { - break; + for (int i = 0, len = array.length; i < len; i++) { + T bytes = array[i]; + if (bytes == null) { + return EMPTY_BUFFER; } - if (a.length > 0) { - components.add(wrappedBuffer(a)); + if (!wrapper.isEmpty(bytes)) { + return new CompositeByteBuf(ALLOC, false, maxNumComponents, wrapper, array, i); } } - - if (!components.isEmpty()) { - return new CompositeByteBuf(ALLOC, false, maxNumComponents, components); - } } return EMPTY_BUFFER; } + /** + * Creates a new big-endian composite buffer which wraps the specified + * arrays without copying them. A modification on the specified arrays' + * content will be visible to the returned buffer. + */ + public static ByteBuf wrappedBuffer(int maxNumComponents, byte[]... arrays) { + return wrappedBuffer(maxNumComponents, CompositeByteBuf.BYTE_ARRAY_WRAPPER, arrays); + } + /** * Creates a new big-endian composite buffer which wraps the readable bytes of the * specified buffers without copying them. A modification on the content * of the specified buffers will be visible to the returned buffer. * @param maxNumComponents Advisement as to how many independent buffers are allowed to exist before * consolidation occurs. - * @param buffers The buffers to wrap. Reference count ownership of all variables is transfered to this method. + * @param buffers The buffers to wrap. Reference count ownership of all variables is transferred to this method. * @return The readable portion of the {@code buffers}. The caller is responsible for releasing this buffer. */ public static ByteBuf wrappedBuffer(int maxNumComponents, ByteBuf... buffers) { @@ -320,7 +319,7 @@ public static ByteBuf wrappedBuffer(int maxNumComponents, ByteBuf... buffers) { for (int i = 0; i < buffers.length; i++) { ByteBuf buf = buffers[i]; if (buf.isReadable()) { - return new CompositeByteBuf(ALLOC, false, maxNumComponents, buffers, i, buffers.length); + return new CompositeByteBuf(ALLOC, false, maxNumComponents, buffers, i); } buf.release(); } @@ -335,32 +334,7 @@ public static ByteBuf wrappedBuffer(int maxNumComponents, ByteBuf... buffers) { * specified buffers will be visible to the returned buffer. */ public static ByteBuf wrappedBuffer(int maxNumComponents, ByteBuffer... buffers) { - switch (buffers.length) { - case 0: - break; - case 1: - if (buffers[0].hasRemaining()) { - return wrappedBuffer(buffers[0].order(BIG_ENDIAN)); - } - break; - default: - // Get the list of the component, while guessing the byte order. - final List components = new ArrayList(buffers.length); - for (ByteBuffer b: buffers) { - if (b == null) { - break; - } - if (b.remaining() > 0) { - components.add(wrappedBuffer(b.order(BIG_ENDIAN))); - } - } - - if (!components.isEmpty()) { - return new CompositeByteBuf(ALLOC, false, maxNumComponents, components); - } - } - - return EMPTY_BUFFER; + return wrappedBuffer(maxNumComponents, CompositeByteBuf.BYTE_BUFFER_WRAPPER, buffers); } /** @@ -399,7 +373,7 @@ public static ByteBuf copiedBuffer(byte[] array, int offset, int length) { if (length == 0) { return EMPTY_BUFFER; } - byte[] copy = new byte[length]; + byte[] copy = PlatformDependent.allocateUninitializedArray(length); System.arraycopy(array, offset, copy, 0, length); return wrappedBuffer(copy); } @@ -415,7 +389,7 @@ public static ByteBuf copiedBuffer(ByteBuffer buffer) { if (length == 0) { return EMPTY_BUFFER; } - byte[] copy = new byte[length]; + byte[] copy = PlatformDependent.allocateUninitializedArray(length); // Duplicate the buffer so we not adjust the position during our get operation. // See https://github.com/netty/netty/issues/3896 ByteBuffer duplicate = buffer.duplicate(); @@ -472,7 +446,7 @@ public static ByteBuf copiedBuffer(byte[]... arrays) { return EMPTY_BUFFER; } - byte[] mergedArray = new byte[length]; + byte[] mergedArray = PlatformDependent.allocateUninitializedArray(length); for (int i = 0, j = 0; i < arrays.length; i ++) { byte[] a = arrays[i]; System.arraycopy(a, 0, mergedArray, j, a.length); @@ -526,7 +500,7 @@ public static ByteBuf copiedBuffer(ByteBuf... buffers) { return EMPTY_BUFFER; } - byte[] mergedArray = new byte[length]; + byte[] mergedArray = PlatformDependent.allocateUninitializedArray(length); for (int i = 0, j = 0; i < buffers.length; i ++) { ByteBuf b = buffers[i]; int bLen = b.readableBytes(); @@ -581,7 +555,7 @@ public static ByteBuf copiedBuffer(ByteBuffer... buffers) { return EMPTY_BUFFER; } - byte[] mergedArray = new byte[length]; + byte[] mergedArray = PlatformDependent.allocateUninitializedArray(length); for (int i = 0, j = 0; i < buffers.length; i ++) { // Duplicate the buffer so we not adjust the position during our get operation. // See https://github.com/netty/netty/issues/3896 @@ -881,9 +855,36 @@ public static ByteBuf unreleasableBuffer(ByteBuf buf) { /** * Wrap the given {@link ByteBuf}s in an unmodifiable {@link ByteBuf}. Be aware the returned {@link ByteBuf} will * not try to slice the given {@link ByteBuf}s to reduce GC-Pressure. + * + * @deprecated Use {@link #wrappedUnmodifiableBuffer(ByteBuf...)}. */ + @Deprecated public static ByteBuf unmodifiableBuffer(ByteBuf... buffers) { - return new FixedCompositeByteBuf(ALLOC, buffers); + return wrappedUnmodifiableBuffer(true, buffers); + } + + /** + * Wrap the given {@link ByteBuf}s in an unmodifiable {@link ByteBuf}. Be aware the returned {@link ByteBuf} will + * not try to slice the given {@link ByteBuf}s to reduce GC-Pressure. + * + * The returned {@link ByteBuf} may wrap the provided array directly, and so should not be subsequently modified. + */ + public static ByteBuf wrappedUnmodifiableBuffer(ByteBuf... buffers) { + return wrappedUnmodifiableBuffer(false, buffers); + } + + private static ByteBuf wrappedUnmodifiableBuffer(boolean copy, ByteBuf... buffers) { + switch (buffers.length) { + case 0: + return EMPTY_BUFFER; + case 1: + return buffers[0].asReadOnly(); + default: + if (copy) { + buffers = Arrays.copyOf(buffers, buffers.length, ByteBuf[].class); + } + return new FixedCompositeByteBuf(ALLOC, buffers); + } } private Unpooled() { diff --git a/buffer/src/main/java/io/netty/buffer/UnpooledByteBufAllocator.java b/buffer/src/main/java/io/netty/buffer/UnpooledByteBufAllocator.java index 4edf0dcd3d93..6fe188ed8bdb 100644 --- a/buffer/src/main/java/io/netty/buffer/UnpooledByteBufAllocator.java +++ b/buffer/src/main/java/io/netty/buffer/UnpooledByteBufAllocator.java @@ -140,14 +140,14 @@ private static final class InstrumentedUnpooledUnsafeHeapByteBuf extends Unpoole } @Override - byte[] allocateArray(int initialCapacity) { + protected byte[] allocateArray(int initialCapacity) { byte[] bytes = super.allocateArray(initialCapacity); ((UnpooledByteBufAllocator) alloc()).incrementHeap(bytes.length); return bytes; } @Override - void freeArray(byte[] array) { + protected void freeArray(byte[] array) { int length = array.length; super.freeArray(array); ((UnpooledByteBufAllocator) alloc()).decrementHeap(length); @@ -160,14 +160,14 @@ private static final class InstrumentedUnpooledHeapByteBuf extends UnpooledHeapB } @Override - byte[] allocateArray(int initialCapacity) { + protected byte[] allocateArray(int initialCapacity) { byte[] bytes = super.allocateArray(initialCapacity); ((UnpooledByteBufAllocator) alloc()).incrementHeap(bytes.length); return bytes; } @Override - void freeArray(byte[] array) { + protected void freeArray(byte[] array) { int length = array.length; super.freeArray(array); ((UnpooledByteBufAllocator) alloc()).decrementHeap(length); diff --git a/buffer/src/main/java/io/netty/buffer/UnpooledDirectByteBuf.java b/buffer/src/main/java/io/netty/buffer/UnpooledDirectByteBuf.java index f93e256588cd..bcf816c6d2e7 100644 --- a/buffer/src/main/java/io/netty/buffer/UnpooledDirectByteBuf.java +++ b/buffer/src/main/java/io/netty/buffer/UnpooledDirectByteBuf.java @@ -15,6 +15,8 @@ */ package io.netty.buffer; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + import io.netty.util.internal.PlatformDependent; import java.io.IOException; @@ -52,12 +54,8 @@ public UnpooledDirectByteBuf(ByteBufAllocator alloc, int initialCapacity, int ma if (alloc == null) { throw new NullPointerException("alloc"); } - if (initialCapacity < 0) { - throw new IllegalArgumentException("initialCapacity: " + initialCapacity); - } - if (maxCapacity < 0) { - throw new IllegalArgumentException("maxCapacity: " + maxCapacity); - } + checkPositiveOrZero(initialCapacity, "initialCapacity"); + checkPositiveOrZero(maxCapacity, "maxCapacity"); if (initialCapacity > maxCapacity) { throw new IllegalArgumentException(String.format( "initialCapacity(%d) > maxCapacity(%d)", initialCapacity, maxCapacity)); @@ -561,8 +559,8 @@ public int setBytes(int index, InputStream in, int length) throws IOException { if (buffer.hasArray()) { return in.read(buffer.array(), buffer.arrayOffset() + index, length); } else { - byte[] tmp = new byte[length]; - int readBytes = in.read(tmp); + byte[] tmp = ByteBufUtil.threadLocalTempArray(length); + int readBytes = in.read(tmp, 0, length); if (readBytes <= 0) { return readBytes; } diff --git a/buffer/src/main/java/io/netty/buffer/UnpooledHeapByteBuf.java b/buffer/src/main/java/io/netty/buffer/UnpooledHeapByteBuf.java index 26b1c5122b5d..f37ceb0559aa 100644 --- a/buffer/src/main/java/io/netty/buffer/UnpooledHeapByteBuf.java +++ b/buffer/src/main/java/io/netty/buffer/UnpooledHeapByteBuf.java @@ -84,11 +84,11 @@ protected UnpooledHeapByteBuf(ByteBufAllocator alloc, byte[] initialArray, int m setIndex(0, initialArray.length); } - byte[] allocateArray(int initialCapacity) { + protected byte[] allocateArray(int initialCapacity) { return new byte[initialCapacity]; } - void freeArray(byte[] array) { + protected void freeArray(byte[] array) { // NOOP } @@ -536,7 +536,7 @@ protected void _setLongLE(int index, long value) { @Override public ByteBuf copy(int index, int length) { checkIndex(index, length); - byte[] copiedArray = new byte[length]; + byte[] copiedArray = PlatformDependent.allocateUninitializedArray(length); System.arraycopy(array, index, copiedArray, 0, length); return new UnpooledHeapByteBuf(alloc(), copiedArray, maxCapacity()); } diff --git a/buffer/src/main/java/io/netty/buffer/UnpooledUnsafeDirectByteBuf.java b/buffer/src/main/java/io/netty/buffer/UnpooledUnsafeDirectByteBuf.java index 5ff0c222d27b..0658139d7501 100644 --- a/buffer/src/main/java/io/netty/buffer/UnpooledUnsafeDirectByteBuf.java +++ b/buffer/src/main/java/io/netty/buffer/UnpooledUnsafeDirectByteBuf.java @@ -15,6 +15,8 @@ */ package io.netty.buffer; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + import io.netty.util.internal.PlatformDependent; import java.io.IOException; @@ -53,12 +55,8 @@ public UnpooledUnsafeDirectByteBuf(ByteBufAllocator alloc, int initialCapacity, if (alloc == null) { throw new NullPointerException("alloc"); } - if (initialCapacity < 0) { - throw new IllegalArgumentException("initialCapacity: " + initialCapacity); - } - if (maxCapacity < 0) { - throw new IllegalArgumentException("maxCapacity: " + maxCapacity); - } + checkPositiveOrZero(initialCapacity, "initialCapacity"); + checkPositiveOrZero(maxCapacity, "maxCapacity"); if (initialCapacity > maxCapacity) { throw new IllegalArgumentException(String.format( "initialCapacity(%d) > maxCapacity(%d)", initialCapacity, maxCapacity)); diff --git a/buffer/src/main/java/io/netty/buffer/UnpooledUnsafeHeapByteBuf.java b/buffer/src/main/java/io/netty/buffer/UnpooledUnsafeHeapByteBuf.java index 51786f1567c4..0fbe856466f1 100644 --- a/buffer/src/main/java/io/netty/buffer/UnpooledUnsafeHeapByteBuf.java +++ b/buffer/src/main/java/io/netty/buffer/UnpooledUnsafeHeapByteBuf.java @@ -30,7 +30,7 @@ class UnpooledUnsafeHeapByteBuf extends UnpooledHeapByteBuf { } @Override - byte[] allocateArray(int initialCapacity) { + protected byte[] allocateArray(int initialCapacity) { return PlatformDependent.allocateUninitializedArray(initialCapacity); } diff --git a/buffer/src/main/java/io/netty/buffer/UnsafeByteBufUtil.java b/buffer/src/main/java/io/netty/buffer/UnsafeByteBufUtil.java index 7d866d5b433a..05918a8ebb18 100644 --- a/buffer/src/main/java/io/netty/buffer/UnsafeByteBufUtil.java +++ b/buffer/src/main/java/io/netty/buffer/UnsafeByteBufUtil.java @@ -584,7 +584,9 @@ static void getBytes(AbstractByteBuf buf, long addr, int index, OutputStream out buf.checkIndex(index, length); if (length != 0) { int len = Math.min(length, ByteBufUtil.WRITE_CHUNK_SIZE); - if (buf.alloc().isDirectBufferPooled()) { + if (len <= ByteBufUtil.MAX_TL_ARRAY_LEN || !buf.alloc().isDirectBufferPooled()) { + getBytes(addr, ByteBufUtil.threadLocalTempArray(len), 0, len, out, length); + } else { // if direct buffers are pooled chances are good that heap buffers are pooled as well. ByteBuf tmpBuf = buf.alloc().heapBuffer(len); try { @@ -594,8 +596,6 @@ static void getBytes(AbstractByteBuf buf, long addr, int index, OutputStream out } finally { tmpBuf.release(); } - } else { - getBytes(addr, new byte[len], 0, len, out, length); } } } diff --git a/buffer/src/main/java/io/netty/buffer/WrappedByteBuf.java b/buffer/src/main/java/io/netty/buffer/WrappedByteBuf.java index 45aa60ce889b..33570e200447 100644 --- a/buffer/src/main/java/io/netty/buffer/WrappedByteBuf.java +++ b/buffer/src/main/java/io/netty/buffer/WrappedByteBuf.java @@ -1033,4 +1033,9 @@ public boolean release() { public boolean release(int decrement) { return buf.release(decrement); } + + @Override + final boolean isAccessible() { + return buf.isAccessible(); + } } diff --git a/buffer/src/main/java/io/netty/buffer/WrappedCompositeByteBuf.java b/buffer/src/main/java/io/netty/buffer/WrappedCompositeByteBuf.java index 8f5161620f98..c300ea087e25 100644 --- a/buffer/src/main/java/io/netty/buffer/WrappedCompositeByteBuf.java +++ b/buffer/src/main/java/io/netty/buffer/WrappedCompositeByteBuf.java @@ -423,6 +423,11 @@ public final int refCnt() { return wrapped.refCnt(); } + @Override + final boolean isAccessible() { + return wrapped.isAccessible(); + } + @Override public ByteBuf duplicate() { return wrapped.duplicate(); diff --git a/buffer/src/test/java/io/netty/buffer/AbstractByteBufTest.java b/buffer/src/test/java/io/netty/buffer/AbstractByteBufTest.java index 1af64a7392d6..59194ab374d4 100644 --- a/buffer/src/test/java/io/netty/buffer/AbstractByteBufTest.java +++ b/buffer/src/test/java/io/netty/buffer/AbstractByteBufTest.java @@ -30,6 +30,7 @@ import java.io.RandomAccessFile; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.nio.CharBuffer; import java.nio.ReadOnlyBufferException; import java.nio.channels.Channels; import java.nio.channels.FileChannel; @@ -3648,11 +3649,23 @@ public void testSetUtf16CharSequence() { testSetGetCharSequence(CharsetUtil.UTF_16); } + private static final CharBuffer EXTENDED_ASCII_CHARS, ASCII_CHARS; + + static { + char[] chars = new char[256]; + for (char c = 0; c < chars.length; c++) { + chars[c] = c; + } + EXTENDED_ASCII_CHARS = CharBuffer.wrap(chars); + ASCII_CHARS = CharBuffer.wrap(chars, 0, 128); + } + private void testSetGetCharSequence(Charset charset) { - ByteBuf buf = newBuffer(16); - String sequence = "AB"; + ByteBuf buf = newBuffer(1024); + CharBuffer sequence = CharsetUtil.US_ASCII.equals(charset) + ? ASCII_CHARS : EXTENDED_ASCII_CHARS; int bytes = buf.setCharSequence(1, sequence, charset); - assertEquals(sequence, buf.getCharSequence(1, bytes, charset)); + assertEquals(sequence, CharBuffer.wrap(buf.getCharSequence(1, bytes, charset))); buf.release(); } @@ -3677,12 +3690,13 @@ public void testWriteReadUtf16CharSequence() { } private void testWriteReadCharSequence(Charset charset) { - ByteBuf buf = newBuffer(16); - String sequence = "AB"; + ByteBuf buf = newBuffer(1024); + CharBuffer sequence = CharsetUtil.US_ASCII.equals(charset) + ? ASCII_CHARS : EXTENDED_ASCII_CHARS; buf.writerIndex(1); int bytes = buf.writeCharSequence(sequence, charset); buf.readerIndex(1); - assertEquals(sequence, buf.readCharSequence(bytes, charset)); + assertEquals(sequence, CharBuffer.wrap(buf.readCharSequence(bytes, charset))); buf.release(); } diff --git a/buffer/src/test/java/io/netty/buffer/AbstractCompositeByteBufTest.java b/buffer/src/test/java/io/netty/buffer/AbstractCompositeByteBufTest.java index f51475bf9ae6..ffeebaac9c46 100644 --- a/buffer/src/test/java/io/netty/buffer/AbstractCompositeByteBufTest.java +++ b/buffer/src/test/java/io/netty/buffer/AbstractCompositeByteBufTest.java @@ -16,6 +16,7 @@ package io.netty.buffer; import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.PlatformDependent; import org.junit.Assume; import org.junit.Test; @@ -51,6 +52,8 @@ */ public abstract class AbstractCompositeByteBufTest extends AbstractByteBufTest { + private static final ByteBufAllocator ALLOC = UnpooledByteBufAllocator.DEFAULT; + private final ByteOrder order; protected AbstractCompositeByteBufTest(ByteOrder order) { @@ -87,7 +90,7 @@ protected ByteBuf newBuffer(int length, int maxCapacity) { buffers.add(EMPTY_BUFFER); } - ByteBuf buffer = wrappedBuffer(Integer.MAX_VALUE, buffers.toArray(new ByteBuf[buffers.size()])).order(order); + ByteBuf buffer = wrappedBuffer(Integer.MAX_VALUE, buffers.toArray(new ByteBuf[0])).order(order); // Truncate to the requested capacity. buffer.capacity(length); @@ -132,6 +135,41 @@ public void testComponentAtOffset() { buf.release(); } + @Test + public void testToComponentIndex() { + CompositeByteBuf buf = (CompositeByteBuf) wrappedBuffer(new byte[]{1, 2, 3, 4, 5}, + new byte[]{4, 5, 6, 7, 8, 9, 26}, new byte[]{10, 9, 8, 7, 6, 5, 33}); + + // spot checks + assertEquals(0, buf.toComponentIndex(4)); + assertEquals(1, buf.toComponentIndex(5)); + assertEquals(2, buf.toComponentIndex(15)); + + //Loop through each byte + + byte index = 0; + + while (index < buf.capacity()) { + int cindex = buf.toComponentIndex(index++); + assertTrue(cindex >= 0 && cindex < buf.numComponents()); + } + + buf.release(); + } + + @Test + public void testToByteIndex() { + CompositeByteBuf buf = (CompositeByteBuf) wrappedBuffer(new byte[]{1, 2, 3, 4, 5}, + new byte[]{4, 5, 6, 7, 8, 9, 26}, new byte[]{10, 9, 8, 7, 6, 5, 33}); + + // spot checks + assertEquals(0, buf.toByteIndex(0)); + assertEquals(5, buf.toByteIndex(1)); + assertEquals(12, buf.toByteIndex(2)); + + buf.release(); + } + @Test public void testDiscardReadBytes3() { ByteBuf a, b; @@ -744,6 +782,20 @@ public void testRemoveLastComponentWithOthersLeft() { buf.release(); } + @Test + public void testRemoveComponents() { + CompositeByteBuf buf = compositeBuffer(); + for (int i = 0; i < 10; i++) { + buf.addComponent(wrappedBuffer(new byte[]{1, 2})); + } + assertEquals(10, buf.numComponents()); + assertEquals(20, buf.capacity()); + buf.removeComponents(4, 3); + assertEquals(7, buf.numComponents()); + assertEquals(14, buf.capacity()); + buf.release(); + } + @Test public void testGatheringWritesHeap() throws Exception { testGatheringWrites(buffer().order(order), buffer().order(order)); @@ -1017,6 +1069,33 @@ public void testAddEmptyBufferInMiddle() { cbuf.release(); } + @Test + public void testInsertEmptyBufferInMiddle() { + CompositeByteBuf cbuf = compositeBuffer(); + ByteBuf buf1 = buffer().writeByte((byte) 1); + cbuf.addComponent(true, buf1); + ByteBuf buf2 = buffer().writeByte((byte) 2); + cbuf.addComponent(true, buf2); + + // insert empty one between the first two + cbuf.addComponent(true, 1, EMPTY_BUFFER); + + assertEquals(2, cbuf.readableBytes()); + assertEquals((byte) 1, cbuf.readByte()); + assertEquals((byte) 2, cbuf.readByte()); + + assertEquals(2, cbuf.capacity()); + assertEquals(3, cbuf.numComponents()); + + byte[] dest = new byte[2]; + // should skip over the empty one, not throw a java.lang.Error :) + cbuf.getBytes(0, dest); + + assertArrayEquals(new byte[] {1, 2}, dest); + + cbuf.release(); + } + @Test public void testIterator() { CompositeByteBuf cbuf = compositeBuffer(); @@ -1116,6 +1195,97 @@ public void testReleasesItsComponents() { assertEquals(0, buffer.refCnt()); } + @Test + public void testReleasesItsComponents2() { + // It is important to use a pooled allocator here to ensure + // the slices returned by readRetainedSlice are of type + // PooledSlicedByteBuf, which maintains an independent refcount + // (so that we can be sure to cover this case) + ByteBuf buffer = PooledByteBufAllocator.DEFAULT.buffer(); // 1 + + buffer.writeBytes(new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + + // use readRetainedSlice this time - produces different kind of slices + ByteBuf s1 = buffer.readRetainedSlice(2); // 2 + ByteBuf s2 = s1.readRetainedSlice(2); // 3 + ByteBuf s3 = s2.readRetainedSlice(2); // 4 + ByteBuf s4 = s3.readRetainedSlice(2); // 5 + + ByteBuf composite = Unpooled.compositeBuffer() + .addComponent(s1) + .addComponents(s2, s3, s4) + .order(ByteOrder.LITTLE_ENDIAN); + + assertEquals(1, composite.refCnt()); + assertEquals(2, buffer.refCnt()); + + // releasing composite should release the 4 components + composite.release(); + assertEquals(0, composite.refCnt()); + assertEquals(1, buffer.refCnt()); + + // last remaining ref to buffer + buffer.release(); + assertEquals(0, buffer.refCnt()); + } + + @Test + public void testReleasesOnShrink() { + + ByteBuf b1 = Unpooled.buffer(2).writeShort(1); + ByteBuf b2 = Unpooled.buffer(2).writeShort(2); + + // composite takes ownership of s1 and s2 + ByteBuf composite = Unpooled.compositeBuffer() + .addComponents(b1, b2); + + assertEquals(4, composite.capacity()); + + // reduce capacity down to two, will drop the second component + composite.capacity(2); + assertEquals(2, composite.capacity()); + + // releasing composite should release the components + composite.release(); + assertEquals(0, composite.refCnt()); + assertEquals(0, b1.refCnt()); + assertEquals(0, b2.refCnt()); + } + + @Test + public void testReleasesOnShrink2() { + // It is important to use a pooled allocator here to ensure + // the slices returned by readRetainedSlice are of type + // PooledSlicedByteBuf, which maintains an independent refcount + // (so that we can be sure to cover this case) + ByteBuf buffer = PooledByteBufAllocator.DEFAULT.buffer(); + + buffer.writeShort(1).writeShort(2); + + ByteBuf b1 = buffer.readRetainedSlice(2); + ByteBuf b2 = b1.retainedSlice(b1.readerIndex(), 2); + + // composite takes ownership of b1 and b2 + ByteBuf composite = Unpooled.compositeBuffer() + .addComponents(b1, b2); + + assertEquals(4, composite.capacity()); + + // reduce capacity down to two, will drop the second component + composite.capacity(2); + assertEquals(2, composite.capacity()); + + // releasing composite should release the components + composite.release(); + assertEquals(0, composite.refCnt()); + assertEquals(0, b1.refCnt()); + assertEquals(0, b2.refCnt()); + + // release last remaining ref to buffer + buffer.release(); + assertEquals(0, buffer.refCnt()); + } + @Test public void testAllocatorIsSameWhenCopy() { testAllocatorIsSameWhenCopy(false); @@ -1136,4 +1306,76 @@ private void testAllocatorIsSameWhenCopy(boolean withIndexAndLength) { buffer.release(); copy.release(); } + + @Test + public void testDecomposeMultiple() { + testDecompose(150, 500, 3); + } + + @Test + public void testDecomposeOne() { + testDecompose(310, 50, 1); + } + + @Test + public void testDecomposeNone() { + testDecompose(310, 0, 0); + } + + private static void testDecompose(int offset, int length, int expectedListSize) { + byte[] bytes = new byte[1024]; + PlatformDependent.threadLocalRandom().nextBytes(bytes); + ByteBuf buf = wrappedBuffer(bytes); + + CompositeByteBuf composite = compositeBuffer(); + composite.addComponents(true, + buf.retainedSlice(100, 200), + buf.retainedSlice(300, 400), + buf.retainedSlice(700, 100)); + + ByteBuf slice = composite.slice(offset, length); + List bufferList = composite.decompose(offset, length); + assertEquals(expectedListSize, bufferList.size()); + ByteBuf wrapped = wrappedBuffer(bufferList.toArray(new ByteBuf[0])); + + assertEquals(slice, wrapped); + composite.release(); + buf.release(); + + for (ByteBuf buffer: bufferList) { + assertEquals(0, buffer.refCnt()); + } + } + + @Test + public void testComponentsLessThanLowerBound() { + try { + new CompositeByteBuf(ALLOC, true, 0); + fail(); + } catch (IllegalArgumentException e) { + assertEquals("maxNumComponents: 0 (expected: >= 1)", e.getMessage()); + } + } + + @Test + public void testComponentsEqualToLowerBound() { + assertCompositeBufCreated(1); + } + + @Test + public void testComponentsGreaterThanLowerBound() { + assertCompositeBufCreated(5); + } + + /** + * Assert that a new {@linkplain CompositeByteBuf} was created successfully with the desired number of max + * components. + */ + private static void assertCompositeBufCreated(int expectedMaxComponents) { + CompositeByteBuf buf = new CompositeByteBuf(ALLOC, true, expectedMaxComponents); + + assertEquals(expectedMaxComponents, buf.maxNumComponents()); + assertTrue(buf.release()); + } + } diff --git a/buffer/src/test/java/io/netty/buffer/ByteBufStreamTest.java b/buffer/src/test/java/io/netty/buffer/ByteBufStreamTest.java index f05594d1c12e..222ddcf7ee73 100644 --- a/buffer/src/test/java/io/netty/buffer/ByteBufStreamTest.java +++ b/buffer/src/test/java/io/netty/buffer/ByteBufStreamTest.java @@ -188,11 +188,13 @@ public void testReadLine() throws Exception { String s = in.readLine(); assertNull(s); - int charCount = 5; //total chars in the string below without new line characters - byte[] abc = "a\nb\r\nc\nd\ne".getBytes(utf8); + int charCount = 7; //total chars in the string below without new line characters + byte[] abc = "\na\n\nb\r\nc\nd\ne".getBytes(utf8); buf.writeBytes(abc); in.mark(charCount); + assertEquals("", in.readLine()); assertEquals("a", in.readLine()); + assertEquals("", in.readLine()); assertEquals("b", in.readLine()); assertEquals("c", in.readLine()); assertEquals("d", in.readLine()); diff --git a/buffer/src/test/java/io/netty/buffer/FixedCompositeByteBufTest.java b/buffer/src/test/java/io/netty/buffer/FixedCompositeByteBufTest.java index 481c70fc8430..b6260f7848bf 100644 --- a/buffer/src/test/java/io/netty/buffer/FixedCompositeByteBufTest.java +++ b/buffer/src/test/java/io/netty/buffer/FixedCompositeByteBufTest.java @@ -401,7 +401,7 @@ public void testHasMemoryAddressWhenEmpty() { buf.release(); } - @Test(expected = UnsupportedOperationException.class) + @Test public void testHasNoMemoryAddressWhenMultipleBuffers() { ByteBuf buf1 = directBuffer(10); if (!buf1.hasMemoryAddress()) { @@ -415,6 +415,8 @@ public void testHasNoMemoryAddressWhenMultipleBuffers() { try { buf.memoryAddress(); fail(); + } catch (UnsupportedOperationException expected) { + // expected } finally { buf.release(); } diff --git a/buffer/src/test/java/io/netty/buffer/PoolArenaTest.java b/buffer/src/test/java/io/netty/buffer/PoolArenaTest.java index 3fafef9c04ef..0e18eaf3f049 100644 --- a/buffer/src/test/java/io/netty/buffer/PoolArenaTest.java +++ b/buffer/src/test/java/io/netty/buffer/PoolArenaTest.java @@ -16,6 +16,7 @@ package io.netty.buffer; +import io.netty.util.internal.PlatformDependent; import org.junit.Assert; import org.junit.Test; @@ -43,6 +44,25 @@ public void testNormalizeAlignedCapacity() throws Exception { } } + @Test + public void testDirectArenaOffsetCacheLine() throws Exception { + int capacity = 5; + int alignment = 128; + + for (int i = 0; i < 1000; i++) { + ByteBuffer bb = PlatformDependent.useDirectBufferNoCleaner() + ? PlatformDependent.allocateDirectNoCleaner(capacity + alignment) + : ByteBuffer.allocateDirect(capacity + alignment); + + PoolArena.DirectArena arena = new PoolArena.DirectArena(null, 0, 0, 9, 9, alignment); + int offset = arena.offsetCacheLine(bb); + long address = PlatformDependent.directBufferAddress(bb); + + Assert.assertEquals(0, (offset + address) & (alignment - 1)); + PlatformDependent.freeDirectWithCleaner(bb); + } + } + @Test public final void testAllocationCounter() { final PooledByteBufAllocator allocator = new PooledByteBufAllocator( diff --git a/buffer/src/test/java/io/netty/buffer/PooledByteBufAllocatorTest.java b/buffer/src/test/java/io/netty/buffer/PooledByteBufAllocatorTest.java index 495bb765406a..39a613249387 100644 --- a/buffer/src/test/java/io/netty/buffer/PooledByteBufAllocatorTest.java +++ b/buffer/src/test/java/io/netty/buffer/PooledByteBufAllocatorTest.java @@ -430,8 +430,14 @@ public void testConcurrentUsage() throws Throwable { Thread.sleep(100); } } finally { + // First mark all AllocationThreads to complete their work and then wait until these are complete + // and rethrow if there was any error. for (AllocationThread t : threads) { - t.finish(); + t.markAsFinished(); + } + + for (AllocationThread t: threads) { + t.joinAndCheckForError(); } } } @@ -461,7 +467,7 @@ private static final class AllocationThread extends Thread { private final ByteBufAllocator allocator; private final AtomicReference finish = new AtomicReference(); - public AllocationThread(ByteBufAllocator allocator) { + AllocationThread(ByteBufAllocator allocator) { this.allocator = allocator; } @@ -494,14 +500,17 @@ private void releaseBuffers() { } } - public boolean isFinished() { + boolean isFinished() { return finish.get() != null; } - public void finish() throws Throwable { + void markAsFinished() { + finish.compareAndSet(null, Boolean.TRUE); + } + + void joinAndCheckForError() throws Throwable { try { // Mark as finish if not already done but ensure we not override the previous set error. - finish.compareAndSet(null, Boolean.TRUE); join(); } finally { releaseBuffers(); @@ -509,7 +518,7 @@ public void finish() throws Throwable { checkForError(); } - public void checkForError() throws Throwable { + void checkForError() throws Throwable { Object obj = finish.get(); if (obj instanceof Throwable) { throw (Throwable) obj; diff --git a/build.yaml b/build.yaml index f07a35b5c801..81ab0c48282a 100644 --- a/build.yaml +++ b/build.yaml @@ -3,20 +3,13 @@ schedules: schedule: adhoc os: - osx/high-sierra - - ubuntu/trusty64 -java: - - oraclejdk8 + - ubuntu/xenial64 build: - script: | echo "OS VERSION ===== $OS_VERSION" if [ "$OS_VERSION" = "osx/high-sierra" ]; then mvn -B clean -DskipTests - mvn -B -pl transport-native-unix-common,transport-native-kqueue -Partifactory deploy -DskipTests -DaltDeploymentRepository="artifactory::default::https://repo.datastax.com/datastax-releases-local" + mvn -B -U -pl transport-native-unix-common,transport-native-kqueue -Partifactory deploy -DskipTests -DaltDeploymentRepository="artifactory::default::https://repo.sjc.dsinternal.org/artifactory/datastax-releases-local" else - export DEBIAN_FRONTEND=noninteractive - export MAVEN_HOME=/home/jenkins/.mvn/apache-maven-3.2.5 - export PATH=$MAVEN_HOME/bin:$PATH - sudo apt-get update - sudo apt-get install -y autoconf automake libtool make tar gcc-multilib libaio-dev - mvn -B clean deploy -Partifactory -DskipTests -DaltDeploymentRepository="artifactory::default::https://repo.datastax.com/datastax-releases-local" + ./docker-datastax-release.sh fi diff --git a/codec-dns/pom.xml b/codec-dns/pom.xml index fe34a550a86a..94de3b97e695 100644 --- a/codec-dns/pom.xml +++ b/codec-dns/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-codec-dns @@ -33,6 +33,21 @@ + + ${project.groupId} + netty-common + ${project.version} + + + ${project.groupId} + netty-buffer + ${project.version} + + + ${project.groupId} + netty-transport + ${project.version} + ${project.groupId} netty-codec diff --git a/codec-dns/src/main/java/io/netty/handler/codec/dns/AbstractDnsRecord.java b/codec-dns/src/main/java/io/netty/handler/codec/dns/AbstractDnsRecord.java index 28b92c27f928..2ba6e573a7fd 100644 --- a/codec-dns/src/main/java/io/netty/handler/codec/dns/AbstractDnsRecord.java +++ b/codec-dns/src/main/java/io/netty/handler/codec/dns/AbstractDnsRecord.java @@ -21,6 +21,7 @@ import java.net.IDN; import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; /** * A skeletal implementation of {@link DnsRecord}. @@ -62,9 +63,7 @@ protected AbstractDnsRecord(String name, DnsRecordType type, long timeToLive) { * @param timeToLive the TTL value of the record */ protected AbstractDnsRecord(String name, DnsRecordType type, int dnsClass, long timeToLive) { - if (timeToLive < 0) { - throw new IllegalArgumentException("timeToLive: " + timeToLive + " (expected: >= 0)"); - } + checkPositiveOrZero(timeToLive, "timeToLive"); // Convert to ASCII which will also check that the length is not too big. // See: // - https://github.com/netty/netty/issues/4937 diff --git a/codec-haproxy/pom.xml b/codec-haproxy/pom.xml index 5f06d4b867d8..5f2ee5c3615f 100644 --- a/codec-haproxy/pom.xml +++ b/codec-haproxy/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-codec-haproxy @@ -33,6 +33,16 @@ + + ${project.groupId} + netty-buffer + ${project.version} + + + ${project.groupId} + netty-transport + ${project.version} + ${project.groupId} netty-codec diff --git a/codec-http/pom.xml b/codec-http/pom.xml index db4407b0697f..d3953462213f 100644 --- a/codec-http/pom.xml +++ b/codec-http/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-codec-http @@ -33,6 +33,21 @@ + + ${project.groupId} + netty-common + ${project.version} + + + ${project.groupId} + netty-buffer + ${project.version} + + + ${project.groupId} + netty-transport + ${project.version} + ${project.groupId} netty-codec @@ -42,7 +57,6 @@ ${project.groupId} netty-handler ${project.version} - true com.jcraft diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/CombinedHttpHeaders.java b/codec-http/src/main/java/io/netty/handler/codec/http/CombinedHttpHeaders.java index 2d43b7ad04a3..b8a2d0b8915c 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/CombinedHttpHeaders.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/CombinedHttpHeaders.java @@ -26,6 +26,7 @@ import java.util.List; import java.util.Map; +import static io.netty.handler.codec.http.HttpHeaderNames.SET_COOKIE; import static io.netty.util.AsciiString.CASE_INSENSITIVE_HASHER; import static io.netty.util.internal.StringUtil.COMMA; import static io.netty.util.internal.StringUtil.unescapeCsvFields; @@ -78,7 +79,7 @@ public CharSequence escape(CharSequence value) { return charSequenceEscaper; } - public CombinedHttpHeadersImpl(HashingStrategy nameHashingStrategy, + CombinedHttpHeadersImpl(HashingStrategy nameHashingStrategy, ValueConverter valueConverter, io.netty.handler.codec.DefaultHeaders.NameValidator nameValidator) { super(nameHashingStrategy, valueConverter, nameValidator); @@ -87,7 +88,7 @@ public CombinedHttpHeadersImpl(HashingStrategy nameHashingStrategy @Override public Iterator valueIterator(CharSequence name) { Iterator itr = super.valueIterator(name); - if (!itr.hasNext()) { + if (!itr.hasNext() || cannotBeCombined(name)) { return itr; } Iterator unescapedItr = unescapeCsvFields(itr.next()).iterator(); @@ -100,7 +101,7 @@ public Iterator valueIterator(CharSequence name) { @Override public List getAll(CharSequence name) { List values = super.getAll(name); - if (values.isEmpty()) { + if (values.isEmpty() || cannotBeCombined(name)) { return values; } if (values.size() != 1) { @@ -213,9 +214,13 @@ public CombinedHttpHeadersImpl setObject(CharSequence name, Iterable values) return this; } + private static boolean cannotBeCombined(CharSequence name) { + return SET_COOKIE.contentEqualsIgnoreCase(name); + } + private CombinedHttpHeadersImpl addEscapedValue(CharSequence name, CharSequence escapedValue) { CharSequence currentValue = super.get(name); - if (currentValue == null) { + if (currentValue == null || cannotBeCombined(name)) { super.add(name, escapedValue); } else { super.set(name, commaSeparateEscapedValues(currentValue, escapedValue)); diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/ComposedLastHttpContent.java b/codec-http/src/main/java/io/netty/handler/codec/http/ComposedLastHttpContent.java index b8b7dc967fac..1ef39e56fff7 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/ComposedLastHttpContent.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/ComposedLastHttpContent.java @@ -28,6 +28,11 @@ final class ComposedLastHttpContent implements LastHttpContent { this.trailingHeaders = trailingHeaders; } + ComposedLastHttpContent(HttpHeaders trailingHeaders, DecoderResult result) { + this(trailingHeaders); + this.result = result; + } + @Override public HttpHeaders trailingHeaders() { return trailingHeaders; diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpHeaders.java b/codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpHeaders.java index 88af27f738e6..ef24c7558342 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpHeaders.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpHeaders.java @@ -372,8 +372,7 @@ private static void validateHeaderNameElement(byte value) { default: // Check to see if the character is not an ASCII character, or invalid if (value < 0) { - throw new IllegalArgumentException("a header name cannot contain non-ASCII character: " + - value); + throw new IllegalArgumentException("a header name cannot contain non-ASCII character: " + value); } } } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpResponse.java b/codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpResponse.java index 86858108a277..d5b7cf0b7b9a 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpResponse.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpResponse.java @@ -105,4 +105,23 @@ public HttpResponse setProtocolVersion(HttpVersion version) { public String toString() { return HttpMessageUtil.appendResponse(new StringBuilder(256), this).toString(); } + + @Override + public int hashCode() { + int result = 1; + result = 31 * result + status.hashCode(); + result = 31 * result + super.hashCode(); + return result; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof DefaultHttpResponse)) { + return false; + } + + DefaultHttpResponse other = (DefaultHttpResponse) o; + + return status.equals(other.status()) && super.equals(o); + } } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpContentDecoder.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpContentDecoder.java index e85adaaa373c..eb7b7c0a6c23 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpContentDecoder.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpContentDecoder.java @@ -19,6 +19,7 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.CodecException; +import io.netty.handler.codec.DecoderResult; import io.netty.handler.codec.MessageToMessageDecoder; import io.netty.util.ReferenceCountUtil; @@ -50,102 +51,107 @@ public abstract class HttpContentDecoder extends MessageToMessageDecoder out) throws Exception { - if (msg instanceof HttpResponse && ((HttpResponse) msg).status().code() == 100) { + try { + if (msg instanceof HttpResponse && ((HttpResponse) msg).status().code() == 100) { - if (!(msg instanceof LastHttpContent)) { - continueResponse = true; + if (!(msg instanceof LastHttpContent)) { + continueResponse = true; + } + // 100-continue response must be passed through. + out.add(ReferenceCountUtil.retain(msg)); + return; } - // 100-continue response must be passed through. - out.add(ReferenceCountUtil.retain(msg)); - return; - } - if (continueResponse) { - if (msg instanceof LastHttpContent) { - continueResponse = false; + if (continueResponse) { + if (msg instanceof LastHttpContent) { + continueResponse = false; + } + // 100-continue response must be passed through. + out.add(ReferenceCountUtil.retain(msg)); + return; } - // 100-continue response must be passed through. - out.add(ReferenceCountUtil.retain(msg)); - return; - } - if (msg instanceof HttpMessage) { - cleanup(); - final HttpMessage message = (HttpMessage) msg; - final HttpHeaders headers = message.headers(); + if (msg instanceof HttpMessage) { + cleanup(); + final HttpMessage message = (HttpMessage) msg; + final HttpHeaders headers = message.headers(); - // Determine the content encoding. - String contentEncoding = headers.get(HttpHeaderNames.CONTENT_ENCODING); - if (contentEncoding != null) { - contentEncoding = contentEncoding.trim(); - } else { - contentEncoding = IDENTITY; - } - decoder = newContentDecoder(contentEncoding); + // Determine the content encoding. + String contentEncoding = headers.get(HttpHeaderNames.CONTENT_ENCODING); + if (contentEncoding != null) { + contentEncoding = contentEncoding.trim(); + } else { + contentEncoding = IDENTITY; + } + decoder = newContentDecoder(contentEncoding); - if (decoder == null) { - if (message instanceof HttpContent) { - ((HttpContent) message).retain(); + if (decoder == null) { + if (message instanceof HttpContent) { + ((HttpContent) message).retain(); + } + out.add(message); + return; } - out.add(message); - return; - } - // Remove content-length header: - // the correct value can be set only after all chunks are processed/decoded. - // If buffering is not an issue, add HttpObjectAggregator down the chain, it will set the header. - // Otherwise, rely on LastHttpContent message. - if (headers.contains(HttpHeaderNames.CONTENT_LENGTH)) { - headers.remove(HttpHeaderNames.CONTENT_LENGTH); - headers.set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED); - } - // Either it is already chunked or EOF terminated. - // See https://github.com/netty/netty/issues/5892 + // Remove content-length header: + // the correct value can be set only after all chunks are processed/decoded. + // If buffering is not an issue, add HttpObjectAggregator down the chain, it will set the header. + // Otherwise, rely on LastHttpContent message. + if (headers.contains(HttpHeaderNames.CONTENT_LENGTH)) { + headers.remove(HttpHeaderNames.CONTENT_LENGTH); + headers.set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED); + } + // Either it is already chunked or EOF terminated. + // See https://github.com/netty/netty/issues/5892 - // set new content encoding, - CharSequence targetContentEncoding = getTargetContentEncoding(contentEncoding); - if (HttpHeaderValues.IDENTITY.contentEquals(targetContentEncoding)) { - // Do NOT set the 'Content-Encoding' header if the target encoding is 'identity' - // as per: http://tools.ietf.org/html/rfc2616#section-14.11 - headers.remove(HttpHeaderNames.CONTENT_ENCODING); - } else { - headers.set(HttpHeaderNames.CONTENT_ENCODING, targetContentEncoding); - } + // set new content encoding, + CharSequence targetContentEncoding = getTargetContentEncoding(contentEncoding); + if (HttpHeaderValues.IDENTITY.contentEquals(targetContentEncoding)) { + // Do NOT set the 'Content-Encoding' header if the target encoding is 'identity' + // as per: http://tools.ietf.org/html/rfc2616#section-14.11 + headers.remove(HttpHeaderNames.CONTENT_ENCODING); + } else { + headers.set(HttpHeaderNames.CONTENT_ENCODING, targetContentEncoding); + } - if (message instanceof HttpContent) { - // If message is a full request or response object (headers + data), don't copy data part into out. - // Output headers only; data part will be decoded below. - // Note: "copy" object must not be an instance of LastHttpContent class, - // as this would (erroneously) indicate the end of the HttpMessage to other handlers. - HttpMessage copy; - if (message instanceof HttpRequest) { - HttpRequest r = (HttpRequest) message; // HttpRequest or FullHttpRequest - copy = new DefaultHttpRequest(r.protocolVersion(), r.method(), r.uri()); - } else if (message instanceof HttpResponse) { - HttpResponse r = (HttpResponse) message; // HttpResponse or FullHttpResponse - copy = new DefaultHttpResponse(r.protocolVersion(), r.status()); + if (message instanceof HttpContent) { + // If message is a full request or response object (headers + data), don't copy data part into out. + // Output headers only; data part will be decoded below. + // Note: "copy" object must not be an instance of LastHttpContent class, + // as this would (erroneously) indicate the end of the HttpMessage to other handlers. + HttpMessage copy; + if (message instanceof HttpRequest) { + HttpRequest r = (HttpRequest) message; // HttpRequest or FullHttpRequest + copy = new DefaultHttpRequest(r.protocolVersion(), r.method(), r.uri()); + } else if (message instanceof HttpResponse) { + HttpResponse r = (HttpResponse) message; // HttpResponse or FullHttpResponse + copy = new DefaultHttpResponse(r.protocolVersion(), r.status()); + } else { + throw new CodecException("Object of class " + message.getClass().getName() + + " is not a HttpRequest or HttpResponse"); + } + copy.headers().set(message.headers()); + copy.setDecoderResult(message.decoderResult()); + out.add(copy); } else { - throw new CodecException("Object of class " + message.getClass().getName() + - " is not a HttpRequest or HttpResponse"); + out.add(message); } - copy.headers().set(message.headers()); - copy.setDecoderResult(message.decoderResult()); - out.add(copy); - } else { - out.add(message); } - } - if (msg instanceof HttpContent) { - final HttpContent c = (HttpContent) msg; - if (decoder == null) { - out.add(c.retain()); - } else { - decodeContent(c, out); + if (msg instanceof HttpContent) { + final HttpContent c = (HttpContent) msg; + if (decoder == null) { + out.add(c.retain()); + } else { + decodeContent(c, out); + } } + } finally { + needRead = out.isEmpty(); } } @@ -164,7 +170,21 @@ private void decodeContent(HttpContent c, List out) { if (headers.isEmpty()) { out.add(LastHttpContent.EMPTY_LAST_CONTENT); } else { - out.add(new ComposedLastHttpContent(headers)); + out.add(new ComposedLastHttpContent(headers, DecoderResult.SUCCESS)); + } + } + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + boolean needRead = this.needRead; + this.needRead = true; + + try { + ctx.fireChannelReadComplete(); + } finally { + if (needRead && !ctx.channel().config().isAutoRead()) { + ctx.read(); } } } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpContentEncoder.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpContentEncoder.java index 0078edc42134..7be3b8b92e92 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpContentEncoder.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpContentEncoder.java @@ -19,6 +19,7 @@ import io.netty.buffer.ByteBufHolder; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.DecoderResult; import io.netty.handler.codec.MessageToMessageCodec; import io.netty.util.ReferenceCountUtil; @@ -77,10 +78,10 @@ protected void decode(ChannelHandlerContext ctx, HttpRequest msg, List o acceptedEncoding = HttpContentDecoder.IDENTITY; } - HttpMethod meth = msg.method(); - if (meth == HttpMethod.HEAD) { + HttpMethod method = msg.method(); + if (HttpMethod.HEAD.equals(method)) { acceptedEncoding = ZERO_LENGTH_HEAD; - } else if (meth == HttpMethod.CONNECT) { + } else if (HttpMethod.CONNECT.equals(method)) { acceptedEncoding = ZERO_LENGTH_CONNECT; } @@ -264,7 +265,7 @@ private boolean encodeContent(HttpContent c, List out) { if (headers.isEmpty()) { out.add(LastHttpContent.EMPTY_LAST_CONTENT); } else { - out.add(new ComposedLastHttpContent(headers)); + out.add(new ComposedLastHttpContent(headers, DecoderResult.SUCCESS)); } return true; } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaders.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaders.java index 35cfc5c5f181..694bc4d56e44 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaders.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaders.java @@ -1695,7 +1695,7 @@ public String toString() { } /** - * Returns a deap copy of the passed in {@link HttpHeaders}. + * Returns a deep copy of the passed in {@link HttpHeaders}. */ public HttpHeaders copy() { return new DefaultHttpHeaders().set(this); diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpMethod.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpMethod.java index 3d975516d5fc..a634bd0016f3 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpMethod.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpMethod.java @@ -156,6 +156,9 @@ public int hashCode() { @Override public boolean equals(Object o) { + if (this == o) { + return true; + } if (!(o instanceof HttpMethod)) { return false; } @@ -171,6 +174,9 @@ public String toString() { @Override public int compareTo(HttpMethod o) { + if (o == this) { + return 0; + } return name().compareTo(o.name()); } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectDecoder.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectDecoder.java index af1d642a0397..8d4fc1d42957 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectDecoder.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectDecoder.java @@ -15,6 +15,8 @@ */ package io.netty.handler.codec.http; +import static io.netty.util.internal.ObjectUtil.checkPositive; + import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; @@ -168,21 +170,10 @@ protected HttpObjectDecoder( protected HttpObjectDecoder( int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, boolean chunkedSupported, boolean validateHeaders, int initialBufferSize) { - if (maxInitialLineLength <= 0) { - throw new IllegalArgumentException( - "maxInitialLineLength must be a positive integer: " + - maxInitialLineLength); - } - if (maxHeaderSize <= 0) { - throw new IllegalArgumentException( - "maxHeaderSize must be a positive integer: " + - maxHeaderSize); - } - if (maxChunkSize <= 0) { - throw new IllegalArgumentException( - "maxChunkSize must be a positive integer: " + - maxChunkSize); - } + checkPositive(maxInitialLineLength, "maxInitialLineLength"); + checkPositive(maxHeaderSize, "maxHeaderSize"); + checkPositive(maxChunkSize, "maxChunkSize"); + AppendableCharSequence seq = new AppendableCharSequence(initialBufferSize); lineParser = new LineParser(seq, maxInitialLineLength); headerParser = new HeaderParser(seq, maxHeaderSize); @@ -584,7 +575,7 @@ private State readHeaders(ByteBuf buffer) { } if (line.length() > 0) { do { - char firstChar = line.charAt(0); + char firstChar = line.charAtUnsafe(0); if (name != null && (firstChar == ' ' || firstChar == '\t')) { //please do not make one line from below code //as it breaks +XX:OptimizeStringConcat optimization @@ -609,23 +600,61 @@ private State readHeaders(ByteBuf buffer) { if (name != null) { headers.add(name, value); } + // reset name and value fields name = null; value = null; - State nextState; + List values = headers.getAll(HttpHeaderNames.CONTENT_LENGTH); + int contentLengthValuesCount = values.size(); + + if (contentLengthValuesCount > 0) { + // Guard against multiple Content-Length headers as stated in + // https://tools.ietf.org/html/rfc7230#section-3.3.2: + // + // If a message is received that has multiple Content-Length header + // fields with field-values consisting of the same decimal value, or a + // single Content-Length header field with a field value containing a + // list of identical decimal values (e.g., "Content-Length: 42, 42"), + // indicating that duplicate Content-Length header fields have been + // generated or combined by an upstream message processor, then the + // recipient MUST either reject the message as invalid or replace the + // duplicated field-values with a single valid Content-Length field + // containing that decimal value prior to determining the message body + // length or forwarding the message. + if (contentLengthValuesCount > 1 && message.protocolVersion() == HttpVersion.HTTP_1_1) { + throw new IllegalArgumentException("Multiple Content-Length headers found"); + } + contentLength = Long.parseLong(values.get(0)); + } if (isContentAlwaysEmpty(message)) { HttpUtil.setTransferEncodingChunked(message, false); - nextState = State.SKIP_CONTROL_CHARS; + return State.SKIP_CONTROL_CHARS; } else if (HttpUtil.isTransferEncodingChunked(message)) { - nextState = State.READ_CHUNK_SIZE; + // See https://tools.ietf.org/html/rfc7230#section-3.3.3 + // + // If a message is received with both a Transfer-Encoding and a + // Content-Length header field, the Transfer-Encoding overrides the + // Content-Length. Such a message might indicate an attempt to + // perform request smuggling (Section 9.5) or response splitting + // (Section 9.4) and ought to be handled as an error. A sender MUST + // remove the received Content-Length field prior to forwarding such + // a message downstream. + // + // This is also what http_parser does: + // https://github.com/nodejs/http-parser/blob/v2.9.2/http_parser.c#L1769 + if (contentLengthValuesCount > 0 && message.protocolVersion() == HttpVersion.HTTP_1_1) { + throw new IllegalArgumentException( + "Both 'Content-Length: " + contentLength + "' and 'Transfer-Encoding: chunked' found"); + } + + return State.READ_CHUNK_SIZE; } else if (contentLength() >= 0) { - nextState = State.READ_FIXED_LENGTH_CONTENT; + return State.READ_FIXED_LENGTH_CONTENT; } else { - nextState = State.READ_VARIABLE_LENGTH_CONTENT; + return State.READ_VARIABLE_LENGTH_CONTENT; } - return nextState; } private long contentLength() { @@ -640,49 +669,50 @@ private LastHttpContent readTrailingHeaders(ByteBuf buffer) { if (line == null) { return null; } + LastHttpContent trailer = this.trailer; + if (line.length() == 0 && trailer == null) { + // We have received the empty line which signals the trailer is complete and did not parse any trailers + // before. Just return an empty last content to reduce allocations. + return LastHttpContent.EMPTY_LAST_CONTENT; + } + CharSequence lastHeader = null; - if (line.length() > 0) { - LastHttpContent trailer = this.trailer; - if (trailer == null) { - trailer = this.trailer = new DefaultLastHttpContent(Unpooled.EMPTY_BUFFER, validateHeaders); - } - do { - char firstChar = line.charAt(0); - if (lastHeader != null && (firstChar == ' ' || firstChar == '\t')) { - List current = trailer.trailingHeaders().getAll(lastHeader); - if (!current.isEmpty()) { - int lastPos = current.size() - 1; - //please do not make one line from below code - //as it breaks +XX:OptimizeStringConcat optimization - String lineTrimmed = line.toString().trim(); - String currentLastPos = current.get(lastPos); - current.set(lastPos, currentLastPos + lineTrimmed); - } - } else { - splitHeader(line); - CharSequence headerName = name; - if (!HttpHeaderNames.CONTENT_LENGTH.contentEqualsIgnoreCase(headerName) && + if (trailer == null) { + trailer = this.trailer = new DefaultLastHttpContent(Unpooled.EMPTY_BUFFER, validateHeaders); + } + while (line.length() > 0) { + char firstChar = line.charAtUnsafe(0); + if (lastHeader != null && (firstChar == ' ' || firstChar == '\t')) { + List current = trailer.trailingHeaders().getAll(lastHeader); + if (!current.isEmpty()) { + int lastPos = current.size() - 1; + //please do not make one line from below code + //as it breaks +XX:OptimizeStringConcat optimization + String lineTrimmed = line.toString().trim(); + String currentLastPos = current.get(lastPos); + current.set(lastPos, currentLastPos + lineTrimmed); + } + } else { + splitHeader(line); + CharSequence headerName = name; + if (!HttpHeaderNames.CONTENT_LENGTH.contentEqualsIgnoreCase(headerName) && !HttpHeaderNames.TRANSFER_ENCODING.contentEqualsIgnoreCase(headerName) && !HttpHeaderNames.TRAILER.contentEqualsIgnoreCase(headerName)) { - trailer.trailingHeaders().add(headerName, value); - } - lastHeader = name; - // reset name and value fields - name = null; - value = null; - } - - line = headerParser.parse(buffer); - if (line == null) { - return null; + trailer.trailingHeaders().add(headerName, value); } - } while (line.length() > 0); - - this.trailer = null; - return trailer; + lastHeader = name; + // reset name and value fields + name = null; + value = null; + } + line = headerParser.parse(buffer); + if (line == null) { + return null; + } } - return LastHttpContent.EMPTY_LAST_CONTENT; + this.trailer = null; + return trailer; } protected abstract boolean isDecodingRequest(); @@ -735,14 +765,33 @@ private void splitHeader(AppendableCharSequence sb) { nameStart = findNonWhitespace(sb, 0); for (nameEnd = nameStart; nameEnd < length; nameEnd ++) { - char ch = sb.charAt(nameEnd); - if (ch == ':' || Character.isWhitespace(ch)) { + char ch = sb.charAtUnsafe(nameEnd); + // https://tools.ietf.org/html/rfc7230#section-3.2.4 + // + // No whitespace is allowed between the header field-name and colon. In + // the past, differences in the handling of such whitespace have led to + // security vulnerabilities in request routing and response handling. A + // server MUST reject any received request message that contains + // whitespace between a header field-name and colon with a response code + // of 400 (Bad Request). A proxy MUST remove any such whitespace from a + // response message before forwarding the message downstream. + if (ch == ':' || + // In case of decoding a request we will just continue processing and header validation + // is done in the DefaultHttpHeaders implementation. + // + // In the case of decoding a response we will "skip" the whitespace. + (!isDecodingRequest() && Character.isWhitespace(ch))) { break; } } + if (nameEnd == length) { + // There was no colon present at all. + throw new IllegalArgumentException("No colon found"); + } + for (colonEnd = nameEnd; colonEnd < length; colonEnd ++) { - if (sb.charAt(colonEnd) == ':') { + if (sb.charAtUnsafe(colonEnd) == ':') { colonEnd ++; break; } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectEncoder.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectEncoder.java index fe03378bba6e..5841dc1e1528 100755 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectEncoder.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectEncoder.java @@ -83,7 +83,8 @@ protected void encode(ChannelHandlerContext ctx, Object msg, List out) t ByteBuf buf = null; if (msg instanceof HttpMessage) { if (state != ST_INIT) { - throw new IllegalStateException("unexpected message type: " + StringUtil.simpleClassName(msg)); + throw new IllegalStateException("unexpected message type: " + StringUtil.simpleClassName(msg) + + ", state: " + state); } @SuppressWarnings({ "unchecked", "CastConflictsWithInstanceof" }) diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpResponseStatus.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpResponseStatus.java index b7e2c10d456c..ef60a41d4f3b 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpResponseStatus.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpResponseStatus.java @@ -22,6 +22,7 @@ import static io.netty.handler.codec.http.HttpConstants.SP; import static io.netty.util.ByteProcessor.FIND_ASCII_SPACE; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; import static java.lang.Integer.parseInt; /** @@ -538,10 +539,7 @@ public HttpResponseStatus(int code, String reasonPhrase) { } private HttpResponseStatus(int code, String reasonPhrase, boolean bytes) { - if (code < 0) { - throw new IllegalArgumentException( - "code: " + code + " (expected: 0+)"); - } + checkPositiveOrZero(code, "code"); if (reasonPhrase == null) { throw new NullPointerException("reasonPhrase"); diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpServerCodec.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpServerCodec.java index 4e8d61361b29..6e36128fa23c 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpServerCodec.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpServerCodec.java @@ -81,16 +81,18 @@ public void upgradeFrom(ChannelHandlerContext ctx) { } private final class HttpServerRequestDecoder extends HttpRequestDecoder { - public HttpServerRequestDecoder(int maxInitialLineLength, int maxHeaderSize, int maxChunkSize) { + + HttpServerRequestDecoder(int maxInitialLineLength, int maxHeaderSize, int maxChunkSize) { super(maxInitialLineLength, maxHeaderSize, maxChunkSize); } - public HttpServerRequestDecoder(int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, + HttpServerRequestDecoder(int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, boolean validateHeaders) { super(maxInitialLineLength, maxHeaderSize, maxChunkSize, validateHeaders); } - public HttpServerRequestDecoder(int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, + HttpServerRequestDecoder(int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, + boolean validateHeaders, int initialBufferSize) { super(maxInitialLineLength, maxHeaderSize, maxChunkSize, validateHeaders, initialBufferSize); } @@ -115,7 +117,8 @@ private final class HttpServerResponseEncoder extends HttpResponseEncoder { @Override protected void sanitizeHeadersBeforeEncode(HttpResponse msg, boolean isAlwaysEmpty) { - if (!isAlwaysEmpty && method == HttpMethod.CONNECT && msg.status().codeClass() == HttpStatusClass.SUCCESS) { + if (!isAlwaysEmpty && HttpMethod.CONNECT.equals(method) + && msg.status().codeClass() == HttpStatusClass.SUCCESS) { // Stripping Transfer-Encoding: // See https://tools.ietf.org/html/rfc7230#section-3.3.1 msg.headers().remove(HttpHeaderNames.TRANSFER_ENCODING); diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpServerUpgradeHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpServerUpgradeHandler.java index f1f3efcb1de7..2b54b0e4b211 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpServerUpgradeHandler.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpServerUpgradeHandler.java @@ -14,9 +14,6 @@ */ package io.netty.handler.codec.http; -import static io.netty.util.AsciiString.containsContentEqualsIgnoreCase; -import static io.netty.util.AsciiString.containsAllContentEqualsIgnoreCase; - import io.netty.buffer.Unpooled; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; @@ -30,7 +27,10 @@ import static io.netty.handler.codec.http.HttpResponseStatus.SWITCHING_PROTOCOLS; import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; +import static io.netty.util.AsciiString.containsAllContentEqualsIgnoreCase; +import static io.netty.util.AsciiString.containsContentEqualsIgnoreCase; import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.StringUtil.COMMA; /** * A server-side handler that receives HTTP requests and optionally performs a protocol switch if @@ -284,16 +284,23 @@ private boolean upgrade(final ChannelHandlerContext ctx, final FullHttpRequest r } // Make sure the CONNECTION header is present. - CharSequence connectionHeader = request.headers().get(HttpHeaderNames.CONNECTION); - if (connectionHeader == null) { + List connectionHeaderValues = request.headers().getAll(HttpHeaderNames.CONNECTION); + + if (connectionHeaderValues == null) { return false; } + final StringBuilder concatenatedConnectionValue = new StringBuilder(connectionHeaderValues.size() * 10); + for (CharSequence connectionHeaderValue : connectionHeaderValues) { + concatenatedConnectionValue.append(connectionHeaderValue).append(COMMA); + } + concatenatedConnectionValue.setLength(concatenatedConnectionValue.length() - 1); + // Make sure the CONNECTION header contains UPGRADE as well as all protocol-specific headers. Collection requiredHeaders = upgradeCodec.requiredUpgradeHeaders(); - List values = splitHeader(connectionHeader); + List values = splitHeader(concatenatedConnectionValue); if (!containsContentEqualsIgnoreCase(values, HttpHeaderNames.UPGRADE) || - !containsAllContentEqualsIgnoreCase(values, requiredHeaders)) { + !containsAllContentEqualsIgnoreCase(values, requiredHeaders)) { return false; } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpUtil.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpUtil.java index 5b6e90e9abd3..8e04b10db62a 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpUtil.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpUtil.java @@ -15,18 +15,18 @@ */ package io.netty.handler.codec.http; -import io.netty.util.AsciiString; -import io.netty.util.CharsetUtil; -import io.netty.util.NetUtil; - import java.net.InetSocketAddress; import java.net.URI; -import java.util.ArrayList; import java.nio.charset.Charset; import java.nio.charset.UnsupportedCharsetException; +import java.util.ArrayList; import java.util.Iterator; import java.util.List; +import io.netty.util.AsciiString; +import io.netty.util.CharsetUtil; +import io.netty.util.NetUtil; + /** * Utility methods useful in the HTTP context. */ @@ -60,12 +60,13 @@ public static boolean isAsteriskForm(URI uri) { /** * Returns {@code true} if and only if the connection can remain open and * thus 'kept alive'. This methods respects the value of the. + * * {@code "Connection"} header first and then the return value of * {@link HttpVersion#isKeepAliveDefault()}. */ public static boolean isKeepAlive(HttpMessage message) { CharSequence connection = message.headers().get(HttpHeaderNames.CONNECTION); - if (connection != null && HttpHeaderValues.CLOSE.contentEqualsIgnoreCase(connection)) { + if (HttpHeaderValues.CLOSE.contentEqualsIgnoreCase(connection)) { return false; } @@ -193,6 +194,7 @@ public static long getContentLength(HttpMessage message, long defaultValue) { /** * Get an {@code int} representation of {@link #getContentLength(HttpMessage, long)}. + * * @return the content length or {@code defaultValue} if this message does * not have the {@code "Content-Length"} header or its value is not * a number. Not to exceed the boundaries of integer. @@ -249,13 +251,9 @@ public static boolean isContentLengthSet(HttpMessage m) { * present */ public static boolean is100ContinueExpected(HttpMessage message) { - if (!isExpectHeaderValid(message)) { - return false; - } - - final String expectValue = message.headers().get(HttpHeaderNames.EXPECT); - // unquoted tokens in the expect header are case-insensitive, thus 100-continue is case insensitive - return HttpHeaderValues.CONTINUE.toString().equalsIgnoreCase(expectValue); + return isExpectHeaderValid(message) + // unquoted tokens in the expect header are case-insensitive, thus 100-continue is case insensitive + && message.headers().contains(HttpHeaderNames.EXPECT, HttpHeaderValues.CONTINUE, true); } /** @@ -313,6 +311,7 @@ public static boolean isTransferEncodingChunked(HttpMessage message) { /** * Set the {@link HttpHeaderNames#TRANSFER_ENCODING} to either include {@link HttpHeaderValues#CHUNKED} if * {@code chunked} is {@code true}, or remove {@link HttpHeaderValues#CHUNKED} if {@code chunked} is {@code false}. + * * @param m The message which contains the headers to modify. * @param chunked if {@code true} then include {@link HttpHeaderValues#CHUNKED} in the headers. otherwise remove * {@link HttpHeaderValues#CHUNKED} from the headers. @@ -371,7 +370,7 @@ public static Charset getCharset(CharSequence contentTypeValue) { /** * Fetch charset from message's Content-Type header. * - * @param message entity to fetch Content-Type header from + * @param message entity to fetch Content-Type header from * @param defaultCharset result to use in case of empty, incorrect or doesn't contain required part header value * @return the charset from message's Content-Type header or {@code defaultCharset} * if charset is not presented or unparsable @@ -389,7 +388,7 @@ public static Charset getCharset(HttpMessage message, Charset defaultCharset) { * Fetch charset from Content-Type header value. * * @param contentTypeValue Content-Type header value to parse - * @param defaultCharset result to use in case of empty, incorrect or doesn't contain required part header value + * @param defaultCharset result to use in case of empty, incorrect or doesn't contain required part header value * @return the charset from message's Content-Type header or {@code defaultCharset} * if charset is not presented or unparsable */ @@ -459,13 +458,23 @@ public static CharSequence getCharsetAsSequence(CharSequence contentTypeValue) { if (contentTypeValue == null) { throw new NullPointerException("contentTypeValue"); } + int indexOfCharset = AsciiString.indexOfIgnoreCaseAscii(contentTypeValue, CHARSET_EQUALS, 0); - if (indexOfCharset != AsciiString.INDEX_NOT_FOUND) { - int indexOfEncoding = indexOfCharset + CHARSET_EQUALS.length(); - if (indexOfEncoding < contentTypeValue.length()) { - return contentTypeValue.subSequence(indexOfEncoding, contentTypeValue.length()); + if (indexOfCharset == AsciiString.INDEX_NOT_FOUND) { + return null; + } + + int indexOfEncoding = indexOfCharset + CHARSET_EQUALS.length(); + if (indexOfEncoding < contentTypeValue.length()) { + CharSequence charsetCandidate = contentTypeValue.subSequence(indexOfEncoding, contentTypeValue.length()); + int indexOfSemicolon = AsciiString.indexOfIgnoreCaseAscii(charsetCandidate, SEMICOLON, 0); + if (indexOfSemicolon == AsciiString.INDEX_NOT_FOUND) { + return charsetCandidate; } + + return charsetCandidate.subSequence(0, indexOfSemicolon); } + return null; } @@ -517,6 +526,7 @@ public static CharSequence getMimeType(CharSequence contentTypeValue) { /** * Formats the host string of an address so it can be used for computing an HTTP component * such as an URL or a Host header + * * @param addr the address * @return the formatted String */ @@ -526,7 +536,7 @@ public static String formatHostnameForHttp(InetSocketAddress addr) { if (!addr.isUnresolved()) { hostString = NetUtil.toAddressString(addr.getAddress()); } - return "[" + hostString + "]"; + return '[' + hostString + ']'; } return hostString; } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpVersion.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpVersion.java index a643f42458d9..7ba40eed90c3 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpVersion.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpVersion.java @@ -15,6 +15,8 @@ */ package io.netty.handler.codec.http; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + import io.netty.buffer.ByteBuf; import io.netty.util.CharsetUtil; @@ -165,12 +167,8 @@ private HttpVersion( } } - if (majorVersion < 0) { - throw new IllegalArgumentException("negative majorVersion"); - } - if (minorVersion < 0) { - throw new IllegalArgumentException("negative minorVersion"); - } + checkPositiveOrZero(majorVersion, "majorVersion"); + checkPositiveOrZero(minorVersion, "minorVersion"); this.protocolName = protocolName; this.majorVersion = majorVersion; diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/cookie/ClientCookieEncoder.java b/codec-http/src/main/java/io/netty/handler/codec/http/cookie/ClientCookieEncoder.java index 9554b64ed462..0a66aaad7a2f 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/cookie/ClientCookieEncoder.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/cookie/ClientCookieEncoder.java @@ -161,7 +161,7 @@ public String encode(Collection cookies) { if (cookies.size() == 1) { encode(buf, cookies.iterator().next()); } else { - Cookie[] cookiesSorted = cookies.toArray(new Cookie[cookies.size()]); + Cookie[] cookiesSorted = cookies.toArray(new Cookie[0]); Arrays.sort(cookiesSorted, COOKIE_COMPARATOR); for (Cookie c : cookiesSorted) { encode(buf, c); @@ -198,7 +198,7 @@ public String encode(Iterable cookies) { while (cookiesIt.hasNext()) { cookiesList.add(cookiesIt.next()); } - Cookie[] cookiesSorted = cookiesList.toArray(new Cookie[cookiesList.size()]); + Cookie[] cookiesSorted = cookiesList.toArray(new Cookie[0]); Arrays.sort(cookiesSorted, COOKIE_COMPARATOR); for (Cookie c : cookiesSorted) { encode(buf, c); diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/cookie/CookieEncoder.java b/codec-http/src/main/java/io/netty/handler/codec/http/cookie/CookieEncoder.java index d64876815295..39d1d07ccba2 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/cookie/CookieEncoder.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/cookie/CookieEncoder.java @@ -44,7 +44,8 @@ protected void validateCookie(String name, String value) { } if ((pos = firstInvalidCookieValueOctet(unwrappedValue)) >= 0) { - throw new IllegalArgumentException("Cookie value contains an invalid char: " + value.charAt(pos)); + throw new IllegalArgumentException("Cookie value contains an invalid char: " + + unwrappedValue.charAt(pos)); } } } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/cookie/CookieUtil.java b/codec-http/src/main/java/io/netty/handler/codec/http/cookie/CookieUtil.java index 1e9d9c8f87e2..2e818b92eb39 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/cookie/CookieUtil.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/cookie/CookieUtil.java @@ -97,24 +97,24 @@ static String stripTrailingSeparator(StringBuilder buf) { static void add(StringBuilder sb, String name, long val) { sb.append(name); - sb.append((char) HttpConstants.EQUALS); + sb.append('='); sb.append(val); - sb.append((char) HttpConstants.SEMICOLON); - sb.append((char) HttpConstants.SP); + sb.append(';'); + sb.append(HttpConstants.SP_CHAR); } static void add(StringBuilder sb, String name, String val) { sb.append(name); - sb.append((char) HttpConstants.EQUALS); + sb.append('='); sb.append(val); - sb.append((char) HttpConstants.SEMICOLON); - sb.append((char) HttpConstants.SP); + sb.append(';'); + sb.append(HttpConstants.SP_CHAR); } static void add(StringBuilder sb, String name) { sb.append(name); - sb.append((char) HttpConstants.SEMICOLON); - sb.append((char) HttpConstants.SP); + sb.append(';'); + sb.append(HttpConstants.SP_CHAR); } static void addQuoted(StringBuilder sb, String name, String val) { @@ -123,12 +123,12 @@ static void addQuoted(StringBuilder sb, String name, String val) { } sb.append(name); - sb.append((char) HttpConstants.EQUALS); - sb.append((char) HttpConstants.DOUBLE_QUOTE); + sb.append('='); + sb.append('"'); sb.append(val); - sb.append((char) HttpConstants.DOUBLE_QUOTE); - sb.append((char) HttpConstants.SEMICOLON); - sb.append((char) HttpConstants.SP); + sb.append('"'); + sb.append(';'); + sb.append(HttpConstants.SP_CHAR); } static int firstInvalidCookieNameOctet(CharSequence cs) { diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/cookie/ServerCookieEncoder.java b/codec-http/src/main/java/io/netty/handler/codec/http/cookie/ServerCookieEncoder.java index b707dc33d180..c5a1d7d0e7d1 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/cookie/ServerCookieEncoder.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/cookie/ServerCookieEncoder.java @@ -105,10 +105,10 @@ public String encode(Cookie cookie) { add(buf, CookieHeaderNames.MAX_AGE, cookie.maxAge()); Date expires = new Date(cookie.maxAge() * 1000 + System.currentTimeMillis()); buf.append(CookieHeaderNames.EXPIRES); - buf.append((char) HttpConstants.EQUALS); + buf.append('='); DateFormatter.append(expires, buf); - buf.append((char) HttpConstants.SEMICOLON); - buf.append((char) HttpConstants.SP); + buf.append(';'); + buf.append(HttpConstants.SP_CHAR); } if (cookie.path() != null) { diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java index 39f79a4cd1d8..f1e162f1acc2 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java @@ -191,7 +191,7 @@ private void setAllowCredentials(final HttpResponse response) { private static boolean isPreflightRequest(final HttpRequest request) { final HttpHeaders headers = request.headers(); - return request.method().equals(OPTIONS) && + return OPTIONS.equals(request.method()) && headers.contains(HttpHeaderNames.ORIGIN) && headers.contains(HttpHeaderNames.ACCESS_CONTROL_REQUEST_METHOD); } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/multipart/AbstractHttpData.java b/codec-http/src/main/java/io/netty/handler/codec/http/multipart/AbstractHttpData.java index ff05753bae53..6657f5f52743 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/multipart/AbstractHttpData.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/multipart/AbstractHttpData.java @@ -59,7 +59,9 @@ protected AbstractHttpData(String name, Charset charset, long size) { } @Override - public long getMaxSize() { return maxSize; } + public long getMaxSize() { + return maxSize; + } @Override public void setMaxSize(long maxSize) { diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/multipart/AbstractMemoryHttpData.java b/codec-http/src/main/java/io/netty/handler/codec/http/multipart/AbstractMemoryHttpData.java index 31aa9ce64b5f..4cb7e567b252 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/multipart/AbstractMemoryHttpData.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/multipart/AbstractMemoryHttpData.java @@ -128,8 +128,7 @@ public void setContent(File file) throws IOException { } long newsize = file.length(); if (newsize > Integer.MAX_VALUE) { - throw new IllegalArgumentException( - "File too big to be loaded in memory"); + throw new IllegalArgumentException("File too big to be loaded in memory"); } checkSize(newsize); FileInputStream inputStream = new FileInputStream(file); diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpPostMultipartRequestDecoder.java b/codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpPostMultipartRequestDecoder.java index 8e3a90c009ba..4fefc1139822 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpPostMultipartRequestDecoder.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpPostMultipartRequestDecoder.java @@ -1515,6 +1515,6 @@ private static String[] splitMultipartHeaderValues(String svalue) { } } values.add(svalue.substring(start)); - return values.toArray(new String[values.size()]); + return values.toArray(new String[0]); } } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/Utf8FrameValidator.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/Utf8FrameValidator.java index 55f07057ec08..5ce5ec369cb8 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/Utf8FrameValidator.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/Utf8FrameValidator.java @@ -47,7 +47,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception if ((frame instanceof TextWebSocketFrame) || (utf8Validator != null && utf8Validator.isChecking())) { // Check UTF-8 correctness for this payload - checkUTF8String(ctx, frame.content()); + checkUTF8String(frame.content()); // This does a second check to make sure UTF-8 // correctness for entire text message @@ -60,12 +60,12 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception if (fragmentedFramesCount == 0) { // First text or binary frame for a fragmented set if (frame instanceof TextWebSocketFrame) { - checkUTF8String(ctx, frame.content()); + checkUTF8String(frame.content()); } } else { // Subsequent frames - only check if init frame is text if (utf8Validator != null && utf8Validator.isChecking()) { - checkUTF8String(ctx, frame.content()); + checkUTF8String(frame.content()); } } @@ -77,17 +77,18 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception super.channelRead(ctx, msg); } - private void checkUTF8String(ChannelHandlerContext ctx, ByteBuf buffer) { - try { - if (utf8Validator == null) { - utf8Validator = new Utf8Validator(); - } - utf8Validator.check(buffer); - } catch (CorruptedFrameException ex) { - if (ctx.channel().isActive()) { - ctx.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE); - } + private void checkUTF8String(ByteBuf buffer) { + if (utf8Validator == null) { + utf8Validator = new Utf8Validator(); } + utf8Validator.check(buffer); } + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (cause instanceof CorruptedFrameException && ctx.channel().isOpen()) { + ctx.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE); + } + super.exceptionCaught(ctx, cause); + } } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket08FrameDecoder.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket08FrameDecoder.java index 604746c5ae2f..f5a6dc193600 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket08FrameDecoder.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket08FrameDecoder.java @@ -451,7 +451,7 @@ protected void checkCloseFrameBody( // Must have 2 byte integer within the valid range int statusCode = buffer.readShort(); if (statusCode >= 0 && statusCode <= 999 || statusCode >= 1004 && statusCode <= 1006 - || statusCode >= 1012 && statusCode <= 2999) { + || statusCode >= 1015 && statusCode <= 2999) { protocolViolation(ctx, "Invalid close frame getStatus code: " + statusCode); } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker00.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker00.java index 05070f74b764..f02626319f32 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker00.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker00.java @@ -131,22 +131,23 @@ protected FullHttpRequest newHandshakeRequest() { // Format request FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path); HttpHeaders headers = request.headers(); - headers.add(HttpHeaderNames.UPGRADE, WEBSOCKET) - .add(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE) - .add(HttpHeaderNames.HOST, websocketHostValue(wsURL)) - .add(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL)) - .add(HttpHeaderNames.SEC_WEBSOCKET_KEY1, key1) - .add(HttpHeaderNames.SEC_WEBSOCKET_KEY2, key2); - - String expectedSubprotocol = expectedSubprotocol(); - if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) { - headers.add(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol); - } if (customHeaders != null) { headers.add(customHeaders); } + headers.set(HttpHeaderNames.UPGRADE, WEBSOCKET) + .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE) + .set(HttpHeaderNames.HOST, websocketHostValue(wsURL)) + .set(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL)) + .set(HttpHeaderNames.SEC_WEBSOCKET_KEY1, key1) + .set(HttpHeaderNames.SEC_WEBSOCKET_KEY2, key2); + + String expectedSubprotocol = expectedSubprotocol(); + if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) { + headers.set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol); + } + // Set Content-Length to workaround some known defect. // See also: http://www.ietf.org/mail-archive/web/hybi/current/msg02149.html headers.set(HttpHeaderNames.CONTENT_LENGTH, key3.length); diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker07.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker07.java index f85d086b2f49..4632a4aecb69 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker07.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker07.java @@ -145,22 +145,22 @@ protected FullHttpRequest newHandshakeRequest() { FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path); HttpHeaders headers = request.headers(); - headers.add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET) - .add(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE) - .add(HttpHeaderNames.SEC_WEBSOCKET_KEY, key) - .add(HttpHeaderNames.HOST, websocketHostValue(wsURL)) - .add(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL)); + if (customHeaders != null) { + headers.add(customHeaders); + } + + headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET) + .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE) + .set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key) + .set(HttpHeaderNames.HOST, websocketHostValue(wsURL)) + .set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL)); String expectedSubprotocol = expectedSubprotocol(); if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) { - headers.add(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol); + headers.set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol); } - headers.add(HttpHeaderNames.SEC_WEBSOCKET_VERSION, "7"); - - if (customHeaders != null) { - headers.add(customHeaders); - } + headers.set(HttpHeaderNames.SEC_WEBSOCKET_VERSION, "7"); return request; } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker08.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker08.java index 5bfef6449c11..1a11aa6358e7 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker08.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker08.java @@ -146,22 +146,22 @@ protected FullHttpRequest newHandshakeRequest() { FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path); HttpHeaders headers = request.headers(); - headers.add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET) - .add(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE) - .add(HttpHeaderNames.SEC_WEBSOCKET_KEY, key) - .add(HttpHeaderNames.HOST, websocketHostValue(wsURL)) - .add(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL)); + if (customHeaders != null) { + headers.add(customHeaders); + } + + headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET) + .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE) + .set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key) + .set(HttpHeaderNames.HOST, websocketHostValue(wsURL)) + .set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL)); String expectedSubprotocol = expectedSubprotocol(); if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) { - headers.add(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol); + headers.set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol); } - headers.add(HttpHeaderNames.SEC_WEBSOCKET_VERSION, "8"); - - if (customHeaders != null) { - headers.add(customHeaders); - } + headers.set(HttpHeaderNames.SEC_WEBSOCKET_VERSION, "8"); return request; } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker13.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker13.java index 9490e3bcef51..808f7fc49ad2 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker13.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker13.java @@ -146,22 +146,22 @@ protected FullHttpRequest newHandshakeRequest() { FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path); HttpHeaders headers = request.headers(); - headers.add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET) - .add(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE) - .add(HttpHeaderNames.SEC_WEBSOCKET_KEY, key) - .add(HttpHeaderNames.HOST, websocketHostValue(wsURL)) - .add(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL)); + if (customHeaders != null) { + headers.add(customHeaders); + } + + headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET) + .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE) + .set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key) + .set(HttpHeaderNames.HOST, websocketHostValue(wsURL)) + .set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL)); String expectedSubprotocol = expectedSubprotocol(); if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) { - headers.add(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol); + headers.set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol); } - headers.add(HttpHeaderNames.SEC_WEBSOCKET_VERSION, "13"); - - if (customHeaders != null) { - headers.add(customHeaders); - } + headers.set(HttpHeaderNames.SEC_WEBSOCKET_VERSION, "13"); return request; } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientProtocolHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientProtocolHandler.java index 2ace7e8aa09e..a40e2d74e54e 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientProtocolHandler.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientProtocolHandler.java @@ -45,7 +45,9 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler { /** * Returns the used handshaker */ - public WebSocketClientHandshaker handshaker() { return handshaker; } + public WebSocketClientHandshaker handshaker() { + return handshaker; + } /** * Events that are fired to notify about handshake status @@ -151,6 +153,23 @@ public WebSocketClientProtocolHandler(URI webSocketURL, WebSocketVersion version * {@code true} if close frames should not be forwarded and just close the channel */ public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames) { + this(handshaker, handleCloseFrames, true); + } + + /** + * Base constructor + * + * @param handshaker + * The {@link WebSocketClientHandshaker} which will be used to issue the handshake once the connection + * was established to the remote peer. + * @param handleCloseFrames + * {@code true} if close frames should not be forwarded and just close the channel + * @param dropPongFrames + * {@code true} if pong frames should not be forwarded + */ + public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames, + boolean dropPongFrames) { + super(dropPongFrames); this.handshaker = handshaker; this.handleCloseFrames = handleCloseFrames; } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketProtocolHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketProtocolHandler.java index 7037ae346867..53532cafe98b 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketProtocolHandler.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketProtocolHandler.java @@ -22,6 +22,27 @@ import java.util.List; abstract class WebSocketProtocolHandler extends MessageToMessageDecoder { + + private final boolean dropPongFrames; + + /** + * Creates a new {@link WebSocketProtocolHandler} that will drop {@link PongWebSocketFrame}s. + */ + WebSocketProtocolHandler() { + this(true); + } + + /** + * Creates a new {@link WebSocketProtocolHandler}, given a parameter that determines whether or not to drop {@link + * PongWebSocketFrame}s. + * + * @param dropPongFrames + * {@code true} if {@link PongWebSocketFrame}s should be dropped + */ + WebSocketProtocolHandler(boolean dropPongFrames) { + this.dropPongFrames = dropPongFrames; + } + @Override protected void decode(ChannelHandlerContext ctx, WebSocketFrame frame, List out) throws Exception { if (frame instanceof PingWebSocketFrame) { @@ -29,8 +50,7 @@ protected void decode(ChannelHandlerContext ctx, WebSocketFrame frame, List methodMap = new HashMap(); @@ -130,7 +130,7 @@ public static HttpMethod valueOf(String name) { if (result != null) { return result; } else { - return new HttpMethod(name); + return HttpMethod.valueOf(name); } } diff --git a/codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyGoAwayFrame.java b/codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyGoAwayFrame.java index 4d88875a6e8c..79c21f2404d2 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyGoAwayFrame.java +++ b/codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyGoAwayFrame.java @@ -15,6 +15,8 @@ */ package io.netty.handler.codec.spdy; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + import io.netty.util.internal.StringUtil; /** @@ -62,10 +64,7 @@ public int lastGoodStreamId() { @Override public SpdyGoAwayFrame setLastGoodStreamId(int lastGoodStreamId) { - if (lastGoodStreamId < 0) { - throw new IllegalArgumentException("Last-good-stream-ID" - + " cannot be negative: " + lastGoodStreamId); - } + checkPositiveOrZero(lastGoodStreamId, "lastGoodStreamId"); this.lastGoodStreamId = lastGoodStreamId; return this; } diff --git a/codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyStreamFrame.java b/codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyStreamFrame.java index 4618d4d4a95c..487844ecd913 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyStreamFrame.java +++ b/codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyStreamFrame.java @@ -15,6 +15,8 @@ */ package io.netty.handler.codec.spdy; +import static io.netty.util.internal.ObjectUtil.checkPositive; + /** * The default {@link SpdyStreamFrame} implementation. */ @@ -39,10 +41,7 @@ public int streamId() { @Override public SpdyStreamFrame setStreamId(int streamId) { - if (streamId <= 0) { - throw new IllegalArgumentException( - "Stream-ID must be positive: " + streamId); - } + checkPositive(streamId, "streamId"); this.streamId = streamId; return this; } diff --git a/codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdySynReplyFrame.java b/codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdySynReplyFrame.java index 7efc905641e3..f757d1dbd663 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdySynReplyFrame.java +++ b/codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdySynReplyFrame.java @@ -20,8 +20,7 @@ /** * The default {@link SpdySynReplyFrame} implementation. */ -public class DefaultSpdySynReplyFrame extends DefaultSpdyHeadersFrame - implements SpdySynReplyFrame { +public class DefaultSpdySynReplyFrame extends DefaultSpdyHeadersFrame implements SpdySynReplyFrame { /** * Creates a new instance. diff --git a/codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdySynStreamFrame.java b/codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdySynStreamFrame.java index f8adc1c5f1ca..46fe30163634 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdySynStreamFrame.java +++ b/codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdySynStreamFrame.java @@ -15,6 +15,8 @@ */ package io.netty.handler.codec.spdy; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + import io.netty.util.internal.StringUtil; /** @@ -77,11 +79,7 @@ public int associatedStreamId() { @Override public SpdySynStreamFrame setAssociatedStreamId(int associatedStreamId) { - if (associatedStreamId < 0) { - throw new IllegalArgumentException( - "Associated-To-Stream-ID cannot be negative: " + - associatedStreamId); - } + checkPositiveOrZero(associatedStreamId, "associatedStreamId"); this.associatedStreamId = associatedStreamId; return this; } diff --git a/codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyWindowUpdateFrame.java b/codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyWindowUpdateFrame.java index f14611bac614..22b0406c80c0 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyWindowUpdateFrame.java +++ b/codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyWindowUpdateFrame.java @@ -15,6 +15,9 @@ */ package io.netty.handler.codec.spdy; +import static io.netty.util.internal.ObjectUtil.checkPositive; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + import io.netty.util.internal.StringUtil; /** @@ -43,10 +46,7 @@ public int streamId() { @Override public SpdyWindowUpdateFrame setStreamId(int streamId) { - if (streamId < 0) { - throw new IllegalArgumentException( - "Stream-ID cannot be negative: " + streamId); - } + checkPositiveOrZero(streamId, "streamId"); this.streamId = streamId; return this; } @@ -58,11 +58,7 @@ public int deltaWindowSize() { @Override public SpdyWindowUpdateFrame setDeltaWindowSize(int deltaWindowSize) { - if (deltaWindowSize <= 0) { - throw new IllegalArgumentException( - "Delta-Window-Size must be positive: " + - deltaWindowSize); - } + checkPositive(deltaWindowSize, "deltaWindowSize"); this.deltaWindowSize = deltaWindowSize; return this; } diff --git a/codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyFrameDecoder.java b/codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyFrameDecoder.java index e0d1112813b7..fc432b683096 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyFrameDecoder.java +++ b/codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyFrameDecoder.java @@ -38,6 +38,8 @@ import static io.netty.handler.codec.spdy.SpdyCodecUtil.getUnsignedInt; import static io.netty.handler.codec.spdy.SpdyCodecUtil.getUnsignedMedium; import static io.netty.handler.codec.spdy.SpdyCodecUtil.getUnsignedShort; +import static io.netty.util.internal.ObjectUtil.checkPositive; + import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; @@ -95,10 +97,7 @@ public SpdyFrameDecoder(SpdyVersion spdyVersion, SpdyFrameDecoderDelegate delega if (delegate == null) { throw new NullPointerException("delegate"); } - if (maxChunkSize <= 0) { - throw new IllegalArgumentException( - "maxChunkSize must be a positive integer: " + maxChunkSize); - } + checkPositive(maxChunkSize, "maxChunkSize"); this.spdyVersion = spdyVersion.getVersion(); this.delegate = delegate; this.maxChunkSize = maxChunkSize; diff --git a/codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHttpDecoder.java b/codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHttpDecoder.java index 366ad15b662f..5e16a6f4f2a7 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHttpDecoder.java +++ b/codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHttpDecoder.java @@ -38,6 +38,7 @@ import java.util.Map; import static io.netty.handler.codec.spdy.SpdyHeaders.HttpNames.*; +import static io.netty.util.internal.ObjectUtil.checkPositive; /** * Decodes {@link SpdySynStreamFrame}s, {@link SpdySynReplyFrame}s, @@ -103,10 +104,7 @@ protected SpdyHttpDecoder(SpdyVersion version, int maxContentLength, Map strItr = headers.valueStringIterator(SET_COOKIE); + assertTrue(strItr.hasNext()); + assertEquals("a", strItr.next()); + assertTrue(strItr.hasNext()); + assertEquals("b", strItr.next()); + assertTrue(strItr.hasNext()); + assertEquals("c", strItr.next()); + } + private static void assertValueIterator(Iterator strItr) { assertTrue(strItr.hasNext()); assertEquals("a", strItr.next()); diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/DefaultHttpHeadersTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/DefaultHttpHeadersTest.java index 3f2a678ab525..d0f6c414f463 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/DefaultHttpHeadersTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/DefaultHttpHeadersTest.java @@ -232,7 +232,7 @@ public void providesHeaderNamesAsArray() throws Exception { .add(HttpHeaderNames.CONTENT_LENGTH, 10) .names(); - String[] namesArray = nettyHeaders.toArray(new String[nettyHeaders.size()]); + String[] namesArray = nettyHeaders.toArray(new String[0]); assertArrayEquals(namesArray, new String[] { HttpHeaderNames.CONTENT_LENGTH.toString() }); } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/DefaultHttpResponseTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/DefaultHttpResponseTest.java new file mode 100644 index 000000000000..0a466c6bff97 --- /dev/null +++ b/codec-http/src/test/java/io/netty/handler/codec/http/DefaultHttpResponseTest.java @@ -0,0 +1,40 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; + +public class DefaultHttpResponseTest { + + @Test + public void testNotEquals() { + HttpResponse ok = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + HttpResponse notFound = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.NOT_FOUND); + assertNotEquals(ok, notFound); + assertNotEquals(ok.hashCode(), notFound.hashCode()); + } + + @Test + public void testEquals() { + HttpResponse ok = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + HttpResponse ok2 = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + assertEquals(ok, ok2); + assertEquals(ok.hashCode(), ok2.hashCode()); + } +} diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/HttpContentCompressorTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/HttpContentCompressorTest.java index 6ab2ab88c629..508eb5ae54c6 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/HttpContentCompressorTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/HttpContentCompressorTest.java @@ -18,6 +18,7 @@ import io.netty.buffer.ByteBufUtil; import io.netty.buffer.Unpooled; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.DecoderResult; import io.netty.handler.codec.EncoderException; import io.netty.handler.codec.compression.ZlibWrapper; import io.netty.util.CharsetUtil; @@ -188,6 +189,7 @@ public void testChunkedContentWithTrailingHeader() throws Exception { assertThat(chunk.content().isReadable(), is(false)); assertThat(chunk, is(instanceOf(LastHttpContent.class))); assertEquals("Netty", ((LastHttpContent) chunk).trailingHeaders().get(of("X-Test"))); + assertEquals(DecoderResult.SUCCESS, chunk.decoderResult()); chunk.release(); assertThat(ch.readOutbound(), is(nullValue())); @@ -331,6 +333,7 @@ public void testEmptyFullContentWithTrailer() throws Exception { assertThat(res.content().readableBytes(), is(0)); assertThat(res.content().toString(CharsetUtil.US_ASCII), is("")); assertEquals("Netty", res.trailingHeaders().get(of("X-Test"))); + assertEquals(DecoderResult.SUCCESS, res.decoderResult()); assertThat(ch.readOutbound(), is(nullValue())); } @@ -370,6 +373,7 @@ public void test100Continue() throws Exception { assertThat(res.content().readableBytes(), is(0)); assertThat(res.content().toString(CharsetUtil.US_ASCII), is("")); assertEquals("Netty", res.trailingHeaders().get(of("X-Test"))); + assertEquals(DecoderResult.SUCCESS, res.decoderResult()); assertThat(ch.readOutbound(), is(nullValue())); } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/HttpContentDecoderTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/HttpContentDecoderTest.java index 7a27a4c08acf..ff0628527a3c 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/HttpContentDecoderTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/HttpContentDecoderTest.java @@ -89,6 +89,48 @@ public void testRequestDecompression() { assertFalse(channel.finish()); // assert that no messages are left in channel } + @Test + public void testChunkedRequestDecompression() { + HttpResponseDecoder decoder = new HttpResponseDecoder(); + HttpContentDecoder decompressor = new HttpContentDecompressor(); + + EmbeddedChannel channel = new EmbeddedChannel(decoder, decompressor, null); + + String headers = "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "Trailer: My-Trailer\r\n" + + "Content-Encoding: gzip\r\n\r\n"; + + channel.writeInbound(Unpooled.copiedBuffer(headers.getBytes(CharsetUtil.US_ASCII))); + + String chunkLength = Integer.toHexString(GZ_HELLO_WORLD.length); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(chunkLength + "\r\n", CharsetUtil.US_ASCII))); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(GZ_HELLO_WORLD))); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer("\r\n".getBytes(CharsetUtil.US_ASCII)))); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer("0\r\n", CharsetUtil.US_ASCII))); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer("My-Trailer: 42\r\n\r\n\r\n", CharsetUtil.US_ASCII))); + + Object ob1 = channel.readInbound(); + assertThat(ob1, is(instanceOf(DefaultHttpResponse.class))); + + Object ob2 = channel.readInbound(); + assertThat(ob1, is(instanceOf(DefaultHttpResponse.class))); + HttpContent content = (HttpContent) ob2; + assertEquals(HELLO_WORLD, content.content().toString(CharsetUtil.US_ASCII)); + content.release(); + + Object ob3 = channel.readInbound(); + assertThat(ob1, is(instanceOf(DefaultHttpResponse.class))); + LastHttpContent lastContent = (LastHttpContent) ob3; + assertNotNull(lastContent.decoderResult()); + assertTrue(lastContent.decoderResult().isSuccess()); + assertFalse(lastContent.trailingHeaders().isEmpty()); + assertEquals("42", lastContent.trailingHeaders().get("My-Trailer")); + assertHasInboundMessages(channel, false); + assertHasOutboundMessages(channel, false); + assertFalse(channel.finish()); + } + @Test public void testResponseDecompression() { // baseline test: response decoder, content decompressor && request aggregator work as expected diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/HttpContentDecompressorTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/HttpContentDecompressorTest.java new file mode 100644 index 000000000000..4a659fad5ed6 --- /dev/null +++ b/codec-http/src/test/java/io/netty/handler/codec/http/HttpContentDecompressorTest.java @@ -0,0 +1,70 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.Assert; +import org.junit.Test; + +import java.util.concurrent.atomic.AtomicInteger; + +public class HttpContentDecompressorTest { + + // See https://github.com/netty/netty/issues/8915. + @Test + public void testInvokeReadWhenNotProduceMessage() { + final AtomicInteger readCalled = new AtomicInteger(); + EmbeddedChannel channel = new EmbeddedChannel(new ChannelOutboundHandlerAdapter() { + @Override + public void read(ChannelHandlerContext ctx) { + readCalled.incrementAndGet(); + ctx.read(); + } + }, new HttpContentDecompressor(), new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + ctx.fireChannelRead(msg); + ctx.read(); + } + }); + + channel.config().setAutoRead(false); + + readCalled.set(0); + HttpResponse response = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + response.headers().set(HttpHeaderNames.CONTENT_ENCODING, "gzip"); + response.headers().set(HttpHeaderNames.CONTENT_TYPE, "application/json;charset=UTF-8"); + response.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED); + + Assert.assertTrue(channel.writeInbound(response)); + + // we triggered read explicitly + Assert.assertEquals(1, readCalled.get()); + + Assert.assertTrue(channel.readInbound() instanceof HttpResponse); + + Assert.assertFalse(channel.writeInbound(new DefaultHttpContent(Unpooled.EMPTY_BUFFER))); + + // read was triggered by the HttpContentDecompressor itself as it did not produce any message to the next + // inbound handler. + Assert.assertEquals(2, readCalled.get()); + Assert.assertFalse(channel.finishAndReleaseAll()); + } +} diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/HttpContentEncoderTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/HttpContentEncoderTest.java index 9242cf872833..6301ee8c0b56 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/HttpContentEncoderTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/HttpContentEncoderTest.java @@ -22,6 +22,7 @@ import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.CodecException; +import io.netty.handler.codec.DecoderResult; import io.netty.handler.codec.EncoderException; import io.netty.handler.codec.MessageToByteEncoder; import io.netty.util.CharsetUtil; @@ -156,6 +157,7 @@ public void testChunkedContentWithTrailingHeader() throws Exception { assertThat(chunk.content().isReadable(), is(false)); assertThat(chunk, is(instanceOf(LastHttpContent.class))); assertEquals("Netty", ((LastHttpContent) chunk).trailingHeaders().get(of("X-Test"))); + assertEquals(DecoderResult.SUCCESS, res.decoderResult()); chunk.release(); assertThat(ch.readOutbound(), is(nullValue())); @@ -285,6 +287,7 @@ public void testEmptyFullContentWithTrailer() throws Exception { assertThat(res.content().readableBytes(), is(0)); assertThat(res.content().toString(CharsetUtil.US_ASCII), is("")); assertEquals("Netty", res.trailingHeaders().get(of("X-Test"))); + assertEquals(DecoderResult.SUCCESS, res.decoderResult()); assertThat(ch.readOutbound(), is(nullValue())); } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/HttpObjectAggregatorTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/HttpObjectAggregatorTest.java index 7dc0ac5a5195..503c93339f01 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/HttpObjectAggregatorTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/HttpObjectAggregatorTest.java @@ -23,7 +23,10 @@ import io.netty.handler.codec.DecoderResult; import io.netty.handler.codec.DecoderResultProvider; import io.netty.handler.codec.TooLongFrameException; +import io.netty.util.AsciiString; import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; + import org.junit.Test; import org.mockito.Mockito; @@ -40,6 +43,7 @@ import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.junit.Assert.assertSame; public class HttpObjectAggregatorTest { @@ -517,4 +521,123 @@ public void testReplaceAggregatedResponse() { aggregatedRep.release(); replacedRep.release(); } + + @Test + public void testSelectiveRequestAggregation() { + HttpObjectAggregator myPostAggregator = new HttpObjectAggregator(1024 * 1024) { + @Override + protected boolean isStartMessage(HttpObject msg) throws Exception { + if (msg instanceof HttpRequest) { + HttpRequest request = (HttpRequest) msg; + HttpMethod method = request.method(); + + if (method.equals(HttpMethod.POST)) { + return true; + } + } + + return false; + } + }; + + EmbeddedChannel channel = new EmbeddedChannel(myPostAggregator); + + try { + // Aggregate: POST + HttpRequest request1 = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/"); + HttpContent content1 = new DefaultHttpContent(Unpooled.copiedBuffer("Hello, World!", CharsetUtil.UTF_8)); + request1.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.TEXT_PLAIN); + + assertTrue(channel.writeInbound(request1, content1, LastHttpContent.EMPTY_LAST_CONTENT)); + + // Getting an aggregated response out + Object msg1 = channel.readInbound(); + try { + assertTrue(msg1 instanceof FullHttpRequest); + } finally { + ReferenceCountUtil.release(msg1); + } + + // Don't aggregate: non-POST + HttpRequest request2 = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.PUT, "/"); + HttpContent content2 = new DefaultHttpContent(Unpooled.copiedBuffer("Hello, World!", CharsetUtil.UTF_8)); + request2.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.TEXT_PLAIN); + + try { + assertTrue(channel.writeInbound(request2, content2, LastHttpContent.EMPTY_LAST_CONTENT)); + + // Getting the same response objects out + assertSame(request2, channel.readInbound()); + assertSame(content2, channel.readInbound()); + assertSame(LastHttpContent.EMPTY_LAST_CONTENT, channel.readInbound()); + } finally { + ReferenceCountUtil.release(request2); + ReferenceCountUtil.release(content2); + } + + assertFalse(channel.finish()); + } finally { + channel.close(); + } + } + + @Test + public void testSelectiveResponseAggregation() { + HttpObjectAggregator myTextAggregator = new HttpObjectAggregator(1024 * 1024) { + @Override + protected boolean isStartMessage(HttpObject msg) throws Exception { + if (msg instanceof HttpResponse) { + HttpResponse response = (HttpResponse) msg; + HttpHeaders headers = response.headers(); + + String contentType = headers.get(HttpHeaderNames.CONTENT_TYPE); + if (AsciiString.contentEqualsIgnoreCase(contentType, HttpHeaderValues.TEXT_PLAIN)) { + return true; + } + } + + return false; + } + }; + + EmbeddedChannel channel = new EmbeddedChannel(myTextAggregator); + + try { + // Aggregate: text/plain + HttpResponse response1 = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + HttpContent content1 = new DefaultHttpContent(Unpooled.copiedBuffer("Hello, World!", CharsetUtil.UTF_8)); + response1.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.TEXT_PLAIN); + + assertTrue(channel.writeInbound(response1, content1, LastHttpContent.EMPTY_LAST_CONTENT)); + + // Getting an aggregated response out + Object msg1 = channel.readInbound(); + try { + assertTrue(msg1 instanceof FullHttpResponse); + } finally { + ReferenceCountUtil.release(msg1); + } + + // Don't aggregate: application/json + HttpResponse response2 = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + HttpContent content2 = new DefaultHttpContent(Unpooled.copiedBuffer("{key: 'value'}", CharsetUtil.UTF_8)); + response2.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON); + + try { + assertTrue(channel.writeInbound(response2, content2, LastHttpContent.EMPTY_LAST_CONTENT)); + + // Getting the same response objects out + assertSame(response2, channel.readInbound()); + assertSame(content2, channel.readInbound()); + assertSame(LastHttpContent.EMPTY_LAST_CONTENT, channel.readInbound()); + } finally { + ReferenceCountUtil.release(response2); + ReferenceCountUtil.release(content2); + } + + assertFalse(channel.finish()); + } finally { + channel.close(); + } + } } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/HttpRequestDecoderTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/HttpRequestDecoderTest.java index 45720631c40a..717b58090158 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/HttpRequestDecoderTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/HttpRequestDecoderTest.java @@ -320,4 +320,80 @@ public void testTooLargeHeaders() { assertTrue(request.decoderResult().cause() instanceof TooLongFrameException); assertFalse(channel.finish()); } + + @Test + public void testWhitespace() { + String requestStr = "GET /some/path HTTP/1.1\r\n" + + "Transfer-Encoding : chunked\r\n" + + "Host: netty.io\n\r\n"; + testInvalidHeaders0(requestStr); + } + + @Test + public void testHeaderWithNoValueAndMissingColon() { + String requestStr = "GET /some/path HTTP/1.1\r\n" + + "Content-Length: 0\r\n" + + "Host:\r\n" + + "netty.io\r\n\r\n"; + testInvalidHeaders0(requestStr); + } + + @Test + public void testMultipleContentLengthHeaders() { + String requestStr = "GET /some/path HTTP/1.1\r\n" + + "Content-Length: 1\r\n" + + "Content-Length: 0\r\n\r\n" + + "b"; + testInvalidHeaders0(requestStr); + } + + @Test + public void testMultipleContentLengthHeaders2() { + String requestStr = "GET /some/path HTTP/1.1\r\n" + + "Content-Length: 1\r\n" + + "Connection: close\r\n" + + "Content-Length: 0\r\n\r\n" + + "b"; + testInvalidHeaders0(requestStr); + } + + @Test + public void testContentLengthHeaderWithCommaValue() { + String requestStr = "GET /some/path HTTP/1.1\r\n" + + "Content-Length: 1,1\r\n\r\n" + + "b"; + testInvalidHeaders0(requestStr); + } + + @Test + public void testMultipleContentLengthHeadersWithFolding() { + String requestStr = "POST / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "Connection: close\r\n" + + "Content-Length: 5\r\n" + + "Content-Length:\r\n" + + "\t6\r\n\r\n" + + "123456"; + testInvalidHeaders0(requestStr); + } + + @Test + public void testContentLengthHeaderAndChunked() { + String requestStr = "POST / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "Connection: close\r\n" + + "Content-Length: 5\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "0\r\n\r\n"; + testInvalidHeaders0(requestStr); + } + + private static void testInvalidHeaders0(String requestStr) { + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + HttpRequest request = channel.readInbound(); + assertTrue(request.decoderResult().isFailure()); + assertTrue(request.decoderResult().cause() instanceof IllegalArgumentException); + assertFalse(channel.finish()); + } } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/HttpResponseDecoderTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/HttpResponseDecoderTest.java index 017dbd5ff943..6378a376b11c 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/HttpResponseDecoderTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/HttpResponseDecoderTest.java @@ -683,4 +683,48 @@ public void testConnectionClosedBeforeHeadersReceived() { assertThat(message.decoderResult().cause(), instanceOf(PrematureChannelClosureException.class)); assertNull(channel.readInbound()); } + + @Test + public void testTrailerWithEmptyLineInSeparateBuffer() { + HttpResponseDecoder decoder = new HttpResponseDecoder(); + EmbeddedChannel channel = new EmbeddedChannel(decoder); + + String headers = "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "Trailer: My-Trailer\r\n"; + assertFalse(channel.writeInbound(Unpooled.copiedBuffer(headers.getBytes(CharsetUtil.US_ASCII)))); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer("\r\n".getBytes(CharsetUtil.US_ASCII)))); + + assertTrue(channel.writeInbound(Unpooled.copiedBuffer("0\r\n", CharsetUtil.US_ASCII))); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer("My-Trailer: 42\r\n", CharsetUtil.US_ASCII))); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer("\r\n", CharsetUtil.US_ASCII))); + + HttpResponse response = channel.readInbound(); + assertEquals(2, response.headers().size()); + assertEquals("chunked", response.headers().get(HttpHeaderNames.TRANSFER_ENCODING)); + assertEquals("My-Trailer", response.headers().get(HttpHeaderNames.TRAILER)); + + LastHttpContent lastContent = channel.readInbound(); + assertEquals(1, lastContent.trailingHeaders().size()); + assertEquals("42", lastContent.trailingHeaders().get("My-Trailer")); + assertEquals(0, lastContent.content().readableBytes()); + lastContent.release(); + + assertFalse(channel.finish()); + } + + @Test + public void testWhitespace() { + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); + String requestStr = "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding : chunked\r\n" + + "Host: netty.io\n\r\n"; + + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + HttpResponse response = channel.readInbound(); + assertFalse(response.decoderResult().isFailure()); + assertEquals(HttpHeaderValues.CHUNKED.toString(), response.headers().get(HttpHeaderNames.TRANSFER_ENCODING)); + assertEquals("netty.io", response.headers().get(HttpHeaderNames.HOST)); + assertFalse(channel.finish()); + } } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/HttpUtilTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/HttpUtilTest.java index 6ad08e6b4c88..31596067aefe 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/HttpUtilTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/HttpUtilTest.java @@ -15,10 +15,6 @@ */ package io.netty.handler.codec.http; -import io.netty.util.CharsetUtil; -import io.netty.util.ReferenceCountUtil; -import org.junit.Test; - import java.net.InetAddress; import java.net.InetSocketAddress; import java.nio.charset.StandardCharsets; @@ -26,12 +22,14 @@ import java.util.Collections; import java.util.List; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import org.junit.Test; + import static io.netty.handler.codec.http.HttpHeadersTestUtils.of; -import static org.hamcrest.Matchers.hasToString; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -91,6 +89,22 @@ public void testGetCharset() { assertEquals(CharsetUtil.UTF_8, HttpUtil.getCharset(UPPER_CASE_NORMAL_CONTENT_TYPE)); } + @Test + public void testGetCharsetIfNotLastParameter() { + String NORMAL_CONTENT_TYPE_WITH_PARAMETERS = "application/soap-xml; charset=utf-8; " + + "action=\"http://www.soap-service.by/foo/add\""; + + HttpMessage message = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, + "http://localhost:7788/foo"); + message.headers().set(HttpHeaderNames.CONTENT_TYPE, NORMAL_CONTENT_TYPE_WITH_PARAMETERS); + + assertEquals(CharsetUtil.UTF_8, HttpUtil.getCharset(message)); + assertEquals(CharsetUtil.UTF_8, HttpUtil.getCharset(NORMAL_CONTENT_TYPE_WITH_PARAMETERS)); + + assertEquals("utf-8", HttpUtil.getCharsetAsSequence(message)); + assertEquals("utf-8", HttpUtil.getCharsetAsSequence(NORMAL_CONTENT_TYPE_WITH_PARAMETERS)); + } + @Test public void testGetCharset_defaultValue() { final String SIMPLE_CONTENT_TYPE = "text/html"; @@ -292,4 +306,15 @@ public void testIpv4Unresolved() { InetSocketAddress socketAddress = InetSocketAddress.createUnresolved("10.0.0.1", 8080); assertEquals("10.0.0.1", HttpUtil.formatHostnameForHttp(socketAddress)); } + + @Test + public void testKeepAliveIfConnectionHeaderAbsent() { + HttpMessage http11Message = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, + "http:localhost/http_1_1"); + assertTrue(HttpUtil.isKeepAlive(http11Message)); + + HttpMessage http10Message = new DefaultHttpRequest(HttpVersion.HTTP_1_0, HttpMethod.GET, + "http:localhost/http_1_0"); + assertFalse(HttpUtil.isKeepAlive(http10Message)); + } } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/cookie/ServerCookieEncoderTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/cookie/ServerCookieEncoderTest.java index 723a7a17a9f6..82f813ff4779 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/cookie/ServerCookieEncoderTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/cookie/ServerCookieEncoderTest.java @@ -18,7 +18,9 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; +import static org.junit.matchers.JUnitMatchers.containsString; import io.netty.handler.codec.DateFormatter; import java.text.ParseException; @@ -131,6 +133,15 @@ public void illegalCharInCookieValueMakesStrictEncoderThrowsException() { assertEquals(illegalChars.size(), exceptions); } + @Test + public void illegalCharInWrappedValueAppearsInException() { + try { + ServerCookieEncoder.STRICT.encode(new DefaultCookie("name", "\"value,\"")); + } catch (IllegalArgumentException e) { + assertThat(e.getMessage().toLowerCase(), containsString("cookie value contains an invalid char: ,")); + } + } + @Test public void testEncodingMultipleCookiesLax() { List result = new ArrayList(); diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocket08FrameDecoderTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocket08FrameDecoderTest.java index 547eed6174e8..cb6c187efcf2 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocket08FrameDecoderTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocket08FrameDecoderTest.java @@ -12,18 +12,61 @@ */ package io.netty.handler.codec.http.websocketx; +import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; - +import io.netty.channel.embedded.EmbeddedChannel; import org.junit.Test; -import org.mockito.Mockito; + +import java.util.HashSet; +import java.util.Set; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; public class WebSocket08FrameDecoderTest { @Test public void channelInactive() throws Exception { final WebSocket08FrameDecoder decoder = new WebSocket08FrameDecoder(true, true, 65535, false); - final ChannelHandlerContext ctx = Mockito.mock(ChannelHandlerContext.class); + final ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); decoder.channelInactive(ctx); - Mockito.verify(ctx).fireChannelInactive(); + verify(ctx).fireChannelInactive(); + } + + @Test + public void supportIanaStatusCodes() throws Exception { + Set forbiddenIanaCodes = new HashSet(); + forbiddenIanaCodes.add(1004); + forbiddenIanaCodes.add(1005); + forbiddenIanaCodes.add(1006); + Set validIanaCodes = new HashSet(); + for (int i = 1000; i < 1015; i++) { + validIanaCodes.add(i); + } + validIanaCodes.removeAll(forbiddenIanaCodes); + + for (int statusCode: validIanaCodes) { + EmbeddedChannel encoderChannel = new EmbeddedChannel(new WebSocket08FrameEncoder(true)); + EmbeddedChannel decoderChannel = new EmbeddedChannel(new WebSocket08FrameDecoder(true, true, 65535, false)); + + assertTrue(encoderChannel.writeOutbound(new CloseWebSocketFrame(statusCode, "Bye"))); + assertTrue(encoderChannel.finish()); + ByteBuf serializedCloseFrame = encoderChannel.readOutbound(); + assertNull(encoderChannel.readOutbound()); + + assertTrue(decoderChannel.writeInbound(serializedCloseFrame)); + assertTrue(decoderChannel.finish()); + + CloseWebSocketFrame outputFrame = decoderChannel.readInbound(); + assertNull(decoderChannel.readOutbound()); + try { + assertEquals(statusCode, outputFrame.statusCode()); + } finally { + outputFrame.release(); + } + } } } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker00Test.java b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker00Test.java index bda0734ad969..33c6ce6847e4 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker00Test.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker00Test.java @@ -16,17 +16,35 @@ package io.netty.handler.codec.http.websocketx; import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaders; import java.net.URI; public class WebSocketClientHandshaker00Test extends WebSocketClientHandshakerTest { @Override - protected WebSocketClientHandshaker newHandshaker(URI uri) { - return new WebSocketClientHandshaker00(uri, WebSocketVersion.V00, null, null, 1024); + protected WebSocketClientHandshaker newHandshaker(URI uri, String subprotocol, HttpHeaders headers) { + return new WebSocketClientHandshaker00(uri, WebSocketVersion.V00, subprotocol, headers, 1024); } @Override protected CharSequence getOriginHeaderName() { return HttpHeaderNames.ORIGIN; } + + @Override + protected CharSequence getProtocolHeaderName() { + return HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL; + } + + @Override + protected CharSequence[] getHandshakeHeaderNames() { + return new CharSequence[] { + HttpHeaderNames.CONNECTION, + HttpHeaderNames.UPGRADE, + HttpHeaderNames.HOST, + HttpHeaderNames.ORIGIN, + HttpHeaderNames.SEC_WEBSOCKET_KEY1, + HttpHeaderNames.SEC_WEBSOCKET_KEY2, + }; + } } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker07Test.java b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker07Test.java index bce6c73a78ff..9ff3e8485b90 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker07Test.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker07Test.java @@ -16,17 +16,35 @@ package io.netty.handler.codec.http.websocketx; import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaders; import java.net.URI; public class WebSocketClientHandshaker07Test extends WebSocketClientHandshakerTest { @Override - protected WebSocketClientHandshaker newHandshaker(URI uri) { - return new WebSocketClientHandshaker07(uri, WebSocketVersion.V07, null, false, null, 1024); + protected WebSocketClientHandshaker newHandshaker(URI uri, String subprotocol, HttpHeaders headers) { + return new WebSocketClientHandshaker07(uri, WebSocketVersion.V07, subprotocol, false, headers, 1024); } @Override protected CharSequence getOriginHeaderName() { return HttpHeaderNames.SEC_WEBSOCKET_ORIGIN; } + + @Override + protected CharSequence getProtocolHeaderName() { + return HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL; + } + + @Override + protected CharSequence[] getHandshakeHeaderNames() { + return new CharSequence[] { + HttpHeaderNames.UPGRADE, + HttpHeaderNames.CONNECTION, + HttpHeaderNames.SEC_WEBSOCKET_KEY, + HttpHeaderNames.HOST, + HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, + HttpHeaderNames.SEC_WEBSOCKET_VERSION, + }; + } } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker08Test.java b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker08Test.java index 4ce8016adda3..1efb6821b9b6 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker08Test.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker08Test.java @@ -15,11 +15,13 @@ */ package io.netty.handler.codec.http.websocketx; +import io.netty.handler.codec.http.HttpHeaders; + import java.net.URI; public class WebSocketClientHandshaker08Test extends WebSocketClientHandshaker07Test { @Override - protected WebSocketClientHandshaker newHandshaker(URI uri) { - return new WebSocketClientHandshaker07(uri, WebSocketVersion.V08, null, false, null, 1024); + protected WebSocketClientHandshaker newHandshaker(URI uri, String subprotocol, HttpHeaders headers) { + return new WebSocketClientHandshaker08(uri, WebSocketVersion.V08, subprotocol, false, headers, 1024); } } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker13Test.java b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker13Test.java index ad89fde6bc12..1727178831df 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker13Test.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker13Test.java @@ -15,11 +15,13 @@ */ package io.netty.handler.codec.http.websocketx; +import io.netty.handler.codec.http.HttpHeaders; + import java.net.URI; public class WebSocketClientHandshaker13Test extends WebSocketClientHandshaker07Test { @Override - protected WebSocketClientHandshaker newHandshaker(URI uri) { - return new WebSocketClientHandshaker13(uri, WebSocketVersion.V13, null, false, null, 1024); + protected WebSocketClientHandshaker newHandshaker(URI uri, String subprotocol, HttpHeaders headers) { + return new WebSocketClientHandshaker13(uri, WebSocketVersion.V13, subprotocol, false, headers, 1024); } } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java index eeb0d69d1dc2..2054af513f90 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java @@ -21,11 +21,13 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.EmptyHttpHeaders; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpClientCodec; import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpRequestEncoder; import io.netty.handler.codec.http.HttpResponseDecoder; @@ -35,14 +37,21 @@ import java.net.URI; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; public abstract class WebSocketClientHandshakerTest { - protected abstract WebSocketClientHandshaker newHandshaker(URI uri); + protected abstract WebSocketClientHandshaker newHandshaker(URI uri, String subprotocol, HttpHeaders headers); + + protected WebSocketClientHandshaker newHandshaker(URI uri) { + return newHandshaker(uri, null, null); + } protected abstract CharSequence getOriginHeaderName(); + protected abstract CharSequence getProtocolHeaderName(); + + protected abstract CharSequence[] getHandshakeHeaderNames(); + @Test public void hostHeaderWs() { for (String scheme : new String[]{"ws://", "http://"}) { @@ -231,8 +240,8 @@ protected WebSocketFrameEncoder newWebSocketEncoder() { } }; - byte[] data = new byte[24]; - PlatformDependent.threadLocalRandom().nextBytes(data); + // use randomBytes helper from utils to check that it functions properly + byte[] data = WebSocketUtil.randomBytes(24); // Create a EmbeddedChannel which we will use to encode a BinaryWebsocketFrame to bytes and so use these // to test the actual handshaker. @@ -292,4 +301,36 @@ protected void channelRead0(ChannelHandlerContext ctx, FullHttpResponse msg) thr frame.release(); } } + + @Test + public void testDuplicateWebsocketHandshakeHeaders() { + URI uri = URI.create("ws://localhost:9999/foo"); + + HttpHeaders inputHeaders = new DefaultHttpHeaders(); + String bogusSubProtocol = "bogusSubProtocol"; + String bogusHeaderValue = "bogusHeaderValue"; + + // add values for the headers that are reserved for use in the websockets handshake + for (CharSequence header : getHandshakeHeaderNames()) { + inputHeaders.add(header, bogusHeaderValue); + } + inputHeaders.add(getProtocolHeaderName(), bogusSubProtocol); + + String realSubProtocol = "realSubProtocol"; + WebSocketClientHandshaker handshaker = newHandshaker(uri, realSubProtocol, inputHeaders); + FullHttpRequest request = handshaker.newHandshakeRequest(); + HttpHeaders outputHeaders = request.headers(); + + // the header values passed in originally have been replaced with values generated by the Handshaker + for (CharSequence header : getHandshakeHeaderNames()) { + assertEquals(1, outputHeaders.getAll(header).size()); + assertNotEquals(bogusHeaderValue, outputHeaders.get(header)); + } + + // the subprotocol header value is that of the subprotocol string passed into the Handshaker + assertEquals(1, outputHeaders.getAll(getProtocolHeaderName()).size()); + assertEquals(realSubProtocol, outputHeaders.get(getProtocolHeaderName())); + + request.release(); + } } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketProtocolHandlerTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketProtocolHandlerTest.java new file mode 100644 index 000000000000..af7498222ced --- /dev/null +++ b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketProtocolHandlerTest.java @@ -0,0 +1,89 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.CharsetUtil; +import org.junit.Test; + +import static org.junit.Assert.*; + +/** + * Tests common, abstract class functionality in {@link WebSocketClientProtocolHandler}. + */ +public class WebSocketProtocolHandlerTest { + + @Test + public void testPingFrame() { + ByteBuf pingData = Unpooled.copiedBuffer("Hello, world", CharsetUtil.UTF_8); + EmbeddedChannel channel = new EmbeddedChannel(new WebSocketProtocolHandler() { }); + + PingWebSocketFrame inputMessage = new PingWebSocketFrame(pingData); + assertFalse(channel.writeInbound(inputMessage)); // the message was not propagated inbound + + // a Pong frame was written to the channel + PongWebSocketFrame response = channel.readOutbound(); + assertEquals(pingData, response.content()); + + pingData.release(); + assertFalse(channel.finish()); + } + + @Test + public void testPongFrameDropFrameFalse() { + EmbeddedChannel channel = new EmbeddedChannel(new WebSocketProtocolHandler(false) { }); + + PongWebSocketFrame pingResponse = new PongWebSocketFrame(); + assertTrue(channel.writeInbound(pingResponse)); + + assertPropagatedInbound(pingResponse, channel); + + pingResponse.release(); + assertFalse(channel.finish()); + } + + @Test + public void testPongFrameDropFrameTrue() { + EmbeddedChannel channel = new EmbeddedChannel(new WebSocketProtocolHandler(true) { }); + + PongWebSocketFrame pingResponse = new PongWebSocketFrame(); + assertFalse(channel.writeInbound(pingResponse)); // message was not propagated inbound + } + + @Test + public void testTextFrame() { + EmbeddedChannel channel = new EmbeddedChannel(new WebSocketProtocolHandler() { }); + + TextWebSocketFrame textFrame = new TextWebSocketFrame(); + assertTrue(channel.writeInbound(textFrame)); + + assertPropagatedInbound(textFrame, channel); + + textFrame.release(); + assertFalse(channel.finish()); + } + + /** + * Asserts that a message was propagated inbound through the channel. + */ + private static void assertPropagatedInbound(T message, EmbeddedChannel channel) { + T propagatedResponse = channel.readInbound(); + assertEquals(message, propagatedResponse); + } +} diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker00Test.java b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker00Test.java index 76826aba8b06..8783e0b5bbc5 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker00Test.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker00Test.java @@ -32,7 +32,9 @@ import org.junit.Assert; import org.junit.Test; -import static io.netty.handler.codec.http.HttpVersion.*; +import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; public class WebSocketServerHandshaker00Test { @@ -46,6 +48,34 @@ public void testPerformOpeningHandshakeSubProtocolNotSupported() { testPerformOpeningHandshake0(false); } + @Test + public void testPerformHandshakeWithoutOriginHeader() { + EmbeddedChannel ch = new EmbeddedChannel( + new HttpObjectAggregator(42), new HttpRequestDecoder(), new HttpResponseEncoder()); + + FullHttpRequest req = new DefaultFullHttpRequest( + HTTP_1_1, HttpMethod.GET, "/chat", Unpooled.copiedBuffer("^n:ds[4U", CharsetUtil.US_ASCII)); + + req.headers().set(HttpHeaderNames.HOST, "server.example.com"); + req.headers().set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET); + req.headers().set(HttpHeaderNames.CONNECTION, "Upgrade"); + req.headers().set(HttpHeaderNames.SEC_WEBSOCKET_KEY1, "4 @1 46546xW%0l 1 5"); + req.headers().set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, "chat, superchat"); + + WebSocketServerHandshaker00 handshaker00 = new WebSocketServerHandshaker00( + "ws://example.com/chat", "chat", Integer.MAX_VALUE); + try { + handshaker00.handshake(ch, req); + fail("Expecting WebSocketHandshakeException"); + } catch (WebSocketHandshakeException e) { + assertEquals("Missing origin header, got only " + + "[host, upgrade, connection, sec-websocket-key1, sec-websocket-protocol]", + e.getMessage()); + } finally { + req.release(); + } + } + private static void testPerformOpeningHandshake0(boolean subProtocol) { EmbeddedChannel ch = new EmbeddedChannel( new HttpObjectAggregator(42), new HttpRequestDecoder(), new HttpResponseEncoder()); diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketUtf8FrameValidatorTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketUtf8FrameValidatorTest.java new file mode 100644 index 000000000000..c3bb0ed7dd3d --- /dev/null +++ b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketUtf8FrameValidatorTest.java @@ -0,0 +1,55 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.CorruptedFrameException; +import org.junit.Assert; +import org.junit.Test; + +public class WebSocketUtf8FrameValidatorTest { + + @Test + public void testCorruptedFrameExceptionInFinish() { + assertCorruptedFrameExceptionHandling(new byte[]{-50}); + } + + @Test + public void testCorruptedFrameExceptionInCheck() { + assertCorruptedFrameExceptionHandling(new byte[]{-8, -120, -128, -128, -128}); + } + + private void assertCorruptedFrameExceptionHandling(byte[] data) { + EmbeddedChannel channel = new EmbeddedChannel(new Utf8FrameValidator()); + try { + channel.writeInbound(new TextWebSocketFrame(Unpooled.copiedBuffer(data))); + Assert.fail(); + } catch (CorruptedFrameException e) { + // expected exception + } + Assert.assertTrue(channel.finish()); + ByteBuf buf = channel.readOutbound(); + Assert.assertNotNull(buf); + try { + Assert.assertFalse(buf.isReadable()); + } finally { + buf.release(); + } + Assert.assertNull(channel.readOutbound()); + } +} diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketUtilTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketUtilTest.java new file mode 100644 index 000000000000..e40501951ab6 --- /dev/null +++ b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketUtilTest.java @@ -0,0 +1,43 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +import org.junit.Test; + +import static org.junit.Assert.assertTrue; + +public class WebSocketUtilTest { + + // how many times do we want to run each random variable checker + private static final int NUM_ITERATIONS = 1000; + + private static void assertRandomWithinBoundaries(int min, int max) { + int r = WebSocketUtil.randomNumber(min, max); + assertTrue(min <= r && r <= max); + } + + @Test + public void testRandomNumberGenerator() { + int iteration = 0; + while (++iteration < NUM_ITERATIONS) { + assertRandomWithinBoundaries(0, 1); + assertRandomWithinBoundaries(0, 1); + assertRandomWithinBoundaries(-1, 1); + assertRandomWithinBoundaries(-1, 0); + } + } + +} diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionTestUtil.java b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionTestUtil.java index 411b167f8f3c..38867febd100 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionTestUtil.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionTestUtil.java @@ -69,7 +69,7 @@ static final class WebSocketExtensionDataMatcher implements ArgumentMatcheremptyMap())); assertNotNull(extension); - assertEquals(WebSocketClientExtension.RSV1, extension.rsv()); + assertEquals(RSV1, extension.rsv()); assertTrue(extension.newExtensionDecoder() instanceof PerMessageDeflateDecoder); assertTrue(extension.newExtensionEncoder() instanceof PerMessageDeflateEncoder); } @@ -92,7 +96,7 @@ public void testCustomHandshake() { // test assertNotNull(extension); - assertEquals(WebSocketClientExtension.RSV1, extension.rsv()); + assertEquals(RSV1, extension.rsv()); assertTrue(extension.newExtensionDecoder() instanceof PerMessageDeflateDecoder); assertTrue(extension.newExtensionEncoder() instanceof PerMessageDeflateEncoder); @@ -107,7 +111,7 @@ public void testCustomHandshake() { // test assertNotNull(extension); - assertEquals(WebSocketClientExtension.RSV1, extension.rsv()); + assertEquals(RSV1, extension.rsv()); assertTrue(extension.newExtensionDecoder() instanceof PerMessageDeflateDecoder); assertTrue(extension.newExtensionEncoder() instanceof PerMessageDeflateEncoder); @@ -121,4 +125,61 @@ public void testCustomHandshake() { // test assertNull(extension); } + + @Test + public void testDecoderNoClientContext() { + PerMessageDeflateClientExtensionHandshaker handshaker = + new PerMessageDeflateClientExtensionHandshaker(6, true, MAX_WINDOW_SIZE, true, false); + + byte[] firstPayload = new byte[] { + 76, -50, -53, 10, -62, 48, 20, 4, -48, 95, 41, 89, -37, 36, 77, 90, 31, -39, 41, -72, 112, 33, -120, 20, + 20, 119, -79, 70, 123, -95, 121, -48, 92, -116, 80, -6, -17, -58, -99, -37, -31, 12, 51, 19, 1, -9, -12, + 68, -111, -117, 25, 58, 111, 77, -127, -66, -64, -34, 20, 59, -64, -29, -2, 90, -100, -115, 30, 16, 114, + -68, 61, 29, 40, 89, -112, -73, 25, 35, 120, -105, -67, -32, -43, -70, -84, 120, -55, 69, 43, -124, 106, + -92, 18, -110, 114, -50, 111, 25, -3, 10, 17, -75, 13, 127, -84, 106, 90, -66, 84, -75, 84, 53, -89, + -75, 92, -3, -40, -61, 119, 49, -117, 30, 49, 68, -59, 88, 74, -119, -34, 1, -83, -7, -48, 124, -124, + -23, 16, 88, -118, 121, 54, -53, 1, 44, 32, 81, 19, 25, -115, -43, -32, -64, -67, -120, -110, -101, 121, + -2, 2 + }; + + byte[] secondPayload = new byte[] { + -86, 86, 42, 46, 77, 78, 78, 45, 6, 26, 83, 82, 84, -102, -86, 3, -28, 38, 21, 39, 23, 101, 38, -91, 2, + -51, -51, 47, 74, 73, 45, 114, -54, -49, -49, -10, 49, -78, -118, 112, 10, 9, 13, 118, 1, -102, 84, + -108, 90, 88, 10, 116, 27, -56, -84, 124, -112, -13, 16, 26, 116, -108, 18, -117, -46, -127, 6, 69, 99, + -45, 24, 91, 91, 11, 0 + }; + + Map parameters = Collections.singletonMap(CLIENT_NO_CONTEXT, null); + + WebSocketClientExtension extension = handshaker.handshakeExtension( + new WebSocketExtensionData(PERMESSAGE_DEFLATE_EXTENSION, parameters)); + assertNotNull(extension); + + EmbeddedChannel decoderChannel = new EmbeddedChannel(extension.newExtensionDecoder()); + assertTrue( + decoderChannel.writeInbound(new TextWebSocketFrame(true, RSV1, Unpooled.copiedBuffer(firstPayload)))); + TextWebSocketFrame firstFrameDecompressed = decoderChannel.readInbound(); + assertTrue( + decoderChannel.writeInbound(new TextWebSocketFrame(true, RSV1, Unpooled.copiedBuffer(secondPayload)))); + TextWebSocketFrame secondFrameDecompressed = decoderChannel.readInbound(); + + assertNotNull(firstFrameDecompressed); + assertNotNull(firstFrameDecompressed.content()); + assertTrue(firstFrameDecompressed instanceof TextWebSocketFrame); + assertEquals(firstFrameDecompressed.text(), + "{\"info\":\"Welcome to the BitMEX Realtime API.\",\"version\"" + + ":\"2018-10-02T22:53:23.000Z\",\"timestamp\":\"2018-10-15T06:43:40.437Z\"," + + "\"docs\":\"https://www.bitmex.com/app/wsAPI\",\"limit\":{\"remaining\":39}}"); + assertTrue(firstFrameDecompressed.release()); + + assertNotNull(secondFrameDecompressed); + assertNotNull(secondFrameDecompressed.content()); + assertTrue(secondFrameDecompressed instanceof TextWebSocketFrame); + assertEquals(secondFrameDecompressed.text(), + "{\"success\":true,\"subscribe\":\"orderBookL2:XBTUSD\"," + + "\"request\":{\"op\":\"subscribe\",\"args\":[\"orderBookL2:XBTUSD\"]}}"); + assertTrue(secondFrameDecompressed.release()); + + assertFalse(decoderChannel.finish()); + } } diff --git a/codec-http2/pom.xml b/codec-http2/pom.xml index d0961ff6f663..7973628d221b 100644 --- a/codec-http2/pom.xml +++ b/codec-http2/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-codec-http2 @@ -35,7 +35,22 @@ ${project.groupId} - netty-codec-http + netty-common + ${project.version} + + + ${project.groupId} + netty-buffer + ${project.version} + + + ${project.groupId} + netty-transport + ${project.version} + + + ${project.groupId} + netty-codec ${project.version} @@ -43,6 +58,11 @@ netty-handler ${project.version} + + ${project.groupId} + netty-codec-http + ${project.version} + com.jcraft jzlib diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/CompressorHttp2ConnectionEncoder.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/CompressorHttp2ConnectionEncoder.java index 3137da21ed25..d0ba8646c734 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/CompressorHttp2ConnectionEncoder.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/CompressorHttp2ConnectionEncoder.java @@ -108,7 +108,7 @@ public ChannelFuture writeData(final ChannelHandlerContext ctx, final int stream return promise; } - PromiseCombiner combiner = new PromiseCombiner(); + PromiseCombiner combiner = new PromiseCombiner(ctx.executor()); for (;;) { ByteBuf nextBuf = nextReadableBuf(channel); boolean compressedEndOfStream = nextBuf == null && endOfStream; diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2Connection.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2Connection.java index 425131745726..4d866eb41af4 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2Connection.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2Connection.java @@ -25,7 +25,6 @@ import io.netty.util.concurrent.Promise; import io.netty.util.concurrent.UnaryPromiseNotifier; import io.netty.util.internal.EmptyArrays; -import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.UnstableApi; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -225,7 +224,12 @@ public boolean goAwayReceived() { } @Override - public void goAwayReceived(final int lastKnownStream, long errorCode, ByteBuf debugData) { + public void goAwayReceived(final int lastKnownStream, long errorCode, ByteBuf debugData) throws Http2Exception { + if (localEndpoint.lastStreamKnownByPeer() >= 0 && localEndpoint.lastStreamKnownByPeer() < lastKnownStream) { + throw connectionError(PROTOCOL_ERROR, "lastStreamId MUST NOT increase. Current value: %d new value: %d", + localEndpoint.lastStreamKnownByPeer(), lastKnownStream); + } + localEndpoint.lastStreamKnownByPeer(lastKnownStream); for (int i = 0; i < listeners.size(); ++i) { try { @@ -235,19 +239,7 @@ public void goAwayReceived(final int lastKnownStream, long errorCode, ByteBuf de } } - try { - forEachActiveStream(new Http2StreamVisitor() { - @Override - public boolean visit(Http2Stream stream) { - if (stream.id() > lastKnownStream && localEndpoint.isValidStreamId(stream.id())) { - stream.close(); - } - return true; - } - }); - } catch (Http2Exception e) { - PlatformDependent.throwException(e); - } + closeStreamsGreaterThanLastKnownStreamId(lastKnownStream, localEndpoint); } @Override @@ -256,7 +248,20 @@ public boolean goAwaySent() { } @Override - public void goAwaySent(final int lastKnownStream, long errorCode, ByteBuf debugData) { + public boolean goAwaySent(final int lastKnownStream, long errorCode, ByteBuf debugData) throws Http2Exception { + if (remoteEndpoint.lastStreamKnownByPeer() >= 0) { + // Protect against re-entrancy. Could happen if writing the frame fails, and error handling + // treating this is a connection handler and doing a graceful shutdown... + if (lastKnownStream == remoteEndpoint.lastStreamKnownByPeer()) { + return false; + } + if (lastKnownStream > remoteEndpoint.lastStreamKnownByPeer()) { + throw connectionError(PROTOCOL_ERROR, "Last stream identifier must not increase between " + + "sending multiple GOAWAY frames (was '%d', is '%d').", + remoteEndpoint.lastStreamKnownByPeer(), lastKnownStream); + } + } + remoteEndpoint.lastStreamKnownByPeer(lastKnownStream); for (int i = 0; i < listeners.size(); ++i) { try { @@ -266,19 +271,21 @@ public void goAwaySent(final int lastKnownStream, long errorCode, ByteBuf debugD } } - try { - forEachActiveStream(new Http2StreamVisitor() { - @Override - public boolean visit(Http2Stream stream) { - if (stream.id() > lastKnownStream && remoteEndpoint.isValidStreamId(stream.id())) { - stream.close(); - } - return true; + closeStreamsGreaterThanLastKnownStreamId(lastKnownStream, remoteEndpoint); + return true; + } + + private void closeStreamsGreaterThanLastKnownStreamId(final int lastKnownStream, + final DefaultEndpoint endpoint) throws Http2Exception { + forEachActiveStream(new Http2StreamVisitor() { + @Override + public boolean visit(Http2Stream stream) { + if (stream.id() > lastKnownStream && endpoint.isValidStreamId(stream.id())) { + stream.close(); } - }); - } catch (Http2Exception e) { - PlatformDependent.throwException(e); - } + return true; + } + }); } /** @@ -863,10 +870,10 @@ private void updateMaxStreams() { private void checkNewStreamAllowed(int streamId, State state) throws Http2Exception { assert state != IDLE; - if (goAwayReceived() && streamId > localEndpoint.lastStreamKnownByPeer()) { - throw connectionError(PROTOCOL_ERROR, "Cannot create stream %d since this endpoint has received a " + - "GOAWAY frame with last stream id %d.", streamId, - localEndpoint.lastStreamKnownByPeer()); + if (lastStreamKnownByPeer >= 0 && streamId > lastStreamKnownByPeer) { + throw streamError(streamId, REFUSED_STREAM, + "Cannot create stream %d greater than Last-Stream-ID %d from GOAWAY.", + streamId, lastStreamKnownByPeer); } if (!isValidStreamId(streamId)) { if (streamId < 0) { @@ -923,7 +930,7 @@ private final class ActiveStreams { private final Set streams = new LinkedHashSet(); private int pendingIterations; - public ActiveStreams(List listeners) { + ActiveStreams(List listeners) { this.listeners = listeners; } diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoder.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoder.java index 49335e95e737..2d78fc9ca917 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoder.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoder.java @@ -158,12 +158,8 @@ private int unconsumedBytes(Http2Stream stream) { void onGoAwayRead0(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData) throws Http2Exception { - if (connection.goAwayReceived() && connection.local().lastStreamKnownByPeer() < lastStreamId) { - throw connectionError(PROTOCOL_ERROR, "lastStreamId MUST NOT increase. Current value: %d new value: %d", - connection.local().lastStreamKnownByPeer(), lastStreamId); - } - listener.onGoAwayRead(ctx, lastStreamId, errorCode, debugData); connection.goAwayReceived(lastStreamId, errorCode, debugData); + listener.onGoAwayRead(ctx, lastStreamId, errorCode, debugData); } void onUnknownFrame0(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, @@ -535,12 +531,18 @@ private boolean shouldIgnoreHeadersOrDataFrame(ChannelHandlerContext ctx, int st throw streamError(streamId, STREAM_CLOSED, "Received %s frame for an unknown stream %d", frameName, streamId); } else if (stream.isResetSent() || streamCreatedAfterGoAwaySent(streamId)) { + // If we have sent a reset stream it is assumed the stream will be closed after the write completes. + // If we have not sent a reset, but the stream was created after a GoAway this is not supported by + // DefaultHttp2Connection and if a custom Http2Connection is used it is assumed the lifetime is managed + // elsewhere so we don't close the stream or otherwise modify the stream's state. + if (logger.isInfoEnabled()) { - logger.info("{} ignoring {} frame for stream {} {}", ctx.channel(), frameName, + logger.info("{} ignoring {} frame for stream {}", ctx.channel(), frameName, stream.isResetSent() ? "RST_STREAM sent." : ("Stream created after GOAWAY sent. Last known stream by peer " + connection.remote().lastStreamKnownByPeer())); } + return true; } return false; diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoder.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoder.java index d0c5944cb9f7..ff6a30eae623 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoder.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoder.java @@ -30,6 +30,7 @@ import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; import static io.netty.handler.codec.http2.Http2Exception.connectionError; import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; import static java.lang.Integer.MAX_VALUE; import static java.lang.Math.min; @@ -163,7 +164,12 @@ public ChannelFuture writeHeaders(final ChannelHandlerContext ctx, final int str Http2Stream stream = connection.stream(streamId); if (stream == null) { try { - stream = connection.local().createStream(streamId, endOfStream); + // We don't create the stream in a `halfClosed` state because if this is an initial + // HEADERS frame we don't want the connection state to signify that the HEADERS have + // been sent until after they have been encoded and placed in the outbound buffer. + // Therefore, we let the `LifeCycleManager` will take care of transitioning the state + // as appropriate. + stream = connection.local().createStream(streamId, /*endOfStream*/ false); } catch (Http2Exception cause) { if (connection.remote().mayHaveCreatedStream(streamId)) { promise.tryFailure(new IllegalStateException("Stream no longer exists: " + streamId, cause)); @@ -190,17 +196,10 @@ public ChannelFuture writeHeaders(final ChannelHandlerContext ctx, final int str // for this stream. Http2RemoteFlowController flowController = flowController(); if (!endOfStream || !flowController.hasFlowControlled(stream)) { + // The behavior here should mirror that in FlowControlledHeaders + + promise = promise.unvoid(); boolean isInformational = validateHeadersSentState(stream, headers, connection.isServer(), endOfStream); - if (endOfStream) { - final Http2Stream finalStream = stream; - final ChannelFutureListener closeStreamLocalListener = new ChannelFutureListener() { - @Override - public void operationComplete(ChannelFuture future) throws Exception { - lifecycleManager.closeStreamLocal(finalStream, future); - } - }; - promise = promise.unvoid().addListener(closeStreamLocalListener); - } ChannelFuture future = frameWriter.writeHeaders(ctx, streamId, headers, streamDependency, weight, exclusive, padding, endOfStream, promise); @@ -222,6 +221,13 @@ public void operationComplete(ChannelFuture future) throws Exception { lifecycleManager.onError(ctx, true, failureCause); } + if (endOfStream) { + // Must handle calling onError before calling closeStreamLocal, otherwise the error handler will + // incorrectly think the stream no longer exists and so may not send RST_STREAM or perform similar + // appropriate action. + lifecycleManager.closeStreamLocal(stream, future); + } + return future; } else { // Pass headers to the flow-controller so it can maintain their sequence relative to DATA frames. @@ -288,6 +294,7 @@ public ChannelFuture writePushPromise(ChannelHandlerContext ctx, int streamId, i // Reserve the promised stream. connection.local().reservePushStream(promisedStreamId, stream); + promise = promise.unvoid(); ChannelFuture future = frameWriter.writePushPromise(ctx, streamId, promisedStreamId, headers, padding, promise); // Writing headers may fail during the encode state if they violate HPACK limits. @@ -391,6 +398,9 @@ public void error(ChannelHandlerContext ctx, Throwable cause) { queue.releaseAndFailAll(cause); // Don't update dataSize because we need to ensure the size() method returns a consistent size even after // error so we don't invalidate flow control when returning bytes to flow control. + // + // That said we will set dataSize and padding to 0 in the write(...) method if we cleared the queue + // because of an error. lifecycleManager.onError(ctx, true, cause); } @@ -399,11 +409,21 @@ public void write(ChannelHandlerContext ctx, int allowedBytes) { int queuedData = queue.readableBytes(); if (!endOfStream) { if (queuedData == 0) { - // There's no need to write any data frames because there are only empty data frames in the queue - // and it is not end of stream yet. Just complete their promises by getting the buffer corresponding - // to 0 bytes and writing it to the channel (to preserve notification order). - ChannelPromise writePromise = ctx.newPromise().addListener(this); - ctx.write(queue.remove(0, writePromise), writePromise); + if (queue.isEmpty()) { + // When the queue is empty it means we did clear it because of an error(...) call + // (as otherwise we will have at least 1 entry in there), which will happen either when called + // explicit or when the write itself fails. In this case just set dataSize and padding to 0 + // which will signal back that the whole frame was consumed. + // + // See https://github.com/netty/netty/issues/8707. + padding = dataSize = 0; + } else { + // There's no need to write any data frames because there are only empty data frames in the + // queue and it is not end of stream yet. Just complete their promises by getting the buffer + // corresponding to 0 bytes and writing it to the channel (to preserve notification order). + ChannelPromise writePromise = ctx.newPromise().addListener(this); + ctx.write(queue.remove(0, writePromise), writePromise); + } return; } @@ -468,7 +488,7 @@ private final class FlowControlledHeaders extends FlowControlledBase { FlowControlledHeaders(Http2Stream stream, Http2Headers headers, int streamDependency, short weight, boolean exclusive, int padding, boolean endOfStream, ChannelPromise promise) { - super(stream, padding, endOfStream, promise); + super(stream, padding, endOfStream, promise.unvoid()); this.headers = headers; this.streamDependency = streamDependency; this.weight = weight; @@ -491,9 +511,8 @@ public void error(ChannelHandlerContext ctx, Throwable cause) { @Override public void write(ChannelHandlerContext ctx, int allowedBytes) { boolean isInformational = validateHeadersSentState(stream, headers, connection.isServer(), endOfStream); - if (promise.isVoid()) { - promise = ctx.newPromise(); - } + // The code is currently requiring adding this listener before writing, in order to call onError() before + // closeStreamLocal(). promise.addListener(this); ChannelFuture f = frameWriter.writeHeaders(ctx, stream.id(), headers, streamDependency, weight, exclusive, @@ -525,9 +544,7 @@ public abstract class FlowControlledBase implements Http2RemoteFlowController.Fl FlowControlledBase(final Http2Stream stream, int padding, boolean endOfStream, final ChannelPromise promise) { - if (padding < 0) { - throw new IllegalArgumentException("padding must be >= 0"); - } + checkPositiveOrZero(padding, "padding"); this.padding = padding; this.endOfStream = endOfStream; this.stream = stream; diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameReader.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameReader.java index 63e184bc7b4b..cc63b2847461 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameReader.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameReader.java @@ -239,8 +239,8 @@ private void processPayloadState(ChannelHandlerContext ctx, ByteBuf in, Http2Fra return; } - // Get a view of the buffer for the size of the payload. - ByteBuf payload = in.readSlice(payloadLength); + // Only process up to payloadLength bytes. + int payloadEndIndex = in.readerIndex() + payloadLength; // We have consumed the data, next time we read we will be expecting to read a frame header. readingHeaders = true; @@ -248,39 +248,40 @@ private void processPayloadState(ChannelHandlerContext ctx, ByteBuf in, Http2Fra // Read the payload and fire the frame event to the listener. switch (frameType) { case DATA: - readDataFrame(ctx, payload, listener); + readDataFrame(ctx, in, payloadEndIndex, listener); break; case HEADERS: - readHeadersFrame(ctx, payload, listener); + readHeadersFrame(ctx, in, payloadEndIndex, listener); break; case PRIORITY: - readPriorityFrame(ctx, payload, listener); + readPriorityFrame(ctx, in, listener); break; case RST_STREAM: - readRstStreamFrame(ctx, payload, listener); + readRstStreamFrame(ctx, in, listener); break; case SETTINGS: - readSettingsFrame(ctx, payload, listener); + readSettingsFrame(ctx, in, listener); break; case PUSH_PROMISE: - readPushPromiseFrame(ctx, payload, listener); + readPushPromiseFrame(ctx, in, payloadEndIndex, listener); break; case PING: - readPingFrame(ctx, payload.readLong(), listener); + readPingFrame(ctx, in.readLong(), listener); break; case GO_AWAY: - readGoAwayFrame(ctx, payload, listener); + readGoAwayFrame(ctx, in, payloadEndIndex, listener); break; case WINDOW_UPDATE: - readWindowUpdateFrame(ctx, payload, listener); + readWindowUpdateFrame(ctx, in, listener); break; case CONTINUATION: - readContinuationFrame(payload, listener); + readContinuationFrame(in, payloadEndIndex, listener); break; default: - readUnknownFrame(ctx, payload, listener); + readUnknownFrame(ctx, in, payloadEndIndex, listener); break; } + in.readerIndex(payloadEndIndex); } private void verifyDataFrame() throws Http2Exception { @@ -408,21 +409,20 @@ private void verifyUnknownFrame() throws Http2Exception { verifyNotProcessingHeaders(); } - private void readDataFrame(ChannelHandlerContext ctx, ByteBuf payload, + private void readDataFrame(ChannelHandlerContext ctx, ByteBuf payload, int payloadEndIndex, Http2FrameListener listener) throws Http2Exception { int padding = readPadding(payload); verifyPadding(padding); // Determine how much data there is to read by removing the trailing // padding. - int dataLength = lengthWithoutTrailingPadding(payload.readableBytes(), padding); + int dataLength = lengthWithoutTrailingPadding(payloadEndIndex - payload.readerIndex(), padding); ByteBuf data = payload.readSlice(dataLength); listener.onDataRead(ctx, streamId, data, padding, flags.endOfStream()); - payload.skipBytes(payload.readableBytes()); } - private void readHeadersFrame(final ChannelHandlerContext ctx, ByteBuf payload, + private void readHeadersFrame(final ChannelHandlerContext ctx, ByteBuf payload, int payloadEndIndex, Http2FrameListener listener) throws Http2Exception { final int headersStreamId = streamId; final Http2Flags headersFlags = flags; @@ -439,7 +439,7 @@ private void readHeadersFrame(final ChannelHandlerContext ctx, ByteBuf payload, throw streamError(streamId, PROTOCOL_ERROR, "A stream cannot depend on itself."); } final short weight = (short) (payload.readUnsignedByte() + 1); - final ByteBuf fragment = payload.readSlice(lengthWithoutTrailingPadding(payload.readableBytes(), padding)); + final int lenToRead = lengthWithoutTrailingPadding(payloadEndIndex - payload.readerIndex(), padding); // Create a handler that invokes the listener when the header block is complete. headersContinuation = new HeadersContinuation() { @@ -449,10 +449,10 @@ public int getStreamId() { } @Override - public void processFragment(boolean endOfHeaders, ByteBuf fragment, + public void processFragment(boolean endOfHeaders, ByteBuf fragment, int len, Http2FrameListener listener) throws Http2Exception { final HeadersBlockBuilder hdrBlockBuilder = headersBlockBuilder(); - hdrBlockBuilder.addFragment(fragment, ctx.alloc(), endOfHeaders); + hdrBlockBuilder.addFragment(fragment, len, ctx.alloc(), endOfHeaders); if (endOfHeaders) { listener.onHeadersRead(ctx, headersStreamId, hdrBlockBuilder.headers(), streamDependency, weight, exclusive, padding, headersFlags.endOfStream()); @@ -461,7 +461,7 @@ public void processFragment(boolean endOfHeaders, ByteBuf fragment, }; // Process the initial fragment, invoking the listener's callback if end of headers. - headersContinuation.processFragment(flags.endOfHeaders(), fragment, listener); + headersContinuation.processFragment(flags.endOfHeaders(), payload, lenToRead, listener); resetHeadersContinuationIfEnd(flags.endOfHeaders()); return; } @@ -475,10 +475,10 @@ public int getStreamId() { } @Override - public void processFragment(boolean endOfHeaders, ByteBuf fragment, + public void processFragment(boolean endOfHeaders, ByteBuf fragment, int len, Http2FrameListener listener) throws Http2Exception { final HeadersBlockBuilder hdrBlockBuilder = headersBlockBuilder(); - hdrBlockBuilder.addFragment(fragment, ctx.alloc(), endOfHeaders); + hdrBlockBuilder.addFragment(fragment, len, ctx.alloc(), endOfHeaders); if (endOfHeaders) { listener.onHeadersRead(ctx, headersStreamId, hdrBlockBuilder.headers(), padding, headersFlags.endOfStream()); @@ -487,8 +487,8 @@ public void processFragment(boolean endOfHeaders, ByteBuf fragment, }; // Process the initial fragment, invoking the listener's callback if end of headers. - final ByteBuf fragment = payload.readSlice(lengthWithoutTrailingPadding(payload.readableBytes(), padding)); - headersContinuation.processFragment(flags.endOfHeaders(), fragment, listener); + int len = lengthWithoutTrailingPadding(payloadEndIndex - payload.readerIndex(), padding); + headersContinuation.processFragment(flags.endOfHeaders(), payload, len, listener); resetHeadersContinuationIfEnd(flags.endOfHeaders()); } @@ -543,7 +543,7 @@ private void readSettingsFrame(ChannelHandlerContext ctx, ByteBuf payload, } } - private void readPushPromiseFrame(final ChannelHandlerContext ctx, ByteBuf payload, + private void readPushPromiseFrame(final ChannelHandlerContext ctx, ByteBuf payload, int payloadEndIndex, Http2FrameListener listener) throws Http2Exception { final int pushPromiseStreamId = streamId; final int padding = readPadding(payload); @@ -558,9 +558,9 @@ public int getStreamId() { } @Override - public void processFragment(boolean endOfHeaders, ByteBuf fragment, + public void processFragment(boolean endOfHeaders, ByteBuf fragment, int len, Http2FrameListener listener) throws Http2Exception { - headersBlockBuilder().addFragment(fragment, ctx.alloc(), endOfHeaders); + headersBlockBuilder().addFragment(fragment, len, ctx.alloc(), endOfHeaders); if (endOfHeaders) { listener.onPushPromiseRead(ctx, pushPromiseStreamId, promisedStreamId, headersBlockBuilder().headers(), padding); @@ -569,8 +569,8 @@ public void processFragment(boolean endOfHeaders, ByteBuf fragment, }; // Process the initial fragment, invoking the listener's callback if end of headers. - final ByteBuf fragment = payload.readSlice(lengthWithoutTrailingPadding(payload.readableBytes(), padding)); - headersContinuation.processFragment(flags.endOfHeaders(), fragment, listener); + int len = lengthWithoutTrailingPadding(payloadEndIndex - payload.readerIndex(), padding); + headersContinuation.processFragment(flags.endOfHeaders(), payload, len, listener); resetHeadersContinuationIfEnd(flags.endOfHeaders()); } @@ -583,11 +583,11 @@ private void readPingFrame(ChannelHandlerContext ctx, long data, } } - private static void readGoAwayFrame(ChannelHandlerContext ctx, ByteBuf payload, + private static void readGoAwayFrame(ChannelHandlerContext ctx, ByteBuf payload, int payloadEndIndex, Http2FrameListener listener) throws Http2Exception { int lastStreamId = readUnsignedInt(payload); long errorCode = payload.readUnsignedInt(); - ByteBuf debugData = payload.readSlice(payload.readableBytes()); + ByteBuf debugData = payload.readSlice(payloadEndIndex - payload.readerIndex()); listener.onGoAwayRead(ctx, lastStreamId, errorCode, debugData); } @@ -601,18 +601,17 @@ private void readWindowUpdateFrame(ChannelHandlerContext ctx, ByteBuf payload, listener.onWindowUpdateRead(ctx, streamId, windowSizeIncrement); } - private void readContinuationFrame(ByteBuf payload, Http2FrameListener listener) + private void readContinuationFrame(ByteBuf payload, int payloadEndIndex, Http2FrameListener listener) throws Http2Exception { // Process the initial fragment, invoking the listener's callback if end of headers. - final ByteBuf continuationFragment = payload.readSlice(payload.readableBytes()); - headersContinuation.processFragment(flags.endOfHeaders(), continuationFragment, - listener); + headersContinuation.processFragment(flags.endOfHeaders(), payload, + payloadEndIndex - payload.readerIndex(), listener); resetHeadersContinuationIfEnd(flags.endOfHeaders()); } - private void readUnknownFrame(ChannelHandlerContext ctx, ByteBuf payload, Http2FrameListener listener) - throws Http2Exception { - payload = payload.readSlice(payload.readableBytes()); + private void readUnknownFrame(ChannelHandlerContext ctx, ByteBuf payload, + int payloadEndIndex, Http2FrameListener listener) throws Http2Exception { + payload = payload.readSlice(payloadEndIndex - payload.readerIndex()); listener.onUnknownFrame(ctx, frameType, streamId, flags, payload); } @@ -664,7 +663,7 @@ private abstract class HeadersContinuation { * @param fragment the fragment of the header block to be added. * @param listener the listener to be notified if the header block is completed. */ - abstract void processFragment(boolean endOfHeaders, ByteBuf fragment, + abstract void processFragment(boolean endOfHeaders, ByteBuf fragment, int len, Http2FrameListener listener) throws Http2Exception; final HeadersBlockBuilder headersBlockBuilder() { @@ -704,33 +703,32 @@ private void headerSizeExceeded() throws Http2Exception { * This is used for an optimization for when the first fragment is the full * block. In that case, the buffer is used directly without copying. */ - final void addFragment(ByteBuf fragment, ByteBufAllocator alloc, boolean endOfHeaders) throws Http2Exception { + final void addFragment(ByteBuf fragment, int len, ByteBufAllocator alloc, + boolean endOfHeaders) throws Http2Exception { if (headerBlock == null) { - if (fragment.readableBytes() > headersDecoder.configuration().maxHeaderListSizeGoAway()) { + if (len > headersDecoder.configuration().maxHeaderListSizeGoAway()) { headerSizeExceeded(); } if (endOfHeaders) { // Optimization - don't bother copying, just use the buffer as-is. Need // to retain since we release when the header block is built. - headerBlock = fragment.retain(); + headerBlock = fragment.readRetainedSlice(len); } else { - headerBlock = alloc.buffer(fragment.readableBytes()); - headerBlock.writeBytes(fragment); + headerBlock = alloc.buffer(len).writeBytes(fragment, len); } return; } - if (headersDecoder.configuration().maxHeaderListSizeGoAway() - fragment.readableBytes() < + if (headersDecoder.configuration().maxHeaderListSizeGoAway() - len < headerBlock.readableBytes()) { headerSizeExceeded(); } - if (headerBlock.isWritable(fragment.readableBytes())) { + if (headerBlock.isWritable(len)) { // The buffer can hold the requested bytes, just write it directly. - headerBlock.writeBytes(fragment); + headerBlock.writeBytes(fragment, len); } else { // Allocate a new buffer that is big enough to hold the entire header block so far. - ByteBuf buf = alloc.buffer(headerBlock.readableBytes() + fragment.readableBytes()); - buf.writeBytes(headerBlock); - buf.writeBytes(fragment); + ByteBuf buf = alloc.buffer(headerBlock.readableBytes() + len); + buf.writeBytes(headerBlock).writeBytes(fragment, len); headerBlock.release(); headerBlock = buf; } diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameWriter.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameWriter.java index 3fa84137042d..e2b23963378a 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameWriter.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameWriter.java @@ -61,6 +61,8 @@ import static io.netty.handler.codec.http2.Http2FrameTypes.SETTINGS; import static io.netty.handler.codec.http2.Http2FrameTypes.WINDOW_UPDATE; import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkPositive; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; import static java.lang.Math.max; import static java.lang.Math.min; @@ -195,7 +197,7 @@ public ChannelFuture writeData(ChannelHandlerContext ctx, int streamId, ByteBuf ctx.write(lastFrame, promiseAggregator.newPromise()); // Write the payload. - lastFrame = data.readSlice(maxFrameSize); + lastFrame = data.readableBytes() != maxFrameSize ? data.readSlice(maxFrameSize) : data; data = null; ctx.write(lastFrame, promiseAggregator.newPromise()); } @@ -614,15 +616,11 @@ private static void writePaddingLength(ByteBuf buf, int padding) { } private static void verifyStreamId(int streamId, String argumentName) { - if (streamId <= 0) { - throw new IllegalArgumentException(argumentName + " must be > 0"); - } + checkPositive(streamId, "streamId"); } private static void verifyStreamOrConnectionId(int streamId, String argumentName) { - if (streamId < 0) { - throw new IllegalArgumentException(argumentName + " must be >= 0"); - } + checkPositiveOrZero(streamId, "streamId"); } private static void verifyWeight(short weight) { @@ -638,9 +636,7 @@ private static void verifyErrorCode(long errorCode) { } private static void verifyWindowSizeIncrement(int windowSizeIncrement) { - if (windowSizeIncrement < 0) { - throw new IllegalArgumentException("WindowSizeIncrement must be >= 0"); - } + checkPositiveOrZero(windowSizeIncrement, "windowSizeIncrement"); } private static void verifyPingPayload(ByteBuf data) { diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2GoAwayFrame.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2GoAwayFrame.java index 77207673303f..2dbd738d04bd 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2GoAwayFrame.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2GoAwayFrame.java @@ -15,6 +15,8 @@ */ package io.netty.handler.codec.http2; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + import io.netty.buffer.ByteBuf; import io.netty.buffer.DefaultByteBufHolder; import io.netty.buffer.Unpooled; @@ -98,9 +100,7 @@ public int extraStreamIds() { @Override public Http2GoAwayFrame setExtraStreamIds(int extraStreamIds) { - if (extraStreamIds < 0) { - throw new IllegalArgumentException("extraStreamIds must be non-negative"); - } + checkPositiveOrZero(extraStreamIds, "extraStreamIds"); this.extraStreamIds = extraStreamIds; return this; } diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2HeadersDecoder.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2HeadersDecoder.java index 684fef03528c..5d6320950ca3 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2HeadersDecoder.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2HeadersDecoder.java @@ -22,6 +22,7 @@ import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_HEADER_LIST_SIZE; import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_INITIAL_HUFFMAN_DECODE_CAPACITY; import static io.netty.handler.codec.http2.Http2Error.COMPRESSION_ERROR; +import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR; import static io.netty.handler.codec.http2.Http2Exception.connectionError; @UnstableApi @@ -31,6 +32,7 @@ public class DefaultHttp2HeadersDecoder implements Http2HeadersDecoder, Http2Hea private final HpackDecoder hpackDecoder; private final boolean validateHeaders; + private long maxHeaderListSizeGoAway; /** * Used to calculate an exponential moving average of header sizes to get an estimate of how large the data @@ -79,6 +81,8 @@ public DefaultHttp2HeadersDecoder(boolean validateHeaders, long maxHeaderListSiz DefaultHttp2HeadersDecoder(boolean validateHeaders, HpackDecoder hpackDecoder) { this.hpackDecoder = ObjectUtil.checkNotNull(hpackDecoder, "hpackDecoder"); this.validateHeaders = validateHeaders; + this.maxHeaderListSizeGoAway = + Http2CodecUtil.calculateMaxHeaderListSizeGoAway(hpackDecoder.getMaxHeaderListSize()); } @Override @@ -93,7 +97,12 @@ public long maxHeaderTableSize() { @Override public void maxHeaderListSize(long max, long goAwayMax) throws Http2Exception { - hpackDecoder.setMaxHeaderListSize(max, goAwayMax); + if (goAwayMax < max || goAwayMax < 0) { + throw connectionError(INTERNAL_ERROR, "Header List Size GO_AWAY %d must be non-negative and >= %d", + goAwayMax, max); + } + hpackDecoder.setMaxHeaderListSize(max); + this.maxHeaderListSizeGoAway = goAwayMax; } @Override @@ -103,7 +112,7 @@ public long maxHeaderListSize() { @Override public long maxHeaderListSizeGoAway() { - return hpackDecoder.getMaxHeaderListSizeGoAway(); + return maxHeaderListSizeGoAway; } @Override diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2LocalFlowController.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2LocalFlowController.java index 74dc3ae31c44..3dccea5056e4 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2LocalFlowController.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2LocalFlowController.java @@ -24,6 +24,7 @@ import static io.netty.handler.codec.http2.Http2Exception.connectionError; import static io.netty.handler.codec.http2.Http2Exception.streamError; import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; import static java.lang.Math.max; import static java.lang.Math.min; import io.netty.buffer.ByteBuf; @@ -173,9 +174,7 @@ public void incrementWindowSize(Http2Stream stream, int delta) throws Http2Excep @Override public boolean consumeBytes(Http2Stream stream, int numBytes) throws Http2Exception { assert ctx != null && ctx.executor().inEventLoop(); - if (numBytes < 0) { - throw new IllegalArgumentException("numBytes must not be negative"); - } + checkPositiveOrZero(numBytes, "numBytes"); if (numBytes == 0) { return false; } @@ -296,7 +295,7 @@ private static boolean isClosed(Http2Stream stream) { * received. */ private final class AutoRefillState extends DefaultState { - public AutoRefillState(Http2Stream stream, int initialWindowSize) { + AutoRefillState(Http2Stream stream, int initialWindowSize) { super(stream, initialWindowSize); } @@ -349,7 +348,7 @@ private class DefaultState implements FlowState { private int lowerBound; private boolean endOfStream; - public DefaultState(Http2Stream stream, int initialWindowSize) { + DefaultState(Http2Stream stream, int initialWindowSize) { this.stream = stream; window(initialWindowSize); streamWindowUpdateRatio = windowUpdateRatio; @@ -613,7 +612,7 @@ private final class WindowUpdateVisitor implements Http2StreamVisitor { private CompositeStreamException compositeException; private final int delta; - public WindowUpdateVisitor(int delta) { + WindowUpdateVisitor(int delta) { this.delta = delta; } diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2RemoteFlowController.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2RemoteFlowController.java index 217cf8dc251d..ef6ec986ee8f 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2RemoteFlowController.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2RemoteFlowController.java @@ -31,6 +31,7 @@ import static io.netty.handler.codec.http2.Http2Exception.streamError; import static io.netty.handler.codec.http2.Http2Stream.State.HALF_CLOSED_LOCAL; import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; import static java.lang.Math.max; import static java.lang.Math.min; @@ -635,9 +636,7 @@ final void writePendingBytes() throws Http2Exception { } void initialWindowSize(int newWindowSize) throws Http2Exception { - if (newWindowSize < 0) { - throw new IllegalArgumentException("Invalid initial window size: " + newWindowSize); - } + checkPositiveOrZero(newWindowSize, "newWindowSize"); final int delta = newWindowSize - initialWindowSize; initialWindowSize = newWindowSize; diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/DelegatingDecompressorFrameListener.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/DelegatingDecompressorFrameListener.java index 78ef230c62ff..3e73bd68dd86 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/DelegatingDecompressorFrameListener.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/DelegatingDecompressorFrameListener.java @@ -33,6 +33,7 @@ import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR; import static io.netty.handler.codec.http2.Http2Exception.streamError; import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; /** * A HTTP2 frame listener that will decompress data frames according to the {@code content-encoding} header for each @@ -398,9 +399,7 @@ void incrementDecompressedBytes(int delta) { * @return The number of pre-decompressed bytes that have been consumed. */ int consumeBytes(int streamId, int decompressedBytes) throws Http2Exception { - if (decompressedBytes < 0) { - throw new IllegalArgumentException("decompressedBytes must not be negative: " + decompressedBytes); - } + checkPositiveOrZero(decompressedBytes, "decompressedBytes"); if (decompressed - decompressedBytes < 0) { throw streamError(streamId, INTERNAL_ERROR, "Attempting to return too many bytes for stream %d. decompressed: %d " + diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/HpackDecoder.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/HpackDecoder.java index 67d6aa9944bf..389fc0483629 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/HpackDecoder.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/HpackDecoder.java @@ -42,9 +42,9 @@ import static io.netty.handler.codec.http2.Http2CodecUtil.MIN_HEADER_TABLE_SIZE; import static io.netty.handler.codec.http2.Http2CodecUtil.headerListSizeExceeded; import static io.netty.handler.codec.http2.Http2Error.COMPRESSION_ERROR; -import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR; import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; import static io.netty.handler.codec.http2.Http2Exception.connectionError; +import static io.netty.handler.codec.http2.Http2Exception.streamError; import static io.netty.handler.codec.http2.Http2Headers.PseudoHeaderName.getPseudoHeader; import static io.netty.handler.codec.http2.Http2Headers.PseudoHeaderName.hasPseudoHeaderFormat; import static io.netty.util.AsciiString.EMPTY_STRING; @@ -84,7 +84,6 @@ final class HpackDecoder { private final HpackDynamicTable hpackDynamicTable; private final HpackHuffmanDecoder hpackHuffmanDecoder; - private long maxHeaderListSizeGoAway; private long maxHeaderListSize; private long maxDynamicTableSize; private long encoderMaxDynamicTableSize; @@ -108,7 +107,6 @@ final class HpackDecoder { */ HpackDecoder(long maxHeaderListSize, int initialHuffmanDecodeCapacity, int maxHeaderTableSize) { this.maxHeaderListSize = checkPositive(maxHeaderListSize, "maxHeaderListSize"); - this.maxHeaderListSizeGoAway = Http2CodecUtil.calculateMaxHeaderListSizeGoAway(maxHeaderListSize); maxDynamicTableSize = encoderMaxDynamicTableSize = maxHeaderTableSize; maxDynamicTableSizeChangeRequired = false; @@ -122,14 +120,21 @@ final class HpackDecoder { * This method assumes the entire header block is contained in {@code in}. */ public void decode(int streamId, ByteBuf in, Http2Headers headers, boolean validateHeaders) throws Http2Exception { + Http2HeadersSink sink = new Http2HeadersSink(streamId, headers, maxHeaderListSize, validateHeaders); + decode(in, sink); + + // Now that we've read all of our headers we can perform the validation steps. We must + // delay throwing until this point to prevent dynamic table corruption. + sink.finish(); + } + + private void decode(ByteBuf in, Sink sink) throws Http2Exception { int index = 0; - long headersLength = 0; int nameLength = 0; int valueLength = 0; byte state = READ_HEADER_REPRESENTATION; boolean huffmanEncoded = false; CharSequence name = null; - HeaderType headerType = null; IndexType indexType = IndexType.NONE; while (in.isReadable()) { switch (state) { @@ -150,9 +155,7 @@ public void decode(int streamId, ByteBuf in, Http2Headers headers, boolean valid break; default: HpackHeaderField indexedHeader = getIndexedHeader(index); - headerType = validate(indexedHeader.name, headerType, validateHeaders); - headersLength = addHeader(headers, indexedHeader.name, indexedHeader.value, - headersLength); + sink.appendToHeaderList(indexedHeader.name, indexedHeader.value); } } else if ((b & 0x40) == 0x40) { // Literal Header Field with Incremental Indexing @@ -168,7 +171,6 @@ public void decode(int streamId, ByteBuf in, Http2Headers headers, boolean valid default: // Index was stored as the prefix name = readName(index); - headerType = validate(name, headerType, validateHeaders); nameLength = name.length(); state = READ_LITERAL_HEADER_VALUE_LENGTH_PREFIX; } @@ -193,11 +195,10 @@ public void decode(int streamId, ByteBuf in, Http2Headers headers, boolean valid state = READ_INDEXED_HEADER_NAME; break; default: - // Index was stored as the prefix - name = readName(index); - headerType = validate(name, headerType, validateHeaders); - nameLength = name.length(); - state = READ_LITERAL_HEADER_VALUE_LENGTH_PREFIX; + // Index was stored as the prefix + name = readName(index); + nameLength = name.length(); + state = READ_LITERAL_HEADER_VALUE_LENGTH_PREFIX; } } break; @@ -209,15 +210,13 @@ public void decode(int streamId, ByteBuf in, Http2Headers headers, boolean valid case READ_INDEXED_HEADER: HpackHeaderField indexedHeader = getIndexedHeader(decodeULE128(in, index)); - headerType = validate(indexedHeader.name, headerType, validateHeaders); - headersLength = addHeader(headers, indexedHeader.name, indexedHeader.value, headersLength); + sink.appendToHeaderList(indexedHeader.name, indexedHeader.value); state = READ_HEADER_REPRESENTATION; break; case READ_INDEXED_HEADER_NAME: // Header Name matches an entry in the Header Table name = readName(decodeULE128(in, index)); - headerType = validate(name, headerType, validateHeaders); nameLength = name.length(); state = READ_LITERAL_HEADER_VALUE_LENGTH_PREFIX; break; @@ -229,9 +228,6 @@ public void decode(int streamId, ByteBuf in, Http2Headers headers, boolean valid if (index == 0x7f) { state = READ_LITERAL_HEADER_NAME_LENGTH; } else { - if (index > maxHeaderListSizeGoAway - headersLength) { - headerListSizeExceeded(maxHeaderListSizeGoAway); - } nameLength = index; state = READ_LITERAL_HEADER_NAME; } @@ -241,9 +237,6 @@ public void decode(int streamId, ByteBuf in, Http2Headers headers, boolean valid // Header Name is a Literal String nameLength = decodeULE128(in, index); - if (nameLength > maxHeaderListSizeGoAway - headersLength) { - headerListSizeExceeded(maxHeaderListSizeGoAway); - } state = READ_LITERAL_HEADER_NAME; break; @@ -254,7 +247,6 @@ public void decode(int streamId, ByteBuf in, Http2Headers headers, boolean valid } name = readStringLiteral(in, nameLength, huffmanEncoded); - headerType = validate(name, headerType, validateHeaders); state = READ_LITERAL_HEADER_VALUE_LENGTH_PREFIX; break; @@ -268,15 +260,10 @@ public void decode(int streamId, ByteBuf in, Http2Headers headers, boolean valid state = READ_LITERAL_HEADER_VALUE_LENGTH; break; case 0: - headerType = validate(name, headerType, validateHeaders); - headersLength = insertHeader(headers, name, EMPTY_STRING, indexType, headersLength); + insertHeader(sink, name, EMPTY_STRING, indexType); state = READ_HEADER_REPRESENTATION; break; default: - // Check new header size against max header size - if ((long) index + nameLength > maxHeaderListSizeGoAway - headersLength) { - headerListSizeExceeded(maxHeaderListSizeGoAway); - } valueLength = index; state = READ_LITERAL_HEADER_VALUE; } @@ -287,10 +274,6 @@ public void decode(int streamId, ByteBuf in, Http2Headers headers, boolean valid // Header Value is a Literal String valueLength = decodeULE128(in, index); - // Check new header size against max header size - if ((long) valueLength + nameLength > maxHeaderListSizeGoAway - headersLength) { - headerListSizeExceeded(maxHeaderListSizeGoAway); - } state = READ_LITERAL_HEADER_VALUE; break; @@ -301,8 +284,7 @@ public void decode(int streamId, ByteBuf in, Http2Headers headers, boolean valid } CharSequence value = readStringLiteral(in, valueLength, huffmanEncoded); - headerType = validate(name, headerType, validateHeaders); - headersLength = insertHeader(headers, name, value, indexType, headersLength); + insertHeader(sink, name, value, indexType); state = READ_HEADER_REPRESENTATION; break; @@ -311,13 +293,6 @@ public void decode(int streamId, ByteBuf in, Http2Headers headers, boolean valid } } - // we have read all of our headers, and not exceeded maxHeaderListSizeGoAway see if we have - // exceeded our actual maxHeaderListSize. This must be done here to prevent dynamic table - // corruption - if (headersLength > maxHeaderListSize) { - headerListSizeExceeded(streamId, maxHeaderListSize, true); - } - if (state != READ_HEADER_REPRESENTATION) { throw connectionError(COMPRESSION_ERROR, "Incomplete header block fragment."); } @@ -341,27 +316,27 @@ public void setMaxHeaderTableSize(long maxHeaderTableSize) throws Http2Exception } } + /** + * @deprecated use {@link #setMaxHeaderListSize(long)}; {@code maxHeaderListSizeGoAway} is + * ignored + */ + @Deprecated public void setMaxHeaderListSize(long maxHeaderListSize, long maxHeaderListSizeGoAway) throws Http2Exception { - if (maxHeaderListSizeGoAway < maxHeaderListSize || maxHeaderListSizeGoAway < 0) { - throw connectionError(INTERNAL_ERROR, "Header List Size GO_AWAY %d must be positive and >= %d", - maxHeaderListSizeGoAway, maxHeaderListSize); - } + setMaxHeaderListSize(maxHeaderListSize); + } + + public void setMaxHeaderListSize(long maxHeaderListSize) throws Http2Exception { if (maxHeaderListSize < MIN_HEADER_LIST_SIZE || maxHeaderListSize > MAX_HEADER_LIST_SIZE) { throw connectionError(PROTOCOL_ERROR, "Header List Size must be >= %d and <= %d but was %d", MIN_HEADER_TABLE_SIZE, MAX_HEADER_TABLE_SIZE, maxHeaderListSize); } this.maxHeaderListSize = maxHeaderListSize; - this.maxHeaderListSizeGoAway = maxHeaderListSizeGoAway; } public long getMaxHeaderListSize() { return maxHeaderListSize; } - public long getMaxHeaderListSizeGoAway() { - return maxHeaderListSizeGoAway; - } - /** * Return the maximum table size. This is the maximum size allowed by both the encoder and the * decoder. @@ -400,26 +375,23 @@ private void setDynamicTableSize(long dynamicTableSize) throws Http2Exception { hpackDynamicTable.setCapacity(dynamicTableSize); } - private HeaderType validate(CharSequence name, HeaderType previousHeaderType, - final boolean validateHeaders) throws Http2Exception { - if (!validateHeaders) { - return null; - } - + private static HeaderType validate(int streamId, CharSequence name, + HeaderType previousHeaderType) throws Http2Exception { if (hasPseudoHeaderFormat(name)) { if (previousHeaderType == HeaderType.REGULAR_HEADER) { - throw connectionError(PROTOCOL_ERROR, "Pseudo-header field '%s' found after regular header.", name); + throw streamError(streamId, PROTOCOL_ERROR, + "Pseudo-header field '%s' found after regular header.", name); } final Http2Headers.PseudoHeaderName pseudoHeader = getPseudoHeader(name); if (pseudoHeader == null) { - throw connectionError(PROTOCOL_ERROR, "Invalid HTTP/2 pseudo-header '%s' encountered.", name); + throw streamError(streamId, PROTOCOL_ERROR, "Invalid HTTP/2 pseudo-header '%s' encountered.", name); } final HeaderType currentHeaderType = pseudoHeader.isRequestOnly() ? HeaderType.REQUEST_PSEUDO_HEADER : HeaderType.RESPONSE_PSEUDO_HEADER; if (previousHeaderType != null && currentHeaderType != previousHeaderType) { - throw connectionError(PROTOCOL_ERROR, "Mix of request and response pseudo-headers."); + throw streamError(streamId, PROTOCOL_ERROR, "Mix of request and response pseudo-headers."); } return currentHeaderType; @@ -450,9 +422,8 @@ private HpackHeaderField getIndexedHeader(int index) throws Http2Exception { throw INDEX_HEADER_ILLEGAL_INDEX_VALUE; } - private long insertHeader(Http2Headers headers, CharSequence name, CharSequence value, - IndexType indexType, long headerSize) throws Http2Exception { - headerSize = addHeader(headers, name, value, headerSize); + private void insertHeader(Sink sink, CharSequence name, CharSequence value, IndexType indexType) { + sink.appendToHeaderList(name, value); switch (indexType) { case NONE: @@ -466,18 +437,6 @@ private long insertHeader(Http2Headers headers, CharSequence name, CharSequence default: throw new Error("should not reach here"); } - - return headerSize; - } - - private long addHeader(Http2Headers headers, CharSequence name, CharSequence value, long headersLength) - throws Http2Exception { - headersLength += HpackHeaderField.sizeOf(name, value); - if (headersLength > maxHeaderListSizeGoAway) { - headerListSizeExceeded(maxHeaderListSizeGoAway); - } - headers.add(name, value); - return headersLength; } private CharSequence readStringLiteral(ByteBuf in, int length, boolean huffmanEncoded) throws Http2Exception { @@ -553,4 +512,58 @@ private enum HeaderType { REQUEST_PSEUDO_HEADER, RESPONSE_PSEUDO_HEADER } + + private interface Sink { + void appendToHeaderList(CharSequence name, CharSequence value); + void finish() throws Http2Exception; + } + + private static final class Http2HeadersSink implements Sink { + private final Http2Headers headers; + private final long maxHeaderListSize; + private final int streamId; + private final boolean validate; + private long headersLength; + private boolean exceededMaxLength; + private HeaderType previousType; + private Http2Exception validationException; + + Http2HeadersSink(int streamId, Http2Headers headers, long maxHeaderListSize, boolean validate) { + this.headers = headers; + this.maxHeaderListSize = maxHeaderListSize; + this.streamId = streamId; + this.validate = validate; + } + + @Override + public void finish() throws Http2Exception { + if (exceededMaxLength) { + headerListSizeExceeded(streamId, maxHeaderListSize, true); + } else if (validationException != null) { + throw validationException; + } + } + + @Override + public void appendToHeaderList(CharSequence name, CharSequence value) { + headersLength += HpackHeaderField.sizeOf(name, value); + exceededMaxLength |= headersLength > maxHeaderListSize; + + if (exceededMaxLength || validationException != null) { + // We don't store the header since we've already failed validation requirements. + return; + } + + if (validate) { + try { + previousType = HpackDecoder.validate(streamId, name, previousType); + } catch (Http2Exception ex) { + validationException = ex; + return; + } + } + + headers.add(name, value); + } + } } diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/HpackEncoder.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/HpackEncoder.java index 7719072a499c..301a2c51cfb8 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/HpackEncoder.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/HpackEncoder.java @@ -75,14 +75,14 @@ final class HpackEncoder { /** * Creates a new encoder. */ - public HpackEncoder(boolean ignoreMaxHeaderListSize) { + HpackEncoder(boolean ignoreMaxHeaderListSize) { this(ignoreMaxHeaderListSize, 16); } /** * Creates a new encoder. */ - public HpackEncoder(boolean ignoreMaxHeaderListSize, int arraySizeHint) { + HpackEncoder(boolean ignoreMaxHeaderListSize, int arraySizeHint) { this.ignoreMaxHeaderListSize = ignoreMaxHeaderListSize; maxHeaderTableSize = DEFAULT_HEADER_TABLE_SIZE; maxHeaderListSize = MAX_HEADER_LIST_SIZE; diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2CodecUtil.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2CodecUtil.java index 317ee48063d8..7ebc8fd48888 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2CodecUtil.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2CodecUtil.java @@ -67,9 +67,6 @@ public final class Http2CodecUtil { private static final ByteBuf CONNECTION_PREFACE = unreleasableBuffer(directBuffer(24).writeBytes("PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n".getBytes(UTF_8))) .asReadOnly(); - private static final ByteBuf EMPTY_PING = - unreleasableBuffer(directBuffer(PING_FRAME_PAYLOAD_LENGTH).writeZero(PING_FRAME_PAYLOAD_LENGTH)) - .asReadOnly(); private static final int MAX_PADDING_LENGTH_LENGTH = 1; public static final int DATA_FRAME_HEADER_LENGTH = FRAME_HEADER_LENGTH + MAX_PADDING_LENGTH_LENGTH; @@ -169,14 +166,6 @@ public static ByteBuf connectionPrefaceBuf() { return CONNECTION_PREFACE.retainedDuplicate(); } - /** - * Returns a buffer filled with all zeros that is the appropriate length for a PING frame. - */ - public static ByteBuf emptyPingBuf() { - // Return a duplicate so that modifications to the reader index will not affect the original buffer. - return EMPTY_PING.retainedDuplicate(); - } - /** * Iteratively looks through the causality chain for the given exception and returns the first * {@link Http2Exception} or {@code null} if none. diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Connection.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Connection.java index f43820bf2392..7639e0840634 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Connection.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Connection.java @@ -326,8 +326,15 @@ interface PropertyKey { /** * Indicates that a {@code GOAWAY} was received from the remote endpoint and sets the last known stream. + * @param lastKnownStream The Last-Stream-ID in the + * GOAWAY frame. + * @param errorCode the Error Code in the + * GOAWAY frame. + * @param message The Additional Debug Data in the + * GOAWAY frame. Note that reference count ownership + * belongs to the caller (ownership is not transferred to this method). */ - void goAwayReceived(int lastKnownStream, long errorCode, ByteBuf message); + void goAwayReceived(int lastKnownStream, long errorCode, ByteBuf message) throws Http2Exception; /** * Indicates whether or not a {@code GOAWAY} was sent to the remote endpoint. @@ -335,7 +342,15 @@ interface PropertyKey { boolean goAwaySent(); /** - * Indicates that a {@code GOAWAY} was sent to the remote endpoint and sets the last known stream. + * Updates the local state of this {@link Http2Connection} as a result of a {@code GOAWAY} to send to the remote + * endpoint. + * @param lastKnownStream The Last-Stream-ID in the + * GOAWAY frame. + * @param errorCode the Error Code in the + * GOAWAY frame. + * GOAWAY frame. Note that reference count ownership + * belongs to the caller (ownership is not transferred to this method). + * @return {@code true} if the corresponding {@code GOAWAY} frame should be sent to the remote endpoint. */ - void goAwaySent(int lastKnownStream, long errorCode, ByteBuf message); + boolean goAwaySent(int lastKnownStream, long errorCode, ByteBuf message) throws Http2Exception; } diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandler.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandler.java index 29f9def87e2c..618a4a6771bd 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandler.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandler.java @@ -233,7 +233,7 @@ private final class PrefaceDecoder extends BaseDecoder { private ByteBuf clientPrefaceString; private boolean prefaceSent; - public PrefaceDecoder(ChannelHandlerContext ctx) throws Exception { + PrefaceDecoder(ChannelHandlerContext ctx) throws Exception { clientPrefaceString = clientPrefaceString(encoder.connection()); // This handler was just added to the context. In case it was handled after // the connection became active, send the connection preface now. @@ -527,8 +527,18 @@ public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { } } - void channelReadComplete0(ChannelHandlerContext ctx) throws Exception { - super.channelReadComplete(ctx); + final void channelReadComplete0(ChannelHandlerContext ctx) { + // Discard bytes of the cumulation buffer if needed. + discardSomeReadBytes(); + + // Ensure we never stale the HTTP/2 Channel. Flow-control is enforced by HTTP/2. + // + // See https://tools.ietf.org/html/rfc7540#section-5.2.2 + if (!ctx.channel().config().isAutoRead()) { + ctx.read(); + } + + ctx.fireChannelReadComplete(); } /** @@ -701,7 +711,9 @@ protected void onStreamError(ChannelHandlerContext ctx, boolean outbound, } if (stream == null) { - resetUnknownStream(ctx, streamId, http2Ex.error().code(), ctx.newPromise()); + if (!outbound || connection().local().mayHaveCreatedStream(streamId)) { + resetUnknownStream(ctx, streamId, http2Ex.error().code(), ctx.newPromise()); + } } else { resetStream(ctx, stream, http2Ex.error().code(), ctx.newPromise()); } @@ -792,47 +804,37 @@ public void operationComplete(ChannelFuture future) throws Exception { @Override public ChannelFuture goAway(final ChannelHandlerContext ctx, final int lastStreamId, final long errorCode, final ByteBuf debugData, ChannelPromise promise) { + promise = promise.unvoid(); + final Http2Connection connection = connection(); try { - promise = promise.unvoid(); - final Http2Connection connection = connection(); - if (connection().goAwaySent()) { - // Protect against re-entrancy. Could happen if writing the frame fails, and error handling - // treating this is a connection handler and doing a graceful shutdown... - if (lastStreamId == connection().remote().lastStreamKnownByPeer()) { - // Release the data and notify the promise - debugData.release(); - return promise.setSuccess(); - } - if (lastStreamId > connection.remote().lastStreamKnownByPeer()) { - throw connectionError(PROTOCOL_ERROR, "Last stream identifier must not increase between " + - "sending multiple GOAWAY frames (was '%d', is '%d').", - connection.remote().lastStreamKnownByPeer(), lastStreamId); - } + if (!connection.goAwaySent(lastStreamId, errorCode, debugData)) { + debugData.release(); + promise.trySuccess(); + return promise; } + } catch (Throwable cause) { + debugData.release(); + promise.tryFailure(cause); + return promise; + } - connection.goAwaySent(lastStreamId, errorCode, debugData); - - // Need to retain before we write the buffer because if we do it after the refCnt could already be 0 and - // result in an IllegalRefCountException. - debugData.retain(); - ChannelFuture future = frameWriter().writeGoAway(ctx, lastStreamId, errorCode, debugData, promise); + // Need to retain before we write the buffer because if we do it after the refCnt could already be 0 and + // result in an IllegalRefCountException. + debugData.retain(); + ChannelFuture future = frameWriter().writeGoAway(ctx, lastStreamId, errorCode, debugData, promise); - if (future.isDone()) { - processGoAwayWriteResult(ctx, lastStreamId, errorCode, debugData, future); - } else { - future.addListener(new ChannelFutureListener() { - @Override - public void operationComplete(ChannelFuture future) throws Exception { - processGoAwayWriteResult(ctx, lastStreamId, errorCode, debugData, future); - } - }); - } - - return future; - } catch (Throwable cause) { // Make sure to catch Throwable because we are doing a retain() in this method. - debugData.release(); - return promise.setFailure(cause); + if (future.isDone()) { + processGoAwayWriteResult(ctx, lastStreamId, errorCode, debugData, future); + } else { + future.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + processGoAwayWriteResult(ctx, lastStreamId, errorCode, debugData, future); + } + }); } + + return future; } /** diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2DataFrame.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2DataFrame.java index a0d55c487984..2c99e2c287a4 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2DataFrame.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2DataFrame.java @@ -37,7 +37,7 @@ public interface Http2DataFrame extends Http2StreamFrame, ByteBufHolder { ByteBuf content(); /** - * Returns the number of bytes that are flow-controlled initialy, so even if the {@link #content()} is consumed + * Returns the number of bytes that are flow-controlled initially, so even if the {@link #content()} is consumed * this will not change. */ int initialFlowControlledBytes(); diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Exception.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Exception.java index 258f871eb840..c41600109322 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Exception.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Exception.java @@ -194,7 +194,7 @@ public static int streamId(Http2Exception e) { /** * Provides a hint as to if shutdown is justified, what type of shutdown should be executed. */ - public static enum ShutdownHint { + public enum ShutdownHint { /** * Do not shutdown the underlying channel. */ diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameCodec.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameCodec.java index da2b27a1d7bd..cf756cab25c7 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameCodec.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameCodec.java @@ -29,6 +29,8 @@ import io.netty.handler.codec.http2.StreamBufferingEncoder.Http2GoAwayException; import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCounted; +import io.netty.util.collection.IntObjectHashMap; +import io.netty.util.collection.IntObjectMap; import io.netty.util.internal.UnstableApi; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -143,16 +145,17 @@ public class Http2FrameCodec extends Http2ConnectionHandler { private static final InternalLogger LOG = InternalLoggerFactory.getInstance(Http2FrameCodec.class); - private final PropertyKey streamKey; + protected final PropertyKey streamKey; private final PropertyKey upgradeKey; private final Integer initialFlowControlWindowSize; - private ChannelHandlerContext ctx; + ChannelHandlerContext ctx; /** Number of buffered streams if the {@link StreamBufferingEncoder} is used. **/ private int numBufferedStreams; - private DefaultHttp2FrameStream frameStreamToInitialize; + private final IntObjectMap frameStreamToInitializeMap = + new IntObjectHashMap(8); Http2FrameCodec(Http2ConnectionEncoder encoder, Http2ConnectionDecoder decoder, Http2Settings initialSettings) { super(decoder, encoder, initialSettings); @@ -358,12 +361,17 @@ private void writeHeadersFrame( } stream.id = streamId; - // TODO: This depends on the fact that the connection based API will create Http2Stream objects - // synchronously. We should investigate how to refactor this later on when we consolidate some layers. - assert frameStreamToInitialize == null; - frameStreamToInitialize = stream; + // Use a Map to store all pending streams as we may have multiple. This is needed as if we would store the + // stream in a field directly we may override the stored field before onStreamAdded(...) was called + // and so not correctly set the property for the buffered stream. + // + // See https://github.com/netty/netty/issues/8692 + Object old = frameStreamToInitializeMap.put(streamId, stream); + + // We should not re-use ids. + assert old == null; - // TODO(buchgr): Once Http2Stream2 and Http2Stream are merged this is no longer necessary. + // TODO(buchgr): Once Http2FrameStream and Http2Stream are merged this is no longer necessary. final ChannelPromise writePromise = ctx.newPromise(); encoder().writeHeaders(ctx, streamId, headersFrame.headers(), headersFrame.padding(), @@ -399,7 +407,7 @@ private void onStreamActive0(Http2Stream stream) { return; } - DefaultHttp2FrameStream stream2 = newStream().setStreamAndProperty(streamKey, stream); + Http2FrameStream stream2 = newStream().setStreamAndProperty(streamKey, stream); onHttp2StreamStateChanged(ctx, stream2); } @@ -407,9 +415,10 @@ private final class ConnectionListener extends Http2ConnectionAdapter { @Override public void onStreamAdded(Http2Stream stream) { - if (frameStreamToInitialize != null && stream.id() == frameStreamToInitialize.id()) { - frameStreamToInitialize.setStreamAndProperty(streamKey, stream); - frameStreamToInitialize = null; + DefaultHttp2FrameStream frameStream = frameStreamToInitializeMap.remove(stream.id()); + + if (frameStream != null) { + frameStream.setStreamAndProperty(streamKey, stream); } } @@ -420,7 +429,7 @@ public void onStreamActive(Http2Stream stream) { @Override public void onStreamClosed(Http2Stream stream) { - DefaultHttp2FrameStream stream2 = stream.getProperty(streamKey); + Http2FrameStream stream2 = stream.getProperty(streamKey); if (stream2 != null) { onHttp2StreamStateChanged(ctx, stream2); } @@ -428,7 +437,7 @@ public void onStreamClosed(Http2Stream stream) { @Override public void onStreamHalfClosed(Http2Stream stream) { - DefaultHttp2FrameStream stream2 = stream.getProperty(streamKey); + Http2FrameStream stream2 = stream.getProperty(streamKey); if (stream2 != null) { onHttp2StreamStateChanged(ctx, stream2); } diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameLogger.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameLogger.java index 8a67e5f87ad0..791e991cb841 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameLogger.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameLogger.java @@ -20,7 +20,6 @@ import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.logging.LogLevel; -import io.netty.util.internal.StringUtil; import io.netty.util.internal.UnstableApi; import io.netty.util.internal.logging.InternalLogLevel; import io.netty.util.internal.logging.InternalLogger; @@ -60,34 +59,48 @@ private Http2FrameLogger(InternalLogLevel level, InternalLogger logger) { this.logger = checkNotNull(logger, "logger"); } + public boolean isEnabled() { + return logger.isEnabled(level); + } + public void logData(Direction direction, ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, boolean endStream) { - logger.log(level, "{} {} DATA: streamId={} padding={} endStream={} length={} bytes={}", ctx.channel(), - direction.name(), streamId, padding, endStream, data.readableBytes(), toString(data)); + if (isEnabled()) { + logger.log(level, "{} {} DATA: streamId={} padding={} endStream={} length={} bytes={}", ctx.channel(), + direction.name(), streamId, padding, endStream, data.readableBytes(), toString(data)); + } } public void logHeaders(Direction direction, ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding, boolean endStream) { - logger.log(level, "{} {} HEADERS: streamId={} headers={} padding={} endStream={}", ctx.channel(), - direction.name(), streamId, headers, padding, endStream); + if (isEnabled()) { + logger.log(level, "{} {} HEADERS: streamId={} headers={} padding={} endStream={}", ctx.channel(), + direction.name(), streamId, headers, padding, endStream); + } } public void logHeaders(Direction direction, ChannelHandlerContext ctx, int streamId, Http2Headers headers, int streamDependency, short weight, boolean exclusive, int padding, boolean endStream) { - logger.log(level, "{} {} HEADERS: streamId={} headers={} streamDependency={} weight={} exclusive={} " + - "padding={} endStream={}", ctx.channel(), - direction.name(), streamId, headers, streamDependency, weight, exclusive, padding, endStream); + if (isEnabled()) { + logger.log(level, "{} {} HEADERS: streamId={} headers={} streamDependency={} weight={} exclusive={} " + + "padding={} endStream={}", ctx.channel(), + direction.name(), streamId, headers, streamDependency, weight, exclusive, padding, endStream); + } } public void logPriority(Direction direction, ChannelHandlerContext ctx, int streamId, int streamDependency, short weight, boolean exclusive) { - logger.log(level, "{} {} PRIORITY: streamId={} streamDependency={} weight={} exclusive={}", ctx.channel(), - direction.name(), streamId, streamDependency, weight, exclusive); + if (isEnabled()) { + logger.log(level, "{} {} PRIORITY: streamId={} streamDependency={} weight={} exclusive={}", ctx.channel(), + direction.name(), streamId, streamDependency, weight, exclusive); + } } public void logRstStream(Direction direction, ChannelHandlerContext ctx, int streamId, long errorCode) { - logger.log(level, "{} {} RST_STREAM: streamId={} errorCode={}", ctx.channel(), - direction.name(), streamId, errorCode); + if (isEnabled()) { + logger.log(level, "{} {} RST_STREAM: streamId={} errorCode={}", ctx.channel(), + direction.name(), streamId, errorCode); + } } public void logSettingsAck(Direction direction, ChannelHandlerContext ctx) { @@ -95,48 +108,58 @@ public void logSettingsAck(Direction direction, ChannelHandlerContext ctx) { } public void logSettings(Direction direction, ChannelHandlerContext ctx, Http2Settings settings) { - logger.log(level, "{} {} SETTINGS: ack=false settings={}", ctx.channel(), direction.name(), settings); + if (isEnabled()) { + logger.log(level, "{} {} SETTINGS: ack=false settings={}", ctx.channel(), direction.name(), settings); + } } public void logPing(Direction direction, ChannelHandlerContext ctx, long data) { - logger.log(level, "{} {} PING: ack=false bytes={}", ctx.channel(), - direction.name(), data); + if (isEnabled()) { + logger.log(level, "{} {} PING: ack=false bytes={}", ctx.channel(), + direction.name(), data); + } } public void logPingAck(Direction direction, ChannelHandlerContext ctx, long data) { - logger.log(level, "{} {} PING: ack=true bytes={}", ctx.channel(), - direction.name(), data); + if (isEnabled()) { + logger.log(level, "{} {} PING: ack=true bytes={}", ctx.channel(), + direction.name(), data); + } } public void logPushPromise(Direction direction, ChannelHandlerContext ctx, int streamId, int promisedStreamId, Http2Headers headers, int padding) { - logger.log(level, "{} {} PUSH_PROMISE: streamId={} promisedStreamId={} headers={} padding={}", ctx.channel(), - direction.name(), streamId, promisedStreamId, headers, padding); + if (isEnabled()) { + logger.log(level, "{} {} PUSH_PROMISE: streamId={} promisedStreamId={} headers={} padding={}", + ctx.channel(), direction.name(), streamId, promisedStreamId, headers, padding); + } } public void logGoAway(Direction direction, ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData) { - logger.log(level, "{} {} GO_AWAY: lastStreamId={} errorCode={} length={} bytes={}", ctx.channel(), - direction.name(), lastStreamId, errorCode, debugData.readableBytes(), toString(debugData)); + if (isEnabled()) { + logger.log(level, "{} {} GO_AWAY: lastStreamId={} errorCode={} length={} bytes={}", ctx.channel(), + direction.name(), lastStreamId, errorCode, debugData.readableBytes(), toString(debugData)); + } } public void logWindowsUpdate(Direction direction, ChannelHandlerContext ctx, int streamId, int windowSizeIncrement) { - logger.log(level, "{} {} WINDOW_UPDATE: streamId={} windowSizeIncrement={}", ctx.channel(), - direction.name(), streamId, windowSizeIncrement); + if (isEnabled()) { + logger.log(level, "{} {} WINDOW_UPDATE: streamId={} windowSizeIncrement={}", ctx.channel(), + direction.name(), streamId, windowSizeIncrement); + } } public void logUnknownFrame(Direction direction, ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, ByteBuf data) { - logger.log(level, "{} {} UNKNOWN: frameType={} streamId={} flags={} length={} bytes={}", ctx.channel(), - direction.name(), frameType & 0xFF, streamId, flags.value(), data.readableBytes(), toString(data)); + if (isEnabled()) { + logger.log(level, "{} {} UNKNOWN: frameType={} streamId={} flags={} length={} bytes={}", ctx.channel(), + direction.name(), frameType & 0xFF, streamId, flags.value(), data.readableBytes(), toString(data)); + } } private String toString(ByteBuf buf) { - if (!logger.isEnabled(level)) { - return StringUtil.EMPTY_STRING; - } - if (level == InternalLogLevel.TRACE || buf.readableBytes() <= BUFFER_LENGTH_THRESHOLD) { // Log the entire buffer. return ByteBufUtil.hexDump(buf); diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexCodec.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexCodec.java index dd5b1c5d1b9c..cec2dd94a848 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexCodec.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexCodec.java @@ -31,10 +31,10 @@ import io.netty.channel.ChannelPromise; import io.netty.channel.DefaultChannelConfig; import io.netty.channel.DefaultChannelPipeline; -import io.netty.channel.DefaultMaxMessagesRecvByteBufAllocator; import io.netty.channel.EventLoop; import io.netty.channel.MessageSizeEstimator; import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.RecvByteBufAllocator.Handle; import io.netty.channel.VoidChannelPromise; import io.netty.channel.WriteBufferWaterMark; import io.netty.util.DefaultAttributeMap; @@ -43,13 +43,19 @@ import io.netty.util.internal.StringUtil; import io.netty.util.internal.ThrowableUtil; import io.netty.util.internal.UnstableApi; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; import java.net.SocketAddress; import java.nio.channels.ClosedChannelException; import java.util.ArrayDeque; import java.util.Queue; +import java.util.concurrent.RejectedExecutionException; +import static io.netty.handler.codec.http2.Http2CodecUtil.HTTP_UPGRADE_STREAM_ID; import static io.netty.handler.codec.http2.Http2CodecUtil.isStreamIdValid; +import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR; +import static io.netty.handler.codec.http2.Http2Exception.connectionError; import static java.lang.Math.min; /** @@ -100,9 +106,11 @@ @UnstableApi public class Http2MultiplexCodec extends Http2FrameCodec { + private static final InternalLogger logger = InternalLoggerFactory.getInstance(DefaultHttp2StreamChannel.class); + private static final ChannelFutureListener CHILD_CHANNEL_REGISTRATION_LISTENER = new ChannelFutureListener() { @Override - public void operationComplete(ChannelFuture future) throws Exception { + public void operationComplete(ChannelFuture future) { registerDone(future); } }; @@ -139,20 +147,8 @@ public Handle newHandle() { } } - private static final class Http2StreamChannelRecvByteBufAllocator extends DefaultMaxMessagesRecvByteBufAllocator { - - @Override - public MaxMessageHandle newHandle() { - return new MaxMessageHandle() { - @Override - public int guess() { - return 1024; - } - }; - } - } - private final ChannelHandler inboundStreamHandler; + private final ChannelHandler upgradeStreamHandler; private int initialOutboundStreamWindow = Http2CodecUtil.DEFAULT_WINDOW_SIZE; private boolean parentReadInProgress; @@ -168,9 +164,25 @@ public int guess() { Http2MultiplexCodec(Http2ConnectionEncoder encoder, Http2ConnectionDecoder decoder, Http2Settings initialSettings, - ChannelHandler inboundStreamHandler) { + ChannelHandler inboundStreamHandler, + ChannelHandler upgradeStreamHandler) { super(encoder, decoder, initialSettings); this.inboundStreamHandler = inboundStreamHandler; + this.upgradeStreamHandler = upgradeStreamHandler; + } + + @Override + public void onHttpClientUpgrade() throws Http2Exception { + // We must have an upgrade handler or else we can't handle the stream + if (upgradeStreamHandler == null) { + throw connectionError(INTERNAL_ERROR, "Client is misconfigured for upgrade requests"); + } + // Creates the Http2Stream in the Connection. + super.onHttpClientUpgrade(); + // Now make a new FrameStream, set it's underlying Http2Stream, and initialize it. + Http2MultiplexCodecStream codecStream = newStream(); + codecStream.setStreamAndProperty(streamKey, connection().stream(HTTP_UPGRADE_STREAM_ID)); + onHttp2UpgradeStreamInitialized(ctx, codecStream); } private static void registerDone(ChannelFuture future) { @@ -204,7 +216,7 @@ public final void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { while (ch != null) { DefaultHttp2StreamChannel curr = ch; ch = curr.next; - curr.next = null; + curr.next = curr.previous = null; } head = tail = null; } @@ -218,7 +230,7 @@ Http2MultiplexCodecStream newStream() { final void onHttp2Frame(ChannelHandlerContext ctx, Http2Frame frame) { if (frame instanceof Http2StreamFrame) { Http2StreamFrame streamFrame = (Http2StreamFrame) frame; - onHttp2StreamFrame(((Http2MultiplexCodecStream) streamFrame.stream()).channel, streamFrame); + ((Http2MultiplexCodecStream) streamFrame.stream()).channel.fireChildRead(streamFrame); } else if (frame instanceof Http2GoAwayFrame) { onHttp2GoAwayFrame(ctx, (Http2GoAwayFrame) frame); // Allow other handlers to act on GOAWAY frame @@ -236,6 +248,22 @@ final void onHttp2Frame(ChannelHandlerContext ctx, Http2Frame frame) { } } + private void onHttp2UpgradeStreamInitialized(ChannelHandlerContext ctx, Http2MultiplexCodecStream stream) { + assert stream.state() == Http2Stream.State.HALF_CLOSED_LOCAL; + DefaultHttp2StreamChannel ch = new DefaultHttp2StreamChannel(stream, true); + ch.outboundClosed = true; + + // Add our upgrade handler to the channel and then register the channel. + // The register call fires the channelActive, etc. + ch.pipeline().addLast(upgradeStreamHandler); + ChannelFuture future = ctx.channel().eventLoop().register(ch); + if (future.isDone()) { + registerDone(future); + } else { + future.addListener(CHILD_CHANNEL_REGISTRATION_LISTENER); + } + } + @Override final void onHttp2StreamStateChanged(ChannelHandlerContext ctx, Http2FrameStream stream) { Http2MultiplexCodecStream s = (Http2MultiplexCodecStream) stream; @@ -289,38 +317,48 @@ final void onHttp2FrameStreamException(ChannelHandlerContext ctx, Http2FrameStre } } - private void onHttp2StreamFrame(DefaultHttp2StreamChannel childChannel, Http2StreamFrame frame) { - switch (childChannel.fireChildRead(frame)) { - case READ_PROCESSED_BUT_STOP_READING: - childChannel.fireChildReadComplete(); - break; - case READ_PROCESSED_OK_TO_PROCESS_MORE: - addChildChannelToReadPendingQueue(childChannel); - break; - case READ_IGNORED_CHANNEL_INACTIVE: - case READ_QUEUED: - // nothing to do: - break; - default: - throw new Error(); + private boolean isChildChannelInReadPendingQueue(DefaultHttp2StreamChannel childChannel) { + return childChannel.previous != null || childChannel.next != null || head == childChannel; + } + + final void tryAddChildChannelToReadPendingQueue(DefaultHttp2StreamChannel childChannel) { + if (!isChildChannelInReadPendingQueue(childChannel)) { + addChildChannelToReadPendingQueue(childChannel); } } final void addChildChannelToReadPendingQueue(DefaultHttp2StreamChannel childChannel) { - if (!childChannel.fireChannelReadPending) { - assert childChannel.next == null; + if (tail == null) { + assert head == null; + tail = head = childChannel; + } else { + childChannel.previous = tail; + tail.next = childChannel; + tail = childChannel; + } + } - if (tail == null) { - assert head == null; - tail = head = childChannel; - } else { - tail.next = childChannel; - tail = childChannel; - } - childChannel.fireChannelReadPending = true; + private void tryRemoveChildChannelFromReadPendingQueue(DefaultHttp2StreamChannel childChannel) { + if (isChildChannelInReadPendingQueue(childChannel)) { + removeChildChannelFromReadPendingQueue(childChannel); } } + private void removeChildChannelFromReadPendingQueue(DefaultHttp2StreamChannel childChannel) { + DefaultHttp2StreamChannel previous = childChannel.previous; + if (childChannel.next != null) { + childChannel.next.previous = previous; + } else { + tail = tail.previous; // If there is no next, this childChannel is the tail, so move the tail back. + } + if (previous != null) { + previous.next = childChannel.next; + } else { + head = head.next; // If there is no previous, this childChannel is the head, so move the tail forward. + } + childChannel.next = childChannel.previous = null; + } + private void onHttp2GoAwayFrame(ChannelHandlerContext ctx, final Http2GoAwayFrame goAwayFrame) { try { forEachActiveStream(new Http2FrameStreamVisitor() { @@ -345,8 +383,14 @@ public boolean visit(Http2FrameStream stream) { */ @Override public final void channelReadComplete(ChannelHandlerContext ctx) throws Exception { - parentReadInProgress = false; - onChannelReadComplete(ctx); + try { + onChannelReadComplete(ctx); + } finally { + parentReadInProgress = false; + tail = head = null; + // We always flush as this is what Http2ConnectionHandler does for now. + flush0(ctx); + } channelReadComplete0(ctx); } @@ -360,58 +404,24 @@ final void onChannelReadComplete(ChannelHandlerContext ctx) { // If we have many child channel we can optimize for the case when multiple call flush() in // channelReadComplete(...) callbacks and only do it once as otherwise we will end-up with multiple // write calls on the socket which is expensive. - try { - DefaultHttp2StreamChannel current = head; - while (current != null) { - DefaultHttp2StreamChannel childChannel = current; - if (childChannel.fireChannelReadPending) { - // Clear early in case fireChildReadComplete() causes it to need to be re-processed - childChannel.fireChannelReadPending = false; - childChannel.fireChildReadComplete(); - } - childChannel.next = null; - current = current.next; - } - } finally { - tail = head = null; - - // We always flush as this is what Http2ConnectionHandler does for now. - flush0(ctx); + DefaultHttp2StreamChannel current = head; + while (current != null) { + DefaultHttp2StreamChannel childChannel = current; + // Clear early in case fireChildReadComplete() causes it to need to be re-processed + current = current.next; + childChannel.next = childChannel.previous = null; + childChannel.fireChildReadComplete(); } } - // Allow to override for testing - void flush0(ChannelHandlerContext ctx) { + final void flush0(ChannelHandlerContext ctx) { flush(ctx); } - /** - * Return bytes to flow control. - *

- * Package private to allow to override for testing - * @param ctx The {@link ChannelHandlerContext} associated with the parent channel. - * @param stream The object representing the HTTP/2 stream. - * @param bytes The number of bytes to return to flow control. - * @return {@code true} if a frame has been written as a result of this method call. - * @throws Http2Exception If this operation violates the flow control limits. - */ - boolean onBytesConsumed(@SuppressWarnings("unused") ChannelHandlerContext ctx, - Http2FrameStream stream, int bytes) throws Http2Exception { - return consumeBytes(stream.id(), bytes); - } - - // Allow to extend for testing - static class Http2MultiplexCodecStream extends DefaultHttp2FrameStream { + static final class Http2MultiplexCodecStream extends DefaultHttp2FrameStream { DefaultHttp2StreamChannel channel; } - private enum ReadState { - READ_QUEUED, - READ_IGNORED_CHANNEL_INACTIVE, - READ_PROCESSED_BUT_STOP_READING, - READ_PROCESSED_OK_TO_PROCESS_MORE - } - private boolean initialWritability(DefaultHttp2FrameStream stream) { // If the stream id is not valid yet we will just mark the channel as writable as we will be notified // about non-writability state as soon as the first Http2HeaderFrame is written (if needed). @@ -419,6 +429,26 @@ private boolean initialWritability(DefaultHttp2FrameStream stream) { return !isStreamIdValid(stream.id()) || isWritable(stream); } + /** + * The current status of the read-processing for a {@link Http2StreamChannel}. + */ + private enum ReadStatus { + /** + * No read in progress and no read was requested (yet) + */ + IDLE, + + /** + * Reading in progress + */ + IN_PROGRESS, + + /** + * A read operation was requested. + */ + REQUESTED + } + // TODO: Handle writability changes due writing from outside the eventloop. private final class DefaultHttp2StreamChannel extends DefaultAttributeMap implements Http2StreamChannel { private final Http2StreamChannelConfig config = new Http2StreamChannelConfig(this); @@ -434,24 +464,26 @@ private final class DefaultHttp2StreamChannel extends DefaultAttributeMap implem private volatile boolean writable; private boolean outboundClosed; - private boolean closePending; - private boolean readInProgress; + + /** + * This variable represents if a read is in progress for the current channel or was requested. + * Note that depending upon the {@link RecvByteBufAllocator} behavior a read may extend beyond the + * {@link Http2ChannelUnsafe#beginRead()} method scope. The {@link Http2ChannelUnsafe#beginRead()} loop may + * drain all pending data, and then if the parent channel is reading this channel may still accept frames. + */ + private ReadStatus readStatus = ReadStatus.IDLE; + private Queue inboundBuffer; /** {@code true} after the first HEADERS frame has been written **/ private boolean firstFrameWritten; - /** {@code true} if a close without an error was initiated **/ - private boolean streamClosedWithoutError; - - // Keeps track of flush calls in channelReadComplete(...) and aggregate these. - private boolean inFireChannelReadComplete; - - boolean fireChannelReadPending; - - // Holds the reference to the next DefaultHttp2StreamChannel that should be processed in - // channelReadComplete(...) + // Currently the child channel and parent channel are always on the same EventLoop thread. This allows us to + // extend the read loop of a child channel if the child channel drains its queued data during read, and the + // parent channel is still in its read loop. The next/previous links build a doubly linked list that the parent + // channel will iterate in its channelReadComplete to end the read cycle for each child channel in the list. DefaultHttp2StreamChannel next; + DefaultHttp2StreamChannel previous; DefaultHttp2StreamChannel(DefaultHttp2FrameStream stream, boolean outbound) { this.stream = stream; @@ -479,13 +511,10 @@ public Http2FrameStream stream() { } void streamClosed() { - streamClosedWithoutError = true; - if (readInProgress) { - // Just call closeForcibly() as this will take care of fireChannelInactive(). - unsafe().closeForcibly(); - } else { - closePending = true; - } + unsafe.readEOS(); + // Attempt to drain any queued data from the queue and deliver it to the application before closing this + // channel. + unsafe.doBeginRead(); } @Override @@ -729,49 +758,48 @@ void writabilityChanged(boolean writable) { * Receive a read message. This does not notify handlers unless a read is in progress on the * channel. */ - ReadState fireChildRead(Http2Frame frame) { + void fireChildRead(Http2Frame frame) { assert eventLoop().inEventLoop(); if (!isActive()) { ReferenceCountUtil.release(frame); - return ReadState.READ_IGNORED_CHANNEL_INACTIVE; - } - if (readInProgress && (inboundBuffer == null || inboundBuffer.isEmpty())) { - // Check for null because inboundBuffer doesn't support null; we want to be consistent - // for what values are supported. - RecvByteBufAllocator.ExtendedHandle allocHandle = unsafe.recvBufAllocHandle(); + } else if (readStatus != ReadStatus.IDLE) { + // If a read is in progress or has been requested, there cannot be anything in the queue, + // otherwise we would have drained it from the queue and processed it during the read cycle. + assert inboundBuffer == null || inboundBuffer.isEmpty(); + final Handle allocHandle = unsafe.recvBufAllocHandle(); unsafe.doRead0(frame, allocHandle); - return allocHandle.continueReading() ? - ReadState.READ_PROCESSED_OK_TO_PROCESS_MORE : ReadState.READ_PROCESSED_BUT_STOP_READING; + // We currently don't need to check for readEOS because the parent channel and child channel are limited + // to the same EventLoop thread. There are a limited number of frame types that may come after EOS is + // read (unknown, reset) and the trade off is less conditionals for the hot path (headers/data) at the + // cost of additional readComplete notifications on the rare path. + if (allocHandle.continueReading()) { + tryAddChildChannelToReadPendingQueue(this); + } else { + tryRemoveChildChannelFromReadPendingQueue(this); + unsafe.notifyReadComplete(allocHandle); + } } else { if (inboundBuffer == null) { inboundBuffer = new ArrayDeque(4); } inboundBuffer.add(frame); - return ReadState.READ_QUEUED; } } void fireChildReadComplete() { assert eventLoop().inEventLoop(); - try { - if (readInProgress) { - inFireChannelReadComplete = true; - readInProgress = false; - unsafe().recvBufAllocHandle().readComplete(); - pipeline().fireChannelReadComplete(); - } - } finally { - inFireChannelReadComplete = false; - } + assert readStatus != ReadStatus.IDLE; + unsafe.notifyReadComplete(unsafe.recvBufAllocHandle()); } private final class Http2ChannelUnsafe implements Unsafe { private final VoidChannelPromise unsafeVoidPromise = new VoidChannelPromise(DefaultHttp2StreamChannel.this, false); @SuppressWarnings("deprecation") - private RecvByteBufAllocator.ExtendedHandle recvHandle; + private Handle recvHandle; private boolean writeDoneAndNoFlush; private boolean closeInitiated; + private boolean readEOS; @Override public void connect(final SocketAddress remoteAddress, @@ -783,9 +811,10 @@ public void connect(final SocketAddress remoteAddress, } @Override - public RecvByteBufAllocator.ExtendedHandle recvBufAllocHandle() { + public Handle recvBufAllocHandle() { if (recvHandle == null) { - recvHandle = (RecvByteBufAllocator.ExtendedHandle) config().getRecvByteBufAllocator().newHandle(); + recvHandle = config().getRecvByteBufAllocator().newHandle(); + recvHandle.reset(config()); } return recvHandle; } @@ -850,7 +879,7 @@ public void close(final ChannelPromise promise) { // This means close() was called before so we just register a listener and return closePromise.addListener(new ChannelFutureListener() { @Override - public void operationComplete(ChannelFuture future) throws Exception { + public void operationComplete(ChannelFuture future) { promise.setSuccess(); } }); @@ -859,12 +888,13 @@ public void operationComplete(ChannelFuture future) throws Exception { } closeInitiated = true; - closePending = false; - fireChannelReadPending = false; + tryRemoveChildChannelFromReadPendingQueue(DefaultHttp2StreamChannel.this); + + final boolean wasActive = isActive(); - // Only ever send a reset frame if the connection is still alive as otherwise it makes no sense at - // all anyway. - if (parent().isActive() && !streamClosedWithoutError && isStreamIdValid(stream().id())) { + // Only ever send a reset frame if the connection is still alive and if the stream may have existed + // as otherwise we may send a RST on a stream in an invalid state and cause a connection error. + if (parent().isActive() && !readEOS && connection().streamMayHaveExisted(stream().id())) { Http2StreamFrame resetFrame = new DefaultHttp2ResetFrame(Http2Error.CANCEL).stream(stream()); write(resetFrame, unsafe().voidPromise()); flush(); @@ -885,10 +915,7 @@ public void operationComplete(ChannelFuture future) throws Exception { closePromise.setSuccess(); promise.setSuccess(); - pipeline().fireChannelInactive(); - if (isRegistered()) { - deregister(unsafe().voidPromise()); - } + fireChannelInactiveAndDeregister(voidPromise(), wasActive); if (recvHandle != null) { recvHandle.channelClosed(); } @@ -901,83 +928,157 @@ public void closeForcibly() { @Override public void deregister(ChannelPromise promise) { + fireChannelInactiveAndDeregister(promise, false); + } + + private void fireChannelInactiveAndDeregister(final ChannelPromise promise, + final boolean fireChannelInactive) { if (!promise.setUncancellable()) { return; } - if (registered) { - registered = true; + + if (!registered) { promise.setSuccess(); - pipeline().fireChannelUnregistered(); - } else { - promise.setFailure(new IllegalStateException("Not registered")); + return; + } + + // As a user may call deregister() from within any method while doing processing in the ChannelPipeline, + // we need to ensure we do the actual deregister operation later. This is necessary to preserve the + // behavior of the AbstractChannel, which always invokes channelUnregistered and channelInactive + // events 'later' to ensure the current events in the handler are completed before these events. + // + // See: + // https://github.com/netty/netty/issues/4435 + invokeLater(new Runnable() { + @Override + public void run() { + if (fireChannelInactive) { + pipeline.fireChannelInactive(); + } + // The user can fire `deregister` events multiple times but we only want to fire the pipeline + // event if the channel was actually registered. + if (registered) { + registered = false; + pipeline.fireChannelUnregistered(); + } + safeSetSuccess(promise); + } + }); + } + + private void safeSetSuccess(ChannelPromise promise) { + if (!(promise instanceof VoidChannelPromise) && !promise.trySuccess()) { + logger.warn("Failed to mark a promise as success because it is done already: {}", promise); + } + } + + private void invokeLater(Runnable task) { + try { + // This method is used by outbound operation implementations to trigger an inbound event later. + // They do not trigger an inbound event immediately because an outbound operation might have been + // triggered by another inbound event handler method. If fired immediately, the call stack + // will look like this for example: + // + // handlerA.inboundBufferUpdated() - (1) an inbound handler method closes a connection. + // -> handlerA.ctx.close() + // -> channel.unsafe.close() + // -> handlerA.channelInactive() - (2) another inbound handler method called while in (1) yet + // + // which means the execution of two inbound handler methods of the same handler overlap undesirably. + eventLoop().execute(task); + } catch (RejectedExecutionException e) { + logger.warn("Can't invoke task later as EventLoop rejected it", e); } } @Override public void beginRead() { - if (readInProgress || !isActive()) { + if (!isActive()) { return; } - readInProgress = true; + switch (readStatus) { + case IDLE: + readStatus = ReadStatus.IN_PROGRESS; + doBeginRead(); + break; + case IN_PROGRESS: + readStatus = ReadStatus.REQUESTED; + break; + default: + break; + } + } - final RecvByteBufAllocator.Handle allocHandle = unsafe().recvBufAllocHandle(); - allocHandle.reset(config()); - if (inboundBuffer == null || inboundBuffer.isEmpty()) { - if (closePending) { + void doBeginRead() { + Object message; + if (inboundBuffer == null || (message = inboundBuffer.poll()) == null) { + if (readEOS) { unsafe.closeForcibly(); } - return; + } else { + final Handle allocHandle = recvBufAllocHandle(); + allocHandle.reset(config()); + boolean continueReading = false; + do { + doRead0((Http2Frame) message, allocHandle); + } while ((readEOS || (continueReading = allocHandle.continueReading())) && + (message = inboundBuffer.poll()) != null); + + if (continueReading && parentReadInProgress && !readEOS) { + // Currently the parent and child channel are on the same EventLoop thread. If the parent is + // currently reading it is possile that more frames will be delivered to this child channel. In + // the case that this child channel still wants to read we delay the channelReadComplete on this + // child channel until the parent is done reading. + assert !isChildChannelInReadPendingQueue(DefaultHttp2StreamChannel.this); + addChildChannelToReadPendingQueue(DefaultHttp2StreamChannel.this); + } else { + notifyReadComplete(allocHandle); + } } + } - // We have already checked that the queue is not empty, so before this value is used it will always be - // set by allocHandle.continueReading(). - boolean continueReading; - do { - Object m = inboundBuffer.poll(); - if (m == null) { - continueReading = false; - break; - } - doRead0((Http2Frame) m, allocHandle); - } while (continueReading = allocHandle.continueReading()); + void readEOS() { + readEOS = true; + } - if (continueReading && parentReadInProgress) { - // We don't know if more frames will be delivered in the parent channel's read loop, so add this - // channel to the channelReadComplete queue to be notified later. - addChildChannelToReadPendingQueue(DefaultHttp2StreamChannel.this); + void notifyReadComplete(Handle allocHandle) { + assert next == null && previous == null; + if (readStatus == ReadStatus.REQUESTED) { + readStatus = ReadStatus.IN_PROGRESS; } else { - // Reading data may result in frames being written (e.g. WINDOW_UPDATE, RST, etc..). If the parent - // channel is not currently reading we need to force a flush at the child channel, because we cannot - // rely upon flush occurring in channelReadComplete on the parent channel. - readInProgress = false; - allocHandle.readComplete(); - pipeline().fireChannelReadComplete(); - flush(); - if (closePending) { - unsafe.closeForcibly(); - } + readStatus = ReadStatus.IDLE; + } + allocHandle.readComplete(); + pipeline().fireChannelReadComplete(); + // Reading data may result in frames being written (e.g. WINDOW_UPDATE, RST, etc..). If the parent + // channel is not currently reading we need to force a flush at the child channel, because we cannot + // rely upon flush occurring in channelReadComplete on the parent channel. + flush(); + if (readEOS) { + unsafe.closeForcibly(); } } @SuppressWarnings("deprecation") - void doRead0(Http2Frame frame, RecvByteBufAllocator.Handle allocHandle) { - int numBytesToBeConsumed = 0; + void doRead0(Http2Frame frame, Handle allocHandle) { + pipeline().fireChannelRead(frame); + allocHandle.incMessagesRead(1); + if (frame instanceof Http2DataFrame) { - numBytesToBeConsumed = ((Http2DataFrame) frame).initialFlowControlledBytes(); + final int numBytesToBeConsumed = ((Http2DataFrame) frame).initialFlowControlledBytes(); + allocHandle.attemptedBytesRead(numBytesToBeConsumed); allocHandle.lastBytesRead(numBytesToBeConsumed); + if (numBytesToBeConsumed != 0) { + try { + writeDoneAndNoFlush |= consumeBytes(stream.id(), numBytesToBeConsumed); + } catch (Http2Exception e) { + pipeline().fireExceptionCaught(e); + } + } } else { + allocHandle.attemptedBytesRead(MIN_HTTP2_FRAME_SIZE); allocHandle.lastBytesRead(MIN_HTTP2_FRAME_SIZE); } - allocHandle.incMessagesRead(1); - pipeline().fireChannelRead(frame); - - if (numBytesToBeConsumed != 0) { - try { - writeDoneAndNoFlush |= onBytesConsumed(ctx, stream, numBytesToBeConsumed); - } catch (Http2Exception e) { - pipeline().fireExceptionCaught(e); - } - } } @Override @@ -1014,7 +1115,7 @@ public void write(Object msg, final ChannelPromise promise) { } else { future.addListener(new ChannelFutureListener() { @Override - public void operationComplete(ChannelFuture future) throws Exception { + public void operationComplete(ChannelFuture future) { firstWriteComplete(future, promise); } }); @@ -1036,7 +1137,7 @@ public void operationComplete(ChannelFuture future) throws Exception { } else { future.addListener(new ChannelFutureListener() { @Override - public void operationComplete(ChannelFuture future) throws Exception { + public void operationComplete(ChannelFuture future) { writeComplete(future, promise); } }); @@ -1056,9 +1157,9 @@ private void firstWriteComplete(ChannelFuture future, ChannelPromise promise) { writabilityChanged(Http2MultiplexCodec.this.isWritable(stream)); promise.setSuccess(); } else { - promise.setFailure(wrapStreamClosedError(cause)); // If the first write fails there is not much we can do, just close closeForcibly(); + promise.setFailure(wrapStreamClosedError(cause)); } } @@ -1068,8 +1169,6 @@ private void writeComplete(ChannelFuture future, ChannelPromise promise) { promise.setSuccess(); } else { Throwable error = wrapStreamClosedError(cause); - promise.setFailure(error); - if (error instanceof ClosedChannelException) { if (config.isAutoClose()) { // Close channel if needed. @@ -1078,6 +1177,7 @@ private void writeComplete(ChannelFuture future, ChannelPromise promise) { outboundClosed = true; } } + promise.setFailure(error); } } @@ -1108,18 +1208,16 @@ private ChannelFuture write0(Object msg) { @Override public void flush() { - if (!writeDoneAndNoFlush) { + // If we are currently in the parent channel's read loop we should just ignore the flush. + // We will ensure we trigger ctx.flush() after we processed all Channels later on and + // so aggregate the flushes. This is done as ctx.flush() is expensive when as it may trigger an + // write(...) or writev(...) operation on the socket. + if (!writeDoneAndNoFlush || parentReadInProgress) { // There is nothing to flush so this is a NOOP. return; } try { - // If we are currently in the channelReadComplete(...) call we should just ignore the flush. - // We will ensure we trigger ctx.flush() after we processed all Channels later on and - // so aggregate the flushes. This is done as ctx.flush() is expensive when as it may trigger an - // write(...) or writev(...) operation on the socket. - if (!inFireChannelReadComplete) { - flush0(ctx); - } + flush0(ctx); } finally { writeDoneAndNoFlush = false; } @@ -1143,10 +1241,8 @@ public ChannelOutboundBuffer outboundBuffer() { * changes. */ private final class Http2StreamChannelConfig extends DefaultChannelConfig { - Http2StreamChannelConfig(Channel channel) { super(channel); - setRecvByteBufAllocator(new Http2StreamChannelRecvByteBufAllocator()); } @Override diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexCodecBuilder.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexCodecBuilder.java index 8e0929094a6b..c5732ec68748 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexCodecBuilder.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexCodecBuilder.java @@ -27,8 +27,10 @@ @UnstableApi public class Http2MultiplexCodecBuilder extends AbstractHttp2ConnectionHandlerBuilder { + private Http2FrameWriter frameWriter; final ChannelHandler childHandler; + private ChannelHandler upgradeStreamHandler; Http2MultiplexCodecBuilder(boolean server, ChannelHandler childHandler) { server(server); @@ -43,6 +45,12 @@ private static ChannelHandler checkSharable(ChannelHandler handler) { return handler; } + // For testing only. + Http2MultiplexCodecBuilder frameWriter(Http2FrameWriter frameWriter) { + this.frameWriter = checkNotNull(frameWriter, "frameWriter"); + return this; + } + /** * Creates a builder for a HTTP/2 client. * @@ -83,6 +91,14 @@ public Http2MultiplexCodecBuilder gracefulShutdownTimeoutMillis(long gracefulShu return super.gracefulShutdownTimeoutMillis(gracefulShutdownTimeoutMillis); } + public Http2MultiplexCodecBuilder withUpgradeStreamHandler(ChannelHandler upgradeStreamHandler) { + if (this.isServer()) { + throw new IllegalArgumentException("Server codecs don't use an extra handler for the upgrade stream"); + } + this.upgradeStreamHandler = upgradeStreamHandler; + return this; + } + @Override public boolean isServer() { return super.isServer(); @@ -151,12 +167,34 @@ public Http2MultiplexCodecBuilder initialHuffmanDecodeCapacity(int initialHuffma @Override public Http2MultiplexCodec build() { + Http2FrameWriter frameWriter = this.frameWriter; + if (frameWriter != null) { + // This is to support our tests and will never be executed by the user as frameWriter(...) + // is package-private. + DefaultHttp2Connection connection = new DefaultHttp2Connection(isServer(), maxReservedStreams()); + Long maxHeaderListSize = initialSettings().maxHeaderListSize(); + Http2FrameReader frameReader = new DefaultHttp2FrameReader(maxHeaderListSize == null ? + new DefaultHttp2HeadersDecoder(true) : + new DefaultHttp2HeadersDecoder(true, maxHeaderListSize)); + + if (frameLogger() != null) { + frameWriter = new Http2OutboundFrameLogger(frameWriter, frameLogger()); + frameReader = new Http2InboundFrameLogger(frameReader, frameLogger()); + } + Http2ConnectionEncoder encoder = new DefaultHttp2ConnectionEncoder(connection, frameWriter); + if (encoderEnforceMaxConcurrentStreams()) { + encoder = new StreamBufferingEncoder(encoder); + } + Http2ConnectionDecoder decoder = new DefaultHttp2ConnectionDecoder(connection, encoder, frameReader); + + return build(decoder, encoder, initialSettings()); + } return super.build(); } @Override protected Http2MultiplexCodec build( Http2ConnectionDecoder decoder, Http2ConnectionEncoder encoder, Http2Settings initialSettings) { - return new Http2MultiplexCodec(encoder, decoder, initialSettings, childHandler); + return new Http2MultiplexCodec(encoder, decoder, initialSettings, childHandler, upgradeStreamHandler); } } diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2StreamFrameToHttpObjectCodec.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2StreamFrameToHttpObjectCodec.java index 36a7fba1c0c6..e13a45fe0152 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2StreamFrameToHttpObjectCodec.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2StreamFrameToHttpObjectCodec.java @@ -41,6 +41,8 @@ import io.netty.handler.codec.http.HttpVersion; import io.netty.handler.codec.http.LastHttpContent; import io.netty.handler.ssl.SslHandler; +import io.netty.util.Attribute; +import io.netty.util.AttributeKey; import io.netty.util.internal.UnstableApi; import java.util.List; @@ -57,16 +59,17 @@ @UnstableApi @Sharable public class Http2StreamFrameToHttpObjectCodec extends MessageToMessageCodec { + + private static final AttributeKey SCHEME_ATTR_KEY = + AttributeKey.valueOf(HttpScheme.class, "STREAMFRAMECODEC_SCHEME"); + private final boolean isServer; private final boolean validateHeaders; - private HttpScheme scheme; - public Http2StreamFrameToHttpObjectCodec(final boolean isServer, final boolean validateHeaders) { this.isServer = isServer; this.validateHeaders = validateHeaders; - scheme = HttpScheme.HTTP; } public Http2StreamFrameToHttpObjectCodec(final boolean isServer) { @@ -154,7 +157,7 @@ protected void encode(ChannelHandlerContext ctx, HttpObject obj, List ou final HttpResponse res = (HttpResponse) obj; if (res.status().equals(HttpResponseStatus.CONTINUE)) { if (res instanceof FullHttpResponse) { - final Http2Headers headers = toHttp2Headers(res); + final Http2Headers headers = toHttp2Headers(ctx, res); out.add(new DefaultHttp2HeadersFrame(headers, false)); return; } else { @@ -165,7 +168,7 @@ protected void encode(ChannelHandlerContext ctx, HttpObject obj, List ou } if (obj instanceof HttpMessage) { - Http2Headers headers = toHttp2Headers((HttpMessage) obj); + Http2Headers headers = toHttp2Headers(ctx, (HttpMessage) obj); boolean noMoreFrames = false; if (obj instanceof FullHttpMessage) { FullHttpMessage full = (FullHttpMessage) obj; @@ -184,11 +187,11 @@ protected void encode(ChannelHandlerContext ctx, HttpObject obj, List ou } } - private Http2Headers toHttp2Headers(final HttpMessage msg) { + private Http2Headers toHttp2Headers(final ChannelHandlerContext ctx, final HttpMessage msg) { if (msg instanceof HttpRequest) { msg.headers().set( HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), - scheme.name()); + connectionScheme(ctx)); } return HttpConversionUtil.toHttp2Headers(msg, validateHeaders); @@ -213,17 +216,35 @@ private FullHttpMessage newFullMessage(final int id, public void handlerAdded(final ChannelHandlerContext ctx) throws Exception { super.handlerAdded(ctx); - // this handler is typically used on an Http2StreamChannel. at this + // this handler is typically used on an Http2StreamChannel. At this // stage, ssl handshake should've been established. checking for the // presence of SslHandler in the parent's channel pipeline to // determine the HTTP scheme should suffice, even for the case where // SniHandler is used. - scheme = isSsl(ctx) ? HttpScheme.HTTPS : HttpScheme.HTTP; + final Attribute schemeAttribute = connectionSchemeAttribute(ctx); + if (schemeAttribute.get() == null) { + final HttpScheme scheme = isSsl(ctx) ? HttpScheme.HTTPS : HttpScheme.HTTP; + schemeAttribute.set(scheme); + } } protected boolean isSsl(final ChannelHandlerContext ctx) { - final Channel ch = ctx.channel(); - final Channel connChannel = (ch instanceof Http2StreamChannel) ? ch.parent() : ch; + final Channel connChannel = connectionChannel(ctx); return null != connChannel.pipeline().get(SslHandler.class); } + + private static HttpScheme connectionScheme(ChannelHandlerContext ctx) { + final HttpScheme scheme = connectionSchemeAttribute(ctx).get(); + return scheme == null ? HttpScheme.HTTP : scheme; + } + + private static Attribute connectionSchemeAttribute(ChannelHandlerContext ctx) { + final Channel ch = connectionChannel(ctx); + return ch.attr(SCHEME_ATTR_KEY); + } + + private static Channel connectionChannel(ChannelHandlerContext ctx) { + final Channel ch = ctx.channel(); + return ch instanceof Http2StreamChannel ? ch.parent() : ch; + } } diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2UnknownFrame.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2UnknownFrame.java index 51eeafaf0d95..41e578703a5b 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2UnknownFrame.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2UnknownFrame.java @@ -20,13 +20,12 @@ import io.netty.util.internal.UnstableApi; @UnstableApi -public interface Http2UnknownFrame extends Http2Frame, ByteBufHolder { +public interface Http2UnknownFrame extends Http2StreamFrame, ByteBufHolder { + @Override Http2FrameStream stream(); - /** - * Set the {@link Http2FrameStream} object for this frame. - */ + @Override Http2UnknownFrame stream(Http2FrameStream stream); byte frameType(); diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/UniformStreamByteDistributor.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/UniformStreamByteDistributor.java index 421d3488b20c..c7e57911b8f2 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/UniformStreamByteDistributor.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/UniformStreamByteDistributor.java @@ -24,6 +24,7 @@ import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR; import static io.netty.handler.codec.http2.Http2Exception.connectionError; import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkPositive; import static java.lang.Math.max; import static java.lang.Math.min; @@ -72,9 +73,7 @@ public void onStreamClosed(Http2Stream stream) { * Must be > 0. */ public void minAllocationChunk(int minAllocationChunk) { - if (minAllocationChunk <= 0) { - throw new IllegalArgumentException("minAllocationChunk must be > 0"); - } + checkPositive(minAllocationChunk, "minAllocationChunk"); this.minAllocationChunk = minAllocationChunk; } diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/WeightedFairQueueByteDistributor.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/WeightedFairQueueByteDistributor.java index b7c2cddf52cd..b0f3a25dc4a4 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/WeightedFairQueueByteDistributor.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/WeightedFairQueueByteDistributor.java @@ -37,6 +37,8 @@ import static io.netty.handler.codec.http2.Http2CodecUtil.streamableBytes; import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR; import static io.netty.handler.codec.http2.Http2Exception.connectionError; +import static io.netty.util.internal.ObjectUtil.checkPositive; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; import static java.lang.Integer.MAX_VALUE; import static java.lang.Math.max; import static java.lang.Math.min; @@ -96,9 +98,8 @@ public WeightedFairQueueByteDistributor(Http2Connection connection) { } public WeightedFairQueueByteDistributor(Http2Connection connection, int maxStateOnlySize) { - if (maxStateOnlySize < 0) { - throw new IllegalArgumentException("maxStateOnlySize: " + maxStateOnlySize + " (expected: >0)"); - } else if (maxStateOnlySize == 0) { + checkPositiveOrZero(maxStateOnlySize, "maxStateOnlySize"); + if (maxStateOnlySize == 0) { stateOnlyMap = IntCollections.emptyMap(); stateOnlyRemovalQueue = EmptyPriorityQueue.instance(); } else { @@ -281,9 +282,7 @@ public boolean distribute(int maxBytes, Writer writer) throws Http2Exception { * @param allocationQuantum the amount of bytes that will be allocated to each stream. Must be > 0. */ public void allocationQuantum(int allocationQuantum) { - if (allocationQuantum <= 0) { - throw new IllegalArgumentException("allocationQuantum must be > 0"); - } + checkPositive(allocationQuantum, "allocationQuantum"); this.allocationQuantum = allocationQuantum; } diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/CleartextHttp2ServerUpgradeHandlerTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/CleartextHttp2ServerUpgradeHandlerTest.java index a544054b6685..129f62d7574d 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/CleartextHttp2ServerUpgradeHandlerTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/CleartextHttp2ServerUpgradeHandlerTest.java @@ -42,8 +42,15 @@ import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.*; -import static org.mockito.Mockito.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; /** * Tests for {@link CleartextHttp2ServerUpgradeHandler} @@ -112,47 +119,35 @@ public void priorKnowledge() throws Exception { @Test public void upgrade() throws Exception { - setUpServerChannel(); - String upgradeString = "GET / HTTP/1.1\r\n" + - "Host: example.com\r\n" + - "Connection: Upgrade, HTTP2-Settings\r\n" + - "Upgrade: h2c\r\n" + - "HTTP2-Settings: AAMAAABkAAQAAP__\r\n\r\n"; - ByteBuf upgrade = Unpooled.copiedBuffer(upgradeString, CharsetUtil.US_ASCII); - - assertFalse(channel.writeInbound(upgrade)); - - assertEquals(1, userEvents.size()); - - Object userEvent = userEvents.get(0); - assertTrue(userEvent instanceof UpgradeEvent); - assertEquals("h2c", ((UpgradeEvent) userEvent).protocol()); - ReferenceCountUtil.release(userEvent); - - assertEquals(100, http2ConnectionHandler.connection().local().maxActiveStreams()); - assertEquals(65535, http2ConnectionHandler.connection().local().flowController().initialWindowSize()); - - assertEquals(1, http2ConnectionHandler.connection().numActiveStreams()); - assertNotNull(http2ConnectionHandler.connection().stream(1)); - - Http2Stream stream = http2ConnectionHandler.connection().stream(1); - assertEquals(State.HALF_CLOSED_REMOTE, stream.state()); - assertFalse(stream.isHeadersSent()); - - String expectedHttpResponse = "HTTP/1.1 101 Switching Protocols\r\n" + - "connection: upgrade\r\n" + - "upgrade: h2c\r\n\r\n"; - ByteBuf responseBuffer = channel.readOutbound(); - assertEquals(expectedHttpResponse, responseBuffer.toString(CharsetUtil.UTF_8)); - responseBuffer.release(); + "Host: example.com\r\n" + + "Connection: Upgrade, HTTP2-Settings\r\n" + + "Upgrade: h2c\r\n" + + "HTTP2-Settings: AAMAAABkAAQAAP__\r\n\r\n"; + validateClearTextUpgrade(upgradeString); + } - // Check that the preface was send (a.k.a the settings frame) - ByteBuf settingsBuffer = channel.readOutbound(); - assertNotNull(settingsBuffer); - settingsBuffer.release(); + @Test + public void upgradeWithMultipleConnectionHeaders() { + String upgradeString = "GET / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "Connection: keep-alive\r\n" + + "Connection: Upgrade, HTTP2-Settings\r\n" + + "Upgrade: h2c\r\n" + + "HTTP2-Settings: AAMAAABkAAQAAP__\r\n\r\n"; + validateClearTextUpgrade(upgradeString); + } - assertNull(channel.readOutbound()); + @Test + public void requiredHeadersInSeparateConnectionHeaders() { + String upgradeString = "GET / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "Connection: keep-alive\r\n" + + "Connection: HTTP2-Settings\r\n" + + "Connection: Upgrade\r\n" + + "Upgrade: h2c\r\n" + + "HTTP2-Settings: AAMAAABkAAQAAP__\r\n\r\n"; + validateClearTextUpgrade(upgradeString); } @Test @@ -254,4 +249,43 @@ private static ByteBuf settingsFrameBuf() { private static Http2Settings expectedSettings() { return new Http2Settings().maxConcurrentStreams(100).initialWindowSize(65535); } + + private void validateClearTextUpgrade(String upgradeString) { + setUpServerChannel(); + + ByteBuf upgrade = Unpooled.copiedBuffer(upgradeString, CharsetUtil.US_ASCII); + + assertFalse(channel.writeInbound(upgrade)); + + assertEquals(1, userEvents.size()); + + Object userEvent = userEvents.get(0); + assertTrue(userEvent instanceof UpgradeEvent); + assertEquals("h2c", ((UpgradeEvent) userEvent).protocol()); + ReferenceCountUtil.release(userEvent); + + assertEquals(100, http2ConnectionHandler.connection().local().maxActiveStreams()); + assertEquals(65535, http2ConnectionHandler.connection().local().flowController().initialWindowSize()); + + assertEquals(1, http2ConnectionHandler.connection().numActiveStreams()); + assertNotNull(http2ConnectionHandler.connection().stream(1)); + + Http2Stream stream = http2ConnectionHandler.connection().stream(1); + assertEquals(State.HALF_CLOSED_REMOTE, stream.state()); + assertFalse(stream.isHeadersSent()); + + String expectedHttpResponse = "HTTP/1.1 101 Switching Protocols\r\n" + + "connection: upgrade\r\n" + + "upgrade: h2c\r\n\r\n"; + ByteBuf responseBuffer = channel.readOutbound(); + assertEquals(expectedHttpResponse, responseBuffer.toString(CharsetUtil.UTF_8)); + responseBuffer.release(); + + // Check that the preface was send (a.k.a the settings frame) + ByteBuf settingsBuffer = channel.readOutbound(); + assertNotNull(settingsBuffer); + settingsBuffer.release(); + + assertNull(channel.readOutbound()); + } } diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoderTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoderTest.java index 8d71ddeaffe4..7e87d52893c2 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoderTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoderTest.java @@ -674,13 +674,6 @@ public void rstStreamReadAfterGoAwayShouldSucceed() throws Exception { verify(listener).onRstStreamRead(eq(ctx), anyInt(), anyLong()); } - @Test(expected = Http2Exception.class) - public void goawayIncreasedLastStreamIdShouldThrow() throws Exception { - when(local.lastStreamKnownByPeer()).thenReturn(1); - when(connection.goAwayReceived()).thenReturn(true); - decode().onGoAwayRead(ctx, 3, 2L, EMPTY_BUFFER); - } - @Test(expected = Http2Exception.class) public void rstStreamReadForUnknownStreamShouldThrow() throws Exception { when(connection.streamMayHaveExisted(STREAM_ID)).thenReturn(false); diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoderTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoderTest.java index 5f9a6b18719c..9ca7b1fee1b0 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoderTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoderTest.java @@ -34,6 +34,7 @@ import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; +import org.mockito.InOrder; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.mockito.invocation.InvocationOnMock; @@ -54,6 +55,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -65,6 +67,7 @@ import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; @@ -124,6 +127,13 @@ public void setup() throws Exception { when(channel.unsafe()).thenReturn(unsafe); ChannelConfig config = new DefaultChannelConfig(channel); when(channel.config()).thenReturn(config); + doAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock in) { + return newPromise().setFailure((Throwable) in.getArgument(0)); + } + }).when(channel).newFailedFuture(any(Throwable.class)); + when(writer.configuration()).thenReturn(writerConfig); when(writerConfig.frameSizePolicy()).thenReturn(frameSizePolicy); when(frameSizePolicy.maxFrameSize()).thenReturn(64); @@ -204,6 +214,36 @@ public ChannelFuture answer(InvocationOnMock in) throws Throwable { encoder.lifecycleManager(lifecycleManager); } + @Test + public void dataWithEndOfStreamWriteShouldSignalThatFrameWasConsumedOnError() throws Exception { + dataWriteShouldSignalThatFrameWasConsumedOnError0(true); + } + + @Test + public void dataWriteShouldSignalThatFrameWasConsumedOnError() throws Exception { + dataWriteShouldSignalThatFrameWasConsumedOnError0(false); + } + + private void dataWriteShouldSignalThatFrameWasConsumedOnError0(boolean endOfStream) throws Exception { + createStream(STREAM_ID, false); + final ByteBuf data = dummyData(); + ChannelPromise p = newPromise(); + encoder.writeData(ctx, STREAM_ID, data, 0, endOfStream, p); + + FlowControlled controlled = payloadCaptor.getValue(); + assertEquals(8, controlled.size()); + payloadCaptor.getValue().write(ctx, 4); + assertEquals(4, controlled.size()); + + Throwable error = new IllegalStateException(); + payloadCaptor.getValue().error(ctx, error); + payloadCaptor.getValue().write(ctx, 8); + assertEquals(0, controlled.size()); + assertEquals("abcd", writtenData.get(0)); + assertEquals(0, data.refCnt()); + assertSame(error, p.cause()); + } + @Test public void dataWriteShouldSucceed() throws Exception { createStream(STREAM_ID, false); @@ -707,6 +747,59 @@ public void headersWriteShouldHalfClosePushStream() throws Exception { verify(lifecycleManager).closeStreamLocal(eq(stream), eq(promise)); } + @Test + public void headersWriteShouldHalfCloseAfterOnErrorForPreCreatedStream() throws Exception { + final ChannelPromise promise = newPromise(); + final Throwable ex = new RuntimeException(); + // Fake an encoding error, like HPACK's HeaderListSizeException + when(writer.writeHeaders(eq(ctx), eq(STREAM_ID), eq(EmptyHttp2Headers.INSTANCE), eq(0), + eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(true), eq(promise))) + .thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocation) { + promise.setFailure(ex); + return promise; + } + }); + + writeAllFlowControlledFrames(); + Http2Stream stream = createStream(STREAM_ID, false); + encoder.writeHeaders(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, true, promise); + + assertTrue(promise.isDone()); + assertFalse(promise.isSuccess()); + assertFalse(stream.isHeadersSent()); + InOrder inOrder = inOrder(lifecycleManager); + inOrder.verify(lifecycleManager).onError(eq(ctx), eq(true), eq(ex)); + inOrder.verify(lifecycleManager).closeStreamLocal(eq(stream(STREAM_ID)), eq(promise)); + } + + @Test + public void headersWriteShouldHalfCloseAfterOnErrorForImplicitlyCreatedStream() throws Exception { + final ChannelPromise promise = newPromise(); + final Throwable ex = new RuntimeException(); + // Fake an encoding error, like HPACK's HeaderListSizeException + when(writer.writeHeaders(eq(ctx), eq(STREAM_ID), eq(EmptyHttp2Headers.INSTANCE), eq(0), + eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(true), eq(promise))) + .thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocation) { + promise.setFailure(ex); + return promise; + } + }); + + writeAllFlowControlledFrames(); + encoder.writeHeaders(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, true, promise); + + assertTrue(promise.isDone()); + assertFalse(promise.isSuccess()); + assertFalse(stream(STREAM_ID).isHeadersSent()); + InOrder inOrder = inOrder(lifecycleManager); + inOrder.verify(lifecycleManager).onError(eq(ctx), eq(true), eq(ex)); + inOrder.verify(lifecycleManager).closeStreamLocal(eq(stream(STREAM_ID)), eq(promise)); + } + @Test public void encoderDelegatesGoAwayToLifeCycleManager() { ChannelPromise promise = newPromise(); @@ -770,7 +863,7 @@ public void canWriteDataFrameAfterGoAwayReceived() throws Exception { } @Test - public void canWriteHeaderFrameAfterGoAwayReceived() { + public void canWriteHeaderFrameAfterGoAwayReceived() throws Http2Exception { writeAllFlowControlledFrames(); goAwayReceived(STREAM_ID); ChannelPromise promise = newPromise(); @@ -803,11 +896,11 @@ private Http2Stream stream(int streamId) { return connection.stream(streamId); } - private void goAwayReceived(int lastStreamId) { + private void goAwayReceived(int lastStreamId) throws Http2Exception { connection.goAwayReceived(lastStreamId, 0, EMPTY_BUFFER); } - private void goAwaySent(int lastStreamId) { + private void goAwaySent(int lastStreamId) throws Http2Exception { connection.goAwaySent(lastStreamId, 0, EMPTY_BUFFER); } diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionTest.java index ef385beed646..b43c00020670 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionTest.java @@ -47,8 +47,8 @@ import static org.mockito.Mockito.any; import static org.mockito.Mockito.anyInt; import static org.mockito.Mockito.anyLong; -import static org.mockito.Mockito.eq; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.eq; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; @@ -423,11 +423,29 @@ public void reserveWithPushDisallowedShouldThrow() throws Http2Exception { } @Test(expected = Http2Exception.class) - public void goAwayReceivedShouldDisallowCreation() throws Http2Exception { + public void goAwayReceivedShouldDisallowLocalCreation() throws Http2Exception { + server.goAwayReceived(0, 1L, Unpooled.EMPTY_BUFFER); + server.local().createStream(3, true); + } + + @Test + public void goAwayReceivedShouldAllowRemoteCreation() throws Http2Exception { server.goAwayReceived(0, 1L, Unpooled.EMPTY_BUFFER); server.remote().createStream(3, true); } + @Test(expected = Http2Exception.class) + public void goAwaySentShouldDisallowRemoteCreation() throws Http2Exception { + server.goAwaySent(0, 1L, Unpooled.EMPTY_BUFFER); + server.remote().createStream(2, true); + } + + @Test + public void goAwaySentShouldAllowLocalCreation() throws Http2Exception { + server.goAwaySent(0, 1L, Unpooled.EMPTY_BUFFER); + server.local().createStream(2, true); + } + @Test public void closeShouldSucceed() throws Http2Exception { Http2Stream stream = server.remote().createStream(3, true); @@ -606,7 +624,7 @@ private static final class ListenerExceptionThrower implements Answer { private final boolean[] array; private final int index; - public ListenerExceptionThrower(boolean[] array, int index) { + ListenerExceptionThrower(boolean[] array, int index) { this.array = array; this.index = index; } @@ -622,7 +640,7 @@ private static final class ListenerVerifyCallAnswer implements Answer { private final boolean[] array; private final int index; - public ListenerVerifyCallAnswer(boolean[] array, int index) { + ListenerVerifyCallAnswer(boolean[] array, int index) { this.array = array; this.index = index; } diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/HpackDecoderTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/HpackDecoderTest.java index 4652a1fceb66..994fef6f37af 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/HpackDecoderTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/HpackDecoderTest.java @@ -431,15 +431,13 @@ public void testLiteralNeverIndexedWithLargeValue() throws Http2Exception { } @Test - public void testDecodeLargerThanMaxHeaderListSizeButSmallerThanMaxHeaderListSizeUpdatesDynamicTable() - throws Http2Exception { + public void testDecodeLargerThanMaxHeaderListSizeUpdatesDynamicTable() throws Http2Exception { ByteBuf in = Unpooled.buffer(300); try { - hpackDecoder.setMaxHeaderListSize(200, 300); + hpackDecoder.setMaxHeaderListSize(200); HpackEncoder hpackEncoder = new HpackEncoder(true); // encode headers that are slightly larger than maxHeaderListSize - // but smaller than maxHeaderListSizeGoAway Http2Headers toEncode = new DefaultHttp2Headers(); toEncode.add("test_1", "1"); toEncode.add("test_2", "2"); @@ -447,8 +445,7 @@ public void testDecodeLargerThanMaxHeaderListSizeButSmallerThanMaxHeaderListSize toEncode.add("test_3", "3"); hpackEncoder.encodeHeaders(1, in, toEncode, NEVER_SENSITIVE); - // decode the headers, we should get an exception, but - // the decoded headers object should contain all of the headers + // decode the headers, we should get an exception Http2Headers decoded = new DefaultHttp2Headers(); try { hpackDecoder.decode(1, in, decoded, true); @@ -457,8 +454,18 @@ public void testDecodeLargerThanMaxHeaderListSizeButSmallerThanMaxHeaderListSize assertTrue(e instanceof Http2Exception.HeaderListSizeException); } - assertEquals(4, decoded.size()); - assertTrue(decoded.contains("test_3")); + // but the dynamic table should have been updated, so that later blocks + // can refer to earlier headers + in.clear(); + // 0x80, "indexed header field representation" + // index 62, the first (most recent) dynamic table entry + in.writeByte(0x80 | 62); + Http2Headers decoded2 = new DefaultHttp2Headers(); + hpackDecoder.decode(1, in, decoded2, true); + + Http2Headers golden = new DefaultHttp2Headers(); + golden.add("test_3", "3"); + assertEquals(golden, decoded2); } finally { in.release(); } @@ -468,11 +475,10 @@ public void testDecodeLargerThanMaxHeaderListSizeButSmallerThanMaxHeaderListSize public void testDecodeCountsNamesOnlyOnce() throws Http2Exception { ByteBuf in = Unpooled.buffer(200); try { - hpackDecoder.setMaxHeaderListSize(3500, 4000); + hpackDecoder.setMaxHeaderListSize(3500); HpackEncoder hpackEncoder = new HpackEncoder(true); // encode headers that are slightly larger than maxHeaderListSize - // but smaller than maxHeaderListSizeGoAway Http2Headers toEncode = new DefaultHttp2Headers(); toEncode.add(String.format("%03000d", 0).replace('0', 'f'), "value"); toEncode.add("accept", "value"); @@ -493,7 +499,7 @@ public void testAccountForHeaderOverhead() throws Exception { String headerName = "12345"; String headerValue = "56789"; long headerSize = headerName.length() + headerValue.length(); - hpackDecoder.setMaxHeaderListSize(headerSize, 100); + hpackDecoder.setMaxHeaderListSize(headerSize); HpackEncoder hpackEncoder = new HpackEncoder(true); Http2Headers toEncode = new DefaultHttp2Headers(); @@ -538,7 +544,7 @@ public void unknownPseudoHeader() throws Exception { Http2Headers decoded = new DefaultHttp2Headers(); - expectedException.expect(Http2Exception.class); + expectedException.expect(Http2Exception.StreamException.class); hpackDecoder.decode(1, in, decoded, true); } finally { in.release(); @@ -582,7 +588,7 @@ public void requestPseudoHeaderInResponse() throws Exception { Http2Headers decoded = new DefaultHttp2Headers(); - expectedException.expect(Http2Exception.class); + expectedException.expect(Http2Exception.StreamException.class); hpackDecoder.decode(1, in, decoded, true); } finally { in.release(); @@ -602,7 +608,7 @@ public void responsePseudoHeaderInRequest() throws Exception { Http2Headers decoded = new DefaultHttp2Headers(); - expectedException.expect(Http2Exception.class); + expectedException.expect(Http2Exception.StreamException.class); hpackDecoder.decode(1, in, decoded, true); } finally { in.release(); @@ -622,10 +628,47 @@ public void pseudoHeaderAfterRegularHeader() throws Exception { Http2Headers decoded = new DefaultHttp2Headers(); - expectedException.expect(Http2Exception.class); + expectedException.expect(Http2Exception.StreamException.class); hpackDecoder.decode(1, in, decoded, true); } finally { in.release(); } } + + @Test + public void failedValidationDoesntCorruptHpack() throws Exception { + ByteBuf in1 = Unpooled.buffer(200); + ByteBuf in2 = Unpooled.buffer(200); + try { + HpackEncoder hpackEncoder = new HpackEncoder(true); + + Http2Headers toEncode = new DefaultHttp2Headers(); + toEncode.add(":method", "GET"); + toEncode.add(":status", "200"); + toEncode.add("foo", "bar"); + hpackEncoder.encodeHeaders(1, in1, toEncode, NEVER_SENSITIVE); + + Http2Headers decoded = new DefaultHttp2Headers(); + + try { + hpackDecoder.decode(1, in1, decoded, true); + fail("Should have thrown a StreamException"); + } catch (Http2Exception.StreamException expected) { + assertEquals(1, expected.streamId()); + } + + // Do it again, this time without validation, to make sure the HPACK state is still sane. + decoded.clear(); + hpackEncoder.encodeHeaders(1, in2, toEncode, NEVER_SENSITIVE); + hpackDecoder.decode(1, in2, decoded, false); + + assertEquals(3, decoded.size()); + assertEquals("GET", decoded.method().toString()); + assertEquals("200", decoded.status().toString()); + assertEquals("bar", decoded.get("foo").toString()); + } finally { + in1.release(); + in2.release(); + } + } } diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/HpackEncoderTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/HpackEncoderTest.java index b1dce05c38b3..de049a6e5066 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/HpackEncoderTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/HpackEncoderTest.java @@ -74,7 +74,7 @@ public void testWillEncode16MBHeaderByDefault() throws Http2Exception { try { hpackEncoder.encodeHeaders(0, buf, headersIn, Http2HeadersEncoder.NEVER_SENSITIVE); - hpackDecoder.setMaxHeaderListSize(bigHeaderSize + 1024, bigHeaderSize + 1024); + hpackDecoder.setMaxHeaderListSize(bigHeaderSize + 1024); hpackDecoder.decode(0, buf, headersOut, false); } finally { buf.release(); diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/HpackTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/HpackTest.java index 77da8665c0a5..fe9fa3260138 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/HpackTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/HpackTest.java @@ -31,6 +31,7 @@ */ package io.netty.handler.codec.http2; +import io.netty.util.internal.ResourcesUtil; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -38,7 +39,6 @@ import java.io.File; import java.io.InputStream; -import java.net.URL; import java.util.ArrayList; import java.util.Collection; @@ -56,8 +56,7 @@ public HpackTest(String fileName) { @Parameters(name = "{0}") public static Collection data() { - URL url = HpackTest.class.getResource(TEST_DIR); - File[] files = new File(url.getFile()).listFiles(); + File[] files = ResourcesUtil.getFile(HpackTest.class, TEST_DIR).listFiles(); if (files == null) { throw new NullPointerException("files"); } diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ClientUpgradeCodecTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ClientUpgradeCodecTest.java index cb774a94ca94..fcfdb4b75e54 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ClientUpgradeCodecTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ClientUpgradeCodecTest.java @@ -14,9 +14,11 @@ */ package io.netty.handler.codec.http2; +import io.netty.channel.Channel; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.http.DefaultFullHttpRequest; import io.netty.handler.codec.http.FullHttpRequest; @@ -43,7 +45,8 @@ public void testUpgradeToHttp2FrameCodec() throws Exception { @Test public void testUpgradeToHttp2MultiplexCodec() throws Exception { - testUpgrade(Http2MultiplexCodecBuilder.forClient(new HttpInboundHandler()).build()); + testUpgrade(Http2MultiplexCodecBuilder.forClient(new HttpInboundHandler()) + .withUpgradeStreamHandler(new ChannelInboundHandlerAdapter()).build()); } private static void testUpgrade(Http2ConnectionHandler handler) throws Exception { diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionHandlerTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionHandlerTest.java index be5c41b0e55d..9b8c62402723 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionHandlerTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionHandlerTest.java @@ -22,8 +22,10 @@ import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelMetadata; import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelConfig; import io.netty.channel.DefaultChannelPromise; import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http2.Http2CodecUtil.SimpleChannelPromiseAggregator; @@ -142,6 +144,11 @@ public void setup() throws Exception { promise = new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE); voidPromise = new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE); + + when(channel.metadata()).thenReturn(new ChannelMetadata(false)); + DefaultChannelConfig config = new DefaultChannelConfig(channel); + when(channel.config()).thenReturn(config); + Throwable fakeException = new RuntimeException("Fake exception"); when(encoder.connection()).thenReturn(connection); when(decoder.connection()).thenReturn(connection); @@ -189,6 +196,7 @@ public Http2Stream answer(InvocationOnMock in) throws Throwable { when(connection.stream(NON_EXISTANT_STREAM_ID)).thenReturn(null); when(connection.numActiveStreams()).thenReturn(1); when(connection.stream(STREAM_ID)).thenReturn(stream); + when(connection.goAwaySent(anyInt(), anyLong(), any(ByteBuf.class))).thenReturn(true); when(stream.open(anyBoolean())).thenReturn(stream); when(encoder.writeSettings(eq(ctx), any(Http2Settings.class), eq(promise))).thenReturn(future); when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT); @@ -638,6 +646,12 @@ public void cannotSendGoAwayFrameWithIncreasingLastStreamIds() throws Exception when(connection.goAwaySent()).thenReturn(true); when(remote.lastStreamKnownByPeer()).thenReturn(STREAM_ID); + doAnswer(new Answer() { + @Override + public Boolean answer(InvocationOnMock invocationOnMock) { + throw new IllegalStateException(); + } + }).when(connection).goAwaySent(anyInt(), anyLong(), any(ByteBuf.class)); handler.goAway(ctx, STREAM_ID + 2, errorCode, data, promise); assertTrue(promise.isDone()); assertFalse(promise.isSuccess()); @@ -677,6 +691,14 @@ public void channelReadCompleteTriggersFlush() throws Exception { verify(ctx, times(1)).flush(); } + @Test + public void channelReadCompleteCallsReadWhenAutoReadFalse() throws Exception { + channel.config().setAutoRead(false); + handler = newHandler(); + handler.channelReadComplete(ctx); + verify(ctx, times(1)).read(); + } + @Test public void channelClosedDoesNotThrowPrefaceException() throws Exception { when(connection.isServer()).thenReturn(true); diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionRoundtripTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionRoundtripTest.java index 9860d28fee90..32febac05f41 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionRoundtripTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionRoundtripTest.java @@ -27,6 +27,7 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOutboundHandlerAdapter; import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPromise; import io.netty.channel.DefaultEventLoopGroup; @@ -37,6 +38,7 @@ import io.netty.handler.codec.http2.Http2TestUtil.Http2Runnable; import io.netty.util.AsciiString; import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCountUtil; import io.netty.util.concurrent.Future; import org.junit.After; import org.junit.Before; @@ -52,11 +54,14 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicReference; +import static io.netty.buffer.Unpooled.EMPTY_BUFFER; import static io.netty.handler.codec.http2.Http2CodecUtil.CONNECTION_STREAM_ID; import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT; +import static io.netty.handler.codec.http2.Http2Error.NO_ERROR; import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; import static io.netty.handler.codec.http2.Http2TestUtil.randomString; import static io.netty.handler.codec.http2.Http2TestUtil.runInChannel; +import static java.lang.Integer.MAX_VALUE; import static java.util.concurrent.TimeUnit.SECONDS; import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.not; @@ -220,7 +225,7 @@ public void run() throws Http2Exception { anyLong()); // The server will not respond, and so don't wait for graceful shutdown - http2Client.gracefulShutdownTimeoutMillis(0); + setClientGracefulShutdownTime(0); } @Test @@ -679,7 +684,7 @@ public void writeOfEmptyReleasedBufferMultipleBuffersTrailersQueuedInFlowControl writeOfEmptyReleasedBufferQueuedInFlowControllerShouldFail(WriteEmptyBufferMode.SECOND_WITH_TRAILERS); } - public void writeOfEmptyReleasedBufferQueuedInFlowControllerShouldFail(final WriteEmptyBufferMode mode) + private void writeOfEmptyReleasedBufferQueuedInFlowControllerShouldFail(final WriteEmptyBufferMode mode) throws Exception { bootstrapEnv(1, 1, 2, 1); @@ -725,6 +730,59 @@ public void run() throws Http2Exception { } } + @Test + public void writeFailureFlowControllerRemoveFrame() + throws Exception { + bootstrapEnv(1, 1, 2, 1); + + final ChannelPromise dataPromise = newPromise(); + final ChannelPromise assertPromise = newPromise(); + + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Client.encoder().writeHeaders(ctx(), 3, EmptyHttp2Headers.INSTANCE, 0, (short) 16, false, 0, false, + newPromise()); + clientChannel.pipeline().addFirst(new ChannelOutboundHandlerAdapter() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + ReferenceCountUtil.release(msg); + + // Ensure we update the window size so we will try to write the rest of the frame while + // processing the flush. + http2Client.encoder().flowController().initialWindowSize(8); + promise.setFailure(new IllegalStateException()); + } + }); + + http2Client.encoder().flowController().initialWindowSize(4); + http2Client.encoder().writeData(ctx(), 3, randomBytes(8), 0, false, dataPromise); + assertTrue(http2Client.encoder().flowController() + .hasFlowControlled(http2Client.connection().stream(3))); + + http2Client.flush(ctx()); + + try { + // The Frame should have been removed after the write failed. + assertFalse(http2Client.encoder().flowController() + .hasFlowControlled(http2Client.connection().stream(3))); + assertPromise.setSuccess(); + } catch (Throwable error) { + assertPromise.setFailure(error); + } + } + }); + + try { + dataPromise.get(); + fail(); + } catch (ExecutionException e) { + assertThat(e.getCause(), is(instanceOf(IllegalStateException.class))); + } + + assertPromise.sync(); + } + @Test public void nonHttp2ExceptionInPipelineShouldNotCloseConnection() throws Exception { bootstrapEnv(1, 1, 2, 1); @@ -766,7 +824,7 @@ public void handlerAdded(ChannelHandlerContext ctx) throws Exception { assertTrue(clientChannel.isOpen()); // Set the timeout very low because we know graceful shutdown won't complete - http2Client.gracefulShutdownTimeoutMillis(0); + setClientGracefulShutdownTime(0); } @Test @@ -774,7 +832,7 @@ public void noMoreStreamIdsShouldSendGoAway() throws Exception { bootstrapEnv(1, 1, 3, 1, 1); // Don't wait for the server to close streams - http2Client.gracefulShutdownTimeoutMillis(0); + setClientGracefulShutdownTime(0); // Create a single stream by sending a HEADERS frame to the server. final Http2Headers headers = dummyHeaders(); @@ -792,7 +850,7 @@ public void run() throws Http2Exception { runInChannel(clientChannel, new Http2Runnable() { @Override public void run() throws Http2Exception { - http2Client.encoder().writeHeaders(ctx(), Integer.MAX_VALUE + 1, headers, 0, (short) 16, false, 0, + http2Client.encoder().writeHeaders(ctx(), MAX_VALUE + 1, headers, 0, (short) 16, false, 0, true, newPromise()); http2Client.flush(ctx()); } @@ -803,6 +861,167 @@ public void run() throws Http2Exception { eq(PROTOCOL_ERROR.code()), any(ByteBuf.class)); } + @Test + public void createStreamAfterReceiveGoAwayShouldNotSendGoAway() throws Exception { + bootstrapEnv(1, 1, 2, 1, 1); + + // We want both sides to do graceful shutdown during the test. + setClientGracefulShutdownTime(10000); + setServerGracefulShutdownTime(10000); + + final CountDownLatch clientGoAwayLatch = new CountDownLatch(1); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + clientGoAwayLatch.countDown(); + return null; + } + }).when(clientListener).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class)); + + // Create a single stream by sending a HEADERS frame to the server. + final Http2Headers headers = dummyHeaders(); + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Client.encoder().writeHeaders(ctx(), 3, headers, 0, (short) 16, false, 0, + false, newPromise()); + http2Client.flush(ctx()); + } + }); + + assertTrue(serverSettingsAckLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + + // Server has received the headers, so the stream is open + assertTrue(requestLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + + runInChannel(serverChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Server.encoder().writeGoAway(serverCtx(), 3, NO_ERROR.code(), EMPTY_BUFFER, serverNewPromise()); + http2Server.flush(serverCtx()); + } + }); + + // wait for the client to receive the GO_AWAY. + assertTrue(clientGoAwayLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + verify(clientListener).onGoAwayRead(any(ChannelHandlerContext.class), eq(3), eq(NO_ERROR.code()), + any(ByteBuf.class)); + + final AtomicReference clientWriteAfterGoAwayFutureRef = new AtomicReference(); + final CountDownLatch clientWriteAfterGoAwayLatch = new CountDownLatch(1); + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + ChannelFuture f = http2Client.encoder().writeHeaders(ctx(), 5, headers, 0, (short) 16, false, 0, + true, newPromise()); + clientWriteAfterGoAwayFutureRef.set(f); + http2Client.flush(ctx()); + f.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + clientWriteAfterGoAwayLatch.countDown(); + } + }); + } + }); + + // Wait for the client's write operation to complete. + assertTrue(clientWriteAfterGoAwayLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + + ChannelFuture clientWriteAfterGoAwayFuture = clientWriteAfterGoAwayFutureRef.get(); + assertNotNull(clientWriteAfterGoAwayFuture); + Throwable clientCause = clientWriteAfterGoAwayFuture.cause(); + assertThat(clientCause, is(instanceOf(Http2Exception.StreamException.class))); + assertEquals(Http2Error.REFUSED_STREAM.code(), ((Http2Exception.StreamException) clientCause).error().code()); + + // Wait for the server to receive a GO_AWAY, but this is expected to timeout! + assertFalse(goAwayLatch.await(1, SECONDS)); + verify(serverListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(), + any(ByteBuf.class)); + + // Shutdown shouldn't wait for the server to close streams + setClientGracefulShutdownTime(0); + setServerGracefulShutdownTime(0); + } + + @Test + public void createStreamSynchronouslyAfterGoAwayReceivedShouldFailLocally() throws Exception { + bootstrapEnv(1, 1, 2, 1, 1); + + final CountDownLatch clientGoAwayLatch = new CountDownLatch(1); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + clientGoAwayLatch.countDown(); + return null; + } + }).when(clientListener).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class)); + + // We want both sides to do graceful shutdown during the test. + setClientGracefulShutdownTime(10000); + setServerGracefulShutdownTime(10000); + + final Http2Headers headers = dummyHeaders(); + final AtomicReference clientWriteAfterGoAwayFutureRef = new AtomicReference(); + final CountDownLatch clientWriteAfterGoAwayLatch = new CountDownLatch(1); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + ChannelFuture f = http2Client.encoder().writeHeaders(ctx(), 5, headers, 0, (short) 16, false, 0, + true, newPromise()); + clientWriteAfterGoAwayFutureRef.set(f); + f.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + clientWriteAfterGoAwayLatch.countDown(); + } + }); + http2Client.flush(ctx()); + return null; + } + }).when(clientListener).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class)); + + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Client.encoder().writeHeaders(ctx(), 3, headers, 0, (short) 16, false, 0, + true, newPromise()); + http2Client.flush(ctx()); + } + }); + + assertTrue(serverSettingsAckLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + + // Server has received the headers, so the stream is open + assertTrue(requestLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + + runInChannel(serverChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Server.encoder().writeGoAway(serverCtx(), 3, NO_ERROR.code(), EMPTY_BUFFER, serverNewPromise()); + http2Server.flush(serverCtx()); + } + }); + + // Wait for the client's write operation to complete. + assertTrue(clientWriteAfterGoAwayLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + + ChannelFuture clientWriteAfterGoAwayFuture = clientWriteAfterGoAwayFutureRef.get(); + assertNotNull(clientWriteAfterGoAwayFuture); + Throwable clientCause = clientWriteAfterGoAwayFuture.cause(); + assertThat(clientCause, is(instanceOf(Http2Exception.StreamException.class))); + assertEquals(Http2Error.REFUSED_STREAM.code(), ((Http2Exception.StreamException) clientCause).error().code()); + + // Wait for the server to receive a GO_AWAY, but this is expected to timeout! + assertFalse(goAwayLatch.await(1, SECONDS)); + verify(serverListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(), + any(ByteBuf.class)); + + // Shutdown shouldn't wait for the server to close streams + setClientGracefulShutdownTime(0); + setServerGracefulShutdownTime(0); + } + @Test public void flowControlProperlyChunksLargeMessage() throws Exception { final Http2Headers headers = dummyHeaders(); @@ -861,7 +1080,7 @@ public void run() throws Http2Exception { assertArrayEquals(data.array(), received); } finally { // Don't wait for server to close streams - http2Client.gracefulShutdownTimeoutMillis(0); + setClientGracefulShutdownTime(0); data.release(); out.close(); } @@ -949,7 +1168,7 @@ public void run() throws Http2Exception { } } finally { // Don't wait for server to close streams - http2Client.gracefulShutdownTimeoutMillis(0); + setClientGracefulShutdownTime(0); data.release(); } } @@ -1063,6 +1282,28 @@ public Integer answer(InvocationOnMock invocation) throws Throwable { any(ByteBuf.class), anyInt(), anyBoolean()); } + private void setClientGracefulShutdownTime(final long millis) throws InterruptedException { + setGracefulShutdownTime(clientChannel, http2Client, millis); + } + + private void setServerGracefulShutdownTime(final long millis) throws InterruptedException { + setGracefulShutdownTime(serverChannel, http2Server, millis); + } + + private static void setGracefulShutdownTime(Channel channel, final Http2ConnectionHandler handler, + final long millis) throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(1); + runInChannel(channel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + handler.gracefulShutdownTimeoutMillis(millis); + latch.countDown(); + } + }); + + assertTrue(latch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + } + /** * Creates a {@link ByteBuf} of the given length, filled with random bytes. */ diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2FrameCodecTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2FrameCodecTest.java index 212398604723..27d13cf50b32 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2FrameCodecTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2FrameCodecTest.java @@ -15,9 +15,7 @@ package io.netty.handler.codec.http2; import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufUtil; import io.netty.buffer.Unpooled; -import io.netty.buffer.UnpooledByteBufAllocator; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; @@ -57,6 +55,11 @@ import java.util.concurrent.atomic.AtomicBoolean; import static io.netty.handler.codec.http2.Http2CodecUtil.isStreamIdValid; +import static io.netty.handler.codec.http2.Http2TestUtil.anyChannelPromise; +import static io.netty.handler.codec.http2.Http2TestUtil.anyHttp2Settings; +import static io.netty.handler.codec.http2.Http2TestUtil.assertEqualsAndRelease; +import static io.netty.handler.codec.http2.Http2TestUtil.bb; + import static org.hamcrest.Matchers.instanceOf; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -73,7 +76,6 @@ import static org.mockito.Mockito.eq; import static org.mockito.Mockito.never; import static org.mockito.Mockito.same; -import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -86,9 +88,10 @@ public class Http2FrameCodecTest { private Http2FrameWriter frameWriter; private Http2FrameCodec frameCodec; private EmbeddedChannel channel; + // For injecting inbound frames - private Http2FrameListener frameListener; - private ChannelHandlerContext http2HandlerCtx; + private Http2FrameInboundWriter frameInboundWriter; + private LastInboundHandler inboundHandler; private final Http2Headers request = new DefaultHttp2Headers() @@ -123,29 +126,29 @@ private void setUp(Http2FrameCodecBuilder frameCodecBuilder, Http2Settings initi */ tearDown(); - frameWriter = spy(new VerifiableHttp2FrameWriter()); + frameWriter = Http2TestUtil.mockedFrameWriter(); + frameCodec = frameCodecBuilder.frameWriter(frameWriter).frameLogger(new Http2FrameLogger(LogLevel.TRACE)) .initialSettings(initialRemoteSettings).build(); - frameListener = ((DefaultHttp2ConnectionDecoder) frameCodec.decoder()) - .internalFrameListener(); inboundHandler = new LastInboundHandler(); channel = new EmbeddedChannel(); + frameInboundWriter = new Http2FrameInboundWriter(channel); channel.connect(new InetSocketAddress(0)); channel.pipeline().addLast(frameCodec); channel.pipeline().addLast(inboundHandler); channel.pipeline().fireChannelActive(); - http2HandlerCtx = channel.pipeline().context(frameCodec); - // Handshake - verify(frameWriter).writeSettings(eq(http2HandlerCtx), - anyHttp2Settings(), anyChannelPromise()); + verify(frameWriter).writeSettings(eqFrameCodecCtx(), anyHttp2Settings(), anyChannelPromise()); verifyNoMoreInteractions(frameWriter); channel.writeInbound(Http2CodecUtil.connectionPrefaceBuf()); - frameListener.onSettingsRead(http2HandlerCtx, initialRemoteSettings); - verify(frameWriter).writeSettingsAck(eq(http2HandlerCtx), anyChannelPromise()); - frameListener.onSettingsAckRead(http2HandlerCtx); + + frameInboundWriter.writeInboundSettings(initialRemoteSettings); + + verify(frameWriter).writeSettingsAck(eqFrameCodecCtx(), anyChannelPromise()); + + frameInboundWriter.writeInboundSettingsAck(); Http2SettingsFrame settingsFrame = inboundHandler.readInbound(); assertNotNull(settingsFrame); @@ -153,7 +156,7 @@ private void setUp(Http2FrameCodecBuilder frameCodecBuilder, Http2Settings initi @Test public void stateChanges() throws Exception { - frameListener.onHeadersRead(http2HandlerCtx, 1, request, 31, true); + frameInboundWriter.writeInboundHeaders(1, request, 31, true); Http2Stream stream = frameCodec.connection().stream(1); assertNotNull(stream); @@ -169,12 +172,12 @@ public void stateChanges() throws Exception { assertEquals(inboundFrame, new DefaultHttp2HeadersFrame(request, true, 31).stream(stream2)); assertNull(inboundHandler.readInbound()); - inboundHandler.writeOutbound(new DefaultHttp2HeadersFrame(response, true, 27).stream(stream2)); + channel.writeOutbound(new DefaultHttp2HeadersFrame(response, true, 27).stream(stream2)); verify(frameWriter).writeHeaders( - eq(http2HandlerCtx), eq(1), eq(response), anyInt(), anyShort(), anyBoolean(), + eqFrameCodecCtx(), eq(1), eq(response), anyInt(), anyShort(), anyBoolean(), eq(27), eq(true), anyChannelPromise()); verify(frameWriter, never()).writeRstStream( - any(ChannelHandlerContext.class), anyInt(), anyLong(), anyChannelPromise()); + eqFrameCodecCtx(), anyInt(), anyLong(), anyChannelPromise()); assertEquals(State.CLOSED, stream.state()); event = inboundHandler.readInboundMessageOrUserEvent(); @@ -185,7 +188,7 @@ public void stateChanges() throws Exception { @Test public void headerRequestHeaderResponse() throws Exception { - frameListener.onHeadersRead(http2HandlerCtx, 1, request, 31, true); + frameInboundWriter.writeInboundHeaders(1, request, 31, true); Http2Stream stream = frameCodec.connection().stream(1); assertNotNull(stream); @@ -198,34 +201,34 @@ public void headerRequestHeaderResponse() throws Exception { assertEquals(inboundFrame, new DefaultHttp2HeadersFrame(request, true, 31).stream(stream2)); assertNull(inboundHandler.readInbound()); - inboundHandler.writeOutbound(new DefaultHttp2HeadersFrame(response, true, 27).stream(stream2)); + channel.writeOutbound(new DefaultHttp2HeadersFrame(response, true, 27).stream(stream2)); verify(frameWriter).writeHeaders( - eq(http2HandlerCtx), eq(1), eq(response), anyInt(), anyShort(), anyBoolean(), + eqFrameCodecCtx(), eq(1), eq(response), anyInt(), anyShort(), anyBoolean(), eq(27), eq(true), anyChannelPromise()); verify(frameWriter, never()).writeRstStream( - any(ChannelHandlerContext.class), anyInt(), anyLong(), anyChannelPromise()); + eqFrameCodecCtx(), anyInt(), anyLong(), anyChannelPromise()); assertEquals(State.CLOSED, stream.state()); assertTrue(channel.isActive()); } - @Test - public void flowControlShouldBeResilientToMissingStreams() throws Http2Exception { - Http2Connection conn = new DefaultHttp2Connection(true); - Http2ConnectionEncoder enc = new DefaultHttp2ConnectionEncoder(conn, new DefaultHttp2FrameWriter()); - Http2ConnectionDecoder dec = new DefaultHttp2ConnectionDecoder(conn, enc, new DefaultHttp2FrameReader()); - Http2FrameCodec codec = new Http2FrameCodec(enc, dec, new Http2Settings()); - EmbeddedChannel em = new EmbeddedChannel(codec); + @Test + public void flowControlShouldBeResilientToMissingStreams() throws Http2Exception { + Http2Connection conn = new DefaultHttp2Connection(true); + Http2ConnectionEncoder enc = new DefaultHttp2ConnectionEncoder(conn, new DefaultHttp2FrameWriter()); + Http2ConnectionDecoder dec = new DefaultHttp2ConnectionDecoder(conn, enc, new DefaultHttp2FrameReader()); + Http2FrameCodec codec = new Http2FrameCodec(enc, dec, new Http2Settings()); + EmbeddedChannel em = new EmbeddedChannel(codec); - // We call #consumeBytes on a stream id which has not been seen yet to emulate the case - // where a stream is deregistered which in reality can happen in response to a RST. - assertFalse(codec.consumeBytes(1, 1)); - assertTrue(em.finishAndReleaseAll()); - } + // We call #consumeBytes on a stream id which has not been seen yet to emulate the case + // where a stream is deregistered which in reality can happen in response to a RST. + assertFalse(codec.consumeBytes(1, 1)); + assertTrue(em.finishAndReleaseAll()); + } @Test public void entityRequestEntityResponse() throws Exception { - frameListener.onHeadersRead(http2HandlerCtx, 1, request, 0, false); + frameInboundWriter.writeInboundHeaders(1, request, 0, false); Http2Stream stream = frameCodec.connection().stream(1); assertNotNull(stream); @@ -239,39 +242,35 @@ public void entityRequestEntityResponse() throws Exception { assertNull(inboundHandler.readInbound()); ByteBuf hello = bb("hello"); - frameListener.onDataRead(http2HandlerCtx, 1, hello, 31, true); - // Release hello to emulate ByteToMessageDecoder - hello.release(); + frameInboundWriter.writeInboundData(1, hello, 31, true); Http2DataFrame inboundData = inboundHandler.readInbound(); Http2DataFrame expected = new DefaultHttp2DataFrame(bb("hello"), true, 31).stream(stream2); - assertEquals(expected, inboundData); + assertEqualsAndRelease(expected, inboundData); - assertEquals(1, inboundData.refCnt()); - expected.release(); - inboundData.release(); assertNull(inboundHandler.readInbound()); - inboundHandler.writeOutbound(new DefaultHttp2HeadersFrame(response, false).stream(stream2)); - verify(frameWriter).writeHeaders(eq(http2HandlerCtx), eq(1), eq(response), anyInt(), + channel.writeOutbound(new DefaultHttp2HeadersFrame(response, false).stream(stream2)); + verify(frameWriter).writeHeaders(eqFrameCodecCtx(), eq(1), eq(response), anyInt(), anyShort(), anyBoolean(), eq(0), eq(false), anyChannelPromise()); - inboundHandler.writeOutbound(new DefaultHttp2DataFrame(bb("world"), true, 27).stream(stream2)); + channel.writeOutbound(new DefaultHttp2DataFrame(bb("world"), true, 27).stream(stream2)); ArgumentCaptor outboundData = ArgumentCaptor.forClass(ByteBuf.class); - verify(frameWriter).writeData(eq(http2HandlerCtx), eq(1), outboundData.capture(), eq(27), + verify(frameWriter).writeData(eqFrameCodecCtx(), eq(1), outboundData.capture(), eq(27), eq(true), anyChannelPromise()); ByteBuf bb = bb("world"); assertEquals(bb, outboundData.getValue()); assertEquals(1, outboundData.getValue().refCnt()); bb.release(); - verify(frameWriter, never()).writeRstStream( - any(ChannelHandlerContext.class), anyInt(), anyLong(), anyChannelPromise()); + outboundData.getValue().release(); + + verify(frameWriter, never()).writeRstStream(eqFrameCodecCtx(), anyInt(), anyLong(), anyChannelPromise()); assertTrue(channel.isActive()); } @Test public void sendRstStream() throws Exception { - frameListener.onHeadersRead(http2HandlerCtx, 3, request, 31, true); + frameInboundWriter.writeInboundHeaders(3, request, 31, true); Http2Stream stream = frameCodec.connection().stream(3); assertNotNull(stream); @@ -285,16 +284,15 @@ public void sendRstStream() throws Exception { assertNotNull(stream2); assertEquals(3, stream2.id()); - inboundHandler.writeOutbound(new DefaultHttp2ResetFrame(314 /* non-standard error */).stream(stream2)); - verify(frameWriter).writeRstStream( - eq(http2HandlerCtx), eq(3), eq(314L), anyChannelPromise()); + channel.writeOutbound(new DefaultHttp2ResetFrame(314 /* non-standard error */).stream(stream2)); + verify(frameWriter).writeRstStream(eqFrameCodecCtx(), eq(3), eq(314L), anyChannelPromise()); assertEquals(State.CLOSED, stream.state()); assertTrue(channel.isActive()); } @Test public void receiveRstStream() throws Exception { - frameListener.onHeadersRead(http2HandlerCtx, 3, request, 31, false); + frameInboundWriter.writeInboundHeaders(3, request, 31, false); Http2Stream stream = frameCodec.connection().stream(3); assertNotNull(stream); @@ -304,7 +302,7 @@ public void receiveRstStream() throws Exception { Http2HeadersFrame actualHeaders = inboundHandler.readInbound(); assertEquals(expectedHeaders.stream(actualHeaders.stream()), actualHeaders); - frameListener.onRstStreamRead(http2HandlerCtx, 3, Http2Error.NO_ERROR.code()); + frameInboundWriter.writeInboundRstStream(3, Http2Error.NO_ERROR.code()); Http2ResetFrame expectedRst = new DefaultHttp2ResetFrame(Http2Error.NO_ERROR).stream(actualHeaders.stream()); Http2ResetFrame actualRst = inboundHandler.readInbound(); @@ -315,8 +313,7 @@ public void receiveRstStream() throws Exception { @Test public void sendGoAway() throws Exception { - frameListener.onHeadersRead(http2HandlerCtx, 3, request, 31, false); - + frameInboundWriter.writeInboundHeaders(3, request, 31, false); Http2Stream stream = frameCodec.connection().stream(3); assertNotNull(stream); assertEquals(State.OPEN, stream.state()); @@ -324,32 +321,29 @@ public void sendGoAway() throws Exception { ByteBuf debugData = bb("debug"); ByteBuf expected = debugData.copy(); - Http2GoAwayFrame goAwayFrame = new DefaultHttp2GoAwayFrame(Http2Error.NO_ERROR.code(), debugData.slice()); + Http2GoAwayFrame goAwayFrame = new DefaultHttp2GoAwayFrame(Http2Error.NO_ERROR.code(), debugData); goAwayFrame.setExtraStreamIds(2); - inboundHandler.writeOutbound(goAwayFrame); - verify(frameWriter).writeGoAway( - eq(http2HandlerCtx), eq(7), eq(Http2Error.NO_ERROR.code()), eq(expected), anyChannelPromise()); + channel.writeOutbound(goAwayFrame); + verify(frameWriter).writeGoAway(eqFrameCodecCtx(), eq(7), + eq(Http2Error.NO_ERROR.code()), eq(expected), anyChannelPromise()); assertEquals(1, debugData.refCnt()); assertEquals(State.OPEN, stream.state()); assertTrue(channel.isActive()); expected.release(); + debugData.release(); } @Test public void receiveGoaway() throws Exception { ByteBuf debugData = bb("foo"); - frameListener.onGoAwayRead(http2HandlerCtx, 2, Http2Error.NO_ERROR.code(), debugData); - // Release debugData to emulate ByteToMessageDecoder - debugData.release(); + frameInboundWriter.writeInboundGoAway(2, Http2Error.NO_ERROR.code(), debugData); Http2GoAwayFrame expectedFrame = new DefaultHttp2GoAwayFrame(2, Http2Error.NO_ERROR.code(), bb("foo")); Http2GoAwayFrame actualFrame = inboundHandler.readInbound(); - assertEquals(expectedFrame, actualFrame); - assertNull(inboundHandler.readInbound()); + assertEqualsAndRelease(expectedFrame, actualFrame); - expectedFrame.release(); - actualFrame.release(); + assertNull(inboundHandler.readInbound()); } @Test @@ -383,7 +377,7 @@ public ReferenceCounted touch(Object hint) { @Test public void goAwayLastStreamIdOverflowed() throws Exception { - frameListener.onHeadersRead(http2HandlerCtx, 5, request, 31, false); + frameInboundWriter.writeInboundHeaders(5, request, 31, false); Http2Stream stream = frameCodec.connection().stream(5); assertNotNull(stream); @@ -393,10 +387,10 @@ public void goAwayLastStreamIdOverflowed() throws Exception { Http2GoAwayFrame goAwayFrame = new DefaultHttp2GoAwayFrame(Http2Error.NO_ERROR.code(), debugData.slice()); goAwayFrame.setExtraStreamIds(Integer.MAX_VALUE); - inboundHandler.writeOutbound(goAwayFrame); + channel.writeOutbound(goAwayFrame); // When the last stream id computation overflows, the last stream id should just be set to 2^31 - 1. - verify(frameWriter).writeGoAway(eq(http2HandlerCtx), eq(Integer.MAX_VALUE), eq(Http2Error.NO_ERROR.code()), - eq(debugData), anyChannelPromise()); + verify(frameWriter).writeGoAway(eqFrameCodecCtx(), eq(Integer.MAX_VALUE), + eq(Http2Error.NO_ERROR.code()), eq(debugData), anyChannelPromise()); assertEquals(1, debugData.refCnt()); assertEquals(State.OPEN, stream.state()); assertTrue(channel.isActive()); @@ -404,13 +398,13 @@ public void goAwayLastStreamIdOverflowed() throws Exception { @Test public void streamErrorShouldFireExceptionForInbound() throws Exception { - frameListener.onHeadersRead(http2HandlerCtx, 3, request, 31, false); + frameInboundWriter.writeInboundHeaders(3, request, 31, false); Http2Stream stream = frameCodec.connection().stream(3); assertNotNull(stream); StreamException streamEx = new StreamException(3, Http2Error.INTERNAL_ERROR, "foo"); - frameCodec.onError(http2HandlerCtx, false, streamEx); + channel.pipeline().fireExceptionCaught(streamEx); Http2FrameStreamEvent event = inboundHandler.readInboundMessageOrUserEvent(); assertEquals(Http2FrameStreamEvent.Type.State, event.type()); @@ -430,13 +424,13 @@ public void streamErrorShouldFireExceptionForInbound() throws Exception { @Test public void streamErrorShouldNotFireExceptionForOutbound() throws Exception { - frameListener.onHeadersRead(http2HandlerCtx, 3, request, 31, false); + frameInboundWriter.writeInboundHeaders(3, request, 31, false); Http2Stream stream = frameCodec.connection().stream(3); assertNotNull(stream); StreamException streamEx = new StreamException(3, Http2Error.INTERNAL_ERROR, "foo"); - frameCodec.onError(http2HandlerCtx, true, streamEx); + frameCodec.onError(frameCodec.ctx, true, streamEx); Http2FrameStreamEvent event = inboundHandler.readInboundMessageOrUserEvent(); assertEquals(Http2FrameStreamEvent.Type.State, event.type()); @@ -452,14 +446,14 @@ public void streamErrorShouldNotFireExceptionForOutbound() throws Exception { @Test public void windowUpdateFrameDecrementsConsumedBytes() throws Exception { - frameListener.onHeadersRead(http2HandlerCtx, 3, request, 31, false); + frameInboundWriter.writeInboundHeaders(3, request, 31, false); Http2Connection connection = frameCodec.connection(); Http2Stream stream = connection.stream(3); assertNotNull(stream); ByteBuf data = Unpooled.buffer(100).writeZero(100); - frameListener.onDataRead(http2HandlerCtx, 3, data, 0, true); + frameInboundWriter.writeInboundData(3, data, 0, false); Http2HeadersFrame inboundHeaders = inboundHandler.readInbound(); assertNotNull(inboundHeaders); @@ -472,12 +466,11 @@ public void windowUpdateFrameDecrementsConsumedBytes() throws Exception { int after = connection.local().flowController().unconsumedBytes(stream); assertEquals(100, before - after); assertTrue(f.isSuccess()); - data.release(); } @Test public void windowUpdateMayFail() throws Exception { - frameListener.onHeadersRead(http2HandlerCtx, 3, request, 31, false); + frameInboundWriter.writeInboundHeaders(3, request, 31, false); Http2Connection connection = frameCodec.connection(); Http2Stream stream = connection.stream(3); assertNotNull(stream); @@ -496,10 +489,10 @@ public void windowUpdateMayFail() throws Exception { @Test public void inboundWindowUpdateShouldBeForwarded() throws Exception { - frameListener.onHeadersRead(http2HandlerCtx, 3, request, 31, false); - frameListener.onWindowUpdateRead(http2HandlerCtx, 3, 100); + frameInboundWriter.writeInboundHeaders(3, request, 31, false); + frameInboundWriter.writeInboundWindowUpdate(3, 100); // Connection-level window update - frameListener.onWindowUpdateRead(http2HandlerCtx, 0, 100); + frameInboundWriter.writeInboundWindowUpdate(0, 100); Http2HeadersFrame headersFrame = inboundHandler.readInbound(); assertNotNull(headersFrame); @@ -558,7 +551,7 @@ public void writeUnknownFrame() { unknownFrame.stream(stream); channel.write(unknownFrame); - verify(frameWriter).writeFrame(eq(http2HandlerCtx), eq(unknownFrame.frameType()), + verify(frameWriter).writeFrame(eqFrameCodecCtx(), eq(unknownFrame.frameType()), eq(unknownFrame.stream().id()), eq(unknownFrame.flags()), eq(buffer), any(ChannelPromise.class)); } @@ -567,7 +560,7 @@ public void sendSettingsFrame() { Http2Settings settings = new Http2Settings(); channel.write(new DefaultHttp2SettingsFrame(settings)); - verify(frameWriter).writeSettings(eq(http2HandlerCtx), same(settings), any(ChannelPromise.class)); + verify(frameWriter).writeSettings(eqFrameCodecCtx(), same(settings), any(ChannelPromise.class)); } @Test(timeout = 5000) @@ -619,11 +612,54 @@ public void newOutboundStreamsShouldBeBuffered() throws Exception { assertFalse(promise2.isDone()); // Increase concurrent streams limit to 2 - frameListener.onSettingsRead(http2HandlerCtx, new Http2Settings().maxConcurrentStreams(2)); + frameInboundWriter.writeInboundSettings(new Http2Settings().maxConcurrentStreams(2)); + + channel.flush(); + + assertTrue(promise2.syncUninterruptibly().isSuccess()); + } + + @Test + public void multipleNewOutboundStreamsShouldBeBuffered() throws Exception { + // We use a limit of 1 and then increase it step by step. + setUp(Http2FrameCodecBuilder.forServer().encoderEnforceMaxConcurrentStreams(true), + new Http2Settings().maxConcurrentStreams(1)); + + Http2FrameStream stream1 = frameCodec.newStream(); + Http2FrameStream stream2 = frameCodec.newStream(); + Http2FrameStream stream3 = frameCodec.newStream(); + + ChannelPromise promise1 = channel.newPromise(); + ChannelPromise promise2 = channel.newPromise(); + ChannelPromise promise3 = channel.newPromise(); + + channel.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers()).stream(stream1), promise1); + channel.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers()).stream(stream2), promise2); + channel.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers()).stream(stream3), promise3); + + assertTrue(isStreamIdValid(stream1.id())); + channel.runPendingTasks(); + assertTrue(isStreamIdValid(stream2.id())); + + assertTrue(promise1.syncUninterruptibly().isSuccess()); + assertFalse(promise2.isDone()); + assertFalse(promise3.isDone()); + // Increase concurrent streams limit to 2 + frameInboundWriter.writeInboundSettings(new Http2Settings().maxConcurrentStreams(2)); channel.flush(); + // As we increased the limit to 2 we should have also succeed the second frame. assertTrue(promise2.syncUninterruptibly().isSuccess()); + assertFalse(promise3.isDone()); + + frameInboundWriter.writeInboundSettings(new Http2Settings().maxConcurrentStreams(3)); + channel.flush(); + + // With the max streams of 3 all streams should be succeed now. + assertTrue(promise3.syncUninterruptibly().isSuccess()); + + assertFalse(channel.finishAndReleaseAll()); } @Test @@ -643,7 +679,7 @@ public void streamIdentifiersExhausted() throws Http2Exception { @Test public void receivePing() throws Http2Exception { - frameListener.onPingRead(http2HandlerCtx, 12345L); + frameInboundWriter.writeInboundPing(false, 12345L); Http2PingFrame pingFrame = inboundHandler.readInbound(); assertNotNull(pingFrame); @@ -656,13 +692,13 @@ public void receivePing() throws Http2Exception { public void sendPing() { channel.writeAndFlush(new DefaultHttp2PingFrame(12345)); - verify(frameWriter).writePing(eq(http2HandlerCtx), eq(false), eq(12345L), anyChannelPromise()); + verify(frameWriter).writePing(eqFrameCodecCtx(), eq(false), eq(12345L), anyChannelPromise()); } @Test public void receiveSettings() throws Http2Exception { Http2Settings settings = new Http2Settings().maxConcurrentStreams(1); - frameListener.onSettingsRead(http2HandlerCtx, settings); + frameInboundWriter.writeInboundSettings(settings); Http2SettingsFrame settingsFrame = inboundHandler.readInbound(); assertNotNull(settingsFrame); @@ -674,7 +710,7 @@ public void sendSettings() { Http2Settings settings = new Http2Settings().maxConcurrentStreams(1); channel.writeAndFlush(new DefaultHttp2SettingsFrame(settings)); - verify(frameWriter).writeSettings(eq(http2HandlerCtx), eq(settings), anyChannelPromise()); + verify(frameWriter).writeSettings(eqFrameCodecCtx(), eq(settings), anyChannelPromise()); } @Test @@ -682,7 +718,7 @@ public void iterateActiveStreams() throws Exception { setUp(Http2FrameCodecBuilder.forServer().encoderEnforceMaxConcurrentStreams(true), new Http2Settings().maxConcurrentStreams(1)); - frameListener.onHeadersRead(http2HandlerCtx, 3, request, 0, false); + frameInboundWriter.writeInboundHeaders(3, request, 0, false); Http2HeadersFrame headersFrame = inboundHandler.readInbound(); assertNotNull(headersFrame); @@ -736,8 +772,7 @@ public void operationComplete(ChannelFuture future) throws Exception { @Test public void upgradeEventNoRefCntError() throws Exception { - frameListener.onHeadersRead(http2HandlerCtx, Http2CodecUtil.HTTP_UPGRADE_STREAM_ID, request, 31, false); - + frameInboundWriter.writeInboundHeaders(Http2CodecUtil.HTTP_UPGRADE_STREAM_ID, request, 31, false); // Using reflect as the constructor is package-private and the class is final. Constructor constructor = UpgradeEvent.class.getDeclaredConstructor(CharSequence.class, FullHttpRequest.class); @@ -753,7 +788,7 @@ public void upgradeEventNoRefCntError() throws Exception { @Test public void upgradeWithoutFlowControlling() throws Exception { - channel.pipeline().addAfter(http2HandlerCtx.name(), null, new ChannelInboundHandlerAdapter() { + channel.pipeline().addAfter(frameCodec.ctx.name(), null, new ChannelInboundHandlerAdapter() { @Override public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception { if (msg instanceof Http2DataFrame) { @@ -774,7 +809,7 @@ public void operationComplete(ChannelFuture future) throws Exception { } }); - frameListener.onHeadersRead(http2HandlerCtx, Http2CodecUtil.HTTP_UPGRADE_STREAM_ID, request, 31, false); + frameInboundWriter.writeInboundHeaders(Http2CodecUtil.HTTP_UPGRADE_STREAM_ID, request, 31, false); // Using reflect as the constructor is package-private and the class is final. Constructor constructor = @@ -792,24 +827,7 @@ public void operationComplete(ChannelFuture future) throws Exception { channel.pipeline().fireUserEventTriggered(upgradeEvent); } - private static ChannelPromise anyChannelPromise() { - return any(ChannelPromise.class); - } - - private static Http2Settings anyHttp2Settings() { - return any(Http2Settings.class); - } - - private static ByteBuf bb(String s) { - return ByteBufUtil.writeUtf8(UnpooledByteBufAllocator.DEFAULT, s); - } - - private static class VerifiableHttp2FrameWriter extends DefaultHttp2FrameWriter { - @Override - public ChannelFuture writeData(ChannelHandlerContext ctx, int streamId, ByteBuf data, - int padding, boolean endStream, ChannelPromise promise) { - // duplicate 'data' to prevent readerIndex from being changed, to ease verification - return super.writeData(ctx, streamId, data.duplicate(), padding, endStream, promise); - } + private ChannelHandlerContext eqFrameCodecCtx() { + return eq(frameCodec.ctx); } } diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2FrameInboundWriter.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2FrameInboundWriter.java new file mode 100644 index 000000000000..ace50aadbea6 --- /dev/null +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2FrameInboundWriter.java @@ -0,0 +1,340 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelProgressivePromise; +import io.netty.channel.ChannelPromise; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.Attribute; +import io.netty.util.AttributeKey; +import io.netty.util.concurrent.EventExecutor; + +import java.net.SocketAddress; + +/** + * Utility class which allows easy writing of HTTP2 frames via {@link EmbeddedChannel#writeInbound(Object...)}. + */ +final class Http2FrameInboundWriter { + + private final ChannelHandlerContext ctx; + private final Http2FrameWriter writer; + + Http2FrameInboundWriter(EmbeddedChannel channel) { + this(channel, new DefaultHttp2FrameWriter()); + } + + Http2FrameInboundWriter(EmbeddedChannel channel, Http2FrameWriter writer) { + this.ctx = new WriteInboundChannelHandlerContext(channel); + this.writer = writer; + } + + void writeInboundData(int streamId, ByteBuf data, int padding, boolean endStream) { + writer.writeData(ctx, streamId, data, padding, endStream, ctx.newPromise()).syncUninterruptibly(); + } + + void writeInboundHeaders(int streamId, Http2Headers headers, + int padding, boolean endStream) { + writer.writeHeaders(ctx, streamId, headers, padding, endStream, ctx.newPromise()).syncUninterruptibly(); + } + + void writeInboundHeaders(int streamId, Http2Headers headers, + int streamDependency, short weight, boolean exclusive, int padding, boolean endStream) { + writer.writeHeaders(ctx, streamId, headers, streamDependency, + weight, exclusive, padding, endStream, ctx.newPromise()).syncUninterruptibly(); + } + + void writeInboundPriority(int streamId, int streamDependency, + short weight, boolean exclusive) { + writer.writePriority(ctx, streamId, streamDependency, weight, + exclusive, ctx.newPromise()).syncUninterruptibly(); + } + + void writeInboundRstStream(int streamId, long errorCode) { + writer.writeRstStream(ctx, streamId, errorCode, ctx.newPromise()).syncUninterruptibly(); + } + + void writeInboundSettings(Http2Settings settings) { + writer.writeSettings(ctx, settings, ctx.newPromise()).syncUninterruptibly(); + } + + void writeInboundSettingsAck() { + writer.writeSettingsAck(ctx, ctx.newPromise()).syncUninterruptibly(); + } + + void writeInboundPing(boolean ack, long data) { + writer.writePing(ctx, ack, data, ctx.newPromise()).syncUninterruptibly(); + } + + void writePushPromise(int streamId, int promisedStreamId, + Http2Headers headers, int padding) { + writer.writePushPromise(ctx, streamId, promisedStreamId, + headers, padding, ctx.newPromise()).syncUninterruptibly(); + } + + void writeInboundGoAway(int lastStreamId, long errorCode, ByteBuf debugData) { + writer.writeGoAway(ctx, lastStreamId, errorCode, debugData, ctx.newPromise()).syncUninterruptibly(); + } + + void writeInboundWindowUpdate(int streamId, int windowSizeIncrement) { + writer.writeWindowUpdate(ctx, streamId, windowSizeIncrement, ctx.newPromise()).syncUninterruptibly(); + } + + void writeInboundFrame(byte frameType, int streamId, + Http2Flags flags, ByteBuf payload) { + writer.writeFrame(ctx, frameType, streamId, flags, payload, ctx.newPromise()).syncUninterruptibly(); + } + + private static final class WriteInboundChannelHandlerContext extends ChannelOutboundHandlerAdapter + implements ChannelHandlerContext { + private final EmbeddedChannel channel; + + WriteInboundChannelHandlerContext(EmbeddedChannel channel) { + this.channel = channel; + } + + @Override + public Channel channel() { + return channel; + } + + @Override + public EventExecutor executor() { + return channel.eventLoop(); + } + + @Override + public String name() { + return "WriteInbound"; + } + + @Override + public ChannelHandler handler() { + return this; + } + + @Override + public boolean isRemoved() { + return false; + } + + @Override + public ChannelHandlerContext fireChannelRegistered() { + channel.pipeline().fireChannelRegistered(); + return this; + } + + @Override + public ChannelHandlerContext fireChannelUnregistered() { + channel.pipeline().fireChannelUnregistered(); + return this; + } + + @Override + public ChannelHandlerContext fireChannelActive() { + channel.pipeline().fireChannelActive(); + return this; + } + + @Override + public ChannelHandlerContext fireChannelInactive() { + channel.pipeline().fireChannelInactive(); + return this; + } + + @Override + public ChannelHandlerContext fireExceptionCaught(Throwable cause) { + channel.pipeline().fireExceptionCaught(cause); + return this; + } + + @Override + public ChannelHandlerContext fireUserEventTriggered(Object evt) { + channel.pipeline().fireUserEventTriggered(evt); + return this; + } + + @Override + public ChannelHandlerContext fireChannelRead(Object msg) { + channel.pipeline().fireChannelRead(msg); + return this; + } + + @Override + public ChannelHandlerContext fireChannelReadComplete() { + channel.pipeline().fireChannelReadComplete(); + return this; + } + + @Override + public ChannelHandlerContext fireChannelWritabilityChanged() { + channel.pipeline().fireChannelWritabilityChanged(); + return this; + } + + @Override + public ChannelHandlerContext read() { + channel.read(); + return this; + } + + @Override + public ChannelHandlerContext flush() { + channel.pipeline().fireChannelReadComplete(); + return this; + } + + @Override + public ChannelPipeline pipeline() { + return channel.pipeline(); + } + + @Override + public ByteBufAllocator alloc() { + return channel.alloc(); + } + + @Override + public Attribute attr(AttributeKey key) { + return channel.attr(key); + } + + @Override + public boolean hasAttr(AttributeKey key) { + return channel.hasAttr(key); + } + + @Override + public ChannelFuture bind(SocketAddress localAddress) { + return channel.bind(localAddress); + } + + @Override + public ChannelFuture connect(SocketAddress remoteAddress) { + return channel.connect(remoteAddress); + } + + @Override + public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress) { + return channel.connect(remoteAddress, localAddress); + } + + @Override + public ChannelFuture disconnect() { + return channel.disconnect(); + } + + @Override + public ChannelFuture close() { + return channel.close(); + } + + @Override + public ChannelFuture deregister() { + return channel.deregister(); + } + + @Override + public ChannelFuture bind(SocketAddress localAddress, ChannelPromise promise) { + return channel.bind(localAddress, promise); + } + + @Override + public ChannelFuture connect(SocketAddress remoteAddress, ChannelPromise promise) { + return channel.connect(remoteAddress, promise); + } + + @Override + public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { + return channel.connect(remoteAddress, localAddress, promise); + } + + @Override + public ChannelFuture disconnect(ChannelPromise promise) { + return channel.disconnect(promise); + } + + @Override + public ChannelFuture close(ChannelPromise promise) { + return channel.close(promise); + } + + @Override + public ChannelFuture deregister(ChannelPromise promise) { + return channel.deregister(promise); + } + + @Override + public ChannelFuture write(Object msg) { + return write(msg, newPromise()); + } + + @Override + public ChannelFuture write(Object msg, ChannelPromise promise) { + return writeAndFlush(msg, promise); + } + + @Override + public ChannelFuture writeAndFlush(Object msg, ChannelPromise promise) { + try { + channel.writeInbound(msg); + channel.runPendingTasks(); + promise.setSuccess(); + } catch (Throwable cause) { + promise.setFailure(cause); + } + return promise; + } + + @Override + public ChannelFuture writeAndFlush(Object msg) { + return writeAndFlush(msg, newPromise()); + } + + @Override + public ChannelPromise newPromise() { + return channel.newPromise(); + } + + @Override + public ChannelProgressivePromise newProgressivePromise() { + return channel.newProgressivePromise(); + } + + @Override + public ChannelFuture newSucceededFuture() { + return channel.newSucceededFuture(); + } + + @Override + public ChannelFuture newFailedFuture(Throwable cause) { + return channel.newFailedFuture(cause); + } + + @Override + public ChannelPromise voidPromise() { + return channel.voidPromise(); + } + } +} diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexCodecClientUpgradeTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexCodecClientUpgradeTest.java new file mode 100644 index 000000000000..26b63ed7f9c3 --- /dev/null +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexCodecClientUpgradeTest.java @@ -0,0 +1,81 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http2; + +import org.junit.Test; + +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.embedded.EmbeddedChannel; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class Http2MultiplexCodecClientUpgradeTest { + + @ChannelHandler.Sharable + private final class NoopHandler extends ChannelInboundHandlerAdapter { + @Override + public void channelActive(ChannelHandlerContext ctx) { + ctx.channel().close(); + } + } + + private final class UpgradeHandler extends ChannelInboundHandlerAdapter { + Http2Stream.State stateOnActive; + int streamId; + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + Http2StreamChannel ch = (Http2StreamChannel) ctx.channel(); + stateOnActive = ch.stream().state(); + streamId = ch.stream().id(); + super.channelActive(ctx); + } + } + + private Http2MultiplexCodec newCodec(ChannelHandler upgradeHandler) { + Http2MultiplexCodecBuilder builder = Http2MultiplexCodecBuilder.forClient(new NoopHandler()); + builder.withUpgradeStreamHandler(upgradeHandler); + return builder.build(); + } + + @Test + public void upgradeHandlerGetsActivated() throws Exception { + UpgradeHandler upgradeHandler = new UpgradeHandler(); + Http2MultiplexCodec codec = newCodec(upgradeHandler); + EmbeddedChannel ch = new EmbeddedChannel(codec); + + codec.onHttpClientUpgrade(); + + assertFalse(upgradeHandler.stateOnActive.localSideOpen()); + assertTrue(upgradeHandler.stateOnActive.remoteSideOpen()); + assertEquals(1, upgradeHandler.streamId); + assertTrue(ch.finishAndReleaseAll()); + } + + @Test(expected = Http2Exception.class) + public void clientUpgradeWithoutUpgradeHandlerThrowsHttp2Exception() throws Http2Exception { + Http2MultiplexCodec codec = Http2MultiplexCodecBuilder.forClient(new NoopHandler()).build(); + EmbeddedChannel ch = new EmbeddedChannel(codec); + try { + codec.onHttpClientUpgrade(); + } finally { + assertTrue(ch.finishAndReleaseAll()); + } + } +} diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexCodecTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexCodecTest.java index 7a0e8c6a3a06..7788e6dd64c7 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexCodecTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexCodecTest.java @@ -15,8 +15,7 @@ package io.netty.handler.codec.http2; import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; @@ -28,45 +27,60 @@ import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpScheme; import io.netty.handler.codec.http2.Http2Exception.StreamException; +import io.netty.handler.codec.http2.LastInboundHandler.Consumer; import io.netty.util.AsciiString; import io.netty.util.AttributeKey; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentMatcher; +import org.mockito.Mockito; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; import java.net.InetSocketAddress; import java.nio.channels.ClosedChannelException; +import java.util.ArrayDeque; +import java.util.Queue; import java.util.concurrent.atomic.AtomicBoolean; - -import org.junit.After; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import static io.netty.util.ReferenceCountUtil.release; -import static org.hamcrest.Matchers.instanceOf; +import static io.netty.handler.codec.http2.Http2TestUtil.anyChannelPromise; +import static io.netty.handler.codec.http2.Http2TestUtil.anyHttp2Settings; +import static io.netty.handler.codec.http2.Http2TestUtil.assertEqualsAndRelease; +import static io.netty.handler.codec.http2.Http2TestUtil.bb; + import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; -import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyShort; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; /** * Unit tests for {@link Http2MultiplexCodec}. */ public class Http2MultiplexCodecTest { - - private EmbeddedChannel parentChannel; - private Writer writer; - - private TestChannelInitializer childChannelInitializer; - - private static final Http2Headers request = new DefaultHttp2Headers() + private final Http2Headers request = new DefaultHttp2Headers() .method(HttpMethod.GET.asciiName()).scheme(HttpScheme.HTTPS.name()) .authority(new AsciiString("example.org")).path(new AsciiString("/foo")); - private TestableHttp2MultiplexCodec codec; - private TestableHttp2MultiplexCodec.Stream inboundStream; - private TestableHttp2MultiplexCodec.Stream outboundStream; + private EmbeddedChannel parentChannel; + private Http2FrameWriter frameWriter; + private Http2FrameInboundWriter frameInboundWriter; + private TestChannelInitializer childChannelInitializer; + private Http2MultiplexCodec codec; private static final int initialRemoteStreamWindow = 1024; @@ -74,25 +88,38 @@ public class Http2MultiplexCodecTest { public void setUp() { childChannelInitializer = new TestChannelInitializer(); parentChannel = new EmbeddedChannel(); - writer = new Writer(); - + frameInboundWriter = new Http2FrameInboundWriter(parentChannel); parentChannel.connect(new InetSocketAddress(0)); - codec = new TestableHttp2MultiplexCodecBuilder(true, childChannelInitializer).build(); + frameWriter = Http2TestUtil.mockedFrameWriter(); + codec = new Http2MultiplexCodecBuilder(true, childChannelInitializer).frameWriter(frameWriter).build(); parentChannel.pipeline().addLast(codec); parentChannel.runPendingTasks(); + parentChannel.pipeline().fireChannelActive(); + + parentChannel.writeInbound(Http2CodecUtil.connectionPrefaceBuf()); Http2Settings settings = new Http2Settings().initialWindowSize(initialRemoteStreamWindow); - codec.onHttp2Frame(new DefaultHttp2SettingsFrame(settings)); + frameInboundWriter.writeInboundSettings(settings); + + verify(frameWriter).writeSettingsAck(eqMultiplexCodecCtx(), anyChannelPromise()); - inboundStream = codec.newStream(); - inboundStream.id = 3; - outboundStream = codec.newStream(); - outboundStream.id = 2; + frameInboundWriter.writeInboundSettingsAck(); + + Http2SettingsFrame settingsFrame = parentChannel.readInbound(); + assertNotNull(settingsFrame); + + // Handshake + verify(frameWriter).writeSettings(eqMultiplexCodecCtx(), + anyHttp2Settings(), anyChannelPromise()); + } + + private ChannelHandlerContext eqMultiplexCodecCtx() { + return eq(codec.ctx); } @After public void tearDown() throws Exception { - if (childChannelInitializer.handler != null) { + if (childChannelInitializer.handler instanceof LastInboundHandler) { ((LastInboundHandler) childChannelInitializer.handler).finishAndReleaseAll(); } parentChannel.finishAndReleaseAll(); @@ -104,123 +131,220 @@ public void tearDown() throws Exception { // TODO(buchgr): GOAWAY Logic // TODO(buchgr): Test ChannelConfig.setMaxMessagesPerRead + @Test + public void writeUnknownFrame() { + Http2StreamChannel childChannel = newOutboundStream(new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) { + ctx.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers())); + ctx.writeAndFlush(new DefaultHttp2UnknownFrame((byte) 99, new Http2Flags())); + ctx.fireChannelActive(); + } + }); + assertTrue(childChannel.isActive()); + + parentChannel.runPendingTasks(); + + verify(frameWriter).writeFrame(eq(codec.ctx), eq((byte) 99), eqStreamId(childChannel), any(Http2Flags.class), + any(ByteBuf.class), any(ChannelPromise.class)); + } + + private Http2StreamChannel newInboundStream(int streamId, boolean endStream, final ChannelHandler childHandler) { + return newInboundStream(streamId, endStream, null, childHandler); + } + + private Http2StreamChannel newInboundStream(int streamId, boolean endStream, + AtomicInteger maxReads, final ChannelHandler childHandler) { + final AtomicReference streamChannelRef = new AtomicReference(); + childChannelInitializer.maxReads = maxReads; + childChannelInitializer.handler = new ChannelInboundHandlerAdapter() { + @Override + public void channelRegistered(ChannelHandlerContext ctx) { + assertNull(streamChannelRef.get()); + streamChannelRef.set((Http2StreamChannel) ctx.channel()); + ctx.pipeline().addLast(childHandler); + ctx.fireChannelRegistered(); + } + }; + + frameInboundWriter.writeInboundHeaders(streamId, request, 0, endStream); + parentChannel.runPendingTasks(); + Http2StreamChannel channel = streamChannelRef.get(); + assertEquals(streamId, channel.stream().id()); + return channel; + } + + @Test + public void readUnkownFrame() { + LastInboundHandler handler = new LastInboundHandler(); + + Http2StreamChannel channel = newInboundStream(3, true, handler); + frameInboundWriter.writeInboundFrame((byte) 99, channel.stream().id(), new Http2Flags(), Unpooled.EMPTY_BUFFER); + + // header frame and unknown frame + verifyFramesMultiplexedToCorrectChannel(channel, handler, 2); + + Channel childChannel = newOutboundStream(new ChannelInboundHandlerAdapter()); + assertTrue(childChannel.isActive()); + } + @Test public void headerAndDataFramesShouldBeDelivered() { LastInboundHandler inboundHandler = new LastInboundHandler(); - childChannelInitializer.handler = inboundHandler; - Http2HeadersFrame headersFrame = new DefaultHttp2HeadersFrame(request).stream(inboundStream); - Http2DataFrame dataFrame1 = new DefaultHttp2DataFrame(bb("hello")).stream(inboundStream); - Http2DataFrame dataFrame2 = new DefaultHttp2DataFrame(bb("world")).stream(inboundStream); + Http2StreamChannel channel = newInboundStream(3, false, inboundHandler); + Http2HeadersFrame headersFrame = new DefaultHttp2HeadersFrame(request).stream(channel.stream()); + Http2DataFrame dataFrame1 = new DefaultHttp2DataFrame(bb("hello")).stream(channel.stream()); + Http2DataFrame dataFrame2 = new DefaultHttp2DataFrame(bb("world")).stream(channel.stream()); - assertFalse(inboundHandler.isChannelActive()); - inboundStream.state = Http2Stream.State.OPEN; - codec.onHttp2StreamStateChanged(inboundStream); - codec.onHttp2Frame(headersFrame); assertTrue(inboundHandler.isChannelActive()); - codec.onHttp2Frame(dataFrame1); - codec.onHttp2Frame(dataFrame2); + frameInboundWriter.writeInboundData(channel.stream().id(), bb("hello"), 0, false); + frameInboundWriter.writeInboundData(channel.stream().id(), bb("world"), 0, false); assertEquals(headersFrame, inboundHandler.readInbound()); - assertEquals(dataFrame1, inboundHandler.readInbound()); - assertEquals(dataFrame2, inboundHandler.readInbound()); - assertNull(inboundHandler.readInbound()); - dataFrame1.release(); - dataFrame2.release(); + assertEqualsAndRelease(dataFrame1, inboundHandler.readInbound()); + assertEqualsAndRelease(dataFrame2, inboundHandler.readInbound()); + + assertNull(inboundHandler.readInbound()); } @Test public void framesShouldBeMultiplexed() { + LastInboundHandler handler1 = new LastInboundHandler(); + Http2StreamChannel channel1 = newInboundStream(3, false, handler1); + LastInboundHandler handler2 = new LastInboundHandler(); + Http2StreamChannel channel2 = newInboundStream(5, false, handler2); + LastInboundHandler handler3 = new LastInboundHandler(); + Http2StreamChannel channel3 = newInboundStream(11, false, handler3); - TestableHttp2MultiplexCodec.Stream stream3 = codec.newStream(); - stream3.id = 3; - TestableHttp2MultiplexCodec.Stream stream5 = codec.newStream(); - stream5.id = 5; - - TestableHttp2MultiplexCodec.Stream stream11 = codec.newStream(); - stream11.id = 11; - - LastInboundHandler inboundHandler3 = streamActiveAndWriteHeaders(stream3); - LastInboundHandler inboundHandler5 = streamActiveAndWriteHeaders(stream5); - LastInboundHandler inboundHandler11 = streamActiveAndWriteHeaders(stream11); + verifyFramesMultiplexedToCorrectChannel(channel1, handler1, 1); + verifyFramesMultiplexedToCorrectChannel(channel2, handler2, 1); + verifyFramesMultiplexedToCorrectChannel(channel3, handler3, 1); - verifyFramesMultiplexedToCorrectChannel(stream3, inboundHandler3, 1); - verifyFramesMultiplexedToCorrectChannel(stream5, inboundHandler5, 1); - verifyFramesMultiplexedToCorrectChannel(stream11, inboundHandler11, 1); + frameInboundWriter.writeInboundData(channel2.stream().id(), bb("hello"), 0, false); + frameInboundWriter.writeInboundData(channel1.stream().id(), bb("foo"), 0, true); + frameInboundWriter.writeInboundData(channel2.stream().id(), bb("world"), 0, true); + frameInboundWriter.writeInboundData(channel3.stream().id(), bb("bar"), 0, true); - codec.onHttp2Frame(new DefaultHttp2DataFrame(bb("hello"), false).stream(stream5)); - codec.onHttp2Frame(new DefaultHttp2DataFrame(bb("foo"), true).stream(stream3)); - codec.onHttp2Frame(new DefaultHttp2DataFrame(bb("world"), true).stream(stream5)); - codec.onHttp2Frame(new DefaultHttp2DataFrame(bb("bar"), true).stream(stream11)); - verifyFramesMultiplexedToCorrectChannel(stream5, inboundHandler5, 2); - verifyFramesMultiplexedToCorrectChannel(stream3, inboundHandler3, 1); - verifyFramesMultiplexedToCorrectChannel(stream11, inboundHandler11, 1); + verifyFramesMultiplexedToCorrectChannel(channel1, handler1, 1); + verifyFramesMultiplexedToCorrectChannel(channel2, handler2, 2); + verifyFramesMultiplexedToCorrectChannel(channel3, handler3, 1); } @Test - public void inboundDataFrameShouldEmitWindowUpdateFrame() { - LastInboundHandler inboundHandler = streamActiveAndWriteHeaders(inboundStream); + public void inboundDataFrameShouldUpdateLocalFlowController() throws Http2Exception { + Http2LocalFlowController flowController = Mockito.mock(Http2LocalFlowController.class); + codec.connection().local().flowController(flowController); + + LastInboundHandler handler = new LastInboundHandler(); + final Http2StreamChannel channel = newInboundStream(3, false, handler); + ByteBuf tenBytes = bb("0123456789"); - codec.onHttp2Frame(new DefaultHttp2DataFrame(tenBytes, true).stream(inboundStream)); - codec.onChannelReadComplete(); - Http2WindowUpdateFrame windowUpdate = parentChannel.readOutbound(); - assertNotNull(windowUpdate); + frameInboundWriter.writeInboundData(channel.stream().id(), tenBytes, 0, true); - assertEquals(inboundStream, windowUpdate.stream()); - assertEquals(10, windowUpdate.windowSizeIncrement()); + // Verify we marked the bytes as consumed + verify(flowController).consumeBytes(argThat(new ArgumentMatcher() { + @Override + public boolean matches(Http2Stream http2Stream) { + return http2Stream.id() == channel.stream().id(); + } + }), eq(10)); // headers and data frame - verifyFramesMultiplexedToCorrectChannel(inboundStream, inboundHandler, 2); + verifyFramesMultiplexedToCorrectChannel(channel, handler, 2); } @Test public void unhandledHttp2FramesShouldBePropagated() { - assertThat(parentChannel.readInbound(), instanceOf(Http2SettingsFrame.class)); - Http2PingFrame pingFrame = new DefaultHttp2PingFrame(0); - codec.onHttp2Frame(pingFrame); - assertSame(parentChannel.readInbound(), pingFrame); + frameInboundWriter.writeInboundPing(false, 0); + assertEquals(parentChannel.readInbound(), pingFrame); - DefaultHttp2GoAwayFrame goAwayFrame = - new DefaultHttp2GoAwayFrame(1, parentChannel.alloc().buffer().writeLong(8)); - codec.onHttp2Frame(goAwayFrame); + DefaultHttp2GoAwayFrame goAwayFrame = new DefaultHttp2GoAwayFrame(1, + parentChannel.alloc().buffer().writeLong(8)); + frameInboundWriter.writeInboundGoAway(0, goAwayFrame.errorCode(), goAwayFrame.content().retainedDuplicate()); Http2GoAwayFrame frame = parentChannel.readInbound(); - assertSame(frame, goAwayFrame); - assertTrue(frame.release()); + assertEqualsAndRelease(frame, goAwayFrame); } @Test public void channelReadShouldRespectAutoRead() { - LastInboundHandler inboundHandler = streamActiveAndWriteHeaders(inboundStream); - Channel childChannel = inboundHandler.channel(); + LastInboundHandler inboundHandler = new LastInboundHandler(); + Http2StreamChannel childChannel = newInboundStream(3, false, inboundHandler); assertTrue(childChannel.config().isAutoRead()); Http2HeadersFrame headersFrame = inboundHandler.readInbound(); assertNotNull(headersFrame); childChannel.config().setAutoRead(false); - codec.onHttp2Frame( - new DefaultHttp2DataFrame(bb("hello world"), false).stream(inboundStream)); - codec.onChannelReadComplete(); + + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("hello world"), 0, false); Http2DataFrame dataFrame0 = inboundHandler.readInbound(); assertNotNull(dataFrame0); release(dataFrame0); - codec.onHttp2Frame(new DefaultHttp2DataFrame(bb("foo"), false).stream(inboundStream)); - codec.onHttp2Frame(new DefaultHttp2DataFrame(bb("bar"), true).stream(inboundStream)); - codec.onChannelReadComplete(); + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("foo"), 0, false); + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("bar"), 0, false); - dataFrame0 = inboundHandler.readInbound(); - assertNull(dataFrame0); + assertNull(inboundHandler.readInbound()); childChannel.config().setAutoRead(true); - verifyFramesMultiplexedToCorrectChannel(inboundStream, inboundHandler, 2); + verifyFramesMultiplexedToCorrectChannel(childChannel, inboundHandler, 2); + } + + @Test + public void readInChannelReadWithoutAutoRead() { + useReadWithoutAutoRead(false); } - private Http2StreamChannel newOutboundStream() { - return new Http2StreamChannelBootstrap(parentChannel).handler(childChannelInitializer) + @Test + public void readInChannelReadCompleteWithoutAutoRead() { + useReadWithoutAutoRead(true); + } + + private void useReadWithoutAutoRead(final boolean readComplete) { + LastInboundHandler inboundHandler = new LastInboundHandler(); + Http2StreamChannel childChannel = newInboundStream(3, false, inboundHandler); + assertTrue(childChannel.config().isAutoRead()); + childChannel.config().setAutoRead(false); + assertFalse(childChannel.config().isAutoRead()); + + Http2HeadersFrame headersFrame = inboundHandler.readInbound(); + assertNotNull(headersFrame); + + // Add a handler which will request reads. + childChannel.pipeline().addFirst(new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + ctx.fireChannelRead(msg); + if (!readComplete) { + ctx.read(); + } + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) { + ctx.fireChannelReadComplete(); + if (readComplete) { + ctx.read(); + } + } + }); + + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("hello world"), 0, false); + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("foo"), 0, false); + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("bar"), 0, false); + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("hello world"), 0, false); + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("foo"), 0, false); + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("bar"), 0, true); + + verifyFramesMultiplexedToCorrectChannel(childChannel, inboundHandler, 6); + } + + private Http2StreamChannel newOutboundStream(ChannelHandler handler) { + return new Http2StreamChannelBootstrap(parentChannel).handler(handler) .open().syncUninterruptibly().getNow(); } @@ -228,12 +352,11 @@ private Http2StreamChannel newOutboundStream() { * A child channel for a HTTP/2 stream in IDLE state (that is no headers sent or received), * should not emit a RST_STREAM frame on close, as this is a connection error of type protocol error. */ - @Test public void idleOutboundStreamShouldNotWriteResetFrameOnClose() { - childChannelInitializer.handler = new LastInboundHandler(); + LastInboundHandler handler = new LastInboundHandler(); - Channel childChannel = newOutboundStream(); + Channel childChannel = newOutboundStream(handler); assertTrue(childChannel.isActive()); childChannel.close(); @@ -246,87 +369,105 @@ public void idleOutboundStreamShouldNotWriteResetFrameOnClose() { @Test public void outboundStreamShouldWriteResetFrameOnClose_headersSent() { - childChannelInitializer.handler = new ChannelInboundHandlerAdapter() { + ChannelHandler handler = new ChannelInboundHandlerAdapter() { @Override - public void channelActive(ChannelHandlerContext ctx) throws Exception { + public void channelActive(ChannelHandlerContext ctx) { ctx.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers())); ctx.fireChannelActive(); } }; - Channel childChannel = newOutboundStream(); + Http2StreamChannel childChannel = newOutboundStream(handler); assertTrue(childChannel.isActive()); - Http2FrameStream stream2 = readOutboundHeadersAndAssignId(); + childChannel.close(); + verify(frameWriter).writeRstStream(eqMultiplexCodecCtx(), + eqStreamId(childChannel), eq(Http2Error.CANCEL.code()), anyChannelPromise()); + } + + @Test + public void outboundStreamShouldNotWriteResetFrameOnClose_IfStreamDidntExist() { + when(frameWriter.writeHeaders(eqMultiplexCodecCtx(), anyInt(), + any(Http2Headers.class), anyInt(), anyShort(), anyBoolean(), anyInt(), anyBoolean(), + any(ChannelPromise.class))).thenAnswer(new Answer() { + + private boolean headersWritten; + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + // We want to fail to write the first headers frame. This is what happens if the connection + // refuses to allocate a new stream due to having received a GOAWAY. + if (!headersWritten) { + headersWritten = true; + return ((ChannelPromise) invocationOnMock.getArgument(8)).setFailure(new Exception("boom")); + } + return ((ChannelPromise) invocationOnMock.getArgument(8)).setSuccess(); + } + }); + + Http2StreamChannel childChannel = newOutboundStream(new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) { + ctx.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers())); + ctx.fireChannelActive(); + } + }); + + assertFalse(childChannel.isActive()); childChannel.close(); parentChannel.runPendingTasks(); + // The channel was never active so we should not generate a RST frame. + verify(frameWriter, never()).writeRstStream(eqMultiplexCodecCtx(), eqStreamId(childChannel), anyLong(), + anyChannelPromise()); - Http2ResetFrame reset = parentChannel.readOutbound(); - assertEquals(stream2, reset.stream()); - assertEquals(Http2Error.CANCEL.code(), reset.errorCode()); + assertTrue(parentChannel.outboundMessages().isEmpty()); } @Test public void inboundRstStreamFireChannelInactive() { - LastInboundHandler inboundHandler = streamActiveAndWriteHeaders(inboundStream); + LastInboundHandler inboundHandler = new LastInboundHandler(); + Http2StreamChannel channel = newInboundStream(3, false, inboundHandler); assertTrue(inboundHandler.isChannelActive()); - codec.onHttp2Frame(new DefaultHttp2ResetFrame(Http2Error.INTERNAL_ERROR) - .stream(inboundStream)); - codec.onChannelReadComplete(); - - // This will be called by the frame codec. - inboundStream.state = Http2Stream.State.CLOSED; - codec.onHttp2StreamStateChanged(inboundStream); - parentChannel.runPendingTasks(); + frameInboundWriter.writeInboundRstStream(channel.stream().id(), Http2Error.INTERNAL_ERROR.code()); assertFalse(inboundHandler.isChannelActive()); + // A RST_STREAM frame should NOT be emitted, as we received a RST_STREAM. - assertNull(parentChannel.readOutbound()); + verify(frameWriter, Mockito.never()).writeRstStream(eqMultiplexCodecCtx(), eqStreamId(channel), + anyLong(), anyChannelPromise()); } @Test(expected = StreamException.class) public void streamExceptionTriggersChildChannelExceptionAndClose() throws Exception { - LastInboundHandler inboundHandler = streamActiveAndWriteHeaders(inboundStream); - - StreamException cause = new StreamException(inboundStream.id(), Http2Error.PROTOCOL_ERROR, "baaam!"); - Http2FrameStreamException http2Ex = new Http2FrameStreamException( - inboundStream, Http2Error.PROTOCOL_ERROR, cause); - codec.onHttp2FrameStreamException(http2Ex); - - inboundHandler.checkException(); - } - - @Test(expected = StreamException.class) - public void streamExceptionClosesChildChannel() throws Exception { - LastInboundHandler inboundHandler = streamActiveAndWriteHeaders(inboundStream); - - assertTrue(inboundHandler.isChannelActive()); - StreamException cause = new StreamException(inboundStream.id(), Http2Error.PROTOCOL_ERROR, "baaam!"); - Http2FrameStreamException http2Ex = new Http2FrameStreamException( - inboundStream, Http2Error.PROTOCOL_ERROR, cause); - codec.onHttp2FrameStreamException(http2Ex); - parentChannel.runPendingTasks(); + LastInboundHandler inboundHandler = new LastInboundHandler(); + Http2StreamChannel channel = newInboundStream(3, false, inboundHandler); + assertTrue(channel.isActive()); + StreamException cause = new StreamException(channel.stream().id(), Http2Error.PROTOCOL_ERROR, "baaam!"); + parentChannel.pipeline().fireExceptionCaught(cause); - assertFalse(inboundHandler.isChannelActive()); + assertFalse(channel.isActive()); inboundHandler.checkException(); } @Test(expected = ClosedChannelException.class) public void streamClosedErrorTranslatedToClosedChannelExceptionOnWrites() throws Exception { - writer = new Writer() { - @Override - void write(Object msg, ChannelPromise promise) { - promise.tryFailure(new StreamException(inboundStream.id(), Http2Error.STREAM_CLOSED, "Stream Closed")); - } - }; LastInboundHandler inboundHandler = new LastInboundHandler(); - childChannelInitializer.handler = inboundHandler; - Channel childChannel = newOutboundStream(); + final Http2StreamChannel childChannel = newOutboundStream(inboundHandler); assertTrue(childChannel.isActive()); + Http2Headers headers = new DefaultHttp2Headers(); + when(frameWriter.writeHeaders(eqMultiplexCodecCtx(), anyInt(), + eq(headers), anyInt(), anyShort(), anyBoolean(), anyInt(), anyBoolean(), + any(ChannelPromise.class))).thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + return ((ChannelPromise) invocationOnMock.getArgument(8)).setFailure( + new StreamException(childChannel.stream().id(), Http2Error.STREAM_CLOSED, "Stream Closed")); + } + }); ChannelFuture future = childChannel.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers())); + parentChannel.flush(); assertFalse(childChannel.isActive()); @@ -340,9 +481,7 @@ void write(Object msg, ChannelPromise promise) { @Test public void creatingWritingReadingAndClosingOutboundStreamShouldWork() { LastInboundHandler inboundHandler = new LastInboundHandler(); - childChannelInitializer.handler = inboundHandler; - - Http2StreamChannel childChannel = newOutboundStream(); + Http2StreamChannel childChannel = newOutboundStream(inboundHandler); assertTrue(childChannel.isActive()); assertTrue(inboundHandler.isChannelActive()); @@ -350,25 +489,21 @@ public void creatingWritingReadingAndClosingOutboundStreamShouldWork() { Http2Headers headers = new DefaultHttp2Headers().scheme("https").method("GET").path("/foo.txt"); childChannel.writeAndFlush(new DefaultHttp2HeadersFrame(headers)); - readOutboundHeadersAndAssignId(); - // Read from the child channel - headers = new DefaultHttp2Headers().scheme("https").status("200"); - codec.onHttp2Frame(new DefaultHttp2HeadersFrame(headers).stream(childChannel.stream())); - codec.onChannelReadComplete(); + frameInboundWriter.writeInboundHeaders(childChannel.stream().id(), headers, 0, false); Http2HeadersFrame headersFrame = inboundHandler.readInbound(); assertNotNull(headersFrame); - assertSame(headers, headersFrame.headers()); + assertEquals(headers, headersFrame.headers()); // Close the child channel. childChannel.close(); parentChannel.runPendingTasks(); // An active outbound stream should emit a RST_STREAM frame. - Http2ResetFrame rstFrame = parentChannel.readOutbound(); - assertNotNull(rstFrame); - assertEquals(childChannel.stream(), rstFrame.stream()); + verify(frameWriter).writeRstStream(eqMultiplexCodecCtx(), eqStreamId(childChannel), + anyLong(), anyChannelPromise()); + assertFalse(childChannel.isOpen()); assertFalse(childChannel.isActive()); assertFalse(inboundHandler.isChannelActive()); @@ -379,33 +514,36 @@ public void creatingWritingReadingAndClosingOutboundStreamShouldWork() { // @Test(expected = Http2NoMoreStreamIdsException.class) public void failedOutboundStreamCreationThrowsAndClosesChannel() throws Exception { - writer = new Writer() { - @Override - void write(Object msg, ChannelPromise promise) { - promise.tryFailure(new Http2NoMoreStreamIdsException()); - } - }; - LastInboundHandler inboundHandler = new LastInboundHandler(); - childChannelInitializer.handler = inboundHandler; - - Channel childChannel = newOutboundStream(); + LastInboundHandler handler = new LastInboundHandler(); + Http2StreamChannel childChannel = newOutboundStream(handler); assertTrue(childChannel.isActive()); - ChannelFuture future = childChannel.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers())); + Http2Headers headers = new DefaultHttp2Headers(); + when(frameWriter.writeHeaders(eqMultiplexCodecCtx(), anyInt(), + eq(headers), anyInt(), anyShort(), anyBoolean(), anyInt(), anyBoolean(), + any(ChannelPromise.class))).thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + return ((ChannelPromise) invocationOnMock.getArgument(8)).setFailure( + new Http2NoMoreStreamIdsException()); + } + }); + + ChannelFuture future = childChannel.writeAndFlush(new DefaultHttp2HeadersFrame(headers)); parentChannel.flush(); assertFalse(childChannel.isActive()); assertFalse(childChannel.isOpen()); - inboundHandler.checkException(); + handler.checkException(); future.syncUninterruptibly(); } @Test public void channelClosedWhenCloseListenerCompletes() { - LastInboundHandler inboundHandler = streamActiveAndWriteHeaders(inboundStream); - Http2StreamChannel childChannel = (Http2StreamChannel) inboundHandler.channel(); + LastInboundHandler inboundHandler = new LastInboundHandler(); + Http2StreamChannel childChannel = newInboundStream(3, false, inboundHandler); assertTrue(childChannel.isOpen()); assertTrue(childChannel.isActive()); @@ -426,13 +564,41 @@ public void operationComplete(ChannelFuture future) { childChannel.close(p).syncUninterruptibly(); assertFalse(channelOpen.get()); + assertFalse(channelActive.get()); assertFalse(childChannel.isActive()); } @Test public void channelClosedWhenChannelClosePromiseCompletes() { - LastInboundHandler inboundHandler = streamActiveAndWriteHeaders(inboundStream); - Http2StreamChannel childChannel = (Http2StreamChannel) inboundHandler.channel(); + LastInboundHandler inboundHandler = new LastInboundHandler(); + Http2StreamChannel childChannel = newInboundStream(3, false, inboundHandler); + + assertTrue(childChannel.isOpen()); + assertTrue(childChannel.isActive()); + + final AtomicBoolean channelOpen = new AtomicBoolean(true); + final AtomicBoolean channelActive = new AtomicBoolean(true); + + childChannel.closeFuture().addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + channelOpen.set(future.channel().isOpen()); + channelActive.set(future.channel().isActive()); + } + }); + childChannel.close().syncUninterruptibly(); + + assertFalse(channelOpen.get()); + assertFalse(channelActive.get()); + assertFalse(childChannel.isActive()); + } + + @Test + public void channelClosedWhenWriteFutureFails() { + final Queue writePromises = new ArrayDeque(); + + LastInboundHandler inboundHandler = new LastInboundHandler(); + Http2StreamChannel childChannel = newInboundStream(3, false, inboundHandler); assertTrue(childChannel.isOpen()); assertTrue(childChannel.isActive()); @@ -440,23 +606,41 @@ public void channelClosedWhenChannelClosePromiseCompletes() { final AtomicBoolean channelOpen = new AtomicBoolean(true); final AtomicBoolean channelActive = new AtomicBoolean(true); - childChannel.closeFuture().addListener(new ChannelFutureListener() { + Http2Headers headers = new DefaultHttp2Headers(); + when(frameWriter.writeHeaders(eqMultiplexCodecCtx(), anyInt(), + eq(headers), anyInt(), anyShort(), anyBoolean(), anyInt(), anyBoolean(), + any(ChannelPromise.class))).thenAnswer(new Answer() { @Override - public void operationComplete(ChannelFuture future) { + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + ChannelPromise promise = invocationOnMock.getArgument(8); + writePromises.offer(promise); + return promise; + } + }); + + ChannelFuture f = childChannel.writeAndFlush(new DefaultHttp2HeadersFrame(headers)); + assertFalse(f.isDone()); + f.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { channelOpen.set(future.channel().isOpen()); channelActive.set(future.channel().isActive()); } }); - childChannel.close().syncUninterruptibly(); + + ChannelPromise first = writePromises.poll(); + first.setFailure(new ClosedChannelException()); + f.awaitUninterruptibly(); assertFalse(channelOpen.get()); + assertFalse(channelActive.get()); assertFalse(childChannel.isActive()); } @Test public void channelClosedTwiceMarksPromiseAsSuccessful() { - LastInboundHandler inboundHandler = streamActiveAndWriteHeaders(inboundStream); - Http2StreamChannel childChannel = (Http2StreamChannel) inboundHandler.channel(); + LastInboundHandler inboundHandler = new LastInboundHandler(); + Http2StreamChannel childChannel = newInboundStream(3, false, inboundHandler); assertTrue(childChannel.isOpen()); assertTrue(childChannel.isActive()); @@ -471,7 +655,7 @@ public void channelClosedTwiceMarksPromiseAsSuccessful() { public void settingChannelOptsAndAttrs() { AttributeKey key = AttributeKey.newInstance("foo"); - Channel childChannel = newOutboundStream(); + Channel childChannel = newOutboundStream(new ChannelInboundHandlerAdapter()); childChannel.config().setAutoRead(false).setWriteSpinCount(1000); childChannel.attr(key).set("bar"); assertFalse(childChannel.config().isAutoRead()); @@ -481,50 +665,52 @@ public void settingChannelOptsAndAttrs() { @Test public void outboundFlowControlWritability() { - Http2StreamChannel childChannel = newOutboundStream(); + Http2StreamChannel childChannel = newOutboundStream(new ChannelInboundHandlerAdapter()); assertTrue(childChannel.isActive()); assertTrue(childChannel.isWritable()); childChannel.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers())); parentChannel.flush(); - Http2FrameStream stream = readOutboundHeadersAndAssignId(); - // Test for initial window size assertEquals(initialRemoteStreamWindow, childChannel.config().getWriteBufferHighWaterMark()); - codec.onHttp2StreamWritabilityChanged(stream, true); assertTrue(childChannel.isWritable()); - codec.onHttp2StreamWritabilityChanged(stream, false); + childChannel.write(new DefaultHttp2DataFrame(Unpooled.buffer().writeZero(16 * 1024 * 1024))); assertFalse(childChannel.isWritable()); } @Test public void writabilityAndFlowControl() { - LastInboundHandler inboundHandler = streamActiveAndWriteHeaders(inboundStream); - Http2StreamChannel childChannel = (Http2StreamChannel) inboundHandler.channel(); + LastInboundHandler inboundHandler = new LastInboundHandler(); + Http2StreamChannel childChannel = newInboundStream(3, false, inboundHandler); assertEquals("", inboundHandler.writabilityStates()); + assertTrue(childChannel.isWritable()); // HEADERS frames are not flow controlled, so they should not affect the flow control window. childChannel.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers())); - codec.onHttp2StreamWritabilityChanged(childChannel.stream(), true); + codec.onHttp2StreamWritabilityChanged(codec.ctx, childChannel.stream(), true); - assertEquals("true", inboundHandler.writabilityStates()); + assertTrue(childChannel.isWritable()); + assertEquals("", inboundHandler.writabilityStates()); - codec.onHttp2StreamWritabilityChanged(childChannel.stream(), true); - assertEquals("true", inboundHandler.writabilityStates()); + codec.onHttp2StreamWritabilityChanged(codec.ctx, childChannel.stream(), true); + assertTrue(childChannel.isWritable()); + assertEquals("", inboundHandler.writabilityStates()); - codec.onHttp2StreamWritabilityChanged(childChannel.stream(), false); - assertEquals("true,false", inboundHandler.writabilityStates()); + codec.onHttp2StreamWritabilityChanged(codec.ctx, childChannel.stream(), false); + assertFalse(childChannel.isWritable()); + assertEquals("false", inboundHandler.writabilityStates()); - codec.onHttp2StreamWritabilityChanged(childChannel.stream(), false); - assertEquals("true,false", inboundHandler.writabilityStates()); + codec.onHttp2StreamWritabilityChanged(codec.ctx, childChannel.stream(), false); + assertFalse(childChannel.isWritable()); + assertEquals("false", inboundHandler.writabilityStates()); } @Test public void channelClosedWhenInactiveFired() { - LastInboundHandler inboundHandler = streamActiveAndWriteHeaders(inboundStream); - Http2StreamChannel childChannel = (Http2StreamChannel) inboundHandler.channel(); + LastInboundHandler inboundHandler = new LastInboundHandler(); + Http2StreamChannel childChannel = newInboundStream(3, false, inboundHandler); final AtomicBoolean channelOpen = new AtomicBoolean(false); final AtomicBoolean channelActive = new AtomicBoolean(false); @@ -546,28 +732,56 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception { assertFalse(channelActive.get()); } - @Ignore("not supported anymore atm") @Test - public void cancellingWritesBeforeFlush() { - LastInboundHandler inboundHandler = streamActiveAndWriteHeaders(inboundStream); - Channel childChannel = inboundHandler.channel(); + public void channelInactiveHappensAfterExceptionCaughtEvents() throws Exception { + final AtomicInteger count = new AtomicInteger(0); + final AtomicInteger exceptionCaught = new AtomicInteger(-1); + final AtomicInteger channelInactive = new AtomicInteger(-1); + final AtomicInteger channelUnregistered = new AtomicInteger(-1); + Http2StreamChannel childChannel = newOutboundStream(new ChannelInboundHandlerAdapter() { + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + ctx.close(); + throw new Exception("exception"); + } + }); + + childChannel.pipeline().addLast(new ChannelInboundHandlerAdapter() { - Http2HeadersFrame headers1 = new DefaultHttp2HeadersFrame(new DefaultHttp2Headers()); - Http2HeadersFrame headers2 = new DefaultHttp2HeadersFrame(new DefaultHttp2Headers()); - ChannelPromise writePromise = childChannel.newPromise(); - childChannel.write(headers1, writePromise); - childChannel.write(headers2); - assertTrue(writePromise.cancel(false)); - childChannel.flush(); + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + channelInactive.set(count.getAndIncrement()); + super.channelInactive(ctx); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + exceptionCaught.set(count.getAndIncrement()); + super.exceptionCaught(ctx, cause); + } + + @Override + public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { + channelUnregistered.set(count.getAndIncrement()); + super.channelUnregistered(ctx); + } + }); - Http2HeadersFrame headers = parentChannel.readOutbound(); - assertSame(headers, headers2); + childChannel.pipeline().fireUserEventTriggered(new Object()); + parentChannel.runPendingTasks(); + + // The events should have happened in this order because the inactive and deregistration events + // get deferred as they do in the AbstractChannel. + assertEquals(0, exceptionCaught.get()); + assertEquals(1, channelInactive.get()); + assertEquals(2, channelUnregistered.get()); } @Test public void callUnsafeCloseMultipleTimes() { - LastInboundHandler inboundHandler = streamActiveAndWriteHeaders(inboundStream); - Http2StreamChannel childChannel = (Http2StreamChannel) inboundHandler.channel(); + LastInboundHandler inboundHandler = new LastInboundHandler(); + Http2StreamChannel childChannel = newInboundStream(3, false, inboundHandler); childChannel.unsafe().close(childChannel.voidPromise()); ChannelPromise promise = childChannel.newPromise(); @@ -576,147 +790,218 @@ public void callUnsafeCloseMultipleTimes() { childChannel.closeFuture().syncUninterruptibly(); } - private LastInboundHandler streamActiveAndWriteHeaders(Http2FrameStream stream) { - LastInboundHandler inboundHandler = new LastInboundHandler(); - childChannelInitializer.handler = inboundHandler; - assertFalse(inboundHandler.isChannelActive()); - ((TestableHttp2MultiplexCodec.Stream) stream).state = Http2Stream.State.OPEN; - codec.onHttp2StreamStateChanged(stream); - codec.onHttp2Frame(new DefaultHttp2HeadersFrame(request).stream(stream)); - codec.onChannelReadComplete(); - assertTrue(inboundHandler.isChannelActive()); + @Test + public void endOfStreamDoesNotDiscardData() { + AtomicInteger numReads = new AtomicInteger(1); + final AtomicBoolean shouldDisableAutoRead = new AtomicBoolean(); + Consumer ctxConsumer = new Consumer() { + @Override + public void accept(ChannelHandlerContext obj) { + if (shouldDisableAutoRead.get()) { + obj.channel().config().setAutoRead(false); + } + } + }; + LastInboundHandler inboundHandler = new LastInboundHandler(ctxConsumer); + Http2StreamChannel childChannel = newInboundStream(3, false, numReads, inboundHandler); + childChannel.config().setAutoRead(false); - return inboundHandler; - } + Http2DataFrame dataFrame1 = new DefaultHttp2DataFrame(bb("1")).stream(childChannel.stream()); + Http2DataFrame dataFrame2 = new DefaultHttp2DataFrame(bb("2")).stream(childChannel.stream()); + Http2DataFrame dataFrame3 = new DefaultHttp2DataFrame(bb("3")).stream(childChannel.stream()); + Http2DataFrame dataFrame4 = new DefaultHttp2DataFrame(bb("4")).stream(childChannel.stream()); - private static void verifyFramesMultiplexedToCorrectChannel(Http2FrameStream stream, - LastInboundHandler inboundHandler, - int numFrames) { - for (int i = 0; i < numFrames; i++) { - Http2StreamFrame frame = inboundHandler.readInbound(); - assertNotNull(frame); - assertEquals(stream, frame.stream()); - release(frame); - } - assertNull(inboundHandler.readInbound()); - } + assertEquals(new DefaultHttp2HeadersFrame(request).stream(childChannel.stream()), inboundHandler.readInbound()); - private static ByteBuf bb(String s) { - return ByteBufUtil.writeUtf8(UnpooledByteBufAllocator.DEFAULT, s); - } + ChannelHandler readCompleteSupressHandler = new ChannelInboundHandlerAdapter() { + @Override + public void channelReadComplete(ChannelHandlerContext ctx) { + // We want to simulate the parent channel calling channelRead and delay calling channelReadComplete. + } + }; - /** - * Simulates the frame codec, in first assigning an identifier and the completing the write promise. - */ - private Http2FrameStream readOutboundHeadersAndAssignId() { - // Only peek at the frame, so to not complete the promise of the write. We need to first - // assign a stream identifier, as the frame codec would do. - Http2HeadersFrame headersFrame = (Http2HeadersFrame) parentChannel.outboundMessages().peek(); - assertNotNull(headersFrame); - assertNotNull(headersFrame.stream()); - assertFalse(Http2CodecUtil.isStreamIdValid(headersFrame.stream().id())); - ((TestableHttp2MultiplexCodec.Stream) headersFrame.stream()).id = outboundStream.id(); + parentChannel.pipeline().addFirst(readCompleteSupressHandler); + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("1"), 0, false); + + assertEqualsAndRelease(dataFrame1, inboundHandler.readInbound()); + + // Deliver frames, and then a stream closed while read is inactive. + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("2"), 0, false); + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("3"), 0, false); + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("4"), 0, false); + + shouldDisableAutoRead.set(true); + childChannel.config().setAutoRead(true); + numReads.set(1); + + frameInboundWriter.writeInboundRstStream(childChannel.stream().id(), Http2Error.NO_ERROR.code()); + + // Detecting EOS should flush all pending data regardless of read calls. + assertEqualsAndRelease(dataFrame2, inboundHandler.readInbound()); + assertEqualsAndRelease(dataFrame3, inboundHandler.readInbound()); + assertEqualsAndRelease(dataFrame4, inboundHandler.readInbound()); - // Now read it and complete the write promise. - assertSame(headersFrame, parentChannel.readOutbound()); + Http2ResetFrame resetFrame = inboundHandler.readInbound(); + assertEquals(childChannel.stream(), resetFrame.stream()); + assertEquals(Http2Error.NO_ERROR.code(), resetFrame.errorCode()); - return headersFrame.stream(); + assertNull(inboundHandler.readInbound()); + + // Now we want to call channelReadComplete and simulate the end of the read loop. + parentChannel.pipeline().remove(readCompleteSupressHandler); + parentChannel.flushInbound(); + + childChannel.closeFuture().syncUninterruptibly(); } - /** - * This class removes the bits that would transform the frames to bytes and so make it easier to test the actual - * special handling of the codec. - */ - private final class TestableHttp2MultiplexCodec extends Http2MultiplexCodec { + @Test + public void childQueueIsDrainedAndNewDataIsDispatchedInParentReadLoopAutoRead() { + AtomicInteger numReads = new AtomicInteger(1); + final AtomicInteger channelReadCompleteCount = new AtomicInteger(0); + final AtomicBoolean shouldDisableAutoRead = new AtomicBoolean(); + Consumer ctxConsumer = new Consumer() { + @Override + public void accept(ChannelHandlerContext obj) { + channelReadCompleteCount.incrementAndGet(); + if (shouldDisableAutoRead.get()) { + obj.channel().config().setAutoRead(false); + } + } + }; + LastInboundHandler inboundHandler = new LastInboundHandler(ctxConsumer); + Http2StreamChannel childChannel = newInboundStream(3, false, numReads, inboundHandler); + childChannel.config().setAutoRead(false); - public TestableHttp2MultiplexCodec(Http2ConnectionEncoder encoder, - Http2ConnectionDecoder decoder, - Http2Settings initialSettings, - ChannelHandler inboundStreamHandler) { - super(encoder, decoder, initialSettings, inboundStreamHandler); - } + Http2DataFrame dataFrame1 = new DefaultHttp2DataFrame(bb("1")).stream(childChannel.stream()); + Http2DataFrame dataFrame2 = new DefaultHttp2DataFrame(bb("2")).stream(childChannel.stream()); + Http2DataFrame dataFrame3 = new DefaultHttp2DataFrame(bb("3")).stream(childChannel.stream()); + Http2DataFrame dataFrame4 = new DefaultHttp2DataFrame(bb("4")).stream(childChannel.stream()); - void onHttp2Frame(Http2Frame frame) { - onHttp2Frame(ctx, frame); - } + assertEquals(new DefaultHttp2HeadersFrame(request).stream(childChannel.stream()), inboundHandler.readInbound()); - void onChannelReadComplete() { - onChannelReadComplete(ctx); - } + ChannelHandler readCompleteSupressHandler = new ChannelInboundHandlerAdapter() { + @Override + public void channelReadComplete(ChannelHandlerContext ctx) { + // We want to simulate the parent channel calling channelRead and delay calling channelReadComplete. + } + }; + parentChannel.pipeline().addFirst(readCompleteSupressHandler); - void onHttp2StreamStateChanged(Http2FrameStream stream) { - onHttp2StreamStateChanged(ctx, stream); - } + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("1"), 0, false); - void onHttp2FrameStreamException(Http2FrameStreamException cause) { - onHttp2FrameStreamException(ctx, cause); - } + assertEqualsAndRelease(dataFrame1, inboundHandler.readInbound()); - void onHttp2StreamWritabilityChanged(Http2FrameStream stream, boolean writable) { - onHttp2StreamWritabilityChanged(ctx, stream, writable); - } + // We want one item to be in the queue, and allow the numReads to be larger than 1. This will ensure that + // when beginRead() is called the child channel is added to the readPending queue of the parent channel. + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("2"), 0, false); - @Override - boolean onBytesConsumed(ChannelHandlerContext ctx, Http2FrameStream stream, int bytes) { - writer.write(new DefaultHttp2WindowUpdateFrame(bytes).stream(stream), ctx.newPromise()); - return true; - } + numReads.set(10); + shouldDisableAutoRead.set(true); + childChannel.config().setAutoRead(true); - @Override - public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { - writer.write(msg, promise); - } + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("3"), 0, false); + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("4"), 0, false); - @Override - void flush0(ChannelHandlerContext ctx) { - // Do nothing - } + // Detecting EOS should flush all pending data regardless of read calls. + assertEqualsAndRelease(dataFrame2, inboundHandler.readInbound()); + assertEqualsAndRelease(dataFrame3, inboundHandler.readInbound()); + assertEqualsAndRelease(dataFrame4, inboundHandler.readInbound()); - @Override - Stream newStream() { - return new Stream(); - } + assertNull(inboundHandler.readInbound()); - final class Stream extends Http2MultiplexCodecStream { - Http2Stream.State state = Http2Stream.State.IDLE; - int id = -1; + // Now we want to call channelReadComplete and simulate the end of the read loop. + parentChannel.pipeline().remove(readCompleteSupressHandler); + parentChannel.flushInbound(); + // 3 = 1 for initialization + 1 for read when auto read was off + 1 for when auto read was back on + assertEquals(3, channelReadCompleteCount.get()); + } + + @Test + public void childQueueIsDrainedAndNewDataIsDispatchedInParentReadLoopNoAutoRead() { + final AtomicInteger numReads = new AtomicInteger(1); + final AtomicInteger channelReadCompleteCount = new AtomicInteger(0); + final AtomicBoolean shouldDisableAutoRead = new AtomicBoolean(); + Consumer ctxConsumer = new Consumer() { @Override - public int id() { - return id; + public void accept(ChannelHandlerContext obj) { + channelReadCompleteCount.incrementAndGet(); + if (shouldDisableAutoRead.get()) { + obj.channel().config().setAutoRead(false); + } } + }; + final LastInboundHandler inboundHandler = new LastInboundHandler(ctxConsumer); + Http2StreamChannel childChannel = newInboundStream(3, false, numReads, inboundHandler); + childChannel.config().setAutoRead(false); + + Http2DataFrame dataFrame1 = new DefaultHttp2DataFrame(bb("1")).stream(childChannel.stream()); + Http2DataFrame dataFrame2 = new DefaultHttp2DataFrame(bb("2")).stream(childChannel.stream()); + Http2DataFrame dataFrame3 = new DefaultHttp2DataFrame(bb("3")).stream(childChannel.stream()); + Http2DataFrame dataFrame4 = new DefaultHttp2DataFrame(bb("4")).stream(childChannel.stream()); + assertEquals(new DefaultHttp2HeadersFrame(request).stream(childChannel.stream()), inboundHandler.readInbound()); + + ChannelHandler readCompleteSupressHandler = new ChannelInboundHandlerAdapter() { @Override - public Http2Stream.State state() { - return state; + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + // We want to simulate the parent channel calling channelRead and delay calling channelReadComplete. } - } - } + }; + parentChannel.pipeline().addFirst(readCompleteSupressHandler); - private final class TestableHttp2MultiplexCodecBuilder extends Http2MultiplexCodecBuilder { + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("1"), 0, false); - TestableHttp2MultiplexCodecBuilder(boolean server, ChannelHandler childHandler) { - super(server, childHandler); - } + assertEqualsAndRelease(dataFrame1, inboundHandler.readInbound()); - @Override - public TestableHttp2MultiplexCodec build() { - return (TestableHttp2MultiplexCodec) super.build(); - } + // We want one item to be in the queue, and allow the numReads to be larger than 1. This will ensure that + // when beginRead() is called the child channel is added to the readPending queue of the parent channel. + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("2"), 0, false); - @Override - protected Http2MultiplexCodec build( - Http2ConnectionDecoder decoder, Http2ConnectionEncoder encoder, Http2Settings initialSettings) { - return new TestableHttp2MultiplexCodec( - encoder, decoder, initialSettings, childHandler); - } - } + numReads.set(2); + childChannel.read(); - class Writer { + assertEqualsAndRelease(dataFrame2, inboundHandler.readInbound()); - void write(Object msg, ChannelPromise promise) { - parentChannel.outboundMessages().add(msg); - promise.setSuccess(); + assertNull(inboundHandler.readInbound()); + + // This is the second item that was read, this should be the last until we call read() again. This should also + // notify of readComplete(). + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("3"), 0, false); + + assertEqualsAndRelease(dataFrame3, inboundHandler.readInbound()); + + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("4"), 0, false); + assertNull(inboundHandler.readInbound()); + + childChannel.read(); + + assertEqualsAndRelease(dataFrame4, inboundHandler.readInbound()); + + assertNull(inboundHandler.readInbound()); + + // Now we want to call channelReadComplete and simulate the end of the read loop. + parentChannel.pipeline().remove(readCompleteSupressHandler); + parentChannel.flushInbound(); + + // 3 = 1 for initialization + 1 for first read of 2 items + 1 for second read of 2 items + + // 1 for parent channel readComplete + assertEquals(4, channelReadCompleteCount.get()); + } + + private static void verifyFramesMultiplexedToCorrectChannel(Http2StreamChannel streamChannel, + LastInboundHandler inboundHandler, + int numFrames) { + for (int i = 0; i < numFrames; i++) { + Http2StreamFrame frame = inboundHandler.readInbound(); + assertNotNull(frame); + assertEquals(streamChannel.stream(), frame.stream()); + release(frame); } + assertNull(inboundHandler.readInbound()); + } + + private static int eqStreamId(Http2StreamChannel channel) { + return eq(channel.stream().id()); } } diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2StreamFrameToHttpObjectCodecTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2StreamFrameToHttpObjectCodecTest.java index 45e781c86c47..393a4060ef4f 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2StreamFrameToHttpObjectCodecTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2StreamFrameToHttpObjectCodecTest.java @@ -19,6 +19,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelOutboundHandlerAdapter; import io.netty.channel.ChannelPromise; @@ -871,4 +872,65 @@ public void testPassThroughOtherAsClient() throws Exception { frame.release(); } } + + @Test + public void testIsSharableBetweenChannels() throws Exception { + final Queue frames = new ConcurrentLinkedQueue(); + final ChannelHandler sharedHandler = new Http2StreamFrameToHttpObjectCodec(false); + + final SslContext ctx = SslContextBuilder.forClient().sslProvider(SslProvider.JDK).build(); + EmbeddedChannel tlsCh = new EmbeddedChannel(ctx.newHandler(ByteBufAllocator.DEFAULT), + new ChannelOutboundHandlerAdapter() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + if (msg instanceof Http2StreamFrame) { + frames.add((Http2StreamFrame) msg); + promise.setSuccess(); + } else { + ctx.write(msg, promise); + } + } + }, sharedHandler); + + EmbeddedChannel plaintextCh = new EmbeddedChannel( + new ChannelOutboundHandlerAdapter() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + if (msg instanceof Http2StreamFrame) { + frames.add((Http2StreamFrame) msg); + promise.setSuccess(); + } else { + ctx.write(msg, promise); + } + } + }, sharedHandler); + + FullHttpRequest req = new DefaultFullHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.GET, "/hello/world"); + assertTrue(tlsCh.writeOutbound(req)); + assertTrue(tlsCh.finishAndReleaseAll()); + + Http2HeadersFrame headersFrame = (Http2HeadersFrame) frames.poll(); + Http2Headers headers = headersFrame.headers(); + + assertThat(headers.scheme().toString(), is("https")); + assertThat(headers.method().toString(), is("GET")); + assertThat(headers.path().toString(), is("/hello/world")); + assertTrue(headersFrame.isEndStream()); + assertNull(frames.poll()); + + // Run the plaintext channel + req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/hello/world"); + assertFalse(plaintextCh.writeOutbound(req)); + assertFalse(plaintextCh.finishAndReleaseAll()); + + headersFrame = (Http2HeadersFrame) frames.poll(); + headers = headersFrame.headers(); + + assertThat(headers.scheme().toString(), is("http")); + assertThat(headers.method().toString(), is("GET")); + assertThat(headers.path().toString(), is("/hello/world")); + assertTrue(headersFrame.isEndStream()); + assertNull(frames.poll()); + } } diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2TestUtil.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2TestUtil.java index e7130ee5d902..f6603812374c 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2TestUtil.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2TestUtil.java @@ -15,7 +15,9 @@ package io.netty.handler.codec.http2; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; import io.netty.buffer.Unpooled; +import io.netty.buffer.UnpooledByteBufAllocator; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; @@ -24,18 +26,34 @@ import io.netty.channel.DefaultChannelPromise; import io.netty.handler.codec.ByteToMessageDecoder; import io.netty.util.AsciiString; +import io.netty.util.ReferenceCountUtil; import io.netty.util.concurrent.Future; import io.netty.util.concurrent.GenericFutureListener; import io.netty.util.concurrent.ImmediateEventExecutor; import junit.framework.AssertionFailedError; +import org.mockito.Mockito; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; import java.util.List; import java.util.Random; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_HEADER_LIST_SIZE; import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_HEADER_TABLE_SIZE; +import static io.netty.util.ReferenceCountUtil.release; import static java.lang.Math.min; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyByte; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyShort; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.when; /** * Utilities for the integration tests. @@ -506,4 +524,179 @@ public int windowSize() { return isWriteAllowed ? (int) min(pendingBytes, Integer.MAX_VALUE) : -1; } } + + static Http2FrameWriter mockedFrameWriter() { + Http2FrameWriter.Configuration configuration = new Http2FrameWriter.Configuration() { + private final Http2HeadersEncoder.Configuration headerConfiguration = + new Http2HeadersEncoder.Configuration() { + @Override + public void maxHeaderTableSize(long max) { + // NOOP + } + + @Override + public long maxHeaderTableSize() { + return 0; + } + + @Override + public void maxHeaderListSize(long max) { + // NOOP + } + + @Override + public long maxHeaderListSize() { + return 0; + } + }; + + private final Http2FrameSizePolicy policy = new Http2FrameSizePolicy() { + @Override + public void maxFrameSize(int max) { + // NOOP + } + + @Override + public int maxFrameSize() { + return 0; + } + }; + @Override + public Http2HeadersEncoder.Configuration headersConfiguration() { + return headerConfiguration; + } + + @Override + public Http2FrameSizePolicy frameSizePolicy() { + return policy; + } + }; + + final ConcurrentLinkedQueue buffers = new ConcurrentLinkedQueue(); + + Http2FrameWriter frameWriter = Mockito.mock(Http2FrameWriter.class); + doAnswer(new Answer() { + @Override + public Object answer(InvocationOnMock invocationOnMock) { + for (;;) { + ByteBuf buf = buffers.poll(); + if (buf == null) { + break; + } + buf.release(); + } + return null; + } + }).when(frameWriter).close(); + + when(frameWriter.configuration()).thenReturn(configuration); + when(frameWriter.writeSettings(any(ChannelHandlerContext.class), any(Http2Settings.class), + any(ChannelPromise.class))).thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + return ((ChannelPromise) invocationOnMock.getArgument(2)).setSuccess(); + } + }); + + when(frameWriter.writeSettingsAck(any(ChannelHandlerContext.class), any(ChannelPromise.class))) + .thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + return ((ChannelPromise) invocationOnMock.getArgument(1)).setSuccess(); + } + }); + + when(frameWriter.writeGoAway(any(ChannelHandlerContext.class), anyInt(), + anyLong(), any(ByteBuf.class), any(ChannelPromise.class))).thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + buffers.offer((ByteBuf) invocationOnMock.getArgument(3)); + return ((ChannelPromise) invocationOnMock.getArgument(4)).setSuccess(); + } + }); + when(frameWriter.writeHeaders(any(ChannelHandlerContext.class), anyInt(), any(Http2Headers.class), anyInt(), + anyBoolean(), any(ChannelPromise.class))).thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + return ((ChannelPromise) invocationOnMock.getArgument(5)).setSuccess(); + } + }); + + when(frameWriter.writeHeaders(any(ChannelHandlerContext.class), anyInt(), + any(Http2Headers.class), anyInt(), anyShort(), anyBoolean(), anyInt(), anyBoolean(), + any(ChannelPromise.class))).thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + return ((ChannelPromise) invocationOnMock.getArgument(8)).setSuccess(); + } + }); + + when(frameWriter.writeData(any(ChannelHandlerContext.class), anyInt(), any(ByteBuf.class), anyInt(), + anyBoolean(), any(ChannelPromise.class))).thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + buffers.offer((ByteBuf) invocationOnMock.getArgument(2)); + return ((ChannelPromise) invocationOnMock.getArgument(5)).setSuccess(); + } + }); + + when(frameWriter.writeRstStream(any(ChannelHandlerContext.class), anyInt(), + anyLong(), any(ChannelPromise.class))).thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + return ((ChannelPromise) invocationOnMock.getArgument(3)).setSuccess(); + } + }); + + when(frameWriter.writeWindowUpdate(any(ChannelHandlerContext.class), anyInt(), anyInt(), + any(ChannelPromise.class))).then(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + return ((ChannelPromise) invocationOnMock.getArgument(3)).setSuccess(); + } + }); + + when(frameWriter.writePushPromise(any(ChannelHandlerContext.class), anyInt(), anyInt(), any(Http2Headers.class), + anyInt(), anyChannelPromise())).thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + return ((ChannelPromise) invocationOnMock.getArgument(5)).setSuccess(); + } + }); + + when(frameWriter.writeFrame(any(ChannelHandlerContext.class), anyByte(), anyInt(), any(Http2Flags.class), + any(ByteBuf.class), anyChannelPromise())).thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + buffers.offer((ByteBuf) invocationOnMock.getArgument(4)); + return ((ChannelPromise) invocationOnMock.getArgument(5)).setSuccess(); + } + }); + return frameWriter; + } + + static ChannelPromise anyChannelPromise() { + return any(ChannelPromise.class); + } + + static Http2Settings anyHttp2Settings() { + return any(Http2Settings.class); + } + + static ByteBuf bb(String s) { + return ByteBufUtil.writeUtf8(UnpooledByteBufAllocator.DEFAULT, s); + } + + static void assertEqualsAndRelease(Http2Frame expected, Http2Frame actual) { + try { + assertEquals(expected, actual); + } finally { + release(expected); + release(actual); + // Will return -1 when not implements ReferenceCounted. + assertTrue(ReferenceCountUtil.refCnt(expected) <= 0); + assertTrue(ReferenceCountUtil.refCnt(actual) <= 0); + } + } + } diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/LastInboundHandler.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/LastInboundHandler.java index 55c4df0a984a..38f400af14ee 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/LastInboundHandler.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/LastInboundHandler.java @@ -27,6 +27,7 @@ import java.util.List; import java.util.concurrent.locks.LockSupport; +import static io.netty.util.internal.ObjectUtil.checkNotNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; /** @@ -34,11 +35,36 @@ */ public class LastInboundHandler extends ChannelDuplexHandler { private final List queue = new ArrayList(); + private final Consumer channelReadCompleteConsumer; private Throwable lastException; private ChannelHandlerContext ctx; private boolean channelActive; private String writabilityStates = ""; + // TODO(scott): use JDK 8's Consumer + public interface Consumer { + void accept(T obj); + } + + private static final Consumer NOOP_CONSUMER = new Consumer() { + @Override + public void accept(Object obj) { + } + }; + + @SuppressWarnings("unchecked") + public static Consumer noopConsumer() { + return (Consumer) NOOP_CONSUMER; + } + + public LastInboundHandler() { + this(LastInboundHandler.noopConsumer()); + } + + public LastInboundHandler(Consumer channelReadCompleteConsumer) { + this.channelReadCompleteConsumer = checkNotNull(channelReadCompleteConsumer, "channelReadCompleteConsumer"); + } + @Override public void handlerAdded(ChannelHandlerContext ctx) throws Exception { super.handlerAdded(ctx); @@ -86,6 +112,11 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception queue.add(msg); } + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + channelReadCompleteConsumer.accept(ctx); + } + @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { queue.add(new UserEvent(evt)); diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/StreamBufferingEncoderTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/StreamBufferingEncoderTest.java index a1768cb85769..15dc168b4da8 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/StreamBufferingEncoderTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/StreamBufferingEncoderTest.java @@ -222,7 +222,7 @@ public void alternatingWritesToActiveAndBufferedStreams() { } @Test - public void bufferingNewStreamFailsAfterGoAwayReceived() { + public void bufferingNewStreamFailsAfterGoAwayReceived() throws Http2Exception { encoder.writeSettingsAck(ctx, newPromise()); setMaxConcurrentStreams(0); connection.goAwayReceived(1, 8, EMPTY_BUFFER); @@ -235,7 +235,7 @@ public void bufferingNewStreamFailsAfterGoAwayReceived() { } @Test - public void receivingGoAwayFailsBufferedStreams() { + public void receivingGoAwayFailsBufferedStreams() throws Http2Exception { encoder.writeSettingsAck(ctx, newPromise()); setMaxConcurrentStreams(5); diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/TestChannelInitializer.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/TestChannelInitializer.java index 5e13d056bbd6..e8bbab678ef0 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/TestChannelInitializer.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/TestChannelInitializer.java @@ -16,10 +16,17 @@ package io.netty.handler.codec.http2; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.netty.channel.Channel; +import io.netty.channel.ChannelConfig; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandler.Sharable; import io.netty.channel.ChannelInitializer; +import io.netty.channel.RecvByteBufAllocator; +import io.netty.util.UncheckedBooleanSupplier; + +import java.util.concurrent.atomic.AtomicInteger; /** * Channel initializer useful in tests. @@ -27,6 +34,7 @@ @Sharable public class TestChannelInitializer extends ChannelInitializer { ChannelHandler handler; + AtomicInteger maxReads; @Override public void initChannel(Channel channel) { @@ -34,5 +42,86 @@ public void initChannel(Channel channel) { channel.pipeline().addLast(handler); handler = null; } + if (maxReads != null) { + channel.config().setRecvByteBufAllocator(new TestNumReadsRecvByteBufAllocator(maxReads)); + } + } + + /** + * Designed to read a single byte at a time to control the number of reads done at a fine granularity. + */ + static final class TestNumReadsRecvByteBufAllocator implements RecvByteBufAllocator { + private final AtomicInteger numReads; + private TestNumReadsRecvByteBufAllocator(AtomicInteger numReads) { + this.numReads = numReads; + } + + @Override + public ExtendedHandle newHandle() { + return new ExtendedHandle() { + private int attemptedBytesRead; + private int lastBytesRead; + private int numMessagesRead; + @Override + public ByteBuf allocate(ByteBufAllocator alloc) { + return alloc.ioBuffer(guess(), guess()); + } + + @Override + public int guess() { + return 1; // only ever allocate buffers of size 1 to ensure the number of reads is controlled. + } + + @Override + public void reset(ChannelConfig config) { + numMessagesRead = 0; + } + + @Override + public void incMessagesRead(int numMessages) { + numMessagesRead += numMessages; + } + + @Override + public void lastBytesRead(int bytes) { + lastBytesRead = bytes; + } + + @Override + public int lastBytesRead() { + return lastBytesRead; + } + + @Override + public void attemptedBytesRead(int bytes) { + attemptedBytesRead = bytes; + } + + @Override + public int attemptedBytesRead() { + return attemptedBytesRead; + } + + @Override + public boolean continueReading() { + return numMessagesRead < numReads.get(); + } + + @Override + public boolean continueReading(UncheckedBooleanSupplier maybeMoreDataSupplier) { + return continueReading(); + } + + @Override + public void readComplete() { + // Nothing needs to be done or adjusted after each read cycle is completed. + } + + @Override + public void channelClosed() { + // noop + } + }; + } } } diff --git a/codec-memcache/pom.xml b/codec-memcache/pom.xml index 21dc05549e9b..f54cfca998f2 100644 --- a/codec-memcache/pom.xml +++ b/codec-memcache/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-codec-memcache @@ -33,6 +33,21 @@ + + ${project.groupId} + netty-common + ${project.version} + + + ${project.groupId} + netty-buffer + ${project.version} + + + ${project.groupId} + netty-transport + ${project.version} + ${project.groupId} netty-codec diff --git a/codec-memcache/src/main/java/io/netty/handler/codec/memcache/binary/AbstractBinaryMemcacheDecoder.java b/codec-memcache/src/main/java/io/netty/handler/codec/memcache/binary/AbstractBinaryMemcacheDecoder.java index 2c9038282991..bec754afbda2 100644 --- a/codec-memcache/src/main/java/io/netty/handler/codec/memcache/binary/AbstractBinaryMemcacheDecoder.java +++ b/codec-memcache/src/main/java/io/netty/handler/codec/memcache/binary/AbstractBinaryMemcacheDecoder.java @@ -15,6 +15,8 @@ */ package io.netty.handler.codec.memcache.binary; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; @@ -59,9 +61,7 @@ protected AbstractBinaryMemcacheDecoder() { * @param chunkSize the maximum chunk size of the payload. */ protected AbstractBinaryMemcacheDecoder(int chunkSize) { - if (chunkSize < 0) { - throw new IllegalArgumentException("chunkSize must be a positive integer: " + chunkSize); - } + checkPositiveOrZero(chunkSize, "chunkSize"); this.chunkSize = chunkSize; } diff --git a/codec-mqtt/pom.xml b/codec-mqtt/pom.xml index 0bd0b7d8c2b5..08c050239111 100644 --- a/codec-mqtt/pom.xml +++ b/codec-mqtt/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-codec-mqtt @@ -33,6 +33,21 @@ + + ${project.groupId} + netty-common + ${project.version} + + + ${project.groupId} + netty-buffer + ${project.version} + + + ${project.groupId} + netty-transport + ${project.version} + ${project.groupId} netty-codec diff --git a/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttDecoder.java b/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttDecoder.java index e014f37ac2bd..e93d34bde155 100644 --- a/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttDecoder.java +++ b/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttDecoder.java @@ -82,11 +82,11 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf buffer, List ou } case READ_VARIABLE_HEADER: try { + final Result decodedVariableHeader = decodeVariableHeader(buffer, mqttFixedHeader); + variableHeader = decodedVariableHeader.value; if (bytesRemainingInVariablePart > maxBytesInMessage) { throw new DecoderException("too large message: " + bytesRemainingInVariablePart + " bytes"); } - final Result decodedVariableHeader = decodeVariableHeader(buffer, mqttFixedHeader); - variableHeader = decodedVariableHeader.value; bytesRemainingInVariablePart -= decodedVariableHeader.numberOfBytesConsumed; checkpoint(DecoderState.READ_PAYLOAD); // fall through @@ -133,7 +133,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf buffer, List ou private MqttMessage invalidMessage(Throwable cause) { checkpoint(DecoderState.BAD_MESSAGE); - return MqttMessageFactory.newInvalidMessage(cause); + return MqttMessageFactory.newInvalidMessage(mqttFixedHeader, variableHeader, cause); } /** diff --git a/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttMessageFactory.java b/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttMessageFactory.java index 69f07c07273c..09f42698d781 100644 --- a/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttMessageFactory.java +++ b/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttMessageFactory.java @@ -85,5 +85,10 @@ public static MqttMessage newInvalidMessage(Throwable cause) { return new MqttMessage(null, null, null, DecoderResult.failure(cause)); } + public static MqttMessage newInvalidMessage(MqttFixedHeader mqttFixedHeader, Object variableHeader, + Throwable cause) { + return new MqttMessage(mqttFixedHeader, variableHeader, null, DecoderResult.failure(cause)); + } + private MqttMessageFactory() { } } diff --git a/codec-mqtt/src/test/java/io/netty/handler/codec/mqtt/MqttCodecTest.java b/codec-mqtt/src/test/java/io/netty/handler/codec/mqtt/MqttCodecTest.java index f69387ef2166..927334cb1d86 100644 --- a/codec-mqtt/src/test/java/io/netty/handler/codec/mqtt/MqttCodecTest.java +++ b/codec-mqtt/src/test/java/io/netty/handler/codec/mqtt/MqttCodecTest.java @@ -58,6 +58,11 @@ public class MqttCodecTest { private final MqttDecoder mqttDecoder = new MqttDecoder(); + /** + * MqttDecoder with an unrealistic max payload size of 1 byte. + */ + private final MqttDecoder mqttDecoderLimitedMessageSize = new MqttDecoder(1); + @Before public void setup() { MockitoAnnotations.initMocks(this); @@ -297,6 +302,154 @@ public void testUnknownMessageType() throws Exception { } } + @Test + public void testConnectMessageForMqtt31TooLarge() throws Exception { + final MqttConnectMessage message = createConnectMessage(MqttVersion.MQTT_3_1); + ByteBuf byteBuf = MqttEncoder.doEncode(ALLOCATOR, message); + + try { + final List out = new LinkedList(); + mqttDecoderLimitedMessageSize.decode(ctx, byteBuf, out); + + assertEquals("Expected one object but got " + out.size(), 1, out.size()); + + final MqttMessage decodedMessage = (MqttMessage) out.get(0); + + validateFixedHeaders(message.fixedHeader(), decodedMessage.fixedHeader()); + validateConnectVariableHeader(message.variableHeader(), + (MqttConnectVariableHeader) decodedMessage.variableHeader()); + validateDecoderExceptionTooLargeMessage(decodedMessage); + } finally { + byteBuf.release(); + } + } + + @Test + public void testConnectMessageForMqtt311TooLarge() throws Exception { + final MqttConnectMessage message = createConnectMessage(MqttVersion.MQTT_3_1_1); + ByteBuf byteBuf = MqttEncoder.doEncode(ALLOCATOR, message); + + try { + final List out = new LinkedList(); + mqttDecoderLimitedMessageSize.decode(ctx, byteBuf, out); + + assertEquals("Expected one object but got " + out.size(), 1, out.size()); + + final MqttMessage decodedMessage = (MqttMessage) out.get(0); + + validateFixedHeaders(message.fixedHeader(), decodedMessage.fixedHeader()); + validateConnectVariableHeader(message.variableHeader(), + (MqttConnectVariableHeader) decodedMessage.variableHeader()); + validateDecoderExceptionTooLargeMessage(decodedMessage); + } finally { + byteBuf.release(); + } + } + + @Test + public void testConnAckMessageTooLarge() throws Exception { + final MqttConnAckMessage message = createConnAckMessage(); + ByteBuf byteBuf = MqttEncoder.doEncode(ALLOCATOR, message); + + try { + final List out = new LinkedList(); + mqttDecoderLimitedMessageSize.decode(ctx, byteBuf, out); + + assertEquals("Expected one object but got " + out.size(), 1, out.size()); + + final MqttMessage decodedMessage = (MqttMessage) out.get(0); + validateFixedHeaders(message.fixedHeader(), decodedMessage.fixedHeader()); + validateDecoderExceptionTooLargeMessage(decodedMessage); + } finally { + byteBuf.release(); + } + } + + @Test + public void testPublishMessageTooLarge() throws Exception { + final MqttPublishMessage message = createPublishMessage(); + ByteBuf byteBuf = MqttEncoder.doEncode(ALLOCATOR, message); + + try { + final List out = new LinkedList(); + mqttDecoderLimitedMessageSize.decode(ctx, byteBuf, out); + + assertEquals("Expected one object but got " + out.size(), 1, out.size()); + + final MqttMessage decodedMessage = (MqttMessage) out.get(0); + + validateFixedHeaders(message.fixedHeader(), decodedMessage.fixedHeader()); + validatePublishVariableHeader(message.variableHeader(), + (MqttPublishVariableHeader) decodedMessage.variableHeader()); + validateDecoderExceptionTooLargeMessage(decodedMessage); + } finally { + byteBuf.release(); + } + } + + @Test + public void testSubscribeMessageTooLarge() throws Exception { + final MqttSubscribeMessage message = createSubscribeMessage(); + ByteBuf byteBuf = MqttEncoder.doEncode(ALLOCATOR, message); + + try { + final List out = new LinkedList(); + mqttDecoderLimitedMessageSize.decode(ctx, byteBuf, out); + + assertEquals("Expected one object but got " + out.size(), 1, out.size()); + + final MqttMessage decodedMessage = (MqttMessage) out.get(0); + validateFixedHeaders(message.fixedHeader(), decodedMessage.fixedHeader()); + validateMessageIdVariableHeader(message.variableHeader(), + (MqttMessageIdVariableHeader) decodedMessage.variableHeader()); + validateDecoderExceptionTooLargeMessage(decodedMessage); + } finally { + byteBuf.release(); + } + } + + @Test + public void testSubAckMessageTooLarge() throws Exception { + final MqttSubAckMessage message = createSubAckMessage(); + ByteBuf byteBuf = MqttEncoder.doEncode(ALLOCATOR, message); + + try { + final List out = new LinkedList(); + mqttDecoderLimitedMessageSize.decode(ctx, byteBuf, out); + + assertEquals("Expected one object but got " + out.size(), 1, out.size()); + + final MqttMessage decodedMessage = (MqttMessage) out.get(0); + validateFixedHeaders(message.fixedHeader(), decodedMessage.fixedHeader()); + validateMessageIdVariableHeader(message.variableHeader(), + (MqttMessageIdVariableHeader) decodedMessage.variableHeader()); + validateDecoderExceptionTooLargeMessage(decodedMessage); + } finally { + byteBuf.release(); + } + } + + @Test + public void testUnSubscribeMessageTooLarge() throws Exception { + final MqttUnsubscribeMessage message = createUnsubscribeMessage(); + ByteBuf byteBuf = MqttEncoder.doEncode(ALLOCATOR, message); + + try { + final List out = new LinkedList(); + mqttDecoderLimitedMessageSize.decode(ctx, byteBuf, out); + + assertEquals("Expected one object but got " + out.size(), 1, out.size()); + + final MqttMessage decodedMessage = (MqttMessage) out.get(0); + validateFixedHeaders(message.fixedHeader(), decodedMessage.fixedHeader()); + validateMessageIdVariableHeader(message.variableHeader(), + (MqttMessageIdVariableHeader) decodedMessage.variableHeader()); + validateDecoderExceptionTooLargeMessage(decodedMessage); + } finally { + byteBuf.release(); + } + } + private void testMessageWithOnlyFixedHeader(MqttMessageType messageType) throws Exception { MqttMessage message = createMessageWithFixedHeader(messageType); ByteBuf byteBuf = MqttEncoder.doEncode(ALLOCATOR, message); @@ -340,7 +493,7 @@ private static MqttMessage createMessageWithFixedHeaderAndMessageIdVariableHeade new MqttFixedHeader( messageType, false, - messageType == MqttMessageType.PUBREL ? MqttQoS.AT_LEAST_ONCE : MqttQoS.AT_MOST_ONCE, + messageType == MqttMessageType.PUBREL ? MqttQoS.AT_LEAST_ONCE : MqttQoS.AT_MOST_ONCE, false, 0); MqttMessageIdVariableHeader mqttMessageIdVariableHeader = MqttMessageIdVariableHeader.from(12345); @@ -378,7 +531,7 @@ private static MqttPublishMessage createPublishMessage() { MqttFixedHeader mqttFixedHeader = new MqttFixedHeader(MqttMessageType.PUBLISH, false, MqttQoS.AT_LEAST_ONCE, true, 0); MqttPublishVariableHeader mqttPublishVariableHeader = new MqttPublishVariableHeader("/abc", 1234); - ByteBuf payload = ALLOCATOR.buffer(); + ByteBuf payload = ALLOCATOR.buffer(); payload.writeBytes("whatever".getBytes(CharsetUtil.UTF_8)); return new MqttPublishMessage(mqttFixedHeader, mqttPublishVariableHeader, payload); } @@ -530,4 +683,14 @@ private static void validateUnsubscribePayload(MqttUnsubscribePayload expected, expected.topics().toArray(), actual.topics().toArray()); } + + private static void validateDecoderExceptionTooLargeMessage(MqttMessage message) { + assertNull("MqttMessage payload expected null ", message.payload()); + assertTrue(message.decoderResult().isFailure()); + Throwable cause = message.decoderResult().cause(); + assertTrue("MqttMessage DecoderResult cause expected instance of DecoderException ", + cause instanceof DecoderException); + assertTrue("MqttMessage DecoderResult cause reason expect to contain 'too large message' ", + cause.getMessage().contains("too large message:")); + } } diff --git a/codec-redis/pom.xml b/codec-redis/pom.xml index 0944db44483a..b92ca2f0c714 100644 --- a/codec-redis/pom.xml +++ b/codec-redis/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-codec-redis @@ -33,6 +33,21 @@ + + ${project.groupId} + netty-common + ${project.version} + + + ${project.groupId} + netty-buffer + ${project.version} + + + ${project.groupId} + netty-transport + ${project.version} + ${project.groupId} netty-codec diff --git a/codec-smtp/pom.xml b/codec-smtp/pom.xml index 217873d8e580..d8570d8b204f 100644 --- a/codec-smtp/pom.xml +++ b/codec-smtp/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-codec-smtp @@ -33,6 +33,21 @@ + + ${project.groupId} + netty-common + ${project.version} + + + ${project.groupId} + netty-buffer + ${project.version} + + + ${project.groupId} + netty-transport + ${project.version} + ${project.groupId} netty-codec diff --git a/codec-socks/pom.xml b/codec-socks/pom.xml index a1069cea05cb..67c68c3c4dc6 100644 --- a/codec-socks/pom.xml +++ b/codec-socks/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-codec-socks @@ -33,6 +33,21 @@ + + ${project.groupId} + netty-common + ${project.version} + + + ${project.groupId} + netty-buffer + ${project.version} + + + ${project.groupId} + netty-transport + ${project.version} + ${project.groupId} netty-codec diff --git a/codec-socks/src/main/java/io/netty/handler/codec/socksx/v5/Socks5AddressEncoder.java b/codec-socks/src/main/java/io/netty/handler/codec/socksx/v5/Socks5AddressEncoder.java index fcd29cf3f3fa..390622caba77 100644 --- a/codec-socks/src/main/java/io/netty/handler/codec/socksx/v5/Socks5AddressEncoder.java +++ b/codec-socks/src/main/java/io/netty/handler/codec/socksx/v5/Socks5AddressEncoder.java @@ -44,7 +44,6 @@ public void encodeAddress(Socks5AddressType addrType, String addrValue, ByteBuf out.writeByte(addrValue.length()); out.writeCharSequence(addrValue, CharsetUtil.US_ASCII); } else { - out.writeByte(1); out.writeByte(0); } } else if (typeVal == Socks5AddressType.IPv6.byteValue()) { diff --git a/codec-socks/src/test/java/io/netty/handler/codec/socks/SocksCmdRequestTest.java b/codec-socks/src/test/java/io/netty/handler/codec/socks/SocksCmdRequestTest.java index 37439a7e6af2..4013c8b929a5 100644 --- a/codec-socks/src/test/java/io/netty/handler/codec/socks/SocksCmdRequestTest.java +++ b/codec-socks/src/test/java/io/netty/handler/codec/socks/SocksCmdRequestTest.java @@ -21,6 +21,7 @@ import org.junit.Test; import java.net.IDN; +import java.nio.CharBuffer; import static org.junit.Assert.*; @@ -101,7 +102,7 @@ public void testHostNotEncodedForUnknown() { @Test public void testIDNEncodeToAsciiForDomain() { String host = "тест.рф"; - String asciiHost = IDN.toASCII(host); + CharBuffer asciiHost = CharBuffer.wrap(IDN.toASCII(host)); short port = 10000; SocksCmdRequest rq = new SocksCmdRequest(SocksCmdType.BIND, SocksAddressType.DOMAIN, host, port); @@ -116,7 +117,8 @@ public void testIDNEncodeToAsciiForDomain() { assertEquals((byte) 0x00, buffer.readByte()); assertEquals(SocksAddressType.DOMAIN.byteValue(), buffer.readByte()); assertEquals((byte) asciiHost.length(), buffer.readUnsignedByte()); - assertEquals(asciiHost, buffer.readCharSequence(asciiHost.length(), CharsetUtil.US_ASCII)); + assertEquals(asciiHost, + CharBuffer.wrap(buffer.readCharSequence(asciiHost.length(), CharsetUtil.US_ASCII))); assertEquals(port, buffer.readUnsignedShort()); buffer.release(); diff --git a/codec-socks/src/test/java/io/netty/handler/codec/socks/SocksCmdResponseTest.java b/codec-socks/src/test/java/io/netty/handler/codec/socks/SocksCmdResponseTest.java index 6861f761de5b..3ca64252d5c5 100644 --- a/codec-socks/src/test/java/io/netty/handler/codec/socks/SocksCmdResponseTest.java +++ b/codec-socks/src/test/java/io/netty/handler/codec/socks/SocksCmdResponseTest.java @@ -21,6 +21,7 @@ import org.junit.Test; import java.net.IDN; +import java.nio.CharBuffer; import static org.junit.Assert.*; @@ -135,7 +136,7 @@ public void testHostNotEncodedForUnknown() { @Test public void testIDNEncodeToAsciiForDomain() { String host = "тест.рф"; - String asciiHost = IDN.toASCII(host); + CharBuffer asciiHost = CharBuffer.wrap(IDN.toASCII(host)); short port = 10000; SocksCmdResponse rs = new SocksCmdResponse(SocksCmdStatus.SUCCESS, SocksAddressType.DOMAIN, host, port); @@ -150,7 +151,8 @@ public void testIDNEncodeToAsciiForDomain() { assertEquals((byte) 0x00, buffer.readByte()); assertEquals(SocksAddressType.DOMAIN.byteValue(), buffer.readByte()); assertEquals((byte) asciiHost.length(), buffer.readUnsignedByte()); - assertEquals(asciiHost, buffer.readCharSequence(asciiHost.length(), CharsetUtil.US_ASCII)); + assertEquals(asciiHost, + CharBuffer.wrap(buffer.readCharSequence(asciiHost.length(), CharsetUtil.US_ASCII))); assertEquals(port, buffer.readUnsignedShort()); buffer.release(); diff --git a/codec-socks/src/test/java/io/netty/handler/codec/socksx/v5/DefaultSocks5CommandResponseTest.java b/codec-socks/src/test/java/io/netty/handler/codec/socksx/v5/DefaultSocks5CommandResponseTest.java index 36fea750275f..9da123602e7b 100755 --- a/codec-socks/src/test/java/io/netty/handler/codec/socksx/v5/DefaultSocks5CommandResponseTest.java +++ b/codec-socks/src/test/java/io/netty/handler/codec/socksx/v5/DefaultSocks5CommandResponseTest.java @@ -51,8 +51,7 @@ public void testEmptyDomain() { 0x00, // success reply 0x00, // reserved 0x03, // address type domain - 0x01, // length of domain - 0x00, // domain value + 0x00, // length of domain 0x00, // port value 0x00 }; diff --git a/codec-stomp/pom.xml b/codec-stomp/pom.xml index 15d755138da2..58ac05cfb581 100644 --- a/codec-stomp/pom.xml +++ b/codec-stomp/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-codec-stomp @@ -33,6 +33,21 @@ + + ${project.groupId} + netty-common + ${project.version} + + + ${project.groupId} + netty-buffer + ${project.version} + + + ${project.groupId} + netty-transport + ${project.version} + ${project.groupId} netty-codec diff --git a/codec-stomp/src/main/java/io/netty/handler/codec/stomp/StompSubframeDecoder.java b/codec-stomp/src/main/java/io/netty/handler/codec/stomp/StompSubframeDecoder.java index b4d6a1495823..0139426172f4 100644 --- a/codec-stomp/src/main/java/io/netty/handler/codec/stomp/StompSubframeDecoder.java +++ b/codec-stomp/src/main/java/io/netty/handler/codec/stomp/StompSubframeDecoder.java @@ -30,6 +30,7 @@ import static io.netty.buffer.ByteBufUtil.indexOf; import static io.netty.buffer.ByteBufUtil.readBytes; +import static io.netty.util.internal.ObjectUtil.checkPositive; /** * Decodes {@link ByteBuf}s into {@link StompHeadersSubframe}s and @@ -90,16 +91,8 @@ public StompSubframeDecoder(int maxLineLength, int maxChunkSize) { public StompSubframeDecoder(int maxLineLength, int maxChunkSize, boolean validateHeaders) { super(State.SKIP_CONTROL_CHARACTERS); - if (maxLineLength <= 0) { - throw new IllegalArgumentException( - "maxLineLength must be a positive integer: " + - maxLineLength); - } - if (maxChunkSize <= 0) { - throw new IllegalArgumentException( - "maxChunkSize must be a positive integer: " + - maxChunkSize); - } + checkPositive(maxLineLength, "maxLineLength"); + checkPositive(maxChunkSize, "maxChunkSize"); this.maxChunkSize = maxChunkSize; this.maxLineLength = maxLineLength; this.validateHeaders = validateHeaders; diff --git a/codec-xml/pom.xml b/codec-xml/pom.xml index 1f4539d85761..ff0e5598a927 100644 --- a/codec-xml/pom.xml +++ b/codec-xml/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-codec-xml @@ -35,12 +35,17 @@ ${project.groupId} - netty-codec + netty-buffer + ${project.version} + + + ${project.groupId} + netty-transport ${project.version} ${project.groupId} - netty-handler + netty-codec ${project.version} diff --git a/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlAttribute.java b/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlAttribute.java index d40ce138c339..968d306c3069 100644 --- a/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlAttribute.java +++ b/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlAttribute.java @@ -57,16 +57,30 @@ public String value() { @Override public boolean equals(Object o) { - if (this == o) { return true; } - if (o == null || getClass() != o.getClass()) { return false; } + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } XmlAttribute that = (XmlAttribute) o; - if (!name.equals(that.name)) { return false; } - if (namespace != null ? !namespace.equals(that.namespace) : that.namespace != null) { return false; } - if (prefix != null ? !prefix.equals(that.prefix) : that.prefix != null) { return false; } - if (type != null ? !type.equals(that.type) : that.type != null) { return false; } - if (value != null ? !value.equals(that.value) : that.value != null) { return false; } + if (!name.equals(that.name)) { + return false; + } + if (namespace != null ? !namespace.equals(that.namespace) : that.namespace != null) { + return false; + } + if (prefix != null ? !prefix.equals(that.prefix) : that.prefix != null) { + return false; + } + if (type != null ? !type.equals(that.type) : that.type != null) { + return false; + } + if (value != null ? !value.equals(that.value) : that.value != null) { + return false; + } return true; } diff --git a/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlContent.java b/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlContent.java index a47df3312c86..275297c64147 100644 --- a/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlContent.java +++ b/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlContent.java @@ -32,12 +32,18 @@ public String data() { @Override public boolean equals(Object o) { - if (this == o) { return true; } - if (o == null || getClass() != o.getClass()) { return false; } + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } XmlContent that = (XmlContent) o; - if (data != null ? !data.equals(that.data) : that.data != null) { return false; } + if (data != null ? !data.equals(that.data) : that.data != null) { + return false; + } return true; } diff --git a/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlDTD.java b/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlDTD.java index 754539b243ca..e36648f55eac 100644 --- a/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlDTD.java +++ b/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlDTD.java @@ -32,12 +32,18 @@ public String text() { @Override public boolean equals(Object o) { - if (this == o) { return true; } - if (o == null || getClass() != o.getClass()) { return false; } + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } XmlDTD xmlDTD = (XmlDTD) o; - if (text != null ? !text.equals(xmlDTD.text) : xmlDTD.text != null) { return false; } + if (text != null ? !text.equals(xmlDTD.text) : xmlDTD.text != null) { + return false; + } return true; } diff --git a/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlDocumentStart.java b/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlDocumentStart.java index 311a1f5702c3..98ce875cf164 100644 --- a/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlDocumentStart.java +++ b/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlDocumentStart.java @@ -54,17 +54,27 @@ public String encodingScheme() { @Override public boolean equals(Object o) { - if (this == o) { return true; } - if (o == null || getClass() != o.getClass()) { return false; } + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } XmlDocumentStart that = (XmlDocumentStart) o; - if (standalone != that.standalone) { return false; } - if (encoding != null ? !encoding.equals(that.encoding) : that.encoding != null) { return false; } + if (standalone != that.standalone) { + return false; + } + if (encoding != null ? !encoding.equals(that.encoding) : that.encoding != null) { + return false; + } if (encodingScheme != null ? !encodingScheme.equals(that.encodingScheme) : that.encodingScheme != null) { return false; } - if (version != null ? !version.equals(that.version) : that.version != null) { return false; } + if (version != null ? !version.equals(that.version) : that.version != null) { + return false; + } return true; } diff --git a/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlElement.java b/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlElement.java index 885e814a098c..8391bd09ee64 100644 --- a/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlElement.java +++ b/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlElement.java @@ -54,15 +54,27 @@ public List namespaces() { @Override public boolean equals(Object o) { - if (this == o) { return true; } - if (o == null || getClass() != o.getClass()) { return false; } + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } XmlElement that = (XmlElement) o; - if (!name.equals(that.name)) { return false; } - if (namespace != null ? !namespace.equals(that.namespace) : that.namespace != null) { return false; } - if (namespaces != null ? !namespaces.equals(that.namespaces) : that.namespaces != null) { return false; } - if (prefix != null ? !prefix.equals(that.prefix) : that.prefix != null) { return false; } + if (!name.equals(that.name)) { + return false; + } + if (namespace != null ? !namespace.equals(that.namespace) : that.namespace != null) { + return false; + } + if (namespaces != null ? !namespaces.equals(that.namespaces) : that.namespaces != null) { + return false; + } + if (prefix != null ? !prefix.equals(that.prefix) : that.prefix != null) { + return false; + } return true; } diff --git a/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlElementStart.java b/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlElementStart.java index 17d603eb4caf..19024230c896 100644 --- a/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlElementStart.java +++ b/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlElementStart.java @@ -35,13 +35,21 @@ public List attributes() { @Override public boolean equals(Object o) { - if (this == o) { return true; } - if (o == null || getClass() != o.getClass()) { return false; } - if (!super.equals(o)) { return false; } + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + if (!super.equals(o)) { + return false; + } XmlElementStart that = (XmlElementStart) o; - if (attributes != null ? !attributes.equals(that.attributes) : that.attributes != null) { return false; } + if (attributes != null ? !attributes.equals(that.attributes) : that.attributes != null) { + return false; + } return true; } diff --git a/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlEntityReference.java b/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlEntityReference.java index ba3cba9adad7..78ed9e709e02 100644 --- a/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlEntityReference.java +++ b/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlEntityReference.java @@ -38,13 +38,21 @@ public String text() { @Override public boolean equals(Object o) { - if (this == o) { return true; } - if (o == null || getClass() != o.getClass()) { return false; } + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } XmlEntityReference that = (XmlEntityReference) o; - if (name != null ? !name.equals(that.name) : that.name != null) { return false; } - if (text != null ? !text.equals(that.text) : that.text != null) { return false; } + if (name != null ? !name.equals(that.name) : that.name != null) { + return false; + } + if (text != null ? !text.equals(that.text) : that.text != null) { + return false; + } return true; } diff --git a/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlNamespace.java b/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlNamespace.java index 9cbb86fb42c5..2d0ae5696ba0 100644 --- a/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlNamespace.java +++ b/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlNamespace.java @@ -38,13 +38,21 @@ public String uri() { @Override public boolean equals(Object o) { - if (this == o) { return true; } - if (o == null || getClass() != o.getClass()) { return false; } + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } XmlNamespace that = (XmlNamespace) o; - if (prefix != null ? !prefix.equals(that.prefix) : that.prefix != null) { return false; } - if (uri != null ? !uri.equals(that.uri) : that.uri != null) { return false; } + if (prefix != null ? !prefix.equals(that.prefix) : that.prefix != null) { + return false; + } + if (uri != null ? !uri.equals(that.uri) : that.uri != null) { + return false; + } return true; } diff --git a/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlProcessingInstruction.java b/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlProcessingInstruction.java index 27dc4deee0d2..6f7588062cde 100644 --- a/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlProcessingInstruction.java +++ b/codec-xml/src/main/java/io/netty/handler/codec/xml/XmlProcessingInstruction.java @@ -38,13 +38,21 @@ public String target() { @Override public boolean equals(Object o) { - if (this == o) { return true; } - if (o == null || getClass() != o.getClass()) { return false; } + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } XmlProcessingInstruction that = (XmlProcessingInstruction) o; - if (data != null ? !data.equals(that.data) : that.data != null) { return false; } - if (target != null ? !target.equals(that.target) : that.target != null) { return false; } + if (data != null ? !data.equals(that.data) : that.data != null) { + return false; + } + if (target != null ? !target.equals(that.target) : that.target != null) { + return false; + } return true; } diff --git a/codec/pom.xml b/codec/pom.xml index 28c8429cc9be..b36a60b8196a 100644 --- a/codec/pom.xml +++ b/codec/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-codec @@ -33,6 +33,16 @@ + + ${project.groupId} + netty-common + ${project.version} + + + ${project.groupId} + netty-buffer + ${project.version} + ${project.groupId} netty-transport diff --git a/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java b/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java index 9d52b5506e13..bed1efc211cf 100644 --- a/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java @@ -15,6 +15,8 @@ */ package io.netty.handler.codec; +import static io.netty.util.internal.ObjectUtil.checkPositive; + import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.CompositeByteBuf; @@ -75,23 +77,28 @@ public abstract class ByteToMessageDecoder extends ChannelInboundHandlerAdapter public static final Cumulator MERGE_CUMULATOR = new Cumulator() { @Override public ByteBuf cumulate(ByteBufAllocator alloc, ByteBuf cumulation, ByteBuf in) { - final ByteBuf buffer; - if (cumulation.writerIndex() > cumulation.maxCapacity() - in.readableBytes() + try { + final ByteBuf buffer; + if (cumulation.writerIndex() > cumulation.maxCapacity() - in.readableBytes() || cumulation.refCnt() > 1 || cumulation.isReadOnly()) { - // Expand cumulation (by replace it) when either there is not more room in the buffer - // or if the refCnt is greater then 1 which may happen when the user use slice().retain() or - // duplicate().retain() or if its read-only. - // - // See: - // - https://github.com/netty/netty/issues/2327 - // - https://github.com/netty/netty/issues/1764 - buffer = expandCumulation(alloc, cumulation, in.readableBytes()); - } else { - buffer = cumulation; + // Expand cumulation (by replace it) when either there is not more room in the buffer + // or if the refCnt is greater then 1 which may happen when the user use slice().retain() or + // duplicate().retain() or if its read-only. + // + // See: + // - https://github.com/netty/netty/issues/2327 + // - https://github.com/netty/netty/issues/1764 + buffer = expandCumulation(alloc, cumulation, in.readableBytes()); + } else { + buffer = cumulation; + } + buffer.writeBytes(in); + return buffer; + } finally { + // We must release in in all cases as otherwise it may produce a leak if writeBytes(...) throw + // for whatever release (for example because of OutOfMemoryError) + in.release(); } - buffer.writeBytes(in); - in.release(); - return buffer; } }; @@ -104,28 +111,36 @@ public ByteBuf cumulate(ByteBufAllocator alloc, ByteBuf cumulation, ByteBuf in) @Override public ByteBuf cumulate(ByteBufAllocator alloc, ByteBuf cumulation, ByteBuf in) { ByteBuf buffer; - if (cumulation.refCnt() > 1) { - // Expand cumulation (by replace it) when the refCnt is greater then 1 which may happen when the user - // use slice().retain() or duplicate().retain(). - // - // See: - // - https://github.com/netty/netty/issues/2327 - // - https://github.com/netty/netty/issues/1764 - buffer = expandCumulation(alloc, cumulation, in.readableBytes()); - buffer.writeBytes(in); - in.release(); - } else { - CompositeByteBuf composite; - if (cumulation instanceof CompositeByteBuf) { - composite = (CompositeByteBuf) cumulation; + try { + if (cumulation.refCnt() > 1) { + // Expand cumulation (by replace it) when the refCnt is greater then 1 which may happen when the + // user use slice().retain() or duplicate().retain(). + // + // See: + // - https://github.com/netty/netty/issues/2327 + // - https://github.com/netty/netty/issues/1764 + buffer = expandCumulation(alloc, cumulation, in.readableBytes()); + buffer.writeBytes(in); } else { - composite = alloc.compositeBuffer(Integer.MAX_VALUE); - composite.addComponent(true, cumulation); + CompositeByteBuf composite; + if (cumulation instanceof CompositeByteBuf) { + composite = (CompositeByteBuf) cumulation; + } else { + composite = alloc.compositeBuffer(Integer.MAX_VALUE); + composite.addComponent(true, cumulation); + } + composite.addComponent(true, in); + in = null; + buffer = composite; + } + return buffer; + } finally { + if (in != null) { + // We must release if the ownership was not transferred as otherwise it may produce a leak if + // writeBytes(...) throw for whatever release (for example because of OutOfMemoryError). + in.release(); } - composite.addComponent(true, in); - buffer = composite; } - return buffer; } }; @@ -189,9 +204,7 @@ public void setCumulator(Cumulator cumulator) { * The default is {@code 16}. */ public void setDiscardAfterReads(int discardAfterReads) { - if (discardAfterReads <= 0) { - throw new IllegalArgumentException("discardAfterReads must be > 0"); - } + checkPositive(discardAfterReads, "discardAfterReads"); this.discardAfterReads = discardAfterReads; } diff --git a/codec/src/main/java/io/netty/handler/codec/CharSequenceValueConverter.java b/codec/src/main/java/io/netty/handler/codec/CharSequenceValueConverter.java index b6510d1ff7e0..1157af0623e5 100644 --- a/codec/src/main/java/io/netty/handler/codec/CharSequenceValueConverter.java +++ b/codec/src/main/java/io/netty/handler/codec/CharSequenceValueConverter.java @@ -77,7 +77,7 @@ public CharSequence convertByte(byte value) { @Override public byte convertToByte(CharSequence value) { - if (value instanceof AsciiString) { + if (value instanceof AsciiString && value.length() == 1) { return ((AsciiString) value).byteAt(0); } return Byte.parseByte(value.toString()); diff --git a/codec/src/main/java/io/netty/handler/codec/CodecOutputList.java b/codec/src/main/java/io/netty/handler/codec/CodecOutputList.java index bd707ec5d31d..5cf161fc694c 100644 --- a/codec/src/main/java/io/netty/handler/codec/CodecOutputList.java +++ b/codec/src/main/java/io/netty/handler/codec/CodecOutputList.java @@ -148,7 +148,7 @@ public void add(int index, Object element) { expandArray(); } - if (index != size - 1) { + if (index != size) { System.arraycopy(array, index, array, index + 1, size - index); } diff --git a/codec/src/main/java/io/netty/handler/codec/DateFormatter.java b/codec/src/main/java/io/netty/handler/codec/DateFormatter.java index 86df148500fb..b07912abaa2e 100644 --- a/codec/src/main/java/io/netty/handler/codec/DateFormatter.java +++ b/codec/src/main/java/io/netty/handler/codec/DateFormatter.java @@ -261,10 +261,6 @@ private boolean tryParseDayOfMonth(CharSequence txt, int tokenStart, int tokenEn return false; } - private static boolean matchMonth(String month, CharSequence txt, int tokenStart) { - return AsciiString.regionMatchesAscii(month, true, 0, txt, tokenStart, 3); - } - private boolean tryParseMonth(CharSequence txt, int tokenStart, int tokenEnd) { int len = tokenEnd - tokenStart; @@ -272,29 +268,33 @@ private boolean tryParseMonth(CharSequence txt, int tokenStart, int tokenEnd) { return false; } - if (matchMonth("Jan", txt, tokenStart)) { + char monthChar1 = AsciiString.toLowerCase(txt.charAt(tokenStart)); + char monthChar2 = AsciiString.toLowerCase(txt.charAt(tokenStart + 1)); + char monthChar3 = AsciiString.toLowerCase(txt.charAt(tokenStart + 2)); + + if (monthChar1 == 'j' && monthChar2 == 'a' && monthChar3 == 'n') { month = Calendar.JANUARY; - } else if (matchMonth("Feb", txt, tokenStart)) { + } else if (monthChar1 == 'f' && monthChar2 == 'e' && monthChar3 == 'b') { month = Calendar.FEBRUARY; - } else if (matchMonth("Mar", txt, tokenStart)) { + } else if (monthChar1 == 'm' && monthChar2 == 'a' && monthChar3 == 'r') { month = Calendar.MARCH; - } else if (matchMonth("Apr", txt, tokenStart)) { + } else if (monthChar1 == 'a' && monthChar2 == 'p' && monthChar3 == 'r') { month = Calendar.APRIL; - } else if (matchMonth("May", txt, tokenStart)) { + } else if (monthChar1 == 'm' && monthChar2 == 'a' && monthChar3 == 'y') { month = Calendar.MAY; - } else if (matchMonth("Jun", txt, tokenStart)) { + } else if (monthChar1 == 'j' && monthChar2 == 'u' && monthChar3 == 'n') { month = Calendar.JUNE; - } else if (matchMonth("Jul", txt, tokenStart)) { + } else if (monthChar1 == 'j' && monthChar2 == 'u' && monthChar3 == 'l') { month = Calendar.JULY; - } else if (matchMonth("Aug", txt, tokenStart)) { + } else if (monthChar1 == 'a' && monthChar2 == 'u' && monthChar3 == 'g') { month = Calendar.AUGUST; - } else if (matchMonth("Sep", txt, tokenStart)) { + } else if (monthChar1 == 's' && monthChar2 == 'e' && monthChar3 == 'p') { month = Calendar.SEPTEMBER; - } else if (matchMonth("Oct", txt, tokenStart)) { + } else if (monthChar1 == 'o' && monthChar2 == 'c' && monthChar3 == 't') { month = Calendar.OCTOBER; - } else if (matchMonth("Nov", txt, tokenStart)) { + } else if (monthChar1 == 'n' && monthChar2 == 'o' && monthChar3 == 'v') { month = Calendar.NOVEMBER; - } else if (matchMonth("Dec", txt, tokenStart)) { + } else if (monthChar1 == 'd' && monthChar2 == 'e' && monthChar3 == 'c') { month = Calendar.DECEMBER; } else { return false; diff --git a/codec/src/main/java/io/netty/handler/codec/DelimiterBasedFrameDecoder.java b/codec/src/main/java/io/netty/handler/codec/DelimiterBasedFrameDecoder.java index 27e8c201576e..2fb6fec729b5 100644 --- a/codec/src/main/java/io/netty/handler/codec/DelimiterBasedFrameDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/DelimiterBasedFrameDecoder.java @@ -15,6 +15,8 @@ */ package io.netty.handler.codec; +import static io.netty.util.internal.ObjectUtil.checkPositive; + import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; @@ -346,10 +348,6 @@ private static void validateDelimiter(ByteBuf delimiter) { } private static void validateMaxFrameLength(int maxFrameLength) { - if (maxFrameLength <= 0) { - throw new IllegalArgumentException( - "maxFrameLength must be a positive integer: " + - maxFrameLength); - } + checkPositive(maxFrameLength, "maxFrameLength"); } } diff --git a/codec/src/main/java/io/netty/handler/codec/FixedLengthFrameDecoder.java b/codec/src/main/java/io/netty/handler/codec/FixedLengthFrameDecoder.java index 5b4bb7187c50..9475e5b65118 100644 --- a/codec/src/main/java/io/netty/handler/codec/FixedLengthFrameDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/FixedLengthFrameDecoder.java @@ -15,6 +15,8 @@ */ package io.netty.handler.codec; +import static io.netty.util.internal.ObjectUtil.checkPositive; + import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; @@ -46,10 +48,7 @@ public class FixedLengthFrameDecoder extends ByteToMessageDecoder { * @param frameLength the length of the frame */ public FixedLengthFrameDecoder(int frameLength) { - if (frameLength <= 0) { - throw new IllegalArgumentException( - "frameLength must be a positive integer: " + frameLength); - } + checkPositive(frameLength, "frameLength"); this.frameLength = frameLength; } diff --git a/codec/src/main/java/io/netty/handler/codec/LengthFieldBasedFrameDecoder.java b/codec/src/main/java/io/netty/handler/codec/LengthFieldBasedFrameDecoder.java index 4d94bdf881a5..b2a5cac49bf7 100644 --- a/codec/src/main/java/io/netty/handler/codec/LengthFieldBasedFrameDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/LengthFieldBasedFrameDecoder.java @@ -15,6 +15,9 @@ */ package io.netty.handler.codec; +import static io.netty.util.internal.ObjectUtil.checkPositive; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + import java.nio.ByteOrder; import java.util.List; @@ -302,23 +305,11 @@ public LengthFieldBasedFrameDecoder( throw new NullPointerException("byteOrder"); } - if (maxFrameLength <= 0) { - throw new IllegalArgumentException( - "maxFrameLength must be a positive integer: " + - maxFrameLength); - } + checkPositive(maxFrameLength, "maxFrameLength"); - if (lengthFieldOffset < 0) { - throw new IllegalArgumentException( - "lengthFieldOffset must be a non-negative integer: " + - lengthFieldOffset); - } + checkPositiveOrZero(lengthFieldOffset, "lengthFieldOffset"); - if (initialBytesToStrip < 0) { - throw new IllegalArgumentException( - "initialBytesToStrip must be a non-negative integer: " + - initialBytesToStrip); - } + checkPositiveOrZero(initialBytesToStrip, "initialBytesToStrip"); if (lengthFieldOffset > maxFrameLength - lengthFieldLength) { throw new IllegalArgumentException( diff --git a/codec/src/main/java/io/netty/handler/codec/LengthFieldPrepender.java b/codec/src/main/java/io/netty/handler/codec/LengthFieldPrepender.java index 4076e07deed9..b5c787f531fe 100644 --- a/codec/src/main/java/io/netty/handler/codec/LengthFieldPrepender.java +++ b/codec/src/main/java/io/netty/handler/codec/LengthFieldPrepender.java @@ -15,6 +15,8 @@ */ package io.netty.handler.codec; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandler.Sharable; import io.netty.channel.ChannelHandlerContext; @@ -163,10 +165,7 @@ protected void encode(ChannelHandlerContext ctx, ByteBuf msg, List out) length += lengthFieldLength; } - if (length < 0) { - throw new IllegalArgumentException( - "Adjusted frame length (" + length + ") is less than zero"); - } + checkPositiveOrZero(length, "length"); switch (lengthFieldLength) { case 1: diff --git a/codec/src/main/java/io/netty/handler/codec/LineBasedFrameDecoder.java b/codec/src/main/java/io/netty/handler/codec/LineBasedFrameDecoder.java index 305981b7622d..fda45cc401ec 100644 --- a/codec/src/main/java/io/netty/handler/codec/LineBasedFrameDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/LineBasedFrameDecoder.java @@ -25,6 +25,12 @@ * A decoder that splits the received {@link ByteBuf}s on line endings. *

* Both {@code "\n"} and {@code "\r\n"} are handled. + *

+ * The byte stream is expected to be in UTF-8 character encoding or ASCII. The current implementation + * uses direct {@code byte} to {@code char} cast and then compares that {@code char} to a few low range + * ASCII characters like {@code '\n'} or {@code '\r'}. UTF-8 is not using low range [0..0x7F] + * byte values for multibyte codepoint representations therefore fully supported by this implementation. + *

* For a more general delimiter-based decoder, see {@link DelimiterBasedFrameDecoder}. */ public class LineBasedFrameDecoder extends ByteToMessageDecoder { @@ -137,6 +143,8 @@ protected Object decode(ChannelHandlerContext ctx, ByteBuf buffer) throws Except } else { discardedBytes += buffer.readableBytes(); buffer.readerIndex(buffer.writerIndex()); + // We skip everything in the buffer, we need to set the offset to 0 again. + offset = 0; } return null; } diff --git a/codec/src/main/java/io/netty/handler/codec/MessageAggregator.java b/codec/src/main/java/io/netty/handler/codec/MessageAggregator.java index 2cdb880c9932..11a75f2dd6e1 100644 --- a/codec/src/main/java/io/netty/handler/codec/MessageAggregator.java +++ b/codec/src/main/java/io/netty/handler/codec/MessageAggregator.java @@ -28,6 +28,7 @@ import java.util.List; import static io.netty.buffer.Unpooled.EMPTY_BUFFER; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; /** * An abstract {@link ChannelHandler} that aggregates a series of message objects into a single aggregated message. @@ -61,6 +62,8 @@ public abstract class MessageAggregator inboundMess } private static void validateMaxContentLength(int maxContentLength) { - if (maxContentLength < 0) { - throw new IllegalArgumentException("maxContentLength: " + maxContentLength + " (expected: >= 0)"); - } + checkPositiveOrZero(maxContentLength, "maxContentLength"); } @Override @@ -96,7 +97,20 @@ public boolean acceptInboundMessage(Object msg) throws Exception { @SuppressWarnings("unchecked") I in = (I) msg; - return (isContentMessage(in) || isStartMessage(in)) && !isAggregated(in); + if (isAggregated(in)) { + return false; + } + + // NOTE: It's tempting to make this check only if aggregating is false. There are however + // side conditions in decode(...) in respect to large messages. + if (isStartMessage(in)) { + aggregating = true; + return true; + } else if (aggregating && isContentMessage(in)) { + return true; + } + + return false; } /** @@ -192,6 +206,8 @@ protected final ChannelHandlerContext ctx() { @Override protected void decode(final ChannelHandlerContext ctx, I msg, List out) throws Exception { + assert aggregating; + if (isStartMessage(msg)) { handlingOversizedMessage = false; if (currentMessage != null) { @@ -246,7 +262,7 @@ public void operationComplete(ChannelFuture future) throws Exception { } else { aggregated = beginAggregation(m, EMPTY_BUFFER); } - finishAggregation(aggregated); + finishAggregation0(aggregated); out.add(aggregated); return; } @@ -301,7 +317,7 @@ public void operationComplete(ChannelFuture future) throws Exception { } if (last) { - finishAggregation(currentMessage); + finishAggregation0(currentMessage); // All done out.add(currentMessage); @@ -371,6 +387,11 @@ protected abstract Object newContinueResponse(S start, int maxContentLength, Cha */ protected void aggregate(O aggregated, C content) throws Exception { } + private void finishAggregation0(O aggregated) throws Exception { + aggregating = false; + finishAggregation(aggregated); + } + /** * Invoked when the specified {@code aggregated} message is about to be passed to the next handler in the pipeline. */ @@ -378,6 +399,7 @@ protected void finishAggregation(O aggregated) throws Exception { } private void invokeHandleOversizedMessage(ChannelHandlerContext ctx, S oversized) throws Exception { handlingOversizedMessage = true; + aggregating = false; currentMessage = null; try { handleOversizedMessage(ctx, oversized); @@ -441,6 +463,7 @@ private void releaseCurrentMessage() { currentMessage.release(); currentMessage = null; handlingOversizedMessage = false; + aggregating = false; } } } diff --git a/codec/src/main/java/io/netty/handler/codec/MessageToMessageDecoder.java b/codec/src/main/java/io/netty/handler/codec/MessageToMessageDecoder.java index 0337e768a0c7..ec828e2e3281 100644 --- a/codec/src/main/java/io/netty/handler/codec/MessageToMessageDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/MessageToMessageDecoder.java @@ -107,7 +107,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception /** * Decode from one message to an other. This method will be called for each written message that can be handled - * by this encoder. + * by this decoder. * * @param ctx the {@link ChannelHandlerContext} which this {@link MessageToMessageDecoder} belongs to * @param msg the message to decode to an other one diff --git a/codec/src/main/java/io/netty/handler/codec/MessageToMessageEncoder.java b/codec/src/main/java/io/netty/handler/codec/MessageToMessageEncoder.java index 16eed6890196..439dc8cb193e 100644 --- a/codec/src/main/java/io/netty/handler/codec/MessageToMessageEncoder.java +++ b/codec/src/main/java/io/netty/handler/codec/MessageToMessageEncoder.java @@ -22,6 +22,7 @@ import io.netty.channel.ChannelPromise; import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCounted; +import io.netty.util.concurrent.PromiseCombiner; import io.netty.util.internal.StringUtil; import io.netty.util.internal.TypeParameterMatcher; @@ -108,28 +109,36 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) if (out != null) { final int sizeMinusOne = out.size() - 1; if (sizeMinusOne == 0) { - ctx.write(out.get(0), promise); + ctx.write(out.getUnsafe(0), promise); } else if (sizeMinusOne > 0) { // Check if we can use a voidPromise for our extra writes to reduce GC-Pressure // See https://github.com/netty/netty/issues/2525 - ChannelPromise voidPromise = ctx.voidPromise(); - boolean isVoidPromise = promise == voidPromise; - for (int i = 0; i < sizeMinusOne; i ++) { - ChannelPromise p; - if (isVoidPromise) { - p = voidPromise; - } else { - p = ctx.newPromise(); - } - ctx.write(out.getUnsafe(i), p); + if (promise == ctx.voidPromise()) { + writeVoidPromise(ctx, out); + } else { + writePromiseCombiner(ctx, out, promise); } - ctx.write(out.getUnsafe(sizeMinusOne), promise); } out.recycle(); } } } + private static void writeVoidPromise(ChannelHandlerContext ctx, CodecOutputList out) { + final ChannelPromise voidPromise = ctx.voidPromise(); + for (int i = 0; i < out.size(); i++) { + ctx.write(out.getUnsafe(i), voidPromise); + } + } + + private static void writePromiseCombiner(ChannelHandlerContext ctx, CodecOutputList out, ChannelPromise promise) { + final PromiseCombiner combiner = new PromiseCombiner(ctx.executor()); + for (int i = 0; i < out.size(); i++) { + combiner.add(ctx.write(out.getUnsafe(i))); + } + combiner.finish(promise); + } + /** * Encode from one message to an other. This method will be called for each written message that can be handled * by this encoder. diff --git a/codec/src/main/java/io/netty/handler/codec/bytes/ByteArrayDecoder.java b/codec/src/main/java/io/netty/handler/codec/bytes/ByteArrayDecoder.java index 1fe78c505125..21429d1867b1 100644 --- a/codec/src/main/java/io/netty/handler/codec/bytes/ByteArrayDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/bytes/ByteArrayDecoder.java @@ -16,6 +16,7 @@ package io.netty.handler.codec.bytes; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPipeline; import io.netty.handler.codec.LengthFieldBasedFrameDecoder; @@ -52,9 +53,6 @@ public class ByteArrayDecoder extends MessageToMessageDecoder { @Override protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List out) throws Exception { // copy the ByteBuf content to a byte array - byte[] array = new byte[msg.readableBytes()]; - msg.getBytes(0, array); - - out.add(array); + out.add(ByteBufUtil.getBytes(msg)); } } diff --git a/codec/src/main/java/io/netty/handler/codec/compression/Bzip2DivSufSort.java b/codec/src/main/java/io/netty/handler/codec/compression/Bzip2DivSufSort.java index cdf92a698313..813872938c2c 100644 --- a/codec/src/main/java/io/netty/handler/codec/compression/Bzip2DivSufSort.java +++ b/codec/src/main/java/io/netty/handler/codec/compression/Bzip2DivSufSort.java @@ -568,7 +568,9 @@ private void ssMergeForward(final int pa, int[] buf, final int bufoffset, SA[i++] = SA[k]; SA[k++] = SA[i]; if (last <= k) { - while (j < bufend) { SA[i++] = buf[j]; buf[j++] = SA[i]; } + while (j < bufend) { + SA[i++] = buf[j]; buf[j++] = SA[i]; + } SA[i] = buf[j]; buf[j] = t; return; } diff --git a/codec/src/main/java/io/netty/handler/codec/compression/JdkZlibEncoder.java b/codec/src/main/java/io/netty/handler/codec/compression/JdkZlibEncoder.java index f039fa66e849..276d7f86b0cd 100644 --- a/codec/src/main/java/io/netty/handler/codec/compression/JdkZlibEncoder.java +++ b/codec/src/main/java/io/netty/handler/codec/compression/JdkZlibEncoder.java @@ -225,8 +225,18 @@ protected void encode(ChannelHandlerContext ctx, ByteBuf uncompressed, ByteBuf o } deflater.setInput(inAry, offset, len); - while (!deflater.needsInput()) { + for (;;) { deflate(out); + if (deflater.needsInput()) { + // Consumed everything + break; + } else { + if (!out.isWritable()) { + // We did not consume everything but the buffer is not writable anymore. Increase the capacity to + // make more room. + out.ensureWritable(out.writerIndex()); + } + } } } diff --git a/codec/src/main/java/io/netty/handler/codec/json/JsonObjectDecoder.java b/codec/src/main/java/io/netty/handler/codec/json/JsonObjectDecoder.java index f63e62066ee0..2508ff6c09c1 100644 --- a/codec/src/main/java/io/netty/handler/codec/json/JsonObjectDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/json/JsonObjectDecoder.java @@ -30,7 +30,12 @@ /** * Splits a byte stream of JSON objects and arrays into individual objects/arrays and passes them up the * {@link ChannelPipeline}. - * + *

+ * The byte stream is expected to be in UTF-8 character encoding or ASCII. The current implementation + * uses direct {@code byte} to {@code char} cast and then compares that {@code char} to a few low range + * ASCII characters like {@code '{'}, {@code '['} or {@code '"'}. UTF-8 is not using low range [0..0x7F] + * byte values for multibyte codepoint representations therefore fully supported by this implementation. + *

* This class does not do any real parsing or validation. A sequence of bytes is considered a JSON object/array * if it contains a matching number of opening and closing braces/brackets. It's up to a subsequent * {@link ChannelHandler} to parse the JSON text into a more usable form i.e. a POJO. diff --git a/codec/src/main/java/io/netty/handler/codec/protobuf/ProtobufDecoder.java b/codec/src/main/java/io/netty/handler/codec/protobuf/ProtobufDecoder.java index d48bfbf7a905..9ef56f11d4f4 100644 --- a/codec/src/main/java/io/netty/handler/codec/protobuf/ProtobufDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/protobuf/ProtobufDecoder.java @@ -20,6 +20,7 @@ import com.google.protobuf.Message; import com.google.protobuf.MessageLite; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; import io.netty.channel.ChannelHandler.Sharable; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPipeline; @@ -111,8 +112,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List out) array = msg.array(); offset = msg.arrayOffset() + msg.readerIndex(); } else { - array = new byte[length]; - msg.getBytes(msg.readerIndex(), array, 0, length); + array = ByteBufUtil.getBytes(msg, msg.readerIndex(), length, false); offset = 0; } diff --git a/codec/src/main/java/io/netty/handler/codec/protobuf/ProtobufDecoderNano.java b/codec/src/main/java/io/netty/handler/codec/protobuf/ProtobufDecoderNano.java index 5144b5fa3dd8..0d6685c979bd 100644 --- a/codec/src/main/java/io/netty/handler/codec/protobuf/ProtobufDecoderNano.java +++ b/codec/src/main/java/io/netty/handler/codec/protobuf/ProtobufDecoderNano.java @@ -20,6 +20,7 @@ import java.util.List; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; import io.netty.channel.ChannelHandler.Sharable; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPipeline; @@ -78,8 +79,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List out) array = msg.array(); offset = msg.arrayOffset() + msg.readerIndex(); } else { - array = new byte[length]; - msg.getBytes(msg.readerIndex(), array, 0, length); + array = ByteBufUtil.getBytes(msg, msg.readerIndex(), length, false); offset = 0; } MessageNano prototype = clazz.getConstructor().newInstance(); diff --git a/codec/src/main/java/io/netty/handler/codec/serialization/ClassResolvers.java b/codec/src/main/java/io/netty/handler/codec/serialization/ClassResolvers.java index 0d8c11f2f9f3..32250ada801b 100644 --- a/codec/src/main/java/io/netty/handler/codec/serialization/ClassResolvers.java +++ b/codec/src/main/java/io/netty/handler/codec/serialization/ClassResolvers.java @@ -32,7 +32,7 @@ public static ClassResolver cacheDisabled(ClassLoader classLoader) { } /** - * non-agressive non-concurrent cache + * non-aggressive non-concurrent cache * good for non-shared default cache * * @param classLoader - specific classLoader to use, or null if you want to revert to default @@ -45,7 +45,7 @@ public static ClassResolver weakCachingResolver(ClassLoader classLoader) { } /** - * agressive non-concurrent cache + * aggressive non-concurrent cache * good for non-shared cache, when we're not worried about class unloading * * @param classLoader - specific classLoader to use, or null if you want to revert to default @@ -58,7 +58,7 @@ public static ClassResolver softCachingResolver(ClassLoader classLoader) { } /** - * non-agressive concurrent cache + * non-aggressive concurrent cache * good for shared cache, when we're worried about class unloading * * @param classLoader - specific classLoader to use, or null if you want to revert to default @@ -72,7 +72,7 @@ public static ClassResolver weakCachingConcurrentResolver(ClassLoader classLoade } /** - * agressive concurrent cache + * aggressive concurrent cache * good for shared cache, when we're not worried about class unloading * * @param classLoader - specific classLoader to use, or null if you want to revert to default diff --git a/codec/src/main/java/io/netty/handler/codec/xml/XmlFrameDecoder.java b/codec/src/main/java/io/netty/handler/codec/xml/XmlFrameDecoder.java index 4a8b262bf8a3..480f0be1c53e 100644 --- a/codec/src/main/java/io/netty/handler/codec/xml/XmlFrameDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/xml/XmlFrameDecoder.java @@ -59,6 +59,12 @@ * +-----------------+-------------------------------------+ * * + *

+ * The byte stream is expected to be in UTF-8 character encoding or ASCII. The current implementation + * uses direct {@code byte} to {@code char} cast and then compares that {@code char} to a few low range + * ASCII characters like {@code '<'}, {@code '>'} or {@code '/'}. UTF-8 is not using low range [0..0x7F] + * byte values for multibyte codepoint representations therefore fully supported by this implementation. + *

* Please note that this decoder is not suitable for * xml streaming protocols such as * XMPP, diff --git a/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java b/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java index 8462ee4a020c..93dd2218ac0f 100644 --- a/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java +++ b/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java @@ -16,7 +16,10 @@ package io.netty.handler.codec; import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.Unpooled; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.buffer.UnpooledHeapByteBuf; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.embedded.EmbeddedChannel; @@ -27,10 +30,7 @@ import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingDeque; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; public class ByteToMessageDecoderTest { @@ -305,4 +305,44 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t assertFalse(channel.writeInbound(Unpooled.wrappedBuffer(new byte[] { (byte) 2 }))); assertFalse(channel.finish()); } + + @Test + public void releaseWhenMergeCumulateThrows() { + final Error error = new Error(); + + ByteBuf cumulation = new UnpooledHeapByteBuf(UnpooledByteBufAllocator.DEFAULT, 0, 64) { + @Override + public ByteBuf writeBytes(ByteBuf src) { + throw error; + } + }; + ByteBuf in = Unpooled.buffer().writeZero(12); + try { + ByteToMessageDecoder.MERGE_CUMULATOR.cumulate(UnpooledByteBufAllocator.DEFAULT, cumulation, in); + fail(); + } catch (Error expected) { + assertSame(error, expected); + assertEquals(0, in.refCnt()); + } + } + + @Test + public void releaseWhenCompositeCumulateThrows() { + final Error error = new Error(); + + ByteBuf cumulation = new CompositeByteBuf(UnpooledByteBufAllocator.DEFAULT, false, 64) { + @Override + public CompositeByteBuf addComponent(boolean increaseWriterIndex, ByteBuf buffer) { + throw error; + } + }; + ByteBuf in = Unpooled.buffer().writeZero(12); + try { + ByteToMessageDecoder.COMPOSITE_CUMULATOR.cumulate(UnpooledByteBufAllocator.DEFAULT, cumulation, in); + fail(); + } catch (Error expected) { + assertSame(error, expected); + assertEquals(0, in.refCnt()); + } + } } diff --git a/codec/src/test/java/io/netty/handler/codec/CharSequenceValueConverterTest.java b/codec/src/test/java/io/netty/handler/codec/CharSequenceValueConverterTest.java index 5543e2f90955..2347f0d0bf9b 100644 --- a/codec/src/test/java/io/netty/handler/codec/CharSequenceValueConverterTest.java +++ b/codec/src/test/java/io/netty/handler/codec/CharSequenceValueConverterTest.java @@ -14,6 +14,7 @@ */ package io.netty.handler.codec; +import io.netty.util.AsciiString; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -30,6 +31,16 @@ public void testBoolean() { assertFalse(converter.convertToBoolean(converter.convertBoolean(false))); } + @Test + public void testByteFromAsciiString() { + assertEquals(127, converter.convertToByte(AsciiString.of("127"))); + } + + @Test(expected = NumberFormatException.class) + public void testByteFromEmptyAsciiString() { + converter.convertToByte(AsciiString.EMPTY_STRING); + } + @Test public void testByte() { assertEquals(Byte.MAX_VALUE, converter.convertToByte(converter.convertByte(Byte.MAX_VALUE))); diff --git a/codec/src/test/java/io/netty/handler/codec/CodecOutputListTest.java b/codec/src/test/java/io/netty/handler/codec/CodecOutputListTest.java new file mode 100644 index 000000000000..1aa0927999da --- /dev/null +++ b/codec/src/test/java/io/netty/handler/codec/CodecOutputListTest.java @@ -0,0 +1,53 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec; + +import org.junit.Test; + + +import static org.junit.Assert.*; + +public class CodecOutputListTest { + + @Test + public void testCodecOutputListAdd() { + CodecOutputList codecOutputList = CodecOutputList.newInstance(); + try { + assertEquals(0, codecOutputList.size()); + assertTrue(codecOutputList.isEmpty()); + + codecOutputList.add(1); + assertEquals(1, codecOutputList.size()); + assertFalse(codecOutputList.isEmpty()); + assertEquals(1, codecOutputList.get(0)); + + codecOutputList.add(0, 0); + assertEquals(2, codecOutputList.size()); + assertFalse(codecOutputList.isEmpty()); + assertEquals(0, codecOutputList.get(0)); + assertEquals(1, codecOutputList.get(1)); + + codecOutputList.add(1, 2); + assertEquals(3, codecOutputList.size()); + assertFalse(codecOutputList.isEmpty()); + assertEquals(0, codecOutputList.get(0)); + assertEquals(2, codecOutputList.get(1)); + assertEquals(1, codecOutputList.get(2)); + } finally { + codecOutputList.recycle(); + } + } +} diff --git a/codec/src/test/java/io/netty/handler/codec/DateFormatterTest.java b/codec/src/test/java/io/netty/handler/codec/DateFormatterTest.java index 99d1d3864a9a..e833fdfb605c 100644 --- a/codec/src/test/java/io/netty/handler/codec/DateFormatterTest.java +++ b/codec/src/test/java/io/netty/handler/codec/DateFormatterTest.java @@ -17,6 +17,7 @@ import org.junit.Test; +import java.util.Calendar; import java.util.Date; import static org.junit.Assert.*; @@ -111,4 +112,26 @@ public void testParseInvalidInput() { public void testFormat() { assertEquals("Sun, 6 Nov 1994 08:49:37 GMT", format(DATE)); } + + @Test + public void testParseAllMonths() { + assertEquals(Calendar.JANUARY, getMonth(parseHttpDate("Sun, 6 Jan 1994 08:49:37 GMT"))); + assertEquals(Calendar.FEBRUARY, getMonth(parseHttpDate("Sun, 6 Feb 1994 08:49:37 GMT"))); + assertEquals(Calendar.MARCH, getMonth(parseHttpDate("Sun, 6 Mar 1994 08:49:37 GMT"))); + assertEquals(Calendar.APRIL, getMonth(parseHttpDate("Sun, 6 Apr 1994 08:49:37 GMT"))); + assertEquals(Calendar.MAY, getMonth(parseHttpDate("Sun, 6 May 1994 08:49:37 GMT"))); + assertEquals(Calendar.JUNE, getMonth(parseHttpDate("Sun, 6 Jun 1994 08:49:37 GMT"))); + assertEquals(Calendar.JULY, getMonth(parseHttpDate("Sun, 6 Jul 1994 08:49:37 GMT"))); + assertEquals(Calendar.AUGUST, getMonth(parseHttpDate("Sun, 6 Aug 1994 08:49:37 GMT"))); + assertEquals(Calendar.SEPTEMBER, getMonth(parseHttpDate("Sun, 6 Sep 1994 08:49:37 GMT"))); + assertEquals(Calendar.OCTOBER, getMonth(parseHttpDate("Sun Oct 6 08:49:37 1994"))); + assertEquals(Calendar.NOVEMBER, getMonth(parseHttpDate("Sun Nov 6 08:49:37 1994"))); + assertEquals(Calendar.DECEMBER, getMonth(parseHttpDate("Sun Dec 6 08:49:37 1994"))); + } + + private static int getMonth(Date referenceDate) { + Calendar cal = Calendar.getInstance(); + cal.setTime(referenceDate); + return cal.get(Calendar.MONTH); + } } diff --git a/codec/src/test/java/io/netty/handler/codec/DefaultHeadersTest.java b/codec/src/test/java/io/netty/handler/codec/DefaultHeadersTest.java index 073872bf10c5..2ed0498be4ba 100644 --- a/codec/src/test/java/io/netty/handler/codec/DefaultHeadersTest.java +++ b/codec/src/test/java/io/netty/handler/codec/DefaultHeadersTest.java @@ -41,11 +41,11 @@ public class DefaultHeadersTest { private static final class TestDefaultHeaders extends DefaultHeaders { - public TestDefaultHeaders() { + TestDefaultHeaders() { this(CharSequenceValueConverter.INSTANCE); } - public TestDefaultHeaders(ValueConverter converter) { + TestDefaultHeaders(ValueConverter converter) { super(converter); } } diff --git a/codec/src/test/java/io/netty/handler/codec/LineBasedFrameDecoderTest.java b/codec/src/test/java/io/netty/handler/codec/LineBasedFrameDecoderTest.java index b68c38bbbeab..9edd178fa251 100644 --- a/codec/src/test/java/io/netty/handler/codec/LineBasedFrameDecoderTest.java +++ b/codec/src/test/java/io/netty/handler/codec/LineBasedFrameDecoderTest.java @@ -185,4 +185,27 @@ public void testEmptyLine() throws Exception { buf.release(); buf2.release(); } + + @Test + public void testNotFailFast() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new LineBasedFrameDecoder(2, false, false)); + assertFalse(ch.writeInbound(wrappedBuffer(new byte[] { 0, 1, 2 }))); + assertFalse(ch.writeInbound(wrappedBuffer(new byte[]{ 3, 4 }))); + try { + ch.writeInbound(wrappedBuffer(new byte[] { '\n' })); + fail(); + } catch (TooLongFrameException expected) { + // Expected once we received a full frame. + } + assertFalse(ch.writeInbound(wrappedBuffer(new byte[] { '5' }))); + assertTrue(ch.writeInbound(wrappedBuffer(new byte[] { '\n' }))); + + ByteBuf expected = wrappedBuffer(new byte[] { '5', '\n' }); + ByteBuf buffer = ch.readInbound(); + assertEquals(expected, buffer); + expected.release(); + buffer.release(); + + assertFalse(ch.finish()); + } } diff --git a/codec/src/test/java/io/netty/handler/codec/MessageToMessageEncoderTest.java b/codec/src/test/java/io/netty/handler/codec/MessageToMessageEncoderTest.java index 322631adf6c9..2d09725af0b6 100644 --- a/codec/src/test/java/io/netty/handler/codec/MessageToMessageEncoderTest.java +++ b/codec/src/test/java/io/netty/handler/codec/MessageToMessageEncoderTest.java @@ -15,9 +15,14 @@ */ package io.netty.handler.codec; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; import io.netty.channel.embedded.EmbeddedChannel; import org.junit.Test; +import static org.junit.Assert.*; import java.util.List; @@ -37,4 +42,37 @@ protected void encode(ChannelHandlerContext ctx, Object msg, List out) t }); channel.writeOutbound(new Object()); } + + @Test + public void testIntermediateWriteFailures() { + ChannelHandler encoder = new MessageToMessageEncoder() { + @Override + protected void encode(ChannelHandlerContext ctx, Object msg, List out) { + out.add(new Object()); + out.add(msg); + } + }; + + final Exception firstWriteException = new Exception(); + + ChannelHandler writeThrower = new ChannelOutboundHandlerAdapter() { + private boolean firstWritten; + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + if (firstWritten) { + ctx.write(msg, promise); + } else { + firstWritten = true; + promise.setFailure(firstWriteException); + } + } + }; + + EmbeddedChannel channel = new EmbeddedChannel(writeThrower, encoder); + Object msg = new Object(); + ChannelFuture write = channel.writeAndFlush(msg); + assertSame(firstWriteException, write.cause()); + assertSame(msg, channel.readOutbound()); + assertFalse(channel.finish()); + } } diff --git a/codec/src/test/java/io/netty/handler/codec/compression/SnappyTest.java b/codec/src/test/java/io/netty/handler/codec/compression/SnappyTest.java index 232804709392..115deef15804 100644 --- a/codec/src/test/java/io/netty/handler/codec/compression/SnappyTest.java +++ b/codec/src/test/java/io/netty/handler/codec/compression/SnappyTest.java @@ -24,6 +24,8 @@ import static io.netty.handler.codec.compression.Snappy.*; import static org.junit.Assert.*; +import java.nio.CharBuffer; + public class SnappyTest { private final Snappy snappy = new Snappy(); @@ -219,7 +221,8 @@ public void encodeAndDecodeLongTextUsesCopy() throws Exception { // Decode ByteBuf outDecoded = Unpooled.buffer(); snappy.decode(out, outDecoded); - assertEquals(srcStr, outDecoded.getCharSequence(0, outDecoded.writerIndex(), CharsetUtil.US_ASCII)); + assertEquals(CharBuffer.wrap(srcStr), + CharBuffer.wrap(outDecoded.getCharSequence(0, outDecoded.writerIndex(), CharsetUtil.US_ASCII))); in.release(); out.release(); diff --git a/common/pom.xml b/common/pom.xml index dfd9adcb8e8f..6c041034b3c9 100644 --- a/common/pom.xml +++ b/common/pom.xml @@ -21,7 +21,7 @@ io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-common diff --git a/common/src/main/java/io/netty/util/AbstractReferenceCounted.java b/common/src/main/java/io/netty/util/AbstractReferenceCounted.java index fe20f92a2e8d..b7480c4eb872 100644 --- a/common/src/main/java/io/netty/util/AbstractReferenceCounted.java +++ b/common/src/main/java/io/netty/util/AbstractReferenceCounted.java @@ -15,30 +15,58 @@ */ package io.netty.util; +import static io.netty.util.internal.ObjectUtil.checkPositive; + import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; -import static io.netty.util.internal.ObjectUtil.checkPositive; +import io.netty.util.internal.PlatformDependent; /** * Abstract base class for classes wants to implement {@link ReferenceCounted}. */ public abstract class AbstractReferenceCounted implements ReferenceCounted { - + private static final long REFCNT_FIELD_OFFSET; private static final AtomicIntegerFieldUpdater refCntUpdater = AtomicIntegerFieldUpdater.newUpdater(AbstractReferenceCounted.class, "refCnt"); - private volatile int refCnt = 1; + // even => "real" refcount is (refCnt >>> 1); odd => "real" refcount is 0 + @SuppressWarnings("unused") + private volatile int refCnt = 2; + + static { + long refCntFieldOffset = -1; + try { + if (PlatformDependent.hasUnsafe()) { + refCntFieldOffset = PlatformDependent.objectFieldOffset( + AbstractReferenceCounted.class.getDeclaredField("refCnt")); + } + } catch (Throwable ignore) { + refCntFieldOffset = -1; + } + + REFCNT_FIELD_OFFSET = refCntFieldOffset; + } + + private static int realRefCnt(int rawCnt) { + return (rawCnt & 1) != 0 ? 0 : rawCnt >>> 1; + } + + private int nonVolatileRawCnt() { + // TODO: Once we compile against later versions of Java we can replace the Unsafe usage here by varhandles. + return REFCNT_FIELD_OFFSET != -1 ? PlatformDependent.getInt(this, REFCNT_FIELD_OFFSET) + : refCntUpdater.get(this); + } @Override - public final int refCnt() { - return refCnt; + public int refCnt() { + return realRefCnt(refCntUpdater.get(this)); } /** * An unsafe operation intended for use by a subclass that sets the reference count of the buffer directly */ - protected final void setRefCnt(int refCnt) { - refCntUpdater.set(this, refCnt); + protected final void setRefCnt(int newRefCnt) { + refCntUpdater.set(this, newRefCnt << 1); // overflow OK here } @Override @@ -51,12 +79,19 @@ public ReferenceCounted retain(int increment) { return retain0(checkPositive(increment, "increment")); } - private ReferenceCounted retain0(int increment) { - int oldRef = refCntUpdater.getAndAdd(this, increment); - if (oldRef <= 0 || oldRef + increment < oldRef) { - // Ensure we don't resurrect (which means the refCnt was 0) and also that we encountered an overflow. - refCntUpdater.getAndAdd(this, -increment); - throw new IllegalReferenceCountException(oldRef, increment); + private ReferenceCounted retain0(final int increment) { + // all changes to the raw count are 2x the "real" change + int adjustedIncrement = increment << 1; // overflow OK here + int oldRef = refCntUpdater.getAndAdd(this, adjustedIncrement); + if ((oldRef & 1) != 0) { + throw new IllegalReferenceCountException(0, increment); + } + // don't pass 0! + if ((oldRef <= 0 && oldRef + adjustedIncrement >= 0) + || (oldRef >= 0 && oldRef + adjustedIncrement < oldRef)) { + // overflow case + refCntUpdater.getAndAdd(this, -adjustedIncrement); + throw new IllegalReferenceCountException(realRefCnt(oldRef), increment); } return this; } @@ -77,16 +112,55 @@ public boolean release(int decrement) { } private boolean release0(int decrement) { - int oldRef = refCntUpdater.getAndAdd(this, -decrement); - if (oldRef == decrement) { - deallocate(); - return true; - } else if (oldRef < decrement || oldRef - decrement > oldRef) { - // Ensure we don't over-release, and avoid underflow. - refCntUpdater.getAndAdd(this, decrement); - throw new IllegalReferenceCountException(oldRef, -decrement); + int rawCnt = nonVolatileRawCnt(), realCnt = toLiveRealCnt(rawCnt, decrement); + if (decrement == realCnt) { + if (refCntUpdater.compareAndSet(this, rawCnt, 1)) { + deallocate(); + return true; + } + return retryRelease0(decrement); + } + return releaseNonFinal0(decrement, rawCnt, realCnt); + } + + private boolean releaseNonFinal0(int decrement, int rawCnt, int realCnt) { + if (decrement < realCnt + // all changes to the raw count are 2x the "real" change + && refCntUpdater.compareAndSet(this, rawCnt, rawCnt - (decrement << 1))) { + return false; + } + return retryRelease0(decrement); + } + + private boolean retryRelease0(int decrement) { + for (;;) { + int rawCnt = refCntUpdater.get(this), realCnt = toLiveRealCnt(rawCnt, decrement); + if (decrement == realCnt) { + if (refCntUpdater.compareAndSet(this, rawCnt, 1)) { + deallocate(); + return true; + } + } else if (decrement < realCnt) { + // all changes to the raw count are 2x the "real" change + if (refCntUpdater.compareAndSet(this, rawCnt, rawCnt - (decrement << 1))) { + return false; + } + } else { + throw new IllegalReferenceCountException(realCnt, -decrement); + } + Thread.yield(); // this benefits throughput under high contention + } + } + + /** + * Like {@link #realRefCnt(int)} but throws if refCnt == 0 + */ + private static int toLiveRealCnt(int rawCnt, int decrement) { + if ((rawCnt & 1) == 0) { + return rawCnt >>> 1; } - return false; + // odd rawCnt => already deallocated + throw new IllegalReferenceCountException(0, -decrement); } /** diff --git a/common/src/main/java/io/netty/util/AsciiString.java b/common/src/main/java/io/netty/util/AsciiString.java index eca0288c2ea8..302b36e89bc6 100644 --- a/common/src/main/java/io/netty/util/AsciiString.java +++ b/common/src/main/java/io/netty/util/AsciiString.java @@ -146,7 +146,7 @@ public AsciiString(ByteBuffer value, int start, int length, boolean copy) { this.offset = start; } } else { - this.value = new byte[length]; + this.value = PlatformDependent.allocateUninitializedArray(length); int oldPos = value.position(); value.get(this.value, 0, length); value.position(oldPos); @@ -172,7 +172,7 @@ public AsciiString(char[] value, int start, int length) { + ") <= " + "value.length(" + value.length + ')'); } - this.value = new byte[length]; + this.value = PlatformDependent.allocateUninitializedArray(length); for (int i = 0, j = start; i < length; i++, j++) { this.value[i] = c2b(value[j]); } @@ -219,7 +219,7 @@ public AsciiString(CharSequence value, int start, int length) { + ") <= " + "value.length(" + value.length() + ')'); } - this.value = new byte[length]; + this.value = PlatformDependent.allocateUninitializedArray(length); for (int i = 0, j = start; i < length; i++, j++) { this.value[i] = c2b(value.charAt(j)); } @@ -483,7 +483,7 @@ public AsciiString concat(CharSequence string) { return that; } - byte[] newValue = new byte[thisLen + thatLen]; + byte[] newValue = PlatformDependent.allocateUninitializedArray(thisLen + thatLen); System.arraycopy(value, arrayOffset(), newValue, 0, thisLen); System.arraycopy(that.value, that.arrayOffset(), newValue, thisLen, thatLen); return new AsciiString(newValue, false); @@ -493,7 +493,7 @@ public AsciiString concat(CharSequence string) { return new AsciiString(string); } - byte[] newValue = new byte[thisLen + thatLen]; + byte[] newValue = PlatformDependent.allocateUninitializedArray(thisLen + thatLen); System.arraycopy(value, arrayOffset(), newValue, 0, thisLen); for (int i = thisLen, j = 0; i < newValue.length; i++, j++) { newValue[i] = c2b(string.charAt(j)); @@ -722,7 +722,7 @@ public int indexOf(char ch, int start) { } final byte chAsByte = c2b0(ch); - final int len = offset + start + length; + final int len = offset + length; for (int i = start + offset; i < len; ++i) { if (value[i] == chAsByte) { return i - offset; @@ -881,7 +881,7 @@ public AsciiString replace(char oldChar, char newChar) { final int len = offset + length; for (int i = offset; i < len; ++i) { if (value[i] == oldCharAsByte) { - byte[] buffer = new byte[length()]; + byte[] buffer = PlatformDependent.allocateUninitializedArray(length()); System.arraycopy(value, offset, buffer, 0, i - offset); buffer[i - offset] = newCharAsByte; ++i; @@ -942,7 +942,7 @@ public AsciiString toLowerCase() { return this; } - final byte[] newValue = new byte[length()]; + final byte[] newValue = PlatformDependent.allocateUninitializedArray(length()); for (i = 0, j = arrayOffset(); i < newValue.length; ++i, ++j) { newValue[i] = toLowerCase(value[j]); } @@ -972,7 +972,7 @@ public AsciiString toUpperCase() { return this; } - final byte[] newValue = new byte[length()]; + final byte[] newValue = PlatformDependent.allocateUninitializedArray(length()); for (i = 0, j = arrayOffset(); i < newValue.length; ++i, ++j) { newValue[i] = toUpperCase(value[j]); } @@ -1115,7 +1115,7 @@ public AsciiString[] split(char delim) { } } - return res.toArray(new AsciiString[res.size()]); + return res.toArray(new AsciiString[0]); } /** @@ -1452,8 +1452,8 @@ public static boolean contentEqualsIgnoreCase(CharSequence a, CharSequence b) { if (a.length() != b.length()) { return false; } - for (int i = 0, j = 0; i < a.length(); ++i, ++j) { - if (!equalsIgnoreCase(a.charAt(i), b.charAt(j))) { + for (int i = 0; i < a.length(); ++i) { + if (!equalsIgnoreCase(a.charAt(i), b.charAt(i))) { return false; } } @@ -1827,7 +1827,13 @@ private static byte toLowerCase(byte b) { return isUpperCase(b) ? (byte) (b + 32) : b; } - private static char toLowerCase(char c) { + /** + * If the character is uppercase - converts the character to lowercase, + * otherwise returns the character as it is. Only for ASCII characters. + * + * @return lowercase ASCII character equivalent + */ + public static char toLowerCase(char c) { return isUpperCase(c) ? (char) (c + 32) : c; } diff --git a/common/src/main/java/io/netty/util/CharsetUtil.java b/common/src/main/java/io/netty/util/CharsetUtil.java index 4d71b0a7082b..a9317e5497b0 100644 --- a/common/src/main/java/io/netty/util/CharsetUtil.java +++ b/common/src/main/java/io/netty/util/CharsetUtil.java @@ -65,7 +65,9 @@ public final class CharsetUtil { private static final Charset[] CHARSETS = new Charset[] { UTF_16, UTF_16BE, UTF_16LE, UTF_8, ISO_8859_1, US_ASCII }; - public static Charset[] values() { return CHARSETS; } + public static Charset[] values() { + return CHARSETS; + } /** * @deprecated Use {@link #encoder(Charset)}. diff --git a/common/src/main/java/io/netty/util/HashedWheelTimer.java b/common/src/main/java/io/netty/util/HashedWheelTimer.java index b62ace4695ed..e8e72b63872e 100644 --- a/common/src/main/java/io/netty/util/HashedWheelTimer.java +++ b/common/src/main/java/io/netty/util/HashedWheelTimer.java @@ -84,6 +84,7 @@ public class HashedWheelTimer implements Timer { private static final AtomicInteger INSTANCE_COUNTER = new AtomicInteger(); private static final AtomicBoolean WARNED_TOO_MANY_INSTANCES = new AtomicBoolean(); private static final int INSTANCE_COUNT_LIMIT = 64; + private static final long MILLISECOND_NANOS = TimeUnit.MILLISECONDS.toNanos(1); private static final ResourceLeakDetector leakDetector = ResourceLeakDetectorFactory.instance() .newResourceLeakDetector(HashedWheelTimer.class, 1); @@ -259,14 +260,25 @@ public HashedWheelTimer( mask = wheel.length - 1; // Convert tickDuration to nanos. - this.tickDuration = unit.toNanos(tickDuration); + long duration = unit.toNanos(tickDuration); // Prevent overflow. - if (this.tickDuration >= Long.MAX_VALUE / wheel.length) { + if (duration >= Long.MAX_VALUE / wheel.length) { throw new IllegalArgumentException(String.format( "tickDuration: %d (expected: 0 < tickDuration in nanos < %d", tickDuration, Long.MAX_VALUE / wheel.length)); } + + if (duration < MILLISECOND_NANOS) { + if (logger.isWarnEnabled()) { + logger.warn("Configured tickDuration %d smaller then %d, using 1ms.", + tickDuration, MILLISECOND_NANOS); + } + this.tickDuration = MILLISECOND_NANOS; + } else { + this.tickDuration = duration; + } + workerThread = threadFactory.newThread(worker); leak = leakDetection || !workerThread.isDaemon() ? leakDetector.track(this) : null; @@ -437,10 +449,12 @@ public long pendingTimeouts() { } private static void reportTooManyInstances() { - String resourceType = simpleClassName(HashedWheelTimer.class); - logger.error("You are creating too many " + resourceType + " instances. " + - resourceType + " is a shared resource that must be reused across the JVM," + - "so that only a few instances are created."); + if (logger.isErrorEnabled()) { + String resourceType = simpleClassName(HashedWheelTimer.class); + logger.error("You are creating too many " + resourceType + " instances. " + + resourceType + " is a shared resource that must be reused across the JVM," + + "so that only a few instances are created."); + } } private final class Worker implements Runnable { diff --git a/common/src/main/java/io/netty/util/HashingStrategy.java b/common/src/main/java/io/netty/util/HashingStrategy.java index f15f4a3299ba..957d765c4622 100644 --- a/common/src/main/java/io/netty/util/HashingStrategy.java +++ b/common/src/main/java/io/netty/util/HashingStrategy.java @@ -41,7 +41,7 @@ public interface HashingStrategy { * This method has the following restrictions: *
    *
  • reflexive - {@code equals(a, a)} should return true
  • - *
  • symmetric - {@code equals(a, b)} returns {@code true} iff {@code equals(b, a)} returns + *
  • symmetric - {@code equals(a, b)} returns {@code true} if {@code equals(b, a)} returns * {@code true}
  • *
  • transitive - if {@code equals(a, b)} returns {@code true} and {@code equals(a, c)} returns * {@code true} then {@code equals(b, c)} should also return {@code true}
  • diff --git a/common/src/main/java/io/netty/util/Recycler.java b/common/src/main/java/io/netty/util/Recycler.java index e84adc5a9a68..87b182e6bed3 100644 --- a/common/src/main/java/io/netty/util/Recycler.java +++ b/common/src/main/java/io/netty/util/Recycler.java @@ -17,7 +17,6 @@ package io.netty.util; import io.netty.util.concurrent.FastThreadLocal; -import io.netty.util.internal.ObjectCleaner; import io.netty.util.internal.SystemPropertyUtil; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -217,6 +216,12 @@ public void recycle(Object object) { if (object != value) { throw new IllegalArgumentException("object does not belong to handle"); } + + Stack stack = this.stack; + if (lastRecycledId != recycleId || stack == null) { + throw new IllegalStateException("recycled already"); + } + stack.push(this); } } @@ -244,10 +249,9 @@ static final class Link extends AtomicInteger { Link next; } - // This act as a place holder for the head Link but also will be used by the ObjectCleaner - // to return space that was before reserved. Its important this does not hold any reference to - // either Stack or WeakOrderQueue. - static final class Head implements Runnable { + // This act as a place holder for the head Link but also will reclaim space once finalized. + // Its important this does not hold any reference to either Stack or WeakOrderQueue. + static final class Head { private final AtomicInteger availableSharedCapacity; Link link; @@ -256,12 +260,21 @@ static final class Head implements Runnable { this.availableSharedCapacity = availableSharedCapacity; } + /// TODO: In the future when we move to Java9+ we should use java.lang.ref.Cleaner. @Override - public void run() { - Link head = link; - while (head != null) { - reclaimSpace(LINK_CAPACITY); - head = head.next; + protected void finalize() throws Throwable { + try { + super.finalize(); + } finally { + Link head = link; + link = null; + while (head != null) { + reclaimSpace(LINK_CAPACITY); + Link next = head.next; + // Unlink to help GC and guard against GC nepotism. + head.next = null; + head = next; + } } } @@ -318,12 +331,6 @@ static WeakOrderQueue newQueue(Stack stack, Thread thread) { // may be accessed while its still constructed. stack.setHead(queue); - // We need to reclaim all space that was reserved by this WeakOrderQueue so we not run out of space in - // the stack. This is needed as we not have a good life-time control over the queue as it is used in a - // WeakHashMap which will drop it at any time. - final Head head = queue.head; - ObjectCleaner.register(queue, head); - return queue; } diff --git a/common/src/main/java/io/netty/util/ResourceLeakDetector.java b/common/src/main/java/io/netty/util/ResourceLeakDetector.java index c29320adb7d3..e7fc140c67c6 100644 --- a/common/src/main/java/io/netty/util/ResourceLeakDetector.java +++ b/common/src/main/java/io/netty/util/ResourceLeakDetector.java @@ -26,8 +26,10 @@ import java.lang.ref.ReferenceQueue; import java.lang.reflect.Method; import java.util.Arrays; +import java.util.Collections; import java.util.HashSet; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.concurrent.atomic.AtomicReference; @@ -46,7 +48,12 @@ public class ResourceLeakDetector { private static final String PROP_TARGET_RECORDS = "io.netty.leakDetection.targetRecords"; private static final int DEFAULT_TARGET_RECORDS = 4; + private static final String PROP_SAMPLING_INTERVAL = "io.netty.leakDetection.samplingInterval"; + // There is a minor performance benefit in TLR if this is a power of 2. + private static final int DEFAULT_SAMPLING_INTERVAL = 128; + private static final int TARGET_RECORDS; + static final int SAMPLING_INTERVAL; /** * Represents the level of resource leak detection. @@ -115,6 +122,7 @@ static Level parseLevel(String levelStr) { Level level = Level.parseLevel(levelStr); TARGET_RECORDS = SystemPropertyUtil.getInt(PROP_TARGET_RECORDS, DEFAULT_TARGET_RECORDS); + SAMPLING_INTERVAL = SystemPropertyUtil.getInt(PROP_SAMPLING_INTERVAL, DEFAULT_SAMPLING_INTERVAL); ResourceLeakDetector.level = level; if (logger.isDebugEnabled()) { @@ -123,9 +131,6 @@ static Level parseLevel(String levelStr) { } } - // There is a minor performance benefit in TLR if this is a power of 2. - static final int DEFAULT_SAMPLING_INTERVAL = 128; - /** * @deprecated Use {@link #setLevel(Level)} instead. */ @@ -159,7 +164,8 @@ public static Level getLevel() { } /** the collection of active resources */ - private final ConcurrentMap, LeakEntry> allLeaks = PlatformDependent.newConcurrentHashMap(); + private final Set> allLeaks = + Collections.newSetFromMap(new ConcurrentHashMap, Boolean>()); private final ReferenceQueue refQueue = new ReferenceQueue(); private final ConcurrentMap reportedLeaks = PlatformDependent.newConcurrentHashMap(); @@ -353,13 +359,13 @@ private static final class DefaultResourceLeak @SuppressWarnings("unused") private volatile int droppedRecords; - private final ConcurrentMap, LeakEntry> allLeaks; + private final Set> allLeaks; private final int trackedHash; DefaultResourceLeak( Object referent, ReferenceQueue refQueue, - ConcurrentMap, LeakEntry> allLeaks) { + Set> allLeaks) { super(referent, refQueue); assert referent != null; @@ -368,7 +374,7 @@ private static final class DefaultResourceLeak // It's important that we not store a reference to the referent as this would disallow it from // be collected via the WeakReference. trackedHash = System.identityHashCode(referent); - allLeaks.put(this, LeakEntry.INSTANCE); + allLeaks.add(this); // Create a new Record so we always have the creation stacktrace included. headUpdater.set(this, new Record(Record.BOTTOM)); this.allLeaks = allLeaks; @@ -441,13 +447,12 @@ private void record0(Object hint) { boolean dispose() { clear(); - return allLeaks.remove(this, LeakEntry.INSTANCE); + return allLeaks.remove(this); } @Override public boolean close() { - // Use the ConcurrentMap remove method, which avoids allocating an iterator. - if (allLeaks.remove(this, LeakEntry.INSTANCE)) { + if (allLeaks.remove(this)) { // Call clear so the reference is not even enqueued. clear(); headUpdater.set(this, null); @@ -461,11 +466,42 @@ public boolean close(T trackedObject) { // Ensure that the object that was tracked is the same as the one that was passed to close(...). assert trackedHash == System.identityHashCode(trackedObject); - // We need to actually do the null check of the trackedObject after we close the leak because otherwise - // we may get false-positives reported by the ResourceLeakDetector. This can happen as the JIT / GC may - // be able to figure out that we do not need the trackedObject anymore and so already enqueue it for - // collection before we actually get a chance to close the enclosing ResourceLeak. - return close() && trackedObject != null; + try { + return close(); + } finally { + // This method will do `synchronized(trackedObject)` and we should be sure this will not cause deadlock. + // It should not, because somewhere up the callstack should be a (successful) `trackedObject.release`, + // therefore it is unreasonable that anyone else, anywhere, is holding a lock on the trackedObject. + // (Unreasonable but possible, unfortunately.) + reachabilityFence0(trackedObject); + } + } + + /** + * Ensures that the object referenced by the given reference remains + * strongly reachable, + * regardless of any prior actions of the program that might otherwise cause + * the object to become unreachable; thus, the referenced object is not + * reclaimable by garbage collection at least until after the invocation of + * this method. + * + *

    Recent versions of the JDK have a nasty habit of prematurely deciding objects are unreachable. + * see: https://stackoverflow.com/questions/26642153/finalize-called-on-strongly-reachable-object-in-java-8 + * The Java 9 method Reference.reachabilityFence offers a solution to this problem. + * + *

    This method is always implemented as a synchronization on {@code ref}, not as + * {@code Reference.reachabilityFence} for consistency across platforms and to allow building on JDK 6-8. + * It is the caller's responsibility to ensure that this synchronization will not cause deadlock. + * + * @param ref the reference. If {@code null}, this method has no effect. + * @see java.lang.ref.Reference#reachabilityFence + */ + private static void reachabilityFence0(Object ref) { + if (ref != null) { + synchronized (ref) { + // Empty synchronized is ok: https://stackoverflow.com/a/31933260/1151521 + } + } } @Override @@ -501,7 +537,7 @@ public String toString() { if (duped > 0) { buf.append(": ") - .append(dropped) + .append(duped) .append(" leak records were discarded because they were duplicates") .append(NEWLINE); } @@ -606,22 +642,4 @@ public String toString() { return buf.toString(); } } - - private static final class LeakEntry { - static final LeakEntry INSTANCE = new LeakEntry(); - private static final int HASH = System.identityHashCode(INSTANCE); - - private LeakEntry() { - } - - @Override - public int hashCode() { - return HASH; - } - - @Override - public boolean equals(Object obj) { - return obj == this; - } - } } diff --git a/common/src/main/java/io/netty/util/ResourceLeakDetectorFactory.java b/common/src/main/java/io/netty/util/ResourceLeakDetectorFactory.java index 5ca63f08d304..65f1fc8c654d 100644 --- a/common/src/main/java/io/netty/util/ResourceLeakDetectorFactory.java +++ b/common/src/main/java/io/netty/util/ResourceLeakDetectorFactory.java @@ -62,7 +62,7 @@ public static void setResourceLeakDetectorFactory(ResourceLeakDetectorFactory fa * @return a new instance of {@link ResourceLeakDetector} */ public final ResourceLeakDetector newResourceLeakDetector(Class resource) { - return newResourceLeakDetector(resource, ResourceLeakDetector.DEFAULT_SAMPLING_INTERVAL); + return newResourceLeakDetector(resource, ResourceLeakDetector.SAMPLING_INTERVAL); } /** @@ -90,7 +90,7 @@ public abstract ResourceLeakDetector newResourceLeakDetector( */ @SuppressWarnings("deprecation") public ResourceLeakDetector newResourceLeakDetector(Class resource, int samplingInterval) { - return newResourceLeakDetector(resource, ResourceLeakDetector.DEFAULT_SAMPLING_INTERVAL, Long.MAX_VALUE); + return newResourceLeakDetector(resource, ResourceLeakDetector.SAMPLING_INTERVAL, Long.MAX_VALUE); } /** diff --git a/common/src/main/java/io/netty/util/concurrent/AbstractScheduledEventExecutor.java b/common/src/main/java/io/netty/util/concurrent/AbstractScheduledEventExecutor.java index 742bffb029c4..f563ddcb8437 100644 --- a/common/src/main/java/io/netty/util/concurrent/AbstractScheduledEventExecutor.java +++ b/common/src/main/java/io/netty/util/concurrent/AbstractScheduledEventExecutor.java @@ -38,7 +38,7 @@ public int compare(ScheduledFutureTask o1, ScheduledFutureTask o2) { } }; - PriorityQueue> scheduledTaskQueue; + protected PriorityQueue> scheduledTaskQueue; protected AbstractScheduledEventExecutor() { } @@ -78,7 +78,7 @@ protected void cancelScheduledTasks() { } final ScheduledFutureTask[] scheduledTasks = - scheduledTaskQueue.toArray(new ScheduledFutureTask[scheduledTaskQueue.size()]); + scheduledTaskQueue.toArray(new ScheduledFutureTask[0]); for (ScheduledFutureTask task: scheduledTasks) { task.cancelWithoutRemove(false); @@ -159,7 +159,7 @@ public ScheduledFuture schedule(Runnable command, long delay, TimeUnit unit) if (delay < 0) { delay = 0; } - validateScheduled(delay, unit); + validateScheduled0(delay, unit); return schedule(new ScheduledFutureTask( this, command, null, ScheduledFutureTask.deadlineNanos(unit.toNanos(delay)))); @@ -172,7 +172,7 @@ public ScheduledFuture schedule(Callable callable, long delay, TimeUni if (delay < 0) { delay = 0; } - validateScheduled(delay, unit); + validateScheduled0(delay, unit); return schedule(new ScheduledFutureTask( this, callable, ScheduledFutureTask.deadlineNanos(unit.toNanos(delay)))); @@ -190,8 +190,8 @@ public ScheduledFuture scheduleAtFixedRate(Runnable command, long initialDela throw new IllegalArgumentException( String.format("period: %d (expected: > 0)", period)); } - validateScheduled(initialDelay, unit); - validateScheduled(period, unit); + validateScheduled0(initialDelay, unit); + validateScheduled0(period, unit); return schedule(new ScheduledFutureTask( this, Executors.callable(command, null), @@ -211,17 +211,25 @@ public ScheduledFuture scheduleWithFixedDelay(Runnable command, long initialD String.format("delay: %d (expected: > 0)", delay)); } - validateScheduled(initialDelay, unit); - validateScheduled(delay, unit); + validateScheduled0(initialDelay, unit); + validateScheduled0(delay, unit); return schedule(new ScheduledFutureTask( this, Executors.callable(command, null), ScheduledFutureTask.deadlineNanos(unit.toNanos(initialDelay)), -unit.toNanos(delay))); } + @SuppressWarnings("deprecation") + private void validateScheduled0(long amount, TimeUnit unit) { + validateScheduled(amount, unit); + } + /** * Sub-classes may override this to restrict the maximal amount of time someone can use to schedule a task. + * + * @deprecated will be removed in the future. */ + @Deprecated protected void validateScheduled(long amount, TimeUnit unit) { // NOOP } diff --git a/common/src/main/java/io/netty/util/concurrent/DefaultPromise.java b/common/src/main/java/io/netty/util/concurrent/DefaultPromise.java index 2f4dc7999330..99f946a831e6 100644 --- a/common/src/main/java/io/netty/util/concurrent/DefaultPromise.java +++ b/common/src/main/java/io/netty/util/concurrent/DefaultPromise.java @@ -91,7 +91,6 @@ protected DefaultPromise() { @Override public Promise setSuccess(V result) { if (setSuccess0(result)) { - notifyListeners(); return this; } throw new IllegalStateException("complete already: " + this); @@ -99,17 +98,12 @@ public Promise setSuccess(V result) { @Override public boolean trySuccess(V result) { - if (setSuccess0(result)) { - notifyListeners(); - return true; - } - return false; + return setSuccess0(result); } @Override public Promise setFailure(Throwable cause) { if (setFailure0(cause)) { - notifyListeners(); return this; } throw new IllegalStateException("complete already: " + this, cause); @@ -117,11 +111,7 @@ public Promise setFailure(Throwable cause) { @Override public boolean tryFailure(Throwable cause) { - if (setFailure0(cause)) { - notifyListeners(); - return true; - } - return false; + return setFailure0(cause); } @Override @@ -301,7 +291,7 @@ public boolean awaitUninterruptibly(long timeoutMillis) { @Override public V getNow() { Object result = this.result; - if (result instanceof CauseHolder || result == SUCCESS) { + if (result instanceof CauseHolder || result == SUCCESS || result == UNCANCELLABLE) { return null; } return (V) result; @@ -315,8 +305,9 @@ public V getNow() { @Override public boolean cancel(boolean mayInterruptIfRunning) { if (RESULT_UPDATER.compareAndSet(this, null, CANCELLATION_CAUSE_HOLDER)) { - checkNotifyWaiters(); - notifyListeners(); + if (checkNotifyWaiters()) { + notifyListeners(); + } return true; } return false; @@ -510,7 +501,9 @@ private static void notifyListener0(Future future, GenericFutureListener l) { try { l.operationComplete(future); } catch (Throwable t) { - logger.warn("An exception was thrown by " + l.getClass().getName() + ".operationComplete()", t); + if (logger.isWarnEnabled()) { + logger.warn("An exception was thrown by " + l.getClass().getName() + ".operationComplete()", t); + } } } @@ -543,16 +536,23 @@ private boolean setFailure0(Throwable cause) { private boolean setValue0(Object objResult) { if (RESULT_UPDATER.compareAndSet(this, null, objResult) || RESULT_UPDATER.compareAndSet(this, UNCANCELLABLE, objResult)) { - checkNotifyWaiters(); + if (checkNotifyWaiters()) { + notifyListeners(); + } return true; } return false; } - private synchronized void checkNotifyWaiters() { + /** + * Check if there are any waiters and if so notify these. + * @return {@code true} if there are any listeners attached to the promise, {@code false} otherwise. + */ + private synchronized boolean checkNotifyWaiters() { if (waiters > 0) { notifyAll(); } + return listeners != null; } private void incWaiters() { @@ -740,7 +740,9 @@ private static void notifyProgressiveListener0( try { l.operationProgressed(future, progress, total); } catch (Throwable t) { - logger.warn("An exception was thrown by " + l.getClass().getName() + ".operationProgressed()", t); + if (logger.isWarnEnabled()) { + logger.warn("An exception was thrown by " + l.getClass().getName() + ".operationProgressed()", t); + } } } diff --git a/common/src/main/java/io/netty/util/concurrent/FastThreadLocal.java b/common/src/main/java/io/netty/util/concurrent/FastThreadLocal.java index 561058abca07..9d808c7f45d5 100644 --- a/common/src/main/java/io/netty/util/concurrent/FastThreadLocal.java +++ b/common/src/main/java/io/netty/util/concurrent/FastThreadLocal.java @@ -16,7 +16,6 @@ package io.netty.util.concurrent; import io.netty.util.internal.InternalThreadLocalMap; -import io.netty.util.internal.ObjectCleaner; import io.netty.util.internal.PlatformDependent; import java.util.Collections; @@ -63,7 +62,7 @@ public static void removeAll() { @SuppressWarnings("unchecked") Set> variablesToRemove = (Set>) v; FastThreadLocal[] variablesToRemoveArray = - variablesToRemove.toArray(new FastThreadLocal[variablesToRemove.size()]); + variablesToRemove.toArray(new FastThreadLocal[0]); for (FastThreadLocal tlv: variablesToRemoveArray) { tlv.remove(threadLocalMap); } @@ -153,6 +152,8 @@ private void registerCleaner(final InternalThreadLocalMap threadLocalMap) { threadLocalMap.setCleanerFlag(index); + // TODO: We need to find a better way to handle this. + /* // We will need to ensure we will trigger remove(InternalThreadLocalMap) so everything will be released // and FastThreadLocal.onRemoval(...) will be called. ObjectCleaner.register(current, new Runnable() { @@ -164,6 +165,7 @@ public void run() { // the Thread is collected by GC. In this case the ThreadLocal will be gone away already. } }); + */ } /** @@ -281,7 +283,9 @@ protected V initialValue() throws Exception { } /** - * Invoked when this thread local variable is removed by {@link #remove()}. + * Invoked when this thread local variable is removed by {@link #remove()}. Be aware that {@link #remove()} + * is not guaranteed to be called when the `Thread` completes which means you can not depend on this for + * cleanup of the resources in the case of `Thread` completion. */ protected void onRemoval(@SuppressWarnings("UnusedParameters") V value) throws Exception { } } diff --git a/common/src/main/java/io/netty/util/concurrent/NonStickyEventExecutorGroup.java b/common/src/main/java/io/netty/util/concurrent/NonStickyEventExecutorGroup.java index ed3a5873ecae..bcc4b8299640 100644 --- a/common/src/main/java/io/netty/util/concurrent/NonStickyEventExecutorGroup.java +++ b/common/src/main/java/io/netty/util/concurrent/NonStickyEventExecutorGroup.java @@ -259,7 +259,24 @@ public void run() { } } else { state.set(NONE); - return; // done + // After setting the state to NONE, look at the tasks queue one more time. + // If it is empty, then we can return from this method. + // Otherwise, it means the producer thread has called execute(Runnable) + // and enqueued a task in between the tasks.poll() above and the state.set(NONE) here. + // There are two possible scenarios when this happen + // + // 1. The producer thread sees state == NONE, hence the compareAndSet(NONE, SUBMITTED) + // is successfully setting the state to SUBMITTED. This mean the producer + // will call / has called executor.execute(this). In this case, we can just return. + // 2. The producer thread don't see the state change, hence the compareAndSet(NONE, SUBMITTED) + // returns false. In this case, the producer thread won't call executor.execute. + // In this case, we need to change the state to RUNNING and keeps running. + // + // The above cases can be distinguished by performing a + // compareAndSet(NONE, RUNNING). If it returns "false", it is case 1; otherwise it is case 2. + if (tasks.peek() == null || !state.compareAndSet(NONE, RUNNING)) { + return; // done + } } } } diff --git a/common/src/main/java/io/netty/util/concurrent/PromiseAggregator.java b/common/src/main/java/io/netty/util/concurrent/PromiseAggregator.java index 7d908c037949..215d70923280 100644 --- a/common/src/main/java/io/netty/util/concurrent/PromiseAggregator.java +++ b/common/src/main/java/io/netty/util/concurrent/PromiseAggregator.java @@ -20,7 +20,7 @@ import java.util.Set; /** - * @deprecated Use {@link PromiseCombiner} + * @deprecated Use {@link PromiseCombiner#PromiseCombiner(EventExecutor)}. * * {@link GenericFutureListener} implementation which consolidates multiple {@link Future}s * into one, by listening to individual {@link Future}s and producing an aggregated result diff --git a/common/src/main/java/io/netty/util/concurrent/PromiseCombiner.java b/common/src/main/java/io/netty/util/concurrent/PromiseCombiner.java index 6624f05db1ef..8895c1a0e8e0 100644 --- a/common/src/main/java/io/netty/util/concurrent/PromiseCombiner.java +++ b/common/src/main/java/io/netty/util/concurrent/PromiseCombiner.java @@ -28,26 +28,62 @@ * {@link PromiseCombiner#add(Future)} and {@link PromiseCombiner#addAll(Future[])} methods. When all futures to be * combined have been added, callers must provide an aggregate promise to be notified when all combined promises have * finished via the {@link PromiseCombiner#finish(Promise)} method.

    + * + *

    This implementation is NOT thread-safe and all methods must be called + * from the {@link EventExecutor} thread.

    */ public final class PromiseCombiner { private int expectedCount; private int doneCount; - private boolean doneAdding; private Promise aggregatePromise; private Throwable cause; private final GenericFutureListener> listener = new GenericFutureListener>() { @Override - public void operationComplete(Future future) throws Exception { + public void operationComplete(final Future future) { + if (executor.inEventLoop()) { + operationComplete0(future); + } else { + executor.execute(new Runnable() { + @Override + public void run() { + operationComplete0(future); + } + }); + } + } + + private void operationComplete0(Future future) { + assert executor.inEventLoop(); ++doneCount; if (!future.isSuccess() && cause == null) { cause = future.cause(); } - if (doneCount == expectedCount && doneAdding) { + if (doneCount == expectedCount && aggregatePromise != null) { tryPromise(); } } }; + private final EventExecutor executor; + + /** + * Deprecated use {@link PromiseCombiner#PromiseCombiner(EventExecutor)}. + */ + @Deprecated + public PromiseCombiner() { + this(ImmediateEventExecutor.INSTANCE); + } + + /** + * The {@link EventExecutor} to use for notifications. You must call {@link #add(Future)}, {@link #addAll(Future[])} + * and {@link #finish(Promise)} from within the {@link EventExecutor} thread. + * + * @param executor the {@link EventExecutor} to use for notifications. + */ + public PromiseCombiner(EventExecutor executor) { + this.executor = ObjectUtil.checkNotNull(executor, "executor"); + } + /** * Adds a new promise to be combined. New promises may be added until an aggregate promise is added via the * {@link PromiseCombiner#finish(Promise)} method. @@ -70,6 +106,7 @@ public void add(Promise promise) { @SuppressWarnings({ "unchecked", "rawtypes" }) public void add(Future future) { checkAddAllowed(); + checkInEventLoop(); ++expectedCount; future.addListener(listener); } @@ -112,22 +149,29 @@ public void addAll(Future... futures) { * @param aggregatePromise the promise to notify when all combined futures have finished */ public void finish(Promise aggregatePromise) { - if (doneAdding) { + ObjectUtil.checkNotNull(aggregatePromise, "aggregatePromise"); + checkInEventLoop(); + if (this.aggregatePromise != null) { throw new IllegalStateException("Already finished"); } - doneAdding = true; - this.aggregatePromise = ObjectUtil.checkNotNull(aggregatePromise, "aggregatePromise"); + this.aggregatePromise = aggregatePromise; if (doneCount == expectedCount) { tryPromise(); } } + private void checkInEventLoop() { + if (!executor.inEventLoop()) { + throw new IllegalStateException("Must be called from EventExecutor thread"); + } + } + private boolean tryPromise() { return (cause == null) ? aggregatePromise.trySuccess(null) : aggregatePromise.tryFailure(cause); } private void checkAddAllowed() { - if (doneAdding) { + if (aggregatePromise != null) { throw new IllegalStateException("Adding promises is not allowed after finished adding"); } } diff --git a/common/src/main/java/io/netty/util/concurrent/ScheduledFutureTask.java b/common/src/main/java/io/netty/util/concurrent/ScheduledFutureTask.java index 6043d6ffec0f..1eaa7b927625 100644 --- a/common/src/main/java/io/netty/util/concurrent/ScheduledFutureTask.java +++ b/common/src/main/java/io/netty/util/concurrent/ScheduledFutureTask.java @@ -35,7 +35,9 @@ static long nanoTime() { } static long deadlineNanos(long delay) { - return nanoTime() + delay; + long deadlineNanos = nanoTime() + delay; + // Guard against overflow + return deadlineNanos < 0 ? Long.MAX_VALUE : deadlineNanos; } private final long id = nextTaskId.getAndIncrement(); diff --git a/common/src/main/java/io/netty/util/concurrent/SingleThreadEventExecutor.java b/common/src/main/java/io/netty/util/concurrent/SingleThreadEventExecutor.java index 09197e6fdb0d..03ae3347aeed 100644 --- a/common/src/main/java/io/netty/util/concurrent/SingleThreadEventExecutor.java +++ b/common/src/main/java/io/netty/util/concurrent/SingleThreadEventExecutor.java @@ -304,7 +304,7 @@ protected boolean hasTasks() { * Return the number of tasks that are pending for processing. * * Be aware that this operation may be expensive as it depends on the internal implementation of the - * SingleThreadEventExecutor. So use it was care! + * SingleThreadEventExecutor. So use it with care! */ public int pendingTasks() { return taskQueue.size(); @@ -443,6 +443,19 @@ protected long delayNanos(long currentTimeNanos) { return scheduledTask.delayNanos(currentTimeNanos); } + /** + * Returns the absolute point in time (relative to {@link #nanoTime()}) at which the the next + * closest scheduled task should run. + */ + @UnstableApi + protected long deadlineNanos() { + ScheduledFutureTask scheduledTask = peekScheduledTask(); + if (scheduledTask == null) { + return nanoTime() + SCHEDULE_PURGE_INTERVAL; + } + return scheduledTask.deadlineNanos(); + } + /** * Updates the internal timestamp that tells when a submitted task was executed most recently. * {@link #runAllTasks()} and {@link #runAllTasks(long)} updates this timestamp automatically, and thus there's @@ -582,19 +595,8 @@ public Future shutdownGracefully(long quietPeriod, long timeout, TimeUnit uni gracefulShutdownQuietPeriod = unit.toNanos(quietPeriod); gracefulShutdownTimeout = unit.toNanos(timeout); - if (oldState == ST_NOT_STARTED) { - try { - doStartThread(); - } catch (Throwable cause) { - STATE_UPDATER.set(this, ST_TERMINATED); - terminationFuture.tryFailure(cause); - - if (!(cause instanceof Exception)) { - // Also rethrow as it may be an OOME for example - PlatformDependent.throwException(cause); - } - return terminationFuture; - } + if (ensureThreadStarted(oldState)) { + return terminationFuture; } if (wakeup) { @@ -645,19 +647,8 @@ public void shutdown() { } } - if (oldState == ST_NOT_STARTED) { - try { - doStartThread(); - } catch (Throwable cause) { - STATE_UPDATER.set(this, ST_TERMINATED); - terminationFuture.tryFailure(cause); - - if (!(cause instanceof Exception)) { - // Also rethrow as it may be an OOME for example - PlatformDependent.throwException(cause); - } - return; - } + if (ensureThreadStarted(oldState)) { + return; } if (wakeup) { @@ -765,8 +756,20 @@ public void execute(Runnable task) { addTask(task); if (!inEventLoop) { startThread(); - if (isShutdown() && removeTask(task)) { - reject(); + if (isShutdown()) { + boolean reject = false; + try { + if (removeTask(task)) { + reject = true; + } + } catch (UnsupportedOperationException e) { + // The task queue does not support removal so the best thing we can do is to just move on and + // hope we will be able to pick-up the task before its completely terminated. + // In worst case we will log on termination. + } + if (reject) { + reject(); + } } } @@ -810,7 +813,7 @@ private void throwIfInEventLoop(String method) { /** * Returns the {@link ThreadProperties} of the {@link Thread} that powers the {@link SingleThreadEventExecutor}. - * If the {@link SingleThreadEventExecutor} is not started yet, this operation will start it and block until the + * If the {@link SingleThreadEventExecutor} is not started yet, this operation will start it and block until * it is fully started. */ public final ThreadProperties threadProperties() { @@ -868,6 +871,24 @@ private void startThread() { } } + private boolean ensureThreadStarted(int oldState) { + if (oldState == ST_NOT_STARTED) { + try { + doStartThread(); + } catch (Throwable cause) { + STATE_UPDATER.set(this, ST_TERMINATED); + terminationFuture.tryFailure(cause); + + if (!(cause instanceof Exception)) { + // Also rethrow as it may be an OOME for example + PlatformDependent.throwException(cause); + } + return true; + } + } + return false; + } + private void doStartThread() { assert thread == null; executor.execute(new Runnable() { @@ -896,9 +917,11 @@ public void run() { // Check if confirmShutdown() was called at the end of the loop. if (success && gracefulShutdownStartTime == 0) { - logger.error("Buggy " + EventExecutor.class.getSimpleName() + " implementation; " + - SingleThreadEventExecutor.class.getSimpleName() + ".confirmShutdown() must be called " + - "before run() implementation terminates."); + if (logger.isErrorEnabled()) { + logger.error("Buggy " + EventExecutor.class.getSimpleName() + " implementation; " + + SingleThreadEventExecutor.class.getSimpleName() + ".confirmShutdown() must " + + "be called before run() implementation terminates."); + } } try { @@ -912,14 +935,20 @@ public void run() { try { cleanup(); } finally { + // Lets remove all FastThreadLocals for the Thread as we are about to terminate and notify + // the future. The user may block on the future and once it unblocks the JVM may terminate + // and start unloading classes. + // See https://github.com/netty/netty/issues/6596. + FastThreadLocal.removeAll(); + STATE_UPDATER.set(SingleThreadEventExecutor.this, ST_TERMINATED); threadLock.release(); if (!taskQueue.isEmpty()) { - logger.warn( - "An event executor terminated with " + - "non-empty task queue (" + taskQueue.size() + ')'); + if (logger.isWarnEnabled()) { + logger.warn("An event executor terminated with " + + "non-empty task queue (" + taskQueue.size() + ')'); + } } - terminationFuture.setSuccess(null); } } diff --git a/common/src/main/java/io/netty/util/internal/AppendableCharSequence.java b/common/src/main/java/io/netty/util/internal/AppendableCharSequence.java index 408c32f38027..e8e6abf56613 100644 --- a/common/src/main/java/io/netty/util/internal/AppendableCharSequence.java +++ b/common/src/main/java/io/netty/util/internal/AppendableCharSequence.java @@ -63,6 +63,12 @@ public char charAtUnsafe(int index) { @Override public AppendableCharSequence subSequence(int start, int end) { + if (start == end) { + // If start and end index is the same we need to return an empty sequence to conform to the interface. + // As our expanding logic depends on the fact that we have a char[] with length > 0 we need to construct + // an instance for which this is true. + return new AppendableCharSequence(Math.min(16, chars.length)); + } return new AppendableCharSequence(Arrays.copyOfRange(chars, start, end)); } diff --git a/common/src/main/java/io/netty/util/internal/CleanerJava6.java b/common/src/main/java/io/netty/util/internal/CleanerJava6.java index 8bbdbf724016..383582360332 100644 --- a/common/src/main/java/io/netty/util/internal/CleanerJava6.java +++ b/common/src/main/java/io/netty/util/internal/CleanerJava6.java @@ -21,6 +21,8 @@ import java.lang.reflect.Field; import java.lang.reflect.Method; import java.nio.ByteBuffer; +import java.security.AccessController; +import java.security.PrivilegedAction; /** @@ -32,41 +34,72 @@ final class CleanerJava6 implements Cleaner { private static final long CLEANER_FIELD_OFFSET; private static final Method CLEAN_METHOD; + private static final Field CLEANER_FIELD; private static final InternalLogger logger = InternalLoggerFactory.getInstance(CleanerJava6.class); static { - long fieldOffset = -1; - Method clean = null; + long fieldOffset; + Method clean; + Field cleanerField; Throwable error = null; - if (PlatformDependent0.hasUnsafe()) { - ByteBuffer direct = ByteBuffer.allocateDirect(1); - try { - Field cleanerField = direct.getClass().getDeclaredField("cleaner"); + final ByteBuffer direct = ByteBuffer.allocateDirect(1); + try { + Object mayBeCleanerField = AccessController.doPrivileged(new PrivilegedAction() { + @Override + public Object run() { + try { + Field cleanerField = direct.getClass().getDeclaredField("cleaner"); + if (!PlatformDependent.hasUnsafe()) { + // We need to make it accessible if we do not use Unsafe as we will access it via + // reflection. + cleanerField.setAccessible(true); + } + return cleanerField; + } catch (Throwable cause) { + return cause; + } + } + }); + if (mayBeCleanerField instanceof Throwable) { + throw (Throwable) mayBeCleanerField; + } + + cleanerField = (Field) mayBeCleanerField; + + final Object cleaner; + + // If we have sun.misc.Unsafe we will use it as its faster then using reflection, + // otherwise let us try reflection as last resort. + if (PlatformDependent.hasUnsafe()) { fieldOffset = PlatformDependent0.objectFieldOffset(cleanerField); - Object cleaner = PlatformDependent0.getObject(direct, fieldOffset); - clean = cleaner.getClass().getDeclaredMethod("clean"); - clean.invoke(cleaner); - } catch (Throwable t) { - // We don't have ByteBuffer.cleaner(). + cleaner = PlatformDependent0.getObject(direct, fieldOffset); + } else { fieldOffset = -1; - clean = null; - error = t; + cleaner = cleanerField.get(direct); } - } else { - error = new UnsupportedOperationException("sun.misc.Unsafe unavailable"); + clean = cleaner.getClass().getDeclaredMethod("clean"); + clean.invoke(cleaner); + } catch (Throwable t) { + // We don't have ByteBuffer.cleaner(). + fieldOffset = -1; + clean = null; + error = t; + cleanerField = null; } + if (error == null) { logger.debug("java.nio.ByteBuffer.cleaner(): available"); } else { logger.debug("java.nio.ByteBuffer.cleaner(): unavailable", error); } + CLEANER_FIELD = cleanerField; CLEANER_FIELD_OFFSET = fieldOffset; CLEAN_METHOD = clean; } static boolean isSupported() { - return CLEANER_FIELD_OFFSET != -1; + return CLEANER_FIELD_OFFSET != -1 || CLEANER_FIELD != null; } @Override @@ -74,13 +107,45 @@ public void freeDirectBuffer(ByteBuffer buffer) { if (!buffer.isDirect()) { return; } - try { - Object cleaner = PlatformDependent0.getObject(buffer, CLEANER_FIELD_OFFSET); - if (cleaner != null) { - CLEAN_METHOD.invoke(cleaner); + if (System.getSecurityManager() == null) { + try { + freeDirectBuffer0(buffer); + } catch (Throwable cause) { + PlatformDependent0.throwException(cause); } - } catch (Throwable cause) { + } else { + freeDirectBufferPrivileged(buffer); + } + } + + private static void freeDirectBufferPrivileged(final ByteBuffer buffer) { + Throwable cause = AccessController.doPrivileged(new PrivilegedAction() { + @Override + public Throwable run() { + try { + freeDirectBuffer0(buffer); + return null; + } catch (Throwable cause) { + return cause; + } + } + }); + if (cause != null) { PlatformDependent0.throwException(cause); } } + + private static void freeDirectBuffer0(ByteBuffer buffer) throws Exception { + final Object cleaner; + // If CLEANER_FIELD_OFFSET == -1 we need to use reflection to access the cleaner, otherwise we can use + // sun.misc.Unsafe. + if (CLEANER_FIELD_OFFSET == -1) { + cleaner = CLEANER_FIELD.get(buffer); + } else { + cleaner = PlatformDependent0.getObject(buffer, CLEANER_FIELD_OFFSET); + } + if (cleaner != null) { + CLEAN_METHOD.invoke(cleaner); + } + } } diff --git a/common/src/main/java/io/netty/util/internal/CleanerJava9.java b/common/src/main/java/io/netty/util/internal/CleanerJava9.java index d111a7e54a58..3a0c5854c38b 100644 --- a/common/src/main/java/io/netty/util/internal/CleanerJava9.java +++ b/common/src/main/java/io/netty/util/internal/CleanerJava9.java @@ -21,6 +21,8 @@ import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.nio.ByteBuffer; +import java.security.AccessController; +import java.security.PrivilegedAction; /** * Provide a way to clean a ByteBuffer on Java9+. @@ -34,20 +36,26 @@ final class CleanerJava9 implements Cleaner { final Method method; final Throwable error; if (PlatformDependent0.hasUnsafe()) { - ByteBuffer buffer = ByteBuffer.allocateDirect(1); - Object maybeInvokeMethod; - try { - // See https://bugs.openjdk.java.net/browse/JDK-8171377 - Method m = PlatformDependent0.UNSAFE.getClass().getDeclaredMethod("invokeCleaner", ByteBuffer.class); - m.invoke(PlatformDependent0.UNSAFE, buffer); - maybeInvokeMethod = m; - } catch (NoSuchMethodException e) { - maybeInvokeMethod = e; - } catch (InvocationTargetException e) { - maybeInvokeMethod = e; - } catch (IllegalAccessException e) { - maybeInvokeMethod = e; - } + final ByteBuffer buffer = ByteBuffer.allocateDirect(1); + Object maybeInvokeMethod = AccessController.doPrivileged(new PrivilegedAction() { + @Override + public Object run() { + try { + // See https://bugs.openjdk.java.net/browse/JDK-8171377 + Method m = PlatformDependent0.UNSAFE.getClass().getDeclaredMethod( + "invokeCleaner", ByteBuffer.class); + m.invoke(PlatformDependent0.UNSAFE, buffer); + return m; + } catch (NoSuchMethodException e) { + return e; + } catch (InvocationTargetException e) { + return e; + } catch (IllegalAccessException e) { + return e; + } + } + }); + if (maybeInvokeMethod instanceof Throwable) { method = null; error = (Throwable) maybeInvokeMethod; @@ -73,10 +81,35 @@ static boolean isSupported() { @Override public void freeDirectBuffer(ByteBuffer buffer) { - try { - INVOKE_CLEANER.invoke(PlatformDependent0.UNSAFE, buffer); - } catch (Throwable cause) { - PlatformDependent0.throwException(cause); + // Try to minimize overhead when there is no SecurityManager present. + // See https://bugs.openjdk.java.net/browse/JDK-8191053. + if (System.getSecurityManager() == null) { + try { + INVOKE_CLEANER.invoke(PlatformDependent0.UNSAFE, buffer); + } catch (Throwable cause) { + PlatformDependent0.throwException(cause); + } + } else { + freeDirectBufferPrivileged(buffer); + } + } + + private static void freeDirectBufferPrivileged(final ByteBuffer buffer) { + Exception error = AccessController.doPrivileged(new PrivilegedAction() { + @Override + public Exception run() { + try { + INVOKE_CLEANER.invoke(PlatformDependent0.UNSAFE, buffer); + } catch (InvocationTargetException e) { + return e; + } catch (IllegalAccessException e) { + return e; + } + return null; + } + }); + if (error != null) { + PlatformDependent0.throwException(error); } } } diff --git a/common/src/main/java/io/netty/util/internal/ConcurrentSet.java b/common/src/main/java/io/netty/util/internal/ConcurrentSet.java index 52f8c124d4ee..a735f4e430b6 100644 --- a/common/src/main/java/io/netty/util/internal/ConcurrentSet.java +++ b/common/src/main/java/io/netty/util/internal/ConcurrentSet.java @@ -20,6 +20,10 @@ import java.util.Iterator; import java.util.concurrent.ConcurrentMap; +/** + * @deprecated For removal in Netty 4.2. Please use {@link ConcurrentHashMap#newKeySet()} instead + */ +@Deprecated public final class ConcurrentSet extends AbstractSet implements Serializable { private static final long serialVersionUID = -6761513279741915432L; diff --git a/common/src/main/java/io/netty/util/internal/IntegerHolder.java b/common/src/main/java/io/netty/util/internal/IntegerHolder.java index 2a8d069a2747..e19335ccf059 100644 --- a/common/src/main/java/io/netty/util/internal/IntegerHolder.java +++ b/common/src/main/java/io/netty/util/internal/IntegerHolder.java @@ -16,6 +16,10 @@ package io.netty.util.internal; +/** + * @deprecated For removal in netty 4.2 + */ +@Deprecated public final class IntegerHolder { public int value; } diff --git a/common/src/main/java/io/netty/util/internal/InternalThreadLocalMap.java b/common/src/main/java/io/netty/util/internal/InternalThreadLocalMap.java index 6ab3bdf1316d..0a6a6c564871 100644 --- a/common/src/main/java/io/netty/util/internal/InternalThreadLocalMap.java +++ b/common/src/main/java/io/netty/util/internal/InternalThreadLocalMap.java @@ -257,10 +257,12 @@ public Map, Map> typeParameterMatcherFind return cache; } + @Deprecated public IntegerHolder counterHashCode() { return counterHashCode; } + @Deprecated public void setCounterHashCode(IntegerHolder counterHashCode) { this.counterHashCode = counterHashCode; } diff --git a/common/src/main/java/io/netty/util/internal/NativeLibraryLoader.java b/common/src/main/java/io/netty/util/internal/NativeLibraryLoader.java index cb3295e6f23b..98e551c4eb66 100644 --- a/common/src/main/java/io/netty/util/internal/NativeLibraryLoader.java +++ b/common/src/main/java/io/netty/util/internal/NativeLibraryLoader.java @@ -15,6 +15,7 @@ */ package io.netty.util.internal; +import io.netty.util.CharsetUtil; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -47,6 +48,11 @@ public final class NativeLibraryLoader { private static final String NATIVE_RESOURCE_HOME = "META-INF/native/"; private static final File WORKDIR; private static final boolean DELETE_NATIVE_LIB_AFTER_LOADING; + private static final boolean TRY_TO_PATCH_SHADED_ID; + + // Just use a-Z and numbers as valid ID bytes. + private static final byte[] UNIQUE_ID_BYTES = + "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ".getBytes(CharsetUtil.US_ASCII); static { String workdir = SystemPropertyUtil.get("io.netty.native.workdir"); @@ -69,6 +75,11 @@ public final class NativeLibraryLoader { DELETE_NATIVE_LIB_AFTER_LOADING = SystemPropertyUtil.getBoolean( "io.netty.native.deleteLibAfterLoading", true); + logger.debug("-Dio.netty.native.deleteLibAfterLoading: {}", DELETE_NATIVE_LIB_AFTER_LOADING); + + TRY_TO_PATCH_SHADED_ID = SystemPropertyUtil.getBoolean( + "io.netty.native.tryPatchShadedId", true); + logger.debug("-Dio.netty.native.tryPatchShadedId: {}", TRY_TO_PATCH_SHADED_ID); } /** @@ -117,7 +128,8 @@ private static String calculatePackagePrefix() { */ public static void load(String originalName, ClassLoader loader) { // Adjust expected name to support shading of native libraries. - String name = calculatePackagePrefix().replace('.', '_') + originalName; + String packagePrefix = calculatePackagePrefix().replace('.', '_'); + String name = packagePrefix + originalName; List suppressed = new ArrayList(); try { // first try to load from java.library.path @@ -172,27 +184,34 @@ public static void load(String originalName, ClassLoader loader) { in = url.openStream(); out = new FileOutputStream(tmpFile); - byte[] buffer = new byte[8192]; - int length; - while ((length = in.read(buffer)) > 0) { - out.write(buffer, 0, length); + if (shouldShadedLibraryIdBePatched(packagePrefix)) { + patchShadedLibraryId(in, out, originalName, name); + } else { + byte[] buffer = new byte[8192]; + int length; + while ((length = in.read(buffer)) > 0) { + out.write(buffer, 0, length); + } } + out.flush(); // Close the output stream before loading the unpacked library, // because otherwise Windows will refuse to load it when it's in use by other process. closeQuietly(out); out = null; - loadLibrary(loader, tmpFile.getPath(), true); } catch (UnsatisfiedLinkError e) { try { if (tmpFile != null && tmpFile.isFile() && tmpFile.canRead() && !NoexecVolumeDetector.canExecuteExecutable(tmpFile)) { + // Pass "io.netty.native.workdir" as an argument to allow shading tools to see + // the string. Since this is printed out to users to tell them what to do next, + // we want the value to be correct even when shading. logger.info("{} exists but cannot be executed even when execute permissions set; " + - "check volume for \"noexec\" flag; use -Dio.netty.native.workdir=[path] " + + "check volume for \"noexec\" flag; use -D{}=[path] " + "to set native working directory separately.", - tmpFile.getPath()); + tmpFile.getPath(), "io.netty.native.workdir"); } } catch (Throwable t) { suppressed.add(t); @@ -218,6 +237,93 @@ public static void load(String originalName, ClassLoader loader) { } } + // Package-private for testing. + static boolean patchShadedLibraryId(InputStream in, OutputStream out, String originalName, String name) + throws IOException { + byte[] buffer = new byte[8192]; + int length; + // We read the whole native lib into memory to make it easier to monkey-patch the id. + ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(in.available()); + + while ((length = in.read(buffer)) > 0) { + byteArrayOutputStream.write(buffer, 0, length); + } + byteArrayOutputStream.flush(); + byte[] bytes = byteArrayOutputStream.toByteArray(); + byteArrayOutputStream.close(); + + final boolean patched; + // Try to patch the library id. + if (!patchShadedLibraryId(bytes, originalName, name)) { + // We did not find the Id, check if we used a originalName that has the os and arch as suffix. + // If this is the case we should also try to patch with the os and arch suffix removed. + String os = PlatformDependent.normalizedOs(); + String arch = PlatformDependent.normalizedArch(); + String osArch = "_" + os + "_" + arch; + if (originalName.endsWith(osArch)) { + patched = patchShadedLibraryId(bytes, + originalName.substring(0, originalName.length() - osArch.length()), name); + } else { + patched = false; + } + } else { + patched = true; + } + out.write(bytes, 0, bytes.length); + return patched; + } + + private static boolean shouldShadedLibraryIdBePatched(String packagePrefix) { + return TRY_TO_PATCH_SHADED_ID && PlatformDependent.isOsx() && !packagePrefix.isEmpty(); + } + + /** + * Try to patch shaded library to ensure it uses a unique ID. + */ + private static boolean patchShadedLibraryId(byte[] bytes, String originalName, String name) { + // Our native libs always have the name as part of their id so we can search for it and replace it + // to make the ID unique if shading is used. + byte[] nameBytes = originalName.getBytes(CharsetUtil.UTF_8); + int idIdx = -1; + + // Be aware this is a really raw way of patching a dylib but it does all we need without implementing + // a full mach-o parser and writer. Basically we just replace the the original bytes with some + // random bytes as part of the ID regeneration. The important thing here is that we need to use the same + // length to not corrupt the mach-o header. + outerLoop: for (int i = 0; i < bytes.length && bytes.length - i >= nameBytes.length; i++) { + int idx = i; + for (int j = 0; j < nameBytes.length;) { + if (bytes[idx++] != nameBytes[j++]) { + // Did not match the name, increase the index and try again. + break; + } else if (j == nameBytes.length) { + // We found the index within the id. + idIdx = i; + break outerLoop; + } + } + } + + if (idIdx == -1) { + logger.debug("Was not able to find the ID of the shaded native library {}, can't adjust it.", name); + return false; + } else { + // We found our ID... now monkey-patch it! + for (int i = 0; i < nameBytes.length; i++) { + // We should only use bytes as replacement that are in our UNIQUE_ID_BYTES array. + bytes[idIdx + i] = UNIQUE_ID_BYTES[PlatformDependent.threadLocalRandom() + .nextInt(UNIQUE_ID_BYTES.length)]; + } + + if (logger.isDebugEnabled()) { + logger.debug( + "Found the ID of the shaded native library {}. Replacing ID part {} with {}", + name, originalName, new String(bytes, idIdx, nameBytes.length, CharsetUtil.UTF_8)); + } + return true; + } + } + /** * Loading the native library into the specified {@link ClassLoader}. * @param loader - The {@link ClassLoader} where the native library will be loaded into diff --git a/common/src/main/java/io/netty/util/internal/PlatformDependent.java b/common/src/main/java/io/netty/util/internal/PlatformDependent.java index 904e9d2fa194..a43a534424a3 100644 --- a/common/src/main/java/io/netty/util/internal/PlatformDependent.java +++ b/common/src/main/java/io/netty/util/internal/PlatformDependent.java @@ -29,6 +29,7 @@ import org.jctools.util.UnsafeAccess; import java.io.File; +import java.lang.reflect.Field; import java.lang.reflect.Method; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -73,14 +74,14 @@ public final class PlatformDependent { private static final boolean IS_WINDOWS = isWindows0(); private static final boolean IS_OSX = isOsx0(); + private static final boolean IS_J9_JVM = isJ9Jvm0(); private static final boolean MAYBE_SUPER_USER; private static final boolean CAN_ENABLE_TCP_NODELAY_BY_DEFAULT = !isAndroid(); private static final Throwable UNSAFE_UNAVAILABILITY_CAUSE = unsafeUnavailabilityCause0(); - private static final boolean DIRECT_BUFFER_PREFERRED = - UNSAFE_UNAVAILABILITY_CAUSE == null && !SystemPropertyUtil.getBoolean("io.netty.noPreferDirect", false); + private static final boolean DIRECT_BUFFER_PREFERRED; private static final long MAX_DIRECT_MEMORY = maxDirectMemory0(); private static final int MPSC_CHUNK_SIZE = 1024; @@ -128,20 +129,6 @@ public Random current() { } }; } - if (logger.isDebugEnabled()) { - logger.debug("-Dio.netty.noPreferDirect: {}", !DIRECT_BUFFER_PREFERRED); - } - - /* - * We do not want to log this message if unsafe is explicitly disabled. Do not remove the explicit no unsafe - * guard. - */ - if (!hasUnsafe() && !isAndroid() && !PlatformDependent0.isExplicitNoUnsafe()) { - logger.info( - "Your platform does not provide complete low-level API for accessing direct buffers reliably. " + - "Unless explicitly requested, heap buffer will always be preferred to avoid potential system " + - "instability."); - } // Here is how the system property is used: // @@ -157,14 +144,12 @@ public Random current() { } else { USE_DIRECT_BUFFER_NO_CLEANER = true; if (maxDirectMemory < 0) { - maxDirectMemory = maxDirectMemory0(); + maxDirectMemory = MAX_DIRECT_MEMORY; } } - DIRECT_MEMORY_COUNTER = new AtomicLong(); - DIRECT_MEMORY_LIMIT = maxDirectMemory; - logger.debug("-Dio.netty.maxDirectMemory: {} bytes", maxDirectMemory); + DIRECT_MEMORY_LIMIT = maxDirectMemory >= 1 ? maxDirectMemory : MAX_DIRECT_MEMORY; int tryAllocateUninitializedArray = SystemPropertyUtil.getInt("io.netty.uninitializedArrayAllocationThreshold", 1024); @@ -174,7 +159,7 @@ public Random current() { MAYBE_SUPER_USER = maybeSuperUser0(); - if (!isAndroid() && hasUnsafe()) { + if (!isAndroid()) { // only direct to method if we are not running on android. // See https://github.com/netty/netty/issues/2604 if (javaVersion() >= 9) { @@ -185,6 +170,24 @@ public Random current() { } else { CLEANER = NOOP; } + + // We should always prefer direct buffers by default if we can use a Cleaner to release direct buffers. + DIRECT_BUFFER_PREFERRED = CLEANER != NOOP + && !SystemPropertyUtil.getBoolean("io.netty.noPreferDirect", false); + if (logger.isDebugEnabled()) { + logger.debug("-Dio.netty.noPreferDirect: {}", !DIRECT_BUFFER_PREFERRED); + } + + /* + * We do not want to log this message if unsafe is explicitly disabled. Do not remove the explicit no unsafe + * guard. + */ + if (CLEANER == NOOP && !PlatformDependent0.isExplicitNoUnsafe()) { + logger.info( + "Your platform does not provide complete low-level API for accessing direct buffers reliably. " + + "Unless explicitly requested, heap buffer will always be preferred to avoid potential system " + + "instability."); + } } public static boolean hasDirectBufferNoCleanerConstructor() { @@ -275,7 +278,7 @@ public static boolean directBufferPreferred() { * Returns the maximum memory reserved for direct buffer allocation. */ public static long maxDirectMemory() { - return MAX_DIRECT_MEMORY; + return DIRECT_MEMORY_LIMIT; } /** @@ -387,6 +390,10 @@ public static ByteBuffer directBuffer(long memoryAddress, int size) { "sun.misc.Unsafe or java.nio.DirectByteBuffer.(long, int) not available"); } + public static Object getObject(Object object, long fieldOffset) { + return PlatformDependent0.getObject(object, fieldOffset); + } + public static int getInt(Object object, long fieldOffset) { return PlatformDependent0.getInt(object, fieldOffset); } @@ -551,6 +558,14 @@ public static void putLong(byte[] data, int index, long value) { PlatformDependent0.putLong(data, index, value); } + public static void putObject(Object o, long offset, Object x) { + PlatformDependent0.putObject(o, offset, x); + } + + public static long objectFieldOffset(Field field) { + return PlatformDependent0.objectFieldOffset(field); + } + public static void copyMemory(long srcAddr, long dstAddr, long length) { PlatformDependent0.copyMemory(srcAddr, dstAddr, length); } @@ -767,83 +782,43 @@ public static int hashCodeAscii(byte[] bytes, int startPos, int length) { * The resulting hash code will be case insensitive. */ public static int hashCodeAscii(CharSequence bytes) { + final int length = bytes.length(); + final int remainingBytes = length & 7; int hash = HASH_CODE_ASCII_SEED; - final int remainingBytes = bytes.length() & 7; // Benchmarking shows that by just naively looping for inputs 8~31 bytes long we incur a relatively large // performance penalty (only achieve about 60% performance of loop which iterates over each char). So because // of this we take special provisions to unroll the looping for these conditions. - switch (bytes.length()) { - case 31: - case 30: - case 29: - case 28: - case 27: - case 26: - case 25: - case 24: - hash = hashCodeAsciiCompute(bytes, bytes.length() - 24, - hashCodeAsciiCompute(bytes, bytes.length() - 16, - hashCodeAsciiCompute(bytes, bytes.length() - 8, hash))); - break; - case 23: - case 22: - case 21: - case 20: - case 19: - case 18: - case 17: - case 16: - hash = hashCodeAsciiCompute(bytes, bytes.length() - 16, - hashCodeAsciiCompute(bytes, bytes.length() - 8, hash)); - break; - case 15: - case 14: - case 13: - case 12: - case 11: - case 10: - case 9: - case 8: - hash = hashCodeAsciiCompute(bytes, bytes.length() - 8, hash); - break; - case 7: - case 6: - case 5: - case 4: - case 3: - case 2: - case 1: - case 0: - break; - default: - for (int i = bytes.length() - 8; i >= remainingBytes; i -= 8) { - hash = hashCodeAsciiCompute(bytes, i, hash); + if (length >= 32) { + for (int i = length - 8; i >= remainingBytes; i -= 8) { + hash = hashCodeAsciiCompute(bytes, i, hash); + } + } else if (length >= 8) { + hash = hashCodeAsciiCompute(bytes, length - 8, hash); + if (length >= 16) { + hash = hashCodeAsciiCompute(bytes, length - 16, hash); + if (length >= 24) { + hash = hashCodeAsciiCompute(bytes, length - 24, hash); } - break; + } } - switch(remainingBytes) { - case 7: - return ((hash * HASH_CODE_C1 + hashCodeAsciiSanitizeByte(bytes.charAt(0))) - * HASH_CODE_C2 + hashCodeAsciiSanitizeShort(bytes, 1)) - * HASH_CODE_C1 + hashCodeAsciiSanitizeInt(bytes, 3); - case 6: - return (hash * HASH_CODE_C1 + hashCodeAsciiSanitizeShort(bytes, 0)) - * HASH_CODE_C2 + hashCodeAsciiSanitizeInt(bytes, 2); - case 5: - return (hash * HASH_CODE_C1 + hashCodeAsciiSanitizeByte(bytes.charAt(0))) - * HASH_CODE_C2 + hashCodeAsciiSanitizeInt(bytes, 1); - case 4: - return hash * HASH_CODE_C1 + hashCodeAsciiSanitizeInt(bytes, 0); - case 3: - return (hash * HASH_CODE_C1 + hashCodeAsciiSanitizeByte(bytes.charAt(0))) - * HASH_CODE_C2 + hashCodeAsciiSanitizeShort(bytes, 1); - case 2: - return hash * HASH_CODE_C1 + hashCodeAsciiSanitizeShort(bytes, 0); - case 1: - return hash * HASH_CODE_C1 + hashCodeAsciiSanitizeByte(bytes.charAt(0)); - default: - return hash; + if (remainingBytes == 0) { + return hash; + } + int offset = 0; + if (remainingBytes != 2 & remainingBytes != 4 & remainingBytes != 6) { // 1, 3, 5, 7 + hash = hash * HASH_CODE_C1 + hashCodeAsciiSanitizeByte(bytes.charAt(0)); + offset = 1; + } + if (remainingBytes != 1 & remainingBytes != 4 & remainingBytes != 5) { // 2, 3, 6, 7 + hash = hash * (offset == 0 ? HASH_CODE_C1 : HASH_CODE_C2) + + hashCodeAsciiSanitize(hashCodeAsciiSanitizeShort(bytes, offset)); + offset += 2; + } + if (remainingBytes >= 4) { // 4, 5, 6, 7 + return hash * ((offset == 0 | offset == 3) ? HASH_CODE_C1 : HASH_CODE_C2) + + hashCodeAsciiSanitizeInt(bytes, offset); } + return hash; } private static final class Mpsc { @@ -1012,6 +987,19 @@ private static Throwable unsafeUnavailabilityCause0() { } } + /** + * Returns {@code true} if the running JVM is either IBM J9 or + * Eclipse OpenJ9, {@code false} otherwise. + */ + public static boolean isJ9Jvm() { + return IS_J9_JVM; + } + + private static boolean isJ9Jvm0() { + String vmName = SystemPropertyUtil.get("java.vm.name", "").toLowerCase(); + return vmName.startsWith("ibm j9") || vmName.startsWith("eclipse openj9"); + } + private static long maxDirectMemory0() { long maxDirectMemory = 0; @@ -1019,10 +1007,14 @@ private static long maxDirectMemory0() { try { systemClassLoader = getSystemClassLoader(); - // On z/OS we should not use VM.maxDirectMemory() as it not reflects the correct value. + // When using IBM J9 / Eclipse OpenJ9 we should not use VM.maxDirectMemory() as it not reflects the + // correct value. // See: // - https://github.com/netty/netty/issues/7654 - if (!SystemPropertyUtil.get("os.name", "").toLowerCase().contains("z/os")) { + String vmName = SystemPropertyUtil.get("java.vm.name", "").toLowerCase(); + if (!vmName.startsWith("ibm j9") && + // https://github.com/eclipse/openj9/blob/openj9-0.8.0/runtime/include/vendor_version.h#L53 + !vmName.startsWith("eclipse openj9")) { // Try to get from sun.misc.VM.maxDirectMemory() which should be most accurate. Class vmClass = Class.forName("sun.misc.VM", true, systemClassLoader); Method m = vmClass.getDeclaredMethod("maxDirectMemory"); diff --git a/common/src/main/java/io/netty/util/internal/PlatformDependent0.java b/common/src/main/java/io/netty/util/internal/PlatformDependent0.java index 275227b5f05e..edf965893cfa 100644 --- a/common/src/main/java/io/netty/util/internal/PlatformDependent0.java +++ b/common/src/main/java/io/netty/util/internal/PlatformDependent0.java @@ -262,13 +262,32 @@ public Object run() { DIRECT_BUFFER_CONSTRUCTOR = directBufferConstructor; ADDRESS_FIELD_OFFSET = objectFieldOffset(addressField); BYTE_ARRAY_BASE_OFFSET = UNSAFE.arrayBaseOffset(byte[].class); - boolean unaligned; + final boolean unaligned; Object maybeUnaligned = AccessController.doPrivileged(new PrivilegedAction() { @Override public Object run() { try { Class bitsClass = Class.forName("java.nio.Bits", false, getSystemClassLoader()); + int version = javaVersion(); + if (version >= 9) { + // Java9/10 use all lowercase and later versions all uppercase. + String fieldName = version >= 11 ? "UNALIGNED" : "unaligned"; + // On Java9 and later we try to directly access the field as we can do this without + // adjust the accessible levels. + try { + Field unalignedField = bitsClass.getDeclaredField(fieldName); + if (unalignedField.getType() == boolean.class) { + long offset = UNSAFE.staticFieldOffset(unalignedField); + Object object = UNSAFE.staticFieldBase(unalignedField); + return UNSAFE.getBoolean(object, offset); + } + // There is something unexpected stored in the field, + // let us fall-back and try to use a reflective method call as last resort. + } catch (NoSuchFieldException ignore) { + // We did not find the field we expected, move on. + } + } Method unalignedMethod = bitsClass.getDeclaredMethod("unaligned"); Throwable cause = ReflectionUtil.trySetAccessible(unalignedMethod, true); if (cause != null) { @@ -368,7 +387,7 @@ public Object run() { } static boolean isExplicitNoUnsafe() { - return EXPLICIT_NO_UNSAFE_CAUSE == null; + return EXPLICIT_NO_UNSAFE_CAUSE != null; } private static Throwable explicitNoUnsafeCause0() { @@ -427,7 +446,10 @@ static ByteBuffer reallocateDirectNoCleaner(ByteBuffer buffer, int capacity) { } static ByteBuffer allocateDirectNoCleaner(int capacity) { - return newDirectBuffer(UNSAFE.allocateMemory(capacity), capacity); + // Calling malloc with capacity of 0 may return a null ptr or a memory address that can be used. + // Just use 1 to make it safe to use in all cases: + // See: http://pubs.opengroup.org/onlinepubs/009695399/functions/malloc.html + return newDirectBuffer(UNSAFE.allocateMemory(Math.max(1, capacity)), capacity); } static boolean hasAllocateArrayMethod() { @@ -546,8 +568,21 @@ static void putLong(byte[] data, int index, long value) { UNSAFE.putLong(data, BYTE_ARRAY_BASE_OFFSET + index, value); } + static void putObject(Object o, long offset, Object x) { + UNSAFE.putObject(o, offset, x); + } + static void copyMemory(long srcAddr, long dstAddr, long length) { - //UNSAFE.copyMemory(srcAddr, dstAddr, length); + // Manual safe-point polling is only needed prior Java9: + // See https://bugs.openjdk.java.net/browse/JDK-8149596 + if (javaVersion() <= 8) { + copyMemoryWithSafePointPolling(srcAddr, dstAddr, length); + } else { + UNSAFE.copyMemory(srcAddr, dstAddr, length); + } + } + + private static void copyMemoryWithSafePointPolling(long srcAddr, long dstAddr, long length) { while (length > 0) { long size = Math.min(length, UNSAFE_COPY_THRESHOLD); UNSAFE.copyMemory(srcAddr, dstAddr, size); @@ -558,7 +593,17 @@ static void copyMemory(long srcAddr, long dstAddr, long length) { } static void copyMemory(Object src, long srcOffset, Object dst, long dstOffset, long length) { - //UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, length); + // Manual safe-point polling is only needed prior Java9: + // See https://bugs.openjdk.java.net/browse/JDK-8149596 + if (javaVersion() <= 8) { + copyMemoryWithSafePointPolling(src, srcOffset, dst, dstOffset, length); + } else { + UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, length); + } + } + + private static void copyMemoryWithSafePointPolling( + Object src, long srcOffset, Object dst, long dstOffset, long length) { while (length > 0) { long size = Math.min(length, UNSAFE_COPY_THRESHOLD); UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); @@ -577,72 +622,57 @@ static void setMemory(Object o, long offset, long bytes, byte value) { } static boolean equals(byte[] bytes1, int startPos1, byte[] bytes2, int startPos2, int length) { - if (length <= 0) { - return true; - } - final long baseOffset1 = BYTE_ARRAY_BASE_OFFSET + startPos1; - final long baseOffset2 = BYTE_ARRAY_BASE_OFFSET + startPos2; int remainingBytes = length & 7; - final long end = baseOffset1 + remainingBytes; - for (long i = baseOffset1 - 8 + length, j = baseOffset2 - 8 + length; i >= end; i -= 8, j -= 8) { - if (UNSAFE.getLong(bytes1, i) != UNSAFE.getLong(bytes2, j)) { - return false; + final long baseOffset1 = BYTE_ARRAY_BASE_OFFSET + startPos1; + final long diff = startPos2 - startPos1; + if (length >= 8) { + final long end = baseOffset1 + remainingBytes; + for (long i = baseOffset1 - 8 + length; i >= end; i -= 8) { + if (UNSAFE.getLong(bytes1, i) != UNSAFE.getLong(bytes2, i + diff)) { + return false; + } } } - if (remainingBytes >= 4) { remainingBytes -= 4; - if (UNSAFE.getInt(bytes1, baseOffset1 + remainingBytes) != - UNSAFE.getInt(bytes2, baseOffset2 + remainingBytes)) { + long pos = baseOffset1 + remainingBytes; + if (UNSAFE.getInt(bytes1, pos) != UNSAFE.getInt(bytes2, pos + diff)) { return false; } } + final long baseOffset2 = baseOffset1 + diff; if (remainingBytes >= 2) { return UNSAFE.getChar(bytes1, baseOffset1) == UNSAFE.getChar(bytes2, baseOffset2) && - (remainingBytes == 2 || bytes1[startPos1 + 2] == bytes2[startPos2 + 2]); + (remainingBytes == 2 || + UNSAFE.getByte(bytes1, baseOffset1 + 2) == UNSAFE.getByte(bytes2, baseOffset2 + 2)); } - return bytes1[startPos1] == bytes2[startPos2]; + return remainingBytes == 0 || + UNSAFE.getByte(bytes1, baseOffset1) == UNSAFE.getByte(bytes2, baseOffset2); } static int equalsConstantTime(byte[] bytes1, int startPos1, byte[] bytes2, int startPos2, int length) { long result = 0; + long remainingBytes = length & 7; final long baseOffset1 = BYTE_ARRAY_BASE_OFFSET + startPos1; - final long baseOffset2 = BYTE_ARRAY_BASE_OFFSET + startPos2; - final int remainingBytes = length & 7; final long end = baseOffset1 + remainingBytes; - for (long i = baseOffset1 - 8 + length, j = baseOffset2 - 8 + length; i >= end; i -= 8, j -= 8) { - result |= UNSAFE.getLong(bytes1, i) ^ UNSAFE.getLong(bytes2, j); - } - switch (remainingBytes) { - case 7: - return ConstantTimeUtils.equalsConstantTime(result | - (UNSAFE.getInt(bytes1, baseOffset1 + 3) ^ UNSAFE.getInt(bytes2, baseOffset2 + 3)) | - (UNSAFE.getChar(bytes1, baseOffset1 + 1) ^ UNSAFE.getChar(bytes2, baseOffset2 + 1)) | - (UNSAFE.getByte(bytes1, baseOffset1) ^ UNSAFE.getByte(bytes2, baseOffset2)), 0); - case 6: - return ConstantTimeUtils.equalsConstantTime(result | - (UNSAFE.getInt(bytes1, baseOffset1 + 2) ^ UNSAFE.getInt(bytes2, baseOffset2 + 2)) | - (UNSAFE.getChar(bytes1, baseOffset1) ^ UNSAFE.getChar(bytes2, baseOffset2)), 0); - case 5: - return ConstantTimeUtils.equalsConstantTime(result | - (UNSAFE.getInt(bytes1, baseOffset1 + 1) ^ UNSAFE.getInt(bytes2, baseOffset2 + 1)) | - (UNSAFE.getByte(bytes1, baseOffset1) ^ UNSAFE.getByte(bytes2, baseOffset2)), 0); - case 4: - return ConstantTimeUtils.equalsConstantTime(result | - (UNSAFE.getInt(bytes1, baseOffset1) ^ UNSAFE.getInt(bytes2, baseOffset2)), 0); - case 3: - return ConstantTimeUtils.equalsConstantTime(result | - (UNSAFE.getChar(bytes1, baseOffset1 + 1) ^ UNSAFE.getChar(bytes2, baseOffset2 + 1)) | - (UNSAFE.getByte(bytes1, baseOffset1) ^ UNSAFE.getByte(bytes2, baseOffset2)), 0); - case 2: - return ConstantTimeUtils.equalsConstantTime(result | - (UNSAFE.getChar(bytes1, baseOffset1) ^ UNSAFE.getChar(bytes2, baseOffset2)), 0); - case 1: - return ConstantTimeUtils.equalsConstantTime(result | - (UNSAFE.getByte(bytes1, baseOffset1) ^ UNSAFE.getByte(bytes2, baseOffset2)), 0); - default: - return ConstantTimeUtils.equalsConstantTime(result, 0); + final long diff = startPos2 - startPos1; + for (long i = baseOffset1 - 8 + length; i >= end; i -= 8) { + result |= UNSAFE.getLong(bytes1, i) ^ UNSAFE.getLong(bytes2, i + diff); } + if (remainingBytes >= 4) { + result |= UNSAFE.getInt(bytes1, baseOffset1) ^ UNSAFE.getInt(bytes2, baseOffset1 + diff); + remainingBytes -= 4; + } + if (remainingBytes >= 2) { + long pos = end - remainingBytes; + result |= UNSAFE.getChar(bytes1, pos) ^ UNSAFE.getChar(bytes2, pos + diff); + remainingBytes -= 2; + } + if (remainingBytes == 1) { + long pos = end - 1; + result |= UNSAFE.getByte(bytes1, pos) ^ UNSAFE.getByte(bytes2, pos + diff); + } + return ConstantTimeUtils.equalsConstantTime(result, 0); } static boolean isZero(byte[] bytes, int startPos, int length) { @@ -673,35 +703,30 @@ static boolean isZero(byte[] bytes, int startPos, int length) { static int hashCodeAscii(byte[] bytes, int startPos, int length) { int hash = HASH_CODE_ASCII_SEED; - final long baseOffset = BYTE_ARRAY_BASE_OFFSET + startPos; + long baseOffset = BYTE_ARRAY_BASE_OFFSET + startPos; final int remainingBytes = length & 7; final long end = baseOffset + remainingBytes; for (long i = baseOffset - 8 + length; i >= end; i -= 8) { hash = hashCodeAsciiCompute(UNSAFE.getLong(bytes, i), hash); } - switch(remainingBytes) { - case 7: - return ((hash * HASH_CODE_C1 + hashCodeAsciiSanitize(UNSAFE.getByte(bytes, baseOffset))) - * HASH_CODE_C2 + hashCodeAsciiSanitize(UNSAFE.getShort(bytes, baseOffset + 1))) - * HASH_CODE_C1 + hashCodeAsciiSanitize(UNSAFE.getInt(bytes, baseOffset + 3)); - case 6: - return (hash * HASH_CODE_C1 + hashCodeAsciiSanitize(UNSAFE.getShort(bytes, baseOffset))) - * HASH_CODE_C2 + hashCodeAsciiSanitize(UNSAFE.getInt(bytes, baseOffset + 2)); - case 5: - return (hash * HASH_CODE_C1 + hashCodeAsciiSanitize(UNSAFE.getByte(bytes, baseOffset))) - * HASH_CODE_C2 + hashCodeAsciiSanitize(UNSAFE.getInt(bytes, baseOffset + 1)); - case 4: - return hash * HASH_CODE_C1 + hashCodeAsciiSanitize(UNSAFE.getInt(bytes, baseOffset)); - case 3: - return (hash * HASH_CODE_C1 + hashCodeAsciiSanitize(UNSAFE.getByte(bytes, baseOffset))) - * HASH_CODE_C2 + hashCodeAsciiSanitize(UNSAFE.getShort(bytes, baseOffset + 1)); - case 2: - return hash * HASH_CODE_C1 + hashCodeAsciiSanitize(UNSAFE.getShort(bytes, baseOffset)); - case 1: - return hash * HASH_CODE_C1 + hashCodeAsciiSanitize(UNSAFE.getByte(bytes, baseOffset)); - default: + if (remainingBytes == 0) { return hash; } + int hcConst = HASH_CODE_C1; + if (remainingBytes != 2 & remainingBytes != 4 & remainingBytes != 6) { // 1, 3, 5, 7 + hash = hash * HASH_CODE_C1 + hashCodeAsciiSanitize(UNSAFE.getByte(bytes, baseOffset)); + hcConst = HASH_CODE_C2; + baseOffset++; + } + if (remainingBytes != 1 & remainingBytes != 4 & remainingBytes != 5) { // 2, 3, 6, 7 + hash = hash * hcConst + hashCodeAsciiSanitize(UNSAFE.getShort(bytes, baseOffset)); + hcConst = hcConst == HASH_CODE_C1 ? HASH_CODE_C2 : HASH_CODE_C1; + baseOffset += 2; + } + if (remainingBytes >= 4) { // 4, 5, 6, 7 + return hash * hcConst + hashCodeAsciiSanitize(UNSAFE.getInt(bytes, baseOffset)); + } + return hash; } static int hashCodeAsciiCompute(long value, int hash) { diff --git a/common/src/main/java/io/netty/util/internal/ResourcesUtil.java b/common/src/main/java/io/netty/util/internal/ResourcesUtil.java new file mode 100644 index 000000000000..7b1c8ea695d4 --- /dev/null +++ b/common/src/main/java/io/netty/util/internal/ResourcesUtil.java @@ -0,0 +1,43 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.util.internal; + +import java.io.File; +import java.io.UnsupportedEncodingException; +import java.net.URLDecoder; + +/** + * A utility class that provides various common operations and constants + * related to loading resources + */ +public final class ResourcesUtil { + + /** + * Returns a {@link File} named {@code fileName} associated with {@link Class} {@code resourceClass} . + * + * @param resourceClass The associated class + * @param fileName The file name + * @return The file named {@code fileName} associated with {@link Class} {@code resourceClass} . + */ + public static File getFile(Class resourceClass, String fileName) { + try { + return new File(URLDecoder.decode(resourceClass.getResource(fileName).getFile(), "UTF-8")); + } catch (UnsupportedEncodingException e) { + return new File(resourceClass.getResource(fileName).getFile()); + } + } + + private ResourcesUtil() { } +} diff --git a/common/src/main/java/io/netty/util/internal/logging/AbstractInternalLogger.java b/common/src/main/java/io/netty/util/internal/logging/AbstractInternalLogger.java index 78007c1e4a9a..6486efaf088f 100644 --- a/common/src/main/java/io/netty/util/internal/logging/AbstractInternalLogger.java +++ b/common/src/main/java/io/netty/util/internal/logging/AbstractInternalLogger.java @@ -29,7 +29,7 @@ public abstract class AbstractInternalLogger implements InternalLogger, Serializ private static final long serialVersionUID = -6382972526573193470L; - private static final String EXCEPTION_MESSAGE = "Unexpected exception:"; + static final String EXCEPTION_MESSAGE = "Unexpected exception:"; private final String name; diff --git a/common/src/main/java/io/netty/util/internal/logging/InternalLoggerFactory.java b/common/src/main/java/io/netty/util/internal/logging/InternalLoggerFactory.java index 9f85e3646b40..12c1b5a4477d 100644 --- a/common/src/main/java/io/netty/util/internal/logging/InternalLoggerFactory.java +++ b/common/src/main/java/io/netty/util/internal/logging/InternalLoggerFactory.java @@ -41,13 +41,18 @@ private static InternalLoggerFactory newDefaultFactory(String name) { try { f = new Slf4JLoggerFactory(true); f.newInstance(name).debug("Using SLF4J as the default logging framework"); - } catch (Throwable t1) { + } catch (Throwable ignore1) { try { f = Log4JLoggerFactory.INSTANCE; f.newInstance(name).debug("Using Log4J as the default logging framework"); - } catch (Throwable t2) { - f = JdkLoggerFactory.INSTANCE; - f.newInstance(name).debug("Using java.util.logging as the default logging framework"); + } catch (Throwable ignore2) { + try { + f = Log4J2LoggerFactory.INSTANCE; + f.newInstance(name).debug("Using Log4J2 as the default logging framework"); + } catch (Throwable ignore3) { + f = JdkLoggerFactory.INSTANCE; + f.newInstance(name).debug("Using java.util.logging as the default logging framework"); + } } } return f; diff --git a/common/src/main/java/io/netty/util/internal/logging/LocationAwareSlf4JLogger.java b/common/src/main/java/io/netty/util/internal/logging/LocationAwareSlf4JLogger.java new file mode 100644 index 000000000000..7f3628368965 --- /dev/null +++ b/common/src/main/java/io/netty/util/internal/logging/LocationAwareSlf4JLogger.java @@ -0,0 +1,252 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.internal.logging; + +import org.slf4j.spi.LocationAwareLogger; + +import static org.slf4j.spi.LocationAwareLogger.*; + +/** + * SLF4J logger which is location aware and so will log the correct origin of the + * logging event by filter out the wrapper itself. + */ +final class LocationAwareSlf4JLogger extends AbstractInternalLogger { + + // IMPORTANT: All our log methods first check if the log level is enabled before call the wrapped + // LocationAwareLogger.log(...) method. This is done to reduce GC creation that is caused by varargs. + + static final String FQCN = LocationAwareSlf4JLogger.class.getName(); + private static final long serialVersionUID = -8292030083201538180L; + + private final transient LocationAwareLogger logger; + + LocationAwareSlf4JLogger(LocationAwareLogger logger) { + super(logger.getName()); + this.logger = logger; + } + + private void log(final int level, final String message) { + logger.log(null, FQCN, level, message, null, null); + } + + private void log(final int level, final String message, Throwable cause) { + logger.log(null, FQCN, level, message, null, cause); + } + + private void log(final int level, final org.slf4j.helpers.FormattingTuple tuple) { + logger.log(null, FQCN, level, tuple.getMessage(), tuple.getArgArray(), tuple.getThrowable()); + } + + @Override + public boolean isTraceEnabled() { + return logger.isTraceEnabled(); + } + + @Override + public void trace(String msg) { + if (isTraceEnabled()) { + log(TRACE_INT, msg); + } + } + + @Override + public void trace(String format, Object arg) { + if (isTraceEnabled()) { + log(TRACE_INT, org.slf4j.helpers.MessageFormatter.format(format, arg)); + } + } + + @Override + public void trace(String format, Object argA, Object argB) { + if (isTraceEnabled()) { + log(TRACE_INT, org.slf4j.helpers.MessageFormatter.format(format, argA, argB)); + } + } + + @Override + public void trace(String format, Object... argArray) { + if (isTraceEnabled()) { + log(TRACE_INT, org.slf4j.helpers.MessageFormatter.arrayFormat(format, argArray)); + } + } + + @Override + public void trace(String msg, Throwable t) { + if (isTraceEnabled()) { + log(TRACE_INT, msg, t); + } + } + + @Override + public boolean isDebugEnabled() { + return logger.isDebugEnabled(); + } + + @Override + public void debug(String msg) { + if (isDebugEnabled()) { + log(DEBUG_INT, msg); + } + } + + @Override + public void debug(String format, Object arg) { + if (isDebugEnabled()) { + log(DEBUG_INT, org.slf4j.helpers.MessageFormatter.format(format, arg)); + } + } + + @Override + public void debug(String format, Object argA, Object argB) { + if (isDebugEnabled()) { + log(DEBUG_INT, org.slf4j.helpers.MessageFormatter.format(format, argA, argB)); + } + } + + @Override + public void debug(String format, Object... argArray) { + if (isDebugEnabled()) { + log(DEBUG_INT, org.slf4j.helpers.MessageFormatter.arrayFormat(format, argArray)); + } + } + + @Override + public void debug(String msg, Throwable t) { + if (isDebugEnabled()) { + log(DEBUG_INT, msg, t); + } + } + + @Override + public boolean isInfoEnabled() { + return logger.isInfoEnabled(); + } + + @Override + public void info(String msg) { + if (isInfoEnabled()) { + log(INFO_INT, msg); + } + } + + @Override + public void info(String format, Object arg) { + if (isInfoEnabled()) { + log(INFO_INT, org.slf4j.helpers.MessageFormatter.format(format, arg)); + } + } + + @Override + public void info(String format, Object argA, Object argB) { + if (isInfoEnabled()) { + log(INFO_INT, org.slf4j.helpers.MessageFormatter.format(format, argA, argB)); + } + } + + @Override + public void info(String format, Object... argArray) { + if (isInfoEnabled()) { + log(INFO_INT, org.slf4j.helpers.MessageFormatter.arrayFormat(format, argArray)); + } + } + + @Override + public void info(String msg, Throwable t) { + if (isInfoEnabled()) { + log(INFO_INT, msg, t); + } + } + + @Override + public boolean isWarnEnabled() { + return logger.isWarnEnabled(); + } + + @Override + public void warn(String msg) { + if (isWarnEnabled()) { + log(WARN_INT, msg); + } + } + + @Override + public void warn(String format, Object arg) { + if (isWarnEnabled()) { + log(WARN_INT, org.slf4j.helpers.MessageFormatter.format(format, arg)); + } + } + + @Override + public void warn(String format, Object... argArray) { + if (isWarnEnabled()) { + log(WARN_INT, org.slf4j.helpers.MessageFormatter.arrayFormat(format, argArray)); + } + } + + @Override + public void warn(String format, Object argA, Object argB) { + if (isWarnEnabled()) { + log(WARN_INT, org.slf4j.helpers.MessageFormatter.format(format, argA, argB)); + } + } + + @Override + public void warn(String msg, Throwable t) { + if (isWarnEnabled()) { + log(WARN_INT, msg, t); + } + } + + @Override + public boolean isErrorEnabled() { + return logger.isErrorEnabled(); + } + + @Override + public void error(String msg) { + if (isErrorEnabled()) { + log(ERROR_INT, msg); + } + } + + @Override + public void error(String format, Object arg) { + if (isErrorEnabled()) { + log(ERROR_INT, org.slf4j.helpers.MessageFormatter.format(format, arg)); + } + } + + @Override + public void error(String format, Object argA, Object argB) { + if (isErrorEnabled()) { + log(ERROR_INT, org.slf4j.helpers.MessageFormatter.format(format, argA, argB)); + } + } + + @Override + public void error(String format, Object... argArray) { + if (isErrorEnabled()) { + log(ERROR_INT, org.slf4j.helpers.MessageFormatter.arrayFormat(format, argArray)); + } + } + + @Override + public void error(String msg, Throwable t) { + if (isErrorEnabled()) { + log(ERROR_INT, msg, t); + } + } +} diff --git a/common/src/main/java/io/netty/util/internal/logging/Log4J2Logger.java b/common/src/main/java/io/netty/util/internal/logging/Log4J2Logger.java index 8ed2fdf48df6..5c3593f203ae 100644 --- a/common/src/main/java/io/netty/util/internal/logging/Log4J2Logger.java +++ b/common/src/main/java/io/netty/util/internal/logging/Log4J2Logger.java @@ -21,15 +21,42 @@ import org.apache.logging.log4j.spi.ExtendedLogger; import org.apache.logging.log4j.spi.ExtendedLoggerWrapper; +import java.security.AccessController; +import java.security.PrivilegedAction; + +import static io.netty.util.internal.logging.AbstractInternalLogger.EXCEPTION_MESSAGE; + class Log4J2Logger extends ExtendedLoggerWrapper implements InternalLogger { private static final long serialVersionUID = 5485418394879791397L; - - /** {@linkplain AbstractInternalLogger#EXCEPTION_MESSAGE} */ - private static final String EXCEPTION_MESSAGE = "Unexpected exception:"; + private static final boolean VARARGS_ONLY; + + static { + // Older Log4J2 versions have only log methods that takes the format + varargs. So we should not use + // Log4J2 if the version is too old. + // See https://github.com/netty/netty/issues/8217 + VARARGS_ONLY = AccessController.doPrivileged(new PrivilegedAction() { + @Override + public Boolean run() { + try { + Logger.class.getMethod("debug", String.class, Object.class); + return false; + } catch (NoSuchMethodException ignore) { + // Log4J2 version too old. + return true; + } catch (SecurityException ignore) { + // We could not detect the version so we will use Log4J2 if its on the classpath. + return false; + } + } + }); + } Log4J2Logger(Logger logger) { super((ExtendedLogger) logger, logger.getName(), logger.getMessageFactory()); + if (VARARGS_ONLY) { + throw new UnsupportedOperationException("Log4J2 version mismatch"); + } } @Override @@ -97,7 +124,7 @@ public void log(InternalLogLevel level, Throwable t) { log(toLevel(level), EXCEPTION_MESSAGE, t); } - protected Level toLevel(InternalLogLevel level) { + private static Level toLevel(InternalLogLevel level) { switch (level) { case INFO: return Level.INFO; diff --git a/common/src/main/java/io/netty/util/internal/logging/Slf4JLogger.java b/common/src/main/java/io/netty/util/internal/logging/Slf4JLogger.java index e78727ca0f58..d628456a79d7 100644 --- a/common/src/main/java/io/netty/util/internal/logging/Slf4JLogger.java +++ b/common/src/main/java/io/netty/util/internal/logging/Slf4JLogger.java @@ -20,7 +20,7 @@ /** * SLF4J logger. */ -class Slf4JLogger extends AbstractInternalLogger { +final class Slf4JLogger extends AbstractInternalLogger { private static final long serialVersionUID = 108038972685130825L; diff --git a/common/src/main/java/io/netty/util/internal/logging/Slf4JLoggerFactory.java b/common/src/main/java/io/netty/util/internal/logging/Slf4JLoggerFactory.java index 42df0a1b76ea..0d97a5ea9e2f 100644 --- a/common/src/main/java/io/netty/util/internal/logging/Slf4JLoggerFactory.java +++ b/common/src/main/java/io/netty/util/internal/logging/Slf4JLoggerFactory.java @@ -16,8 +16,10 @@ package io.netty.util.internal.logging; +import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.slf4j.helpers.NOPLoggerFactory; +import org.slf4j.spi.LocationAwareLogger; /** * Logger factory which creates a SLF4J @@ -44,6 +46,12 @@ public Slf4JLoggerFactory() { @Override public InternalLogger newInstance(String name) { - return new Slf4JLogger(LoggerFactory.getLogger(name)); + return wrapLogger(LoggerFactory.getLogger(name)); + } + + // package-private for testing. + static InternalLogger wrapLogger(Logger logger) { + return logger instanceof LocationAwareLogger ? + new LocationAwareSlf4JLogger((LocationAwareLogger) logger) : new Slf4JLogger(logger); } } diff --git a/common/src/main/templates/io/netty/util/collection/KObjectHashMap.template b/common/src/main/templates/io/netty/util/collection/KObjectHashMap.template index d8aeb1b25c96..7db1ceb6c30f 100644 --- a/common/src/main/templates/io/netty/util/collection/KObjectHashMap.template +++ b/common/src/main/templates/io/netty/util/collection/KObjectHashMap.template @@ -236,7 +236,7 @@ public class @K@ObjectHashMap implements @K@ObjectMap { @Override public void remove() { - throw new UnsupportedOperationException(); + iter.remove(); } }; } diff --git a/common/src/test/java/io/netty/util/AbstractReferenceCountedTest.java b/common/src/test/java/io/netty/util/AbstractReferenceCountedTest.java index 5af113490f8e..26762355fe29 100644 --- a/common/src/test/java/io/netty/util/AbstractReferenceCountedTest.java +++ b/common/src/test/java/io/netty/util/AbstractReferenceCountedTest.java @@ -15,8 +15,17 @@ */ package io.netty.util; +import io.netty.util.internal.ThreadLocalRandom; import org.junit.Test; +import java.util.ArrayDeque; +import java.util.Queue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicInteger; + import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -74,6 +83,107 @@ public void testRetainResurrect2() { referenceCounted.retain(2); } + @Test(timeout = 30000) + public void testRetainFromMultipleThreadsThrowsReferenceCountException() throws Exception { + int threads = 4; + Queue> futures = new ArrayDeque>(threads); + ExecutorService service = Executors.newFixedThreadPool(threads); + final AtomicInteger refCountExceptions = new AtomicInteger(); + + try { + for (int i = 0; i < 10000; i++) { + final AbstractReferenceCounted referenceCounted = newReferenceCounted(); + final CountDownLatch retainLatch = new CountDownLatch(1); + assertTrue(referenceCounted.release()); + + for (int a = 0; a < threads; a++) { + final int retainCnt = ThreadLocalRandom.current().nextInt(1, Integer.MAX_VALUE); + futures.add(service.submit(new Runnable() { + @Override + public void run() { + try { + retainLatch.await(); + try { + referenceCounted.retain(retainCnt); + } catch (IllegalReferenceCountException e) { + refCountExceptions.incrementAndGet(); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + })); + } + retainLatch.countDown(); + + for (;;) { + Future f = futures.poll(); + if (f == null) { + break; + } + f.get(); + } + assertEquals(4, refCountExceptions.get()); + refCountExceptions.set(0); + } + } finally { + service.shutdown(); + } + } + + @Test(timeout = 30000) + public void testReleaseFromMultipleThreadsThrowsReferenceCountException() throws Exception { + int threads = 4; + Queue> futures = new ArrayDeque>(threads); + ExecutorService service = Executors.newFixedThreadPool(threads); + final AtomicInteger refCountExceptions = new AtomicInteger(); + + try { + for (int i = 0; i < 10000; i++) { + final AbstractReferenceCounted referenceCounted = newReferenceCounted(); + final CountDownLatch releaseLatch = new CountDownLatch(1); + final AtomicInteger releasedCount = new AtomicInteger(); + + for (int a = 0; a < threads; a++) { + final AtomicInteger releaseCnt = new AtomicInteger(0); + + futures.add(service.submit(new Runnable() { + @Override + public void run() { + try { + releaseLatch.await(); + try { + if (referenceCounted.release(releaseCnt.incrementAndGet())) { + releasedCount.incrementAndGet(); + } + } catch (IllegalReferenceCountException e) { + refCountExceptions.incrementAndGet(); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + })); + } + releaseLatch.countDown(); + + for (;;) { + Future f = futures.poll(); + if (f == null) { + break; + } + f.get(); + } + assertEquals(3, refCountExceptions.get()); + assertEquals(1, releasedCount.get()); + + refCountExceptions.set(0); + } + } finally { + service.shutdown(); + } + } + private static AbstractReferenceCounted newReferenceCounted() { return new AbstractReferenceCounted() { @Override diff --git a/common/src/test/java/io/netty/util/AsciiStringCharacterTest.java b/common/src/test/java/io/netty/util/AsciiStringCharacterTest.java index fe2ec301b0f8..c2a835df660a 100644 --- a/common/src/test/java/io/netty/util/AsciiStringCharacterTest.java +++ b/common/src/test/java/io/netty/util/AsciiStringCharacterTest.java @@ -240,6 +240,9 @@ public void testEqualsIgnoreCase() { assertThat(AsciiString.contentEqualsIgnoreCase(null, "foo"), is(false)); assertThat(AsciiString.contentEqualsIgnoreCase("bar", null), is(false)); assertThat(AsciiString.contentEqualsIgnoreCase("FoO", "fOo"), is(true)); + assertThat(AsciiString.contentEqualsIgnoreCase("FoO", "bar"), is(false)); + assertThat(AsciiString.contentEqualsIgnoreCase("Foo", "foobar"), is(false)); + assertThat(AsciiString.contentEqualsIgnoreCase("foobar", "Foo"), is(false)); // Test variations (Ascii + String, Ascii + Ascii, String + Ascii) assertThat(AsciiString.contentEqualsIgnoreCase(new AsciiString("FoO"), "fOo"), is(true)); @@ -398,4 +401,18 @@ public void testSubStringHashCode() { //two "123"s assertEquals(AsciiString.hashCode("123"), AsciiString.hashCode("a123".substring(1))); } + + @Test + public void testIndexOf() { + AsciiString foo = AsciiString.of("This is a test"); + int i1 = foo.indexOf(' ', 0); + assertEquals(4, i1); + int i2 = foo.indexOf(' ', i1 + 1); + assertEquals(7, i2); + int i3 = foo.indexOf(' ', i2 + 1); + assertEquals(9, i3); + assertTrue(i3 + 1 < foo.length()); + int i4 = foo.indexOf(' ', i3 + 1); + assertEquals(i4, -1); + } } diff --git a/common/src/test/java/io/netty/util/ConstantPoolTest.java b/common/src/test/java/io/netty/util/ConstantPoolTest.java index d8b1bd352b79..a3fc4e7c1f92 100644 --- a/common/src/test/java/io/netty/util/ConstantPoolTest.java +++ b/common/src/test/java/io/netty/util/ConstantPoolTest.java @@ -75,7 +75,7 @@ public void testCompare() { set.add(d); set.add(a); - TestConstant[] array = set.toArray(new TestConstant[5]); + TestConstant[] array = set.toArray(new TestConstant[0]); assertThat(array.length, is(5)); // Sort by name diff --git a/common/src/test/java/io/netty/util/RecyclerTest.java b/common/src/test/java/io/netty/util/RecyclerTest.java index 6eeace5f0017..e4bcbf1da46f 100644 --- a/common/src/test/java/io/netty/util/RecyclerTest.java +++ b/common/src/test/java/io/netty/util/RecyclerTest.java @@ -81,6 +81,38 @@ public void testMultipleRecycle() { object.recycle(); } + @Test(expected = IllegalStateException.class) + public void testMultipleRecycleAtDifferentThread() throws InterruptedException { + Recycler recycler = newRecycler(1024); + final HandledObject object = recycler.get(); + final AtomicReference exceptionStore = new AtomicReference(); + final Thread thread1 = new Thread(new Runnable() { + @Override + public void run() { + object.recycle(); + } + }); + thread1.start(); + thread1.join(); + + final Thread thread2 = new Thread(new Runnable() { + @Override + public void run() { + try { + object.recycle(); + } catch (IllegalStateException e) { + exceptionStore.set(e); + } + } + }); + thread2.start(); + thread2.join(); + IllegalStateException exception = exceptionStore.get(); + if (exception != null) { + throw exception; + } + } + @Test public void testRecycle() { Recycler recycler = newRecycler(1024); diff --git a/common/src/test/java/io/netty/util/concurrent/DefaultPromiseTest.java b/common/src/test/java/io/netty/util/concurrent/DefaultPromiseTest.java index 9c29f94a555b..b5e5907293a1 100644 --- a/common/src/test/java/io/netty/util/concurrent/DefaultPromiseTest.java +++ b/common/src/test/java/io/netty/util/concurrent/DefaultPromiseTest.java @@ -21,6 +21,7 @@ import io.netty.util.internal.logging.InternalLoggerFactory; import org.junit.BeforeClass; import org.junit.Test; +import org.mockito.Mockito; import java.util.HashMap; import java.util.Map; @@ -37,10 +38,7 @@ import static java.lang.Math.max; import static org.hamcrest.Matchers.lessThan; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertSame; -import static org.junit.Assert.assertThat; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; @SuppressWarnings("unchecked") public class DefaultPromiseTest { @@ -66,6 +64,41 @@ private static int stackOverflowTestDepth() { return max(stackOverflowDepth << 1, stackOverflowDepth); } + @Test + public void testCancelDoesNotScheduleWhenNoListeners() { + EventExecutor executor = Mockito.mock(EventExecutor.class); + Mockito.when(executor.inEventLoop()).thenReturn(false); + + Promise promise = new DefaultPromise(executor); + promise.cancel(false); + Mockito.verify(executor, Mockito.never()).execute(Mockito.any(Runnable.class)); + assertTrue(promise.isCancelled()); + } + + @Test + public void testSuccessDoesNotScheduleWhenNoListeners() { + EventExecutor executor = Mockito.mock(EventExecutor.class); + Mockito.when(executor.inEventLoop()).thenReturn(false); + + Object value = new Object(); + Promise promise = new DefaultPromise(executor); + promise.setSuccess(value); + Mockito.verify(executor, Mockito.never()).execute(Mockito.any(Runnable.class)); + assertSame(value, promise.getNow()); + } + + @Test + public void testFailureDoesNotScheduleWhenNoListeners() { + EventExecutor executor = Mockito.mock(EventExecutor.class); + Mockito.when(executor.inEventLoop()).thenReturn(false); + + Exception cause = new Exception(); + Promise promise = new DefaultPromise(executor); + promise.setFailure(cause); + Mockito.verify(executor, Mockito.never()).execute(Mockito.any(Runnable.class)); + assertSame(cause, promise.cause()); + } + @Test(expected = CancellationException.class) public void testCancellationExceptionIsThrownWhenBlockingGet() throws InterruptedException, ExecutionException { final Promise promise = new DefaultPromise(ImmediateEventExecutor.INSTANCE); @@ -258,6 +291,22 @@ public void signalSuccessCompletionValue() { assertTrue(promise.isSuccess()); } + @Test + public void setUncancellableGetNow() { + final Promise promise = new DefaultPromise(ImmediateEventExecutor.INSTANCE); + assertNull(promise.getNow()); + assertTrue(promise.setUncancellable()); + assertNull(promise.getNow()); + assertFalse(promise.isDone()); + assertFalse(promise.isSuccess()); + + promise.setSuccess("success"); + + assertTrue(promise.isDone()); + assertTrue(promise.isSuccess()); + assertEquals("success", promise.getNow()); + } + private static void testStackOverFlowChainedFuturesA(int promiseChainLength, final EventExecutor executor, boolean runTestInExecutorThread) throws InterruptedException { diff --git a/common/src/test/java/io/netty/util/concurrent/FastThreadLocalTest.java b/common/src/test/java/io/netty/util/concurrent/FastThreadLocalTest.java index 4551d4f78b92..6457de297a11 100644 --- a/common/src/test/java/io/netty/util/concurrent/FastThreadLocalTest.java +++ b/common/src/test/java/io/netty/util/concurrent/FastThreadLocalTest.java @@ -18,6 +18,7 @@ import io.netty.util.internal.ObjectCleaner; import org.junit.Before; +import org.junit.Ignore; import org.junit.Test; import java.util.concurrent.atomic.AtomicBoolean; @@ -96,13 +97,13 @@ public void run() { thread.start(); thread.join(); - assertEquals(1, ObjectCleaner.getLiveSetCount() - sizeWhenStart); + assertEquals(0, ObjectCleaner.getLiveSetCount() - sizeWhenStart); Thread thread2 = new Thread(runnable); thread2.start(); thread2.join(); - assertEquals(2, ObjectCleaner.getLiveSetCount() - sizeWhenStart); + assertEquals(0, ObjectCleaner.getLiveSetCount() - sizeWhenStart); } @Test @@ -128,13 +129,13 @@ public void run() { thread.start(); thread.join(); - assertEquals(2, ObjectCleaner.getLiveSetCount() - sizeWhenStart); + assertEquals(0, ObjectCleaner.getLiveSetCount() - sizeWhenStart); Thread thread2 = new Thread(runnable); thread2.start(); thread2.join(); - assertEquals(4, ObjectCleaner.getLiveSetCount() - sizeWhenStart); + assertEquals(0, ObjectCleaner.getLiveSetCount() - sizeWhenStart); } @Test(timeout = 4000) @@ -142,6 +143,7 @@ public void testOnRemoveCalledForFastThreadLocalGet() throws Exception { testOnRemoveCalled(true, true); } + @Ignore("onRemoval(...) not called with non FastThreadLocal") @Test(timeout = 4000) public void testOnRemoveCalledForNonFastThreadLocalGet() throws Exception { testOnRemoveCalled(false, true); @@ -152,6 +154,7 @@ public void testOnRemoveCalledForFastThreadLocalSet() throws Exception { testOnRemoveCalled(true, false); } + @Ignore("onRemoval(...) not called with non FastThreadLocal") @Test(timeout = 4000) public void testOnRemoveCalledForNonFastThreadLocalSet() throws Exception { testOnRemoveCalled(false, false); diff --git a/common/src/test/java/io/netty/util/concurrent/NonStickyEventExecutorGroupTest.java b/common/src/test/java/io/netty/util/concurrent/NonStickyEventExecutorGroupTest.java index 16035505e8a7..b32df7eb4f7f 100644 --- a/common/src/test/java/io/netty/util/concurrent/NonStickyEventExecutorGroupTest.java +++ b/common/src/test/java/io/netty/util/concurrent/NonStickyEventExecutorGroupTest.java @@ -25,6 +25,7 @@ import java.util.Collection; import java.util.List; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -93,6 +94,35 @@ public void run() { } } + @Test + public void testRaceCondition() throws InterruptedException { + EventExecutorGroup group = new UnorderedThreadPoolEventExecutor(1); + NonStickyEventExecutorGroup nonStickyGroup = new NonStickyEventExecutorGroup(group, maxTaskExecutePerRun); + + try { + EventExecutor executor = nonStickyGroup.next(); + + for (int j = 0; j < 5000; j++) { + final CountDownLatch firstCompleted = new CountDownLatch(1); + final CountDownLatch latch = new CountDownLatch(2); + for (int i = 0; i < 2; i++) { + executor.execute(new Runnable() { + @Override + public void run() { + firstCompleted.countDown(); + latch.countDown(); + } + }); + Assert.assertTrue(firstCompleted.await(1, TimeUnit.SECONDS)); + } + + Assert.assertTrue(latch.await(5, TimeUnit.SECONDS)); + } + } finally { + nonStickyGroup.shutdownGracefully(); + } + } + private static void execute(EventExecutorGroup group, CountDownLatch startLatch) throws Throwable { EventExecutor executor = group.next(); Assert.assertTrue(executor instanceof OrderedEventExecutor); diff --git a/common/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java b/common/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java index b46fa4102497..c1ec93784a84 100644 --- a/common/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java +++ b/common/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java @@ -15,9 +15,11 @@ */ package io.netty.util.concurrent; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -25,6 +27,7 @@ import static org.mockito.Mockito.any; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -55,7 +58,19 @@ public void accept(GenericFutureListener> listener) { @Before public void setup() { MockitoAnnotations.initMocks(this); - combiner = new PromiseCombiner(); + combiner = new PromiseCombiner(ImmediateEventExecutor.INSTANCE); + } + + @Test + public void testNullArgument() { + try { + combiner.finish(null); + Assert.fail(); + } catch (NullPointerException expected) { + // expected + } + combiner.finish(p1); + verify(p1).trySuccess(null); } @Test @@ -148,6 +163,38 @@ public void testAddFail() throws Exception { verifyFail(p3, e1); } + @Test + public void testEventExecutor() { + EventExecutor executor = mock(EventExecutor.class); + when(executor.inEventLoop()).thenReturn(false); + combiner = new PromiseCombiner(executor); + + Future future = mock(Future.class); + + try { + combiner.add(future); + Assert.fail(); + } catch (IllegalStateException expected) { + // expected + } + + try { + combiner.addAll(future); + Assert.fail(); + } catch (IllegalStateException expected) { + // expected + } + + @SuppressWarnings("unchecked") + Promise promise = (Promise) mock(Promise.class); + try { + combiner.finish(promise); + Assert.fail(); + } catch (IllegalStateException expected) { + // expected + } + } + private static void verifyFail(Promise p, Throwable cause) { verify(p).tryFailure(eq(cause)); } diff --git a/common/src/test/java/io/netty/util/concurrent/ScheduledFutureTaskTest.java b/common/src/test/java/io/netty/util/concurrent/ScheduledFutureTaskTest.java new file mode 100644 index 000000000000..1c12692ff312 --- /dev/null +++ b/common/src/test/java/io/netty/util/concurrent/ScheduledFutureTaskTest.java @@ -0,0 +1,27 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +import org.junit.Assert; +import org.junit.Test; + +public class ScheduledFutureTaskTest { + + @Test + public void testDeadlineNanosNotOverflow() { + Assert.assertEquals(Long.MAX_VALUE, ScheduledFutureTask.deadlineNanos(Long.MAX_VALUE)); + } +} diff --git a/common/src/test/java/io/netty/util/concurrent/SingleThreadEventExecutorTest.java b/common/src/test/java/io/netty/util/concurrent/SingleThreadEventExecutorTest.java index 55981b24060a..efb0eb015ea3 100644 --- a/common/src/test/java/io/netty/util/concurrent/SingleThreadEventExecutorTest.java +++ b/common/src/test/java/io/netty/util/concurrent/SingleThreadEventExecutorTest.java @@ -31,7 +31,7 @@ public class SingleThreadEventExecutorTest { @Test - public void testWrappedExecutureIsShutdown() { + public void testWrappedExecutorIsShutdown() { ExecutorService executorService = Executors.newSingleThreadExecutor(); SingleThreadEventExecutor executor = new SingleThreadEventExecutor(null, executorService, false) { diff --git a/common/src/test/java/io/netty/util/internal/AppendableCharSequenceTest.java b/common/src/test/java/io/netty/util/internal/AppendableCharSequenceTest.java index 2d7bab4ec849..9d08c3ee5771 100644 --- a/common/src/test/java/io/netty/util/internal/AppendableCharSequenceTest.java +++ b/common/src/test/java/io/netty/util/internal/AppendableCharSequenceTest.java @@ -64,6 +64,16 @@ public void testSubSequence() { assertEquals("abcdefghij", master.subSequence(0, 10).toString()); } + @Test + public void testEmptySubSequence() { + AppendableCharSequence master = new AppendableCharSequence(26); + master.append("abcdefghijlkmonpqrstuvwxyz"); + AppendableCharSequence sub = master.subSequence(0, 0); + assertEquals(0, sub.length()); + sub.append('b'); + assertEquals('b', sub.charAt(0)); + } + private static void testSimpleAppend0(AppendableCharSequence seq) { String text = "testdata"; for (int i = 0; i < text.length(); i++) { diff --git a/common/src/test/java/io/netty/util/internal/NativeLibraryLoaderTest.java b/common/src/test/java/io/netty/util/internal/NativeLibraryLoaderTest.java index de73b9ecf50f..e591b6cf3be1 100644 --- a/common/src/test/java/io/netty/util/internal/NativeLibraryLoaderTest.java +++ b/common/src/test/java/io/netty/util/internal/NativeLibraryLoaderTest.java @@ -15,11 +15,19 @@ */ package io.netty.util.internal; +import io.netty.util.CharsetUtil; import org.junit.Test; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.util.Arrays; import java.util.UUID; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -65,4 +73,70 @@ private static void verifySuppressedException(UnsatisfiedLinkError error, throw new RuntimeException(e); } } + + @Test + public void testPatchingId() throws IOException { + testPatchingId0(true, false); + } + + @Test + public void testPatchingIdWithOsArch() throws IOException { + testPatchingId0(true, true); + } + + @Test + public void testPatchingIdNotMatch() throws IOException { + testPatchingId0(false, false); + } + + @Test + public void testPatchingIdWithOsArchNotMatch() throws IOException { + testPatchingId0(false, true); + } + + private static void testPatchingId0(boolean match, boolean withOsArch) throws IOException { + byte[] bytes = new byte[1024]; + PlatformDependent.threadLocalRandom().nextBytes(bytes); + byte[] idBytes = ("/workspace/netty-tcnative/boringssl-static/target/" + + "native-build/target/lib/libnetty_tcnative-2.0.20.Final.jnilib").getBytes(CharsetUtil.UTF_8); + + String originalName; + if (match) { + originalName = "netty-tcnative"; + } else { + originalName = "nonexist_tcnative"; + } + String name = "shaded_" + originalName; + if (withOsArch) { + name += "_osx_x86_64"; + } + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + out.write(bytes, 0, bytes.length); + out.write(idBytes, 0, idBytes.length); + out.write(bytes, 0 , bytes.length); + + out.flush(); + byte[] inBytes = out.toByteArray(); + out.close(); + + InputStream inputStream = new ByteArrayInputStream(inBytes); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + try { + assertEquals(match, + NativeLibraryLoader.patchShadedLibraryId(inputStream, outputStream, originalName, name)); + + outputStream.flush(); + byte[] outputBytes = outputStream.toByteArray(); + assertArrayEquals(bytes, Arrays.copyOfRange(outputBytes, 0, bytes.length)); + byte[] patchedId = Arrays.copyOfRange(outputBytes, bytes.length, bytes.length + idBytes.length); + assertEquals(!match, Arrays.equals(idBytes, patchedId)); + assertArrayEquals(bytes, + Arrays.copyOfRange(outputBytes, bytes.length + idBytes.length, outputBytes.length)); + assertEquals(inBytes.length, outputBytes.length); + } finally { + inputStream.close(); + outputStream.close(); + } + } } diff --git a/common/src/test/java/io/netty/util/internal/PlatformDependentTest.java b/common/src/test/java/io/netty/util/internal/PlatformDependentTest.java index 3a747deb6158..295fe30a14c6 100644 --- a/common/src/test/java/io/netty/util/internal/PlatformDependentTest.java +++ b/common/src/test/java/io/netty/util/internal/PlatformDependentTest.java @@ -17,14 +17,13 @@ import org.junit.Test; +import java.nio.ByteBuffer; import java.util.Random; import static io.netty.util.internal.PlatformDependent.hashCodeAscii; import static io.netty.util.internal.PlatformDependent.hashCodeAsciiSafe; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; +import static org.junit.Assume.assumeTrue; public class PlatformDependentTest { private static final Random r = new Random(); @@ -146,4 +145,13 @@ public void testHashCodeAscii() { hashCodeAscii(string)); } } + + @Test + public void testAllocateWithCapacity0() { + assumeTrue(PlatformDependent.hasDirectBufferNoCleanerConstructor()); + ByteBuffer buffer = PlatformDependent.allocateDirectNoCleaner(0); + assertNotEquals(0, PlatformDependent.directBufferAddress(buffer)); + assertEquals(0, buffer.capacity()); + PlatformDependent.freeDirectNoCleaner(buffer); + } } diff --git a/common/src/test/java/io/netty/util/internal/logging/Log4J2LoggerTest.java b/common/src/test/java/io/netty/util/internal/logging/Log4J2LoggerTest.java index b561993ebd83..a72eb9a2bd4a 100644 --- a/common/src/test/java/io/netty/util/internal/logging/Log4J2LoggerTest.java +++ b/common/src/test/java/io/netty/util/internal/logging/Log4J2LoggerTest.java @@ -17,7 +17,6 @@ import static org.junit.Assert.assertEquals; -import java.lang.reflect.Field; import java.lang.reflect.Method; import java.util.Arrays; @@ -29,7 +28,6 @@ import org.apache.logging.log4j.spi.ExtendedLoggerWrapper; import org.hamcrest.CoreMatchers; import org.junit.Assume; -import org.junit.Test; import io.netty.util.internal.ReflectionUtil; @@ -56,25 +54,6 @@ public void logMessage(String fqcn, Level level, Marker marker, Message message, }; } - @Test - public void testEXCEPTION_MESSAGE() { - assertEquals(getFieldValue(AbstractInternalLogger.class, "EXCEPTION_MESSAGE"), - getFieldValue(Log4J2Logger.class, "EXCEPTION_MESSAGE")); - } - - @SuppressWarnings("unchecked") - private static T getFieldValue(Class clazz, String fieldName) { - try { - Field field = clazz.getDeclaredField(fieldName); - if (!field.isAccessible()) { - Assume.assumeThat(ReflectionUtil.trySetAccessible(field, true), CoreMatchers.nullValue()); - } - return (T) field.get(AbstractInternalLogger.class); - } catch (ReflectiveOperationException e) { - throw new IllegalStateException(e); - } - } - @Override protected void setLevelEnable(InternalLogLevel level, boolean enable) throws Exception { Level targetLevel = Level.valueOf(level.name()); diff --git a/common/src/test/java/io/netty/util/internal/logging/Slf4JLoggerFactoryTest.java b/common/src/test/java/io/netty/util/internal/logging/Slf4JLoggerFactoryTest.java index 9ace34d5511b..9b08fe5397cb 100644 --- a/common/src/test/java/io/netty/util/internal/logging/Slf4JLoggerFactoryTest.java +++ b/common/src/test/java/io/netty/util/internal/logging/Slf4JLoggerFactoryTest.java @@ -16,15 +16,111 @@ package io.netty.util.internal.logging; import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; +import org.slf4j.Logger; +import org.slf4j.Marker; +import org.slf4j.spi.LocationAwareLogger; -import static org.junit.Assert.*; +import java.util.Iterator; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; public class Slf4JLoggerFactoryTest { @Test public void testCreation() { InternalLogger logger = Slf4JLoggerFactory.INSTANCE.newInstance("foo"); - assertTrue(logger instanceof Slf4JLogger); + assertTrue(logger instanceof Slf4JLogger || logger instanceof LocationAwareSlf4JLogger); assertEquals("foo", logger.name()); } + + @Test + public void testCreationLogger() { + Logger logger = mock(Logger.class); + when(logger.getName()).thenReturn("testlogger"); + InternalLogger internalLogger = Slf4JLoggerFactory.wrapLogger(logger); + assertTrue(internalLogger instanceof Slf4JLogger); + assertEquals("testlogger", internalLogger.name()); + } + + @Test + public void testCreationLocationAwareLogger() { + Logger logger = mock(LocationAwareLogger.class); + when(logger.getName()).thenReturn("testlogger"); + InternalLogger internalLogger = Slf4JLoggerFactory.wrapLogger(logger); + assertTrue(internalLogger instanceof LocationAwareSlf4JLogger); + assertEquals("testlogger", internalLogger.name()); + } + + @Test + public void testFormatMessage() { + ArgumentCaptor captor = ArgumentCaptor.forClass(String.class); + LocationAwareLogger logger = mock(LocationAwareLogger.class); + when(logger.isDebugEnabled()).thenReturn(true); + when(logger.isErrorEnabled()).thenReturn(true); + when(logger.isInfoEnabled()).thenReturn(true); + when(logger.isTraceEnabled()).thenReturn(true); + when(logger.isWarnEnabled()).thenReturn(true); + when(logger.getName()).thenReturn("testlogger"); + + InternalLogger internalLogger = Slf4JLoggerFactory.wrapLogger(logger); + internalLogger.debug("{}", "debug"); + internalLogger.debug("{} {}", "debug1", "debug2"); + internalLogger.debug("{} {} {}", "debug1", "debug2", "debug3"); + + internalLogger.error("{}", "error"); + internalLogger.error("{} {}", "error1", "error2"); + internalLogger.error("{} {} {}", "error1", "error2", "error3"); + + internalLogger.info("{}", "info"); + internalLogger.info("{} {}", "info1", "info2"); + internalLogger.info("{} {} {}", "info1", "info2", "info3"); + + internalLogger.trace("{}", "trace"); + internalLogger.trace("{} {}", "trace1", "trace2"); + internalLogger.trace("{} {} {}", "trace1", "trace2", "trace3"); + + internalLogger.warn("{}", "warn"); + internalLogger.warn("{} {}", "warn1", "warn2"); + internalLogger.warn("{} {} {}", "warn1", "warn2", "warn3"); + + verify(logger, times(3)).log(ArgumentMatchers.isNull(), eq(LocationAwareSlf4JLogger.FQCN), + eq(LocationAwareLogger.DEBUG_INT), captor.capture(), any(Object[].class), + ArgumentMatchers.isNull()); + verify(logger, times(3)).log(ArgumentMatchers.isNull(), eq(LocationAwareSlf4JLogger.FQCN), + eq(LocationAwareLogger.ERROR_INT), captor.capture(), any(Object[].class), + ArgumentMatchers.isNull()); + verify(logger, times(3)).log(ArgumentMatchers.isNull(), eq(LocationAwareSlf4JLogger.FQCN), + eq(LocationAwareLogger.INFO_INT), captor.capture(), any(Object[].class), + ArgumentMatchers.isNull()); + verify(logger, times(3)).log(ArgumentMatchers.isNull(), eq(LocationAwareSlf4JLogger.FQCN), + eq(LocationAwareLogger.TRACE_INT), captor.capture(), any(Object[].class), + ArgumentMatchers.isNull()); + verify(logger, times(3)).log(ArgumentMatchers.isNull(), eq(LocationAwareSlf4JLogger.FQCN), + eq(LocationAwareLogger.WARN_INT), captor.capture(), any(Object[].class), + ArgumentMatchers.isNull()); + + Iterator logMessages = captor.getAllValues().iterator(); + assertEquals("debug", logMessages.next()); + assertEquals("debug1 debug2", logMessages.next()); + assertEquals("debug1 debug2 debug3", logMessages.next()); + assertEquals("error", logMessages.next()); + assertEquals("error1 error2", logMessages.next()); + assertEquals("error1 error2 error3", logMessages.next()); + assertEquals("info", logMessages.next()); + assertEquals("info1 info2", logMessages.next()); + assertEquals("info1 info2 info3", logMessages.next()); + assertEquals("trace", logMessages.next()); + assertEquals("trace1 trace2", logMessages.next()); + assertEquals("trace1 trace2 trace3", logMessages.next()); + assertEquals("warn", logMessages.next()); + assertEquals("warn1 warn2", logMessages.next()); + assertEquals("warn1 warn2 warn3", logMessages.next()); + assertFalse(logMessages.hasNext()); + } } diff --git a/common/src/test/templates/io/netty/util/collection/KObjectHashMapTest.template b/common/src/test/templates/io/netty/util/collection/KObjectHashMapTest.template index 2be12a9ea81a..d4214f7fce6c 100644 --- a/common/src/test/templates/io/netty/util/collection/KObjectHashMapTest.template +++ b/common/src/test/templates/io/netty/util/collection/KObjectHashMapTest.template @@ -613,4 +613,34 @@ public class @K@ObjectHashMapTest { } assertTrue(map.isEmpty()); } + + @Test + public void valuesIteratorRemove() { + Value v1 = new Value("v1"); + Value v2 = new Value("v2"); + Value v3 = new Value("v3"); + map.put((@k@) 1, v1); + map.put((@k@) 2, v2); + map.put((@k@) 3, v3); + + Iterator it = map.values().iterator(); + + assertSame(v1, it.next()); + assertSame(v2, it.next()); + it.remove(); + + assertSame(v3, it.next()); + assertFalse(it.hasNext()); + + assertEquals(2, map.size()); + assertSame(v1, map.get((@k@) 1)); + assertNull(map.get((@k@) 2)); + assertSame(v3, map.get((@k@) 3)); + + it = map.values().iterator(); + + assertSame(v1, it.next()); + assertSame(v3, it.next()); + assertFalse(it.hasNext()); + } } diff --git a/dev-tools/pom.xml b/dev-tools/pom.xml index 68136e081c7b..c3fb1e662da0 100644 --- a/dev-tools/pom.xml +++ b/dev-tools/pom.xml @@ -25,7 +25,7 @@ io.netty netty-dev-tools - 4.1.25.5.dse + 4.1.34.3.dse Netty/Dev-Tools @@ -52,6 +52,5 @@ - netty-4.1.25.dse diff --git a/docker-datastax-release.sh b/docker-datastax-release.sh new file mode 100755 index 000000000000..40812056919e --- /dev/null +++ b/docker-datastax-release.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +if [ ! -f /usr/bin/docker ]; then + export DEBIAN_FRONTEND=noninteractive + sudo apt-get update + sudo apt-get install -y software-properties-common + curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add - + sudo add-apt-repository \ + "deb [arch=amd64] https://download.docker.com/linux/ubuntu $(lsb_release -cs) stable" + sudo apt-get update + sudo apt-get install -y docker-ce +fi + +sudo docker build -f docker/Dockerfile-netty-centos6 -t netty-centos6 . +sudo docker run -t -v ~/.m2:/root/.m2 -v ~/.ssh:/root/.ssh -v ~/.gnupg:/root/.gnupg -v `pwd`:/code -w /code --entrypoint="" netty-centos6 bash -ic "mvn -B clean deploy -Partifactory -DskipTests -DaltDeploymentRepository=\"artifactory::default::https://repo.sjc.dsinternal.org/artifactory/datastax-releases-local\"" diff --git a/docker/Dockerfile-netty-centos6 b/docker/Dockerfile-netty-centos6 index d6e2339f5105..de868debd493 100644 --- a/docker/Dockerfile-netty-centos6 +++ b/docker/Dockerfile-netty-centos6 @@ -3,7 +3,7 @@ MAINTAINER netty@googlegroups.com ENTRYPOINT /bin/bash ENV SOURCE_DIR $HOME/source -ENV MAVEN_VERSION 3.5.2 +ENV MAVEN_VERSION 3.2.5 ENV JAVA_VERSION 1.8.0 RUN mkdir $SOURCE_DIR @@ -23,12 +23,14 @@ RUN yum install -y \ make \ openssl-devel \ tar \ - wget - + wget \ + libaio \ + libaio-devel RUN wget -q http://archive.apache.org/dist/maven/maven-3/$MAVEN_VERSION/binaries/apache-maven-$MAVEN_VERSION-bin.tar.gz && tar xfz apache-maven-$MAVEN_VERSION-bin.tar.gz && mv apache-maven-$MAVEN_VERSION /opt/ RUN echo 'PATH=/opt/apache-maven-$MAVEN_VERSION/bin:$PATH' >> ~/.bashrc + RUN echo 'export JAVA_HOME="/usr/lib/jvm/java-$JAVA_VERSION/"' >> ~/.bashrc RUN rm -rf $SOURCE_DIR diff --git a/docker/Dockerfile.centos b/docker/Dockerfile.centos new file mode 100644 index 000000000000..4d888209023f --- /dev/null +++ b/docker/Dockerfile.centos @@ -0,0 +1,27 @@ +ARG centos_version=6 +FROM centos:$centos_version +# needed to do again after FROM due to docker limitation +ARG centos_version + +# install dependencies +RUN yum install -y \ + apr-devel \ + autoconf \ + automake \ + git \ + glibc-devel \ + libtool \ + lksctp-tools \ + lsb-core \ + make \ + openssl-devel \ + tar \ + wget + +ARG java_version=1.8 +ENV JAVA_VERSION $java_version +# installing java with jabba +RUN curl -sL https://github.com/shyiko/jabba/raw/master/install.sh | JABBA_COMMAND="install $JAVA_VERSION -o /jdk" bash + +RUN echo 'export JAVA_HOME="/jdk"' >> ~/.bashrc +RUN echo 'PATH=/jdk/bin:$PATH' >> ~/.bashrc diff --git a/docker/README.md b/docker/README.md index 157236f5bbad..9970aaed5da8 100644 --- a/docker/README.md +++ b/docker/README.md @@ -1,15 +1,19 @@ +# Using the docker images -** Create a docker image ** ``` -docker build -f Dockerfile-netty-centos6 . -t netty-centos6 +cd /path/to/netty/ ``` -** Using the image ** +## centos 6 with java 8 ``` -cd /path/to/netty/ +docker-compose -f docker/docker-compose.yaml -f docker/docker-compose.centos-6.18.yaml run test ``` +## centos 7 with java 9 + ``` -docker run -it -v ~/.m2:/root/.m2 -v ~/.ssh:/root/.ssh -v ~/.gnupg:/root/.gnupg -v `pwd`:/code -w /code netty-centos6 bash +docker-compose -f docker/docker-compose.yaml -f docker/docker-compose.centos-7.19.yaml run test ``` + +etc, etc diff --git a/docker/docker-compose.centos-6.110.yaml b/docker/docker-compose.centos-6.110.yaml new file mode 100644 index 000000000000..144780de9ec9 --- /dev/null +++ b/docker/docker-compose.centos-6.110.yaml @@ -0,0 +1,21 @@ +version: "3" + +services: + + runtime-setup: + image: netty:centos-6-1.10 + build: + args: + centos_version : "6" + java_version : "openjdk@1.10.0-2" + + test: + image: netty:centos-6-1.10 + test-leak: + image: netty:centos-6-1.10 + + test-boringssl-static: + image: netty:centos-6-1.10 + + shell: + image: netty:centos-6-1.10 diff --git a/docker/docker-compose.centos-6.111.yaml b/docker/docker-compose.centos-6.111.yaml new file mode 100644 index 000000000000..73618c749ea2 --- /dev/null +++ b/docker/docker-compose.centos-6.111.yaml @@ -0,0 +1,22 @@ +version: "3" + +services: + + runtime-setup: + image: netty:centos-6-1.11 + build: + args: + centos_version : "6" + java_version : "openjdk@1.11.0-2" + + test: + image: netty:centos-6-1.11 + + test-leak: + image: netty:centos-6-1.11 + + test-boringssl-static: + image: netty:centos-6-1.11 + + shell: + image: netty:centos-6-1.11 diff --git a/docker/docker-compose.centos-6.112.yaml b/docker/docker-compose.centos-6.112.yaml new file mode 100644 index 000000000000..86a97a5ce824 --- /dev/null +++ b/docker/docker-compose.centos-6.112.yaml @@ -0,0 +1,22 @@ +version: "3" + +services: + + runtime-setup: + image: netty:centos-6-1.12 + build: + args: + centos_version : "6" + java_version : "openjdk@1.12.0" + + test: + image: netty:centos-6-1.12 + + test-leak: + image: netty:centos-6-1.12 + + test-boringssl-static: + image: netty:centos-6-1.12 + + shell: + image: netty:centos-6-1.12 diff --git a/docker/docker-compose.centos-6.113.yaml b/docker/docker-compose.centos-6.113.yaml new file mode 100644 index 000000000000..78cf6c82ba54 --- /dev/null +++ b/docker/docker-compose.centos-6.113.yaml @@ -0,0 +1,22 @@ +version: "3" + +services: + + runtime-setup: + image: netty:centos-6-1.13 + build: + args: + centos_version : "6" + java_version : "openjdk@1.13.0-9" + + test: + image: netty:centos-6-1.13 + + test-leak: + image: netty:centos-6-1.13 + + test-boringssl-static: + image: netty:centos-6-1.13 + + shell: + image: netty:centos-6-1.13 diff --git a/docker/docker-compose.centos-6.18.yaml b/docker/docker-compose.centos-6.18.yaml new file mode 100644 index 000000000000..ecde4daf0bf9 --- /dev/null +++ b/docker/docker-compose.centos-6.18.yaml @@ -0,0 +1,22 @@ +version: "3" + +services: + + runtime-setup: + image: netty:centos-6-1.8 + build: + args: + centos_version : "6" + java_version : "1.8.202" + + test: + image: netty:centos-6-1.8 + + test-leak: + image: netty:centos-6-1.8 + + test-boringssl-static: + image: netty:centos-6-1.8 + + shell: + image: netty:centos-6-1.8 diff --git a/docker/docker-compose.centos-6.19.yaml b/docker/docker-compose.centos-6.19.yaml new file mode 100644 index 000000000000..93588d6e4e38 --- /dev/null +++ b/docker/docker-compose.centos-6.19.yaml @@ -0,0 +1,22 @@ +version: "3" + +services: + + runtime-setup: + image: netty:centos-6-1.9 + build: + args: + centos_version : "6" + java_version : "openjdk@1.9.0-4" + + test: + image: netty:centos-6-1.9 + + test-leak: + image: netty:centos-6-1.9 + + test-boringssl-static: + image: netty:centos-6-1.9 + + shell: + image: netty:centos-6-1.9 diff --git a/docker/docker-compose.centos-6.openj9111.yaml b/docker/docker-compose.centos-6.openj9111.yaml new file mode 100644 index 000000000000..a7d08ddd51bb --- /dev/null +++ b/docker/docker-compose.centos-6.openj9111.yaml @@ -0,0 +1,22 @@ +version: "3" + +services: + + runtime-setup: + image: netty:centos-6-openj9-1.11 + build: + args: + centos_version : "6" + java_version : "adopt-openj9@1.11.0-2" + + test: + image: netty:centos-6-openj9-1.11 + + test-leak: + image: netty:centos-6-openj9-1.11 + + test-boringssl-static: + image: netty:centos-6-openj9-1.11 + + shell: + image: netty:centos-6-openj9-1.11 diff --git a/docker/docker-compose.centos-7.110.yaml b/docker/docker-compose.centos-7.110.yaml new file mode 100644 index 000000000000..7f165208e1e0 --- /dev/null +++ b/docker/docker-compose.centos-7.110.yaml @@ -0,0 +1,22 @@ +version: "3" + +services: + + runtime-setup: + image: netty:centos-7-1.10 + build: + args: + centos_version : "7" + java_version : "openjdk@1.10.0-2" + + test: + image: netty:centos-7-1.10 + + test-leak: + image: netty:centos-7-1.10 + + test-boringssl-static: + image: netty:centos-7-1.10 + + shell: + image: netty:centos-7-1.10 diff --git a/docker/docker-compose.centos-7.111.yaml b/docker/docker-compose.centos-7.111.yaml new file mode 100644 index 000000000000..75d635d25b3f --- /dev/null +++ b/docker/docker-compose.centos-7.111.yaml @@ -0,0 +1,22 @@ +version: "3" + +services: + + runtime-setup: + image: netty:centos-7-1.11 + build: + args: + centos_version : "7" + java_version : "openjdk@1.11.0-2" + + test: + image: netty:centos-7-1.11 + + test-leak: + image: netty:centos-7-1.11 + + test-boringssl-static: + image: netty:centos-7-1.11 + + shell: + image: netty:centos-7-1.11 diff --git a/docker/docker-compose.centos-7.112.yaml b/docker/docker-compose.centos-7.112.yaml new file mode 100644 index 000000000000..627e8d34a0e3 --- /dev/null +++ b/docker/docker-compose.centos-7.112.yaml @@ -0,0 +1,22 @@ +version: "3" + +services: + + runtime-setup: + image: netty:centos-7-1.12 + build: + args: + centos_version : "7" + java_version : "openjdk@1.12.0" + + test: + image: netty:centos-7-1.12 + + test-leak: + image: netty:centos-7-1.12 + + test-boringssl-static: + image: netty:centos-7-1.12 + + shell: + image: netty:centos-7-1.12 diff --git a/docker/docker-compose.centos-7.113.yaml b/docker/docker-compose.centos-7.113.yaml new file mode 100644 index 000000000000..d05d70a15805 --- /dev/null +++ b/docker/docker-compose.centos-7.113.yaml @@ -0,0 +1,22 @@ +version: "3" + +services: + + runtime-setup: + image: netty:centos-7-1.13 + build: + args: + centos_version : "7" + java_version : "openjdk@1.13.0-9" + + test: + image: netty:centos-7-1.13 + + test-leak: + image: netty:centos-7-1.13 + + test-boringssl-static: + image: netty:centos-7-1.13 + + shell: + image: netty:centos-7-1.13 diff --git a/docker/docker-compose.centos-7.18.yaml b/docker/docker-compose.centos-7.18.yaml new file mode 100644 index 000000000000..327920119743 --- /dev/null +++ b/docker/docker-compose.centos-7.18.yaml @@ -0,0 +1,22 @@ +version: "3" + +services: + + runtime-setup: + image: netty:centos-7-1.8 + build: + args: + centos_version : "7" + java_version : "1.8.202" + + test: + image: netty:centos-7-1.8 + + test-leak: + image: netty:centos-7-1.8 + + test-boringssl-static: + image: netty:centos-7-1.8 + + shell: + image: netty:centos-7-1.8 diff --git a/docker/docker-compose.centos-7.19.yaml b/docker/docker-compose.centos-7.19.yaml new file mode 100644 index 000000000000..5f1af4f11e79 --- /dev/null +++ b/docker/docker-compose.centos-7.19.yaml @@ -0,0 +1,22 @@ +version: "3" + +services: + + runtime-setup: + image: netty:centos-7-1.9 + build: + args: + centos_version : "7" + java_version : "openjdk@1.9.0-4" + + test: + image: netty:centos-7-1.9 + + test-leak: + image: netty:centos-7-1.9 + + test-boringssl-static: + image: netty:centos-7-1.9 + + shell: + image: netty:centos-7-1.9 diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml new file mode 100644 index 000000000000..3e3bd2cf6e9f --- /dev/null +++ b/docker/docker-compose.yaml @@ -0,0 +1,42 @@ +version: "3" + +services: + + runtime-setup: + image: netty:default + build: + context: . + dockerfile: Dockerfile.centos + + common: &common + image: netty:default + depends_on: [runtime-setup] + volumes: + - ~/.ssh:/root/.ssh + - ~/.gnupg:/root/.gnupg + - ..:/code + working_dir: /code + + test-leak: + <<: *common + command: /bin/bash -cl "./mvnw -Pleak clean install -Dio.netty.testsuite.badHost=netty.io" + + test: + <<: *common + command: /bin/bash -cl "./mvnw clean install -Dio.netty.testsuite.badHost=netty.io" + + test-boringssl-static: + <<: *common + command: /bin/bash -cl "./mvnw -P boringssl clean install -Dio.netty.testsuite.badHost=netty.io -Dxml.skip=true" + + shell: + <<: *common + environment: + - SANOTYPE_USER + - SANOTYPE_PASSWORD + volumes: + - ~/.ssh:/root/.ssh + - ~/.gnupg:/root/.gnupg + - ..:/code + - ~/.m2:/root/.m2 + entrypoint: /bin/bash diff --git a/example/pom.xml b/example/pom.xml index 62e633befab6..b528fd6019d7 100644 --- a/example/pom.xml +++ b/example/pom.xml @@ -21,7 +21,7 @@ io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-example @@ -29,7 +29,20 @@ Netty/Example + + true + + + ${project.groupId} + netty-common + ${project.version} + + + ${project.groupId} + netty-buffer + ${project.version} + ${project.groupId} netty-transport @@ -37,7 +50,7 @@ ${project.groupId} - netty-transport-sctp + netty-codec ${project.version} @@ -45,6 +58,11 @@ netty-handler ${project.version} + + ${project.groupId} + netty-transport-sctp + ${project.version} + ${project.groupId} netty-handler-proxy @@ -80,6 +98,7 @@ netty-codec-stomp ${project.version} + com.google.protobuf protobuf-java diff --git a/example/src/main/java/io/netty/example/echo/EchoServer.java b/example/src/main/java/io/netty/example/echo/EchoServer.java index ddc43fe042c9..b7b4a770595c 100644 --- a/example/src/main/java/io/netty/example/echo/EchoServer.java +++ b/example/src/main/java/io/netty/example/echo/EchoServer.java @@ -51,6 +51,7 @@ public static void main(String[] args) throws Exception { // Configure the server. EventLoopGroup bossGroup = new NioEventLoopGroup(1); EventLoopGroup workerGroup = new NioEventLoopGroup(); + final EchoServerHandler serverHandler = new EchoServerHandler(); try { ServerBootstrap b = new ServerBootstrap(); b.group(bossGroup, workerGroup) @@ -65,7 +66,7 @@ public void initChannel(SocketChannel ch) throws Exception { p.addLast(sslCtx.newHandler(ch.alloc())); } //p.addLast(new LoggingHandler(LogLevel.INFO)); - p.addLast(new EchoServerHandler()); + p.addLast(serverHandler); } }); diff --git a/example/src/main/java/io/netty/example/http/file/HttpStaticFileServerHandler.java b/example/src/main/java/io/netty/example/http/file/HttpStaticFileServerHandler.java index e2a801324dd1..d5e875102e49 100644 --- a/example/src/main/java/io/netty/example/http/file/HttpStaticFileServerHandler.java +++ b/example/src/main/java/io/netty/example/http/file/HttpStaticFileServerHandler.java @@ -117,11 +117,12 @@ public void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) thr return; } - if (request.method() != GET) { + if (!GET.equals(request.method())) { sendError(ctx, METHOD_NOT_ALLOWED); return; } + final boolean keepAlive = HttpUtil.isKeepAlive(request); final String uri = request.uri(); final String path = sanitizeUri(uri); if (path == null) { @@ -137,9 +138,9 @@ public void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) thr if (file.isDirectory()) { if (uri.endsWith("/")) { - sendListing(ctx, file, uri); + sendListing(ctx, file, uri, keepAlive); } else { - sendRedirect(ctx, uri + '/'); + sendRedirect(ctx, uri + '/', keepAlive); } return; } @@ -160,7 +161,7 @@ public void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) thr long ifModifiedSinceDateSeconds = ifModifiedSinceDate.getTime() / 1000; long fileLastModifiedSeconds = file.lastModified() / 1000; if (ifModifiedSinceDateSeconds == fileLastModifiedSeconds) { - sendNotModified(ctx); + sendNotModified(ctx, keepAlive); return; } } @@ -178,8 +179,9 @@ public void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) thr HttpUtil.setContentLength(response, fileLength); setContentTypeHeader(response, file); setDateAndCacheHeaders(response, file); - if (HttpUtil.isKeepAlive(request)) { - response.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE); + + if (!keepAlive) { + response.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE); } // Write the initial line and the header. @@ -218,7 +220,7 @@ public void operationComplete(ChannelProgressiveFuture future) { }); // Decide whether to close the connection or not. - if (!HttpUtil.isKeepAlive(request)) { + if (!keepAlive) { // Close the connection when the whole content is written out. lastContentFuture.addListener(ChannelFutureListener.CLOSE); } @@ -264,7 +266,7 @@ private static String sanitizeUri(String uri) { private static final Pattern ALLOWED_FILE_NAME = Pattern.compile("[^-\\._]?[^<>&\\\"]*"); - private static void sendListing(ChannelHandlerContext ctx, File dir, String dirPath) { + private static void sendListing(ChannelHandlerContext ctx, File dir, String dirPath, boolean keepAlive) { FullHttpResponse response = new DefaultFullHttpResponse(HTTP_1_1, OK); response.headers().set(HttpHeaderNames.CONTENT_TYPE, "text/html; charset=UTF-8"); @@ -304,16 +306,14 @@ private static void sendListing(ChannelHandlerContext ctx, File dir, String dirP response.content().writeBytes(buffer); buffer.release(); - // Close the connection as soon as the error message is sent. - ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE); + sendAndCleanupConnection(ctx, response, keepAlive); } - private static void sendRedirect(ChannelHandlerContext ctx, String newUri) { + private static void sendRedirect(ChannelHandlerContext ctx, String newUri, boolean keepAlive) { FullHttpResponse response = new DefaultFullHttpResponse(HTTP_1_1, FOUND); response.headers().set(HttpHeaderNames.LOCATION, newUri); - // Close the connection as soon as the error message is sent. - ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE); + sendAndCleanupConnection(ctx, response, keepAlive); } private static void sendError(ChannelHandlerContext ctx, HttpResponseStatus status) { @@ -321,8 +321,7 @@ private static void sendError(ChannelHandlerContext ctx, HttpResponseStatus stat HTTP_1_1, status, Unpooled.copiedBuffer("Failure: " + status + "\r\n", CharsetUtil.UTF_8)); response.headers().set(HttpHeaderNames.CONTENT_TYPE, "text/plain; charset=UTF-8"); - // Close the connection as soon as the error message is sent. - ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE); + sendAndCleanupConnection(ctx, response, false); } /** @@ -331,12 +330,32 @@ private static void sendError(ChannelHandlerContext ctx, HttpResponseStatus stat * @param ctx * Context */ - private static void sendNotModified(ChannelHandlerContext ctx) { + private static void sendNotModified(ChannelHandlerContext ctx, boolean keepAlive) { FullHttpResponse response = new DefaultFullHttpResponse(HTTP_1_1, NOT_MODIFIED); setDateHeader(response); - // Close the connection as soon as the error message is sent. - ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE); + sendAndCleanupConnection(ctx, response, keepAlive); + } + + /** + * If Keep-Alive is disabled, attaches "Connection: close" header to the response + * and closes the connection after the response being sent. + */ + private static void sendAndCleanupConnection(ChannelHandlerContext ctx, FullHttpResponse response, + boolean keepAlive) { + HttpUtil.setContentLength(response, response.content().readableBytes()); + if (!keepAlive) { + // We're going to close the connection as soon as the response is sent, + // so we should also make it clear for the client. + response.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE); + } + + ChannelFuture flushPromise = ctx.writeAndFlush(response); + + if (!keepAlive) { + // Close the connection as soon as the response is sent. + flushPromise.addListener(ChannelFutureListener.CLOSE); + } } /** diff --git a/example/src/main/java/io/netty/example/http/helloworld/HttpHelloWorldServerHandler.java b/example/src/main/java/io/netty/example/http/helloworld/HttpHelloWorldServerHandler.java index af142dc11d7c..3faf51af415b 100644 --- a/example/src/main/java/io/netty/example/http/helloworld/HttpHelloWorldServerHandler.java +++ b/example/src/main/java/io/netty/example/http/helloworld/HttpHelloWorldServerHandler.java @@ -18,16 +18,18 @@ import io.netty.buffer.Unpooled; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpObject; import io.netty.handler.codec.http.HttpUtil; import io.netty.handler.codec.http.HttpRequest; import io.netty.util.AsciiString; + import static io.netty.handler.codec.http.HttpResponseStatus.*; import static io.netty.handler.codec.http.HttpVersion.*; -public class HttpHelloWorldServerHandler extends ChannelInboundHandlerAdapter { +public class HttpHelloWorldServerHandler extends SimpleChannelInboundHandler { private static final byte[] CONTENT = { 'H', 'e', 'l', 'l', 'o', ' ', 'W', 'o', 'r', 'l', 'd' }; private static final AsciiString CONTENT_TYPE = AsciiString.cached("Content-Type"); @@ -41,7 +43,7 @@ public void channelReadComplete(ChannelHandlerContext ctx) { } @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) { + public void channelRead0(ChannelHandlerContext ctx, HttpObject msg) { if (msg instanceof HttpRequest) { HttpRequest req = (HttpRequest) msg; diff --git a/example/src/main/java/io/netty/example/http/upload/HttpUploadServerHandler.java b/example/src/main/java/io/netty/example/http/upload/HttpUploadServerHandler.java index 1246a29fd44c..8c9cd2948b0e 100644 --- a/example/src/main/java/io/netty/example/http/upload/HttpUploadServerHandler.java +++ b/example/src/main/java/io/netty/example/http/upload/HttpUploadServerHandler.java @@ -145,7 +145,7 @@ public void channelRead0(ChannelHandlerContext ctx, HttpObject msg) throws Excep responseContent.append("\r\n\r\n"); // if GET Method: should not try to create a HttpPostRequestDecoder - if (request.method().equals(HttpMethod.GET)) { + if (HttpMethod.GET.equals(request.method())) { // GET Method: should not try to create a HttpPostRequestDecoder // So stop here responseContent.append("\r\n\r\nEND OF GET CONTENT\r\n"); @@ -225,12 +225,8 @@ private void readHttpDataChunkByChunk() { logger.info(" 100% (FinalSize: " + partialContent.length() + ")"); partialContent = null; } - try { - // new value - writeHttpData(data); - } finally { - data.release(); - } + // new value + writeHttpData(data); } } // Check partial decoding for a FileUpload diff --git a/example/src/main/java/io/netty/example/http/websocketx/benchmarkserver/WebSocketServerHandler.java b/example/src/main/java/io/netty/example/http/websocketx/benchmarkserver/WebSocketServerHandler.java index 78f61ba14f45..a31d4f308e89 100644 --- a/example/src/main/java/io/netty/example/http/websocketx/benchmarkserver/WebSocketServerHandler.java +++ b/example/src/main/java/io/netty/example/http/websocketx/benchmarkserver/WebSocketServerHandler.java @@ -71,7 +71,7 @@ private void handleHttpRequest(ChannelHandlerContext ctx, FullHttpRequest req) { } // Allow only GET methods. - if (req.method() != GET) { + if (!GET.equals(req.method())) { sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN)); return; } diff --git a/example/src/main/java/io/netty/example/http/websocketx/server/WebSocketFrameHandler.java b/example/src/main/java/io/netty/example/http/websocketx/server/WebSocketFrameHandler.java index aa27c0163ea4..a4536682e13e 100644 --- a/example/src/main/java/io/netty/example/http/websocketx/server/WebSocketFrameHandler.java +++ b/example/src/main/java/io/netty/example/http/websocketx/server/WebSocketFrameHandler.java @@ -21,16 +21,12 @@ import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketFrame; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * Echoes uppercase content of text frames. */ public class WebSocketFrameHandler extends SimpleChannelInboundHandler { - private static final Logger logger = LoggerFactory.getLogger(WebSocketFrameHandler.class); - @Override protected void channelRead0(ChannelHandlerContext ctx, WebSocketFrame frame) throws Exception { // ping and pong frames already handled @@ -38,7 +34,6 @@ protected void channelRead0(ChannelHandlerContext ctx, WebSocketFrame frame) thr if (frame instanceof TextWebSocketFrame) { // Send the uppercase string back. String request = ((TextWebSocketFrame) frame).text(); - logger.info("{} received {}", ctx.channel(), request); ctx.channel().writeAndFlush(new TextWebSocketFrame(request.toUpperCase(Locale.US))); } else { String message = "unsupported frame type: " + frame.getClass().getName(); diff --git a/example/src/main/java/io/netty/example/http/websocketx/server/WebSocketIndexPageHandler.java b/example/src/main/java/io/netty/example/http/websocketx/server/WebSocketIndexPageHandler.java index 7d543ce39965..124415d12afb 100644 --- a/example/src/main/java/io/netty/example/http/websocketx/server/WebSocketIndexPageHandler.java +++ b/example/src/main/java/io/netty/example/http/websocketx/server/WebSocketIndexPageHandler.java @@ -58,7 +58,7 @@ protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest req) thro } // Allow only GET methods. - if (req.method() != GET) { + if (!GET.equals(req.method())) { sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN)); return; } diff --git a/example/src/main/java/io/netty/example/memcache/binary/MemcacheClientHandler.java b/example/src/main/java/io/netty/example/memcache/binary/MemcacheClientHandler.java index c2ee142f7b6e..1864ad33c267 100644 --- a/example/src/main/java/io/netty/example/memcache/binary/MemcacheClientHandler.java +++ b/example/src/main/java/io/netty/example/memcache/binary/MemcacheClientHandler.java @@ -69,6 +69,7 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) public void channelRead(ChannelHandlerContext ctx, Object msg) { FullBinaryMemcacheResponse res = (FullBinaryMemcacheResponse) msg; System.out.println(res.content().toString(CharsetUtil.UTF_8)); + res.release(); } @Override diff --git a/example/src/main/java/io/netty/example/ocsp/OcspClientExample.java b/example/src/main/java/io/netty/example/ocsp/OcspClientExample.java index f36001cab43c..814e44363d8a 100644 --- a/example/src/main/java/io/netty/example/ocsp/OcspClientExample.java +++ b/example/src/main/java/io/netty/example/ocsp/OcspClientExample.java @@ -157,7 +157,7 @@ private static class HttpClientHandler extends ChannelInboundHandlerAdapter { private final Promise promise; - public HttpClientHandler(String host, Promise promise) { + HttpClientHandler(String host, Promise promise) { this.host = host; this.promise = promise; } @@ -203,7 +203,7 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E private static class ExampleOcspClientHandler extends OcspClientHandler { - public ExampleOcspClientHandler(ReferenceCountedOpenSslEngine engine) { + ExampleOcspClientHandler(ReferenceCountedOpenSslEngine engine) { super(engine); } diff --git a/example/src/main/java/io/netty/example/sctp/SctpEchoServer.java b/example/src/main/java/io/netty/example/sctp/SctpEchoServer.java index 96493f4ddfa4..4346e399bdeb 100644 --- a/example/src/main/java/io/netty/example/sctp/SctpEchoServer.java +++ b/example/src/main/java/io/netty/example/sctp/SctpEchoServer.java @@ -37,6 +37,7 @@ public static void main(String[] args) throws Exception { // Configure the server. EventLoopGroup bossGroup = new NioEventLoopGroup(1); EventLoopGroup workerGroup = new NioEventLoopGroup(); + final SctpEchoServerHandler serverHandler = new SctpEchoServerHandler(); try { ServerBootstrap b = new ServerBootstrap(); b.group(bossGroup, workerGroup) @@ -48,7 +49,7 @@ public static void main(String[] args) throws Exception { public void initChannel(SctpChannel ch) throws Exception { ch.pipeline().addLast( //new LoggingHandler(LogLevel.INFO), - new SctpEchoServerHandler()); + serverHandler); } }); diff --git a/handler-proxy/pom.xml b/handler-proxy/pom.xml index 09a0dd87ccd3..24f20e6e7e67 100644 --- a/handler-proxy/pom.xml +++ b/handler-proxy/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-handler-proxy @@ -35,11 +35,26 @@ Netty/Handler/Proxy + + ${project.groupId} + netty-common + ${project.version} + + + ${project.groupId} + netty-buffer + ${project.version} + ${project.groupId} netty-transport ${project.version} + + ${project.groupId} + netty-codec + ${project.version} + ${project.groupId} netty-codec-socks @@ -50,6 +65,7 @@ netty-codec-http ${project.version} + ${project.groupId} netty-handler diff --git a/handler-proxy/src/main/java/io/netty/handler/proxy/HttpProxyHandler.java b/handler-proxy/src/main/java/io/netty/handler/proxy/HttpProxyHandler.java index 2c41faba0fff..61088eb674f8 100644 --- a/handler-proxy/src/main/java/io/netty/handler/proxy/HttpProxyHandler.java +++ b/handler-proxy/src/main/java/io/netty/handler/proxy/HttpProxyHandler.java @@ -47,9 +47,10 @@ public final class HttpProxyHandler extends ProxyHandler { private final String username; private final String password; private final CharSequence authorization; + private final HttpHeaders outboundHeaders; private final boolean ignoreDefaultPortsInConnectHostHeader; private HttpResponseStatus status; - private HttpHeaders headers; + private HttpHeaders inboundHeaders; public HttpProxyHandler(SocketAddress proxyAddress) { this(proxyAddress, null); @@ -66,7 +67,7 @@ public HttpProxyHandler(SocketAddress proxyAddress, username = null; password = null; authorization = null; - this.headers = headers; + this.outboundHeaders = headers; this.ignoreDefaultPortsInConnectHostHeader = ignoreDefaultPortsInConnectHostHeader; } @@ -102,7 +103,7 @@ public HttpProxyHandler(SocketAddress proxyAddress, authz.release(); authzBase64.release(); - this.headers = headers; + this.outboundHeaders = headers; this.ignoreDefaultPortsInConnectHostHeader = ignoreDefaultPortsInConnectHostHeader; } @@ -163,8 +164,8 @@ protected Object newInitialMessage(ChannelHandlerContext ctx) throws Exception { req.headers().set(HttpHeaderNames.PROXY_AUTHORIZATION, authorization); } - if (headers != null) { - req.headers().add(headers); + if (outboundHeaders != null) { + req.headers().add(outboundHeaders); } return req; @@ -174,21 +175,48 @@ protected Object newInitialMessage(ChannelHandlerContext ctx) throws Exception { protected boolean handleResponse(ChannelHandlerContext ctx, Object response) throws Exception { if (response instanceof HttpResponse) { if (status != null) { - throw new ProxyConnectException(exceptionMessage("too many responses")); + throw new HttpProxyConnectException(exceptionMessage("too many responses"), /*headers=*/ null); } - status = ((HttpResponse) response).status(); + HttpResponse res = (HttpResponse) response; + status = res.status(); + inboundHeaders = res.headers(); } boolean finished = response instanceof LastHttpContent; if (finished) { if (status == null) { - throw new ProxyConnectException(exceptionMessage("missing response")); + throw new HttpProxyConnectException(exceptionMessage("missing response"), inboundHeaders); } if (status.code() != 200) { - throw new ProxyConnectException(exceptionMessage("status: " + status)); + throw new HttpProxyConnectException(exceptionMessage("status: " + status), inboundHeaders); } } return finished; } + + /** + * Specific case of a connection failure, which may include headers from the proxy. + */ + public static final class HttpProxyConnectException extends ProxyConnectException { + private static final long serialVersionUID = -8824334609292146066L; + + private final HttpHeaders headers; + + /** + * @param message The failure message. + * @param headers Header associated with the connection failure. May be {@code null}. + */ + public HttpProxyConnectException(String message, HttpHeaders headers) { + super(message); + this.headers = headers; + } + + /** + * Returns headers, if any. May be {@code null}. + */ + public HttpHeaders headers() { + return headers; + } + } } diff --git a/handler-proxy/src/test/java/io/netty/handler/proxy/HttpProxyHandlerTest.java b/handler-proxy/src/test/java/io/netty/handler/proxy/HttpProxyHandlerTest.java index 04747e44c09f..5f920ab1f279 100644 --- a/handler-proxy/src/test/java/io/netty/handler/proxy/HttpProxyHandlerTest.java +++ b/handler-proxy/src/test/java/io/netty/handler/proxy/HttpProxyHandlerTest.java @@ -15,20 +15,36 @@ */ package io.netty.handler.proxy; +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; +import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpResponseEncoder; +import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.proxy.HttpProxyHandler.HttpProxyConnectException; import io.netty.util.NetUtil; +import java.util.concurrent.atomic.AtomicReference; import org.junit.Test; import java.net.InetAddress; import java.net.InetSocketAddress; -import static org.junit.Assert.assertEquals; +import static org.junit.Assert.*; import static org.mockito.Mockito.*; public class HttpProxyHandlerTest { @@ -153,6 +169,65 @@ public void testCustomHeaders() throws Exception { true); } + @Test + public void testExceptionDuringConnect() throws Exception { + EventLoopGroup group = null; + Channel serverChannel = null; + Channel clientChannel = null; + try { + group = new DefaultEventLoopGroup(1); + final LocalAddress addr = new LocalAddress("a"); + final AtomicReference exception = new AtomicReference(); + ChannelFuture sf = + new ServerBootstrap().channel(LocalServerChannel.class).group(group).childHandler( + new ChannelInitializer() { + + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addFirst(new HttpResponseEncoder()); + DefaultFullHttpResponse response = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, + HttpResponseStatus.BAD_GATEWAY); + response.headers().add("name", "value"); + response.headers().add(HttpHeaderNames.CONTENT_LENGTH, "0"); + ch.writeAndFlush(response); + } + }).bind(addr); + serverChannel = sf.sync().channel(); + ChannelFuture cf = new Bootstrap().channel(LocalChannel.class).group(group).handler( + new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addFirst(new HttpProxyHandler(addr)); + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, + Throwable cause) { + exception.set(cause); + } + }); + } + }).connect(new InetSocketAddress("localhost", 1234)); + clientChannel = cf.sync().channel(); + clientChannel.close().sync(); + + assertTrue(exception.get() instanceof HttpProxyConnectException); + HttpProxyConnectException actual = (HttpProxyConnectException) exception.get(); + assertNotNull(actual.headers()); + assertEquals("value", actual.headers().get("name")); + } finally { + if (clientChannel != null) { + clientChannel.close(); + } + if (serverChannel != null) { + serverChannel.close(); + } + if (group != null) { + group.shutdownGracefully(); + } + } + } + private static void testInitialMessage(InetSocketAddress socketAddress, String expectedUrl, String expectedHostHeader, diff --git a/handler-proxy/src/test/java/io/netty/handler/proxy/ProxyHandlerTest.java b/handler-proxy/src/test/java/io/netty/handler/proxy/ProxyHandlerTest.java index e79bed7cf50f..77e20fb864c2 100644 --- a/handler-proxy/src/test/java/io/netty/handler/proxy/ProxyHandlerTest.java +++ b/handler-proxy/src/test/java/io/netty/handler/proxy/ProxyHandlerTest.java @@ -35,6 +35,7 @@ import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.codec.LineBasedFrameDecoder; import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import io.netty.handler.ssl.util.SelfSignedCertificate; import io.netty.resolver.NoopAddressResolverGroup; @@ -63,8 +64,8 @@ import java.util.Queue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; +import java.util.Random; import static org.hamcrest.CoreMatchers.*; import static org.junit.Assert.*; @@ -91,8 +92,8 @@ public class ProxyHandlerTest { SslContext cctx; try { SelfSignedCertificate ssc = new SelfSignedCertificate(); - sctx = SslContext.newServerContext(ssc.certificate(), ssc.privateKey()); - cctx = SslContext.newClientContext(InsecureTrustManagerFactory.INSTANCE); + sctx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()).build(); + cctx = SslContextBuilder.forClient().trustManager(InsecureTrustManagerFactory.INSTANCE).build(); } catch (Exception e) { throw new Error(e); } @@ -130,15 +131,27 @@ public class ProxyHandlerTest { deadSocks5Proxy, interSocks5Proxy, anonSocks5Proxy, socks5Proxy ); + // set to non-zero value in case you need predictable shuffling of test cases + // look for "Seed used: *" debug message in test logs + private static final long reproducibleSeed = 0L; + @Parameters(name = "{index}: {0}") public static List testItems() { + List items = Arrays.asList( // HTTP ------------------------------------------------------- new SuccessTestItem( - "Anonymous HTTP proxy: successful connection", + "Anonymous HTTP proxy: successful connection, AUTO_READ on", + DESTINATION, + true, + new HttpProxyHandler(anonHttpProxy.address())), + + new SuccessTestItem( + "Anonymous HTTP proxy: successful connection, AUTO_READ off", DESTINATION, + false, new HttpProxyHandler(anonHttpProxy.address())), new FailureTestItem( @@ -152,8 +165,15 @@ public static List testItems() { new HttpProxyHandler(httpProxy.address())), new SuccessTestItem( - "HTTP proxy: successful connection", + "HTTP proxy: successful connection, AUTO_READ on", + DESTINATION, + true, + new HttpProxyHandler(httpProxy.address(), USERNAME, PASSWORD)), + + new SuccessTestItem( + "HTTP proxy: successful connection, AUTO_READ off", DESTINATION, + false, new HttpProxyHandler(httpProxy.address(), USERNAME, PASSWORD)), new FailureTestItem( @@ -173,8 +193,16 @@ public static List testItems() { // HTTPS ------------------------------------------------------ new SuccessTestItem( - "Anonymous HTTPS proxy: successful connection", + "Anonymous HTTPS proxy: successful connection, AUTO_READ on", + DESTINATION, + true, + clientSslCtx.newHandler(PooledByteBufAllocator.DEFAULT), + new HttpProxyHandler(anonHttpsProxy.address())), + + new SuccessTestItem( + "Anonymous HTTPS proxy: successful connection, AUTO_READ off", DESTINATION, + false, clientSslCtx.newHandler(PooledByteBufAllocator.DEFAULT), new HttpProxyHandler(anonHttpsProxy.address())), @@ -191,8 +219,16 @@ public static List testItems() { new HttpProxyHandler(httpsProxy.address())), new SuccessTestItem( - "HTTPS proxy: successful connection", + "HTTPS proxy: successful connection, AUTO_READ on", DESTINATION, + true, + clientSslCtx.newHandler(PooledByteBufAllocator.DEFAULT), + new HttpProxyHandler(httpsProxy.address(), USERNAME, PASSWORD)), + + new SuccessTestItem( + "HTTPS proxy: successful connection, AUTO_READ off", + DESTINATION, + false, clientSslCtx.newHandler(PooledByteBufAllocator.DEFAULT), new HttpProxyHandler(httpsProxy.address(), USERNAME, PASSWORD)), @@ -216,8 +252,15 @@ public static List testItems() { // SOCKS4 ----------------------------------------------------- new SuccessTestItem( - "Anonymous SOCKS4: successful connection", + "Anonymous SOCKS4: successful connection, AUTO_READ on", + DESTINATION, + true, + new Socks4ProxyHandler(anonSocks4Proxy.address())), + + new SuccessTestItem( + "Anonymous SOCKS4: successful connection, AUTO_READ off", DESTINATION, + false, new Socks4ProxyHandler(anonSocks4Proxy.address())), new FailureTestItem( @@ -231,8 +274,15 @@ public static List testItems() { new Socks4ProxyHandler(socks4Proxy.address())), new SuccessTestItem( - "SOCKS4: successful connection", + "SOCKS4: successful connection, AUTO_READ on", DESTINATION, + true, + new Socks4ProxyHandler(socks4Proxy.address(), USERNAME)), + + new SuccessTestItem( + "SOCKS4: successful connection, AUTO_READ off", + DESTINATION, + false, new Socks4ProxyHandler(socks4Proxy.address(), USERNAME)), new FailureTestItem( @@ -252,8 +302,15 @@ public static List testItems() { // SOCKS5 ----------------------------------------------------- new SuccessTestItem( - "Anonymous SOCKS5: successful connection", + "Anonymous SOCKS5: successful connection, AUTO_READ on", DESTINATION, + true, + new Socks5ProxyHandler(anonSocks5Proxy.address())), + + new SuccessTestItem( + "Anonymous SOCKS5: successful connection, AUTO_READ off", + DESTINATION, + false, new Socks5ProxyHandler(anonSocks5Proxy.address())), new FailureTestItem( @@ -267,8 +324,15 @@ public static List testItems() { new Socks5ProxyHandler(socks5Proxy.address())), new SuccessTestItem( - "SOCKS5: successful connection", + "SOCKS5: successful connection, AUTO_READ on", + DESTINATION, + true, + new Socks5ProxyHandler(socks5Proxy.address(), USERNAME, PASSWORD)), + + new SuccessTestItem( + "SOCKS5: successful connection, AUTO_READ off", DESTINATION, + false, new Socks5ProxyHandler(socks5Proxy.address(), USERNAME, PASSWORD)), new FailureTestItem( @@ -288,8 +352,20 @@ public static List testItems() { // HTTP + HTTPS + SOCKS4 + SOCKS5 new SuccessTestItem( - "Single-chain: successful connection", + "Single-chain: successful connection, AUTO_READ on", DESTINATION, + true, + new Socks5ProxyHandler(interSocks5Proxy.address()), // SOCKS5 + new Socks4ProxyHandler(interSocks4Proxy.address()), // SOCKS4 + clientSslCtx.newHandler(PooledByteBufAllocator.DEFAULT), + new HttpProxyHandler(interHttpsProxy.address()), // HTTPS + new HttpProxyHandler(interHttpProxy.address()), // HTTP + new HttpProxyHandler(anonHttpProxy.address())), + + new SuccessTestItem( + "Single-chain: successful connection, AUTO_READ off", + DESTINATION, + false, new Socks5ProxyHandler(interSocks5Proxy.address()), // SOCKS5 new Socks4ProxyHandler(interSocks4Proxy.address()), // SOCKS4 clientSslCtx.newHandler(PooledByteBufAllocator.DEFAULT), @@ -300,8 +376,9 @@ public static List testItems() { // (HTTP + HTTPS + SOCKS4 + SOCKS5) * 2 new SuccessTestItem( - "Double-chain: successful connection", + "Double-chain: successful connection, AUTO_READ on", DESTINATION, + true, new Socks5ProxyHandler(interSocks5Proxy.address()), // SOCKS5 new Socks4ProxyHandler(interSocks4Proxy.address()), // SOCKS4 clientSslCtx.newHandler(PooledByteBufAllocator.DEFAULT), @@ -312,8 +389,23 @@ public static List testItems() { clientSslCtx.newHandler(PooledByteBufAllocator.DEFAULT), new HttpProxyHandler(interHttpsProxy.address()), // HTTPS new HttpProxyHandler(interHttpProxy.address()), // HTTP - new HttpProxyHandler(anonHttpProxy.address())) + new HttpProxyHandler(anonHttpProxy.address())), + new SuccessTestItem( + "Double-chain: successful connection, AUTO_READ off", + DESTINATION, + false, + new Socks5ProxyHandler(interSocks5Proxy.address()), // SOCKS5 + new Socks4ProxyHandler(interSocks4Proxy.address()), // SOCKS4 + clientSslCtx.newHandler(PooledByteBufAllocator.DEFAULT), + new HttpProxyHandler(interHttpsProxy.address()), // HTTPS + new HttpProxyHandler(interHttpProxy.address()), // HTTP + new Socks5ProxyHandler(interSocks5Proxy.address()), // SOCKS5 + new Socks4ProxyHandler(interSocks4Proxy.address()), // SOCKS4 + clientSslCtx.newHandler(PooledByteBufAllocator.DEFAULT), + new HttpProxyHandler(interHttpsProxy.address()), // HTTPS + new HttpProxyHandler(interHttpProxy.address()), // HTTP + new HttpProxyHandler(anonHttpProxy.address())) ); // Convert the test items to the list of constructor parameters. @@ -323,7 +415,9 @@ public static List testItems() { } // Randomize the execution order to increase the possibility of exposing failure dependencies. - Collections.shuffle(params); + long seed = (reproducibleSeed == 0L) ? System.currentTimeMillis() : reproducibleSeed; + logger.debug("Seed used: {}\n", seed); + Collections.shuffle(params, new Random(seed)); return params; } @@ -516,8 +610,16 @@ public String toString() { private static final class SuccessTestItem extends TestItem { private final int expectedEventCount; - - SuccessTestItem(String name, InetSocketAddress destination, ChannelHandler... clientHandlers) { + // Probably we need to be more flexible here and as for the configuration map, + // not a single key. But as far as it works for now, I'm leaving the impl. + // as is, in case we need to cover more cases (like, AUTO_CLOSE, TCP_NODELAY etc) + // feel free to replace this boolean with either config or method to setup bootstrap + private final boolean autoRead; + + SuccessTestItem(String name, + InetSocketAddress destination, + boolean autoRead, + ChannelHandler... clientHandlers) { super(name, destination, clientHandlers); int expectedEventCount = 0; for (ChannelHandler h: clientHandlers) { @@ -527,6 +629,7 @@ private static final class SuccessTestItem extends TestItem { } this.expectedEventCount = expectedEventCount; + this.autoRead = autoRead; } @Override @@ -535,7 +638,7 @@ protected void test() throws Exception { Bootstrap b = new Bootstrap(); b.group(group); b.channel(NioSocketChannel.class); - b.option(ChannelOption.AUTO_READ, ThreadLocalRandom.current().nextBoolean()); + b.option(ChannelOption.AUTO_READ, this.autoRead); b.resolver(NoopAddressResolverGroup.INSTANCE); b.handler(new ChannelInitializer() { @Override diff --git a/handler/pom.xml b/handler/pom.xml index 2fe59853145d..028cea90bcf1 100644 --- a/handler/pom.xml +++ b/handler/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-handler @@ -35,6 +35,11 @@ Netty/Handler + + ${project.groupId} + netty-common + ${project.version} + ${project.groupId} netty-buffer diff --git a/handler/src/main/java/io/netty/handler/flush/FlushConsolidationHandler.java b/handler/src/main/java/io/netty/handler/flush/FlushConsolidationHandler.java index 12d31627f6ec..472a83bffcae 100644 --- a/handler/src/main/java/io/netty/handler/flush/FlushConsolidationHandler.java +++ b/handler/src/main/java/io/netty/handler/flush/FlushConsolidationHandler.java @@ -47,8 +47,8 @@ * high throughput, this gives the opportunity to process other flushes before the task gets executed, thus * batching multiple flushes into one. * - * If {@code explicitFlushAfterFlushes} is reached the flush will also be forwarded as well (whether while in a read - * loop, or while batching outside of a read loop). + * If {@code explicitFlushAfterFlushes} is reached the flush will be forwarded as well (whether while in a read loop, or + * while batching outside of a read loop). *

    * If the {@link Channel} becomes non-writable it will also try to execute any pending flush operations. *

    @@ -65,10 +65,17 @@ public class FlushConsolidationHandler extends ChannelDuplexHandler { private Future nextScheduledFlush; /** - * Create new instance which explicit flush after 256 pending flush operations latest. + * The default number of flushes after which a flush will be forwarded to downstream handlers (whether while in a + * read loop, or while batching outside of a read loop). + */ + public static final int DEFAULT_EXPLICIT_FLUSH_AFTER_FLUSHES = 256; + + /** + * Create new instance which explicit flush after {@value DEFAULT_EXPLICIT_FLUSH_AFTER_FLUSHES} pending flush + * operations at the latest. */ public FlushConsolidationHandler() { - this(256, false); + this(DEFAULT_EXPLICIT_FLUSH_AFTER_FLUSHES, false); } /** diff --git a/handler/src/main/java/io/netty/handler/ipfilter/UniqueIpFilter.java b/handler/src/main/java/io/netty/handler/ipfilter/UniqueIpFilter.java index ac346481d734..15099bb2a314 100644 --- a/handler/src/main/java/io/netty/handler/ipfilter/UniqueIpFilter.java +++ b/handler/src/main/java/io/netty/handler/ipfilter/UniqueIpFilter.java @@ -38,10 +38,9 @@ public class UniqueIpFilter extends AbstractRemoteAddressFilter> o2j = PlatformDependent.newConcurrentHashMap(); + private static final Map j2oTls13; + private static final Map> o2jTls13; + + static { + Map j2oTls13Map = new HashMap(); + j2oTls13Map.put("TLS_AES_128_GCM_SHA256", "AEAD-AES128-GCM-SHA256"); + j2oTls13Map.put("TLS_AES_256_GCM_SHA384", "AEAD-AES256-GCM-SHA384"); + j2oTls13Map.put("TLS_CHACHA20_POLY1305_SHA256", "AEAD-CHACHA20-POLY1305-SHA256"); + j2oTls13 = Collections.unmodifiableMap(j2oTls13Map); + + Map> o2jTls13Map = new HashMap>(); + o2jTls13Map.put("TLS_AES_128_GCM_SHA256", singletonMap("TLS", "TLS_AES_128_GCM_SHA256")); + o2jTls13Map.put("TLS_AES_256_GCM_SHA384", singletonMap("TLS", "TLS_AES_256_GCM_SHA384")); + o2jTls13Map.put("TLS_CHACHA20_POLY1305_SHA256", singletonMap("TLS", "TLS_CHACHA20_POLY1305_SHA256")); + o2jTls13Map.put("AEAD-AES128-GCM-SHA256", singletonMap("TLS", "TLS_AES_128_GCM_SHA256")); + o2jTls13Map.put("AEAD-AES256-GCM-SHA384", singletonMap("TLS", "TLS_AES_256_GCM_SHA384")); + o2jTls13Map.put("AEAD-CHACHA20-POLY1305-SHA256", singletonMap("TLS", "TLS_CHACHA20_POLY1305_SHA256")); + o2jTls13 = Collections.unmodifiableMap(o2jTls13Map); + } + /** * Clears the cache for testing purpose. */ @@ -122,49 +145,26 @@ static boolean isO2JCached(String key, String protocol, String value) { } } - /** - * Converts the specified Java cipher suites to the colon-separated OpenSSL cipher suite specification. - */ - static String toOpenSsl(Iterable javaCipherSuites) { - final StringBuilder buf = new StringBuilder(); - for (String c: javaCipherSuites) { - if (c == null) { - break; - } - - String converted = toOpenSsl(c); - if (converted != null) { - c = converted; - } - - buf.append(c); - buf.append(':'); - } - - if (buf.length() > 0) { - buf.setLength(buf.length() - 1); - return buf.toString(); - } else { - return ""; - } - } - /** * Converts the specified Java cipher suite to its corresponding OpenSSL cipher suite name. * * @return {@code null} if the conversion has failed */ - static String toOpenSsl(String javaCipherSuite) { + static String toOpenSsl(String javaCipherSuite, boolean boringSSL) { String converted = j2o.get(javaCipherSuite); if (converted != null) { return converted; - } else { - return cacheFromJava(javaCipherSuite); } + return cacheFromJava(javaCipherSuite, boringSSL); } - private static String cacheFromJava(String javaCipherSuite) { - String openSslCipherSuite = toOpenSslUncached(javaCipherSuite); + private static String cacheFromJava(String javaCipherSuite, boolean boringSSL) { + String converted = j2oTls13.get(javaCipherSuite); + if (converted != null) { + return boringSSL ? converted : javaCipherSuite; + } + + String openSslCipherSuite = toOpenSslUncached(javaCipherSuite, boringSSL); if (openSslCipherSuite == null) { return null; } @@ -185,7 +185,12 @@ private static String cacheFromJava(String javaCipherSuite) { return openSslCipherSuite; } - static String toOpenSslUncached(String javaCipherSuite) { + static String toOpenSslUncached(String javaCipherSuite, boolean boringSSL) { + String converted = j2oTls13.get(javaCipherSuite); + if (converted != null) { + return boringSSL ? converted : javaCipherSuite; + } + Matcher m = JAVA_CIPHERSUITE_PATTERN.matcher(javaCipherSuite); if (!m.matches()) { return null; @@ -287,14 +292,23 @@ static String toJava(String openSslCipherSuite, String protocol) { String javaCipherSuite = p2j.get(protocol); if (javaCipherSuite == null) { - javaCipherSuite = protocol + '_' + p2j.get(""); + String cipher = p2j.get(""); + if (cipher == null) { + return null; + } + javaCipherSuite = protocol + '_' + cipher; } return javaCipherSuite; } private static Map cacheFromOpenSsl(String openSslCipherSuite) { - String javaCipherSuiteSuffix = toJavaUncached(openSslCipherSuite); + Map converted = o2jTls13.get(openSslCipherSuite); + if (converted != null) { + return converted; + } + + String javaCipherSuiteSuffix = toJavaUncached0(openSslCipherSuite, false); if (javaCipherSuiteSuffix == null) { return null; } @@ -320,6 +334,17 @@ private static Map cacheFromOpenSsl(String openSslCipherSuite) { } static String toJavaUncached(String openSslCipherSuite) { + return toJavaUncached0(openSslCipherSuite, true); + } + + private static String toJavaUncached0(String openSslCipherSuite, boolean checkTls13) { + if (checkTls13) { + Map converted = o2jTls13.get(openSslCipherSuite); + if (converted != null) { + return converted.get("TLS"); + } + } + Matcher m = OPENSSL_CIPHERSUITE_PATTERN.matcher(openSslCipherSuite); if (!m.matches()) { return null; @@ -423,5 +448,47 @@ private static String toJavaHmacAlgo(String hmacAlgo) { return hmacAlgo; } + /** + * Convert the given ciphers if needed to OpenSSL format and append them to the correct {@link StringBuilder} + * depending on if its a TLSv1.3 cipher or not. If this methods returns without throwing an exception its + * guaranteed that at least one of the {@link StringBuilder}s contain some ciphers that can be used to configure + * OpenSSL. + */ + static void convertToCipherStrings(Iterable cipherSuites, StringBuilder cipherBuilder, + StringBuilder cipherTLSv13Builder, boolean boringSSL) { + for (String c: cipherSuites) { + if (c == null) { + break; + } + + String converted = toOpenSsl(c, boringSSL); + if (converted == null) { + converted = c; + } + + if (!OpenSsl.isCipherSuiteAvailable(converted)) { + throw new IllegalArgumentException("unsupported cipher suite: " + c + '(' + converted + ')'); + } + + if (SslUtils.isTLSv13Cipher(converted) || SslUtils.isTLSv13Cipher(c)) { + cipherTLSv13Builder.append(converted); + cipherTLSv13Builder.append(':'); + } else { + cipherBuilder.append(converted); + cipherBuilder.append(':'); + } + } + + if (cipherBuilder.length() == 0 && cipherTLSv13Builder.length() == 0) { + throw new IllegalArgumentException("empty cipher suites"); + } + if (cipherBuilder.length() > 0) { + cipherBuilder.setLength(cipherBuilder.length() - 1); + } + if (cipherTLSv13Builder.length() > 0) { + cipherTLSv13Builder.setLength(cipherTLSv13Builder.length() - 1); + } + } + private CipherSuiteConverter() { } } diff --git a/handler/src/main/java/io/netty/handler/ssl/ConscryptAlpnSslEngine.java b/handler/src/main/java/io/netty/handler/ssl/ConscryptAlpnSslEngine.java index b13dd7566c88..d9767a710627 100644 --- a/handler/src/main/java/io/netty/handler/ssl/ConscryptAlpnSslEngine.java +++ b/handler/src/main/java/io/netty/handler/ssl/ConscryptAlpnSslEngine.java @@ -71,7 +71,7 @@ private ConscryptAlpnSslEngine(SSLEngine engine, ByteBufAllocator alloc, List leakDetector = + ResourceLeakDetectorFactory.instance().newResourceLeakDetector(DefaultOpenSslKeyMaterial.class); + private final ResourceLeakTracker leak; + private final X509Certificate[] x509CertificateChain; + private long chain; + private long privateKey; + + DefaultOpenSslKeyMaterial(long chain, long privateKey, X509Certificate[] x509CertificateChain) { + this.chain = chain; + this.privateKey = privateKey; + this.x509CertificateChain = x509CertificateChain; + leak = leakDetector.track(this); + } + + @Override + public X509Certificate[] certificateChain() { + return x509CertificateChain.clone(); + } + + @Override + public long certificateChainAddress() { + if (refCnt() <= 0) { + throw new IllegalReferenceCountException(); + } + return chain; + } + + @Override + public long privateKeyAddress() { + if (refCnt() <= 0) { + throw new IllegalReferenceCountException(); + } + return privateKey; + } + + @Override + protected void deallocate() { + SSL.freeX509Chain(chain); + chain = 0; + SSL.freePrivateKey(privateKey); + privateKey = 0; + if (leak != null) { + boolean closed = leak.close(this); + assert closed; + } + } + + @Override + public DefaultOpenSslKeyMaterial retain() { + if (leak != null) { + leak.record(); + } + super.retain(); + return this; + } + + @Override + public DefaultOpenSslKeyMaterial retain(int increment) { + if (leak != null) { + leak.record(); + } + super.retain(increment); + return this; + } + + @Override + public DefaultOpenSslKeyMaterial touch() { + if (leak != null) { + leak.record(); + } + super.touch(); + return this; + } + + @Override + public DefaultOpenSslKeyMaterial touch(Object hint) { + if (leak != null) { + leak.record(hint); + } + return this; + } + + @Override + public boolean release() { + if (leak != null) { + leak.record(); + } + return super.release(); + } + + @Override + public boolean release(int decrement) { + if (leak != null) { + leak.record(); + } + return super.release(decrement); + } +} diff --git a/handler/src/main/java/io/netty/handler/ssl/DelegatingSslContext.java b/handler/src/main/java/io/netty/handler/ssl/DelegatingSslContext.java index af027982f230..c79da85458ee 100644 --- a/handler/src/main/java/io/netty/handler/ssl/DelegatingSslContext.java +++ b/handler/src/main/java/io/netty/handler/ssl/DelegatingSslContext.java @@ -21,6 +21,7 @@ import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLSessionContext; import java.util.List; +import java.util.concurrent.Executor; /** * Adapter class which allows to wrap another {@link SslContext} and init {@link SSLEngine} instances. @@ -86,6 +87,21 @@ protected final SslHandler newHandler(ByteBufAllocator alloc, String peerHost, i return handler; } + @Override + protected SslHandler newHandler(ByteBufAllocator alloc, boolean startTls, Executor executor) { + SslHandler handler = ctx.newHandler(alloc, startTls, executor); + initHandler(handler); + return handler; + } + + @Override + protected SslHandler newHandler(ByteBufAllocator alloc, String peerHost, int peerPort, + boolean startTls, Executor executor) { + SslHandler handler = ctx.newHandler(alloc, peerHost, peerPort, startTls, executor); + initHandler(handler); + return handler; + } + @Override public final SSLSessionContext sessionContext() { return ctx.sessionContext(); diff --git a/handler/src/main/java/io/netty/handler/ssl/ExtendedOpenSslSession.java b/handler/src/main/java/io/netty/handler/ssl/ExtendedOpenSslSession.java new file mode 100644 index 000000000000..184845a9fe1f --- /dev/null +++ b/handler/src/main/java/io/netty/handler/ssl/ExtendedOpenSslSession.java @@ -0,0 +1,178 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import javax.net.ssl.ExtendedSSLSession; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSessionContext; +import javax.security.cert.X509Certificate; +import java.security.Principal; +import java.security.cert.Certificate; +import java.util.Collections; +import java.util.List; + +/** + * Delegates all operations to a wrapped {@link OpenSslSession} except the methods defined by {@link ExtendedSSLSession} + * itself. + */ +abstract class ExtendedOpenSslSession extends ExtendedSSLSession implements OpenSslSession { + + // TODO: use OpenSSL API to actually fetch the real data but for now just do what Conscrypt does: + // https://github.com/google/conscrypt/blob/1.2.0/common/ + // src/main/java/org/conscrypt/Java7ExtendedSSLSession.java#L32 + private static final String[] LOCAL_SUPPORTED_SIGNATURE_ALGORITHMS = { + "SHA512withRSA", "SHA512withECDSA", "SHA384withRSA", "SHA384withECDSA", "SHA256withRSA", + "SHA256withECDSA", "SHA224withRSA", "SHA224withECDSA", "SHA1withRSA", "SHA1withECDSA", + }; + + private final OpenSslSession wrapped; + + ExtendedOpenSslSession(OpenSslSession wrapped) { + this.wrapped = wrapped; + } + + // Use rawtypes an unchecked override to be able to also work on java7. + @SuppressWarnings({ "unchecked", "rawtypes" }) + public abstract List getRequestedServerNames(); + + // Do not mark as override so we can compile on java8. + public List getStatusResponses() { + // Just return an empty list for now until we support it as otherwise we will fail in java9 + // because of their sun.security.ssl.X509TrustManagerImpl class. + return Collections.emptyList(); + } + + @Override + public final void handshakeFinished() throws SSLException { + wrapped.handshakeFinished(); + } + + @Override + public final void tryExpandApplicationBufferSize(int packetLengthDataOnly) { + wrapped.tryExpandApplicationBufferSize(packetLengthDataOnly); + } + + @Override + public final String[] getLocalSupportedSignatureAlgorithms() { + return LOCAL_SUPPORTED_SIGNATURE_ALGORITHMS.clone(); + } + + @Override + public final byte[] getId() { + return wrapped.getId(); + } + + @Override + public final SSLSessionContext getSessionContext() { + return wrapped.getSessionContext(); + } + + @Override + public final long getCreationTime() { + return wrapped.getCreationTime(); + } + + @Override + public final long getLastAccessedTime() { + return wrapped.getLastAccessedTime(); + } + + @Override + public final void invalidate() { + wrapped.invalidate(); + } + + @Override + public final boolean isValid() { + return wrapped.isValid(); + } + + @Override + public final void putValue(String s, Object o) { + wrapped.putValue(s, o); + } + + @Override + public final Object getValue(String s) { + return wrapped.getValue(s); + } + + @Override + public final void removeValue(String s) { + wrapped.removeValue(s); + } + + @Override + public final String[] getValueNames() { + return wrapped.getValueNames(); + } + + @Override + public final Certificate[] getPeerCertificates() throws SSLPeerUnverifiedException { + return wrapped.getPeerCertificates(); + } + + @Override + public final Certificate[] getLocalCertificates() { + return wrapped.getLocalCertificates(); + } + + @Override + public final X509Certificate[] getPeerCertificateChain() throws SSLPeerUnverifiedException { + return wrapped.getPeerCertificateChain(); + } + + @Override + public final Principal getPeerPrincipal() throws SSLPeerUnverifiedException { + return wrapped.getPeerPrincipal(); + } + + @Override + public final Principal getLocalPrincipal() { + return wrapped.getLocalPrincipal(); + } + + @Override + public final String getCipherSuite() { + return wrapped.getCipherSuite(); + } + + @Override + public String getProtocol() { + return wrapped.getProtocol(); + } + + @Override + public final String getPeerHost() { + return wrapped.getPeerHost(); + } + + @Override + public final int getPeerPort() { + return wrapped.getPeerPort(); + } + + @Override + public final int getPacketBufferSize() { + return wrapped.getPacketBufferSize(); + } + + @Override + public final int getApplicationBufferSize() { + return wrapped.getApplicationBufferSize(); + } +} diff --git a/handler/src/main/java/io/netty/handler/ssl/IdentityCipherSuiteFilter.java b/handler/src/main/java/io/netty/handler/ssl/IdentityCipherSuiteFilter.java index cfd5bb7b2a9f..9f2198103470 100644 --- a/handler/src/main/java/io/netty/handler/ssl/IdentityCipherSuiteFilter.java +++ b/handler/src/main/java/io/netty/handler/ssl/IdentityCipherSuiteFilter.java @@ -46,8 +46,8 @@ public String[] filterCipherSuites(Iterable ciphers, List defaul Set supportedCiphers) { if (ciphers == null) { return defaultToDefaultCiphers ? - defaultCiphers.toArray(new String[defaultCiphers.size()]) : - supportedCiphers.toArray(new String[supportedCiphers.size()]); + defaultCiphers.toArray(new String[0]) : + supportedCiphers.toArray(new String[0]); } else { List newCiphers = new ArrayList(supportedCiphers.size()); for (String c : ciphers) { @@ -56,7 +56,7 @@ public String[] filterCipherSuites(Iterable ciphers, List defaul } newCiphers.add(c); } - return newCiphers.toArray(new String[newCiphers.size()]); + return newCiphers.toArray(new String[0]); } } } diff --git a/handler/src/main/java/io/netty/handler/ssl/Java8SslUtils.java b/handler/src/main/java/io/netty/handler/ssl/Java8SslUtils.java index 583d4cf49867..c47d96571218 100644 --- a/handler/src/main/java/io/netty/handler/ssl/Java8SslUtils.java +++ b/handler/src/main/java/io/netty/handler/ssl/Java8SslUtils.java @@ -48,11 +48,25 @@ static List getSniHostNames(SSLParameters sslParameters) { } static void setSniHostNames(SSLParameters sslParameters, List names) { + sslParameters.setServerNames(getSniHostNames(names)); + } + + static List getSniHostNames(List names) { + if (names == null || names.isEmpty()) { + return Collections.emptyList(); + } List sniServerNames = new ArrayList(names.size()); for (String name: names) { sniServerNames.add(new SNIHostName(name)); } - sslParameters.setServerNames(sniServerNames); + return sniServerNames; + } + + static List getSniHostName(byte[] hostname) { + if (hostname == null || hostname.length == 0) { + return Collections.emptyList(); + } + return Collections.singletonList(new SNIHostName(hostname)); } static boolean getUseCipherSuitesOrder(SSLParameters sslParameters) { @@ -69,7 +83,7 @@ static void setSNIMatchers(SSLParameters sslParameters, Collection matchers) } @SuppressWarnings("unchecked") - static boolean checkSniHostnameMatch(Collection matchers, String hostname) { + static boolean checkSniHostnameMatch(Collection matchers, byte[] hostname) { if (matchers != null && !matchers.isEmpty()) { SNIHostName name = new SNIHostName(hostname); Iterator matcherIt = (Iterator) matchers.iterator(); diff --git a/handler/src/main/java/io/netty/handler/ssl/JdkSslContext.java b/handler/src/main/java/io/netty/handler/ssl/JdkSslContext.java index 41fd9c6f6a8f..d74bbdec981f 100644 --- a/handler/src/main/java/io/netty/handler/ssl/JdkSslContext.java +++ b/handler/src/main/java/io/netty/handler/ssl/JdkSslContext.java @@ -17,6 +17,7 @@ package io.netty.handler.ssl; import io.netty.buffer.ByteBufAllocator; +import io.netty.util.ReferenceCountUtil; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -26,6 +27,7 @@ import java.security.KeyException; import java.security.KeyStoreException; import java.security.NoSuchAlgorithmException; +import java.security.Provider; import java.security.Security; import java.security.UnrecoverableKeyException; import java.security.cert.CertificateException; @@ -34,6 +36,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.HashSet; +import java.util.LinkedHashSet; import java.util.List; import java.util.Set; @@ -58,11 +61,13 @@ public class JdkSslContext extends SslContext { static final String PROTOCOL = "TLS"; private static final String[] DEFAULT_PROTOCOLS; private static final List DEFAULT_CIPHERS; + private static final List DEFAULT_CIPHERS_NON_TLSV13; private static final Set SUPPORTED_CIPHERS; + private static final Set SUPPORTED_CIPHERS_NON_TLSV13; + private static final Provider DEFAULT_PROVIDER; static { SSLContext context; - int i; try { context = SSLContext.getInstance(PROTOCOL); context.init(null, null, null); @@ -70,31 +75,54 @@ public class JdkSslContext extends SslContext { throw new Error("failed to initialize the default SSL context", e); } + DEFAULT_PROVIDER = context.getProvider(); + SSLEngine engine = context.createSSLEngine(); + DEFAULT_PROTOCOLS = defaultProtocols(engine); + + SUPPORTED_CIPHERS = Collections.unmodifiableSet(supportedCiphers(engine)); + DEFAULT_CIPHERS = Collections.unmodifiableList(defaultCiphers(engine, SUPPORTED_CIPHERS)); + + List ciphersNonTLSv13 = new ArrayList(DEFAULT_CIPHERS); + ciphersNonTLSv13.removeAll(Arrays.asList(SslUtils.DEFAULT_TLSV13_CIPHER_SUITES)); + DEFAULT_CIPHERS_NON_TLSV13 = Collections.unmodifiableList(ciphersNonTLSv13); + + Set suppertedCiphersNonTLSv13 = new LinkedHashSet(SUPPORTED_CIPHERS); + suppertedCiphersNonTLSv13.removeAll(Arrays.asList(SslUtils.DEFAULT_TLSV13_CIPHER_SUITES)); + SUPPORTED_CIPHERS_NON_TLSV13 = Collections.unmodifiableSet(suppertedCiphersNonTLSv13); + + if (logger.isDebugEnabled()) { + logger.debug("Default protocols (JDK): {} ", Arrays.asList(DEFAULT_PROTOCOLS)); + logger.debug("Default cipher suites (JDK): {}", DEFAULT_CIPHERS); + } + } + private static String[] defaultProtocols(SSLEngine engine) { // Choose the sensible default list of protocols. final String[] supportedProtocols = engine.getSupportedProtocols(); Set supportedProtocolsSet = new HashSet(supportedProtocols.length); - for (i = 0; i < supportedProtocols.length; ++i) { + for (int i = 0; i < supportedProtocols.length; ++i) { supportedProtocolsSet.add(supportedProtocols[i]); } List protocols = new ArrayList(); addIfSupported( supportedProtocolsSet, protocols, - "TLSv1.2", "TLSv1.1", "TLSv1"); + // Do not include TLSv1.3 for now by default. + SslUtils.PROTOCOL_TLS_V1_2, SslUtils.PROTOCOL_TLS_V1_1, SslUtils.PROTOCOL_TLS_V1); if (!protocols.isEmpty()) { - DEFAULT_PROTOCOLS = protocols.toArray(new String[protocols.size()]); - } else { - DEFAULT_PROTOCOLS = engine.getEnabledProtocols(); + return protocols.toArray(new String[0]); } + return engine.getEnabledProtocols(); + } + private static Set supportedCiphers(SSLEngine engine) { // Choose the sensible default list of cipher suites. final String[] supportedCiphers = engine.getSupportedCipherSuites(); - SUPPORTED_CIPHERS = new HashSet(supportedCiphers.length); - for (i = 0; i < supportedCiphers.length; ++i) { + Set supportedCiphersSet = new LinkedHashSet(supportedCiphers.length); + for (int i = 0; i < supportedCiphers.length; ++i) { String supportedCipher = supportedCiphers[i]; - SUPPORTED_CIPHERS.add(supportedCipher); + supportedCiphersSet.add(supportedCipher); // IBM's J9 JVM utilizes a custom naming scheme for ciphers and only returns ciphers with the "SSL_" // prefix instead of the "TLS_" prefix (as defined in the JSSE cipher suite names [1]). According to IBM's // documentation [2] the "SSL_" prefix is "interchangeable" with the "TLS_" prefix. @@ -108,21 +136,29 @@ public class JdkSslContext extends SslContext { final String tlsPrefixedCipherName = "TLS_" + supportedCipher.substring("SSL_".length()); try { engine.setEnabledCipherSuites(new String[]{tlsPrefixedCipherName}); - SUPPORTED_CIPHERS.add(tlsPrefixedCipherName); + supportedCiphersSet.add(tlsPrefixedCipherName); } catch (IllegalArgumentException ignored) { // The cipher is not supported ... move on to the next cipher. } } } + return supportedCiphersSet; + } + + private static List defaultCiphers(SSLEngine engine, Set supportedCiphers) { List ciphers = new ArrayList(); - addIfSupported(SUPPORTED_CIPHERS, ciphers, DEFAULT_CIPHER_SUITES); + addIfSupported(supportedCiphers, ciphers, DEFAULT_CIPHER_SUITES); useFallbackCiphersIfDefaultIsEmpty(ciphers, engine.getEnabledCipherSuites()); - DEFAULT_CIPHERS = Collections.unmodifiableList(ciphers); + return ciphers; + } - if (logger.isDebugEnabled()) { - logger.debug("Default protocols (JDK): {} ", Arrays.asList(DEFAULT_PROTOCOLS)); - logger.debug("Default cipher suites (JDK): {}", DEFAULT_CIPHERS); + private static boolean isTlsV13Supported(String[] protocols) { + for (String protocol: protocols) { + if (SslUtils.PROTOCOL_TLS_V1_3.equals(protocol)) { + return true; + } } + return false; } private final String[] protocols; @@ -140,7 +176,10 @@ public class JdkSslContext extends SslContext { * @param sslContext the {@link SSLContext} to use. * @param isClient {@code true} if this context should create {@link SSLEngine}s for client-side usage. * @param clientAuth the {@link ClientAuth} to use. This will only be used when {@param isClient} is {@code false}. + * @deprecated Use {@link #JdkSslContext(SSLContext, boolean, Iterable, CipherSuiteFilter, + * ApplicationProtocolConfig, ClientAuth, String[], boolean)} */ + @Deprecated public JdkSslContext(SSLContext sslContext, boolean isClient, ClientAuth clientAuth) { this(sslContext, isClient, null, IdentityCipherSuiteFilter.INSTANCE, @@ -156,11 +195,44 @@ public JdkSslContext(SSLContext sslContext, boolean isClient, * @param cipherFilter the filter to use. * @param apn the {@link ApplicationProtocolConfig} to use. * @param clientAuth the {@link ClientAuth} to use. This will only be used when {@param isClient} is {@code false}. + * @deprecated Use {@link #JdkSslContext(SSLContext, boolean, Iterable, CipherSuiteFilter, + * ApplicationProtocolConfig, ClientAuth, String[], boolean)} */ + @Deprecated public JdkSslContext(SSLContext sslContext, boolean isClient, Iterable ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn, ClientAuth clientAuth) { - this(sslContext, isClient, ciphers, cipherFilter, toNegotiator(apn, !isClient), clientAuth, null, false); + this(sslContext, isClient, ciphers, cipherFilter, apn, clientAuth, null, false); + } + + /** + * Creates a new {@link JdkSslContext} from a pre-configured {@link SSLContext}. + * + * @param sslContext the {@link SSLContext} to use. + * @param isClient {@code true} if this context should create {@link SSLEngine}s for client-side usage. + * @param ciphers the ciphers to use or {@code null} if the standard should be used. + * @param cipherFilter the filter to use. + * @param apn the {@link ApplicationProtocolConfig} to use. + * @param clientAuth the {@link ClientAuth} to use. This will only be used when {@param isClient} is {@code false}. + * @param protocols the protocols to enable, or {@code null} to enable the default protocols. + * @param startTls {@code true} if the first write request shouldn't be encrypted + */ + public JdkSslContext(SSLContext sslContext, + boolean isClient, + Iterable ciphers, + CipherSuiteFilter cipherFilter, + ApplicationProtocolConfig apn, + ClientAuth clientAuth, + String[] protocols, + boolean startTls) { + this(sslContext, + isClient, + ciphers, + cipherFilter, + toNegotiator(apn, !isClient), + clientAuth, + protocols == null ? null : protocols.clone(), + startTls); } @SuppressWarnings("deprecation") @@ -169,11 +241,49 @@ public JdkSslContext(SSLContext sslContext, boolean isClient, Iterable c super(startTls); this.apn = checkNotNull(apn, "apn"); this.clientAuth = checkNotNull(clientAuth, "clientAuth"); + this.sslContext = checkNotNull(sslContext, "sslContext"); + + final List defaultCiphers; + final Set supportedCiphers; + if (DEFAULT_PROVIDER.equals(sslContext.getProvider())) { + this.protocols = protocols == null? DEFAULT_PROTOCOLS : protocols; + if (isTlsV13Supported(this.protocols)) { + supportedCiphers = SUPPORTED_CIPHERS; + defaultCiphers = DEFAULT_CIPHERS; + } else { + // TLSv1.3 is not supported, ensure we do not include any TLSv1.3 ciphersuite. + supportedCiphers = SUPPORTED_CIPHERS_NON_TLSV13; + defaultCiphers = DEFAULT_CIPHERS_NON_TLSV13; + } + } else { + // This is a different Provider then the one used by the JDK by default so we can not just assume + // the same protocols and ciphers are supported. For example even if Java11+ is used Conscrypt will + // not support TLSv1.3 and the TLSv1.3 ciphersuites. + SSLEngine engine = sslContext.createSSLEngine(); + try { + if (protocols == null) { + this.protocols = defaultProtocols(engine); + } else { + this.protocols = protocols; + } + supportedCiphers = supportedCiphers(engine); + defaultCiphers = defaultCiphers(engine, supportedCiphers); + if (!isTlsV13Supported(this.protocols)) { + // TLSv1.3 is not supported, ensure we do not include any TLSv1.3 ciphersuite. + for (String cipher: SslUtils.DEFAULT_TLSV13_CIPHER_SUITES) { + supportedCiphers.remove(cipher); + defaultCiphers.remove(cipher); + } + } + } finally { + ReferenceCountUtil.release(engine); + } + } + cipherSuites = checkNotNull(cipherFilter, "cipherFilter").filterCipherSuites( - ciphers, DEFAULT_CIPHERS, SUPPORTED_CIPHERS); - this.protocols = protocols == null ? DEFAULT_PROTOCOLS : protocols; + ciphers, defaultCiphers, supportedCiphers); + unmodifiableCipherSuites = Collections.unmodifiableList(Arrays.asList(cipherSuites)); - this.sslContext = checkNotNull(sslContext, "sslContext"); this.isClient = isClient; } diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSsl.java b/handler/src/main/java/io/netty/handler/ssl/OpenSsl.java index 165b5cbc896e..78a528ff985d 100644 --- a/handler/src/main/java/io/netty/handler/ssl/OpenSsl.java +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSsl.java @@ -17,7 +17,7 @@ package io.netty.handler.ssl; import io.netty.buffer.ByteBuf; -import io.netty.handler.ssl.util.SelfSignedCertificate; +import io.netty.buffer.ByteBufAllocator; import io.netty.internal.tcnative.Buffer; import io.netty.internal.tcnative.Library; import io.netty.internal.tcnative.SSL; @@ -30,24 +30,19 @@ import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; +import java.io.ByteArrayInputStream; import java.security.AccessController; import java.security.PrivilegedAction; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.LinkedHashSet; import java.util.List; import java.util.Set; -import static io.netty.handler.ssl.SslUtils.DEFAULT_CIPHER_SUITES; -import static io.netty.handler.ssl.SslUtils.addIfSupported; -import static io.netty.handler.ssl.SslUtils.useFallbackCiphersIfDefaultIsEmpty; -import static io.netty.handler.ssl.SslUtils.PROTOCOL_SSL_V2; -import static io.netty.handler.ssl.SslUtils.PROTOCOL_SSL_V2_HELLO; -import static io.netty.handler.ssl.SslUtils.PROTOCOL_SSL_V3; -import static io.netty.handler.ssl.SslUtils.PROTOCOL_TLS_V1; -import static io.netty.handler.ssl.SslUtils.PROTOCOL_TLS_V1_1; -import static io.netty.handler.ssl.SslUtils.PROTOCOL_TLS_V1_2; +import static io.netty.handler.ssl.SslUtils.*; /** * Tells if {@code netty-tcnative} and its OpenSSL support @@ -66,7 +61,8 @@ public final class OpenSsl { private static final boolean SUPPORTS_HOSTNAME_VALIDATION; private static final boolean USE_KEYMANAGER_FACTORY; private static final boolean SUPPORTS_OCSP; - + private static final boolean TLSV13_SUPPORTED; + private static final boolean IS_BORINGSSL; static final Set SUPPORTED_PROTOCOLS_SET; static { @@ -105,7 +101,13 @@ public final class OpenSsl { } try { - initializeTcNative(); + String engine = SystemPropertyUtil.get("io.netty.handler.ssl.openssl.engine", null); + if (engine == null) { + logger.debug("Initialize netty-tcnative using engine: 'default'"); + } else { + logger.debug("Initialize netty-tcnative using engine: '{}'", engine); + } + initializeTcNative(engine); // The library was initialized successfully. If loading the library failed above, // reset the cause now since it appears that the library was loaded by some other @@ -133,22 +135,59 @@ public final class OpenSsl { boolean supportsKeyManagerFactory = false; boolean useKeyManagerFactory = false; boolean supportsHostNameValidation = false; + boolean tlsv13Supported = false; + + IS_BORINGSSL = "BoringSSL".equals(versionString()); + try { final long sslCtx = SSLContext.make(SSL.SSL_PROTOCOL_ALL, SSL.SSL_MODE_SERVER); long certBio = 0; - SelfSignedCertificate cert = null; try { - SSLContext.setCipherSuite(sslCtx, "ALL"); + try { + StringBuilder tlsv13Ciphers = new StringBuilder(); + + for (String cipher: TLSV13_CIPHERS) { + String converted = CipherSuiteConverter.toOpenSsl(cipher, IS_BORINGSSL); + if (converted != null) { + tlsv13Ciphers.append(converted).append(':'); + } + } + if (tlsv13Ciphers.length() == 0) { + tlsv13Supported = false; + } else { + tlsv13Ciphers.setLength(tlsv13Ciphers.length() - 1); + SSLContext.setCipherSuite(sslCtx, tlsv13Ciphers.toString() , true); + tlsv13Supported = true; + } + + } catch (Exception ignore) { + tlsv13Supported = false; + } + + SSLContext.setCipherSuite(sslCtx, "ALL", false); + final long ssl = SSL.newSSL(sslCtx, true); try { for (String c: SSL.getCiphers(ssl)) { // Filter out bad input. - if (c == null || c.isEmpty() || availableOpenSslCipherSuites.contains(c)) { + if (c == null || c.isEmpty() || availableOpenSslCipherSuites.contains(c) || + // Filter out TLSv1.3 ciphers if not supported. + !tlsv13Supported && isTLSv13Cipher(c)) { continue; } availableOpenSslCipherSuites.add(c); } - + if (IS_BORINGSSL) { + // Currently BoringSSL does not include these when calling SSL.getCiphers() even when these + // are supported. + Collections.addAll(availableOpenSslCipherSuites, + "TLS_AES_128_GCM_SHA256", + "TLS_AES_256_GCM_SHA384" , + "TLS_CHACHA20_POLY1305_SHA256", + "AEAD-AES128-GCM-SHA256", + "AEAD-AES256-GCM-SHA384", + "AEAD-CHACHA20-POLY1305-SHA256"); + } try { SSL.setHostNameValidation(ssl, 0, "netty.io"); supportsHostNameValidation = true; @@ -156,8 +195,8 @@ public final class OpenSsl { logger.debug("Hostname Verification not supported."); } try { - cert = new SelfSignedCertificate(); - certBio = ReferenceCountedOpenSslContext.toBIO(cert.cert()); + X509Certificate certificate = selfSignedCertificate(); + certBio = ReferenceCountedOpenSslContext.toBIO(ByteBufAllocator.DEFAULT, certificate); SSL.setCertificateChainBio(ssl, certBio, false); supportsKeyManagerFactory = true; try { @@ -179,9 +218,6 @@ public Boolean run() { if (certBio != 0) { SSL.freeBIO(certBio); } - if (cert != null) { - cert.delete(); - } } } finally { SSLContext.free(sslCtx); @@ -194,11 +230,18 @@ public Boolean run() { AVAILABLE_OPENSSL_CIPHER_SUITES.size() * 2); for (String cipher: AVAILABLE_OPENSSL_CIPHER_SUITES) { // Included converted but also openssl cipher name - availableJavaCipherSuites.add(CipherSuiteConverter.toJava(cipher, "TLS")); - availableJavaCipherSuites.add(CipherSuiteConverter.toJava(cipher, "SSL")); + if (!isTLSv13Cipher(cipher)) { + availableJavaCipherSuites.add(CipherSuiteConverter.toJava(cipher, "TLS")); + availableJavaCipherSuites.add(CipherSuiteConverter.toJava(cipher, "SSL")); + } else { + // TLSv1.3 ciphers have the correct format. + availableJavaCipherSuites.add(cipher); + } } addIfSupported(availableJavaCipherSuites, defaultCiphers, DEFAULT_CIPHER_SUITES); + addIfSupported(availableJavaCipherSuites, defaultCiphers, TLSV13_CIPHER_SUITES); + useFallbackCiphersIfDefaultIsEmpty(defaultCiphers, availableJavaCipherSuites); DEFAULT_CIPHERS = Collections.unmodifiableList(defaultCiphers); @@ -217,27 +260,35 @@ public Boolean run() { Set protocols = new LinkedHashSet(6); // Seems like there is no way to explicitly disable SSLv2Hello in openssl so it is always enabled protocols.add(PROTOCOL_SSL_V2_HELLO); - if (doesSupportProtocol(SSL.SSL_PROTOCOL_SSLV2)) { + if (doesSupportProtocol(SSL.SSL_PROTOCOL_SSLV2, SSL.SSL_OP_NO_SSLv2)) { protocols.add(PROTOCOL_SSL_V2); } - if (doesSupportProtocol(SSL.SSL_PROTOCOL_SSLV3)) { + if (doesSupportProtocol(SSL.SSL_PROTOCOL_SSLV3, SSL.SSL_OP_NO_SSLv3)) { protocols.add(PROTOCOL_SSL_V3); } - if (doesSupportProtocol(SSL.SSL_PROTOCOL_TLSV1)) { + if (doesSupportProtocol(SSL.SSL_PROTOCOL_TLSV1, SSL.SSL_OP_NO_TLSv1)) { protocols.add(PROTOCOL_TLS_V1); } - if (doesSupportProtocol(SSL.SSL_PROTOCOL_TLSV1_1)) { + if (doesSupportProtocol(SSL.SSL_PROTOCOL_TLSV1_1, SSL.SSL_OP_NO_TLSv1_1)) { protocols.add(PROTOCOL_TLS_V1_1); } - if (doesSupportProtocol(SSL.SSL_PROTOCOL_TLSV1_2)) { + if (doesSupportProtocol(SSL.SSL_PROTOCOL_TLSV1_2, SSL.SSL_OP_NO_TLSv1_2)) { protocols.add(PROTOCOL_TLS_V1_2); } + // This is only supported by java11 and later. + if (tlsv13Supported && doesSupportProtocol(SSL.SSL_PROTOCOL_TLSV1_3, SSL.SSL_OP_NO_TLSv1_3)) { + protocols.add(PROTOCOL_TLS_V1_3); + TLSV13_SUPPORTED = true; + } else { + TLSV13_SUPPORTED = false; + } + SUPPORTED_PROTOCOLS_SET = Collections.unmodifiableSet(protocols); SUPPORTS_OCSP = doesSupportOcsp(); if (logger.isDebugEnabled()) { - logger.debug("Supported protocols (OpenSSL): {} ", Arrays.asList(SUPPORTED_PROTOCOLS_SET)); + logger.debug("Supported protocols (OpenSSL): {} ", SUPPORTED_PROTOCOLS_SET); logger.debug("Default cipher suites (OpenSSL): {}", DEFAULT_CIPHERS); } } else { @@ -250,9 +301,50 @@ public Boolean run() { USE_KEYMANAGER_FACTORY = false; SUPPORTED_PROTOCOLS_SET = Collections.emptySet(); SUPPORTS_OCSP = false; + TLSV13_SUPPORTED = false; + IS_BORINGSSL = false; } } + /** + * Returns a self-signed {@link X509Certificate} for {@code netty.io}. + */ + static X509Certificate selfSignedCertificate() throws CertificateException { + // Bytes of self-signed certificate for netty.io + byte[] certBytes = { + 48, -126, 1, -92, 48, -126, 1, 13, -96, 3, 2, 1, 2, 2, 9, 0, -9, 61, + 44, 121, -118, -4, -45, -120, 48, 13, 6, 9, 42, -122, 72, -122, + -9, 13, 1, 1, 5, 5, 0, 48, 19, 49, 17, 48, 15, 6, 3, 85, 4, 3, 19, + 8, 110, 101, 116, 116, 121, 46, 105, 111, 48, 32, 23, 13, 49, 55, + 49, 48, 50, 48, 49, 56, 49, 54, 51, 54, 90, 24, 15, 57, 57, 57, 57, + 49, 50, 51, 49, 50, 51, 53, 57, 53, 57, 90, 48, 19, 49, 17, 48, 15, + 6, 3, 85, 4, 3, 19, 8, 110, 101, 116, 116, 121, 46, 105, 111, 48, -127, + -97, 48, 13, 6, 9, 42, -122, 72, -122, -9, 13, 1, 1, 1, 5, 0, 3, -127, + -115, 0, 48, -127, -119, 2, -127, -127, 0, -116, 37, 122, -53, 28, 46, + 13, -90, -14, -33, 111, -108, -41, 59, 90, 124, 113, -112, -66, -17, + -102, 44, 13, 7, -33, -28, 24, -79, -126, -76, 40, 111, -126, -103, + -102, 34, 11, 45, 16, -38, 63, 24, 80, 24, 76, 88, -93, 96, 11, 38, + -19, -64, -11, 87, -49, -52, -65, 24, 36, -22, 53, 8, -42, 14, -121, + 114, 6, 17, -82, 10, 92, -91, -127, 81, -12, -75, 105, -10, -106, 91, + -38, 111, 50, 57, -97, -125, 109, 42, -87, -1, -19, 80, 78, 49, -97, -4, + 23, -2, -103, 122, -107, -43, 4, -31, -21, 90, 39, -9, -106, 34, -101, + -116, 31, -94, -84, 80, -6, -78, -33, 87, -90, 31, 103, 100, 56, -103, + -5, 11, 2, 3, 1, 0, 1, 48, 13, 6, 9, 42, -122, 72, -122, -9, 13, 1, 1, + 5, 5, 0, 3, -127, -127, 0, 112, 45, -73, 5, 64, 49, 59, 101, 51, 73, + -96, 62, 23, -84, 90, -41, -58, 83, -20, -72, 38, 123, -108, -45, 28, + 96, -122, -18, 30, 42, 86, 87, -87, -28, 107, 110, 11, -59, 91, 100, + 101, -18, 26, -103, -78, -80, -3, 38, 113, 83, -48, -108, 109, 41, -15, + 6, 112, 105, 7, -46, -11, -3, -51, 40, -66, -73, -83, -46, -94, -121, + -88, 51, -106, -77, 109, 53, -7, 123, 91, 75, -105, -22, 64, 121, -72, + -59, -21, -44, 84, 12, 9, 120, 21, -26, 13, 49, -81, -58, -47, 117, + -44, -18, -17, 124, 49, -48, 19, 16, -41, 71, -52, -107, 99, -19, -29, + 105, -93, -71, -38, -97, -128, -2, 118, 119, 49, -126, 109, 119 }; + + CertificateFactory cf = CertificateFactory.getInstance("X.509"); + return (X509Certificate) cf.generateCertificate( + new ByteArrayInputStream(certBytes)); + } + private static boolean doesSupportOcsp() { boolean supportsOcsp = false; if (version() >= 0x10002000L) { @@ -271,7 +363,11 @@ private static boolean doesSupportOcsp() { } return supportsOcsp; } - private static boolean doesSupportProtocol(int protocol) { + private static boolean doesSupportProtocol(int protocol, int opt) { + if (opt == 0) { + // If the opt is 0 the protocol is not supported. This is for example the case with BoringSSL and SSLv2. + return false; + } long sslCtx = -1; try { sslCtx = SSLContext.make(protocol, SSL.SSL_MODE_COMBINED); @@ -377,7 +473,7 @@ public static Set availableJavaCipherSuites() { * Both Java-style cipher suite and OpenSSL-style cipher suite are accepted. */ public static boolean isCipherSuiteAvailable(String cipherSuite) { - String converted = CipherSuiteConverter.toOpenSsl(cipherSuite); + String converted = CipherSuiteConverter.toOpenSsl(cipherSuite, IS_BORINGSSL); if (converted != null) { cipherSuite = converted; } @@ -428,11 +524,11 @@ private static void loadTcNative() throws Exception { libNames.add(staticLibName); NativeLibraryLoader.loadFirstAvailable(SSL.class.getClassLoader(), - libNames.toArray(new String[libNames.size()])); + libNames.toArray(new String[0])); } - private static boolean initializeTcNative() throws Exception { - return Library.initialize(); + private static boolean initializeTcNative(String engine) throws Exception { + return Library.initialize("provided", engine); } static void releaseIfNeeded(ReferenceCounted counted) { @@ -440,4 +536,12 @@ static void releaseIfNeeded(ReferenceCounted counted) { ReferenceCountUtil.safeRelease(counted); } } + + static boolean isTlsv13Supported() { + return TLSV13_SUPPORTED; + } + + static boolean isBoringSSL() { + return IS_BORINGSSL; + } } diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslCachingKeyMaterialProvider.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslCachingKeyMaterialProvider.java new file mode 100644 index 000000000000..db8779bdbd7a --- /dev/null +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslCachingKeyMaterialProvider.java @@ -0,0 +1,68 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBufAllocator; + +import javax.net.ssl.X509KeyManager; +import java.util.Iterator; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +/** + * {@link OpenSslKeyMaterialProvider} that will cache the {@link OpenSslKeyMaterial} to reduce the overhead + * of parsing the chain and the key for generation of the material. + */ +final class OpenSslCachingKeyMaterialProvider extends OpenSslKeyMaterialProvider { + + private final ConcurrentMap cache = new ConcurrentHashMap(); + + OpenSslCachingKeyMaterialProvider(X509KeyManager keyManager, String password) { + super(keyManager, password); + } + + @Override + OpenSslKeyMaterial chooseKeyMaterial(ByteBufAllocator allocator, String alias) throws Exception { + OpenSslKeyMaterial material = cache.get(alias); + if (material == null) { + material = super.chooseKeyMaterial(allocator, alias); + if (material == null) { + // No keymaterial should be used. + return null; + } + + OpenSslKeyMaterial old = cache.putIfAbsent(alias, material); + if (old != null) { + material.release(); + material = old; + } + } + // We need to call retain() as we want to always have at least a refCnt() of 1 before destroy() was called. + return material.retain(); + } + + @Override + void destroy() { + // Remove and release all entries. + do { + Iterator iterator = cache.values().iterator(); + while (iterator.hasNext()) { + iterator.next().release(); + iterator.remove(); + } + } while (!cache.isEmpty()); + } +} diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslCachingX509KeyManagerFactory.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslCachingX509KeyManagerFactory.java new file mode 100644 index 000000000000..6581d9b73244 --- /dev/null +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslCachingX509KeyManagerFactory.java @@ -0,0 +1,60 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import javax.net.ssl.KeyManager; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.KeyManagerFactorySpi; +import javax.net.ssl.ManagerFactoryParameters; +import javax.net.ssl.X509KeyManager; +import java.security.InvalidAlgorithmParameterException; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.PrivateKey; +import java.security.UnrecoverableKeyException; +import java.security.cert.X509Certificate; + +/** + * Wraps another {@link KeyManagerFactory} and caches its chains / certs for an alias for better performance when using + * {@link SslProvider#OPENSSL} or {@link SslProvider#OPENSSL_REFCNT}. + * + * Because of the caching its important that the wrapped {@link KeyManagerFactory}s {@link X509KeyManager}s always + * return the same {@link X509Certificate} chain and {@link PrivateKey} for the same alias. + */ +public final class OpenSslCachingX509KeyManagerFactory extends KeyManagerFactory { + + public OpenSslCachingX509KeyManagerFactory(final KeyManagerFactory factory) { + super(new KeyManagerFactorySpi() { + @Override + protected void engineInit(KeyStore keyStore, char[] chars) + throws KeyStoreException, NoSuchAlgorithmException, UnrecoverableKeyException { + factory.init(keyStore, chars); + } + + @Override + protected void engineInit(ManagerFactoryParameters managerFactoryParameters) + throws InvalidAlgorithmParameterException { + factory.init(managerFactoryParameters); + } + + @Override + protected KeyManager[] engineGetKeyManagers() { + return factory.getKeyManagers(); + } + }, factory.getProvider(), factory.getAlgorithm()); + } +} diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslCertificateException.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslCertificateException.java index 4672d00787b9..f20b2d3ba08e 100644 --- a/handler/src/main/java/io/netty/handler/ssl/OpenSslCertificateException.java +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslCertificateException.java @@ -70,7 +70,9 @@ public int errorCode() { } private static int checkErrorCode(int errorCode) { - if (!CertificateVerifier.isValid(errorCode)) { + // Call OpenSsl.isAvailable() to ensure we try to load the native lib as CertificateVerifier.isValid(...) + // will depend on it. If loading fails we will just skip the validation. + if (OpenSsl.isAvailable() && !CertificateVerifier.isValid(errorCode)) { throw new IllegalArgumentException("errorCode '" + errorCode + "' invalid, see https://www.openssl.org/docs/man1.0.2/apps/verify.html."); } diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslClientContext.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslClientContext.java index 46412e9f526b..6856c7f2746c 100644 --- a/handler/src/main/java/io/netty/handler/ssl/OpenSslClientContext.java +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslClientContext.java @@ -203,9 +203,4 @@ public OpenSslClientContext(File trustCertCollectionFile, TrustManagerFactory tr public OpenSslSessionContext sessionContext() { return sessionContext; } - - @Override - OpenSslKeyMaterialManager keyMaterialManager() { - return null; - } } diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslExtendedKeyMaterialManager.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslExtendedKeyMaterialManager.java deleted file mode 100644 index 38f6a7f723b9..000000000000 --- a/handler/src/main/java/io/netty/handler/ssl/OpenSslExtendedKeyMaterialManager.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright 2016 The Netty Project - * - * The Netty Project licenses this file to you under the Apache License, - * version 2.0 (the "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at: - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - */ -package io.netty.handler.ssl; - -import javax.net.ssl.X509ExtendedKeyManager; -import javax.security.auth.x500.X500Principal; - -final class OpenSslExtendedKeyMaterialManager extends OpenSslKeyMaterialManager { - - private final X509ExtendedKeyManager keyManager; - - OpenSslExtendedKeyMaterialManager(X509ExtendedKeyManager keyManager, String password) { - super(keyManager, password); - this.keyManager = keyManager; - } - - @Override - protected String chooseClientAlias(ReferenceCountedOpenSslEngine engine, String[] keyTypes, - X500Principal[] issuer) { - return keyManager.chooseEngineClientAlias(keyTypes, issuer, engine); - } - - @Override - protected String chooseServerAlias(ReferenceCountedOpenSslEngine engine, String type) { - return keyManager.chooseEngineServerAlias(type, null, engine); - } -} diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslJavaxX509Certificate.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslJavaxX509Certificate.java index da10dedf0c1a..af52ddc92ae4 100644 --- a/handler/src/main/java/io/netty/handler/ssl/OpenSslJavaxX509Certificate.java +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslJavaxX509Certificate.java @@ -32,7 +32,7 @@ final class OpenSslJavaxX509Certificate extends X509Certificate { private final byte[] bytes; private X509Certificate wrapped; - public OpenSslJavaxX509Certificate(byte[] bytes) { + OpenSslJavaxX509Certificate(byte[] bytes) { this.bytes = bytes; } diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterial.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterial.java new file mode 100644 index 000000000000..68fc85a3ed76 --- /dev/null +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterial.java @@ -0,0 +1,59 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.util.ReferenceCounted; + +import java.security.cert.X509Certificate; + +/** + * Holds references to the native key-material that is used by OpenSSL. + */ +interface OpenSslKeyMaterial extends ReferenceCounted { + + /** + * Returns the configured {@link X509Certificate}s. + */ + X509Certificate[] certificateChain(); + + /** + * Returns the pointer to the {@code STACK_OF(X509)} which holds the certificate chain. + */ + long certificateChainAddress(); + + /** + * Returns the pointer to the {@code EVP_PKEY}. + */ + long privateKeyAddress(); + + @Override + OpenSslKeyMaterial retain(); + + @Override + OpenSslKeyMaterial retain(int increment); + + @Override + OpenSslKeyMaterial touch(); + + @Override + OpenSslKeyMaterial touch(Object hint); + + @Override + boolean release(); + + @Override + boolean release(int decrement); +} diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterialManager.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterialManager.java index 2e48e8b04be8..9f0e0199efb2 100644 --- a/handler/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterialManager.java +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterialManager.java @@ -15,11 +15,10 @@ */ package io.netty.handler.ssl; -import io.netty.buffer.ByteBufAllocator; -import io.netty.internal.tcnative.CertificateRequestedCallback; import io.netty.internal.tcnative.SSL; import javax.net.ssl.SSLException; +import javax.net.ssl.X509ExtendedKeyManager; import javax.net.ssl.X509KeyManager; import javax.security.auth.x500.X500Principal; import java.security.PrivateKey; @@ -29,14 +28,12 @@ import java.util.Map; import java.util.Set; -import static io.netty.handler.ssl.ReferenceCountedOpenSslContext.freeBio; -import static io.netty.handler.ssl.ReferenceCountedOpenSslContext.toBIO; /** * Manages key material for {@link OpenSslEngine}s and so set the right {@link PrivateKey}s and * {@link X509Certificate}s. */ -class OpenSslKeyMaterialManager { +final class OpenSslKeyMaterialManager { // Code in this class is inspired by code of conscrypts: // - https://android.googlesource.com/platform/external/ @@ -62,15 +59,13 @@ class OpenSslKeyMaterialManager { KEY_TYPES.put("DH_RSA", KEY_TYPE_DH_RSA); } - private final X509KeyManager keyManager; - private final String password; + private final OpenSslKeyMaterialProvider provider; - OpenSslKeyMaterialManager(X509KeyManager keyManager, String password) { - this.keyManager = keyManager; - this.password = password; + OpenSslKeyMaterialManager(OpenSslKeyMaterialProvider provider) { + this.provider = provider; } - void setKeyMaterial(ReferenceCountedOpenSslEngine engine) throws SSLException { + void setKeyMaterialServerSide(ReferenceCountedOpenSslEngine engine) throws SSLException { long ssl = engine.sslPointer(); String[] authMethods = SSL.authenticationMethods(ssl); Set aliases = new HashSet(authMethods.length); @@ -79,101 +74,54 @@ void setKeyMaterial(ReferenceCountedOpenSslEngine engine) throws SSLException { if (type != null) { String alias = chooseServerAlias(engine, type); if (alias != null && aliases.add(alias)) { - setKeyMaterial(ssl, alias); + setKeyMaterial(engine, alias); } } } } - CertificateRequestedCallback.KeyMaterial keyMaterial(ReferenceCountedOpenSslEngine engine, String[] keyTypes, - X500Principal[] issuer) throws SSLException { + void setKeyMaterialClientSide(ReferenceCountedOpenSslEngine engine, String[] keyTypes, + X500Principal[] issuer) throws SSLException { String alias = chooseClientAlias(engine, keyTypes, issuer); - long keyBio = 0; - long keyCertChainBio = 0; - long pkey = 0; - long certChain = 0; - - try { - // TODO: Should we cache these and so not need to do a memory copy all the time ? - X509Certificate[] certificates = keyManager.getCertificateChain(alias); - if (certificates == null || certificates.length == 0) { - return null; - } - - PrivateKey key = keyManager.getPrivateKey(alias); - keyCertChainBio = toBIO(certificates); - certChain = SSL.parseX509Chain(keyCertChainBio); - if (key != null) { - keyBio = toBIO(key); - pkey = SSL.parsePrivateKey(keyBio, password); - } - CertificateRequestedCallback.KeyMaterial material = new CertificateRequestedCallback.KeyMaterial( - certChain, pkey); - - // Reset to 0 so we do not free these. This is needed as the client certificate callback takes ownership - // of both the key and the certificate if they are returned from this method, and thus must not - // be freed here. - certChain = pkey = 0; - return material; - } catch (SSLException e) { - throw e; - } catch (Exception e) { - throw new SSLException(e); - } finally { - freeBio(keyBio); - freeBio(keyCertChainBio); - SSL.freePrivateKey(pkey); - SSL.freeX509Chain(certChain); + // Only try to set the keymaterial if we have a match. This is also consistent with what OpenJDK does: + // http://hg.openjdk.java.net/jdk/jdk11/file/76072a077ee1/ + // src/java.base/share/classes/sun/security/ssl/CertificateRequest.java#l362 + if (alias != null) { + setKeyMaterial(engine, alias); } } - private void setKeyMaterial(long ssl, String alias) throws SSLException { - long keyBio = 0; - long keyCertChainBio = 0; - long keyCertChainBio2 = 0; - + private void setKeyMaterial(ReferenceCountedOpenSslEngine engine, String alias) throws SSLException { + OpenSslKeyMaterial keyMaterial = null; try { - // TODO: Should we cache these and so not need to do a memory copy all the time ? - X509Certificate[] certificates = keyManager.getCertificateChain(alias); - if (certificates == null || certificates.length == 0) { - return; - } - - PrivateKey key = keyManager.getPrivateKey(alias); - - // Only encode one time - PemEncoded encoded = PemX509Certificate.toPEM(ByteBufAllocator.DEFAULT, true, certificates); - try { - keyCertChainBio = toBIO(ByteBufAllocator.DEFAULT, encoded.retain()); - keyCertChainBio2 = toBIO(ByteBufAllocator.DEFAULT, encoded.retain()); - - if (key != null) { - keyBio = toBIO(key); - } - SSL.setCertificateBio(ssl, keyCertChainBio, keyBio, password); - - // We may have more then one cert in the chain so add all of them now. - SSL.setCertificateChainBio(ssl, keyCertChainBio2, true); - } finally { - encoded.release(); + keyMaterial = provider.chooseKeyMaterial(engine.alloc, alias); + if (keyMaterial != null) { + engine.setKeyMaterial(keyMaterial); } } catch (SSLException e) { throw e; } catch (Exception e) { throw new SSLException(e); } finally { - freeBio(keyBio); - freeBio(keyCertChainBio); - freeBio(keyCertChainBio2); + if (keyMaterial != null) { + keyMaterial.release(); + } } } - - protected String chooseClientAlias(@SuppressWarnings("unused") ReferenceCountedOpenSslEngine engine, + private String chooseClientAlias(ReferenceCountedOpenSslEngine engine, String[] keyTypes, X500Principal[] issuer) { - return keyManager.chooseClientAlias(keyTypes, issuer, null); + X509KeyManager manager = provider.keyManager(); + if (manager instanceof X509ExtendedKeyManager) { + return ((X509ExtendedKeyManager) manager).chooseEngineClientAlias(keyTypes, issuer, engine); + } + return manager.chooseClientAlias(keyTypes, issuer, null); } - protected String chooseServerAlias(@SuppressWarnings("unused") ReferenceCountedOpenSslEngine engine, String type) { - return keyManager.chooseServerAlias(type, null, null); + private String chooseServerAlias(ReferenceCountedOpenSslEngine engine, String type) { + X509KeyManager manager = provider.keyManager(); + if (manager instanceof X509ExtendedKeyManager) { + return ((X509ExtendedKeyManager) manager).chooseEngineServerAlias(type, null, engine); + } + return manager.chooseServerAlias(type, null, null); } } diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterialProvider.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterialProvider.java new file mode 100644 index 000000000000..72cd2e0c8b4f --- /dev/null +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterialProvider.java @@ -0,0 +1,100 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.internal.tcnative.SSL; + +import javax.net.ssl.X509KeyManager; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; + +import static io.netty.handler.ssl.ReferenceCountedOpenSslContext.toBIO; + +/** + * Provides {@link OpenSslKeyMaterial} for a given alias. + */ +class OpenSslKeyMaterialProvider { + + private final X509KeyManager keyManager; + private final String password; + + OpenSslKeyMaterialProvider(X509KeyManager keyManager, String password) { + this.keyManager = keyManager; + this.password = password; + } + + /** + * Returns the underlying {@link X509KeyManager} that is used. + */ + X509KeyManager keyManager() { + return keyManager; + } + + /** + * Returns the {@link OpenSslKeyMaterial} or {@code null} (if none) that should be used during the handshake by + * OpenSSL. + */ + OpenSslKeyMaterial chooseKeyMaterial(ByteBufAllocator allocator, String alias) throws Exception { + X509Certificate[] certificates = keyManager.getCertificateChain(alias); + if (certificates == null || certificates.length == 0) { + return null; + } + + PrivateKey key = keyManager.getPrivateKey(alias); + PemEncoded encoded = PemX509Certificate.toPEM(allocator, true, certificates); + long chainBio = 0; + long pkeyBio = 0; + long chain = 0; + long pkey = 0; + try { + chainBio = toBIO(allocator, encoded.retain()); + chain = SSL.parseX509Chain(chainBio); + + OpenSslKeyMaterial keyMaterial; + if (key instanceof OpenSslPrivateKey) { + keyMaterial = ((OpenSslPrivateKey) key).toKeyMaterial(chain, certificates); + } else { + pkeyBio = toBIO(allocator, key); + pkey = key == null ? 0 : SSL.parsePrivateKey(pkeyBio, password); + keyMaterial = new DefaultOpenSslKeyMaterial(chain, pkey, certificates); + } + + // See the chain and pkey to 0 so we will not release it as the ownership was + // transferred to OpenSslKeyMaterial. + chain = 0; + pkey = 0; + return keyMaterial; + } finally { + SSL.freeBIO(chainBio); + SSL.freeBIO(pkeyBio); + if (chain != 0) { + SSL.freeX509Chain(chain); + } + if (pkey != 0) { + SSL.freePrivateKey(pkey); + } + encoded.release(); + } + } + + /** + * Will be invoked once the provider should be destroyed. + */ + void destroy() { + // NOOP. + } +} diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslPrivateKey.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslPrivateKey.java new file mode 100644 index 000000000000..67639aae3ca6 --- /dev/null +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslPrivateKey.java @@ -0,0 +1,200 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.internal.tcnative.SSL; +import io.netty.util.AbstractReferenceCounted; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.internal.EmptyArrays; + +import javax.security.auth.Destroyable; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; + +final class OpenSslPrivateKey extends AbstractReferenceCounted implements PrivateKey { + + private long privateKeyAddress; + + OpenSslPrivateKey(long privateKeyAddress) { + this.privateKeyAddress = privateKeyAddress; + } + + @Override + public String getAlgorithm() { + return "unkown"; + } + + @Override + public String getFormat() { + // As we do not support encoding we should return null as stated in the javadocs of PrivateKey. + return null; + } + + @Override + public byte[] getEncoded() { + return null; + } + + /** + * Returns the pointer to the {@code EVP_PKEY}. + */ + long privateKeyAddress() { + if (refCnt() <= 0) { + throw new IllegalReferenceCountException(); + } + return privateKeyAddress; + } + + @Override + protected void deallocate() { + SSL.freePrivateKey(privateKeyAddress); + privateKeyAddress = 0; + } + + @Override + public OpenSslPrivateKey retain() { + super.retain(); + return this; + } + + @Override + public OpenSslPrivateKey retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public OpenSslPrivateKey touch() { + super.touch(); + return this; + } + + @Override + public OpenSslPrivateKey touch(Object hint) { + return this; + } + + /** + * NOTE: This is a JDK8 interface/method. Due to backwards compatibility + * reasons it's not possible to slap the {@code @Override} annotation onto + * this method. + * + * @see Destroyable#destroy() + */ + public void destroy() { + release(refCnt()); + } + + /** + * NOTE: This is a JDK8 interface/method. Due to backwards compatibility + * reasons it's not possible to slap the {@code @Override} annotation onto + * this method. + * + * @see Destroyable#isDestroyed() + */ + public boolean isDestroyed() { + return refCnt() == 0; + } + + /** + * Convert to a {@link OpenSslKeyMaterial}. Reference count of both is shared. + */ + OpenSslKeyMaterial toKeyMaterial(long certificateChain, X509Certificate[] chain) { + return new OpenSslPrivateKeyMaterial(certificateChain, chain); + } + + private final class OpenSslPrivateKeyMaterial implements OpenSslKeyMaterial { + + private long certificateChain; + private final X509Certificate[] x509CertificateChain; + + OpenSslPrivateKeyMaterial(long certificateChain, X509Certificate[] x509CertificateChain) { + this.certificateChain = certificateChain; + this.x509CertificateChain = x509CertificateChain == null ? + EmptyArrays.EMPTY_X509_CERTIFICATES : x509CertificateChain; + } + + @Override + public X509Certificate[] certificateChain() { + return x509CertificateChain.clone(); + } + + @Override + public long certificateChainAddress() { + if (refCnt() <= 0) { + throw new IllegalReferenceCountException(); + } + return certificateChain; + } + + @Override + public long privateKeyAddress() { + return OpenSslPrivateKey.this.privateKeyAddress(); + } + + @Override + public OpenSslKeyMaterial retain() { + OpenSslPrivateKey.this.retain(); + return this; + } + + @Override + public OpenSslKeyMaterial retain(int increment) { + OpenSslPrivateKey.this.retain(increment); + return this; + } + + @Override + public OpenSslKeyMaterial touch() { + OpenSslPrivateKey.this.touch(); + return this; + } + + @Override + public OpenSslKeyMaterial touch(Object hint) { + OpenSslPrivateKey.this.touch(hint); + return this; + } + + @Override + public boolean release() { + if (OpenSslPrivateKey.this.release()) { + releaseChain(); + return true; + } + return false; + } + + @Override + public boolean release(int decrement) { + if (OpenSslPrivateKey.this.release(decrement)) { + releaseChain(); + return true; + } + return false; + } + + private void releaseChain() { + SSL.freeX509Chain(certificateChain); + certificateChain = 0; + } + + @Override + public int refCnt() { + return OpenSslPrivateKey.this.refCnt(); + } + } +} diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslServerContext.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslServerContext.java index f57434b133c0..e27b05ae8b9c 100644 --- a/handler/src/main/java/io/netty/handler/ssl/OpenSslServerContext.java +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslServerContext.java @@ -15,7 +15,6 @@ */ package io.netty.handler.ssl; -import io.netty.handler.ssl.ReferenceCountedOpenSslServerContext.ServerContext; import io.netty.internal.tcnative.SSL; import java.io.File; @@ -37,7 +36,6 @@ */ public final class OpenSslServerContext extends OpenSslContext { private final OpenSslServerSessionContext sessionContext; - private final OpenSslKeyMaterialManager keyMaterialManager; /** * Creates a new instance. @@ -349,10 +347,8 @@ private OpenSslServerContext( // Create a new SSL_CTX and configure it. boolean success = false; try { - ServerContext context = newSessionContext(this, ctx, engineMap, trustCertCollection, trustManagerFactory, - keyCertChain, key, keyPassword, keyManagerFactory); - sessionContext = context.sessionContext; - keyMaterialManager = context.keyMaterialManager; + sessionContext = newSessionContext(this, ctx, engineMap, trustCertCollection, trustManagerFactory, + keyCertChain, key, keyPassword, keyManagerFactory); success = true; } finally { if (!success) { @@ -365,9 +361,4 @@ private OpenSslServerContext( public OpenSslServerSessionContext sessionContext() { return sessionContext; } - - @Override - OpenSslKeyMaterialManager keyMaterialManager() { - return keyMaterialManager; - } } diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslServerSessionContext.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslServerSessionContext.java index 8c92debfecde..691ee0b661bf 100644 --- a/handler/src/main/java/io/netty/handler/ssl/OpenSslServerSessionContext.java +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslServerSessionContext.java @@ -25,8 +25,8 @@ * {@link OpenSslSessionContext} implementation which offers extra methods which are only useful for the server-side. */ public final class OpenSslServerSessionContext extends OpenSslSessionContext { - OpenSslServerSessionContext(ReferenceCountedOpenSslContext context) { - super(context); + OpenSslServerSessionContext(ReferenceCountedOpenSslContext context, OpenSslKeyMaterialProvider provider) { + super(context, provider); } @Override diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslSession.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslSession.java new file mode 100644 index 000000000000..f9fde2662e98 --- /dev/null +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslSession.java @@ -0,0 +1,36 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLSession; + +interface OpenSslSession extends SSLSession { + + /** + * Finish the handshake and so init everything in the {@link OpenSslSession} that should be accessible by + * the user. + */ + void handshakeFinished() throws SSLException; + + /** + * Expand (or increase) the value returned by {@link #getApplicationBufferSize()} if necessary. + *

    + * This is only called in a synchronized block, so no need to use atomic operations. + * @param packetLengthDataOnly The packet size which exceeds the current {@link #getApplicationBufferSize()}. + */ + void tryExpandApplicationBufferSize(int packetLengthDataOnly); +} diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslSessionContext.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslSessionContext.java index 846a968735d8..9faefb138011 100644 --- a/handler/src/main/java/io/netty/handler/ssl/OpenSslSessionContext.java +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslSessionContext.java @@ -34,14 +34,21 @@ public abstract class OpenSslSessionContext implements SSLSessionContext { private static final Enumeration EMPTY = new EmptyEnumeration(); private final OpenSslSessionStats stats; + + // The OpenSslKeyMaterialProvider is not really used by the OpenSslSessionContext but only be stored here + // to make it easier to destroy it later because the ReferenceCountedOpenSslContext will hold a reference + // to OpenSslSessionContext. + private final OpenSslKeyMaterialProvider provider; + final ReferenceCountedOpenSslContext context; // IMPORTANT: We take the OpenSslContext and not just the long (which points the native instance) to prevent // the GC to collect OpenSslContext as this would also free the pointer and so could result in a // segfault when the user calls any of the methods here that try to pass the pointer down to the native // level. - OpenSslSessionContext(ReferenceCountedOpenSslContext context) { + OpenSslSessionContext(ReferenceCountedOpenSslContext context, OpenSslKeyMaterialProvider provider) { this.context = context; + this.provider = provider; stats = new OpenSslSessionStats(context); } @@ -123,6 +130,12 @@ public OpenSslSessionStats stats() { return stats; } + final void destroy() { + if (provider != null) { + provider.destroy(); + } + } + private static final class EmptyEnumeration implements Enumeration { @Override public boolean hasMoreElements() { diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslTlsv13X509ExtendedTrustManager.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslTlsv13X509ExtendedTrustManager.java new file mode 100644 index 000000000000..00c6886e9ad4 --- /dev/null +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslTlsv13X509ExtendedTrustManager.java @@ -0,0 +1,498 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.PlatformDependent; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.HandshakeStatus; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSessionContext; +import javax.net.ssl.X509ExtendedTrustManager; +import java.net.Socket; +import java.nio.ByteBuffer; +import java.security.Principal; +import java.security.cert.Certificate; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.List; + +/** + * Provide a way to use {@code TLSv1.3} with Java versions prior to 11 by adding a + * = 7 && session instanceof ExtendedOpenSslSession) { + final ExtendedOpenSslSession extendedOpenSslSession = (ExtendedOpenSslSession) session; + return new ExtendedOpenSslSession(extendedOpenSslSession) { + @Override + public List getRequestedServerNames() { + return extendedOpenSslSession.getRequestedServerNames(); + } + + @Override + public String[] getPeerSupportedSignatureAlgorithms() { + return extendedOpenSslSession.getPeerSupportedSignatureAlgorithms(); + } + + @Override + public String getProtocol() { + return SslUtils.PROTOCOL_TLS_V1_2; + } + }; + } else { + return new SSLSession() { + @Override + public byte[] getId() { + return session.getId(); + } + + @Override + public SSLSessionContext getSessionContext() { + return session.getSessionContext(); + } + + @Override + public long getCreationTime() { + return session.getCreationTime(); + } + + @Override + public long getLastAccessedTime() { + return session.getLastAccessedTime(); + } + + @Override + public void invalidate() { + session.invalidate(); + } + + @Override + public boolean isValid() { + return session.isValid(); + } + + @Override + public void putValue(String s, Object o) { + session.putValue(s, o); + } + + @Override + public Object getValue(String s) { + return session.getValue(s); + } + + @Override + public void removeValue(String s) { + session.removeValue(s); + } + + @Override + public String[] getValueNames() { + return session.getValueNames(); + } + + @Override + public Certificate[] getPeerCertificates() throws SSLPeerUnverifiedException { + return session.getPeerCertificates(); + } + + @Override + public Certificate[] getLocalCertificates() { + return session.getLocalCertificates(); + } + + @Override + public javax.security.cert.X509Certificate[] getPeerCertificateChain() + throws SSLPeerUnverifiedException { + return session.getPeerCertificateChain(); + } + + @Override + public Principal getPeerPrincipal() throws SSLPeerUnverifiedException { + return session.getPeerPrincipal(); + } + + @Override + public Principal getLocalPrincipal() { + return session.getLocalPrincipal(); + } + + @Override + public String getCipherSuite() { + return session.getCipherSuite(); + } + + @Override + public String getProtocol() { + return SslUtils.PROTOCOL_TLS_V1_2; + } + + @Override + public String getPeerHost() { + return session.getPeerHost(); + } + + @Override + public int getPeerPort() { + return session.getPeerPort(); + } + + @Override + public int getPacketBufferSize() { + return session.getPacketBufferSize(); + } + + @Override + public int getApplicationBufferSize() { + return session.getApplicationBufferSize(); + } + }; + } + } + }; + } + return engine; + } + + @Override + public void checkClientTrusted(X509Certificate[] x509Certificates, final String s, SSLEngine sslEngine) + throws CertificateException { + tm.checkClientTrusted(x509Certificates, s, wrapEngine(sslEngine)); + } + + @Override + public void checkServerTrusted(X509Certificate[] x509Certificates, String s, SSLEngine sslEngine) + throws CertificateException { + tm.checkServerTrusted(x509Certificates, s, wrapEngine(sslEngine)); + } + + @Override + public void checkClientTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException { + tm.checkClientTrusted(x509Certificates, s); + } + + @Override + public void checkServerTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException { + tm.checkServerTrusted(x509Certificates, s); + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return tm.getAcceptedIssuers(); + } + + private static final class DummySSLEngine extends SSLEngine { + + private final boolean client; + + DummySSLEngine(boolean client) { + this.client = client; + } + + @Override + public SSLSession getHandshakeSession() { + return new SSLSession() { + @Override + public byte[] getId() { + return EmptyArrays.EMPTY_BYTES; + } + + @Override + public SSLSessionContext getSessionContext() { + return null; + } + + @Override + public long getCreationTime() { + return 0; + } + + @Override + public long getLastAccessedTime() { + return 0; + } + + @Override + public void invalidate() { + // NOOP + } + + @Override + public boolean isValid() { + return false; + } + + @Override + public void putValue(String s, Object o) { + // NOOP + } + + @Override + public Object getValue(String s) { + return null; + } + + @Override + public void removeValue(String s) { + // NOOP + } + + @Override + public String[] getValueNames() { + return EmptyArrays.EMPTY_STRINGS; + } + + @Override + public Certificate[] getPeerCertificates() throws SSLPeerUnverifiedException { + return EmptyArrays.EMPTY_CERTIFICATES; + } + + @Override + public Certificate[] getLocalCertificates() { + return EmptyArrays.EMPTY_CERTIFICATES; + } + + @Override + public javax.security.cert.X509Certificate[] getPeerCertificateChain() + throws SSLPeerUnverifiedException { + return EmptyArrays.EMPTY_JAVAX_X509_CERTIFICATES; + } + + @Override + public Principal getPeerPrincipal() throws SSLPeerUnverifiedException { + return null; + } + + @Override + public Principal getLocalPrincipal() { + return null; + } + + @Override + public String getCipherSuite() { + return null; + } + + @Override + public String getProtocol() { + return SslUtils.PROTOCOL_TLS_V1_3; + } + + @Override + public String getPeerHost() { + return null; + } + + @Override + public int getPeerPort() { + return 0; + } + + @Override + public int getPacketBufferSize() { + return 0; + } + + @Override + public int getApplicationBufferSize() { + return 0; + } + }; + } + + @Override + public SSLEngineResult wrap(ByteBuffer[] byteBuffers, int i, int i1, ByteBuffer byteBuffer) + throws SSLException { + throw new UnsupportedOperationException(); + } + + @Override + public SSLEngineResult unwrap(ByteBuffer byteBuffer, ByteBuffer[] byteBuffers, int i, int i1) + throws SSLException { + throw new UnsupportedOperationException(); + } + + @Override + public Runnable getDelegatedTask() { + return null; + } + + @Override + public void closeInbound() throws SSLException { + // NOOP + } + + @Override + public boolean isInboundDone() { + return true; + } + + @Override + public void closeOutbound() { + // NOOP + } + + @Override + public boolean isOutboundDone() { + return true; + } + + @Override + public String[] getSupportedCipherSuites() { + return EmptyArrays.EMPTY_STRINGS; + } + + @Override + public String[] getEnabledCipherSuites() { + return EmptyArrays.EMPTY_STRINGS; + } + + @Override + public void setEnabledCipherSuites(String[] strings) { + // NOOP + } + + @Override + public String[] getSupportedProtocols() { + return new String[] { SslUtils.PROTOCOL_TLS_V1_3 }; + } + + @Override + public String[] getEnabledProtocols() { + return new String[] { SslUtils.PROTOCOL_TLS_V1_3 }; + } + + @Override + public void setEnabledProtocols(String[] strings) { + // NOOP + } + + @Override + public SSLSession getSession() { + return getHandshakeSession(); + } + + @Override + public void beginHandshake() throws SSLException { + // NOOP + } + + @Override + public HandshakeStatus getHandshakeStatus() { + return HandshakeStatus.NEED_TASK; + } + + @Override + public void setUseClientMode(boolean b) { + // NOOP + } + + @Override + public boolean getUseClientMode() { + return client; + } + + @Override + public void setNeedClientAuth(boolean b) { + // NOOP + } + + @Override + public boolean getNeedClientAuth() { + return false; + } + + @Override + public void setWantClientAuth(boolean b) { + // NOOP + } + + @Override + public boolean getWantClientAuth() { + return false; + } + + @Override + public void setEnableSessionCreation(boolean b) { + // NOOP + } + + @Override + public boolean getEnableSessionCreation() { + return false; + } + } +} diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslX509Certificate.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslX509Certificate.java index 77d0713613a9..52dbbdfa8d64 100644 --- a/handler/src/main/java/io/netty/handler/ssl/OpenSslX509Certificate.java +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslX509Certificate.java @@ -15,20 +15,25 @@ */ package io.netty.handler.ssl; +import javax.security.auth.x500.X500Principal; import java.io.ByteArrayInputStream; import java.math.BigInteger; import java.security.InvalidKeyException; import java.security.NoSuchAlgorithmException; import java.security.NoSuchProviderException; import java.security.Principal; +import java.security.Provider; import java.security.PublicKey; import java.security.SignatureException; import java.security.cert.CertificateEncodingException; import java.security.cert.CertificateException; import java.security.cert.CertificateExpiredException; import java.security.cert.CertificateNotYetValidException; +import java.security.cert.CertificateParsingException; import java.security.cert.X509Certificate; +import java.util.Collection; import java.util.Date; +import java.util.List; import java.util.Set; final class OpenSslX509Certificate extends X509Certificate { @@ -36,7 +41,7 @@ final class OpenSslX509Certificate extends X509Certificate { private final byte[] bytes; private X509Certificate wrapped; - public OpenSslX509Certificate(byte[] bytes) { + OpenSslX509Certificate(byte[] bytes) { this.bytes = bytes; } @@ -50,6 +55,37 @@ public void checkValidity(Date date) throws CertificateExpiredException, Certifi unwrap().checkValidity(date); } + @Override + public X500Principal getIssuerX500Principal() { + return unwrap().getIssuerX500Principal(); + } + + @Override + public X500Principal getSubjectX500Principal() { + return unwrap().getSubjectX500Principal(); + } + + @Override + public List getExtendedKeyUsage() throws CertificateParsingException { + return unwrap().getExtendedKeyUsage(); + } + + @Override + public Collection> getSubjectAlternativeNames() throws CertificateParsingException { + return unwrap().getSubjectAlternativeNames(); + } + + @Override + public Collection> getIssuerAlternativeNames() throws CertificateParsingException { + return unwrap().getSubjectAlternativeNames(); + } + + // No @Override annotation as it was only introduced in Java8. + public void verify(PublicKey key, Provider sigProvider) + throws CertificateException, NoSuchAlgorithmException, InvalidKeyException, SignatureException { + unwrap().verify(key, sigProvider); + } + @Override public int getVersion() { return unwrap().getVersion(); diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslX509KeyManagerFactory.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslX509KeyManagerFactory.java new file mode 100644 index 000000000000..90d94cb656cc --- /dev/null +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslX509KeyManagerFactory.java @@ -0,0 +1,373 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.internal.tcnative.SSL; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.ObjectUtil; + +import javax.net.ssl.KeyManager; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.KeyManagerFactorySpi; +import javax.net.ssl.ManagerFactoryParameters; +import javax.net.ssl.X509KeyManager; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.security.InvalidAlgorithmParameterException; +import java.security.Key; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.KeyStoreSpi; +import java.security.NoSuchAlgorithmException; +import java.security.PrivateKey; +import java.security.Provider; +import java.security.UnrecoverableKeyException; +import java.security.cert.Certificate; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.Collections; +import java.util.Date; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.Map; + +/** + * Special {@link KeyManagerFactory} that pre-compute the keymaterial used when {@link SslProvider#OPENSSL} or + * {@link SslProvider#OPENSSL_REFCNT} is used and so will improve handshake times and its performance. + * + * Because the keymaterial is pre-computed any modification to the {@link KeyStore} is ignored after + * {@link #init(KeyStore, char[])} is called. + * + * {@link #init(ManagerFactoryParameters)} is not supported by this implementation and so a call to it will always + * result in an {@link InvalidAlgorithmParameterException}. + */ +public final class OpenSslX509KeyManagerFactory extends KeyManagerFactory { + + private final OpenSslKeyManagerFactorySpi spi; + + public OpenSslX509KeyManagerFactory() { + this(newOpenSslKeyManagerFactorySpi(null)); + } + + public OpenSslX509KeyManagerFactory(Provider provider) { + this(newOpenSslKeyManagerFactorySpi(provider)); + } + + public OpenSslX509KeyManagerFactory(String algorithm, Provider provider) throws NoSuchAlgorithmException { + this(newOpenSslKeyManagerFactorySpi(algorithm, provider)); + } + + private OpenSslX509KeyManagerFactory(OpenSslKeyManagerFactorySpi spi) { + super(spi, spi.kmf.getProvider(), spi.kmf.getAlgorithm()); + this.spi = spi; + } + + private static OpenSslKeyManagerFactorySpi newOpenSslKeyManagerFactorySpi(Provider provider) { + try { + return newOpenSslKeyManagerFactorySpi(null, provider); + } catch (NoSuchAlgorithmException e) { + // This should never happen as we use the default algorithm. + throw new IllegalStateException(e); + } + } + + private static OpenSslKeyManagerFactorySpi newOpenSslKeyManagerFactorySpi(String algorithm, Provider provider) + throws NoSuchAlgorithmException { + if (algorithm == null) { + algorithm = KeyManagerFactory.getDefaultAlgorithm(); + } + return new OpenSslKeyManagerFactorySpi( + provider == null ? KeyManagerFactory.getInstance(algorithm) : + KeyManagerFactory.getInstance(algorithm, provider)); + } + + OpenSslKeyMaterialProvider newProvider() { + return spi.newProvider(); + } + + private static final class OpenSslKeyManagerFactorySpi extends KeyManagerFactorySpi { + final KeyManagerFactory kmf; + private volatile ProviderFactory providerFactory; + + OpenSslKeyManagerFactorySpi(KeyManagerFactory kmf) { + this.kmf = ObjectUtil.checkNotNull(kmf, "kmf"); + } + + @Override + protected synchronized void engineInit(KeyStore keyStore, char[] chars) + throws KeyStoreException, NoSuchAlgorithmException, UnrecoverableKeyException { + if (providerFactory != null) { + throw new KeyStoreException("Already initialized"); + } + if (!keyStore.aliases().hasMoreElements()) { + throw new KeyStoreException("No aliases found"); + } + + kmf.init(keyStore, chars); + providerFactory = new ProviderFactory(ReferenceCountedOpenSslContext.chooseX509KeyManager( + kmf.getKeyManagers()), password(chars), Collections.list(keyStore.aliases())); + } + + private static String password(char[] password) { + if (password == null || password.length == 0) { + return null; + } + return new String(password); + } + + @Override + protected void engineInit(ManagerFactoryParameters managerFactoryParameters) + throws InvalidAlgorithmParameterException { + throw new InvalidAlgorithmParameterException("Not supported"); + } + + @Override + protected KeyManager[] engineGetKeyManagers() { + ProviderFactory providerFactory = this.providerFactory; + if (providerFactory == null) { + throw new IllegalStateException("engineInit(...) not called yet"); + } + return new KeyManager[] { providerFactory.keyManager }; + } + + OpenSslKeyMaterialProvider newProvider() { + ProviderFactory providerFactory = this.providerFactory; + if (providerFactory == null) { + throw new IllegalStateException("engineInit(...) not called yet"); + } + return providerFactory.newProvider(); + } + + private static final class ProviderFactory { + private final X509KeyManager keyManager; + private final String password; + private final Iterable aliases; + + ProviderFactory(X509KeyManager keyManager, String password, Iterable aliases) { + this.keyManager = keyManager; + this.password = password; + this.aliases = aliases; + } + + OpenSslKeyMaterialProvider newProvider() { + return new OpenSslPopulatedKeyMaterialProvider(keyManager, + password, aliases); + } + + /** + * {@link OpenSslKeyMaterialProvider} implementation that pre-compute the {@link OpenSslKeyMaterial} for + * all aliases. + */ + private static final class OpenSslPopulatedKeyMaterialProvider extends OpenSslKeyMaterialProvider { + private final Map materialMap; + + OpenSslPopulatedKeyMaterialProvider( + X509KeyManager keyManager, String password, Iterable aliases) { + super(keyManager, password); + materialMap = new HashMap(); + boolean initComplete = false; + try { + for (String alias: aliases) { + if (alias != null && !materialMap.containsKey(alias)) { + try { + materialMap.put(alias, super.chooseKeyMaterial( + UnpooledByteBufAllocator.DEFAULT, alias)); + } catch (Exception e) { + // Just store the exception and rethrow it when we try to choose the keymaterial + // for this alias later on. + materialMap.put(alias, e); + } + } + } + initComplete = true; + } finally { + if (!initComplete) { + destroy(); + } + } + if (materialMap.isEmpty()) { + throw new IllegalArgumentException("aliases must be non-empty"); + } + } + + @Override + OpenSslKeyMaterial chooseKeyMaterial(ByteBufAllocator allocator, String alias) throws Exception { + Object value = materialMap.get(alias); + if (value == null) { + // There is no keymaterial for the requested alias, return null + return null; + } + if (value instanceof OpenSslKeyMaterial) { + return ((OpenSslKeyMaterial) value).retain(); + } + throw (Exception) value; + } + + @Override + void destroy() { + for (Object material: materialMap.values()) { + ReferenceCountUtil.release(material); + } + materialMap.clear(); + } + } + } + } + + /** + * Create a new initialized {@link OpenSslX509KeyManagerFactory} which loads its {@link PrivateKey} directly from + * an {@code OpenSSL engine} via the + * ENGINE_load_private_key + * function. + */ + public static OpenSslX509KeyManagerFactory newEngineBased(File certificateChain, String password) + throws CertificateException, IOException, + KeyStoreException, NoSuchAlgorithmException, UnrecoverableKeyException { + return newEngineBased(SslContext.toX509Certificates(certificateChain), password); + } + + /** + * Create a new initialized {@link OpenSslX509KeyManagerFactory} which loads its {@link PrivateKey} directly from + * an {@code OpenSSL engine} via the + * ENGINE_load_private_key + * function. + */ + public static OpenSslX509KeyManagerFactory newEngineBased(X509Certificate[] certificateChain, String password) + throws CertificateException, IOException, + KeyStoreException, NoSuchAlgorithmException, UnrecoverableKeyException { + KeyStore store = new OpenSslEngineKeyStore(certificateChain.clone()); + store.load(null, null); + OpenSslX509KeyManagerFactory factory = new OpenSslX509KeyManagerFactory(); + factory.init(store, password == null ? null : password.toCharArray()); + return factory; + } + + private static final class OpenSslEngineKeyStore extends KeyStore { + private OpenSslEngineKeyStore(final X509Certificate[] certificateChain) { + super(new KeyStoreSpi() { + + private final Date creationDate = new Date(); + + @Override + public Key engineGetKey(String alias, char[] password) throws UnrecoverableKeyException { + if (engineContainsAlias(alias)) { + try { + return new OpenSslPrivateKey(SSL.loadPrivateKeyFromEngine( + alias, password == null ? null : new String(password))); + } catch (Exception e) { + UnrecoverableKeyException keyException = + new UnrecoverableKeyException("Unable to load key from engine"); + keyException.initCause(e); + throw keyException; + } + } + return null; + } + + @Override + public Certificate[] engineGetCertificateChain(String alias) { + return engineContainsAlias(alias)? certificateChain.clone() : null; + } + + @Override + public Certificate engineGetCertificate(String alias) { + return engineContainsAlias(alias)? certificateChain[0] : null; + } + + @Override + public Date engineGetCreationDate(String alias) { + return engineContainsAlias(alias)? creationDate : null; + } + + @Override + public void engineSetKeyEntry(String alias, Key key, char[] password, Certificate[] chain) + throws KeyStoreException { + throw new KeyStoreException("Not supported"); + } + + @Override + public void engineSetKeyEntry(String alias, byte[] key, Certificate[] chain) throws KeyStoreException { + throw new KeyStoreException("Not supported"); + } + + @Override + public void engineSetCertificateEntry(String alias, Certificate cert) throws KeyStoreException { + throw new KeyStoreException("Not supported"); + } + + @Override + public void engineDeleteEntry(String alias) throws KeyStoreException { + throw new KeyStoreException("Not supported"); + } + + @Override + public Enumeration engineAliases() { + return Collections.enumeration(Collections.singleton(SslContext.ALIAS)); + } + + @Override + public boolean engineContainsAlias(String alias) { + return SslContext.ALIAS.equals(alias); + } + + @Override + public int engineSize() { + return 1; + } + + @Override + public boolean engineIsKeyEntry(String alias) { + return engineContainsAlias(alias); + } + + @Override + public boolean engineIsCertificateEntry(String alias) { + return engineContainsAlias(alias); + } + + @Override + public String engineGetCertificateAlias(Certificate cert) { + if (cert instanceof X509Certificate) { + for (X509Certificate x509Certificate : certificateChain) { + if (x509Certificate.equals(cert)) { + return SslContext.ALIAS; + } + } + } + return null; + } + + @Override + public void engineStore(OutputStream stream, char[] password) { + throw new UnsupportedOperationException(); + } + + @Override + public void engineLoad(InputStream stream, char[] password) { + if (stream != null && password != null) { + throw new UnsupportedOperationException(); + } + } + }, null, "native"); + + OpenSsl.ensureAvailability(); + } + } +} diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslX509TrustManagerWrapper.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslX509TrustManagerWrapper.java new file mode 100644 index 000000000000..0a0db0bbf4ea --- /dev/null +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslX509TrustManagerWrapper.java @@ -0,0 +1,190 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManager; +import javax.net.ssl.X509ExtendedTrustManager; +import javax.net.ssl.X509TrustManager; +import java.lang.reflect.Field; +import java.security.AccessController; +import java.security.KeyManagementException; +import java.security.NoSuchAlgorithmException; +import java.security.PrivilegedAction; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; + +/** + * Utility which allows to wrap {@link X509TrustManager} implementations with the internal implementation used by + * {@code SSLContextImpl} that provides extended verification. + * + * This is really a "hack" until there is an official API as requested on the in + * JDK-8210843. + */ +final class OpenSslX509TrustManagerWrapper { + private static final InternalLogger LOGGER = InternalLoggerFactory + .getInstance(OpenSslX509TrustManagerWrapper.class); + private static final TrustManagerWrapper WRAPPER; + + static { + // By default we will not do any wrapping but just return the passed in manager. + TrustManagerWrapper wrapper = new TrustManagerWrapper() { + @Override + public X509TrustManager wrapIfNeeded(X509TrustManager manager) { + return manager; + } + }; + + Throwable cause = null; + Throwable unsafeCause = PlatformDependent.getUnsafeUnavailabilityCause(); + if (unsafeCause == null) { + SSLContext context; + try { + context = newSSLContext(); + // Now init with an array that only holds a X509TrustManager. This should be wrapped into an + // AbstractTrustManagerWrapper which will delegate the TrustManager itself but also do extra + // validations. + // + // See: + // - http://hg.openjdk.java.net/jdk8u/jdk8u/jdk/file/ + // cadea780bc76/src/share/classes/sun/security/ssl/SSLContextImpl.java#l127 + context.init(null, new TrustManager[] { + new X509TrustManager() { + @Override + public void checkClientTrusted(X509Certificate[] x509Certificates, String s) + throws CertificateException { + } + + @Override + public void checkServerTrusted(X509Certificate[] x509Certificates, String s) + throws CertificateException { + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return EmptyArrays.EMPTY_X509_CERTIFICATES; + } + } + }, null); + } catch (Throwable error) { + context = null; + cause = error; + } + if (cause != null) { + LOGGER.debug("Unable to access wrapped TrustManager", cause); + } else { + final SSLContext finalContext = context; + Object maybeWrapper = AccessController.doPrivileged(new PrivilegedAction() { + @Override + public Object run() { + try { + Field contextSpiField = SSLContext.class.getDeclaredField("contextSpi"); + final long spiOffset = PlatformDependent.objectFieldOffset(contextSpiField); + Object spi = PlatformDependent.getObject(finalContext, spiOffset); + if (spi != null) { + Class clazz = spi.getClass(); + + // Let's cycle through the whole hierarchy until we find what we are looking for or + // there is nothing left in which case we will not wrap at all. + do { + try { + Field trustManagerField = clazz.getDeclaredField("trustManager"); + final long tmOffset = PlatformDependent.objectFieldOffset(trustManagerField); + Object trustManager = PlatformDependent.getObject(spi, tmOffset); + if (trustManager instanceof X509ExtendedTrustManager) { + return new UnsafeTrustManagerWrapper(spiOffset, tmOffset); + } + } catch (NoSuchFieldException ignore) { + // try next + } + clazz = clazz.getSuperclass(); + } while (clazz != null); + } + throw new NoSuchFieldException(); + } catch (NoSuchFieldException e) { + return e; + } catch (SecurityException e) { + return e; + } + } + }); + if (maybeWrapper instanceof Throwable) { + LOGGER.debug("Unable to access wrapped TrustManager", (Throwable) maybeWrapper); + } else { + wrapper = (TrustManagerWrapper) maybeWrapper; + } + } + } else { + LOGGER.debug("Unable to access wrapped TrustManager", cause); + } + WRAPPER = wrapper; + } + + private OpenSslX509TrustManagerWrapper() { } + + static X509TrustManager wrapIfNeeded(X509TrustManager trustManager) { + return WRAPPER.wrapIfNeeded(trustManager); + } + + private interface TrustManagerWrapper { + X509TrustManager wrapIfNeeded(X509TrustManager manager); + } + + private static SSLContext newSSLContext() throws NoSuchAlgorithmException { + return SSLContext.getInstance("TLS"); + } + + private static final class UnsafeTrustManagerWrapper implements TrustManagerWrapper { + private final long spiOffset; + private final long tmOffset; + + UnsafeTrustManagerWrapper(long spiOffset, long tmOffset) { + this.spiOffset = spiOffset; + this.tmOffset = tmOffset; + } + + @Override + public X509TrustManager wrapIfNeeded(X509TrustManager manager) { + if (!(manager instanceof X509ExtendedTrustManager)) { + try { + SSLContext ctx = newSSLContext(); + ctx.init(null, new TrustManager[] { manager }, null); + Object spi = PlatformDependent.getObject(ctx, spiOffset); + if (spi != null) { + Object tm = PlatformDependent.getObject(spi, tmOffset); + if (tm instanceof X509ExtendedTrustManager) { + return (X509TrustManager) tm; + } + } + } catch (NoSuchAlgorithmException e) { + // This should never happen as we did the same in the static + // before. + PlatformDependent.throwException(e); + } catch (KeyManagementException e) { + // This should never happen as we did the same in the static + // before. + PlatformDependent.throwException(e); + } + } + return manager; + } + } +} diff --git a/handler/src/main/java/io/netty/handler/ssl/PemPrivateKey.java b/handler/src/main/java/io/netty/handler/ssl/PemPrivateKey.java index e7bfd12392a8..46145a04b1ec 100644 --- a/handler/src/main/java/io/netty/handler/ssl/PemPrivateKey.java +++ b/handler/src/main/java/io/netty/handler/ssl/PemPrivateKey.java @@ -60,7 +60,12 @@ static PemEncoded toPEM(ByteBufAllocator allocator, boolean useDirect, PrivateKe return ((PemEncoded) key).retain(); } - ByteBuf encoded = Unpooled.wrappedBuffer(key.getEncoded()); + byte[] bytes = key.getEncoded(); + if (bytes == null) { + throw new IllegalArgumentException(key.getClass().getName() + " does not support encoding"); + } + + ByteBuf encoded = Unpooled.wrappedBuffer(bytes); try { ByteBuf base64 = SslUtils.toBase64(allocator, encoded); try { diff --git a/handler/src/main/java/io/netty/handler/ssl/PemReader.java b/handler/src/main/java/io/netty/handler/ssl/PemReader.java index 016d215c2f83..4cddad9c912a 100644 --- a/handler/src/main/java/io/netty/handler/ssl/PemReader.java +++ b/handler/src/main/java/io/netty/handler/ssl/PemReader.java @@ -98,7 +98,7 @@ static ByteBuf[] readCertificates(InputStream in) throws CertificateException { throw new CertificateException("found no certificates in input stream"); } - return certs.toArray(new ByteBuf[certs.size()]); + return certs.toArray(new ByteBuf[0]); } static ByteBuf readPrivateKey(File file) throws KeyException { diff --git a/handler/src/main/java/io/netty/handler/ssl/PemValue.java b/handler/src/main/java/io/netty/handler/ssl/PemValue.java index becb5b849215..ada5e4ddf661 100644 --- a/handler/src/main/java/io/netty/handler/ssl/PemValue.java +++ b/handler/src/main/java/io/netty/handler/ssl/PemValue.java @@ -34,7 +34,7 @@ class PemValue extends AbstractReferenceCounted implements PemEncoded { private final boolean sensitive; - public PemValue(ByteBuf content, boolean sensitive) { + PemValue(ByteBuf content, boolean sensitive) { this.content = ObjectUtil.checkNotNull(content, "content"); this.sensitive = sensitive; } diff --git a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslClientContext.java b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslClientContext.java index b2135734e936..f508971f014d 100644 --- a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslClientContext.java +++ b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslClientContext.java @@ -15,25 +15,27 @@ */ package io.netty.handler.ssl; +import io.netty.internal.tcnative.CertificateCallback; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; -import io.netty.internal.tcnative.CertificateRequestedCallback; import io.netty.internal.tcnative.SSL; import io.netty.internal.tcnative.SSLContext; import java.security.KeyStore; import java.security.PrivateKey; import java.security.cert.X509Certificate; + +import java.util.Arrays; +import java.util.Collections; import java.util.HashSet; +import java.util.LinkedHashSet; import java.util.Set; import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLException; import javax.net.ssl.SSLHandshakeException; import javax.net.ssl.TrustManagerFactory; -import javax.net.ssl.X509ExtendedKeyManager; import javax.net.ssl.X509ExtendedTrustManager; -import javax.net.ssl.X509KeyManager; import javax.net.ssl.X509TrustManager; import javax.security.auth.x500.X500Principal; @@ -48,6 +50,12 @@ public final class ReferenceCountedOpenSslClientContext extends ReferenceCountedOpenSslContext { private static final InternalLogger logger = InternalLoggerFactory.getInstance(ReferenceCountedOpenSslClientContext.class); + private static final Set SUPPORTED_KEY_TYPES = Collections.unmodifiableSet(new LinkedHashSet( + Arrays.asList(OpenSslKeyMaterialManager.KEY_TYPE_RSA, + OpenSslKeyMaterialManager.KEY_TYPE_DH_RSA, + OpenSslKeyMaterialManager.KEY_TYPE_EC, + OpenSslKeyMaterialManager.KEY_TYPE_EC_RSA, + OpenSslKeyMaterialManager.KEY_TYPE_EC_EC))); private final OpenSslSessionContext sessionContext; ReferenceCountedOpenSslClientContext(X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory, @@ -70,11 +78,6 @@ public final class ReferenceCountedOpenSslClientContext extends ReferenceCounted } } - @Override - OpenSslKeyMaterialManager keyMaterialManager() { - return null; - } - @Override public OpenSslSessionContext sessionContext() { return sessionContext; @@ -90,71 +93,89 @@ static OpenSslSessionContext newSessionContext(ReferenceCountedOpenSslContext th throw new IllegalArgumentException( "Either both keyCertChain and key needs to be null or none of them"); } + OpenSslKeyMaterialProvider keyMaterialProvider = null; try { - if (!OpenSsl.useKeyManagerFactory()) { - if (keyManagerFactory != null) { - throw new IllegalArgumentException( - "KeyManagerFactory not supported"); - } - if (keyCertChain != null/* && key != null*/) { - setKeyMaterial(ctx, keyCertChain, key, keyPassword); - } - } else { - // javadocs state that keyManagerFactory has precedent over keyCertChain - if (keyManagerFactory == null && keyCertChain != null) { - keyManagerFactory = buildKeyManagerFactory( - keyCertChain, key, keyPassword, keyManagerFactory); - } + try { + if (!OpenSsl.useKeyManagerFactory()) { + if (keyManagerFactory != null) { + throw new IllegalArgumentException( + "KeyManagerFactory not supported"); + } + if (keyCertChain != null/* && key != null*/) { + setKeyMaterial(ctx, keyCertChain, key, keyPassword); + } + } else { + // javadocs state that keyManagerFactory has precedent over keyCertChain + if (keyManagerFactory == null && keyCertChain != null) { + char[] keyPasswordChars = keyStorePassword(keyPassword); + KeyStore ks = buildKeyStore(keyCertChain, key, keyPasswordChars); + if (ks.aliases().hasMoreElements()) { + keyManagerFactory = new OpenSslX509KeyManagerFactory(); + } else { + keyManagerFactory = new OpenSslCachingX509KeyManagerFactory( + KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())); + } + keyManagerFactory.init(ks, keyPasswordChars); + keyMaterialProvider = providerFor(keyManagerFactory, keyPassword); + } else if (keyManagerFactory != null) { + keyMaterialProvider = providerFor(keyManagerFactory, keyPassword); + } - if (keyManagerFactory != null) { - X509KeyManager keyManager = chooseX509KeyManager(keyManagerFactory.getKeyManagers()); - OpenSslKeyMaterialManager materialManager = useExtendedKeyManager(keyManager) ? - new OpenSslExtendedKeyMaterialManager( - (X509ExtendedKeyManager) keyManager, keyPassword) : - new OpenSslKeyMaterialManager(keyManager, keyPassword); - SSLContext.setCertRequestedCallback(ctx, new OpenSslCertificateRequestedCallback( - engineMap, materialManager)); + if (keyMaterialProvider != null) { + OpenSslKeyMaterialManager materialManager = new OpenSslKeyMaterialManager(keyMaterialProvider); + SSLContext.setCertificateCallback(ctx, new OpenSslClientCertificateCallback( + engineMap, materialManager)); + } } + } catch (Exception e) { + throw new SSLException("failed to set certificate and key", e); } - } catch (Exception e) { - throw new SSLException("failed to set certificate and key", e); - } - SSLContext.setVerify(ctx, SSL.SSL_CVERIFY_NONE, VERIFY_DEPTH); + SSLContext.setVerify(ctx, SSL.SSL_CVERIFY_NONE, VERIFY_DEPTH); - try { - if (trustCertCollection != null) { - trustManagerFactory = buildTrustManagerFactory(trustCertCollection, trustManagerFactory); - } else if (trustManagerFactory == null) { - trustManagerFactory = TrustManagerFactory.getInstance( - TrustManagerFactory.getDefaultAlgorithm()); - trustManagerFactory.init((KeyStore) null); - } - final X509TrustManager manager = chooseTrustManager(trustManagerFactory.getTrustManagers()); + try { + if (trustCertCollection != null) { + trustManagerFactory = buildTrustManagerFactory(trustCertCollection, trustManagerFactory); + } else if (trustManagerFactory == null) { + trustManagerFactory = TrustManagerFactory.getInstance( + TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init((KeyStore) null); + } + final X509TrustManager manager = chooseTrustManager(trustManagerFactory.getTrustManagers()); - // IMPORTANT: The callbacks set for verification must be static to prevent memory leak as - // otherwise the context can never be collected. This is because the JNI code holds - // a global reference to the callbacks. - // - // See https://github.com/netty/netty/issues/5372 + // IMPORTANT: The callbacks set for verification must be static to prevent memory leak as + // otherwise the context can never be collected. This is because the JNI code holds + // a global reference to the callbacks. + // + // See https://github.com/netty/netty/issues/5372 - // Use this to prevent an error when running on java < 7 - if (useExtendedTrustManager(manager)) { - SSLContext.setCertVerifyCallback(ctx, - new ExtendedTrustManagerVerifyCallback(engineMap, (X509ExtendedTrustManager) manager)); - } else { - SSLContext.setCertVerifyCallback(ctx, new TrustManagerVerifyCallback(engineMap, manager)); + // Use this to prevent an error when running on java < 7 + if (useExtendedTrustManager(manager)) { + SSLContext.setCertVerifyCallback(ctx, + new ExtendedTrustManagerVerifyCallback(engineMap, (X509ExtendedTrustManager) manager)); + } else { + SSLContext.setCertVerifyCallback(ctx, new TrustManagerVerifyCallback(engineMap, manager)); + } + } catch (Exception e) { + if (keyMaterialProvider != null) { + keyMaterialProvider.destroy(); + } + throw new SSLException("unable to setup trustmanager", e); + } + OpenSslClientSessionContext context = new OpenSslClientSessionContext(thiz, keyMaterialProvider); + keyMaterialProvider = null; + return context; + } finally { + if (keyMaterialProvider != null) { + keyMaterialProvider.destroy(); } - } catch (Exception e) { - throw new SSLException("unable to setup trustmanager", e); } - return new OpenSslClientSessionContext(thiz); } // No cache is currently supported for client side mode. static final class OpenSslClientSessionContext extends OpenSslSessionContext { - OpenSslClientSessionContext(ReferenceCountedOpenSslContext context) { - super(context); + OpenSslClientSessionContext(ReferenceCountedOpenSslContext context, OpenSslKeyMaterialProvider provider) { + super(context, provider); } @Override @@ -212,7 +233,7 @@ private static final class ExtendedTrustManagerVerifyCallback extends AbstractCe ExtendedTrustManagerVerifyCallback(OpenSslEngineMap engineMap, X509ExtendedTrustManager manager) { super(engineMap); - this.manager = manager; + this.manager = OpenSslTlsv13X509ExtendedTrustManager.wrap(manager, true); } @Override @@ -222,21 +243,21 @@ void verify(ReferenceCountedOpenSslEngine engine, X509Certificate[] peerCerts, S } } - private static final class OpenSslCertificateRequestedCallback implements CertificateRequestedCallback { + private static final class OpenSslClientCertificateCallback implements CertificateCallback { private final OpenSslEngineMap engineMap; private final OpenSslKeyMaterialManager keyManagerHolder; - OpenSslCertificateRequestedCallback(OpenSslEngineMap engineMap, OpenSslKeyMaterialManager keyManagerHolder) { + OpenSslClientCertificateCallback(OpenSslEngineMap engineMap, OpenSslKeyMaterialManager keyManagerHolder) { this.engineMap = engineMap; this.keyManagerHolder = keyManagerHolder; } @Override - public KeyMaterial requested(long ssl, byte[] keyTypeBytes, byte[][] asn1DerEncodedPrincipals) { + public void handle(long ssl, byte[] keyTypeBytes, byte[][] asn1DerEncodedPrincipals) throws Exception { final ReferenceCountedOpenSslEngine engine = engineMap.get(ssl); try { final Set keyTypesSet = supportedClientKeyTypes(keyTypeBytes); - final String[] keyTypes = keyTypesSet.toArray(new String[keyTypesSet.size()]); + final String[] keyTypes = keyTypesSet.toArray(new String[0]); final X500Principal[] issuers; if (asn1DerEncodedPrincipals == null) { issuers = null; @@ -246,13 +267,12 @@ public KeyMaterial requested(long ssl, byte[] keyTypeBytes, byte[][] asn1DerEnco issuers[i] = new X500Principal(asn1DerEncodedPrincipals[i]); } } - return keyManagerHolder.keyMaterial(engine, keyTypes, issuers); + keyManagerHolder.setKeyMaterialClientSide(engine, keyTypes, issuers); } catch (Throwable cause) { logger.debug("request of key failed", cause); SSLHandshakeException e = new SSLHandshakeException("General OpenSslEngine problem"); e.initCause(cause); engine.handshakeException = e; - return null; } } @@ -265,6 +285,10 @@ public KeyMaterial requested(long ssl, byte[] keyTypeBytes, byte[][] asn1DerEnco * {@code X509ExtendedKeyManager.chooseEngineClientAlias}. */ private static Set supportedClientKeyTypes(byte[] clientCertificateTypes) { + if (clientCertificateTypes == null) { + // Try all of the supported key types. + return SUPPORTED_KEY_TYPES; + } Set result = new HashSet(clientCertificateTypes.length); for (byte keyTypeCode : clientCertificateTypes) { String keyType = clientKeyType(keyTypeCode); @@ -280,15 +304,15 @@ private static Set supportedClientKeyTypes(byte[] clientCertificateTypes private static String clientKeyType(byte clientCertificateType) { // See also http://www.ietf.org/assignments/tls-parameters/tls-parameters.xml switch (clientCertificateType) { - case CertificateRequestedCallback.TLS_CT_RSA_SIGN: + case CertificateCallback.TLS_CT_RSA_SIGN: return OpenSslKeyMaterialManager.KEY_TYPE_RSA; // RFC rsa_sign - case CertificateRequestedCallback.TLS_CT_RSA_FIXED_DH: + case CertificateCallback.TLS_CT_RSA_FIXED_DH: return OpenSslKeyMaterialManager.KEY_TYPE_DH_RSA; // RFC rsa_fixed_dh - case CertificateRequestedCallback.TLS_CT_ECDSA_SIGN: + case CertificateCallback.TLS_CT_ECDSA_SIGN: return OpenSslKeyMaterialManager.KEY_TYPE_EC; // RFC ecdsa_sign - case CertificateRequestedCallback.TLS_CT_RSA_FIXED_ECDH: + case CertificateCallback.TLS_CT_RSA_FIXED_ECDH: return OpenSslKeyMaterialManager.KEY_TYPE_EC_RSA; // RFC rsa_fixed_ecdh - case CertificateRequestedCallback.TLS_CT_ECDSA_FIXED_ECDH: + case CertificateCallback.TLS_CT_ECDSA_FIXED_ECDH: return OpenSslKeyMaterialManager.KEY_TYPE_EC_EC; // RFC ecdsa_fixed_ecdh default: return null; diff --git a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslContext.java b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslContext.java index a695d2d689ac..6506e949a52b 100644 --- a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslContext.java +++ b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslContext.java @@ -44,16 +44,17 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.concurrent.Executor; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReadWriteLock; import java.util.concurrent.locks.ReentrantReadWriteLock; import javax.net.ssl.KeyManager; +import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLException; import javax.net.ssl.SSLHandshakeException; import javax.net.ssl.TrustManager; -import javax.net.ssl.X509ExtendedKeyManager; import javax.net.ssl.X509ExtendedTrustManager; import javax.net.ssl.X509KeyManager; import javax.net.ssl.X509TrustManager; @@ -85,7 +86,8 @@ public Integer run() { 2048)); } }); - + private static final boolean USE_TASKS = + SystemPropertyUtil.getBoolean("io.netty.handler.ssl.openssl.useTasks", false); private static final Integer DH_KEY_LENGTH; private static final ResourceLeakDetector leakDetector = ResourceLeakDetectorFactory.instance().newResourceLeakDetector(ReferenceCountedOpenSslContext.class); @@ -225,24 +227,71 @@ public String run() { boolean success = false; try { try { - ctx = SSLContext.make(SSL.SSL_PROTOCOL_ALL, mode); + int protocolOpts = SSL.SSL_PROTOCOL_SSLV3 | SSL.SSL_PROTOCOL_TLSV1 | + SSL.SSL_PROTOCOL_TLSV1_1 | SSL.SSL_PROTOCOL_TLSV1_2; + if (OpenSsl.isTlsv13Supported()) { + protocolOpts |= SSL.SSL_PROTOCOL_TLSV1_3; + } + ctx = SSLContext.make(protocolOpts, mode); } catch (Exception e) { throw new SSLException("failed to create an SSL_CTX", e); } - SSLContext.setOptions(ctx, SSLContext.getOptions(ctx) | - SSL.SSL_OP_NO_SSLv2 | - SSL.SSL_OP_NO_SSLv3 | - SSL.SSL_OP_CIPHER_SERVER_PREFERENCE | + boolean tlsv13Supported = OpenSsl.isTlsv13Supported(); + StringBuilder cipherBuilder = new StringBuilder(); + StringBuilder cipherTLSv13Builder = new StringBuilder(); + + /* List the ciphers that are permitted to negotiate. */ + try { + if (unmodifiableCiphers.isEmpty()) { + // Set non TLSv1.3 ciphers. + SSLContext.setCipherSuite(ctx, StringUtil.EMPTY_STRING, false); + if (tlsv13Supported) { + // Set TLSv1.3 ciphers. + SSLContext.setCipherSuite(ctx, StringUtil.EMPTY_STRING, true); + } + } else { + CipherSuiteConverter.convertToCipherStrings( + unmodifiableCiphers, cipherBuilder, cipherTLSv13Builder, OpenSsl.isBoringSSL()); + + // Set non TLSv1.3 ciphers. + SSLContext.setCipherSuite(ctx, cipherBuilder.toString(), false); + if (tlsv13Supported) { + // Set TLSv1.3 ciphers. + SSLContext.setCipherSuite(ctx, cipherTLSv13Builder.toString(), true); + } + } + } catch (SSLException e) { + throw e; + } catch (Exception e) { + throw new SSLException("failed to set cipher suite: " + unmodifiableCiphers, e); + } + + int options = SSLContext.getOptions(ctx) | + SSL.SSL_OP_NO_SSLv2 | + SSL.SSL_OP_NO_SSLv3 | + // Disable TLSv1.3 by default for now. Even if TLSv1.3 is not supported this will + // work fine as in this case SSL_OP_NO_TLSv1_3 will be 0. + SSL.SSL_OP_NO_TLSv1_3 | - // We do not support compression at the moment so we should explicitly disable it. - SSL.SSL_OP_NO_COMPRESSION | + SSL.SSL_OP_CIPHER_SERVER_PREFERENCE | - // Disable ticket support by default to be more inline with SSLEngineImpl of the JDK. - // This also let SSLSession.getId() work the same way for the JDK implementation and the - // OpenSSLEngine. If tickets are supported SSLSession.getId() will only return an ID on the - // server-side if it could make use of tickets. - SSL.SSL_OP_NO_TICKET); + // We do not support compression at the moment so we should explicitly disable it. + SSL.SSL_OP_NO_COMPRESSION | + + // Disable ticket support by default to be more inline with SSLEngineImpl of the JDK. + // This also let SSLSession.getId() work the same way for the JDK implementation and the + // OpenSSLEngine. If tickets are supported SSLSession.getId() will only return an ID on the + // server-side if it could make use of tickets. + SSL.SSL_OP_NO_TICKET; + + if (cipherBuilder.length() == 0) { + // No ciphers that are compatible with SSLv2 / SSLv3 / TLSv1 / TLSv1.1 / TLSv1.2 + options |= SSL.SSL_OP_NO_SSLv2 | SSL.SSL_OP_NO_SSLv3 | SSL.SSL_OP_NO_TLSv1 + | SSL.SSL_OP_NO_TLSv1_1 | SSL.SSL_OP_NO_TLSv1_2; + } + + SSLContext.setOptions(ctx, options); // We need to enable SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER as the memory address may change between // calling OpenSSLEngine.wrap(...). @@ -253,19 +302,10 @@ public String run() { SSLContext.setTmpDHLength(ctx, DH_KEY_LENGTH); } - /* List the ciphers that are permitted to negotiate. */ - try { - SSLContext.setCipherSuite(ctx, CipherSuiteConverter.toOpenSsl(unmodifiableCiphers)); - } catch (SSLException e) { - throw e; - } catch (Exception e) { - throw new SSLException("failed to set cipher suite: " + unmodifiableCiphers, e); - } - List nextProtoList = apn.protocols(); /* Set next protocols for next protocol negotiation extension, if specified */ if (!nextProtoList.isEmpty()) { - String[] appProtocols = nextProtoList.toArray(new String[nextProtoList.size()]); + String[] appProtocols = nextProtoList.toArray(new String[0]); int selectorBehavior = opensslSelectorFailureBehavior(apn.selectorFailureBehavior()); switch (apn.protocol()) { @@ -303,6 +343,8 @@ public String run() { if (enableOcsp) { SSLContext.enableOcsp(ctx, isClient()); } + + SSLContext.setUseTasks(ctx, USE_TASKS); success = true; } finally { if (!success) { @@ -362,12 +404,21 @@ protected final SslHandler newHandler(ByteBufAllocator alloc, String peerHost, i return new SslHandler(newEngine0(alloc, peerHost, peerPort, false), startTls); } + @Override + protected SslHandler newHandler(ByteBufAllocator alloc, boolean startTls, Executor executor) { + return new SslHandler(newEngine0(alloc, null, -1, false), startTls, executor); + } + + @Override + protected SslHandler newHandler(ByteBufAllocator alloc, String peerHost, int peerPort, + boolean startTls, Executor executor) { + return new SslHandler(newEngine0(alloc, peerHost, peerPort, false), executor); + } + SSLEngine newEngine0(ByteBufAllocator alloc, String peerHost, int peerPort, boolean jdkCompatibilityMode) { return new ReferenceCountedOpenSslEngine(this, alloc, peerHost, peerPort, jdkCompatibilityMode, true); } - abstract OpenSslKeyMaterialManager keyMaterialManager(); - /** * Returns a new server-side {@link SSLEngine} with the current configuration. */ @@ -385,13 +436,7 @@ public final SSLEngine newEngine(ByteBufAllocator alloc) { */ @Deprecated public final long context() { - Lock readerLock = ctxLock.readLock(); - readerLock.lock(); - try { - return ctx; - } finally { - readerLock.unlock(); - } + return sslCtxPointer(); } /** @@ -466,7 +511,7 @@ public final long sslCtxPointer() { Lock readerLock = ctxLock.readLock(); readerLock.lock(); try { - return ctx; + return SSLContext.getSslCtx(ctx); } finally { readerLock.unlock(); } @@ -486,6 +531,11 @@ private void destroy() { SSLContext.free(ctx); ctx = 0; + + OpenSslSessionContext context = sessionContext(); + if (context != null) { + context.destroy(); + } } } finally { writerLock.unlock(); @@ -503,7 +553,7 @@ protected static X509Certificate[] certificates(byte[][] chain) { protected static X509TrustManager chooseTrustManager(TrustManager[] managers) { for (TrustManager m : managers) { if (m instanceof X509TrustManager) { - return (X509TrustManager) m; + return OpenSslX509TrustManagerWrapper.wrapIfNeeded((X509TrustManager) m); } } throw new IllegalStateException("no X509TrustManager found"); @@ -566,10 +616,6 @@ static boolean useExtendedTrustManager(X509TrustManager trustManager) { return PlatformDependent.javaVersion() >= 7 && trustManager instanceof X509ExtendedTrustManager; } - static boolean useExtendedKeyManager(X509KeyManager keyManager) { - return PlatformDependent.javaVersion() >= 7 && keyManager instanceof X509ExtendedKeyManager; - } - @Override public final int refCnt() { return refCnt.refCnt(); @@ -710,7 +756,7 @@ static void setKeyMaterial(long ctx, X509Certificate[] keyCertChain, PrivateKey keyCertChainBio2 = toBIO(ByteBufAllocator.DEFAULT, encoded.retain()); if (key != null) { - keyBio = toBIO(key); + keyBio = toBIO(ByteBufAllocator.DEFAULT, key); } SSLContext.setCertificateBio( @@ -742,12 +788,11 @@ static void freeBio(long bio) { * Return the pointer to a in-memory BIO * or {@code 0} if the {@code key} is {@code null}. The BIO contains the content of the {@code key}. */ - static long toBIO(PrivateKey key) throws Exception { + static long toBIO(ByteBufAllocator allocator, PrivateKey key) throws Exception { if (key == null) { return 0; } - ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; PemEncoded pem = PemPrivateKey.toPEM(allocator, true, key); try { return toBIO(allocator, pem.retain()); @@ -760,7 +805,7 @@ static long toBIO(PrivateKey key) throws Exception { * Return the pointer to a in-memory BIO * or {@code 0} if the {@code certChain} is {@code null}. The BIO contains the content of the {@code certChain}. */ - static long toBIO(X509Certificate... certChain) throws Exception { + static long toBIO(ByteBufAllocator allocator, X509Certificate... certChain) throws Exception { if (certChain == null) { return 0; } @@ -769,7 +814,6 @@ static long toBIO(X509Certificate... certChain) throws Exception { throw new IllegalArgumentException("certChain can't be empty"); } - ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; PemEncoded pem = PemX509Certificate.toPEM(allocator, true, certChain); try { return toBIO(allocator, pem.retain()); @@ -821,4 +865,23 @@ private static long newBIO(ByteBuf buffer) throws Exception { buffer.release(); } } + + /** + * Returns the {@link OpenSslKeyMaterialProvider} that should be used for OpenSSL. Depending on the given + * {@link KeyManagerFactory} this may cache the {@link OpenSslKeyMaterial} for better performance if it can + * ensure that the same material is always returned for the same alias. + */ + static OpenSslKeyMaterialProvider providerFor(KeyManagerFactory factory, String password) { + if (factory instanceof OpenSslX509KeyManagerFactory) { + return ((OpenSslX509KeyManagerFactory) factory).newProvider(); + } + + X509KeyManager keyManager = chooseX509KeyManager(factory.getKeyManagers()); + if (factory instanceof OpenSslCachingX509KeyManagerFactory) { + // The user explicit used OpenSslCachingX509KeyManagerFactory which signals us that its fine to cache. + return new OpenSslCachingKeyMaterialProvider(keyManager, password); + } + // We can not be sure if the material may change at runtime so we will not cache it. + return new OpenSslKeyMaterialProvider(keyManager, password); + } } diff --git a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java index c0eb141a5f33..4627d9502bc1 100644 --- a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java +++ b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java @@ -20,6 +20,7 @@ import io.netty.internal.tcnative.Buffer; import io.netty.internal.tcnative.SSL; import io.netty.util.AbstractReferenceCounted; +import io.netty.util.CharsetUtil; import io.netty.util.ReferenceCounted; import io.netty.util.ResourceLeakDetector; import io.netty.util.ResourceLeakDetectorFactory; @@ -39,9 +40,12 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.concurrent.locks.Lock; @@ -64,15 +68,15 @@ import static io.netty.handler.ssl.SslUtils.PROTOCOL_TLS_V1; import static io.netty.handler.ssl.SslUtils.PROTOCOL_TLS_V1_1; import static io.netty.handler.ssl.SslUtils.PROTOCOL_TLS_V1_2; +import static io.netty.handler.ssl.SslUtils.PROTOCOL_TLS_V1_3; import static io.netty.handler.ssl.SslUtils.SSL_RECORD_HEADER_LENGTH; -import static io.netty.internal.tcnative.SSL.SSL_MAX_PLAINTEXT_LENGTH; -import static io.netty.internal.tcnative.SSL.SSL_MAX_RECORD_LENGTH; import static io.netty.util.internal.EmptyArrays.EMPTY_CERTIFICATES; import static io.netty.util.internal.EmptyArrays.EMPTY_JAVAX_X509_CERTIFICATES; import static io.netty.util.internal.ObjectUtil.checkNotNull; import static java.lang.Integer.MAX_VALUE; import static java.lang.Math.min; import static javax.net.ssl.SSLEngineResult.HandshakeStatus.FINISHED; +import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_TASK; import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_UNWRAP; import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_WRAP; import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING; @@ -107,12 +111,14 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc private static final int OPENSSL_OP_NO_PROTOCOL_INDEX_TLSv1 = 2; private static final int OPENSSL_OP_NO_PROTOCOL_INDEX_TLSv1_1 = 3; private static final int OPENSSL_OP_NO_PROTOCOL_INDEX_TLSv1_2 = 4; + private static final int OPENSSL_OP_NO_PROTOCOL_INDEX_TLSv1_3 = 5; private static final int[] OPENSSL_OP_NO_PROTOCOLS = { SSL.SSL_OP_NO_SSLv2, SSL.SSL_OP_NO_SSLv3, SSL.SSL_OP_NO_TLSv1, SSL.SSL_OP_NO_TLSv1_1, - SSL.SSL_OP_NO_TLSv1_2 + SSL.SSL_OP_NO_TLSv1_2, + SSL.SSL_OP_NO_TLSv1_3 }; /** * The flags argument is usually 0. @@ -122,16 +128,15 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc /** * Depends upon tcnative ... only use if tcnative is available! */ - static final int MAX_PLAINTEXT_LENGTH = SSL_MAX_PLAINTEXT_LENGTH; + static final int MAX_PLAINTEXT_LENGTH = SSL.SSL_MAX_PLAINTEXT_LENGTH; /** * Depends upon tcnative ... only use if tcnative is available! */ - private static final int MAX_RECORD_SIZE = SSL_MAX_RECORD_LENGTH; + private static final int MAX_RECORD_SIZE = SSL.SSL_MAX_RECORD_LENGTH; private static final AtomicIntegerFieldUpdater DESTROYED_UPDATER = AtomicIntegerFieldUpdater.newUpdater(ReferenceCountedOpenSslEngine.class, "destroyed"); - private static final String INVALID_CIPHER = "SSL_NULL_WITH_NULL_NULL"; private static final SSLEngineResult NEED_UNWRAP_OK = new SSLEngineResult(OK, NEED_UNWRAP, 0, 0); private static final SSLEngineResult NEED_UNWRAP_CLOSED = new SSLEngineResult(CLOSED, NEED_UNWRAP, 0, 0); private static final SSLEngineResult NEED_WRAP_OK = new SSLEngineResult(OK, NEED_WRAP, 0, 0); @@ -141,7 +146,6 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc // OpenSSL state private long ssl; private long networkBIO; - private boolean certificateSet; private enum HandshakeState { /** @@ -166,6 +170,7 @@ private enum HandshakeState { private boolean receivedShutdown; private volatile int destroyed; private volatile String applicationProtocol; + private volatile boolean needTask; // Reference Counting private final ResourceLeakTracker leak; @@ -190,6 +195,7 @@ protected void deallocate() { }; private volatile ClientAuth clientAuth = ClientAuth.NONE; + private volatile Certificate[] localCertificateChain; // Updated once a new handshake is started and so the SSLSession reused. private volatile long lastAccessed = -1; @@ -209,14 +215,12 @@ protected void deallocate() { final boolean jdkCompatibilityMode; private final boolean clientMode; - private final ByteBufAllocator alloc; + final ByteBufAllocator alloc; private final OpenSslEngineMap engineMap; private final OpenSslApplicationProtocolNegotiator apn; private final OpenSslSession session; - private final Certificate[] localCerts; private final ByteBuffer[] singleSrcBuffer = new ByteBuffer[1]; private final ByteBuffer[] singleDstBuffer = new ByteBuffer[1]; - private final OpenSslKeyMaterialManager keyMaterialManager; private final boolean enableOcsp; private int maxWrapOverhead; private int maxWrapBufferSize; @@ -237,18 +241,95 @@ protected void deallocate() { * wrap or unwrap call. * @param leakDetection {@code true} to enable leak detection of this object. */ - ReferenceCountedOpenSslEngine(ReferenceCountedOpenSslContext context, ByteBufAllocator alloc, String peerHost, + ReferenceCountedOpenSslEngine(ReferenceCountedOpenSslContext context, final ByteBufAllocator alloc, String peerHost, int peerPort, boolean jdkCompatibilityMode, boolean leakDetection) { super(peerHost, peerPort); OpenSsl.ensureAvailability(); this.alloc = checkNotNull(alloc, "alloc"); apn = (OpenSslApplicationProtocolNegotiator) context.applicationProtocolNegotiator(); - session = new OpenSslSession(context.sessionContext()); clientMode = context.isClient(); + if (PlatformDependent.javaVersion() >= 7) { + session = new ExtendedOpenSslSession(new DefaultOpenSslSession(context.sessionContext())) { + private String[] peerSupportedSignatureAlgorithms; + private List requestedServerNames; + + @Override + public List getRequestedServerNames() { + if (clientMode) { + return Java8SslUtils.getSniHostNames(sniHostNames); + } else { + synchronized (ReferenceCountedOpenSslEngine.this) { + if (requestedServerNames == null) { + if (isDestroyed()) { + requestedServerNames = Collections.emptyList(); + } else { + String name = SSL.getSniHostname(ssl); + if (name == null) { + requestedServerNames = Collections.emptyList(); + } else { + // Convert to bytes as we do not want to do any strict validation of the + // SNIHostName while creating it. + requestedServerNames = + Java8SslUtils.getSniHostName( + SSL.getSniHostname(ssl).getBytes(CharsetUtil.UTF_8)); + } + } + } + return requestedServerNames; + } + } + } + + @Override + public String[] getPeerSupportedSignatureAlgorithms() { + synchronized (ReferenceCountedOpenSslEngine.this) { + if (peerSupportedSignatureAlgorithms == null) { + if (isDestroyed()) { + peerSupportedSignatureAlgorithms = EmptyArrays.EMPTY_STRINGS; + } else { + String[] algs = SSL.getSigAlgs(ssl); + if (algs == null) { + peerSupportedSignatureAlgorithms = EmptyArrays.EMPTY_STRINGS; + } else { + Set algorithmList = new LinkedHashSet(algs.length); + for (String alg: algs) { + String converted = SignatureAlgorithmConverter.toJavaName(alg); + + if (converted != null) { + algorithmList.add(converted); + } + } + peerSupportedSignatureAlgorithms = algorithmList.toArray(new String[0]); + } + } + } + return peerSupportedSignatureAlgorithms.clone(); + } + } + + @Override + public List getStatusResponses() { + byte[] ocspResponse = null; + if (enableOcsp && clientMode) { + synchronized (ReferenceCountedOpenSslEngine.this) { + if (!isDestroyed()) { + ocspResponse = SSL.getOcspResponse(ssl); + } + } + } + return ocspResponse == null ? + Collections.emptyList() : Collections.singletonList(ocspResponse); + } + }; + } else { + session = new DefaultOpenSslSession(context.sessionContext()); + } engineMap = context.engineMap; - localCerts = context.keyCertChain; - keyMaterialManager = context.keyMaterialManager(); enableOcsp = context.enableOcsp; + // context.keyCertChain will only be non-null if we do not use the KeyManagerFactory. In this case + // localCertificateChain will be set in setKeyMaterial(...). + localCertificateChain = context.keyCertChain; + this.jdkCompatibilityMode = jdkCompatibilityMode; Lock readerLock = context.ctxLock.readLock(); readerLock.lock(); @@ -271,10 +352,11 @@ protected void deallocate() { setEnabledProtocols(context.protocols); } - // Use SNI if peerHost was specified + // Use SNI if peerHost was specified and a valid hostname // See https://github.com/netty/netty/issues/4746 - if (clientMode && peerHost != null) { + if (clientMode && SslUtils.isValidHostNameForSNI(peerHost)) { SSL.setTlsExtHostName(ssl, peerHost); + sniHostNames = Collections.singletonList(peerHost); } if (enableOcsp) { @@ -288,7 +370,11 @@ protected void deallocate() { // setMode may impact the overhead. calculateMaxWrapOverhead(); } catch (Throwable cause) { - SSL.freeSSL(ssl); + // Call shutdown so we are sure we correctly release all native memory and also guard against the + // case when shutdown() will be called by the finalizer again. If we would call SSL.free(...) directly + // the finalizer may end up calling it again as we would miss to update the DESTROYED_UPDATER. + shutdown(); + PlatformDependent.throwException(cause); } } @@ -298,6 +384,11 @@ protected void deallocate() { leak = leakDetection ? leakDetector.track(this) : null; } + final void setKeyMaterial(OpenSslKeyMaterial keyMaterial) throws Exception { + SSL.setKeyMaterial(ssl, keyMaterial.certificateChainAddress(), keyMaterial.privateKeyAddress()); + localCertificateChain = keyMaterial.certificateChain(); + } + /** * Sets the OCSP response. */ @@ -607,8 +698,17 @@ public final SSLEngineResult wrap( int bioLengthBefore = SSL.bioLengthByteBuffer(networkBIO); - // Explicit use outboundClosed as we want to drain any bytes that are still present. + // Explicitly use outboundClosed as we want to drain any bytes that are still present. if (outboundClosed) { + // If the outbound was closed we want to ensure we can produce the alert to the destination buffer. + // This is true even if we not using jdkCompatibilityMode. + // + // We use a plaintextLength of 2 as we at least want to have an alert fit into it. + // https://tools.ietf.org/html/rfc5246#section-7.2 + if (!isBytesAvailableEnoughForWrap(dst.remaining(), 2, 1)) { + return new SSLEngineResult(BUFFER_OVERFLOW, getHandshakeStatus(), 0, 0); + } + // There is something left to drain. // See https://github.com/netty/netty/issues/6260 bytesProduced = SSL.bioFlushByteBuffer(networkBIO); @@ -655,6 +755,10 @@ public final SSLEngineResult wrap( // we may have freed up space by flushing above. bytesProduced = bioLengthBefore - SSL.bioLengthByteBuffer(networkBIO); + if (status == NEED_TASK) { + return newResult(status, 0, bytesProduced); + } + if (bytesProduced > 0) { // If we have filled up the dst buffer and we have not finished the handshake we should try to // wrap again. Otherwise we should only try to wrap again if there is still data pending in @@ -785,9 +889,11 @@ public final SSLEngineResult wrap( // to write encrypted data to. This is an OVERFLOW condition. // [1] https://www.openssl.org/docs/manmaster/ssl/SSL_write.html return newResult(BUFFER_OVERFLOW, status, bytesConsumed, bytesProduced); + } else if (sslError == SSL.SSL_ERROR_WANT_X509_LOOKUP) { + return newResult(NEED_TASK, bytesConsumed, bytesProduced); } else { // Everything else is considered as error - throw shutdownWithError("SSL_write"); + throw shutdownWithError("SSL_write", sslError); } } } @@ -825,6 +931,10 @@ private SSLEngineResult newResult(SSLEngineResult.Status status, SSLEngineResult } return new SSLEngineResult(CLOSED, hs, bytesConsumed, bytesProduced); } + if (hs == NEED_TASK) { + // Set needTask to true so getHandshakeStatus() will return the correct value. + needTask = true; + } return new SSLEngineResult(status, hs, bytesConsumed, bytesProduced); } @@ -844,22 +954,23 @@ private SSLEngineResult newResultMayFinishHandshake(SSLEngineResult.Status statu /** * Log the error, shutdown the engine and throw an exception. */ - private SSLException shutdownWithError(String operations) { - String err = SSL.getLastError(); - return shutdownWithError(operations, err); + private SSLException shutdownWithError(String operations, int sslError) { + return shutdownWithError(operations, sslError, SSL.getLastErrorNumber()); } - private SSLException shutdownWithError(String operation, String err) { + private SSLException shutdownWithError(String operation, int sslError, int error) { + String errorString = SSL.getErrorString(error); if (logger.isDebugEnabled()) { - logger.debug("{} failed: OpenSSL error: {}", operation, err); + logger.debug("{} failed with {}: OpenSSL error: {} {}", + operation, sslError, error, errorString); } // There was an internal error -- shutdown shutdown(); if (handshakeState == HandshakeState.FINISHED) { - return new SSLException(err); + return new SSLException(errorString); } - return new SSLHandshakeException(err); + return new SSLHandshakeException(errorString); } public final SSLEngineResult unwrap( @@ -921,6 +1032,11 @@ public final SSLEngineResult unwrap( } status = handshake(); + + if (status == NEED_TASK) { + return newResult(status, 0, 0); + } + if (status == NEED_WRAP) { return NEED_WRAP_OK; } @@ -1061,9 +1177,12 @@ public final SSLEngineResult unwrap( closeAll(); } return newResultMayFinishHandshake(isInboundDone() ? CLOSED : OK, status, - bytesConsumed, bytesProduced); + bytesConsumed, bytesProduced); + } else if (sslError == SSL.SSL_ERROR_WANT_X509_LOOKUP) { + return newResult(isInboundDone() ? CLOSED : OK, + NEED_TASK, bytesConsumed, bytesProduced); } else { - return sslReadErrorResult(SSL.getLastErrorNumber(), bytesConsumed, + return sslReadErrorResult(sslError, SSL.getLastErrorNumber(), bytesConsumed, bytesProduced); } } @@ -1092,22 +1211,24 @@ public final SSLEngineResult unwrap( } } - private SSLEngineResult sslReadErrorResult(int err, int bytesConsumed, int bytesProduced) throws SSLException { - String errStr = SSL.getErrorString(err); - + private SSLEngineResult sslReadErrorResult(int error, int stackError, int bytesConsumed, int bytesProduced) + throws SSLException { // Check if we have a pending handshakeException and if so see if we need to consume all pending data from the // BIO first or can just shutdown and throw it now. // This is needed so we ensure close_notify etc is correctly send to the remote peer. // See https://github.com/netty/netty/issues/3900 if (SSL.bioLengthNonApplication(networkBIO) > 0) { if (handshakeException == null && handshakeState != HandshakeState.FINISHED) { - // we seems to have data left that needs to be transfered and so the user needs + // we seems to have data left that needs to be transferred and so the user needs // call wrap(...). Store the error so we can pick it up later. - handshakeException = new SSLHandshakeException(errStr); + handshakeException = new SSLHandshakeException(SSL.getErrorString(stackError)); } + // We need to clear all errors so we not pick up anything that was left on the stack on the next + // operation. Note that shutdownWithError(...) will cleanup the stack as well so its only needed here. + SSL.clearError(); return new SSLEngineResult(OK, NEED_WRAP, bytesConsumed, bytesProduced); } - throw shutdownWithError("SSL_read", errStr); + throw shutdownWithError("SSL_read", error, stackError); } private void closeAll() throws SSLException { @@ -1120,7 +1241,10 @@ private void rejectRemoteInitiatedRenegotiation() throws SSLHandshakeException { // As rejectRemoteInitiatedRenegotiation() is called in a finally block we also need to check if we shutdown // the engine before as otherwise SSL.getHandshakeCount(ssl) will throw an NPE if the passed in ssl is 0. // See https://github.com/netty/netty/issues/7353 - if (!isDestroyed() && SSL.getHandshakeCount(ssl) > 1) { + if (!isDestroyed() && SSL.getHandshakeCount(ssl) > 1 && + // As we may count multiple handshakes when TLSv1.3 is used we should just ignore this here as + // renegotiation is not supported in TLSv1.3 as per spec. + !SslUtils.PROTOCOL_TLS_V1_3.equals(session.getProtocol()) && handshakeState == HandshakeState.FINISHED) { // TODO: In future versions me may also want to send a fatal_alert to the client and so notify it // that the renegotiation failed. shutdown(); @@ -1189,10 +1313,29 @@ public final synchronized SSLEngineResult unwrap(ByteBuffer src, ByteBuffer[] ds } @Override - public final Runnable getDelegatedTask() { - // Currently, we do not delegate SSL computation tasks - // TODO: in the future, possibly create tasks to do encrypt / decrypt async - return null; + public final synchronized Runnable getDelegatedTask() { + if (isDestroyed()) { + return null; + } + final Runnable task = SSL.getTask(ssl); + if (task == null) { + return null; + } + return new Runnable() { + @Override + public void run() { + try { + if (isDestroyed()) { + // The engine was destroyed in the meantime, just return. + return; + } + task.run(); + } finally { + // The task was run, reset needTask to false so getHandshakeStatus() returns the correct value. + needTask = false; + } + } + }; } @Override @@ -1256,7 +1399,8 @@ private boolean doSSLShutdown() { int sslErr = SSL.getError(ssl, err); if (sslErr == SSL.SSL_ERROR_SYSCALL || sslErr == SSL.SSL_ERROR_SSL) { if (logger.isDebugEnabled()) { - logger.debug("SSL_shutdown failed: OpenSSL error: {}", SSL.getLastError()); + int error = SSL.getLastErrorNumber(); + logger.debug("SSL_shutdown failed: OpenSSL error: {} {}", error, SSL.getErrorString(error)); } // There was an internal error -- shutdown shutdown(); @@ -1276,7 +1420,7 @@ public final synchronized boolean isOutboundDone() { @Override public final String[] getSupportedCipherSuites() { - return OpenSsl.AVAILABLE_CIPHER_SUITES.toArray(new String[OpenSsl.AVAILABLE_CIPHER_SUITES.size()]); + return OpenSsl.AVAILABLE_CIPHER_SUITES.toArray(new String[0]); } @Override @@ -1292,15 +1436,18 @@ public final String[] getEnabledCipherSuites() { if (enabled == null) { return EmptyArrays.EMPTY_STRINGS; } else { + List enabledList = new ArrayList(); synchronized (this) { for (int i = 0; i < enabled.length; i++) { String mapped = toJavaCipherSuite(enabled[i]); - if (mapped != null) { - enabled[i] = mapped; + final String cipher = mapped == null ? enabled[i] : mapped; + if (!OpenSsl.isTlsv13Supported() && SslUtils.isTLSv13Cipher(cipher)) { + continue; } + enabledList.add(cipher); } } - return enabled; + return enabledList.toArray(new String[0]); } } @@ -1309,35 +1456,28 @@ public final void setEnabledCipherSuites(String[] cipherSuites) { checkNotNull(cipherSuites, "cipherSuites"); final StringBuilder buf = new StringBuilder(); - for (String c: cipherSuites) { - if (c == null) { - break; - } - - String converted = CipherSuiteConverter.toOpenSsl(c); - if (converted == null) { - converted = c; - } - - if (!OpenSsl.isCipherSuiteAvailable(converted)) { - throw new IllegalArgumentException("unsupported cipher suite: " + c + '(' + converted + ')'); - } - - buf.append(converted); - buf.append(':'); - } - - if (buf.length() == 0) { - throw new IllegalArgumentException("empty cipher suites"); - } - buf.setLength(buf.length() - 1); + final StringBuilder bufTLSv13 = new StringBuilder(); + CipherSuiteConverter.convertToCipherStrings(Arrays.asList(cipherSuites), buf, bufTLSv13, OpenSsl.isBoringSSL()); final String cipherSuiteSpec = buf.toString(); + final String cipherSuiteSpecTLSv13 = bufTLSv13.toString(); + if (!OpenSsl.isTlsv13Supported() && !cipherSuiteSpecTLSv13.isEmpty()) { + throw new IllegalArgumentException("TLSv1.3 is not supported by this java version."); + } synchronized (this) { if (!isDestroyed()) { + // TODO: Should we also adjust the protocols based on if there are any ciphers left that can be used + // for TLSv1.3 or for previor SSL/TLS versions ? try { - SSL.setCipherSuites(ssl, cipherSuiteSpec); + // Set non TLSv1.3 ciphers. + SSL.setCipherSuites(ssl, cipherSuiteSpec, false); + + if (OpenSsl.isTlsv13Supported()) { + // Set TLSv1.3 ciphers. + SSL.setCipherSuites(ssl, cipherSuiteSpecTLSv13, true); + } + } catch (Exception e) { throw new IllegalStateException("failed to enable cipher suites: " + cipherSuiteSpec, e); } @@ -1349,7 +1489,7 @@ public final void setEnabledCipherSuites(String[] cipherSuites) { @Override public final String[] getSupportedProtocols() { - return OpenSsl.SUPPORTED_PROTOCOLS_SET.toArray(new String[OpenSsl.SUPPORTED_PROTOCOLS_SET.size()]); + return OpenSsl.SUPPORTED_PROTOCOLS_SET.toArray(new String[0]); } @Override @@ -1363,7 +1503,7 @@ public final String[] getEnabledProtocols() { if (!isDestroyed()) { opts = SSL.getOptions(ssl); } else { - return enabled.toArray(new String[1]); + return enabled.toArray(new String[0]); } } if (isProtocolEnabled(opts, SSL.SSL_OP_NO_TLSv1, PROTOCOL_TLS_V1)) { @@ -1375,13 +1515,16 @@ public final String[] getEnabledProtocols() { if (isProtocolEnabled(opts, SSL.SSL_OP_NO_TLSv1_2, PROTOCOL_TLS_V1_2)) { enabled.add(PROTOCOL_TLS_V1_2); } + if (isProtocolEnabled(opts, SSL.SSL_OP_NO_TLSv1_3, PROTOCOL_TLS_V1_3)) { + enabled.add(PROTOCOL_TLS_V1_3); + } if (isProtocolEnabled(opts, SSL.SSL_OP_NO_SSLv2, PROTOCOL_SSL_V2)) { enabled.add(PROTOCOL_SSL_V2); } if (isProtocolEnabled(opts, SSL.SSL_OP_NO_SSLv3, PROTOCOL_SSL_V3)) { enabled.add(PROTOCOL_SSL_V3); } - return enabled.toArray(new String[enabled.size()]); + return enabled.toArray(new String[0]); } private static boolean isProtocolEnabled(int opts, int disableMask, String protocolString) { @@ -1446,13 +1589,20 @@ public final void setEnabledProtocols(String[] protocols) { if (maxProtocolIndex < OPENSSL_OP_NO_PROTOCOL_INDEX_TLSv1_2) { maxProtocolIndex = OPENSSL_OP_NO_PROTOCOL_INDEX_TLSv1_2; } + } else if (p.equals(PROTOCOL_TLS_V1_3)) { + if (minProtocolIndex > OPENSSL_OP_NO_PROTOCOL_INDEX_TLSv1_3) { + minProtocolIndex = OPENSSL_OP_NO_PROTOCOL_INDEX_TLSv1_3; + } + if (maxProtocolIndex < OPENSSL_OP_NO_PROTOCOL_INDEX_TLSv1_3) { + maxProtocolIndex = OPENSSL_OP_NO_PROTOCOL_INDEX_TLSv1_3; + } } } synchronized (this) { if (!isDestroyed()) { // Clear out options which disable protocols SSL.clearOptions(ssl, SSL.SSL_OP_NO_SSLv2 | SSL.SSL_OP_NO_SSLv3 | SSL.SSL_OP_NO_TLSv1 | - SSL.SSL_OP_NO_TLSv1_1 | SSL.SSL_OP_NO_TLSv1_2); + SSL.SSL_OP_NO_TLSv1_1 | SSL.SSL_OP_NO_TLSv1_2 | SSL.SSL_OP_NO_TLSv1_3); int opts = 0; for (int i = 0; i < minProtocolIndex; ++i) { @@ -1499,7 +1649,10 @@ public final synchronized void beginHandshake() throws SSLException { throw RENEGOTIATION_UNSUPPORTED; case NOT_STARTED: handshakeState = HandshakeState.STARTED_EXPLICITLY; - handshake(); + if (handshake() == NEED_TASK) { + // Set needTask to true so getHandshakeStatus() will return the correct value. + needTask = true; + } calculateMaxWrapOverhead(); break; default: @@ -1555,11 +1708,6 @@ private SSLEngineResult.HandshakeStatus handshake() throws SSLException { lastAccessed = System.currentTimeMillis(); } - if (!certificateSet && keyMaterialManager != null) { - certificateSet = true; - keyMaterialManager.setKeyMaterial(this); - } - int code = SSL.doHandshake(ssl); if (code <= 0) { // Check if we have a pending exception that was created during the handshake and if so throw it after @@ -1574,14 +1722,21 @@ private SSLEngineResult.HandshakeStatus handshake() throws SSLException { int sslError = SSL.getError(ssl, code); if (sslError == SSL.SSL_ERROR_WANT_READ || sslError == SSL.SSL_ERROR_WANT_WRITE) { return pendingStatus(SSL.bioLengthNonApplication(networkBIO)); - } else { - // Everything else is considered as error - throw shutdownWithError("SSL_do_handshake"); } + + if (sslError == SSL.SSL_ERROR_WANT_X509_LOOKUP) { + return NEED_TASK; + } + + // Everything else is considered as error + throw shutdownWithError("SSL_do_handshake", sslError); + } + // We have produced more data as part of the handshake if this is the case the user should call wrap(...) + if (SSL.bioLengthNonApplication(networkBIO) > 0) { + return NEED_WRAP; } // if SSL_do_handshake returns > 0 or sslError == SSL.SSL_ERROR_NAME it means the handshake was finished. session.handshakeFinished(); - engineMap.remove(ssl); return FINISHED; } @@ -1598,12 +1753,26 @@ private SSLEngineResult.HandshakeStatus mayFinishHandshake(SSLEngineResult.Hands @Override public final synchronized SSLEngineResult.HandshakeStatus getHandshakeStatus() { // Check if we are in the initial handshake phase or shutdown phase - return needPendingStatus() ? pendingStatus(SSL.bioLengthNonApplication(networkBIO)) : NOT_HANDSHAKING; + if (needPendingStatus()) { + if (needTask) { + // There is a task outstanding + return NEED_TASK; + } + return pendingStatus(SSL.bioLengthNonApplication(networkBIO)); + } + return NOT_HANDSHAKING; } private SSLEngineResult.HandshakeStatus getHandshakeStatus(int pending) { // Check if we are in the initial handshake phase or shutdown phase - return needPendingStatus() ? pendingStatus(pending) : NOT_HANDSHAKING; + if (needPendingStatus()) { + if (needTask) { + // There is a task outstanding + return NEED_TASK; + } + return pendingStatus(pending); + } + return NOT_HANDSHAKING; } private boolean needPendingStatus() { @@ -1619,7 +1788,8 @@ private String toJavaCipherSuite(String openSslCipherSuite) { return null; } - String prefix = toJavaCipherSuitePrefix(SSL.getVersion(ssl)); + String version = SSL.getVersion(ssl); + String prefix = toJavaCipherSuitePrefix(version); return CipherSuiteConverter.toJava(openSslCipherSuite, prefix); } @@ -1773,10 +1943,20 @@ public final synchronized void setSSLParameters(SSLParameters sslParameters) { } final String endPointIdentificationAlgorithm = sslParameters.getEndpointIdentificationAlgorithm(); - final boolean endPointVerificationEnabled = endPointIdentificationAlgorithm != null && - !endPointIdentificationAlgorithm.isEmpty(); - SSL.setHostNameValidation(ssl, DEFAULT_HOSTNAME_VALIDATION_FLAGS, - endPointVerificationEnabled ? getPeerHost() : null); + final boolean endPointVerificationEnabled = isEndPointVerificationEnabled(endPointIdentificationAlgorithm); + + final boolean wasEndPointVerificationEnabled = + isEndPointVerificationEnabled(this.endPointIdentificationAlgorithm); + + if (wasEndPointVerificationEnabled && !endPointVerificationEnabled) { + // Passing in null will disable hostname verification again so only do so if it was enabled before. + SSL.setHostNameValidation(ssl, DEFAULT_HOSTNAME_VALIDATION_FLAGS, null); + } else { + String host = endPointVerificationEnabled ? getPeerHost() : null; + if (host != null && !host.isEmpty()) { + SSL.setHostNameValidation(ssl, DEFAULT_HOSTNAME_VALIDATION_FLAGS, host); + } + } // If the user asks for hostname verification we must ensure we verify the peer. // If the user disables hostname verification we leave it up to the user to change the mode manually. if (clientMode && endPointVerificationEnabled) { @@ -1789,11 +1969,15 @@ public final synchronized void setSSLParameters(SSLParameters sslParameters) { super.setSSLParameters(sslParameters); } + private static boolean isEndPointVerificationEnabled(String endPointIdentificationAlgorithm) { + return endPointIdentificationAlgorithm != null && !endPointIdentificationAlgorithm.isEmpty(); + } + private boolean isDestroyed() { return destroyed != 0; } - final boolean checkSniHostnameMatch(String hostname) { + final boolean checkSniHostnameMatch(byte[] hostname) { return Java8SslUtils.checkSniHostnameMatch(matchers, hostname); } @@ -1810,13 +1994,14 @@ private static long bufferAddress(ByteBuffer b) { return Buffer.address(b); } - private final class OpenSslSession implements SSLSession { + private final class DefaultOpenSslSession implements OpenSslSession { private final OpenSslSessionContext sessionContext; // These are guarded by synchronized(OpenSslEngine.this) as handshakeFinished() may be triggered by any // thread. private X509Certificate[] x509PeerCerts; private Certificate[] peerCerts; + private String protocol; private String cipher; private byte[] id; @@ -1826,10 +2011,14 @@ private final class OpenSslSession implements SSLSession { // lazy init for memory reasons private Map values; - OpenSslSession(OpenSslSessionContext sessionContext) { + DefaultOpenSslSession(OpenSslSessionContext sessionContext) { this.sessionContext = sessionContext; } + private SSLSessionBindingEvent newSSLSessionBindingEvent(String name) { + return new SSLSessionBindingEvent(session, name); + } + @Override public byte[] getId() { synchronized (ReferenceCountedOpenSslEngine.this) { @@ -1889,14 +2078,19 @@ public void putValue(String name, Object value) { if (value == null) { throw new NullPointerException("value"); } - Map values = this.values; - if (values == null) { - // Use size of 2 to keep the memory overhead small - values = this.values = new HashMap(2); + final Object old; + synchronized (this) { + Map values = this.values; + if (values == null) { + // Use size of 2 to keep the memory overhead small + values = this.values = new HashMap(2); + } + old = values.put(name, value); } - Object old = values.put(name, value); + if (value instanceof SSLSessionBindingListener) { - ((SSLSessionBindingListener) value).valueBound(new SSLSessionBindingEvent(this, name)); + // Use newSSLSessionBindingEvent so we alway use the wrapper if needed. + ((SSLSessionBindingListener) value).valueBound(newSSLSessionBindingEvent(name)); } notifyUnbound(old, name); } @@ -1906,10 +2100,12 @@ public Object getValue(String name) { if (name == null) { throw new NullPointerException("name"); } - if (values == null) { - return null; + synchronized (this) { + if (values == null) { + return null; + } + return values.get(name); } - return values.get(name); } @Override @@ -1917,26 +2113,34 @@ public void removeValue(String name) { if (name == null) { throw new NullPointerException("name"); } - Map values = this.values; - if (values == null) { - return; + + final Object old; + synchronized (this) { + Map values = this.values; + if (values == null) { + return; + } + old = values.remove(name); } - Object old = values.remove(name); + notifyUnbound(old, name); } @Override public String[] getValueNames() { - Map values = this.values; - if (values == null || values.isEmpty()) { - return EmptyArrays.EMPTY_STRINGS; + synchronized (this) { + Map values = this.values; + if (values == null || values.isEmpty()) { + return EmptyArrays.EMPTY_STRINGS; + } + return values.keySet().toArray(new String[0]); } - return values.keySet().toArray(new String[values.size()]); } private void notifyUnbound(Object value, String name) { if (value instanceof SSLSessionBindingListener) { - ((SSLSessionBindingListener) value).valueUnbound(new SSLSessionBindingEvent(this, name)); + // Use newSSLSessionBindingEvent so we alway use the wrapper if needed. + ((SSLSessionBindingListener) value).valueUnbound(newSSLSessionBindingEvent(name)); } } @@ -1944,7 +2148,8 @@ private void notifyUnbound(Object value, String name) { * Finish the handshake and so init everything in the {@link OpenSslSession} that should be accessible by * the user. */ - void handshakeFinished() throws SSLException { + @Override + public void handshakeFinished() throws SSLException { synchronized (ReferenceCountedOpenSslEngine.this) { if (!isDestroyed()) { id = SSL.getSessionId(ssl); @@ -2084,6 +2289,7 @@ public Certificate[] getPeerCertificates() throws SSLPeerUnverifiedException { @Override public Certificate[] getLocalCertificates() { + Certificate[] localCerts = ReferenceCountedOpenSslEngine.this.localCertificateChain; if (localCerts == null) { return null; } @@ -2110,7 +2316,7 @@ public Principal getPeerPrincipal() throws SSLPeerUnverifiedException { @Override public Principal getLocalPrincipal() { - Certificate[] local = localCerts; + Certificate[] local = ReferenceCountedOpenSslEngine.this.localCertificateChain; if (local == null || local.length == 0) { return null; } @@ -2121,7 +2327,7 @@ public Principal getLocalPrincipal() { public String getCipherSuite() { synchronized (ReferenceCountedOpenSslEngine.this) { if (cipher == null) { - return INVALID_CIPHER; + return SslUtils.INVALID_CIPHER; } return cipher; } @@ -2162,13 +2368,8 @@ public int getApplicationBufferSize() { return applicationBufferSize; } - /** - * Expand (or increase) the value returned by {@link #getApplicationBufferSize()} if necessary. - *

    - * This is only called in a synchronized block, so no need to use atomic operations. - * @param packetLengthDataOnly The packet size which exceeds the current {@link #getApplicationBufferSize()}. - */ - void tryExpandApplicationBufferSize(int packetLengthDataOnly) { + @Override + public void tryExpandApplicationBufferSize(int packetLengthDataOnly) { if (packetLengthDataOnly > MAX_PLAINTEXT_LENGTH && applicationBufferSize != MAX_RECORD_SIZE) { applicationBufferSize = MAX_RECORD_SIZE; } diff --git a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslServerContext.java b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslServerContext.java index 4c9df3148c0c..e901aebb3833 100644 --- a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslServerContext.java +++ b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslServerContext.java @@ -15,9 +15,12 @@ */ package io.netty.handler.ssl; +import io.netty.buffer.ByteBufAllocator; +import io.netty.internal.tcnative.CertificateCallback; import io.netty.internal.tcnative.SSL; import io.netty.internal.tcnative.SSLContext; import io.netty.internal.tcnative.SniHostNameMatcher; +import io.netty.util.CharsetUtil; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -27,10 +30,9 @@ import java.security.cert.X509Certificate; import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLException; +import javax.net.ssl.SSLHandshakeException; import javax.net.ssl.TrustManagerFactory; -import javax.net.ssl.X509ExtendedKeyManager; import javax.net.ssl.X509ExtendedTrustManager; -import javax.net.ssl.X509KeyManager; import javax.net.ssl.X509TrustManager; import static io.netty.util.internal.ObjectUtil.checkNotNull; @@ -48,7 +50,6 @@ public final class ReferenceCountedOpenSslServerContext extends ReferenceCounted InternalLoggerFactory.getInstance(ReferenceCountedOpenSslServerContext.class); private static final byte[] ID = {'n', 'e', 't', 't', 'y'}; private final OpenSslServerSessionContext sessionContext; - private final OpenSslKeyMaterialManager keyMaterialManager; ReferenceCountedOpenSslServerContext( X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory, @@ -72,10 +73,8 @@ private ReferenceCountedOpenSslServerContext( // Create a new SSL_CTX and configure it. boolean success = false; try { - ServerContext context = newSessionContext(this, ctx, engineMap, trustCertCollection, trustManagerFactory, + sessionContext = newSessionContext(this, ctx, engineMap, trustCertCollection, trustManagerFactory, keyCertChain, key, keyPassword, keyManagerFactory); - sessionContext = context.sessionContext; - keyMaterialManager = context.keyMaterialManager; success = true; } finally { if (!success) { @@ -89,104 +88,135 @@ public OpenSslServerSessionContext sessionContext() { return sessionContext; } - @Override - OpenSslKeyMaterialManager keyMaterialManager() { - return keyMaterialManager; - } - - static final class ServerContext { - OpenSslServerSessionContext sessionContext; - OpenSslKeyMaterialManager keyMaterialManager; - } - - static ServerContext newSessionContext(ReferenceCountedOpenSslContext thiz, long ctx, OpenSslEngineMap engineMap, - X509Certificate[] trustCertCollection, - TrustManagerFactory trustManagerFactory, - X509Certificate[] keyCertChain, PrivateKey key, - String keyPassword, KeyManagerFactory keyManagerFactory) + static OpenSslServerSessionContext newSessionContext(ReferenceCountedOpenSslContext thiz, long ctx, + OpenSslEngineMap engineMap, + X509Certificate[] trustCertCollection, + TrustManagerFactory trustManagerFactory, + X509Certificate[] keyCertChain, PrivateKey key, + String keyPassword, KeyManagerFactory keyManagerFactory) throws SSLException { - ServerContext result = new ServerContext(); + OpenSslKeyMaterialProvider keyMaterialProvider = null; try { - SSLContext.setVerify(ctx, SSL.SSL_CVERIFY_NONE, VERIFY_DEPTH); - if (!OpenSsl.useKeyManagerFactory()) { - if (keyManagerFactory != null) { - throw new IllegalArgumentException( - "KeyManagerFactory not supported"); - } - checkNotNull(keyCertChain, "keyCertChain"); - - setKeyMaterial(ctx, keyCertChain, key, keyPassword); - } else { - // javadocs state that keyManagerFactory has precedent over keyCertChain, and we must have a - // keyManagerFactory for the server so build one if it is not specified. - if (keyManagerFactory == null) { - keyManagerFactory = buildKeyManagerFactory( - keyCertChain, key, keyPassword, keyManagerFactory); + try { + SSLContext.setVerify(ctx, SSL.SSL_CVERIFY_NONE, VERIFY_DEPTH); + if (!OpenSsl.useKeyManagerFactory()) { + if (keyManagerFactory != null) { + throw new IllegalArgumentException( + "KeyManagerFactory not supported"); + } + checkNotNull(keyCertChain, "keyCertChain"); + + setKeyMaterial(ctx, keyCertChain, key, keyPassword); + } else { + // javadocs state that keyManagerFactory has precedent over keyCertChain, and we must have a + // keyManagerFactory for the server so build one if it is not specified. + if (keyManagerFactory == null) { + char[] keyPasswordChars = keyStorePassword(keyPassword); + KeyStore ks = buildKeyStore(keyCertChain, key, keyPasswordChars); + if (ks.aliases().hasMoreElements()) { + keyManagerFactory = new OpenSslX509KeyManagerFactory(); + } else { + keyManagerFactory = new OpenSslCachingX509KeyManagerFactory( + KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())); + } + keyManagerFactory.init(ks, keyPasswordChars); + } + keyMaterialProvider = providerFor(keyManagerFactory, keyPassword); + + SSLContext.setCertificateCallback(ctx, new OpenSslServerCertificateCallback( + engineMap, new OpenSslKeyMaterialManager(keyMaterialProvider))); } - X509KeyManager keyManager = chooseX509KeyManager(keyManagerFactory.getKeyManagers()); - result.keyMaterialManager = useExtendedKeyManager(keyManager) ? - new OpenSslExtendedKeyMaterialManager( - (X509ExtendedKeyManager) keyManager, keyPassword) : - new OpenSslKeyMaterialManager(keyManager, keyPassword); - } - } catch (Exception e) { - throw new SSLException("failed to set certificate and key", e); - } - try { - if (trustCertCollection != null) { - trustManagerFactory = buildTrustManagerFactory(trustCertCollection, trustManagerFactory); - } else if (trustManagerFactory == null) { - // Mimic the way SSLContext.getInstance(KeyManager[], null, null) works - trustManagerFactory = TrustManagerFactory.getInstance( - TrustManagerFactory.getDefaultAlgorithm()); - trustManagerFactory.init((KeyStore) null); + } catch (Exception e) { + throw new SSLException("failed to set certificate and key", e); } + try { + if (trustCertCollection != null) { + trustManagerFactory = buildTrustManagerFactory(trustCertCollection, trustManagerFactory); + } else if (trustManagerFactory == null) { + // Mimic the way SSLContext.getInstance(KeyManager[], null, null) works + trustManagerFactory = TrustManagerFactory.getInstance( + TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init((KeyStore) null); + } - final X509TrustManager manager = chooseTrustManager(trustManagerFactory.getTrustManagers()); + final X509TrustManager manager = chooseTrustManager(trustManagerFactory.getTrustManagers()); - // IMPORTANT: The callbacks set for verification must be static to prevent memory leak as - // otherwise the context can never be collected. This is because the JNI code holds - // a global reference to the callbacks. - // - // See https://github.com/netty/netty/issues/5372 + // IMPORTANT: The callbacks set for verification must be static to prevent memory leak as + // otherwise the context can never be collected. This is because the JNI code holds + // a global reference to the callbacks. + // + // See https://github.com/netty/netty/issues/5372 - // Use this to prevent an error when running on java < 7 - if (useExtendedTrustManager(manager)) { - SSLContext.setCertVerifyCallback(ctx, - new ExtendedTrustManagerVerifyCallback(engineMap, (X509ExtendedTrustManager) manager)); - } else { - SSLContext.setCertVerifyCallback(ctx, new TrustManagerVerifyCallback(engineMap, manager)); - } + // Use this to prevent an error when running on java < 7 + if (useExtendedTrustManager(manager)) { + SSLContext.setCertVerifyCallback(ctx, new ExtendedTrustManagerVerifyCallback( + engineMap, (X509ExtendedTrustManager) manager)); + } else { + SSLContext.setCertVerifyCallback(ctx, new TrustManagerVerifyCallback(engineMap, manager)); + } - X509Certificate[] issuers = manager.getAcceptedIssuers(); - if (issuers != null && issuers.length > 0) { - long bio = 0; - try { - bio = toBIO(issuers); - if (!SSLContext.setCACertificateBio(ctx, bio)) { - throw new SSLException("unable to setup accepted issuers for trustmanager " + manager); + X509Certificate[] issuers = manager.getAcceptedIssuers(); + if (issuers != null && issuers.length > 0) { + long bio = 0; + try { + bio = toBIO(ByteBufAllocator.DEFAULT, issuers); + if (!SSLContext.setCACertificateBio(ctx, bio)) { + throw new SSLException("unable to setup accepted issuers for trustmanager " + manager); + } + } finally { + freeBio(bio); } - } finally { - freeBio(bio); } + + if (PlatformDependent.javaVersion() >= 8) { + // Only do on Java8+ as SNIMatcher is not supported in earlier releases. + // IMPORTANT: The callbacks set for hostname matching must be static to prevent memory leak as + // otherwise the context can never be collected. This is because the JNI code holds + // a global reference to the matcher. + SSLContext.setSniHostnameMatcher(ctx, new OpenSslSniHostnameMatcher(engineMap)); + } + } catch (SSLException e) { + throw e; + } catch (Exception e) { + throw new SSLException("unable to setup trustmanager", e); } - if (PlatformDependent.javaVersion() >= 8) { - // Only do on Java8+ as SNIMatcher is not supported in earlier releases. - // IMPORTANT: The callbacks set for hostname matching must be static to prevent memory leak as - // otherwise the context can never be collected. This is because the JNI code holds - // a global reference to the matcher. - SSLContext.setSniHostnameMatcher(ctx, new OpenSslSniHostnameMatcher(engineMap)); + OpenSslServerSessionContext sessionContext = new OpenSslServerSessionContext(thiz, keyMaterialProvider); + sessionContext.setSessionIdContext(ID); + + keyMaterialProvider = null; + + return sessionContext; + } finally { + if (keyMaterialProvider != null) { + keyMaterialProvider.destroy(); } - } catch (SSLException e) { - throw e; - } catch (Exception e) { - throw new SSLException("unable to setup trustmanager", e); + } + } + + private static final class OpenSslServerCertificateCallback implements CertificateCallback { + private final OpenSslEngineMap engineMap; + private final OpenSslKeyMaterialManager keyManagerHolder; + + OpenSslServerCertificateCallback(OpenSslEngineMap engineMap, OpenSslKeyMaterialManager keyManagerHolder) { + this.engineMap = engineMap; + this.keyManagerHolder = keyManagerHolder; } - result.sessionContext = new OpenSslServerSessionContext(thiz); - result.sessionContext.setSessionIdContext(ID); - return result; + @Override + public void handle(long ssl, byte[] keyTypeBytes, byte[][] asn1DerEncodedPrincipals) throws Exception { + final ReferenceCountedOpenSslEngine engine = engineMap.get(ssl); + try { + // For now we just ignore the asn1DerEncodedPrincipals as this is kind of inline with what the + // OpenJDK SSLEngineImpl does. + keyManagerHolder.setKeyMaterialServerSide(engine); + } catch (Throwable cause) { + logger.debug("Failed to set the server-side key material", cause); + SSLHandshakeException e = new SSLHandshakeException("General OpenSslEngine problem"); + e.initCause(cause); + engine.handshakeException = e; + } + } } private static final class TrustManagerVerifyCallback extends AbstractCertificateVerifier { @@ -209,7 +239,7 @@ private static final class ExtendedTrustManagerVerifyCallback extends AbstractCe ExtendedTrustManagerVerifyCallback(OpenSslEngineMap engineMap, X509ExtendedTrustManager manager) { super(engineMap); - this.manager = manager; + this.manager = OpenSslTlsv13X509ExtendedTrustManager.wrap(manager, false); } @Override @@ -230,7 +260,8 @@ private static final class OpenSslSniHostnameMatcher implements SniHostNameMatch public boolean match(long ssl, String hostname) { ReferenceCountedOpenSslEngine engine = engineMap.get(ssl); if (engine != null) { - return engine.checkSniHostnameMatch(hostname); + // TODO: In the next release of tcnative we should pass the byte[] directly in and not use a String. + return engine.checkSniHostnameMatch(hostname.getBytes(CharsetUtil.UTF_8)); } logger.warn("No ReferenceCountedOpenSslEngine found for SSL pointer: {}", ssl); return false; diff --git a/handler/src/main/java/io/netty/handler/ssl/SignatureAlgorithmConverter.java b/handler/src/main/java/io/netty/handler/ssl/SignatureAlgorithmConverter.java new file mode 100644 index 000000000000..c68a5f9a51c6 --- /dev/null +++ b/handler/src/main/java/io/netty/handler/ssl/SignatureAlgorithmConverter.java @@ -0,0 +1,74 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import java.util.Locale; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Converts OpenSSL signature Algorithm names to + * + * Java signature Algorithm names. + */ +final class SignatureAlgorithmConverter { + + private SignatureAlgorithmConverter() { } + + // OpenSSL has 3 different formats it uses at the moment we will match against all of these. + // For example: + // ecdsa-with-SHA384 + // hmacWithSHA384 + // dsa_with_SHA224 + // + // For more details see https://github.com/openssl/openssl/blob/OpenSSL_1_0_2p/crypto/objects/obj_dat.h + // + // BoringSSL uses a different format: + // https://github.com/google/boringssl/blob/8525ff3/ssl/ssl_privkey.cc#L436 + // + private static final Pattern PATTERN = Pattern.compile( + // group 1 - 2 + "(?:(^[a-zA-Z].+)With(.+)Encryption$)|" + + // group 3 - 4 + "(?:(^[a-zA-Z].+)(?:_with_|-with-|_pkcs1_|_pss_rsae_)(.+$))|" + + // group 5 - 6 + "(?:(^[a-zA-Z].+)_(.+$))"); + + /** + * Converts an OpenSSL algorithm name to a Java algorithm name and return it, + * or return {@code null} if the conversation failed because the format is not known. + */ + static String toJavaName(String opensslName) { + if (opensslName == null) { + return null; + } + Matcher matcher = PATTERN.matcher(opensslName); + if (matcher.matches()) { + String group1 = matcher.group(1); + if (group1 != null) { + return group1.toUpperCase(Locale.ROOT) + "with" + matcher.group(2).toUpperCase(Locale.ROOT); + } + if (matcher.group(3) != null) { + return matcher.group(4).toUpperCase(Locale.ROOT) + "with" + matcher.group(3).toUpperCase(Locale.ROOT); + } + + if (matcher.group(5) != null) { + return matcher.group(6).toUpperCase(Locale.ROOT) + "with" + matcher.group(5).toUpperCase(Locale.ROOT); + } + } + return null; + } +} diff --git a/handler/src/main/java/io/netty/handler/ssl/SniHandler.java b/handler/src/main/java/io/netty/handler/ssl/SniHandler.java index cda2bbdce7ae..c6a82278d823 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SniHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SniHandler.java @@ -15,6 +15,7 @@ */ package io.netty.handler.ssl; +import io.netty.buffer.ByteBufAllocator; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.DecoderException; import io.netty.util.AsyncMapping; @@ -129,7 +130,7 @@ protected final void onLookupComplete(ChannelHandlerContext ctx, protected void replaceHandler(ChannelHandlerContext ctx, String hostname, SslContext sslContext) throws Exception { SslHandler sslHandler = null; try { - sslHandler = sslContext.newHandler(ctx.alloc()); + sslHandler = newSslHandler(sslContext, ctx.alloc()); ctx.pipeline().replace(this, SslHandler.class.getName(), sslHandler); sslHandler = null; } finally { @@ -142,6 +143,14 @@ protected void replaceHandler(ChannelHandlerContext ctx, String hostname, SslCon } } + /** + * Returns a new {@link SslHandler} using the given {@link SslContext} and {@link ByteBufAllocator}. + * Users may override this method to implement custom behavior. + */ + protected SslHandler newSslHandler(SslContext context, ByteBufAllocator allocator) { + return context.newHandler(allocator); + } + private static final class AsyncMappingAdapter implements AsyncMapping { private final Mapping mapping; diff --git a/handler/src/main/java/io/netty/handler/ssl/SslContext.java b/handler/src/main/java/io/netty/handler/ssl/SslContext.java index ef5c4bfdca79..dab0ce745d28 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslContext.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslContext.java @@ -53,7 +53,6 @@ import java.security.KeyStoreException; import java.security.NoSuchAlgorithmException; import java.security.PrivateKey; -import java.security.Security; import java.security.UnrecoverableKeyException; import java.security.cert.CertificateException; import java.security.cert.CertificateFactory; @@ -61,6 +60,7 @@ import java.security.spec.InvalidKeySpecException; import java.security.spec.PKCS8EncodedKeySpec; import java.util.List; +import java.util.concurrent.Executor; /** * A secure socket protocol implementation which acts as a factory for {@link SSLEngine} and {@link SslHandler}. @@ -85,6 +85,8 @@ * */ public abstract class SslContext { + static final String ALIAS = "key"; + static final CertificateFactory X509_CERT_FACTORY; static { try { @@ -878,6 +880,22 @@ public final List nextProtocols() { */ public abstract SSLSessionContext sessionContext(); + /** + * Create a new SslHandler. + * @see #newHandler(ByteBufAllocator, Executor) + */ + public final SslHandler newHandler(ByteBufAllocator alloc) { + return newHandler(alloc, startTls); + } + + /** + * Create a new SslHandler. + * @see #newHandler(ByteBufAllocator) + */ + protected SslHandler newHandler(ByteBufAllocator alloc, boolean startTls) { + return new SslHandler(newEngine(alloc), startTls); + } + /** * Creates a new {@link SslHandler}. *

    If {@link SslProvider#OPENSSL_REFCNT} is used then the returned {@link SslHandler} will release the engine @@ -899,18 +917,37 @@ public final List nextProtocols() { * SSLEngine javadocs which * limits wrap/unwrap to operate on a single SSL/TLS packet. * @param alloc If supported by the SSLEngine then the SSLEngine will use this to allocate ByteBuf objects. + * @param delegatedTaskExecutor the {@link Executor} that will be used to execute tasks that are returned by + * {@link SSLEngine#getDelegatedTask()}. * @return a new {@link SslHandler} */ - public final SslHandler newHandler(ByteBufAllocator alloc) { - return newHandler(alloc, startTls); + public SslHandler newHandler(ByteBufAllocator alloc, Executor delegatedTaskExecutor) { + return newHandler(alloc, startTls, delegatedTaskExecutor); } /** * Create a new SslHandler. - * @see #newHandler(ByteBufAllocator) + * @see #newHandler(ByteBufAllocator, String, int, boolean, Executor) */ - protected SslHandler newHandler(ByteBufAllocator alloc, boolean startTls) { - return new SslHandler(newEngine(alloc), startTls); + protected SslHandler newHandler(ByteBufAllocator alloc, boolean startTls, Executor executor) { + return new SslHandler(newEngine(alloc), startTls, executor); + } + + /** + * Creates a new {@link SslHandler} + * + * @see #newHandler(ByteBufAllocator, String, int, Executor) + */ + public final SslHandler newHandler(ByteBufAllocator alloc, String peerHost, int peerPort) { + return newHandler(alloc, peerHost, peerPort, startTls); + } + + /** + * Create a new SslHandler. + * @see #newHandler(ByteBufAllocator, String, int, boolean, Executor) + */ + protected SslHandler newHandler(ByteBufAllocator alloc, String peerHost, int peerPort, boolean startTls) { + return new SslHandler(newEngine(alloc, peerHost, peerPort), startTls); } /** @@ -936,19 +973,19 @@ protected SslHandler newHandler(ByteBufAllocator alloc, boolean startTls) { * @param alloc If supported by the SSLEngine then the SSLEngine will use this to allocate ByteBuf objects. * @param peerHost the non-authoritative name of the host * @param peerPort the non-authoritative port + * @param delegatedTaskExecutor the {@link Executor} that will be used to execute tasks that are returned by + * {@link SSLEngine#getDelegatedTask()}. * * @return a new {@link SslHandler} */ - public final SslHandler newHandler(ByteBufAllocator alloc, String peerHost, int peerPort) { - return newHandler(alloc, peerHost, peerPort, startTls); + public SslHandler newHandler(ByteBufAllocator alloc, String peerHost, int peerPort, + Executor delegatedTaskExecutor) { + return newHandler(alloc, peerHost, peerPort, startTls, delegatedTaskExecutor); } - /** - * Create a new SslHandler. - * @see #newHandler(ByteBufAllocator, String, int, boolean) - */ - protected SslHandler newHandler(ByteBufAllocator alloc, String peerHost, int peerPort, boolean startTls) { - return new SslHandler(newEngine(alloc, peerHost, peerPort), startTls); + protected SslHandler newHandler(ByteBufAllocator alloc, String peerHost, int peerPort, boolean startTls, + Executor delegatedTaskExecutor) { + return new SslHandler(newEngine(alloc, peerHost, peerPort), startTls, delegatedTaskExecutor); } /** @@ -1000,7 +1037,7 @@ static KeyStore buildKeyStore(X509Certificate[] certChain, PrivateKey key, char[ CertificateException, IOException { KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); ks.load(null, null); - ks.setKeyEntry("key", key, keyPasswordChars, certChain); + ks.setKeyEntry(ALIAS, key, keyPasswordChars, certChain); return ks; } @@ -1081,10 +1118,9 @@ private static X509Certificate[] getCertificatesFromBuffers(ByteBuf[] certs) thr CertificateFactory cf = CertificateFactory.getInstance("X.509"); X509Certificate[] x509Certs = new X509Certificate[certs.length]; - int i = 0; try { - for (; i < certs.length; i++) { - InputStream is = new ByteBufInputStream(certs[i], true); + for (int i = 0; i < certs.length; i++) { + InputStream is = new ByteBufInputStream(certs[i], false); try { x509Certs[i] = (X509Certificate) cf.generateCertificate(is); } finally { @@ -1097,8 +1133,8 @@ private static X509Certificate[] getCertificatesFromBuffers(ByteBuf[] certs) thr } } } finally { - for (; i < certs.length; i++) { - certs[i].release(); + for (ByteBuf buf: certs) { + buf.release(); } } return x509Certs; @@ -1154,8 +1190,15 @@ static KeyManagerFactory buildKeyManagerFactory(X509Certificate[] certChainFile, String keyPassword, KeyManagerFactory kmf) throws KeyStoreException, NoSuchAlgorithmException, IOException, CertificateException, UnrecoverableKeyException { - char[] keyPasswordChars = keyPassword == null ? EmptyArrays.EMPTY_CHARS : keyPassword.toCharArray(); + char[] keyPasswordChars = keyStorePassword(keyPassword); KeyStore ks = buildKeyStore(certChainFile, key, keyPasswordChars); + return buildKeyManagerFactory(ks, keyAlgorithm, keyPasswordChars, kmf); + } + + static KeyManagerFactory buildKeyManagerFactory(KeyStore ks, + String keyAlgorithm, + char[] keyPasswordChars, KeyManagerFactory kmf) + throws KeyStoreException, NoSuchAlgorithmException, UnrecoverableKeyException { // Set up key manager factory to use our key store if (kmf == null) { kmf = KeyManagerFactory.getInstance(keyAlgorithm); @@ -1164,4 +1207,8 @@ static KeyManagerFactory buildKeyManagerFactory(X509Certificate[] certChainFile, return kmf; } + + static char[] keyStorePassword(String keyPassword) { + return keyPassword == null ? EmptyArrays.EMPTY_CHARS : keyPassword.toCharArray(); + } } diff --git a/handler/src/main/java/io/netty/handler/ssl/SslContextBuilder.java b/handler/src/main/java/io/netty/handler/ssl/SslContextBuilder.java index 2e3c55b9cb2d..ae21440ce099 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslContextBuilder.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslContextBuilder.java @@ -121,6 +121,9 @@ public static SslContextBuilder forServer( /** * Creates a builder for new server-side {@link SslContext}. * + * If you use {@link SslProvider#OPENSSL} or {@link SslProvider#OPENSSL_REFCNT} consider using + * {@link OpenSslX509KeyManagerFactory} or {@link OpenSslCachingX509KeyManagerFactory}. + * * @param keyManagerFactory non-{@code null} factory for server's private key * @see #keyManager(KeyManagerFactory) */ @@ -335,6 +338,9 @@ public SslContextBuilder keyManager(PrivateKey key, String keyPassword, X509Cert * if the used openssl version is 1.0.1+. You can check if your openssl version supports using a * {@link KeyManagerFactory} by calling {@link OpenSsl#supportsKeyManagerFactory()}. If this is not the case * you must use {@link #keyManager(File, File)} or {@link #keyManager(File, File, String)}. + * + * If you use {@link SslProvider#OPENSSL} or {@link SslProvider#OPENSSL_REFCNT} consider using + * {@link OpenSslX509KeyManagerFactory} or {@link OpenSslCachingX509KeyManagerFactory}. */ public SslContextBuilder keyManager(KeyManagerFactory keyManagerFactory) { if (forServer) { diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index fb6fc92766ad..c3f74af4ef0b 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -33,6 +33,7 @@ import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromiseNotifier; import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.DecoderException; import io.netty.handler.codec.UnsupportedMessageTypeException; import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCounted; @@ -42,6 +43,7 @@ import io.netty.util.concurrent.FutureListener; import io.netty.util.concurrent.ImmediateExecutor; import io.netty.util.concurrent.Promise; +import io.netty.util.concurrent.PromiseNotifier; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.ThrowableUtil; import io.netty.util.internal.UnstableApi; @@ -54,10 +56,9 @@ import java.nio.channels.ClosedChannelException; import java.nio.channels.DatagramChannel; import java.nio.channels.SocketChannel; -import java.util.ArrayList; import java.util.List; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; +import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import java.util.regex.Pattern; @@ -93,8 +94,8 @@ * *

    Closing the session

    *

    - * To close the SSL session, the {@link #close()} method should be - * called to send the {@code close_notify} message to the remote peer. One + * To close the SSL session, the {@link #closeOutbound()} method should be + * called to send the {@code close_notify} message to the remote peer. One * exception is when you close the {@link Channel} - {@link SslHandler} * intercepts the close request and send the {@code close_notify} message * before the channel closure automatically. Once the SSL session is closed, @@ -389,6 +390,7 @@ abstract SSLEngineResult unwrap(SslHandler handler, ByteBuf in, int readerIndex, private boolean flushedBeforeHandshake; private boolean readDuringHandshake; private boolean handshakeStarted; + private SslHandlerCoalescingBufferQueue pendingUnencryptedWrites; private Promise handshakePromise = new LazyChannelPromise(); private final LazyChannelPromise sslClosePromise = new LazyChannelPromise(); @@ -401,6 +403,7 @@ abstract SSLEngineResult unwrap(SslHandler handler, ByteBuf in, int readerIndex, private boolean outboundClosed; private boolean closeNotify; + private boolean processTask; private int packetLength; @@ -416,7 +419,7 @@ abstract SSLEngineResult unwrap(SslHandler handler, ByteBuf in, int readerIndex, volatile int wrapDataSize = MAX_PLAINTEXT_LENGTH; /** - * Creates a new instance. + * Creates a new instance which runs all delegated tasks directly on the {@link EventExecutor}. * * @param engine the {@link SSLEngine} this handler will use */ @@ -425,29 +428,36 @@ public SslHandler(SSLEngine engine) { } /** - * Creates a new instance. + * Creates a new instance which runs all delegated tasks directly on the {@link EventExecutor}. * * @param engine the {@link SSLEngine} this handler will use * @param startTls {@code true} if the first write request shouldn't be * encrypted by the {@link SSLEngine} */ - @SuppressWarnings("deprecation") public SslHandler(SSLEngine engine, boolean startTls) { this(engine, startTls, ImmediateExecutor.INSTANCE); } /** - * @deprecated Use {@link #SslHandler(SSLEngine)} instead. + * Creates a new instance. + * + * @param engine the {@link SSLEngine} this handler will use + * @param delegatedTaskExecutor the {@link Executor} that will be used to execute tasks that are returned by + * {@link SSLEngine#getDelegatedTask()}. */ - @Deprecated public SslHandler(SSLEngine engine, Executor delegatedTaskExecutor) { this(engine, false, delegatedTaskExecutor); } /** - * @deprecated Use {@link #SslHandler(SSLEngine, boolean)} instead. + * Creates a new instance. + * + * @param engine the {@link SSLEngine} this handler will use + * @param startTls {@code true} if the first write request shouldn't be + * encrypted by the {@link SSLEngine} + * @param delegatedTaskExecutor the {@link Executor} that will be used to execute tasks that are returned by + * {@link SSLEngine#getDelegatedTask()}. */ - @Deprecated public SslHandler(SSLEngine engine, boolean startTls, Executor delegatedTaskExecutor) { if (engine == null) { throw new NullPointerException("engine"); @@ -622,42 +632,64 @@ public Future handshakeFuture() { } /** - * Sends an SSL {@code close_notify} message to the specified channel and - * destroys the underlying {@link SSLEngine}. - * - * @deprecated use {@link Channel#close()} or {@link ChannelHandlerContext#close()} + * Use {@link #closeOutbound()} */ @Deprecated public ChannelFuture close() { - return close(ctx.newPromise()); + return closeOutbound(); } /** - * See {@link #close()} - * - * @deprecated use {@link Channel#close()} or {@link ChannelHandlerContext#close()} + * Use {@link #closeOutbound(ChannelPromise)} */ @Deprecated - public ChannelFuture close(final ChannelPromise promise) { + public ChannelFuture close(ChannelPromise promise) { + return closeOutbound(promise); + } + + /** + * Sends an SSL {@code close_notify} message to the specified channel and + * destroys the underlying {@link SSLEngine}. This will not close the underlying + * {@link Channel}. If you want to also close the {@link Channel} use {@link Channel#close()} or + * {@link ChannelHandlerContext#close()} + */ + public ChannelFuture closeOutbound() { + return closeOutbound(ctx.newPromise()); + } + + /** + * Sends an SSL {@code close_notify} message to the specified channel and + * destroys the underlying {@link SSLEngine}. This will not close the underlying + * {@link Channel}. If you want to also close the {@link Channel} use {@link Channel#close()} or + * {@link ChannelHandlerContext#close()} + */ + public ChannelFuture closeOutbound(final ChannelPromise promise) { final ChannelHandlerContext ctx = this.ctx; - ctx.executor().execute(new Runnable() { - @Override - public void run() { - outboundClosed = true; - engine.closeOutbound(); - try { - flush(ctx, promise); - } catch (Exception e) { - if (!promise.tryFailure(e)) { - logger.warn("{} flush() raised a masked exception.", ctx.channel(), e); - } + if (ctx.executor().inEventLoop()) { + closeOutbound0(promise); + } else { + ctx.executor().execute(new Runnable() { + @Override + public void run() { + closeOutbound0(promise); } - } - }); - + }); + } return promise; } + private void closeOutbound0(ChannelPromise promise) { + outboundClosed = true; + engine.closeOutbound(); + try { + flush(ctx, promise); + } catch (Exception e) { + if (!promise.tryFailure(e)) { + logger.warn("{} flush() raised a masked exception.", ctx.channel(), e); + } + } + } + /** * Return the {@link Future} that will get notified if the inbound of the {@link SSLEngine} is closed. * @@ -745,6 +777,13 @@ public void flush(ChannelHandlerContext ctx) throws Exception { sentFirstMessage = true; pendingUnencryptedWrites.writeAndRemoveAll(ctx); forceFlush(ctx); + // Explicit start handshake processing once we send the first message. This will also ensure + // we will schedule the timeout if needed. + startHandshakeProcessing(); + return; + } + + if (processTask) { return; } @@ -787,7 +826,7 @@ private void wrap(ChannelHandlerContext ctx, boolean inUnwrap) throws SSLExcepti final int wrapDataSize = this.wrapDataSize; // Only continue to loop if the handler was not removed in the meantime. // See https://github.com/netty/netty/issues/5860 - while (!ctx.isRemoved()) { + outer: while (!ctx.isRemoved()) { promise = ctx.newPromise(); buf = wrapDataSize > 0 ? pendingUnencryptedWrites.remove(alloc, wrapDataSize, promise) : @@ -824,7 +863,11 @@ private void wrap(ChannelHandlerContext ctx, boolean inUnwrap) throws SSLExcepti switch (result.getHandshakeStatus()) { case NEED_TASK: - runDelegatedTasks(); + if (!runDelegatedTasks(inUnwrap)) { + // We scheduled a task on the delegatingTaskExecutor, so stop processing as we will + // resume once the task completes. + break outer; + } break; case FINISHED: setHandshakeSuccess(); @@ -893,7 +936,7 @@ private boolean wrapNonAppData(ChannelHandlerContext ctx, boolean inUnwrap) thro try { // Only continue to loop if the handler was not removed in the meantime. // See https://github.com/netty/netty/issues/5860 - while (!ctx.isRemoved()) { + outer: while (!ctx.isRemoved()) { if (out == null) { // As this is called for the handshake we have no real idea how big the buffer needs to be. // That said 2048 should give us enough room to include everything like ALPN / NPN data. @@ -910,12 +953,17 @@ private boolean wrapNonAppData(ChannelHandlerContext ctx, boolean inUnwrap) thro out = null; } - switch (result.getHandshakeStatus()) { + HandshakeStatus status = result.getHandshakeStatus(); + switch (status) { case FINISHED: setHandshakeSuccess(); return false; case NEED_TASK: - runDelegatedTasks(); + if (!runDelegatedTasks(inUnwrap)) { + // We scheduled a task on the delegatingTaskExecutor, so stop processing as we will + // resume once the task completes. + break outer; + } break; case NEED_UNWRAP: if (inUnwrap) { @@ -941,7 +989,7 @@ private boolean wrapNonAppData(ChannelHandlerContext ctx, boolean inUnwrap) thro throw new IllegalStateException("Unknown handshake status: " + result.getHandshakeStatus()); } - if (result.bytesProduced() == 0) { + if (result.bytesProduced() == 0 && status != HandshakeStatus.NEED_TASK) { break; } @@ -1217,6 +1265,9 @@ private void handleUnwrapThrowable(ChannelHandlerContext ctx, Throwable cause) { @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws SSLException { + if (processTask) { + return; + } if (jdkCompatibilityMode) { decodeJdkCompatible(ctx, in); } else { @@ -1226,6 +1277,10 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t @Override public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + channelReadComplete0(ctx); + } + + private void channelReadComplete0(ChannelHandlerContext ctx) { // Discard bytes of the cumulation buffer if needed. discardSomeReadBytes(); @@ -1340,7 +1395,16 @@ private int unwrap( } break; case NEED_TASK: - runDelegatedTasks(); + if (!runDelegatedTasks(true)) { + // We scheduled a task on the delegatingTaskExecutor, so stop processing as we will + // resume once the task completes. + // + // We break out of the loop only and do NOT return here as we still may need to notify + // about the closure of the SSLEngine. + // + wrapLater = false; + break unwrapLoop; + } break; case FINISHED: setHandshakeSuccess(); @@ -1363,13 +1427,7 @@ private int unwrap( wrapLater = true; continue; } - if (flushedBeforeHandshake) { - // We need to call wrap(...) in case there was a flush done before the handshake completed. - // - // See https://github.com/netty/netty/pull/2437 - flushedBeforeHandshake = false; - wrapLater = true; - } + // If we are not handshaking and there is no more data to unwrap then the next call to unwrap // will not produce any data. We can avoid the potentially costly unwrap operation and break // out of the loop. @@ -1381,7 +1439,9 @@ private int unwrap( throw new IllegalStateException("unknown handshake status: " + handshakeStatus); } - if (status == Status.BUFFER_UNDERFLOW || consumed == 0 && produced == 0) { + if (status == Status.BUFFER_UNDERFLOW || + // If we processed NEED_TASK we should try again even we did not consume or produce anything. + handshakeStatus != HandshakeStatus.NEED_TASK && consumed == 0 && produced == 0) { if (handshakeStatus == HandshakeStatus.NEED_UNWRAP) { // The underlying engine is starving so we need to feed it with more data. // See https://github.com/netty/netty/pull/5039 @@ -1392,6 +1452,15 @@ private int unwrap( } } + if (flushedBeforeHandshake && handshakePromise.isDone()) { + // We need to call wrap(...) in case there was a flush done before the handshake completed to ensure + // we do not stale. + // + // See https://github.com/netty/netty/pull/2437 + flushedBeforeHandshake = false; + wrapLater = true; + } + if (wrapLater) { wrap(ctx, true); } @@ -1418,65 +1487,233 @@ private static ByteBuffer toByteBuffer(ByteBuf out, int index, int len) { out.nioBuffer(index, len); } - /** - * Fetches all delegated tasks from the {@link SSLEngine} and runs them via the {@link #delegatedTaskExecutor}. - * If the {@link #delegatedTaskExecutor} is {@link ImmediateExecutor}, just call {@link Runnable#run()} directly - * instead of using {@link Executor#execute(Runnable)}. Otherwise, run the tasks via - * the {@link #delegatedTaskExecutor} and wait until the tasks are finished. - */ - private void runDelegatedTasks() { - if (delegatedTaskExecutor == ImmediateExecutor.INSTANCE) { - for (;;) { - Runnable task = engine.getDelegatedTask(); - if (task == null) { - break; - } + private static boolean inEventLoop(Executor executor) { + return executor instanceof EventExecutor && ((EventExecutor) executor).inEventLoop(); + } - task.run(); + private static void runAllDelegatedTasks(SSLEngine engine) { + for (;;) { + Runnable task = engine.getDelegatedTask(); + if (task == null) { + return; } + task.run(); + } + } + + /** + * Will either run the delegated task directly calling {@link Runnable#run()} and return {@code true} or will + * offload the delegated task using {@link Executor#execute(Runnable)} and return {@code false}. + * + * If the task is offloaded it will take care to resume its work on the {@link EventExecutor} once there are no + * more tasks to process. + */ + private boolean runDelegatedTasks(boolean inUnwrap) { + if (delegatedTaskExecutor == ImmediateExecutor.INSTANCE || inEventLoop(delegatedTaskExecutor)) { + // We should run the task directly in the EventExecutor thread and not offload at all. + runAllDelegatedTasks(engine); + return true; } else { - final List tasks = new ArrayList(2); - for (;;) { - final Runnable task = engine.getDelegatedTask(); - if (task == null) { - break; + executeDelegatedTasks(inUnwrap); + return false; + } + } + + private void executeDelegatedTasks(boolean inUnwrap) { + processTask = true; + try { + delegatedTaskExecutor.execute(new SslTasksRunner(inUnwrap)); + } catch (RejectedExecutionException e) { + processTask = false; + throw e; + } + } + + /** + * {@link Runnable} that will be scheduled on the {@code delegatedTaskExecutor} and will take care + * of resume work on the {@link EventExecutor} once the task was executed. + */ + private final class SslTasksRunner implements Runnable { + private final boolean inUnwrap; + + SslTasksRunner(boolean inUnwrap) { + this.inUnwrap = inUnwrap; + } + + // Handle errors which happened during task processing. + private void taskError(Throwable e) { + if (inUnwrap) { + // As the error happened while the task was scheduled as part of unwrap(...) we also need to ensure + // we fire it through the pipeline as inbound error to be consistent with what we do in decode(...). + // + // This will also ensure we fail the handshake future and flush all produced data. + try { + handleUnwrapThrowable(ctx, e); + } catch (Throwable cause) { + safeExceptionCaught(cause); } + } else { + setHandshakeFailure(ctx, e); + forceFlush(ctx); + } + } - tasks.add(task); + // Try to call exceptionCaught(...) + private void safeExceptionCaught(Throwable cause) { + try { + exceptionCaught(ctx, wrapIfNeeded(cause)); + } catch (Throwable error) { + ctx.fireExceptionCaught(error); } + } - if (tasks.isEmpty()) { - return; + private Throwable wrapIfNeeded(Throwable cause) { + if (!inUnwrap) { + // If we are not in unwrap(...) we can just rethrow without wrapping at all. + return cause; } + // As the exception would have been triggered by an inbound operation we will need to wrap it in a + // DecoderException to mimic what a decoder would do when decode(...) throws. + return cause instanceof DecoderException ? cause : new DecoderException(cause); + } + + private void tryDecodeAgain() { + try { + channelRead(ctx, Unpooled.EMPTY_BUFFER); + } catch (Throwable cause) { + safeExceptionCaught(cause); + } finally { + // As we called channelRead(...) we also need to call channelReadComplete(...) which + // will ensure we either call ctx.fireChannelReadComplete() or will trigger a ctx.read() if + // more data is needed. + channelReadComplete0(ctx); + } + } - final CountDownLatch latch = new CountDownLatch(1); - delegatedTaskExecutor.execute(new Runnable() { - @Override - public void run() { - try { - for (Runnable task: tasks) { - task.run(); + /** + * Executed after the wrapped {@code task} was executed via {@code delegatedTaskExecutor} to resume work + * on the {@link EventExecutor}. + */ + private void resumeOnEventExecutor() { + assert ctx.executor().inEventLoop(); + + processTask = false; + + try { + HandshakeStatus status = engine.getHandshakeStatus(); + switch (status) { + // There is another task that needs to be executed and offloaded to the delegatingTaskExecutor. + case NEED_TASK: + executeDelegatedTasks(inUnwrap); + + break; + + // The handshake finished, lets notify about the completion of it and resume processing. + case FINISHED: + setHandshakeSuccess(); + + // deliberate fall-through + + // Not handshaking anymore, lets notify about the completion if not done yet and resume processing. + case NOT_HANDSHAKING: + setHandshakeSuccessIfStillHandshaking(); + try { + // Lets call wrap to ensure we produce the alert if there is any pending and also to + // ensure we flush any queued data.. + wrap(ctx, inUnwrap); + } catch (Throwable e) { + taskError(e); + return; + } + if (inUnwrap) { + // If we were in the unwrap call when the task was processed we should also try to unwrap + // non app data first as there may not anything left in the inbound buffer to process. + unwrapNonAppData(ctx); } - } catch (Exception e) { - ctx.fireExceptionCaught(e); - } finally { - latch.countDown(); - } - } - }); - boolean interrupted = false; - while (latch.getCount() != 0) { - try { - latch.await(); - } catch (InterruptedException e) { - // Interrupt later. - interrupted = true; + // Flush now as we may have written some data as part of the wrap call. + forceFlush(ctx); + + tryDecodeAgain(); + break; + + // We need more data so lets try to unwrap first and then call decode again which will feed us + // with buffered data (if there is any). + case NEED_UNWRAP: + unwrapNonAppData(ctx); + tryDecodeAgain(); + break; + + // To make progress we need to call SSLEngine.wrap(...) which may produce more output data + // that will be written to the Channel. + case NEED_WRAP: + try { + if (!wrapNonAppData(ctx, false) && inUnwrap) { + // The handshake finished in wrapNonAppData(...), we need to try call + // unwrapNonAppData(...) as we may have some alert that we should read. + // + // This mimics what we would do when we are calling this method while in unwrap(...). + unwrapNonAppData(ctx); + } + + // Flush now as we may have written some data as part of the wrap call. + forceFlush(ctx); + } catch (Throwable e) { + taskError(e); + return; + } + + // Now try to feed in more data that we have buffered. + tryDecodeAgain(); + break; + default: + // Should never reach here as we handle all cases. + throw new AssertionError(); } + } catch (Throwable cause) { + safeExceptionCaught(cause); } + } - if (interrupted) { - Thread.currentThread().interrupt(); + @Override + public void run() { + try { + runAllDelegatedTasks(engine); + + // All tasks were processed. + assert engine.getHandshakeStatus() != HandshakeStatus.NEED_TASK; + + // Jump back on the EventExecutor. + ctx.executor().execute(new Runnable() { + @Override + public void run() { + resumeOnEventExecutor(); + } + }); + } catch (final Throwable cause) { + handleException(cause); + } + } + + private void handleException(final Throwable cause) { + if (ctx.executor().inEventLoop()) { + processTask = false; + safeExceptionCaught(cause); + } else { + try { + ctx.executor().execute(new Runnable() { + @Override + public void run() { + processTask = false; + safeExceptionCaught(cause); + } + }); + } catch (RejectedExecutionException ignore) { + processTask = false; + // the context itself will handle the rejected exception when try to schedule the operation so + // ignore the RejectedExecutionException + ctx.fireExceptionCaught(cause); + } } } } @@ -1541,7 +1778,8 @@ private void setHandshakeFailure(ChannelHandlerContext ctx, Throwable cause, boo // // See https://github.com/netty/netty/issues/1340 String msg = e.getMessage(); - if (msg == null || !msg.contains("possible truncation attack")) { + if (msg == null || !(msg.contains("possible truncation attack") || + msg.contains("closing inbound before receiving peer's close_notify"))) { logger.debug("{} SSLEngine.closeInbound() raised an exception.", ctx.channel(), e); } } @@ -1636,14 +1874,15 @@ public void handlerAdded(final ChannelHandlerContext ctx) throws Exception { } private void startHandshakeProcessing() { - handshakeStarted = true; - if (engine.getUseClientMode()) { - // Begin the initial handshake. - // channelActive() event has been fired already, which means this.channelActive() will - // not be invoked. We have to initialize here instead. - handshake(null); - } else { - applyHandshakeTimeout(null); + if (!handshakeStarted) { + handshakeStarted = true; + if (engine.getUseClientMode()) { + // Begin the initial handshake. + // channelActive() event has been fired already, which means this.channelActive() will + // not be invoked. We have to initialize here instead. + handshake(); + } + applyHandshakeTimeout(); } } @@ -1677,52 +1916,46 @@ public Future renegotiate(final Promise promise) { executor.execute(new Runnable() { @Override public void run() { - handshake(promise); + renegotiateOnEventLoop(promise); } }); return promise; } - handshake(promise); + renegotiateOnEventLoop(promise); return promise; } + private void renegotiateOnEventLoop(final Promise newHandshakePromise) { + final Promise oldHandshakePromise = handshakePromise; + if (!oldHandshakePromise.isDone()) { + // There's no need to handshake because handshake is in progress already. + // Merge the new promise into the old one. + oldHandshakePromise.addListener(new PromiseNotifier>(newHandshakePromise)); + } else { + handshakePromise = newHandshakePromise; + handshake(); + applyHandshakeTimeout(); + } + } + /** * Performs TLS (re)negotiation. - * - * @param newHandshakePromise if {@code null}, use the existing {@link #handshakePromise}, - * assuming that the current negotiation has not been finished. - * Currently, {@code null} is expected only for the initial handshake. */ - private void handshake(final Promise newHandshakePromise) { - final Promise p; - if (newHandshakePromise != null) { - final Promise oldHandshakePromise = handshakePromise; - if (!oldHandshakePromise.isDone()) { - // There's no need to handshake because handshake is in progress already. - // Merge the new promise into the old one. - oldHandshakePromise.addListener(new FutureListener() { - @Override - public void operationComplete(Future future) throws Exception { - if (future.isSuccess()) { - newHandshakePromise.setSuccess(future.getNow()); - } else { - newHandshakePromise.setFailure(future.cause()); - } - } - }); - return; - } - - handshakePromise = p = newHandshakePromise; - } else if (engine.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING) { + private void handshake() { + if (engine.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING) { // Not all SSLEngine implementations support calling beginHandshake multiple times while a handshake // is in progress. See https://github.com/netty/netty/issues/4718. return; } else { - // Forced to reuse the old handshake. - p = handshakePromise; - assert !p.isDone(); + if (handshakePromise.isDone()) { + // If the handshake is done already lets just return directly as there is no need to trigger it again. + // This can happen if the handshake(...) was triggered before we called channelActive(...) by a + // flush() that was triggered by a ChannelFutureListener that was added to the ChannelFuture returned + // from the connect(...) method. In this case we will see the flush() happen before we had a chance to + // call fireChannelActive() on the pipeline. + return; + } } // Begin handshake. @@ -1733,27 +1966,27 @@ public void operationComplete(Future future) throws Exception { } catch (Throwable e) { setHandshakeFailure(ctx, e); } finally { - forceFlush(ctx); + forceFlush(ctx); } - applyHandshakeTimeout(p); } - private void applyHandshakeTimeout(Promise p) { - final Promise promise = p == null ? handshakePromise : p; + private void applyHandshakeTimeout() { + final Promise localHandshakePromise = this.handshakePromise; + // Set timeout if necessary. final long handshakeTimeoutMillis = this.handshakeTimeoutMillis; - if (handshakeTimeoutMillis <= 0 || promise.isDone()) { + if (handshakeTimeoutMillis <= 0 || localHandshakePromise.isDone()) { return; } final ScheduledFuture timeoutFuture = ctx.executor().schedule(new Runnable() { @Override public void run() { - if (promise.isDone()) { + if (localHandshakePromise.isDone()) { return; } try { - if (handshakePromise.tryFailure(HANDSHAKE_TIMED_OUT)) { + if (localHandshakePromise.tryFailure(HANDSHAKE_TIMED_OUT)) { SslUtils.handleHandshakeFailure(ctx, HANDSHAKE_TIMED_OUT, true); } } finally { @@ -1763,7 +1996,7 @@ public void run() { }, handshakeTimeoutMillis, TimeUnit.MILLISECONDS); // Cancel the handshake timeout when handshake is finished. - promise.addListener(new FutureListener() { + localHandshakePromise.addListener(new FutureListener() { @Override public void operationComplete(Future f) throws Exception { timeoutFuture.cancel(false); diff --git a/handler/src/main/java/io/netty/handler/ssl/SslUtils.java b/handler/src/main/java/io/netty/handler/ssl/SslUtils.java index 3dd40fdf37be..e7640365adc9 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslUtils.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslUtils.java @@ -21,9 +21,15 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.base64.Base64; import io.netty.handler.codec.base64.Base64Dialect; +import io.netty.util.NetUtil; +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.PlatformDependent; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashSet; import java.util.List; import java.util.Set; @@ -35,7 +41,11 @@ * Constants for SSL packets. */ final class SslUtils { - + // See https://tools.ietf.org/html/rfc8446#appendix-B.4 + static final Set TLSV13_CIPHERS = Collections.unmodifiableSet(new LinkedHashSet( + asList("TLS_AES_256_GCM_SHA384", "TLS_CHACHA20_POLY1305_SHA256", + "TLS_AES_128_GCM_SHA256", "TLS_AES_128_CCM_8_SHA256", + "TLS_AES_128_CCM_SHA256"))); // Protocols static final String PROTOCOL_SSL_V2_HELLO = "SSLv2Hello"; static final String PROTOCOL_SSL_V2 = "SSLv2"; @@ -43,6 +53,9 @@ final class SslUtils { static final String PROTOCOL_TLS_V1 = "TLSv1"; static final String PROTOCOL_TLS_V1_1 = "TLSv1.1"; static final String PROTOCOL_TLS_V1_2 = "TLSv1.2"; + static final String PROTOCOL_TLS_V1_3 = "TLSv1.3"; + + static final String INVALID_CIPHER = "SSL_NULL_WITH_NULL_NULL"; /** * change cipher spec @@ -84,20 +97,37 @@ final class SslUtils { */ static final int NOT_ENCRYPTED = -2; - static final String[] DEFAULT_CIPHER_SUITES = { + static final String[] DEFAULT_CIPHER_SUITES; + static final String[] DEFAULT_TLSV13_CIPHER_SUITES; + static final String[] TLSV13_CIPHER_SUITES = { "TLS_AES_128_GCM_SHA256", "TLS_AES_256_GCM_SHA384" }; + + static { + if (PlatformDependent.javaVersion() >= 11) { + DEFAULT_TLSV13_CIPHER_SUITES = TLSV13_CIPHER_SUITES; + } else { + DEFAULT_TLSV13_CIPHER_SUITES = EmptyArrays.EMPTY_STRINGS; + } + + List defaultCiphers = new ArrayList(); // GCM (Galois/Counter Mode) requires JDK 8. - "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", - "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", - "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", - "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", + defaultCiphers.add("TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384"); + defaultCiphers.add("TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"); + defaultCiphers.add("TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"); + defaultCiphers.add("TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA"); // AES256 requires JCE unlimited strength jurisdiction policy files. - "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", + defaultCiphers.add("TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA"); // GCM (Galois/Counter Mode) requires JDK 8. - "TLS_RSA_WITH_AES_128_GCM_SHA256", - "TLS_RSA_WITH_AES_128_CBC_SHA", + defaultCiphers.add("TLS_RSA_WITH_AES_128_GCM_SHA256"); + defaultCiphers.add("TLS_RSA_WITH_AES_128_CBC_SHA"); // AES256 requires JCE unlimited strength jurisdiction policy files. - "TLS_RSA_WITH_AES_256_CBC_SHA" - }; + defaultCiphers.add("TLS_RSA_WITH_AES_256_CBC_SHA"); + + for (String tlsv13Cipher: DEFAULT_TLSV13_CIPHER_SUITES) { + defaultCiphers.add(tlsv13Cipher); + } + + DEFAULT_CIPHER_SUITES = defaultCiphers.toArray(new String[0]); + } /** * Add elements from {@code names} into {@code enabled} if they are in {@code supported}. @@ -349,6 +379,25 @@ static ByteBuf toBase64(ByteBufAllocator allocator, ByteBuf src) { return dst; } + /** + * Validate that the given hostname can be used in SNI extension. + */ + static boolean isValidHostNameForSNI(String hostname) { + return hostname != null && + hostname.indexOf('.') > 0 && + !hostname.endsWith(".") && + !NetUtil.isValidIpV4Address(hostname) && + !NetUtil.isValidIpV6Address(hostname); + } + + /** + * Returns {@code true} if the the given cipher (in openssl format) is for TLSv1.3, {@code false} otherwise. + */ + static boolean isTLSv13Cipher(String cipher) { + // See https://tools.ietf.org/html/rfc8446#appendix-B.4 + return TLSV13_CIPHERS.contains(cipher); + } + private SslUtils() { } } diff --git a/handler/src/main/java/io/netty/handler/ssl/SupportedCipherSuiteFilter.java b/handler/src/main/java/io/netty/handler/ssl/SupportedCipherSuiteFilter.java index 132f1a0e5e52..265672309f8e 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SupportedCipherSuiteFilter.java +++ b/handler/src/main/java/io/netty/handler/ssl/SupportedCipherSuiteFilter.java @@ -53,7 +53,7 @@ public String[] filterCipherSuites(Iterable ciphers, List defaul newCiphers.add(c); } } - return newCiphers.toArray(new String[newCiphers.size()]); + return newCiphers.toArray(new String[0]); } } diff --git a/handler/src/main/java/io/netty/handler/ssl/util/FingerprintTrustManagerFactory.java b/handler/src/main/java/io/netty/handler/ssl/util/FingerprintTrustManagerFactory.java index b4fca297a797..454334501efd 100644 --- a/handler/src/main/java/io/netty/handler/ssl/util/FingerprintTrustManagerFactory.java +++ b/handler/src/main/java/io/netty/handler/ssl/util/FingerprintTrustManagerFactory.java @@ -172,7 +172,7 @@ public FingerprintTrustManagerFactory(byte[]... fingerprints) { list.add(f.clone()); } - this.fingerprints = list.toArray(new byte[list.size()][]); + this.fingerprints = list.toArray(new byte[0][]); } private static byte[][] toFingerprintArray(Iterable fingerprints) { @@ -197,7 +197,7 @@ private static byte[][] toFingerprintArray(Iterable fingerprints) { list.add(StringUtil.decodeHexDump(f)); } - return list.toArray(new byte[list.size()][]); + return list.toArray(new byte[0][]); } @Override diff --git a/handler/src/main/java/io/netty/handler/ssl/util/InsecureTrustManagerFactory.java b/handler/src/main/java/io/netty/handler/ssl/util/InsecureTrustManagerFactory.java index d78b4686772a..23ce8f6af8b4 100644 --- a/handler/src/main/java/io/netty/handler/ssl/util/InsecureTrustManagerFactory.java +++ b/handler/src/main/java/io/netty/handler/ssl/util/InsecureTrustManagerFactory.java @@ -44,12 +44,16 @@ public final class InsecureTrustManagerFactory extends SimpleTrustManagerFactory private static final TrustManager tm = new X509TrustManager() { @Override public void checkClientTrusted(X509Certificate[] chain, String s) { - logger.debug("Accepting a client certificate: " + chain[0].getSubjectDN()); + if (logger.isDebugEnabled()) { + logger.debug("Accepting a client certificate: " + chain[0].getSubjectDN()); + } } @Override public void checkServerTrusted(X509Certificate[] chain, String s) { - logger.debug("Accepting a server certificate: " + chain[0].getSubjectDN()); + if (logger.isDebugEnabled()) { + logger.debug("Accepting a server certificate: " + chain[0].getSubjectDN()); + } } @Override diff --git a/handler/src/main/java/io/netty/handler/ssl/util/OpenJdkSelfSignedCertGenerator.java b/handler/src/main/java/io/netty/handler/ssl/util/OpenJdkSelfSignedCertGenerator.java index 07a6fb91eb5b..30d74e270540 100644 --- a/handler/src/main/java/io/netty/handler/ssl/util/OpenJdkSelfSignedCertGenerator.java +++ b/handler/src/main/java/io/netty/handler/ssl/util/OpenJdkSelfSignedCertGenerator.java @@ -64,16 +64,16 @@ static String[] generate(String fqdn, KeyPair keypair, SecureRandom random, Date info.set(X509CertInfo.VALIDITY, new CertificateValidity(notBefore, notAfter)); info.set(X509CertInfo.KEY, new CertificateX509Key(keypair.getPublic())); info.set(X509CertInfo.ALGORITHM_ID, - new CertificateAlgorithmId(new AlgorithmId(AlgorithmId.sha1WithRSAEncryption_oid))); + new CertificateAlgorithmId(new AlgorithmId(AlgorithmId.sha256WithRSAEncryption_oid))); // Sign the cert to identify the algorithm that's used. X509CertImpl cert = new X509CertImpl(info); - cert.sign(key, "SHA1withRSA"); + cert.sign(key, "SHA256withRSA"); // Update the algorithm and sign again. info.set(CertificateAlgorithmId.NAME + '.' + CertificateAlgorithmId.ALGORITHM, cert.get(X509CertImpl.SIG_ALG)); cert = new X509CertImpl(info); - cert.sign(key, "SHA1withRSA"); + cert.sign(key, "SHA256withRSA"); cert.verify(keypair.getPublic()); return newSelfSignedCertificate(fqdn, key, cert); diff --git a/handler/src/main/java/io/netty/handler/ssl/util/SelfSignedCertificate.java b/handler/src/main/java/io/netty/handler/ssl/util/SelfSignedCertificate.java index 112e1a88a094..9f010ce8ec72 100644 --- a/handler/src/main/java/io/netty/handler/ssl/util/SelfSignedCertificate.java +++ b/handler/src/main/java/io/netty/handler/ssl/util/SelfSignedCertificate.java @@ -175,7 +175,9 @@ public SelfSignedCertificate(String fqdn, SecureRandom random, int bits, Date no try { certificateInput.close(); } catch (IOException e) { - logger.warn("Failed to close a file: " + certificate, e); + if (logger.isWarnEnabled()) { + logger.warn("Failed to close a file: " + certificate, e); + } } } } @@ -288,7 +290,9 @@ static String[] newSelfSignedCertificate( private static void safeDelete(File certFile) { if (!certFile.delete()) { - logger.warn("Failed to delete a file: " + certFile); + if (logger.isWarnEnabled()) { + logger.warn("Failed to delete a file: " + certFile); + } } } @@ -296,7 +300,9 @@ private static void safeClose(File keyFile, OutputStream keyOut) { try { keyOut.close(); } catch (IOException e) { - logger.warn("Failed to close a file: " + keyFile, e); + if (logger.isWarnEnabled()) { + logger.warn("Failed to close a file: " + keyFile, e); + } } } } diff --git a/handler/src/main/java/io/netty/handler/stream/ChunkedWriteHandler.java b/handler/src/main/java/io/netty/handler/stream/ChunkedWriteHandler.java index 6f753c5c5100..1a1822b59735 100644 --- a/handler/src/main/java/io/netty/handler/stream/ChunkedWriteHandler.java +++ b/handler/src/main/java/io/netty/handler/stream/ChunkedWriteHandler.java @@ -166,20 +166,28 @@ private void discard(Throwable cause) { Object message = currentWrite.msg; if (message instanceof ChunkedInput) { ChunkedInput in = (ChunkedInput) message; + boolean endOfInput; + long inputLength; try { - if (!in.isEndOfInput()) { - if (cause == null) { - cause = new ClosedChannelException(); - } - currentWrite.fail(cause); - } else { - currentWrite.success(in.length()); - } + endOfInput = in.isEndOfInput(); + inputLength = in.length(); closeInput(in); } catch (Exception e) { - currentWrite.fail(e); - logger.warn(ChunkedInput.class.getSimpleName() + ".isEndOfInput() failed", e); closeInput(in); + currentWrite.fail(e); + if (logger.isWarnEnabled()) { + logger.warn(ChunkedInput.class.getSimpleName() + " failed", e); + } + continue; + } + + if (!endOfInput) { + if (cause == null) { + cause = new ClosedChannelException(); + } + currentWrite.fail(cause); + } else { + currentWrite.success(inputLength); } } else { if (cause == null) { @@ -207,6 +215,21 @@ private void doFlush(final ChannelHandlerContext ctx) { if (currentWrite == null) { break; } + + if (currentWrite.promise.isDone()) { + // This might happen e.g. in the case when a write operation + // failed, but there're still unconsumed chunks left. + // Most chunked input sources would stop generating chunks + // and report end of input, but this doesn't work with any + // source wrapped in HttpChunkedInput. + // Note, that we're not trying to release the message/chunks + // as this had to be done already by someone who resolved the + // promise (using ChunkedInput.close method). + // See https://github.com/netty/netty/issues/8700. + this.currentWrite = null; + continue; + } + final PendingWrite currentWrite = this.currentWrite; final Object pendingMessage = currentWrite.msg; @@ -232,8 +255,8 @@ private void doFlush(final ChannelHandlerContext ctx) { ReferenceCountUtil.release(message); } - currentWrite.fail(t); closeInput(chunks); + currentWrite.fail(t); break; } @@ -262,9 +285,17 @@ private void doFlush(final ChannelHandlerContext ctx) { f.addListener(new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { - currentWrite.progress(chunks.progress(), chunks.length()); - currentWrite.success(chunks.length()); - closeInput(chunks); + if (!future.isSuccess()) { + closeInput(chunks); + currentWrite.fail(future.cause()); + } else { + // read state of the input in local variables before closing it + long inputProgress = chunks.progress(); + long inputLength = chunks.length(); + closeInput(chunks); + currentWrite.progress(inputProgress, inputLength); + currentWrite.success(inputLength); + } } }); } else if (channel.isWritable()) { @@ -272,7 +303,7 @@ public void operationComplete(ChannelFuture future) throws Exception { @Override public void operationComplete(ChannelFuture future) throws Exception { if (!future.isSuccess()) { - closeInput((ChunkedInput) pendingMessage); + closeInput(chunks); currentWrite.fail(future.cause()); } else { currentWrite.progress(chunks.progress(), chunks.length()); @@ -284,7 +315,7 @@ public void operationComplete(ChannelFuture future) throws Exception { @Override public void operationComplete(ChannelFuture future) throws Exception { if (!future.isSuccess()) { - closeInput((ChunkedInput) pendingMessage); + closeInput(chunks); currentWrite.fail(future.cause()); } else { currentWrite.progress(chunks.progress(), chunks.length()); diff --git a/handler/src/main/java/io/netty/handler/timeout/ReadTimeoutHandler.java b/handler/src/main/java/io/netty/handler/timeout/ReadTimeoutHandler.java index 1fe47889d8a4..cba6e60d720e 100644 --- a/handler/src/main/java/io/netty/handler/timeout/ReadTimeoutHandler.java +++ b/handler/src/main/java/io/netty/handler/timeout/ReadTimeoutHandler.java @@ -33,7 +33,7 @@ * * public class MyChannelInitializer extends {@link ChannelInitializer}<{@link Channel}> { * public void initChannel({@link Channel} channel) { - * channel.pipeline().addLast("readTimeoutHandler", new {@link ReadTimeoutHandler}(30); + * channel.pipeline().addLast("readTimeoutHandler", new {@link ReadTimeoutHandler}(30)); * channel.pipeline().addLast("myHandler", new MyHandler()); * } * } diff --git a/handler/src/main/java/io/netty/handler/traffic/AbstractTrafficShapingHandler.java b/handler/src/main/java/io/netty/handler/traffic/AbstractTrafficShapingHandler.java index caeb767fcf19..09e0f38798e6 100644 --- a/handler/src/main/java/io/netty/handler/traffic/AbstractTrafficShapingHandler.java +++ b/handler/src/main/java/io/netty/handler/traffic/AbstractTrafficShapingHandler.java @@ -441,11 +441,15 @@ public void run() { // Anything else allows the handler to reset the AutoRead if (logger.isDebugEnabled()) { if (config.isAutoRead() && !isHandlerActive(ctx)) { - logger.debug("Unsuspend: " + config.isAutoRead() + ':' + - isHandlerActive(ctx)); + if (logger.isDebugEnabled()) { + logger.debug("Unsuspend: " + config.isAutoRead() + ':' + + isHandlerActive(ctx)); + } } else { - logger.debug("Normal unsuspend: " + config.isAutoRead() + ':' - + isHandlerActive(ctx)); + if (logger.isDebugEnabled()) { + logger.debug("Normal unsuspend: " + config.isAutoRead() + ':' + + isHandlerActive(ctx)); + } } } channel.attr(READ_SUSPENDED).set(false); diff --git a/handler/src/test/java/io/netty/handler/ipfilter/UniqueIpFilterTest.java b/handler/src/test/java/io/netty/handler/ipfilter/UniqueIpFilterTest.java new file mode 100644 index 000000000000..6e9eb66f3e78 --- /dev/null +++ b/handler/src/test/java/io/netty/handler/ipfilter/UniqueIpFilterTest.java @@ -0,0 +1,74 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ipfilter; + +import io.netty.channel.ChannelHandler; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.internal.SocketUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.net.SocketAddress; +import java.util.concurrent.Callable; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +public class UniqueIpFilterTest { + + @Test + public void testUniqueIpFilterHandler() throws ExecutionException, InterruptedException { + final CyclicBarrier barrier = new CyclicBarrier(2); + ExecutorService executorService = Executors.newFixedThreadPool(2); + try { + for (int round = 0; round < 10000; round++) { + final UniqueIpFilter ipFilter = new UniqueIpFilter(); + Future future1 = newChannelAsync(barrier, executorService, ipFilter); + Future future2 = newChannelAsync(barrier, executorService, ipFilter); + EmbeddedChannel ch1 = future1.get(); + EmbeddedChannel ch2 = future2.get(); + Assert.assertTrue(ch1.isActive() || ch2.isActive()); + Assert.assertFalse(ch1.isActive() && ch2.isActive()); + + barrier.reset(); + ch1.close().await(); + ch2.close().await(); + } + } finally { + executorService.shutdown(); + } + } + + private static Future newChannelAsync(final CyclicBarrier barrier, + ExecutorService executorService, + final ChannelHandler... handler) { + return executorService.submit(new Callable() { + @Override + public EmbeddedChannel call() throws Exception { + barrier.await(); + return new EmbeddedChannel(handler) { + @Override + protected SocketAddress remoteAddress0() { + return isActive() ? SocketUtils.socketAddress("91.92.93.1", 5421) : null; + } + }; + } + }); + } + +} diff --git a/handler/src/test/java/io/netty/handler/ssl/CipherSuiteCanaryTest.java b/handler/src/test/java/io/netty/handler/ssl/CipherSuiteCanaryTest.java new file mode 100644 index 000000000000..855cb848b798 --- /dev/null +++ b/handler/src/test/java/io/netty/handler/ssl/CipherSuiteCanaryTest.java @@ -0,0 +1,302 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version + * 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package io.netty.handler.ssl; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.handler.ssl.util.SelfSignedCertificate; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.Promise; + +import java.net.SocketAddress; +import java.security.NoSuchAlgorithmException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import org.junit.AfterClass; +import org.junit.Assume; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; + +import static org.junit.Assert.assertTrue; + +/** + * The purpose of this unit test is to act as a canary and catch changes in supported cipher suites. + */ +@RunWith(Parameterized.class) +public class CipherSuiteCanaryTest { + + private static EventLoopGroup GROUP; + + private static SelfSignedCertificate CERT; + + @Parameters(name = "{index}: serverSslProvider = {0}, clientSslProvider = {1}, rfcCipherName = {2}, delegate = {3}") + public static Collection parameters() { + List dst = new ArrayList(); + dst.addAll(expand("TLS_DHE_RSA_WITH_AES_128_GCM_SHA256")); // DHE-RSA-AES128-GCM-SHA256 + return dst; + } + + @BeforeClass + public static void init() throws Exception { + GROUP = new DefaultEventLoopGroup(); + CERT = new SelfSignedCertificate(); + } + + @AfterClass + public static void destroy() { + GROUP.shutdownGracefully(); + CERT.delete(); + } + + private final SslProvider serverSslProvider; + + private final SslProvider clientSslProvider; + + private final String rfcCipherName; + private final boolean delegate; + + public CipherSuiteCanaryTest(SslProvider serverSslProvider, SslProvider clientSslProvider, + String rfcCipherName, boolean delegate) { + this.serverSslProvider = serverSslProvider; + this.clientSslProvider = clientSslProvider; + this.rfcCipherName = rfcCipherName; + this.delegate = delegate; + } + + private static void assumeCipherAvailable(SslProvider provider, String cipher) throws NoSuchAlgorithmException { + boolean cipherSupported = false; + if (provider == SslProvider.JDK) { + SSLEngine engine = SSLContext.getDefault().createSSLEngine(); + for (String c: engine.getSupportedCipherSuites()) { + if (cipher.equals(c)) { + cipherSupported = true; + break; + } + } + } else { + cipherSupported = OpenSsl.isCipherSuiteAvailable(cipher); + } + Assume.assumeTrue("Unsupported cipher: " + cipher, cipherSupported); + } + + private static SslHandler newSslHandler(SslContext sslCtx, ByteBufAllocator allocator, Executor executor) { + if (executor == null) { + return sslCtx.newHandler(allocator); + } else { + return sslCtx.newHandler(allocator, executor); + } + } + + @Test + public void testHandshake() throws Exception { + // Check if the cipher is supported at all which may not be the case for various JDK versions and OpenSSL API + // implementations. + assumeCipherAvailable(serverSslProvider, rfcCipherName); + assumeCipherAvailable(clientSslProvider, rfcCipherName); + + List ciphers = Collections.singletonList(rfcCipherName); + + final SslContext sslServerContext = SslContextBuilder.forServer(CERT.certificate(), CERT.privateKey()) + .sslProvider(serverSslProvider) + .ciphers(ciphers) + // As this is not a TLSv1.3 cipher we should ensure we talk something else. + .protocols(SslUtils.PROTOCOL_TLS_V1_2) + .build(); + + final ExecutorService executorService = delegate ? Executors.newCachedThreadPool() : null; + + try { + final SslContext sslClientContext = SslContextBuilder.forClient() + .sslProvider(clientSslProvider) + .ciphers(ciphers) + // As this is not a TLSv1.3 cipher we should ensure we talk something else. + .protocols(SslUtils.PROTOCOL_TLS_V1_2) + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .build(); + + try { + final Promise serverPromise = GROUP.next().newPromise(); + final Promise clientPromise = GROUP.next().newPromise(); + + ChannelHandler serverHandler = new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ChannelPipeline pipeline = ch.pipeline(); + pipeline.addLast(newSslHandler(sslServerContext, ch.alloc(), executorService)); + + pipeline.addLast(new SimpleChannelInboundHandler() { + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + serverPromise.cancel(true); + ctx.fireChannelInactive(); + } + + @Override + public void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception { + if (serverPromise.trySuccess(null)) { + ctx.writeAndFlush(Unpooled.wrappedBuffer(new byte[] {'P', 'O', 'N', 'G'})); + } + ctx.close(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (!serverPromise.tryFailure(cause)) { + ctx.fireExceptionCaught(cause); + } + } + }); + } + }; + + LocalAddress address = new LocalAddress("test-" + serverSslProvider + + '-' + clientSslProvider + '-' + rfcCipherName); + + Channel server = server(address, serverHandler); + try { + ChannelHandler clientHandler = new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ChannelPipeline pipeline = ch.pipeline(); + pipeline.addLast(newSslHandler(sslClientContext, ch.alloc(), executorService)); + + pipeline.addLast(new SimpleChannelInboundHandler() { + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + clientPromise.cancel(true); + ctx.fireChannelInactive(); + } + + @Override + public void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception { + clientPromise.trySuccess(null); + ctx.close(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) + throws Exception { + if (!clientPromise.tryFailure(cause)) { + ctx.fireExceptionCaught(cause); + } + } + }); + } + }; + + Channel client = client(server, clientHandler); + try { + client.writeAndFlush(Unpooled.wrappedBuffer(new byte[] {'P', 'I', 'N', 'G'})) + .syncUninterruptibly(); + + assertTrue("client timeout", clientPromise.await(5L, TimeUnit.SECONDS)); + assertTrue("server timeout", serverPromise.await(5L, TimeUnit.SECONDS)); + + clientPromise.sync(); + serverPromise.sync(); + } finally { + client.close().sync(); + } + } finally { + server.close().sync(); + } + } finally { + ReferenceCountUtil.release(sslClientContext); + } + } finally { + ReferenceCountUtil.release(sslServerContext); + + if (executorService != null) { + executorService.shutdown(); + } + } + } + + private static Channel server(LocalAddress address, ChannelHandler handler) throws Exception { + ServerBootstrap bootstrap = new ServerBootstrap() + .channel(LocalServerChannel.class) + .group(GROUP) + .childHandler(handler); + + return bootstrap.bind(address).sync().channel(); + } + + private static Channel client(Channel server, ChannelHandler handler) throws Exception { + SocketAddress remoteAddress = server.localAddress(); + + Bootstrap bootstrap = new Bootstrap() + .channel(LocalChannel.class) + .group(GROUP) + .handler(handler); + + return bootstrap.connect(remoteAddress).sync().channel(); + } + + private static List expand(String rfcCipherName) { + List dst = new ArrayList(); + SslProvider[] sslProviders = SslProvider.values(); + + for (int i = 0; i < sslProviders.length; i++) { + SslProvider serverSslProvider = sslProviders[i]; + + for (int j = 0; j < sslProviders.length; j++) { + SslProvider clientSslProvider = sslProviders[j]; + + if ((serverSslProvider != SslProvider.JDK || clientSslProvider != SslProvider.JDK) + && !OpenSsl.isAvailable()) { + continue; + } + + dst.add(new Object[]{serverSslProvider, clientSslProvider, rfcCipherName, true}); + dst.add(new Object[]{serverSslProvider, clientSslProvider, rfcCipherName, false}); + } + } + + if (dst.isEmpty()) { + throw new IllegalStateException(); + } + + return dst; + } +} diff --git a/handler/src/test/java/io/netty/handler/ssl/CipherSuiteConverterTest.java b/handler/src/test/java/io/netty/handler/ssl/CipherSuiteConverterTest.java index ffe53d2ba8b2..a0d425d0a7ac 100644 --- a/handler/src/test/java/io/netty/handler/ssl/CipherSuiteConverterTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/CipherSuiteConverterTest.java @@ -22,8 +22,7 @@ import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.sameInstance; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertThat; +import static org.junit.Assert.*; public class CipherSuiteConverterTest { @@ -123,10 +122,14 @@ public void testJ2OMappings() throws Exception { testJ2OMapping("TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256", "ECDHE-PSK-CHACHA20-POLY1305"); testJ2OMapping("TLS_DHE_PSK_WITH_CHACHA20_POLY1305_SHA256", "DHE-PSK-CHACHA20-POLY1305"); testJ2OMapping("TLS_RSA_PSK_WITH_CHACHA20_POLY1305_SHA256", "RSA-PSK-CHACHA20-POLY1305"); + + testJ2OMapping("TLS_AES_128_GCM_SHA256", "TLS_AES_128_GCM_SHA256"); + testJ2OMapping("TLS_AES_256_GCM_SHA384", "TLS_AES_256_GCM_SHA384"); + testJ2OMapping("TLS_CHACHA20_POLY1305_SHA256", "TLS_CHACHA20_POLY1305_SHA256"); } private static void testJ2OMapping(String javaCipherSuite, String openSslCipherSuite) { - final String actual = CipherSuiteConverter.toOpenSslUncached(javaCipherSuite); + final String actual = CipherSuiteConverter.toOpenSslUncached(javaCipherSuite, false); logger.info("{} => {}", javaCipherSuite, actual); assertThat(actual, is(openSslCipherSuite)); } @@ -317,15 +320,18 @@ private static void testUnknownOpenSSLCiphersToJava(String openSslCipherSuite) { private static void testUnknownJavaCiphersToOpenSSL(String javaCipherSuite) { CipherSuiteConverter.clearCache(); - assertNull(CipherSuiteConverter.toOpenSsl(javaCipherSuite)); - assertNull(CipherSuiteConverter.toOpenSsl(javaCipherSuite)); + assertNull(CipherSuiteConverter.toOpenSsl(javaCipherSuite, false)); + assertNull(CipherSuiteConverter.toOpenSsl(javaCipherSuite, true)); } private static void testCachedJ2OMapping(String javaCipherSuite, String openSslCipherSuite) { CipherSuiteConverter.clearCache(); - final String actual1 = CipherSuiteConverter.toOpenSsl(javaCipherSuite); + // For TLSv1.3 this should make no diffierence if boringSSL is true or false + final String actual1 = CipherSuiteConverter.toOpenSsl(javaCipherSuite, false); assertThat(actual1, is(openSslCipherSuite)); + final String actual2 = CipherSuiteConverter.toOpenSsl(javaCipherSuite, true); + assertEquals(actual1, actual2); // Ensure that the cache entries have been created. assertThat(CipherSuiteConverter.isJ2OCached(javaCipherSuite, actual1), is(true)); @@ -333,12 +339,12 @@ private static void testCachedJ2OMapping(String javaCipherSuite, String openSslC assertThat(CipherSuiteConverter.isO2JCached(actual1, "SSL", "SSL_" + javaCipherSuite.substring(4)), is(true)); assertThat(CipherSuiteConverter.isO2JCached(actual1, "TLS", "TLS_" + javaCipherSuite.substring(4)), is(true)); - final String actual2 = CipherSuiteConverter.toOpenSsl(javaCipherSuite); - assertThat(actual2, is(openSslCipherSuite)); + final String actual3 = CipherSuiteConverter.toOpenSsl(javaCipherSuite, false); + assertThat(actual3, is(openSslCipherSuite)); // Test if the returned cipher strings are identical, // so that the TLS sessions with the same cipher suite do not create many strings. - assertThat(actual1, is(sameInstance(actual2))); + assertThat(actual1, is(sameInstance(actual3))); } @Test @@ -374,4 +380,34 @@ private static void testCachedO2JMapping(String javaCipherSuite, String openSslC assertThat(tlsActual1, is(sameInstance(tlsActual2))); assertThat(sslActual1, is(sameInstance(sslActual2))); } + + @Test + public void testTlsv13Mappings() { + CipherSuiteConverter.clearCache(); + + assertEquals("TLS_AES_128_GCM_SHA256", + CipherSuiteConverter.toJava("TLS_AES_128_GCM_SHA256", "TLS")); + assertNull(CipherSuiteConverter.toJava("TLS_AES_128_GCM_SHA256", "SSL")); + assertEquals("TLS_AES_256_GCM_SHA384", + CipherSuiteConverter.toJava("TLS_AES_256_GCM_SHA384", "TLS")); + assertNull(CipherSuiteConverter.toJava("TLS_AES_256_GCM_SHA384", "SSL")); + assertEquals("TLS_CHACHA20_POLY1305_SHA256", + CipherSuiteConverter.toJava("TLS_CHACHA20_POLY1305_SHA256", "TLS")); + assertNull(CipherSuiteConverter.toJava("TLS_CHACHA20_POLY1305_SHA256", "SSL")); + + // BoringSSL use different cipher naming then OpenSSL so we need to test for both + assertEquals("TLS_AES_128_GCM_SHA256", + CipherSuiteConverter.toOpenSsl("TLS_AES_128_GCM_SHA256", false)); + assertEquals("TLS_AES_256_GCM_SHA384", + CipherSuiteConverter.toOpenSsl("TLS_AES_256_GCM_SHA384", false)); + assertEquals("TLS_CHACHA20_POLY1305_SHA256", + CipherSuiteConverter.toOpenSsl("TLS_CHACHA20_POLY1305_SHA256", false)); + + assertEquals("AEAD-AES128-GCM-SHA256", + CipherSuiteConverter.toOpenSsl("TLS_AES_128_GCM_SHA256", true)); + assertEquals("AEAD-AES256-GCM-SHA384", + CipherSuiteConverter.toOpenSsl("TLS_AES_256_GCM_SHA384", true)); + assertEquals("AEAD-CHACHA20-POLY1305-SHA256", + CipherSuiteConverter.toOpenSsl("TLS_CHACHA20_POLY1305_SHA256", true)); + } } diff --git a/handler/src/test/java/io/netty/handler/ssl/ConscryptJdkSslEngineInteropTest.java b/handler/src/test/java/io/netty/handler/ssl/ConscryptJdkSslEngineInteropTest.java index 870f71ef364a..da2d76757fd5 100644 --- a/handler/src/test/java/io/netty/handler/ssl/ConscryptJdkSslEngineInteropTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/ConscryptJdkSslEngineInteropTest.java @@ -31,17 +31,18 @@ @RunWith(Parameterized.class) public class ConscryptJdkSslEngineInteropTest extends SSLEngineTest { - @Parameterized.Parameters(name = "{index}: bufferType = {0}") - public static Collection data() { - List params = new ArrayList(); + @Parameterized.Parameters(name = "{index}: bufferType = {0}, combo = {1}, delegate = {2}") + public static Collection data() { + List params = new ArrayList(); for (BufferType type: BufferType.values()) { - params.add(type); + params.add(new Object[] { type, ProtocolCipherCombo.tlsv12(), false }); + params.add(new Object[] { type, ProtocolCipherCombo.tlsv12(), true }); } return params; } - public ConscryptJdkSslEngineInteropTest(BufferType type) { - super(type); + public ConscryptJdkSslEngineInteropTest(BufferType type, ProtocolCipherCombo combo, boolean delegate) { + super(type, combo, delegate); } @BeforeClass @@ -79,4 +80,18 @@ protected boolean mySetupMutualAuthServerIsValidServerException(Throwable cause) // TODO(scott): work around for a JDK issue. The exception should be SSLHandshakeException. return super.mySetupMutualAuthServerIsValidServerException(cause) || causedBySSLException(cause); } + + @Ignore("Ignore due bug in Conscrypt") + @Override + public void testSessionBindingEvent() throws Exception { + // Ignore due bug in Conscrypt where the incorrect SSLSession object is used in the SSLSessionBindingEvent. + // See https://github.com/google/conscrypt/issues/593 + } + + @Ignore("Ignore due bug in Conscrypt") + @Override + public void testHandshakeSession() throws Exception { + // Ignore as Conscrypt does not correctly return the local certificates while the TrustManager is invoked. + // See https://github.com/google/conscrypt/issues/634 + } } diff --git a/handler/src/test/java/io/netty/handler/ssl/ConscryptOpenSslEngineInteropTest.java b/handler/src/test/java/io/netty/handler/ssl/ConscryptOpenSslEngineInteropTest.java new file mode 100644 index 000000000000..9cb5bb610072 --- /dev/null +++ b/handler/src/test/java/io/netty/handler/ssl/ConscryptOpenSslEngineInteropTest.java @@ -0,0 +1,143 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLException; +import java.security.Provider; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +import static io.netty.handler.ssl.OpenSslTestUtils.checkShouldUseKeyManagerFactory; +import static org.junit.Assume.assumeTrue; + +@RunWith(Parameterized.class) +public class ConscryptOpenSslEngineInteropTest extends ConscryptSslEngineTest { + + @Parameterized.Parameters(name = "{index}: bufferType = {0}, combo = {1}, delegate = {2}") + public static Collection data() { + List params = new ArrayList(); + for (BufferType type: BufferType.values()) { + params.add(new Object[] { type, ProtocolCipherCombo.tlsv12(), false }); + params.add(new Object[] { type, ProtocolCipherCombo.tlsv12(), true }); + } + return params; + } + + public ConscryptOpenSslEngineInteropTest(BufferType type, ProtocolCipherCombo combo, boolean delegate) { + super(type, combo, delegate); + } + + @BeforeClass + public static void checkOpenssl() { + assumeTrue(OpenSsl.isAvailable()); + } + + @Override + protected SslProvider sslClientProvider() { + return SslProvider.JDK; + } + + @Override + protected SslProvider sslServerProvider() { + return SslProvider.OPENSSL; + } + + @Override + protected Provider serverSslContextProvider() { + return null; + } + + @Override + @Test + @Ignore("TODO: Make this work with Conscrypt") + public void testMutualAuthValidClientCertChainTooLongFailOptionalClientAuth() { + super.testMutualAuthValidClientCertChainTooLongFailOptionalClientAuth(); + } + + @Override + @Test + @Ignore("TODO: Make this work with Conscrypt") + public void testMutualAuthValidClientCertChainTooLongFailRequireClientAuth() { + super.testMutualAuthValidClientCertChainTooLongFailRequireClientAuth(); + } + + @Override + protected boolean mySetupMutualAuthServerIsValidClientException(Throwable cause) { + // TODO(scott): work around for a JDK issue. The exception should be SSLHandshakeException. + return super.mySetupMutualAuthServerIsValidClientException(cause) || causedBySSLException(cause); + } + + @Override + @Test + public void testMutualAuthInvalidIntermediateCASucceedWithOptionalClientAuth() throws Exception { + checkShouldUseKeyManagerFactory(); + super.testMutualAuthInvalidIntermediateCASucceedWithOptionalClientAuth(); + } + + @Override + @Test + public void testMutualAuthInvalidIntermediateCAFailWithOptionalClientAuth() throws Exception { + checkShouldUseKeyManagerFactory(); + super.testMutualAuthInvalidIntermediateCAFailWithOptionalClientAuth(); + } + + @Override + @Test + public void testMutualAuthInvalidIntermediateCAFailWithRequiredClientAuth() throws Exception { + checkShouldUseKeyManagerFactory(); + super.testMutualAuthInvalidIntermediateCAFailWithRequiredClientAuth(); + } + + @Override + @Test + public void testClientHostnameValidationSuccess() throws InterruptedException, SSLException { + assumeTrue(OpenSsl.supportsHostnameValidation()); + super.testClientHostnameValidationSuccess(); + } + + @Override + @Test + public void testClientHostnameValidationFail() throws InterruptedException, SSLException { + assumeTrue(OpenSsl.supportsHostnameValidation()); + super.testClientHostnameValidationFail(); + } + + @Override + @Test + public void testSessionAfterHandshakeKeyManagerFactoryMutualAuth() throws Exception { + checkShouldUseKeyManagerFactory(); + super.testSessionAfterHandshakeKeyManagerFactoryMutualAuth(); + } + + @Override + protected boolean mySetupMutualAuthServerIsValidServerException(Throwable cause) { + // TODO(scott): work around for a JDK issue. The exception should be SSLHandshakeException. + return super.mySetupMutualAuthServerIsValidServerException(cause) || causedBySSLException(cause); + } + + @Override + protected SSLEngine wrapEngine(SSLEngine engine) { + return Java8SslTestUtils.wrapSSLEngineForTesting(engine); + } +} diff --git a/handler/src/test/java/io/netty/handler/ssl/ConscryptSslEngineTest.java b/handler/src/test/java/io/netty/handler/ssl/ConscryptSslEngineTest.java new file mode 100644 index 000000000000..114552f37347 --- /dev/null +++ b/handler/src/test/java/io/netty/handler/ssl/ConscryptSslEngineTest.java @@ -0,0 +1,95 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.security.Provider; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +import static org.junit.Assume.assumeTrue; + +@RunWith(Parameterized.class) +public class ConscryptSslEngineTest extends SSLEngineTest { + + @Parameterized.Parameters(name = "{index}: bufferType = {0}, combo = {1}, delegate = {2}") + public static Collection data() { + List params = new ArrayList(); + for (BufferType type: BufferType.values()) { + params.add(new Object[] { type, ProtocolCipherCombo.tlsv12(), false }); + params.add(new Object[] { type, ProtocolCipherCombo.tlsv12(), true }); + } + return params; + } + + public ConscryptSslEngineTest(BufferType type, ProtocolCipherCombo combo, boolean delegate) { + super(type, combo, delegate); + } + + @BeforeClass + public static void checkConscrypt() { + assumeTrue(Conscrypt.isAvailable()); + } + + @Override + protected SslProvider sslClientProvider() { + return SslProvider.JDK; + } + + @Override + protected SslProvider sslServerProvider() { + return SslProvider.JDK; + } + + @Override + protected Provider clientSslContextProvider() { + return Java8SslTestUtils.conscryptProvider(); + } + + @Override + protected Provider serverSslContextProvider() { + return Java8SslTestUtils.conscryptProvider(); + } + + @Ignore /* Does the JDK support a "max certificate chain length"? */ + @Override + public void testMutualAuthValidClientCertChainTooLongFailOptionalClientAuth() { + } + + @Ignore /* Does the JDK support a "max certificate chain length"? */ + @Override + public void testMutualAuthValidClientCertChainTooLongFailRequireClientAuth() { + } + + @Ignore("Ignore due bug in Conscrypt") + @Override + public void testSessionBindingEvent() throws Exception { + // Ignore due bug in Conscrypt where the incorrect SSLSession object is used in the SSLSessionBindingEvent. + // See https://github.com/google/conscrypt/issues/593 + } + + @Ignore("Ignore due bug in Conscrypt") + @Override + public void testHandshakeSession() throws Exception { + // Ignore as Conscrypt does not correctly return the local certificates while the TrustManager is invoked. + // See https://github.com/google/conscrypt/issues/634 + } +} diff --git a/handler/src/test/java/io/netty/handler/ssl/Java8SslTestUtils.java b/handler/src/test/java/io/netty/handler/ssl/Java8SslTestUtils.java index cc2e6c6ed325..50fb935d9dc7 100644 --- a/handler/src/test/java/io/netty/handler/ssl/Java8SslTestUtils.java +++ b/handler/src/test/java/io/netty/handler/ssl/Java8SslTestUtils.java @@ -20,19 +20,21 @@ import javax.net.ssl.SNIMatcher; import javax.net.ssl.SNIServerName; +import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLParameters; import java.security.Provider; +import java.util.Arrays; import java.util.Collections; final class Java8SslTestUtils { private Java8SslTestUtils() { } - static void setSNIMatcher(SSLParameters parameters) { + static void setSNIMatcher(SSLParameters parameters, final byte[] match) { SNIMatcher matcher = new SNIMatcher(0) { @Override public boolean matches(SNIServerName sniServerName) { - return false; + return Arrays.equals(match, sniServerName.getEncoded()); } }; parameters.setSNIMatchers(Collections.singleton(matcher)); @@ -41,4 +43,14 @@ public boolean matches(SNIServerName sniServerName) { static Provider conscryptProvider() { return new OpenSSLProvider(); } + + /** + * Wraps the given {@link SSLEngine} to add extra tests while executing methods if possible / needed. + */ + static SSLEngine wrapSSLEngineForTesting(SSLEngine engine) { + if (engine instanceof ReferenceCountedOpenSslEngine) { + return new OpenSslErrorStackAssertSSLEngine((ReferenceCountedOpenSslEngine) engine); + } + return engine; + } } diff --git a/handler/src/test/java/io/netty/handler/ssl/JdkConscryptSslEngineInteropTest.java b/handler/src/test/java/io/netty/handler/ssl/JdkConscryptSslEngineInteropTest.java index 6d8862a0f23d..d7aa08fab400 100644 --- a/handler/src/test/java/io/netty/handler/ssl/JdkConscryptSslEngineInteropTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/JdkConscryptSslEngineInteropTest.java @@ -16,6 +16,7 @@ package io.netty.handler.ssl; import java.security.Provider; + import org.junit.BeforeClass; import org.junit.Ignore; import org.junit.Test; @@ -31,17 +32,18 @@ @RunWith(Parameterized.class) public class JdkConscryptSslEngineInteropTest extends SSLEngineTest { - @Parameterized.Parameters(name = "{index}: bufferType = {0}") - public static Collection data() { - List params = new ArrayList(); + @Parameterized.Parameters(name = "{index}: bufferType = {0}, combo = {1}, delegate = {2}") + public static Collection data() { + List params = new ArrayList(); for (BufferType type: BufferType.values()) { - params.add(type); + params.add(new Object[] { type, ProtocolCipherCombo.tlsv12(), false }); + params.add(new Object[] { type, ProtocolCipherCombo.tlsv12(), true }); } return params; } - public JdkConscryptSslEngineInteropTest(BufferType type) { - super(type); + public JdkConscryptSslEngineInteropTest(BufferType type, ProtocolCipherCombo combo, boolean delegate) { + super(type, combo, delegate); } @BeforeClass @@ -83,4 +85,11 @@ protected boolean mySetupMutualAuthServerIsValidClientException(Throwable cause) // TODO(scott): work around for a JDK issue. The exception should be SSLHandshakeException. return super.mySetupMutualAuthServerIsValidClientException(cause) || causedBySSLException(cause); } + + @Ignore("Ignore due bug in Conscrypt") + @Override + public void testHandshakeSession() throws Exception { + // Ignore as Conscrypt does not correctly return the local certificates while the TrustManager is invoked. + // See https://github.com/google/conscrypt/issues/634 + } } diff --git a/handler/src/test/java/io/netty/handler/ssl/JdkOpenSslEngineInteroptTest.java b/handler/src/test/java/io/netty/handler/ssl/JdkOpenSslEngineInteroptTest.java index d696d6b2e7a5..23e004bb842e 100644 --- a/handler/src/test/java/io/netty/handler/ssl/JdkOpenSslEngineInteroptTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/JdkOpenSslEngineInteroptTest.java @@ -15,11 +15,13 @@ */ package io.netty.handler.ssl; +import io.netty.util.internal.PlatformDependent; import org.junit.BeforeClass; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; +import javax.net.ssl.SSLEngine; import java.util.ArrayList; import java.util.Collection; import java.util.List; @@ -31,17 +33,23 @@ @RunWith(Parameterized.class) public class JdkOpenSslEngineInteroptTest extends SSLEngineTest { - @Parameterized.Parameters(name = "{index}: bufferType = {0}") - public static Collection data() { - List params = new ArrayList(); + @Parameterized.Parameters(name = "{index}: bufferType = {0}, combo = {1}, delegate = {2}") + public static Collection data() { + List params = new ArrayList(); for (BufferType type: BufferType.values()) { - params.add(type); + params.add(new Object[] { type, ProtocolCipherCombo.tlsv12(), false }); + params.add(new Object[] { type, ProtocolCipherCombo.tlsv12(), true }); + + if (PlatformDependent.javaVersion() >= 11 && OpenSsl.isTlsv13Supported()) { + params.add(new Object[] { type, ProtocolCipherCombo.tlsv13(), false }); + params.add(new Object[] { type, ProtocolCipherCombo.tlsv13(), true }); + } } return params; } - public JdkOpenSslEngineInteroptTest(BufferType type) { - super(type); + public JdkOpenSslEngineInteroptTest(BufferType type, ProtocolCipherCombo protocolCipherCombo, boolean delegate) { + super(type, protocolCipherCombo, delegate); } @BeforeClass @@ -94,6 +102,20 @@ public void testMutualAuthValidClientCertChainTooLongFailRequireClientAuth() thr super.testMutualAuthValidClientCertChainTooLongFailRequireClientAuth(); } + @Override + @Test + public void testSessionAfterHandshakeKeyManagerFactoryMutualAuth() throws Exception { + checkShouldUseKeyManagerFactory(); + super.testSessionAfterHandshakeKeyManagerFactoryMutualAuth(); + } + + @Override + @Test + public void testSessionAfterHandshakeKeyManagerFactory() throws Exception { + checkShouldUseKeyManagerFactory(); + super.testSessionAfterHandshakeKeyManagerFactory(); + } + @Override protected void mySetupMutualAuthServerInitSslHandler(SslHandler handler) { ReferenceCountedOpenSslEngine engine = (ReferenceCountedOpenSslEngine) handler.engine(); @@ -105,4 +127,15 @@ protected boolean mySetupMutualAuthServerIsValidClientException(Throwable cause) // TODO(scott): work around for a JDK issue. The exception should be SSLHandshakeException. return super.mySetupMutualAuthServerIsValidClientException(cause) || causedBySSLException(cause); } + + @Override + public void testHandshakeSession() throws Exception { + checkShouldUseKeyManagerFactory(); + super.testHandshakeSession(); + } + + @Override + protected SSLEngine wrapEngine(SSLEngine engine) { + return Java8SslTestUtils.wrapSSLEngineForTesting(engine); + } } diff --git a/handler/src/test/java/io/netty/handler/ssl/JdkSslEngineTest.java b/handler/src/test/java/io/netty/handler/ssl/JdkSslEngineTest.java index f37a6aff2512..db86298932bf 100644 --- a/handler/src/test/java/io/netty/handler/ssl/JdkSslEngineTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/JdkSslEngineTest.java @@ -26,6 +26,7 @@ import java.util.ArrayList; import java.util.Collection; +import io.netty.util.internal.EmptyArrays; import io.netty.util.internal.PlatformDependent; import org.junit.Ignore; import org.junit.Test; @@ -141,12 +142,18 @@ final void activate(JdkSslEngineTest instance) { private static final String FALLBACK_APPLICATION_LEVEL_PROTOCOL = "my-protocol-http1_1"; private static final String APPLICATION_LEVEL_PROTOCOL_NOT_COMPATIBLE = "my-protocol-FOO"; - @Parameterized.Parameters(name = "{index}: providerType = {0}, bufferType = {1}") + @Parameterized.Parameters(name = "{index}: providerType = {0}, bufferType = {1}, combo = {2}, delegate = {3}") public static Collection data() { List params = new ArrayList(); for (ProviderType providerType : ProviderType.values()) { for (BufferType bufferType : BufferType.values()) { - params.add(new Object[]{providerType, bufferType}); + params.add(new Object[]{ providerType, bufferType, ProtocolCipherCombo.tlsv12(), true }); + params.add(new Object[]{ providerType, bufferType, ProtocolCipherCombo.tlsv12(), false }); + + if (PlatformDependent.javaVersion() >= 11) { + params.add(new Object[] { providerType, bufferType, ProtocolCipherCombo.tlsv13(), true }); + params.add(new Object[] { providerType, bufferType, ProtocolCipherCombo.tlsv13(), false }); + } } } return params; @@ -156,8 +163,9 @@ public static Collection data() { private Provider provider; - public JdkSslEngineTest(ProviderType providerType, BufferType bufferType) { - super(bufferType); + public JdkSslEngineTest(ProviderType providerType, BufferType bufferType, + ProtocolCipherCombo protocolCipherCombo, boolean delegate) { + super(bufferType, protocolCipherCombo, delegate); this.providerType = providerType; } @@ -235,9 +243,11 @@ public String select(List protocols) { InsecureTrustManagerFactory.INSTANCE, null, IdentityCipherSuiteFilter.INSTANCE, clientApn, 0, 0); - setupHandlers(serverSslCtx, clientSslCtx); + setupHandlers(new TestDelegatingSslContext(serverSslCtx), new TestDelegatingSslContext(clientSslCtx)); assertTrue(clientLatch.await(2, TimeUnit.SECONDS)); - assertTrue(clientException instanceof SSLHandshakeException); + // When using TLSv1.3 the handshake is NOT sent in an extra round trip which means there will be + // no exception reported in this case but just the channel will be closed. + assertTrue(clientException instanceof SSLHandshakeException || clientException == null); } } catch (SkipTestException e) { // ALPN availability is dependent on the java version. If ALPN is not available because of @@ -358,4 +368,16 @@ private static final class SkipTestException extends RuntimeException { super(message); } } + + private final class TestDelegatingSslContext extends DelegatingSslContext { + TestDelegatingSslContext(SslContext ctx) { + super(ctx); + } + + @Override + protected void initEngine(SSLEngine engine) { + engine.setEnabledProtocols(protocols()); + engine.setEnabledCipherSuites(ciphers().toArray(EmptyArrays.EMPTY_STRINGS)); + } + } } diff --git a/handler/src/test/java/io/netty/handler/ssl/OpenSslCachingKeyMaterialProviderTest.java b/handler/src/test/java/io/netty/handler/ssl/OpenSslCachingKeyMaterialProviderTest.java new file mode 100644 index 000000000000..cfb4557fc237 --- /dev/null +++ b/handler/src/test/java/io/netty/handler/ssl/OpenSslCachingKeyMaterialProviderTest.java @@ -0,0 +1,70 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.buffer.UnpooledByteBufAllocator; +import org.junit.Assert; +import org.junit.Test; + +import javax.net.ssl.KeyManagerFactory; + +import static org.junit.Assert.*; + +public class OpenSslCachingKeyMaterialProviderTest extends OpenSslKeyMaterialProviderTest { + + @Override + protected KeyManagerFactory newKeyManagerFactory() throws Exception { + return new OpenSslCachingX509KeyManagerFactory(super.newKeyManagerFactory()); + } + + @Override + protected OpenSslKeyMaterialProvider newMaterialProvider(KeyManagerFactory factory, String password) { + return new OpenSslCachingKeyMaterialProvider(ReferenceCountedOpenSslContext.chooseX509KeyManager( + factory.getKeyManagers()), password); + } + + @Override + protected void assertRelease(OpenSslKeyMaterial material) { + Assert.assertFalse(material.release()); + } + + @Test + public void testMaterialCached() throws Exception { + OpenSslKeyMaterialProvider provider = newMaterialProvider(newKeyManagerFactory(), PASSWORD); + + OpenSslKeyMaterial material = provider.chooseKeyMaterial(UnpooledByteBufAllocator.DEFAULT, EXISTING_ALIAS); + assertNotNull(material); + assertNotEquals(0, material.certificateChainAddress()); + assertNotEquals(0, material.privateKeyAddress()); + assertEquals(2, material.refCnt()); + + OpenSslKeyMaterial material2 = provider.chooseKeyMaterial(UnpooledByteBufAllocator.DEFAULT, EXISTING_ALIAS); + assertNotNull(material2); + assertEquals(material.certificateChainAddress(), material2.certificateChainAddress()); + assertEquals(material.privateKeyAddress(), material2.privateKeyAddress()); + assertEquals(3, material.refCnt()); + assertEquals(3, material2.refCnt()); + + assertFalse(material.release()); + assertFalse(material2.release()); + + // After this the material should have been released. + provider.destroy(); + + assertEquals(0, material.refCnt()); + assertEquals(0, material2.refCnt()); + } +} diff --git a/handler/src/test/java/io/netty/handler/ssl/OpenSslCertificateExceptionTest.java b/handler/src/test/java/io/netty/handler/ssl/OpenSslCertificateExceptionTest.java index 229e853cd203..c9e8163fa66d 100644 --- a/handler/src/test/java/io/netty/handler/ssl/OpenSslCertificateExceptionTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/OpenSslCertificateExceptionTest.java @@ -18,20 +18,15 @@ import io.netty.internal.tcnative.CertificateVerifier; import org.junit.Assert; import org.junit.Assume; -import org.junit.BeforeClass; import org.junit.Test; import java.lang.reflect.Field; public class OpenSslCertificateExceptionTest { - @BeforeClass - public static void assumeOpenSsl() { - Assume.assumeTrue(OpenSsl.isAvailable()); - } - @Test public void testValidErrorCode() throws Exception { + Assume.assumeTrue(OpenSsl.isAvailable()); Field[] fields = CertificateVerifier.class.getFields(); for (Field field : fields) { if (field.isAccessible()) { @@ -44,6 +39,13 @@ public void testValidErrorCode() throws Exception { @Test(expected = IllegalArgumentException.class) public void testNonValidErrorCode() { + Assume.assumeTrue(OpenSsl.isAvailable()); new OpenSslCertificateException(Integer.MIN_VALUE); } + + @Test + public void testCanBeInstancedWhenOpenSslIsNotAvailable() { + Assume.assumeFalse(OpenSsl.isAvailable()); + new OpenSslCertificateException(0); + } } diff --git a/handler/src/test/java/io/netty/handler/ssl/OpenSslConscryptSslEngineInteropTest.java b/handler/src/test/java/io/netty/handler/ssl/OpenSslConscryptSslEngineInteropTest.java new file mode 100644 index 000000000000..fa49fa20cbdc --- /dev/null +++ b/handler/src/test/java/io/netty/handler/ssl/OpenSslConscryptSslEngineInteropTest.java @@ -0,0 +1,143 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLException; +import java.security.Provider; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +import static io.netty.handler.ssl.OpenSslTestUtils.checkShouldUseKeyManagerFactory; +import static org.junit.Assume.assumeTrue; + +@RunWith(Parameterized.class) +public class OpenSslConscryptSslEngineInteropTest extends ConscryptSslEngineTest { + + @Parameterized.Parameters(name = "{index}: bufferType = {0}, combo = {1}, delegate = {2}") + public static Collection data() { + List params = new ArrayList(); + for (BufferType type: BufferType.values()) { + params.add(new Object[] { type, ProtocolCipherCombo.tlsv12(), false }); + params.add(new Object[] { type, ProtocolCipherCombo.tlsv12(), true }); + } + return params; + } + + public OpenSslConscryptSslEngineInteropTest(BufferType type, ProtocolCipherCombo combo, boolean delegate) { + super(type, combo, delegate); + } + + @BeforeClass + public static void checkOpenssl() { + assumeTrue(OpenSsl.isAvailable()); + } + + @Override + protected SslProvider sslClientProvider() { + return SslProvider.OPENSSL; + } + + @Override + protected SslProvider sslServerProvider() { + return SslProvider.JDK; + } + + @Override + protected Provider clientSslContextProvider() { + return null; + } + + @Override + @Test + @Ignore("TODO: Make this work with Conscrypt") + public void testMutualAuthValidClientCertChainTooLongFailOptionalClientAuth() { + super.testMutualAuthValidClientCertChainTooLongFailOptionalClientAuth(); + } + + @Override + @Test + @Ignore("TODO: Make this work with Conscrypt") + public void testMutualAuthValidClientCertChainTooLongFailRequireClientAuth() { + super.testMutualAuthValidClientCertChainTooLongFailRequireClientAuth(); + } + + @Override + protected boolean mySetupMutualAuthServerIsValidClientException(Throwable cause) { + // TODO(scott): work around for a JDK issue. The exception should be SSLHandshakeException. + return super.mySetupMutualAuthServerIsValidClientException(cause) || causedBySSLException(cause); + } + + @Override + @Test + public void testMutualAuthInvalidIntermediateCASucceedWithOptionalClientAuth() throws Exception { + checkShouldUseKeyManagerFactory(); + super.testMutualAuthInvalidIntermediateCASucceedWithOptionalClientAuth(); + } + + @Override + @Test + public void testMutualAuthInvalidIntermediateCAFailWithOptionalClientAuth() throws Exception { + checkShouldUseKeyManagerFactory(); + super.testMutualAuthInvalidIntermediateCAFailWithOptionalClientAuth(); + } + + @Override + @Test + public void testMutualAuthInvalidIntermediateCAFailWithRequiredClientAuth() throws Exception { + checkShouldUseKeyManagerFactory(); + super.testMutualAuthInvalidIntermediateCAFailWithRequiredClientAuth(); + } + + @Override + @Test + public void testClientHostnameValidationSuccess() throws InterruptedException, SSLException { + assumeTrue(OpenSsl.supportsHostnameValidation()); + super.testClientHostnameValidationSuccess(); + } + + @Override + @Test + public void testClientHostnameValidationFail() throws InterruptedException, SSLException { + assumeTrue(OpenSsl.supportsHostnameValidation()); + super.testClientHostnameValidationFail(); + } + + @Override + @Test + public void testSessionAfterHandshakeKeyManagerFactoryMutualAuth() throws Exception { + checkShouldUseKeyManagerFactory(); + super.testSessionAfterHandshakeKeyManagerFactoryMutualAuth(); + } + + @Override + protected boolean mySetupMutualAuthServerIsValidServerException(Throwable cause) { + // TODO(scott): work around for a JDK issue. The exception should be SSLHandshakeException. + return super.mySetupMutualAuthServerIsValidServerException(cause) || causedBySSLException(cause); + } + + @Override + protected SSLEngine wrapEngine(SSLEngine engine) { + return Java8SslTestUtils.wrapSSLEngineForTesting(engine); + } +} diff --git a/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java b/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java index 9972b2d5f4b4..c90a69fd7f27 100644 --- a/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java @@ -21,6 +21,9 @@ import io.netty.handler.ssl.ApplicationProtocolConfig.SelectorFailureBehavior; import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import io.netty.handler.ssl.util.SelfSignedCertificate; +import io.netty.internal.tcnative.SSL; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.EmptyArrays; import io.netty.util.internal.PlatformDependent; import org.junit.Assume; import org.junit.BeforeClass; @@ -64,17 +67,23 @@ public class OpenSslEngineTest extends SSLEngineTest { private static final String PREFERRED_APPLICATION_LEVEL_PROTOCOL = "my-protocol-http2"; private static final String FALLBACK_APPLICATION_LEVEL_PROTOCOL = "my-protocol-http1_1"; - @Parameterized.Parameters(name = "{index}: bufferType = {0}") - public static Collection data() { - List params = new ArrayList(); + @Parameterized.Parameters(name = "{index}: bufferType = {0}, combo = {1}, delegate = {2}") + public static Collection data() { + List params = new ArrayList(); for (BufferType type: BufferType.values()) { - params.add(type); + params.add(new Object[] { type, ProtocolCipherCombo.tlsv12(), false }); + params.add(new Object[] { type, ProtocolCipherCombo.tlsv12(), true }); + + if (OpenSsl.isTlsv13Supported()) { + params.add(new Object[] { type, ProtocolCipherCombo.tlsv13(), false }); + params.add(new Object[] { type, ProtocolCipherCombo.tlsv13(), true }); + } } return params; } - public OpenSslEngineTest(BufferType type) { - super(type); + public OpenSslEngineTest(BufferType type, ProtocolCipherCombo cipherCombo, boolean delegate) { + super(type, cipherCombo, delegate); } @BeforeClass @@ -82,6 +91,26 @@ public static void checkOpenSsl() { assumeTrue(OpenSsl.isAvailable()); } + @Override + public void tearDown() throws InterruptedException { + super.tearDown(); + assertEquals("SSL error stack not correctly consumed", 0, SSL.getLastErrorNumber()); + } + + @Override + @Test + public void testSessionAfterHandshakeKeyManagerFactory() throws Exception { + checkShouldUseKeyManagerFactory(); + super.testSessionAfterHandshakeKeyManagerFactory(); + } + + @Override + @Test + public void testSessionAfterHandshakeKeyManagerFactoryMutualAuth() throws Exception { + checkShouldUseKeyManagerFactory(); + super.testSessionAfterHandshakeKeyManagerFactoryMutualAuth(); + } + @Override @Test public void testMutualAuthInvalidIntermediateCASucceedWithOptionalClientAuth() throws Exception { @@ -131,6 +160,12 @@ public void testClientHostnameValidationFail() throws InterruptedException, SSLE super.testClientHostnameValidationFail(); } + @Override + public void testHandshakeSession() throws Exception { + checkShouldUseKeyManagerFactory(); + super.testHandshakeSession(); + } + private static boolean isNpnSupported(String versionString) { String[] versionStringParts = versionString.split(" ", -1); if (versionStringParts.length == 2 && "LibreSSL".equals(versionStringParts[0])) { @@ -197,18 +232,22 @@ public void testEnablingAnAlreadyDisabledSslProtocol() throws Exception { @Test public void testWrapBuffersNoWritePendingError() throws Exception { clientSslCtx = SslContextBuilder.forClient() - .trustManager(InsecureTrustManagerFactory.INSTANCE) - .sslProvider(sslClientProvider()) - .build(); + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .build(); SelfSignedCertificate ssc = new SelfSignedCertificate(); serverSslCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) - .sslProvider(sslServerProvider()) - .build(); + .sslProvider(sslServerProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .build(); SSLEngine clientEngine = null; SSLEngine serverEngine = null; try { - clientEngine = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); - serverEngine = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); handshake(clientEngine, serverEngine); ByteBuffer src = allocateBuffer(1024 * 10); @@ -231,18 +270,22 @@ public void testWrapBuffersNoWritePendingError() throws Exception { @Test public void testOnlySmallBufferNeededForWrap() throws Exception { clientSslCtx = SslContextBuilder.forClient() - .trustManager(InsecureTrustManagerFactory.INSTANCE) - .sslProvider(sslClientProvider()) - .build(); + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .build(); SelfSignedCertificate ssc = new SelfSignedCertificate(); serverSslCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) - .sslProvider(sslServerProvider()) - .build(); + .sslProvider(sslServerProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .build(); SSLEngine clientEngine = null; SSLEngine serverEngine = null; try { - clientEngine = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); - serverEngine = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); handshake(clientEngine, serverEngine); // Allocate a buffer which is small enough and set the limit to the capacity to mark its whole content @@ -251,9 +294,9 @@ public void testOnlySmallBufferNeededForWrap() throws Exception { ByteBuffer src = allocateBuffer(srcLen); ByteBuffer dstTooSmall = allocateBuffer( - src.capacity() + ((ReferenceCountedOpenSslEngine) clientEngine).maxWrapOverhead() - 1); + src.capacity() + unwrapEngine(clientEngine).maxWrapOverhead() - 1); ByteBuffer dst = allocateBuffer( - src.capacity() + ((ReferenceCountedOpenSslEngine) clientEngine).maxWrapOverhead()); + src.capacity() + unwrapEngine(clientEngine).maxWrapOverhead()); // Check that we fail to wrap if the dst buffers capacity is not at least // src.capacity() + ReferenceCountedOpenSslEngine.MAX_TLS_RECORD_OVERHEAD_LENGTH @@ -282,25 +325,29 @@ public void testOnlySmallBufferNeededForWrap() throws Exception { @Test public void testNeededDstCapacityIsCorrectlyCalculated() throws Exception { clientSslCtx = SslContextBuilder.forClient() - .trustManager(InsecureTrustManagerFactory.INSTANCE) - .sslProvider(sslClientProvider()) - .build(); + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .build(); SelfSignedCertificate ssc = new SelfSignedCertificate(); serverSslCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) - .sslProvider(sslServerProvider()) - .build(); + .sslProvider(sslServerProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .build(); SSLEngine clientEngine = null; SSLEngine serverEngine = null; try { - clientEngine = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); - serverEngine = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); handshake(clientEngine, serverEngine); ByteBuffer src = allocateBuffer(1024); ByteBuffer src2 = src.duplicate(); ByteBuffer dst = allocateBuffer(src.capacity() - + ((ReferenceCountedOpenSslEngine) clientEngine).maxWrapOverhead()); + + unwrapEngine(clientEngine).maxWrapOverhead()); SSLEngineResult result = clientEngine.wrap(new ByteBuffer[] { src, src2 }, dst); assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus()); @@ -318,18 +365,22 @@ public void testNeededDstCapacityIsCorrectlyCalculated() throws Exception { @Test public void testSrcsLenOverFlowCorrectlyHandled() throws Exception { clientSslCtx = SslContextBuilder.forClient() - .trustManager(InsecureTrustManagerFactory.INSTANCE) - .sslProvider(sslClientProvider()) - .build(); + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .build(); SelfSignedCertificate ssc = new SelfSignedCertificate(); serverSslCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) - .sslProvider(sslServerProvider()) - .build(); + .sslProvider(sslServerProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .build(); SSLEngine clientEngine = null; SSLEngine serverEngine = null; try { - clientEngine = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); - serverEngine = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); handshake(clientEngine, serverEngine); ByteBuffer src = allocateBuffer(1024); @@ -343,9 +394,9 @@ public void testSrcsLenOverFlowCorrectlyHandled() throws Exception { srcsLen += dup.capacity(); } - ByteBuffer[] srcs = srcList.toArray(new ByteBuffer[srcList.size()]); + ByteBuffer[] srcs = srcList.toArray(new ByteBuffer[0]); ByteBuffer dst = allocateBuffer( - ((ReferenceCountedOpenSslEngine) clientEngine).maxEncryptedPacketLength() - 1); + unwrapEngine(clientEngine).maxEncryptedPacketLength() - 1); SSLEngineResult result = clientEngine.wrap(srcs, dst); assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus()); @@ -365,9 +416,11 @@ public void testSrcsLenOverFlowCorrectlyHandled() throws Exception { @Test public void testCalculateOutNetBufSizeOverflow() throws SSLException { clientSslCtx = SslContextBuilder.forClient() - .trustManager(InsecureTrustManagerFactory.INSTANCE) - .sslProvider(sslClientProvider()) - .build(); + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .build(); SSLEngine clientEngine = null; try { clientEngine = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); @@ -381,9 +434,11 @@ public void testCalculateOutNetBufSizeOverflow() throws SSLException { @Test public void testCalculateOutNetBufSize0() throws SSLException { clientSslCtx = SslContextBuilder.forClient() - .trustManager(InsecureTrustManagerFactory.INSTANCE) - .sslProvider(sslClientProvider()) - .build(); + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .build(); SSLEngine clientEngine = null; try { clientEngine = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); @@ -393,6 +448,72 @@ public void testCalculateOutNetBufSize0() throws SSLException { } } + @Test + public void testCorrectlyCalculateSpaceForAlert() throws Exception { + testCorrectlyCalculateSpaceForAlert(true); + } + + @Test + public void testCorrectlyCalculateSpaceForAlertJDKCompatabilityModeOff() throws Exception { + testCorrectlyCalculateSpaceForAlert(false); + } + + private void testCorrectlyCalculateSpaceForAlert(boolean jdkCompatabilityMode) throws Exception { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + serverSslCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(sslServerProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .build(); + + clientSslCtx = SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .build(); + SSLEngine clientEngine = null; + SSLEngine serverEngine = null; + try { + if (jdkCompatabilityMode) { + clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + } else { + clientEngine = wrapEngine(clientSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine()); + serverEngine = wrapEngine(serverSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine()); + } + handshake(clientEngine, serverEngine); + + // This should produce an alert + clientEngine.closeOutbound(); + + ByteBuffer empty = allocateBuffer(0); + ByteBuffer dst = allocateBuffer(clientEngine.getSession().getPacketBufferSize()); + // Limit to something that is guaranteed to be too small to hold a SSL Record. + dst.limit(1); + + // As we called closeOutbound() before this should produce a BUFFER_OVERFLOW. + SSLEngineResult result = clientEngine.wrap(empty, dst); + assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus()); + + // This must calculate a length that can hold an alert at least (or more). + dst.limit(dst.capacity()); + + result = clientEngine.wrap(empty, dst); + assertEquals(SSLEngineResult.Status.CLOSED, result.getStatus()); + + // flip the buffer so we can verify we produced a full length buffer. + dst.flip(); + + int length = SslUtils.getEncryptedPacketLength(new ByteBuffer[] { dst }, 0); + assertEquals(length, dst.remaining()); + } finally { + cleanupClientSslEngine(clientEngine); + cleanupServerSslEngine(serverEngine); + ssc.delete(); + } + } + @Override protected void mySetupMutualAuthServerInitSslHandler(SslHandler handler) { ReferenceCountedOpenSslEngine engine = (ReferenceCountedOpenSslEngine) handler.engine(); @@ -402,70 +523,49 @@ protected void mySetupMutualAuthServerInitSslHandler(SslHandler handler) { @Test public void testWrapWithDifferentSizesTLSv1() throws Exception { clientSslCtx = SslContextBuilder.forClient() - .trustManager(InsecureTrustManagerFactory.INSTANCE) - .sslProvider(sslClientProvider()) - .build(); + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .build(); SelfSignedCertificate ssc = new SelfSignedCertificate(); serverSslCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) - .sslProvider(sslServerProvider()) - .build(); + .sslProvider(sslServerProvider()) + .build(); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1, "ADH-AES128-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1, "AES128-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1, "ECDHE-RSA-AES128-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1, "ADH-CAMELLIA128-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1, "DES-CBC3-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1, "AECDH-AES128-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1, "AECDH-DES-CBC3-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1, "CAMELLIA128-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1, "DHE-RSA-AES256-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1, "SEED-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1, "RC4-MD5"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1, "ADH-AES256-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1, "AES256-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1, "ADH-SEED-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1, "ADH-DES-CBC3-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1, "EDH-RSA-DES-CBC3-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1, "ADH-RC4-MD5"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1, "IDEA-CBC-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1, "DHE-RSA-AES128-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1, "RC4-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1, "CAMELLIA256-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1, "AECDH-RC4-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1, "DHE-RSA-SEED-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1, "AECDH-AES256-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1, "ECDHE-RSA-DES-CBC3-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1, "ADH-CAMELLIA256-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1, "DHE-RSA-CAMELLIA256-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1, "ECDHE-RSA-AES256-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1, "DHE-RSA-CAMELLIA128-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1, "ECDHE-RSA-RC4-SHA"); } @Test public void testWrapWithDifferentSizesTLSv1_1() throws Exception { clientSslCtx = SslContextBuilder.forClient() - .trustManager(InsecureTrustManagerFactory.INSTANCE) - .sslProvider(sslClientProvider()) - .build(); + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .build(); SelfSignedCertificate ssc = new SelfSignedCertificate(); serverSslCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) - .sslProvider(sslServerProvider()) - .build(); + .sslProvider(sslServerProvider()) + .build(); testWrapWithDifferentSizes(PROTOCOL_TLS_V1_1, "ECDHE-RSA-AES256-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1_1, "DHE-RSA-AES256-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1_1, "DHE-RSA-CAMELLIA256-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1_1, "ADH-CAMELLIA256-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1_1, "ADH-AES256-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1_1, "AES256-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1_1, "CAMELLIA256-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1_1, "AECDH-AES128-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1_1, "DHE-RSA-CAMELLIA128-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1_1, "ECDHE-RSA-AES256-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1_1, "ADH-AES128-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1_1, "ADH-SEED-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1_1, "ADH-CAMELLIA128-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1_1, "SEED-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1_1, "CAMELLIA128-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1_1, "IDEA-CBC-SHA"); @@ -490,15 +590,10 @@ public void testWrapWithDifferentSizesTLSv1_2() throws Exception { .sslProvider(sslServerProvider()) .build(); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "ADH-AES128-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "AES128-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "ECDHE-RSA-AES128-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "ADH-CAMELLIA128-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "DHE-RSA-AES256-GCM-SHA384"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "DES-CBC3-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "AECDH-AES128-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "AES128-GCM-SHA256"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "DHE-RSA-AES128-GCM-SHA256"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "ECDHE-RSA-AES256-SHA384"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "AECDH-DES-CBC3-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "AES256-GCM-SHA384"); @@ -506,35 +601,19 @@ public void testWrapWithDifferentSizesTLSv1_2() throws Exception { testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "ECDHE-RSA-AES128-GCM-SHA256"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "ECDHE-RSA-AES128-SHA256"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "CAMELLIA128-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "DHE-RSA-AES256-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "SEED-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "RC4-MD5"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "ADH-AES256-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "AES256-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "ADH-SEED-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "DHE-RSA-AES128-SHA256"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "ADH-DES-CBC3-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "EDH-RSA-DES-CBC3-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "ADH-RC4-MD5"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "IDEA-CBC-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "DHE-RSA-AES128-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "RC4-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "CAMELLIA256-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "ADH-AES128-GCM-SHA256"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "AES128-SHA256"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "AECDH-RC4-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "ADH-AES256-GCM-SHA384"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "DHE-RSA-SEED-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "DHE-RSA-AES256-SHA256"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "AECDH-AES256-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "ECDHE-RSA-DES-CBC3-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "ADH-CAMELLIA256-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "ECDHE-RSA-AES256-GCM-SHA384"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "DHE-RSA-CAMELLIA256-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "ADH-AES256-SHA256"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "ADH-AES128-SHA256"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "ECDHE-RSA-AES256-SHA"); - testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "DHE-RSA-CAMELLIA128-SHA"); testWrapWithDifferentSizes(PROTOCOL_TLS_V1_2, "ECDHE-RSA-RC4-SHA"); } @@ -550,9 +629,7 @@ public void testWrapWithDifferentSizesSSLv3() throws Exception { .build(); testWrapWithDifferentSizes(PROTOCOL_SSL_V3, "ADH-AES128-SHA"); - testWrapWithDifferentSizes(PROTOCOL_SSL_V3, "AES128-SHA"); testWrapWithDifferentSizes(PROTOCOL_SSL_V3, "ADH-CAMELLIA128-SHA"); - testWrapWithDifferentSizes(PROTOCOL_SSL_V3, "DES-CBC3-SHA"); testWrapWithDifferentSizes(PROTOCOL_SSL_V3, "AECDH-AES128-SHA"); testWrapWithDifferentSizes(PROTOCOL_SSL_V3, "AECDH-DES-CBC3-SHA"); testWrapWithDifferentSizes(PROTOCOL_SSL_V3, "CAMELLIA128-SHA"); @@ -560,7 +637,6 @@ public void testWrapWithDifferentSizesSSLv3() throws Exception { testWrapWithDifferentSizes(PROTOCOL_SSL_V3, "SEED-SHA"); testWrapWithDifferentSizes(PROTOCOL_SSL_V3, "RC4-MD5"); testWrapWithDifferentSizes(PROTOCOL_SSL_V3, "ADH-AES256-SHA"); - testWrapWithDifferentSizes(PROTOCOL_SSL_V3, "AES256-SHA"); testWrapWithDifferentSizes(PROTOCOL_SSL_V3, "ADH-SEED-SHA"); testWrapWithDifferentSizes(PROTOCOL_SSL_V3, "ADH-DES-CBC3-SHA"); testWrapWithDifferentSizes(PROTOCOL_SSL_V3, "EDH-RSA-DES-CBC3-SHA"); @@ -587,14 +663,18 @@ public void testMultipleRecordsInOneBufferWithNonZeroPositionJDKCompatabilityMod .forClient() .trustManager(cert.cert()) .sslProvider(sslClientProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .build(); - SSLEngine client = clientSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine(); + SSLEngine client = wrapEngine(clientSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine()); serverSslCtx = SslContextBuilder .forServer(cert.certificate(), cert.privateKey()) .sslProvider(sslServerProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .build(); - SSLEngine server = serverSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine(); + SSLEngine server = wrapEngine(serverSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine()); try { // Choose buffer size small enough that we can put multiple buffers into one buffer and pass it into the @@ -664,14 +744,18 @@ public void testInputTooBigAndFillsUpBuffersJDKCompatabilityModeOff() throws Exc .forClient() .trustManager(cert.cert()) .sslProvider(sslClientProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .build(); - SSLEngine client = clientSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine(); + SSLEngine client = wrapEngine(clientSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine()); serverSslCtx = SslContextBuilder .forServer(cert.certificate(), cert.privateKey()) .sslProvider(sslServerProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .build(); - SSLEngine server = serverSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine(); + SSLEngine server = wrapEngine(serverSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine()); try { ByteBuffer plainClient = allocateBuffer(MAX_PLAINTEXT_LENGTH + 100); @@ -748,14 +832,18 @@ public void testPartialPacketUnwrapJDKCompatabilityModeOff() throws Exception { .forClient() .trustManager(cert.cert()) .sslProvider(sslClientProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .build(); - SSLEngine client = clientSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine(); + SSLEngine client = wrapEngine(clientSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine()); serverSslCtx = SslContextBuilder .forServer(cert.certificate(), cert.privateKey()) .sslProvider(sslServerProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .build(); - SSLEngine server = serverSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine(); + SSLEngine server = wrapEngine(serverSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine()); try { ByteBuffer plainClient = allocateBuffer(1024); @@ -823,14 +911,18 @@ public void testBufferUnderFlowAvoidedIfJDKCompatabilityModeOff() throws Excepti .forClient() .trustManager(cert.cert()) .sslProvider(sslClientProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .build(); - SSLEngine client = clientSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine(); + SSLEngine client = wrapEngine(clientSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine()); serverSslCtx = SslContextBuilder .forServer(cert.certificate(), cert.privateKey()) .sslProvider(sslServerProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .build(); - SSLEngine server = serverSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine(); + SSLEngine server = wrapEngine(serverSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine()); try { ByteBuffer plainClient = allocateBuffer(1024); @@ -905,8 +997,8 @@ private void testWrapWithDifferentSizes(String protocol, String cipher) throws E SSLEngine clientEngine = null; SSLEngine serverEngine = null; try { - clientEngine = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); - serverEngine = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); clientEngine.setEnabledCipherSuites(new String[] { cipher }); clientEngine.setEnabledProtocols(new String[] { protocol }); serverEngine.setEnabledCipherSuites(new String[] { cipher }); @@ -915,7 +1007,8 @@ private void testWrapWithDifferentSizes(String protocol, String cipher) throws E try { handshake(clientEngine, serverEngine); } catch (SSLException e) { - if (e.getMessage().contains("unsupported protocol")) { + if (e.getMessage().contains("unsupported protocol") || + e.getMessage().contains("no protocols available")) { Assume.assumeNoException(protocol + " not supported with cipher " + cipher, e); } throw e; @@ -936,7 +1029,7 @@ private void testWrapWithDifferentSizes(String protocol, String cipher) throws E private void testWrapDstBigEnough(SSLEngine engine, int srcLen) throws SSLException { ByteBuffer src = allocateBuffer(srcLen); - ByteBuffer dst = allocateBuffer(srcLen + ((ReferenceCountedOpenSslEngine) engine).maxWrapOverhead()); + ByteBuffer dst = allocateBuffer(srcLen + unwrapEngine(engine).maxWrapOverhead()); SSLEngineResult result = engine.wrap(src, dst); assertEquals(SSLEngineResult.Status.OK, result.getStatus()); @@ -955,14 +1048,40 @@ public void testSNIMatchersDoesNotThrow() throws Exception { assumeTrue(PlatformDependent.javaVersion() >= 8); SelfSignedCertificate ssc = new SelfSignedCertificate(); serverSslCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) - .sslProvider(sslServerProvider()) - .build(); + .sslProvider(sslServerProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .build(); + + SSLEngine engine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + try { + SSLParameters parameters = new SSLParameters(); + Java8SslTestUtils.setSNIMatcher(parameters, EmptyArrays.EMPTY_BYTES); + engine.setSSLParameters(parameters); + } finally { + cleanupServerSslEngine(engine); + ssc.delete(); + } + } - SSLEngine engine = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + @Test + public void testSNIMatchersWithSNINameWithUnderscore() throws Exception { + assumeTrue(PlatformDependent.javaVersion() >= 8); + byte[] name = "rb8hx3pww30y3tvw0mwy.v1_1".getBytes(CharsetUtil.UTF_8); + SelfSignedCertificate ssc = new SelfSignedCertificate(); + serverSslCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(sslServerProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .build(); + + SSLEngine engine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); try { SSLParameters parameters = new SSLParameters(); - Java8SslTestUtils.setSNIMatcher(parameters); + Java8SslTestUtils.setSNIMatcher(parameters, name); engine.setSSLParameters(parameters); + assertTrue(unwrapEngine(engine).checkSniHostnameMatch(name)); + assertFalse(unwrapEngine(engine).checkSniHostnameMatch("other".getBytes(CharsetUtil.UTF_8))); } finally { cleanupServerSslEngine(engine); ssc.delete(); @@ -973,10 +1092,12 @@ public void testSNIMatchersDoesNotThrow() throws Exception { public void testAlgorithmConstraintsThrows() throws Exception { SelfSignedCertificate ssc = new SelfSignedCertificate(); serverSslCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) - .sslProvider(sslServerProvider()) - .build(); + .sslProvider(sslServerProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .build(); - SSLEngine engine = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + SSLEngine engine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); try { SSLParameters parameters = new SSLParameters(); parameters.setAlgorithmConstraints(new AlgorithmConstraints() { @@ -1021,4 +1142,19 @@ private static ApplicationProtocolConfig acceptingNegotiator(Protocol protocol, SelectedListenerFailureBehavior.ACCEPT, supportedProtocols); } + + @Override + protected SSLEngine wrapEngine(SSLEngine engine) { + if (PlatformDependent.javaVersion() >= 8) { + return Java8SslTestUtils.wrapSSLEngineForTesting(engine); + } + return engine; + } + + ReferenceCountedOpenSslEngine unwrapEngine(SSLEngine engine) { + if (engine instanceof JdkSslEngine) { + return (ReferenceCountedOpenSslEngine) ((JdkSslEngine) engine).getWrappedEngine(); + } + return (ReferenceCountedOpenSslEngine) engine; + } } diff --git a/handler/src/test/java/io/netty/handler/ssl/OpenSslErrorStackAssertSSLEngine.java b/handler/src/test/java/io/netty/handler/ssl/OpenSslErrorStackAssertSSLEngine.java new file mode 100644 index 000000000000..af71f72570a7 --- /dev/null +++ b/handler/src/test/java/io/netty/handler/ssl/OpenSslErrorStackAssertSSLEngine.java @@ -0,0 +1,440 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.internal.tcnative.SSL; +import io.netty.util.ReferenceCounted; +import io.netty.util.internal.PlatformDependent; +import org.junit.Assert; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLSession; +import java.nio.ByteBuffer; +import java.util.List; +import java.util.function.BiFunction; + +/** + * Special {@link SSLEngine} which allows to wrap a {@link ReferenceCountedOpenSslEngine} and verify that that + * Error stack is empty after each method call. + */ +final class OpenSslErrorStackAssertSSLEngine extends JdkSslEngine implements ReferenceCounted { + + OpenSslErrorStackAssertSSLEngine(ReferenceCountedOpenSslEngine engine) { + super(engine); + } + + @Override + public String getPeerHost() { + try { + return getWrappedEngine().getPeerHost(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public int getPeerPort() { + try { + return getWrappedEngine().getPeerPort(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public SSLEngineResult wrap(ByteBuffer src, ByteBuffer dst) throws SSLException { + try { + return getWrappedEngine().wrap(src, dst); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public SSLEngineResult wrap(ByteBuffer[] srcs, ByteBuffer dst) throws SSLException { + try { + return getWrappedEngine().wrap(srcs, dst); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public SSLEngineResult wrap(ByteBuffer[] byteBuffers, int i, int i1, ByteBuffer byteBuffer) throws SSLException { + try { + return getWrappedEngine().wrap(byteBuffers, i, i1, byteBuffer); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public SSLEngineResult unwrap(ByteBuffer src, ByteBuffer dst) throws SSLException { + try { + return getWrappedEngine().unwrap(src, dst); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public SSLEngineResult unwrap(ByteBuffer src, ByteBuffer[] dsts) throws SSLException { + try { + return getWrappedEngine().unwrap(src, dsts); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public SSLEngineResult unwrap(ByteBuffer byteBuffer, ByteBuffer[] byteBuffers, int i, int i1) throws SSLException { + try { + return getWrappedEngine().unwrap(byteBuffer, byteBuffers, i, i1); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public Runnable getDelegatedTask() { + try { + return getWrappedEngine().getDelegatedTask(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public void closeInbound() throws SSLException { + try { + getWrappedEngine().closeInbound(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public boolean isInboundDone() { + try { + return getWrappedEngine().isInboundDone(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public void closeOutbound() { + try { + getWrappedEngine().closeOutbound(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public boolean isOutboundDone() { + try { + return getWrappedEngine().isOutboundDone(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public String[] getSupportedCipherSuites() { + try { + return getWrappedEngine().getSupportedCipherSuites(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public String[] getEnabledCipherSuites() { + try { + return getWrappedEngine().getEnabledCipherSuites(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public void setEnabledCipherSuites(String[] strings) { + try { + getWrappedEngine().setEnabledCipherSuites(strings); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public String[] getSupportedProtocols() { + try { + return getWrappedEngine().getSupportedProtocols(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public String[] getEnabledProtocols() { + try { + return getWrappedEngine().getEnabledProtocols(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public void setEnabledProtocols(String[] strings) { + try { + getWrappedEngine().setEnabledProtocols(strings); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public SSLSession getSession() { + try { + return getWrappedEngine().getSession(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public SSLSession getHandshakeSession() { + try { + return getWrappedEngine().getHandshakeSession(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public void beginHandshake() throws SSLException { + try { + getWrappedEngine().beginHandshake(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public SSLEngineResult.HandshakeStatus getHandshakeStatus() { + try { + return getWrappedEngine().getHandshakeStatus(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public void setUseClientMode(boolean b) { + try { + getWrappedEngine().setUseClientMode(b); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public boolean getUseClientMode() { + try { + return getWrappedEngine().getUseClientMode(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public void setNeedClientAuth(boolean b) { + try { + getWrappedEngine().setNeedClientAuth(b); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public boolean getNeedClientAuth() { + try { + return getWrappedEngine().getNeedClientAuth(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public void setWantClientAuth(boolean b) { + try { + getWrappedEngine().setWantClientAuth(b); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public boolean getWantClientAuth() { + try { + return getWrappedEngine().getWantClientAuth(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public void setEnableSessionCreation(boolean b) { + try { + getWrappedEngine().setEnableSessionCreation(b); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public boolean getEnableSessionCreation() { + try { + return getWrappedEngine().getEnableSessionCreation(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public SSLParameters getSSLParameters() { + try { + return getWrappedEngine().getSSLParameters(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public void setSSLParameters(SSLParameters params) { + try { + getWrappedEngine().setSSLParameters(params); + } finally { + assertErrorStackEmpty(); + } + } + + public String getApplicationProtocol() { + if (PlatformDependent.javaVersion() >= 9) { + try { + return Java9SslUtils.getApplicationProtocol(getWrappedEngine()); + } finally { + assertErrorStackEmpty(); + } + } + throw new UnsupportedOperationException(); + } + + public String getHandshakeApplicationProtocol() { + if (PlatformDependent.javaVersion() >= 9) { + try { + return Java9SslUtils.getHandshakeApplicationProtocol(getWrappedEngine()); + } finally { + assertErrorStackEmpty(); + } + } + throw new UnsupportedOperationException(); + } + + public void setHandshakeApplicationProtocolSelector(BiFunction, String> selector) { + if (PlatformDependent.javaVersion() >= 9) { + try { + Java9SslUtils.setHandshakeApplicationProtocolSelector(getWrappedEngine(), selector); + } finally { + assertErrorStackEmpty(); + } + } + throw new UnsupportedOperationException(); + } + + public BiFunction, String> getHandshakeApplicationProtocolSelector() { + if (PlatformDependent.javaVersion() >= 9) { + try { + return Java9SslUtils.getHandshakeApplicationProtocolSelector(getWrappedEngine()); + } finally { + assertErrorStackEmpty(); + } + } + throw new UnsupportedOperationException(); + } + + @Override + public int refCnt() { + return getWrappedEngine().refCnt(); + } + + @Override + public OpenSslErrorStackAssertSSLEngine retain() { + getWrappedEngine().retain(); + return this; + } + + @Override + public OpenSslErrorStackAssertSSLEngine retain(int increment) { + getWrappedEngine().retain(increment); + return this; + } + + @Override + public OpenSslErrorStackAssertSSLEngine touch() { + getWrappedEngine().touch(); + return this; + } + + @Override + public OpenSslErrorStackAssertSSLEngine touch(Object hint) { + getWrappedEngine().touch(hint); + return this; + } + + @Override + public boolean release() { + return getWrappedEngine().release(); + } + + @Override + public boolean release(int decrement) { + return getWrappedEngine().release(decrement); + } + + @Override + public String getNegotiatedApplicationProtocol() { + return getWrappedEngine().getNegotiatedApplicationProtocol(); + } + + @Override + void setNegotiatedApplicationProtocol(String applicationProtocol) { + throw new UnsupportedOperationException(); + } + + @Override + public ReferenceCountedOpenSslEngine getWrappedEngine() { + return (ReferenceCountedOpenSslEngine) super.getWrappedEngine(); + } + + private static void assertErrorStackEmpty() { + Assert.assertEquals("SSL error stack non-empty", 0, SSL.getLastErrorNumber()); + } +} diff --git a/handler/src/test/java/io/netty/handler/ssl/OpenSslJdkSslEngineInteroptTest.java b/handler/src/test/java/io/netty/handler/ssl/OpenSslJdkSslEngineInteroptTest.java index f63a16feb5e7..440c34c25b33 100644 --- a/handler/src/test/java/io/netty/handler/ssl/OpenSslJdkSslEngineInteroptTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/OpenSslJdkSslEngineInteroptTest.java @@ -15,10 +15,12 @@ */ package io.netty.handler.ssl; +import io.netty.util.internal.PlatformDependent; import org.junit.BeforeClass; import org.junit.Ignore; import org.junit.Test; +import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLException; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -33,17 +35,23 @@ @RunWith(Parameterized.class) public class OpenSslJdkSslEngineInteroptTest extends SSLEngineTest { - @Parameterized.Parameters(name = "{index}: bufferType = {0}") - public static Collection data() { - List params = new ArrayList(); + @Parameterized.Parameters(name = "{index}: bufferType = {0}, combo = {1}, delegate = {2}") + public static Collection data() { + List params = new ArrayList(); for (BufferType type: BufferType.values()) { - params.add(type); + params.add(new Object[] { type, ProtocolCipherCombo.tlsv12(), false }); + params.add(new Object[] { type, ProtocolCipherCombo.tlsv12(), true }); + + if (PlatformDependent.javaVersion() >= 11 && OpenSsl.isTlsv13Supported()) { + params.add(new Object[] { type, ProtocolCipherCombo.tlsv13(), false }); + params.add(new Object[] { type, ProtocolCipherCombo.tlsv13(), true }); + } } return params; } - public OpenSslJdkSslEngineInteroptTest(BufferType type) { - super(type); + public OpenSslJdkSslEngineInteroptTest(BufferType type, ProtocolCipherCombo combo, boolean delegate) { + super(type, combo, delegate); } @BeforeClass @@ -106,9 +114,27 @@ public void testClientHostnameValidationFail() throws InterruptedException, SSLE super.testClientHostnameValidationFail(); } + @Override + @Test + public void testSessionAfterHandshakeKeyManagerFactoryMutualAuth() throws Exception { + checkShouldUseKeyManagerFactory(); + super.testSessionAfterHandshakeKeyManagerFactoryMutualAuth(); + } + @Override protected boolean mySetupMutualAuthServerIsValidServerException(Throwable cause) { // TODO(scott): work around for a JDK issue. The exception should be SSLHandshakeException. return super.mySetupMutualAuthServerIsValidServerException(cause) || causedBySSLException(cause); } + + @Override + public void testHandshakeSession() throws Exception { + checkShouldUseKeyManagerFactory(); + super.testHandshakeSession(); + } + + @Override + protected SSLEngine wrapEngine(SSLEngine engine) { + return Java8SslTestUtils.wrapSSLEngineForTesting(engine); + } } diff --git a/handler/src/test/java/io/netty/handler/ssl/OpenSslKeyMaterialManagerTest.java b/handler/src/test/java/io/netty/handler/ssl/OpenSslKeyMaterialManagerTest.java new file mode 100644 index 000000000000..ae197a23e98b --- /dev/null +++ b/handler/src/test/java/io/netty/handler/ssl/OpenSslKeyMaterialManagerTest.java @@ -0,0 +1,83 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.util.internal.EmptyArrays; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Test; + +import javax.net.ssl.SSLException; +import javax.net.ssl.X509ExtendedKeyManager; +import java.net.Socket; +import java.security.Principal; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; + +public class OpenSslKeyMaterialManagerTest { + + @Test + public void testChooseClientAliasReturnsNull() throws SSLException { + Assume.assumeTrue(OpenSsl.isAvailable()); + + X509ExtendedKeyManager keyManager = new X509ExtendedKeyManager() { + @Override + public String[] getClientAliases(String s, Principal[] principals) { + return EmptyArrays.EMPTY_STRINGS; + } + + @Override + public String chooseClientAlias(String[] strings, Principal[] principals, Socket socket) { + return null; + } + + @Override + public String[] getServerAliases(String s, Principal[] principals) { + return EmptyArrays.EMPTY_STRINGS; + } + + @Override + public String chooseServerAlias(String s, Principal[] principals, Socket socket) { + return null; + } + + @Override + public X509Certificate[] getCertificateChain(String s) { + return EmptyArrays.EMPTY_X509_CERTIFICATES; + } + + @Override + public PrivateKey getPrivateKey(String s) { + return null; + } + }; + + OpenSslKeyMaterialManager manager = new OpenSslKeyMaterialManager( + new OpenSslKeyMaterialProvider(keyManager, null) { + @Override + OpenSslKeyMaterial chooseKeyMaterial(ByteBufAllocator allocator, String alias) throws Exception { + Assert.fail("Should not be called when alias is null"); + return null; + } + }); + SslContext context = SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL).build(); + OpenSslEngine engine = + (OpenSslEngine) context.newEngine(UnpooledByteBufAllocator.DEFAULT); + manager.setKeyMaterialClientSide(engine, EmptyArrays.EMPTY_STRINGS, null); + } +} diff --git a/handler/src/test/java/io/netty/handler/ssl/OpenSslKeyMaterialProviderTest.java b/handler/src/test/java/io/netty/handler/ssl/OpenSslKeyMaterialProviderTest.java new file mode 100644 index 000000000000..5b793fe87b17 --- /dev/null +++ b/handler/src/test/java/io/netty/handler/ssl/OpenSslKeyMaterialProviderTest.java @@ -0,0 +1,75 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.buffer.UnpooledByteBufAllocator; +import org.junit.BeforeClass; +import org.junit.Test; + +import javax.net.ssl.KeyManagerFactory; + +import java.security.KeyStore; + +import static org.junit.Assert.*; +import static org.junit.Assume.assumeTrue; + +public class OpenSslKeyMaterialProviderTest { + + static final String PASSWORD = "example"; + static final String EXISTING_ALIAS = "1"; + private static final String NON_EXISTING_ALIAS = "nonexisting"; + + @BeforeClass + public static void checkOpenSsl() { + assumeTrue(OpenSsl.isAvailable()); + } + + protected KeyManagerFactory newKeyManagerFactory() throws Exception { + char[] password = PASSWORD.toCharArray(); + final KeyStore keystore = KeyStore.getInstance("PKCS12"); + keystore.load(getClass().getResourceAsStream("mutual_auth_server.p12"), password); + + KeyManagerFactory kmf = + KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + kmf.init(keystore, password); + return kmf; + } + + protected OpenSslKeyMaterialProvider newMaterialProvider(KeyManagerFactory factory, String password) { + return new OpenSslKeyMaterialProvider(ReferenceCountedOpenSslContext.chooseX509KeyManager( + factory.getKeyManagers()), password); + } + + protected void assertRelease(OpenSslKeyMaterial material) { + assertTrue(material.release()); + } + + @Test + public void testChooseKeyMaterial() throws Exception { + OpenSslKeyMaterialProvider provider = newMaterialProvider(newKeyManagerFactory(), PASSWORD); + OpenSslKeyMaterial nonExistingMaterial = provider.chooseKeyMaterial( + UnpooledByteBufAllocator.DEFAULT, NON_EXISTING_ALIAS); + assertNull(nonExistingMaterial); + + OpenSslKeyMaterial material = provider.chooseKeyMaterial(UnpooledByteBufAllocator.DEFAULT, EXISTING_ALIAS); + assertNotNull(material); + assertNotEquals(0, material.certificateChainAddress()); + assertNotEquals(0, material.privateKeyAddress()); + assertRelease(material); + + provider.destroy(); + } +} diff --git a/handler/src/test/java/io/netty/handler/ssl/OpenSslTest.java b/handler/src/test/java/io/netty/handler/ssl/OpenSslTest.java index 4fd849109977..0a9d19db6585 100644 --- a/handler/src/test/java/io/netty/handler/ssl/OpenSslTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/OpenSslTest.java @@ -22,7 +22,9 @@ public class OpenSslTest { @Test public void testDefaultCiphers() { - Assert.assertTrue( - OpenSsl.DEFAULT_CIPHERS.size() <= SslUtils.DEFAULT_CIPHER_SUITES.length); + if (!OpenSsl.isTlsv13Supported()) { + Assert.assertTrue( + OpenSsl.DEFAULT_CIPHERS.size() <= SslUtils.DEFAULT_CIPHER_SUITES.length); + } } } diff --git a/handler/src/test/java/io/netty/handler/ssl/OpenSslX509KeyManagerFactoryProviderTest.java b/handler/src/test/java/io/netty/handler/ssl/OpenSslX509KeyManagerFactoryProviderTest.java new file mode 100644 index 000000000000..6afe5e4e72f4 --- /dev/null +++ b/handler/src/test/java/io/netty/handler/ssl/OpenSslX509KeyManagerFactoryProviderTest.java @@ -0,0 +1,38 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import javax.net.ssl.KeyManagerFactory; +import java.security.KeyStore; + +public class OpenSslX509KeyManagerFactoryProviderTest extends OpenSslCachingKeyMaterialProviderTest { + + @Override + protected KeyManagerFactory newKeyManagerFactory() throws Exception { + char[] password = PASSWORD.toCharArray(); + final KeyStore keystore = KeyStore.getInstance("PKCS12"); + keystore.load(getClass().getResourceAsStream("mutual_auth_server.p12"), password); + + OpenSslX509KeyManagerFactory kmf = new OpenSslX509KeyManagerFactory(); + kmf.init(keystore, password); + return kmf; + } + + @Override + protected OpenSslKeyMaterialProvider newMaterialProvider(KeyManagerFactory kmf, String password) { + return ((OpenSslX509KeyManagerFactory) kmf).newProvider(); + } +} diff --git a/handler/src/test/java/io/netty/handler/ssl/ParameterizedSslHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/ParameterizedSslHandlerTest.java index 813e59937281..6d01a2673e77 100644 --- a/handler/src/test/java/io/netty/handler/ssl/ParameterizedSslHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/ParameterizedSslHandlerTest.java @@ -33,6 +33,7 @@ import io.netty.handler.ssl.util.SelfSignedCertificate; import io.netty.handler.ssl.util.SimpleTrustManagerFactory; import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.ResourcesUtil; import io.netty.util.concurrent.Future; import io.netty.util.concurrent.FutureListener; import io.netty.util.concurrent.Promise; @@ -46,7 +47,6 @@ import javax.net.ssl.SSLException; import javax.net.ssl.TrustManager; import javax.net.ssl.X509TrustManager; -import java.io.File; import java.net.InetSocketAddress; import java.security.KeyStore; import java.security.cert.CertificateException; @@ -302,8 +302,8 @@ public X509Certificate[] getAcceptedIssuers() { final SslContext sslClientCtx = SslContextBuilder.forClient() .trustManager(InsecureTrustManagerFactory.INSTANCE) - .keyManager(new File(getClass().getResource("test.crt").getFile()), - new File(getClass().getResource("test_unencrypted.pem").getFile())) + .keyManager(ResourcesUtil.getFile(getClass(), "test.crt"), + ResourcesUtil.getFile(getClass(), "test_unencrypted.pem")) .sslProvider(clientProvider).build(); NioEventLoopGroup group = new NioEventLoopGroup(); @@ -340,7 +340,7 @@ protected void initChannel(Channel ch) throws Exception { public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { if (cause.getCause() instanceof SSLException) { // We received the alert and so produce an SSLException. - promise.setSuccess(null); + promise.trySuccess(null); } } }); @@ -381,12 +381,21 @@ private void testCloseNotify(final long closeNotifyReadTimeout, final boolean ti SelfSignedCertificate ssc = new SelfSignedCertificate(); final SslContext sslServerCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) - .sslProvider(serverProvider) - .build(); + .sslProvider(serverProvider) + // Use TLSv1.2 as we depend on the fact that the handshake + // is done in an extra round trip in the test which + // is not true in TLSv1.3 + .protocols(SslUtils.PROTOCOL_TLS_V1_2) + .build(); final SslContext sslClientCtx = SslContextBuilder.forClient() - .trustManager(InsecureTrustManagerFactory.INSTANCE) - .sslProvider(clientProvider).build(); + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(clientProvider) + // Use TLSv1.2 as we depend on the fact that the handshake + // is done in an extra round trip in the test which + // is not true in TLSv1.3 + .protocols(SslUtils.PROTOCOL_TLS_V1_2) + .build(); EventLoopGroup group = new NioEventLoopGroup(); Channel sc = null; diff --git a/handler/src/test/java/io/netty/handler/ssl/PemEncodedTest.java b/handler/src/test/java/io/netty/handler/ssl/PemEncodedTest.java index 793f77228785..b2531eb8f6dc 100644 --- a/handler/src/test/java/io/netty/handler/ssl/PemEncodedTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/PemEncodedTest.java @@ -24,7 +24,10 @@ import java.io.ByteArrayOutputStream; import java.io.File; import java.io.FileInputStream; +import java.security.PrivateKey; +import io.netty.buffer.Unpooled; +import io.netty.buffer.UnpooledByteBufAllocator; import org.junit.Test; import io.netty.handler.ssl.util.SelfSignedCertificate; @@ -69,6 +72,26 @@ private static void testPemEncoded(SslProvider provider) throws Exception { } } + @Test(expected = IllegalArgumentException.class) + public void testEncodedReturnsNull() throws Exception { + PemPrivateKey.toPEM(UnpooledByteBufAllocator.DEFAULT, true, new PrivateKey() { + @Override + public String getAlgorithm() { + return null; + } + + @Override + public String getFormat() { + return null; + } + + @Override + public byte[] getEncoded() { + return null; + } + }); + } + private static void assertRelease(PemEncoded encoded) { assertTrue(encoded.release()); } diff --git a/handler/src/test/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngineTest.java b/handler/src/test/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngineTest.java index ddbd0b16dedb..1728a53844a2 100644 --- a/handler/src/test/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngineTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngineTest.java @@ -23,8 +23,8 @@ public class ReferenceCountedOpenSslEngineTest extends OpenSslEngineTest { - public ReferenceCountedOpenSslEngineTest(BufferType type) { - super(type); + public ReferenceCountedOpenSslEngineTest(BufferType type, ProtocolCipherCombo combo, boolean delegate) { + super(type, combo, delegate); } @Override @@ -60,9 +60,11 @@ protected void cleanupServerSslEngine(SSLEngine engine) { @Test(expected = NullPointerException.class) public void testNotLeakOnException() throws Exception { clientSslCtx = SslContextBuilder.forClient() - .trustManager(InsecureTrustManagerFactory.INSTANCE) - .sslProvider(sslClientProvider()) - .build(); + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .build(); clientSslCtx.newEngine(null); } diff --git a/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java b/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java index 7d86fa278e5b..9ecb615daabf 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java @@ -33,17 +33,21 @@ import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.ssl.ApplicationProtocolConfig.Protocol; import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import io.netty.handler.ssl.util.SelfSignedCertificate; +import io.netty.handler.ssl.util.SimpleTrustManagerFactory; import io.netty.util.CharsetUtil; import io.netty.util.NetUtil; import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.ResourcesUtil; import io.netty.util.concurrent.Future; import io.netty.util.concurrent.Promise; import io.netty.util.internal.EmptyArrays; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.StringUtil; import org.junit.After; +import org.junit.Assume; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; @@ -53,39 +57,61 @@ import java.io.ByteArrayInputStream; import java.io.File; import java.io.FileInputStream; +import java.io.IOException; import java.io.InputStream; import java.net.InetSocketAddress; +import java.net.Socket; import java.nio.ByteBuffer; import java.nio.channels.ClosedChannelException; import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.Principal; import java.security.Provider; +import java.security.UnrecoverableKeyException; import java.security.cert.Certificate; import java.security.cert.CertificateException; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; import java.util.List; +import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.ManagerFactoryParameters; +import javax.net.ssl.SNIHostName; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.Status; import javax.net.ssl.SSLException; import javax.net.ssl.SSLHandshakeException; import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSessionBindingEvent; +import javax.net.ssl.SSLSessionBindingListener; +import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.TrustManagerFactorySpi; +import javax.net.ssl.X509ExtendedTrustManager; +import javax.net.ssl.X509TrustManager; import javax.security.cert.X509Certificate; -import static io.netty.handler.ssl.SslUtils.PROTOCOL_TLS_V1; -import static io.netty.handler.ssl.SslUtils.PROTOCOL_TLS_V1_2; -import static io.netty.handler.ssl.SslUtils.SSL_RECORD_HEADER_LENGTH; - +import static io.netty.handler.ssl.SslUtils.*; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Mockito.verify; @@ -203,10 +229,45 @@ enum BufferType { Mixed } + static final class ProtocolCipherCombo { + private static final ProtocolCipherCombo TLSV12 = new ProtocolCipherCombo( + PROTOCOL_TLS_V1_2, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"); + private static final ProtocolCipherCombo TLSV13 = new ProtocolCipherCombo( + PROTOCOL_TLS_V1_3, "TLS_AES_128_GCM_SHA256"); + final String protocol; + final String cipher; + + private ProtocolCipherCombo(String protocol, String cipher) { + this.protocol = protocol; + this.cipher = cipher; + } + + static ProtocolCipherCombo tlsv12() { + return TLSV12; + } + + static ProtocolCipherCombo tlsv13() { + return TLSV13; + } + + @Override + public String toString() { + return "ProtocolCipherCombo{" + + "protocol='" + protocol + '\'' + + ", cipher='" + cipher + '\'' + + '}'; + } + } + private final BufferType type; + private final ProtocolCipherCombo protocolCipherCombo; + private final boolean delegate; + private ExecutorService delegatingExecutor; - protected SSLEngineTest(BufferType type) { + protected SSLEngineTest(BufferType type, ProtocolCipherCombo protocolCipherCombo, boolean delegate) { this.type = type; + this.protocolCipherCombo = protocolCipherCombo; + this.delegate = delegate; } protected ByteBuffer allocateBuffer(int len) { @@ -392,6 +453,9 @@ public void setup() { MockitoAnnotations.initMocks(this); serverLatch = new CountDownLatch(1); clientLatch = new CountDownLatch(1); + if (delegate) { + delegatingExecutor = Executors.newCachedThreadPool(); + } } @After @@ -451,25 +515,32 @@ public void tearDown() throws InterruptedException { clientGroupShutdownFuture.sync(); } serverException = null; + + if (delegatingExecutor != null) { + delegatingExecutor.shutdown(); + } } @Test - public void testMutualAuthSameCerts() throws Exception { - mySetupMutualAuth(new File(getClass().getResource("test_unencrypted.pem").getFile()), - new File(getClass().getResource("test.crt").getFile()), - null); + public void testMutualAuthSameCerts() throws Throwable { + mySetupMutualAuth(ResourcesUtil.getFile(getClass(), "test_unencrypted.pem"), + ResourcesUtil.getFile(getClass(), "test.crt"), + null); runTest(null); assertTrue(serverLatch.await(2, TimeUnit.SECONDS)); - assertNull(serverException); + Throwable cause = serverException; + if (cause != null) { + throw cause; + } } @Test public void testMutualAuthDiffCerts() throws Exception { - File serverKeyFile = new File(getClass().getResource("test_encrypted.pem").getFile()); - File serverCrtFile = new File(getClass().getResource("test.crt").getFile()); + File serverKeyFile = ResourcesUtil.getFile(getClass(), "test_encrypted.pem"); + File serverCrtFile = ResourcesUtil.getFile(getClass(), "test.crt"); String serverKeyPassword = "12345"; - File clientKeyFile = new File(getClass().getResource("test2_encrypted.pem").getFile()); - File clientCrtFile = new File(getClass().getResource("test2.crt").getFile()); + File clientKeyFile = ResourcesUtil.getFile(getClass(), "test2_encrypted.pem"); + File clientCrtFile = ResourcesUtil.getFile(getClass(), "test2.crt"); String clientKeyPassword = "12345"; mySetupMutualAuth(clientCrtFile, serverKeyFile, serverCrtFile, serverKeyPassword, serverCrtFile, clientKeyFile, clientCrtFile, clientKeyPassword); @@ -479,11 +550,11 @@ public void testMutualAuthDiffCerts() throws Exception { @Test public void testMutualAuthDiffCertsServerFailure() throws Exception { - File serverKeyFile = new File(getClass().getResource("test_encrypted.pem").getFile()); - File serverCrtFile = new File(getClass().getResource("test.crt").getFile()); + File serverKeyFile = ResourcesUtil.getFile(getClass(), "test_encrypted.pem"); + File serverCrtFile = ResourcesUtil.getFile(getClass(), "test.crt"); String serverKeyPassword = "12345"; - File clientKeyFile = new File(getClass().getResource("test2_encrypted.pem").getFile()); - File clientCrtFile = new File(getClass().getResource("test2.crt").getFile()); + File clientKeyFile = ResourcesUtil.getFile(getClass(), "test2_encrypted.pem"); + File clientCrtFile = ResourcesUtil.getFile(getClass(), "test2.crt"); String clientKeyPassword = "12345"; // Client trusts server but server only trusts itself mySetupMutualAuth(serverCrtFile, serverKeyFile, serverCrtFile, serverKeyPassword, @@ -494,11 +565,11 @@ public void testMutualAuthDiffCertsServerFailure() throws Exception { @Test public void testMutualAuthDiffCertsClientFailure() throws Exception { - File serverKeyFile = new File(getClass().getResource("test_unencrypted.pem").getFile()); - File serverCrtFile = new File(getClass().getResource("test.crt").getFile()); + File serverKeyFile = ResourcesUtil.getFile(getClass(), "test_unencrypted.pem"); + File serverCrtFile = ResourcesUtil.getFile(getClass(), "test.crt"); String serverKeyPassword = null; - File clientKeyFile = new File(getClass().getResource("test2_unencrypted.pem").getFile()); - File clientCrtFile = new File(getClass().getResource("test2.crt").getFile()); + File clientKeyFile = ResourcesUtil.getFile(getClass(), "test2_unencrypted.pem"); + File clientCrtFile = ResourcesUtil.getFile(getClass(), "test2.crt"); String clientKeyPassword = null; // Server trusts client but client only trusts itself mySetupMutualAuth(clientCrtFile, serverKeyFile, serverCrtFile, serverKeyPassword, @@ -544,7 +615,7 @@ private void testMutualAuthInvalidClientCertSucceed(ClientAuth auth) throws Exce final KeyManagerFactory clientKeyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); clientKeyManagerFactory.init(clientKeyStore, password); - File commonCertChain = new File(getClass().getResource("mutual_auth_ca.pem").getFile()); + File commonCertChain = ResourcesUtil.getFile(getClass(), "mutual_auth_ca.pem"); mySetupMutualAuth(serverKeyManagerFactory, commonCertChain, clientKeyManagerFactory, commonCertChain, auth, false, false); @@ -571,7 +642,7 @@ private void testMutualAuthClientCertFail(ClientAuth auth, String clientCert, bo final KeyManagerFactory clientKeyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); clientKeyManagerFactory.init(clientKeyStore, password); - File commonCertChain = new File(getClass().getResource("mutual_auth_ca.pem").getFile()); + File commonCertChain = ResourcesUtil.getFile(getClass(), "mutual_auth_ca.pem"); mySetupMutualAuth(serverKeyManagerFactory, commonCertChain, clientKeyManagerFactory, commonCertChain, auth, true, serverInitEngine); @@ -603,7 +674,8 @@ protected boolean mySetupMutualAuthServerIsValidClientException(Throwable cause) } protected boolean mySetupMutualAuthServerIsValidException(Throwable cause) { - return cause instanceof SSLHandshakeException || cause instanceof ClosedChannelException; + // As in TLSv1.3 the handshake is sent without an extra roundtrip an SSLException is valid as well. + return cause instanceof SSLException || cause instanceof ClosedChannelException; } protected void mySetupMutualAuthServerInitSslHandler(SslHandler handler) { @@ -614,25 +686,30 @@ private void mySetupMutualAuth(KeyManagerFactory serverKMF, final File serverTru ClientAuth clientAuth, final boolean failureExpected, final boolean serverInitEngine) throws SSLException, InterruptedException { - serverSslCtx = SslContextBuilder.forServer(serverKMF) - .sslProvider(sslServerProvider()) - .sslContextProvider(serverSslContextProvider()) - .trustManager(serverTrustManager) - .clientAuth(clientAuth) - .ciphers(null, IdentityCipherSuiteFilter.INSTANCE) - .sessionCacheSize(0) - .sessionTimeout(0) - .build(); + serverSslCtx = + SslContextBuilder.forServer(serverKMF) + .protocols(protocols()) + .ciphers(ciphers()) + .sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()) + .trustManager(serverTrustManager) + .clientAuth(clientAuth) + .ciphers(null, IdentityCipherSuiteFilter.INSTANCE) + .sessionCacheSize(0) + .sessionTimeout(0).build(); + + clientSslCtx = + SslContextBuilder.forClient() + .protocols(protocols()) + .ciphers(ciphers()) + .sslProvider(sslClientProvider()) + .sslContextProvider(clientSslContextProvider()) + .trustManager(clientTrustManager) + .keyManager(clientKMF) + .ciphers(null, IdentityCipherSuiteFilter.INSTANCE) + .sessionCacheSize(0) + .sessionTimeout(0).build(); - clientSslCtx = SslContextBuilder.forClient() - .sslProvider(sslClientProvider()) - .sslContextProvider(clientSslContextProvider()) - .trustManager(clientTrustManager) - .keyManager(clientKMF) - .ciphers(null, IdentityCipherSuiteFilter.INSTANCE) - .sessionCacheSize(0) - .sessionTimeout(0) - .build(); serverConnectedChannel = null; sb = new ServerBootstrap(); cb = new Bootstrap(); @@ -645,7 +722,8 @@ protected void initChannel(Channel ch) throws Exception { ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), type)); ChannelPipeline p = ch.pipeline(); - SslHandler handler = serverSslCtx.newHandler(ch.alloc()); + SslHandler handler = delegatingExecutor == null ? serverSslCtx.newHandler(ch.alloc()) : + serverSslCtx.newHandler(ch.alloc(), delegatingExecutor); if (serverInitEngine) { mySetupMutualAuthServerInitSslHandler(handler); } @@ -688,16 +766,20 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E protected void initChannel(Channel ch) throws Exception { ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), type)); ChannelPipeline p = ch.pipeline(); - p.addLast(clientSslCtx.newHandler(ch.alloc())); + + SslHandler handler = delegatingExecutor == null ? clientSslCtx.newHandler(ch.alloc()) : + clientSslCtx.newHandler(ch.alloc(), delegatingExecutor); + p.addLast(handler); p.addLast(new MessageDelegatorChannelHandler(clientReceiver, clientLatch)); p.addLast(new ChannelInboundHandlerAdapter() { @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { if (evt == SslHandshakeCompletionEvent.SUCCESS) { - if (failureExpected) { - clientException = new IllegalStateException("handshake complete. expected failure"); + // With TLS1.3 a mutal auth error will not be propagated as a handshake error most of the + // time as the handshake needs NO extra roundtrip. + if (!failureExpected) { + clientLatch.countDown(); } - clientLatch.countDown(); } else if (evt instanceof SslHandshakeCompletionEvent) { clientException = ((SslHandshakeCompletionEvent) evt).cause(); clientLatch.countDown(); @@ -707,7 +789,7 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - if (cause.getCause() instanceof SSLHandshakeException) { + if (cause.getCause() instanceof SSLException) { clientException = cause.getCause(); clientLatch.countDown(); } else { @@ -718,7 +800,7 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E } }); - serverChannel = sb.bind(new InetSocketAddress(0)).sync().channel(); + serverChannel = sb.bind(new InetSocketAddress(8443)).sync().channel(); int port = ((InetSocketAddress) serverChannel.localAddress()).getPort(); ChannelFuture ccf = cb.connect(new InetSocketAddress(NetUtil.LOCALHOST, port)); @@ -728,9 +810,9 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E @Test public void testClientHostnameValidationSuccess() throws InterruptedException, SSLException { - mySetupClientHostnameValidation(new File(getClass().getResource("localhost_server.pem").getFile()), - new File(getClass().getResource("localhost_server.key").getFile()), - new File(getClass().getResource("mutual_auth_ca.pem").getFile()), + mySetupClientHostnameValidation(ResourcesUtil.getFile(getClass(), "localhost_server.pem"), + ResourcesUtil.getFile(getClass(), "localhost_server.key"), + ResourcesUtil.getFile(getClass(), "mutual_auth_ca.pem"), false); assertTrue(clientLatch.await(5, TimeUnit.SECONDS)); assertNull(clientException); @@ -740,9 +822,9 @@ public void testClientHostnameValidationSuccess() throws InterruptedException, S @Test public void testClientHostnameValidationFail() throws InterruptedException, SSLException { - mySetupClientHostnameValidation(new File(getClass().getResource("notlocalhost_server.pem").getFile()), - new File(getClass().getResource("notlocalhost_server.key").getFile()), - new File(getClass().getResource("mutual_auth_ca.pem").getFile()), + mySetupClientHostnameValidation(ResourcesUtil.getFile(getClass(), "notlocalhost_server.pem"), + ResourcesUtil.getFile(getClass(), "notlocalhost_server.key"), + ResourcesUtil.getFile(getClass(), "mutual_auth_ca.pem"), true); assertTrue(clientLatch.await(5, TimeUnit.SECONDS)); assertTrue("unexpected exception: " + clientException, @@ -759,6 +841,8 @@ private void mySetupClientHostnameValidation(File serverCrtFile, File serverKeyF final String expectedHost = "localhost"; serverSslCtx = SslContextBuilder.forServer(serverCrtFile, serverKeyFile, null) .sslProvider(sslServerProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .sslContextProvider(serverSslContextProvider()) .trustManager(InsecureTrustManagerFactory.INSTANCE) .ciphers(null, IdentityCipherSuiteFilter.INSTANCE) @@ -768,12 +852,15 @@ private void mySetupClientHostnameValidation(File serverCrtFile, File serverKeyF clientSslCtx = SslContextBuilder.forClient() .sslProvider(sslClientProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .sslContextProvider(clientSslContextProvider()) .trustManager(clientTrustCrtFile) .ciphers(null, IdentityCipherSuiteFilter.INSTANCE) .sessionCacheSize(0) .sessionTimeout(0) .build(); + serverConnectedChannel = null; sb = new ServerBootstrap(); cb = new Bootstrap(); @@ -785,7 +872,10 @@ private void mySetupClientHostnameValidation(File serverCrtFile, File serverKeyF protected void initChannel(Channel ch) throws Exception { ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), type)); ChannelPipeline p = ch.pipeline(); - p.addLast(serverSslCtx.newHandler(ch.alloc())); + + SslHandler handler = delegatingExecutor == null ? serverSslCtx.newHandler(ch.alloc()) : + serverSslCtx.newHandler(ch.alloc(), delegatingExecutor); + p.addLast(handler); p.addLast(new MessageDelegatorChannelHandler(serverReceiver, serverLatch)); p.addLast(new ChannelInboundHandlerAdapter() { @Override @@ -825,8 +915,16 @@ protected void initChannel(Channel ch) throws Exception { ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), type)); ChannelPipeline p = ch.pipeline(); InetSocketAddress remoteAddress = (InetSocketAddress) serverChannel.localAddress(); - SslHandler sslHandler = clientSslCtx.newHandler(ch.alloc(), expectedHost, 0); + + SslHandler sslHandler = delegatingExecutor == null ? + clientSslCtx.newHandler(ch.alloc(), expectedHost, 0) : + clientSslCtx.newHandler(ch.alloc(), expectedHost, 0, delegatingExecutor); + SSLParameters parameters = sslHandler.engine().getSSLParameters(); + if (SslUtils.isValidHostNameForSNI(expectedHost)) { + assertEquals(1, parameters.getServerNames().size()); + assertEquals(new SNIHostName(expectedHost), parameters.getServerNames().get(0)); + } parameters.setEndpointIdentificationAlgorithm("HTTPS"); sslHandler.engine().setSSLParameters(parameters); p.addLast(sslHandler); @@ -872,28 +970,66 @@ private void mySetupMutualAuth(File keyFile, File crtFile, String keyPassword) mySetupMutualAuth(crtFile, keyFile, crtFile, keyPassword, crtFile, keyFile, crtFile, keyPassword); } + private void verifySSLSessionForMutualAuth(SSLSession session, File certFile, String principalName) + throws Exception { + InputStream in = null; + try { + assertEquals(principalName, session.getLocalPrincipal().getName()); + assertEquals(principalName, session.getPeerPrincipal().getName()); + assertNotNull(session.getId()); + assertEquals(protocolCipherCombo.cipher, session.getCipherSuite()); + assertEquals(protocolCipherCombo.protocol, session.getProtocol()); + assertTrue(session.getApplicationBufferSize() > 0); + assertTrue(session.getCreationTime() > 0); + assertTrue(session.isValid()); + assertTrue(session.getLastAccessedTime() > 0); + + in = new FileInputStream(certFile); + final byte[] certBytes = SslContext.X509_CERT_FACTORY + .generateCertificate(in).getEncoded(); + + // Verify session + assertEquals(1, session.getPeerCertificates().length); + assertArrayEquals(certBytes, session.getPeerCertificates()[0].getEncoded()); + + assertEquals(1, session.getPeerCertificateChain().length); + assertArrayEquals(certBytes, session.getPeerCertificateChain()[0].getEncoded()); + + assertEquals(1, session.getLocalCertificates().length); + assertArrayEquals(certBytes, session.getLocalCertificates()[0].getEncoded()); + } finally { + if (in != null) { + in.close(); + } + } + } + private void mySetupMutualAuth( File servertTrustCrtFile, File serverKeyFile, final File serverCrtFile, String serverKeyPassword, - File clientTrustCrtFile, File clientKeyFile, File clientCrtFile, String clientKeyPassword) + File clientTrustCrtFile, File clientKeyFile, final File clientCrtFile, String clientKeyPassword) throws InterruptedException, SSLException { - serverSslCtx = SslContextBuilder.forServer(serverCrtFile, serverKeyFile, serverKeyPassword) - .sslProvider(sslServerProvider()) - .sslContextProvider(serverSslContextProvider()) - .trustManager(servertTrustCrtFile) - .ciphers(null, IdentityCipherSuiteFilter.INSTANCE) - .sessionCacheSize(0) - .sessionTimeout(0) - .build(); + serverSslCtx = + SslContextBuilder.forServer(serverCrtFile, serverKeyFile, serverKeyPassword) + .sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .trustManager(servertTrustCrtFile) + .ciphers(null, IdentityCipherSuiteFilter.INSTANCE) + .sessionCacheSize(0) + .sessionTimeout(0).build(); + clientSslCtx = + SslContextBuilder.forClient() + .sslProvider(sslClientProvider()) + .sslContextProvider(clientSslContextProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .trustManager(clientTrustCrtFile) + .keyManager(clientCrtFile, clientKeyFile, clientKeyPassword) + .ciphers(null, IdentityCipherSuiteFilter.INSTANCE) + .sessionCacheSize(0) + .sessionTimeout(0).build(); - clientSslCtx = SslContextBuilder.forClient() - .sslProvider(sslClientProvider()) - .sslContextProvider(clientSslContextProvider()) - .trustManager(clientTrustCrtFile) - .keyManager(clientCrtFile, clientKeyFile, clientKeyPassword) - .ciphers(null, IdentityCipherSuiteFilter.INSTANCE) - .sessionCacheSize(0) - .sessionTimeout(0) - .build(); serverConnectedChannel = null; sb = new ServerBootstrap(); cb = new Bootstrap(); @@ -902,13 +1038,14 @@ private void mySetupMutualAuth( sb.channel(NioServerSocketChannel.class); sb.childHandler(new ChannelInitializer() { @Override - protected void initChannel(Channel ch) throws Exception { + protected void initChannel(Channel ch) { ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), type)); ChannelPipeline p = ch.pipeline(); - SSLEngine engine = serverSslCtx.newEngine(ch.alloc()); + final SSLEngine engine = wrapEngine(serverSslCtx.newEngine(ch.alloc())); engine.setUseClientMode(false); engine.setNeedClientAuth(true); + p.addLast(new SslHandler(engine)); p.addLast(new MessageDelegatorChannelHandler(serverReceiver, serverLatch)); p.addLast(new ChannelInboundHandlerAdapter() { @@ -927,27 +1064,8 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { if (evt == SslHandshakeCompletionEvent.SUCCESS) { try { - InputStream in = new FileInputStream(serverCrtFile); - try { - final byte[] cert = SslContext.X509_CERT_FACTORY - .generateCertificate(in).getEncoded(); - - // Verify session - SSLSession session = ctx.pipeline().get(SslHandler.class).engine().getSession(); - assertEquals(1, session.getPeerCertificates().length); - assertArrayEquals(cert, session.getPeerCertificates()[0].getEncoded()); - - assertEquals(1, session.getPeerCertificateChain().length); - assertArrayEquals(cert, session.getPeerCertificateChain()[0].getEncoded()); - - assertEquals(1, session.getLocalCertificates().length); - assertArrayEquals(cert, session.getLocalCertificates()[0].getEncoded()); - - assertEquals(PRINCIPAL_NAME, session.getLocalPrincipal().getName()); - assertEquals(PRINCIPAL_NAME, session.getPeerPrincipal().getName()); - } finally { - in.close(); - } + verifySSLSessionForMutualAuth( + engine.getSession(), serverCrtFile, PRINCIPAL_NAME); } catch (Throwable cause) { serverException = cause; } @@ -965,13 +1083,29 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc protected void initChannel(Channel ch) throws Exception { ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), type)); + final SslHandler handler = delegatingExecutor == null ? + clientSslCtx.newHandler(ch.alloc()) : + clientSslCtx.newHandler(ch.alloc(), delegatingExecutor); + + handler.engine().setNeedClientAuth(true); ChannelPipeline p = ch.pipeline(); - p.addLast(clientSslCtx.newHandler(ch.alloc())); + p.addLast(handler); p.addLast(new MessageDelegatorChannelHandler(clientReceiver, clientLatch)); p.addLast(new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt == SslHandshakeCompletionEvent.SUCCESS) { + try { + verifySSLSessionForMutualAuth( + handler.engine().getSession(), clientCrtFile, PRINCIPAL_NAME); + } catch (Throwable cause) { + clientException = cause; + } + } + } + @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - cause.printStackTrace(); if (cause.getCause() instanceof SSLHandshakeException) { clientException = cause.getCause(); clientLatch.countDown(); @@ -1024,7 +1158,7 @@ private static void writeAndVerifyReceived(ByteBuf message, Channel sendChannel, MessageReceiver receiver) throws Exception { List dataCapture = null; try { - sendChannel.writeAndFlush(message); + assertTrue(sendChannel.writeAndFlush(message).await(5, TimeUnit.SECONDS)); receiverLatch.await(5, TimeUnit.SECONDS); message.resetReaderIndex(); ArgumentCaptor captor = ArgumentCaptor.forClass(ByteBuf.class); @@ -1047,7 +1181,7 @@ public void testGetCreationTime() throws Exception { .sslContextProvider(clientSslContextProvider()).build(); SSLEngine engine = null; try { - engine = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + engine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); assertTrue(engine.getSession().getCreationTime() <= System.currentTimeMillis()); } finally { cleanupClientSslEngine(engine); @@ -1060,17 +1194,21 @@ public void testSessionInvalidate() throws Exception { .trustManager(InsecureTrustManagerFactory.INSTANCE) .sslProvider(sslClientProvider()) .sslContextProvider(clientSslContextProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .build(); SelfSignedCertificate ssc = new SelfSignedCertificate(); serverSslCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) .sslProvider(sslServerProvider()) .sslContextProvider(serverSslContextProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .build(); SSLEngine clientEngine = null; SSLEngine serverEngine = null; try { - clientEngine = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); - serverEngine = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); handshake(clientEngine, serverEngine); SSLSession session = serverEngine.getSession(); @@ -1089,18 +1227,22 @@ public void testSSLSessionId() throws Exception { clientSslCtx = SslContextBuilder.forClient() .trustManager(InsecureTrustManagerFactory.INSTANCE) .sslProvider(sslClientProvider()) + // This test only works for non TLSv1.3 for now + .protocols(PROTOCOL_TLS_V1_2) .sslContextProvider(clientSslContextProvider()) .build(); SelfSignedCertificate ssc = new SelfSignedCertificate(); serverSslCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) .sslProvider(sslServerProvider()) + // This test only works for non TLSv1.3 for now + .protocols(PROTOCOL_TLS_V1_2) .sslContextProvider(serverSslContextProvider()) .build(); SSLEngine clientEngine = null; SSLEngine serverEngine = null; try { - clientEngine = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); - serverEngine = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); // Before the handshake the id should have length == 0 assertEquals(0, clientEngine.getSession().getId().length); @@ -1122,10 +1264,14 @@ public void testSSLSessionId() throws Exception { @Test(timeout = 30000) public void clientInitiatedRenegotiationWithFatalAlertDoesNotInfiniteLoopServer() throws CertificateException, SSLException, InterruptedException, ExecutionException { + Assume.assumeTrue(PlatformDependent.javaVersion() >= 11); final SelfSignedCertificate ssc = new SelfSignedCertificate(); serverSslCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) - .sslProvider(sslServerProvider()) - .sslContextProvider(serverSslContextProvider()).build(); + .sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .build(); sb = new ServerBootstrap() .group(new NioEventLoopGroup(1)) .channel(NioServerSocketChannel.class) @@ -1135,7 +1281,12 @@ public void initChannel(SocketChannel ch) { ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), type)); ChannelPipeline p = ch.pipeline(); - p.addLast(serverSslCtx.newHandler(ch.alloc())); + + SslHandler handler = delegatingExecutor == null ? + serverSslCtx.newHandler(ch.alloc()) : + serverSslCtx.newHandler(ch.alloc(), delegatingExecutor); + + p.addLast(handler); p.addLast(new ChannelInboundHandlerAdapter() { @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { @@ -1175,8 +1326,12 @@ public void channelInactive(ChannelHandlerContext ctx) { serverChannel = sb.bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); clientSslCtx = SslContextBuilder.forClient() - .sslProvider(SslProvider.JDK) // OpenSslEngine doesn't support renegotiation on client side - .trustManager(InsecureTrustManagerFactory.INSTANCE).build(); + // OpenSslEngine doesn't support renegotiation on client side + .sslProvider(SslProvider.JDK) + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .protocols(protocols()) + .ciphers(ciphers()) + .build(); cb = new Bootstrap(); cb.group(new NioEventLoopGroup(1)) @@ -1187,7 +1342,11 @@ public void initChannel(SocketChannel ch) { ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), type)); ChannelPipeline p = ch.pipeline(); - SslHandler sslHandler = clientSslCtx.newHandler(ch.alloc()); + + SslHandler sslHandler = delegatingExecutor == null ? + clientSslCtx.newHandler(ch.alloc()) : + clientSslCtx.newHandler(ch.alloc(), delegatingExecutor); + // The renegotiate is not expected to succeed, so we should stop trying in a timely manner so // the unit test can terminate relativley quicly. sslHandler.setHandshakeTimeout(1, TimeUnit.SECONDS); @@ -1233,14 +1392,16 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { protected void testEnablingAnAlreadyDisabledSslProtocol(String[] protocols1, String[] protocols2) throws Exception { SSLEngine sslEngine = null; try { - File serverKeyFile = new File(getClass().getResource("test_unencrypted.pem").getFile()); - File serverCrtFile = new File(getClass().getResource("test.crt").getFile()); + File serverKeyFile = ResourcesUtil.getFile(getClass(), "test_unencrypted.pem"); + File serverCrtFile = ResourcesUtil.getFile(getClass(), "test.crt"); serverSslCtx = SslContextBuilder.forServer(serverCrtFile, serverKeyFile) - .sslProvider(sslServerProvider()) - .sslContextProvider(serverSslContextProvider()) - .build(); + .sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .build(); - sslEngine = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + sslEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); // Disable all protocols sslEngine.setEnabledProtocols(EmptyArrays.EMPTY_STRINGS); @@ -1265,7 +1426,7 @@ protected void testEnablingAnAlreadyDisabledSslProtocol(String[] protocols1, Str } } - protected void handshake(SSLEngine clientEngine, SSLEngine serverEngine) throws SSLException { + protected void handshake(SSLEngine clientEngine, SSLEngine serverEngine) throws Exception { ByteBuffer cTOs = allocateBuffer(clientEngine.getSession().getPacketBufferSize()); ByteBuffer sTOc = allocateBuffer(serverEngine.getSession().getPacketBufferSize()); @@ -1317,7 +1478,10 @@ protected void handshake(SSLEngine clientEngine, SSLEngine serverEngine) throws cTOsPos = cTOs.position(); sTOcPos = sTOc.position(); - if (!clientHandshakeFinished) { + if (!clientHandshakeFinished || + // After the handshake completes it is possible we have more data that was send by the server as + // the server will send session updates after the handshake. In this case continue to unwrap. + SslUtils.PROTOCOL_TLS_V1_3.equals(clientEngine.getSession().getProtocol())) { int clientAppReadBufferPos = clientAppReadBuffer.position(); clientResult = clientEngine.unwrap(sTOc, clientAppReadBuffer); @@ -1329,7 +1493,7 @@ protected void handshake(SSLEngine clientEngine, SSLEngine serverEngine) throws clientHandshakeFinished = true; } } else { - assertFalse(sTOc.hasRemaining()); + assertEquals(0, sTOc.remaining()); } if (!serverHandshakeFinished) { @@ -1355,14 +1519,18 @@ private static boolean isHandshakeFinished(SSLEngineResult result) { return result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED; } - private static void runDelegatedTasks(SSLEngineResult result, SSLEngine engine) { + private void runDelegatedTasks(SSLEngineResult result, SSLEngine engine) throws Exception { if (result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_TASK) { for (;;) { Runnable task = engine.getDelegatedTask(); if (task == null) { break; } - task.run(); + if (delegatingExecutor == null) { + task.run(); + } else { + delegatingExecutor.submit(task).get(); + } } } } @@ -1412,24 +1580,35 @@ protected void setupHandlers(ApplicationProtocolConfig serverApn, ApplicationPro SelfSignedCertificate ssc = new SelfSignedCertificate(); try { - setupHandlers(SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey(), null) - .sslProvider(sslServerProvider()) - .sslContextProvider(serverSslContextProvider()) - .ciphers(null, IdentityCipherSuiteFilter.INSTANCE) - .applicationProtocolConfig(serverApn) - .sessionCacheSize(0) - .sessionTimeout(0) - .build(), - - SslContextBuilder.forClient() - .sslProvider(sslClientProvider()) - .sslContextProvider(clientSslContextProvider()) - .applicationProtocolConfig(clientApn) - .trustManager(InsecureTrustManagerFactory.INSTANCE) - .ciphers(null, IdentityCipherSuiteFilter.INSTANCE) - .sessionCacheSize(0) - .sessionTimeout(0) - .build()); + SslContextBuilder serverCtxBuilder = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey(), null) + .sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()) + .ciphers(null, IdentityCipherSuiteFilter.INSTANCE) + .applicationProtocolConfig(serverApn) + .sessionCacheSize(0) + .sessionTimeout(0); + if (serverApn.protocol() == Protocol.NPN || serverApn.protocol() == Protocol.NPN_AND_ALPN) { + // NPN is not really well supported with TLSv1.3 so force to use TLSv1.2 + // See https://github.com/openssl/openssl/issues/3665 + serverCtxBuilder.protocols(PROTOCOL_TLS_V1_2); + } + + SslContextBuilder clientCtxBuilder = SslContextBuilder.forClient() + .sslProvider(sslClientProvider()) + .sslContextProvider(clientSslContextProvider()) + .applicationProtocolConfig(clientApn) + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .ciphers(null, IdentityCipherSuiteFilter.INSTANCE) + .sessionCacheSize(0) + .sessionTimeout(0); + + if (clientApn.protocol() == Protocol.NPN || clientApn.protocol() == Protocol.NPN_AND_ALPN) { + // NPN is not really well supported with TLSv1.3 so force to use TLSv1.2 + // See https://github.com/openssl/openssl/issues/3665 + clientCtxBuilder.protocols(PROTOCOL_TLS_V1_2); + } + + setupHandlers(serverCtxBuilder.build(), clientCtxBuilder.build()); } finally { ssc.delete(); } @@ -1453,7 +1632,12 @@ protected void initChannel(Channel ch) throws Exception { ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), type)); ChannelPipeline p = ch.pipeline(); - p.addLast(serverSslCtx.newHandler(ch.alloc())); + + SslHandler sslHandler = delegatingExecutor == null ? + serverSslCtx.newHandler(ch.alloc()) : + serverSslCtx.newHandler(ch.alloc(), delegatingExecutor); + + p.addLast(sslHandler); p.addLast(new MessageDelegatorChannelHandler(serverReceiver, serverLatch)); p.addLast(new ChannelInboundHandlerAdapter() { @Override @@ -1478,7 +1662,12 @@ protected void initChannel(Channel ch) throws Exception { ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), type)); ChannelPipeline p = ch.pipeline(); - p.addLast(clientSslCtx.newHandler(ch.alloc())); + + SslHandler sslHandler = delegatingExecutor == null ? + clientSslCtx.newHandler(ch.alloc()) : + clientSslCtx.newHandler(ch.alloc(), delegatingExecutor); + + p.addLast(sslHandler); p.addLast(new MessageDelegatorChannelHandler(clientReceiver, clientLatch)); p.addLast(new ChannelInboundHandlerAdapter() { @Override @@ -1490,6 +1679,11 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E ctx.fireExceptionCaught(cause); } } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + clientLatch.countDown(); + } }); } }); @@ -1503,12 +1697,15 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E @Test(timeout = 30000) public void testMutualAuthSameCertChain() throws Exception { - serverSslCtx = SslContextBuilder.forServer( - new ByteArrayInputStream(X509_CERT_PEM.getBytes(CharsetUtil.UTF_8)), - new ByteArrayInputStream(PRIVATE_KEY_PEM.getBytes(CharsetUtil.UTF_8))) - .trustManager(new ByteArrayInputStream(X509_CERT_PEM.getBytes(CharsetUtil.UTF_8))) - .clientAuth(ClientAuth.REQUIRE).sslProvider(sslServerProvider()) - .sslContextProvider(serverSslContextProvider()).build(); + serverSslCtx = + SslContextBuilder.forServer( + new ByteArrayInputStream(X509_CERT_PEM.getBytes(CharsetUtil.UTF_8)), + new ByteArrayInputStream(PRIVATE_KEY_PEM.getBytes(CharsetUtil.UTF_8))) + .trustManager(new ByteArrayInputStream(X509_CERT_PEM.getBytes(CharsetUtil.UTF_8))) + .clientAuth(ClientAuth.REQUIRE).sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()) + .protocols(protocols()) + .ciphers(ciphers()).build(); sb = new ServerBootstrap(); sb.group(new NioEventLoopGroup(), new NioEventLoopGroup()); @@ -1520,7 +1717,11 @@ public void testMutualAuthSameCertChain() throws Exception { protected void initChannel(Channel ch) throws Exception { ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), type)); - ch.pipeline().addFirst(serverSslCtx.newHandler(ch.alloc())); + SslHandler sslHandler = delegatingExecutor == null ? + serverSslCtx.newHandler(ch.alloc()) : + serverSslCtx.newHandler(ch.alloc(), delegatingExecutor); + + ch.pipeline().addFirst(sslHandler); ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { @@ -1559,13 +1760,14 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc } }).bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); - clientSslCtx = SslContextBuilder.forClient() - .keyManager( + clientSslCtx = + SslContextBuilder.forClient().keyManager( new ByteArrayInputStream(CLIENT_X509_CERT_CHAIN_PEM.getBytes(CharsetUtil.UTF_8)), new ByteArrayInputStream(CLIENT_PRIVATE_KEY_PEM.getBytes(CharsetUtil.UTF_8))) .trustManager(new ByteArrayInputStream(X509_CERT_PEM.getBytes(CharsetUtil.UTF_8))) .sslProvider(sslClientProvider()) - .sslContextProvider(clientSslContextProvider()).build(); + .sslContextProvider(clientSslContextProvider()) + .protocols(protocols()).ciphers(ciphers()).build(); cb = new Bootstrap(); cb.group(new NioEventLoopGroup()); cb.channel(NioSocketChannel.class); @@ -1573,7 +1775,7 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc @Override protected void initChannel(Channel ch) throws Exception { ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), type)); - ch.pipeline().addLast(new SslHandler(clientSslCtx.newEngine(ch.alloc()))); + ch.pipeline().addLast(new SslHandler(wrapEngine(clientSslCtx.newEngine(ch.alloc())))); } }).connect(serverChannel.localAddress()).syncUninterruptibly().channel(); @@ -1589,14 +1791,18 @@ public void testUnwrapBehavior() throws Exception { .forClient() .trustManager(cert.cert()) .sslProvider(sslClientProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .build(); - SSLEngine client = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); serverSslCtx = SslContextBuilder .forServer(cert.certificate(), cert.privateKey()) .sslProvider(sslServerProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .build(); - SSLEngine server = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + SSLEngine server = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); byte[] bytes = "Hello World".getBytes(CharsetUtil.US_ASCII); @@ -1681,14 +1887,14 @@ private void testProtocol(String[] clientProtocols, String[] serverProtocols) th .sslProvider(sslClientProvider()) .protocols(clientProtocols) .build(); - SSLEngine client = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); serverSslCtx = SslContextBuilder .forServer(cert.certificate(), cert.privateKey()) .sslProvider(sslServerProvider()) .protocols(serverProtocols) .build(); - SSLEngine server = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + SSLEngine server = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); try { handshake(client, server); @@ -1720,8 +1926,8 @@ public void testHandshakeCompletesWithNonContiguousProtocolsTLSv1_2CipherOnly() SSLEngine clientEngine = null; SSLEngine serverEngine = null; try { - clientEngine = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); - serverEngine = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); handshake(clientEngine, serverEngine); } finally { cleanupClientSslEngine(clientEngine); @@ -1751,8 +1957,8 @@ public void testHandshakeCompletesWithoutFilteringSupportedCipher() throws Excep SSLEngine clientEngine = null; SSLEngine serverEngine = null; try { - clientEngine = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); - serverEngine = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); handshake(clientEngine, serverEngine); } finally { cleanupClientSslEngine(clientEngine); @@ -1769,14 +1975,18 @@ public void testPacketBufferSizeLimit() throws Exception { .forClient() .trustManager(cert.cert()) .sslProvider(sslClientProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .build(); - SSLEngine client = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); serverSslCtx = SslContextBuilder .forServer(cert.certificate(), cert.privateKey()) .sslProvider(sslServerProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .build(); - SSLEngine server = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + SSLEngine server = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); try { // Allocate an buffer that is bigger then the max plain record size. @@ -1808,8 +2018,10 @@ public void testSSLEngineUnwrapNoSslRecord() throws Exception { clientSslCtx = SslContextBuilder .forClient() .sslProvider(sslClientProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .build(); - SSLEngine client = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); try { ByteBuffer src = allocateBuffer(client.getSession().getApplicationBufferSize()); @@ -1836,8 +2048,10 @@ public void testBeginHandshakeAfterEngineClosed() throws SSLException { clientSslCtx = SslContextBuilder .forClient() .sslProvider(sslClientProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .build(); - SSLEngine client = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); try { client.closeInbound(); @@ -1860,14 +2074,18 @@ public void testBeginHandshakeCloseOutbound() throws Exception { clientSslCtx = SslContextBuilder .forClient() .sslProvider(sslClientProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .build(); - SSLEngine client = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); serverSslCtx = SslContextBuilder .forServer(cert.certificate(), cert.privateKey()) .sslProvider(sslServerProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .build(); - SSLEngine server = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + SSLEngine server = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); try { testBeginHandshakeCloseOutbound(client); @@ -1907,14 +2125,18 @@ public void testCloseInboundAfterBeginHandshake() throws Exception { clientSslCtx = SslContextBuilder .forClient() .sslProvider(sslClientProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .build(); - SSLEngine client = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); serverSslCtx = SslContextBuilder .forServer(cert.certificate(), cert.privateKey()) .sslProvider(sslServerProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .build(); - SSLEngine server = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + SSLEngine server = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); try { testCloseInboundAfterBeginHandshake(client); @@ -1944,14 +2166,18 @@ public void testCloseNotifySequence() throws Exception { .forClient() .trustManager(cert.cert()) .sslProvider(sslClientProvider()) + // This test only works for non TLSv1.3 for now + .protocols(PROTOCOL_TLS_V1_2) .build(); - SSLEngine client = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); serverSslCtx = SslContextBuilder .forServer(cert.certificate(), cert.privateKey()) .sslProvider(sslServerProvider()) + // This test only works for non TLSv1.3 for now + .protocols(PROTOCOL_TLS_V1_2) .build(); - SSLEngine server = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + SSLEngine server = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); try { ByteBuffer plainClientOut = allocateBuffer(client.getSession().getApplicationBufferSize()); @@ -1976,7 +2202,14 @@ public void testCloseNotifySequence() throws Exception { assertEquals(SSLEngineResult.Status.CLOSED, result.getStatus()); // Need an UNWRAP to read the response of the close_notify - assertEquals(SSLEngineResult.HandshakeStatus.NEED_UNWRAP, result.getHandshakeStatus()); + if (PlatformDependent.javaVersion() >= 12 && sslClientProvider() == SslProvider.JDK) { + // This is a workaround for a possible JDK12+ bug. + // + // See http://mail.openjdk.java.net/pipermail/security-dev/2019-February/019406.html. + assertEquals(SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING, result.getHandshakeStatus()); + } else { + assertEquals(SSLEngineResult.HandshakeStatus.NEED_UNWRAP, result.getHandshakeStatus()); + } int produced = result.bytesProduced(); int consumed = result.bytesConsumed(); @@ -2011,6 +2244,7 @@ public void testCloseNotifySequence() throws Exception { result = server.wrap(empty, encryptedServerToClient); encryptedServerToClient.flip(); + assertEquals(SSLEngineResult.Status.CLOSED, result.getStatus()); // UNWRAP/WRAP are not expected after this point assertEquals(SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING, result.getHandshakeStatus()); @@ -2025,6 +2259,7 @@ public void testCloseNotifySequence() throws Exception { assertTrue(server.isInboundDone()); result = client.unwrap(encryptedServerToClient, plainClientOut); + plainClientOut.flip(); assertEquals(SSLEngineResult.Status.CLOSED, result.getStatus()); // UNWRAP/WRAP are not expected after this point @@ -2085,14 +2320,18 @@ public void testWrapAfterCloseOutbound() throws Exception { .forClient() .trustManager(cert.cert()) .sslProvider(sslClientProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .build(); - SSLEngine client = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); serverSslCtx = SslContextBuilder .forServer(cert.certificate(), cert.privateKey()) .sslProvider(sslServerProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .build(); - SSLEngine server = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + SSLEngine server = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); try { ByteBuffer dst = allocateBuffer(client.getSession().getPacketBufferSize()); @@ -2124,14 +2363,18 @@ public void testMultipleRecordsInOneBufferWithNonZeroPosition() throws Exception .forClient() .trustManager(cert.cert()) .sslProvider(sslClientProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .build(); - SSLEngine client = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); serverSslCtx = SslContextBuilder .forServer(cert.certificate(), cert.privateKey()) .sslProvider(sslServerProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .build(); - SSLEngine server = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + SSLEngine server = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); try { // Choose buffer size small enough that we can put multiple buffers into one buffer and pass it into the @@ -2199,14 +2442,18 @@ public void testMultipleRecordsInOneBufferBiggerThenPacketBufferSize() throws Ex .forClient() .trustManager(cert.cert()) .sslProvider(sslClientProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .build(); - SSLEngine client = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); serverSslCtx = SslContextBuilder .forServer(cert.certificate(), cert.privateKey()) .sslProvider(sslServerProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .build(); - SSLEngine server = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + SSLEngine server = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); try { ByteBuffer plainClientOut = allocateBuffer(4096); @@ -2219,21 +2466,35 @@ public void testMultipleRecordsInOneBufferBiggerThenPacketBufferSize() throws Ex int srcLen = plainClientOut.remaining(); SSLEngineResult result; - while (encClientToServer.position() <= server.getSession().getPacketBufferSize()) { + int count = 0; + do { + int plainClientOutPosition = plainClientOut.position(); + int encClientToServerPosition = encClientToServer.position(); result = client.wrap(plainClientOut, encClientToServer); + if (result.getStatus() == Status.BUFFER_OVERFLOW) { + // We did not have enough room to wrap + assertEquals(plainClientOutPosition, plainClientOut.position()); + assertEquals(encClientToServerPosition, encClientToServer.position()); + break; + } assertEquals(SSLEngineResult.Status.OK, result.getStatus()); assertEquals(srcLen, result.bytesConsumed()); assertTrue(result.bytesProduced() > 0); plainClientOut.clear(); - } + ++count; + } while (encClientToServer.position() < server.getSession().getPacketBufferSize()); + + // Check that we were able to wrap multiple times. + assertTrue(count >= 2); encClientToServer.flip(); result = server.unwrap(encClientToServer, plainServerOut); assertEquals(SSLEngineResult.Status.OK, result.getStatus()); assertTrue(result.bytesConsumed() > 0); assertTrue(result.bytesProduced() > 0); + assertTrue(encClientToServer.hasRemaining()); } finally { cert.delete(); cleanupClientSslEngine(client); @@ -2249,14 +2510,18 @@ public void testBufferUnderFlow() throws Exception { .forClient() .trustManager(cert.cert()) .sslProvider(sslClientProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .build(); - SSLEngine client = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); serverSslCtx = SslContextBuilder .forServer(cert.certificate(), cert.privateKey()) .sslProvider(sslServerProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .build(); - SSLEngine server = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + SSLEngine server = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); try { ByteBuffer plainClient = allocateBuffer(1024); @@ -2320,14 +2585,18 @@ public void testWrapDoesNotZeroOutSrc() throws Exception { .forClient() .trustManager(cert.cert()) .sslProvider(sslClientProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .build(); - SSLEngine client = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); serverSslCtx = SslContextBuilder .forServer(cert.certificate(), cert.privateKey()) .sslProvider(sslServerProvider()) + .protocols(protocols()) + .ciphers(ciphers()) .build(); - SSLEngine server = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + SSLEngine server = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); try { ByteBuffer plainServerOut = allocateBuffer(server.getSession().getApplicationBufferSize() / 2); @@ -2355,4 +2624,590 @@ public void testWrapDoesNotZeroOutSrc() throws Exception { cert.delete(); } } + + @Test + public void testDisableProtocols() throws Exception { + testDisableProtocols(PROTOCOL_SSL_V2, PROTOCOL_SSL_V2); + testDisableProtocols(PROTOCOL_SSL_V3, PROTOCOL_SSL_V2, PROTOCOL_SSL_V3); + testDisableProtocols(PROTOCOL_TLS_V1, PROTOCOL_SSL_V2, PROTOCOL_SSL_V3, PROTOCOL_TLS_V1); + testDisableProtocols(PROTOCOL_TLS_V1_1, PROTOCOL_SSL_V2, PROTOCOL_SSL_V3, PROTOCOL_TLS_V1, PROTOCOL_TLS_V1_1); + testDisableProtocols(PROTOCOL_TLS_V1_2, PROTOCOL_SSL_V2, + PROTOCOL_SSL_V3, PROTOCOL_TLS_V1, PROTOCOL_TLS_V1_1, PROTOCOL_TLS_V1_2); + } + + private void testDisableProtocols(String protocol, String... disabledProtocols) throws Exception { + SelfSignedCertificate cert = new SelfSignedCertificate(); + + SslContext ctx = SslContextBuilder + .forServer(cert.certificate(), cert.privateKey()) + .sslProvider(sslServerProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .build(); + SSLEngine server = wrapEngine(ctx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + try { + Set supported = new HashSet(Arrays.asList(server.getSupportedProtocols())); + if (supported.contains(protocol)) { + server.setEnabledProtocols(server.getSupportedProtocols()); + assertEquals(supported, new HashSet(Arrays.asList(server.getSupportedProtocols()))); + + for (String disabled : disabledProtocols) { + supported.remove(disabled); + } + if (supported.contains(PROTOCOL_SSL_V2_HELLO) && supported.size() == 1) { + // It's not allowed to set only PROTOCOL_SSL_V2_HELLO if using JDK SSLEngine. + return; + } + server.setEnabledProtocols(supported.toArray(new String[0])); + assertEquals(supported, new HashSet(Arrays.asList(server.getEnabledProtocols()))); + server.setEnabledProtocols(server.getSupportedProtocols()); + } + } finally { + cleanupServerSslEngine(server); + cleanupClientSslContext(ctx); + cert.delete(); + } + } + + @Test + public void testUsingX509TrustManagerVerifiesHostname() throws Exception { + SslProvider clientProvider = sslClientProvider(); + if (clientProvider == SslProvider.OPENSSL || clientProvider == SslProvider.OPENSSL_REFCNT) { + // Need to check if we support hostname validation in the current used OpenSSL version before running + // the test. + Assume.assumeTrue(OpenSsl.supportsHostnameValidation()); + } + SelfSignedCertificate cert = new SelfSignedCertificate(); + clientSslCtx = SslContextBuilder + .forClient() + .trustManager(new TrustManagerFactory(new TrustManagerFactorySpi() { + @Override + protected void engineInit(KeyStore keyStore) { + // NOOP + } + @Override + protected TrustManager[] engineGetTrustManagers() { + // Provide a custom trust manager, this manager trust all certificates + return new TrustManager[] { + new X509TrustManager() { + @Override + public void checkClientTrusted( + java.security.cert.X509Certificate[] x509Certificates, String s) { + // NOOP + } + + @Override + public void checkServerTrusted( + java.security.cert.X509Certificate[] x509Certificates, String s) { + // NOOP + } + + @Override + public java.security.cert.X509Certificate[] getAcceptedIssuers() { + return EmptyArrays.EMPTY_X509_CERTIFICATES; + } + } + }; + } + + @Override + protected void engineInit(ManagerFactoryParameters managerFactoryParameters) { + } + }, null, TrustManagerFactory.getDefaultAlgorithm()) { + }) + .sslProvider(sslClientProvider()) + .build(); + + SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT, "netty.io", 1234)); + SSLParameters sslParameters = client.getSSLParameters(); + sslParameters.setEndpointIdentificationAlgorithm("HTTPS"); + client.setSSLParameters(sslParameters); + + serverSslCtx = SslContextBuilder + .forServer(cert.certificate(), cert.privateKey()) + .sslProvider(sslServerProvider()) + .build(); + + SSLEngine server = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + try { + handshake(client, server); + fail(); + } catch (SSLException expected) { + // expected as the hostname not matches. + } finally { + cleanupClientSslEngine(client); + cleanupServerSslEngine(server); + cert.delete(); + } + } + + @Test + public void testInvalidCipher() throws Exception { + SelfSignedCertificate cert = new SelfSignedCertificate(); + List cipherList = new ArrayList(); + Collections.addAll(cipherList, ((SSLSocketFactory) SSLSocketFactory.getDefault()).getDefaultCipherSuites()); + cipherList.add("InvalidCipher"); + SSLEngine server = null; + try { + serverSslCtx = SslContextBuilder.forServer(cert.key(), cert.cert()).sslProvider(sslClientProvider()) + .ciphers(cipherList).build(); + server = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + fail(); + } catch (IllegalArgumentException expected) { + // expected when invalid cipher is used. + } catch (SSLException expected) { + // expected when invalid cipher is used. + } finally { + cert.delete(); + cleanupServerSslEngine(server); + } + } + + @Test + public void testGetCiphersuite() throws Exception { + clientSslCtx = SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .sslContextProvider(clientSslContextProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .build(); + SelfSignedCertificate ssc = new SelfSignedCertificate(); + serverSslCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .build(); + SSLEngine clientEngine = null; + SSLEngine serverEngine = null; + try { + clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + handshake(clientEngine, serverEngine); + + String clientCipher = clientEngine.getSession().getCipherSuite(); + String serverCipher = serverEngine.getSession().getCipherSuite(); + assertEquals(clientCipher, serverCipher); + + assertEquals(protocolCipherCombo.cipher, clientCipher); + } finally { + cleanupClientSslEngine(clientEngine); + cleanupServerSslEngine(serverEngine); + ssc.delete(); + } + } + + @Test + public void testSessionBindingEvent() throws Exception { + clientSslCtx = SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .sslContextProvider(clientSslContextProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .build(); + SelfSignedCertificate ssc = new SelfSignedCertificate(); + serverSslCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .build(); + SSLEngine clientEngine = null; + SSLEngine serverEngine = null; + try { + clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + handshake(clientEngine, serverEngine); + SSLSession session = clientEngine.getSession(); + assertEquals(0, session.getValueNames().length); + + class SSLSessionBindingEventValue implements SSLSessionBindingListener { + SSLSessionBindingEvent boundEvent; + SSLSessionBindingEvent unboundEvent; + + @Override + public void valueBound(SSLSessionBindingEvent sslSessionBindingEvent) { + assertNull(boundEvent); + boundEvent = sslSessionBindingEvent; + } + + @Override + public void valueUnbound(SSLSessionBindingEvent sslSessionBindingEvent) { + assertNull(unboundEvent); + unboundEvent = sslSessionBindingEvent; + } + } + + String name = "name"; + String name2 = "name2"; + + SSLSessionBindingEventValue value1 = new SSLSessionBindingEventValue(); + session.putValue(name, value1); + assertSSLSessionBindingEventValue(name, session, value1.boundEvent); + assertNull(value1.unboundEvent); + assertEquals(1, session.getValueNames().length); + + session.putValue(name2, "value"); + + SSLSessionBindingEventValue value2 = new SSLSessionBindingEventValue(); + session.putValue(name, value2); + assertEquals(2, session.getValueNames().length); + + assertSSLSessionBindingEventValue(name, session, value1.unboundEvent); + assertSSLSessionBindingEventValue(name, session, value2.boundEvent); + assertNull(value2.unboundEvent); + assertEquals(2, session.getValueNames().length); + + session.removeValue(name); + assertSSLSessionBindingEventValue(name, session, value2.unboundEvent); + assertEquals(1, session.getValueNames().length); + session.removeValue(name2); + } finally { + cleanupClientSslEngine(clientEngine); + cleanupServerSslEngine(serverEngine); + ssc.delete(); + } + } + + private static void assertSSLSessionBindingEventValue( + String name, SSLSession session, SSLSessionBindingEvent event) { + assertEquals(name, event.getName()); + assertEquals(session, event.getSession()); + assertEquals(session, event.getSource()); + } + + @Test + public void testSessionAfterHandshake() throws Exception { + testSessionAfterHandshake0(false, false); + } + + @Test + public void testSessionAfterHandshakeMutualAuth() throws Exception { + testSessionAfterHandshake0(false, true); + } + + @Test + public void testSessionAfterHandshakeKeyManagerFactory() throws Exception { + testSessionAfterHandshake0(true, false); + } + + @Test + public void testSessionAfterHandshakeKeyManagerFactoryMutualAuth() throws Exception { + testSessionAfterHandshake0(true, true); + } + + private void testSessionAfterHandshake0(boolean useKeyManagerFactory, boolean mutualAuth) throws Exception { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + KeyManagerFactory kmf = useKeyManagerFactory ? + SslContext.buildKeyManagerFactory( + new java.security.cert.X509Certificate[] { ssc.cert()}, ssc.key(), null, null) : null; + + SslContextBuilder clientContextBuilder = SslContextBuilder.forClient(); + if (mutualAuth) { + if (kmf != null) { + clientContextBuilder.keyManager(kmf); + } else { + clientContextBuilder.keyManager(ssc.key(), ssc.cert()); + } + } + clientSslCtx = clientContextBuilder + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .sslContextProvider(clientSslContextProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .build(); + + SslContextBuilder serverContextBuilder = kmf != null ? + SslContextBuilder.forServer(kmf) : + SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()); + if (mutualAuth) { + serverContextBuilder.clientAuth(ClientAuth.REQUIRE); + } + serverSslCtx = serverContextBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .build(); + SSLEngine clientEngine = null; + SSLEngine serverEngine = null; + try { + clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + handshake(clientEngine, serverEngine); + + SSLSession clientSession = clientEngine.getSession(); + SSLSession serverSession = serverEngine.getSession(); + + assertNull(clientSession.getPeerHost()); + assertNull(serverSession.getPeerHost()); + assertEquals(-1, clientSession.getPeerPort()); + assertEquals(-1, serverSession.getPeerPort()); + + assertTrue(clientSession.getCreationTime() > 0); + assertTrue(serverSession.getCreationTime() > 0); + + assertTrue(clientSession.getLastAccessedTime() > 0); + assertTrue(serverSession.getLastAccessedTime() > 0); + + assertEquals(protocolCipherCombo.protocol, clientSession.getProtocol()); + assertEquals(protocolCipherCombo.protocol, serverSession.getProtocol()); + + assertEquals(protocolCipherCombo.cipher, clientSession.getCipherSuite()); + assertEquals(protocolCipherCombo.cipher, serverSession.getCipherSuite()); + + assertNotNull(clientSession.getId()); + assertNotNull(serverSession.getId()); + + assertTrue(clientSession.getApplicationBufferSize() > 0); + assertTrue(serverSession.getApplicationBufferSize() > 0); + + assertTrue(clientSession.getPacketBufferSize() > 0); + assertTrue(serverSession.getPacketBufferSize() > 0); + + assertNotNull(clientSession.getSessionContext()); + assertNotNull(serverSession.getSessionContext()); + + Object value = new Object(); + + assertEquals(0, clientSession.getValueNames().length); + clientSession.putValue("test", value); + assertEquals("test", clientSession.getValueNames()[0]); + assertSame(value, clientSession.getValue("test")); + clientSession.removeValue("test"); + assertEquals(0, clientSession.getValueNames().length); + + assertEquals(0, serverSession.getValueNames().length); + serverSession.putValue("test", value); + assertEquals("test", serverSession.getValueNames()[0]); + assertSame(value, serverSession.getValue("test")); + serverSession.removeValue("test"); + assertEquals(0, serverSession.getValueNames().length); + + Certificate[] serverLocalCertificates = serverSession.getLocalCertificates(); + assertEquals(1, serverLocalCertificates.length); + assertArrayEquals(ssc.cert().getEncoded(), serverLocalCertificates[0].getEncoded()); + + Principal serverLocalPrincipal = serverSession.getLocalPrincipal(); + assertNotNull(serverLocalPrincipal); + + if (mutualAuth) { + Certificate[] clientLocalCertificates = clientSession.getLocalCertificates(); + assertEquals(1, clientLocalCertificates.length); + + Certificate[] serverPeerCertificates = serverSession.getPeerCertificates(); + assertEquals(1, serverPeerCertificates.length); + assertArrayEquals(clientLocalCertificates[0].getEncoded(), serverPeerCertificates[0].getEncoded()); + + X509Certificate[] serverPeerX509Certificates = serverSession.getPeerCertificateChain(); + assertEquals(1, serverPeerX509Certificates.length); + assertArrayEquals(clientLocalCertificates[0].getEncoded(), serverPeerX509Certificates[0].getEncoded()); + + Principal clientLocalPrincipial = clientSession.getLocalPrincipal(); + assertNotNull(clientLocalPrincipial); + + Principal serverPeerPrincipal = serverSession.getPeerPrincipal(); + assertEquals(clientLocalPrincipial, serverPeerPrincipal); + } else { + assertNull(clientSession.getLocalCertificates()); + assertNull(clientSession.getLocalPrincipal()); + + try { + serverSession.getPeerCertificates(); + fail(); + } catch (SSLPeerUnverifiedException expected) { + // As we did not use mutual auth this is expected + } + + try { + serverSession.getPeerCertificateChain(); + fail(); + } catch (SSLPeerUnverifiedException expected) { + // As we did not use mutual auth this is expected + } + + try { + serverSession.getPeerPrincipal(); + fail(); + } catch (SSLPeerUnverifiedException expected) { + // As we did not use mutual auth this is expected + } + } + + Certificate[] clientPeerCertificates = clientSession.getPeerCertificates(); + assertEquals(1, clientPeerCertificates.length); + assertArrayEquals(serverLocalCertificates[0].getEncoded(), clientPeerCertificates[0].getEncoded()); + + X509Certificate[] clientPeerX509Certificates = clientSession.getPeerCertificateChain(); + assertEquals(1, clientPeerX509Certificates.length); + assertArrayEquals(serverLocalCertificates[0].getEncoded(), clientPeerX509Certificates[0].getEncoded()); + + Principal clientPeerPrincipal = clientSession.getPeerPrincipal(); + assertEquals(serverLocalPrincipal, clientPeerPrincipal); + } finally { + cleanupClientSslEngine(clientEngine); + cleanupServerSslEngine(serverEngine); + ssc.delete(); + } + } + + @Test + public void testHandshakeSession() throws Exception { + final SelfSignedCertificate ssc = new SelfSignedCertificate(); + + final TestTrustManagerFactory clientTmf = new TestTrustManagerFactory(ssc.cert()); + final TestTrustManagerFactory serverTmf = new TestTrustManagerFactory(ssc.cert()); + + clientSslCtx = SslContextBuilder.forClient() + .trustManager(new SimpleTrustManagerFactory() { + @Override + protected void engineInit(KeyStore keyStore) { + // NOOP + } + + @Override + protected void engineInit(ManagerFactoryParameters managerFactoryParameters) { + // NOOP + } + + @Override + protected TrustManager[] engineGetTrustManagers() { + return new TrustManager[] { clientTmf }; + } + }) + .keyManager(newKeyManagerFactory(ssc)) + .sslProvider(sslClientProvider()) + .sslContextProvider(clientSslContextProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .build(); + serverSslCtx = SslContextBuilder.forServer(newKeyManagerFactory(ssc)) + .trustManager(new SimpleTrustManagerFactory() { + @Override + protected void engineInit(KeyStore keyStore) { + // NOOP + } + + @Override + protected void engineInit(ManagerFactoryParameters managerFactoryParameters) { + // NOOP + } + + @Override + protected TrustManager[] engineGetTrustManagers() { + return new TrustManager[] { serverTmf }; + } + }) + .sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .clientAuth(ClientAuth.REQUIRE) + .build(); + SSLEngine clientEngine = null; + SSLEngine serverEngine = null; + try { + clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + handshake(clientEngine, serverEngine); + + assertTrue(clientTmf.isVerified()); + assertTrue(serverTmf.isVerified()); + } finally { + cleanupClientSslEngine(clientEngine); + cleanupServerSslEngine(serverEngine); + ssc.delete(); + } + } + + private KeyManagerFactory newKeyManagerFactory(SelfSignedCertificate ssc) + throws UnrecoverableKeyException, KeyStoreException, NoSuchAlgorithmException, + CertificateException, IOException { + return SslContext.buildKeyManagerFactory( + new java.security.cert.X509Certificate[] { ssc.cert() }, ssc.key(), null, null); + } + + private final class TestTrustManagerFactory extends X509ExtendedTrustManager { + private final Certificate localCert; + private volatile boolean verified; + + TestTrustManagerFactory(Certificate localCert) { + this.localCert = localCert; + } + + boolean isVerified() { + return verified; + } + + @Override + public void checkClientTrusted( + java.security.cert.X509Certificate[] x509Certificates, String s, Socket socket) { + fail(); + } + + @Override + public void checkServerTrusted( + java.security.cert.X509Certificate[] x509Certificates, String s, Socket socket) { + fail(); + } + + @Override + public void checkClientTrusted( + java.security.cert.X509Certificate[] x509Certificates, String s, SSLEngine sslEngine) { + verified = true; + assertFalse(sslEngine.getUseClientMode()); + SSLSession session = sslEngine.getHandshakeSession(); + assertNotNull(session); + Certificate[] localCertificates = session.getLocalCertificates(); + assertNotNull(localCertificates); + assertEquals(1, localCertificates.length); + assertEquals(localCert, localCertificates[0]); + assertNotNull(session.getLocalPrincipal()); + } + + @Override + public void checkServerTrusted( + java.security.cert.X509Certificate[] x509Certificates, String s, SSLEngine sslEngine) { + verified = true; + assertTrue(sslEngine.getUseClientMode()); + SSLSession session = sslEngine.getHandshakeSession(); + assertNotNull(session); + assertNull(session.getLocalCertificates()); + assertNull(session.getLocalPrincipal()); + } + + @Override + public void checkClientTrusted( + java.security.cert.X509Certificate[] x509Certificates, String s) { + fail(); + } + + @Override + public void checkServerTrusted( + java.security.cert.X509Certificate[] x509Certificates, String s) { + fail(); + } + + @Override + public java.security.cert.X509Certificate[] getAcceptedIssuers() { + return EmptyArrays.EMPTY_X509_CERTIFICATES; + } + } + + protected SSLEngine wrapEngine(SSLEngine engine) { + return engine; + } + + protected List ciphers() { + return Collections.singletonList(protocolCipherCombo.cipher); + } + + protected String[] protocols() { + return new String[] { protocolCipherCombo.protocol }; + } } diff --git a/handler/src/test/java/io/netty/handler/ssl/SignatureAlgorithmConverterTest.java b/handler/src/test/java/io/netty/handler/ssl/SignatureAlgorithmConverterTest.java new file mode 100644 index 000000000000..0b8ce3a03e6b --- /dev/null +++ b/handler/src/test/java/io/netty/handler/ssl/SignatureAlgorithmConverterTest.java @@ -0,0 +1,59 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +public class SignatureAlgorithmConverterTest { + + @Test + public void testWithEncryption() { + assertEquals("SHA512withRSA", SignatureAlgorithmConverter.toJavaName("sha512WithRSAEncryption")); + } + + @Test + public void testWithDash() { + assertEquals("SHA256withECDSA", SignatureAlgorithmConverter.toJavaName("ecdsa-with-SHA256")); + } + + @Test + public void testWithUnderscore() { + assertEquals("SHA256withDSA", SignatureAlgorithmConverter.toJavaName("dsa_with_SHA256")); + } + + @Test + public void testBoringSSLOneUnderscore() { + assertEquals("SHA256withECDSA", SignatureAlgorithmConverter.toJavaName("ecdsa_sha256")); + } + + @Test + public void testBoringSSLPkcs1() { + assertEquals("SHA256withRSA", SignatureAlgorithmConverter.toJavaName("rsa_pkcs1_sha256")); + } + + @Test + public void testBoringSSLPSS() { + assertEquals("SHA256withRSA", SignatureAlgorithmConverter.toJavaName("rsa_pss_rsae_sha256")); + } + + @Test + public void testInvalid() { + assertNull(SignatureAlgorithmConverter.toJavaName("ThisIsSomethingInvalid")); + } +} diff --git a/handler/src/test/java/io/netty/handler/ssl/SniClientJava8TestUtil.java b/handler/src/test/java/io/netty/handler/ssl/SniClientJava8TestUtil.java index 7ba413c7696f..4db7c7b73a3e 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SniClientJava8TestUtil.java +++ b/handler/src/test/java/io/netty/handler/ssl/SniClientJava8TestUtil.java @@ -29,14 +29,43 @@ import io.netty.channel.local.LocalServerChannel; import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import io.netty.handler.ssl.util.SelfSignedCertificate; +import io.netty.handler.ssl.util.SimpleTrustManagerFactory; +import io.netty.util.ReferenceCountUtil; import io.netty.util.concurrent.Promise; +import io.netty.util.internal.EmptyArrays; import io.netty.util.internal.ThrowableUtil; +import org.junit.Assert; +import javax.net.ssl.ExtendedSSLSession; +import javax.net.ssl.KeyManager; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.KeyManagerFactorySpi; +import javax.net.ssl.ManagerFactoryParameters; +import javax.net.ssl.SNIHostName; import javax.net.ssl.SNIMatcher; import javax.net.ssl.SNIServerName; +import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLException; import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLSession; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509ExtendedKeyManager; +import javax.net.ssl.X509ExtendedTrustManager; +import java.io.IOException; +import java.net.Socket; +import java.security.InvalidAlgorithmParameterException; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.Principal; +import java.security.PrivateKey; +import java.security.UnrecoverableKeyException; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.ArrayList; import java.util.Collections; +import java.util.List; /** * In extra class to be able to run tests with java7 without trying to load classes that not exists in java7. @@ -48,20 +77,25 @@ private SniClientJava8TestUtil() { } static void testSniClient(SslProvider sslClientProvider, SslProvider sslServerProvider, final boolean match) throws Exception { final String sniHost = "sni.netty.io"; + SelfSignedCertificate cert = new SelfSignedCertificate(); LocalAddress address = new LocalAddress("test"); EventLoopGroup group = new DefaultEventLoopGroup(1); + SslContext sslServerContext = null; + SslContext sslClientContext = null; + Channel sc = null; Channel cc = null; try { - SelfSignedCertificate cert = new SelfSignedCertificate(); - final SslContext sslServerContext = SslContextBuilder.forServer(cert.key(), cert.cert()) + sslServerContext = SslContextBuilder.forServer(cert.key(), cert.cert()) .sslProvider(sslServerProvider).build(); final Promise promise = group.next().newPromise(); ServerBootstrap sb = new ServerBootstrap(); + + final SslContext finalContext = sslServerContext; sc = sb.group(group).channel(LocalServerChannel.class).childHandler(new ChannelInitializer() { @Override protected void initChannel(Channel ch) throws Exception { - SslHandler handler = sslServerContext.newHandler(ch.alloc()); + SslHandler handler = finalContext.newHandler(ch.alloc()); SSLParameters parameters = handler.engine().getSSLParameters(); SNIMatcher matcher = new SNIMatcher(0) { @Override @@ -104,11 +138,11 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc } }).bind(address).syncUninterruptibly().channel(); - SslContext sslContext = SslContextBuilder.forClient().trustManager(InsecureTrustManagerFactory.INSTANCE) + sslClientContext = SslContextBuilder.forClient().trustManager(InsecureTrustManagerFactory.INSTANCE) .sslProvider(sslClientProvider).build(); SslHandler sslHandler = new SslHandler( - sslContext.newEngine(ByteBufAllocator.DEFAULT, sniHost, -1)); + sslClientContext.newEngine(ByteBufAllocator.DEFAULT, sniHost, -1)); Bootstrap cb = new Bootstrap(); cc = cb.group(group).channel(LocalChannel.class).handler(sslHandler) .connect(address).syncUninterruptibly().channel(); @@ -122,7 +156,190 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc if (sc != null) { sc.close().syncUninterruptibly(); } + + ReferenceCountUtil.release(sslServerContext); + ReferenceCountUtil.release(sslClientContext); + + cert.delete(); + group.shutdownGracefully(); } } + + static void assertSSLSession(boolean clientSide, SSLSession session, String name) { + assertSSLSession(clientSide, session, new SNIHostName(name)); + } + + private static void assertSSLSession(boolean clientSide, SSLSession session, SNIServerName name) { + Assert.assertNotNull(session); + if (session instanceof ExtendedSSLSession) { + ExtendedSSLSession extendedSSLSession = (ExtendedSSLSession) session; + List names = extendedSSLSession.getRequestedServerNames(); + Assert.assertEquals(1, names.size()); + Assert.assertEquals(name, names.get(0)); + Assert.assertTrue(extendedSSLSession.getLocalSupportedSignatureAlgorithms().length > 0); + if (clientSide) { + Assert.assertEquals(0, extendedSSLSession.getPeerSupportedSignatureAlgorithms().length); + } else { + Assert.assertTrue(extendedSSLSession.getPeerSupportedSignatureAlgorithms().length >= 0); + } + } + } + + static TrustManagerFactory newSniX509TrustmanagerFactory(String name) { + return new SniX509TrustmanagerFactory(new SNIHostName(name)); + } + + private static final class SniX509TrustmanagerFactory extends SimpleTrustManagerFactory { + + private final SNIServerName name; + + SniX509TrustmanagerFactory(SNIServerName name) { + this.name = name; + } + + @Override + protected void engineInit(KeyStore keyStore) throws Exception { + // NOOP + } + + @Override + protected void engineInit(ManagerFactoryParameters managerFactoryParameters) throws Exception { + // NOOP + } + + @Override + protected TrustManager[] engineGetTrustManagers() { + return new TrustManager[] { new X509ExtendedTrustManager() { + @Override + public void checkClientTrusted(X509Certificate[] x509Certificates, String s, Socket socket) + throws CertificateException { + Assert.fail(); + } + + @Override + public void checkServerTrusted(X509Certificate[] x509Certificates, String s, Socket socket) + throws CertificateException { + Assert.fail(); + } + + @Override + public void checkClientTrusted(X509Certificate[] x509Certificates, String s, SSLEngine sslEngine) + throws CertificateException { + Assert.fail(); + } + + @Override + public void checkServerTrusted(X509Certificate[] x509Certificates, String s, SSLEngine sslEngine) + throws CertificateException { + assertSSLSession(sslEngine.getUseClientMode(), sslEngine.getHandshakeSession(), name); + } + + @Override + public void checkClientTrusted(X509Certificate[] x509Certificates, String s) + throws CertificateException { + Assert.fail(); + } + + @Override + public void checkServerTrusted(X509Certificate[] x509Certificates, String s) + throws CertificateException { + Assert.fail(); + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return EmptyArrays.EMPTY_X509_CERTIFICATES; + } + } }; + } + } + + static KeyManagerFactory newSniX509KeyManagerFactory(SelfSignedCertificate cert, String hostname) + throws NoSuchAlgorithmException, KeyStoreException, UnrecoverableKeyException, + IOException, CertificateException { + return new SniX509KeyManagerFactory( + new SNIHostName(hostname), SslContext.buildKeyManagerFactory( + new X509Certificate[] { cert.cert() }, cert.key(), null, null)); + } + + private static final class SniX509KeyManagerFactory extends KeyManagerFactory { + + SniX509KeyManagerFactory(final SNIServerName name, final KeyManagerFactory factory) { + super(new KeyManagerFactorySpi() { + @Override + protected void engineInit(KeyStore keyStore, char[] chars) + throws KeyStoreException, NoSuchAlgorithmException, UnrecoverableKeyException { + factory.init(keyStore, chars); + } + + @Override + protected void engineInit(ManagerFactoryParameters managerFactoryParameters) + throws InvalidAlgorithmParameterException { + factory.init(managerFactoryParameters); + } + + @Override + protected KeyManager[] engineGetKeyManagers() { + List managers = new ArrayList(); + for (final KeyManager km: factory.getKeyManagers()) { + if (km instanceof X509ExtendedKeyManager) { + managers.add(new X509ExtendedKeyManager() { + @Override + public String[] getClientAliases(String s, Principal[] principals) { + return ((X509ExtendedKeyManager) km).getClientAliases(s, principals); + } + + @Override + public String chooseClientAlias(String[] strings, Principal[] principals, + Socket socket) { + return ((X509ExtendedKeyManager) km).chooseClientAlias(strings, principals, socket); + } + + @Override + public String[] getServerAliases(String s, Principal[] principals) { + return ((X509ExtendedKeyManager) km).getServerAliases(s, principals); + } + + @Override + public String chooseServerAlias(String s, Principal[] principals, Socket socket) { + return ((X509ExtendedKeyManager) km).chooseServerAlias(s, principals, socket); + } + + @Override + public X509Certificate[] getCertificateChain(String s) { + return ((X509ExtendedKeyManager) km).getCertificateChain(s); + } + + @Override + public PrivateKey getPrivateKey(String s) { + return ((X509ExtendedKeyManager) km).getPrivateKey(s); + } + + @Override + public String chooseEngineClientAlias(String[] strings, Principal[] principals, + SSLEngine sslEngine) { + return ((X509ExtendedKeyManager) km) + .chooseEngineClientAlias(strings, principals, sslEngine); + } + + @Override + public String chooseEngineServerAlias(String s, Principal[] principals, + SSLEngine sslEngine) { + + SSLSession session = sslEngine.getHandshakeSession(); + assertSSLSession(sslEngine.getUseClientMode(), session, name); + return ((X509ExtendedKeyManager) km) + .chooseEngineServerAlias(s, principals, sslEngine); + } + }); + } else { + managers.add(km); + } + } + return managers.toArray(new KeyManager[0]); + } + }, factory.getProvider(), factory.getAlgorithm()); + } + } } diff --git a/handler/src/test/java/io/netty/handler/ssl/SniClientTest.java b/handler/src/test/java/io/netty/handler/ssl/SniClientTest.java index ca1c9b8bf319..56ea815eb85b 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SniClientTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SniClientTest.java @@ -28,104 +28,100 @@ import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import io.netty.handler.ssl.util.SelfSignedCertificate; import io.netty.util.Mapping; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; import io.netty.util.concurrent.Promise; import io.netty.util.internal.PlatformDependent; import org.junit.Assert; import org.junit.Assume; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; +import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLException; - +import javax.net.ssl.TrustManagerFactory; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + +@RunWith(Parameterized.class) public class SniClientTest { - @Test(timeout = 30000) - public void testSniClientJdkSslServerJdkSsl() throws Exception { - testSniClient(SslProvider.JDK, SslProvider.JDK); - } - - @Test(timeout = 30000) - public void testSniClientOpenSslServerOpenSsl() throws Exception { - Assume.assumeTrue(OpenSsl.isAvailable()); - testSniClient(SslProvider.OPENSSL, SslProvider.OPENSSL); - } - - @Test(timeout = 30000) - public void testSniClientJdkSslServerOpenSsl() throws Exception { - Assume.assumeTrue(OpenSsl.isAvailable()); - testSniClient(SslProvider.JDK, SslProvider.OPENSSL); - } - - @Test(timeout = 30000) - public void testSniClientOpenSslServerJdkSsl() throws Exception { - Assume.assumeTrue(OpenSsl.isAvailable()); - testSniClient(SslProvider.OPENSSL, SslProvider.JDK); - } - - @Test(timeout = 30000) - public void testSniSNIMatcherMatchesClientJdkSslServerJdkSsl() throws Exception { - Assume.assumeTrue(PlatformDependent.javaVersion() >= 8); - SniClientJava8TestUtil.testSniClient(SslProvider.JDK, SslProvider.JDK, true); - } + @Parameters(name = "{index}: serverSslProvider = {0}, clientSslProvider = {1}") + public static Collection parameters() { + List providers = new ArrayList(Arrays.asList(SslProvider.values())); + if (!OpenSsl.isAvailable()) { + providers.remove(SslProvider.OPENSSL); + providers.remove(SslProvider.OPENSSL_REFCNT); + } - @Test(timeout = 30000, expected = SSLException.class) - public void testSniSNIMatcherDoesNotMatchClientJdkSslServerJdkSsl() throws Exception { - Assume.assumeTrue(PlatformDependent.javaVersion() >= 8); - SniClientJava8TestUtil.testSniClient(SslProvider.JDK, SslProvider.JDK, false); + List params = new ArrayList(); + for (SslProvider sp: providers) { + for (SslProvider cp: providers) { + params.add(new Object[] { sp, cp }); + } + } + return params; } - @Test(timeout = 30000) - public void testSniSNIMatcherMatchesClientOpenSslServerOpenSsl() throws Exception { - Assume.assumeTrue(PlatformDependent.javaVersion() >= 8); - Assume.assumeTrue(OpenSsl.isAvailable()); - SniClientJava8TestUtil.testSniClient(SslProvider.OPENSSL, SslProvider.OPENSSL, true); - } + private final SslProvider serverProvider; + private final SslProvider clientProvider; - @Test(timeout = 30000, expected = SSLException.class) - public void testSniSNIMatcherDoesNotMatchClientOpenSslServerOpenSsl() throws Exception { - Assume.assumeTrue(PlatformDependent.javaVersion() >= 8); - Assume.assumeTrue(OpenSsl.isAvailable()); - SniClientJava8TestUtil.testSniClient(SslProvider.OPENSSL, SslProvider.OPENSSL, false); + public SniClientTest(SslProvider serverProvider, SslProvider clientProvider) { + this.serverProvider = serverProvider; + this.clientProvider = clientProvider; } @Test(timeout = 30000) - public void testSniSNIMatcherMatchesClientJdkSslServerOpenSsl() throws Exception { - Assume.assumeTrue(PlatformDependent.javaVersion() >= 8); - Assume.assumeTrue(OpenSsl.isAvailable()); - SniClientJava8TestUtil.testSniClient(SslProvider.JDK, SslProvider.OPENSSL, true); - } - - @Test(timeout = 30000, expected = SSLException.class) - public void testSniSNIMatcherDoesNotMatchClientJdkSslServerOpenSsl() throws Exception { - Assume.assumeTrue(PlatformDependent.javaVersion() >= 8); - Assume.assumeTrue(OpenSsl.isAvailable()); - SniClientJava8TestUtil.testSniClient(SslProvider.JDK, SslProvider.OPENSSL, false); + public void testSniClient() throws Exception { + testSniClient(serverProvider, clientProvider); } @Test(timeout = 30000) - public void testSniSNIMatcherMatchesClientOpenSslServerJdkSsl() throws Exception { + public void testSniSNIMatcherMatchesClient() throws Exception { Assume.assumeTrue(PlatformDependent.javaVersion() >= 8); - Assume.assumeTrue(OpenSsl.isAvailable()); - SniClientJava8TestUtil.testSniClient(SslProvider.OPENSSL, SslProvider.JDK, true); + SniClientJava8TestUtil.testSniClient(serverProvider, clientProvider, true); } @Test(timeout = 30000, expected = SSLException.class) - public void testSniSNIMatcherDoesNotMatchClientOpenSslServerJdkSsl() throws Exception { + public void testSniSNIMatcherDoesNotMatchClient() throws Exception { Assume.assumeTrue(PlatformDependent.javaVersion() >= 8); - Assume.assumeTrue(OpenSsl.isAvailable()); - SniClientJava8TestUtil.testSniClient(SslProvider.OPENSSL, SslProvider.JDK, false); + SniClientJava8TestUtil.testSniClient(serverProvider, clientProvider, false); } - private static void testSniClient(SslProvider sslClientProvider, SslProvider sslServerProvider) throws Exception { - final String sniHost = "sni.netty.io"; + private static void testSniClient(SslProvider sslServerProvider, SslProvider sslClientProvider) throws Exception { + String sniHostName = "sni.netty.io"; LocalAddress address = new LocalAddress("test"); EventLoopGroup group = new DefaultEventLoopGroup(1); + SelfSignedCertificate cert = new SelfSignedCertificate(); + SslContext sslServerContext = null; + SslContext sslClientContext = null; + Channel sc = null; Channel cc = null; try { - SelfSignedCertificate cert = new SelfSignedCertificate(); - final SslContext sslServerContext = SslContextBuilder.forServer(cert.key(), cert.cert()) - .sslProvider(sslServerProvider).build(); + if ((sslServerProvider == SslProvider.OPENSSL || sslServerProvider == SslProvider.OPENSSL_REFCNT) + && !OpenSsl.useKeyManagerFactory()) { + sslServerContext = SslContextBuilder.forServer(cert.certificate(), cert.privateKey()) + .sslProvider(sslServerProvider) + .build(); + } else { + // The used OpenSSL version does support a KeyManagerFactory, so use it. + KeyManagerFactory kmf = PlatformDependent.javaVersion() >= 8 ? + SniClientJava8TestUtil.newSniX509KeyManagerFactory(cert, sniHostName) : + SslContext.buildKeyManagerFactory( + new X509Certificate[] { cert.cert() }, cert.key(), null, null); + + sslServerContext = SslContextBuilder.forServer(kmf) + .sslProvider(sslServerProvider) + .build(); + } + final SslContext finalContext = sslServerContext; final Promise promise = group.next().newPromise(); ServerBootstrap sb = new ServerBootstrap(); sc = sb.group(group).channel(LocalServerChannel.class).childHandler(new ChannelInitializer() { @@ -135,19 +131,33 @@ protected void initChannel(Channel ch) throws Exception { @Override public SslContext map(String input) { promise.setSuccess(input); - return sslServerContext; + return finalContext; } })); } }).bind(address).syncUninterruptibly().channel(); - SslContext sslContext = SslContextBuilder.forClient().trustManager(InsecureTrustManagerFactory.INSTANCE) - .sslProvider(sslClientProvider).build(); + TrustManagerFactory tmf = PlatformDependent.javaVersion() >= 8 ? + SniClientJava8TestUtil.newSniX509TrustmanagerFactory(sniHostName) : + InsecureTrustManagerFactory.INSTANCE; + sslClientContext = SslContextBuilder.forClient().trustManager(tmf) + .sslProvider(sslClientProvider).build(); Bootstrap cb = new Bootstrap(); - cc = cb.group(group).channel(LocalChannel.class).handler(new SslHandler( - sslContext.newEngine(ByteBufAllocator.DEFAULT, sniHost, -1))) + + SslHandler handler = new SslHandler( + sslClientContext.newEngine(ByteBufAllocator.DEFAULT, sniHostName, -1)); + cc = cb.group(group).channel(LocalChannel.class).handler(handler) .connect(address).syncUninterruptibly().channel(); - Assert.assertEquals(sniHost, promise.syncUninterruptibly().getNow()); + Assert.assertEquals(sniHostName, promise.syncUninterruptibly().getNow()); + + // After we are done with handshaking getHandshakeSession() should return null. + handler.handshakeFuture().syncUninterruptibly(); + Assert.assertNull(handler.engine().getHandshakeSession()); + + if (PlatformDependent.javaVersion() >= 8) { + SniClientJava8TestUtil.assertSSLSession( + handler.engine().getUseClientMode(), handler.engine().getSession(), sniHostName); + } } finally { if (cc != null) { cc.close().syncUninterruptibly(); @@ -155,6 +165,11 @@ public SslContext map(String input) { if (sc != null) { sc.close().syncUninterruptibly(); } + ReferenceCountUtil.release(sslServerContext); + ReferenceCountUtil.release(sslClientContext); + + cert.delete(); + group.shutdownGracefully(); } } diff --git a/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java index a6dceb364e70..8d8003d745e1 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java @@ -44,6 +44,7 @@ import io.netty.util.Mapping; import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCounted; +import io.netty.util.internal.ResourcesUtil; import io.netty.util.concurrent.Promise; import io.netty.util.internal.ObjectUtil; import io.netty.util.internal.StringUtil; @@ -98,8 +99,8 @@ private static SslContext makeSslContext(SslProvider provider, boolean apn) thro assumeApnSupported(provider); } - File keyFile = new File(SniHandlerTest.class.getResource("test_encrypted.pem").getFile()); - File crtFile = new File(SniHandlerTest.class.getResource("test.crt").getFile()); + File keyFile = ResourcesUtil.getFile(SniHandlerTest.class, "test_encrypted.pem"); + File crtFile = ResourcesUtil.getFile(SniHandlerTest.class, "test.crt"); SslContextBuilder sslCtxBuilder = SslContextBuilder.forServer(crtFile, keyFile, "12345") .sslProvider(provider); @@ -114,7 +115,7 @@ private static SslContext makeSslClientContext(SslProvider provider, boolean apn assumeApnSupported(provider); } - File crtFile = new File(SniHandlerTest.class.getResource("test.crt").getFile()); + File crtFile = ResourcesUtil.getFile(SniHandlerTest.class, "test.crt"); SslContextBuilder sslCtxBuilder = SslContextBuilder.forClient().trustManager(crtFile).sslProvider(provider); if (apn) { @@ -541,7 +542,7 @@ protected void initChannel(Channel ch) throws Exception { private static class CustomSslHandler extends SslHandler { private final SslContext sslContext; - public CustomSslHandler(SslContext sslContext, SSLEngine sslEngine) { + CustomSslHandler(SslContext sslContext, SSLEngine sslEngine) { super(sslEngine); this.sslContext = ObjectUtil.checkNotNull(sslContext, "sslContext"); } diff --git a/handler/src/test/java/io/netty/handler/ssl/SslContextBuilderTest.java b/handler/src/test/java/io/netty/handler/ssl/SslContextBuilderTest.java index 752424cfb557..20f2ccbb1459 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SslContextBuilderTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SslContextBuilderTest.java @@ -15,15 +15,18 @@ */ package io.netty.handler.ssl; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; - import io.netty.buffer.UnpooledByteBufAllocator; import io.netty.handler.ssl.util.SelfSignedCertificate; import org.junit.Assume; +import org.junit.Ignore; +import org.junit.Rule; import org.junit.Test; import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLException; +import java.util.Collections; + +import static org.junit.Assert.*; public class SslContextBuilderTest { @@ -71,6 +74,39 @@ public void testServerContextOpenssl() throws Exception { testServerContext(SslProvider.OPENSSL); } + @Test(expected = IllegalArgumentException.class) + public void testInvalidCipherJdk() throws Exception { + Assume.assumeTrue(OpenSsl.isAvailable()); + testInvalidCipher(SslProvider.JDK); + } + + @Test + public void testInvalidCipherOpenSSL() throws Exception { + Assume.assumeTrue(OpenSsl.isAvailable()); + try { + // This may fail or not depending on the OpenSSL version used + // See https://github.com/openssl/openssl/issues/7196 + testInvalidCipher(SslProvider.OPENSSL); + if (!OpenSsl.versionString().contains("1.1.1")) { + fail(); + } + } catch (SSLException expected) { + // ok + } + } + + private static void testInvalidCipher(SslProvider provider) throws Exception { + SelfSignedCertificate cert = new SelfSignedCertificate(); + SslContextBuilder builder = SslContextBuilder.forClient() + .sslProvider(provider) + .ciphers(Collections.singleton("SOME_INVALID_CIPHER")) + .keyManager(cert.certificate(), + cert.privateKey()) + .trustManager(cert.certificate()); + SslContext context = builder.build(); + context.newEngine(UnpooledByteBufAllocator.DEFAULT); + } + private static void testClientContextFromFile(SslProvider provider) throws Exception { SelfSignedCertificate cert = new SelfSignedCertificate(); SslContextBuilder builder = SslContextBuilder.forClient() diff --git a/handler/src/test/java/io/netty/handler/ssl/SslContextTest.java b/handler/src/test/java/io/netty/handler/ssl/SslContextTest.java index 1c0032885ac9..247575b2dfa7 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SslContextTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SslContextTest.java @@ -15,6 +15,7 @@ */ package io.netty.handler.ssl; +import io.netty.util.internal.ResourcesUtil; import org.junit.Assert; import org.junit.Test; @@ -23,6 +24,7 @@ import java.security.KeyManagementException; import java.security.NoSuchAlgorithmException; import java.security.PrivateKey; +import java.security.cert.CertificateException; import java.security.spec.InvalidKeySpecException; import javax.net.ssl.SSLContext; @@ -37,58 +39,58 @@ public abstract class SslContextTest { @Test(expected = IOException.class) public void testUnencryptedEmptyPassword() throws Exception { PrivateKey key = SslContext.toPrivateKey( - new File(getClass().getResource("test2_unencrypted.pem").getFile()), ""); + ResourcesUtil.getFile(getClass(), "test2_unencrypted.pem"), ""); Assert.assertNotNull(key); } @Test public void testUnEncryptedNullPassword() throws Exception { PrivateKey key = SslContext.toPrivateKey( - new File(getClass().getResource("test2_unencrypted.pem").getFile()), null); + ResourcesUtil.getFile(getClass(), "test2_unencrypted.pem"), null); Assert.assertNotNull(key); } @Test public void testEncryptedEmptyPassword() throws Exception { PrivateKey key = SslContext.toPrivateKey( - new File(getClass().getResource("test_encrypted_empty_pass.pem").getFile()), ""); + ResourcesUtil.getFile(getClass(), "test_encrypted_empty_pass.pem"), ""); Assert.assertNotNull(key); } @Test(expected = InvalidKeySpecException.class) public void testEncryptedNullPassword() throws Exception { SslContext.toPrivateKey( - new File(getClass().getResource("test_encrypted_empty_pass.pem").getFile()), null); + ResourcesUtil.getFile(getClass(), "test_encrypted_empty_pass.pem"), null); } @Test public void testSslServerWithEncryptedPrivateKey() throws SSLException { - File keyFile = new File(getClass().getResource("test_encrypted.pem").getFile()); - File crtFile = new File(getClass().getResource("test.crt").getFile()); + File keyFile = ResourcesUtil.getFile(getClass(), "test_encrypted.pem"); + File crtFile = ResourcesUtil.getFile(getClass(), "test.crt"); newServerContext(crtFile, keyFile, "12345"); } @Test public void testSslServerWithEncryptedPrivateKey2() throws SSLException { - File keyFile = new File(getClass().getResource("test2_encrypted.pem").getFile()); - File crtFile = new File(getClass().getResource("test2.crt").getFile()); + File keyFile = ResourcesUtil.getFile(getClass(), "test2_encrypted.pem"); + File crtFile = ResourcesUtil.getFile(getClass(), "test2.crt"); newServerContext(crtFile, keyFile, "12345"); } @Test public void testSslServerWithUnencryptedPrivateKey() throws SSLException { - File keyFile = new File(getClass().getResource("test_unencrypted.pem").getFile()); - File crtFile = new File(getClass().getResource("test.crt").getFile()); + File keyFile = ResourcesUtil.getFile(getClass(), "test_unencrypted.pem"); + File crtFile = ResourcesUtil.getFile(getClass(), "test.crt"); newServerContext(crtFile, keyFile, null); } @Test(expected = SSLException.class) public void testSslServerWithUnencryptedPrivateKeyEmptyPass() throws SSLException { - File keyFile = new File(getClass().getResource("test_unencrypted.pem").getFile()); - File crtFile = new File(getClass().getResource("test.crt").getFile()); + File keyFile = ResourcesUtil.getFile(getClass(), "test_unencrypted.pem"); + File crtFile = ResourcesUtil.getFile(getClass(), "test.crt"); newServerContext(crtFile, keyFile, ""); } @@ -107,12 +109,17 @@ public void testSupportedCiphers() throws KeyManagementException, NoSuchAlgorith exception = e; } assumeNotNull(exception); - File keyFile = new File(getClass().getResource("test_unencrypted.pem").getFile()); - File crtFile = new File(getClass().getResource("test.crt").getFile()); + File keyFile = ResourcesUtil.getFile(getClass(), "test_unencrypted.pem"); + File crtFile = ResourcesUtil.getFile(getClass(), "test.crt"); SslContext sslContext = newServerContext(crtFile, keyFile, null); assertFalse(sslContext.cipherSuites().contains(unsupportedCipher)); } + @Test(expected = CertificateException.class) + public void test() throws CertificateException { + SslContext.toX509Certificates(new File(getClass().getResource("ec_params_unsupported.pem").getFile())); + } + protected abstract SslContext newServerContext(File crtFile, File keyFile, String pass) throws SSLException; } diff --git a/handler/src/test/java/io/netty/handler/ssl/SslErrorTest.java b/handler/src/test/java/io/netty/handler/ssl/SslErrorTest.java index 27aa9bfe16d7..fbe31d73f7a1 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SslErrorTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SslErrorTest.java @@ -124,9 +124,10 @@ public void testCorrectAlert() throws Exception { Assume.assumeTrue(OpenSsl.isAvailable()); SelfSignedCertificate ssc = new SelfSignedCertificate(); - final SslContext sslServerCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) - .sslProvider(serverProvider) - .trustManager(new SimpleTrustManagerFactory() { + final SslContext sslServerCtx = + SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(serverProvider) + .trustManager(new SimpleTrustManagerFactory() { @Override protected void engineInit(KeyStore keyStore) { } @Override @@ -203,14 +204,24 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { if (reason == CertPathValidatorException.BasicReason.EXPIRED) { verifyException(unwrappedCause, "expired", promise); } else if (reason == CertPathValidatorException.BasicReason.NOT_YET_VALID) { - verifyException(unwrappedCause, "bad", promise); + // BoringSSL uses "expired" in this case while others use "bad" + if (OpenSsl.isBoringSSL()) { + verifyException(unwrappedCause, "expired", promise); + } else { + verifyException(unwrappedCause, "bad", promise); + } } else if (reason == CertPathValidatorException.BasicReason.REVOKED) { verifyException(unwrappedCause, "revoked", promise); } } else if (exception instanceof CertificateExpiredException) { verifyException(unwrappedCause, "expired", promise); } else if (exception instanceof CertificateNotYetValidException) { - verifyException(unwrappedCause, "bad", promise); + // BoringSSL uses "expired" in this case while others use "bad" + if (OpenSsl.isBoringSSL()) { + verifyException(unwrappedCause, "expired", promise); + } else { + verifyException(unwrappedCause, "bad", promise); + } } else if (exception instanceof CertificateRevokedException) { verifyException(unwrappedCause, "revoked", promise); } @@ -242,14 +253,16 @@ private static void verifyException(Throwable cause, String messagePart, Promise if (message.toLowerCase(Locale.UK).contains(messagePart.toLowerCase(Locale.UK))) { promise.setSuccess(null); } else { - promise.setFailure(new AssertionError("message not contains '" + messagePart + "': " + message)); + Throwable error = new AssertionError("message not contains '" + messagePart + "': " + message); + error.initCause(cause); + promise.setFailure(error); } } private static final class TestCertificateException extends CertificateException { private static final long serialVersionUID = -5816338303868751410L; - public TestCertificateException(Throwable cause) { + TestCertificateException(Throwable cause) { super(cause); } } diff --git a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java index 02cc3ee5b5b0..6749fe4ebc61 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java @@ -53,7 +53,10 @@ import io.netty.util.ReferenceCounted; import io.netty.util.concurrent.Future; import io.netty.util.concurrent.FutureListener; +import io.netty.util.concurrent.ImmediateEventExecutor; +import io.netty.util.concurrent.ImmediateExecutor; import io.netty.util.concurrent.Promise; +import org.hamcrest.CoreMatchers; import org.junit.Test; import java.net.InetSocketAddress; @@ -62,8 +65,13 @@ import java.util.concurrent.BlockingQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; @@ -71,9 +79,7 @@ import javax.net.ssl.SSLProtocolException; import static io.netty.buffer.Unpooled.wrappedBuffer; -import static org.hamcrest.CoreMatchers.instanceOf; -import static org.hamcrest.CoreMatchers.is; -import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.CoreMatchers.*; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; @@ -132,6 +138,16 @@ public void testServerHandshakeTimeout() throws Exception { testHandshakeTimeout(false); } + private static SSLEngine newServerModeSSLEngine() throws NoSuchAlgorithmException { + SSLEngine engine = SSLContext.getDefault().createSSLEngine(); + // Set the mode before we try to do the handshake as otherwise it may throw an IllegalStateException. + // See: + // - https://docs.oracle.com/javase/10/docs/api/javax/net/ssl/SSLEngine.html#beginHandshake() + // - http://mail.openjdk.java.net/pipermail/security-dev/2018-July/017715.html + engine.setUseClientMode(false); + return engine; + } + private static void testHandshakeTimeout(boolean client) throws Exception { SSLEngine engine = SSLContext.getDefault().createSSLEngine(); engine.setUseClientMode(client); @@ -154,9 +170,7 @@ private static void testHandshakeTimeout(boolean client) throws Exception { @Test public void testTruncatedPacket() throws Exception { - SSLEngine engine = SSLContext.getDefault().createSSLEngine(); - engine.setUseClientMode(false); - + SSLEngine engine = newServerModeSSLEngine(); EmbeddedChannel ch = new EmbeddedChannel(new SslHandler(engine)); // Push the first part of a 5-byte handshake message. @@ -182,9 +196,7 @@ public void testTruncatedPacket() throws Exception { @Test public void testNonByteBufWriteIsReleased() throws Exception { - SSLEngine engine = SSLContext.getDefault().createSSLEngine(); - engine.setUseClientMode(false); - + SSLEngine engine = newServerModeSSLEngine(); EmbeddedChannel ch = new EmbeddedChannel(new SslHandler(engine)); AbstractReferenceCounted referenceCounted = new AbstractReferenceCounted() { @@ -209,9 +221,7 @@ protected void deallocate() { @Test(expected = UnsupportedMessageTypeException.class) public void testNonByteBufNotPassThrough() throws Exception { - SSLEngine engine = SSLContext.getDefault().createSSLEngine(); - engine.setUseClientMode(false); - + SSLEngine engine = newServerModeSSLEngine(); EmbeddedChannel ch = new EmbeddedChannel(new SslHandler(engine)); try { @@ -223,9 +233,7 @@ public void testNonByteBufNotPassThrough() throws Exception { @Test public void testIncompleteWriteDoesNotCompletePromisePrematurely() throws NoSuchAlgorithmException { - SSLEngine engine = SSLContext.getDefault().createSSLEngine(); - engine.setUseClientMode(false); - + SSLEngine engine = newServerModeSSLEngine(); EmbeddedChannel ch = new EmbeddedChannel(new SslHandler(engine)); ChannelPromise promise = ch.newPromise(); @@ -397,7 +405,8 @@ public void channelInactive(ChannelHandlerContext ctx) { @Test public void testCloseFutureNotified() throws Exception { - SslHandler handler = new SslHandler(SSLContext.getDefault().createSSLEngine()); + SSLEngine engine = newServerModeSSLEngine(); + SslHandler handler = new SslHandler(engine); EmbeddedChannel ch = new EmbeddedChannel(handler); ch.close(); @@ -415,7 +424,7 @@ public void testCloseFutureNotified() throws Exception { @Test(timeout = 5000) public void testEventsFired() throws Exception { - SSLEngine engine = SSLContext.getDefault().createSSLEngine(); + SSLEngine engine = newServerModeSSLEngine(); final BlockingQueue events = new LinkedBlockingQueue(); EmbeddedChannel channel = new EmbeddedChannel(new SslHandler(engine), new ChannelInboundHandlerAdapter() { @Override @@ -626,17 +635,24 @@ protected void initChannel(Channel ch) { }); sc = sb.bind(address).syncUninterruptibly().channel(); + final AtomicReference sslHandlerRef = new AtomicReference(); Bootstrap b = new Bootstrap() .group(group) .channel(LocalChannel.class) .handler(new ChannelInitializer() { @Override protected void initChannel(Channel ch) { - ch.pipeline().addLast(sslClientCtx.newHandler(ch.alloc())); + SslHandler handler = sslClientCtx.newHandler(ch.alloc()); + + // We propagate the SslHandler via an AtomicReference to the outer-scope as using + // pipeline.get(...) may return null if the pipeline was teared down by the time we call it. + // This will happen if the channel was closed in the meantime. + sslHandlerRef.set(handler); + ch.pipeline().addLast(handler); } }); cc = b.connect(sc.localAddress()).syncUninterruptibly().channel(); - SslHandler handler = cc.pipeline().get(SslHandler.class); + SslHandler handler = sslHandlerRef.get(); handler.handshakeFuture().awaitUninterruptibly(); assertFalse(handler.handshakeFuture().isSuccess()); @@ -668,4 +684,313 @@ public void testOutboundClosedAfterChannelInactive() throws Exception { assertTrue(engine.isOutboundDone()); } + + @Test(timeout = 10000) + public void testHandshakeFailedByWriteBeforeChannelActive() throws Exception { + final SslContext sslClientCtx = SslContextBuilder.forClient() + .protocols(SslUtils.PROTOCOL_SSL_V3) + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(SslProvider.JDK).build(); + + EventLoopGroup group = new NioEventLoopGroup(); + Channel sc = null; + Channel cc = null; + final CountDownLatch activeLatch = new CountDownLatch(1); + final AtomicReference errorRef = new AtomicReference(); + final SslHandler sslHandler = sslClientCtx.newHandler(UnpooledByteBufAllocator.DEFAULT); + try { + sc = new ServerBootstrap() + .group(group) + .channel(NioServerSocketChannel.class) + .childHandler(new ChannelInboundHandlerAdapter()) + .bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); + + cc = new Bootstrap() + .group(group) + .channel(NioSocketChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addLast(sslHandler); + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) + throws Exception { + if (cause instanceof AssertionError) { + errorRef.set((AssertionError) cause); + } + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + activeLatch.countDown(); + } + }); + } + }).connect(sc.localAddress()).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + // Write something to trigger the handshake before fireChannelActive is called. + future.channel().writeAndFlush(wrappedBuffer(new byte [] { 1, 2, 3, 4 })); + } + }).syncUninterruptibly().channel(); + + // Ensure there is no AssertionError thrown by having the handshake failed by the writeAndFlush(...) before + // channelActive(...) was called. Let's first wait for the activeLatch countdown to happen and after this + // check if we saw and AssertionError (even if we timed out waiting). + activeLatch.await(5, TimeUnit.SECONDS); + AssertionError error = errorRef.get(); + if (error != null) { + throw error; + } + assertThat(sslHandler.handshakeFuture().await().cause(), + CoreMatchers.instanceOf(SSLException.class)); + } finally { + if (cc != null) { + cc.close().syncUninterruptibly(); + } + if (sc != null) { + sc.close().syncUninterruptibly(); + } + group.shutdownGracefully(); + + ReferenceCountUtil.release(sslClientCtx); + } + } + + @Test(timeout = 10000) + public void testHandshakeTimeoutFlushStartsHandshake() throws Exception { + testHandshakeTimeout0(false); + } + + @Test(timeout = 10000) + public void testHandshakeTimeoutStartTLS() throws Exception { + testHandshakeTimeout0(true); + } + + private static void testHandshakeTimeout0(final boolean startTls) throws Exception { + final SslContext sslClientCtx = SslContextBuilder.forClient() + .startTls(true) + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(SslProvider.JDK).build(); + + EventLoopGroup group = new NioEventLoopGroup(); + Channel sc = null; + Channel cc = null; + final SslHandler sslHandler = sslClientCtx.newHandler(UnpooledByteBufAllocator.DEFAULT); + sslHandler.setHandshakeTimeout(500, TimeUnit.MILLISECONDS); + + try { + sc = new ServerBootstrap() + .group(group) + .channel(NioServerSocketChannel.class) + .childHandler(new ChannelInboundHandlerAdapter()) + .bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); + + ChannelFuture future = new Bootstrap() + .group(group) + .channel(NioSocketChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addLast(sslHandler); + if (startTls) { + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + ctx.writeAndFlush(wrappedBuffer(new byte[] { 1, 2, 3, 4 })); + } + }); + } + } + }).connect(sc.localAddress()); + if (!startTls) { + future.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + // Write something to trigger the handshake before fireChannelActive is called. + future.channel().writeAndFlush(wrappedBuffer(new byte [] { 1, 2, 3, 4 })); + } + }); + } + cc = future.syncUninterruptibly().channel(); + + Throwable cause = sslHandler.handshakeFuture().await().cause(); + assertThat(cause, CoreMatchers.instanceOf(SSLException.class)); + assertThat(cause.getMessage(), containsString("timed out")); + } finally { + if (cc != null) { + cc.close().syncUninterruptibly(); + } + if (sc != null) { + sc.close().syncUninterruptibly(); + } + group.shutdownGracefully(); + ReferenceCountUtil.release(sslClientCtx); + } + } + + @Test + public void testHandshakeWithExecutorThatExecuteDirecty() throws Exception { + testHandshakeWithExecutor(new Executor() { + @Override + public void execute(Runnable command) { + command.run(); + } + }); + } + + @Test + public void testHandshakeWithImmediateExecutor() throws Exception { + testHandshakeWithExecutor(ImmediateExecutor.INSTANCE); + } + + @Test + public void testHandshakeWithImmediateEventExecutor() throws Exception { + testHandshakeWithExecutor(ImmediateEventExecutor.INSTANCE); + } + + @Test + public void testHandshakeWithExecutor() throws Exception { + ExecutorService executorService = Executors.newCachedThreadPool(); + try { + testHandshakeWithExecutor(executorService); + } finally { + executorService.shutdown(); + } + } + + private void testHandshakeWithExecutor(Executor executor) throws Exception { + final SslContext sslClientCtx = SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(SslProvider.JDK).build(); + + final SelfSignedCertificate cert = new SelfSignedCertificate(); + final SslContext sslServerCtx = SslContextBuilder.forServer(cert.key(), cert.cert()) + .sslProvider(SslProvider.JDK).build(); + + EventLoopGroup group = new NioEventLoopGroup(); + Channel sc = null; + Channel cc = null; + final SslHandler clientSslHandler = sslClientCtx.newHandler(UnpooledByteBufAllocator.DEFAULT, executor); + final SslHandler serverSslHandler = sslServerCtx.newHandler(UnpooledByteBufAllocator.DEFAULT, executor); + + try { + sc = new ServerBootstrap() + .group(group) + .channel(NioServerSocketChannel.class) + .childHandler(serverSslHandler) + .bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); + + ChannelFuture future = new Bootstrap() + .group(group) + .channel(NioSocketChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(clientSslHandler); + } + }).connect(sc.localAddress()); + cc = future.syncUninterruptibly().channel(); + + assertTrue(clientSslHandler.handshakeFuture().await().isSuccess()); + assertTrue(serverSslHandler.handshakeFuture().await().isSuccess()); + } finally { + if (cc != null) { + cc.close().syncUninterruptibly(); + } + if (sc != null) { + sc.close().syncUninterruptibly(); + } + group.shutdownGracefully(); + ReferenceCountUtil.release(sslClientCtx); + } + } + + @Test + public void testClientHandshakeTimeoutBecauseExecutorNotExecute() throws Exception { + testHandshakeTimeoutBecauseExecutorNotExecute(true); + } + + @Test + public void testServerHandshakeTimeoutBecauseExecutorNotExecute() throws Exception { + testHandshakeTimeoutBecauseExecutorNotExecute(false); + } + + private void testHandshakeTimeoutBecauseExecutorNotExecute(final boolean client) throws Exception { + final SslContext sslClientCtx = SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(SslProvider.JDK).build(); + + final SelfSignedCertificate cert = new SelfSignedCertificate(); + final SslContext sslServerCtx = SslContextBuilder.forServer(cert.key(), cert.cert()) + .sslProvider(SslProvider.JDK).build(); + + EventLoopGroup group = new NioEventLoopGroup(); + Channel sc = null; + Channel cc = null; + final SslHandler clientSslHandler = sslClientCtx.newHandler(UnpooledByteBufAllocator.DEFAULT, new Executor() { + @Override + public void execute(Runnable command) { + if (!client) { + command.run(); + } + // Do nothing to simulate slow execution. + } + }); + if (client) { + clientSslHandler.setHandshakeTimeout(100, TimeUnit.MILLISECONDS); + } + final SslHandler serverSslHandler = sslServerCtx.newHandler(UnpooledByteBufAllocator.DEFAULT, new Executor() { + @Override + public void execute(Runnable command) { + if (client) { + command.run(); + } + // Do nothing to simulate slow execution. + } + }); + if (!client) { + serverSslHandler.setHandshakeTimeout(100, TimeUnit.MILLISECONDS); + } + try { + sc = new ServerBootstrap() + .group(group) + .channel(NioServerSocketChannel.class) + .childHandler(serverSslHandler) + .bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); + + ChannelFuture future = new Bootstrap() + .group(group) + .channel(NioSocketChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(clientSslHandler); + } + }).connect(sc.localAddress()); + cc = future.syncUninterruptibly().channel(); + + if (client) { + Throwable cause = clientSslHandler.handshakeFuture().await().cause(); + assertThat(cause, CoreMatchers.instanceOf(SSLException.class)); + assertThat(cause.getMessage(), containsString("timed out")); + assertFalse(serverSslHandler.handshakeFuture().await().isSuccess()); + } else { + Throwable cause = serverSslHandler.handshakeFuture().await().cause(); + assertThat(cause, CoreMatchers.instanceOf(SSLException.class)); + assertThat(cause.getMessage(), containsString("timed out")); + assertFalse(clientSslHandler.handshakeFuture().await().isSuccess()); + } + } finally { + if (cc != null) { + cc.close().syncUninterruptibly(); + } + if (sc != null) { + sc.close().syncUninterruptibly(); + } + group.shutdownGracefully(); + ReferenceCountUtil.release(sslClientCtx); + } + } } diff --git a/handler/src/test/java/io/netty/handler/ssl/SslUtilsTest.java b/handler/src/test/java/io/netty/handler/ssl/SslUtilsTest.java index c1de33dd6e3d..ce9a22d717f8 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SslUtilsTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SslUtilsTest.java @@ -28,6 +28,7 @@ import static io.netty.handler.ssl.SslUtils.getEncryptedPacketLength; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; public class SslUtilsTest { @@ -63,4 +64,15 @@ private static SSLEngine newEngine() throws SSLException, NoSuchAlgorithmExcepti engine.beginHandshake(); return engine; } + + @Test + public void testIsTLSv13Cipher() { + assertTrue(SslUtils.isTLSv13Cipher("TLS_AES_128_GCM_SHA256")); + assertTrue(SslUtils.isTLSv13Cipher("TLS_AES_256_GCM_SHA384")); + assertTrue(SslUtils.isTLSv13Cipher("TLS_CHACHA20_POLY1305_SHA256")); + assertTrue(SslUtils.isTLSv13Cipher("TLS_AES_128_CCM_SHA256")); + assertTrue(SslUtils.isTLSv13Cipher("TLS_AES_128_CCM_8_SHA256")); + assertFalse(SslUtils.isTLSv13Cipher("TLS_DHE_RSA_WITH_AES_128_GCM_SHA256")); + } + } diff --git a/handler/src/test/java/io/netty/handler/ssl/ocsp/OcspTest.java b/handler/src/test/java/io/netty/handler/ssl/ocsp/OcspTest.java index 161f52a99f64..cfed8240463b 100644 --- a/handler/src/test/java/io/netty/handler/ssl/ocsp/OcspTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/ocsp/OcspTest.java @@ -459,7 +459,7 @@ private static final class TestClientOcspContext implements OcspClientCallback { private volatile byte[] response; - public TestClientOcspContext(boolean valid) { + TestClientOcspContext(boolean valid) { this.valid = valid; } @@ -481,7 +481,7 @@ private static final class OcspClientCallbackHandler extends OcspClientHandler { private final OcspClientCallback callback; - public OcspClientCallbackHandler(ReferenceCountedOpenSslEngine engine, OcspClientCallback callback) { + OcspClientCallbackHandler(ReferenceCountedOpenSslEngine engine, OcspClientCallback callback) { super(engine); this.callback = callback; } @@ -496,7 +496,7 @@ protected boolean verify(ChannelHandlerContext ctx, ReferenceCountedOpenSslEngin private static final class OcspTestException extends IllegalStateException { private static final long serialVersionUID = 4516426833250228159L; - public OcspTestException(String message) { + OcspTestException(String message) { super(message); } } diff --git a/handler/src/test/java/io/netty/handler/stream/ChunkedWriteHandlerTest.java b/handler/src/test/java/io/netty/handler/stream/ChunkedWriteHandlerTest.java index 66b69516fc0d..be6951d88bdb 100644 --- a/handler/src/test/java/io/netty/handler/stream/ChunkedWriteHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/stream/ChunkedWriteHandlerTest.java @@ -21,8 +21,11 @@ import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; import org.junit.Test; import java.io.ByteArrayInputStream; @@ -30,11 +33,12 @@ import java.io.FileOutputStream; import java.io.IOException; import java.nio.channels.Channels; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertTrue; +import static java.util.concurrent.TimeUnit.*; +import static org.junit.Assert.*; public class ChunkedWriteHandlerTest { private static final byte[] BYTES = new byte[1024 * 64]; @@ -162,8 +166,7 @@ public void operationComplete(ChannelFuture future) throws Exception { EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler()); ch.writeAndFlush(input).addListener(listener).syncUninterruptibly(); - ch.checkException(); - ch.finish(); + assertTrue(ch.finish()); // the listener should have been notified assertTrue(listenerNotified.get()); @@ -220,13 +223,354 @@ public long progress() { EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler()); ch.writeAndFlush(input).syncUninterruptibly(); - ch.checkException(); assertTrue(ch.finish()); assertEquals(0, ch.readOutbound()); assertNull(ch.readOutbound()); } + @Test + public void testWriteFailureChunkedStream() throws IOException { + checkFirstFailed(new ChunkedStream(new ByteArrayInputStream(BYTES))); + } + + @Test + public void testWriteFailureChunkedNioStream() throws IOException { + checkFirstFailed(new ChunkedNioStream(Channels.newChannel(new ByteArrayInputStream(BYTES)))); + } + + @Test + public void testWriteFailureChunkedFile() throws IOException { + checkFirstFailed(new ChunkedFile(TMP)); + } + + @Test + public void testWriteFailureChunkedNioFile() throws IOException { + checkFirstFailed(new ChunkedNioFile(TMP)); + } + + @Test + public void testWriteFailureUnchunkedData() throws IOException { + checkFirstFailed(Unpooled.wrappedBuffer(BYTES)); + } + + @Test + public void testSkipAfterFailedChunkedStream() throws IOException { + checkSkipFailed(new ChunkedStream(new ByteArrayInputStream(BYTES)), + new ChunkedStream(new ByteArrayInputStream(BYTES))); + } + + @Test + public void testSkipAfterFailedChunkedNioStream() throws IOException { + checkSkipFailed(new ChunkedNioStream(Channels.newChannel(new ByteArrayInputStream(BYTES))), + new ChunkedNioStream(Channels.newChannel(new ByteArrayInputStream(BYTES)))); + } + + @Test + public void testSkipAfterFailedChunkedFile() throws IOException { + checkSkipFailed(new ChunkedFile(TMP), new ChunkedFile(TMP)); + } + + @Test + public void testSkipAfterFailedChunkedNioFile() throws IOException { + checkSkipFailed(new ChunkedNioFile(TMP), new ChunkedFile(TMP)); + } + + // See https://github.com/netty/netty/issues/8700. + @Test + public void testFailureWhenLastChunkFailed() throws IOException { + ChannelOutboundHandlerAdapter failLast = new ChannelOutboundHandlerAdapter() { + private int passedWrites; + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + if (++this.passedWrites < 4) { + ctx.write(msg, promise); + } else { + ReferenceCountUtil.release(msg); + promise.tryFailure(new RuntimeException()); + } + } + }; + + EmbeddedChannel ch = new EmbeddedChannel(failLast, new ChunkedWriteHandler()); + ChannelFuture r = ch.writeAndFlush(new ChunkedFile(TMP, 1024 * 16)); // 4 chunks + assertTrue(ch.finish()); + + assertFalse(r.isSuccess()); + assertTrue(r.cause() instanceof RuntimeException); + + // 3 out of 4 chunks were already written + int read = 0; + for (;;) { + ByteBuf buffer = ch.readOutbound(); + if (buffer == null) { + break; + } + read += buffer.readableBytes(); + buffer.release(); + } + + assertEquals(1024 * 16 * 3, read); + } + + @Test + public void testDiscardPendingWritesOnInactive() throws IOException { + + final AtomicBoolean closeWasCalled = new AtomicBoolean(false); + + ChunkedInput notifiableInput = new ChunkedInput() { + private boolean done; + private final ByteBuf buffer = Unpooled.copiedBuffer("Test", CharsetUtil.ISO_8859_1); + + @Override + public boolean isEndOfInput() throws Exception { + return done; + } + + @Override + public void close() throws Exception { + buffer.release(); + closeWasCalled.set(true); + } + + @Deprecated + @Override + public ByteBuf readChunk(ChannelHandlerContext ctx) throws Exception { + return readChunk(ctx.alloc()); + } + + @Override + public ByteBuf readChunk(ByteBufAllocator allocator) throws Exception { + if (done) { + return null; + } + done = true; + return buffer.retainedDuplicate(); + } + + @Override + public long length() { + return -1; + } + + @Override + public long progress() { + return 1; + } + }; + + EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler()); + + // Write 3 messages and close channel before flushing + ChannelFuture r1 = ch.write(new ChunkedFile(TMP)); + ChannelFuture r2 = ch.write(new ChunkedNioFile(TMP)); + ch.write(notifiableInput); + + // Should be `false` as we do not expect any messages to be written + assertFalse(ch.finish()); + + assertFalse(r1.isSuccess()); + assertFalse(r2.isSuccess()); + assertTrue(closeWasCalled.get()); + } + + // See https://github.com/netty/netty/issues/8700. + @Test + public void testStopConsumingChunksWhenFailed() { + final ByteBuf buffer = Unpooled.copiedBuffer("Test", CharsetUtil.ISO_8859_1); + final AtomicInteger chunks = new AtomicInteger(0); + + ChunkedInput nonClosableInput = new ChunkedInput() { + @Override + public boolean isEndOfInput() throws Exception { + return chunks.get() >= 5; + } + + @Override + public void close() throws Exception { + // no-op + } + + @Deprecated + @Override + public ByteBuf readChunk(ChannelHandlerContext ctx) throws Exception { + return readChunk(ctx.alloc()); + } + + @Override + public ByteBuf readChunk(ByteBufAllocator allocator) throws Exception { + chunks.incrementAndGet(); + return buffer.retainedDuplicate(); + } + + @Override + public long length() { + return -1; + } + + @Override + public long progress() { + return 1; + } + }; + + ChannelOutboundHandlerAdapter noOpWrites = new ChannelOutboundHandlerAdapter() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + ReferenceCountUtil.release(msg); + promise.tryFailure(new RuntimeException()); + } + }; + + EmbeddedChannel ch = new EmbeddedChannel(noOpWrites, new ChunkedWriteHandler()); + ch.writeAndFlush(nonClosableInput).awaitUninterruptibly(); + // Should be `false` as we do not expect any messages to be written + assertFalse(ch.finish()); + buffer.release(); + + // We should expect only single chunked being read from the input. + // It's possible to get a race condition here between resolving a promise and + // allocating a new chunk, but should be fine when working with embedded channels. + assertEquals(1, chunks.get()); + } + + @Test + public void testCloseSuccessfulChunkedInput() { + int chunks = 10; + TestChunkedInput input = new TestChunkedInput(chunks); + EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler()); + + assertTrue(ch.writeOutbound(input)); + + for (int i = 0; i < chunks; i++) { + ByteBuf buf = ch.readOutbound(); + assertEquals(i, buf.readInt()); + buf.release(); + } + + assertTrue(input.isClosed()); + assertFalse(ch.finish()); + } + + @Test + public void testCloseFailedChunkedInput() { + Exception error = new Exception("Unable to produce a chunk"); + ThrowingChunkedInput input = new ThrowingChunkedInput(error); + + EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler()); + + try { + ch.writeOutbound(input); + fail("Exception expected"); + } catch (Exception e) { + assertEquals(error, e); + } + + assertTrue(input.isClosed()); + assertFalse(ch.finish()); + } + + @Test + public void testWriteListenerInvokedAfterSuccessfulChunkedInputClosed() throws Exception { + final TestChunkedInput input = new TestChunkedInput(2); + EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler()); + + final AtomicBoolean inputClosedWhenListenerInvoked = new AtomicBoolean(); + final CountDownLatch listenerInvoked = new CountDownLatch(1); + + ChannelFuture writeFuture = ch.write(input); + writeFuture.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + inputClosedWhenListenerInvoked.set(input.isClosed()); + listenerInvoked.countDown(); + } + }); + ch.flush(); + + assertTrue(listenerInvoked.await(10, SECONDS)); + assertTrue(writeFuture.isSuccess()); + assertTrue(inputClosedWhenListenerInvoked.get()); + assertTrue(ch.finishAndReleaseAll()); + } + + @Test + public void testWriteListenerInvokedAfterFailedChunkedInputClosed() throws Exception { + final ThrowingChunkedInput input = new ThrowingChunkedInput(new RuntimeException()); + EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler()); + + final AtomicBoolean inputClosedWhenListenerInvoked = new AtomicBoolean(); + final CountDownLatch listenerInvoked = new CountDownLatch(1); + + ChannelFuture writeFuture = ch.write(input); + writeFuture.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + inputClosedWhenListenerInvoked.set(input.isClosed()); + listenerInvoked.countDown(); + } + }); + ch.flush(); + + assertTrue(listenerInvoked.await(10, SECONDS)); + assertFalse(writeFuture.isSuccess()); + assertTrue(inputClosedWhenListenerInvoked.get()); + assertFalse(ch.finish()); + } + + @Test + public void testWriteListenerInvokedAfterChannelClosedAndInputFullyConsumed() throws Exception { + // use empty input which has endOfInput = true + final TestChunkedInput input = new TestChunkedInput(0); + EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler()); + + final AtomicBoolean inputClosedWhenListenerInvoked = new AtomicBoolean(); + final CountDownLatch listenerInvoked = new CountDownLatch(1); + + ChannelFuture writeFuture = ch.write(input); + writeFuture.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + inputClosedWhenListenerInvoked.set(input.isClosed()); + listenerInvoked.countDown(); + } + }); + ch.close(); // close channel to make handler discard the input on subsequent flush + ch.flush(); + + assertTrue(listenerInvoked.await(10, SECONDS)); + assertTrue(writeFuture.isSuccess()); + assertTrue(inputClosedWhenListenerInvoked.get()); + assertFalse(ch.finish()); + } + + @Test + public void testWriteListenerInvokedAfterChannelClosedAndInputNotFullyConsumed() throws Exception { + // use non-empty input which has endOfInput = false + final TestChunkedInput input = new TestChunkedInput(42); + EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler()); + + final AtomicBoolean inputClosedWhenListenerInvoked = new AtomicBoolean(); + final CountDownLatch listenerInvoked = new CountDownLatch(1); + + ChannelFuture writeFuture = ch.write(input); + writeFuture.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + inputClosedWhenListenerInvoked.set(input.isClosed()); + listenerInvoked.countDown(); + } + }); + ch.close(); // close channel to make handler discard the input on subsequent flush + ch.flush(); + + assertTrue(listenerInvoked.await(10, SECONDS)); + assertFalse(writeFuture.isSuccess()); + assertTrue(inputClosedWhenListenerInvoked.get()); + assertFalse(ch.finish()); + } + private static void check(Object... inputs) { EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler()); @@ -255,4 +599,159 @@ private static void check(Object... inputs) { assertEquals(BYTES.length * inputs.length, read); } + + private static void checkFirstFailed(Object input) { + ChannelOutboundHandlerAdapter noOpWrites = new ChannelOutboundHandlerAdapter() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + ReferenceCountUtil.release(msg); + promise.tryFailure(new RuntimeException()); + } + }; + + EmbeddedChannel ch = new EmbeddedChannel(noOpWrites, new ChunkedWriteHandler()); + ChannelFuture r = ch.writeAndFlush(input); + + // Should be `false` as we do not expect any messages to be written + assertFalse(ch.finish()); + assertTrue(r.cause() instanceof RuntimeException); + } + + private static void checkSkipFailed(Object input1, Object input2) { + ChannelOutboundHandlerAdapter failFirst = new ChannelOutboundHandlerAdapter() { + private boolean alreadyFailed; + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + if (alreadyFailed) { + ctx.write(msg, promise); + } else { + this.alreadyFailed = true; + ReferenceCountUtil.release(msg); + promise.tryFailure(new RuntimeException()); + } + } + }; + + EmbeddedChannel ch = new EmbeddedChannel(failFirst, new ChunkedWriteHandler()); + ChannelFuture r1 = ch.write(input1); + ChannelFuture r2 = ch.writeAndFlush(input2).awaitUninterruptibly(); + assertTrue(ch.finish()); + + assertTrue(r1.cause() instanceof RuntimeException); + assertTrue(r2.isSuccess()); + + // note, that after we've "skipped" the first write, + // we expect to see the second message, chunk by chunk + int i = 0; + int read = 0; + for (;;) { + ByteBuf buffer = ch.readOutbound(); + if (buffer == null) { + break; + } + while (buffer.isReadable()) { + assertEquals(BYTES[i++], buffer.readByte()); + read++; + if (i == BYTES.length) { + i = 0; + } + } + buffer.release(); + } + + assertEquals(BYTES.length, read); + } + + private static final class TestChunkedInput implements ChunkedInput { + private final int chunksToProduce; + + private int chunksProduced; + private volatile boolean closed; + + TestChunkedInput(int chunksToProduce) { + this.chunksToProduce = chunksToProduce; + } + + @Override + public boolean isEndOfInput() { + return chunksProduced >= chunksToProduce; + } + + @Override + public void close() { + closed = true; + } + + @Override + public ByteBuf readChunk(ChannelHandlerContext ctx) { + return readChunk(ctx.alloc()); + } + + @Override + public ByteBuf readChunk(ByteBufAllocator allocator) { + ByteBuf buf = allocator.buffer(); + buf.writeInt(chunksProduced); + chunksProduced++; + return buf; + } + + @Override + public long length() { + return chunksToProduce; + } + + @Override + public long progress() { + return chunksProduced; + } + + boolean isClosed() { + return closed; + } + } + + private static final class ThrowingChunkedInput implements ChunkedInput { + private final Exception error; + + private volatile boolean closed; + + ThrowingChunkedInput(Exception error) { + this.error = error; + } + + @Override + public boolean isEndOfInput() { + return false; + } + + @Override + public void close() { + closed = true; + } + + @Override + public ByteBuf readChunk(ChannelHandlerContext ctx) throws Exception { + return readChunk(ctx.alloc()); + } + + @Override + public ByteBuf readChunk(ByteBufAllocator allocator) throws Exception { + throw error; + } + + @Override + public long length() { + return -1; + } + + @Override + public long progress() { + return -1; + } + + boolean isClosed() { + return closed; + } + } } diff --git a/handler/src/test/java/io/netty/handler/timeout/IdleStateHandlerTest.java b/handler/src/test/java/io/netty/handler/timeout/IdleStateHandlerTest.java index 0a5b6278224e..a27364f43933 100644 --- a/handler/src/test/java/io/netty/handler/timeout/IdleStateHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/timeout/IdleStateHandlerTest.java @@ -317,7 +317,7 @@ private static class TestableIdleStateHandler extends IdleStateHandler { private long ticksInNanos; - public TestableIdleStateHandler(boolean observeOutput, + TestableIdleStateHandler(boolean observeOutput, long readerIdleTime, long writerIdleTime, long allIdleTime, TimeUnit unit) { super(observeOutput, readerIdleTime, writerIdleTime, allIdleTime, unit); @@ -369,7 +369,7 @@ ScheduledFuture schedule(ChannelHandlerContext ctx, Runnable task, long delay private static class ObservableChannel extends EmbeddedChannel { - public ObservableChannel(ChannelHandler... handlers) { + ObservableChannel(ChannelHandler... handlers) { super(handlers); } diff --git a/handler/src/test/resources/io/netty/handler/ssl/ec_params_unsupported.pem b/handler/src/test/resources/io/netty/handler/ssl/ec_params_unsupported.pem new file mode 100644 index 000000000000..cafaea422434 --- /dev/null +++ b/handler/src/test/resources/io/netty/handler/ssl/ec_params_unsupported.pem @@ -0,0 +1,18 @@ +-----BEGIN CERTIFICATE----- +MIIC8TCCApagAwIBAgIJAOeu9WKx0IutMAoGCCqGSM49BAMCMFkxCzAJBgNVBAYT +AkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRn +aXRzIFB0eSBMdGQxEjAQBgNVBAMMCWxvY2FsaG9zdDAeFw0xODExMDEyMDAwMTha +Fw0yMDEwMzEyMDAwMThaMFkxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0 +YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQxEjAQBgNVBAMM +CWxvY2FsaG9zdDCCAUswggEDBgcqhkjOPQIBMIH3AgEBMCwGByqGSM49AQECIQD/ +////AAAAAQAAAAAAAAAAAAAAAP///////////////zBbBCD/////AAAAAQAAAAAA +AAAAAAAAAP///////////////AQgWsY12Ko6k+ez671VdpiGvGUdBrDMU7D2O848 +PifSYEsDFQDEnTYIhucEk2pmeOETnSa3gZ9+kARBBGsX0fLhLEJH+Lzm5WOkQPJ3 +A32BLeszoPShOUXYmMKWT+NC4v4af5uO5+tKfA+eFivOM1drMV7Oy7ZAaDe/UfUC +IQD/////AAAAAP//////////vOb6racXnoTzucrC/GMlUQIBAQNCAAQ3G/YXF+YE +XuASiyC1822n0iNPumHgFplF+6/veicKm+mDNA3NA/1zTRKJOyqpDdMyB9tgFrdV +zcHzw7JW+lDpo1MwUTAdBgNVHQ4EFgQUonraQIcnNMppU+GoJ6+vPbC84pEwHwYD +VR0jBBgwFoAUonraQIcnNMppU+GoJ6+vPbC84pEwDwYDVR0TAQH/BAUwAwEB/zAK +BggqhkjOPQQDAgNJADBGAiEAoIkAinhds0VvNtWdi6f+r+U8AA9rUsR1sJBzVOYD +ErACIQCMMyfEWW8d4N3q8fpZ/lWTNaionVWeZZHWjseTmafWQg== +-----END CERTIFICATE----- diff --git a/microbench/pom.xml b/microbench/pom.xml index 0cabb582bf7c..75e7356f717f 100644 --- a/microbench/pom.xml +++ b/microbench/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-microbench @@ -31,9 +31,15 @@ true - 1.19 + 1.21 + + + + + true - linux @@ -42,6 +48,9 @@ linux + + ${jni.classifier} + @@ -55,6 +64,71 @@ + + mac + + + mac + + + + ${jni.classifier} + + + + + maven-compiler-plugin + + + **/*.java + + + + + + + + benchmark-jar + + + + org.apache.maven.plugins + maven-shade-plugin + 2.2 + + + package + + shade + + + microbenchmarks + + + org.openjdk.jmh.Main + + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + + + + @@ -78,6 +152,18 @@ netty-codec-redis ${project.version} + + ${project.groupId} + netty-transport-native-epoll + ${project.version} + ${epoll.classifier} + + + ${project.groupId} + netty-transport-native-kqueue + ${project.version} + ${kqueue.classifier} + junit junit diff --git a/microbench/src/main/java/io/netty/buffer/AbstractByteBufGetCharSequenceBenchmark.java b/microbench/src/main/java/io/netty/buffer/AbstractByteBufGetCharSequenceBenchmark.java new file mode 100644 index 000000000000..6fcdb94cc893 --- /dev/null +++ b/microbench/src/main/java/io/netty/buffer/AbstractByteBufGetCharSequenceBenchmark.java @@ -0,0 +1,124 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import io.netty.microbench.util.AbstractMicrobenchmark; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; + +import java.nio.charset.Charset; +import java.util.Arrays; +import java.util.concurrent.TimeUnit; + +@Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS) +@Measurement(iterations = 10, time = 1, timeUnit = TimeUnit.SECONDS) +public class AbstractByteBufGetCharSequenceBenchmark extends AbstractMicrobenchmark { + + public enum ByteBufType { + DIRECT { + @Override + ByteBuf newBuffer(byte[] bytes, int length) { + ByteBuf buffer = Unpooled.directBuffer(length); + buffer.writeBytes(bytes, 0, length); + return buffer; + } + }, + HEAP_OFFSET { + @Override + ByteBuf newBuffer(byte[] bytes, int length) { + return Unpooled.wrappedBuffer(bytes, 1, length); + } + }, + HEAP { + @Override + ByteBuf newBuffer(byte[] bytes, int length) { + return Unpooled.wrappedBuffer(bytes, 0, length); + } + }, + COMPOSITE { + @Override + ByteBuf newBuffer(byte[] bytes, int length) { + CompositeByteBuf buffer = Unpooled.compositeBuffer(); + int offset = 0; + // 8 buffers per composite. + int capacity = length / 8; + + while (length > 0) { + buffer.addComponent(true, Unpooled.wrappedBuffer(bytes, offset, Math.min(length, capacity))); + length -= capacity; + offset += capacity; + } + return buffer; + } + }; + abstract ByteBuf newBuffer(byte[] bytes, int length); + } + + @Param({ "8", "64", "1024", "10240", "1073741824" }) + public int size; + + @Param({ "US-ASCII", "ISO_8859_1" }) + public String charsetName; + + @Param + public ByteBufType bufferType; + + private ByteBuf buffer; + private Charset charset; + + @Override + protected String[] jvmArgs() { + // Ensure we minimize the GC overhead by sizing the heap big enough. + return new String[] { "-XX:MaxDirectMemorySize=2g", "-Xmx8g", "-Xms8g", "-Xmn6g" }; + } + + @Setup + public void setup() { + byte[] bytes = new byte[size + 2]; + Arrays.fill(bytes, (byte) 'a'); + + // Use an offset to not allow any optimizations because we use the exact passed in byte[] for heap buffers. + buffer = bufferType.newBuffer(bytes, size); + charset = Charset.forName(charsetName); + } + + @TearDown + public void teardown() { + buffer.release(); + } + + @Benchmark + public int getCharSequence() { + return traverse(buffer.getCharSequence(buffer.readerIndex(), size, charset)); + } + + @Benchmark + public int getCharSequenceOld() { + return traverse(buffer.toString(buffer.readerIndex(), size, charset)); + } + + private static int traverse(CharSequence cs) { + int i = 0, len = cs.length(); + while (i < len && cs.charAt(i++) != 0) { + // ensure result is "used" + } + return i; + } +} diff --git a/microbench/src/main/java/io/netty/buffer/ByteBufAccessBenchmark.java b/microbench/src/main/java/io/netty/buffer/ByteBufAccessBenchmark.java new file mode 100644 index 000000000000..2614408a9fb3 --- /dev/null +++ b/microbench/src/main/java/io/netty/buffer/ByteBufAccessBenchmark.java @@ -0,0 +1,157 @@ +/* +* Copyright 2019 The Netty Project +* +* The Netty Project licenses this file to you under the Apache License, +* version 2.0 (the "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at: +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +package io.netty.buffer; + +import java.nio.ByteBuffer; +import java.util.concurrent.TimeUnit; + +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; + +import io.netty.microbench.util.AbstractMicrobenchmark; +import io.netty.util.internal.PlatformDependent; + +@Warmup(iterations = 5, time = 1500, timeUnit = TimeUnit.MILLISECONDS) +@Measurement(iterations = 10, time = 1500, timeUnit = TimeUnit.MILLISECONDS) +@Fork(3) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +public class ByteBufAccessBenchmark extends AbstractMicrobenchmark { + + static final class NioFacade extends WrappedByteBuf { + private final ByteBuffer byteBuffer; + NioFacade(ByteBuffer byteBuffer) { + super(Unpooled.EMPTY_BUFFER); + this.byteBuffer = byteBuffer; + } + @Override + public ByteBuf setLong(int index, long value) { + byteBuffer.putLong(index, value); + return this; + } + @Override + public long getLong(int index) { + return byteBuffer.getLong(index); + } + @Override + public byte readByte() { + return byteBuffer.get(); + } + @Override + public ByteBuf touch() { + // hack since WrappedByteBuf.readerIndex(int) is final + byteBuffer.position(0); + return this; + } + @Override + public boolean release() { + PlatformDependent.freeDirectWithCleaner(byteBuffer); + return true; + } + } + + public enum ByteBufType { + UNSAFE { + @Override + ByteBuf newBuffer() { + return new UnpooledUnsafeDirectByteBuf( + UnpooledByteBufAllocator.DEFAULT, 64, 64).setIndex(0, 64); + } + }, + UNSAFE_SLICE { + @Override + ByteBuf newBuffer() { + return UNSAFE.newBuffer().slice(16, 48); + } + }, + HEAP { + @Override + ByteBuf newBuffer() { + return new UnpooledUnsafeHeapByteBuf( + UnpooledByteBufAllocator.DEFAULT, 64, 64).setIndex(0, 64); + } + }, + COMPOSITE { + @Override + ByteBuf newBuffer() { + return Unpooled.wrappedBuffer(UNSAFE.newBuffer(), HEAP.newBuffer()); + } + }, + NIO { + @Override + ByteBuf newBuffer() { + return new NioFacade(ByteBuffer.allocateDirect(64)); + } + }; + abstract ByteBuf newBuffer(); + } + + @Param + public ByteBufType bufferType; + + @Param({ "true", "false" }) + public String checkAccessible; + + @Param({ "true", "false" }) + public String checkBounds; + + @Param({ "8" }) + public int batchSize; // applies only to readBatch benchmark + + @Setup + public void setup() { + System.setProperty("io.netty.buffer.checkAccessible", checkAccessible); + System.setProperty("io.netty.buffer.checkBounds", checkBounds); + buffer = bufferType.newBuffer(); + } + + private ByteBuf buffer; + + @TearDown + public void tearDown() { + buffer.release(); + System.clearProperty("io.netty.buffer.checkAccessible"); + System.clearProperty("io.netty.buffer.checkBounds"); + } + + @Benchmark + public long setGetLong() { + return buffer.setLong(0, 1).getLong(0); + } + + @Benchmark + public ByteBuf setLong() { + return buffer.setLong(0, 1); + } + + @Benchmark + public int readBatch() { + buffer.readerIndex(0).touch(); + int result = 0; + for (int i = 0, size = batchSize; i < size; i++) { + result += buffer.readByte(); + } + return result; + } +} diff --git a/microbench/src/main/java/io/netty/buffer/ByteBufUtilDecodeStringBenchmark.java b/microbench/src/main/java/io/netty/buffer/ByteBufUtilDecodeStringBenchmark.java new file mode 100644 index 000000000000..367a645af094 --- /dev/null +++ b/microbench/src/main/java/io/netty/buffer/ByteBufUtilDecodeStringBenchmark.java @@ -0,0 +1,112 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import io.netty.microbench.util.AbstractMicrobenchmark; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; + +import java.nio.charset.Charset; +import java.util.Arrays; +import java.util.concurrent.TimeUnit; + +@Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS) +@Measurement(iterations = 10, time = 1, timeUnit = TimeUnit.SECONDS) +public class ByteBufUtilDecodeStringBenchmark extends AbstractMicrobenchmark { + + public enum ByteBufType { + DIRECT { + @Override + ByteBuf newBuffer(byte[] bytes, int length) { + ByteBuf buffer = Unpooled.directBuffer(length); + buffer.writeBytes(bytes, 0, length); + return buffer; + } + }, + HEAP_OFFSET { + @Override + ByteBuf newBuffer(byte[] bytes, int length) { + return Unpooled.wrappedBuffer(bytes, 1, length); + } + }, + HEAP { + @Override + ByteBuf newBuffer(byte[] bytes, int length) { + return Unpooled.wrappedBuffer(bytes, 0, length); + } + }, + COMPOSITE { + @Override + ByteBuf newBuffer(byte[] bytes, int length) { + CompositeByteBuf buffer = Unpooled.compositeBuffer(); + int offset = 0; + // 8 buffers per composite. + int capacity = length / 8; + + while (length > 0) { + buffer.addComponent(true, Unpooled.wrappedBuffer(bytes, offset, Math.min(length, capacity))); + length -= capacity; + offset += capacity; + } + return buffer; + } + }; + + abstract ByteBuf newBuffer(byte[] bytes, int length); + } + + @Param({ "8", "64", "1024", "10240", "1073741824" }) + public int size; + + @Param({ "US-ASCII", "UTF-8" }) + public String charsetName; + + @Param + public ByteBufType bufferType; + + private ByteBuf buffer; + private Charset charset; + + @Override + protected String[] jvmArgs() { + // Ensure we minimize the GC overhead by sizing the heap big enough. + return new String[] { "-XX:MaxDirectMemorySize=2g", "-Xmx8g", "-Xms8g", "-Xmn6g" }; + } + + @Setup + public void setup() { + byte[] bytes = new byte[size + 2]; + Arrays.fill(bytes, (byte) 'a'); + + // Use an offset to not allow any optimizations because we use the exact passed in byte[] for heap buffers. + buffer = bufferType.newBuffer(bytes, size); + charset = Charset.forName(charsetName); + } + + @TearDown + public void teardown() { + buffer.release(); + } + + @Benchmark + public String decodeString() { + return ByteBufUtil.decodeString(buffer, buffer.readerIndex(), size, charset); + } +} diff --git a/microbench/src/main/java/io/netty/buffer/CompositeByteBufRandomAccessBenchmark.java b/microbench/src/main/java/io/netty/buffer/CompositeByteBufRandomAccessBenchmark.java new file mode 100644 index 000000000000..bb2d9283f0f9 --- /dev/null +++ b/microbench/src/main/java/io/netty/buffer/CompositeByteBufRandomAccessBenchmark.java @@ -0,0 +1,118 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; + +import io.netty.microbench.util.AbstractMicrobenchmark; + +import static io.netty.buffer.Unpooled.EMPTY_BUFFER; +import static io.netty.buffer.Unpooled.wrappedBuffer; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.concurrent.TimeUnit; + +@Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS) +@Measurement(iterations = 10, time = 1, timeUnit = TimeUnit.SECONDS) +public class CompositeByteBufRandomAccessBenchmark extends AbstractMicrobenchmark { + + public enum ByteBufType { + SMALL_CHUNKS { + @Override + ByteBuf newBuffer(int length) { + return newBufferSmallChunks(length); + } + }, + LARGE_CHUNKS { + @Override + ByteBuf newBuffer(int length) { + return newBufferLargeChunks(length); + } + }; + abstract ByteBuf newBuffer(int length); + } + + @Param({ "64", "10240", "1024000" }) // ({ "64", "1024", "10240", "102400", "1024000" }) + public int size; + + @Param + public ByteBufType bufferType; + + private ByteBuf buffer; + private Random random; + + @Setup + public void setup() { + buffer = bufferType.newBuffer(size); + random = new Random(0L); + } + + @TearDown + public void teardown() { + buffer.release(); + } + + @Benchmark + public long setGetLong() { + int i = random.nextInt(size - 8); + return buffer.setLong(i, 1).getLong(i); + } + + @Benchmark + public ByteBuf setLong() { + int i = random.nextInt(size - 8); + return buffer.setLong(i, 1); + } + + private static ByteBuf newBufferSmallChunks(int length) { + + List buffers = new ArrayList(((length + 1) / 45) * 19); + for (int i = 0; i < length + 45; i += 45) { + for (int j = 1; j <= 9; j++) { + buffers.add(EMPTY_BUFFER); + buffers.add(wrappedBuffer(new byte[j])); + } + buffers.add(EMPTY_BUFFER); + } + + ByteBuf buffer = wrappedBuffer(Integer.MAX_VALUE, buffers.toArray(new ByteBuf[0])); + + // Truncate to the requested capacity. + return buffer.capacity(length).writerIndex(0); + } + + private static ByteBuf newBufferLargeChunks(int length) { + + List buffers = new ArrayList((length + 1) / 512); + for (int i = 0; i < length + 1536; i += 1536) { + buffers.add(wrappedBuffer(new byte[512])); + buffers.add(EMPTY_BUFFER); + buffers.add(wrappedBuffer(new byte[1024])); + } + + ByteBuf buffer = wrappedBuffer(Integer.MAX_VALUE, buffers.toArray(new ByteBuf[0])); + + // Truncate to the requested capacity. + return buffer.capacity(length).writerIndex(0); + } +} diff --git a/microbench/src/main/java/io/netty/buffer/CompositeByteBufSequentialBenchmark.java b/microbench/src/main/java/io/netty/buffer/CompositeByteBufSequentialBenchmark.java new file mode 100644 index 000000000000..9ba6f19aa825 --- /dev/null +++ b/microbench/src/main/java/io/netty/buffer/CompositeByteBufSequentialBenchmark.java @@ -0,0 +1,132 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import io.netty.microbench.util.AbstractMicrobenchmark; +import io.netty.util.ByteProcessor; + +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; + +import static io.netty.buffer.Unpooled.EMPTY_BUFFER; +import static io.netty.buffer.Unpooled.wrappedBuffer; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; + +@Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS) +@Measurement(iterations = 10, time = 1, timeUnit = TimeUnit.SECONDS) +public class CompositeByteBufSequentialBenchmark extends AbstractMicrobenchmark { + + public enum ByteBufType { + SMALL_CHUNKS { + @Override + ByteBuf newBuffer(int length) { + return newBufferSmallChunks(length); + } + }, + LARGE_CHUNKS { + @Override + ByteBuf newBuffer(int length) { + return newBufferLargeChunks(length); + } + }; + abstract ByteBuf newBuffer(int length); + } + + @Param({ "8", "64", "1024", "10240", "102400", "1024000" }) + public int size; + + @Param + public ByteBufType bufferType; + + private ByteBuf buffer; + + @Setup + public void setup() { + buffer = bufferType.newBuffer(size); + } + + @TearDown + public void teardown() { + buffer.release(); + } + + private static final ByteProcessor TEST_PROCESSOR = new ByteProcessor() { + @Override + public boolean process(byte value) throws Exception { + return value == 'b'; // false + } + }; + + @Benchmark + public int forEachByte() { + buffer.setIndex(0, buffer.capacity()); + buffer.forEachByte(TEST_PROCESSOR); + return buffer.forEachByteDesc(TEST_PROCESSOR); + } + + @Benchmark + public int sequentialWriteAndRead() { + buffer.clear(); + for (int i = 0, l = buffer.writableBytes(); i < l; i++) { + buffer.writeByte('a'); + } + for (int i = 0, l = buffer.readableBytes(); i < l; i++) { + if (buffer.readByte() == 'b') { + return -1; + } + } + return 1; + } + + private static ByteBuf newBufferSmallChunks(int length) { + + List buffers = new ArrayList(((length + 1) / 45) * 19); + for (int i = 0; i < length + 45; i += 45) { + for (int j = 1; j <= 9; j++) { + buffers.add(EMPTY_BUFFER); + buffers.add(wrappedBuffer(new byte[j])); + } + buffers.add(EMPTY_BUFFER); + } + + ByteBuf buffer = wrappedBuffer(Integer.MAX_VALUE, buffers.toArray(new ByteBuf[0])); + + // Truncate to the requested capacity. + return buffer.capacity(length).writerIndex(0); + } + + private static ByteBuf newBufferLargeChunks(int length) { + + List buffers = new ArrayList((length + 1) / 512); + for (int i = 0; i < length + 1536; i += 1536) { + buffers.add(wrappedBuffer(new byte[512])); + buffers.add(EMPTY_BUFFER); + buffers.add(wrappedBuffer(new byte[1024])); + } + + ByteBuf buffer = wrappedBuffer(Integer.MAX_VALUE, buffers.toArray(new ByteBuf[0])); + + // Truncate to the requested capacity. + return buffer.capacity(length).writerIndex(0); + } +} diff --git a/microbench/src/main/java/io/netty/buffer/CompositeByteBufWriteOutBenchmark.java b/microbench/src/main/java/io/netty/buffer/CompositeByteBufWriteOutBenchmark.java new file mode 100644 index 000000000000..e525c8cd0aa5 --- /dev/null +++ b/microbench/src/main/java/io/netty/buffer/CompositeByteBufWriteOutBenchmark.java @@ -0,0 +1,114 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; + +import io.netty.microbench.util.AbstractMicrobenchmark; + +import static io.netty.buffer.Unpooled.wrappedBuffer; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; + +@Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS) +@Measurement(iterations = 12, time = 1, timeUnit = TimeUnit.SECONDS) +public class CompositeByteBufWriteOutBenchmark extends AbstractMicrobenchmark { + + public enum ByteBufType { + SMALL_CHUNKS { + @Override + ByteBuf[] sourceBuffers(int length) { + return makeSmallChunks(length); + } + }, + LARGE_CHUNKS { + @Override + ByteBuf[] sourceBuffers(int length) { + return makeLargeChunks(length); + } + }; + abstract ByteBuf[] sourceBuffers(int length); + } + + @Override + protected String[] jvmArgs() { + // Ensure we minimize the GC overhead by sizing the heap big enough. + return new String[] { "-XX:MaxDirectMemorySize=2g", "-Xmx4g", "-Xms4g", "-Xmn3g" }; + } + + @Param({ "64", "1024", "10240", "102400", "1024000" }) + public int size; + + @Param + public ByteBufType bufferType; + + private ByteBuf targetBuffer; + + private ByteBuf[] sourceBufs; + + @Setup + public void setup() { + targetBuffer = PooledByteBufAllocator.DEFAULT.directBuffer(size + 2048); + sourceBufs = bufferType.sourceBuffers(size); + } + + @TearDown + public void teardown() { + targetBuffer.release(); + } + + @Benchmark + public int writeCBB() { + ByteBuf cbb = Unpooled.wrappedBuffer(Integer.MAX_VALUE, sourceBufs); // CompositeByteBuf + return targetBuffer.clear().writeBytes(cbb).readableBytes(); + } + + @Benchmark + public int writeFCBB() { + ByteBuf cbb = Unpooled.wrappedUnmodifiableBuffer(sourceBufs); // FastCompositeByteBuf + return targetBuffer.clear().writeBytes(cbb).readableBytes(); + } + + private static ByteBuf[] makeSmallChunks(int length) { + + List buffers = new ArrayList(((length + 1) / 48) * 9); + for (int i = 0; i < length + 48; i += 48) { + for (int j = 4; j <= 12; j++) { + buffers.add(wrappedBuffer(new byte[j])); + } + } + + return buffers.toArray(new ByteBuf[0]); + } + + private static ByteBuf[] makeLargeChunks(int length) { + + List buffers = new ArrayList((length + 1) / 768); + for (int i = 0; i < length + 1536; i += 1536) { + buffers.add(wrappedBuffer(new byte[512])); + buffers.add(wrappedBuffer(new byte[1024])); + } + + return buffers.toArray(new ByteBuf[0]); + } +} diff --git a/microbench/src/main/java/io/netty/handler/codec/CodecOutputListBenchmark.java b/microbench/src/main/java/io/netty/handler/codec/CodecOutputListBenchmark.java index 8db4a99ffc25..d7d40307a5e0 100644 --- a/microbench/src/main/java/io/netty/handler/codec/CodecOutputListBenchmark.java +++ b/microbench/src/main/java/io/netty/handler/codec/CodecOutputListBenchmark.java @@ -18,10 +18,8 @@ import io.netty.microbench.util.AbstractMicrobenchmark; import io.netty.util.internal.RecyclableArrayList; import org.openjdk.jmh.annotations.Benchmark; -import org.openjdk.jmh.annotations.Level; import org.openjdk.jmh.annotations.Param; import org.openjdk.jmh.annotations.Scope; -import org.openjdk.jmh.annotations.Setup; import org.openjdk.jmh.annotations.State; import org.openjdk.jmh.annotations.TearDown; @@ -39,13 +37,6 @@ public class CodecOutputListBenchmark extends AbstractMicrobenchmark { @Param({ "1", "4" }) public int elements; - @Setup(Level.Invocation) - public void setup() { - codecOutputList = CodecOutputList.newInstance(); - recycleableArrayList = RecyclableArrayList.newInstance(16); - arrayList = new ArrayList(16); - } - @TearDown public void destroy() { codecOutputList.recycle(); @@ -54,16 +45,19 @@ public void destroy() { @Benchmark public void codecOutList() { + codecOutputList = CodecOutputList.newInstance(); benchmarkAddAndClear(codecOutputList, elements); } @Benchmark public void recyclableArrayList() { + recycleableArrayList = RecyclableArrayList.newInstance(16); benchmarkAddAndClear(recycleableArrayList, elements); } @Benchmark public void arrayList() { + arrayList = new ArrayList(16); benchmarkAddAndClear(arrayList, elements); } diff --git a/microbench/src/main/java/io/netty/handler/codec/DateFormatter2Benchmark.java b/microbench/src/main/java/io/netty/handler/codec/DateFormatter2Benchmark.java new file mode 100644 index 000000000000..dd44ed60cc26 --- /dev/null +++ b/microbench/src/main/java/io/netty/handler/codec/DateFormatter2Benchmark.java @@ -0,0 +1,95 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec; + +import io.netty.microbench.util.AbstractMicrobenchmark; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Warmup; + +import java.util.Date; + +@Threads(1) +@Warmup(iterations = 5) +@Measurement(iterations = 5) +public class DateFormatter2Benchmark extends AbstractMicrobenchmark { + + @Param({"Sun, 27 Jan 2016 19:18:46 GMT", "Sun, 27 Dec 2016 19:18:46 GMT"}) + String DATE_STRING; + + @Benchmark + public Date parseHttpHeaderDateFormatterNew() { + return DateFormatter.parseHttpDate(DATE_STRING); + } + + /* + @Benchmark + public Date parseHttpHeaderDateFormatter() { + return DateFormatterOld.parseHttpDate(DATE_STRING); + } + */ + + /* + * Benchmark (DATE_STRING) Mode Cnt Score Error Units + * parseHttpHeaderDateFormatter Sun, 27 Jan 2016 19:18:46 GMT thrpt 6 4142781.221 ± 82155.002 ops/s + * parseHttpHeaderDateFormatter Sun, 27 Dec 2016 19:18:46 GMT thrpt 6 3781810.558 ± 38679.061 ops/s + * parseHttpHeaderDateFormatterNew Sun, 27 Jan 2016 19:18:46 GMT thrpt 6 4372569.705 ± 30257.537 ops/s + * parseHttpHeaderDateFormatterNew Sun, 27 Dec 2016 19:18:46 GMT thrpt 6 4339785.100 ± 57542.660 ops/s + */ + + /*Old DateFormatter.tryParseMonth method: + private boolean tryParseMonth(CharSequence txt, int tokenStart, int tokenEnd) { + int len = tokenEnd - tokenStart; + + if (len != 3) { + return false; + } + + if (matchMonth("Jan", txt, tokenStart)) { + month = Calendar.JANUARY; + } else if (matchMonth("Feb", txt, tokenStart)) { + month = Calendar.FEBRUARY; + } else if (matchMonth("Mar", txt, tokenStart)) { + month = Calendar.MARCH; + } else if (matchMonth("Apr", txt, tokenStart)) { + month = Calendar.APRIL; + } else if (matchMonth("May", txt, tokenStart)) { + month = Calendar.MAY; + } else if (matchMonth("Jun", txt, tokenStart)) { + month = Calendar.JUNE; + } else if (matchMonth("Jul", txt, tokenStart)) { + month = Calendar.JULY; + } else if (matchMonth("Aug", txt, tokenStart)) { + month = Calendar.AUGUST; + } else if (matchMonth("Sep", txt, tokenStart)) { + month = Calendar.SEPTEMBER; + } else if (matchMonth("Oct", txt, tokenStart)) { + month = Calendar.OCTOBER; + } else if (matchMonth("Nov", txt, tokenStart)) { + month = Calendar.NOVEMBER; + } else if (matchMonth("Dec", txt, tokenStart)) { + month = Calendar.DECEMBER; + } else { + return false; + } + + return true; + } + */ + +} diff --git a/microbench/src/main/java/io/netty/handler/codec/http2/HpackBenchmarkUtil.java b/microbench/src/main/java/io/netty/handler/codec/http2/HpackBenchmarkUtil.java index c94bc59a8c75..e8a3a56d1ea8 100644 --- a/microbench/src/main/java/io/netty/handler/codec/http2/HpackBenchmarkUtil.java +++ b/microbench/src/main/java/io/netty/handler/codec/http2/HpackBenchmarkUtil.java @@ -49,7 +49,7 @@ private static class HeadersKey { final HpackHeadersSize size; final boolean limitToAscii; - public HeadersKey(HpackHeadersSize size, boolean limitToAscii) { + HeadersKey(HpackHeadersSize size, boolean limitToAscii) { this.size = size; this.limitToAscii = limitToAscii; } diff --git a/microbench/src/main/java/io/netty/microbench/buffer/ByteBufBenchmark.java b/microbench/src/main/java/io/netty/microbench/buffer/ByteBufBenchmark.java index 70d1ceef532a..c1cfa39add1d 100644 --- a/microbench/src/main/java/io/netty/microbench/buffer/ByteBufBenchmark.java +++ b/microbench/src/main/java/io/netty/microbench/buffer/ByteBufBenchmark.java @@ -20,6 +20,7 @@ import io.netty.buffer.Unpooled; import io.netty.microbench.util.AbstractMicrobenchmark; import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Param; import org.openjdk.jmh.annotations.Setup; import org.openjdk.jmh.annotations.TearDown; @@ -27,10 +28,13 @@ public class ByteBufBenchmark extends AbstractMicrobenchmark { static { - System.setProperty("io.netty.buffer.bytebuf.checkAccessible", "false"); + System.setProperty("io.netty.buffer.checkAccessible", "false"); } private static final byte BYTE = '0'; + @Param({ "true", "false" }) + public String checkBounds; + private ByteBuffer byteBuffer; private ByteBuffer directByteBuffer; private ByteBuf buffer; @@ -39,6 +43,7 @@ public class ByteBufBenchmark extends AbstractMicrobenchmark { @Setup public void setup() { + System.setProperty("io.netty.buffer.checkBounds", checkBounds); byteBuffer = ByteBuffer.allocate(8); directByteBuffer = ByteBuffer.allocateDirect(8); buffer = Unpooled.buffer(8); diff --git a/microbench/src/main/java/io/netty/microbench/buffer/HeapByteBufBenchmark.java b/microbench/src/main/java/io/netty/microbench/buffer/HeapByteBufBenchmark.java index b7cf0c5fa8fe..06df3af9cb95 100644 --- a/microbench/src/main/java/io/netty/microbench/buffer/HeapByteBufBenchmark.java +++ b/microbench/src/main/java/io/netty/microbench/buffer/HeapByteBufBenchmark.java @@ -19,6 +19,7 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.microbench.util.AbstractMicrobenchmark; import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Param; import org.openjdk.jmh.annotations.Setup; import org.openjdk.jmh.annotations.TearDown; @@ -26,6 +27,9 @@ public class HeapByteBufBenchmark extends AbstractMicrobenchmark { + @Param({ "true", "false" }) + public String checkBounds; + private ByteBuf unsafeBuffer; private ByteBuf buffer; @@ -39,6 +43,7 @@ private static ByteBuf newBuffer(String classname) throws Exception { @Setup public void setup() throws Exception { + System.setProperty("io.netty.buffer.bytebuf.checkBounds", checkBounds); unsafeBuffer = newBuffer("io.netty.buffer.UnpooledUnsafeHeapByteBuf"); buffer = newBuffer("io.netty.buffer.UnpooledHeapByteBuf"); unsafeBuffer.writeLong(1L); diff --git a/microbench/src/main/java/io/netty/microbench/buffer/UnsafeByteBufBenchmark.java b/microbench/src/main/java/io/netty/microbench/buffer/UnsafeByteBufBenchmark.java new file mode 100644 index 000000000000..435feb61d8f3 --- /dev/null +++ b/microbench/src/main/java/io/netty/microbench/buffer/UnsafeByteBufBenchmark.java @@ -0,0 +1,64 @@ +/* +* Copyright 2018 The Netty Project +* +* The Netty Project licenses this file to you under the Apache License, +* version 2.0 (the "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at: +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +package io.netty.microbench.buffer; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.buffer.UnpooledUnsafeDirectByteBuf; +import io.netty.microbench.util.AbstractMicrobenchmark; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; + +import java.nio.ByteBuffer; + + +public class UnsafeByteBufBenchmark extends AbstractMicrobenchmark { + + private ByteBuf unsafeBuffer; + private ByteBuffer byteBuffer; + + @Setup + public void setup() { + unsafeBuffer = new UnpooledUnsafeDirectByteBuf(UnpooledByteBufAllocator.DEFAULT, 64, 64); + byteBuffer = ByteBuffer.allocateDirect(64); + } + + @TearDown + public void tearDown() { + unsafeBuffer.release(); + } + + @Benchmark + public long setGetLongUnsafeByteBuf() { + return unsafeBuffer.setLong(0, 1).getLong(0); + } + + @Benchmark + public long setGetLongByteBuffer() { + return byteBuffer.putLong(0, 1).getLong(0); + } + + @Benchmark + public ByteBuf setLongUnsafeByteBuf() { + return unsafeBuffer.setLong(0, 1); + } + + @Benchmark + public ByteBuffer setLongByteBuffer() { + return byteBuffer.putLong(0, 1); + } +} diff --git a/microbench/src/main/java/io/netty/microbench/channel/epoll/EpollSocketChannelBenchmark.java b/microbench/src/main/java/io/netty/microbench/channel/epoll/EpollSocketChannelBenchmark.java new file mode 100644 index 000000000000..5ecd18668d21 --- /dev/null +++ b/microbench/src/main/java/io/netty/microbench/channel/epoll/EpollSocketChannelBenchmark.java @@ -0,0 +1,139 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.microbench.channel.epoll; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPromise; +import io.netty.channel.epoll.EpollEventLoopGroup; +import io.netty.channel.epoll.EpollServerSocketChannel; +import io.netty.channel.epoll.EpollSocketChannel; +import io.netty.microbench.util.AbstractMicrobenchmark; +import io.netty.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.TearDown; + +public class EpollSocketChannelBenchmark extends AbstractMicrobenchmark { + + private EpollEventLoopGroup group; + private Channel serverChan; + private Channel chan; + private ByteBuf abyte; + private ScheduledFuture future; + + @Setup + public void setup() throws Exception { + group = new EpollEventLoopGroup(1); + + // add an arbitrary timeout to make the timer reschedule + future = group.schedule(new Runnable() { + @Override + public void run() { + throw new AssertionError(); + } + }, 5, TimeUnit.MINUTES); + serverChan = new ServerBootstrap() + .channel(EpollServerSocketChannel.class) + .group(group) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(new ChannelDuplexHandler() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + if (msg instanceof ByteBuf) { + ctx.writeAndFlush(msg, ctx.voidPromise()); + } else { + throw new AssertionError(); + } + } + }); + } + }) + .bind(0) + .sync() + .channel(); + chan = new Bootstrap() + .channel(EpollSocketChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(new ChannelDuplexHandler() { + + private ChannelPromise lastWritePromise; + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + if (msg instanceof ByteBuf) { + + ByteBuf buf = (ByteBuf) msg; + try { + if (buf.readableBytes() == 1) { + lastWritePromise.trySuccess(); + lastWritePromise = null; + } else { + throw new AssertionError(); + } + } finally { + buf.release(); + } + } else { + throw new AssertionError(); + } + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) + throws Exception { + if (lastWritePromise != null) { + throw new IllegalStateException(); + } + lastWritePromise = promise; + super.write(ctx, msg, ctx.voidPromise()); + } + }); + } + }) + .group(group) + .connect(serverChan.localAddress()) + .sync() + .channel(); + + abyte = chan.alloc().directBuffer(1); + abyte.writeByte('a'); + } + + @TearDown + public void tearDown() throws Exception { + chan.close().sync(); + serverChan.close().sync(); + future.cancel(true); + group.shutdownGracefully(0, 0, TimeUnit.SECONDS).sync(); + abyte.release(); + } + + @Benchmark + public Object pingPong() throws Exception { + return chan.pipeline().writeAndFlush(abyte.retainedSlice()).sync(); + } +} diff --git a/microbench/src/main/java/io/netty/microbench/channel/epoll/package-info.java b/microbench/src/main/java/io/netty/microbench/channel/epoll/package-info.java new file mode 100644 index 000000000000..08304d05d00c --- /dev/null +++ b/microbench/src/main/java/io/netty/microbench/channel/epoll/package-info.java @@ -0,0 +1,19 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +/** + * Benchmarks for {@link io.netty.microbench.channel.epoll}. + */ +package io.netty.microbench.channel.epoll; diff --git a/microbench/src/main/java/io/netty/microbench/concurrent/BurstCostExecutorsBenchmark.java b/microbench/src/main/java/io/netty/microbench/concurrent/BurstCostExecutorsBenchmark.java new file mode 100644 index 000000000000..acab0f033d4f --- /dev/null +++ b/microbench/src/main/java/io/netty/microbench/concurrent/BurstCostExecutorsBenchmark.java @@ -0,0 +1,331 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.microbench.concurrent; + +import io.netty.channel.epoll.Epoll; +import io.netty.channel.epoll.EpollEventLoopGroup; +import io.netty.channel.kqueue.KQueue; +import io.netty.channel.kqueue.KQueueEventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.microbench.util.AbstractMicrobenchmark; +import io.netty.util.concurrent.DefaultEventExecutor; +import io.netty.util.internal.PlatformDependent; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.infra.Blackhole; + +import java.util.Collection; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; + +@State(Scope.Benchmark) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +public class BurstCostExecutorsBenchmark extends AbstractMicrobenchmark { + + /** + * This executor is useful as the best burst latency performer because it won't go to sleep and won't be hit by the + * cost of being awaken on both offer/consumer side. + */ + private static final class SpinExecutorService implements ExecutorService { + + private static final Runnable POISON_PILL = new Runnable() { + @Override + public void run() { + } + }; + private final Queue tasks; + private final AtomicBoolean poisoned = new AtomicBoolean(); + private final Thread executorThread; + + SpinExecutorService(int maxTasks) { + tasks = PlatformDependent.newFixedMpscQueue(maxTasks); + executorThread = new Thread(new Runnable() { + @Override + public void run() { + final Queue tasks = SpinExecutorService.this.tasks; + Runnable task; + while ((task = tasks.poll()) != POISON_PILL) { + if (task != null) { + task.run(); + } + } + } + }); + executorThread.start(); + } + + @Override + public void shutdown() { + if (poisoned.compareAndSet(false, true)) { + while (!tasks.offer(POISON_PILL)) { + // Just try again + } + try { + executorThread.join(); + } catch (InterruptedException e) { + //We're quite trusty :) + } + } + } + + @Override + public List shutdownNow() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isShutdown() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isTerminated() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + throw new UnsupportedOperationException(); + } + + @Override + public Future submit(Callable task) { + throw new UnsupportedOperationException(); + } + + @Override + public Future submit(Runnable task, T result) { + throw new UnsupportedOperationException(); + } + + @Override + public Future submit(Runnable task) { + throw new UnsupportedOperationException(); + } + + @Override + public List> invokeAll(Collection> tasks) throws InterruptedException { + throw new UnsupportedOperationException(); + } + + @Override + public List> invokeAll(Collection> tasks, long timeout, TimeUnit unit) + throws InterruptedException { + throw new UnsupportedOperationException(); + } + + @Override + public T invokeAny(Collection> tasks) + throws InterruptedException, ExecutionException { + throw new UnsupportedOperationException(); + } + + @Override + public T invokeAny(Collection> tasks, long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, TimeoutException { + throw new UnsupportedOperationException(); + } + + @Override + public void execute(Runnable command) { + if (!tasks.offer(command)) { + throw new RejectedExecutionException( + "If that happens, there is something wrong with the available capacity/burst size"); + } + } + } + + private enum ExecutorType { + spinning, + defaultEventExecutor, + juc, + nioEventLoop, + epollEventLoop, + kqueueEventLoop + } + + @Param({ "1", "10" }) + private int burstLength; + @Param({ "spinning", "epollEventLoop", "nioEventLoop", "defaultEventExecutor", "juc", "kqueueEventLoop" }) + private String executorType; + @Param({ "0", "10" }) + private int work; + + private ExecutorService executor; + private ExecutorService executorToShutdown; + + @Setup + public void setup() { + ExecutorType type = ExecutorType.valueOf(executorType); + switch (type) { + case spinning: + //The case with 3 producers can have a peak of 3*burstLength offers: + //4 is to leave some room between the offers and 1024 is to leave some room + //between producer/consumer when work is > 0 and 1 producer. + //If work = 0 then the task queue is supposed to be near empty most of the time. + executor = new SpinExecutorService(Math.min(1024, burstLength * 4)); + executorToShutdown = executor; + break; + case defaultEventExecutor: + executor = new DefaultEventExecutor(); + executorToShutdown = executor; + break; + case juc: + executor = Executors.newSingleThreadScheduledExecutor(); + executorToShutdown = executor; + break; + case nioEventLoop: + NioEventLoopGroup nioEventLoopGroup = new NioEventLoopGroup(1); + nioEventLoopGroup.setIoRatio(1); + executor = nioEventLoopGroup.next(); + executorToShutdown = nioEventLoopGroup; + break; + case epollEventLoop: + Epoll.ensureAvailability(); + EpollEventLoopGroup epollEventLoopGroup = new EpollEventLoopGroup(1); + epollEventLoopGroup.setIoRatio(1); + executor = epollEventLoopGroup.next(); + executorToShutdown = epollEventLoopGroup; + break; + case kqueueEventLoop: + KQueue.ensureAvailability(); + KQueueEventLoopGroup kQueueEventLoopGroup = new KQueueEventLoopGroup(1); + kQueueEventLoopGroup.setIoRatio(1); + executor = kQueueEventLoopGroup.next(); + executorToShutdown = kQueueEventLoopGroup; + break; + } + } + + @TearDown + public void tearDown() { + executorToShutdown.shutdown(); + } + + @State(Scope.Thread) + public static class PerThreadState { + //To reduce the benchmark noise we avoid using AtomicInteger that would + //suffer of false sharing while reading/writing the counter due to the surrounding + //instances on heap: thanks to JMH the "completed" field will be padded + //avoiding false-sharing for free + private static final AtomicIntegerFieldUpdater DONE_UPDATER = + AtomicIntegerFieldUpdater.newUpdater(PerThreadState.class, "completed"); + private volatile int completed; + + private Runnable completeTask; + + @Setup + public void setup(BurstCostExecutorsBenchmark bench) { + final int work = bench.work; + if (work > 0) { + completeTask = new Runnable() { + @Override + public void run() { + Blackhole.consumeCPU(work); + //We can avoid the full barrier cost of a volatile set given that the + //benchmark is focusing on executors with a single threaded consumer: + //it would reduce the cost on consumer side while allowing to focus just + //to the threads hand-off/wake-up cost + DONE_UPDATER.lazySet(PerThreadState.this, completed + 1); + } + }; + } else { + completeTask = new Runnable() { + @Override + public void run() { + //We can avoid the full barrier cost of a volatile set given that the + //benchmark is focusing on executors with a single threaded consumer: + //it would reduce the cost on consumer side while allowing to focus just + //to the threads hand-off/wake-up cost + DONE_UPDATER.lazySet(PerThreadState.this, completed + 1); + } + }; + } + } + + /** + * Single-writer reset of completed counter. + */ + public void resetCompleted() { + //We can avoid the full barrier cost of a volatile set given that + //the counter can be reset from a single thread and it should be reset + //only after any submitted tasks are completed + DONE_UPDATER.lazySet(this, 0); + } + + /** + * It would spin-wait until at least {@code value} tasks are being completed. + */ + public int spinWaitCompletionOf(int value) { + while (true) { + final int lastRead = this.completed; + if (lastRead >= value) { + return lastRead; + } + } + } + } + + @Benchmark + @BenchmarkMode(Mode.SampleTime) + @Threads(1) + public int test1Producer(final PerThreadState state) { + return executeBurst(state); + } + + @Benchmark + @BenchmarkMode(Mode.SampleTime) + @Threads(2) + public int test2Producers(final PerThreadState state) { + return executeBurst(state); + } + + @Benchmark + @BenchmarkMode(Mode.SampleTime) + @Threads(3) + public int test3Producers(final PerThreadState state) { + return executeBurst(state); + } + + private int executeBurst(final PerThreadState state) { + final ExecutorService executor = this.executor; + final int burstLength = this.burstLength; + final Runnable completeTask = state.completeTask; + for (int i = 0; i < burstLength; i++) { + executor.execute(completeTask); + } + final int value = state.spinWaitCompletionOf(burstLength); + state.resetCompleted(); + return value; + } +} diff --git a/microbench/src/main/java/io/netty/microbench/headers/HeadersBenchmark.java b/microbench/src/main/java/io/netty/microbench/headers/HeadersBenchmark.java index 589a89200810..c8706a0a447d 100644 --- a/microbench/src/main/java/io/netty/microbench/headers/HeadersBenchmark.java +++ b/microbench/src/main/java/io/netty/microbench/headers/HeadersBenchmark.java @@ -102,14 +102,6 @@ public void setup() { emptyHttp2HeadersNoValidate = new DefaultHttp2Headers(false); } - @Setup(Level.Invocation) - public void setupEmptyHeaders() { - emptyHttpHeaders.clear(); - emptyHttp2Headers .clear(); - emptyHttpHeadersNoValidate.clear(); - emptyHttp2HeadersNoValidate.clear(); - } - @Benchmark @BenchmarkMode(Mode.AverageTime) public void httpRemove(Blackhole bh) { @@ -183,30 +175,35 @@ public void http2Iterate(Blackhole bh) { @BenchmarkMode(Mode.AverageTime) public void httpAddAllFastest(Blackhole bh) { bh.consume(emptyHttpHeadersNoValidate.add(httpHeaders)); + emptyHttpHeadersNoValidate.clear(); } @Benchmark @BenchmarkMode(Mode.AverageTime) public void httpAddAllFast(Blackhole bh) { bh.consume(emptyHttpHeaders.add(httpHeaders)); + emptyHttpHeaders.clear(); } @Benchmark @BenchmarkMode(Mode.AverageTime) public void http2AddAllFastest(Blackhole bh) { bh.consume(emptyHttp2HeadersNoValidate.add(http2Headers)); + emptyHttp2HeadersNoValidate.clear(); } @Benchmark @BenchmarkMode(Mode.AverageTime) public void http2AddAllFast(Blackhole bh) { bh.consume(emptyHttp2Headers.add(http2Headers)); + emptyHttp2Headers.clear(); } @Benchmark @BenchmarkMode(Mode.AverageTime) public void http2AddAllSlow(Blackhole bh) { bh.consume(emptyHttp2Headers.add(slowHttp2Headers)); + emptyHttp2Headers.clear(); } private static final class SlowHeaders implements Headers { diff --git a/microbench/src/main/java/io/netty/microbench/util/AbstractMicrobenchmarkBase.java b/microbench/src/main/java/io/netty/microbench/util/AbstractMicrobenchmarkBase.java index d3400d34635c..a623cc89ea96 100644 --- a/microbench/src/main/java/io/netty/microbench/util/AbstractMicrobenchmarkBase.java +++ b/microbench/src/main/java/io/netty/microbench/util/AbstractMicrobenchmarkBase.java @@ -43,8 +43,7 @@ public abstract class AbstractMicrobenchmarkBase { protected static final int DEFAULT_WARMUP_ITERATIONS = 10; protected static final int DEFAULT_MEASURE_ITERATIONS = 10; protected static final String[] BASE_JVM_ARGS = { - "-server", "-dsa", "-da", "-ea:io.netty...", "-XX:+AggressiveOpts", "-XX:+UseBiasedLocking", - "-XX:+UseFastAccessorMethods", "-XX:+OptimizeStringConcat", + "-server", "-dsa", "-da", "-ea:io.netty...", "-XX:+HeapDumpOnOutOfMemoryError", "-Dio.netty.leakDetection.level=disabled"}; static { @@ -93,8 +92,7 @@ protected static String[] removeAssertions(String[] jvmArgs) { } } if (jvmArgs.length != customArgs.size()) { - jvmArgs = new String[customArgs.size()]; - customArgs.toArray(jvmArgs); + jvmArgs = customArgs.toArray(new String[0]); } return jvmArgs; } diff --git a/pom.xml b/pom.xml index df92f7a2c5e8..77de42766bc0 100644 --- a/pom.xml +++ b/pom.xml @@ -17,16 +17,16 @@ 4.0.0 - + io.netty netty-parent pom - 4.1.25.5.dse + 4.1.34.3.dse Netty http://netty.io/ @@ -53,7 +53,6 @@ https://github.com/netty/netty scm:git:git://github.com/netty/netty.git scm:git:ssh://git@github.com/netty/netty.git - netty-4.1.25.dse @@ -68,6 +67,50 @@ + + + java13 + + 13 + + + + + true + + 3.0.0-M1 + + 2.0.5.Final + + 1.7 + 1.7 + + true + + + + + + java12 + + 12 + + + + + true + + 3.0.0-M1 + + 2.0.5.Final + + 1.7 + 1.7 + + true + + + java11 @@ -82,6 +125,8 @@ 3.0.0-M1 2.0.5.Final + + true @@ -121,6 +166,13 @@ + + boringssl + + netty-tcnative-boringssl-static + + + leak @@ -136,7 +188,7 @@ noUnsafe - -Dio.netty.noUnsafe + -Dio.netty.noUnsafe=true @@ -200,9 +252,9 @@ ${project.build.directory}/dev-tools UTF-8 UTF-8 - 22 + 23 1.4.11.Final - 2.0.7 + 2.0.8 "${settings.localRepository}"/org/mortbay/jetty/alpn/jetty-alpn-agent/${jetty.alpnAgent.version}/jetty-alpn-agent-${jetty.alpnAgent.version}.jar -server @@ -221,11 +273,11 @@ fedora netty-tcnative - 2.0.8.Final + 2.0.22.Final ${os.detected.classifier} org.conscrypt conscrypt-openjdk-uber - 1.0.1 + 1.3.0 ${os.detected.name}-${os.detected.arch} ${project.basedir}/../common/src/test/resources/logback-test.xml @@ -236,6 +288,7 @@ false false false + true @@ -273,6 +326,7 @@ testsuite-autobahn testsuite-http2 testsuite-osgi + testsuite-shading microbench bom @@ -538,7 +592,7 @@ org.apache.commons commons-compress - 1.12 + 1.18 test @@ -616,6 +670,35 @@ + + com.github.siom79.japicmp + japicmp-maven-plugin + 0.13.1 + + + true + true + \d+\.\d+\.\d+\.Final + + + ^(?!io\.netty\.).* + ^io\.netty\.internal\.tcnative\..* + + + @io.netty.util.internal.UnstableApi + + + ${skipJapicmp} + + + + verify + + cmp + + + + maven-enforcer-plugin ${enforcer.plugin.version} @@ -637,10 +720,10 @@ - x86_64 JDK must be used. + x86_64/AARCH64 JDK must be used. os.detected.arch - ^x86_64$ + ^(x86_64|aarch_64)$ @@ -649,7 +732,7 @@ maven-compiler-plugin - 3.6.0 + 3.8.0 1.8 true @@ -684,8 +767,8 @@ org.codehaus.mojo.signature - java18 - 1.0 + java16 + 1.1 sun.misc.Unsafe @@ -701,6 +784,14 @@ java.nio.channels.SocketChannel java.net.StandardProtocolFamily java.nio.channels.spi.SelectorProvider + java.net.SocketOption + java.net.StandardSocketOptions + java.nio.channels.NetworkChannel + + + java.nio.channels.AsynchronousFileChannel + java.nio.channels.CompletionHandler + java.util.concurrent.CompletableFuture sun.security.x509.AlgorithmId @@ -717,6 +808,7 @@ javax.net.ssl.SSLEngine + javax.net.ssl.ExtendedSSLSession javax.net.ssl.X509ExtendedTrustManager javax.net.ssl.SSLParameters javax.net.ssl.SNIServerName @@ -745,6 +837,7 @@ java.util.concurrent.atomic.LongAdder java.util.function.BiFunction + java.security.cert.X509Certificate java.net.InetAddress @@ -771,7 +864,7 @@ maven-checkstyle-plugin - 2.12.1 + 3.0.0 check-style @@ -785,7 +878,10 @@ true true io/netty/checkstyle.xml - true + + ${project.build.sourceDirectory} + ${project.build.testSourceDirectory} + @@ -864,6 +960,8 @@ ${testJavaHome}/bin/java + + false @@ -971,8 +1069,9 @@ - 2.5.3 false @@ -993,7 +1092,7 @@ 1.9.4 - --> + @@ -1110,12 +1209,12 @@ maven-surefire-plugin - 2.19.1 + 2.22.1 maven-failsafe-plugin - 2.19.1 + 2.22.1 maven-clean-plugin @@ -1309,6 +1408,83 @@ + + org.apache.maven.plugins + maven-remote-resources-plugin + 1.5 + + + io.netty:netty-dev-tools:${project.version} + + ${netty.dev.tools.directory} + + false + false + + + + + process + + + + + + de.thetaphi + forbiddenapis + 2.2 + + + check-forbidden-apis + + ${maven.compiler.target} + + false + + false + + + + + + + + + ${netty.dev.tools.directory}/forbidden/signatures.txt + + **.SuppressForbidden + + compile + + check + + + + check-forbidden-test-apis + + ${maven.compiler.target} + + true + + false + + + + + + + + ${netty.dev.tools.directory}/forbidden/signatures.txt + + **.SuppressForbidden + + test-compile + + testCheck + + + + diff --git a/resolver-dns/pom.xml b/resolver-dns/pom.xml index bc4f25a19283..8c48330700f6 100644 --- a/resolver-dns/pom.xml +++ b/resolver-dns/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-resolver-dns @@ -35,12 +35,17 @@ ${project.groupId} - netty-resolver + netty-common ${project.version} ${project.groupId} - netty-codec-dns + netty-buffer + ${project.version} + + + ${project.groupId} + netty-resolver ${project.version} @@ -48,6 +53,16 @@ netty-transport ${project.version} + + ${project.groupId} + netty-codec + ${project.version} + + + ${project.groupId} + netty-codec-dns + ${project.version} + org.apache.directory.server apacheds-protocol-dns diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/AuthoritativeDnsServerCache.java b/resolver-dns/src/main/java/io/netty/resolver/dns/AuthoritativeDnsServerCache.java new file mode 100644 index 000000000000..b011066700d4 --- /dev/null +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/AuthoritativeDnsServerCache.java @@ -0,0 +1,63 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.resolver.dns; + +import io.netty.channel.EventLoop; +import io.netty.util.internal.UnstableApi; + +import java.net.InetSocketAddress; + +/** + * Cache which stores the nameservers that should be used to resolve a specific hostname. + */ +@UnstableApi +public interface AuthoritativeDnsServerCache { + + /** + * Returns the cached nameservers that should be used to resolve the given hostname. The returned + * {@link DnsServerAddressStream} may contain unresolved {@link InetSocketAddress}es that will be resolved + * when needed while resolving other domain names. + * + * @param hostname the hostname + * @return the cached entries or an {@code null} if none. + */ + DnsServerAddressStream get(String hostname); + + /** + * Caches a nameserver that should be used to resolve the given hostname. + * + * @param hostname the hostname + * @param address the nameserver address (which may be unresolved). + * @param originalTtl the TTL as returned by the DNS server + * @param loop the {@link EventLoop} used to register the TTL timeout + */ + void cache(String hostname, InetSocketAddress address, long originalTtl, EventLoop loop); + + /** + * Clears all cached nameservers. + * + * @see #clear(String) + */ + void clear(); + + /** + * Clears the cached nameservers for the specified hostname. + * + * @return {@code true} if and only if there was an entry for the specified host name in the cache and + * it has been removed by this method + */ + boolean clear(String hostname); +} diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/AuthoritativeDnsServerCacheAdapter.java b/resolver-dns/src/main/java/io/netty/resolver/dns/AuthoritativeDnsServerCacheAdapter.java new file mode 100644 index 000000000000..470106e69534 --- /dev/null +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/AuthoritativeDnsServerCacheAdapter.java @@ -0,0 +1,80 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.resolver.dns; + +import io.netty.channel.EventLoop; +import io.netty.handler.codec.dns.DnsRecord; +import io.netty.util.internal.UnstableApi; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.List; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * {@link AuthoritativeDnsServerCache} implementation which delegates all operations to a wrapped {@link DnsCache}. + * This implementation is only present to preserve a upgrade story. + */ +@UnstableApi +final class AuthoritativeDnsServerCacheAdapter implements AuthoritativeDnsServerCache { + + private static final DnsRecord[] EMPTY = new DnsRecord[0]; + private final DnsCache cache; + + AuthoritativeDnsServerCacheAdapter(DnsCache cache) { + this.cache = checkNotNull(cache, "cache"); + } + + @Override + public DnsServerAddressStream get(String hostname) { + List entries = cache.get(hostname, EMPTY); + if (entries == null || entries.isEmpty()) { + return null; + } + if (entries.get(0).cause() != null) { + return null; + } + + List addresses = new ArrayList(entries.size()); + + int i = 0; + do { + InetAddress addr = entries.get(i).address(); + addresses.add(new InetSocketAddress(addr, DefaultDnsServerAddressStreamProvider.DNS_PORT)); + } while (++i < entries.size()); + return new SequentialDnsServerAddressStream(addresses, 0); + } + + @Override + public void cache(String hostname, InetSocketAddress address, long originalTtl, EventLoop loop) { + // We only cache resolved addresses. + if (!address.isUnresolved()) { + cache.cache(hostname, EMPTY, address.getAddress(), originalTtl, loop); + } + } + + @Override + public void clear() { + cache.clear(); + } + + @Override + public boolean clear(String hostname) { + return cache.clear(hostname); + } +} diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/Cache.java b/resolver-dns/src/main/java/io/netty/resolver/dns/Cache.java new file mode 100644 index 000000000000..cdd590e8fc4e --- /dev/null +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/Cache.java @@ -0,0 +1,292 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.resolver.dns; + +import io.netty.channel.EventLoop; +import io.netty.util.internal.PlatformDependent; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.Delayed; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; + +import static java.util.Collections.singletonList; + +/** + * Abstract cache that automatically removes entries for a hostname once the TTL for an entry is reached. + * + * @param + */ +abstract class Cache { + private static final AtomicReferenceFieldUpdater FUTURE_UPDATER = + AtomicReferenceFieldUpdater.newUpdater(Cache.Entries.class, ScheduledFuture.class, "expirationFuture"); + + private static final ScheduledFuture CANCELLED = new ScheduledFuture() { + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return false; + } + + @Override + public long getDelay(TimeUnit unit) { + // We ignore unit and always return the minimum value to ensure the TTL of the cancelled marker is + // the smallest. + return Long.MIN_VALUE; + } + + @Override + public int compareTo(Delayed o) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isCancelled() { + return true; + } + + @Override + public boolean isDone() { + return true; + } + + @Override + public Object get() { + throw new UnsupportedOperationException(); + } + + @Override + public Object get(long timeout, TimeUnit unit) { + throw new UnsupportedOperationException(); + } + }; + + // Two years are supported by all our EventLoop implementations and so safe to use as maximum. + // See also: https://github.com/netty/netty/commit/b47fb817991b42ec8808c7d26538f3f2464e1fa6 + static final int MAX_SUPPORTED_TTL_SECS = (int) TimeUnit.DAYS.toSeconds(365 * 2); + + private final ConcurrentMap resolveCache = PlatformDependent.newConcurrentHashMap(); + + /** + * Remove everything from the cache. + */ + final void clear() { + while (!resolveCache.isEmpty()) { + for (Iterator> i = resolveCache.entrySet().iterator(); i.hasNext();) { + Map.Entry e = i.next(); + i.remove(); + + e.getValue().clearAndCancel(); + } + } + } + + /** + * Clear all entries (if anything exists) for the given hostname and return {@code true} if anything was removed. + */ + final boolean clear(String hostname) { + Entries entries = resolveCache.remove(hostname); + return entries != null && entries.clearAndCancel(); + } + + /** + * Returns all caches entries for the given hostname. + */ + final List get(String hostname) { + Entries entries = resolveCache.get(hostname); + return entries == null ? null : entries.get(); + } + + /** + * Cache a value for the given hostname that will automatically expire once the TTL is reached. + */ + final void cache(String hostname, E value, int ttl, EventLoop loop) { + Entries entries = resolveCache.get(hostname); + if (entries == null) { + entries = new Entries(hostname); + Entries oldEntries = resolveCache.putIfAbsent(hostname, entries); + if (oldEntries != null) { + entries = oldEntries; + } + } + entries.add(value, ttl, loop); + } + + /** + * Return the number of hostames for which we have cached something. + */ + final int size() { + return resolveCache.size(); + } + + /** + * Returns {@code true} if this entry should replace all other entries that are already cached for the hostname. + */ + protected abstract boolean shouldReplaceAll(E entry); + + /** + * Sort the {@link List} for a {@code hostname} before caching these. + */ + protected void sortEntries( + @SuppressWarnings("unused") String hostname, @SuppressWarnings("unused") List entries) { + // NOOP. + } + + /** + * Returns {@code true} if both entries are equal. + */ + protected abstract boolean equals(E entry, E otherEntry); + + // Directly extend AtomicReference for intrinsics and also to keep memory overhead low. + private final class Entries extends AtomicReference> implements Runnable { + + private final String hostname; + // Needs to be package-private to be able to access it via the AtomicReferenceFieldUpdater + volatile ScheduledFuture expirationFuture; + + Entries(String hostname) { + super(Collections.emptyList()); + this.hostname = hostname; + } + + void add(E e, int ttl, EventLoop loop) { + if (!shouldReplaceAll(e)) { + for (;;) { + List entries = get(); + if (!entries.isEmpty()) { + final E firstEntry = entries.get(0); + if (shouldReplaceAll(firstEntry)) { + assert entries.size() == 1; + + if (compareAndSet(entries, singletonList(e))) { + scheduleCacheExpirationIfNeeded(ttl, loop); + return; + } else { + // Need to try again as CAS failed + continue; + } + } + + // Create a new List for COW semantics + List newEntries = new ArrayList(entries.size() + 1); + int i = 0; + E replacedEntry = null; + do { + E entry = entries.get(i); + // Only add old entry if the address is not the same as the one we try to add as well. + // In this case we will skip it and just add the new entry as this may have + // more up-to-date data and cancel the old after we were able to update the cache. + if (!Cache.this.equals(e, entry)) { + newEntries.add(entry); + } else { + replacedEntry = entry; + newEntries.add(e); + + ++i; + for (; i < entries.size(); ++i) { + newEntries.add(entries.get(i)); + } + break; + } + } while (++i < entries.size()); + if (replacedEntry == null) { + newEntries.add(e); + } + sortEntries(hostname, newEntries); + + if (compareAndSet(entries, Collections.unmodifiableList(newEntries))) { + scheduleCacheExpirationIfNeeded(ttl, loop); + return; + } + } else if (compareAndSet(entries, singletonList(e))) { + scheduleCacheExpirationIfNeeded(ttl, loop); + return; + } + } + } else { + set(singletonList(e)); + scheduleCacheExpirationIfNeeded(ttl, loop); + } + } + + private void scheduleCacheExpirationIfNeeded(int ttl, EventLoop loop) { + for (;;) { + // We currently don't calculate a new TTL when we need to retry the CAS as we don't expect this to + // be invoked very concurrently and also we use SECONDS anyway. If this ever becomes a problem + // we can reconsider. + ScheduledFuture oldFuture = FUTURE_UPDATER.get(this); + if (oldFuture == null || oldFuture.getDelay(TimeUnit.SECONDS) > ttl) { + ScheduledFuture newFuture = loop.schedule(this, ttl, TimeUnit.SECONDS); + // It is possible that + // 1. task will fire in between this line, or + // 2. multiple timers may be set if there is concurrency + // (1) Shouldn't be a problem because we will fail the CAS and then the next loop will see CANCELLED + // so the ttl will not be less, and we will bail out of the loop. + // (2) This is a trade-off to avoid concurrency resulting in contention on a synchronized block. + if (FUTURE_UPDATER.compareAndSet(this, oldFuture, newFuture)) { + if (oldFuture != null) { + oldFuture.cancel(true); + } + break; + } else { + // There was something else scheduled in the meantime... Cancel and try again. + newFuture.cancel(true); + } + } else { + break; + } + } + } + + boolean clearAndCancel() { + List entries = getAndSet(Collections.emptyList()); + if (entries.isEmpty()) { + return false; + } + + ScheduledFuture expirationFuture = FUTURE_UPDATER.getAndSet(this, CANCELLED); + if (expirationFuture != null) { + expirationFuture.cancel(false); + } + + return true; + } + + @Override + public void run() { + // We always remove all entries for a hostname once one entry expire. This is not the + // most efficient to do but this way we can guarantee that if a DnsResolver + // be configured to prefer one ip family over the other we will not return unexpected + // results to the enduser if one of the A or AAAA records has different TTL settings. + // + // As a TTL is just a hint of the maximum time a cache is allowed to cache stuff it's + // completely fine to remove the entry even if the TTL is not reached yet. + // + // See https://github.com/netty/netty/issues/7329 + resolveCache.remove(hostname, this); + + clearAndCancel(); + } + } +} diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DefaultAuthoritativeDnsServerCache.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DefaultAuthoritativeDnsServerCache.java new file mode 100644 index 000000000000..6bc29a6a736c --- /dev/null +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DefaultAuthoritativeDnsServerCache.java @@ -0,0 +1,130 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.resolver.dns; + +import io.netty.channel.EventLoop; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.UnstableApi; + +import java.net.InetSocketAddress; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.concurrent.ConcurrentMap; + +import static io.netty.util.internal.ObjectUtil.*; + +/** + * Default implementation of {@link AuthoritativeDnsServerCache}, backed by a {@link ConcurrentMap}. + */ +@UnstableApi +public class DefaultAuthoritativeDnsServerCache implements AuthoritativeDnsServerCache { + + private final int minTtl; + private final int maxTtl; + private final Comparator comparator; + private final Cache resolveCache = new Cache() { + @Override + protected boolean shouldReplaceAll(InetSocketAddress entry) { + return false; + } + + @Override + protected boolean equals(InetSocketAddress entry, InetSocketAddress otherEntry) { + if (PlatformDependent.javaVersion() >= 7) { + return entry.getHostString().equalsIgnoreCase(otherEntry.getHostString()); + } + return entry.getHostName().equalsIgnoreCase(otherEntry.getHostName()); + } + + @Override + protected void sortEntries(String hostname, List entries) { + if (comparator != null) { + Collections.sort(entries, comparator); + } + } + }; + + /** + * Create a cache that respects the TTL returned by the DNS server. + */ + public DefaultAuthoritativeDnsServerCache() { + this(0, Cache.MAX_SUPPORTED_TTL_SECS, null); + } + + /** + * Create a cache. + * + * @param minTtl the minimum TTL + * @param maxTtl the maximum TTL + * @param comparator the {@link Comparator} to order the {@link InetSocketAddress} for a hostname or {@code null} + * if insertion order should be used. + */ + public DefaultAuthoritativeDnsServerCache(int minTtl, int maxTtl, Comparator comparator) { + this.minTtl = Math.min(Cache.MAX_SUPPORTED_TTL_SECS, checkPositiveOrZero(minTtl, "minTtl")); + this.maxTtl = Math.min(Cache.MAX_SUPPORTED_TTL_SECS, checkPositive(maxTtl, "maxTtl")); + if (minTtl > maxTtl) { + throw new IllegalArgumentException( + "minTtl: " + minTtl + ", maxTtl: " + maxTtl + " (expected: 0 <= minTtl <= maxTtl)"); + } + this.comparator = comparator; + } + + @SuppressWarnings("unchecked") + @Override + public DnsServerAddressStream get(String hostname) { + checkNotNull(hostname, "hostname"); + + List addresses = resolveCache.get(hostname); + if (addresses == null || addresses.isEmpty()) { + return null; + } + return new SequentialDnsServerAddressStream(addresses, 0); + } + + @Override + public void cache(String hostname, InetSocketAddress address, long originalTtl, EventLoop loop) { + checkNotNull(hostname, "hostname"); + checkNotNull(address, "address"); + checkNotNull(loop, "loop"); + + if (PlatformDependent.javaVersion() >= 7 && address.getHostString() == null) { + // We only cache addresses that have also a host string as we will need it later when trying to replace + // unresolved entries in the cache. + return; + } + + resolveCache.cache(hostname, address, Math.max(minTtl, (int) Math.min(maxTtl, originalTtl)), loop); + } + + @Override + public void clear() { + resolveCache.clear(); + } + + @Override + public boolean clear(String hostname) { + checkNotNull(hostname, "hostname"); + + return resolveCache.clear(hostname); + } + + @Override + public String toString() { + return "DefaultAuthoritativeDnsServerCache(minTtl=" + minTtl + ", maxTtl=" + maxTtl + ", cached nameservers=" + + resolveCache.size() + ')'; + } +} diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DefaultDnsCache.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DefaultDnsCache.java index d5431362fec0..c1978395204c 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DefaultDnsCache.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DefaultDnsCache.java @@ -17,19 +17,13 @@ import io.netty.channel.EventLoop; import io.netty.handler.codec.dns.DnsRecord; -import io.netty.util.concurrent.ScheduledFuture; -import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.StringUtil; import io.netty.util.internal.UnstableApi; import java.net.InetAddress; -import java.util.ArrayList; import java.util.Collections; -import java.util.Iterator; import java.util.List; -import java.util.Map; import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; import static io.netty.util.internal.ObjectUtil.checkNotNull; import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; @@ -41,11 +35,25 @@ @UnstableApi public class DefaultDnsCache implements DnsCache { - private final ConcurrentMap resolveCache = PlatformDependent.newConcurrentHashMap(); + private final Cache resolveCache = new Cache() { + + @Override + protected boolean shouldReplaceAll(DefaultDnsCacheEntry entry) { + return entry.cause() != null; + } + + @Override + protected boolean equals(DefaultDnsCacheEntry entry, DefaultDnsCacheEntry otherEntry) { + if (entry.address() != null) { + return entry.address().equals(otherEntry.address()); + } + if (otherEntry.address() != null) { + return false; + } + return entry.cause().equals(otherEntry.cause()); + } + }; - // Two years are supported by all our EventLoop implementations and so safe to use as maximum. - // See also: https://github.com/netty/netty/commit/b47fb817991b42ec8808c7d26538f3f2464e1fa6 - private static final int MAX_SUPPORTED_TTL_SECS = (int) TimeUnit.DAYS.toSeconds(365 * 2); private final int minTtl; private final int maxTtl; private final int negativeTtl; @@ -55,7 +63,7 @@ public class DefaultDnsCache implements DnsCache { * and doesn't cache negative responses. */ public DefaultDnsCache() { - this(0, MAX_SUPPORTED_TTL_SECS, 0); + this(0, Cache.MAX_SUPPORTED_TTL_SECS, 0); } /** @@ -65,8 +73,8 @@ public DefaultDnsCache() { * @param negativeTtl the TTL for failed queries */ public DefaultDnsCache(int minTtl, int maxTtl, int negativeTtl) { - this.minTtl = Math.min(MAX_SUPPORTED_TTL_SECS, checkPositiveOrZero(minTtl, "minTtl")); - this.maxTtl = Math.min(MAX_SUPPORTED_TTL_SECS, checkPositiveOrZero(maxTtl, "maxTtl")); + this.minTtl = Math.min(Cache.MAX_SUPPORTED_TTL_SECS, checkPositiveOrZero(minTtl, "minTtl")); + this.maxTtl = Math.min(Cache.MAX_SUPPORTED_TTL_SECS, checkPositiveOrZero(maxTtl, "maxTtl")); if (minTtl > maxTtl) { throw new IllegalArgumentException( "minTtl: " + minTtl + ", maxTtl: " + maxTtl + " (expected: 0 <= minTtl <= maxTtl)"); @@ -102,21 +110,13 @@ public int negativeTtl() { @Override public void clear() { - while (!resolveCache.isEmpty()) { - for (Iterator> i = resolveCache.entrySet().iterator(); i.hasNext();) { - Map.Entry e = i.next(); - i.remove(); - - e.getValue().clearAndCancel(); - } - } + resolveCache.clear(); } @Override public boolean clear(String hostname) { checkNotNull(hostname, "hostname"); - Entries entries = resolveCache.remove(hostname); - return entries != null && entries.clearAndCancel(); + return resolveCache.clear(appendDot(hostname)); } private static boolean emptyAdditionals(DnsRecord[] additionals) { @@ -130,8 +130,7 @@ public List get(String hostname, DnsRecord[] additional return Collections.emptyList(); } - Entries entries = resolveCache.get(hostname); - return entries == null ? null : entries.get(); + return resolveCache.get(appendDot(hostname)); } @Override @@ -140,11 +139,11 @@ public DnsCacheEntry cache(String hostname, DnsRecord[] additionals, checkNotNull(hostname, "hostname"); checkNotNull(address, "address"); checkNotNull(loop, "loop"); - final DefaultDnsCacheEntry e = new DefaultDnsCacheEntry(hostname, address); + DefaultDnsCacheEntry e = new DefaultDnsCacheEntry(hostname, address); if (maxTtl == 0 || !emptyAdditionals(additionals)) { return e; } - cache0(e, Math.max(minTtl, Math.min(MAX_SUPPORTED_TTL_SECS, (int) Math.min(maxTtl, originalTtl))), loop); + resolveCache.cache(appendDot(hostname), e, Math.max(minTtl, (int) Math.min(maxTtl, originalTtl)), loop); return e; } @@ -154,52 +153,15 @@ public DnsCacheEntry cache(String hostname, DnsRecord[] additionals, Throwable c checkNotNull(cause, "cause"); checkNotNull(loop, "loop"); - final DefaultDnsCacheEntry e = new DefaultDnsCacheEntry(hostname, cause); + DefaultDnsCacheEntry e = new DefaultDnsCacheEntry(hostname, cause); if (negativeTtl == 0 || !emptyAdditionals(additionals)) { return e; } - cache0(e, Math.min(MAX_SUPPORTED_TTL_SECS, negativeTtl), loop); + resolveCache.cache(appendDot(hostname), e, negativeTtl, loop); return e; } - private void cache0(DefaultDnsCacheEntry e, int ttl, EventLoop loop) { - Entries entries = resolveCache.get(e.hostname()); - if (entries == null) { - entries = new Entries(e); - Entries oldEntries = resolveCache.putIfAbsent(e.hostname(), entries); - if (oldEntries != null) { - entries = oldEntries; - } - } - entries.add(e); - - scheduleCacheExpiration(e, ttl, loop); - } - - private void scheduleCacheExpiration(final DefaultDnsCacheEntry e, - int ttl, - EventLoop loop) { - e.scheduleExpiration(loop, new Runnable() { - @Override - public void run() { - // We always remove all entries for a hostname once one entry expire. This is not the - // most efficient to do but this way we can guarantee that if a DnsResolver - // be configured to prefer one ip family over the other we will not return unexpected - // results to the enduser if one of the A or AAAA records has different TTL settings. - // - // As a TTL is just a hint of the maximum time a cache is allowed to cache stuff it's - // completely fine to remove the entry even if the TTL is not reached yet. - // - // See https://github.com/netty/netty/issues/7329 - Entries entries = resolveCache.remove(e.hostname); - if (entries != null) { - entries.clearAndCancel(); - } - } - }, ttl, TimeUnit.SECONDS); - } - @Override public String toString() { return new StringBuilder() @@ -207,7 +169,7 @@ public String toString() { .append(minTtl).append(", maxTtl=") .append(maxTtl).append(", negativeTtl=") .append(negativeTtl).append(", cached resolved hostname=") - .append(resolveCache.size()).append(")") + .append(resolveCache.size()).append(')') .toString(); } @@ -215,17 +177,16 @@ private static final class DefaultDnsCacheEntry implements DnsCacheEntry { private final String hostname; private final InetAddress address; private final Throwable cause; - private volatile ScheduledFuture expirationFuture; DefaultDnsCacheEntry(String hostname, InetAddress address) { - this.hostname = checkNotNull(hostname, "hostname"); - this.address = checkNotNull(address, "address"); + this.hostname = hostname; + this.address = address; cause = null; } DefaultDnsCacheEntry(String hostname, Throwable cause) { - this.hostname = checkNotNull(hostname, "hostname"); - this.cause = checkNotNull(cause, "cause"); + this.hostname = hostname; + this.cause = cause; address = null; } @@ -243,18 +204,6 @@ String hostname() { return hostname; } - void scheduleExpiration(EventLoop loop, Runnable task, long delay, TimeUnit unit) { - assert expirationFuture == null : "expiration task scheduled already"; - expirationFuture = loop.schedule(task, delay, unit); - } - - void cancelExpiration() { - ScheduledFuture expirationFuture = this.expirationFuture; - if (expirationFuture != null) { - expirationFuture.cancel(false); - } - } - @Override public String toString() { if (cause != null) { @@ -265,77 +214,7 @@ public String toString() { } } - // Directly extend AtomicReference for intrinsics and also to keep memory overhead low. - private static final class Entries extends AtomicReference> { - - Entries(DefaultDnsCacheEntry entry) { - super(Collections.singletonList(entry)); - } - - void add(DefaultDnsCacheEntry e) { - if (e.cause() == null) { - for (;;) { - List entries = get(); - if (!entries.isEmpty()) { - final DefaultDnsCacheEntry firstEntry = entries.get(0); - if (firstEntry.cause() != null) { - assert entries.size() == 1; - if (compareAndSet(entries, Collections.singletonList(e))) { - firstEntry.cancelExpiration(); - return; - } else { - // Need to try again as CAS failed - continue; - } - } - - // Create a new List for COW semantics - List newEntries = new ArrayList(entries.size() + 1); - DefaultDnsCacheEntry replacedEntry = null; - for (int i = 0; i < entries.size(); i++) { - DefaultDnsCacheEntry entry = entries.get(i); - // Only add old entry if the address is not the same as the one we try to add as well. - // In this case we will skip it and just add the new entry as this may have - // more up-to-date data and cancel the old after we were able to update the cache. - if (!e.address().equals(entry.address())) { - newEntries.add(entry); - } else { - assert replacedEntry == null; - replacedEntry = entry; - } - } - newEntries.add(e); - if (compareAndSet(entries, newEntries)) { - if (replacedEntry != null) { - replacedEntry.cancelExpiration(); - } - return; - } - } else if (compareAndSet(entries, Collections.singletonList(e))) { - return; - } - } - } else { - List entries = getAndSet(Collections.singletonList(e)); - cancelExpiration(entries); - } - } - - boolean clearAndCancel() { - List entries = getAndSet(Collections.emptyList()); - if (entries.isEmpty()) { - return false; - } - - cancelExpiration(entries); - return true; - } - - private static void cancelExpiration(List entryList) { - final int numEntries = entryList.size(); - for (int i = 0; i < numEntries; i++) { - entryList.get(i).cancelExpiration(); - } - } + private static String appendDot(String hostname) { + return StringUtil.endsWith(hostname, '.') ? hostname : hostname + '.'; } } diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DefaultDnsCnameCache.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DefaultDnsCnameCache.java new file mode 100644 index 000000000000..8638808757ab --- /dev/null +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DefaultDnsCnameCache.java @@ -0,0 +1,99 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.resolver.dns; + +import io.netty.channel.EventLoop; +import io.netty.util.AsciiString; +import io.netty.util.internal.UnstableApi; + +import java.util.List; + +import static io.netty.util.internal.ObjectUtil.*; + +/** + * Default implementation of a {@link DnsCnameCache}. + */ +@UnstableApi +public final class DefaultDnsCnameCache implements DnsCnameCache { + private final int minTtl; + private final int maxTtl; + + private final Cache cache = new Cache() { + @Override + protected boolean shouldReplaceAll(String entry) { + // Only one 1:1 mapping is supported as specified in the RFC. + return true; + } + + @Override + protected boolean equals(String entry, String otherEntry) { + return AsciiString.contentEqualsIgnoreCase(entry, otherEntry); + } + }; + + /** + * Create a cache that respects the TTL returned by the DNS server. + */ + public DefaultDnsCnameCache() { + this(0, Cache.MAX_SUPPORTED_TTL_SECS); + } + + /** + * Create a cache. + * + * @param minTtl the minimum TTL + * @param maxTtl the maximum TTL + */ + public DefaultDnsCnameCache(int minTtl, int maxTtl) { + this.minTtl = Math.min(Cache.MAX_SUPPORTED_TTL_SECS, checkPositiveOrZero(minTtl, "minTtl")); + this.maxTtl = Math.min(Cache.MAX_SUPPORTED_TTL_SECS, checkPositive(maxTtl, "maxTtl")); + if (minTtl > maxTtl) { + throw new IllegalArgumentException( + "minTtl: " + minTtl + ", maxTtl: " + maxTtl + " (expected: 0 <= minTtl <= maxTtl)"); + } + } + + @SuppressWarnings("unchecked") + @Override + public String get(String hostname) { + checkNotNull(hostname, "hostname"); + List cached = cache.get(hostname); + if (cached == null || cached.isEmpty()) { + return null; + } + // We can never have more then one record. + return cached.get(0); + } + + @Override + public void cache(String hostname, String cname, long originalTtl, EventLoop loop) { + checkNotNull(hostname, "hostname"); + checkNotNull(cname, "cname"); + checkNotNull(loop, "loop"); + cache.cache(hostname, cname, Math.max(minTtl, (int) Math.min(maxTtl, originalTtl)), loop); + } + + @Override + public void clear() { + cache.clear(); + } + + @Override + public boolean clear(String hostname) { + checkNotNull(hostname, "hostname"); + return cache.clear(hostname); + } +} diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DefaultDnsServerAddressStreamProvider.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DefaultDnsServerAddressStreamProvider.java index 610a51b7daa4..00be07229520 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DefaultDnsServerAddressStreamProvider.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DefaultDnsServerAddressStreamProvider.java @@ -16,23 +16,17 @@ package io.netty.resolver.dns; import io.netty.util.NetUtil; +import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.SocketUtils; import io.netty.util.internal.UnstableApi; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; -import javax.naming.Context; -import javax.naming.NamingException; -import javax.naming.directory.DirContext; -import javax.naming.directory.InitialDirContext; import java.lang.reflect.Method; import java.net.Inet6Address; import java.net.InetSocketAddress; -import java.net.URI; -import java.net.URISyntaxException; import java.util.ArrayList; import java.util.Collections; -import java.util.Hashtable; import java.util.List; import static io.netty.resolver.dns.DnsServerAddresses.sequential; @@ -50,47 +44,15 @@ public final class DefaultDnsServerAddressStreamProvider implements DnsServerAdd public static final DefaultDnsServerAddressStreamProvider INSTANCE = new DefaultDnsServerAddressStreamProvider(); private static final List DEFAULT_NAME_SERVER_LIST; - private static final InetSocketAddress[] DEFAULT_NAME_SERVER_ARRAY; private static final DnsServerAddresses DEFAULT_NAME_SERVERS; static final int DNS_PORT = 53; static { final List defaultNameServers = new ArrayList(2); - - // Using jndi-dns to obtain the default name servers. - // - // See: - // - http://docs.oracle.com/javase/8/docs/technotes/guides/jndi/jndi-dns.html - // - http://mail.openjdk.java.net/pipermail/net-dev/2017-March/010695.html - Hashtable env = new Hashtable(); - env.put(Context.INITIAL_CONTEXT_FACTORY, "com.sun.jndi.dns.DnsContextFactory"); - env.put("java.naming.provider.url", "dns://"); - try { - DirContext ctx = new InitialDirContext(env); - String dnsUrls = (String) ctx.getEnvironment().get("java.naming.provider.url"); - // Only try if not empty as otherwise we will produce an exception - if (dnsUrls != null && !dnsUrls.isEmpty()) { - String[] servers = dnsUrls.split(" "); - for (String server : servers) { - try { - URI uri = new URI(server); - String host = new URI(server).getHost(); - - if (host == null || host.isEmpty()) { - logger.debug( - "Skipping a nameserver URI as host portion could not be extracted: {}", server); - // If the host portion can not be parsed we should just skip this entry. - continue; - } - int port = uri.getPort(); - defaultNameServers.add(SocketUtils.socketAddress(uri.getHost(), port == -1 ? DNS_PORT : port)); - } catch (URISyntaxException e) { - logger.debug("Skipping a malformed nameserver URI: {}", server, e); - } - } - } - } catch (NamingException ignore) { - // Will try reflection if this fails. + if (!PlatformDependent.isAndroid()) { + // Only try to use when not on Android as the classes not exists there: + // See https://github.com/netty/netty/issues/8654 + DirContextUtils.addNameServers(defaultNameServers, DNS_PORT); } if (defaultNameServers.isEmpty()) { @@ -142,8 +104,7 @@ public final class DefaultDnsServerAddressStreamProvider implements DnsServerAdd } DEFAULT_NAME_SERVER_LIST = Collections.unmodifiableList(defaultNameServers); - DEFAULT_NAME_SERVER_ARRAY = defaultNameServers.toArray(new InetSocketAddress[defaultNameServers.size()]); - DEFAULT_NAME_SERVERS = sequential(DEFAULT_NAME_SERVER_ARRAY); + DEFAULT_NAME_SERVERS = sequential(DEFAULT_NAME_SERVER_LIST); } private DefaultDnsServerAddressStreamProvider() { @@ -177,12 +138,4 @@ public static List defaultAddressList() { public static DnsServerAddresses defaultAddresses() { return DEFAULT_NAME_SERVERS; } - - /** - * Get the array form of {@link #defaultAddressList()}. - * @return The array form of {@link #defaultAddressList()}. - */ - static InetSocketAddress[] defaultAddressArray() { - return DEFAULT_NAME_SERVER_ARRAY.clone(); - } } diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DefaultDnsServerAddresses.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DefaultDnsServerAddresses.java index 0efd24491504..96da62e2ae59 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DefaultDnsServerAddresses.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DefaultDnsServerAddresses.java @@ -17,16 +17,17 @@ package io.netty.resolver.dns; import java.net.InetSocketAddress; +import java.util.List; abstract class DefaultDnsServerAddresses extends DnsServerAddresses { - protected final InetSocketAddress[] addresses; + protected final List addresses; private final String strVal; - DefaultDnsServerAddresses(String type, InetSocketAddress[] addresses) { + DefaultDnsServerAddresses(String type, List addresses) { this.addresses = addresses; - final StringBuilder buf = new StringBuilder(type.length() + 2 + addresses.length * 16); + final StringBuilder buf = new StringBuilder(type.length() + 2 + addresses.size() * 16); buf.append(type).append('('); for (InetSocketAddress a: addresses) { diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DirContextUtils.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DirContextUtils.java new file mode 100644 index 000000000000..45c1ac433f2e --- /dev/null +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DirContextUtils.java @@ -0,0 +1,77 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.resolver.dns; + +import io.netty.util.internal.SocketUtils; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import javax.naming.Context; +import javax.naming.NamingException; +import javax.naming.directory.DirContext; +import javax.naming.directory.InitialDirContext; +import java.net.InetSocketAddress; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Hashtable; +import java.util.List; + +final class DirContextUtils { + private static final InternalLogger logger = + InternalLoggerFactory.getInstance(DirContextUtils.class); + + private DirContextUtils() { } + + static void addNameServers(List defaultNameServers, int defaultPort) { + // Using jndi-dns to obtain the default name servers. + // + // See: + // - http://docs.oracle.com/javase/8/docs/technotes/guides/jndi/jndi-dns.html + // - http://mail.openjdk.java.net/pipermail/net-dev/2017-March/010695.html + Hashtable env = new Hashtable(); + env.put(Context.INITIAL_CONTEXT_FACTORY, "com.sun.jndi.dns.DnsContextFactory"); + env.put("java.naming.provider.url", "dns://"); + + try { + DirContext ctx = new InitialDirContext(env); + String dnsUrls = (String) ctx.getEnvironment().get("java.naming.provider.url"); + // Only try if not empty as otherwise we will produce an exception + if (dnsUrls != null && !dnsUrls.isEmpty()) { + String[] servers = dnsUrls.split(" "); + for (String server : servers) { + try { + URI uri = new URI(server); + String host = new URI(server).getHost(); + + if (host == null || host.isEmpty()) { + logger.debug( + "Skipping a nameserver URI as host portion could not be extracted: {}", server); + // If the host portion can not be parsed we should just skip this entry. + continue; + } + int port = uri.getPort(); + defaultNameServers.add(SocketUtils.socketAddress(uri.getHost(), port == -1 ? + defaultPort : port)); + } catch (URISyntaxException e) { + logger.debug("Skipping a malformed nameserver URI: {}", server, e); + } + } + } + } catch (NamingException ignore) { + // Will try reflection if this fails. + } + } +} diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsAddressResolveContext.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsAddressResolveContext.java index 63490efa4ad9..85078c67924b 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsAddressResolveContext.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsAddressResolveContext.java @@ -25,15 +25,19 @@ import io.netty.channel.EventLoop; import io.netty.handler.codec.dns.DnsRecord; import io.netty.handler.codec.dns.DnsRecordType; +import io.netty.util.concurrent.Promise; final class DnsAddressResolveContext extends DnsResolveContext { private final DnsCache resolveCache; + private final AuthoritativeDnsServerCache authoritativeDnsServerCache; DnsAddressResolveContext(DnsNameResolver parent, String hostname, DnsRecord[] additionals, - DnsServerAddressStream nameServerAddrs, DnsCache resolveCache) { + DnsServerAddressStream nameServerAddrs, DnsCache resolveCache, + AuthoritativeDnsServerCache authoritativeDnsServerCache) { super(parent, hostname, DnsRecord.CLASS_IN, parent.resolveRecordTypes(), additionals, nameServerAddrs); this.resolveCache = resolveCache; + this.authoritativeDnsServerCache = authoritativeDnsServerCache; } @Override @@ -41,7 +45,8 @@ DnsResolveContext newResolverContext(DnsNameResolver parent, String int dnsClass, DnsRecordType[] expectedTypes, DnsRecord[] additionals, DnsServerAddressStream nameServerAddrs) { - return new DnsAddressResolveContext(parent, hostname, additionals, nameServerAddrs, resolveCache); + return new DnsAddressResolveContext(parent, hostname, additionals, nameServerAddrs, resolveCache, + authoritativeDnsServerCache); } @Override @@ -84,4 +89,23 @@ void cache(String hostname, DnsRecord[] additionals, void cache(String hostname, DnsRecord[] additionals, UnknownHostException cause) { resolveCache.cache(hostname, additionals, cause, parent.ch.eventLoop()); } + + @Override + void doSearchDomainQuery(String hostname, Promise> nextPromise) { + // Query the cache for the hostname first and only do a query if we could not find it in the cache. + if (!DnsNameResolver.doResolveAllCached( + hostname, additionals, nextPromise, resolveCache, parent.resolvedInternetProtocolFamiliesUnsafe())) { + super.doSearchDomainQuery(hostname, nextPromise); + } + } + + @Override + DnsCache resolveCache() { + return resolveCache; + } + + @Override + AuthoritativeDnsServerCache authoritativeDnsServerCache() { + return authoritativeDnsServerCache; + } } diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsCnameCache.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsCnameCache.java new file mode 100644 index 000000000000..820ef67d2459 --- /dev/null +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsCnameCache.java @@ -0,0 +1,59 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.resolver.dns; + +import io.netty.channel.EventLoop; +import io.netty.util.internal.UnstableApi; + +/** + * A cache for {@code CNAME}s. + */ +@UnstableApi +public interface DnsCnameCache { + + /** + * Returns the cached cname for the given hostname. + * + * @param hostname the hostname + * @return the cached entries or an {@code null} if none. + */ + String get(String hostname); + + /** + * Caches a cname entry that should be used for the given hostname. + * + * @param hostname the hostname + * @param cname the cname mapping. + * @param originalTtl the TTL as returned by the DNS server + * @param loop the {@link EventLoop} used to register the TTL timeout + */ + void cache(String hostname, String cname, long originalTtl, EventLoop loop); + + /** + * Clears all cached nameservers. + * + * @see #clear(String) + */ + void clear(); + + /** + * Clears the cached nameservers for the specified hostname. + * + * @return {@code true} if and only if there was an entry for the specified host name in the cache and + * it has been removed by this method + */ + boolean clear(String hostname); +} diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java index 5bd7baad2594..c8384fffe9f9 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java @@ -47,6 +47,7 @@ import io.netty.resolver.ResolvedAddressTypes; import io.netty.util.NetUtil; import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.EventExecutor; import io.netty.util.concurrent.FastThreadLocal; import io.netty.util.concurrent.Future; import io.netty.util.concurrent.FutureListener; @@ -67,6 +68,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.Comparator; import java.util.Iterator; import java.util.List; @@ -124,14 +126,10 @@ public class DnsNameResolver extends InetNameResolver { static { String[] searchDomains; try { - Class configClass = Class.forName("sun.net.dns.ResolverConfiguration"); - Method open = configClass.getMethod("open"); - Method nameservers = configClass.getMethod("searchlist"); - Object instance = open.invoke(null); - - @SuppressWarnings("unchecked") - List list = (List) nameservers.invoke(instance); - searchDomains = list.toArray(new String[list.size()]); + List list = PlatformDependent.isWindows() + ? getSearchDomainsHack() + : UnixResolverDnsServerAddressStreamProvider.parseEtcResolverSearchDomains(); + searchDomains = list.toArray(new String[0]); } catch (Exception ignore) { // Failed to get the system name search domain list. searchDomains = EmptyArrays.EMPTY_STRINGS; @@ -147,12 +145,26 @@ public class DnsNameResolver extends InetNameResolver { DEFAULT_NDOTS = ndots; } + @SuppressWarnings("unchecked") + private static List getSearchDomainsHack() throws Exception { + // This code on Java 9+ yields a warning about illegal reflective access that will be denied in + // a future release. There doesn't seem to be a better way to get search domains for Windows yet. + Class configClass = Class.forName("sun.net.dns.ResolverConfiguration"); + Method open = configClass.getMethod("open"); + Method nameservers = configClass.getMethod("searchlist"); + Object instance = open.invoke(null); + + return (List) nameservers.invoke(instance); + } + private static final DatagramDnsResponseDecoder DECODER = new DatagramDnsResponseDecoder(); private static final DatagramDnsQueryEncoder ENCODER = new DatagramDnsQueryEncoder(); final Future channelFuture; - final DatagramChannel ch; + final Channel ch; + // Comparator that ensures we will try first to use the nameservers that use our preferred address type. + private final Comparator nameServerComparator; /** * Manages the {@link DnsQueryContext}s in progress and their query IDs. */ @@ -162,12 +174,13 @@ public class DnsNameResolver extends InetNameResolver { * Cache for {@link #doResolve(String, Promise)} and {@link #doResolveAll(String, Promise)}. */ private final DnsCache resolveCache; - private final DnsCache authoritativeDnsServerCache; + private final AuthoritativeDnsServerCache authoritativeDnsServerCache; + private final DnsCnameCache cnameCache; private final FastThreadLocal nameServerAddrStream = new FastThreadLocal() { @Override - protected DnsServerAddressStream initialValue() throws Exception { + protected DnsServerAddressStream initialValue() { return dnsServerAddressStreamProvider.nameServerAddressStream(""); } }; @@ -214,12 +227,91 @@ protected DnsServerAddressStream initialValue() throws Exception { * @param ndots the ndots value * @param decodeIdn {@code true} if domain / host names should be decoded to unicode when received. * See rfc3492. + * @deprecated Use {@link DnsNameResolverBuilder}. + */ + @Deprecated + public DnsNameResolver( + EventLoop eventLoop, + ChannelFactory channelFactory, + final DnsCache resolveCache, + final DnsCache authoritativeDnsServerCache, + DnsQueryLifecycleObserverFactory dnsQueryLifecycleObserverFactory, + long queryTimeoutMillis, + ResolvedAddressTypes resolvedAddressTypes, + boolean recursionDesired, + int maxQueriesPerResolve, + boolean traceEnabled, + int maxPayloadSize, + boolean optResourceEnabled, + HostsFileEntriesResolver hostsFileEntriesResolver, + DnsServerAddressStreamProvider dnsServerAddressStreamProvider, + String[] searchDomains, + int ndots, + boolean decodeIdn) { + this(eventLoop, channelFactory, resolveCache, + new AuthoritativeDnsServerCacheAdapter(authoritativeDnsServerCache), dnsQueryLifecycleObserverFactory, + queryTimeoutMillis, resolvedAddressTypes, recursionDesired, maxQueriesPerResolve, traceEnabled, + maxPayloadSize, optResourceEnabled, hostsFileEntriesResolver, dnsServerAddressStreamProvider, + searchDomains, ndots, decodeIdn); + } + + /** + * Creates a new DNS-based name resolver that communicates with the specified list of DNS servers. + * + * @param eventLoop the {@link EventLoop} which will perform the communication with the DNS servers + * @param channelFactory the {@link ChannelFactory} that will create a {@link DatagramChannel} + * @param resolveCache the DNS resolved entries cache + * @param authoritativeDnsServerCache the cache used to find the authoritative DNS server for a domain + * @param dnsQueryLifecycleObserverFactory used to generate new instances of {@link DnsQueryLifecycleObserver} which + * can be used to track metrics for DNS servers. + * @param queryTimeoutMillis timeout of each DNS query in millis + * @param resolvedAddressTypes the preferred address types + * @param recursionDesired if recursion desired flag must be set + * @param maxQueriesPerResolve the maximum allowed number of DNS queries for a given name resolution + * @param traceEnabled if trace is enabled + * @param maxPayloadSize the capacity of the datagram packet buffer + * @param optResourceEnabled if automatic inclusion of a optional records is enabled + * @param hostsFileEntriesResolver the {@link HostsFileEntriesResolver} used to check for local aliases + * @param dnsServerAddressStreamProvider The {@link DnsServerAddressStreamProvider} used to determine the name + * servers for each hostname lookup. + * @param searchDomains the list of search domain + * (can be null, if so, will try to default to the underlying platform ones) + * @param ndots the ndots value + * @param decodeIdn {@code true} if domain / host names should be decoded to unicode when received. + * See rfc3492. + * @deprecated Use {@link DnsNameResolverBuilder}. */ + @Deprecated public DnsNameResolver( EventLoop eventLoop, ChannelFactory channelFactory, final DnsCache resolveCache, - DnsCache authoritativeDnsServerCache, + final AuthoritativeDnsServerCache authoritativeDnsServerCache, + DnsQueryLifecycleObserverFactory dnsQueryLifecycleObserverFactory, + long queryTimeoutMillis, + ResolvedAddressTypes resolvedAddressTypes, + boolean recursionDesired, + int maxQueriesPerResolve, + boolean traceEnabled, + int maxPayloadSize, + boolean optResourceEnabled, + HostsFileEntriesResolver hostsFileEntriesResolver, + DnsServerAddressStreamProvider dnsServerAddressStreamProvider, + String[] searchDomains, + int ndots, + boolean decodeIdn) { + this(eventLoop, channelFactory, resolveCache, NoopDnsCnameCache.INSTANCE, authoritativeDnsServerCache, + dnsQueryLifecycleObserverFactory, queryTimeoutMillis, resolvedAddressTypes, recursionDesired, + maxQueriesPerResolve, traceEnabled, maxPayloadSize, optResourceEnabled, hostsFileEntriesResolver, + dnsServerAddressStreamProvider, searchDomains, ndots, decodeIdn); + } + + DnsNameResolver( + EventLoop eventLoop, + ChannelFactory channelFactory, + final DnsCache resolveCache, + final DnsCnameCache cnameCache, + final AuthoritativeDnsServerCache authoritativeDnsServerCache, DnsQueryLifecycleObserverFactory dnsQueryLifecycleObserverFactory, long queryTimeoutMillis, ResolvedAddressTypes resolvedAddressTypes, @@ -244,12 +336,12 @@ public DnsNameResolver( this.dnsServerAddressStreamProvider = checkNotNull(dnsServerAddressStreamProvider, "dnsServerAddressStreamProvider"); this.resolveCache = checkNotNull(resolveCache, "resolveCache"); - this.authoritativeDnsServerCache = checkNotNull(authoritativeDnsServerCache, "authoritativeDnsServerCache"); + this.cnameCache = checkNotNull(cnameCache, "cnameCache"); this.dnsQueryLifecycleObserverFactory = traceEnabled ? - dnsQueryLifecycleObserverFactory instanceof NoopDnsQueryLifecycleObserverFactory ? - new TraceDnsQueryLifeCycleObserverFactory() : - new BiDnsQueryLifecycleObserverFactory(new TraceDnsQueryLifeCycleObserverFactory(), - dnsQueryLifecycleObserverFactory) : + dnsQueryLifecycleObserverFactory instanceof NoopDnsQueryLifecycleObserverFactory ? + new TraceDnsQueryLifeCycleObserverFactory() : + new BiDnsQueryLifecycleObserverFactory(new TraceDnsQueryLifeCycleObserverFactory(), + dnsQueryLifecycleObserverFactory) : checkNotNull(dnsQueryLifecycleObserverFactory, "dnsQueryLifecycleObserverFactory"); this.searchDomains = searchDomains != null ? searchDomains.clone() : DEFAULT_SEARCH_DOMAINS; this.ndots = ndots >= 0 ? ndots : DEFAULT_NDOTS; @@ -261,32 +353,31 @@ public DnsNameResolver( supportsARecords = true; resolveRecordTypes = IPV4_ONLY_RESOLVED_RECORD_TYPES; resolvedInternetProtocolFamilies = IPV4_ONLY_RESOLVED_PROTOCOL_FAMILIES; - preferredAddressType = InternetProtocolFamily.IPv4; break; case IPV4_PREFERRED: supportsAAAARecords = true; supportsARecords = true; resolveRecordTypes = IPV4_PREFERRED_RESOLVED_RECORD_TYPES; resolvedInternetProtocolFamilies = IPV4_PREFERRED_RESOLVED_PROTOCOL_FAMILIES; - preferredAddressType = InternetProtocolFamily.IPv4; break; case IPV6_ONLY: supportsAAAARecords = true; supportsARecords = false; resolveRecordTypes = IPV6_ONLY_RESOLVED_RECORD_TYPES; resolvedInternetProtocolFamilies = IPV6_ONLY_RESOLVED_PROTOCOL_FAMILIES; - preferredAddressType = InternetProtocolFamily.IPv6; break; case IPV6_PREFERRED: supportsAAAARecords = true; supportsARecords = true; resolveRecordTypes = IPV6_PREFERRED_RESOLVED_RECORD_TYPES; resolvedInternetProtocolFamilies = IPV6_PREFERRED_RESOLVED_PROTOCOL_FAMILIES; - preferredAddressType = InternetProtocolFamily.IPv6; break; default: throw new IllegalArgumentException("Unknown ResolvedAddressTypes " + resolvedAddressTypes); } + preferredAddressType = preferredAddressType(this.resolvedAddressTypes); + this.authoritativeDnsServerCache = checkNotNull(authoritativeDnsServerCache, "authoritativeDnsServerCache"); + nameServerComparator = new NameServerComparator(preferredAddressType.addressType()); Bootstrap b = new Bootstrap(); b.group(executor()); @@ -301,20 +392,46 @@ protected void initChannel(DatagramChannel ch) throws Exception { }); channelFuture = responseHandler.channelActivePromise; - ch = (DatagramChannel) b.register().channel(); + ChannelFuture future = b.register(); + Throwable cause = future.cause(); + if (cause != null) { + if (cause instanceof RuntimeException) { + throw (RuntimeException) cause; + } + if (cause instanceof Error) { + throw (Error) cause; + } + throw new IllegalStateException("Unable to create / register Channel", cause); + } + ch = future.channel(); ch.config().setRecvByteBufAllocator(new FixedRecvByteBufAllocator(maxPayloadSize)); ch.closeFuture().addListener(new ChannelFutureListener() { @Override - public void operationComplete(ChannelFuture future) throws Exception { + public void operationComplete(ChannelFuture future) { resolveCache.clear(); + cnameCache.clear(); + authoritativeDnsServerCache.clear(); } }); } + static InternetProtocolFamily preferredAddressType(ResolvedAddressTypes resolvedAddressTypes) { + switch (resolvedAddressTypes) { + case IPV4_ONLY: + case IPV4_PREFERRED: + return InternetProtocolFamily.IPv4; + case IPV6_ONLY: + case IPV6_PREFERRED: + return InternetProtocolFamily.IPv6; + default: + throw new IllegalArgumentException("Unknown ResolvedAddressTypes " + resolvedAddressTypes); + } + } + // Only here to override in unit tests. - int dnsRedirectPort(@SuppressWarnings("unused") InetAddress server) { - return DNS_PORT; + InetSocketAddress newRedirectServerAddress(InetAddress server) { + return new InetSocketAddress(server, DNS_PORT); } final DnsQueryLifecycleObserverFactory dnsQueryLifecycleObserverFactory() { @@ -322,12 +439,26 @@ final DnsQueryLifecycleObserverFactory dnsQueryLifecycleObserverFactory() { } /** - * Provides the opportunity to sort the name servers before following a redirected DNS query. - * @param nameServers The addresses of the DNS servers which are used in the event of a redirect. - * @return A {@link DnsServerAddressStream} which will be used to follow the DNS redirect. + * Creates a new {@link DnsServerAddressStream} to following a redirected DNS query. By overriding this + * it provides the opportunity to sort the name servers before following a redirected DNS query. + * + * @param hostname the hostname. + * @param nameservers The addresses of the DNS servers which are used in the event of a redirect. This may + * contain resolved and unresolved addresses so the used {@link DnsServerAddressStream} must + * allow unresolved addresses if you want to include these as well. + * @return A {@link DnsServerAddressStream} which will be used to follow the DNS redirect or {@code null} if + * none should be followed. */ - protected DnsServerAddressStream uncachedRedirectDnsServerStream(List nameServers) { - return DnsServerAddresses.sequential(nameServers).stream(); + protected DnsServerAddressStream newRedirectDnsServerStream( + @SuppressWarnings("unused") String hostname, List nameservers) { + DnsServerAddressStream cached = authoritativeDnsServerCache().get(hostname); + if (cached == null || cached.size() == 0) { + // If there is no cache hit (which may be the case for example when a NoopAuthoritativeDnsServerCache + // is used), we will just directly use the provided nameservers. + Collections.sort(nameservers, nameServerComparator); + return new SequentialDnsServerAddressStream(nameservers, 0); + } + return cached; } /** @@ -337,10 +468,17 @@ public DnsCache resolveCache() { return resolveCache; } + /** + * Returns the {@link DnsCnameCache}. + */ + DnsCnameCache cnameCache() { + return cnameCache; + } + /** * Returns the cache used for authoritative DNS servers for a domain. */ - public DnsCache authoritativeDnsServerCache() { + public AuthoritativeDnsServerCache authoritativeDnsServerCache() { return authoritativeDnsServerCache; } @@ -784,15 +922,16 @@ protected void doResolveAll(String inetHost, return; } - if (!doResolveAllCached(hostname, additionals, promise, resolveCache)) { + if (!doResolveAllCached(hostname, additionals, promise, resolveCache, resolvedInternetProtocolFamilies)) { doResolveAllUncached(hostname, additionals, promise, resolveCache); } } - private boolean doResolveAllCached(String hostname, - DnsRecord[] additionals, - Promise> promise, - DnsCache resolveCache) { + static boolean doResolveAllCached(String hostname, + DnsRecord[] additionals, + Promise> promise, + DnsCache resolveCache, + InternetProtocolFamily[] resolvedInternetProtocolFamilies) { final List cachedEntries = resolveCache.get(hostname, additionals); if (cachedEntries == null || cachedEntries.isEmpty()) { return false; @@ -824,13 +963,36 @@ private boolean doResolveAllCached(String hostname, } } - private void doResolveAllUncached(String hostname, + private void doResolveAllUncached(final String hostname, + final DnsRecord[] additionals, + final Promise> promise, + final DnsCache resolveCache) { + // Call doResolveUncached0(...) in the EventLoop as we may need to submit multiple queries which would need + // to submit multiple Runnable at the end if we are not already on the EventLoop. + EventExecutor executor = executor(); + if (executor.inEventLoop()) { + doResolveAllUncached0(hostname, additionals, promise, resolveCache); + } else { + executor.execute(new Runnable() { + @Override + public void run() { + doResolveAllUncached0(hostname, additionals, promise, resolveCache); + } + }); + } + } + + private void doResolveAllUncached0(String hostname, DnsRecord[] additionals, Promise> promise, DnsCache resolveCache) { + + assert executor().inEventLoop(); + final DnsServerAddressStream nameServerAddrs = dnsServerAddressStreamProvider.nameServerAddressStream(hostname); - new DnsAddressResolveContext(this, hostname, additionals, nameServerAddrs, resolveCache).resolve(promise); + new DnsAddressResolveContext(this, hostname, additionals, nameServerAddrs, + resolveCache, authoritativeDnsServerCache).resolve(promise); } private static String hostname(String inetHost) { @@ -875,8 +1037,8 @@ private InetSocketAddress nextNameServerAddress() { public Future> query( InetSocketAddress nameServerAddr, DnsQuestion question) { - return query0(nameServerAddr, question, EMPTY_ADDITIONALS, - ch.eventLoop().>newPromise()); + return query0(nameServerAddr, question, EMPTY_ADDITIONALS, true, ch.newPromise(), + ch.eventLoop().>newPromise()); } /** @@ -885,8 +1047,8 @@ public Future> query( public Future> query( InetSocketAddress nameServerAddr, DnsQuestion question, Iterable additionals) { - return query0(nameServerAddr, question, toArray(additionals, false), - ch.eventLoop().>newPromise()); + return query0(nameServerAddr, question, toArray(additionals, false), true, ch.newPromise(), + ch.eventLoop().>newPromise()); } /** @@ -896,7 +1058,7 @@ public Future> query( InetSocketAddress nameServerAddr, DnsQuestion question, Promise> promise) { - return query0(nameServerAddr, question, EMPTY_ADDITIONALS, promise); + return query0(nameServerAddr, question, EMPTY_ADDITIONALS, true, ch.newPromise(), promise); } /** @@ -907,7 +1069,7 @@ public Future> query( Iterable additionals, Promise> promise) { - return query0(nameServerAddr, question, toArray(additionals, false), promise); + return query0(nameServerAddr, question, toArray(additionals, false), true, ch.newPromise(), promise); } /** @@ -928,16 +1090,14 @@ public static boolean isTimeoutError(Throwable cause) { return cause != null && cause.getCause() instanceof DnsNameResolverTimeoutException; } - final Future> query0( - InetSocketAddress nameServerAddr, DnsQuestion question, - DnsRecord[] additionals, - Promise> promise) { - return query0(nameServerAddr, question, additionals, ch.newPromise(), promise); + final void flushQueries() { + ch.flush(); } final Future> query0( InetSocketAddress nameServerAddr, DnsQuestion question, DnsRecord[] additionals, + boolean flush, ChannelPromise writePromise, Promise> promise) { assert !writePromise.isVoid(); @@ -945,7 +1105,8 @@ final Future> query0( final Promise> castPromise = cast( checkNotNull(promise, "promise")); try { - new DnsQueryContext(this, nameServerAddr, question, additionals, castPromise).query(writePromise); + new DnsQueryContext(this, nameServerAddr, question, additionals, castPromise) + .query(flush, writePromise); return castPromise; } catch (Exception e) { return castPromise.setFailure(e); @@ -957,6 +1118,10 @@ private static Promise> cast(P return (Promise>) promise; } + final DnsServerAddressStream newNameServerAddressStream(String hostname) { + return dnsServerAddressStreamProvider.nameServerAddressStream(hostname); + } + private final class DnsResponseHandler extends ChannelInboundHandlerAdapter { private final Promise channelActivePromise; @@ -966,7 +1131,7 @@ private final class DnsResponseHandler extends ChannelInboundHandlerAdapter { } @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + public void channelRead(ChannelHandlerContext ctx, Object msg) { try { final DatagramDnsResponse res = (DatagramDnsResponse) msg; final int queryId = res.id(); @@ -994,7 +1159,7 @@ public void channelActive(ChannelHandlerContext ctx) throws Exception { } @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { logger.warn("{} Unexpected exception: ", ch, cause); } } diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverBuilder.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverBuilder.java index a5df3409dfef..06266ccfd36c 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverBuilder.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverBuilder.java @@ -40,7 +40,8 @@ public final class DnsNameResolverBuilder { private EventLoop eventLoop; private ChannelFactory channelFactory; private DnsCache resolveCache; - private DnsCache authoritativeDnsServerCache; + private DnsCnameCache cnameCache; + private AuthoritativeDnsServerCache authoritativeDnsServerCache; private Integer minTtl; private Integer maxTtl; private Integer negativeTtl; @@ -123,6 +124,17 @@ public DnsNameResolverBuilder resolveCache(DnsCache resolveCache) { return this; } + /** + * Sets the cache for {@code CNAME} mappings. + * + * @param cnameCache the cache used to cache {@code CNAME} mappings for a domain. + * @return {@code this} + */ + public DnsNameResolverBuilder cnameCache(DnsCnameCache cnameCache) { + this.cnameCache = cnameCache; + return this; + } + /** * Set the factory used to generate objects which can observe individual DNS queries. * @param lifecycleObserverFactory the factory used to generate objects which can observe individual DNS queries. @@ -139,8 +151,21 @@ public DnsNameResolverBuilder dnsQueryLifecycleObserverFactory(DnsQueryLifecycle * * @param authoritativeDnsServerCache the authoritative NS servers cache * @return {@code this} + * @deprecated Use {@link #authoritativeDnsServerCache(AuthoritativeDnsServerCache)} */ + @Deprecated public DnsNameResolverBuilder authoritativeDnsServerCache(DnsCache authoritativeDnsServerCache) { + this.authoritativeDnsServerCache = new AuthoritativeDnsServerCacheAdapter(authoritativeDnsServerCache); + return this; + } + + /** + * Sets the cache for authoritative NS servers + * + * @param authoritativeDnsServerCache the authoritative NS servers cache + * @return {@code this} + */ + public DnsNameResolverBuilder authoritativeDnsServerCache(AuthoritativeDnsServerCache authoritativeDnsServerCache) { this.authoritativeDnsServerCache = authoritativeDnsServerCache; return this; } @@ -335,7 +360,7 @@ public DnsNameResolverBuilder searchDomains(Iterable searchDomains) { list.add(f); } - this.searchDomains = list.toArray(new String[list.size()]); + this.searchDomains = list.toArray(new String[0]); return this; } @@ -355,6 +380,19 @@ private DnsCache newCache() { return new DefaultDnsCache(intValue(minTtl, 0), intValue(maxTtl, Integer.MAX_VALUE), intValue(negativeTtl, 0)); } + private AuthoritativeDnsServerCache newAuthoritativeDnsServerCache() { + return new DefaultAuthoritativeDnsServerCache( + intValue(minTtl, 0), intValue(maxTtl, Integer.MAX_VALUE), + // Let us use the sane ordering as DnsNameResolver will be used when returning + // nameservers from the cache. + new NameServerComparator(DnsNameResolver.preferredAddressType(resolvedAddressTypes).addressType())); + } + + private DnsCnameCache newCnameCache() { + return new DefaultDnsCnameCache( + intValue(minTtl, 0), intValue(maxTtl, Integer.MAX_VALUE)); + } + /** * Set if domain / host names should be decoded to unicode when received. * See rfc3492. @@ -386,12 +424,14 @@ public DnsNameResolver build() { } DnsCache resolveCache = this.resolveCache != null ? this.resolveCache : newCache(); - DnsCache authoritativeDnsServerCache = this.authoritativeDnsServerCache != null ? - this.authoritativeDnsServerCache : newCache(); + DnsCnameCache cnameCache = this.cnameCache != null ? this.cnameCache : newCnameCache(); + AuthoritativeDnsServerCache authoritativeDnsServerCache = this.authoritativeDnsServerCache != null ? + this.authoritativeDnsServerCache : newAuthoritativeDnsServerCache(); return new DnsNameResolver( eventLoop, channelFactory, resolveCache, + cnameCache, authoritativeDnsServerCache, dnsQueryLifecycleObserverFactory, queryTimeoutMillis, @@ -428,6 +468,9 @@ public DnsNameResolverBuilder copy() { copiedBuilder.resolveCache(resolveCache); } + if (cnameCache != null) { + copiedBuilder.cnameCache(cnameCache); + } if (maxTtl != null && minTtl != null) { copiedBuilder.ttl(minTtl, maxTtl); } diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java index 12196f184ca9..08bbcb6088f7 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java @@ -28,6 +28,7 @@ import io.netty.handler.codec.dns.DnsResponse; import io.netty.handler.codec.dns.DnsSection; import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.FutureListener; import io.netty.util.concurrent.GenericFutureListener; import io.netty.util.concurrent.Promise; import io.netty.util.concurrent.ScheduledFuture; @@ -39,7 +40,7 @@ import static io.netty.util.internal.ObjectUtil.checkNotNull; -final class DnsQueryContext { +final class DnsQueryContext implements FutureListener> { private static final InternalLogger logger = InternalLoggerFactory.getInstance(DnsQueryContext.class); @@ -68,6 +69,9 @@ final class DnsQueryContext { recursionDesired = parent.isRecursionDesired(); id = parent.queryContextManager.add(this); + // Ensure we remove the id from the QueryContextManager once the query completes. + promise.addListener(this); + if (parent.isOptResourceEnabled()) { optResource = new AbstractDnsOptPseudoRrRecord(parent.maxPayloadSize(), 0, 0) { // We may want to remove this in the future and let the user just specify the opt record in the query. @@ -85,7 +89,7 @@ DnsQuestion question() { return question; } - void query(ChannelPromise writePromise) { + void query(boolean flush, ChannelPromise writePromise) { final DnsQuestion question = question(); final InetSocketAddress nameServerAddr = nameServerAddr(); final DatagramDnsQuery query = new DatagramDnsQuery(null, nameServerAddr, id); @@ -106,18 +110,21 @@ void query(ChannelPromise writePromise) { logger.debug("{} WRITE: [{}: {}], {}", parent.ch, id, nameServerAddr, question); } - sendQuery(query, writePromise); + sendQuery(query, flush, writePromise); } - private void sendQuery(final DnsQuery query, final ChannelPromise writePromise) { + private void sendQuery(final DnsQuery query, final boolean flush, final ChannelPromise writePromise) { if (parent.channelFuture.isDone()) { - writeQuery(query, writePromise); + writeQuery(query, flush, writePromise); } else { parent.channelFuture.addListener(new GenericFutureListener>() { @Override - public void operationComplete(Future future) throws Exception { + public void operationComplete(Future future) { if (future.isSuccess()) { - writeQuery(query, writePromise); + // If the query is done in a late fashion (as the channel was not ready yet) we always flush + // to ensure we did not race with a previous flush() that was done when the Channel was not + // ready yet. + writeQuery(query, true, writePromise); } else { Throwable cause = future.cause(); promise.tryFailure(cause); @@ -128,14 +135,15 @@ public void operationComplete(Future future) throws Exception { } } - private void writeQuery(final DnsQuery query, final ChannelPromise writePromise) { - final ChannelFuture writeFuture = parent.ch.writeAndFlush(query, writePromise); + private void writeQuery(final DnsQuery query, final boolean flush, final ChannelPromise writePromise) { + final ChannelFuture writeFuture = flush ? parent.ch.writeAndFlush(query, writePromise) : + parent.ch.write(query, writePromise); if (writeFuture.isDone()) { onQueryWriteCompletion(writeFuture); } else { writeFuture.addListener(new ChannelFutureListener() { @Override - public void operationComplete(ChannelFuture future) throws Exception { + public void operationComplete(ChannelFuture future) { onQueryWriteCompletion(writeFuture); } }); @@ -181,29 +189,18 @@ void finish(AddressedEnvelope envelope } private void setSuccess(AddressedEnvelope envelope) { - parent.queryContextManager.remove(nameServerAddr(), id); - - // Cancel the timeout task. - final ScheduledFuture timeoutFuture = this.timeoutFuture; - if (timeoutFuture != null) { - timeoutFuture.cancel(false); - } - Promise> promise = this.promise; - if (promise.setUncancellable()) { - @SuppressWarnings("unchecked") - AddressedEnvelope castResponse = - (AddressedEnvelope) envelope.retain(); - if (!promise.trySuccess(castResponse)) { - // We failed to notify the promise as it was failed before, thus we need to release the envelope - envelope.release(); - } + @SuppressWarnings("unchecked") + AddressedEnvelope castResponse = + (AddressedEnvelope) envelope.retain(); + if (!promise.trySuccess(castResponse)) { + // We failed to notify the promise as it was failed before, thus we need to release the envelope + envelope.release(); } } private void setFailure(String message, Throwable cause) { final InetSocketAddress nameServerAddr = nameServerAddr(); - parent.queryContextManager.remove(nameServerAddr, id); final StringBuilder buf = new StringBuilder(message.length() + 64); buf.append('[') @@ -222,4 +219,18 @@ private void setFailure(String message, Throwable cause) { } promise.tryFailure(e); } + + @Override + public void operationComplete(Future> future) { + // Cancel the timeout task. + final ScheduledFuture timeoutFuture = this.timeoutFuture; + if (timeoutFuture != null) { + this.timeoutFuture = null; + timeoutFuture.cancel(false); + } + + // Remove the id from the manager as soon as the query completes. This may be because of success, failure or + // cancellation + parent.queryContextManager.remove(nameServerAddr, id); + } } diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsResolveContext.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsResolveContext.java index bc63e710e14d..f625be975cea 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsResolveContext.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsResolveContext.java @@ -44,20 +44,21 @@ import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.UnknownHostException; +import java.util.AbstractList; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.IdentityHashMap; import java.util.Iterator; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Set; import static io.netty.resolver.dns.DnsAddressDecoder.decodeAddress; -import static io.netty.resolver.dns.DnsNameResolver.trySuccess; import static java.lang.Math.min; -import static java.util.Collections.unmodifiableList; abstract class DnsResolveContext { @@ -97,7 +98,7 @@ public void operationComplete(Future>> queriesInProgress = Collections.newSetFromMap( @@ -124,6 +125,27 @@ public void operationComplete(Future> promise) { final String[] searchDomains = parent.searchDomains(); if (searchDomains.length == 0 || parent.ndots() == 0 || StringUtil.endsWith(hostname, '.')) { - internalResolve(promise); + internalResolve(hostname, promise); } else { final boolean startWithoutSearchDomain = hasNDots(); final String initialHostname = startWithoutSearchDomain ? hostname : hostname + '.' + searchDomains[0]; final int initialSearchDomainIdx = startWithoutSearchDomain ? 0 : 1; - doSearchDomainQuery(initialHostname, new FutureListener>() { + final Promise> searchDomainPromise = parent.executor().newPromise(); + searchDomainPromise.addListener(new FutureListener>() { private int searchDomainIdx = initialSearchDomainIdx; @Override public void operationComplete(Future> future) { @@ -175,15 +198,18 @@ public void operationComplete(Future> future) { if (DnsNameResolver.isTransportOrTimeoutError(cause)) { promise.tryFailure(new SearchDomainUnknownHostException(cause, hostname)); } else if (searchDomainIdx < searchDomains.length) { - doSearchDomainQuery(hostname + '.' + searchDomains[searchDomainIdx++], this); + Promise> newPromise = parent.executor().newPromise(); + newPromise.addListener(this); + doSearchDomainQuery(hostname + '.' + searchDomains[searchDomainIdx++], newPromise); } else if (!startWithoutSearchDomain) { - internalResolve(promise); + internalResolve(hostname, promise); } else { promise.tryFailure(new SearchDomainUnknownHostException(cause, hostname)); } } } }); + doSearchDomainQuery(initialHostname, searchDomainPromise); } } @@ -213,35 +239,42 @@ public Throwable fillInStackTrace() { } } - private void doSearchDomainQuery(String hostname, FutureListener> listener) { + void doSearchDomainQuery(String hostname, Promise> nextPromise) { DnsResolveContext nextContext = newResolverContext(parent, hostname, dnsClass, expectedTypes, additionals, nameServerAddrs); - Promise> nextPromise = parent.executor().newPromise(); - nextContext.internalResolve(nextPromise); - nextPromise.addListener(listener); + nextContext.internalResolve(hostname, nextPromise); } - private void internalResolve(Promise> promise) { - DnsServerAddressStream nameServerAddressStream = getNameServers(hostname); + private static String hostnameWithDot(String name) { + if (StringUtil.endsWith(name, '.')) { + return name; + } + return name + '.'; + } - final int end = expectedTypes.length - 1; - for (int i = 0; i < end; ++i) { - if (!query(hostname, expectedTypes[i], nameServerAddressStream.duplicate(), promise)) { - return; + private void internalResolve(String name, Promise> promise) { + for (;;) { + // Resolve from cnameCache() until there is no more cname entry cached. + String mapping = cnameCache().get(hostnameWithDot(name)); + if (mapping == null) { + break; } + name = mapping; } - query(hostname, expectedTypes[end], nameServerAddressStream, promise); - } - /** - * Add an authoritative nameserver to the cache if its not a root server. - */ - private void addNameServerToCache( - AuthoritativeNameServer name, InetAddress resolved, long ttl) { - if (!name.isRootServer()) { - // Cache NS record if not for a root server as we should never cache for root servers. - parent.authoritativeDnsServerCache().cache(name.domainName(), - additionals, resolved, ttl, parent.ch.eventLoop()); + try { + DnsServerAddressStream nameServerAddressStream = getNameServers(name); + + final int end = expectedTypes.length - 1; + for (int i = 0; i < end; ++i) { + if (!query(name, expectedTypes[i], nameServerAddressStream.duplicate(), false, promise)) { + return; + } + } + query(name, expectedTypes[end], nameServerAddressStream, false, promise); + } finally { + // Now flush everything we submitted before. + parent.flushQueries(); } } @@ -280,55 +313,20 @@ private DnsServerAddressStream getNameServersFromCache(String hostname) { } idx = idx2; - List entries = parent.authoritativeDnsServerCache().get(hostname, additionals); - if (entries != null && !entries.isEmpty()) { - return DnsServerAddresses.sequential(new DnsCacheIterable(entries)).stream(); + DnsServerAddressStream entries = authoritativeDnsServerCache().get(hostname); + if (entries != null) { + // The returned List may contain unresolved InetSocketAddress instances that will be + // resolved on the fly in query(....). + return entries; } } } - private final class DnsCacheIterable implements Iterable { - private final List entries; - - DnsCacheIterable(List entries) { - this.entries = entries; - } - - @Override - public Iterator iterator() { - return new Iterator() { - Iterator entryIterator = entries.iterator(); - - @Override - public boolean hasNext() { - return entryIterator.hasNext(); - } - - @Override - public InetSocketAddress next() { - InetAddress address = entryIterator.next().address(); - return new InetSocketAddress(address, parent.dnsRedirectPort(address)); - } - - @Override - public void remove() { - entryIterator.remove(); - } - }; - } - } - - private void query(final DnsServerAddressStream nameServerAddrStream, final int nameServerAddrStreamIndex, - final DnsQuestion question, - final Promise> promise, Throwable cause) { - query(nameServerAddrStream, nameServerAddrStreamIndex, question, - parent.dnsQueryLifecycleObserverFactory().newDnsQueryLifecycleObserver(question), promise, cause); - } - private void query(final DnsServerAddressStream nameServerAddrStream, final int nameServerAddrStreamIndex, final DnsQuestion question, final DnsQueryLifecycleObserver queryLifecycleObserver, + final boolean flush, final Promise> promise, final Throwable cause) { if (nameServerAddrStreamIndex >= nameServerAddrStream.size() || allowedQueries == 0 || promise.isCancelled()) { @@ -338,11 +336,20 @@ private void query(final DnsServerAddressStream nameServerAddrStream, } --allowedQueries; + final InetSocketAddress nameServerAddr = nameServerAddrStream.next(); + if (nameServerAddr.isUnresolved()) { + queryUnresolvedNameserver(nameServerAddr, nameServerAddrStream, nameServerAddrStreamIndex, question, + queryLifecycleObserver, promise, cause); + return; + } final ChannelPromise writePromise = parent.ch.newPromise(); - final Future> f = parent.query0( - nameServerAddr, question, additionals, writePromise, - parent.ch.eventLoop().>newPromise()); + final Promise> queryPromise = + parent.ch.eventLoop().newPromise(); + + final Future> f = + parent.query0(nameServerAddr, question, additionals, flush, writePromise, queryPromise); + queriesInProgress.add(f); queryLifecycleObserver.queryWritten(nameServerAddr, writePromise); @@ -372,7 +379,8 @@ public void operationComplete(Future> promise, + final Throwable cause) { + final String nameServerName = PlatformDependent.javaVersion() >= 7 ? + nameServerAddr.getHostString() : nameServerAddr.getHostName(); + assert nameServerName != null; + + // Placeholder so we will not try to finish the original query yet. + final Future> resolveFuture = parent.executor() + .newSucceededFuture(null); + queriesInProgress.add(resolveFuture); + + Promise> resolverPromise = parent.executor().newPromise(); + resolverPromise.addListener(new FutureListener>() { + @Override + public void operationComplete(final Future> future) { + // Remove placeholder. + queriesInProgress.remove(resolveFuture); + + if (future.isSuccess()) { + List resolvedAddresses = future.getNow(); + DnsServerAddressStream addressStream = new CombinedDnsServerAddressStream( + nameServerAddr, resolvedAddresses, nameServerAddrStream); + query(addressStream, nameServerAddrStreamIndex, question, + queryLifecycleObserver, true, promise, cause); + } else { + // Ignore the server and try the next one... + query(nameServerAddrStream, nameServerAddrStreamIndex + 1, + question, queryLifecycleObserver, true, promise, cause); + } + } + }); + if (!DnsNameResolver.doResolveAllCached(nameServerName, additionals, resolverPromise, resolveCache(), + parent.resolvedInternetProtocolFamiliesUnsafe())) { + final AuthoritativeDnsServerCache authoritativeDnsServerCache = authoritativeDnsServerCache(); + new DnsAddressResolveContext(parent, nameServerName, additionals, + parent.newNameServerAddressStream(nameServerName), + resolveCache(), new AuthoritativeDnsServerCache() { + @Override + public DnsServerAddressStream get(String hostname) { + // To not risk falling into any loop, we will not use the cache while following redirects but only + // on the initial query. + return null; + } + + @Override + public void cache(String hostname, InetSocketAddress address, long originalTtl, EventLoop loop) { + authoritativeDnsServerCache.cache(hostname, address, originalTtl, loop); + } + + @Override + public void clear() { + authoritativeDnsServerCache.clear(); + } + + @Override + public boolean clear(String hostname) { + return authoritativeDnsServerCache.clear(hostname); + } + }).resolve(resolverPromise); + } + } + private void onResponse(final DnsServerAddressStream nameServerAddrStream, final int nameServerAddrStreamIndex, - final DnsQuestion question, AddressedEnvelope envelope, - final DnsQueryLifecycleObserver queryLifecycleObserver, - Promise> promise) { + final DnsQuestion question, AddressedEnvelope envelope, + final DnsQueryLifecycleObserver queryLifecycleObserver, + Promise> promise) { try { final DnsResponse res = envelope.content(); final DnsResponseCode code = res.code(); @@ -400,7 +475,8 @@ private void onResponse(final DnsServerAddressStream nameServerAddrStream, final final DnsRecordType type = question.type(); if (type == DnsRecordType.CNAME) { - onResponseCNAME(question, buildAliasMap(envelope.content()), queryLifecycleObserver, promise); + onResponseCNAME(question, buildAliasMap(envelope.content(), cnameCache(), parent.executor()), + queryLifecycleObserver, promise); return; } @@ -418,9 +494,32 @@ private void onResponse(final DnsServerAddressStream nameServerAddrStream, final // Retry with the next server if the server did not tell us that the domain does not exist. if (code != DnsResponseCode.NXDOMAIN) { query(nameServerAddrStream, nameServerAddrStreamIndex + 1, question, - queryLifecycleObserver.queryNoAnswer(code), promise, null); + queryLifecycleObserver.queryNoAnswer(code), true, promise, null); } else { queryLifecycleObserver.queryFailed(NXDOMAIN_QUERY_FAILED_EXCEPTION); + + // Try with the next server if is not authoritative for the domain. + // + // From https://tools.ietf.org/html/rfc1035 : + // + // RCODE Response code - this 4 bit field is set as part of + // responses. The values have the following + // interpretation: + // + // .... + // .... + // + // 3 Name Error - Meaningful only for + // responses from an authoritative name + // server, this code signifies that the + // domain name referenced in the query does + // not exist. + // .... + // .... + if (!res.isAuthoritativeAnswer()) { + query(nameServerAddrStream, nameServerAddrStreamIndex + 1, question, + newDnsQueryLifecycleObserver(question), true, promise, null); + } } } finally { ReferenceCountUtil.safeRelease(envelope); @@ -438,11 +537,10 @@ private boolean handleRedirect( // Check if we have answers, if not this may be an non authority NS and so redirects must be handled. if (res.count(DnsSection.ANSWER) == 0) { AuthoritativeNameServerList serverNames = extractAuthoritativeNameServers(question.name(), res); - if (serverNames != null) { - List nameServers = new ArrayList(serverNames.size()); int additionalCount = res.count(DnsSection.ADDITIONAL); + AuthoritativeDnsServerCache authoritativeDnsServerCache = authoritativeDnsServerCache(); for (int i = 0; i < additionalCount; i++) { final DnsRecord r = res.recordAt(DnsSection.ADDITIONAL, i); @@ -451,27 +549,24 @@ private boolean handleRedirect( continue; } - final String recordName = r.name(); - final AuthoritativeNameServer authoritativeNameServer = serverNames.remove(recordName); + // We may have multiple ADDITIONAL entries for the same nameserver name. For example one AAAA and + // one A record. + serverNames.handleWithAdditional(parent, r, authoritativeDnsServerCache); + } - if (authoritativeNameServer == null) { - // Not a server we are interested in. - continue; - } + // Process all unresolved nameservers as well. + serverNames.handleWithoutAdditionals(parent, resolveCache(), authoritativeDnsServerCache); - InetAddress resolved = decodeAddress(r, recordName, parent.isDecodeIdn()); - if (resolved == null) { - // Could not parse it, move to the next. - continue; - } + List addresses = serverNames.addressList(); - nameServers.add(new InetSocketAddress(resolved, parent.dnsRedirectPort(resolved))); - addNameServerToCache(authoritativeNameServer, resolved, r.timeToLive()); - } + // Give the user the chance to sort or filter the used servers for the query. + DnsServerAddressStream serverStream = parent.newRedirectDnsServerStream( + question.name(), addresses); - if (!nameServers.isEmpty()) { - query(parent.uncachedRedirectDnsServerStream(nameServers), 0, question, - queryLifecycleObserver.queryRedirected(unmodifiableList(nameServers)), promise, null); + if (serverStream != null) { + query(serverStream, 0, question, + queryLifecycleObserver.queryRedirected(new DnsAddressStreamList(serverStream)), + true, promise, null); return true; } } @@ -479,6 +574,60 @@ private boolean handleRedirect( return false; } + private static final class DnsAddressStreamList extends AbstractList { + + private final DnsServerAddressStream duplicate; + private List addresses; + + DnsAddressStreamList(DnsServerAddressStream stream) { + duplicate = stream.duplicate(); + } + + @Override + public InetSocketAddress get(int index) { + if (addresses == null) { + DnsServerAddressStream stream = duplicate.duplicate(); + addresses = new ArrayList(size()); + for (int i = 0; i < stream.size(); i++) { + addresses.add(stream.next()); + } + } + return addresses.get(index); + } + + @Override + public int size() { + return duplicate.size(); + } + + @Override + public Iterator iterator() { + return new Iterator() { + private final DnsServerAddressStream stream = duplicate.duplicate(); + private int i; + + @Override + public boolean hasNext() { + return i < stream.size(); + } + + @Override + public InetSocketAddress next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + i++; + return stream.next(); + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + } + /** * Returns the {@code {@link AuthoritativeNameServerList} which were included in {@link DnsSection#AUTHORITY} * or {@code null} if non are found. @@ -493,7 +642,7 @@ private static AuthoritativeNameServerList extractAuthoritativeNameServers(Strin for (int i = 0; i < authorityCount; i++) { serverNames.add(res.recordAt(DnsSection.AUTHORITY, i)); } - return serverNames; + return serverNames.isEmpty() ? null : serverNames; } private void onExpectedResponse( @@ -502,7 +651,7 @@ private void onExpectedResponse( // We often get a bunch of CNAMES as well when we asked for A/AAAA. final DnsResponse response = envelope.content(); - final Map cnames = buildAliasMap(response); + final Map cnames = buildAliasMap(response, cnameCache(), parent.executor()); final int answerCount = response.count(DnsSection.ANSWER); boolean found = false; @@ -526,10 +675,11 @@ private void onExpectedResponse( // Make sure the record is for the questioned domain. if (!recordName.equals(questionName)) { + Map cnamesCopy = new HashMap(cnames); // Even if the record's name is not exactly same, it might be an alias defined in the CNAME records. String resolved = questionName; do { - resolved = cnames.get(resolved); + resolved = cnamesCopy.remove(resolved); if (recordName.equals(resolved)) { break; } @@ -540,7 +690,7 @@ private void onExpectedResponse( } } - final T converted = convertRecord(r, hostname, additionals, parent.ch.eventLoop()); + final T converted = convertRecord(r, hostname, additionals, parent.executor()); if (converted == null) { continue; } @@ -565,8 +715,7 @@ private void onExpectedResponse( } else { queryLifecycleObserver.querySucceed(); // We also got a CNAME so we need to ensure we also query it. - onResponseCNAME(question, cnames, - parent.dnsQueryLifecycleObserverFactory().newDnsQueryLifecycleObserver(question), promise); + onResponseCNAME(question, cnames, newDnsQueryLifecycleObserver(question), promise); } } @@ -597,7 +746,7 @@ private void onResponseCNAME( } } - private static Map buildAliasMap(DnsResponse response) { + private static Map buildAliasMap(DnsResponse response, DnsCnameCache cache, EventLoop loop) { final int answerCount = response.count(DnsSection.ANSWER); Map cnames = null; for (int i = 0; i < answerCount; i ++) { @@ -621,18 +770,27 @@ private static Map buildAliasMap(DnsResponse response) { cnames = new HashMap(min(8, answerCount)); } - cnames.put(r.name().toLowerCase(Locale.US), domainName.toLowerCase(Locale.US)); + String name = r.name().toLowerCase(Locale.US); + String mapping = domainName.toLowerCase(Locale.US); + + // Cache the CNAME as well. + String nameWithDot = hostnameWithDot(name); + String mappingWithDot = hostnameWithDot(mapping); + if (!nameWithDot.equalsIgnoreCase(mappingWithDot)) { + cache.cache(nameWithDot, mappingWithDot, r.timeToLive(), loop); + cnames.put(name, mapping); + } } return cnames != null? cnames : Collections.emptyMap(); } private void tryToFinishResolve(final DnsServerAddressStream nameServerAddrStream, - final int nameServerAddrStreamIndex, - final DnsQuestion question, - final DnsQueryLifecycleObserver queryLifecycleObserver, - final Promise> promise, - final Throwable cause) { + final int nameServerAddrStreamIndex, + final DnsQuestion question, + final DnsQueryLifecycleObserver queryLifecycleObserver, + final Promise> promise, + final Throwable cause) { // There are no queries left to try. if (!queriesInProgress.isEmpty()) { @@ -649,10 +807,11 @@ private void tryToFinishResolve(final DnsServerAddressStream nameServerAddrStrea if (queryLifecycleObserver == NoopDnsQueryLifecycleObserver.INSTANCE) { // If the queryLifecycleObserver has already been terminated we should create a new one for this // fresh query. - query(nameServerAddrStream, nameServerAddrStreamIndex + 1, question, promise, cause); + query(nameServerAddrStream, nameServerAddrStreamIndex + 1, question, + newDnsQueryLifecycleObserver(question), true, promise, cause); } else { query(nameServerAddrStream, nameServerAddrStreamIndex + 1, question, queryLifecycleObserver, - promise, cause); + true, promise, cause); } return; } @@ -668,7 +827,7 @@ private void tryToFinishResolve(final DnsServerAddressStream nameServerAddrStrea // As the last resort, try to query CNAME, just in case the name server has it. triedCNAME = true; - query(hostname, DnsRecordType.CNAME, getNameServers(hostname), promise); + query(hostname, DnsRecordType.CNAME, getNameServers(hostname), true, promise); return; } } else { @@ -695,7 +854,7 @@ private void finishResolve(Promise> promise, Throwable cause) { if (finalResult != null) { // Found at least one resolved record. - trySuccess(promise, filterResults(finalResult)); + DnsNameResolver.trySuccess(promise, filterResults(finalResult)); return; } @@ -745,35 +904,97 @@ private DnsServerAddressStream getNameServers(String hostname) { private void followCname(DnsQuestion question, String cname, DnsQueryLifecycleObserver queryLifecycleObserver, Promise> promise) { + Set cnames = null; + for (;;) { + // Resolve from cnameCache() until there is no more cname entry cached. + String mapping = cnameCache().get(hostnameWithDot(cname)); + if (mapping == null) { + break; + } + if (cnames == null) { + // Detect loops. + cnames = new HashSet(2); + } + if (!cnames.add(cname)) { + // Follow CNAME from cache would loop. Lets break here. + break; + } + cname = mapping; + } + DnsServerAddressStream stream = getNameServers(cname); final DnsQuestion cnameQuestion; try { - cnameQuestion = newQuestion(cname, question.type()); + cnameQuestion = new DefaultDnsQuestion(cname, question.type(), dnsClass); } catch (Throwable cause) { queryLifecycleObserver.queryFailed(cause); PlatformDependent.throwException(cause); return; } - query(stream, 0, cnameQuestion, queryLifecycleObserver.queryCNAMEd(cnameQuestion), promise, null); + query(stream, 0, cnameQuestion, queryLifecycleObserver.queryCNAMEd(cnameQuestion), + true, promise, null); } private boolean query(String hostname, DnsRecordType type, DnsServerAddressStream dnsServerAddressStream, - Promise> promise) { - final DnsQuestion question = newQuestion(hostname, type); - if (question == null) { + boolean flush, Promise> promise) { + final DnsQuestion question; + try { + question = new DefaultDnsQuestion(hostname, type, dnsClass); + } catch (Throwable cause) { + // Assume a single failure means that queries will succeed. If the hostname is invalid for one type + // there is no case where it is known to be valid for another type. + promise.tryFailure(new IllegalArgumentException("Unable to create DNS Question for: [" + hostname + ", " + + type + ']', cause)); return false; } - query(dnsServerAddressStream, 0, question, promise, null); + query(dnsServerAddressStream, 0, question, newDnsQueryLifecycleObserver(question), flush, promise, null); return true; } - private DnsQuestion newQuestion(String hostname, DnsRecordType type) { - try { - return new DefaultDnsQuestion(hostname, type, dnsClass); - } catch (IllegalArgumentException e) { - // java.net.IDN.toASCII(...) may throw an IllegalArgumentException if it fails to parse the hostname - return null; + private DnsQueryLifecycleObserver newDnsQueryLifecycleObserver(DnsQuestion question) { + return parent.dnsQueryLifecycleObserverFactory().newDnsQueryLifecycleObserver(question); + } + + private final class CombinedDnsServerAddressStream implements DnsServerAddressStream { + private final InetSocketAddress replaced; + private final DnsServerAddressStream originalStream; + private final List resolvedAddresses; + private Iterator resolved; + + CombinedDnsServerAddressStream(InetSocketAddress replaced, List resolvedAddresses, + DnsServerAddressStream originalStream) { + this.replaced = replaced; + this.resolvedAddresses = resolvedAddresses; + this.originalStream = originalStream; + resolved = resolvedAddresses.iterator(); + } + + @Override + public InetSocketAddress next() { + if (resolved.hasNext()) { + return nextResolved0(); + } + InetSocketAddress address = originalStream.next(); + if (address.equals(replaced)) { + resolved = resolvedAddresses.iterator(); + return nextResolved0(); + } + return address; + } + + private InetSocketAddress nextResolved0() { + return parent.newRedirectServerAddress(resolved.next()); + } + + @Override + public int size() { + return originalStream.size() + resolvedAddresses.size() - 1; + } + + @Override + public DnsServerAddressStream duplicate() { + return new CombinedDnsServerAddressStream(replaced, resolvedAddresses, originalStream.duplicate()); } } @@ -786,7 +1007,8 @@ private static final class AuthoritativeNameServerList { // We not expect the linked-list to be very long so a double-linked-list is overkill. private AuthoritativeNameServer head; - private int count; + + private int nameServerCount; AuthoritativeNameServerList(String questionName) { this.questionName = questionName.toLowerCase(Locale.US); @@ -830,50 +1052,164 @@ void add(DnsRecord r) { // We are only interested in preserving the nameservers which are the closest to our qName, so ensure // we drop servers that have a smaller dots count. if (head == null || head.dots < dots) { - count = 1; - head = new AuthoritativeNameServer(dots, recordName, domainName); + nameServerCount = 1; + head = new AuthoritativeNameServer(dots, r.timeToLive(), recordName, domainName); } else if (head.dots == dots) { AuthoritativeNameServer serverName = head; while (serverName.next != null) { serverName = serverName.next; } - serverName.next = new AuthoritativeNameServer(dots, recordName, domainName); - count++; + serverName.next = new AuthoritativeNameServer(dots, r.timeToLive(), recordName, domainName); + nameServerCount++; } } - // Just walk the linked-list and mark the entry as removed when matched, so next lookup will need to process - // one node less. - AuthoritativeNameServer remove(String nsName) { + void handleWithAdditional( + DnsNameResolver parent, DnsRecord r, AuthoritativeDnsServerCache authoritativeCache) { + // Just walk the linked-list and mark the entry as handled when matched. AuthoritativeNameServer serverName = head; + String nsName = r.name(); + InetAddress resolved = decodeAddress(r, nsName, parent.isDecodeIdn()); + if (resolved == null) { + // Could not parse the address, just ignore. + return; + } + while (serverName != null) { - if (!serverName.removed && serverName.nsName.equalsIgnoreCase(nsName)) { - serverName.removed = true; - return serverName; + if (serverName.nsName.equalsIgnoreCase(nsName)) { + if (serverName.address != null) { + // We received multiple ADDITIONAL records for the same name. + // Search for the last we insert before and then append a new one. + while (serverName.next != null && serverName.next.isCopy) { + serverName = serverName.next; + } + AuthoritativeNameServer server = new AuthoritativeNameServer(serverName); + server.next = serverName.next; + serverName.next = server; + serverName = server; + + nameServerCount++; + } + // We should replace the TTL if needed with the one of the ADDITIONAL record so we use + // the smallest for caching. + serverName.update(parent.newRedirectServerAddress(resolved), r.timeToLive()); + + // Cache the server now. + cache(serverName, authoritativeCache, parent.executor()); + return; } serverName = serverName.next; } - return null; } - int size() { - return count; + // Now handle all AuthoritativeNameServer for which we had no ADDITIONAL record + void handleWithoutAdditionals( + DnsNameResolver parent, DnsCache cache, AuthoritativeDnsServerCache authoritativeCache) { + AuthoritativeNameServer serverName = head; + + while (serverName != null) { + if (serverName.address == null) { + // These will be resolved on the fly if needed. + cacheUnresolved(serverName, authoritativeCache, parent.executor()); + + // Try to resolve via cache as we had no ADDITIONAL entry for the server. + + List entries = cache.get(serverName.nsName, null); + if (entries != null && !entries.isEmpty()) { + InetAddress address = entries.get(0).address(); + + // If address is null we have a resolution failure cached so just use an unresolved address. + if (address != null) { + serverName.update(parent.newRedirectServerAddress(address)); + + for (int i = 1; i < entries.size(); i++) { + address = entries.get(i).address(); + + assert address != null : + "Cache returned a cached failure, should never return anything else"; + + AuthoritativeNameServer server = new AuthoritativeNameServer(serverName); + server.next = serverName.next; + serverName.next = server; + serverName = server; + serverName.update(parent.newRedirectServerAddress(address)); + + nameServerCount++; + } + } + } + } + serverName = serverName.next; + } + } + + private static void cacheUnresolved( + AuthoritativeNameServer server, AuthoritativeDnsServerCache authoritativeCache, EventLoop loop) { + // We still want to cached the unresolved address + server.address = InetSocketAddress.createUnresolved( + server.nsName, DefaultDnsServerAddressStreamProvider.DNS_PORT); + + // Cache the server now. + cache(server, authoritativeCache, loop); + } + + private static void cache(AuthoritativeNameServer server, AuthoritativeDnsServerCache cache, EventLoop loop) { + // Cache NS record if not for a root server as we should never cache for root servers. + if (!server.isRootServer()) { + cache.cache(server.domainName, server.address, server.ttl, loop); + } + } + + /** + * Returns {@code true} if empty, {@code false} otherwise. + */ + boolean isEmpty() { + return nameServerCount == 0; + } + + /** + * Creates a new {@link List} which holds the {@link InetSocketAddress}es. + */ + List addressList() { + List addressList = new ArrayList(nameServerCount); + + AuthoritativeNameServer server = head; + while (server != null) { + if (server.address != null) { + addressList.add(server.address); + } + server = server.next; + } + return addressList; } } - static final class AuthoritativeNameServer { - final int dots; + private static final class AuthoritativeNameServer { + private final int dots; + private final String domainName; + final boolean isCopy; final String nsName; - final String domainName; + + private long ttl; + private InetSocketAddress address; AuthoritativeNameServer next; - boolean removed; - AuthoritativeNameServer(int dots, String domainName, String nsName) { + AuthoritativeNameServer(int dots, long ttl, String domainName, String nsName) { this.dots = dots; + this.ttl = ttl; this.nsName = nsName; this.domainName = domainName; + isCopy = false; + } + + AuthoritativeNameServer(AuthoritativeNameServer server) { + dots = server.dots; + ttl = server.ttl; + nsName = server.nsName; + domainName = server.domainName; + isCopy = true; } /** @@ -884,10 +1220,16 @@ boolean isRootServer() { } /** - * The domain for which the {@link AuthoritativeNameServer} is responsible. + * Update the server with the given address and TTL if needed. */ - String domainName() { - return domainName; + void update(InetSocketAddress address, long ttl) { + assert this.address == null || this.address.isUnresolved(); + this.address = address; + this.ttl = min(ttl, ttl); + } + + void update(InetSocketAddress address) { + update(address, Long.MAX_VALUE); } } } diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsServerAddressStreamProviders.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsServerAddressStreamProviders.java index 7765b953878b..fbcd8cc9c109 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsServerAddressStreamProviders.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsServerAddressStreamProviders.java @@ -18,18 +18,45 @@ import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.UnstableApi; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; + /** * Utility methods related to {@link DnsServerAddressStreamProvider}. */ @UnstableApi public final class DnsServerAddressStreamProviders { + // We use 5 minutes which is the same as what OpenJDK is using in sun.net.dns.ResolverConfigurationImpl. + private static final long REFRESH_INTERVAL = TimeUnit.MINUTES.toNanos(5); + // TODO(scott): how is this done on Windows? This may require a JNI call to GetNetworkParams // https://msdn.microsoft.com/en-us/library/aa365968(VS.85).aspx. private static final DnsServerAddressStreamProvider DEFAULT_DNS_SERVER_ADDRESS_STREAM_PROVIDER = + new DnsServerAddressStreamProvider() { + private volatile DnsServerAddressStreamProvider currentProvider = provider(); + private final AtomicLong lastRefresh = new AtomicLong(System.nanoTime()); + + @Override + public DnsServerAddressStream nameServerAddressStream(String hostname) { + long last = lastRefresh.get(); + DnsServerAddressStreamProvider current = currentProvider; + if (System.nanoTime() - last > REFRESH_INTERVAL) { + // This is slightly racy which means it will be possible still use the old configuration for a small + // amount of time, but that's ok. + if (lastRefresh.compareAndSet(last, System.nanoTime())) { + current = currentProvider = provider(); + } + } + return current.nameServerAddressStream(hostname); + } + + private DnsServerAddressStreamProvider provider() { // If on windows just use the DefaultDnsServerAddressStreamProvider.INSTANCE as otherwise // we will log some error which may be confusing. - PlatformDependent.isWindows() ? DefaultDnsServerAddressStreamProvider.INSTANCE : + return PlatformDependent.isWindows() ? DefaultDnsServerAddressStreamProvider.INSTANCE : UnixResolverDnsServerAddressStreamProvider.parseSilently(); + } + }; private DnsServerAddressStreamProviders() { } diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsServerAddresses.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsServerAddresses.java index ced160f5f8be..72c0acd1b599 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsServerAddresses.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsServerAddresses.java @@ -23,8 +23,6 @@ import java.util.Collection; import java.util.List; -import static io.netty.resolver.dns.DefaultDnsServerAddressStreamProvider.defaultAddressArray; - /** * Provides an infinite sequence of DNS server addresses to {@link DnsNameResolver}. */ @@ -77,9 +75,9 @@ public static DnsServerAddresses sequential(InetSocketAddress... addresses) { return sequential0(sanitize(addresses)); } - private static DnsServerAddresses sequential0(final InetSocketAddress... addresses) { - if (addresses.length == 1) { - return singleton(addresses[0]); + private static DnsServerAddresses sequential0(final List addresses) { + if (addresses.size() == 1) { + return singleton(addresses.get(0)); } return new DefaultDnsServerAddresses("sequential", addresses) { @@ -106,9 +104,9 @@ public static DnsServerAddresses shuffled(InetSocketAddress... addresses) { return shuffled0(sanitize(addresses)); } - private static DnsServerAddresses shuffled0(final InetSocketAddress[] addresses) { - if (addresses.length == 1) { - return singleton(addresses[0]); + private static DnsServerAddresses shuffled0(List addresses) { + if (addresses.size() == 1) { + return singleton(addresses.get(0)); } return new DefaultDnsServerAddresses("shuffled", addresses) { @@ -139,9 +137,9 @@ public static DnsServerAddresses rotational(InetSocketAddress... addresses) { return rotational0(sanitize(addresses)); } - private static DnsServerAddresses rotational0(final InetSocketAddress[] addresses) { - if (addresses.length == 1) { - return singleton(addresses[0]); + private static DnsServerAddresses rotational0(List addresses) { + if (addresses.size() == 1) { + return singleton(addresses.get(0)); } return new RotationalDnsServerAddresses(addresses); @@ -161,7 +159,7 @@ public static DnsServerAddresses singleton(final InetSocketAddress address) { return new SingletonDnsServerAddresses(address); } - private static InetSocketAddress[] sanitize(Iterable addresses) { + private static List sanitize(Iterable addresses) { if (addresses == null) { throw new NullPointerException("addresses"); } @@ -187,10 +185,10 @@ private static InetSocketAddress[] sanitize(Iterable sanitize(InetSocketAddress[] addresses) { if (addresses == null) { throw new NullPointerException("addresses"); } @@ -207,10 +205,10 @@ private static InetSocketAddress[] sanitize(InetSocketAddress[] addresses) { } if (list.isEmpty()) { - return defaultAddressArray(); + return DefaultDnsServerAddressStreamProvider.defaultAddressList(); } - return list.toArray(new InetSocketAddress[list.size()]); + return list; } /** diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/NameServerComparator.java b/resolver-dns/src/main/java/io/netty/resolver/dns/NameServerComparator.java new file mode 100644 index 000000000000..3f6703cb58de --- /dev/null +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/NameServerComparator.java @@ -0,0 +1,61 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.resolver.dns; + +import io.netty.util.internal.ObjectUtil; + +import java.io.Serializable; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.Comparator; +import java.util.List; + +/** + * Special {@link Comparator} implementation to sort the nameservers to use when follow redirects. + * + * This implementation follows all the semantics listed in the + * Comparator apidocs + * with the limitation that {@link InetSocketAddress#equals(Object)} will not result in the same return value as + * {@link #compare(InetSocketAddress, InetSocketAddress)}. This is completely fine as this should only be used + * to sort {@link List}s. + */ +public final class NameServerComparator implements Comparator, Serializable { + + private static final long serialVersionUID = 8372151874317596185L; + + private final Class preferredAddressType; + + public NameServerComparator(Class preferredAddressType) { + this.preferredAddressType = ObjectUtil.checkNotNull(preferredAddressType, "preferredAddressType"); + } + + @Override + public int compare(InetSocketAddress addr1, InetSocketAddress addr2) { + if (addr1.equals(addr2)) { + return 0; + } + if (!addr1.isUnresolved() && !addr2.isUnresolved()) { + if (addr1.getAddress().getClass() == addr2.getAddress().getClass()) { + return 0; + } + return preferredAddressType.isAssignableFrom(addr1.getAddress().getClass()) ? -1 : 1; + } + if (addr1.isUnresolved() && addr2.isUnresolved()) { + return 0; + } + return addr1.isUnresolved() ? 1 : -1; + } +} diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/NoopAuthoritativeDnsServerCache.java b/resolver-dns/src/main/java/io/netty/resolver/dns/NoopAuthoritativeDnsServerCache.java new file mode 100644 index 000000000000..0cdce779e92d --- /dev/null +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/NoopAuthoritativeDnsServerCache.java @@ -0,0 +1,53 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.resolver.dns; + +import io.netty.channel.EventLoop; +import io.netty.util.internal.UnstableApi; + +import java.net.InetSocketAddress; +import java.util.Collections; +import java.util.List; + +/** + * A noop {@link AuthoritativeDnsServerCache} that actually never caches anything. + */ +@UnstableApi +public final class NoopAuthoritativeDnsServerCache implements AuthoritativeDnsServerCache { + public static final NoopAuthoritativeDnsServerCache INSTANCE = new NoopAuthoritativeDnsServerCache(); + + private NoopAuthoritativeDnsServerCache() { } + + @Override + public DnsServerAddressStream get(String hostname) { + return null; + } + + @Override + public void cache(String hostname, InetSocketAddress address, long originalTtl, EventLoop loop) { + // NOOP + } + + @Override + public void clear() { + // NOOP + } + + @Override + public boolean clear(String hostname) { + return false; + } +} diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/NoopDnsCnameCache.java b/resolver-dns/src/main/java/io/netty/resolver/dns/NoopDnsCnameCache.java new file mode 100644 index 000000000000..54113c40b258 --- /dev/null +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/NoopDnsCnameCache.java @@ -0,0 +1,47 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.resolver.dns; + +import io.netty.channel.EventLoop; +import io.netty.util.internal.UnstableApi; + +@UnstableApi +public final class NoopDnsCnameCache implements DnsCnameCache { + + public static final NoopDnsCnameCache INSTANCE = new NoopDnsCnameCache(); + + private NoopDnsCnameCache() { } + + @Override + public String get(String hostname) { + return null; + } + + @Override + public void cache(String hostname, String cname, long originalTtl, EventLoop loop) { + // NOOP + } + + @Override + public void clear() { + // NOOP + } + + @Override + public boolean clear(String hostname) { + return false; + } +} diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/RotationalDnsServerAddresses.java b/resolver-dns/src/main/java/io/netty/resolver/dns/RotationalDnsServerAddresses.java index dfa2a5085102..623d6f6a10fb 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/RotationalDnsServerAddresses.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/RotationalDnsServerAddresses.java @@ -17,6 +17,7 @@ package io.netty.resolver.dns; import java.net.InetSocketAddress; +import java.util.List; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; final class RotationalDnsServerAddresses extends DefaultDnsServerAddresses { @@ -27,7 +28,7 @@ final class RotationalDnsServerAddresses extends DefaultDnsServerAddresses { @SuppressWarnings("UnusedDeclaration") private volatile int startIdx; - RotationalDnsServerAddresses(InetSocketAddress[] addresses) { + RotationalDnsServerAddresses(List addresses) { super("rotational", addresses); } @@ -36,7 +37,7 @@ public DnsServerAddressStream stream() { for (;;) { int curStartIdx = startIdx; int nextStartIdx = curStartIdx + 1; - if (nextStartIdx >= addresses.length) { + if (nextStartIdx >= addresses.size()) { nextStartIdx = 0; } if (startIdxUpdater.compareAndSet(this, curStartIdx, nextStartIdx)) { diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/SequentialDnsServerAddressStream.java b/resolver-dns/src/main/java/io/netty/resolver/dns/SequentialDnsServerAddressStream.java index b2288e33589d..dd3f3665f04e 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/SequentialDnsServerAddressStream.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/SequentialDnsServerAddressStream.java @@ -17,13 +17,15 @@ package io.netty.resolver.dns; import java.net.InetSocketAddress; +import java.util.Collection; +import java.util.List; final class SequentialDnsServerAddressStream implements DnsServerAddressStream { - private final InetSocketAddress[] addresses; + private final List addresses; private int i; - SequentialDnsServerAddressStream(InetSocketAddress[] addresses, int startIdx) { + SequentialDnsServerAddressStream(List addresses, int startIdx) { this.addresses = addresses; i = startIdx; } @@ -31,8 +33,8 @@ final class SequentialDnsServerAddressStream implements DnsServerAddressStream { @Override public InetSocketAddress next() { int i = this.i; - InetSocketAddress next = addresses[i]; - if (++ i < addresses.length) { + InetSocketAddress next = addresses.get(i); + if (++ i < addresses.size()) { this.i = i; } else { this.i = 0; @@ -42,7 +44,7 @@ public InetSocketAddress next() { @Override public int size() { - return addresses.length; + return addresses.size(); } @Override @@ -55,8 +57,8 @@ public String toString() { return toString("sequential", i, addresses); } - static String toString(String type, int index, InetSocketAddress[] addresses) { - final StringBuilder buf = new StringBuilder(type.length() + 2 + addresses.length * 16); + static String toString(String type, int index, Collection addresses) { + final StringBuilder buf = new StringBuilder(type.length() + 2 + addresses.size() * 16); buf.append(type).append("(index: ").append(index); buf.append(", addrs: ("); for (InetSocketAddress a: addresses) { diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/ShuffledDnsServerAddressStream.java b/resolver-dns/src/main/java/io/netty/resolver/dns/ShuffledDnsServerAddressStream.java index a30302e3ae37..9b56f23b2c2f 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/ShuffledDnsServerAddressStream.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/ShuffledDnsServerAddressStream.java @@ -19,11 +19,13 @@ import io.netty.util.internal.PlatformDependent; import java.net.InetSocketAddress; +import java.util.Collections; +import java.util.List; import java.util.Random; final class ShuffledDnsServerAddressStream implements DnsServerAddressStream { - private final InetSocketAddress[] addresses; + private final List addresses; private int i; /** @@ -31,34 +33,26 @@ final class ShuffledDnsServerAddressStream implements DnsServerAddressStream { * @param addresses The addresses are not cloned. It is assumed the caller has cloned this array or otherwise will * not modify the contents. */ - ShuffledDnsServerAddressStream(InetSocketAddress[] addresses) { + ShuffledDnsServerAddressStream(List addresses) { this.addresses = addresses; shuffle(); } - private ShuffledDnsServerAddressStream(InetSocketAddress[] addresses, int startIdx) { + private ShuffledDnsServerAddressStream(List addresses, int startIdx) { this.addresses = addresses; i = startIdx; } private void shuffle() { - final InetSocketAddress[] addresses = this.addresses; - final Random r = PlatformDependent.threadLocalRandom(); - - for (int i = addresses.length - 1; i >= 0; i --) { - InetSocketAddress tmp = addresses[i]; - int j = r.nextInt(i + 1); - addresses[i] = addresses[j]; - addresses[j] = tmp; - } + Collections.shuffle(addresses, PlatformDependent.threadLocalRandom()); } @Override public InetSocketAddress next() { int i = this.i; - InetSocketAddress next = addresses[i]; - if (++ i < addresses.length) { + InetSocketAddress next = addresses.get(i); + if (++ i < addresses.size()) { this.i = i; } else { this.i = 0; @@ -69,7 +63,7 @@ public InetSocketAddress next() { @Override public int size() { - return addresses.length; + return addresses.size(); } @Override diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/UnixResolverDnsServerAddressStreamProvider.java b/resolver-dns/src/main/java/io/netty/resolver/dns/UnixResolverDnsServerAddressStreamProvider.java index d5571c93a977..afe29313c41f 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/UnixResolverDnsServerAddressStreamProvider.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/UnixResolverDnsServerAddressStreamProvider.java @@ -28,9 +28,11 @@ import java.net.InetSocketAddress; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.regex.Pattern; import static io.netty.resolver.dns.DefaultDnsServerAddressStreamProvider.DNS_PORT; import static io.netty.util.internal.ObjectUtil.checkNotNull; @@ -51,11 +53,13 @@ public final class UnixResolverDnsServerAddressStreamProvider implements DnsServ private static final String SORTLIST_ROW_LABEL = "sortlist"; private static final String OPTIONS_ROW_LABEL = "options"; private static final String DOMAIN_ROW_LABEL = "domain"; + private static final String SEARCH_ROW_LABEL = "search"; private static final String PORT_ROW_LABEL = "port"; private static final String NDOTS_LABEL = "ndots:"; static final int DEFAULT_NDOTS = 1; private final DnsServerAddresses defaultNameServerAddresses; private final Map domainToNameServerStreamMap; + private static final Pattern SEARCH_DOMAIN_PATTERN = Pattern.compile("\\s+"); /** * Attempt to parse {@code /etc/resolv.conf} and files in the {@code /etc/resolver} directory by default. @@ -284,4 +288,60 @@ static int parseEtcResolverFirstNdots(File etcResolvConf) throws IOException { } return DEFAULT_NDOTS; } + + /** + * Parse a file of the format /etc/resolv.conf and return the + * list of search domains found in it or an empty list if not found. + * @return List of search domains. + * @throws IOException If a failure occurs parsing the file. + */ + static List parseEtcResolverSearchDomains() throws IOException { + return parseEtcResolverSearchDomains(new File(ETC_RESOLV_CONF_FILE)); + } + + /** + * Parse a file of the format /etc/resolv.conf and return the + * list of search domains found in it or an empty list if not found. + * @param etcResolvConf a file of the format /etc/resolv.conf. + * @return List of search domains. + * @throws IOException If a failure occurs parsing the file. + */ + static List parseEtcResolverSearchDomains(File etcResolvConf) throws IOException { + String localDomain = null; + List searchDomains = new ArrayList(); + + FileReader fr = new FileReader(etcResolvConf); + BufferedReader br = null; + try { + br = new BufferedReader(fr); + String line; + while ((line = br.readLine()) != null) { + if (localDomain == null && line.startsWith(DOMAIN_ROW_LABEL)) { + int i = indexOfNonWhiteSpace(line, DOMAIN_ROW_LABEL.length()); + if (i >= 0) { + localDomain = line.substring(i); + } + } else if (line.startsWith(SEARCH_ROW_LABEL)) { + int i = indexOfNonWhiteSpace(line, SEARCH_ROW_LABEL.length()); + if (i >= 0) { + // May contain more then one entry, either seperated by whitespace or tab. + // See https://linux.die.net/man/5/resolver + String[] domains = SEARCH_DOMAIN_PATTERN.split(line.substring(i)); + Collections.addAll(searchDomains, domains); + } + } + } + } finally { + if (br == null) { + fr.close(); + } else { + br.close(); + } + } + + // return what was on the 'domain' line only if there were no 'search' lines + return localDomain != null && searchDomains.isEmpty() + ? Collections.singletonList(localDomain) + : searchDomains; + } } diff --git a/resolver-dns/src/test/java/io/netty/resolver/dns/DefaultAuthoritativeDnsServerCacheTest.java b/resolver-dns/src/test/java/io/netty/resolver/dns/DefaultAuthoritativeDnsServerCacheTest.java new file mode 100644 index 000000000000..ada0eaee262d --- /dev/null +++ b/resolver-dns/src/test/java/io/netty/resolver/dns/DefaultAuthoritativeDnsServerCacheTest.java @@ -0,0 +1,198 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.resolver.dns; + +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.EventLoop; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.util.NetUtil; +import org.junit.Test; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.Comparator; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.*; + +public class DefaultAuthoritativeDnsServerCacheTest { + + @Test + public void testExpire() throws Throwable { + InetSocketAddress resolved1 = new InetSocketAddress( + InetAddress.getByAddress("ns1", new byte[] { 10, 0, 0, 1 }), 53); + InetSocketAddress resolved2 = new InetSocketAddress( + InetAddress.getByAddress("ns2", new byte[] { 10, 0, 0, 2 }), 53); + EventLoopGroup group = new DefaultEventLoopGroup(1); + + try { + EventLoop loop = group.next(); + final DefaultAuthoritativeDnsServerCache cache = new DefaultAuthoritativeDnsServerCache(); + cache.cache("netty.io", resolved1, 1, loop); + cache.cache("netty.io", resolved2, 10000, loop); + + Throwable error = loop.schedule(new Callable() { + @Override + public Throwable call() { + try { + assertNull(cache.get("netty.io")); + return null; + } catch (Throwable cause) { + return cause; + } + } + }, 1, TimeUnit.SECONDS).get(); + if (error != null) { + throw error; + } + } finally { + group.shutdownGracefully(); + } + } + + @Test + public void testExpireWithDifferentTTLs() { + testExpireWithTTL0(1); + testExpireWithTTL0(1000); + testExpireWithTTL0(1000000); + } + + private static void testExpireWithTTL0(int days) { + EventLoopGroup group = new NioEventLoopGroup(1); + + try { + EventLoop loop = group.next(); + final DefaultAuthoritativeDnsServerCache cache = new DefaultAuthoritativeDnsServerCache(); + cache.cache("netty.io", new InetSocketAddress(NetUtil.LOCALHOST, 53), days, loop); + } finally { + group.shutdownGracefully(); + } + } + + @Test + public void testAddMultipleDnsServerForSameHostname() throws Exception { + InetSocketAddress resolved1 = new InetSocketAddress( + InetAddress.getByAddress("ns1", new byte[] { 10, 0, 0, 1 }), 53); + InetSocketAddress resolved2 = new InetSocketAddress( + InetAddress.getByAddress("ns2", new byte[] { 10, 0, 0, 2 }), 53); + EventLoopGroup group = new DefaultEventLoopGroup(1); + + try { + EventLoop loop = group.next(); + final DefaultAuthoritativeDnsServerCache cache = new DefaultAuthoritativeDnsServerCache(); + cache.cache("netty.io", resolved1, 100, loop); + cache.cache("netty.io", resolved2, 10000, loop); + + DnsServerAddressStream entries = cache.get("netty.io"); + assertEquals(2, entries.size()); + assertEquals(resolved1, entries.next()); + assertEquals(resolved2, entries.next()); + } finally { + group.shutdownGracefully(); + } + } + + @Test + public void testUnresolvedReplacedByResolved() throws Exception { + InetSocketAddress unresolved = InetSocketAddress.createUnresolved("ns1", 53); + InetSocketAddress resolved1 = new InetSocketAddress( + InetAddress.getByAddress("ns2", new byte[] { 10, 0, 0, 2 }), 53); + InetSocketAddress resolved2 = new InetSocketAddress( + InetAddress.getByAddress("ns1", new byte[] { 10, 0, 0, 1 }), 53); + EventLoopGroup group = new DefaultEventLoopGroup(1); + + try { + EventLoop loop = group.next(); + final DefaultAuthoritativeDnsServerCache cache = new DefaultAuthoritativeDnsServerCache(); + cache.cache("netty.io", unresolved, 100, loop); + cache.cache("netty.io", resolved1, 10000, loop); + + DnsServerAddressStream entries = cache.get("netty.io"); + assertEquals(2, entries.size()); + assertEquals(unresolved, entries.next()); + assertEquals(resolved1, entries.next()); + + cache.cache("netty.io", resolved2, 100, loop); + DnsServerAddressStream entries2 = cache.get("netty.io"); + + assertEquals(2, entries2.size()); + assertEquals(resolved2, entries2.next()); + assertEquals(resolved1, entries2.next()); + } finally { + group.shutdownGracefully(); + } + } + + @Test + public void testUseNoComparator() throws Exception { + testUseComparator0(true); + } + + @Test + public void testUseComparator() throws Exception { + testUseComparator0(false); + } + + private static void testUseComparator0(boolean noComparator) throws Exception { + InetSocketAddress unresolved = InetSocketAddress.createUnresolved("ns1", 53); + InetSocketAddress resolved = new InetSocketAddress( + InetAddress.getByAddress("ns2", new byte[] { 10, 0, 0, 2 }), 53); + EventLoopGroup group = new DefaultEventLoopGroup(1); + + try { + EventLoop loop = group.next(); + final DefaultAuthoritativeDnsServerCache cache; + + if (noComparator) { + cache = new DefaultAuthoritativeDnsServerCache(10000, 10000, null); + } else { + cache = new DefaultAuthoritativeDnsServerCache(10000, 10000, + new Comparator() { + @Override + public int compare(InetSocketAddress o1, InetSocketAddress o2) { + if (o1.equals(o2)) { + return 0; + } + if (o1.isUnresolved()) { + return 1; + } else { + return -1; + } + } + }); + } + cache.cache("netty.io", unresolved, 100, loop); + cache.cache("netty.io", resolved, 10000, loop); + + DnsServerAddressStream entries = cache.get("netty.io"); + assertEquals(2, entries.size()); + + if (noComparator) { + assertEquals(unresolved, entries.next()); + assertEquals(resolved, entries.next()); + } else { + assertEquals(resolved, entries.next()); + assertEquals(unresolved, entries.next()); + } + } finally { + group.shutdownGracefully(); + } + } + +} diff --git a/resolver-dns/src/test/java/io/netty/resolver/dns/DefaultDnsCacheTest.java b/resolver-dns/src/test/java/io/netty/resolver/dns/DefaultDnsCacheTest.java index 17aba98e3c25..d3b2384c7193 100644 --- a/resolver-dns/src/test/java/io/netty/resolver/dns/DefaultDnsCacheTest.java +++ b/resolver-dns/src/test/java/io/netty/resolver/dns/DefaultDnsCacheTest.java @@ -172,4 +172,30 @@ public void testCacheFailed() throws Exception { group.shutdownGracefully(); } } + + @Test + public void testDotHandling() throws Exception { + InetAddress addr1 = InetAddress.getByAddress(new byte[] { 10, 0, 0, 1 }); + InetAddress addr2 = InetAddress.getByAddress(new byte[] { 10, 0, 0, 2 }); + EventLoopGroup group = new DefaultEventLoopGroup(1); + + try { + EventLoop loop = group.next(); + final DefaultDnsCache cache = new DefaultDnsCache(1, 100, 100); + cache.cache("netty.io", null, addr1, 10000, loop); + cache.cache("netty.io.", null, addr2, 10000, loop); + + List entries = cache.get("netty.io", null); + assertEquals(2, entries.size()); + assertEntry(entries.get(0), addr1); + assertEntry(entries.get(1), addr2); + + List entries2 = cache.get("netty.io.", null); + assertEquals(2, entries2.size()); + assertEntry(entries2.get(0), addr1); + assertEntry(entries2.get(1), addr2); + } finally { + group.shutdownGracefully(); + } + } } diff --git a/resolver-dns/src/test/java/io/netty/resolver/dns/DefaultDnsCnameCacheTest.java b/resolver-dns/src/test/java/io/netty/resolver/dns/DefaultDnsCnameCacheTest.java new file mode 100644 index 000000000000..a08702f5f2b1 --- /dev/null +++ b/resolver-dns/src/test/java/io/netty/resolver/dns/DefaultDnsCnameCacheTest.java @@ -0,0 +1,132 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.resolver.dns; + +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.EventLoop; +import io.netty.channel.EventLoopGroup; +import org.junit.Test; + +import java.util.concurrent.Callable; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.*; + +public class DefaultDnsCnameCacheTest { + + @Test + public void testExpire() throws Throwable { + EventLoopGroup group = new DefaultEventLoopGroup(1); + + try { + EventLoop loop = group.next(); + final DefaultDnsCnameCache cache = new DefaultDnsCnameCache(); + cache.cache("netty.io", "mapping.netty.io", 1, loop); + + Throwable error = loop.schedule(new Callable() { + @Override + public Throwable call() { + try { + assertNull(cache.get("netty.io")); + return null; + } catch (Throwable cause) { + return cause; + } + } + }, 1, TimeUnit.SECONDS).get(); + if (error != null) { + throw error; + } + } finally { + group.shutdownGracefully(); + } + } + + @Test + public void testExpireWithDifferentTTLs() { + testExpireWithTTL0(1); + testExpireWithTTL0(1000); + testExpireWithTTL0(1000000); + } + + private static void testExpireWithTTL0(int days) { + EventLoopGroup group = new DefaultEventLoopGroup(1); + + try { + EventLoop loop = group.next(); + final DefaultDnsCnameCache cache = new DefaultDnsCnameCache(); + cache.cache("netty.io", "mapping.netty.io", TimeUnit.DAYS.toSeconds(days), loop); + assertEquals("mapping.netty.io", cache.get("netty.io")); + } finally { + group.shutdownGracefully(); + } + } + + @Test + public void testMultipleCnamesForSameHostname() throws Exception { + EventLoopGroup group = new DefaultEventLoopGroup(1); + + try { + EventLoop loop = group.next(); + final DefaultDnsCnameCache cache = new DefaultDnsCnameCache(); + cache.cache("netty.io", "mapping1.netty.io", 10, loop); + cache.cache("netty.io", "mapping2.netty.io", 10000, loop); + + assertEquals("mapping2.netty.io", cache.get("netty.io")); + } finally { + group.shutdownGracefully(); + } + } + + @Test + public void testAddSameCnameForSameHostname() throws Exception { + EventLoopGroup group = new DefaultEventLoopGroup(1); + + try { + EventLoop loop = group.next(); + final DefaultDnsCnameCache cache = new DefaultDnsCnameCache(); + cache.cache("netty.io", "mapping.netty.io", 10, loop); + cache.cache("netty.io", "mapping.netty.io", 10000, loop); + + assertEquals("mapping.netty.io", cache.get("netty.io")); + } finally { + group.shutdownGracefully(); + } + } + + @Test + public void testClear() throws Exception { + EventLoopGroup group = new DefaultEventLoopGroup(1); + + try { + EventLoop loop = group.next(); + final DefaultDnsCnameCache cache = new DefaultDnsCnameCache(); + cache.cache("x.netty.io", "mapping.netty.io", 100000, loop); + cache.cache("y.netty.io", "mapping.netty.io", 100000, loop); + + assertEquals("mapping.netty.io", cache.get("x.netty.io")); + assertEquals("mapping.netty.io", cache.get("y.netty.io")); + + assertTrue(cache.clear("x.netty.io")); + assertNull(cache.get("x.netty.io")); + assertEquals("mapping.netty.io", cache.get("y.netty.io")); + cache.clear(); + assertNull(cache.get("y.netty.io")); + } finally { + group.shutdownGracefully(); + } + } +} diff --git a/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java b/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java index dcda33c8d5a7..d03ee8707757 100644 --- a/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java +++ b/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java @@ -58,9 +58,12 @@ import org.hamcrest.Matchers; import org.junit.AfterClass; import org.junit.BeforeClass; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import java.io.IOException; +import java.net.DatagramSocket; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.UnknownHostException; @@ -79,26 +82,26 @@ import java.util.Map.Entry; import java.util.Queue; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import static io.netty.handler.codec.dns.DnsRecordType.A; import static io.netty.handler.codec.dns.DnsRecordType.AAAA; import static io.netty.handler.codec.dns.DnsRecordType.CNAME; -import static io.netty.resolver.dns.DefaultDnsServerAddressStreamProvider.DNS_PORT; import static io.netty.resolver.dns.DnsServerAddresses.sequential; +import static java.util.Collections.singletonList; +import static java.util.Collections.singletonMap; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertThat; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; +import static org.junit.Assert.*; public class DnsNameResolverTest { @@ -214,12 +217,14 @@ public class DnsNameResolverTest { "localhost"))); private static final Map DOMAINS_PUNYCODE = new HashMap(); + static { DOMAINS_PUNYCODE.put("büchner.de", "xn--bchner-3ya.de"); DOMAINS_PUNYCODE.put("müller.de", "xn--mller-kva.de"); } private static final Set DOMAINS_ALL; + static { Set all = new HashSet(DOMAINS.size() + DOMAINS_PUNYCODE.size()); all.addAll(DOMAINS); @@ -231,6 +236,7 @@ public class DnsNameResolverTest { * The list of the domain names to exclude from {@link #testResolveAorAAAA()}. */ private static final Set EXCLUSIONS_RESOLVE_A = new HashSet(); + static { Collections.addAll( EXCLUSIONS_RESOLVE_A, @@ -244,6 +250,7 @@ public class DnsNameResolverTest { * Unfortunately, there are only handful of domain names with IPv6 addresses. */ private static final Set EXCLUSIONS_RESOLVE_AAAA = new HashSet(); + static { EXCLUSIONS_RESOLVE_AAAA.addAll(EXCLUSIONS_RESOLVE_A); EXCLUSIONS_RESOLVE_AAAA.addAll(DOMAINS); @@ -281,6 +288,7 @@ public class DnsNameResolverTest { * The list of the domain names to exclude from {@link #testQueryMx()}. */ private static final Set EXCLUSIONS_QUERY_MX = new HashSet(); + static { Collections.addAll( EXCLUSIONS_QUERY_MX, @@ -298,6 +306,9 @@ public class DnsNameResolverTest { private static final TestDnsServer dnsServer = new TestDnsServer(DOMAINS_ALL); private static final EventLoopGroup group = new NioEventLoopGroup(1); + @Rule + public ExpectedException expectedException = ExpectedException.none(); + private static DnsNameResolverBuilder newResolver(boolean decodeToUnicode) { return newResolver(decodeToUnicode, null); } @@ -341,6 +352,7 @@ private static DnsNameResolverBuilder newNonCachedResolver(ResolvedAddressTypes public static void init() throws Exception { dnsServer.start(); } + @AfterClass public static void destroy() { dnsServer.stop(); @@ -393,20 +405,20 @@ public Set getRecords(QuestionRecord question) { }); dnsServer2.start(); try { - final Set overridenHostnames = new HashSet(); + final Set overriddenHostnames = new HashSet(); for (String name : DOMAINS) { if (EXCLUSIONS_RESOLVE_A.contains(name)) { continue; } if (PlatformDependent.threadLocalRandom().nextBoolean()) { - overridenHostnames.add(name); + overriddenHostnames.add(name); } } DnsNameResolver resolver = newResolver(false, new DnsServerAddressStreamProvider() { @Override public DnsServerAddressStream nameServerAddressStream(String hostname) { - return overridenHostnames.contains(hostname) ? sequential(dnsServer2.localAddress()).stream() : - null; + return overriddenHostnames.contains(hostname) ? sequential(dnsServer2.localAddress()).stream() : + null; } }).build(); try { @@ -415,7 +427,7 @@ public DnsServerAddressStream nameServerAddressStream(String hostname) { if (resolvedEntry.getValue().isLoopbackAddress()) { continue; } - if (overridenHostnames.contains(resolvedEntry.getKey())) { + if (overriddenHostnames.contains(resolvedEntry.getKey())) { assertEquals("failed to resolve " + resolvedEntry.getKey(), overriddenIP, resolvedEntry.getValue().getHostAddress()); } else { @@ -448,7 +460,7 @@ public void testResolveA() throws Exception { // Ensure the result from the cache is identical from the uncached one. assertThat(resultB.size(), is(resultA.size())); - for (Entry e: resultA.entrySet()) { + for (Entry e : resultA.entrySet()) { InetAddress expected = e.getValue(); InetAddress actual = resultB.get(e.getKey()); if (!actual.equals(expected)) { @@ -550,7 +562,7 @@ private static Map testResolve0(DnsNameResolver resolver, S assertThat(resolved.getHostName(), is(unresolved)); boolean typeMatches = false; - for (InternetProtocolFamily f: resolver.resolvedInternetProtocolFamiliesUnsafe()) { + for (InternetProtocolFamily f : resolver.resolvedInternetProtocolFamiliesUnsafe()) { Class resolvedType = resolved.getClass(); if (f.addressType().isAssignableFrom(resolvedType)) { typeMatches = true; @@ -575,7 +587,7 @@ public void testQueryMx() { Map>> futures = new LinkedHashMap>>(); - for (String name: DOMAINS) { + for (String name : DOMAINS) { if (EXCLUSIONS_QUERY_MX.contains(name)) { continue; } @@ -583,7 +595,7 @@ public void testQueryMx() { queryMx(resolver, futures, name); } - for (Entry>> e: futures.entrySet()) { + for (Entry>> e : futures.entrySet()) { String hostname = e.getKey(); Future> f = e.getValue().awaitUninterruptibly(); @@ -592,7 +604,7 @@ public void testQueryMx() { final int answerCount = response.count(DnsSection.ANSWER); final List mxList = new ArrayList(answerCount); - for (int i = 0; i < answerCount; i ++) { + for (int i = 0; i < answerCount; i++) { final DnsRecord r = response.recordAt(DnsSection.ANSWER, i); if (r.type() == DnsRecordType.MX) { mxList.add(r); @@ -601,7 +613,7 @@ public void testQueryMx() { assertThat(mxList.size(), is(greaterThan(0))); StringBuilder buf = new StringBuilder(); - for (DnsRecord r: mxList) { + for (DnsRecord r : mxList) { ByteBuf recordContent = ((ByteBufHolder) r).content(); buf.append(StringUtil.NEWLINE); @@ -886,7 +898,7 @@ public Set getRecords(QuestionRecord question) throws DnsExcepti dnsServer3.start(); DnsNameResolver resolver = null; try { - DnsCache nsCache = new DefaultDnsCache(); + AuthoritativeDnsServerCache nsCache = new DefaultAuthoritativeDnsServerCache(); // What we want to test is the following: // 1. Do a DNS query. // 2. CNAME is returned, we want to lookup that CNAME on multiple DNS servers @@ -896,7 +908,7 @@ public Set getRecords(QuestionRecord question) throws DnsExcepti // The DnsCache is used for the name server cache, but doesn't provide a InetSocketAddress (only InetAddress // so no port), so we only specify the name server in the cache, and then specify both name servers in the // fallback name server provider. - nsCache.cache("nettyfoo.com.", null, dnsServer2.localAddress().getAddress(), 10000, group.next()); + nsCache.cache("nettyfoo.com.", dnsServer2.localAddress(), 10000, group.next()); resolver = new DnsNameResolver( group.next(), new ReflectiveChannelFactory(NioDatagramChannel.class), NoopDnsCache.INSTANCE, nsCache, NoopDnsQueryLifecycleObserverFactory.INSTANCE, 3000, @@ -905,8 +917,10 @@ public Set getRecords(QuestionRecord question) throws DnsExcepti new SequentialDnsServerAddressStreamProvider(dnsServer2.localAddress(), dnsServer3.localAddress()), DnsNameResolver.DEFAULT_SEARCH_DOMAINS, 0, true) { @Override - int dnsRedirectPort(InetAddress server) { - return hitServer2.get() ? dnsServer3.localAddress().getPort() : dnsServer2.localAddress().getPort(); + InetSocketAddress newRedirectServerAddress(InetAddress server) { + int port = hitServer2.get() ? dnsServer3.localAddress().getPort() : + dnsServer2.localAddress().getPort(); + return new InetSocketAddress(server, port); } }; InetAddress resolvedAddress = resolver.resolve(firstName).syncUninterruptibly().getNow(); @@ -956,7 +970,7 @@ public void testResolveAllMx() { assertThat(resolver.isRecursionDesired(), is(true)); final Map>> futures = new LinkedHashMap>>(); - for (String name: DOMAINS) { + for (String name : DOMAINS) { if (EXCLUSIONS_QUERY_MX.contains(name)) { continue; } @@ -964,14 +978,14 @@ public void testResolveAllMx() { futures.put(name, resolver.resolveAll(new DefaultDnsQuestion(name, DnsRecordType.MX))); } - for (Entry>> e: futures.entrySet()) { + for (Entry>> e : futures.entrySet()) { String hostname = e.getKey(); Future> f = e.getValue().awaitUninterruptibly(); final List mxList = f.getNow(); assertThat(mxList.size(), is(greaterThan(0))); StringBuilder buf = new StringBuilder(); - for (DnsRecord r: mxList) { + for (DnsRecord r : mxList) { ByteBuf recordContent = ((ByteBufHolder) r).content(); buf.append(StringUtil.NEWLINE); @@ -1013,7 +1027,7 @@ public InetAddress address(String inetHost, ResolvedAddressTypes resolvedAddress }).build(); final List records = resolver.resolveAll(new DefaultDnsQuestion("foo.com.", A)) - .syncUninterruptibly().getNow(); + .syncUninterruptibly().getNow(); assertThat(records, Matchers.hasSize(1)); assertThat(records.get(0), Matchers.instanceOf(DnsRawRecord.class)); @@ -1129,7 +1143,7 @@ public void aAndAAAAQueryShouldTryFirstDnsServerBeforeSecond() throws IOExceptio new TestRecursiveCacheDnsQueryLifecycleObserverFactory(); DnsNameResolverBuilder builder = new DnsNameResolverBuilder(group.next()) - .resolvedAddressTypes(ResolvedAddressTypes.IPV6_PREFERRED) + .resolvedAddressTypes(ResolvedAddressTypes.IPV4_ONLY) .dnsQueryLifecycleObserverFactory(lifecycleObserverFactory) .channelType(NioDatagramChannel.class) .optResourceEnabled(false) @@ -1142,18 +1156,12 @@ public void aAndAAAAQueryShouldTryFirstDnsServerBeforeSecond() throws IOExceptio TestDnsQueryLifecycleObserver observer = lifecycleObserverFactory.observers.poll(); assertNotNull(observer); - assertEquals(2, lifecycleObserverFactory.observers.size()); + assertEquals(1, lifecycleObserverFactory.observers.size()); assertEquals(2, observer.events.size()); QueryWrittenEvent writtenEvent = (QueryWrittenEvent) observer.events.poll(); assertEquals(dnsServer1.localAddress(), writtenEvent.dnsServerAddress); QueryFailedEvent failedEvent = (QueryFailedEvent) observer.events.poll(); - observer = lifecycleObserverFactory.observers.poll(); - assertEquals(2, observer.events.size()); - writtenEvent = (QueryWrittenEvent) observer.events.poll(); - assertEquals(dnsServer1.localAddress(), writtenEvent.dnsServerAddress); - failedEvent = (QueryFailedEvent) observer.events.poll(); - observer = lifecycleObserverFactory.observers.poll(); assertEquals(2, observer.events.size()); writtenEvent = (QueryWrittenEvent) observer.events.poll(); @@ -1197,11 +1205,11 @@ private static void testResolvesPreferredWhenNonPreferredFirst0(ResolvedAddressT final String ipv6Address = "0:0:0:0:0:0:1:1"; final String ipv4Address = "1.1.1.1"; if (types == ResolvedAddressTypes.IPV4_PREFERRED) { - records.add(newAddressRecord(name, RecordType.AAAA, ipv6Address)); - records.add(newAddressRecord(name, RecordType.A, ipv4Address)); + records.add(Collections.singleton(TestDnsServer.newAddressRecord(name, RecordType.AAAA, ipv6Address))); + records.add(Collections.singleton(TestDnsServer.newAddressRecord(name, RecordType.A, ipv4Address))); } else { - records.add(newAddressRecord(name, RecordType.A, ipv4Address)); - records.add(newAddressRecord(name, RecordType.AAAA, ipv6Address)); + records.add(Collections.singleton(TestDnsServer.newAddressRecord(name, RecordType.A, ipv4Address))); + records.add(Collections.singleton(TestDnsServer.newAddressRecord(name, RecordType.AAAA, ipv6Address))); } final Iterator> recordsIterator = records.iterator(); RecordStore arbitrarilyOrderedStore = new RecordStore() { @@ -1229,16 +1237,6 @@ public Set getRecords(QuestionRecord questionRecord) { } } - private static Set newAddressRecord(String name, RecordType type, String address) { - ResourceRecordModifier rm = new ResourceRecordModifier(); - rm.setDnsClass(RecordClass.IN); - rm.setDnsName(name); - rm.setDnsTtl(100); - rm.setDnsType(type); - rm.put(DnsAttribute.IP_ADDRESS, address); - return Collections.singleton(rm.getEntry()); - } - private static void testRecursiveResolveCache(boolean cache) throws Exception { final String hostname = "some.record.netty.io"; @@ -1252,20 +1250,23 @@ private static void testRecursiveResolveCache(boolean cache) dnsServerAuthority.localAddress().getAddress().getHostAddress()); dnsServer.start(); - TestDnsCache nsCache = new TestDnsCache(cache ? new DefaultDnsCache() : NoopDnsCache.INSTANCE); + TestAuthoritativeDnsServerCache nsCache = new TestAuthoritativeDnsServerCache( + cache ? new DefaultAuthoritativeDnsServerCache() : NoopAuthoritativeDnsServerCache.INSTANCE); TestRecursiveCacheDnsQueryLifecycleObserverFactory lifecycleObserverFactory = new TestRecursiveCacheDnsQueryLifecycleObserverFactory(); EventLoopGroup group = new NioEventLoopGroup(1); - DnsNameResolver resolver = new DnsNameResolver( + final DnsNameResolver resolver = new DnsNameResolver( group.next(), new ReflectiveChannelFactory(NioDatagramChannel.class), NoopDnsCache.INSTANCE, nsCache, lifecycleObserverFactory, 3000, ResolvedAddressTypes.IPV4_ONLY, true, 10, true, 4096, false, HostsFileEntriesResolver.DEFAULT, new SingletonDnsServerAddressStreamProvider(dnsServer.localAddress()), DnsNameResolver.DEFAULT_SEARCH_DOMAINS, 0, true) { @Override - int dnsRedirectPort(InetAddress server) { - return server.equals(dnsServerAuthority.localAddress().getAddress()) ? - dnsServerAuthority.localAddress().getPort() : DNS_PORT; + InetSocketAddress newRedirectServerAddress(InetAddress server) { + if (server.equals(dnsServerAuthority.localAddress().getAddress())) { + return new InetSocketAddress(server, dnsServerAuthority.localAddress().getPort()); + } + return super.newRedirectServerAddress(server); } }; @@ -1292,12 +1293,16 @@ int dnsRedirectPort(InetAddress server) { QuerySucceededEvent succeededEvent = (QuerySucceededEvent) observer.events.poll(); if (cache) { - assertNull(nsCache.cache.get("io.", null)); - assertNull(nsCache.cache.get("netty.io.", null)); - List entries = nsCache.cache.get("record.netty.io.", null); - assertEquals(1, entries.size()); + assertNull(nsCache.cache.get("io.")); + assertNull(nsCache.cache.get("netty.io.")); + DnsServerAddressStream entries = nsCache.cache.get("record.netty.io."); + + // First address should be resolved (as we received a matching additional record), second is unresolved. + assertEquals(2, entries.size()); + assertFalse(entries.next().isUnresolved()); + assertTrue(entries.next().isUnresolved()); - assertNull(nsCache.cache.get(hostname, null)); + assertNull(nsCache.cache.get(hostname)); // Test again via cache. resolver.resolveAll(hostname).syncUninterruptibly(); @@ -1336,6 +1341,431 @@ int dnsRedirectPort(InetAddress server) { } } + @Test + public void testFollowNsRedirectsNoopCaches() throws Exception { + testFollowNsRedirects(NoopDnsCache.INSTANCE, NoopAuthoritativeDnsServerCache.INSTANCE, false); + } + + @Test + public void testFollowNsRedirectsNoopDnsCache() throws Exception { + testFollowNsRedirects(NoopDnsCache.INSTANCE, new DefaultAuthoritativeDnsServerCache(), false); + } + + @Test + public void testFollowNsRedirectsNoopAuthoritativeDnsServerCache() throws Exception { + testFollowNsRedirects(new DefaultDnsCache(), NoopAuthoritativeDnsServerCache.INSTANCE, false); + } + + @Test + public void testFollowNsRedirectsDefaultCaches() throws Exception { + testFollowNsRedirects(new DefaultDnsCache(), new DefaultAuthoritativeDnsServerCache(), false); + } + + @Test + public void testFollowNsRedirectAndTrySecondNsOnTimeout() throws Exception { + testFollowNsRedirects(NoopDnsCache.INSTANCE, NoopAuthoritativeDnsServerCache.INSTANCE, true); + } + + @Test + public void testFollowNsRedirectAndTrySecondNsOnTimeoutDefaultCaches() throws Exception { + testFollowNsRedirects(new DefaultDnsCache(), new DefaultAuthoritativeDnsServerCache(), true); + } + + private void testFollowNsRedirects(DnsCache cache, AuthoritativeDnsServerCache authoritativeDnsServerCache, + final boolean invalidNsFirst) throws Exception { + final String domain = "netty.io"; + final String ns1Name = "ns1." + domain; + final String ns2Name = "ns2." + domain; + final InetAddress expected = InetAddress.getByAddress("some.record." + domain, new byte[] { 10, 10, 10, 10 }); + + // This is used to simulate a query timeout... + final DatagramSocket socket = new DatagramSocket(new InetSocketAddress(0)); + + final TestDnsServer dnsServerAuthority = new TestDnsServer(new RecordStore() { + @Override + public Set getRecords(QuestionRecord question) { + if (question.getDomainName().equals(expected.getHostName())) { + return Collections.singleton(TestDnsServer.newARecord( + expected.getHostName(), expected.getHostAddress())); + } + return Collections.emptySet(); + } + }); + dnsServerAuthority.start(); + + TestDnsServer redirectServer = new TestDnsServer(new HashSet( + Arrays.asList(expected.getHostName(), ns1Name, ns2Name))) { + @Override + protected DnsMessage filterMessage(DnsMessage message) { + for (QuestionRecord record: message.getQuestionRecords()) { + if (record.getDomainName().equals(expected.getHostName())) { + message.getAdditionalRecords().clear(); + message.getAnswerRecords().clear(); + if (invalidNsFirst) { + message.getAuthorityRecords().add(TestDnsServer.newNsRecord(domain, ns2Name)); + message.getAuthorityRecords().add(TestDnsServer.newNsRecord(domain, ns1Name)); + } else { + message.getAuthorityRecords().add(TestDnsServer.newNsRecord(domain, ns1Name)); + message.getAuthorityRecords().add(TestDnsServer.newNsRecord(domain, ns2Name)); + } + return message; + } + } + return message; + } + }; + redirectServer.start(); + EventLoopGroup group = new NioEventLoopGroup(1); + final DnsNameResolver resolver = new DnsNameResolver( + group.next(), new ReflectiveChannelFactory(NioDatagramChannel.class), + cache, authoritativeDnsServerCache, NoopDnsQueryLifecycleObserverFactory.INSTANCE, 2000, + ResolvedAddressTypes.IPV4_ONLY, true, 10, true, 4096, + false, HostsFileEntriesResolver.DEFAULT, + new SingletonDnsServerAddressStreamProvider(redirectServer.localAddress()), + DnsNameResolver.DEFAULT_SEARCH_DOMAINS, 0, true) { + + @Override + InetSocketAddress newRedirectServerAddress(InetAddress server) { + try { + if (server.getHostName().startsWith(ns1Name)) { + return new InetSocketAddress(InetAddress.getByAddress(ns1Name, + dnsServerAuthority.localAddress().getAddress().getAddress()), + dnsServerAuthority.localAddress().getPort()); + } + if (server.getHostName().startsWith(ns2Name)) { + return new InetSocketAddress(InetAddress.getByAddress(ns2Name, + NetUtil.LOCALHOST.getAddress()), socket.getLocalPort()); + } + } catch (UnknownHostException e) { + throw new IllegalStateException(e); + } + return super.newRedirectServerAddress(server); + } + }; + + try { + List resolved = resolver.resolveAll(expected.getHostName()).syncUninterruptibly().getNow(); + assertEquals(1, resolved.size()); + assertEquals(expected, resolved.get(0)); + + List resolved2 = resolver.resolveAll(expected.getHostName()).syncUninterruptibly().getNow(); + assertEquals(1, resolved2.size()); + assertEquals(expected, resolved2.get(0)); + + if (authoritativeDnsServerCache != NoopAuthoritativeDnsServerCache.INSTANCE) { + DnsServerAddressStream cached = authoritativeDnsServerCache.get(domain + '.'); + assertEquals(2, cached.size()); + InetSocketAddress ns1Address = InetSocketAddress.createUnresolved( + ns1Name + '.', DefaultDnsServerAddressStreamProvider.DNS_PORT); + InetSocketAddress ns2Address = InetSocketAddress.createUnresolved( + ns2Name + '.', DefaultDnsServerAddressStreamProvider.DNS_PORT); + + if (invalidNsFirst) { + assertEquals(ns2Address, cached.next()); + assertEquals(ns1Address, cached.next()); + } else { + assertEquals(ns1Address, cached.next()); + assertEquals(ns2Address, cached.next()); + } + } + if (cache != NoopDnsCache.INSTANCE) { + List ns1Cached = cache.get(ns1Name + '.', null); + assertEquals(1, ns1Cached.size()); + DnsCacheEntry nsEntry = ns1Cached.get(0); + assertNotNull(nsEntry.address()); + assertNull(nsEntry.cause()); + + List ns2Cached = cache.get(ns2Name + '.', null); + if (invalidNsFirst) { + assertEquals(1, ns2Cached.size()); + DnsCacheEntry ns2Entry = ns2Cached.get(0); + assertNotNull(ns2Entry.address()); + assertNull(ns2Entry.cause()); + } else { + // We should not even have tried to resolve the DNS name so this should be null. + assertNull(ns2Cached); + } + + List expectedCached = cache.get(expected.getHostName(), null); + assertEquals(1, expectedCached.size()); + DnsCacheEntry expectedEntry = expectedCached.get(0); + assertEquals(expected, expectedEntry.address()); + assertNull(expectedEntry.cause()); + } + } finally { + resolver.close(); + group.shutdownGracefully(0, 0, TimeUnit.SECONDS); + redirectServer.stop(); + dnsServerAuthority.stop(); + socket.close(); + } + } + + @Test + public void testMultipleAdditionalRecordsForSameNSRecord() throws Exception { + testMultipleAdditionalRecordsForSameNSRecord(false); + } + + @Test + public void testMultipleAdditionalRecordsForSameNSRecordReordered() throws Exception { + testMultipleAdditionalRecordsForSameNSRecord(true); + } + + private static void testMultipleAdditionalRecordsForSameNSRecord(final boolean reversed) throws Exception { + final String domain = "netty.io"; + final String hostname = "test.netty.io"; + final String ns1Name = "ns1." + domain; + final InetSocketAddress ns1Address = new InetSocketAddress( + InetAddress.getByAddress(ns1Name, new byte[] { 10, 0, 0, 1 }), + DefaultDnsServerAddressStreamProvider.DNS_PORT); + final InetSocketAddress ns2Address = new InetSocketAddress( + InetAddress.getByAddress(ns1Name, new byte[] { 10, 0, 0, 2 }), + DefaultDnsServerAddressStreamProvider.DNS_PORT); + final InetSocketAddress ns3Address = new InetSocketAddress( + InetAddress.getByAddress(ns1Name, new byte[] { 10, 0, 0, 3 }), + DefaultDnsServerAddressStreamProvider.DNS_PORT); + final InetSocketAddress ns4Address = new InetSocketAddress( + InetAddress.getByAddress(ns1Name, new byte[] { 10, 0, 0, 4 }), + DefaultDnsServerAddressStreamProvider.DNS_PORT); + + TestDnsServer redirectServer = new TestDnsServer(new HashSet(Arrays.asList(hostname, ns1Name))) { + @Override + protected DnsMessage filterMessage(DnsMessage message) { + for (QuestionRecord record: message.getQuestionRecords()) { + if (record.getDomainName().equals(hostname)) { + message.getAdditionalRecords().clear(); + message.getAnswerRecords().clear(); + message.getAuthorityRecords().add(TestDnsServer.newNsRecord(domain, ns1Name)); + message.getAdditionalRecords().add(newARecord(ns1Address)); + message.getAdditionalRecords().add(newARecord(ns2Address)); + message.getAdditionalRecords().add(newARecord(ns3Address)); + message.getAdditionalRecords().add(newARecord(ns4Address)); + return message; + } + } + return message; + } + + private ResourceRecord newARecord(InetSocketAddress address) { + return TestDnsServer.newARecord(address.getHostName(), address.getAddress().getHostAddress()); + } + }; + redirectServer.start(); + EventLoopGroup group = new NioEventLoopGroup(1); + + final List cached = new CopyOnWriteArrayList(); + final AuthoritativeDnsServerCache authoritativeDnsServerCache = new AuthoritativeDnsServerCache() { + @Override + public DnsServerAddressStream get(String hostname) { + return null; + } + + @Override + public void cache(String hostname, InetSocketAddress address, long originalTtl, EventLoop loop) { + cached.add(address); + } + + @Override + public void clear() { + // NOOP + } + + @Override + public boolean clear(String hostname) { + return false; + } + }; + + final AtomicReference redirectedRef = new AtomicReference(); + final DnsNameResolver resolver = new DnsNameResolver( + group.next(), new ReflectiveChannelFactory(NioDatagramChannel.class), + NoopDnsCache.INSTANCE, authoritativeDnsServerCache, + NoopDnsQueryLifecycleObserverFactory.INSTANCE, 2000, ResolvedAddressTypes.IPV4_ONLY, + true, 10, true, 4096, + false, HostsFileEntriesResolver.DEFAULT, + new SingletonDnsServerAddressStreamProvider(redirectServer.localAddress()), + DnsNameResolver.DEFAULT_SEARCH_DOMAINS, 0, true) { + + @Override + protected DnsServerAddressStream newRedirectDnsServerStream( + String hostname, List nameservers) { + if (reversed) { + Collections.reverse(nameservers); + } + DnsServerAddressStream stream = new SequentialDnsServerAddressStream(nameservers, 0); + redirectedRef.set(stream); + return stream; + } + }; + + try { + Throwable cause = resolver.resolveAll(hostname).await().cause(); + assertTrue(cause instanceof UnknownHostException); + DnsServerAddressStream redirected = redirectedRef.get(); + assertNotNull(redirected); + assertEquals(4, redirected.size()); + assertEquals(4, cached.size()); + + if (reversed) { + assertEquals(ns4Address, redirected.next()); + assertEquals(ns3Address, redirected.next()); + assertEquals(ns2Address, redirected.next()); + assertEquals(ns1Address, redirected.next()); + } else { + assertEquals(ns1Address, redirected.next()); + assertEquals(ns2Address, redirected.next()); + assertEquals(ns3Address, redirected.next()); + assertEquals(ns4Address, redirected.next()); + } + + // We should always have the same order in the cache. + assertEquals(ns1Address, cached.get(0)); + assertEquals(ns2Address, cached.get(1)); + assertEquals(ns3Address, cached.get(2)); + assertEquals(ns4Address, cached.get(3)); + } finally { + resolver.close(); + group.shutdownGracefully(0, 0, TimeUnit.SECONDS); + redirectServer.stop(); + } + } + + @Test + public void testNSRecordsFromCache() throws Exception { + final String domain = "netty.io"; + final String hostname = "test.netty.io"; + final String ns0Name = "ns0." + domain + '.'; + final String ns1Name = "ns1." + domain + '.'; + final String ns2Name = "ns2." + domain + '.'; + + final InetSocketAddress ns0Address = new InetSocketAddress( + InetAddress.getByAddress(ns0Name, new byte[] { 10, 1, 0, 1 }), + DefaultDnsServerAddressStreamProvider.DNS_PORT); + final InetSocketAddress ns1Address = new InetSocketAddress( + InetAddress.getByAddress(ns1Name, new byte[] { 10, 0, 0, 1 }), + DefaultDnsServerAddressStreamProvider.DNS_PORT); + final InetSocketAddress ns2Address = new InetSocketAddress( + InetAddress.getByAddress(ns1Name, new byte[] { 10, 0, 0, 2 }), + DefaultDnsServerAddressStreamProvider.DNS_PORT); + final InetSocketAddress ns3Address = new InetSocketAddress( + InetAddress.getByAddress(ns1Name, new byte[] { 10, 0, 0, 3 }), + DefaultDnsServerAddressStreamProvider.DNS_PORT); + final InetSocketAddress ns4Address = new InetSocketAddress( + InetAddress.getByAddress(ns1Name, new byte[] { 10, 0, 0, 4 }), + DefaultDnsServerAddressStreamProvider.DNS_PORT); + final InetSocketAddress ns5Address = new InetSocketAddress( + InetAddress.getByAddress(ns2Name, new byte[] { 10, 0, 0, 5 }), + DefaultDnsServerAddressStreamProvider.DNS_PORT); + TestDnsServer redirectServer = new TestDnsServer(new HashSet(Arrays.asList(hostname, ns1Name))) { + @Override + protected DnsMessage filterMessage(DnsMessage message) { + for (QuestionRecord record: message.getQuestionRecords()) { + if (record.getDomainName().equals(hostname)) { + message.getAdditionalRecords().clear(); + message.getAnswerRecords().clear(); + message.getAuthorityRecords().add(TestDnsServer.newNsRecord(domain, ns0Name)); + message.getAuthorityRecords().add(TestDnsServer.newNsRecord(domain, ns1Name)); + message.getAuthorityRecords().add(TestDnsServer.newNsRecord(domain, ns2Name)); + + message.getAdditionalRecords().add(newARecord(ns0Address)); + message.getAdditionalRecords().add(newARecord(ns5Address)); + + return message; + } + } + return message; + } + + private ResourceRecord newARecord(InetSocketAddress address) { + return TestDnsServer.newARecord(address.getHostName(), address.getAddress().getHostAddress()); + } + }; + redirectServer.start(); + EventLoopGroup group = new NioEventLoopGroup(1); + + final List cached = new CopyOnWriteArrayList(); + final AuthoritativeDnsServerCache authoritativeDnsServerCache = new AuthoritativeDnsServerCache() { + @Override + public DnsServerAddressStream get(String hostname) { + return null; + } + + @Override + public void cache(String hostname, InetSocketAddress address, long originalTtl, EventLoop loop) { + cached.add(address); + } + + @Override + public void clear() { + // NOOP + } + + @Override + public boolean clear(String hostname) { + return false; + } + }; + + EventLoop loop = group.next(); + DefaultDnsCache cache = new DefaultDnsCache(); + cache.cache(ns1Name, null, ns1Address.getAddress(), 10000, loop); + cache.cache(ns1Name, null, ns2Address.getAddress(), 10000, loop); + cache.cache(ns1Name, null, ns3Address.getAddress(), 10000, loop); + cache.cache(ns1Name, null, ns4Address.getAddress(), 10000, loop); + + final AtomicReference redirectedRef = new AtomicReference(); + final DnsNameResolver resolver = new DnsNameResolver( + loop, new ReflectiveChannelFactory(NioDatagramChannel.class), + cache, authoritativeDnsServerCache, + NoopDnsQueryLifecycleObserverFactory.INSTANCE, 2000, ResolvedAddressTypes.IPV4_ONLY, + true, 10, true, 4096, + false, HostsFileEntriesResolver.DEFAULT, + new SingletonDnsServerAddressStreamProvider(redirectServer.localAddress()), + DnsNameResolver.DEFAULT_SEARCH_DOMAINS, 0, true) { + + @Override + protected DnsServerAddressStream newRedirectDnsServerStream( + String hostname, List nameservers) { + DnsServerAddressStream stream = new SequentialDnsServerAddressStream(nameservers, 0); + redirectedRef.set(stream); + return stream; + } + }; + + try { + Throwable cause = resolver.resolveAll(hostname).await().cause(); + assertTrue(cause instanceof UnknownHostException); + DnsServerAddressStream redirected = redirectedRef.get(); + assertNotNull(redirected); + assertEquals(6, redirected.size()); + assertEquals(3, cached.size()); + + // The redirected addresses should have been retrieven from the DnsCache if not resolved, so these are + // fully resolved. + assertEquals(ns0Address, redirected.next()); + assertEquals(ns1Address, redirected.next()); + assertEquals(ns2Address, redirected.next()); + assertEquals(ns3Address, redirected.next()); + assertEquals(ns4Address, redirected.next()); + assertEquals(ns5Address, redirected.next()); + + // As this address was supplied as ADDITIONAL we should put it resolved into the cache. + assertEquals(ns0Address, cached.get(0)); + assertEquals(ns5Address, cached.get(1)); + + // We should have put the unresolved address in the AuthoritativeDnsServerCache (but only 1 time) + assertEquals(unresolved(ns1Address), cached.get(2)); + } finally { + resolver.close(); + group.shutdownGracefully(0, 0, TimeUnit.SECONDS); + redirectServer.stop(); + } + } + + private static InetSocketAddress unresolved(InetSocketAddress address) { + return InetSocketAddress.createUnresolved(address.getHostString(), address.getPort()); + } + private static void resolve(DnsNameResolver resolver, Map> futures, String hostname) { futures.put(hostname, resolver.resolve(hostname)); } @@ -1480,10 +1910,43 @@ public void querySucceed() { } } + private static final class TestAuthoritativeDnsServerCache implements AuthoritativeDnsServerCache { + final AuthoritativeDnsServerCache cache; + final Map cacheHits = new HashMap(); + + TestAuthoritativeDnsServerCache(AuthoritativeDnsServerCache cache) { + this.cache = cache; + } + + @Override + public void clear() { + cache.clear(); + } + + @Override + public boolean clear(String hostname) { + return cache.clear(hostname); + } + + @Override + public DnsServerAddressStream get(String hostname) { + DnsServerAddressStream cached = cache.get(hostname); + if (cached != null) { + cacheHits.put(hostname, cached.duplicate()); + } + return cached; + } + + @Override + public void cache(String hostname, InetSocketAddress address, long originalTtl, EventLoop loop) { + cache.cache(hostname, address, originalTtl, loop); + } + } + private static final class TestDnsCache implements DnsCache { - private final DnsCache cache; - final Map> cacheHits = new HashMap>(); + final DnsCache cache; + final Map> cacheHits = + new HashMap>(); TestDnsCache(DnsCache cache) { this.cache = cache; @@ -1501,20 +1964,19 @@ public boolean clear(String hostname) { @Override public List get(String hostname, DnsRecord[] additionals) { - List cacheEntries = cache.get(hostname, additionals); - cacheHits.put(hostname, cacheEntries); - return cacheEntries; + List cached = cache.get(hostname, additionals); + cacheHits.put(hostname, cached); + return cached; } @Override - public DnsCacheEntry cache( - String hostname, DnsRecord[] additionals, InetAddress address, long originalTtl, EventLoop loop) { + public DnsCacheEntry cache(String hostname, DnsRecord[] additionals, InetAddress address, + long originalTtl, EventLoop loop) { return cache.cache(hostname, additionals, address, originalTtl, loop); } @Override - public DnsCacheEntry cache( - String hostname, DnsRecord[] additionals, Throwable cause, EventLoop loop) { + public DnsCacheEntry cache(String hostname, DnsRecord[] additionals, Throwable cause, EventLoop loop) { return cache.cache(hostname, additionals, cause, loop); } } @@ -1534,6 +1996,8 @@ private static class RedirectingTestDnsServer extends TestDnsServer { protected DnsMessage filterMessage(DnsMessage message) { // Clear the answers as we want to add our own stuff to test dns redirects. message.getAnswerRecords().clear(); + message.getAuthorityRecords().clear(); + message.getAdditionalRecords().clear(); String name = domain; for (int i = 0 ;; i++) { @@ -1545,30 +2009,13 @@ protected DnsMessage filterMessage(DnsMessage message) { String dnsName = "dns" + idx + '.' + domain; message.getAuthorityRecords().add(newNsRecord(name, dnsName)); message.getAdditionalRecords().add(newARecord(dnsName, i == 0 ? dnsAddress : "1.2.3." + idx)); + + // Add an unresolved NS record (with no additionals as well) + message.getAuthorityRecords().add(newNsRecord(name, "unresolved." + dnsName)); } return message; } - - private static ResourceRecord newARecord(String dnsname, String ipAddress) { - ResourceRecordModifier rm = new ResourceRecordModifier(); - rm.setDnsClass(RecordClass.IN); - rm.setDnsName(dnsname); - rm.setDnsTtl(100); - rm.setDnsType(RecordType.A); - rm.put(DnsAttribute.IP_ADDRESS, ipAddress); - return rm.getEntry(); - } - - private static ResourceRecord newNsRecord(String dnsname, String domainName) { - ResourceRecordModifier rm = new ResourceRecordModifier(); - rm.setDnsClass(RecordClass.IN); - rm.setDnsName(dnsname); - rm.setDnsTtl(100); - rm.setDnsType(RecordType.NS); - rm.put(DnsAttribute.DOMAIN_NAME, domainName); - return rm.getEntry(); - } } @Test(timeout = 3000) @@ -1682,4 +2129,342 @@ public Set getRecords(QuestionRecord question) { } } } + + @Test + public void testFollowCNAMELoop() throws IOException { + expectedException.expect(UnknownHostException.class); + TestDnsServer dnsServer2 = new TestDnsServer(new RecordStore() { + + @Override + public Set getRecords(QuestionRecord question) { + Set records = new LinkedHashSet(4); + + records.add(new TestDnsServer.TestResourceRecord("x." + question.getDomainName(), + RecordType.A, Collections.singletonMap( + DnsAttribute.IP_ADDRESS.toLowerCase(), "10.0.0.99"))); + records.add(new TestDnsServer.TestResourceRecord( + "cname2.netty.io", RecordType.CNAME, + Collections.singletonMap( + DnsAttribute.DOMAIN_NAME.toLowerCase(), "cname.netty.io"))); + records.add(new TestDnsServer.TestResourceRecord( + "cname.netty.io", RecordType.CNAME, + Collections.singletonMap( + DnsAttribute.DOMAIN_NAME.toLowerCase(), "cname2.netty.io"))); + records.add(new TestDnsServer.TestResourceRecord( + question.getDomainName(), RecordType.CNAME, + Collections.singletonMap( + DnsAttribute.DOMAIN_NAME.toLowerCase(), "cname.netty.io"))); + return records; + } + }); + dnsServer2.start(); + DnsNameResolver resolver = null; + try { + DnsNameResolverBuilder builder = newResolver() + .recursionDesired(false) + .resolvedAddressTypes(ResolvedAddressTypes.IPV4_ONLY) + .maxQueriesPerResolve(16) + .nameServerProvider(new SingletonDnsServerAddressStreamProvider(dnsServer2.localAddress())); + + resolver = builder.build(); + resolver.resolveAll("somehost.netty.io").syncUninterruptibly().getNow(); + } finally { + dnsServer2.stop(); + if (resolver != null) { + resolver.close(); + } + } + } + + @Test + public void testSearchDomainQueryFailureForSingleAddressTypeCompletes() { + expectedException.expect(UnknownHostException.class); + testSearchDomainQueryFailureCompletes(ResolvedAddressTypes.IPV4_ONLY); + } + + @Test + public void testSearchDomainQueryFailureForMultipleAddressTypeCompletes() { + expectedException.expect(UnknownHostException.class); + testSearchDomainQueryFailureCompletes(ResolvedAddressTypes.IPV4_PREFERRED); + } + + private void testSearchDomainQueryFailureCompletes(ResolvedAddressTypes types) { + DnsNameResolver resolver = newResolver() + .resolvedAddressTypes(types) + .ndots(1) + .searchDomains(singletonList(".")).build(); + try { + resolver.resolve("invalid.com").syncUninterruptibly(); + } finally { + resolver.close(); + } + } + + @Test(timeout = 2000L) + public void testCachesClearedOnClose() throws Exception { + final CountDownLatch resolveLatch = new CountDownLatch(1); + final CountDownLatch authoritativeLatch = new CountDownLatch(1); + + DnsNameResolver resolver = newResolver().resolveCache(new DnsCache() { + @Override + public void clear() { + resolveLatch.countDown(); + } + + @Override + public boolean clear(String hostname) { + return false; + } + + @Override + public List get(String hostname, DnsRecord[] additionals) { + return null; + } + + @Override + public DnsCacheEntry cache( + String hostname, DnsRecord[] additionals, InetAddress address, long originalTtl, EventLoop loop) { + return null; + } + + @Override + public DnsCacheEntry cache( + String hostname, DnsRecord[] additionals, Throwable cause, EventLoop loop) { + return null; + } + }).authoritativeDnsServerCache(new DnsCache() { + @Override + public void clear() { + authoritativeLatch.countDown(); + } + + @Override + public boolean clear(String hostname) { + return false; + } + + @Override + public List get(String hostname, DnsRecord[] additionals) { + return null; + } + + @Override + public DnsCacheEntry cache( + String hostname, DnsRecord[] additionals, InetAddress address, long originalTtl, EventLoop loop) { + return null; + } + + @Override + public DnsCacheEntry cache(String hostname, DnsRecord[] additionals, Throwable cause, EventLoop loop) { + return null; + } + }).build(); + + resolver.close(); + resolveLatch.await(); + authoritativeLatch.await(); + } + + @Test + public void testResolveACachedWithDot() { + final DnsCache cache = new DefaultDnsCache(); + DnsNameResolver resolver = newResolver(ResolvedAddressTypes.IPV4_ONLY) + .resolveCache(cache).build(); + + try { + String domain = DOMAINS.iterator().next(); + String domainWithDot = domain + '.'; + + resolver.resolve(domain).syncUninterruptibly(); + List cached = cache.get(domain, null); + List cached2 = cache.get(domainWithDot, null); + + assertEquals(1, cached.size()); + assertSame(cached, cached2); + } finally { + resolver.close(); + } + } + + @Test + public void testResolveACachedWithDotSearchDomain() throws Exception { + final TestDnsCache cache = new TestDnsCache(new DefaultDnsCache()); + TestDnsServer server = new TestDnsServer(Collections.singleton("test.netty.io")); + server.start(); + DnsNameResolver resolver = newResolver(ResolvedAddressTypes.IPV4_ONLY) + .searchDomains(Collections.singletonList("netty.io")) + .nameServerProvider(new SingletonDnsServerAddressStreamProvider(server.localAddress())) + .resolveCache(cache).build(); + try { + resolver.resolve("test").syncUninterruptibly(); + + assertNull(cache.cacheHits.get("test.netty.io")); + + List cached = cache.cache.get("test.netty.io", null); + List cached2 = cache.cache.get("test.netty.io.", null); + assertEquals(1, cached.size()); + assertSame(cached, cached2); + + resolver.resolve("test").syncUninterruptibly(); + List entries = cache.cacheHits.get("test.netty.io"); + assertFalse(entries.isEmpty()); + } finally { + resolver.close(); + server.stop(); + } + } + + @Test + public void testChannelFactoryException() { + final IllegalStateException exception = new IllegalStateException(); + try { + newResolver().channelFactory(new ChannelFactory() { + @Override + public DatagramChannel newChannel() { + throw exception; + } + }).build(); + fail(); + } catch (Exception e) { + assertSame(exception, e); + } + } + + @Test + public void testCNameCached() throws Exception { + final Map cache = new ConcurrentHashMap(); + final AtomicInteger cnameQueries = new AtomicInteger(); + final AtomicInteger aQueries = new AtomicInteger(); + + TestDnsServer dnsServer2 = new TestDnsServer(new RecordStore() { + + @Override + public Set getRecords(QuestionRecord question) { + if ("cname.netty.io".equals(question.getDomainName())) { + aQueries.incrementAndGet(); + + return Collections.singleton(new TestDnsServer.TestResourceRecord( + question.getDomainName(), RecordType.A, + Collections.singletonMap( + DnsAttribute.IP_ADDRESS.toLowerCase(), "10.0.0.99"))); + } + if ("x.netty.io".equals(question.getDomainName())) { + cnameQueries.incrementAndGet(); + + return Collections.singleton(new TestDnsServer.TestResourceRecord( + question.getDomainName(), RecordType.CNAME, + Collections.singletonMap( + DnsAttribute.DOMAIN_NAME.toLowerCase(), "cname.netty.io"))); + } + if ("y.netty.io".equals(question.getDomainName())) { + cnameQueries.incrementAndGet(); + + return Collections.singleton(new TestDnsServer.TestResourceRecord( + question.getDomainName(), RecordType.CNAME, + Collections.singletonMap( + DnsAttribute.DOMAIN_NAME.toLowerCase(), "x.netty.io"))); + } + return Collections.emptySet(); + } + }); + dnsServer2.start(); + DnsNameResolver resolver = null; + try { + DnsNameResolverBuilder builder = newResolver() + .recursionDesired(true) + .resolvedAddressTypes(ResolvedAddressTypes.IPV4_ONLY) + .maxQueriesPerResolve(16) + .nameServerProvider(new SingletonDnsServerAddressStreamProvider(dnsServer2.localAddress())) + .resolveCache(NoopDnsCache.INSTANCE) + .cnameCache(new DnsCnameCache() { + @Override + public String get(String hostname) { + assertTrue(hostname, hostname.endsWith(".")); + return cache.get(hostname); + } + + @Override + public void cache(String hostname, String cname, long originalTtl, EventLoop loop) { + assertTrue(hostname, hostname.endsWith(".")); + cache.put(hostname, cname); + } + + @Override + public void clear() { + // NOOP + } + + @Override + public boolean clear(String hostname) { + return false; + } + }); + resolver = builder.build(); + List resolvedAddresses = + resolver.resolveAll("x.netty.io").syncUninterruptibly().getNow(); + assertEquals(1, resolvedAddresses.size()); + assertTrue(resolvedAddresses.contains(InetAddress.getByAddress(new byte[] { 10, 0, 0, 99 }))); + + assertEquals("cname.netty.io.", cache.get("x.netty.io.")); + assertEquals(1, cnameQueries.get()); + assertEquals(1, aQueries.get()); + + resolvedAddresses = + resolver.resolveAll("x.netty.io").syncUninterruptibly().getNow(); + assertEquals(1, resolvedAddresses.size()); + assertTrue(resolvedAddresses.contains(InetAddress.getByAddress(new byte[] { 10, 0, 0, 99 }))); + + // Should not have queried for the CNAME again. + assertEquals(1, cnameQueries.get()); + assertEquals(2, aQueries.get()); + + resolvedAddresses = + resolver.resolveAll("y.netty.io").syncUninterruptibly().getNow(); + assertEquals(1, resolvedAddresses.size()); + assertTrue(resolvedAddresses.contains(InetAddress.getByAddress(new byte[] { 10, 0, 0, 99 }))); + + assertEquals("x.netty.io.", cache.get("y.netty.io.")); + + // Will only query for one CNAME + assertEquals(2, cnameQueries.get()); + assertEquals(3, aQueries.get()); + + resolvedAddresses = + resolver.resolveAll("y.netty.io").syncUninterruptibly().getNow(); + assertEquals(1, resolvedAddresses.size()); + assertTrue(resolvedAddresses.contains(InetAddress.getByAddress(new byte[] { 10, 0, 0, 99 }))); + + // Should not have queried for the CNAME again. + assertEquals(2, cnameQueries.get()); + assertEquals(4, aQueries.get()); + } finally { + dnsServer2.stop(); + if (resolver != null) { + resolver.close(); + } + } + } + + @Test + public void testInstanceWithNullPreferredAddressType() { + new DnsNameResolver( + group.next(), // eventLoop + new ReflectiveChannelFactory(NioDatagramChannel.class), // channelFactory + NoopDnsCache.INSTANCE, // resolveCache + NoopAuthoritativeDnsServerCache.INSTANCE, // authoritativeDnsServerCache + NoopDnsQueryLifecycleObserverFactory.INSTANCE, // dnsQueryLifecycleObserverFactory + 100, // queryTimeoutMillis + null, // resolvedAddressTypes, see https://github.com/netty/netty/pull/8445 + true, // recursionDesired + 1, // maxQueriesPerResolve + false, // traceEnabled + 4096, // maxPayloadSize + true, // optResourceEnabled + HostsFileEntriesResolver.DEFAULT, // hostsFileEntriesResolver + DnsServerAddressStreamProviders.platformDefault(), // dnsServerAddressStreamProvider + null, // searchDomains + 1, // ndots + true // decodeIdn + ).close(); + } } diff --git a/resolver-dns/src/test/java/io/netty/resolver/dns/NameServerComparatorTest.java b/resolver-dns/src/test/java/io/netty/resolver/dns/NameServerComparatorTest.java new file mode 100644 index 000000000000..e29932a18bba --- /dev/null +++ b/resolver-dns/src/test/java/io/netty/resolver/dns/NameServerComparatorTest.java @@ -0,0 +1,172 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.resolver.dns; + +import org.junit.BeforeClass; +import org.junit.Test; + +import java.net.Inet4Address; +import java.net.Inet6Address; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.UnknownHostException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.junit.Assert.*; + + +public class NameServerComparatorTest { + + private static InetSocketAddress IPV4ADDRESS1; + private static InetSocketAddress IPV4ADDRESS2; + private static InetSocketAddress IPV4ADDRESS3; + + private static InetSocketAddress IPV6ADDRESS1; + private static InetSocketAddress IPV6ADDRESS2; + + private static InetSocketAddress UNRESOLVED1; + private static InetSocketAddress UNRESOLVED2; + private static InetSocketAddress UNRESOLVED3; + + @BeforeClass + public static void before() throws UnknownHostException { + IPV4ADDRESS1 = new InetSocketAddress(InetAddress.getByAddress("ns1", new byte[] { 10, 0, 0, 1 }), 53); + IPV4ADDRESS2 = new InetSocketAddress(InetAddress.getByAddress("ns2", new byte[] { 10, 0, 0, 2 }), 53); + IPV4ADDRESS3 = new InetSocketAddress(InetAddress.getByAddress("ns3", new byte[] { 10, 0, 0, 3 }), 53); + + IPV6ADDRESS1 = new InetSocketAddress(InetAddress.getByAddress( + "ns1", new byte[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 }), 53); + IPV6ADDRESS2 = new InetSocketAddress(InetAddress.getByAddress( + "ns2", new byte[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2 }), 53); + + UNRESOLVED1 = InetSocketAddress.createUnresolved("ns3", 53); + UNRESOLVED2 = InetSocketAddress.createUnresolved("ns4", 53); + UNRESOLVED3 = InetSocketAddress.createUnresolved("ns5", 53); + } + + @Test + public void testCompareResolvedOnly() { + NameServerComparator comparator = new NameServerComparator(Inet4Address.class); + int x = comparator.compare(IPV4ADDRESS1, IPV6ADDRESS1); + int y = comparator.compare(IPV6ADDRESS1, IPV4ADDRESS1); + + assertEquals(-1, x); + assertEquals(x, -y); + + assertEquals(0, comparator.compare(IPV4ADDRESS1, IPV4ADDRESS1)); + assertEquals(0, comparator.compare(IPV6ADDRESS1, IPV6ADDRESS1)); + } + + @Test + public void testCompareUnresolvedSimple() { + NameServerComparator comparator = new NameServerComparator(Inet4Address.class); + int x = comparator.compare(IPV4ADDRESS1, UNRESOLVED1); + int y = comparator.compare(UNRESOLVED1, IPV4ADDRESS1); + + assertEquals(-1, x); + assertEquals(x, -y); + assertEquals(0, comparator.compare(IPV4ADDRESS1, IPV4ADDRESS1)); + assertEquals(0, comparator.compare(UNRESOLVED1, UNRESOLVED1)); + } + + @Test + public void testCompareUnresolvedOnly() { + NameServerComparator comparator = new NameServerComparator(Inet4Address.class); + int x = comparator.compare(UNRESOLVED1, UNRESOLVED2); + int y = comparator.compare(UNRESOLVED2, UNRESOLVED1); + + assertEquals(0, x); + assertEquals(x, -y); + + assertEquals(0, comparator.compare(UNRESOLVED1, UNRESOLVED1)); + assertEquals(0, comparator.compare(UNRESOLVED2, UNRESOLVED2)); + } + + @Test + public void testSortAlreadySortedPreferred() { + List expected = Arrays.asList(IPV4ADDRESS1, IPV4ADDRESS2, IPV4ADDRESS3); + List addresses = new ArrayList(expected); + NameServerComparator comparator = new NameServerComparator(Inet4Address.class); + + Collections.sort(addresses, comparator); + + assertEquals(expected, addresses); + } + + @Test + public void testSortAlreadySortedNotPreferred() { + List expected = Arrays.asList(IPV4ADDRESS1, IPV4ADDRESS2, IPV4ADDRESS3); + List addresses = new ArrayList(expected); + NameServerComparator comparator = new NameServerComparator(Inet6Address.class); + + Collections.sort(addresses, comparator); + + assertEquals(expected, addresses); + } + + @Test + public void testSortAlreadySortedUnresolved() { + List expected = Arrays.asList(UNRESOLVED1, UNRESOLVED2, UNRESOLVED3); + List addresses = new ArrayList(expected); + NameServerComparator comparator = new NameServerComparator(Inet6Address.class); + + Collections.sort(addresses, comparator); + + assertEquals(expected, addresses); + } + + @Test + public void testSortAlreadySortedMixed() { + List expected = Arrays.asList( + IPV4ADDRESS1, IPV4ADDRESS2, IPV6ADDRESS1, IPV6ADDRESS2, UNRESOLVED1, UNRESOLVED2); + + List addresses = new ArrayList(expected); + NameServerComparator comparator = new NameServerComparator(Inet4Address.class); + + Collections.sort(addresses, comparator); + + assertEquals(expected, addresses); + } + + @Test + public void testSort1() { + List expected = Arrays.asList( + IPV4ADDRESS1, IPV4ADDRESS2, IPV6ADDRESS1, IPV6ADDRESS2, UNRESOLVED1, UNRESOLVED2); + List addresses = new ArrayList( + Arrays.asList(IPV6ADDRESS1, IPV4ADDRESS1, IPV6ADDRESS2, UNRESOLVED1, UNRESOLVED2, IPV4ADDRESS2)); + NameServerComparator comparator = new NameServerComparator(Inet4Address.class); + + Collections.sort(addresses, comparator); + + assertEquals(expected, addresses); + } + + @Test + public void testSort2() { + List expected = Arrays.asList( + IPV4ADDRESS1, IPV4ADDRESS2, IPV6ADDRESS1, IPV6ADDRESS2, UNRESOLVED1, UNRESOLVED2); + List addresses = new ArrayList( + Arrays.asList(IPV4ADDRESS1, IPV6ADDRESS1, IPV6ADDRESS2, UNRESOLVED1, IPV4ADDRESS2, UNRESOLVED2)); + NameServerComparator comparator = new NameServerComparator(Inet4Address.class); + + Collections.sort(addresses, comparator); + + assertEquals(expected, addresses); + } +} diff --git a/resolver-dns/src/test/java/io/netty/resolver/dns/TestDnsServer.java b/resolver-dns/src/test/java/io/netty/resolver/dns/TestDnsServer.java index dd351ae242e7..229ea98a8806 100644 --- a/resolver-dns/src/test/java/io/netty/resolver/dns/TestDnsServer.java +++ b/resolver-dns/src/test/java/io/netty/resolver/dns/TestDnsServer.java @@ -26,6 +26,7 @@ import org.apache.directory.server.dns.messages.RecordType; import org.apache.directory.server.dns.messages.ResourceRecord; import org.apache.directory.server.dns.messages.ResourceRecordImpl; +import org.apache.directory.server.dns.messages.ResourceRecordModifier; import org.apache.directory.server.dns.protocol.DnsProtocolHandler; import org.apache.directory.server.dns.protocol.DnsUdpDecoder; import org.apache.directory.server.dns.protocol.DnsUdpEncoder; @@ -67,7 +68,7 @@ class TestDnsServer extends DnsServer { BYTES.put("0:1:1:1:1:1:1:1", new byte[]{0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1}); BYTES.put("1:1:1:1:1:1:1:1", new byte[]{0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1}); - IPV6_ADDRESSES = BYTES.keySet().toArray(new String[BYTES.size()]); + IPV6_ADDRESSES = BYTES.keySet().toArray(new String[0]); } private final RecordStore store; @@ -111,6 +112,30 @@ protected DnsMessage filterMessage(DnsMessage message) { return message; } + protected static ResourceRecord newARecord(String name, String ipAddress) { + return newAddressRecord(name, RecordType.A, ipAddress); + } + + protected static ResourceRecord newNsRecord(String dnsname, String domainName) { + ResourceRecordModifier rm = new ResourceRecordModifier(); + rm.setDnsClass(RecordClass.IN); + rm.setDnsName(dnsname); + rm.setDnsTtl(100); + rm.setDnsType(RecordType.NS); + rm.put(DnsAttribute.DOMAIN_NAME, domainName); + return rm.getEntry(); + } + + protected static ResourceRecord newAddressRecord(String name, RecordType type, String address) { + ResourceRecordModifier rm = new ResourceRecordModifier(); + rm.setDnsClass(RecordClass.IN); + rm.setDnsName(name); + rm.setDnsTtl(100); + rm.setDnsType(type); + rm.put(DnsAttribute.IP_ADDRESS, address); + return rm.getEntry(); + } + /** * {@link ProtocolCodecFactory} which allows to test AAAA resolution. */ diff --git a/resolver-dns/src/test/java/io/netty/resolver/dns/UnixResolverDnsServerAddressStreamProviderTest.java b/resolver-dns/src/test/java/io/netty/resolver/dns/UnixResolverDnsServerAddressStreamProviderTest.java index 38c6c919c512..33996679e631 100644 --- a/resolver-dns/src/test/java/io/netty/resolver/dns/UnixResolverDnsServerAddressStreamProviderTest.java +++ b/resolver-dns/src/test/java/io/netty/resolver/dns/UnixResolverDnsServerAddressStreamProviderTest.java @@ -25,6 +25,9 @@ import java.io.IOException; import java.io.OutputStream; import java.net.InetSocketAddress; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; import static io.netty.resolver.dns.UnixResolverDnsServerAddressStreamProvider.DEFAULT_NDOTS; import static io.netty.resolver.dns.UnixResolverDnsServerAddressStreamProvider.parseEtcResolverFirstNdots; @@ -111,6 +114,56 @@ public void emptyEtcResolverDirectoryDoesNotThrow() throws IOException { assertHostNameEquals("127.0.0.2", stream.next()); } + @Test + public void searchDomainsWithOnlyDomain() throws IOException { + File f = buildFile("domain linecorp.local\n" + + "nameserver 127.0.0.2\n"); + List domains = UnixResolverDnsServerAddressStreamProvider.parseEtcResolverSearchDomains(f); + assertEquals(Collections.singletonList("linecorp.local"), domains); + } + + @Test + public void searchDomainsWithOnlySearch() throws IOException { + File f = buildFile("search linecorp.local\n" + + "nameserver 127.0.0.2\n"); + List domains = UnixResolverDnsServerAddressStreamProvider.parseEtcResolverSearchDomains(f); + assertEquals(Collections.singletonList("linecorp.local"), domains); + } + + @Test + public void searchDomainsWithMultipleSearch() throws IOException { + File f = buildFile("search linecorp.local\n" + + "search squarecorp.local\n" + + "nameserver 127.0.0.2\n"); + List domains = UnixResolverDnsServerAddressStreamProvider.parseEtcResolverSearchDomains(f); + assertEquals(Arrays.asList("linecorp.local", "squarecorp.local"), domains); + } + + @Test + public void searchDomainsWithMultipleSearchSeperatedByWhitespace() throws IOException { + File f = buildFile("search linecorp.local squarecorp.local\n" + + "nameserver 127.0.0.2\n"); + List domains = UnixResolverDnsServerAddressStreamProvider.parseEtcResolverSearchDomains(f); + assertEquals(Arrays.asList("linecorp.local", "squarecorp.local"), domains); + } + + @Test + public void searchDomainsWithMultipleSearchSeperatedByTab() throws IOException { + File f = buildFile("search linecorp.local\tsquarecorp.local\n" + + "nameserver 127.0.0.2\n"); + List domains = UnixResolverDnsServerAddressStreamProvider.parseEtcResolverSearchDomains(f); + assertEquals(Arrays.asList("linecorp.local", "squarecorp.local"), domains); + } + + @Test + public void searchDomainsPrecedence() throws IOException { + File f = buildFile("domain linecorp.local\n" + + "search squarecorp.local\n" + + "nameserver 127.0.0.2\n"); + List domains = UnixResolverDnsServerAddressStreamProvider.parseEtcResolverSearchDomains(f); + assertEquals(Collections.singletonList("squarecorp.local"), domains); + } + private File buildFile(String contents) throws IOException { File f = folder.newFile(); OutputStream out = new FileOutputStream(f); diff --git a/resolver/pom.xml b/resolver/pom.xml index b48df3cad1c5..9918805ef9f4 100644 --- a/resolver/pom.xml +++ b/resolver/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-resolver diff --git a/resolver/src/main/java/io/netty/resolver/AddressResolverGroup.java b/resolver/src/main/java/io/netty/resolver/AddressResolverGroup.java index bd1a34664e2a..e9bb9dfd6430 100644 --- a/resolver/src/main/java/io/netty/resolver/AddressResolverGroup.java +++ b/resolver/src/main/java/io/netty/resolver/AddressResolverGroup.java @@ -102,7 +102,7 @@ public void operationComplete(Future future) throws Exception { public void close() { final AddressResolver[] rArray; synchronized (resolvers) { - rArray = (AddressResolver[]) resolvers.values().toArray(new AddressResolver[resolvers.size()]); + rArray = (AddressResolver[]) resolvers.values().toArray(new AddressResolver[0]); resolvers.clear(); } diff --git a/resolver/src/main/java/io/netty/resolver/DefaultHostsFileEntriesResolver.java b/resolver/src/main/java/io/netty/resolver/DefaultHostsFileEntriesResolver.java index 7598a29200c6..9051262eb1c4 100644 --- a/resolver/src/main/java/io/netty/resolver/DefaultHostsFileEntriesResolver.java +++ b/resolver/src/main/java/io/netty/resolver/DefaultHostsFileEntriesResolver.java @@ -15,11 +15,14 @@ */ package io.netty.resolver; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.UnstableApi; import java.net.Inet4Address; import java.net.Inet6Address; import java.net.InetAddress; +import java.nio.charset.Charset; import java.util.Locale; import java.util.Map; @@ -33,7 +36,7 @@ public final class DefaultHostsFileEntriesResolver implements HostsFileEntriesRe private final Map inet6Entries; public DefaultHostsFileEntriesResolver() { - this(HostsFileParser.parseSilently()); + this(parseEntries()); } // for testing purpose only @@ -65,4 +68,14 @@ public InetAddress address(String inetHost, ResolvedAddressTypes resolvedAddress String normalize(String inetHost) { return inetHost.toLowerCase(Locale.ENGLISH); } + + private static HostsFileEntries parseEntries() { + if (PlatformDependent.isWindows()) { + // Ony windows there seems to be no standard for the encoding used for the hosts file, so let us + // try multiple until we either were able to parse it or there is none left and so we return an + // empty intstance. + return HostsFileParser.parseSilently(Charset.defaultCharset(), CharsetUtil.UTF_16, CharsetUtil.UTF_8); + } + return HostsFileParser.parseSilently(); + } } diff --git a/resolver/src/main/java/io/netty/resolver/HostsFileParser.java b/resolver/src/main/java/io/netty/resolver/HostsFileParser.java index 8ee1d195689f..16eaba6167f9 100644 --- a/resolver/src/main/java/io/netty/resolver/HostsFileParser.java +++ b/resolver/src/main/java/io/netty/resolver/HostsFileParser.java @@ -23,12 +23,14 @@ import java.io.BufferedReader; import java.io.File; -import java.io.FileReader; +import java.io.FileInputStream; import java.io.IOException; +import java.io.InputStreamReader; import java.io.Reader; import java.net.Inet4Address; import java.net.Inet6Address; import java.net.InetAddress; +import java.nio.charset.Charset; import java.util.ArrayList; import java.util.List; import java.util.Locale; @@ -66,22 +68,35 @@ private static File locateHostsFile() { } /** - * Parse hosts file at standard OS location. + * Parse hosts file at standard OS location using the systems default {@link Charset} for decoding. * * @return a {@link HostsFileEntries} */ public static HostsFileEntries parseSilently() { + return parseSilently(Charset.defaultCharset()); + } + + /** + * Parse hosts file at standard OS location using the given {@link Charset}s one after each other until + * we were able to parse something or none is left. + * + * @param charsets the {@link Charset}s to try as file encodings when parsing. + * @return a {@link HostsFileEntries} + */ + public static HostsFileEntries parseSilently(Charset... charsets) { File hostsFile = locateHostsFile(); try { - return parse(hostsFile); + return parse(hostsFile, charsets); } catch (IOException e) { - logger.warn("Failed to load and parse hosts file at " + hostsFile.getPath(), e); + if (logger.isWarnEnabled()) { + logger.warn("Failed to load and parse hosts file at " + hostsFile.getPath(), e); + } return HostsFileEntries.EMPTY; } } /** - * Parse hosts file at standard OS location. + * Parse hosts file at standard OS location using the system default {@link Charset} for decoding. * * @return a {@link HostsFileEntries} * @throws IOException file could not be read @@ -91,19 +106,37 @@ public static HostsFileEntries parse() throws IOException { } /** - * Parse a hosts file. + * Parse a hosts file using the system default {@link Charset} for decoding. * * @param file the file to be parsed * @return a {@link HostsFileEntries} * @throws IOException file could not be read */ public static HostsFileEntries parse(File file) throws IOException { + return parse(file, Charset.defaultCharset()); + } + + /** + * Parse a hosts file. + * + * @param file the file to be parsed + * @param charsets the {@link Charset}s to try as file encodings when parsing. + * @return a {@link HostsFileEntries} + * @throws IOException file could not be read + */ + public static HostsFileEntries parse(File file, Charset... charsets) throws IOException { checkNotNull(file, "file"); + checkNotNull(charsets, "charsets"); if (file.exists() && file.isFile()) { - return parse(new BufferedReader(new FileReader(file))); - } else { - return HostsFileEntries.EMPTY; + for (Charset charset: charsets) { + HostsFileEntries entries = parse(new BufferedReader(new InputStreamReader( + new FileInputStream(file), charset))); + if (entries != HostsFileEntries.EMPTY) { + return entries; + } + } } + return HostsFileEntries.EMPTY; } /** diff --git a/resolver/src/test/java/io/netty/resolver/HostsFileParserTest.java b/resolver/src/test/java/io/netty/resolver/HostsFileParserTest.java index 6b908f574205..ee07c592cae4 100644 --- a/resolver/src/test/java/io/netty/resolver/HostsFileParserTest.java +++ b/resolver/src/test/java/io/netty/resolver/HostsFileParserTest.java @@ -15,6 +15,9 @@ */ package io.netty.resolver; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.ResourcesUtil; +import org.junit.Assume; import org.junit.Test; import java.io.BufferedReader; @@ -22,6 +25,8 @@ import java.io.StringReader; import java.net.Inet4Address; import java.net.Inet6Address; +import java.nio.charset.Charset; +import java.nio.charset.UnsupportedCharsetException; import java.util.Map; import static org.junit.Assert.*; @@ -60,4 +65,41 @@ public void testParse() throws IOException { assertEquals("192.168.0.5", inet4Entries.get("host7").getHostAddress()); assertEquals("0:0:0:0:0:0:0:1", inet6Entries.get("host1").getHostAddress()); } + + @Test + public void testParseUnicode() throws IOException { + final Charset unicodeCharset; + try { + unicodeCharset = Charset.forName("unicode"); + } catch (UnsupportedCharsetException e) { + Assume.assumeNoException(e); + return; + } + testParseFile(HostsFileParser.parse( + ResourcesUtil.getFile(getClass(), "hosts-unicode"), unicodeCharset)); + } + + @Test + public void testParseMultipleCharsets() throws IOException { + final Charset unicodeCharset; + try { + unicodeCharset = Charset.forName("unicode"); + } catch (UnsupportedCharsetException e) { + Assume.assumeNoException(e); + return; + } + testParseFile(HostsFileParser.parse(ResourcesUtil.getFile(getClass(), "hosts-unicode"), + CharsetUtil.UTF_8, CharsetUtil.ISO_8859_1, unicodeCharset)); + } + + private static void testParseFile(HostsFileEntries entries) throws IOException { + Map inet4Entries = entries.inet4Entries(); + Map inet6Entries = entries.inet6Entries(); + + assertEquals("Expected 2 IPv4 entries", 2, inet4Entries.size()); + assertEquals("Expected 1 IPv6 entries", 1, inet6Entries.size()); + assertEquals("127.0.0.1", inet4Entries.get("localhost").getHostAddress()); + assertEquals("255.255.255.255", inet4Entries.get("broadcasthost").getHostAddress()); + assertEquals("0:0:0:0:0:0:0:1", inet6Entries.get("localhost").getHostAddress()); + } } diff --git a/resolver/src/test/resources/io/netty/resolver/hosts-unicode b/resolver/src/test/resources/io/netty/resolver/hosts-unicode new file mode 100644 index 000000000000..68750bfd947a Binary files /dev/null and b/resolver/src/test/resources/io/netty/resolver/hosts-unicode differ diff --git a/tarball/pom.xml b/tarball/pom.xml index 485c2dfa422f..b3f8c9f80988 100644 --- a/tarball/pom.xml +++ b/tarball/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-tarball diff --git a/testsuite-autobahn/pom.xml b/testsuite-autobahn/pom.xml index bae2156de9e2..65d776c766e9 100644 --- a/testsuite-autobahn/pom.xml +++ b/testsuite-autobahn/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-testsuite-autobahn @@ -28,7 +28,26 @@ Netty/Testsuite/Autobahn + + true + + + + ${project.groupId} + netty-common + ${project.version} + + + ${project.groupId} + netty-buffer + ${project.version} + + + ${project.groupId} + netty-transport + ${project.version} + ${project.groupId} netty-codec-http @@ -55,7 +74,7 @@ me.normanmaurer.maven.autobahntestsuite autobahntestsuite-maven-plugin - 0.1.4 + 0.1.5 io.netty.testsuite.autobahn.AutobahnServer @@ -73,6 +92,13 @@ + + + org.python + jython-standalone + 2.7.1 + + diff --git a/testsuite-autobahn/src/main/java/io/netty/testsuite/autobahn/AutobahnServerHandler.java b/testsuite-autobahn/src/main/java/io/netty/testsuite/autobahn/AutobahnServerHandler.java index 75f40621e4bd..933c135887b0 100644 --- a/testsuite-autobahn/src/main/java/io/netty/testsuite/autobahn/AutobahnServerHandler.java +++ b/testsuite-autobahn/src/main/java/io/netty/testsuite/autobahn/AutobahnServerHandler.java @@ -78,7 +78,7 @@ private void handleHttpRequest(ChannelHandlerContext ctx, HttpRequest req) } // Allow only GET methods. - if (req.method() != GET) { + if (!GET.equals(req.method())) { sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN)); return; } diff --git a/testsuite-http2/pom.xml b/testsuite-http2/pom.xml index f55c0678baf3..d4809c7779b2 100644 --- a/testsuite-http2/pom.xml +++ b/testsuite-http2/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-testsuite-http2 @@ -28,7 +28,31 @@ Netty/Testsuite/Http2 + + true + + + + ${project.groupId} + netty-common + ${project.version} + + + ${project.groupId} + netty-buffer + ${project.version} + + + ${project.groupId} + netty-transport + ${project.version} + + + ${project.groupId} + netty-handler + ${project.version} + ${project.groupId} netty-codec-http @@ -61,17 +85,15 @@ com.github.madgnome h2spec-maven-plugin - 0.3 + 0.6 io.netty.testsuite.http2.Http2Server - 3.8 - Sends a GOAWAY frame 4.2 - Sends a dynamic table size update at the end of header block 5.1 - idle: Sends a DATA frame - 5.1 - closed: Sends a DATA frame + 5.1 - half closed (remote): Sends a HEADERS frame 5.1 - closed: Sends a HEADERS frame 5.1.1 - Sends stream identifier that is numerically smaller than previous - 7 - Sends a GOAWAY frame with unknown error code 8.1.2.2 - Sends a HEADERS frame that contains the connection-specific header field 8.1.2.2 - Sends a HEADERS frame that contains the TE header field with any value other than "trailers" 8.1.2.3 - Sends a HEADERS frame with empty ":path" pseudo-header field diff --git a/testsuite-http2/src/main/java/io/netty/testsuite/http2/Http2ServerInitializer.java b/testsuite-http2/src/main/java/io/netty/testsuite/http2/Http2ServerInitializer.java index 7b67fc9438cb..0af251cddd73 100644 --- a/testsuite-http2/src/main/java/io/netty/testsuite/http2/Http2ServerInitializer.java +++ b/testsuite-http2/src/main/java/io/netty/testsuite/http2/Http2ServerInitializer.java @@ -16,6 +16,8 @@ package io.netty.testsuite.http2; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInitializer; @@ -59,9 +61,7 @@ public UpgradeCodec newUpgradeCodec(CharSequence protocol) { } Http2ServerInitializer(int maxHttpContentLength) { - if (maxHttpContentLength < 0) { - throw new IllegalArgumentException("maxHttpContentLength (expected >= 0): " + maxHttpContentLength); - } + checkPositiveOrZero(maxHttpContentLength, "maxHttpContentLength"); this.maxHttpContentLength = maxHttpContentLength; } diff --git a/testsuite-osgi/pom.xml b/testsuite-osgi/pom.xml index 6c79b35d4150..47c564c859aa 100644 --- a/testsuite-osgi/pom.xml +++ b/testsuite-osgi/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-testsuite-osgi @@ -29,8 +29,9 @@ Netty/Testsuite/OSGI - 4.9.1 + 4.13.0 --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/jdk.internal.loader=ALL-UNNAMED --add-opens=java.base/java.net=ALL-UNNAMED --add-opens=java.base/java.security=ALL-UNNAMED + true diff --git a/testsuite-shading/pom.xml b/testsuite-shading/pom.xml new file mode 100644 index 000000000000..e461c7499473 --- /dev/null +++ b/testsuite-shading/pom.xml @@ -0,0 +1,388 @@ + + + + + 4.0.0 + + io.netty + netty-parent + 4.1.34.3.dse + + + netty-testsuite-shading + jar + + Netty/Testsuite/Shading + + + ${project.build.directory}/src + ${project.build.directory}/versions + ${project.build.directory}/classes-shaded + ${classesShadedDir}/META-INF/native + shaded + shaded2 + + ${project.artifactId}-${project.version}.jar + io.netty. + true + + + + + + kr.motd.maven + os-maven-plugin + 1.6.0 + + + + + + maven-deploy-plugin + + true + + + + + + + junit + junit + + + + + windows + + + windows + + + + + ${project.groupId} + netty-common + ${project.version} + compile + + + ${project.groupId} + netty-handler + ${project.version} + compile + + + ${project.groupId} + ${tcnative.artifactId} + ${tcnative.version} + ${tcnative.classifier} + compile + + + + + mac + + + mac + + + + netty_transport_native_kqueue_${os.detected.arch}.jnilib + + + + ${project.groupId} + netty-common + ${project.version} + compile + + + ${project.groupId} + netty-transport-native-kqueue + ${project.version} + ${jni.classifier} + compile + + + + ${project.groupId} + netty-handler + ${project.version} + compile + + + ${project.groupId} + ${tcnative.artifactId} + ${tcnative.version} + ${tcnative.classifier} + compile + + + + + + + maven-shade-plugin + + + shade + package + + shade + + + + + ${project.groupId} + + + + + ${shadedPackagePrefix} + ${shadingPrefix}.${shadedPackagePrefix} + + + + + + shade-1 + package + + shade + + + + + ${project.groupId} + + + + + ${shadedPackagePrefix} + ${shadingPrefix2}.${shadedPackagePrefix} + + + + + + + + org.apache.maven.plugins + maven-antrun-plugin + + + unpack-jar-features + package + + run + + + + + + + + + + + + + + + + + + + + + + + + + + + + + org.apache.maven.plugins + maven-failsafe-plugin + + + ${shadingPrefix} + ${shadingPrefix2} + + + + + package + + integration-test + + + + + + + + + linux + + + linux + + + + netty_transport_native_epoll_${os.detected.arch}.so + + + + ${project.groupId} + netty-common + ${project.version} + compile + + + ${project.groupId} + netty-transport-native-epoll + ${project.version} + ${jni.classifier} + compile + + + + ${project.groupId} + netty-handler + ${project.version} + compile + + + ${project.groupId} + ${tcnative.artifactId} + ${tcnative.version} + ${tcnative.classifier} + compile + + + + + + + maven-shade-plugin + + + shade + package + + shade + + + + + ${project.groupId} + + + + + ${shadedPackagePrefix} + ${shadingPrefix}.${shadedPackagePrefix} + + + + + + shade-1 + package + + shade + + + + + ${project.groupId} + + + + + ${shadedPackagePrefix} + ${shadingPrefix2}.${shadedPackagePrefix} + + + + + + + + org.apache.maven.plugins + maven-antrun-plugin + + + unpack-jar-features + package + + run + + + + + + + + + + + + + + + + + + + + + + + + + + + + + org.apache.maven.plugins + maven-failsafe-plugin + + + ${shadingPrefix} + ${shadingPrefix2} + + + + + package + + integration-test + + + + + + + + + + diff --git a/testsuite-shading/src/test/java/io/netty/testsuite/shading/ShadingIT.java b/testsuite-shading/src/test/java/io/netty/testsuite/shading/ShadingIT.java new file mode 100644 index 000000000000..51cef5f56836 --- /dev/null +++ b/testsuite-shading/src/test/java/io/netty/testsuite/shading/ShadingIT.java @@ -0,0 +1,55 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.testsuite.shading; + +import io.netty.util.internal.PlatformDependent; +import org.junit.Test; +import org.junit.Assume; + +import java.lang.reflect.Method; + +public class ShadingIT { + + private static final String SHADING_PREFIX = System.getProperty("shadingPrefix2"); + private static final String SHADING_PREFIX2 = System.getProperty("shadingPrefix"); + + @Test + public void testShadingNativeTransport() throws Exception { + // Skip on windows. + Assume.assumeFalse(PlatformDependent.isWindows()); + + String className = PlatformDependent.isOsx() ? + "io.netty.channel.kqueue.KQueue" : "io.netty.channel.epoll.Epoll"; + testShading0(SHADING_PREFIX, className); + testShading0(SHADING_PREFIX2, className); + } + + @Test + public void testShadingTcnative() throws Exception { + // Skip on windows. + Assume.assumeFalse(PlatformDependent.isWindows()); + + String className = "io.netty.handler.ssl.OpenSsl"; + testShading0(SHADING_PREFIX, className); + testShading0(SHADING_PREFIX2, className); + } + + private static void testShading0(String shadingPrefix, String classname) throws Exception { + final Class clazz = Class.forName(shadingPrefix + '.' + classname); + Method method = clazz.getMethod("ensureAvailability"); + method.invoke(null); + } +} diff --git a/testsuite/pom.xml b/testsuite/pom.xml index 9b9bf225a949..319cce6fbd64 100644 --- a/testsuite/pom.xml +++ b/testsuite/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-testsuite @@ -29,6 +29,21 @@ Netty/Testsuite + + ${project.groupId} + netty-common + ${project.version} + + + ${project.groupId} + netty-buffer + ${project.version} + + + ${project.groupId} + netty-transport + ${project.version} + ${project.groupId} netty-transport-sctp @@ -89,6 +104,7 @@ --add-exports java.base/sun.security.x509=ALL-UNNAMED + true diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/DatagramUnicastTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/DatagramUnicastTest.java index b49f6780ef52..37adc20e3837 100644 --- a/testsuite/src/main/java/io/netty/testsuite/transport/socket/DatagramUnicastTest.java +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/DatagramUnicastTest.java @@ -158,7 +158,7 @@ public void channelRead0(ChannelHandlerContext ctx, Object msgs) throws Exceptio }); final CountDownLatch latch = new CountDownLatch(count); - sc = setupServerChannel(sb, bytes, latch); + sc = setupServerChannel(sb, bytes, latch, false); if (bindClient) { cc = cb.bind(newSocketAddress()).sync().channel(); } else { @@ -209,10 +209,21 @@ private void testSimpleSendWithConnect(Bootstrap sb, Bootstrap cb, ByteBuf buf, private void testSimpleSendWithConnect0(Bootstrap sb, Bootstrap cb, ByteBuf buf, final byte[] bytes, int count, WrapType wrapType) throws Throwable { - cb.handler(new SimpleChannelInboundHandler() { + final CountDownLatch clientLatch = new CountDownLatch(count); + + cb.handler(new SimpleChannelInboundHandler() { @Override - public void channelRead0(ChannelHandlerContext ctx, Object msgs) throws Exception { - // Nothing will be sent. + public void channelRead0(ChannelHandlerContext ctx, DatagramPacket msg) throws Exception { + ByteBuf buf = msg.content(); + assertEquals(bytes.length, buf.readableBytes()); + for (int i = 0; i < bytes.length; i++) { + assertEquals(bytes[i], buf.getByte(buf.readerIndex() + i)); + } + + // Test that the channel's localAddress is equal to the message's recipient + assertEquals(ctx.channel().localAddress(), msg.recipient()); + + clientLatch.countDown(); } }); @@ -220,7 +231,7 @@ public void channelRead0(ChannelHandlerContext ctx, Object msgs) throws Exceptio DatagramChannel cc = null; try { final CountDownLatch latch = new CountDownLatch(count); - sc = setupServerChannel(sb, bytes, latch); + sc = setupServerChannel(sb, bytes, latch, true); cc = (DatagramChannel) cb.connect(sc.localAddress()).sync().channel(); for (int i = 0; i < count; i++) { @@ -243,7 +254,7 @@ public void channelRead0(ChannelHandlerContext ctx, Object msgs) throws Exceptio } cc.flush(); assertTrue(latch.await(10, TimeUnit.SECONDS)); - + assertTrue(clientLatch.await(10, TimeUnit.SECONDS)); assertTrue(cc.isConnected()); // Test what happens when we call disconnect() @@ -264,7 +275,7 @@ public void channelRead0(ChannelHandlerContext ctx, Object msgs) throws Exceptio } @SuppressWarnings("deprecation") - private Channel setupServerChannel(Bootstrap sb, final byte[] bytes, final CountDownLatch latch) + private Channel setupServerChannel(Bootstrap sb, final byte[] bytes, final CountDownLatch latch, final boolean echo) throws Throwable { sb.handler(new ChannelInitializer() { @Override @@ -274,13 +285,16 @@ protected void initChannel(Channel ch) throws Exception { public void channelRead0(ChannelHandlerContext ctx, DatagramPacket msg) throws Exception { ByteBuf buf = msg.content(); assertEquals(bytes.length, buf.readableBytes()); - for (byte b : bytes) { - assertEquals(b, buf.readByte()); + for (int i = 0; i < bytes.length; i++) { + assertEquals(bytes[i], buf.getByte(buf.readerIndex() + i)); } // Test that the channel's localAddress is equal to the message's recipient assertEquals(ctx.channel().localAddress(), msg.recipient()); + if (echo) { + ctx.writeAndFlush(new DatagramPacket(buf.retainedDuplicate(), msg.sender())); + } latch.countDown(); } }); diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketDataReadInitialStateTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketDataReadInitialStateTest.java new file mode 100644 index 000000000000..f4d218d957f9 --- /dev/null +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketDataReadInitialStateTest.java @@ -0,0 +1,184 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.testsuite.transport.socket; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.SimpleChannelInboundHandler; +import org.junit.Test; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; + +import static io.netty.channel.ChannelOption.AUTO_READ; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +public class SocketDataReadInitialStateTest extends AbstractSocketTest { + @Test(timeout = 10000) + public void testAutoReadOffNoDataReadUntilReadCalled() throws Throwable { + run(); + } + + public void testAutoReadOffNoDataReadUntilReadCalled(ServerBootstrap sb, Bootstrap cb) throws Throwable { + Channel serverChannel = null; + Channel clientChannel = null; + final int sleepMs = 100; + try { + sb.option(AUTO_READ, false); + sb.childOption(AUTO_READ, false); + cb.option(AUTO_READ, false); + final CountDownLatch serverReadyLatch = new CountDownLatch(1); + final CountDownLatch acceptorReadLatch = new CountDownLatch(1); + final CountDownLatch serverReadLatch = new CountDownLatch(1); + final CountDownLatch clientReadLatch = new CountDownLatch(1); + final AtomicReference serverConnectedChannelRef = new AtomicReference(); + + sb.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + acceptorReadLatch.countDown(); + ctx.fireChannelRead(msg); + } + }); + } + }); + + sb.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + serverConnectedChannelRef.set(ch); + ch.pipeline().addLast(new SimpleChannelInboundHandler() { + @Override + protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) { + ctx.writeAndFlush(msg.retainedDuplicate()); + serverReadLatch.countDown(); + } + }); + serverReadyLatch.countDown(); + } + }); + + cb.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(new SimpleChannelInboundHandler() { + @Override + protected void channelRead0(ChannelHandlerContext ctx, Object msg) { + clientReadLatch.countDown(); + } + }); + } + }); + + serverChannel = sb.bind().sync().channel(); + clientChannel = cb.connect(serverChannel.localAddress()).sync().channel(); + clientChannel.writeAndFlush(clientChannel.alloc().buffer().writeZero(1)).syncUninterruptibly(); + + // The acceptor shouldn't read any data until we call read() below, but give it some time to see if it will. + Thread.sleep(sleepMs); + assertEquals(1, acceptorReadLatch.getCount()); + serverChannel.read(); + serverReadyLatch.await(); + + Channel serverConnectedChannel = serverConnectedChannelRef.get(); + assertNotNull(serverConnectedChannel); + + // Allow some amount of time for the server peer to receive the message (which isn't expected to happen + // until we call read() below). + Thread.sleep(sleepMs); + assertEquals(1, serverReadLatch.getCount()); + serverConnectedChannel.read(); + serverReadLatch.await(); + + // Allow some amount of time for the client to read the echo. + Thread.sleep(sleepMs); + assertEquals(1, clientReadLatch.getCount()); + clientChannel.read(); + clientReadLatch.await(); + } finally { + if (serverChannel != null) { + serverChannel.close().sync(); + } + if (clientChannel != null) { + clientChannel.close().sync(); + } + } + } + + @Test(timeout = 10000) + public void testAutoReadOnDataReadImmediately() throws Throwable { + run(); + } + + public void testAutoReadOnDataReadImmediately(ServerBootstrap sb, Bootstrap cb) throws Throwable { + Channel serverChannel = null; + Channel clientChannel = null; + try { + sb.option(AUTO_READ, true); + sb.childOption(AUTO_READ, true); + cb.option(AUTO_READ, true); + final CountDownLatch serverReadLatch = new CountDownLatch(1); + final CountDownLatch clientReadLatch = new CountDownLatch(1); + + sb.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(new SimpleChannelInboundHandler() { + @Override + protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) { + ctx.writeAndFlush(msg.retainedDuplicate()); + serverReadLatch.countDown(); + } + }); + } + }); + + cb.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(new SimpleChannelInboundHandler() { + @Override + protected void channelRead0(ChannelHandlerContext ctx, Object msg) { + clientReadLatch.countDown(); + } + }); + } + }); + + serverChannel = sb.bind().sync().channel(); + clientChannel = cb.connect(serverChannel.localAddress()).sync().channel(); + clientChannel.writeAndFlush(clientChannel.alloc().buffer().writeZero(1)).syncUninterruptibly(); + serverReadLatch.await(); + clientReadLatch.await(); + } finally { + if (serverChannel != null) { + serverChannel.close().sync(); + } + if (clientChannel != null) { + clientChannel.close().sync(); + } + } + } +} diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketFileRegionTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketFileRegionTest.java index 53deb6c6743a..881acd1bcff3 100644 --- a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketFileRegionTest.java +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketFileRegionTest.java @@ -22,11 +22,13 @@ import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandler; +import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelOption; import io.netty.channel.DefaultFileRegion; import io.netty.channel.FileRegion; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.util.internal.PlatformDependent; +import org.hamcrest.CoreMatchers; import org.junit.Test; import java.io.File; @@ -73,6 +75,11 @@ public void testFileRegionVoidPromiseNotAutoRead() throws Throwable { run(); } + @Test + public void testFileRegionCountLargerThenFile() throws Throwable { + run(); + } + public void testFileRegion(ServerBootstrap sb, Bootstrap cb) throws Throwable { testFileRegion0(sb, cb, false, true, true); } @@ -93,6 +100,34 @@ public void testFileRegionVoidPromiseNotAutoRead(ServerBootstrap sb, Bootstrap c testFileRegion0(sb, cb, true, false, true); } + public void testFileRegionCountLargerThenFile(ServerBootstrap sb, Bootstrap cb) throws Throwable { + File file = File.createTempFile("netty-", ".tmp"); + file.deleteOnExit(); + + final FileOutputStream out = new FileOutputStream(file); + out.write(data); + out.close(); + + sb.childHandler(new SimpleChannelInboundHandler() { + @Override + protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) { + // Just drop the message. + } + }); + cb.handler(new ChannelInboundHandlerAdapter()); + + Channel sc = sb.bind().sync().channel(); + Channel cc = cb.connect(sc.localAddress()).sync().channel(); + + // Request file region which is bigger then the underlying file. + FileRegion region = new DefaultFileRegion( + new FileInputStream(file).getChannel(), 0, data.length + 1024); + + assertThat(cc.writeAndFlush(region).await().cause(), CoreMatchers.instanceOf(IOException.class)); + cc.close().sync(); + sc.close().sync(); + } + private static void testFileRegion0( ServerBootstrap sb, Bootstrap cb, boolean voidPromise, final boolean autoRead, boolean defaultFileRegion) throws Throwable { diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketHalfClosedTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketHalfClosedTest.java index 3f99cd5b2c60..70bf82fc412b 100644 --- a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketHalfClosedTest.java +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketHalfClosedTest.java @@ -34,6 +34,8 @@ import io.netty.channel.socket.ChannelOutputShutdownEvent; import io.netty.channel.socket.DuplexChannel; import io.netty.util.UncheckedBooleanSupplier; +import io.netty.util.internal.PlatformDependent; +import org.junit.Assume; import org.junit.Test; import java.util.concurrent.CountDownLatch; @@ -229,6 +231,8 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { @Test public void testAutoCloseFalseDoesShutdownOutput() throws Throwable { + // This test only works on Linux / BSD / MacOS as we assume some semantics that are not true for Windows. + Assume.assumeFalse(PlatformDependent.isWindows()); run(); } diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketRstTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketRstTest.java index 88cef1236a67..4cde8b023bb3 100644 --- a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketRstTest.java +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketRstTest.java @@ -22,8 +22,6 @@ import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOption; -import io.netty.channel.nio.NioEventLoopGroup; -import io.netty.util.internal.PlatformDependent; import org.junit.Test; import java.io.IOException; @@ -91,13 +89,6 @@ public void channelInactive(ChannelHandlerContext ctx) { // Verify the client received a RST. Throwable cause = throwableRef.get(); - if (PlatformDependent.javaVersion() >= 11 && sb.config().group() instanceof NioEventLoopGroup) { - // In Java11 calling SocketChannel.close() will also call shutdown(..,SHUT_WR) before actual closing the - // fd which means we may not see the ECONNRESET at all :( - if (cause == null) { - return; - } - } assertTrue("actual [type, message]: [" + cause.getClass() + ", " + cause.getMessage() + "]", cause instanceof IOException); diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslClientRenegotiateTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslClientRenegotiateTest.java index 855f1914096e..4574ca70951e 100644 --- a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslClientRenegotiateTest.java +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslClientRenegotiateTest.java @@ -18,6 +18,7 @@ import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.netty.channel.Channel; import io.netty.channel.ChannelHandler.Sharable; import io.netty.channel.ChannelHandlerContext; @@ -46,6 +47,9 @@ import java.util.ArrayList; import java.util.Collection; import java.util.List; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicReference; import javax.net.ssl.SSLHandshakeException; @@ -73,7 +77,7 @@ public class SocketSslClientRenegotiateTest extends AbstractSocketTest { KEY_FILE = ssc.privateKey(); } - @Parameters(name = "{index}: serverEngine = {0}, clientEngine = {1}") + @Parameters(name = "{index}: serverEngine = {0}, clientEngine = {1}, delegate = {2}") public static Collection data() throws Exception { List serverContexts = new ArrayList(); List clientContexts = new ArrayList(); @@ -91,7 +95,8 @@ public static Collection data() throws Exception { for (SslContext sc: serverContexts) { for (SslContext cc: clientContexts) { for (int i = 0; i < 32; i++) { - params.add(new Object[] { sc, cc}); + params.add(new Object[] { sc, cc, true}); + params.add(new Object[] { sc, cc, false}); } } } @@ -101,6 +106,7 @@ public static Collection data() throws Exception { private final SslContext serverCtx; private final SslContext clientCtx; + private final boolean delegate; private final AtomicReference clientException = new AtomicReference(); private final AtomicReference serverException = new AtomicReference(); @@ -116,9 +122,10 @@ public static Collection data() throws Exception { private final TestHandler serverHandler = new TestHandler(serverException); public SocketSslClientRenegotiateTest( - SslContext serverCtx, SslContext clientCtx) { + SslContext serverCtx, SslContext clientCtx, boolean delegate) { this.serverCtx = serverCtx; this.clientCtx = clientCtx; + this.delegate = delegate; } @Test(timeout = 30000) @@ -129,55 +136,74 @@ public void testSslRenegotiationRejected() throws Throwable { run(); } + private static SslHandler newSslHandler(SslContext sslCtx, ByteBufAllocator allocator, Executor executor) { + if (executor == null) { + return sslCtx.newHandler(allocator); + } else { + return sslCtx.newHandler(allocator, executor); + } + } + public void testSslRenegotiationRejected(ServerBootstrap sb, Bootstrap cb) throws Throwable { reset(); - sb.childHandler(new ChannelInitializer() { - @Override - @SuppressWarnings("deprecation") - public void initChannel(Channel sch) throws Exception { - serverChannel = sch; - serverSslHandler = serverCtx.newHandler(sch.alloc()); + final ExecutorService executorService = delegate ? Executors.newCachedThreadPool() : null; - sch.pipeline().addLast("ssl", serverSslHandler); - sch.pipeline().addLast("handler", serverHandler); + try { + sb.childHandler(new ChannelInitializer() { + @Override + @SuppressWarnings("deprecation") + public void initChannel(Channel sch) throws Exception { + serverChannel = sch; + serverSslHandler = newSslHandler(serverCtx, sch.alloc(), executorService); + // As we test renegotiation we should use a protocol that support it. + serverSslHandler.engine().setEnabledProtocols(new String[]{"TLSv1.2"}); + sch.pipeline().addLast("ssl", serverSslHandler); + sch.pipeline().addLast("handler", serverHandler); + } + }); + + cb.handler(new ChannelInitializer() { + @Override + @SuppressWarnings("deprecation") + public void initChannel(Channel sch) throws Exception { + clientChannel = sch; + clientSslHandler = newSslHandler(clientCtx, sch.alloc(), executorService); + // As we test renegotiation we should use a protocol that support it. + clientSslHandler.engine().setEnabledProtocols(new String[]{"TLSv1.2"}); + sch.pipeline().addLast("ssl", clientSslHandler); + sch.pipeline().addLast("handler", clientHandler); + } + }); + + Channel sc = sb.bind().sync().channel(); + cb.connect(sc.localAddress()).sync(); + + Future clientHandshakeFuture = clientSslHandler.handshakeFuture(); + clientHandshakeFuture.sync(); + + String renegotiation = clientSslHandler.engine().getEnabledCipherSuites()[0]; + // Use the first previous enabled ciphersuite and try to renegotiate. + clientSslHandler.engine().setEnabledCipherSuites(new String[]{renegotiation}); + clientSslHandler.renegotiate().await(); + serverChannel.close().awaitUninterruptibly(); + clientChannel.close().awaitUninterruptibly(); + sc.close().awaitUninterruptibly(); + try { + if (serverException.get() != null) { + throw serverException.get(); + } + fail(); + } catch (DecoderException e) { + assertTrue(e.getCause() instanceof SSLHandshakeException); } - }); - - cb.handler(new ChannelInitializer() { - @Override - @SuppressWarnings("deprecation") - public void initChannel(Channel sch) throws Exception { - clientChannel = sch; - clientSslHandler = clientCtx.newHandler(sch.alloc()); - - sch.pipeline().addLast("ssl", clientSslHandler); - sch.pipeline().addLast("handler", clientHandler); + if (clientException.get() != null) { + throw clientException.get(); } - }); - - Channel sc = sb.bind().sync().channel(); - cb.connect(sc.localAddress()).sync(); - - Future clientHandshakeFuture = clientSslHandler.handshakeFuture(); - clientHandshakeFuture.sync(); - - String renegotiation = clientSslHandler.engine().getSupportedCipherSuites()[0]; - clientSslHandler.engine().setEnabledCipherSuites(new String[] { renegotiation }); - clientSslHandler.renegotiate().await(); - serverChannel.close().awaitUninterruptibly(); - clientChannel.close().awaitUninterruptibly(); - sc.close().awaitUninterruptibly(); - try { - if (serverException.get() != null) { - throw serverException.get(); + } finally { + if (executorService != null) { + executorService.shutdown(); } - fail(); - } catch (DecoderException e) { - assertTrue(e.getCause() instanceof SSLHandshakeException); - } - if (clientException.get() != null) { - throw clientException.get(); } } diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslEchoTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslEchoTest.java index 7c94b5a6710c..4cdeae98beb8 100644 --- a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslEchoTest.java +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslEchoTest.java @@ -123,17 +123,33 @@ public String toString() { "autoRead = {5}, useChunkedWriteHandler = {6}, useCompositeByteBuf = {7}") public static Collection data() throws Exception { List serverContexts = new ArrayList(); - serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK).build()); + serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE) + .sslProvider(SslProvider.JDK) + // As we test renegotiation we should use a protocol that support it. + .protocols("TLSv1.2") + .build()); List clientContexts = new ArrayList(); - clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.JDK).trustManager(CERT_FILE).build()); + clientContexts.add(SslContextBuilder.forClient() + .sslProvider(SslProvider.JDK) + .trustManager(CERT_FILE) + // As we test renegotiation we should use a protocol that support it. + .protocols("TLSv1.2") + .build()); boolean hasOpenSsl = OpenSsl.isAvailable(); if (hasOpenSsl) { serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE) - .sslProvider(SslProvider.OPENSSL).build()); - clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL) - .trustManager(CERT_FILE).build()); + .sslProvider(SslProvider.OPENSSL) + // As we test renegotiation we should use a protocol that support it. + .protocols("TLSv1.2") + .build()); + clientContexts.add(SslContextBuilder.forClient() + .sslProvider(SslProvider.OPENSSL) + .trustManager(CERT_FILE) + // As we test renegotiation we should use a protocol that support it. + .protocols("TLSv1.2") + .build()); } else { logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause()); } diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslGreetingTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslGreetingTest.java index d7db90be8b8a..b3f0e46f45d5 100644 --- a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslGreetingTest.java +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslGreetingTest.java @@ -18,6 +18,7 @@ import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInitializer; @@ -48,6 +49,9 @@ import java.util.Collection; import java.util.List; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicReference; import static org.junit.Assert.assertEquals; @@ -74,7 +78,7 @@ public class SocketSslGreetingTest extends AbstractSocketTest { KEY_FILE = ssc.privateKey(); } - @Parameters(name = "{index}: serverEngine = {0}, clientEngine = {1}") + @Parameters(name = "{index}: serverEngine = {0}, clientEngine = {1}, delegate = {2}") public static Collection data() throws Exception { List serverContexts = new ArrayList(); serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK).build()); @@ -95,7 +99,8 @@ public static Collection data() throws Exception { List params = new ArrayList(); for (SslContext sc: serverContexts) { for (SslContext cc: clientContexts) { - params.add(new Object[] { sc, cc }); + params.add(new Object[] { sc, cc, true }); + params.add(new Object[] { sc, cc, false }); } } return params; @@ -103,10 +108,20 @@ public static Collection data() throws Exception { private final SslContext serverCtx; private final SslContext clientCtx; + private final boolean delegate; - public SocketSslGreetingTest(SslContext serverCtx, SslContext clientCtx) { + public SocketSslGreetingTest(SslContext serverCtx, SslContext clientCtx, boolean delegate) { this.serverCtx = serverCtx; this.clientCtx = clientCtx; + this.delegate = delegate; + } + + private static SslHandler newSslHandler(SslContext sslCtx, ByteBufAllocator allocator, Executor executor) { + if (executor == null) { + return sslCtx.newHandler(allocator); + } else { + return sslCtx.newHandler(allocator, executor); + } } // Test for https://github.com/netty/netty/pull/2437 @@ -119,46 +134,53 @@ public void testSslGreeting(ServerBootstrap sb, Bootstrap cb) throws Throwable { final ServerHandler sh = new ServerHandler(); final ClientHandler ch = new ClientHandler(); - sb.childHandler(new ChannelInitializer() { - @Override - public void initChannel(Channel sch) throws Exception { - ChannelPipeline p = sch.pipeline(); - p.addLast(serverCtx.newHandler(sch.alloc())); - p.addLast(new LoggingHandler(LOG_LEVEL)); - p.addLast(sh); - } - }); - - cb.handler(new ChannelInitializer() { - @Override - public void initChannel(Channel sch) throws Exception { - ChannelPipeline p = sch.pipeline(); - p.addLast(clientCtx.newHandler(sch.alloc())); - p.addLast(new LoggingHandler(LOG_LEVEL)); - p.addLast(ch); - } - }); + final ExecutorService executorService = delegate ? Executors.newCachedThreadPool() : null; + try { + sb.childHandler(new ChannelInitializer() { + @Override + public void initChannel(Channel sch) throws Exception { + ChannelPipeline p = sch.pipeline(); + p.addLast(newSslHandler(serverCtx, sch.alloc(), executorService)); + p.addLast(new LoggingHandler(LOG_LEVEL)); + p.addLast(sh); + } + }); + + cb.handler(new ChannelInitializer() { + @Override + public void initChannel(Channel sch) throws Exception { + ChannelPipeline p = sch.pipeline(); + p.addLast(newSslHandler(clientCtx, sch.alloc(), executorService)); + p.addLast(new LoggingHandler(LOG_LEVEL)); + p.addLast(ch); + } + }); - Channel sc = sb.bind().sync().channel(); - Channel cc = cb.connect(sc.localAddress()).sync().channel(); + Channel sc = sb.bind().sync().channel(); + Channel cc = cb.connect(sc.localAddress()).sync().channel(); - ch.latch.await(); + ch.latch.await(); - sh.channel.close().awaitUninterruptibly(); - cc.close().awaitUninterruptibly(); - sc.close().awaitUninterruptibly(); + sh.channel.close().awaitUninterruptibly(); + cc.close().awaitUninterruptibly(); + sc.close().awaitUninterruptibly(); - if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) { - throw sh.exception.get(); - } - if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) { - throw ch.exception.get(); - } - if (sh.exception.get() != null) { - throw sh.exception.get(); - } - if (ch.exception.get() != null) { - throw ch.exception.get(); + if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) { + throw sh.exception.get(); + } + if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) { + throw ch.exception.get(); + } + if (sh.exception.get() != null) { + throw sh.exception.get(); + } + if (ch.exception.get() != null) { + throw ch.exception.get(); + } + } finally { + if (executorService != null) { + executorService.shutdown(); + } } } diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslSessionReuseTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslSessionReuseTest.java index 4071c8461049..5d0fd0a5e428 100644 --- a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslSessionReuseTest.java +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslSessionReuseTest.java @@ -98,7 +98,7 @@ public void testSslSessionReuse() throws Throwable { public void testSslSessionReuse(ServerBootstrap sb, Bootstrap cb) throws Throwable { final ReadAndDiscardHandler sh = new ReadAndDiscardHandler(true, true); final ReadAndDiscardHandler ch = new ReadAndDiscardHandler(false, true); - final String[] protocols = new String[]{ "TLSv1", "TLSv1.1", "TLSv1.2" }; + final String[] protocols = { "TLSv1", "TLSv1.1", "TLSv1.2" }; sb.childHandler(new ChannelInitializer() { @Override diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/udt/UDTClientServerConnectionTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/udt/UDTClientServerConnectionTest.java index c8786434aeb0..c6af2124e81f 100644 --- a/testsuite/src/main/java/io/netty/testsuite/transport/udt/UDTClientServerConnectionTest.java +++ b/testsuite/src/main/java/io/netty/testsuite/transport/udt/UDTClientServerConnectionTest.java @@ -35,6 +35,8 @@ import io.netty.util.NetUtil; import io.netty.util.concurrent.DefaultThreadFactory; import io.netty.util.concurrent.GlobalEventExecutor; +import io.netty.util.internal.PlatformDependent; +import org.junit.Assume; import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -338,6 +340,8 @@ public void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception */ @Test public void connection() throws Exception { + Assume.assumeFalse("Not supported on J9 JVM", PlatformDependent.isJ9Jvm()); + log.info("Starting server."); // Using LOCALHOST4 as UDT transport does not support IPV6 :( final Server server = new Server(new InetSocketAddress(NetUtil.LOCALHOST4, 0)); diff --git a/transport-native-epoll/pom.xml b/transport-native-epoll/pom.xml index 6022fa2e5218..32f2771168c3 100644 --- a/transport-native-epoll/pom.xml +++ b/transport-native-epoll/pom.xml @@ -19,7 +19,7 @@ io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-transport-native-epoll @@ -161,7 +161,6 @@ generate build - compile @@ -265,7 +264,7 @@ ${linux.sendmmsg.support}${glibc.sendmmsg.support} .*IO_NETTY_SENDMSSG_NOT_FOUND.* - CFLAGS=-O3 -DIO_NETTY_SENDMMSG_NOT_FOUND -Werror -fno-omit-frame-pointer -Wunused-variable -I${unix.common.include.unpacked.dir} + CFLAGS=-O3 -DIO_NETTY_SENDMMSG_NOT_FOUND -Werror -fno-omit-frame-pointer -Wunused-variable -fvisibility=hidden -I${unix.common.include.unpacked.dir} false @@ -281,7 +280,7 @@ ${jni.compiler.args.cflags} ^((?!CFLAGS=).)*$ - CFLAGS=-O3 -Werror -fno-omit-frame-pointer -Wunused-variable -I${unix.common.include.unpacked.dir} + CFLAGS=-O3 -Werror -fno-omit-frame-pointer -Wunused-variable -fvisibility=hidden -I${unix.common.include.unpacked.dir} false @@ -319,12 +318,12 @@ io.netty - netty-transport-native-unix-common + netty-transport ${project.version} io.netty - netty-transport + netty-transport-native-unix-common ${project.version} diff --git a/transport-native-epoll/src/main/c/netty_epoll_linuxsocket.c b/transport-native-epoll/src/main/c/netty_epoll_linuxsocket.c index d6e3ead94fee..05d889f175c0 100644 --- a/transport-native-epoll/src/main/c/netty_epoll_linuxsocket.c +++ b/transport-native-epoll/src/main/c/netty_epoll_linuxsocket.c @@ -50,6 +50,11 @@ #define TCP_NOTSENT_LOWAT 25 #endif +// SO_BUSY_POLL is defined in linux 3.11. We define this here so older kernels can compile. +#ifndef SO_BUSY_POLL +#define SO_BUSY_POLL 46 +#endif + static jclass peerCredentialsClass = NULL; static jmethodID peerCredentialsMethodId = NULL; @@ -111,6 +116,10 @@ static void netty_epoll_linuxsocket_setIpRecvOrigDestAddr(JNIEnv* env, jclass cl netty_unix_socket_setOption(env, fd, IPPROTO_IP, IP_RECVORIGDSTADDR, &optval, sizeof(optval)); } +static void netty_epoll_linuxsocket_setSoBusyPoll(JNIEnv* env, jclass clazz, jint fd, jint optval) { + netty_unix_socket_setOption(env, fd, SOL_SOCKET, SO_BUSY_POLL, &optval, sizeof(optval)); +} + static void netty_epoll_linuxsocket_setTcpMd5Sig(JNIEnv* env, jclass clazz, jint fd, jbyteArray address, jint scopeId, jbyteArray key) { struct sockaddr_storage addr; socklen_t addrSize; @@ -256,6 +265,14 @@ static jint netty_epoll_linuxsocket_isTcpCork(JNIEnv* env, jclass clazz, jint fd return optval; } +static jint netty_epoll_linuxsocket_getSoBusyPoll(JNIEnv* env, jclass clazz, jint fd) { + int optval; + if (netty_unix_socket_getOption(env, fd, SOL_SOCKET, SO_BUSY_POLL, &optval, sizeof(optval)) == -1) { + return -1; + } + return optval; +} + static jint netty_epoll_linuxsocket_getTcpDeferAccept(JNIEnv* env, jclass clazz, jint fd) { int optval; if (netty_unix_socket_getOption(env, fd, IPPROTO_TCP, TCP_DEFER_ACCEPT, &optval, sizeof(optval)) == -1) { @@ -341,10 +358,12 @@ static jlong netty_epoll_linuxsocket_sendFile(JNIEnv* env, jclass clazz, jint fd // JNI Method Registration Table Begin static const JNINativeMethod fixed_method_table[] = { { "setTcpCork", "(II)V", (void *) netty_epoll_linuxsocket_setTcpCork }, + { "setSoBusyPoll", "(II)V", (void *) netty_epoll_linuxsocket_setSoBusyPoll }, { "setTcpQuickAck", "(II)V", (void *) netty_epoll_linuxsocket_setTcpQuickAck }, { "setTcpDeferAccept", "(II)V", (void *) netty_epoll_linuxsocket_setTcpDeferAccept }, { "setTcpNotSentLowAt", "(II)V", (void *) netty_epoll_linuxsocket_setTcpNotSentLowAt }, { "isTcpCork", "(I)I", (void *) netty_epoll_linuxsocket_isTcpCork }, + { "getSoBusyPoll", "(I)I", (void *) netty_epoll_linuxsocket_getSoBusyPoll }, { "getTcpDeferAccept", "(I)I", (void *) netty_epoll_linuxsocket_getTcpDeferAccept }, { "getTcpNotSentLowAt", "(I)I", (void *) netty_epoll_linuxsocket_getTcpNotSentLowAt }, { "isTcpQuickAck", "(I)I", (void *) netty_epoll_linuxsocket_isTcpQuickAck }, diff --git a/transport-native-epoll/src/main/c/netty_epoll_native.c b/transport-native-epoll/src/main/c/netty_epoll_native.c index 90e6f189c341..eb9a2eeac87d 100644 --- a/transport-native-epoll/src/main/c/netty_epoll_native.c +++ b/transport-native-epoll/src/main/c/netty_epoll_native.c @@ -41,6 +41,7 @@ #include #include "netty_epoll_linuxsocket.h" +#include "netty_unix_buffer.h" #include "netty_unix_errors.h" #include "netty_unix_filedescriptor.h" #include "netty_unix_jni.h" @@ -67,11 +68,11 @@ struct mmsghdr { #endif // Those are initialized in the init(...) method and cached for performance reasons -jfieldID packetAddrFieldId = NULL; -jfieldID packetScopeIdFieldId = NULL; -jfieldID packetPortFieldId = NULL; -jfieldID packetMemoryAddressFieldId = NULL; -jfieldID packetCountFieldId = NULL; +static jfieldID packetAddrFieldId = NULL; +static jfieldID packetScopeIdFieldId = NULL; +static jfieldID packetPortFieldId = NULL; +static jfieldID packetMemoryAddressFieldId = NULL; +static jfieldID packetCountFieldId = NULL; // util methods static int getSysctlValue(const char * property, int* returnValue) { @@ -188,13 +189,17 @@ static jint netty_epoll_native_epollWait0(JNIEnv* env, jclass clazz, jint efd, j } } while((err = errno) == EINTR); } else { - struct itimerspec ts; - memset(&ts.it_interval, 0, sizeof(struct timespec)); - ts.it_value.tv_sec = tvSec; - ts.it_value.tv_nsec = tvNsec; - if (timerfd_settime(timerFd, 0, &ts, NULL) < 0) { - netty_unix_errors_throwChannelExceptionErrorNo(env, "timerfd_settime() failed: ", errno); - return -1; + // only reschedule the timer if there is a newer event. + // -1 is a special value used by EpollEventLoop. + if (tvSec != ((jint) -1) && tvNsec != ((jint) -1)) { + struct itimerspec ts; + memset(&ts.it_interval, 0, sizeof(struct timespec)); + ts.it_value.tv_sec = tvSec; + ts.it_value.tv_nsec = tvNsec; + if (timerfd_settime(timerFd, 0, &ts, NULL) < 0) { + netty_unix_errors_throwChannelExceptionErrorNo(env, "timerfd_settime() failed: ", errno); + return -1; + } } do { result = epoll_wait(efd, ev, len, -1); @@ -215,6 +220,33 @@ static jint netty_epoll_native_epollWait0(JNIEnv* env, jclass clazz, jint efd, j return -err; } +static inline void cpu_relax() { +#if defined(__x86_64__) + asm volatile("pause\n": : :"memory"); +#endif +} + +static jint netty_epoll_native_epollBusyWait0(JNIEnv* env, jclass clazz, jint efd, jlong address, jint len) { + struct epoll_event *ev = (struct epoll_event*) (intptr_t) address; + int result, err; + + // Zeros = poll (aka return immediately). + do { + result = epoll_wait(efd, ev, len, 0); + if (result == 0) { + // Since we're always polling epoll_wait with no timeout, + // signal CPU that we're in a busy loop + cpu_relax(); + } + + if (result >= 0) { + return result; + } + } while((err = errno) == EINTR); + + return -err; +} + static jint netty_epoll_native_epollCtlAdd0(JNIEnv* env, jclass clazz, jint efd, jint fd, jint flags) { int res = epollCtl(env, efd, EPOLL_CTL_ADD, fd, flags); if (res < 0) { @@ -624,6 +656,7 @@ static const JNINativeMethod fixed_method_table[] = { { "timerFdRead", "(I)V", (void *) netty_epoll_native_timerFdRead }, { "epollCreate", "()I", (void *) netty_epoll_native_epollCreate }, { "epollWait0", "(IJIIII)I", (void *) netty_epoll_native_epollWait0 }, + { "epollBusyWait0", "(IJI)I", (void *) netty_epoll_native_epollBusyWait0 }, { "epollCtlAdd0", "(III)I", (void *) netty_epoll_native_epollCtlAdd0 }, { "epollCtlMod0", "(III)I", (void *) netty_epoll_native_epollCtlMod0 }, { "epollCtlDel0", "(II)I", (void *) netty_epoll_native_epollCtlDel0 }, @@ -665,13 +698,20 @@ static void freeDynamicMethodsTable(JNINativeMethod* dynamicMethods) { // JNI Method Registration Table End static jint netty_epoll_native_JNI_OnLoad(JNIEnv* env, const char* packagePrefix) { + int limitsOnLoadCalled = 0; + int errorsOnLoadCalled = 0; + int filedescriptorOnLoadCalled = 0; + int socketOnLoadCalled = 0; + int bufferOnLoadCalled = 0; + int linuxsocketOnLoadCalled = 0; + // We must register the statically referenced methods first! if (netty_unix_util_register_natives(env, packagePrefix, "io/netty/channel/epoll/NativeStaticallyReferencedJniMethods", statically_referenced_fixed_method_table, statically_referenced_fixed_method_table_size) != 0) { - return JNI_ERR; + goto error; } // Register the methods which are not referenced by static member variables JNINativeMethod* dynamicMethods = createDynamicMethodsTable(packagePrefix); @@ -681,26 +721,40 @@ static jint netty_epoll_native_JNI_OnLoad(JNIEnv* env, const char* packagePrefix dynamicMethods, dynamicMethodsTableSize()) != 0) { freeDynamicMethodsTable(dynamicMethods); - return JNI_ERR; + goto error; } freeDynamicMethodsTable(dynamicMethods); dynamicMethods = NULL; // Load all c modules that we depend upon if (netty_unix_limits_JNI_OnLoad(env, packagePrefix) == JNI_ERR) { - return JNI_ERR; + goto error; } + limitsOnLoadCalled = 1; + if (netty_unix_errors_JNI_OnLoad(env, packagePrefix) == JNI_ERR) { - return JNI_ERR; + goto error; } + errorsOnLoadCalled = 1; + if (netty_unix_filedescriptor_JNI_OnLoad(env, packagePrefix) == JNI_ERR) { - return JNI_ERR; + goto error; } + filedescriptorOnLoadCalled = 1; + if (netty_unix_socket_JNI_OnLoad(env, packagePrefix) == JNI_ERR) { - return JNI_ERR; + goto error; + } + socketOnLoadCalled = 1; + + if (netty_unix_buffer_JNI_OnLoad(env, packagePrefix) == JNI_ERR) { + goto error; } + bufferOnLoadCalled = 1; + if (netty_epoll_linuxsocket_JNI_OnLoad(env, packagePrefix) == JNI_ERR) { - return JNI_ERR; + goto error; } + linuxsocketOnLoadCalled = 1; // Initialize this module char* nettyClassName = netty_unix_util_prepend(packagePrefix, "io/netty/channel/epoll/NativeDatagramPacketArray$NativeDatagramPacket"); @@ -709,37 +763,64 @@ static jint netty_epoll_native_JNI_OnLoad(JNIEnv* env, const char* packagePrefix nettyClassName = NULL; if (nativeDatagramPacketCls == NULL) { // pending exception... - return JNI_ERR; + goto error; } packetAddrFieldId = (*env)->GetFieldID(env, nativeDatagramPacketCls, "addr", "[B"); if (packetAddrFieldId == NULL) { netty_unix_errors_throwRuntimeException(env, "failed to get field ID: NativeDatagramPacket.addr"); - return JNI_ERR; + goto error; } packetScopeIdFieldId = (*env)->GetFieldID(env, nativeDatagramPacketCls, "scopeId", "I"); if (packetScopeIdFieldId == NULL) { netty_unix_errors_throwRuntimeException(env, "failed to get field ID: NativeDatagramPacket.scopeId"); - return JNI_ERR; + goto error; } packetPortFieldId = (*env)->GetFieldID(env, nativeDatagramPacketCls, "port", "I"); if (packetPortFieldId == NULL) { netty_unix_errors_throwRuntimeException(env, "failed to get field ID: NativeDatagramPacket.port"); - return JNI_ERR; + goto error; } packetMemoryAddressFieldId = (*env)->GetFieldID(env, nativeDatagramPacketCls, "memoryAddress", "J"); if (packetMemoryAddressFieldId == NULL) { netty_unix_errors_throwRuntimeException(env, "failed to get field ID: NativeDatagramPacket.memoryAddress"); - return JNI_ERR; + goto error; } packetCountFieldId = (*env)->GetFieldID(env, nativeDatagramPacketCls, "count", "I"); if (packetCountFieldId == NULL) { netty_unix_errors_throwRuntimeException(env, "failed to get field ID: NativeDatagramPacket.count"); - return JNI_ERR; + goto error; } return NETTY_JNI_VERSION; + +error: + if (limitsOnLoadCalled == 1) { + netty_unix_limits_JNI_OnUnLoad(env); + } + if (errorsOnLoadCalled == 1) { + netty_unix_errors_JNI_OnUnLoad(env); + } + if (filedescriptorOnLoadCalled == 1) { + netty_unix_filedescriptor_JNI_OnUnLoad(env); + } + if (socketOnLoadCalled == 1) { + netty_unix_socket_JNI_OnUnLoad(env); + } + if (bufferOnLoadCalled == 1) { + netty_unix_buffer_JNI_OnUnLoad(env); + } + if (linuxsocketOnLoadCalled == 1) { + netty_epoll_linuxsocket_JNI_OnUnLoad(env); + } + packetAddrFieldId = NULL; + packetScopeIdFieldId = NULL; + packetPortFieldId = NULL; + packetMemoryAddressFieldId = NULL; + packetCountFieldId = NULL; + + return JNI_ERR; } static void netty_epoll_native_JNI_OnUnLoad(JNIEnv* env) { @@ -747,11 +828,18 @@ static void netty_epoll_native_JNI_OnUnLoad(JNIEnv* env) { netty_unix_errors_JNI_OnUnLoad(env); netty_unix_filedescriptor_JNI_OnUnLoad(env); netty_unix_socket_JNI_OnUnLoad(env); + netty_unix_buffer_JNI_OnUnLoad(env); netty_epoll_linuxsocket_JNI_OnUnLoad(env); + + packetAddrFieldId = NULL; + packetScopeIdFieldId = NULL; + packetPortFieldId = NULL; + packetMemoryAddressFieldId = NULL; + packetCountFieldId = NULL; } // Invoked by the JVM when statically linked -jint JNI_OnLoad_netty_transport_native_epoll(JavaVM* vm, void* reserved) { +static jint JNI_OnLoad_netty_transport_native_epoll0(JavaVM* vm, void* reserved) { JNIEnv* env; if ((*vm)->GetEnv(vm, (void**) &env, NETTY_JNI_VERSION) != JNI_OK) { return JNI_ERR; @@ -782,14 +870,7 @@ jint JNI_OnLoad_netty_transport_native_epoll(JavaVM* vm, void* reserved) { return ret; } -#ifndef NETTY_BUILD_STATIC -JNIEXPORT jint JNI_OnLoad(JavaVM* vm, void* reserved) { - return JNI_OnLoad_netty_transport_native_epoll(vm, reserved); -} -#endif /* NETTY_BUILD_STATIC */ - -// Invoked by the JVM when statically linked -void JNI_OnUnload_netty_transport_native_epoll(JavaVM* vm, void* reserved) { +static void JNI_OnUnload_netty_transport_native_epoll0(JavaVM* vm, void* reserved) { JNIEnv* env; if ((*vm)->GetEnv(vm, (void**) &env, NETTY_JNI_VERSION) != JNI_OK) { // Something is wrong but nothing we can do about this :( @@ -798,8 +879,25 @@ void JNI_OnUnload_netty_transport_native_epoll(JavaVM* vm, void* reserved) { netty_epoll_native_JNI_OnUnLoad(env); } +// We build with -fvisibility=hidden so ensure we mark everything that needs to be visible with JNIEXPORT +// http://mail.openjdk.java.net/pipermail/core-libs-dev/2013-February/014549.html + +// Invoked by the JVM when statically linked +JNIEXPORT jint JNI_OnLoad_netty_transport_native_epoll(JavaVM* vm, void* reserved) { + return JNI_OnLoad_netty_transport_native_epoll0(vm, reserved); +} + +// Invoked by the JVM when statically linked +JNIEXPORT void JNI_OnUnload_netty_transport_native_epoll(JavaVM* vm, void* reserved) { + JNI_OnUnload_netty_transport_native_epoll0(vm, reserved); +} + #ifndef NETTY_BUILD_STATIC +JNIEXPORT jint JNI_OnLoad(JavaVM* vm, void* reserved) { + return JNI_OnLoad_netty_transport_native_epoll0(vm, reserved); +} + JNIEXPORT void JNI_OnUnload(JavaVM* vm, void* reserved) { - JNI_OnUnload_netty_transport_native_epoll(vm, reserved); + JNI_OnUnload_netty_transport_native_epoll0(vm, reserved); } #endif /* NETTY_BUILD_STATIC */ diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollChannel.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollChannel.java index c1af9975cd29..41c7d2cda30b 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollChannel.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollChannel.java @@ -60,7 +60,6 @@ abstract class AbstractEpollChannel extends AbstractChannel implements UnixChann private static final ClosedChannelException DO_CLOSE_CLOSED_CHANNEL_EXCEPTION = ThrowableUtil.unknownStackTrace( new ClosedChannelException(), AbstractEpollChannel.class, "doClose()"); private static final ChannelMetadata METADATA = new ChannelMetadata(false); - private final int readFlag; final LinuxSocket socket; /** * The future of the current connection attempt. If not null, subsequent @@ -79,15 +78,13 @@ abstract class AbstractEpollChannel extends AbstractChannel implements UnixChann protected volatile boolean active; - AbstractEpollChannel(LinuxSocket fd, int flag) { - this(null, fd, flag, false); + AbstractEpollChannel(LinuxSocket fd) { + this(null, fd, false); } - AbstractEpollChannel(Channel parent, LinuxSocket fd, int flag, boolean active) { + AbstractEpollChannel(Channel parent, LinuxSocket fd, boolean active) { super(parent); socket = checkNotNull(fd, "fd"); - readFlag = flag; - flags |= flag; this.active = active; if (active) { // Directly cache the remote and local addresses @@ -97,11 +94,9 @@ abstract class AbstractEpollChannel extends AbstractChannel implements UnixChann } } - AbstractEpollChannel(Channel parent, LinuxSocket fd, int flag, SocketAddress remote) { + AbstractEpollChannel(Channel parent, LinuxSocket fd, SocketAddress remote) { super(parent); socket = checkNotNull(fd, "fd"); - readFlag = flag; - flags |= flag; active = true; // Directly cache the remote and local addresses // See https://github.com/netty/netty/issues/2359 @@ -228,7 +223,7 @@ protected final void doBeginRead() throws Exception { // We must set the read flag here as it is possible the user didn't read in the last read loop, the // executeEpollInReadyRunnable could read nothing, and if the user doesn't explicitly call read they will // never get data after this. - setFlag(readFlag); + setFlag(Native.EPOLLIN); // If EPOLL ET mode is enabled and auto read was toggled off on the last read loop then we may not be notified // again if we didn't consume all the data. So we force a read operation here if there maybe more data. @@ -268,7 +263,7 @@ public void run() { } else { // The EventLoop is not registered atm so just update the flags so the correct value // will be used once the channel is registered - flags &= ~readFlag; + flags &= ~Native.EPOLLIN; } } @@ -393,19 +388,14 @@ public void run() { */ abstract void epollInReady(); - final void epollInBefore() { maybeMoreDataToRead = false; } + final void epollInBefore() { + maybeMoreDataToRead = false; + } final void epollInFinally(ChannelConfig config) { - maybeMoreDataToRead = allocHandle.isEdgeTriggered() && allocHandle.maybeMoreDataToRead(); - // Check if there is a readPending which was not processed yet. - // This could be for two reasons: - // * The user called Channel.read() or ChannelHandlerContext.read() in channelRead(...) method - // * The user called Channel.read() or ChannelHandlerContext.read() in channelReadComplete(...) method - // - // See https://github.com/netty/netty/issues/2254 - if (!readPending && !config.isAutoRead()) { - clearEpollIn(); - } else if (readPending && maybeMoreDataToRead) { + maybeMoreDataToRead = allocHandle.maybeMoreDataToRead(); + + if (allocHandle.isReceivedRdHup() || (readPending && maybeMoreDataToRead)) { // trigger a read again as there may be something left to read and because of epoll ET we // will not get notified again until we read everything from the socket // @@ -414,6 +404,14 @@ final void epollInFinally(ChannelConfig config) { // to false before every read operation to prevent re-entry into epollInReady() we will not read from // the underlying OS again unless the user happens to call read again. executeEpollInReadyRunnable(config); + } else if (!readPending && !config.isAutoRead()) { + // Check if there is a readPending which was not processed yet. + // This could be for two reasons: + // * The user called Channel.read() or ChannelHandlerContext.read() in channelRead(...) method + // * The user called Channel.read() or ChannelHandlerContext.read() in channelReadComplete(...) method + // + // See https://github.com/netty/netty/issues/2254 + clearEpollIn(); } } @@ -534,7 +532,7 @@ protected final void clearEpollIn0() { assert eventLoop().inEventLoop(); try { readPending = false; - clearFlag(readFlag); + clearFlag(Native.EPOLLIN); } catch (IOException e) { // When this happens there is something completely wrong with either the filedescriptor or epoll, // so fire the exception through the pipeline and close the Channel. diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollServerChannel.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollServerChannel.java index ebda9ebfe783..b7ccf37b2e36 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollServerChannel.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollServerChannel.java @@ -39,7 +39,7 @@ protected AbstractEpollServerChannel(int fd) { } AbstractEpollServerChannel(LinuxSocket fd, boolean active) { - super(null, fd, Native.EPOLLIN, active); + super(null, fd, active); } @Override diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java index 7bb9e595bfd9..e2f2c88cb72e 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java @@ -54,6 +54,7 @@ import static io.netty.channel.internal.ChannelUtils.WRITE_STATUS_SNDBUF_FULL; import static io.netty.channel.unix.FileDescriptor.pipe; import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; public abstract class AbstractEpollStreamChannel extends AbstractEpollChannel implements DuplexChannel { private static final ChannelMetadata METADATA = new ChannelMetadata(false, 16); @@ -99,19 +100,19 @@ protected AbstractEpollStreamChannel(int fd) { } AbstractEpollStreamChannel(Channel parent, LinuxSocket fd) { - super(parent, fd, Native.EPOLLIN, true); + super(parent, fd, true); // Add EPOLLRDHUP so we are notified once the remote peer close the connection. flags |= Native.EPOLLRDHUP; } AbstractEpollStreamChannel(Channel parent, LinuxSocket fd, SocketAddress remote) { - super(parent, fd, Native.EPOLLIN, remote); + super(parent, fd, remote); // Add EPOLLRDHUP so we are notified once the remote peer close the connection. flags |= Native.EPOLLRDHUP; } protected AbstractEpollStreamChannel(LinuxSocket fd, boolean active) { - super(null, fd, Native.EPOLLIN, active); + super(null, fd, active); // Add EPOLLRDHUP so we are notified once the remote peer close the connection. flags |= Native.EPOLLRDHUP; } @@ -163,9 +164,7 @@ public final ChannelFuture spliceTo(final AbstractEpollStreamChannel ch, final i if (ch.eventLoop() != eventLoop()) { throw new IllegalArgumentException("EventLoops are not the same."); } - if (len < 0) { - throw new IllegalArgumentException("len: " + len + " (expected: >= 0)"); - } + checkPositiveOrZero(len, "len"); if (ch.config().getEpollMode() != EpollMode.LEVEL_TRIGGERED || config().getEpollMode() != EpollMode.LEVEL_TRIGGERED) { throw new IllegalStateException("spliceTo() supported only when using " + EpollMode.LEVEL_TRIGGERED); @@ -214,12 +213,8 @@ public final ChannelFuture spliceTo(final FileDescriptor ch, final int offset, f */ public final ChannelFuture spliceTo(final FileDescriptor ch, final int offset, final int len, final ChannelPromise promise) { - if (len < 0) { - throw new IllegalArgumentException("len: " + len + " (expected: >= 0)"); - } - if (offset < 0) { - throw new IllegalArgumentException("offset must be >= 0 but was " + offset); - } + checkPositiveOrZero(len, "len"); + checkPositiveOrZero(offset, "offser"); if (config().getEpollMode() != EpollMode.LEVEL_TRIGGERED) { throw new IllegalStateException("spliceTo() supported only when using " + EpollMode.LEVEL_TRIGGERED); } @@ -372,13 +367,13 @@ private int writeBytesMultiple( * */ private int writeDefaultFileRegion(ChannelOutboundBuffer in, DefaultFileRegion region) throws Exception { + final long offset = region.transferred(); final long regionCount = region.count(); - if (region.transferred() >= regionCount) { + if (offset >= regionCount) { in.remove(); return 0; } - final long offset = region.transferred(); final long flushedAmount = socket.sendFile(region, region.position(), offset, regionCount - offset); if (flushedAmount > 0) { in.progress(flushedAmount); @@ -386,6 +381,8 @@ private int writeDefaultFileRegion(ChannelOutboundBuffer in, DefaultFileRegion r in.remove(); } return 1; + } else if (flushedAmount == 0) { + validateFileRegion(region, offset); } return WRITE_STATUS_SNDBUF_FULL; } @@ -513,22 +510,13 @@ protected int doWriteSingle(ChannelOutboundBuffer in) throws Exception { */ private int doWriteMultiple(ChannelOutboundBuffer in) throws Exception { final long maxBytesPerGatheringWrite = config().getMaxBytesPerGatheringWrite(); - if (PlatformDependent.hasUnsafe()) { - IovArray array = ((EpollEventLoop) eventLoop()).cleanArray(); - array.maxBytes(maxBytesPerGatheringWrite); - in.forEachFlushedMessage(array); - - if (array.count() >= 1) { - // TODO: Handle the case where cnt == 1 specially. - return writeBytesMultiple(in, array); - } - } else { - ByteBuffer[] buffers = in.nioBuffers(); - int cnt = in.nioBufferCount(); - if (cnt >= 1) { - // TODO: Handle the case where cnt == 1 specially. - return writeBytesMultiple(in, buffers, cnt, in.nioBufferSize(), maxBytesPerGatheringWrite); - } + IovArray array = ((EpollEventLoop) eventLoop()).cleanIovArray(); + array.maxBytes(maxBytesPerGatheringWrite); + in.forEachFlushedMessage(array); + + if (array.count() >= 1) { + // TODO: Handle the case where cnt == 1 specially. + return writeBytesMultiple(in, array); } // cnt == 0, which means the outbound buffer contained empty buffers only. in.removeBytes(0); diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/Epoll.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/Epoll.java index b1e166ba737a..e4ecf42707b5 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/Epoll.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/Epoll.java @@ -16,7 +16,6 @@ package io.netty.channel.epoll; import io.netty.channel.unix.FileDescriptor; -import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.SystemPropertyUtil; /** @@ -58,15 +57,7 @@ public final class Epoll { } } - if (cause != null) { - UNAVAILABILITY_CAUSE = cause; - } else { - UNAVAILABILITY_CAUSE = PlatformDependent.hasUnsafe() - ? null - : new IllegalStateException( - "sun.misc.Unsafe not available", - PlatformDependent.getUnsafeUnavailabilityCause()); - } + UNAVAILABILITY_CAUSE = cause; } /** diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollChannelConfig.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollChannelConfig.java index c0cceb255afc..2d2610c0e27b 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollChannelConfig.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollChannelConfig.java @@ -29,12 +29,10 @@ import static io.netty.channel.unix.Limits.SSIZE_MAX; public class EpollChannelConfig extends DefaultChannelConfig { - final AbstractEpollChannel channel; private volatile long maxBytesPerGatheringWrite = SSIZE_MAX; EpollChannelConfig(AbstractEpollChannel channel) { super(channel); - this.channel = channel; } @Override @@ -136,7 +134,7 @@ public EpollChannelConfig setMessageSizeEstimator(MessageSizeEstimator estimator * {@link EpollMode#LEVEL_TRIGGERED}. */ public EpollMode getEpollMode() { - return channel.isFlagSet(Native.EPOLLET) + return ((AbstractEpollChannel) channel).isFlagSet(Native.EPOLLET) ? EpollMode.EDGE_TRIGGERED : EpollMode.LEVEL_TRIGGERED; } @@ -156,11 +154,11 @@ public EpollChannelConfig setEpollMode(EpollMode mode) { switch (mode) { case EDGE_TRIGGERED: checkChannelNotRegistered(); - channel.setFlag(Native.EPOLLET); + ((AbstractEpollChannel) channel).setFlag(Native.EPOLLET); break; case LEVEL_TRIGGERED: checkChannelNotRegistered(); - channel.clearFlag(Native.EPOLLET); + ((AbstractEpollChannel) channel).clearFlag(Native.EPOLLET); break; default: throw new Error(); @@ -179,7 +177,7 @@ private void checkChannelNotRegistered() { @Override protected final void autoReadCleared() { - channel.clearEpollIn(); + ((AbstractEpollChannel) channel).clearEpollIn(); } final void setMaxBytesPerGatheringWrite(long maxBytesPerGatheringWrite) { diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollChannelOption.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollChannelOption.java index d03e3e723678..1f5127c5dc3e 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollChannelOption.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollChannelOption.java @@ -38,6 +38,7 @@ public final class EpollChannelOption extends UnixChannelOption { public static final ChannelOption TCP_DEFER_ACCEPT = ChannelOption.valueOf(EpollChannelOption.class, "TCP_DEFER_ACCEPT"); public static final ChannelOption TCP_QUICKACK = valueOf(EpollChannelOption.class, "TCP_QUICKACK"); + public static final ChannelOption SO_BUSY_POLL = valueOf(EpollChannelOption.class, "SO_BUSY_POLL"); public static final ChannelOption EPOLL_MODE = ChannelOption.valueOf(EpollChannelOption.class, "EPOLL_MODE"); diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannel.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannel.java index a19191d5967e..7679404fdd78 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannel.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannel.java @@ -28,6 +28,7 @@ import io.netty.channel.socket.DatagramChannelConfig; import io.netty.channel.socket.DatagramPacket; import io.netty.channel.unix.DatagramSocketAddress; +import io.netty.channel.unix.Errors; import io.netty.channel.unix.IovArray; import io.netty.channel.unix.UnixChannelUtil; import io.netty.util.internal.StringUtil; @@ -36,6 +37,7 @@ import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.NetworkInterface; +import java.net.PortUnreachableException; import java.net.SocketAddress; import java.net.SocketException; import java.nio.ByteBuffer; @@ -59,7 +61,7 @@ public final class EpollDatagramChannel extends AbstractEpollChannel implements private volatile boolean connected; public EpollDatagramChannel() { - super(newSocketDgram(), Native.EPOLLIN); + super(newSocketDgram()); config = new EpollDatagramChannelConfig(this); } @@ -68,7 +70,7 @@ public EpollDatagramChannel(int fd) { } EpollDatagramChannel(LinuxSocket fd) { - super(null, fd, Native.EPOLLIN, true); + super(null, fd, true); config = new EpollDatagramChannelConfig(this); } @@ -270,7 +272,8 @@ protected void doWrite(ChannelOutboundBuffer in) throws Exception { try { // Check if sendmmsg(...) is supported which is only the case for GLIBC 2.14+ if (Native.IS_SUPPORTING_SENDMMSG && in.size() > 1) { - NativeDatagramPacketArray array = NativeDatagramPacketArray.getInstance(in); + NativeDatagramPacketArray array = ((EpollEventLoop) eventLoop()).cleanDatagramPacketArray(); + in.forEachFlushedMessage(array); int cnt = array.count(); if (cnt >= 1) { @@ -347,7 +350,7 @@ private boolean doWriteMessage(Object msg) throws Exception { remoteAddress.getAddress(), remoteAddress.getPort()); } } else if (data.nioBufferCount() > 1) { - IovArray array = ((EpollEventLoop) eventLoop()).cleanArray(); + IovArray array = ((EpollEventLoop) eventLoop()).cleanIovArray(); array.add(data); int cnt = array.count(); assert cnt != 0; @@ -448,46 +451,72 @@ void epollInReady() { Throwable exception = null; try { - ByteBuf data = null; + ByteBuf byteBuf = null; try { + boolean connected = isConnected(); do { - data = allocHandle.allocate(allocator); - allocHandle.attemptedBytesRead(data.writableBytes()); - final DatagramSocketAddress remoteAddress; - if (data.hasMemoryAddress()) { - // has a memory address so use optimized call - remoteAddress = socket.recvFromAddress(data.memoryAddress(), data.writerIndex(), - data.capacity()); + byteBuf = allocHandle.allocate(allocator); + allocHandle.attemptedBytesRead(byteBuf.writableBytes()); + + final DatagramPacket packet; + if (connected) { + try { + allocHandle.lastBytesRead(doReadBytes(byteBuf)); + } catch (Errors.NativeIoException e) { + // We need to correctly translate connect errors to match NIO behaviour. + if (e.expectedErr() == Errors.ERROR_ECONNREFUSED_NEGATIVE) { + PortUnreachableException error = new PortUnreachableException(e.getMessage()); + error.initCause(e); + throw error; + } + throw e; + } + if (allocHandle.lastBytesRead() <= 0) { + // nothing was read, release the buffer. + byteBuf.release(); + byteBuf = null; + break; + } + packet = new DatagramPacket( + byteBuf, (InetSocketAddress) localAddress(), (InetSocketAddress) remoteAddress()); } else { - ByteBuffer nioData = data.internalNioBuffer(data.writerIndex(), data.writableBytes()); - remoteAddress = socket.recvFrom(nioData, nioData.position(), nioData.limit()); - } + final DatagramSocketAddress remoteAddress; + if (byteBuf.hasMemoryAddress()) { + // has a memory address so use optimized call + remoteAddress = socket.recvFromAddress(byteBuf.memoryAddress(), byteBuf.writerIndex(), + byteBuf.capacity()); + } else { + ByteBuffer nioData = byteBuf.internalNioBuffer( + byteBuf.writerIndex(), byteBuf.writableBytes()); + remoteAddress = socket.recvFrom(nioData, nioData.position(), nioData.limit()); + } - if (remoteAddress == null) { - allocHandle.lastBytesRead(-1); - data.release(); - data = null; - break; - } + if (remoteAddress == null) { + allocHandle.lastBytesRead(-1); + byteBuf.release(); + byteBuf = null; + break; + } + InetSocketAddress localAddress = remoteAddress.localAddress(); + if (localAddress == null) { + localAddress = (InetSocketAddress) localAddress(); + } + allocHandle.lastBytesRead(remoteAddress.receivedAmount()); + byteBuf.writerIndex(byteBuf.writerIndex() + allocHandle.lastBytesRead()); - InetSocketAddress localAddress = remoteAddress.localAddress(); - if (localAddress == null) { - localAddress = (InetSocketAddress) localAddress(); + packet = new DatagramPacket(byteBuf, localAddress, remoteAddress); } allocHandle.incMessagesRead(1); - allocHandle.lastBytesRead(remoteAddress.receivedAmount()); - data.writerIndex(data.writerIndex() + allocHandle.lastBytesRead()); readPending = false; - pipeline.fireChannelRead( - new DatagramPacket(data, localAddress, remoteAddress)); + pipeline.fireChannelRead(packet); - data = null; + byteBuf = null; } while (allocHandle.continueReading()); } catch (Throwable t) { - if (data != null) { - data.release(); + if (byteBuf != null) { + byteBuf.release(); } exception = t; } diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannelConfig.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannelConfig.java index fbc44c1bcc16..778b555fa195 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannelConfig.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannelConfig.java @@ -31,12 +31,10 @@ public final class EpollDatagramChannelConfig extends EpollChannelConfig implements DatagramChannelConfig { private static final RecvByteBufAllocator DEFAULT_RCVBUF_ALLOCATOR = new FixedRecvByteBufAllocator(2048); - private final EpollDatagramChannel datagramChannel; private boolean activeOnOpen; EpollDatagramChannelConfig(EpollDatagramChannel channel) { super(channel); - datagramChannel = channel; setRecvByteBufAllocator(DEFAULT_RCVBUF_ALLOCATOR); } @@ -49,7 +47,7 @@ public Map, Object> getOptions() { ChannelOption.SO_REUSEADDR, ChannelOption.IP_MULTICAST_LOOP_DISABLED, ChannelOption.IP_MULTICAST_ADDR, ChannelOption.IP_MULTICAST_IF, ChannelOption.IP_MULTICAST_TTL, ChannelOption.IP_TOS, ChannelOption.DATAGRAM_CHANNEL_ACTIVE_ON_REGISTRATION, - EpollChannelOption.SO_REUSEPORT, EpollChannelOption.IP_TRANSPARENT, + EpollChannelOption.SO_REUSEPORT, EpollChannelOption.IP_FREEBIND, EpollChannelOption.IP_TRANSPARENT, EpollChannelOption.IP_RECVORIGDSTADDR); } @@ -92,6 +90,9 @@ public T getOption(ChannelOption option) { if (option == EpollChannelOption.IP_TRANSPARENT) { return (T) Boolean.valueOf(isIpTransparent()); } + if (option == EpollChannelOption.IP_FREEBIND) { + return (T) Boolean.valueOf(isFreeBind()); + } if (option == EpollChannelOption.IP_RECVORIGDSTADDR) { return (T) Boolean.valueOf(isIpRecvOrigDestAddr()); } @@ -125,6 +126,8 @@ public boolean setOption(ChannelOption option, T value) { setActiveOnOpen((Boolean) value); } else if (option == EpollChannelOption.SO_REUSEPORT) { setReusePort((Boolean) value); + } else if (option == EpollChannelOption.IP_FREEBIND) { + setFreeBind((Boolean) value); } else if (option == EpollChannelOption.IP_TRANSPARENT) { setIpTransparent((Boolean) value); } else if (option == EpollChannelOption.IP_RECVORIGDSTADDR) { @@ -219,7 +222,7 @@ public EpollDatagramChannelConfig setMaxMessagesPerRead(int maxMessagesPerRead) @Override public int getSendBufferSize() { try { - return datagramChannel.socket.getSendBufferSize(); + return ((EpollDatagramChannel) channel).socket.getSendBufferSize(); } catch (IOException e) { throw new ChannelException(e); } @@ -228,7 +231,7 @@ public int getSendBufferSize() { @Override public EpollDatagramChannelConfig setSendBufferSize(int sendBufferSize) { try { - datagramChannel.socket.setSendBufferSize(sendBufferSize); + ((EpollDatagramChannel) channel).socket.setSendBufferSize(sendBufferSize); return this; } catch (IOException e) { throw new ChannelException(e); @@ -238,7 +241,7 @@ public EpollDatagramChannelConfig setSendBufferSize(int sendBufferSize) { @Override public int getReceiveBufferSize() { try { - return datagramChannel.socket.getReceiveBufferSize(); + return ((EpollDatagramChannel) channel).socket.getReceiveBufferSize(); } catch (IOException e) { throw new ChannelException(e); } @@ -247,7 +250,7 @@ public int getReceiveBufferSize() { @Override public EpollDatagramChannelConfig setReceiveBufferSize(int receiveBufferSize) { try { - datagramChannel.socket.setReceiveBufferSize(receiveBufferSize); + ((EpollDatagramChannel) channel).socket.setReceiveBufferSize(receiveBufferSize); return this; } catch (IOException e) { throw new ChannelException(e); @@ -257,7 +260,7 @@ public EpollDatagramChannelConfig setReceiveBufferSize(int receiveBufferSize) { @Override public int getTrafficClass() { try { - return datagramChannel.socket.getTrafficClass(); + return ((EpollDatagramChannel) channel).socket.getTrafficClass(); } catch (IOException e) { throw new ChannelException(e); } @@ -266,7 +269,7 @@ public int getTrafficClass() { @Override public EpollDatagramChannelConfig setTrafficClass(int trafficClass) { try { - datagramChannel.socket.setTrafficClass(trafficClass); + ((EpollDatagramChannel) channel).socket.setTrafficClass(trafficClass); return this; } catch (IOException e) { throw new ChannelException(e); @@ -276,7 +279,7 @@ public EpollDatagramChannelConfig setTrafficClass(int trafficClass) { @Override public boolean isReuseAddress() { try { - return datagramChannel.socket.isReuseAddress(); + return ((EpollDatagramChannel) channel).socket.isReuseAddress(); } catch (IOException e) { throw new ChannelException(e); } @@ -285,7 +288,7 @@ public boolean isReuseAddress() { @Override public EpollDatagramChannelConfig setReuseAddress(boolean reuseAddress) { try { - datagramChannel.socket.setReuseAddress(reuseAddress); + ((EpollDatagramChannel) channel).socket.setReuseAddress(reuseAddress); return this; } catch (IOException e) { throw new ChannelException(e); @@ -295,7 +298,7 @@ public EpollDatagramChannelConfig setReuseAddress(boolean reuseAddress) { @Override public boolean isBroadcast() { try { - return datagramChannel.socket.isBroadcast(); + return ((EpollDatagramChannel) channel).socket.isBroadcast(); } catch (IOException e) { throw new ChannelException(e); } @@ -304,7 +307,7 @@ public boolean isBroadcast() { @Override public EpollDatagramChannelConfig setBroadcast(boolean broadcast) { try { - datagramChannel.socket.setBroadcast(broadcast); + ((EpollDatagramChannel) channel).socket.setBroadcast(broadcast); return this; } catch (IOException e) { throw new ChannelException(e); @@ -362,7 +365,7 @@ public EpollDatagramChannelConfig setEpollMode(EpollMode mode) { */ public boolean isReusePort() { try { - return datagramChannel.socket.isReusePort(); + return ((EpollDatagramChannel) channel).socket.isReusePort(); } catch (IOException e) { throw new ChannelException(e); } @@ -377,7 +380,7 @@ public boolean isReusePort() { */ public EpollDatagramChannelConfig setReusePort(boolean reusePort) { try { - datagramChannel.socket.setReusePort(reusePort); + ((EpollDatagramChannel) channel).socket.setReusePort(reusePort); return this; } catch (IOException e) { throw new ChannelException(e); @@ -390,7 +393,7 @@ public EpollDatagramChannelConfig setReusePort(boolean reusePort) { */ public boolean isIpTransparent() { try { - return datagramChannel.socket.isIpTransparent(); + return ((EpollDatagramChannel) channel).socket.isIpTransparent(); } catch (IOException e) { throw new ChannelException(e); } @@ -402,7 +405,32 @@ public boolean isIpTransparent() { */ public EpollDatagramChannelConfig setIpTransparent(boolean ipTransparent) { try { - datagramChannel.socket.setIpTransparent(ipTransparent); + ((EpollDatagramChannel) channel).socket.setIpTransparent(ipTransparent); + return this; + } catch (IOException e) { + throw new ChannelException(e); + } + } + + /** + * Returns {@code true} if IP_FREEBIND is enabled, + * {@code false} otherwise. + */ + public boolean isFreeBind() { + try { + return ((EpollDatagramChannel) channel).socket.isIpFreeBind(); + } catch (IOException e) { + throw new ChannelException(e); + } + } + + /** + * If {@code true} is used IP_FREEBIND is enabled, + * {@code false} for disable it. Default is disabled. + */ + public EpollDatagramChannelConfig setFreeBind(boolean freeBind) { + try { + ((EpollDatagramChannel) channel).socket.setIpFreeBind(freeBind); return this; } catch (IOException e) { throw new ChannelException(e); @@ -415,7 +443,7 @@ public EpollDatagramChannelConfig setIpTransparent(boolean ipTransparent) { */ public boolean isIpRecvOrigDestAddr() { try { - return datagramChannel.socket.isIpRecvOrigDestAddr(); + return ((EpollDatagramChannel) channel).socket.isIpRecvOrigDestAddr(); } catch (IOException e) { throw new ChannelException(e); } @@ -427,7 +455,7 @@ public boolean isIpRecvOrigDestAddr() { */ public EpollDatagramChannelConfig setIpRecvOrigDestAddr(boolean ipTransparent) { try { - datagramChannel.socket.setIpRecvOrigDestAddr(ipTransparent); + ((EpollDatagramChannel) channel).socket.setIpRecvOrigDestAddr(ipTransparent); return this; } catch (IOException e) { throw new ChannelException(e); diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollEventArray.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollEventArray.java index b91e1c0f183c..ea03fb0f2cd5 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollEventArray.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollEventArray.java @@ -15,8 +15,11 @@ */ package io.netty.channel.epoll; +import io.netty.channel.unix.Buffer; import io.netty.util.internal.PlatformDependent; +import java.nio.ByteBuffer; + /** * This is an internal datastructure which can be directly passed to epoll_wait to reduce the overhead. * @@ -41,6 +44,7 @@ public final class EpollEventArray { // The offsiet of the data union in the epoll_event struct private static final int EPOLL_DATA_OFFSET = Native.offsetofEpollData(); + private ByteBuffer memory; private long memoryAddress; private int length; @@ -49,11 +53,8 @@ public final class EpollEventArray { throw new IllegalArgumentException("length must be >= 1 but was " + length); } this.length = length; - memoryAddress = allocate(length); - } - - private static long allocate(int length) { - return PlatformDependent.allocateMemory(length * EPOLL_EVENT_SIZE); + memory = Buffer.allocateDirectWithNativeOrder(calculateBufferCapacity(length)); + memoryAddress = Buffer.memoryAddress(memory); } /** @@ -77,28 +78,43 @@ public int length() { public void increase() { // double the size length <<= 1; - free(); - memoryAddress = allocate(length); + // There is no need to preserve what was in the memory before. + ByteBuffer buffer = Buffer.allocateDirectWithNativeOrder(calculateBufferCapacity(length)); + Buffer.free(memory); + memory = buffer; + memoryAddress = Buffer.memoryAddress(buffer); } /** * Free this {@link EpollEventArray}. Any usage after calling this method may segfault the JVM! */ void free() { - PlatformDependent.freeMemory(memoryAddress); + Buffer.free(memory); + memoryAddress = 0; } /** * Return the events for the {@code epoll_event} on this index. */ int events(int index) { - return PlatformDependent.getInt(memoryAddress + index * EPOLL_EVENT_SIZE); + return getInt(index, 0); } /** * Return the file descriptor for the {@code epoll_event} on this index. */ int fd(int index) { - return PlatformDependent.getInt(memoryAddress + index * EPOLL_EVENT_SIZE + EPOLL_DATA_OFFSET); + return getInt(index, EPOLL_DATA_OFFSET); + } + + private int getInt(int index, int offset) { + if (PlatformDependent.hasUnsafe()) { + return PlatformDependent.getInt(memoryAddress + index * EPOLL_EVENT_SIZE + offset); + } + return memory.getInt(index * EPOLL_EVENT_SIZE + offset); + } + + private static int calculateBufferCapacity(int capacity) { + return capacity * EPOLL_EVENT_SIZE; } } diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollEventLoop.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollEventLoop.java index 986f59a1d8ff..6ac2ce53667b 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollEventLoop.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollEventLoop.java @@ -32,13 +32,9 @@ import io.netty.util.internal.logging.InternalLoggerFactory; import java.io.IOException; -import java.util.ArrayList; -import java.util.Collection; import java.util.Iterator; import java.util.Queue; -import java.util.concurrent.Callable; import java.util.concurrent.Executor; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import static java.lang.Math.min; @@ -59,6 +55,8 @@ public class EpollEventLoop extends SingleThreadEventLoop { Epoll.ensureAvailability(); } + // Pick a number that no task could have previously used. + private long prevDeadlineNanos = nanoTime() - 1; protected final FileDescriptor epollFd; protected final FileDescriptor eventFd; private final FileDescriptor timerFd; @@ -66,7 +64,11 @@ public class EpollEventLoop extends SingleThreadEventLoop { private final IntObjectMap channels = new IntObjectHashMap(4096); protected final boolean allowGrowing; protected final EpollEventArray events; - private final IovArray iovArray = new IovArray(); + + // These are initialized on first use + private IovArray iovArray; + private NativeDatagramPacketArray datagramPacketArray; + protected final SelectStrategy selectStrategy; protected final IntSupplier selectNowSupplier = new IntSupplier() { @Override @@ -74,17 +76,12 @@ public int get() throws Exception { return epollWaitNow(); } }; - private final Callable pendingTasksCallable = new Callable() { - @Override - public Integer call() throws Exception { - return EpollEventLoop.super.pendingTasks(); - } - }; + @SuppressWarnings("unused") // AtomicIntegerFieldUpdater protected volatile int wakenUp; private volatile int ioRatio = 50; // See http://man7.org/linux/man-pages/man2/timerfd_create.2.html. - static final long MAX_SCHEDULED_DAYS = TimeUnit.SECONDS.toDays(999999999); + private static final long MAX_SCHEDULED_TIMERFD_NS = 999999999; protected EpollEventLoop(EventLoopGroup parent, Executor executor, int maxEvents, SelectStrategy strategy, RejectedExecutionHandler rejectedExecutionHandler) { @@ -180,11 +177,27 @@ protected EpollEventLoop(EventLoopGroup parent, Executor executor, int maxEvents /** * Return a cleared {@link IovArray} that can be used for writes in this {@link EventLoop}. */ - IovArray cleanArray() { - iovArray.clear(); + IovArray cleanIovArray() { + if (iovArray == null) { + iovArray = new IovArray(); + } else { + iovArray.clear(); + } return iovArray; } + /** + * Return a cleared {@link NativeDatagramPacketArray} that can be used for writes in this {@link EventLoop}. + */ + NativeDatagramPacketArray cleanDatagramPacketArray() { + if (datagramPacketArray == null) { + datagramPacketArray = new NativeDatagramPacketArray(); + } else { + datagramPacketArray.clear(); + } + return datagramPacketArray; + } + public FileDescriptor epollFd() { return epollFd; } @@ -258,17 +271,6 @@ protected Queue newTaskQueue(int maxPendingTasks) { : PlatformDependent.newMpscQueue(maxPendingTasks); } - @Override - public int pendingTasks() { - // As we use a MpscQueue we need to ensure pendingTasks() is only executed from within the EventLoop as - // otherwise we may see unexpected behavior (as size() is only allowed to be called by a single consumer). - // See https://github.com/netty/netty/issues/5297 - if (inEventLoop()) { - return super.pendingTasks(); - } else { - return submit(pendingTasksCallable).syncUninterruptibly().getNow(); - } - } /** * Returns the percentage of the desired amount of time spent for I/O in the event loop. */ @@ -296,16 +298,29 @@ protected int epollWait(boolean oldWakeup) throws IOException { return epollWaitNow(); } - long totalDelay = delayNanos(System.nanoTime()); - int delaySeconds = (int) min(totalDelay / 1000000000L, Integer.MAX_VALUE); - return Native.epollWait(epollFd, events, timerFd, delaySeconds, - (int) min(totalDelay - delaySeconds * 1000000000L, Integer.MAX_VALUE)); + int delaySeconds; + int delayNanos; + long curDeadlineNanos = deadlineNanos(); + if (curDeadlineNanos == prevDeadlineNanos) { + delaySeconds = -1; + delayNanos = -1; + } else { + long totalDelay = delayNanos(System.nanoTime()); + prevDeadlineNanos = curDeadlineNanos; + delaySeconds = (int) min(totalDelay / 1000000000L, Integer.MAX_VALUE); + delayNanos = (int) min(totalDelay - delaySeconds * 1000000000L, MAX_SCHEDULED_TIMERFD_NS); + } + return Native.epollWait(epollFd, events, timerFd, delaySeconds, delayNanos); } private int epollWaitNow() throws IOException { return Native.epollWait(epollFd, events, timerFd, 0, 0); } + private int epollBusyWait() throws IOException { + return Native.epollBusyWait(epollFd, events); + } + @Override protected void run() { for (;;) { @@ -314,6 +329,11 @@ protected void run() { switch (strategy) { case SelectStrategy.CONTINUE: continue; + + case SelectStrategy.BUSY_WAIT: + strategy = epollBusyWait(); + break; + case SelectStrategy.SELECT: strategy = epollWait(WAKEN_UP_UPDATER.getAndSet(this, 0) == 1); @@ -396,7 +416,10 @@ protected void run() { } } - private static void handleLoopException(Throwable t) { + /** + * Visible only for testing! + */ + void handleLoopException(Throwable t) { logger.warn("Unexpected exception in the selector loop.", t); // Prevent possible consecutive immediate failures that lead to @@ -416,13 +439,9 @@ protected void closeAll() { } // Using the intermediate collection to prevent ConcurrentModificationException. // In the `close()` method, the channel is deleted from `channels` map. - Collection array = new ArrayList(channels.size()); + AbstractEpollChannel[] localChannels = channels.values().toArray(new AbstractEpollChannel[0]); - for (AbstractEpollChannel channel: channels.values()) { - array.add(channel); - } - - for (AbstractEpollChannel ch: array) { + for (AbstractEpollChannel ch : localChannels) { ch.unsafe().close(ch.unsafe().voidPromise()); } } @@ -515,7 +534,14 @@ protected void cleanup() { } } finally { // release native memory - iovArray.release(); + if (iovArray != null) { + iovArray.release(); + iovArray = null; + } + if (datagramPacketArray != null) { + datagramPacketArray.release(); + datagramPacketArray = null; + } events.free(); if (aioContext != null) { aioContext.destroy(); @@ -538,12 +564,4 @@ public void run() { submit(log); } } - - @Override - protected void validateScheduled(long amount, TimeUnit unit) { - long days = unit.toDays(amount); - if (days > MAX_SCHEDULED_DAYS) { - throw new IllegalArgumentException("days: " + days + " (expected: < " + MAX_SCHEDULED_DAYS + ')'); - } - } } diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollRecvByteAllocatorHandle.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollRecvByteAllocatorHandle.java index e4aeb3bbda27..35ad15cab8d6 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollRecvByteAllocatorHandle.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollRecvByteAllocatorHandle.java @@ -19,10 +19,13 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.channel.ChannelConfig; import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.unix.PreferredDirectByteBufAllocator; import io.netty.util.UncheckedBooleanSupplier; import io.netty.util.internal.ObjectUtil; class EpollRecvByteAllocatorHandle implements RecvByteBufAllocator.ExtendedHandle { + private final PreferredDirectByteBufAllocator preferredDirectByteBufAllocator = + new PreferredDirectByteBufAllocator(); private final RecvByteBufAllocator.ExtendedHandle delegate; private final UncheckedBooleanSupplier defaultMaybeMoreDataSupplier = new UncheckedBooleanSupplier() { @Override @@ -34,7 +37,7 @@ public boolean get() { private boolean receivedRdHup; EpollRecvByteAllocatorHandle(RecvByteBufAllocator.ExtendedHandle handle) { - this.delegate = ObjectUtil.checkNotNull(handle, "handle"); + delegate = ObjectUtil.checkNotNull(handle, "handle"); } final void receivedRdHup() { @@ -52,10 +55,11 @@ boolean maybeMoreDataToRead() { * respect auto read we supporting reading to stop if auto read is off. It is expected that the * {@link #EpollSocketChannel} implementations will track if we are in edgeTriggered mode and all data was not * read, and will force a EPOLLIN ready event. + * + * It is assumed RDHUP is handled externally by checking {@link #isReceivedRdHup()}. */ return (isEdgeTriggered && lastBytesRead() > 0) || - (!isEdgeTriggered && lastBytesRead() == attemptedBytesRead()) || - receivedRdHup; + (!isEdgeTriggered && lastBytesRead() == attemptedBytesRead()); } final void edgeTriggered(boolean edgeTriggered) { @@ -68,7 +72,9 @@ final boolean isEdgeTriggered() { @Override public final ByteBuf allocate(ByteBufAllocator alloc) { - return delegate.allocate(alloc); + // We need to ensure we always allocate a direct ByteBuf as we can only use a direct buffer to read via JNI. + preferredDirectByteBufAllocator.updateAllocator(alloc); + return delegate.allocate(preferredDirectByteBufAllocator); } @Override diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollRecvByteAllocatorStreamingHandle.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollRecvByteAllocatorStreamingHandle.java index f6ba5f58b54d..071acd578c8d 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollRecvByteAllocatorStreamingHandle.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollRecvByteAllocatorStreamingHandle.java @@ -18,7 +18,7 @@ import io.netty.channel.RecvByteBufAllocator; final class EpollRecvByteAllocatorStreamingHandle extends EpollRecvByteAllocatorHandle { - public EpollRecvByteAllocatorStreamingHandle(RecvByteBufAllocator.ExtendedHandle handle) { + EpollRecvByteAllocatorStreamingHandle(RecvByteBufAllocator.ExtendedHandle handle) { super(handle); } diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollServerChannelConfig.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollServerChannelConfig.java index 5d6394b1428c..aa3409cf84c0 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollServerChannelConfig.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollServerChannelConfig.java @@ -30,15 +30,14 @@ import static io.netty.channel.ChannelOption.SO_BACKLOG; import static io.netty.channel.ChannelOption.SO_RCVBUF; import static io.netty.channel.ChannelOption.SO_REUSEADDR; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; public class EpollServerChannelConfig extends EpollChannelConfig implements ServerSocketChannelConfig { - protected final AbstractEpollChannel channel; private volatile int backlog = NetUtil.SOMAXCONN; private volatile int pendingFastOpenRequestsThreshold; EpollServerChannelConfig(AbstractEpollChannel channel) { super(channel); - this.channel = channel; } @Override @@ -85,7 +84,7 @@ public boolean setOption(ChannelOption option, T value) { public boolean isReuseAddress() { try { - return channel.socket.isReuseAddress(); + return ((AbstractEpollChannel) channel).socket.isReuseAddress(); } catch (IOException e) { throw new ChannelException(e); } @@ -93,7 +92,7 @@ public boolean isReuseAddress() { public EpollServerChannelConfig setReuseAddress(boolean reuseAddress) { try { - channel.socket.setReuseAddress(reuseAddress); + ((AbstractEpollChannel) channel).socket.setReuseAddress(reuseAddress); return this; } catch (IOException e) { throw new ChannelException(e); @@ -102,7 +101,7 @@ public EpollServerChannelConfig setReuseAddress(boolean reuseAddress) { public int getReceiveBufferSize() { try { - return channel.socket.getReceiveBufferSize(); + return ((AbstractEpollChannel) channel).socket.getReceiveBufferSize(); } catch (IOException e) { throw new ChannelException(e); } @@ -110,7 +109,7 @@ public int getReceiveBufferSize() { public EpollServerChannelConfig setReceiveBufferSize(int receiveBufferSize) { try { - channel.socket.setReceiveBufferSize(receiveBufferSize); + ((AbstractEpollChannel) channel).socket.setReceiveBufferSize(receiveBufferSize); return this; } catch (IOException e) { throw new ChannelException(e); @@ -122,9 +121,7 @@ public int getBacklog() { } public EpollServerChannelConfig setBacklog(int backlog) { - if (backlog < 0) { - throw new IllegalArgumentException("backlog: " + backlog); - } + checkPositiveOrZero(backlog, "backlog"); this.backlog = backlog; return this; } @@ -148,9 +145,7 @@ public int getTcpFastopen() { * @see RFC 7413 TCP FastOpen */ public EpollServerChannelConfig setTcpFastopen(int pendingFastOpenRequestsThreshold) { - if (this.pendingFastOpenRequestsThreshold < 0) { - throw new IllegalArgumentException("pendingFastOpenRequestsThreshold: " + pendingFastOpenRequestsThreshold); - } + checkPositiveOrZero(this.pendingFastOpenRequestsThreshold, "pendingFastOpenRequestsThreshold"); this.pendingFastOpenRequestsThreshold = pendingFastOpenRequestsThreshold; return this; } diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollServerSocketChannelConfig.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollServerSocketChannelConfig.java index dfccb199c983..91861f0c137e 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollServerSocketChannelConfig.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollServerSocketChannelConfig.java @@ -191,7 +191,7 @@ public EpollServerSocketChannelConfig setTcpMd5Sig(Map keys */ public boolean isReusePort() { try { - return channel.socket.isReusePort(); + return ((EpollServerSocketChannel) channel).socket.isReusePort(); } catch (IOException e) { throw new ChannelException(e); } @@ -206,7 +206,7 @@ public boolean isReusePort() { */ public EpollServerSocketChannelConfig setReusePort(boolean reusePort) { try { - channel.socket.setReusePort(reusePort); + ((EpollServerSocketChannel) channel).socket.setReusePort(reusePort); return this; } catch (IOException e) { throw new ChannelException(e); @@ -219,7 +219,7 @@ public EpollServerSocketChannelConfig setReusePort(boolean reusePort) { */ public boolean isFreeBind() { try { - return channel.socket.isIpFreeBind(); + return ((EpollServerSocketChannel) channel).socket.isIpFreeBind(); } catch (IOException e) { throw new ChannelException(e); } @@ -231,7 +231,7 @@ public boolean isFreeBind() { */ public EpollServerSocketChannelConfig setFreeBind(boolean freeBind) { try { - channel.socket.setIpFreeBind(freeBind); + ((EpollServerSocketChannel) channel).socket.setIpFreeBind(freeBind); return this; } catch (IOException e) { throw new ChannelException(e); @@ -244,7 +244,7 @@ public EpollServerSocketChannelConfig setFreeBind(boolean freeBind) { */ public boolean isIpTransparent() { try { - return channel.socket.isIpTransparent(); + return ((EpollServerSocketChannel) channel).socket.isIpTransparent(); } catch (IOException e) { throw new ChannelException(e); } @@ -256,7 +256,7 @@ public boolean isIpTransparent() { */ public EpollServerSocketChannelConfig setIpTransparent(boolean transparent) { try { - channel.socket.setIpTransparent(transparent); + ((EpollServerSocketChannel) channel).socket.setIpTransparent(transparent); return this; } catch (IOException e) { throw new ChannelException(e); @@ -268,7 +268,7 @@ public EpollServerSocketChannelConfig setIpTransparent(boolean transparent) { */ public EpollServerSocketChannelConfig setTcpDeferAccept(int deferAccept) { try { - channel.socket.setTcpDeferAccept(deferAccept); + ((EpollServerSocketChannel) channel).socket.setTcpDeferAccept(deferAccept); return this; } catch (IOException e) { throw new ChannelException(e); @@ -280,7 +280,7 @@ public EpollServerSocketChannelConfig setTcpDeferAccept(int deferAccept) { */ public int getTcpDeferAccept() { try { - return channel.socket.getTcpDeferAccept(); + return ((EpollServerSocketChannel) channel).socket.getTcpDeferAccept(); } catch (IOException e) { throw new ChannelException(e); } diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollSocketChannelConfig.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollSocketChannelConfig.java index 67468910b28e..f3d04dd3f955 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollSocketChannelConfig.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollSocketChannelConfig.java @@ -38,7 +38,6 @@ import static io.netty.channel.ChannelOption.TCP_NODELAY; public final class EpollSocketChannelConfig extends EpollChannelConfig implements SocketChannelConfig { - private final EpollSocketChannel channel; private volatile boolean allowHalfClosure; /** @@ -47,7 +46,6 @@ public final class EpollSocketChannelConfig extends EpollChannelConfig implement EpollSocketChannelConfig(EpollSocketChannel channel) { super(channel); - this.channel = channel; if (PlatformDependent.canEnableTcpNoDelayByDefault()) { setTcpNoDelay(true); } @@ -62,7 +60,7 @@ public Map, Object> getOptions() { ALLOW_HALF_CLOSURE, EpollChannelOption.TCP_CORK, EpollChannelOption.TCP_NOTSENT_LOWAT, EpollChannelOption.TCP_KEEPCNT, EpollChannelOption.TCP_KEEPIDLE, EpollChannelOption.TCP_KEEPINTVL, EpollChannelOption.TCP_MD5SIG, EpollChannelOption.TCP_QUICKACK, EpollChannelOption.IP_TRANSPARENT, - EpollChannelOption.TCP_FASTOPEN_CONNECT); + EpollChannelOption.TCP_FASTOPEN_CONNECT, EpollChannelOption.SO_BUSY_POLL); } @SuppressWarnings("unchecked") @@ -119,6 +117,9 @@ public T getOption(ChannelOption option) { if (option == EpollChannelOption.TCP_FASTOPEN_CONNECT) { return (T) Boolean.valueOf(isTcpFastOpenConnect()); } + if (option == EpollChannelOption.SO_BUSY_POLL) { + return (T) Integer.valueOf(getSoBusyPoll()); + } return super.getOption(option); } @@ -164,6 +165,8 @@ public boolean setOption(ChannelOption option, T value) { setTcpQuickAck((Boolean) value); } else if (option == EpollChannelOption.TCP_FASTOPEN_CONNECT) { setTcpFastOpenConnect((Boolean) value); + } else if (option == EpollChannelOption.SO_BUSY_POLL) { + setSoBusyPoll((Integer) value); } else { return super.setOption(option, value); } @@ -174,7 +177,7 @@ public boolean setOption(ChannelOption option, T value) { @Override public int getReceiveBufferSize() { try { - return channel.socket.getReceiveBufferSize(); + return ((EpollSocketChannel) channel).socket.getReceiveBufferSize(); } catch (IOException e) { throw new ChannelException(e); } @@ -183,7 +186,7 @@ public int getReceiveBufferSize() { @Override public int getSendBufferSize() { try { - return channel.socket.getSendBufferSize(); + return ((EpollSocketChannel) channel).socket.getSendBufferSize(); } catch (IOException e) { throw new ChannelException(e); } @@ -192,7 +195,7 @@ public int getSendBufferSize() { @Override public int getSoLinger() { try { - return channel.socket.getSoLinger(); + return ((EpollSocketChannel) channel).socket.getSoLinger(); } catch (IOException e) { throw new ChannelException(e); } @@ -201,7 +204,7 @@ public int getSoLinger() { @Override public int getTrafficClass() { try { - return channel.socket.getTrafficClass(); + return ((EpollSocketChannel) channel).socket.getTrafficClass(); } catch (IOException e) { throw new ChannelException(e); } @@ -210,7 +213,7 @@ public int getTrafficClass() { @Override public boolean isKeepAlive() { try { - return channel.socket.isKeepAlive(); + return ((EpollSocketChannel) channel).socket.isKeepAlive(); } catch (IOException e) { throw new ChannelException(e); } @@ -219,7 +222,7 @@ public boolean isKeepAlive() { @Override public boolean isReuseAddress() { try { - return channel.socket.isReuseAddress(); + return ((EpollSocketChannel) channel).socket.isReuseAddress(); } catch (IOException e) { throw new ChannelException(e); } @@ -228,7 +231,7 @@ public boolean isReuseAddress() { @Override public boolean isTcpNoDelay() { try { - return channel.socket.isTcpNoDelay(); + return ((EpollSocketChannel) channel).socket.isTcpNoDelay(); } catch (IOException e) { throw new ChannelException(e); } @@ -239,7 +242,18 @@ public boolean isTcpNoDelay() { */ public boolean isTcpCork() { try { - return channel.socket.isTcpCork(); + return ((EpollSocketChannel) channel).socket.isTcpCork(); + } catch (IOException e) { + throw new ChannelException(e); + } + } + + /** + * Get the {@code SO_BUSY_POLL} option on the socket. See {@code man 7 tcp} for more details. + */ + public int getSoBusyPoll() { + try { + return ((EpollSocketChannel) channel).socket.getSoBusyPoll(); } catch (IOException e) { throw new ChannelException(e); } @@ -251,7 +265,7 @@ public boolean isTcpCork() { */ public long getTcpNotSentLowAt() { try { - return channel.socket.getTcpNotSentLowAt(); + return ((EpollSocketChannel) channel).socket.getTcpNotSentLowAt(); } catch (IOException e) { throw new ChannelException(e); } @@ -262,7 +276,7 @@ public long getTcpNotSentLowAt() { */ public int getTcpKeepIdle() { try { - return channel.socket.getTcpKeepIdle(); + return ((EpollSocketChannel) channel).socket.getTcpKeepIdle(); } catch (IOException e) { throw new ChannelException(e); } @@ -273,7 +287,7 @@ public int getTcpKeepIdle() { */ public int getTcpKeepIntvl() { try { - return channel.socket.getTcpKeepIntvl(); + return ((EpollSocketChannel) channel).socket.getTcpKeepIntvl(); } catch (IOException e) { throw new ChannelException(e); } @@ -284,7 +298,7 @@ public int getTcpKeepIntvl() { */ public int getTcpKeepCnt() { try { - return channel.socket.getTcpKeepCnt(); + return ((EpollSocketChannel) channel).socket.getTcpKeepCnt(); } catch (IOException e) { throw new ChannelException(e); } @@ -295,7 +309,7 @@ public int getTcpKeepCnt() { */ public int getTcpUserTimeout() { try { - return channel.socket.getTcpUserTimeout(); + return ((EpollSocketChannel) channel).socket.getTcpUserTimeout(); } catch (IOException e) { throw new ChannelException(e); } @@ -304,7 +318,7 @@ public int getTcpUserTimeout() { @Override public EpollSocketChannelConfig setKeepAlive(boolean keepAlive) { try { - channel.socket.setKeepAlive(keepAlive); + ((EpollSocketChannel) channel).socket.setKeepAlive(keepAlive); return this; } catch (IOException e) { throw new ChannelException(e); @@ -320,7 +334,7 @@ public EpollSocketChannelConfig setPerformancePreferences( @Override public EpollSocketChannelConfig setReceiveBufferSize(int receiveBufferSize) { try { - channel.socket.setReceiveBufferSize(receiveBufferSize); + ((EpollSocketChannel) channel).socket.setReceiveBufferSize(receiveBufferSize); return this; } catch (IOException e) { throw new ChannelException(e); @@ -330,7 +344,7 @@ public EpollSocketChannelConfig setReceiveBufferSize(int receiveBufferSize) { @Override public EpollSocketChannelConfig setReuseAddress(boolean reuseAddress) { try { - channel.socket.setReuseAddress(reuseAddress); + ((EpollSocketChannel) channel).socket.setReuseAddress(reuseAddress); return this; } catch (IOException e) { throw new ChannelException(e); @@ -340,7 +354,7 @@ public EpollSocketChannelConfig setReuseAddress(boolean reuseAddress) { @Override public EpollSocketChannelConfig setSendBufferSize(int sendBufferSize) { try { - channel.socket.setSendBufferSize(sendBufferSize); + ((EpollSocketChannel) channel).socket.setSendBufferSize(sendBufferSize); calculateMaxBytesPerGatheringWrite(); return this; } catch (IOException e) { @@ -351,7 +365,7 @@ public EpollSocketChannelConfig setSendBufferSize(int sendBufferSize) { @Override public EpollSocketChannelConfig setSoLinger(int soLinger) { try { - channel.socket.setSoLinger(soLinger); + ((EpollSocketChannel) channel).socket.setSoLinger(soLinger); return this; } catch (IOException e) { throw new ChannelException(e); @@ -361,7 +375,7 @@ public EpollSocketChannelConfig setSoLinger(int soLinger) { @Override public EpollSocketChannelConfig setTcpNoDelay(boolean tcpNoDelay) { try { - channel.socket.setTcpNoDelay(tcpNoDelay); + ((EpollSocketChannel) channel).socket.setTcpNoDelay(tcpNoDelay); return this; } catch (IOException e) { throw new ChannelException(e); @@ -373,7 +387,19 @@ public EpollSocketChannelConfig setTcpNoDelay(boolean tcpNoDelay) { */ public EpollSocketChannelConfig setTcpCork(boolean tcpCork) { try { - channel.socket.setTcpCork(tcpCork); + ((EpollSocketChannel) channel).socket.setTcpCork(tcpCork); + return this; + } catch (IOException e) { + throw new ChannelException(e); + } + } + + /** + * Set the {@code SO_BUSY_POLL} option on the socket. See {@code man 7 tcp} for more details. + */ + public EpollSocketChannelConfig setSoBusyPoll(int loopMicros) { + try { + ((EpollSocketChannel) channel).socket.setSoBusyPoll(loopMicros); return this; } catch (IOException e) { throw new ChannelException(e); @@ -386,7 +412,7 @@ public EpollSocketChannelConfig setTcpCork(boolean tcpCork) { */ public EpollSocketChannelConfig setTcpNotSentLowAt(long tcpNotSentLowAt) { try { - channel.socket.setTcpNotSentLowAt(tcpNotSentLowAt); + ((EpollSocketChannel) channel).socket.setTcpNotSentLowAt(tcpNotSentLowAt); return this; } catch (IOException e) { throw new ChannelException(e); @@ -396,7 +422,7 @@ public EpollSocketChannelConfig setTcpNotSentLowAt(long tcpNotSentLowAt) { @Override public EpollSocketChannelConfig setTrafficClass(int trafficClass) { try { - channel.socket.setTrafficClass(trafficClass); + ((EpollSocketChannel) channel).socket.setTrafficClass(trafficClass); return this; } catch (IOException e) { throw new ChannelException(e); @@ -408,7 +434,7 @@ public EpollSocketChannelConfig setTrafficClass(int trafficClass) { */ public EpollSocketChannelConfig setTcpKeepIdle(int seconds) { try { - channel.socket.setTcpKeepIdle(seconds); + ((EpollSocketChannel) channel).socket.setTcpKeepIdle(seconds); return this; } catch (IOException e) { throw new ChannelException(e); @@ -420,7 +446,7 @@ public EpollSocketChannelConfig setTcpKeepIdle(int seconds) { */ public EpollSocketChannelConfig setTcpKeepIntvl(int seconds) { try { - channel.socket.setTcpKeepIntvl(seconds); + ((EpollSocketChannel) channel).socket.setTcpKeepIntvl(seconds); return this; } catch (IOException e) { throw new ChannelException(e); @@ -440,7 +466,7 @@ public EpollSocketChannelConfig setTcpKeepCntl(int probes) { */ public EpollSocketChannelConfig setTcpKeepCnt(int probes) { try { - channel.socket.setTcpKeepCnt(probes); + ((EpollSocketChannel) channel).socket.setTcpKeepCnt(probes); return this; } catch (IOException e) { throw new ChannelException(e); @@ -452,7 +478,7 @@ public EpollSocketChannelConfig setTcpKeepCnt(int probes) { */ public EpollSocketChannelConfig setTcpUserTimeout(int milliseconds) { try { - channel.socket.setTcpUserTimeout(milliseconds); + ((EpollSocketChannel) channel).socket.setTcpUserTimeout(milliseconds); return this; } catch (IOException e) { throw new ChannelException(e); @@ -465,7 +491,7 @@ public EpollSocketChannelConfig setTcpUserTimeout(int milliseconds) { */ public boolean isIpTransparent() { try { - return channel.socket.isIpTransparent(); + return ((EpollSocketChannel) channel).socket.isIpTransparent(); } catch (IOException e) { throw new ChannelException(e); } @@ -477,7 +503,7 @@ public boolean isIpTransparent() { */ public EpollSocketChannelConfig setIpTransparent(boolean transparent) { try { - channel.socket.setIpTransparent(transparent); + ((EpollSocketChannel) channel).socket.setIpTransparent(transparent); return this; } catch (IOException e) { throw new ChannelException(e); @@ -491,7 +517,7 @@ public EpollSocketChannelConfig setIpTransparent(boolean transparent) { */ public EpollSocketChannelConfig setTcpMd5Sig(Map keys) { try { - channel.setTcpMd5Sig(keys); + ((EpollSocketChannel) channel).setTcpMd5Sig(keys); return this; } catch (IOException e) { throw new ChannelException(e); @@ -504,7 +530,7 @@ public EpollSocketChannelConfig setTcpMd5Sig(Map keys) { */ public EpollSocketChannelConfig setTcpQuickAck(boolean quickAck) { try { - channel.socket.setTcpQuickAck(quickAck); + ((EpollSocketChannel) channel).socket.setTcpQuickAck(quickAck); return this; } catch (IOException e) { throw new ChannelException(e); @@ -517,7 +543,7 @@ public EpollSocketChannelConfig setTcpQuickAck(boolean quickAck) { */ public boolean isTcpQuickAck() { try { - return channel.socket.isTcpQuickAck(); + return ((EpollSocketChannel) channel).socket.isTcpQuickAck(); } catch (IOException e) { throw new ChannelException(e); } @@ -531,7 +557,7 @@ public boolean isTcpQuickAck() { */ public EpollSocketChannelConfig setTcpFastOpenConnect(boolean fastOpenConnect) { try { - channel.socket.setTcpFastOpenConnect(fastOpenConnect); + ((EpollSocketChannel) channel).socket.setTcpFastOpenConnect(fastOpenConnect); return this; } catch (IOException e) { throw new ChannelException(e); @@ -543,7 +569,7 @@ public EpollSocketChannelConfig setTcpFastOpenConnect(boolean fastOpenConnect) { */ public boolean isTcpFastOpenConnect() { try { - return channel.socket.isTcpFastOpenConnect(); + return ((EpollSocketChannel) channel).socket.isTcpFastOpenConnect(); } catch (IOException e) { throw new ChannelException(e); } diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/LinuxSocket.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/LinuxSocket.java index f0a5304a74cf..a20dc4aa06a5 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/LinuxSocket.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/LinuxSocket.java @@ -40,7 +40,7 @@ final class LinuxSocket extends Socket { private static final ClosedChannelException SENDFILE_CLOSED_CHANNEL_EXCEPTION = ThrowableUtil.unknownStackTrace( new ClosedChannelException(), Native.class, "sendfile(...)"); - public LinuxSocket(int fd) { + LinuxSocket(int fd) { super(fd); } @@ -56,6 +56,10 @@ void setTcpCork(boolean tcpCork) throws IOException { setTcpCork(intValue(), tcpCork ? 1 : 0); } + void setSoBusyPoll(int loopMicros) throws IOException { + setSoBusyPoll(intValue(), loopMicros); + } + void setTcpNotSentLowAt(long tcpNotSentLowAt) throws IOException { if (tcpNotSentLowAt < 0 || tcpNotSentLowAt > MAX_UINT32_T) { throw new IllegalArgumentException("tcpNotSentLowAt must be a uint32_t"); @@ -116,6 +120,10 @@ boolean isTcpCork() throws IOException { return isTcpCork(intValue()) != 0; } + int getSoBusyPoll() throws IOException { + return getSoBusyPoll(intValue()); + } + int getTcpDeferAccept() throws IOException { return getTcpDeferAccept(intValue()); } @@ -190,6 +198,7 @@ private static native long sendFile(int socketFd, DefaultFileRegion src, long ba private static native int getTcpDeferAccept(int fd) throws IOException; private static native int isTcpQuickAck(int fd) throws IOException; private static native int isTcpCork(int fd) throws IOException; + private static native int getSoBusyPoll(int fd) throws IOException; private static native int getTcpNotSentLowAt(int fd) throws IOException; private static native int getTcpKeepIdle(int fd) throws IOException; private static native int getTcpKeepIntvl(int fd) throws IOException; @@ -205,6 +214,7 @@ private static native long sendFile(int socketFd, DefaultFileRegion src, long ba private static native void setTcpDeferAccept(int fd, int deferAccept) throws IOException; private static native void setTcpQuickAck(int fd, int quickAck) throws IOException; private static native void setTcpCork(int fd, int tcpCork) throws IOException; + private static native void setSoBusyPoll(int fd, int loopMicros) throws IOException; private static native void setTcpNotSentLowAt(int fd, int tcpNotSentLowAt) throws IOException; private static native void setTcpFastOpen(int fd, int tcpFastopenBacklog) throws IOException; private static native void setTcpFastOpenConnect(int fd, int tcpFastOpenConnect) throws IOException; diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/Native.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/Native.java index b600ae5d97be..01575b294597 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/Native.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/Native.java @@ -125,6 +125,21 @@ public static int epollWait(FileDescriptor epollFd, EpollEventArray events, File } private static native int epollWait0(int efd, long address, int len, int timerFd, int timeoutSec, int timeoutNs); + /** + * Non-blocking variant of + * {@link #epollWait(FileDescriptor, EpollEventArray, FileDescriptor, int, int)} + * that will also hint to processor we are in a busy-wait loop. + */ + public static int epollBusyWait(FileDescriptor epollFd, EpollEventArray events) throws IOException { + int ready = epollBusyWait0(epollFd.intValue(), events.memoryAddress(), events.length()); + if (ready < 0) { + throw newIOException("epoll_wait", ready); + } + return ready; + } + + private static native int epollBusyWait0(int efd, long address, int len); + public static void epollCtlAdd(int efd, final int fd, final int flags) throws IOException { int res = epollCtlAdd0(efd, fd, flags); if (res < 0) { diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/NativeDatagramPacketArray.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/NativeDatagramPacketArray.java index aafa67ee66d9..6107af2c55f7 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/NativeDatagramPacketArray.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/NativeDatagramPacketArray.java @@ -19,7 +19,6 @@ import io.netty.channel.ChannelOutboundBuffer; import io.netty.channel.socket.DatagramPacket; import io.netty.channel.unix.IovArray; -import io.netty.util.concurrent.FastThreadLocal; import java.net.Inet6Address; import java.net.InetAddress; import java.net.InetSocketAddress; @@ -32,28 +31,15 @@ */ final class NativeDatagramPacketArray implements ChannelOutboundBuffer.MessageProcessor { - private static final FastThreadLocal ARRAY = - new FastThreadLocal() { - @Override - protected NativeDatagramPacketArray initialValue() throws Exception { - return new NativeDatagramPacketArray(); - } - - @Override - protected void onRemoval(NativeDatagramPacketArray value) throws Exception { - NativeDatagramPacket[] packetsArray = value.packets; - // Release all packets - for (NativeDatagramPacket datagramPacket : packetsArray) { - datagramPacket.release(); - } - } - }; - // Use UIO_MAX_IOV as this is the maximum number we can write with one sendmmsg(...) call. private final NativeDatagramPacket[] packets = new NativeDatagramPacket[UIO_MAX_IOV]; + + // We share one IovArray for all NativeDatagramPackets to reduce memory overhead. This will allow us to write + // up to IOV_MAX iovec across all messages in one sendmmsg(...) call. + private final IovArray iovArray = new IovArray(); private int count; - private NativeDatagramPacketArray() { + NativeDatagramPacketArray() { for (int i = 0; i < packets.length; i++) { packets[i] = new NativeDatagramPacket(); } @@ -65,6 +51,8 @@ private NativeDatagramPacketArray() { */ boolean add(DatagramPacket packet) { if (count == packets.length) { + // We already filled up to UIO_MAX_IOV messages. This is the max allowed per sendmmsg(...) call, we will + // try again later. return false; } ByteBuf content = packet.content(); @@ -74,16 +62,20 @@ boolean add(DatagramPacket packet) { } NativeDatagramPacket p = packets[count]; InetSocketAddress recipient = packet.recipient(); - if (!p.init(content, recipient)) { + + int offset = iovArray.count(); + if (!iovArray.add(content)) { + // Not enough space to hold the whole content, we will try again later. return false; } + p.init(iovArray.memoryAddress(offset), iovArray.count() - offset, recipient); count++; return true; } @Override - public boolean processMessage(Object msg) throws Exception { + public boolean processMessage(Object msg) { return msg instanceof DatagramPacket && add((DatagramPacket) msg); } @@ -101,15 +93,13 @@ NativeDatagramPacket[] packets() { return packets; } - /** - * Returns a {@link NativeDatagramPacketArray} which is filled with the flushed messages of - * {@link ChannelOutboundBuffer}. - */ - static NativeDatagramPacketArray getInstance(ChannelOutboundBuffer buffer) throws Exception { - NativeDatagramPacketArray array = ARRAY.get(); - array.count = 0; - buffer.forEachFlushedMessage(array); - return array; + void clear() { + this.count = 0; + this.iovArray.clear(); + } + + void release() { + iovArray.release(); } /** @@ -117,10 +107,6 @@ static NativeDatagramPacketArray getInstance(ChannelOutboundBuffer buffer) throw */ @SuppressWarnings("unused") static final class NativeDatagramPacket { - // Each NativeDatagramPackets holds a IovArray which is used for gathering writes. - // This is ok as NativeDatagramPacketArray is always obtained via a FastThreadLocal and - // so the memory needed is quite small anyway. - private final IovArray array = new IovArray(); // This is the actual struct iovec* private long memoryAddress; @@ -130,21 +116,9 @@ static final class NativeDatagramPacket { private int scopeId; private int port; - private void release() { - array.release(); - } - - /** - * Init this instance and return {@code true} if the init was successful. - */ - private boolean init(ByteBuf buf, InetSocketAddress recipient) { - array.clear(); - if (!array.add(buf)) { - return false; - } - // always start from offset 0 - memoryAddress = array.memoryAddress(0); - count = array.count(); + private void init(long memoryAddress, int count, InetSocketAddress recipient) { + this.memoryAddress = memoryAddress; + this.count = count; InetAddress address = recipient.getAddress(); if (address instanceof Inet6Address) { @@ -155,7 +129,6 @@ private boolean init(ByteBuf buf, InetSocketAddress recipient) { scopeId = 0; } port = recipient.getPort(); - return true; } } } diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDatagramChannelConfigTest.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDatagramChannelConfigTest.java new file mode 100644 index 000000000000..39cf7acd064b --- /dev/null +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDatagramChannelConfigTest.java @@ -0,0 +1,32 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.epoll; + +import org.junit.Test; + +import static org.junit.Assert.assertTrue; + +public class EpollDatagramChannelConfigTest { + + @Test + public void testIpFreeBind() throws Exception { + Epoll.ensureAvailability(); + EpollDatagramChannel channel = new EpollDatagramChannel(); + assertTrue(channel.config().setOption(EpollChannelOption.IP_FREEBIND, true)); + assertTrue(channel.config().getOption(EpollChannelOption.IP_FREEBIND)); + channel.fd().close(); + } +} diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDomainSocketDataReadInitialStateTest.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDomainSocketDataReadInitialStateTest.java new file mode 100644 index 000000000000..9f848cd18d2d --- /dev/null +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDomainSocketDataReadInitialStateTest.java @@ -0,0 +1,36 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.epoll; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.testsuite.transport.TestsuitePermutation; +import io.netty.testsuite.transport.socket.SocketDataReadInitialStateTest; + +import java.net.SocketAddress; +import java.util.List; + +public class EpollDomainSocketDataReadInitialStateTest extends SocketDataReadInitialStateTest { + @Override + protected SocketAddress newSocketAddress() { + return EpollSocketTestPermutation.newSocketAddress(); + } + + @Override + protected List> newFactories() { + return EpollSocketTestPermutation.INSTANCE.domainSocket(); + } +} diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDomainSocketSslClientRenegotiateTest.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDomainSocketSslClientRenegotiateTest.java new file mode 100644 index 000000000000..3f493d08355b --- /dev/null +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDomainSocketSslClientRenegotiateTest.java @@ -0,0 +1,42 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.epoll; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.handler.ssl.SslContext; +import io.netty.testsuite.transport.TestsuitePermutation; +import io.netty.testsuite.transport.socket.SocketSslClientRenegotiateTest; + +import java.net.SocketAddress; +import java.util.List; + +public class EpollDomainSocketSslClientRenegotiateTest extends SocketSslClientRenegotiateTest { + + public EpollDomainSocketSslClientRenegotiateTest(SslContext serverCtx, SslContext clientCtx, boolean delegate) { + super(serverCtx, clientCtx, delegate); + } + + @Override + protected List> newFactories() { + return EpollSocketTestPermutation.INSTANCE.domainSocket(); + } + + @Override + protected SocketAddress newSocketAddress() { + return EpollSocketTestPermutation.newSocketAddress(); + } +} diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDomainSocketSslGreetingTest.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDomainSocketSslGreetingTest.java index a1ed0c5ac2ea..4683a701eaae 100644 --- a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDomainSocketSslGreetingTest.java +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDomainSocketSslGreetingTest.java @@ -26,8 +26,8 @@ public class EpollDomainSocketSslGreetingTest extends SocketSslGreetingTest { - public EpollDomainSocketSslGreetingTest(SslContext serverCtx, SslContext clientCtx) { - super(serverCtx, clientCtx); + public EpollDomainSocketSslGreetingTest(SslContext serverCtx, SslContext clientCtx, boolean delegate) { + super(serverCtx, clientCtx, delegate); } @Override diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollETSocketDataReadInitialStateTest.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollETSocketDataReadInitialStateTest.java new file mode 100644 index 000000000000..9c27f2503074 --- /dev/null +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollETSocketDataReadInitialStateTest.java @@ -0,0 +1,39 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.epoll; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBufAllocator; +import io.netty.testsuite.transport.TestsuitePermutation; +import io.netty.testsuite.transport.socket.SocketDataReadInitialStateTest; + +import java.util.List; + +public class EpollETSocketDataReadInitialStateTest extends SocketDataReadInitialStateTest { + @Override + protected List> newFactories() { + return EpollSocketTestPermutation.INSTANCE.socket(); + } + + @Override + protected void configure(ServerBootstrap bootstrap, Bootstrap bootstrap2, ByteBufAllocator allocator) { + super.configure(bootstrap, bootstrap2, allocator); + bootstrap.option(EpollChannelOption.EPOLL_MODE, EpollMode.EDGE_TRIGGERED) + .childOption(EpollChannelOption.EPOLL_MODE, EpollMode.EDGE_TRIGGERED); + bootstrap2.option(EpollChannelOption.EPOLL_MODE, EpollMode.EDGE_TRIGGERED); + } +} diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollETSocketStringEchoBusyWaitTest.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollETSocketStringEchoBusyWaitTest.java new file mode 100644 index 000000000000..e599e3bef881 --- /dev/null +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollETSocketStringEchoBusyWaitTest.java @@ -0,0 +1,31 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.epoll; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBufAllocator; + +public class EpollETSocketStringEchoBusyWaitTest extends EpollSocketStringEchoBusyWaitTest { + + @Override + protected void configure(ServerBootstrap bootstrap, Bootstrap bootstrap2, ByteBufAllocator allocator) { + super.configure(bootstrap, bootstrap2, allocator); + bootstrap.option(EpollChannelOption.EPOLL_MODE, EpollMode.EDGE_TRIGGERED) + .childOption(EpollChannelOption.EPOLL_MODE, EpollMode.EDGE_TRIGGERED); + bootstrap2.option(EpollChannelOption.EPOLL_MODE, EpollMode.EDGE_TRIGGERED); + } +} diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollEventLoopTest.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollEventLoopTest.java index 0fe824b334d4..4e51114422bc 100644 --- a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollEventLoopTest.java +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollEventLoopTest.java @@ -15,53 +15,52 @@ */ package io.netty.channel.epoll; +import io.netty.channel.DefaultSelectStrategyFactory; import io.netty.channel.EventLoop; import io.netty.channel.EventLoopGroup; +import io.netty.util.concurrent.DefaultThreadFactory; import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.RejectedExecutionHandlers; +import io.netty.util.concurrent.ThreadPerTaskExecutor; import org.junit.Test; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; public class EpollEventLoopTest { - @Test(timeout = 5000L) - public void testScheduleBigDelayOverMax() { - EventLoopGroup group = new EpollEventLoopGroup(1); + @Test + public void testScheduleBigDelayNotOverflow() { + final AtomicReference capture = new AtomicReference(); + + final EventLoopGroup group = new EpollEventLoop(null, + new ThreadPerTaskExecutor(new DefaultThreadFactory(getClass())), 0, + DefaultSelectStrategyFactory.INSTANCE.newSelectStrategy(), RejectedExecutionHandlers.reject()) { + @Override + void handleLoopException(Throwable t) { + capture.set(t); + super.handleLoopException(t); + } + }; - final EventLoop el = group.next(); try { - el.schedule(new Runnable() { + final EventLoop eventLoop = group.next(); + Future future = eventLoop.schedule(new Runnable() { @Override public void run() { // NOOP } - }, Integer.MAX_VALUE, TimeUnit.DAYS); - fail(); - } catch (IllegalArgumentException expected) { - // expected - } - - group.shutdownGracefully(); - } + }, Long.MAX_VALUE, TimeUnit.MILLISECONDS); - @Test - public void testScheduleBigDelay() { - EventLoopGroup group = new EpollEventLoopGroup(1); - - final EventLoop el = group.next(); - Future future = el.schedule(new Runnable() { - @Override - public void run() { - // NOOP - } - }, EpollEventLoop.MAX_SCHEDULED_DAYS, TimeUnit.DAYS); - - assertFalse(future.awaitUninterruptibly(1000)); - assertTrue(future.cancel(true)); - group.shutdownGracefully(); + assertFalse(future.awaitUninterruptibly(1000)); + assertTrue(future.cancel(true)); + assertNull(capture.get()); + } finally { + group.shutdownGracefully(); + } } } diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollLTSocketDataReadInitialStateTest.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollLTSocketDataReadInitialStateTest.java new file mode 100644 index 000000000000..78614edd3586 --- /dev/null +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollLTSocketDataReadInitialStateTest.java @@ -0,0 +1,39 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.epoll; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBufAllocator; +import io.netty.testsuite.transport.TestsuitePermutation; +import io.netty.testsuite.transport.socket.SocketDataReadInitialStateTest; + +import java.util.List; + +public class EpollLTSocketDataReadInitialStateTest extends SocketDataReadInitialStateTest { + @Override + protected List> newFactories() { + return EpollSocketTestPermutation.INSTANCE.socket(); + } + + @Override + protected void configure(ServerBootstrap bootstrap, Bootstrap bootstrap2, ByteBufAllocator allocator) { + super.configure(bootstrap, bootstrap2, allocator); + bootstrap.option(EpollChannelOption.EPOLL_MODE, EpollMode.LEVEL_TRIGGERED) + .childOption(EpollChannelOption.EPOLL_MODE, EpollMode.LEVEL_TRIGGERED); + bootstrap2.option(EpollChannelOption.EPOLL_MODE, EpollMode.LEVEL_TRIGGERED); + } +} diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollLTSocketStringEchoBusyWaitTest.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollLTSocketStringEchoBusyWaitTest.java new file mode 100644 index 000000000000..d6098774c0ab --- /dev/null +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollLTSocketStringEchoBusyWaitTest.java @@ -0,0 +1,31 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.epoll; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBufAllocator; + +public class EpollLTSocketStringEchoBusyWaitTest extends EpollSocketStringEchoBusyWaitTest { + + @Override + protected void configure(ServerBootstrap bootstrap, Bootstrap bootstrap2, ByteBufAllocator allocator) { + super.configure(bootstrap, bootstrap2, allocator); + bootstrap.option(EpollChannelOption.EPOLL_MODE, EpollMode.LEVEL_TRIGGERED) + .childOption(EpollChannelOption.EPOLL_MODE, EpollMode.LEVEL_TRIGGERED); + bootstrap2.option(EpollChannelOption.EPOLL_MODE, EpollMode.LEVEL_TRIGGERED); + } +} diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketSslClientRenegotiateTest.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketSslClientRenegotiateTest.java new file mode 100644 index 000000000000..3f69196f12fe --- /dev/null +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketSslClientRenegotiateTest.java @@ -0,0 +1,36 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.epoll; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.handler.ssl.SslContext; +import io.netty.testsuite.transport.TestsuitePermutation; +import io.netty.testsuite.transport.socket.SocketSslClientRenegotiateTest; + +import java.util.List; + +public class EpollSocketSslClientRenegotiateTest extends SocketSslClientRenegotiateTest { + + public EpollSocketSslClientRenegotiateTest(SslContext serverCtx, SslContext clientCtx, boolean delegate) { + super(serverCtx, clientCtx, delegate); + } + + @Override + protected List> newFactories() { + return EpollSocketTestPermutation.INSTANCE.socket(); + } +} diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketSslGreetingTest.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketSslGreetingTest.java index 21d86b4fc7fb..34bf98a150f6 100644 --- a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketSslGreetingTest.java +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketSslGreetingTest.java @@ -25,8 +25,8 @@ public class EpollSocketSslGreetingTest extends SocketSslGreetingTest { - public EpollSocketSslGreetingTest(SslContext serverCtx, SslContext clientCtx) { - super(serverCtx, clientCtx); + public EpollSocketSslGreetingTest(SslContext serverCtx, SslContext clientCtx, boolean delegate) { + super(serverCtx, clientCtx, delegate); } @Override diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketSslSessionReuseTest.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketSslSessionReuseTest.java new file mode 100644 index 000000000000..2b782247c8fe --- /dev/null +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketSslSessionReuseTest.java @@ -0,0 +1,36 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.epoll; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.handler.ssl.SslContext; +import io.netty.testsuite.transport.TestsuitePermutation; +import io.netty.testsuite.transport.socket.SocketSslSessionReuseTest; + +import java.util.List; + +public class EpollSocketSslSessionReuseTest extends SocketSslSessionReuseTest { + + public EpollSocketSslSessionReuseTest(SslContext serverCtx, SslContext clientCtx) { + super(serverCtx, clientCtx); + } + + @Override + protected List> newFactories() { + return EpollSocketTestPermutation.INSTANCE.socket(); + } +} diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketStringEchoBusyWaitTest.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketStringEchoBusyWaitTest.java new file mode 100644 index 000000000000..f67b6da94fa0 --- /dev/null +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketStringEchoBusyWaitTest.java @@ -0,0 +1,101 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.epoll; + +import java.util.ArrayList; +import java.util.List; + +import org.junit.AfterClass; +import org.junit.BeforeClass; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.SelectStrategy; +import io.netty.channel.SelectStrategyFactory; +import io.netty.testsuite.transport.TestsuitePermutation; +import io.netty.testsuite.transport.TestsuitePermutation.BootstrapComboFactory; +import io.netty.testsuite.transport.TestsuitePermutation.BootstrapFactory; +import io.netty.testsuite.transport.socket.SocketStringEchoTest; +import io.netty.util.IntSupplier; +import io.netty.util.concurrent.DefaultThreadFactory; + +public class EpollSocketStringEchoBusyWaitTest extends SocketStringEchoTest { + + private static EventLoopGroup EPOLL_LOOP; + + @BeforeClass + public static void setup() throws Exception { + EPOLL_LOOP = new EpollEventLoopGroup(2, new DefaultThreadFactory("testsuite-epoll-busy-wait", true), + new SelectStrategyFactory() { + @Override + public SelectStrategy newSelectStrategy() { + return new SelectStrategy() { + @Override + public int calculateStrategy(IntSupplier selectSupplier, boolean hasTasks) { + return SelectStrategy.BUSY_WAIT; + } + }; + } + }); + } + + @AfterClass + public static void teardown() throws Exception { + if (EPOLL_LOOP != null) { + EPOLL_LOOP.shutdownGracefully(); + } + } + + @Override + protected List> newFactories() { + List> list = + new ArrayList>(); + final BootstrapFactory sbf = serverSocket(); + final BootstrapFactory cbf = clientSocket(); + list.add(new BootstrapComboFactory() { + @Override + public ServerBootstrap newServerInstance() { + return sbf.newInstance(); + } + + @Override + public Bootstrap newClientInstance() { + return cbf.newInstance(); + } + }); + + return list; + } + + private static BootstrapFactory serverSocket() { + return new BootstrapFactory() { + @Override + public ServerBootstrap newInstance() { + return new ServerBootstrap().group(EPOLL_LOOP, EPOLL_LOOP).channel(EpollServerSocketChannel.class); + } + }; + } + + private static BootstrapFactory clientSocket() { + return new BootstrapFactory() { + @Override + public Bootstrap newInstance() { + return new Bootstrap().group(EPOLL_LOOP).channel(EpollSocketChannel.class); + } + }; + } +} diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSpliceTest.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSpliceTest.java index 8566a5f009ba..79962285b85a 100644 --- a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSpliceTest.java +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSpliceTest.java @@ -296,7 +296,7 @@ private static class SpliceHandler extends ChannelInboundHandlerAdapter { volatile ChannelFuture future; final AtomicReference exception = new AtomicReference(); - public SpliceHandler(File file) { + SpliceHandler(File file) { this.file = file; } diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollTest.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollTest.java index a2d4fb2d407a..5a9cb19a1ff7 100644 --- a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollTest.java +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollTest.java @@ -15,13 +15,58 @@ */ package io.netty.channel.epoll; -import org.junit.Assert; +import io.netty.channel.unix.FileDescriptor; import org.junit.Test; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + public class EpollTest { @Test public void testIsAvailable() { - Assert.assertTrue(Epoll.isAvailable()); + assertTrue(Epoll.isAvailable()); + } + + // Testcase for https://github.com/netty/netty/issues/8444 + @Test(timeout = 5000) + public void testEpollWaitWithTimeOutMinusOne() throws Exception { + final EpollEventArray eventArray = new EpollEventArray(8); + try { + final FileDescriptor epoll = Native.newEpollCreate(); + final FileDescriptor timerFd = Native.newTimerFd(); + final FileDescriptor eventfd = Native.newEventFd(); + Native.epollCtlAdd(epoll.intValue(), timerFd.intValue(), Native.EPOLLIN); + Native.epollCtlAdd(epoll.intValue(), eventfd.intValue(), Native.EPOLLIN); + + final AtomicReference ref = new AtomicReference(); + Thread t = new Thread(new Runnable() { + @Override + public void run() { + try { + assertEquals(1, Native.epollWait(epoll, eventArray, timerFd, -1, -1)); + // This should have been woken up because of eventfd_write. + assertEquals(eventfd.intValue(), eventArray.fd(0)); + } catch (Throwable cause) { + ref.set(cause); + } + } + }); + t.start(); + t.join(1000); + assertTrue(t.isAlive()); + Native.eventFdWrite(eventfd.intValue(), 1); + + t.join(); + assertNull(ref.get()); + epoll.close(); + timerFd.close(); + eventfd.close(); + } finally { + eventArray.free(); + } } } diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollWriteBeforeRegisteredTest.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollWriteBeforeRegisteredTest.java new file mode 100644 index 000000000000..b1943fd2f804 --- /dev/null +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollWriteBeforeRegisteredTest.java @@ -0,0 +1,30 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.epoll; + +import io.netty.bootstrap.Bootstrap; +import io.netty.testsuite.transport.TestsuitePermutation; +import io.netty.testsuite.transport.socket.WriteBeforeRegisteredTest; + +import java.util.List; + +public class EpollWriteBeforeRegisteredTest extends WriteBeforeRegisteredTest { + + @Override + protected List> newFactories() { + return EpollSocketTestPermutation.INSTANCE.clientSocket(); + } +} diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/LibAIOTest.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/LibAIOTest.java index 9c1575709f19..da1bb7957834 100644 --- a/transport-native-epoll/src/test/java/io/netty/channel/epoll/LibAIOTest.java +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/LibAIOTest.java @@ -531,6 +531,7 @@ static ByteBuffer allocateAlignedByteBuffer(int capacity, long align) { } } + @SuppressForbidden(reason = "libaio") static void freeAlignedByteBuffer(ByteBuffer buffer) { PlatformDependent.freeDirectNoCleaner((ByteBuffer) ((DirectBuffer) buffer).attachment()); } diff --git a/transport-native-kqueue/pom.xml b/transport-native-kqueue/pom.xml index f0bfc2e2da5a..fbb9284aa6bd 100644 --- a/transport-native-kqueue/pom.xml +++ b/transport-native-kqueue/pom.xml @@ -19,7 +19,7 @@ io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-transport-native-kqueue @@ -80,17 +80,16 @@ ${jni.compiler.args.ldflags} ${jni.compiler.args.cflags} - - MACOSX_DEPLOYMENT_TARGET=10.2 + MACOSX_DEPLOYMENT_TARGET=10.6 generate build - compile @@ -198,7 +197,6 @@ generate build - compile @@ -305,7 +303,6 @@ generate build - compile @@ -361,7 +358,7 @@ ${project.build.directory}/unix-common-lib ${unix.common.lib.dir}/META-INF/native/lib ${unix.common.lib.dir}/META-INF/native/include - CFLAGS=-O3 -Werror -fno-omit-frame-pointer -Wunused-variable -I${unix.common.include.unpacked.dir} + CFLAGS=-O3 -Werror -fno-omit-frame-pointer -Wunused-variable -fvisibility=hidden -I${unix.common.include.unpacked.dir} LDFLAGS=-z now -L${unix.common.lib.unpacked.dir} -Wl,--whole-archive -l${unix.common.lib.name} -Wl,--no-whole-archive true @@ -379,12 +376,12 @@ io.netty - netty-transport-native-unix-common + netty-transport ${project.version} io.netty - netty-transport + netty-transport-native-unix-common ${project.version} diff --git a/transport-native-kqueue/src/main/c/netty_kqueue_eventarray.c b/transport-native-kqueue/src/main/c/netty_kqueue_eventarray.c index ad7a20d6546a..dd38cf544c75 100644 --- a/transport-native-kqueue/src/main/c/netty_kqueue_eventarray.c +++ b/transport-native-kqueue/src/main/c/netty_kqueue_eventarray.c @@ -24,104 +24,26 @@ #include "netty_unix_jni.h" #include "netty_unix_util.h" -jfieldID kqueueJniPtrFieldId = NULL; - -static void netty_kqueue_eventarray_evSet(JNIEnv* env, jclass clzz, jlong keventAddress, jobject channel, jint ident, jshort filter, jshort flags, jint fflags) { - // Create a global pointer, cast it as a long, and retain it in java to re-use and free later. - jlong jniSelfPtr = (*env)->GetLongField(env, channel, kqueueJniPtrFieldId); - if (jniSelfPtr == 0) { - jniSelfPtr = (jlong) (*env)->NewGlobalRef(env, channel); - (*env)->SetLongField(env, channel, kqueueJniPtrFieldId, jniSelfPtr); - } else if ((flags & EV_DELETE) != 0) { - // If the event is deleted, make sure it no longer has a reference to the jniSelfPtr because it shouldn't be used after this point. - jniSelfPtr = 0; - } - EV_SET((struct kevent*) keventAddress, ident, filter, flags, fflags, 0, (jobject) jniSelfPtr); -} - -static jobject netty_kqueue_eventarray_getChannel(JNIEnv* env, jclass clazz, jlong keventAddress) { - struct kevent* event = (struct kevent*) keventAddress; - return event->udata == NULL ? NULL : (jobject) event->udata; -} - -static void netty_kqueue_eventarray_deleteGlobalRefs(JNIEnv* env, jclass clazz, jlong channelAddressStart, jlong channelAddressEnd) { - // Iterate over an array of longs, which are really pointers to the jobject NewGlobalRef created above in evSet - // and delete each one. The field has already been set to 0 in java. - jlong* itr = (jlong*) channelAddressStart; - const jlong* end = (jlong*) channelAddressEnd; - for (; itr != end; ++itr) { - (*env)->DeleteGlobalRef(env, (jobject) *itr); - } +static void netty_kqueue_eventarray_evSet(JNIEnv* env, jclass clzz, jlong keventAddress, jint ident, jshort filter, jshort flags, jint fflags) { + EV_SET((struct kevent*) keventAddress, ident, filter, flags, fflags, 0, NULL); } // JNI Method Registration Table Begin static const JNINativeMethod fixed_method_table[] = { - { "deleteGlobalRefs", "(JJ)V", (void *) netty_kqueue_eventarray_deleteGlobalRefs } - // "evSet" has a dynamic signature - // "getChannel" has a dynamic signature + { "evSet", "(JISSI)V", (void *) netty_kqueue_eventarray_evSet } }; static const jint fixed_method_table_size = sizeof(fixed_method_table) / sizeof(fixed_method_table[0]); -static jint dynamicMethodsTableSize() { - return fixed_method_table_size + 2; -} - -static JNINativeMethod* createDynamicMethodsTable(const char* packagePrefix) { - JNINativeMethod* dynamicMethods = malloc(sizeof(JNINativeMethod) * dynamicMethodsTableSize()); - memcpy(dynamicMethods, fixed_method_table, sizeof(fixed_method_table)); - char* dynamicTypeName = netty_unix_util_prepend(packagePrefix, "io/netty/channel/kqueue/AbstractKQueueChannel;ISSI)V"); - JNINativeMethod* dynamicMethod = &dynamicMethods[fixed_method_table_size]; - dynamicMethod->name = "evSet"; - dynamicMethod->signature = netty_unix_util_prepend("(JL", dynamicTypeName); - dynamicMethod->fnPtr = (void *) netty_kqueue_eventarray_evSet; - free(dynamicTypeName); - - ++dynamicMethod; - dynamicTypeName = netty_unix_util_prepend(packagePrefix, "io/netty/channel/kqueue/AbstractKQueueChannel;"); - dynamicMethod->name = "getChannel"; - dynamicMethod->signature = netty_unix_util_prepend("(J)L", dynamicTypeName); - dynamicMethod->fnPtr = (void *) netty_kqueue_eventarray_getChannel; - free(dynamicTypeName); - return dynamicMethods; -} - -static void freeDynamicMethodsTable(JNINativeMethod* dynamicMethods) { - jint fullMethodTableSize = dynamicMethodsTableSize(); - jint i = fixed_method_table_size; - for (; i < fullMethodTableSize; ++i) { - free(dynamicMethods[i].signature); - } - free(dynamicMethods); -} // JNI Method Registration Table End jint netty_kqueue_eventarray_JNI_OnLoad(JNIEnv* env, const char* packagePrefix) { - JNINativeMethod* dynamicMethods = createDynamicMethodsTable(packagePrefix); if (netty_unix_util_register_natives(env, packagePrefix, "io/netty/channel/kqueue/KQueueEventArray", - dynamicMethods, - dynamicMethodsTableSize()) != 0) { - freeDynamicMethodsTable(dynamicMethods); - return JNI_ERR; - } - freeDynamicMethodsTable(dynamicMethods); - dynamicMethods = NULL; - - char* nettyClassName = netty_unix_util_prepend(packagePrefix, "io/netty/channel/kqueue/AbstractKQueueChannel"); - jclass kqueueChannelCls = (*env)->FindClass(env, nettyClassName); - free(nettyClassName); - nettyClassName = NULL; - if (kqueueChannelCls == NULL) { + fixed_method_table, + fixed_method_table_size) != 0) { return JNI_ERR; } - - kqueueJniPtrFieldId = (*env)->GetFieldID(env, kqueueChannelCls, "jniSelfPtr", "J"); - if (kqueueJniPtrFieldId == NULL) { - netty_unix_errors_throwRuntimeException(env, "failed to get field ID: AbstractKQueueChannel.jniSelfPtr"); - return JNI_ERR; - } - return NETTY_JNI_VERSION; } diff --git a/transport-native-kqueue/src/main/c/netty_kqueue_native.c b/transport-native-kqueue/src/main/c/netty_kqueue_native.c index 94899e037e48..a2a5a874a23d 100644 --- a/transport-native-kqueue/src/main/c/netty_kqueue_native.c +++ b/transport-native-kqueue/src/main/c/netty_kqueue_native.c @@ -28,6 +28,7 @@ #include "netty_kqueue_bsdsocket.h" #include "netty_kqueue_eventarray.h" +#include "netty_unix_buffer.h" #include "netty_unix_errors.h" #include "netty_unix_filedescriptor.h" #include "netty_unix_jni.h" @@ -65,7 +66,7 @@ #endif /* NOTE_DISCONNECTED */ #endif /* __APPLE__ */ -clockid_t waitClockId = 0; // initialized by netty_unix_util_initialize_wait_clock +static clockid_t waitClockId = 0; // initialized by netty_unix_util_initialize_wait_clock static jint netty_kqueue_native_kqueueCreate(JNIEnv* env, jclass clazz) { jint kq = kqueue(); @@ -268,45 +269,93 @@ static const jint fixed_method_table_size = sizeof(fixed_method_table) / sizeof( // JNI Method Registration Table End static jint netty_kqueue_native_JNI_OnLoad(JNIEnv* env, const char* packagePrefix) { + int limitsOnLoadCalled = 0; + int errorsOnLoadCalled = 0; + int filedescriptorOnLoadCalled = 0; + int socketOnLoadCalled = 0; + int bufferOnLoadCalled = 0; + int bsdsocketOnLoadCalled = 0; + int eventarrayOnLoadCalled = 0; + // We must register the statically referenced methods first! if (netty_unix_util_register_natives(env, packagePrefix, "io/netty/channel/kqueue/KQueueStaticallyReferencedJniMethods", statically_referenced_fixed_method_table, statically_referenced_fixed_method_table_size) != 0) { - return JNI_ERR; + goto error; } // Register the methods which are not referenced by static member variables if (netty_unix_util_register_natives(env, packagePrefix, "io/netty/channel/kqueue/Native", fixed_method_table, fixed_method_table_size) != 0) { - return JNI_ERR; + goto error; } // Load all c modules that we depend upon if (netty_unix_limits_JNI_OnLoad(env, packagePrefix) == JNI_ERR) { - return JNI_ERR; + goto error; } + limitsOnLoadCalled = 1; + if (netty_unix_errors_JNI_OnLoad(env, packagePrefix) == JNI_ERR) { - return JNI_ERR; + goto error; } + errorsOnLoadCalled = 1; + if (netty_unix_filedescriptor_JNI_OnLoad(env, packagePrefix) == JNI_ERR) { - return JNI_ERR; + goto error; } + filedescriptorOnLoadCalled = 1; + if (netty_unix_socket_JNI_OnLoad(env, packagePrefix) == JNI_ERR) { - return JNI_ERR; + goto error; + } + socketOnLoadCalled = 1; + + if (netty_unix_buffer_JNI_OnLoad(env, packagePrefix) == JNI_ERR) { + goto error; } + bufferOnLoadCalled = 1; + if (netty_kqueue_bsdsocket_JNI_OnLoad(env, packagePrefix) == JNI_ERR) { - return JNI_ERR; + goto error; } + bsdsocketOnLoadCalled = 1; + if (netty_kqueue_eventarray_JNI_OnLoad(env, packagePrefix) == JNI_ERR) { - return JNI_ERR; + goto error; } + eventarrayOnLoadCalled = 1; + // Initialize this module if (!netty_unix_util_initialize_wait_clock(&waitClockId)) { - fprintf(stderr, "FATAL: could not find a clock for clock_gettime!\n"); - return JNI_ERR; + fprintf(stderr, "FATAL: could not find a clock for clock_gettime!\n"); + goto error; } return NETTY_JNI_VERSION; +error: + if (limitsOnLoadCalled == 1) { + netty_unix_limits_JNI_OnUnLoad(env); + } + if (errorsOnLoadCalled == 1) { + netty_unix_errors_JNI_OnUnLoad(env); + } + if (filedescriptorOnLoadCalled == 1) { + netty_unix_filedescriptor_JNI_OnUnLoad(env); + } + if (socketOnLoadCalled == 1) { + netty_unix_socket_JNI_OnUnLoad(env); + } + if (bufferOnLoadCalled == 1) { + netty_unix_buffer_JNI_OnUnLoad(env); + } + if (bsdsocketOnLoadCalled == 1) { + netty_kqueue_bsdsocket_JNI_OnUnLoad(env); + } + if (eventarrayOnLoadCalled == 1) { + netty_kqueue_eventarray_JNI_OnUnLoad(env); + } + return JNI_ERR; } static void netty_kqueue_native_JNI_OnUnLoad(JNIEnv* env) { @@ -314,12 +363,12 @@ static void netty_kqueue_native_JNI_OnUnLoad(JNIEnv* env) { netty_unix_errors_JNI_OnUnLoad(env); netty_unix_filedescriptor_JNI_OnUnLoad(env); netty_unix_socket_JNI_OnUnLoad(env); + netty_unix_buffer_JNI_OnUnLoad(env); netty_kqueue_bsdsocket_JNI_OnUnLoad(env); netty_kqueue_eventarray_JNI_OnUnLoad(env); } -// Invoked by the JVM when statically linked -jint JNI_OnLoad_netty_transport_native_kqueue(JavaVM* vm, void* reserved) { +static jint JNI_OnLoad_netty_transport_native_kqueue0(JavaVM* vm, void* reserved) { JNIEnv* env; if ((*vm)->GetEnv(vm, (void**) &env, NETTY_JNI_VERSION) != JNI_OK) { return JNI_ERR; @@ -351,14 +400,7 @@ jint JNI_OnLoad_netty_transport_native_kqueue(JavaVM* vm, void* reserved) { return ret; } -#ifndef NETTY_BUILD_STATIC -JNIEXPORT jint JNI_OnLoad(JavaVM* vm, void* reserved) { - return JNI_OnLoad_netty_transport_native_kqueue(vm, reserved); -} -#endif /* NETTY_BUILD_STATIC */ - -// Invoked by the JVM when statically linked -void JNI_OnUnload_netty_transport_native_kqueue(JavaVM* vm, void* reserved) { +static void JNI_OnUnload_netty_transport_native_kqueue0(JavaVM* vm, void* reserved) { JNIEnv* env; if ((*vm)->GetEnv(vm, (void**) &env, NETTY_JNI_VERSION) != JNI_OK) { // Something is wrong but nothing we can do about this :( @@ -367,8 +409,25 @@ void JNI_OnUnload_netty_transport_native_kqueue(JavaVM* vm, void* reserved) { netty_kqueue_native_JNI_OnUnLoad(env); } +// We build with -fvisibility=hidden so ensure we mark everything that needs to be visible with JNIEXPORT +// http://mail.openjdk.java.net/pipermail/core-libs-dev/2013-February/014549.html + +// Invoked by the JVM when statically linked +JNIEXPORT jint JNI_OnLoad_netty_transport_native_kqueue(JavaVM* vm, void* reserved) { + return JNI_OnLoad_netty_transport_native_kqueue0(vm, reserved); +} + +// Invoked by the JVM when statically linked +JNIEXPORT void JNI_OnUnload_netty_transport_native_kqueue(JavaVM* vm, void* reserved) { + JNI_OnUnload_netty_transport_native_kqueue0(vm, reserved); +} + #ifndef NETTY_BUILD_STATIC +JNIEXPORT jint JNI_OnLoad(JavaVM* vm, void* reserved) { + return JNI_OnLoad_netty_transport_native_kqueue0(vm, reserved); +} + JNIEXPORT void JNI_OnUnload(JavaVM* vm, void* reserved) { - return JNI_OnUnload_netty_transport_native_kqueue(vm, reserved); + return JNI_OnUnload_netty_transport_native_kqueue0(vm, reserved); } #endif /* NETTY_BUILD_STATIC */ diff --git a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueChannel.java b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueChannel.java index 0a6eb59f7a56..ff5285794fa4 100644 --- a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueChannel.java +++ b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueChannel.java @@ -65,19 +65,10 @@ abstract class AbstractKQueueChannel extends AbstractChannel implements UnixChan private SocketAddress requestedRemoteAddress; final BsdSocket socket; - private boolean readFilterEnabled = true; + private boolean readFilterEnabled; private boolean writeFilterEnabled; boolean readReadyRunnablePending; boolean inputClosedSeenErrorOnRead; - /** - * This member variable means we don't have to have a map in {@link KQueueEventLoop} which associates the FDs - * from kqueue to instances of this class. This field will be initialized by JNI when modifying kqueue events. - * If there is no global reference when JNI gets a kqueue evSet call (aka this field is 0) then a global reference - * will be created and the address will be saved in this member variable. Then when we process a kevent in Java - * we can ask JNI to give us the {@link AbstractKQueueChannel} that corresponds to that event. - */ - long jniSelfPtr; - protected volatile boolean active; private volatile SocketAddress local; private volatile SocketAddress remote; @@ -133,35 +124,7 @@ protected void doClose() throws Exception { // Even if we allow half closed sockets we should give up on reading. Otherwise we may allow a read attempt on a // socket which has not even been connected yet. This has been observed to block during unit tests. inputClosedSeenErrorOnRead = true; - try { - if (isRegistered()) { - // The FD will be closed, which should take care of deleting any associated events from kqueue, but - // since we rely upon jniSelfRef to be consistent we make sure that we clear this reference out for - // all events which are pending in kqueue to avoid referencing a deleted pointer at a later time. - - // Need to check if we are on the EventLoop as doClose() may be triggered by the GlobalEventExecutor - // if SO_LINGER is used. - // - // See https://github.com/netty/netty/issues/7159 - EventLoop loop = eventLoop(); - if (loop.inEventLoop()) { - doDeregister(); - } else { - loop.execute(new Runnable() { - @Override - public void run() { - try { - doDeregister(); - } catch (Throwable cause) { - pipeline().fireExceptionCaught(cause); - } - } - }); - } - } - } finally { - socket.close(); - } + socket.close(); } @Override @@ -187,9 +150,6 @@ protected void doDeregister() throws Exception { evSet0(Native.EVFILT_SOCK, Native.EV_DELETE, 0); ((KQueueEventLoop) eventLoop()).remove(this); - - // Set the filters back to the initial state in case this channel is registered with another event loop. - readFilterEnabled = true; } @Override @@ -216,6 +176,9 @@ protected void doRegister() throws Exception { // make sure the readReadyRunnablePending variable is reset so we will be able to execute the Runnable on the // new EventLoop. readReadyRunnablePending = false; + + ((KQueueEventLoop) eventLoop()).add(this); + // Add the write event first so we get notified of connection refused on the client side! if (writeFilterEnabled) { evSet0(Native.EVFILT_WRITE, Native.EV_ADD_CLEAR_ENABLE); @@ -401,19 +364,14 @@ final void readReady(long numberBytesPending) { abstract void readReady(KQueueRecvByteAllocatorHandle allocHandle); - final void readReadyBefore() { maybeMoreDataToRead = false; } + final void readReadyBefore() { + maybeMoreDataToRead = false; + } final void readReadyFinally(ChannelConfig config) { maybeMoreDataToRead = allocHandle.maybeMoreDataToRead(); - // Check if there is a readPending which was not processed yet. - // This could be for two reasons: - // * The user called Channel.read() or ChannelHandlerContext.read() in channelRead(...) method - // * The user called Channel.read() or ChannelHandlerContext.read() in channelReadComplete(...) method - // - // See https://github.com/netty/netty/issues/2254 - if (!readPending && !config.isAutoRead()) { - clearReadFilter0(); - } else if (readPending && maybeMoreDataToRead) { + + if (allocHandle.isReadEOF() || (readPending && maybeMoreDataToRead)) { // trigger a read again as there may be something left to read and because of ET we // will not get notified again until we read everything from the socket // @@ -422,6 +380,14 @@ final void readReadyFinally(ChannelConfig config) { // to false before every read operation to prevent re-entry into readReady() we will not read from // the underlying OS again unless the user happens to call read again. executeReadReadyRunnable(config); + } else if (!readPending && !config.isAutoRead()) { + // Check if there is a readPending which was not processed yet. + // This could be for two reasons: + // * The user called Channel.read() or ChannelHandlerContext.read() in channelRead(...) method + // * The user called Channel.read() or ChannelHandlerContext.read() in channelReadComplete(...) method + // + // See https://github.com/netty/netty/issues/2254 + clearReadFilter0(); } } diff --git a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueStreamChannel.java b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueStreamChannel.java index 46f4fc72ee2c..4b7f4e06e5b2 100644 --- a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueStreamChannel.java +++ b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueStreamChannel.java @@ -33,7 +33,6 @@ import io.netty.channel.unix.IovArray; import io.netty.channel.unix.SocketWritableByteChannel; import io.netty.channel.unix.UnixChannelUtil; -import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.StringUtil; import io.netty.util.internal.UnstableApi; import io.netty.util.internal.logging.InternalLogger; @@ -211,12 +210,13 @@ private int writeBytesMultiple( */ private int writeDefaultFileRegion(ChannelOutboundBuffer in, DefaultFileRegion region) throws Exception { final long regionCount = region.count(); - if (region.transferred() >= regionCount) { + final long offset = region.transferred(); + + if (offset >= regionCount) { in.remove(); return 0; } - final long offset = region.transferred(); final long flushedAmount = socket.sendFile(region, region.position(), offset, regionCount - offset); if (flushedAmount > 0) { in.progress(flushedAmount); @@ -224,6 +224,8 @@ private int writeDefaultFileRegion(ChannelOutboundBuffer in, DefaultFileRegion r in.remove(); } return 1; + } else if (flushedAmount == 0) { + validateFileRegion(region, offset); } return WRITE_STATUS_SNDBUF_FULL; } @@ -345,22 +347,13 @@ protected int doWriteSingle(ChannelOutboundBuffer in) throws Exception { */ private int doWriteMultiple(ChannelOutboundBuffer in) throws Exception { final long maxBytesPerGatheringWrite = config().getMaxBytesPerGatheringWrite(); - if (PlatformDependent.hasUnsafe()) { - IovArray array = ((KQueueEventLoop) eventLoop()).cleanArray(); - array.maxBytes(maxBytesPerGatheringWrite); - in.forEachFlushedMessage(array); - - if (array.count() >= 1) { - // TODO: Handle the case where cnt == 1 specially. - return writeBytesMultiple(in, array); - } - } else { - ByteBuffer[] buffers = in.nioBuffers(); - int cnt = in.nioBufferCount(); - if (cnt >= 1) { - // TODO: Handle the case where cnt == 1 specially. - return writeBytesMultiple(in, buffers, cnt, in.nioBufferSize(), maxBytesPerGatheringWrite); - } + IovArray array = ((KQueueEventLoop) eventLoop()).cleanArray(); + array.maxBytes(maxBytesPerGatheringWrite); + in.forEachFlushedMessage(array); + + if (array.count() >= 1) { + // TODO: Handle the case where cnt == 1 specially. + return writeBytesMultiple(in, array); } // cnt == 0, which means the outbound buffer contained empty buffers only. in.removeBytes(0); diff --git a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueue.java b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueue.java index b5ec5a51bb92..b5a772fab799 100644 --- a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueue.java +++ b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueue.java @@ -16,7 +16,6 @@ package io.netty.channel.kqueue; import io.netty.channel.unix.FileDescriptor; -import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.SystemPropertyUtil; import io.netty.util.internal.UnstableApi; @@ -48,15 +47,7 @@ public final class KQueue { } } - if (cause != null) { - UNAVAILABILITY_CAUSE = cause; - } else { - UNAVAILABILITY_CAUSE = PlatformDependent.hasUnsafe() - ? null - : new IllegalStateException( - "sun.misc.Unsafe not available", - PlatformDependent.getUnsafeUnavailabilityCause()); - } + UNAVAILABILITY_CAUSE = cause; } /** diff --git a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueChannelConfig.java b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueChannelConfig.java index 878663c5e746..56748e730cd7 100644 --- a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueChannelConfig.java +++ b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueChannelConfig.java @@ -31,13 +31,11 @@ @UnstableApi public class KQueueChannelConfig extends DefaultChannelConfig { - final AbstractKQueueChannel channel; private volatile boolean transportProvidesGuess; private volatile long maxBytesPerGatheringWrite = SSIZE_MAX; KQueueChannelConfig(AbstractKQueueChannel channel) { super(channel); - this.channel = channel; } @Override @@ -69,7 +67,7 @@ public boolean setOption(ChannelOption option, T value) { } /** - * If this is {@code true} then the {@link RecvByteBufAllocator.Handle#guess()} will be overriden to always attempt + * If this is {@code true} then the {@link RecvByteBufAllocator.Handle#guess()} will be overridden to always attempt * to read as many bytes as kqueue says are available. */ public KQueueChannelConfig setRcvAllocTransportProvidesGuess(boolean transportProvidesGuess) { @@ -78,7 +76,7 @@ public KQueueChannelConfig setRcvAllocTransportProvidesGuess(boolean transportPr } /** - * If this is {@code true} then the {@link RecvByteBufAllocator.Handle#guess()} will be overriden to always attempt + * If this is {@code true} then the {@link RecvByteBufAllocator.Handle#guess()} will be overridden to always attempt * to read as many bytes as kqueue says are available. */ public boolean getRcvAllocTransportProvidesGuess() { @@ -154,7 +152,7 @@ public KQueueChannelConfig setMessageSizeEstimator(MessageSizeEstimator estimato @Override protected final void autoReadCleared() { - channel.clearReadFilter(); + ((AbstractKQueueChannel) channel).clearReadFilter(); } final void setMaxBytesPerGatheringWrite(long maxBytesPerGatheringWrite) { diff --git a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueChannelOption.java b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueChannelOption.java index be5dc4e887aa..e70924789e19 100644 --- a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueChannelOption.java +++ b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueChannelOption.java @@ -27,7 +27,7 @@ public final class KQueueChannelOption extends UnixChannelOption { public static final ChannelOption SO_ACCEPTFILTER = valueOf(KQueueChannelOption.class, "SO_ACCEPTFILTER"); /** - * If this is {@code true} then the {@link RecvByteBufAllocator.Handle#guess()} will be overriden to always attempt + * If this is {@code true} then the {@link RecvByteBufAllocator.Handle#guess()} will be overridden to always attempt * to read as many bytes as kqueue says are available. */ public static final ChannelOption RCV_ALLOC_TRANSPORT_PROVIDES_GUESS = diff --git a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannel.java b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannel.java index 887951a28f77..b2925c58c2c4 100644 --- a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannel.java +++ b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannel.java @@ -28,6 +28,7 @@ import io.netty.channel.socket.DatagramChannelConfig; import io.netty.channel.socket.DatagramPacket; import io.netty.channel.unix.DatagramSocketAddress; +import io.netty.channel.unix.Errors; import io.netty.channel.unix.IovArray; import io.netty.channel.unix.UnixChannelUtil; import io.netty.util.internal.StringUtil; @@ -37,6 +38,7 @@ import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.NetworkInterface; +import java.net.PortUnreachableException; import java.net.SocketAddress; import java.net.SocketException; import java.nio.ByteBuffer; @@ -420,41 +422,72 @@ void readReady(KQueueRecvByteAllocatorHandle allocHandle) { Throwable exception = null; try { - ByteBuf data = null; + ByteBuf byteBuf = null; try { + boolean connected = isConnected(); do { - data = allocHandle.allocate(allocator); - allocHandle.attemptedBytesRead(data.writableBytes()); - final DatagramSocketAddress remoteAddress; - if (data.hasMemoryAddress()) { - // has a memory address so use optimized call - remoteAddress = socket.recvFromAddress(data.memoryAddress(), data.writerIndex(), - data.capacity()); + byteBuf = allocHandle.allocate(allocator); + allocHandle.attemptedBytesRead(byteBuf.writableBytes()); + + final DatagramPacket packet; + if (connected) { + try { + allocHandle.lastBytesRead(doReadBytes(byteBuf)); + } catch (Errors.NativeIoException e) { + // We need to correctly translate connect errors to match NIO behaviour. + if (e.expectedErr() == Errors.ERROR_ECONNREFUSED_NEGATIVE) { + PortUnreachableException error = new PortUnreachableException(e.getMessage()); + error.initCause(e); + throw error; + } + throw e; + } + if (allocHandle.lastBytesRead() <= 0) { + // nothing was read, release the buffer. + byteBuf.release(); + byteBuf = null; + break; + } + packet = new DatagramPacket(byteBuf, + (InetSocketAddress) localAddress(), (InetSocketAddress) remoteAddress()); } else { - ByteBuffer nioData = data.internalNioBuffer(data.writerIndex(), data.writableBytes()); - remoteAddress = socket.recvFrom(nioData, nioData.position(), nioData.limit()); - } - - if (remoteAddress == null) { - allocHandle.lastBytesRead(-1); - data.release(); - data = null; - break; + final DatagramSocketAddress remoteAddress; + if (byteBuf.hasMemoryAddress()) { + // has a memory address so use optimized call + remoteAddress = socket.recvFromAddress(byteBuf.memoryAddress(), byteBuf.writerIndex(), + byteBuf.capacity()); + } else { + ByteBuffer nioData = byteBuf.internalNioBuffer( + byteBuf.writerIndex(), byteBuf.writableBytes()); + remoteAddress = socket.recvFrom(nioData, nioData.position(), nioData.limit()); + } + + if (remoteAddress == null) { + allocHandle.lastBytesRead(-1); + byteBuf.release(); + byteBuf = null; + break; + } + InetSocketAddress localAddress = remoteAddress.localAddress(); + if (localAddress == null) { + localAddress = (InetSocketAddress) localAddress(); + } + allocHandle.lastBytesRead(remoteAddress.receivedAmount()); + byteBuf.writerIndex(byteBuf.writerIndex() + allocHandle.lastBytesRead()); + + packet = new DatagramPacket(byteBuf, localAddress, remoteAddress); } allocHandle.incMessagesRead(1); - allocHandle.lastBytesRead(remoteAddress.receivedAmount()); - data.writerIndex(data.writerIndex() + allocHandle.lastBytesRead()); readPending = false; - pipeline.fireChannelRead( - new DatagramPacket(data, (InetSocketAddress) localAddress(), remoteAddress)); + pipeline.fireChannelRead(packet); - data = null; + byteBuf = null; } while (allocHandle.continueReading()); } catch (Throwable t) { - if (data != null) { - data.release(); + if (byteBuf != null) { + byteBuf.release(); } exception = t; } diff --git a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannelConfig.java b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannelConfig.java index c64417485b96..478d5544d164 100644 --- a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannelConfig.java +++ b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannelConfig.java @@ -45,12 +45,10 @@ @UnstableApi public final class KQueueDatagramChannelConfig extends KQueueChannelConfig implements DatagramChannelConfig { private static final RecvByteBufAllocator DEFAULT_RCVBUF_ALLOCATOR = new FixedRecvByteBufAllocator(2048); - private final KQueueDatagramChannel datagramChannel; private boolean activeOnOpen; KQueueDatagramChannelConfig(KQueueDatagramChannel channel) { super(channel); - this.datagramChannel = channel; setRecvByteBufAllocator(DEFAULT_RCVBUF_ALLOCATOR); } @@ -153,7 +151,7 @@ boolean getActiveOnOpen() { */ public boolean isReusePort() { try { - return datagramChannel.socket.isReusePort(); + return ((KQueueDatagramChannel) channel).socket.isReusePort(); } catch (IOException e) { throw new ChannelException(e); } @@ -168,7 +166,7 @@ public boolean isReusePort() { */ public KQueueDatagramChannelConfig setReusePort(boolean reusePort) { try { - datagramChannel.socket.setReusePort(reusePort); + ((KQueueDatagramChannel) channel).socket.setReusePort(reusePort); return this; } catch (IOException e) { throw new ChannelException(e); @@ -253,7 +251,7 @@ public KQueueDatagramChannelConfig setMaxMessagesPerRead(int maxMessagesPerRead) @Override public int getSendBufferSize() { try { - return datagramChannel.socket.getSendBufferSize(); + return ((KQueueDatagramChannel) channel).socket.getSendBufferSize(); } catch (IOException e) { throw new ChannelException(e); } @@ -262,7 +260,7 @@ public int getSendBufferSize() { @Override public KQueueDatagramChannelConfig setSendBufferSize(int sendBufferSize) { try { - datagramChannel.socket.setSendBufferSize(sendBufferSize); + ((KQueueDatagramChannel) channel).socket.setSendBufferSize(sendBufferSize); return this; } catch (IOException e) { throw new ChannelException(e); @@ -272,7 +270,7 @@ public KQueueDatagramChannelConfig setSendBufferSize(int sendBufferSize) { @Override public int getReceiveBufferSize() { try { - return datagramChannel.socket.getReceiveBufferSize(); + return ((KQueueDatagramChannel) channel).socket.getReceiveBufferSize(); } catch (IOException e) { throw new ChannelException(e); } @@ -281,7 +279,7 @@ public int getReceiveBufferSize() { @Override public KQueueDatagramChannelConfig setReceiveBufferSize(int receiveBufferSize) { try { - datagramChannel.socket.setReceiveBufferSize(receiveBufferSize); + ((KQueueDatagramChannel) channel).socket.setReceiveBufferSize(receiveBufferSize); return this; } catch (IOException e) { throw new ChannelException(e); @@ -291,7 +289,7 @@ public KQueueDatagramChannelConfig setReceiveBufferSize(int receiveBufferSize) { @Override public int getTrafficClass() { try { - return datagramChannel.socket.getTrafficClass(); + return ((KQueueDatagramChannel) channel).socket.getTrafficClass(); } catch (IOException e) { throw new ChannelException(e); } @@ -300,7 +298,7 @@ public int getTrafficClass() { @Override public KQueueDatagramChannelConfig setTrafficClass(int trafficClass) { try { - datagramChannel.socket.setTrafficClass(trafficClass); + ((KQueueDatagramChannel) channel).socket.setTrafficClass(trafficClass); return this; } catch (IOException e) { throw new ChannelException(e); @@ -310,7 +308,7 @@ public KQueueDatagramChannelConfig setTrafficClass(int trafficClass) { @Override public boolean isReuseAddress() { try { - return datagramChannel.socket.isReuseAddress(); + return ((KQueueDatagramChannel) channel).socket.isReuseAddress(); } catch (IOException e) { throw new ChannelException(e); } @@ -319,7 +317,7 @@ public boolean isReuseAddress() { @Override public KQueueDatagramChannelConfig setReuseAddress(boolean reuseAddress) { try { - datagramChannel.socket.setReuseAddress(reuseAddress); + ((KQueueDatagramChannel) channel).socket.setReuseAddress(reuseAddress); return this; } catch (IOException e) { throw new ChannelException(e); @@ -329,7 +327,7 @@ public KQueueDatagramChannelConfig setReuseAddress(boolean reuseAddress) { @Override public boolean isBroadcast() { try { - return datagramChannel.socket.isBroadcast(); + return ((KQueueDatagramChannel) channel).socket.isBroadcast(); } catch (IOException e) { throw new ChannelException(e); } @@ -338,7 +336,7 @@ public boolean isBroadcast() { @Override public KQueueDatagramChannelConfig setBroadcast(boolean broadcast) { try { - datagramChannel.socket.setBroadcast(broadcast); + ((KQueueDatagramChannel) channel).socket.setBroadcast(broadcast); return this; } catch (IOException e) { throw new ChannelException(e); diff --git a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueEventArray.java b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueEventArray.java index 636d55a81445..43b5e6266309 100644 --- a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueEventArray.java +++ b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueEventArray.java @@ -15,8 +15,11 @@ */ package io.netty.channel.kqueue; +import io.netty.channel.unix.Buffer; import io.netty.util.internal.PlatformDependent; +import java.nio.ByteBuffer; + /** * Represents an array of kevent structures, backed by offheap memory. * @@ -37,6 +40,7 @@ final class KQueueEventArray { private static final int KQUEUE_FLAGS_OFFSET = Native.offsetofKEventFlags(); private static final int KQUEUE_DATA_OFFSET = Native.offsetofKeventData(); + private ByteBuffer memory; private long memoryAddress; private int size; private int capacity; @@ -45,7 +49,8 @@ final class KQueueEventArray { if (capacity < 1) { throw new IllegalArgumentException("capacity must be >= 1 but was " + capacity); } - memoryAddress = PlatformDependent.allocateMemory(capacity * KQUEUE_EVENT_SIZE); + memory = Buffer.allocateDirectWithNativeOrder(calculateBufferCapacity(capacity)); + memoryAddress = Buffer.memoryAddress(memory); this.capacity = capacity; } @@ -73,11 +78,11 @@ void clear() { } void evSet(AbstractKQueueChannel ch, short filter, short flags, int fflags) { - checkSize(); - evSet(getKEventOffset(size++), ch, ch.socket.intValue(), filter, flags, fflags); + reallocIfNeeded(); + evSet(getKEventOffset(size++) + memoryAddress, ch.socket.intValue(), filter, flags, fflags); } - private void checkSize() { + private void reallocIfNeeded() { if (size == capacity) { realloc(true); } @@ -89,15 +94,25 @@ private void checkSize() { void realloc(boolean throwIfFail) { // Double the capacity while it is "sufficiently small", and otherwise increase by 50%. int newLength = capacity <= 65536 ? capacity << 1 : capacity + capacity >> 1; - long newMemoryAddress = PlatformDependent.reallocateMemory(memoryAddress, newLength * KQUEUE_EVENT_SIZE); - if (newMemoryAddress != 0) { - memoryAddress = newMemoryAddress; - capacity = newLength; - return; - } - if (throwIfFail) { - throw new OutOfMemoryError("unable to allocate " + newLength + " new bytes! Existing capacity is: " - + capacity); + + try { + ByteBuffer buffer = Buffer.allocateDirectWithNativeOrder(calculateBufferCapacity(newLength)); + // Copy over the old content of the memory and reset the position as we always act on the buffer as if + // the position was never increased. + memory.position(0).limit(size); + buffer.put(memory); + buffer.position(0); + + Buffer.free(memory); + memory = buffer; + memoryAddress = Buffer.memoryAddress(buffer); + } catch (OutOfMemoryError e) { + if (throwIfFail) { + OutOfMemoryError error = new OutOfMemoryError( + "unable to allocate " + newLength + " new bytes! Existing capacity is: " + capacity); + error.initCause(e); + throw error; + } } } @@ -105,40 +120,54 @@ void realloc(boolean throwIfFail) { * Free this {@link KQueueEventArray}. Any usage after calling this method may segfault the JVM! */ void free() { - PlatformDependent.freeMemory(memoryAddress); + Buffer.free(memory); memoryAddress = size = capacity = 0; } - long getKEventOffset(int index) { - return memoryAddress + index * KQUEUE_EVENT_SIZE; + private static int getKEventOffset(int index) { + return index * KQUEUE_EVENT_SIZE; + } + + private long getKEventOffsetAddress(int index) { + return getKEventOffset(index) + memoryAddress; + } + + private short getShort(int index, int offset) { + if (PlatformDependent.hasUnsafe()) { + return PlatformDependent.getShort(getKEventOffsetAddress(index) + offset); + } + return memory.getShort(getKEventOffset(index) + offset); } short flags(int index) { - return PlatformDependent.getShort(getKEventOffset(index) + KQUEUE_FLAGS_OFFSET); + return getShort(index, KQUEUE_FLAGS_OFFSET); } short filter(int index) { - return PlatformDependent.getShort(getKEventOffset(index) + KQUEUE_FILTER_OFFSET); + return getShort(index, KQUEUE_FILTER_OFFSET); } short fflags(int index) { - return PlatformDependent.getShort(getKEventOffset(index) + KQUEUE_FFLAGS_OFFSET); + return getShort(index, KQUEUE_FFLAGS_OFFSET); } int fd(int index) { - return PlatformDependent.getInt(getKEventOffset(index) + KQUEUE_IDENT_OFFSET); + if (PlatformDependent.hasUnsafe()) { + return PlatformDependent.getInt(getKEventOffsetAddress(index) + KQUEUE_IDENT_OFFSET); + } + return memory.getInt(getKEventOffset(index) + KQUEUE_IDENT_OFFSET); } long data(int index) { - return PlatformDependent.getLong(getKEventOffset(index) + KQUEUE_DATA_OFFSET); + if (PlatformDependent.hasUnsafe()) { + return PlatformDependent.getLong(getKEventOffsetAddress(index) + KQUEUE_DATA_OFFSET); + } + return memory.getLong(getKEventOffset(index) + KQUEUE_DATA_OFFSET); } - AbstractKQueueChannel channel(int index) { - return getChannel(getKEventOffset(index)); + private static int calculateBufferCapacity(int capacity) { + return capacity * KQUEUE_EVENT_SIZE; } - private static native void evSet(long keventAddress, AbstractKQueueChannel ch, - int ident, short filter, short flags, int fflags); - private static native AbstractKQueueChannel getChannel(long keventAddress); - static native void deleteGlobalRefs(long channelAddressStart, long channelAddressEnd); + private static native void evSet(long keventAddress, int ident, short filter, short flags, int fflags); } diff --git a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueEventLoop.java b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueEventLoop.java index 3badcea96c75..25c50a4f05ba 100644 --- a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueEventLoop.java +++ b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueEventLoop.java @@ -23,6 +23,8 @@ import io.netty.channel.unix.FileDescriptor; import io.netty.channel.unix.IovArray; import io.netty.util.IntSupplier; +import io.netty.util.collection.IntObjectHashMap; +import io.netty.util.collection.IntObjectMap; import io.netty.util.concurrent.RejectedExecutionHandler; import io.netty.util.internal.ObjectUtil; import io.netty.util.internal.PlatformDependent; @@ -31,12 +33,9 @@ import java.io.IOException; import java.util.Queue; -import java.util.concurrent.Callable; import java.util.concurrent.Executor; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; -import static io.netty.channel.kqueue.KQueueEventArray.deleteGlobalRefs; import static java.lang.Math.min; /** @@ -54,7 +53,6 @@ final class KQueueEventLoop extends SingleThreadEventLoop { KQueue.ensureAvailability(); } - private final NativeLongArray jniChannelPointers; private final boolean allowGrowing; private final FileDescriptor kqueueFd; private final KQueueEventArray changeList; @@ -67,18 +65,11 @@ public int get() throws Exception { return kqueueWaitNow(); } }; - private final Callable pendingTasksCallable = new Callable() { - @Override - public Integer call() throws Exception { - return KQueueEventLoop.super.pendingTasks(); - } - }; + private final IntObjectMap channels = new IntObjectHashMap(4096); private volatile int wakenUp; private volatile int ioRatio = 50; - static final long MAX_SCHEDULED_DAYS = 365 * 3; - KQueueEventLoop(EventLoopGroup parent, Executor executor, int maxEvents, SelectStrategy strategy, RejectedExecutionHandler rejectedExecutionHandler) { super(parent, executor, false, DEFAULT_MAX_PENDING_TASKS, rejectedExecutionHandler); @@ -92,7 +83,6 @@ public Integer call() throws Exception { } changeList = new KQueueEventArray(maxEvents); eventList = new KQueueEventArray(maxEvents); - jniChannelPointers = new NativeLongArray(4096); int result = Native.keventAddUserEvent(kqueueFd.intValue(), KQUEUE_WAKE_UP_IDENT); if (result < 0) { cleanup(); @@ -100,18 +90,18 @@ public Integer call() throws Exception { } } + void add(AbstractKQueueChannel ch) { + assert inEventLoop(); + channels.put(ch.fd().intValue(), ch); + } + void evSet(AbstractKQueueChannel ch, short filter, short flags, int fflags) { changeList.evSet(ch, filter, flags, fflags); } - void remove(AbstractKQueueChannel ch) throws IOException { + void remove(AbstractKQueueChannel ch) { assert inEventLoop(); - if (ch.jniSelfPtr == 0) { - return; - } - - jniChannelPointers.add(ch.jniSelfPtr); - ch.jniSelfPtr = 0; + channels.remove(ch.fd().intValue()); } /** @@ -154,32 +144,25 @@ private int kqueueWaitNow() throws IOException { } private int kqueueWait(int timeoutSec, int timeoutNs) throws IOException { - deleteJniChannelPointers(); int numEvents = Native.keventWait(kqueueFd.intValue(), changeList, eventList, timeoutSec, timeoutNs); changeList.clear(); return numEvents; } - private void deleteJniChannelPointers() { - if (!jniChannelPointers.isEmpty()) { - deleteGlobalRefs(jniChannelPointers.memoryAddress(), jniChannelPointers.memoryAddressEnd()); - jniChannelPointers.clear(); - } - } - private void processReady(int ready) { for (int i = 0; i < ready; ++i) { final short filter = eventList.filter(i); final short flags = eventList.flags(i); + final int fd = eventList.fd(i); if (filter == Native.EVFILT_USER || (flags & Native.EV_ERROR) != 0) { // EV_ERROR is returned if the FD is closed synchronously (which removes from kqueue) and then // we later attempt to delete the filters from kqueue. assert filter != Native.EVFILT_USER || - (filter == Native.EVFILT_USER && eventList.fd(i) == KQUEUE_WAKE_UP_IDENT); + (filter == Native.EVFILT_USER && fd == KQUEUE_WAKE_UP_IDENT); continue; } - AbstractKQueueChannel channel = eventList.channel(i); + AbstractKQueueChannel channel = channels.get(fd); if (channel == null) { // This may happen if the channel has already been closed, and it will be removed from kqueue anyways. // We also handle EV_ERROR above to skip this even early if it is a result of a referencing a closed and @@ -217,6 +200,10 @@ protected void run() { switch (strategy) { case SelectStrategy.CONTINUE: continue; + + case SelectStrategy.BUSY_WAIT: + // fall-through to SELECT since the busy-wait is not supported with kqueue + case SelectStrategy.SELECT: strategy = kqueueWait(WAKEN_UP_UPDATER.getAndSet(this, 0) == 1); @@ -304,14 +291,6 @@ protected Queue newTaskQueue(int maxPendingTasks) { : PlatformDependent.newMpscQueue(maxPendingTasks); } - @Override - public int pendingTasks() { - // As we use a MpscQueue we need to ensure pendingTasks() is only executed from within the EventLoop as - // otherwise we may see unexpected behavior (as size() is only allowed to be called by a single consumer). - // See https://github.com/netty/netty/issues/5297 - return inEventLoop() ? super.pendingTasks() : submit(pendingTasksCallable).syncUninterruptibly().getNow(); - } - /** * Returns the percentage of the desired amount of time spent for I/O in the event loop. */ @@ -340,12 +319,6 @@ protected void cleanup() { } } finally { // Cleanup all native memory! - - // The JNI channel pointers should already be deleted because we should wait on kevent before this method, - // but lets just be sure we cleanup native memory. - deleteJniChannelPointers(); - jniChannelPointers.free(); - changeList.free(); eventList.free(); } @@ -370,12 +343,4 @@ private static void handleLoopException(Throwable t) { // Ignore. } } - - @Override - protected void validateScheduled(long amount, TimeUnit unit) { - long days = unit.toDays(amount); - if (days > MAX_SCHEDULED_DAYS) { - throw new IllegalArgumentException("days: " + days + " (expected: < " + MAX_SCHEDULED_DAYS + ')'); - } - } } diff --git a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueRecvByteAllocatorHandle.java b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueRecvByteAllocatorHandle.java index 6ff8d959cb1f..101f4be8fc6c 100644 --- a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueRecvByteAllocatorHandle.java +++ b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueRecvByteAllocatorHandle.java @@ -19,6 +19,7 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.channel.ChannelConfig; import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.unix.PreferredDirectByteBufAllocator; import io.netty.util.UncheckedBooleanSupplier; import io.netty.util.internal.ObjectUtil; @@ -26,7 +27,10 @@ import static java.lang.Math.min; final class KQueueRecvByteAllocatorHandle implements RecvByteBufAllocator.ExtendedHandle { + private final PreferredDirectByteBufAllocator preferredDirectByteBufAllocator = + new PreferredDirectByteBufAllocator(); private final RecvByteBufAllocator.ExtendedHandle delegate; + private final UncheckedBooleanSupplier defaultMaybeMoreDataSupplier = new UncheckedBooleanSupplier() { @Override public boolean get() { @@ -38,7 +42,7 @@ public boolean get() { private long numberBytesPending; KQueueRecvByteAllocatorHandle(RecvByteBufAllocator.ExtendedHandle handle) { - this.delegate = ObjectUtil.checkNotNull(handle, "handle"); + delegate = ObjectUtil.checkNotNull(handle, "handle"); } @Override @@ -59,7 +63,10 @@ public void incMessagesRead(int numMessages) { @Override public ByteBuf allocate(ByteBufAllocator alloc) { - return overrideGuess ? alloc.ioBuffer(guess0()) : delegate.allocate(alloc); + // We need to ensure we always allocate a direct ByteBuf as we can only use a direct buffer to read via JNI. + preferredDirectByteBufAllocator.updateAllocator(alloc); + return overrideGuess ? preferredDirectByteBufAllocator.ioBuffer(guess0()) : + delegate.allocate(preferredDirectByteBufAllocator); } @Override @@ -108,6 +115,10 @@ void readEOF() { readEOF = true; } + boolean isReadEOF() { + return readEOF; + } + void numberBytesPending(long numberBytesPending) { this.numberBytesPending = numberBytesPending; } @@ -121,9 +132,9 @@ boolean maybeMoreDataToRead() { * channel. It is expected that the {@link #KQueueSocketChannel} implementations will track if all data was not * read, and will force a EVFILT_READ ready event. * - * If EOF has been read we must read until we get an error. + * It is assumed EOF is handled externally by checking {@link #isReadEOF()}. */ - return numberBytesPending != 0 || readEOF; + return numberBytesPending != 0; } private int guess0() { diff --git a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueServerChannelConfig.java b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueServerChannelConfig.java index 7f878dc25779..cf103f9a3ae5 100644 --- a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueServerChannelConfig.java +++ b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueServerChannelConfig.java @@ -31,15 +31,14 @@ import static io.netty.channel.ChannelOption.SO_BACKLOG; import static io.netty.channel.ChannelOption.SO_RCVBUF; import static io.netty.channel.ChannelOption.SO_REUSEADDR; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; @UnstableApi public class KQueueServerChannelConfig extends KQueueChannelConfig implements ServerSocketChannelConfig { - protected final AbstractKQueueChannel channel; private volatile int backlog = NetUtil.SOMAXCONN; KQueueServerChannelConfig(AbstractKQueueChannel channel) { super(channel); - this.channel = channel; } @Override @@ -81,7 +80,7 @@ public boolean setOption(ChannelOption option, T value) { public boolean isReuseAddress() { try { - return channel.socket.isReuseAddress(); + return ((AbstractKQueueChannel) channel).socket.isReuseAddress(); } catch (IOException e) { throw new ChannelException(e); } @@ -89,7 +88,7 @@ public boolean isReuseAddress() { public KQueueServerChannelConfig setReuseAddress(boolean reuseAddress) { try { - channel.socket.setReuseAddress(reuseAddress); + ((AbstractKQueueChannel) channel).socket.setReuseAddress(reuseAddress); return this; } catch (IOException e) { throw new ChannelException(e); @@ -98,7 +97,7 @@ public KQueueServerChannelConfig setReuseAddress(boolean reuseAddress) { public int getReceiveBufferSize() { try { - return channel.socket.getReceiveBufferSize(); + return ((AbstractKQueueChannel) channel).socket.getReceiveBufferSize(); } catch (IOException e) { throw new ChannelException(e); } @@ -106,7 +105,7 @@ public int getReceiveBufferSize() { public KQueueServerChannelConfig setReceiveBufferSize(int receiveBufferSize) { try { - channel.socket.setReceiveBufferSize(receiveBufferSize); + ((AbstractKQueueChannel) channel).socket.setReceiveBufferSize(receiveBufferSize); return this; } catch (IOException e) { throw new ChannelException(e); @@ -118,9 +117,7 @@ public int getBacklog() { } public KQueueServerChannelConfig setBacklog(int backlog) { - if (backlog < 0) { - throw new IllegalArgumentException("backlog: " + backlog); - } + checkPositiveOrZero(backlog, "backlog"); this.backlog = backlog; return this; } diff --git a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueServerSocketChannelConfig.java b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueServerSocketChannelConfig.java index dce3e6e71b85..a743e039de6f 100644 --- a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueServerSocketChannelConfig.java +++ b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueServerSocketChannelConfig.java @@ -75,7 +75,7 @@ public boolean setOption(ChannelOption option, T value) { public KQueueServerSocketChannelConfig setReusePort(boolean reusePort) { try { - channel.socket.setReusePort(reusePort); + ((KQueueServerSocketChannel) channel).socket.setReusePort(reusePort); return this; } catch (IOException e) { throw new ChannelException(e); @@ -84,7 +84,7 @@ public KQueueServerSocketChannelConfig setReusePort(boolean reusePort) { public boolean isReusePort() { try { - return channel.socket.isReusePort(); + return ((KQueueServerSocketChannel) channel).socket.isReusePort(); } catch (IOException e) { throw new ChannelException(e); } @@ -92,7 +92,7 @@ public boolean isReusePort() { public KQueueServerSocketChannelConfig setAcceptFilter(AcceptFilter acceptFilter) { try { - channel.socket.setAcceptFilter(acceptFilter); + ((KQueueServerSocketChannel) channel).socket.setAcceptFilter(acceptFilter); return this; } catch (IOException e) { throw new ChannelException(e); @@ -101,7 +101,7 @@ public KQueueServerSocketChannelConfig setAcceptFilter(AcceptFilter acceptFilter public AcceptFilter getAcceptFilter() { try { - return channel.socket.getAcceptFilter(); + return ((KQueueServerSocketChannel) channel).socket.getAcceptFilter(); } catch (IOException e) { throw new ChannelException(e); } diff --git a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueSocketChannelConfig.java b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueSocketChannelConfig.java index 8662e55c7b98..b5c718b9113a 100644 --- a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueSocketChannelConfig.java +++ b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueSocketChannelConfig.java @@ -41,12 +41,10 @@ @UnstableApi public final class KQueueSocketChannelConfig extends KQueueChannelConfig implements SocketChannelConfig { - private final KQueueSocketChannel channel; private volatile boolean allowHalfClosure; KQueueSocketChannelConfig(KQueueSocketChannel channel) { super(channel); - this.channel = channel; if (PlatformDependent.canEnableTcpNoDelayByDefault()) { setTcpNoDelay(true); } @@ -131,7 +129,7 @@ public boolean setOption(ChannelOption option, T value) { @Override public int getReceiveBufferSize() { try { - return channel.socket.getReceiveBufferSize(); + return ((KQueueSocketChannel) channel).socket.getReceiveBufferSize(); } catch (IOException e) { throw new ChannelException(e); } @@ -140,7 +138,7 @@ public int getReceiveBufferSize() { @Override public int getSendBufferSize() { try { - return channel.socket.getSendBufferSize(); + return ((KQueueSocketChannel) channel).socket.getSendBufferSize(); } catch (IOException e) { throw new ChannelException(e); } @@ -149,7 +147,7 @@ public int getSendBufferSize() { @Override public int getSoLinger() { try { - return channel.socket.getSoLinger(); + return ((KQueueSocketChannel) channel).socket.getSoLinger(); } catch (IOException e) { throw new ChannelException(e); } @@ -158,7 +156,7 @@ public int getSoLinger() { @Override public int getTrafficClass() { try { - return channel.socket.getTrafficClass(); + return ((KQueueSocketChannel) channel).socket.getTrafficClass(); } catch (IOException e) { throw new ChannelException(e); } @@ -167,7 +165,7 @@ public int getTrafficClass() { @Override public boolean isKeepAlive() { try { - return channel.socket.isKeepAlive(); + return ((KQueueSocketChannel) channel).socket.isKeepAlive(); } catch (IOException e) { throw new ChannelException(e); } @@ -176,7 +174,7 @@ public boolean isKeepAlive() { @Override public boolean isReuseAddress() { try { - return channel.socket.isReuseAddress(); + return ((KQueueSocketChannel) channel).socket.isReuseAddress(); } catch (IOException e) { throw new ChannelException(e); } @@ -185,7 +183,7 @@ public boolean isReuseAddress() { @Override public boolean isTcpNoDelay() { try { - return channel.socket.isTcpNoDelay(); + return ((KQueueSocketChannel) channel).socket.isTcpNoDelay(); } catch (IOException e) { throw new ChannelException(e); } @@ -193,7 +191,7 @@ public boolean isTcpNoDelay() { public int getSndLowAt() { try { - return channel.socket.getSndLowAt(); + return ((KQueueSocketChannel) channel).socket.getSndLowAt(); } catch (IOException e) { throw new ChannelException(e); } @@ -201,7 +199,7 @@ public int getSndLowAt() { public void setSndLowAt(int sndLowAt) { try { - channel.socket.setSndLowAt(sndLowAt); + ((KQueueSocketChannel) channel).socket.setSndLowAt(sndLowAt); } catch (IOException e) { throw new ChannelException(e); } @@ -209,7 +207,7 @@ public void setSndLowAt(int sndLowAt) { public boolean isTcpNoPush() { try { - return channel.socket.isTcpNoPush(); + return ((KQueueSocketChannel) channel).socket.isTcpNoPush(); } catch (IOException e) { throw new ChannelException(e); } @@ -217,7 +215,7 @@ public boolean isTcpNoPush() { public void setTcpNoPush(boolean tcpNoPush) { try { - channel.socket.setTcpNoPush(tcpNoPush); + ((KQueueSocketChannel) channel).socket.setTcpNoPush(tcpNoPush); } catch (IOException e) { throw new ChannelException(e); } @@ -226,7 +224,7 @@ public void setTcpNoPush(boolean tcpNoPush) { @Override public KQueueSocketChannelConfig setKeepAlive(boolean keepAlive) { try { - channel.socket.setKeepAlive(keepAlive); + ((KQueueSocketChannel) channel).socket.setKeepAlive(keepAlive); return this; } catch (IOException e) { throw new ChannelException(e); @@ -236,7 +234,7 @@ public KQueueSocketChannelConfig setKeepAlive(boolean keepAlive) { @Override public KQueueSocketChannelConfig setReceiveBufferSize(int receiveBufferSize) { try { - channel.socket.setReceiveBufferSize(receiveBufferSize); + ((KQueueSocketChannel) channel).socket.setReceiveBufferSize(receiveBufferSize); return this; } catch (IOException e) { throw new ChannelException(e); @@ -246,7 +244,7 @@ public KQueueSocketChannelConfig setReceiveBufferSize(int receiveBufferSize) { @Override public KQueueSocketChannelConfig setReuseAddress(boolean reuseAddress) { try { - channel.socket.setReuseAddress(reuseAddress); + ((KQueueSocketChannel) channel).socket.setReuseAddress(reuseAddress); return this; } catch (IOException e) { throw new ChannelException(e); @@ -256,7 +254,7 @@ public KQueueSocketChannelConfig setReuseAddress(boolean reuseAddress) { @Override public KQueueSocketChannelConfig setSendBufferSize(int sendBufferSize) { try { - channel.socket.setSendBufferSize(sendBufferSize); + ((KQueueSocketChannel) channel).socket.setSendBufferSize(sendBufferSize); calculateMaxBytesPerGatheringWrite(); return this; } catch (IOException e) { @@ -267,7 +265,7 @@ public KQueueSocketChannelConfig setSendBufferSize(int sendBufferSize) { @Override public KQueueSocketChannelConfig setSoLinger(int soLinger) { try { - channel.socket.setSoLinger(soLinger); + ((KQueueSocketChannel) channel).socket.setSoLinger(soLinger); return this; } catch (IOException e) { throw new ChannelException(e); @@ -277,7 +275,7 @@ public KQueueSocketChannelConfig setSoLinger(int soLinger) { @Override public KQueueSocketChannelConfig setTcpNoDelay(boolean tcpNoDelay) { try { - channel.socket.setTcpNoDelay(tcpNoDelay); + ((KQueueSocketChannel) channel).socket.setTcpNoDelay(tcpNoDelay); return this; } catch (IOException e) { throw new ChannelException(e); @@ -287,7 +285,7 @@ public KQueueSocketChannelConfig setTcpNoDelay(boolean tcpNoDelay) { @Override public KQueueSocketChannelConfig setTrafficClass(int trafficClass) { try { - channel.socket.setTrafficClass(trafficClass); + ((KQueueSocketChannel) channel).socket.setTrafficClass(trafficClass); return this; } catch (IOException e) { throw new ChannelException(e); diff --git a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/Native.java b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/Native.java index d0862748a106..675432a25d96 100644 --- a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/Native.java +++ b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/Native.java @@ -44,7 +44,7 @@ import static io.netty.channel.unix.Errors.newIOException; /** - * Navite helper methods + * Native helper methods *

    Internal usage only! */ final class Native { diff --git a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/NativeLongArray.java b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/NativeLongArray.java index d9f00094f16a..7f83738e7084 100644 --- a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/NativeLongArray.java +++ b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/NativeLongArray.java @@ -15,11 +15,15 @@ */ package io.netty.channel.kqueue; +import io.netty.channel.unix.Buffer; import io.netty.util.internal.PlatformDependent; +import java.nio.ByteBuffer; + import static io.netty.channel.unix.Limits.SIZEOF_JLONG; final class NativeLongArray { + private ByteBuffer memory; private long memoryAddress; private int capacity; private int size; @@ -28,13 +32,27 @@ final class NativeLongArray { if (capacity < 1) { throw new IllegalArgumentException("capacity must be >= 1 but was " + capacity); } - memoryAddress = PlatformDependent.allocateMemory(capacity * SIZEOF_JLONG); + memory = Buffer.allocateDirectWithNativeOrder(calculateBufferCapacity(capacity)); + memoryAddress = Buffer.memoryAddress(memory); this.capacity = capacity; } + private static int idx(int index) { + return index * SIZEOF_JLONG; + } + + private static int calculateBufferCapacity(int capacity) { + return capacity * SIZEOF_JLONG; + } + void add(long value) { - checkSize(); - PlatformDependent.putLong(memoryOffset(size++), value); + reallocIfNeeded(); + if (PlatformDependent.hasUnsafe()) { + PlatformDependent.putLong(memoryOffset(size), value); + } else { + memory.putLong(idx(size), value); + } + ++size; } void clear() { @@ -46,7 +64,7 @@ boolean isEmpty() { } void free() { - PlatformDependent.freeMemory(memoryAddress); + Buffer.free(memory); memoryAddress = 0; } @@ -59,25 +77,25 @@ long memoryAddressEnd() { } private long memoryOffset(int index) { - return memoryAddress + index * SIZEOF_JLONG; + return memoryAddress + idx(index); } - private void checkSize() { + private void reallocIfNeeded() { if (size == capacity) { - realloc(); - } - } + // Double the capacity while it is "sufficiently small", and otherwise increase by 50%. + int newLength = capacity <= 65536 ? capacity << 1 : capacity + capacity >> 1; + ByteBuffer buffer = Buffer.allocateDirectWithNativeOrder(calculateBufferCapacity(newLength)); + // Copy over the old content of the memory and reset the position as we always act on the buffer as if + // the position was never increased. + memory.position(0).limit(size); + buffer.put(memory); + buffer.position(0); - private void realloc() { - // Double the capacity while it is "sufficiently small", and otherwise increase by 50%. - int newLength = capacity <= 65536 ? capacity << 1 : capacity + capacity >> 1; - long newMemoryAddress = PlatformDependent.reallocateMemory(memoryAddress, newLength * SIZEOF_JLONG); - if (newMemoryAddress == 0) { - throw new OutOfMemoryError("unable to allocate " + newLength + " new bytes! Existing capacity is: " - + capacity); + Buffer.free(memory); + memory = buffer; + memoryAddress = Buffer.memoryAddress(buffer); + capacity = newLength; } - memoryAddress = newMemoryAddress; - capacity = newLength; } @Override diff --git a/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueDomainSocketDataReadInitialStateTest.java b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueDomainSocketDataReadInitialStateTest.java new file mode 100644 index 000000000000..10e538b8e42d --- /dev/null +++ b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueDomainSocketDataReadInitialStateTest.java @@ -0,0 +1,36 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.kqueue; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.testsuite.transport.TestsuitePermutation; +import io.netty.testsuite.transport.socket.SocketDataReadInitialStateTest; + +import java.net.SocketAddress; +import java.util.List; + +public class KQueueDomainSocketDataReadInitialStateTest extends SocketDataReadInitialStateTest { + @Override + protected SocketAddress newSocketAddress() { + return KQueueSocketTestPermutation.newSocketAddress(); + } + + @Override + protected List> newFactories() { + return KQueueSocketTestPermutation.INSTANCE.domainSocket(); + } +} diff --git a/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueDomainSocketSslClientRenegotiateTest.java b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueDomainSocketSslClientRenegotiateTest.java new file mode 100644 index 000000000000..a719b7c73ef7 --- /dev/null +++ b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueDomainSocketSslClientRenegotiateTest.java @@ -0,0 +1,42 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.kqueue; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.handler.ssl.SslContext; +import io.netty.testsuite.transport.TestsuitePermutation; +import io.netty.testsuite.transport.socket.SocketSslClientRenegotiateTest; + +import java.net.SocketAddress; +import java.util.List; + +public class KQueueDomainSocketSslClientRenegotiateTest extends SocketSslClientRenegotiateTest { + + public KQueueDomainSocketSslClientRenegotiateTest(SslContext serverCtx, SslContext clientCtx, boolean delegate) { + super(serverCtx, clientCtx, delegate); + } + + @Override + protected List> newFactories() { + return KQueueSocketTestPermutation.INSTANCE.domainSocket(); + } + + @Override + protected SocketAddress newSocketAddress() { + return KQueueSocketTestPermutation.newSocketAddress(); + } +} diff --git a/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueDomainSocketSslGreetingTest.java b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueDomainSocketSslGreetingTest.java index 492021a52ba1..8cbaa00efa29 100644 --- a/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueDomainSocketSslGreetingTest.java +++ b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueDomainSocketSslGreetingTest.java @@ -26,8 +26,8 @@ public class KQueueDomainSocketSslGreetingTest extends SocketSslGreetingTest { - public KQueueDomainSocketSslGreetingTest(SslContext serverCtx, SslContext clientCtx) { - super(serverCtx, clientCtx); + public KQueueDomainSocketSslGreetingTest(SslContext serverCtx, SslContext clientCtx, boolean delegate) { + super(serverCtx, clientCtx, delegate); } @Override diff --git a/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueETSocketDataReadInitialStateTest.java b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueETSocketDataReadInitialStateTest.java new file mode 100644 index 000000000000..3055b1dafe68 --- /dev/null +++ b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueETSocketDataReadInitialStateTest.java @@ -0,0 +1,30 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.kqueue; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.testsuite.transport.TestsuitePermutation; +import io.netty.testsuite.transport.socket.SocketDataReadInitialStateTest; + +import java.util.List; + +public class KQueueETSocketDataReadInitialStateTest extends SocketDataReadInitialStateTest { + @Override + protected List> newFactories() { + return KQueueSocketTestPermutation.INSTANCE.socket(); + } +} diff --git a/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueEventLoopTest.java b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueEventLoopTest.java index c7ad56e37ab1..0d441559994a 100644 --- a/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueEventLoopTest.java +++ b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueEventLoopTest.java @@ -24,32 +24,11 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; public class KQueueEventLoopTest { - @Test(timeout = 5000L) - public void testScheduleBigDelayOverMax() { - EventLoopGroup group = new KQueueEventLoopGroup(1); - - final EventLoop el = group.next(); - try { - el.schedule(new Runnable() { - @Override - public void run() { - // NOOP - } - }, Integer.MAX_VALUE, TimeUnit.DAYS); - fail(); - } catch (IllegalArgumentException expected) { - // expected - } - - group.shutdownGracefully(); - } - @Test - public void testScheduleBigDelay() { + public void testScheduleBigDelayNotOverflow() { EventLoopGroup group = new KQueueEventLoopGroup(1); final EventLoop el = group.next(); @@ -58,7 +37,7 @@ public void testScheduleBigDelay() { public void run() { // NOOP } - }, KQueueEventLoop.MAX_SCHEDULED_DAYS, TimeUnit.DAYS); + }, Long.MAX_VALUE, TimeUnit.MILLISECONDS); assertFalse(future.awaitUninterruptibly(1000)); assertTrue(future.cancel(true)); diff --git a/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueSocketSslClientRenegotiateTest.java b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueSocketSslClientRenegotiateTest.java new file mode 100644 index 000000000000..a3ba23818198 --- /dev/null +++ b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueSocketSslClientRenegotiateTest.java @@ -0,0 +1,36 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.kqueue; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.handler.ssl.SslContext; +import io.netty.testsuite.transport.TestsuitePermutation; +import io.netty.testsuite.transport.socket.SocketSslClientRenegotiateTest; + +import java.util.List; + +public class KQueueSocketSslClientRenegotiateTest extends SocketSslClientRenegotiateTest { + + public KQueueSocketSslClientRenegotiateTest(SslContext serverCtx, SslContext clientCtx, boolean delegate) { + super(serverCtx, clientCtx, delegate); + } + + @Override + protected List> newFactories() { + return KQueueSocketTestPermutation.INSTANCE.socket(); + } +} diff --git a/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueSocketSslGreetingTest.java b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueSocketSslGreetingTest.java index 9242fc388f78..6eecc35a5196 100644 --- a/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueSocketSslGreetingTest.java +++ b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueSocketSslGreetingTest.java @@ -25,8 +25,8 @@ public class KQueueSocketSslGreetingTest extends SocketSslGreetingTest { - public KQueueSocketSslGreetingTest(SslContext serverCtx, SslContext clientCtx) { - super(serverCtx, clientCtx); + public KQueueSocketSslGreetingTest(SslContext serverCtx, SslContext clientCtx, boolean delegate) { + super(serverCtx, clientCtx, delegate); } @Override diff --git a/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueSocketSslSessionReuseTest.java b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueSocketSslSessionReuseTest.java new file mode 100644 index 000000000000..5508dcb0a29a --- /dev/null +++ b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueSocketSslSessionReuseTest.java @@ -0,0 +1,36 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.kqueue; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.handler.ssl.SslContext; +import io.netty.testsuite.transport.TestsuitePermutation; +import io.netty.testsuite.transport.socket.SocketSslSessionReuseTest; + +import java.util.List; + +public class KQueueSocketSslSessionReuseTest extends SocketSslSessionReuseTest { + + public KQueueSocketSslSessionReuseTest(SslContext serverCtx, SslContext clientCtx) { + super(serverCtx, clientCtx); + } + + @Override + protected List> newFactories() { + return KQueueSocketTestPermutation.INSTANCE.socket(); + } +} diff --git a/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KqueueWriteBeforeRegisteredTest.java b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KqueueWriteBeforeRegisteredTest.java new file mode 100644 index 000000000000..b8d447565ce4 --- /dev/null +++ b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KqueueWriteBeforeRegisteredTest.java @@ -0,0 +1,30 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.kqueue; + +import io.netty.bootstrap.Bootstrap; +import io.netty.testsuite.transport.TestsuitePermutation; +import io.netty.testsuite.transport.socket.WriteBeforeRegisteredTest; + +import java.util.List; + +public class KqueueWriteBeforeRegisteredTest extends WriteBeforeRegisteredTest { + + @Override + protected List> newFactories() { + return KQueueSocketTestPermutation.INSTANCE.clientSocket(); + } +} diff --git a/transport-native-unix-common-tests/pom.xml b/transport-native-unix-common-tests/pom.xml index 17933640fb02..62dbb1eda592 100644 --- a/transport-native-unix-common-tests/pom.xml +++ b/transport-native-unix-common-tests/pom.xml @@ -19,7 +19,7 @@ io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-transport-native-unix-common-tests @@ -30,6 +30,11 @@ + + io.netty + netty-transport + ${project.version} + io.netty netty-transport-native-unix-common diff --git a/transport-native-unix-common-tests/src/main/java/io/netty/channel/unix/tests/DetectPeerCloseWithoutReadTest.java b/transport-native-unix-common-tests/src/main/java/io/netty/channel/unix/tests/DetectPeerCloseWithoutReadTest.java index ef5482d11b01..a8b231f5e14c 100644 --- a/transport-native-unix-common-tests/src/main/java/io/netty/channel/unix/tests/DetectPeerCloseWithoutReadTest.java +++ b/transport-native-unix-common-tests/src/main/java/io/netty/channel/unix/tests/DetectPeerCloseWithoutReadTest.java @@ -25,6 +25,7 @@ import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; +import io.netty.channel.FixedRecvByteBufAllocator; import io.netty.channel.ServerChannel; import io.netty.channel.SimpleChannelInboundHandler; import org.junit.Test; @@ -41,7 +42,17 @@ public abstract class DetectPeerCloseWithoutReadTest { protected abstract Class clientChannel(); @Test(timeout = 10000) - public void clientCloseWithoutServerReadIsDetected() throws InterruptedException { + public void clientCloseWithoutServerReadIsDetectedNoExtraReadRequested() throws InterruptedException { + clientCloseWithoutServerReadIsDetected0(false); + } + + @Test(timeout = 10000) + public void clientCloseWithoutServerReadIsDetectedExtraReadRequested() throws InterruptedException { + clientCloseWithoutServerReadIsDetected0(true); + } + + private void clientCloseWithoutServerReadIsDetected0(final boolean extraReadRequested) + throws InterruptedException { EventLoopGroup serverGroup = null; EventLoopGroup clientGroup = null; Channel serverChannel = null; @@ -54,11 +65,15 @@ public void clientCloseWithoutServerReadIsDetected() throws InterruptedException ServerBootstrap sb = new ServerBootstrap(); sb.group(serverGroup); sb.channel(serverChannel()); + // Ensure we read only one message per read() call and that we need multiple read() + // calls to consume everything. sb.childOption(ChannelOption.AUTO_READ, false); + sb.childOption(ChannelOption.MAX_MESSAGES_PER_READ, 1); + sb.childOption(ChannelOption.RCVBUF_ALLOCATOR, new FixedRecvByteBufAllocator(expectedBytes / 10)); sb.childHandler(new ChannelInitializer() { @Override - protected void initChannel(Channel ch) throws Exception { - ch.pipeline().addLast(new TestHandler(bytesRead, latch)); + protected void initChannel(Channel ch) { + ch.pipeline().addLast(new TestHandler(bytesRead, extraReadRequested, latch)); } }); @@ -89,7 +104,16 @@ protected void initChannel(Channel ch) throws Exception { } @Test(timeout = 10000) - public void serverCloseWithoutClientReadIsDetected() throws InterruptedException { + public void serverCloseWithoutClientReadIsDetectedNoExtraReadRequested() throws InterruptedException { + serverCloseWithoutClientReadIsDetected0(false); + } + + @Test(timeout = 10000) + public void serverCloseWithoutClientReadIsDetectedExtraReadRequested() throws InterruptedException { + serverCloseWithoutClientReadIsDetected0(true); + } + + private void serverCloseWithoutClientReadIsDetected0(final boolean extraReadRequested) throws InterruptedException { EventLoopGroup serverGroup = null; EventLoopGroup clientGroup = null; Channel serverChannel = null; @@ -105,10 +129,10 @@ public void serverCloseWithoutClientReadIsDetected() throws InterruptedException sb.channel(serverChannel()); sb.childHandler(new ChannelInitializer() { @Override - protected void initChannel(Channel ch) throws Exception { + protected void initChannel(Channel ch) { ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { @Override - public void channelActive(ChannelHandlerContext ctx) throws Exception { + public void channelActive(ChannelHandlerContext ctx) { ByteBuf buf = ctx.alloc().buffer(expectedBytes); buf.writerIndex(buf.writerIndex() + expectedBytes); ctx.writeAndFlush(buf).addListener(ChannelFutureListener.CLOSE); @@ -123,11 +147,15 @@ public void channelActive(ChannelHandlerContext ctx) throws Exception { Bootstrap cb = new Bootstrap(); cb.group(serverGroup); cb.channel(clientChannel()); + // Ensure we read only one message per read() call and that we need multiple read() + // calls to consume everything. cb.option(ChannelOption.AUTO_READ, false); + cb.option(ChannelOption.MAX_MESSAGES_PER_READ, 1); + cb.option(ChannelOption.RCVBUF_ALLOCATOR, new FixedRecvByteBufAllocator(expectedBytes / 10)); cb.handler(new ChannelInitializer() { @Override protected void initChannel(Channel ch) throws Exception { - ch.pipeline().addLast(new TestHandler(bytesRead, latch)); + ch.pipeline().addLast(new TestHandler(bytesRead, extraReadRequested, latch)); } }); clientChannel = cb.connect(serverChannel.localAddress()).syncUninterruptibly().channel(); @@ -152,22 +180,27 @@ protected void initChannel(Channel ch) throws Exception { private static final class TestHandler extends SimpleChannelInboundHandler { private final AtomicInteger bytesRead; + private final boolean extraReadRequested; private final CountDownLatch latch; - TestHandler(AtomicInteger bytesRead, CountDownLatch latch) { + TestHandler(AtomicInteger bytesRead, boolean extraReadRequested, CountDownLatch latch) { this.bytesRead = bytesRead; + this.extraReadRequested = extraReadRequested; this.latch = latch; } @Override - protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception { + protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) { bytesRead.addAndGet(msg.readableBytes()); - // Because autoread is off, we call read to consume all data until we detect the close. - ctx.read(); + + if (extraReadRequested) { + // Because autoread is off, we call read to consume all data until we detect the close. + ctx.read(); + } } @Override - public void channelInactive(ChannelHandlerContext ctx) throws Exception { + public void channelInactive(ChannelHandlerContext ctx) { latch.countDown(); ctx.fireChannelInactive(); } diff --git a/transport-native-unix-common/pom.xml b/transport-native-unix-common/pom.xml index 62d0b9e8ddb0..10b6d5bc6bd0 100644 --- a/transport-native-unix-common/pom.xml +++ b/transport-native-unix-common/pom.xml @@ -19,7 +19,7 @@ io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-transport-native-unix-common @@ -101,7 +101,7 @@ - + + + ${project.groupId} + netty-common + ${project.version} + ${project.groupId} netty-buffer ${project.version} - ${project.groupId} netty-transport @@ -78,7 +82,7 @@ org.apache.maven.plugins maven-surefire-plugin - always + false diff --git a/transport-udt/src/test/java/io/netty/test/udt/util/CaliperMeasure.java b/transport-udt/src/test/java/io/netty/test/udt/util/CaliperMeasure.java index eb1a489f16a8..68aa2a4d5028 100644 --- a/transport-udt/src/test/java/io/netty/test/udt/util/CaliperMeasure.java +++ b/transport-udt/src/test/java/io/netty/test/udt/util/CaliperMeasure.java @@ -172,7 +172,7 @@ public Map variables() { } private static MeasurementSet measurementSet(final Map map) { - final Measurement[] array = map.values().toArray(new Measurement[map.size()]); + final Measurement[] array = map.values().toArray(new Measurement[0]); return new MeasurementSet(array); } diff --git a/transport/pom.xml b/transport/pom.xml index 67a96a38e504..ff87fbec631d 100644 --- a/transport/pom.xml +++ b/transport/pom.xml @@ -20,7 +20,7 @@ io.netty netty-parent - 4.1.25.5.dse + 4.1.34.3.dse netty-transport @@ -33,6 +33,11 @@ + + ${project.groupId} + netty-common + ${project.version} + ${project.groupId} netty-buffer diff --git a/transport/src/main/java/io/netty/bootstrap/AbstractBootstrap.java b/transport/src/main/java/io/netty/bootstrap/AbstractBootstrap.java index fb05c8b263e9..1030c320411e 100644 --- a/transport/src/main/java/io/netty/bootstrap/AbstractBootstrap.java +++ b/transport/src/main/java/io/netty/bootstrap/AbstractBootstrap.java @@ -126,7 +126,7 @@ public B channelFactory(ChannelFactory channelFactory) { * {@link io.netty.channel.ChannelFactory} which is used to create {@link Channel} instances from * when calling {@link #bind()}. This method is usually only used if {@link #channel(Class)} * is not working for you because of some more complex needs. If your {@link Channel} implementation - * has a no-args constructor, its highly recommend to just use {@link #channel(Class)} for + * has a no-args constructor, its highly recommend to just use {@link #channel(Class)} to * simplify your code. */ @SuppressWarnings({ "unchecked", "deprecation" }) diff --git a/transport/src/main/java/io/netty/bootstrap/ServerBootstrap.java b/transport/src/main/java/io/netty/bootstrap/ServerBootstrap.java index 2a4c04e74162..310e9fb9fe7d 100644 --- a/transport/src/main/java/io/netty/bootstrap/ServerBootstrap.java +++ b/transport/src/main/java/io/netty/bootstrap/ServerBootstrap.java @@ -160,10 +160,10 @@ void init(Channel channel) throws Exception { final Entry, Object>[] currentChildOptions; final Entry, Object>[] currentChildAttrs; synchronized (childOptions) { - currentChildOptions = childOptions.entrySet().toArray(newOptionArray(childOptions.size())); + currentChildOptions = childOptions.entrySet().toArray(newOptionArray(0)); } synchronized (childAttrs) { - currentChildAttrs = childAttrs.entrySet().toArray(newAttrArray(childAttrs.size())); + currentChildAttrs = childAttrs.entrySet().toArray(newAttrArray(0)); } p.addLast(new ChannelInitializer() { diff --git a/transport/src/main/java/io/netty/channel/AbstractChannel.java b/transport/src/main/java/io/netty/channel/AbstractChannel.java index 75ad995cd496..d60b0117c765 100644 --- a/transport/src/main/java/io/netty/channel/AbstractChannel.java +++ b/transport/src/main/java/io/netty/channel/AbstractChannel.java @@ -47,14 +47,14 @@ public abstract class AbstractChannel extends DefaultAttributeMap implements Cha private static final InternalLogger logger = InternalLoggerFactory.getInstance(AbstractChannel.class); - private static final ClosedChannelException FLUSH0_CLOSED_CHANNEL_EXCEPTION = ThrowableUtil.unknownStackTrace( - new ClosedChannelException(), AbstractUnsafe.class, "flush0()"); private static final ClosedChannelException ENSURE_OPEN_CLOSED_CHANNEL_EXCEPTION = ThrowableUtil.unknownStackTrace( - new ClosedChannelException(), AbstractUnsafe.class, "ensureOpen(...)"); + new ExtendedClosedChannelException(null), AbstractUnsafe.class, "ensureOpen(...)"); private static final ClosedChannelException CLOSE_CLOSED_CHANNEL_EXCEPTION = ThrowableUtil.unknownStackTrace( new ClosedChannelException(), AbstractUnsafe.class, "close(...)"); private static final ClosedChannelException WRITE_CLOSED_CHANNEL_EXCEPTION = ThrowableUtil.unknownStackTrace( - new ClosedChannelException(), AbstractUnsafe.class, "write(...)"); + new ExtendedClosedChannelException(null), AbstractUnsafe.class, "write(...)"); + private static final ClosedChannelException FLUSH0_CLOSED_CHANNEL_EXCEPTION = ThrowableUtil.unknownStackTrace( + new ExtendedClosedChannelException(null), AbstractUnsafe.class, "flush0()"); private static final NotYetConnectedException FLUSH0_NOT_YET_CONNECTED_EXCEPTION = ThrowableUtil.unknownStackTrace( new NotYetConnectedException(), AbstractUnsafe.class, "flush0()"); @@ -70,6 +70,7 @@ public abstract class AbstractChannel extends DefaultAttributeMap implements Cha protected volatile EventLoop eventLoop; private volatile boolean registered; private boolean closeInitiated; + private Throwable initialCloseCause; /** Cache for the string representation of this channel */ private boolean strValActive; @@ -187,6 +188,8 @@ public SocketAddress localAddress() { if (localAddress == null) { try { this.localAddress = localAddress = unsafe().localAddress(); + } catch (Error e) { + throw e; } catch (Throwable t) { // Sometimes fails on a closed socket in Windows. return null; @@ -209,6 +212,8 @@ public SocketAddress remoteAddress() { if (remoteAddress == null) { try { this.remoteAddress = remoteAddress = unsafe().remoteAddress(); + } catch (Error e) { + throw e; } catch (Throwable t) { // Sometimes fails on a closed socket in Windows. return null; @@ -883,7 +888,7 @@ public final void write(Object msg, ChannelPromise promise) { // need to fail the future right away. If it is not null the handling of the rest // will be done in flush0() // See https://github.com/netty/netty/issues/2362 - safeSetFailure(promise, WRITE_CLOSED_CHANNEL_EXCEPTION); + safeSetFailure(promise, newWriteException(initialCloseCause)); // release message now to prevent resource-leak ReferenceCountUtil.release(msg); return; @@ -939,7 +944,7 @@ protected void flush0() { outboundBuffer.failFlushed(FLUSH0_NOT_YET_CONNECTED_EXCEPTION, true); } else { // Do not trigger channelWritabilityChanged because the channel is closed already. - outboundBuffer.failFlushed(FLUSH0_CLOSED_CHANNEL_EXCEPTION, false); + outboundBuffer.failFlushed(newFlush0Exception(initialCloseCause), false); } } finally { inFlush0 = false; @@ -959,12 +964,14 @@ protected void flush0() { * This is needed as otherwise {@link #isActive()} , {@link #isOpen()} and {@link #isWritable()} * may still return {@code true} even if the channel should be closed as result of the exception. */ - close(voidPromise(), t, FLUSH0_CLOSED_CHANNEL_EXCEPTION, false); + initialCloseCause = t; + close(voidPromise(), t, newFlush0Exception(t), false); } else { try { shutdownOutput(voidPromise(), t); } catch (Throwable t2) { - close(voidPromise(), t2, FLUSH0_CLOSED_CHANNEL_EXCEPTION, false); + initialCloseCause = t; + close(voidPromise(), t2, newFlush0Exception(t), false); } } } finally { @@ -972,6 +979,30 @@ protected void flush0() { } } + private ClosedChannelException newWriteException(Throwable cause) { + if (cause == null) { + return WRITE_CLOSED_CHANNEL_EXCEPTION; + } + return ThrowableUtil.unknownStackTrace( + new ExtendedClosedChannelException(cause), AbstractUnsafe.class, "write(...)"); + } + + private ClosedChannelException newFlush0Exception(Throwable cause) { + if (cause == null) { + return FLUSH0_CLOSED_CHANNEL_EXCEPTION; + } + return ThrowableUtil.unknownStackTrace( + new ExtendedClosedChannelException(cause), AbstractUnsafe.class, "flush0()"); + } + + private ClosedChannelException newEnsureOpenException(Throwable cause) { + if (cause == null) { + return ENSURE_OPEN_CLOSED_CHANNEL_EXCEPTION; + } + return ThrowableUtil.unknownStackTrace( + new ExtendedClosedChannelException(cause), AbstractUnsafe.class, "ensureOpen(...)"); + } + @Override public final ChannelPromise voidPromise() { assertEventLoop(); @@ -984,7 +1015,7 @@ protected final boolean ensureOpen(ChannelPromise promise) { return true; } - safeSetFailure(promise, ENSURE_OPEN_CLOSED_CHANNEL_EXCEPTION); + safeSetFailure(promise, newEnsureOpenException(initialCloseCause)); return false; } @@ -1135,6 +1166,10 @@ protected Object filterOutboundMessage(Object msg) throws Exception { return msg; } + protected void validateFileRegion(DefaultFileRegion region, long position) throws IOException { + DefaultFileRegion.validate(region, position); + } + static final class CloseFuture extends DefaultChannelPromise { CloseFuture(AbstractChannel ch) { diff --git a/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java b/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java index 3b155a9e2348..f24771a50eac 100644 --- a/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java +++ b/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java @@ -76,10 +76,7 @@ abstract class AbstractChannelHandlerContext extends DefaultAttributeMap // Lazily instantiated tasks used to trigger events to a handler with different executor. // There is no need to make this volatile as at worse it will just create a few more instances then needed. - private Runnable invokeChannelReadCompleteTask; - private Runnable invokeReadTask; - private Runnable invokeChannelWritableStateChangedTask; - private Runnable invokeFlushTask; + private Tasks invokeTasks; private volatile int handlerState = INIT; @@ -379,16 +376,11 @@ static void invokeChannelReadComplete(final AbstractChannelHandlerContext next) if (executor.inEventLoop()) { next.invokeChannelReadComplete(); } else { - Runnable task = next.invokeChannelReadCompleteTask; - if (task == null) { - next.invokeChannelReadCompleteTask = task = new Runnable() { - @Override - public void run() { - next.invokeChannelReadComplete(); - } - }; + Tasks tasks = next.invokeTasks; + if (tasks == null) { + next.invokeTasks = tasks = new Tasks(next); } - executor.execute(task); + executor.execute(tasks.invokeChannelReadCompleteTask); } } @@ -415,16 +407,11 @@ static void invokeChannelWritabilityChanged(final AbstractChannelHandlerContext if (executor.inEventLoop()) { next.invokeChannelWritabilityChanged(); } else { - Runnable task = next.invokeChannelWritableStateChangedTask; - if (task == null) { - next.invokeChannelWritableStateChangedTask = task = new Runnable() { - @Override - public void run() { - next.invokeChannelWritabilityChanged(); - } - }; + Tasks tasks = next.invokeTasks; + if (tasks == null) { + next.invokeTasks = tasks = new Tasks(next); } - executor.execute(task); + executor.execute(tasks.invokeChannelWritableStateChangedTask); } } @@ -672,16 +659,11 @@ public ChannelHandlerContext read() { if (executor.inEventLoop()) { next.invokeRead(); } else { - Runnable task = next.invokeReadTask; - if (task == null) { - next.invokeReadTask = task = new Runnable() { - @Override - public void run() { - next.invokeRead(); - } - }; + Tasks tasks = next.invokeTasks; + if (tasks == null) { + next.invokeTasks = tasks = new Tasks(next); } - executor.execute(task); + executor.execute(tasks.invokeReadTask); } return this; @@ -706,20 +688,6 @@ public ChannelFuture write(Object msg) { @Override public ChannelFuture write(final Object msg, final ChannelPromise promise) { - if (msg == null) { - throw new NullPointerException("msg"); - } - - try { - if (isNotValidPromise(promise, true)) { - ReferenceCountUtil.release(msg); - // cancelled - return promise; - } - } catch (RuntimeException e) { - ReferenceCountUtil.release(msg); - throw e; - } write(msg, false, promise); return promise; @@ -748,16 +716,11 @@ public ChannelHandlerContext flush() { if (executor.inEventLoop()) { next.invokeFlush(); } else { - Runnable task = next.invokeFlushTask; - if (task == null) { - next.invokeFlushTask = task = new Runnable() { - @Override - public void run() { - next.invokeFlush(); - } - }; + Tasks tasks = next.invokeTasks; + if (tasks == null) { + next.invokeTasks = tasks = new Tasks(next); } - safeExecute(executor, task, channel().voidPromise(), null); + safeExecute(executor, tasks.invokeFlushTask, channel().voidPromise(), null); } return this; @@ -781,18 +744,7 @@ private void invokeFlush0() { @Override public ChannelFuture writeAndFlush(Object msg, ChannelPromise promise) { - if (msg == null) { - throw new NullPointerException("msg"); - } - - if (isNotValidPromise(promise, true)) { - ReferenceCountUtil.release(msg); - // cancelled - return promise; - } - write(msg, true, promise); - return promise; } @@ -806,6 +758,18 @@ private void invokeWriteAndFlush(Object msg, ChannelPromise promise) { } private void write(Object msg, boolean flush, ChannelPromise promise) { + ObjectUtil.checkNotNull(msg, "msg"); + try { + if (isNotValidPromise(promise, true)) { + ReferenceCountUtil.release(msg); + // cancelled + return; + } + } catch (RuntimeException e) { + ReferenceCountUtil.release(msg); + throw e; + } + AbstractChannelHandlerContext next = findContextOutbound(); final Object m = pipeline.touch(msg, next); EventExecutor executor = next.executor(); @@ -816,13 +780,19 @@ private void write(Object msg, boolean flush, ChannelPromise promise) { next.invokeWrite(m, promise); } } else { - AbstractWriteTask task; + final AbstractWriteTask task; if (flush) { task = WriteAndFlushTask.newInstance(next, m, promise); } else { task = WriteTask.newInstance(next, m, promise); } - safeExecute(executor, task, promise, m); + if (!safeExecute(executor, task, promise, m)) { + // We failed to submit the AbstractWriteTask. We need to cancel it so we decrement the pending bytes + // and put it back in the Recycler for re-use later. + // + // See https://github.com/netty/netty/issues/8343. + task.cancel(); + } } } @@ -956,14 +926,17 @@ final void setRemoved() { handlerState = REMOVE_COMPLETE; } - final void setAddComplete() { + final boolean setAddComplete() { for (;;) { int oldState = handlerState; + if (oldState == REMOVE_COMPLETE) { + return false; + } // Ensure we never update when the handlerState is REMOVE_COMPLETE already. // oldState is usually ADD_PENDING but can also be REMOVE_COMPLETE when an EventExecutor is used that is not // exposing ordering guarantees. - if (oldState == REMOVE_COMPLETE || HANDLER_STATE_UPDATER.compareAndSet(this, oldState, ADD_COMPLETE)) { - return; + if (HANDLER_STATE_UPDATER.compareAndSet(this, oldState, ADD_COMPLETE)) { + return true; } } } @@ -973,6 +946,26 @@ final void setAddPending() { assert updated; // This should always be true as it MUST be called before setAddComplete() or setRemoved(). } + final void callHandlerAdded() throws Exception { + // We must call setAddComplete before calling handlerAdded. Otherwise if the handlerAdded method generates + // any pipeline events ctx.handler() will miss them because the state will not allow it. + if (setAddComplete()) { + handler().handlerAdded(this); + } + } + + final void callHandlerRemoved() throws Exception { + try { + // Only call handlerRemoved(...) if we called handlerAdded(...) before. + if (handlerState == ADD_COMPLETE) { + handler().handlerRemoved(this); + } + } finally { + // Mark the handler as removed in any case. + setRemoved(); + } + } + /** * Makes best possible effort to detect if {@link ChannelHandler#handlerAdded(ChannelHandlerContext)} was called * yet. If not return {@code false} and if called or could not detect return {@code true}. @@ -1002,9 +995,10 @@ public boolean hasAttr(AttributeKey key) { return channel().hasAttr(key); } - private static void safeExecute(EventExecutor executor, Runnable runnable, ChannelPromise promise, Object msg) { + private static boolean safeExecute(EventExecutor executor, Runnable runnable, ChannelPromise promise, Object msg) { try { executor.execute(runnable); + return true; } catch (Throwable cause) { try { promise.setFailure(cause); @@ -1013,6 +1007,7 @@ private static void safeExecute(EventExecutor executor, Runnable runnable, Chann ReferenceCountUtil.release(msg); } } + return false; } } @@ -1063,20 +1058,35 @@ protected static void init(AbstractWriteTask task, AbstractChannelHandlerContext @Override public final void run() { try { - // Check for null as it may be set to null if the channel is closed already - if (ESTIMATE_TASK_SIZE_ON_SUBMIT) { - ctx.pipeline.decrementPendingOutboundBytes(size); - } + decrementPendingOutboundBytes(); write(ctx, msg, promise); } finally { - // Set to null so the GC can collect them directly - ctx = null; - msg = null; - promise = null; - handle.recycle(this); + recycle(); + } + } + + void cancel() { + try { + decrementPendingOutboundBytes(); + } finally { + recycle(); } } + private void decrementPendingOutboundBytes() { + if (ESTIMATE_TASK_SIZE_ON_SUBMIT) { + ctx.pipeline.decrementPendingOutboundBytes(size); + } + } + + private void recycle() { + // Set to null so the GC can collect them directly + ctx = null; + msg = null; + promise = null; + handle.recycle(this); + } + protected void write(AbstractChannelHandlerContext ctx, Object msg, ChannelPromise promise) { ctx.invokeWrite(msg, promise); } @@ -1091,7 +1101,7 @@ protected WriteTask newObject(Handle handle) { } }; - private static WriteTask newInstance( + static WriteTask newInstance( AbstractChannelHandlerContext ctx, Object msg, ChannelPromise promise) { WriteTask task = RECYCLER.get(); init(task, ctx, msg, promise); @@ -1112,7 +1122,7 @@ protected WriteAndFlushTask newObject(Handle handle) { } }; - private static WriteAndFlushTask newInstance( + static WriteAndFlushTask newInstance( AbstractChannelHandlerContext ctx, Object msg, ChannelPromise promise) { WriteAndFlushTask task = RECYCLER.get(); init(task, ctx, msg, promise); @@ -1129,4 +1139,36 @@ public void write(AbstractChannelHandlerContext ctx, Object msg, ChannelPromise ctx.invokeFlush(); } } + + private static final class Tasks { + private final AbstractChannelHandlerContext next; + private final Runnable invokeChannelReadCompleteTask = new Runnable() { + @Override + public void run() { + next.invokeChannelReadComplete(); + } + }; + private final Runnable invokeReadTask = new Runnable() { + @Override + public void run() { + next.invokeRead(); + } + }; + private final Runnable invokeChannelWritableStateChangedTask = new Runnable() { + @Override + public void run() { + next.invokeChannelWritabilityChanged(); + } + }; + private final Runnable invokeFlushTask = new Runnable() { + @Override + public void run() { + next.invokeFlush(); + } + }; + + Tasks(AbstractChannelHandlerContext next) { + this.next = next; + } + } } diff --git a/transport/src/main/java/io/netty/channel/AbstractCoalescingBufferQueue.java b/transport/src/main/java/io/netty/channel/AbstractCoalescingBufferQueue.java index 552793ee92e2..80bf9098d3fb 100644 --- a/transport/src/main/java/io/netty/channel/AbstractCoalescingBufferQueue.java +++ b/transport/src/main/java/io/netty/channel/AbstractCoalescingBufferQueue.java @@ -40,7 +40,7 @@ public abstract class AbstractCoalescingBufferQueue { * * @param channel the {@link Channel} which will have the {@link Channel#isWritable()} reflect the amount of queued * buffers or {@code null} if there is no writability state updated. - * @param initSize theinitial size of the underlying queue. + * @param initSize the initial size of the underlying queue. */ protected AbstractCoalescingBufferQueue(Channel channel, int initSize) { bufAndListenerPairs = new ArrayDeque(initSize); diff --git a/transport/src/main/java/io/netty/channel/AdaptiveRecvByteBufAllocator.java b/transport/src/main/java/io/netty/channel/AdaptiveRecvByteBufAllocator.java index a2db615f6bb4..a4ab9fe5faae 100644 --- a/transport/src/main/java/io/netty/channel/AdaptiveRecvByteBufAllocator.java +++ b/transport/src/main/java/io/netty/channel/AdaptiveRecvByteBufAllocator.java @@ -18,6 +18,7 @@ import java.util.ArrayList; import java.util.List; +import static io.netty.util.internal.ObjectUtil.checkPositive; import static java.lang.Math.max; import static java.lang.Math.min; @@ -95,7 +96,7 @@ private final class HandleImpl extends MaxMessageHandle { private int nextReceiveBufferSize; private boolean decreaseNow; - public HandleImpl(int minIndex, int maxIndex, int initial) { + HandleImpl(int minIndex, int maxIndex, int initial) { this.minIndex = minIndex; this.maxIndex = maxIndex; @@ -163,9 +164,7 @@ public AdaptiveRecvByteBufAllocator() { * @param maximum the inclusive upper bound of the expected buffer size */ public AdaptiveRecvByteBufAllocator(int minimum, int initial, int maximum) { - if (minimum <= 0) { - throw new IllegalArgumentException("minimum: " + minimum); - } + checkPositive(minimum, "minimum"); if (initial < minimum) { throw new IllegalArgumentException("initial: " + initial); } diff --git a/transport/src/main/java/io/netty/channel/ChannelConfig.java b/transport/src/main/java/io/netty/channel/ChannelConfig.java index 9e77da7a34f8..9550acd01e28 100644 --- a/transport/src/main/java/io/netty/channel/ChannelConfig.java +++ b/transport/src/main/java/io/netty/channel/ChannelConfig.java @@ -121,7 +121,8 @@ public interface ChannelConfig { ChannelConfig setConnectTimeoutMillis(int connectTimeoutMillis); /** - * @deprecated Use {@link MaxMessagesRecvByteBufAllocator} + * @deprecated Use {@link MaxMessagesRecvByteBufAllocator} and + * {@link MaxMessagesRecvByteBufAllocator#maxMessagesPerRead()}. *

    * Returns the maximum number of messages to read per read loop. * a {@link ChannelInboundHandler#channelRead(ChannelHandlerContext, Object) channelRead()} event. @@ -131,7 +132,8 @@ public interface ChannelConfig { int getMaxMessagesPerRead(); /** - * @deprecated Use {@link MaxMessagesRecvByteBufAllocator} + * @deprecated Use {@link MaxMessagesRecvByteBufAllocator} and + * {@link MaxMessagesRecvByteBufAllocator#maxMessagesPerRead(int)}. *

    * Sets the maximum number of messages to read per read loop. * If this value is greater than 1, an event loop might attempt to read multiple times to procure multiple messages. @@ -195,21 +197,15 @@ public interface ChannelConfig { ChannelConfig setAutoRead(boolean autoRead); /** - * @deprecated Auto close will be removed in a future release. - * * Returns {@code true} if and only if the {@link Channel} will be closed automatically and immediately on - * write failure. The default is {@code false}. + * write failure. The default is {@code true}. */ - @Deprecated boolean isAutoClose(); /** - * @deprecated Auto close will be removed in a future release. - * * Sets whether the {@link Channel} should be closed automatically and immediately on write failure. - * The default is {@code false}. + * The default is {@code true}. */ - @Deprecated ChannelConfig setAutoClose(boolean autoClose); /** diff --git a/transport/src/main/java/io/netty/channel/ChannelDuplexHandler.java b/transport/src/main/java/io/netty/channel/ChannelDuplexHandler.java index 30faa3da846b..07c6484e5020 100644 --- a/transport/src/main/java/io/netty/channel/ChannelDuplexHandler.java +++ b/transport/src/main/java/io/netty/channel/ChannelDuplexHandler.java @@ -74,7 +74,7 @@ public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exce } /** - * Calls {@link ChannelHandlerContext#close(ChannelPromise)} to forward + * Calls {@link ChannelHandlerContext#deregister(ChannelPromise)} to forward * to the next {@link ChannelOutboundHandler} in the {@link ChannelPipeline}. * * Sub-classes may override this method to change behavior. diff --git a/transport/src/main/java/io/netty/channel/ChannelFlushPromiseNotifier.java b/transport/src/main/java/io/netty/channel/ChannelFlushPromiseNotifier.java index 26594a38df5a..ff0b9c031b87 100644 --- a/transport/src/main/java/io/netty/channel/ChannelFlushPromiseNotifier.java +++ b/transport/src/main/java/io/netty/channel/ChannelFlushPromiseNotifier.java @@ -15,6 +15,8 @@ */ package io.netty.channel; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + import java.util.ArrayDeque; import java.util.Queue; @@ -64,9 +66,7 @@ public ChannelFlushPromiseNotifier add(ChannelPromise promise, long pendingDataS if (promise == null) { throw new NullPointerException("promise"); } - if (pendingDataSize < 0) { - throw new IllegalArgumentException("pendingDataSize must be >= 0 but was " + pendingDataSize); - } + checkPositiveOrZero(pendingDataSize, "pendingDataSize"); long checkpoint = writeCounter + pendingDataSize; if (promise instanceof FlushCheckpoint) { FlushCheckpoint cp = (FlushCheckpoint) promise; @@ -81,9 +81,7 @@ public ChannelFlushPromiseNotifier add(ChannelPromise promise, long pendingDataS * Increase the current write counter by the given delta */ public ChannelFlushPromiseNotifier increaseWriteCounter(long delta) { - if (delta < 0) { - throw new IllegalArgumentException("delta must be >= 0 but was " + delta); - } + checkPositiveOrZero(delta, "delta"); writeCounter += delta; return this; } diff --git a/transport/src/main/java/io/netty/channel/ChannelHandler.java b/transport/src/main/java/io/netty/channel/ChannelHandler.java index 898f7523bcf2..f9401080eb48 100644 --- a/transport/src/main/java/io/netty/channel/ChannelHandler.java +++ b/transport/src/main/java/io/netty/channel/ChannelHandler.java @@ -73,13 +73,12 @@ * * {@code @Override} * public void channelRead0({@link ChannelHandlerContext} ctx, Message message) { - * {@link Channel} ch = e.getChannel(); * if (message instanceof LoginMessage) { * authenticate((LoginMessage) message); * loggedIn = true; * } else (message instanceof GetDataMessage) { * if (loggedIn) { - * ch.write(fetchSecret((GetDataMessage) message)); + * ctx.writeAndFlush(fetchSecret((GetDataMessage) message)); * } else { * fail(); * } @@ -123,13 +122,12 @@ * {@code @Override} * public void channelRead({@link ChannelHandlerContext} ctx, Message message) { * {@link Attribute}<{@link Boolean}> attr = ctx.attr(auth); - * {@link Channel} ch = ctx.channel(); * if (message instanceof LoginMessage) { * authenticate((LoginMessage) o); * attr.set(true); * } else (message instanceof GetDataMessage) { * if (Boolean.TRUE.equals(attr.get())) { - * ch.write(fetchSecret((GetDataMessage) o)); + * ctx.writeAndFlush(fetchSecret((GetDataMessage) o)); * } else { * fail(); * } diff --git a/transport/src/main/java/io/netty/channel/ChannelHandlerAdapter.java b/transport/src/main/java/io/netty/channel/ChannelHandlerAdapter.java index aadc691d669a..ee380b584874 100644 --- a/transport/src/main/java/io/netty/channel/ChannelHandlerAdapter.java +++ b/transport/src/main/java/io/netty/channel/ChannelHandlerAdapter.java @@ -81,8 +81,11 @@ public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { * to the next {@link ChannelHandler} in the {@link ChannelPipeline}. * * Sub-classes may override this method to change behavior. + * + * @deprecated is part of {@link ChannelInboundHandler} */ @Override + @Deprecated public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { ctx.fireExceptionCaught(cause); } diff --git a/transport/src/main/java/io/netty/channel/ChannelInitializer.java b/transport/src/main/java/io/netty/channel/ChannelInitializer.java index 1d8578c3dd2b..9aa4eaa57010 100644 --- a/transport/src/main/java/io/netty/channel/ChannelInitializer.java +++ b/transport/src/main/java/io/netty/channel/ChannelInitializer.java @@ -18,11 +18,12 @@ import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; import io.netty.channel.ChannelHandler.Sharable; -import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; -import java.util.concurrent.ConcurrentMap; +import java.util.Collections; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; /** * A special {@link ChannelInboundHandler} which offers an easy way to initialize a {@link Channel} once it was @@ -53,9 +54,10 @@ public abstract class ChannelInitializer extends ChannelInboundHandlerAdapter { private static final InternalLogger logger = InternalLoggerFactory.getInstance(ChannelInitializer.class); - // We use a ConcurrentMap as a ChannelInitializer is usually shared between all Channels in a Bootstrap / + // We use a Set as a ChannelInitializer is usually shared between all Channels in a Bootstrap / // ServerBootstrap. This way we can reduce the memory usage compared to use Attributes. - private final ConcurrentMap initMap = PlatformDependent.newConcurrentHashMap(); + private final Set initMap = Collections.newSetFromMap( + new ConcurrentHashMap()); /** * This method will be called once the {@link Channel} was registered. After the method returns this instance @@ -77,6 +79,9 @@ public final void channelRegistered(ChannelHandlerContext ctx) throws Exception // we called initChannel(...) so we need to call now pipeline.fireChannelRegistered() to ensure we not // miss an event. ctx.pipeline().fireChannelRegistered(); + + // We are done with init the Channel, removing all the state for the Channel now. + removeState(ctx); } else { // Called initChannel(...) before which is the expected behavior, so just forward the event. ctx.fireChannelRegistered(); @@ -88,7 +93,9 @@ public final void channelRegistered(ChannelHandlerContext ctx) throws Exception */ @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - logger.warn("Failed to initialize a channel. Closing: " + ctx.channel(), cause); + if (logger.isWarnEnabled()) { + logger.warn("Failed to initialize a channel. Closing: " + ctx.channel(), cause); + } ctx.close(); } @@ -102,13 +109,22 @@ public void handlerAdded(ChannelHandlerContext ctx) throws Exception { // The good thing about calling initChannel(...) in handlerAdded(...) is that there will be no ordering // surprises if a ChannelInitializer will add another ChannelInitializer. This is as all handlers // will be added in the expected order. - initChannel(ctx); + if (initChannel(ctx)) { + + // We are done with init the Channel, removing the initializer now. + removeState(ctx); + } } } + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + initMap.remove(ctx); + } + @SuppressWarnings("unchecked") private boolean initChannel(ChannelHandlerContext ctx) throws Exception { - if (initMap.putIfAbsent(ctx, Boolean.TRUE) == null) { // Guard against re-entrance. + if (initMap.add(ctx)) { // Guard against re-entrance. try { initChannel((C) ctx.channel()); } catch (Throwable cause) { @@ -116,21 +132,29 @@ private boolean initChannel(ChannelHandlerContext ctx) throws Exception { // We do so to prevent multiple calls to initChannel(...). exceptionCaught(ctx, cause); } finally { - remove(ctx); + ChannelPipeline pipeline = ctx.pipeline(); + if (pipeline.context(this) != null) { + pipeline.remove(this); + } } return true; } return false; } - private void remove(ChannelHandlerContext ctx) { - try { - ChannelPipeline pipeline = ctx.pipeline(); - if (pipeline.context(this) != null) { - pipeline.remove(this); - } - } finally { + private void removeState(final ChannelHandlerContext ctx) { + // The removal may happen in an async fashion if the EventExecutor we use does something funky. + if (ctx.isRemoved()) { initMap.remove(ctx); + } else { + // The context is not removed yet which is most likely the case because a custom EventExecutor is used. + // Let's schedule it on the EventExecutor to give it some more time to be completed in case it is offloaded. + ctx.executor().execute(new Runnable() { + @Override + public void run() { + initMap.remove(ctx); + } + }); } } } diff --git a/transport/src/main/java/io/netty/channel/ChannelMetadata.java b/transport/src/main/java/io/netty/channel/ChannelMetadata.java index c77f53091668..e3b116b9883a 100644 --- a/transport/src/main/java/io/netty/channel/ChannelMetadata.java +++ b/transport/src/main/java/io/netty/channel/ChannelMetadata.java @@ -15,6 +15,8 @@ */ package io.netty.channel; +import static io.netty.util.internal.ObjectUtil.checkPositive; + import java.net.SocketAddress; /** @@ -46,10 +48,7 @@ public ChannelMetadata(boolean hasDisconnect) { * set for {@link MaxMessagesRecvByteBufAllocator#maxMessagesPerRead()}. Must be {@code > 0}. */ public ChannelMetadata(boolean hasDisconnect, int defaultMaxMessagesPerRead) { - if (defaultMaxMessagesPerRead <= 0) { - throw new IllegalArgumentException("defaultMaxMessagesPerRead: " + defaultMaxMessagesPerRead + - " (expected > 0)"); - } + checkPositive(defaultMaxMessagesPerRead, "defaultMaxMessagesPerRead"); this.hasDisconnect = hasDisconnect; this.defaultMaxMessagesPerRead = defaultMaxMessagesPerRead; } diff --git a/transport/src/main/java/io/netty/channel/ChannelOption.java b/transport/src/main/java/io/netty/channel/ChannelOption.java index 4cb909a1db1a..97bf31545c4d 100644 --- a/transport/src/main/java/io/netty/channel/ChannelOption.java +++ b/transport/src/main/java/io/netty/channel/ChannelOption.java @@ -78,6 +78,7 @@ public static ChannelOption newInstance(String name) { public static final ChannelOption CONNECT_TIMEOUT_MILLIS = valueOf("CONNECT_TIMEOUT_MILLIS"); /** * @deprecated Use {@link MaxMessagesRecvByteBufAllocator} + * and {@link MaxMessagesRecvByteBufAllocator#maxMessagesPerRead(int)}. */ @Deprecated public static final ChannelOption MAX_MESSAGES_PER_READ = valueOf("MAX_MESSAGES_PER_READ"); @@ -99,12 +100,9 @@ public static ChannelOption newInstance(String name) { public static final ChannelOption AUTO_READ = valueOf("AUTO_READ"); /** - * @deprecated Auto close will be removed in a future release. - * * If {@code true} then the {@link Channel} is closed automatically and immediately on write failure. * The default value is {@code true}. */ - @Deprecated public static final ChannelOption AUTO_CLOSE = valueOf("AUTO_CLOSE"); public static final ChannelOption SO_BROADCAST = valueOf("SO_BROADCAST"); diff --git a/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java b/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java index 155d73f35eb3..d3a934a82973 100644 --- a/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java +++ b/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java @@ -434,21 +434,9 @@ public ByteBuffer[] nioBuffers(int maxCount, long maxBytes) { } nioBuffers[nioBufferCount++] = nioBuf; } else { - ByteBuffer[] nioBufs = entry.bufs; - if (nioBufs == null) { - // cached ByteBuffers as they may be expensive to create in terms - // of Object allocation - entry.bufs = nioBufs = buf.nioBuffers(); - } - for (int i = 0; i < nioBufs.length && nioBufferCount < maxCount; ++i) { - ByteBuffer nioBuf = nioBufs[i]; - if (nioBuf == null) { - break; - } else if (!nioBuf.hasRemaining()) { - continue; - } - nioBuffers[nioBufferCount++] = nioBuf; - } + // The code exists in an extra method to ensure the method is not too big to inline as this + // branch is not very likely to get hit very frequently. + nioBufferCount = nioBuffers(entry, buf, nioBuffers, nioBufferCount, maxCount); } if (nioBufferCount == maxCount) { break; @@ -463,6 +451,25 @@ public ByteBuffer[] nioBuffers(int maxCount, long maxBytes) { return nioBuffers; } + private static int nioBuffers(Entry entry, ByteBuf buf, ByteBuffer[] nioBuffers, int nioBufferCount, int maxCount) { + ByteBuffer[] nioBufs = entry.bufs; + if (nioBufs == null) { + // cached ByteBuffers as they may be expensive to create in terms + // of Object allocation + entry.bufs = nioBufs = buf.nioBuffers(); + } + for (int i = 0; i < nioBufs.length && nioBufferCount < maxCount; ++i) { + ByteBuffer nioBuf = nioBufs[i]; + if (nioBuf == null) { + break; + } else if (!nioBuf.hasRemaining()) { + continue; + } + nioBuffers[nioBufferCount++] = nioBuf; + } + return nioBufferCount; + } + private static ByteBuffer[] expandNioBufferArray(ByteBuffer[] array, int neededSpace, int size) { int newCapacity = array.length; do { diff --git a/transport/src/main/java/io/netty/channel/ChannelPipeline.java b/transport/src/main/java/io/netty/channel/ChannelPipeline.java index 9652637278e0..c415d8536d5e 100644 --- a/transport/src/main/java/io/netty/channel/ChannelPipeline.java +++ b/transport/src/main/java/io/netty/channel/ChannelPipeline.java @@ -597,7 +597,7 @@ T replace(Class oldHandlerType, String newName, @Override ChannelPipeline fireChannelRegistered(); - @Override + @Override ChannelPipeline fireChannelUnregistered(); @Override diff --git a/transport/src/main/java/io/netty/channel/DefaultChannelConfig.java b/transport/src/main/java/io/netty/channel/DefaultChannelConfig.java index 4118708637ce..405678a26987 100644 --- a/transport/src/main/java/io/netty/channel/DefaultChannelConfig.java +++ b/transport/src/main/java/io/netty/channel/DefaultChannelConfig.java @@ -36,6 +36,8 @@ import static io.netty.channel.ChannelOption.WRITE_BUFFER_WATER_MARK; import static io.netty.channel.ChannelOption.WRITE_SPIN_COUNT; import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkPositive; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; /** * The default {@link ChannelConfig} implementation. @@ -209,10 +211,7 @@ public int getConnectTimeoutMillis() { @Override public ChannelConfig setConnectTimeoutMillis(int connectTimeoutMillis) { - if (connectTimeoutMillis < 0) { - throw new IllegalArgumentException(String.format( - "connectTimeoutMillis: %d (expected: >= 0)", connectTimeoutMillis)); - } + checkPositiveOrZero(connectTimeoutMillis, "connectTimeoutMillis"); this.connectTimeoutMillis = connectTimeoutMillis; return this; } @@ -261,10 +260,7 @@ public int getWriteSpinCount() { @Override public ChannelConfig setWriteSpinCount(int writeSpinCount) { - if (writeSpinCount <= 0) { - throw new IllegalArgumentException( - "writeSpinCount must be a positive integer."); - } + checkPositive(writeSpinCount, "writeSpinCount"); // Integer.MAX_VALUE is used as a special value in the channel implementations to indicate the channel cannot // accept any more data, and results in the writeOp being set on the selector (or execute a runnable which tries // to flush later because the writeSpinCount quantum has been exhausted). This strategy prevents additional @@ -357,10 +353,7 @@ public int getWriteBufferHighWaterMark() { @Override public ChannelConfig setWriteBufferHighWaterMark(int writeBufferHighWaterMark) { - if (writeBufferHighWaterMark < 0) { - throw new IllegalArgumentException( - "writeBufferHighWaterMark must be >= 0"); - } + checkPositiveOrZero(writeBufferHighWaterMark, "writeBufferHighWaterMark"); for (;;) { WriteBufferWaterMark waterMark = writeBufferWaterMark; if (writeBufferHighWaterMark < waterMark.low()) { @@ -383,10 +376,7 @@ public int getWriteBufferLowWaterMark() { @Override public ChannelConfig setWriteBufferLowWaterMark(int writeBufferLowWaterMark) { - if (writeBufferLowWaterMark < 0) { - throw new IllegalArgumentException( - "writeBufferLowWaterMark must be >= 0"); - } + checkPositiveOrZero(writeBufferLowWaterMark, "writeBufferLowWaterMark"); for (;;) { WriteBufferWaterMark waterMark = writeBufferWaterMark; if (writeBufferLowWaterMark > waterMark.high()) { diff --git a/transport/src/main/java/io/netty/channel/DefaultChannelPipeline.java b/transport/src/main/java/io/netty/channel/DefaultChannelPipeline.java index 0d3307a5351a..2b307cc911f4 100644 --- a/transport/src/main/java/io/netty/channel/DefaultChannelPipeline.java +++ b/transport/src/main/java/io/netty/channel/DefaultChannelPipeline.java @@ -53,7 +53,7 @@ public class DefaultChannelPipeline implements ChannelPipeline { private static final FastThreadLocal, String>> nameCaches = new FastThreadLocal, String>>() { @Override - protected Map, String> initialValue() throws Exception { + protected Map, String> initialValue() { return new WeakHashMap, String>(); } }; @@ -163,7 +163,7 @@ public final ChannelPipeline addFirst(EventExecutorGroup group, String name, Cha addFirst0(newCtx); - // If the registered is false it means that the channel was not registered on an eventloop yet. + // If the registered is false it means that the channel was not registered on an eventLoop yet. // In this case we add the context to the pipeline and add a task that will call // ChannelHandler.handlerAdded(...) once the channel is registered. if (!registered) { @@ -174,13 +174,7 @@ public final ChannelPipeline addFirst(EventExecutorGroup group, String name, Cha EventExecutor executor = newCtx.executor(); if (!executor.inEventLoop()) { - newCtx.setAddPending(); - executor.execute(new Runnable() { - @Override - public void run() { - callHandlerAdded0(newCtx); - } - }); + callHandlerAddedInEventLoop(newCtx, executor); return this; } } @@ -211,7 +205,7 @@ public final ChannelPipeline addLast(EventExecutorGroup group, String name, Chan addLast0(newCtx); - // If the registered is false it means that the channel was not registered on an eventloop yet. + // If the registered is false it means that the channel was not registered on an eventLoop yet. // In this case we add the context to the pipeline and add a task that will call // ChannelHandler.handlerAdded(...) once the channel is registered. if (!registered) { @@ -222,13 +216,7 @@ public final ChannelPipeline addLast(EventExecutorGroup group, String name, Chan EventExecutor executor = newCtx.executor(); if (!executor.inEventLoop()) { - newCtx.setAddPending(); - executor.execute(new Runnable() { - @Override - public void run() { - callHandlerAdded0(newCtx); - } - }); + callHandlerAddedInEventLoop(newCtx, executor); return this; } } @@ -263,7 +251,7 @@ public final ChannelPipeline addBefore( addBefore0(ctx, newCtx); - // If the registered is false it means that the channel was not registered on an eventloop yet. + // If the registered is false it means that the channel was not registered on an eventLoop yet. // In this case we add the context to the pipeline and add a task that will call // ChannelHandler.handlerAdded(...) once the channel is registered. if (!registered) { @@ -274,13 +262,7 @@ public final ChannelPipeline addBefore( EventExecutor executor = newCtx.executor(); if (!executor.inEventLoop()) { - newCtx.setAddPending(); - executor.execute(new Runnable() { - @Override - public void run() { - callHandlerAdded0(newCtx); - } - }); + callHandlerAddedInEventLoop(newCtx, executor); return this; } } @@ -323,7 +305,7 @@ public final ChannelPipeline addAfter( addAfter0(ctx, newCtx); - // If the registered is false it means that the channel was not registered on an eventloop yet. + // If the registered is false it means that the channel was not registered on an eventLoop yet. // In this case we remove the context from the pipeline and add a task that will call // ChannelHandler.handlerRemoved(...) once the channel is registered. if (!registered) { @@ -333,13 +315,7 @@ public final ChannelPipeline addAfter( } EventExecutor executor = newCtx.executor(); if (!executor.inEventLoop()) { - newCtx.setAddPending(); - executor.execute(new Runnable() { - @Override - public void run() { - callHandlerAdded0(newCtx); - } - }); + callHandlerAddedInEventLoop(newCtx, executor); return this; } } @@ -631,19 +607,12 @@ private static void checkMultiplicity(ChannelHandler handler) { private void callHandlerAdded0(final AbstractChannelHandlerContext ctx) { try { - // We must call setAddComplete before calling handlerAdded. Otherwise if the handlerAdded method generates - // any pipeline events ctx.handler() will miss them because the state will not allow it. - ctx.setAddComplete(); - ctx.handler().handlerAdded(ctx); + ctx.callHandlerAdded(); } catch (Throwable t) { boolean removed = false; try { remove0(ctx); - try { - ctx.handler().handlerRemoved(ctx); - } finally { - ctx.setRemoved(); - } + ctx.callHandlerRemoved(); removed = true; } catch (Throwable t2) { if (logger.isWarnEnabled()) { @@ -666,11 +635,7 @@ private void callHandlerAdded0(final AbstractChannelHandlerContext ctx) { private void callHandlerRemoved0(final AbstractChannelHandlerContext ctx) { // Notify the complete removal. try { - try { - ctx.handler().handlerRemoved(ctx); - } finally { - ctx.setRemoved(); - } + ctx.callHandlerRemoved(); } catch (Throwable t) { fireExceptionCaught(new ChannelPipelineException( ctx.handler().getClass().getName() + ".handlerRemoved() has thrown an exception.", t)); @@ -1179,6 +1144,16 @@ private void callHandlerCallbackLater(AbstractChannelHandlerContext ctx, boolean } } + private void callHandlerAddedInEventLoop(final AbstractChannelHandlerContext newCtx, EventExecutor executor) { + newCtx.setAddPending(); + executor.execute(new Runnable() { + @Override + public void run() { + callHandlerAdded0(newCtx); + } + }); + } + /** * Called once a {@link Throwable} hit the end of the {@link ChannelPipeline} without been handled by the user * in {@link ChannelHandler#exceptionCaught(ChannelHandlerContext, Throwable)}. @@ -1278,49 +1253,49 @@ public ChannelHandler handler() { } @Override - public void channelRegistered(ChannelHandlerContext ctx) throws Exception { } + public void channelRegistered(ChannelHandlerContext ctx) { } @Override - public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { } + public void channelUnregistered(ChannelHandlerContext ctx) { } @Override - public void channelActive(ChannelHandlerContext ctx) throws Exception { + public void channelActive(ChannelHandlerContext ctx) { onUnhandledInboundChannelActive(); } @Override - public void channelInactive(ChannelHandlerContext ctx) throws Exception { + public void channelInactive(ChannelHandlerContext ctx) { onUnhandledInboundChannelInactive(); } @Override - public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { + public void channelWritabilityChanged(ChannelHandlerContext ctx) { onUnhandledChannelWritabilityChanged(); } @Override - public void handlerAdded(ChannelHandlerContext ctx) throws Exception { } + public void handlerAdded(ChannelHandlerContext ctx) { } @Override - public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { } + public void handlerRemoved(ChannelHandlerContext ctx) { } @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { onUnhandledInboundUserEventTriggered(evt); } @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { onUnhandledInboundException(cause); } @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + public void channelRead(ChannelHandlerContext ctx, Object msg) { onUnhandledInboundMessage(msg); } @Override - public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + public void channelReadComplete(ChannelHandlerContext ctx) { onUnhandledInboundChannelReadComplete(); } } @@ -1331,7 +1306,7 @@ final class HeadContext extends AbstractChannelHandlerContext private final Unsafe unsafe; HeadContext(DefaultChannelPipeline pipeline) { - super(pipeline, null, HEAD_NAME, false, true); + super(pipeline, null, HEAD_NAME, true, true); unsafe = pipeline.channel().unsafe(); setAddComplete(); } @@ -1342,19 +1317,18 @@ public ChannelHandler handler() { } @Override - public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + public void handlerAdded(ChannelHandlerContext ctx) { // NOOP } @Override - public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + public void handlerRemoved(ChannelHandlerContext ctx) { // NOOP } @Override public void bind( - ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) - throws Exception { + ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) { unsafe.bind(localAddress, promise); } @@ -1362,22 +1336,22 @@ public void bind( public void connect( ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress, - ChannelPromise promise) throws Exception { + ChannelPromise promise) { unsafe.connect(remoteAddress, localAddress, promise); } @Override - public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) { unsafe.disconnect(promise); } @Override - public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + public void close(ChannelHandlerContext ctx, ChannelPromise promise) { unsafe.close(promise); } @Override - public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) { unsafe.deregister(promise); } @@ -1387,28 +1361,28 @@ public void read(ChannelHandlerContext ctx) { } @Override - public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { unsafe.write(msg, promise); } @Override - public void flush(ChannelHandlerContext ctx) throws Exception { + public void flush(ChannelHandlerContext ctx) { unsafe.flush(); } @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { ctx.fireExceptionCaught(cause); } @Override - public void channelRegistered(ChannelHandlerContext ctx) throws Exception { + public void channelRegistered(ChannelHandlerContext ctx) { invokeHandlerAddedIfNeeded(); ctx.fireChannelRegistered(); } @Override - public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { + public void channelUnregistered(ChannelHandlerContext ctx) { ctx.fireChannelUnregistered(); // Remove all handlers sequentially if channel is closed and unregistered. @@ -1418,24 +1392,24 @@ public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { } @Override - public void channelActive(ChannelHandlerContext ctx) throws Exception { + public void channelActive(ChannelHandlerContext ctx) { ctx.fireChannelActive(); readIfIsAutoRead(); } @Override - public void channelInactive(ChannelHandlerContext ctx) throws Exception { + public void channelInactive(ChannelHandlerContext ctx) { ctx.fireChannelInactive(); } @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + public void channelRead(ChannelHandlerContext ctx, Object msg) { ctx.fireChannelRead(msg); } @Override - public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + public void channelReadComplete(ChannelHandlerContext ctx) { ctx.fireChannelReadComplete(); readIfIsAutoRead(); @@ -1448,12 +1422,12 @@ private void readIfIsAutoRead() { } @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { ctx.fireUserEventTriggered(evt); } @Override - public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { + public void channelWritabilityChanged(ChannelHandlerContext ctx) { ctx.fireChannelWritabilityChanged(); } } diff --git a/transport/src/main/java/io/netty/channel/DefaultFileRegion.java b/transport/src/main/java/io/netty/channel/DefaultFileRegion.java index 2ccb48586026..2f6bea95a1cf 100644 --- a/transport/src/main/java/io/netty/channel/DefaultFileRegion.java +++ b/transport/src/main/java/io/netty/channel/DefaultFileRegion.java @@ -15,6 +15,8 @@ */ package io.netty.channel; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + import io.netty.util.AbstractReferenceCounted; import io.netty.util.IllegalReferenceCountException; import io.netty.util.internal.logging.InternalLogger; @@ -44,7 +46,7 @@ public class DefaultFileRegion extends AbstractReferenceCounted implements FileR /** * Create a new instance * - * @param file the {@link FileChannel} which should be transfered + * @param file the {@link FileChannel} which should be transferred * @param position the position from which the transfer should start * @param count the number of bytes to transfer */ @@ -52,12 +54,8 @@ public DefaultFileRegion(FileChannel file, long position, long count) { if (file == null) { throw new NullPointerException("file"); } - if (position < 0) { - throw new IllegalArgumentException("position must be >= 0 but was " + position); - } - if (count < 0) { - throw new IllegalArgumentException("count must be >= 0 but was " + count); - } + checkPositiveOrZero(position, "position"); + checkPositiveOrZero(count, "count"); this.file = file; this.position = position; this.count = count; @@ -68,7 +66,7 @@ public DefaultFileRegion(FileChannel file, long position, long count) { * Create a new instance using the given {@link File}. The {@link File} will be opened lazily or * explicitly via {@link #open()}. * - * @param f the {@link File} which should be transfered + * @param f the {@link File} which should be transferred * @param position the position from which the transfer should start * @param count the number of bytes to transfer */ @@ -76,12 +74,8 @@ public DefaultFileRegion(File f, long position, long count) { if (f == null) { throw new NullPointerException("f"); } - if (position < 0) { - throw new IllegalArgumentException("position must be >= 0 but was " + position); - } - if (count < 0) { - throw new IllegalArgumentException("count must be >= 0 but was " + count); - } + checkPositiveOrZero(position, "position"); + checkPositiveOrZero(count, "count"); this.position = position; this.count = count; this.f = f; @@ -145,6 +139,12 @@ public long transferTo(WritableByteChannel target, long position) throws IOExcep long written = file.transferTo(this.position + position, count, target); if (written > 0) { transferred += written; + } else if (written == 0) { + // If the amount of written data is 0 we need to check if the requested count is bigger then the + // actual file itself as it may have been truncated on disk. + // + // See https://github.com/netty/netty/issues/8868 + validate(this, position); } return written; } @@ -188,4 +188,16 @@ public FileRegion touch() { public FileRegion touch(Object hint) { return this; } + + static void validate(DefaultFileRegion region, long position) throws IOException { + // If the amount of written data is 0 we need to check if the requested count is bigger then the + // actual file itself as it may have been truncated on disk. + // + // See https://github.com/netty/netty/issues/8868 + long size = region.file.size(); + long count = region.count - position; + if (region.position + count + position > size) { + throw new IOException("Underlying file size " + size + " smaller then requested count " + region.count); + } + } } diff --git a/transport/src/main/java/io/netty/channel/DefaultMaxBytesRecvByteBufAllocator.java b/transport/src/main/java/io/netty/channel/DefaultMaxBytesRecvByteBufAllocator.java index 3732bf0db161..ae5e465b32c4 100644 --- a/transport/src/main/java/io/netty/channel/DefaultMaxBytesRecvByteBufAllocator.java +++ b/transport/src/main/java/io/netty/channel/DefaultMaxBytesRecvByteBufAllocator.java @@ -15,6 +15,8 @@ */ package io.netty.channel; +import static io.netty.util.internal.ObjectUtil.checkPositive; + import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.util.UncheckedBooleanSupplier; @@ -128,9 +130,7 @@ public int maxBytesPerRead() { @Override public DefaultMaxBytesRecvByteBufAllocator maxBytesPerRead(int maxBytesPerRead) { - if (maxBytesPerRead <= 0) { - throw new IllegalArgumentException("maxBytesPerRead: " + maxBytesPerRead + " (expected: > 0)"); - } + checkPositive(maxBytesPerRead, "maxBytesPerRead"); // There is a dependency between this.maxBytesPerRead and this.maxBytesPerIndividualRead (a < b). // Write operations must be synchronized, but independent read operations can just be volatile. synchronized (this) { @@ -153,10 +153,7 @@ public int maxBytesPerIndividualRead() { @Override public DefaultMaxBytesRecvByteBufAllocator maxBytesPerIndividualRead(int maxBytesPerIndividualRead) { - if (maxBytesPerIndividualRead <= 0) { - throw new IllegalArgumentException( - "maxBytesPerIndividualRead: " + maxBytesPerIndividualRead + " (expected: > 0)"); - } + checkPositive(maxBytesPerIndividualRead, "maxBytesPerIndividualRead"); // There is a dependency between this.maxBytesPerRead and this.maxBytesPerIndividualRead (a < b). // Write operations must be synchronized, but independent read operations can just be volatile. synchronized (this) { @@ -178,13 +175,8 @@ public synchronized Entry maxBytesPerReadPair() { } private static void checkMaxBytesPerReadPair(int maxBytesPerRead, int maxBytesPerIndividualRead) { - if (maxBytesPerRead <= 0) { - throw new IllegalArgumentException("maxBytesPerRead: " + maxBytesPerRead + " (expected: > 0)"); - } - if (maxBytesPerIndividualRead <= 0) { - throw new IllegalArgumentException( - "maxBytesPerIndividualRead: " + maxBytesPerIndividualRead + " (expected: > 0)"); - } + checkPositive(maxBytesPerRead, "maxBytesPerRead"); + checkPositive(maxBytesPerIndividualRead, "maxBytesPerIndividualRead"); if (maxBytesPerRead < maxBytesPerIndividualRead) { throw new IllegalArgumentException( "maxBytesPerRead cannot be less than " + diff --git a/transport/src/main/java/io/netty/channel/DefaultMaxMessagesRecvByteBufAllocator.java b/transport/src/main/java/io/netty/channel/DefaultMaxMessagesRecvByteBufAllocator.java index 1381696be240..23baba9ec2da 100644 --- a/transport/src/main/java/io/netty/channel/DefaultMaxMessagesRecvByteBufAllocator.java +++ b/transport/src/main/java/io/netty/channel/DefaultMaxMessagesRecvByteBufAllocator.java @@ -15,6 +15,8 @@ */ package io.netty.channel; +import static io.netty.util.internal.ObjectUtil.checkPositive; + import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.util.UncheckedBooleanSupplier; @@ -42,9 +44,7 @@ public int maxMessagesPerRead() { @Override public MaxMessagesRecvByteBufAllocator maxMessagesPerRead(int maxMessagesPerRead) { - if (maxMessagesPerRead <= 0) { - throw new IllegalArgumentException("maxMessagesPerRead: " + maxMessagesPerRead + " (expected: > 0)"); - } + checkPositive(maxMessagesPerRead, "maxMessagesPerRead"); this.maxMessagesPerRead = maxMessagesPerRead; return this; } diff --git a/transport/src/main/java/io/netty/channel/DefaultMessageSizeEstimator.java b/transport/src/main/java/io/netty/channel/DefaultMessageSizeEstimator.java index 1459743259c8..cedbb2ad4228 100644 --- a/transport/src/main/java/io/netty/channel/DefaultMessageSizeEstimator.java +++ b/transport/src/main/java/io/netty/channel/DefaultMessageSizeEstimator.java @@ -15,6 +15,8 @@ */ package io.netty.channel; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufHolder; @@ -59,9 +61,7 @@ public int size(Object msg) { * @param unknownSize The size which is returned for unknown messages. */ public DefaultMessageSizeEstimator(int unknownSize) { - if (unknownSize < 0) { - throw new IllegalArgumentException("unknownSize: " + unknownSize + " (expected: >= 0)"); - } + checkPositiveOrZero(unknownSize, "unknownSize"); handle = new HandleImpl(unknownSize); } diff --git a/transport/src/main/java/io/netty/channel/ExtendedClosedChannelException.java b/transport/src/main/java/io/netty/channel/ExtendedClosedChannelException.java new file mode 100644 index 000000000000..3b908cd1930a --- /dev/null +++ b/transport/src/main/java/io/netty/channel/ExtendedClosedChannelException.java @@ -0,0 +1,32 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import java.nio.channels.ClosedChannelException; + +final class ExtendedClosedChannelException extends ClosedChannelException { + + ExtendedClosedChannelException(Throwable cause) { + if (cause != null) { + initCause(cause); + } + } + + @Override + public Throwable fillInStackTrace() { + return this; + } +} diff --git a/transport/src/main/java/io/netty/channel/FileRegion.java b/transport/src/main/java/io/netty/channel/FileRegion.java index 48fde291317c..8c7128587642 100644 --- a/transport/src/main/java/io/netty/channel/FileRegion.java +++ b/transport/src/main/java/io/netty/channel/FileRegion.java @@ -58,7 +58,7 @@ public interface FileRegion extends ReferenceCounted { long position(); /** - * Returns the bytes which was transfered already. + * Returns the bytes which was transferred already. * * @deprecated Use {@link #transferred()} instead. */ @@ -66,7 +66,7 @@ public interface FileRegion extends ReferenceCounted { long transfered(); /** - * Returns the bytes which was transfered already. + * Returns the bytes which was transferred already. */ long transferred(); diff --git a/transport/src/main/java/io/netty/channel/FixedRecvByteBufAllocator.java b/transport/src/main/java/io/netty/channel/FixedRecvByteBufAllocator.java index 8dab77dad77f..8bdd43df81bf 100644 --- a/transport/src/main/java/io/netty/channel/FixedRecvByteBufAllocator.java +++ b/transport/src/main/java/io/netty/channel/FixedRecvByteBufAllocator.java @@ -15,6 +15,8 @@ */ package io.netty.channel; +import static io.netty.util.internal.ObjectUtil.checkPositive; + /** * The {@link RecvByteBufAllocator} that always yields the same buffer * size prediction. This predictor ignores the feed back from the I/O thread. @@ -26,7 +28,7 @@ public class FixedRecvByteBufAllocator extends DefaultMaxMessagesRecvByteBufAllo private final class HandleImpl extends MaxMessageHandle { private final int bufferSize; - public HandleImpl(int bufferSize) { + HandleImpl(int bufferSize) { this.bufferSize = bufferSize; } @@ -41,10 +43,7 @@ public int guess() { * the specified buffer size. */ public FixedRecvByteBufAllocator(int bufferSize) { - if (bufferSize <= 0) { - throw new IllegalArgumentException( - "bufferSize must greater than 0: " + bufferSize); - } + checkPositive(bufferSize, "bufferSize"); this.bufferSize = bufferSize; } diff --git a/transport/src/main/java/io/netty/channel/PendingWriteQueue.java b/transport/src/main/java/io/netty/channel/PendingWriteQueue.java index 16ae47bd24a8..bbc661bf6c80 100644 --- a/transport/src/main/java/io/netty/channel/PendingWriteQueue.java +++ b/transport/src/main/java/io/netty/channel/PendingWriteQueue.java @@ -130,7 +130,7 @@ public ChannelFuture removeAndWriteAll() { } ChannelPromise p = ctx.newPromise(); - PromiseCombiner combiner = new PromiseCombiner(); + PromiseCombiner combiner = new PromiseCombiner(ctx.executor()); try { // It is possible for some of the written promises to trigger more writes. The new writes // will "revive" the queue, so we need to write them up until the queue is empty. diff --git a/transport/src/main/java/io/netty/channel/RecvByteBufAllocator.java b/transport/src/main/java/io/netty/channel/RecvByteBufAllocator.java index 5a34ad27f3da..39aab5c7717a 100644 --- a/transport/src/main/java/io/netty/channel/RecvByteBufAllocator.java +++ b/transport/src/main/java/io/netty/channel/RecvByteBufAllocator.java @@ -96,7 +96,7 @@ interface Handle { int attemptedBytesRead(); /** - * Determine if the current read loop should should continue. + * Determine if the current read loop should continue. * @return {@code true} if the read loop should continue reading. {@code false} if the read loop is complete. */ boolean continueReading(); diff --git a/transport/src/main/java/io/netty/channel/ReflectiveChannelFactory.java b/transport/src/main/java/io/netty/channel/ReflectiveChannelFactory.java index 502d15f17327..677c1c071d21 100644 --- a/transport/src/main/java/io/netty/channel/ReflectiveChannelFactory.java +++ b/transport/src/main/java/io/netty/channel/ReflectiveChannelFactory.java @@ -16,33 +16,40 @@ package io.netty.channel; +import io.netty.util.internal.ObjectUtil; import io.netty.util.internal.StringUtil; +import java.lang.reflect.Constructor; + /** * A {@link ChannelFactory} that instantiates a new {@link Channel} by invoking its default constructor reflectively. */ public class ReflectiveChannelFactory implements ChannelFactory { - private final Class clazz; + private final Constructor constructor; public ReflectiveChannelFactory(Class clazz) { - if (clazz == null) { - throw new NullPointerException("clazz"); + ObjectUtil.checkNotNull(clazz, "clazz"); + try { + this.constructor = clazz.getConstructor(); + } catch (NoSuchMethodException e) { + throw new IllegalArgumentException("Class " + StringUtil.simpleClassName(clazz) + + " does not have a public non-arg constructor", e); } - this.clazz = clazz; } @Override public T newChannel() { try { - return clazz.getConstructor().newInstance(); + return constructor.newInstance(); } catch (Throwable t) { - throw new ChannelException("Unable to create Channel from class " + clazz, t); + throw new ChannelException("Unable to create Channel from class " + constructor.getDeclaringClass(), t); } } @Override public String toString() { - return StringUtil.simpleClassName(clazz) + ".class"; + return StringUtil.simpleClassName(ReflectiveChannelFactory.class) + + '(' + StringUtil.simpleClassName(constructor.getDeclaringClass()) + ".class)"; } } diff --git a/transport/src/main/java/io/netty/channel/SelectStrategy.java b/transport/src/main/java/io/netty/channel/SelectStrategy.java index 0c8aca5789a7..447fb7f3f100 100644 --- a/transport/src/main/java/io/netty/channel/SelectStrategy.java +++ b/transport/src/main/java/io/netty/channel/SelectStrategy.java @@ -33,6 +33,10 @@ public interface SelectStrategy { * Indicates the IO loop should be retried, no blocking select to follow directly. */ int CONTINUE = -2; + /** + * Indicates the IO loop to poll for new events without blocking. + */ + int BUSY_WAIT = -3; /** * The {@link SelectStrategy} can be used to steer the outcome of a potential select diff --git a/transport/src/main/java/io/netty/channel/SimpleUserEventChannelHandler.java b/transport/src/main/java/io/netty/channel/SimpleUserEventChannelHandler.java new file mode 100644 index 000000000000..c976c2d26c25 --- /dev/null +++ b/transport/src/main/java/io/netty/channel/SimpleUserEventChannelHandler.java @@ -0,0 +1,120 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.TypeParameterMatcher; + +/** + * {@link ChannelInboundHandlerAdapter} which allows to conveniently only handle a specific type of user events. + * + * For example, here is an implementation which only handle {@link String} user events. + * + *

    + *     public class StringEventHandler extends
    + *             {@link SimpleUserEventChannelHandler}<{@link String}> {
    + *
    + *         {@code @Override}
    + *         protected void eventReceived({@link ChannelHandlerContext} ctx, {@link String} evt)
    + *                 throws {@link Exception} {
    + *             System.out.println(evt);
    + *         }
    + *     }
    + * 
    + * + * Be aware that depending of the constructor parameters it will release all handled events by passing them to + * {@link ReferenceCountUtil#release(Object)}. In this case you may need to use + * {@link ReferenceCountUtil#retain(Object)} if you pass the object to the next handler in the {@link ChannelPipeline}. + */ +public abstract class SimpleUserEventChannelHandler extends ChannelInboundHandlerAdapter { + + private final TypeParameterMatcher matcher; + private final boolean autoRelease; + + /** + * see {@link #SimpleUserEventChannelHandler(boolean)} with {@code true} as boolean parameter. + */ + protected SimpleUserEventChannelHandler() { + this(true); + } + + /** + * Create a new instance which will try to detect the types to match out of the type parameter of the class. + * + * @param autoRelease {@code true} if handled events should be released automatically by passing them to + * {@link ReferenceCountUtil#release(Object)}. + */ + protected SimpleUserEventChannelHandler(boolean autoRelease) { + matcher = TypeParameterMatcher.find(this, SimpleUserEventChannelHandler.class, "I"); + this.autoRelease = autoRelease; + } + + /** + * see {@link #SimpleUserEventChannelHandler(Class, boolean)} with {@code true} as boolean value. + */ + protected SimpleUserEventChannelHandler(Class eventType) { + this(eventType, true); + } + + /** + * Create a new instance + * + * @param eventType The type of events to match + * @param autoRelease {@code true} if handled events should be released automatically by passing them to + * {@link ReferenceCountUtil#release(Object)}. + */ + protected SimpleUserEventChannelHandler(Class eventType, boolean autoRelease) { + matcher = TypeParameterMatcher.get(eventType); + this.autoRelease = autoRelease; + } + + /** + * Returns {@code true} if the given user event should be handled. If {@code false} it will be passed to the next + * {@link ChannelInboundHandler} in the {@link ChannelPipeline}. + */ + protected boolean acceptEvent(Object evt) throws Exception { + return matcher.match(evt); + } + + @Override + public final void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + boolean release = true; + try { + if (acceptEvent(evt)) { + @SuppressWarnings("unchecked") + I ievt = (I) evt; + eventReceived(ctx, ievt); + } else { + release = false; + ctx.fireUserEventTriggered(evt); + } + } finally { + if (autoRelease && release) { + ReferenceCountUtil.release(evt); + } + } + } + + /** + * Is called for each user event triggered of type {@link I}. + * + * @param ctx the {@link ChannelHandlerContext} which this {@link SimpleUserEventChannelHandler} belongs to + * @param evt the user event to handle + * + * @throws Exception is thrown if an error occurred + */ + protected abstract void eventReceived(ChannelHandlerContext ctx, I evt) throws Exception; +} diff --git a/transport/src/main/java/io/netty/channel/ThreadPerChannelEventLoop.java b/transport/src/main/java/io/netty/channel/ThreadPerChannelEventLoop.java index b10d79de7eb1..1d4b95817c0b 100644 --- a/transport/src/main/java/io/netty/channel/ThreadPerChannelEventLoop.java +++ b/transport/src/main/java/io/netty/channel/ThreadPerChannelEventLoop.java @@ -19,7 +19,9 @@ * {@link SingleThreadEventLoop} which is used to handle OIO {@link Channel}'s. So in general there will be * one {@link ThreadPerChannelEventLoop} per {@link Channel}. * + * @deprecated this will be remove in the next-major release. */ +@Deprecated public class ThreadPerChannelEventLoop extends SingleThreadEventLoop { private final ThreadPerChannelEventLoopGroup parent; diff --git a/transport/src/main/java/io/netty/channel/ThreadPerChannelEventLoopGroup.java b/transport/src/main/java/io/netty/channel/ThreadPerChannelEventLoopGroup.java index f254b8de92e1..7ee89d0211d5 100644 --- a/transport/src/main/java/io/netty/channel/ThreadPerChannelEventLoopGroup.java +++ b/transport/src/main/java/io/netty/channel/ThreadPerChannelEventLoopGroup.java @@ -42,7 +42,10 @@ /** * An {@link EventLoopGroup} that creates one {@link EventLoop} per {@link Channel}. + * + * @deprecated this will be remove in the next-major release. */ +@Deprecated public class ThreadPerChannelEventLoopGroup extends AbstractEventExecutorGroup implements EventLoopGroup { private final Object[] childArgs; diff --git a/transport/src/main/java/io/netty/channel/WriteBufferWaterMark.java b/transport/src/main/java/io/netty/channel/WriteBufferWaterMark.java index ee3d4666a153..3deb74f6bd53 100644 --- a/transport/src/main/java/io/netty/channel/WriteBufferWaterMark.java +++ b/transport/src/main/java/io/netty/channel/WriteBufferWaterMark.java @@ -15,6 +15,8 @@ */ package io.netty.channel; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + /** * WriteBufferWaterMark is used to set low water mark and high water mark for the write buffer. *

    @@ -54,9 +56,7 @@ public WriteBufferWaterMark(int low, int high) { */ WriteBufferWaterMark(int low, int high, boolean validate) { if (validate) { - if (low < 0) { - throw new IllegalArgumentException("write buffer's low water mark must be >= 0"); - } + checkPositiveOrZero(low, "low"); if (high < low) { throw new IllegalArgumentException( "write buffer's high water mark cannot be less than " + diff --git a/transport/src/main/java/io/netty/channel/nio/NioEventLoop.java b/transport/src/main/java/io/netty/channel/nio/NioEventLoop.java index 2b9316719734..8ba5366155d0 100644 --- a/transport/src/main/java/io/netty/channel/nio/NioEventLoop.java +++ b/transport/src/main/java/io/netty/channel/nio/NioEventLoop.java @@ -33,8 +33,9 @@ import java.lang.reflect.Field; import java.nio.channels.CancelledKeyException; import java.nio.channels.SelectableChannel; -import java.nio.channels.SelectionKey; import java.nio.channels.Selector; +import java.nio.channels.SelectionKey; + import java.nio.channels.spi.SelectorProvider; import java.security.AccessController; import java.security.PrivilegedAction; @@ -43,7 +44,6 @@ import java.util.Iterator; import java.util.Queue; import java.util.Set; -import java.util.concurrent.Callable; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -59,7 +59,7 @@ public class NioEventLoop extends SingleThreadEventLoop { private static final int CLEANUP_INTERVAL = 256; // XXX Hard-coded value, but won't need customization. - private static final boolean DISABLE_KEYSET_OPTIMIZATION = + private static final boolean DISABLE_KEY_SET_OPTIMIZATION = SystemPropertyUtil.getBoolean("io.netty.noKeySetOptimization", false); private static final int MIN_PREMATURE_SELECTOR_RETURNS = 3; @@ -71,12 +71,6 @@ public int get() throws Exception { return selectNow(); } }; - private final Callable pendingTasksCallable = new Callable() { - @Override - public Integer call() throws Exception { - return NioEventLoop.super.pendingTasks(); - } - }; // Workaround for JDK NIO bug. // @@ -85,8 +79,8 @@ public Integer call() throws Exception { // - https://github.com/netty/netty/issues/203 static { final String key = "sun.nio.ch.bugLevel"; - final String buglevel = SystemPropertyUtil.get(key); - if (buglevel == null) { + final String bugLevel = SystemPropertyUtil.get(key); + if (bugLevel == null) { try { AccessController.doPrivileged(new PrivilegedAction() { @Override @@ -108,13 +102,11 @@ public Void run() { SELECTOR_AUTO_REBUILD_THRESHOLD = selectorAutoRebuildThreshold; if (logger.isDebugEnabled()) { - logger.debug("-Dio.netty.noKeySetOptimization: {}", DISABLE_KEYSET_OPTIMIZATION); + logger.debug("-Dio.netty.noKeySetOptimization: {}", DISABLE_KEY_SET_OPTIMIZATION); logger.debug("-Dio.netty.selectorAutoRebuildThreshold: {}", SELECTOR_AUTO_REBUILD_THRESHOLD); } } - static final long MAX_SCHEDULED_DAYS = 365 * 3; - /** * The NIO {@link Selector}. */ @@ -177,12 +169,10 @@ private SelectorTuple openSelector() { throw new ChannelException("failed to open a new selector", e); } - if (DISABLE_KEYSET_OPTIMIZATION) { + if (DISABLE_KEY_SET_OPTIMIZATION) { return new SelectorTuple(unwrappedSelector); } - final SelectedSelectionKeySet selectedKeySet = new SelectedSelectionKeySet(); - Object maybeSelectorImplClass = AccessController.doPrivileged(new PrivilegedAction() { @Override public Object run() { @@ -198,8 +188,8 @@ public Object run() { }); if (!(maybeSelectorImplClass instanceof Class) || - // ensure the current selector implementation is what we can instrument. - !((Class) maybeSelectorImplClass).isAssignableFrom(unwrappedSelector.getClass())) { + // ensure the current selector implementation is what we can instrument. + !((Class) maybeSelectorImplClass).isAssignableFrom(unwrappedSelector.getClass())) { if (maybeSelectorImplClass instanceof Throwable) { Throwable t = (Throwable) maybeSelectorImplClass; logger.trace("failed to instrument a special java.util.Set into: {}", unwrappedSelector, t); @@ -208,6 +198,7 @@ public Object run() { } final Class selectorImplClass = (Class) maybeSelectorImplClass; + final SelectedSelectionKeySet selectedKeySet = new SelectedSelectionKeySet(); Object maybeException = AccessController.doPrivileged(new PrivilegedAction() { @Override @@ -216,6 +207,23 @@ public Object run() { Field selectedKeysField = selectorImplClass.getDeclaredField("selectedKeys"); Field publicSelectedKeysField = selectorImplClass.getDeclaredField("publicSelectedKeys"); + if (PlatformDependent.javaVersion() >= 9 && PlatformDependent.hasUnsafe()) { + // Let us try to use sun.misc.Unsafe to replace the SelectionKeySet. + // This allows us to also do this in Java9+ without any extra flags. + long selectedKeysFieldOffset = PlatformDependent.objectFieldOffset(selectedKeysField); + long publicSelectedKeysFieldOffset = + PlatformDependent.objectFieldOffset(publicSelectedKeysField); + + if (selectedKeysFieldOffset != -1 && publicSelectedKeysFieldOffset != -1) { + PlatformDependent.putObject( + unwrappedSelector, selectedKeysFieldOffset, selectedKeySet); + PlatformDependent.putObject( + unwrappedSelector, publicSelectedKeysFieldOffset, selectedKeySet); + return null; + } + // We could not retrieve the offset, lets try reflection as last-resort. + } + Throwable cause = ReflectionUtil.trySetAccessible(selectedKeysField, true); if (cause != null) { return cause; @@ -262,18 +270,6 @@ protected Queue newTaskQueue(int maxPendingTasks) { : PlatformDependent.newMpscQueue(maxPendingTasks); } - @Override - public int pendingTasks() { - // As we use a MpscQueue we need to ensure pendingTasks() is only executed from within the EventLoop as - // otherwise we may see unexpected behavior (as size() is only allowed to be called by a single consumer). - // See https://github.com/netty/netty/issues/5297 - if (inEventLoop()) { - return super.pendingTasks(); - } else { - return submit(pendingTasksCallable).syncUninterruptibly().getNow(); - } - } - /** * Registers an arbitrary {@link SelectableChannel}, not necessarily created by Netty, to the {@link Selector} * of this event loop. Once the specified {@link SelectableChannel} is registered, the specified {@code task} will @@ -298,8 +294,28 @@ public void register(final SelectableChannel ch, final int interestOps, final Ni throw new IllegalStateException("event loop shut down"); } + if (inEventLoop()) { + register0(ch, interestOps, task); + } else { + try { + // Offload to the EventLoop as otherwise java.nio.channels.spi.AbstractSelectableChannel.register + // may block for a long time while trying to obtain an internal lock that may be hold while selecting. + submit(new Runnable() { + @Override + public void run() { + register0(ch, interestOps, task); + } + }).sync(); + } catch (InterruptedException ignore) { + // Even if interrupted we did schedule it so just mark the Thread as interrupted. + Thread.currentThread().interrupt(); + } + } + } + + private void register0(SelectableChannel ch, int interestOps, NioTask task) { try { - ch.register(selector, interestOps, task); + ch.register(unwrappedSelector, interestOps, task); } catch (Exception e) { throw new EventLoopException("failed to register a channel", e); } @@ -397,16 +413,23 @@ private void rebuildSelector0() { } } - logger.info("Migrated " + nChannels + " channel(s) to the new Selector."); + if (logger.isInfoEnabled()) { + logger.info("Migrated " + nChannels + " channel(s) to the new Selector."); + } } @Override protected void run() { for (;;) { try { - switch (selectStrategy.calculateStrategy(selectNowSupplier, hasTasks())) { + try { + switch (selectStrategy.calculateStrategy(selectNowSupplier, hasTasks())) { case SelectStrategy.CONTINUE: continue; + + case SelectStrategy.BUSY_WAIT: + // fall-through to SELECT since the busy-wait is not supported with NIO + case SelectStrategy.SELECT: select(wakenUp.getAndSet(false)); @@ -443,6 +466,13 @@ protected void run() { } // fall through default: + } + } catch (IOException e) { + // If we receive an IOException here its because the Selector is messed up. Let's rebuild + // the selector and retry. https://github.com/netty/netty/issues/8566 + rebuildSelector0(); + handleLoopException(e); + continue; } cancelledKeys = 0; @@ -784,17 +814,9 @@ private void select(boolean oldWakenUp) throws IOException { selectCnt = 1; } else if (SELECTOR_AUTO_REBUILD_THRESHOLD > 0 && selectCnt >= SELECTOR_AUTO_REBUILD_THRESHOLD) { - // The selector returned prematurely many times in a row. - // Rebuild the selector to work around the problem. - logger.warn( - "Selector.select() returned prematurely {} times in a row; rebuilding Selector {}.", - selectCnt, selector); - - rebuildSelector(); - selector = this.selector; - - // Select again to populate selectedKeys. - selector.selectNow(); + // The code exists in an extra method to ensure the method is not too big to inline as this + // branch is not very likely to get hit very frequently. + selector = selectRebuildSelector(selectCnt); selectCnt = 1; break; } @@ -817,6 +839,21 @@ private void select(boolean oldWakenUp) throws IOException { } } + private Selector selectRebuildSelector(int selectCnt) throws IOException { + // The selector returned prematurely many times in a row. + // Rebuild the selector to work around the problem. + logger.warn( + "Selector.select() returned prematurely {} times in a row; rebuilding Selector {}.", + selectCnt, selector); + + rebuildSelector(); + Selector selector = this.selector; + + // Select again to populate selectedKeys. + selector.selectNow(); + return selector; + } + private void selectAgain() { needsToSelectAgain = false; try { @@ -825,12 +862,4 @@ private void selectAgain() { logger.warn("Failed to update SelectionKeys.", t); } } - - @Override - protected void validateScheduled(long amount, TimeUnit unit) { - long days = unit.toDays(amount); - if (days > MAX_SCHEDULED_DAYS) { - throw new IllegalArgumentException("days: " + days + " (expected: < " + MAX_SCHEDULED_DAYS + ')'); - } - } } diff --git a/transport/src/main/java/io/netty/channel/nio/SelectedSelectionKeySet.java b/transport/src/main/java/io/netty/channel/nio/SelectedSelectionKeySet.java index 534e27f6c778..8be2672a5a2c 100644 --- a/transport/src/main/java/io/netty/channel/nio/SelectedSelectionKeySet.java +++ b/transport/src/main/java/io/netty/channel/nio/SelectedSelectionKeySet.java @@ -19,6 +19,7 @@ import java.util.AbstractSet; import java.util.Arrays; import java.util.Iterator; +import java.util.NoSuchElementException; final class SelectedSelectionKeySet extends AbstractSet { @@ -43,11 +44,6 @@ public boolean add(SelectionKey o) { return true; } - @Override - public int size() { - return size; - } - @Override public boolean remove(Object o) { return false; @@ -58,9 +54,34 @@ public boolean contains(Object o) { return false; } + @Override + public int size() { + return size; + } + @Override public Iterator iterator() { - throw new UnsupportedOperationException(); + return new Iterator() { + private int idx; + + @Override + public boolean hasNext() { + return idx < size; + } + + @Override + public SelectionKey next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return keys[idx++]; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; } void reset() { diff --git a/transport/src/main/java/io/netty/channel/oio/AbstractOioByteChannel.java b/transport/src/main/java/io/netty/channel/oio/AbstractOioByteChannel.java index 6eda848f82ef..54ea0deadd8b 100644 --- a/transport/src/main/java/io/netty/channel/oio/AbstractOioByteChannel.java +++ b/transport/src/main/java/io/netty/channel/oio/AbstractOioByteChannel.java @@ -34,6 +34,8 @@ /** * Abstract base class for OIO which reads and writes bytes from/to a Socket + * + * @deprecated use NIO / EPOLL / KQUEUE transport. */ public abstract class AbstractOioByteChannel extends AbstractOioChannel { diff --git a/transport/src/main/java/io/netty/channel/oio/AbstractOioChannel.java b/transport/src/main/java/io/netty/channel/oio/AbstractOioChannel.java index 7aa312d8e90b..b046d00a1c02 100644 --- a/transport/src/main/java/io/netty/channel/oio/AbstractOioChannel.java +++ b/transport/src/main/java/io/netty/channel/oio/AbstractOioChannel.java @@ -25,7 +25,10 @@ /** * Abstract base class for {@link Channel} implementations that use Old-Blocking-IO + * + * @deprecated use NIO / EPOLL / KQUEUE transport. */ +@Deprecated public abstract class AbstractOioChannel extends AbstractChannel { protected static final int SO_TIMEOUT = 1000; diff --git a/transport/src/main/java/io/netty/channel/oio/AbstractOioMessageChannel.java b/transport/src/main/java/io/netty/channel/oio/AbstractOioMessageChannel.java index 0543b83c13ad..721631159ca2 100644 --- a/transport/src/main/java/io/netty/channel/oio/AbstractOioMessageChannel.java +++ b/transport/src/main/java/io/netty/channel/oio/AbstractOioMessageChannel.java @@ -26,7 +26,10 @@ /** * Abstract base class for OIO which reads and writes objects from/to a Socket + * + * @deprecated use NIO / EPOLL / KQUEUE transport. */ +@Deprecated public abstract class AbstractOioMessageChannel extends AbstractOioChannel { private final List readBuf = new ArrayList(); diff --git a/transport/src/main/java/io/netty/channel/oio/OioByteStreamChannel.java b/transport/src/main/java/io/netty/channel/oio/OioByteStreamChannel.java index 6d8863b9d356..d352823fdab9 100644 --- a/transport/src/main/java/io/netty/channel/oio/OioByteStreamChannel.java +++ b/transport/src/main/java/io/netty/channel/oio/OioByteStreamChannel.java @@ -31,7 +31,10 @@ /** * Abstract base class for OIO Channels that are based on streams. + * + * @deprecated use NIO / EPOLL / KQUEUE transport. */ +@Deprecated public abstract class OioByteStreamChannel extends AbstractOioByteChannel { private static final InputStream CLOSED_IN = new InputStream() { diff --git a/transport/src/main/java/io/netty/channel/oio/OioEventLoopGroup.java b/transport/src/main/java/io/netty/channel/oio/OioEventLoopGroup.java index 684af4098396..91a2c4b3ffbf 100644 --- a/transport/src/main/java/io/netty/channel/oio/OioEventLoopGroup.java +++ b/transport/src/main/java/io/netty/channel/oio/OioEventLoopGroup.java @@ -30,7 +30,10 @@ /** * {@link EventLoopGroup} which is used to handle OIO {@link Channel}'s. Each {@link Channel} will be handled by its * own {@link EventLoop} to not block others. + * + * @deprecated use NIO / EPOLL / KQUEUE transport. */ +@Deprecated public class OioEventLoopGroup extends ThreadPerChannelEventLoopGroup { /** diff --git a/transport/src/main/java/io/netty/channel/oio/package-info.java b/transport/src/main/java/io/netty/channel/oio/package-info.java index 948e8492ad76..af5b19a1b387 100644 --- a/transport/src/main/java/io/netty/channel/oio/package-info.java +++ b/transport/src/main/java/io/netty/channel/oio/package-info.java @@ -17,5 +17,7 @@ /** * Old blocking I/O based channel API implementation - recommended for * a small number of connections (< 1000). + * + * @deprecated use NIO / EPOLL / KQUEUE transport. */ package io.netty.channel.oio; diff --git a/transport/src/main/java/io/netty/channel/pool/FixedChannelPool.java b/transport/src/main/java/io/netty/channel/pool/FixedChannelPool.java index 1c0ebcc5343b..26ded639ef91 100644 --- a/transport/src/main/java/io/netty/channel/pool/FixedChannelPool.java +++ b/transport/src/main/java/io/netty/channel/pool/FixedChannelPool.java @@ -20,6 +20,7 @@ import io.netty.util.concurrent.EventExecutor; import io.netty.util.concurrent.Future; import io.netty.util.concurrent.FutureListener; +import io.netty.util.concurrent.GlobalEventExecutor; import io.netty.util.concurrent.Promise; import io.netty.util.internal.ObjectUtil; import io.netty.util.internal.ThrowableUtil; @@ -27,6 +28,7 @@ import java.nio.channels.ClosedChannelException; import java.util.ArrayDeque; import java.util.Queue; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; @@ -69,7 +71,7 @@ public enum AcquireTimeoutAction { private final Queue pendingAcquireQueue = new ArrayDeque(); private final int maxConnections; private final int maxPendingAcquires; - private int acquiredChannelCount; + private final AtomicInteger acquiredChannelCount = new AtomicInteger(); private int pendingAcquireCount; private boolean closed; @@ -228,6 +230,11 @@ public void onTimeout(AcquireTask task) { this.maxPendingAcquires = maxPendingAcquires; } + /** Returns the number of acquired channels that this pool thinks it has. */ + public int acquiredChannelCount() { + return acquiredChannelCount.get(); + } + @Override public Future acquire(final Promise promise) { try { @@ -254,8 +261,8 @@ private void acquire0(final Promise promise) { promise.setFailure(POOL_CLOSED_ON_ACQUIRE_EXCEPTION); return; } - if (acquiredChannelCount < maxConnections) { - assert acquiredChannelCount >= 0; + if (acquiredChannelCount.get() < maxConnections) { + assert acquiredChannelCount.get() >= 0; // We need to create a new promise as we need to ensure the AcquireListener runs in the correct // EventLoop @@ -318,10 +325,9 @@ public void operationComplete(Future future) throws Exception { } private void decrementAndRunTaskQueue() { - --acquiredChannelCount; - // We should never have a negative value. - assert acquiredChannelCount >= 0; + int currentCount = acquiredChannelCount.decrementAndGet(); + assert currentCount >= 0; // Run the pending acquire tasks before notify the original promise so if the user would // try to acquire again from the ChannelFutureListener and the pendingAcquireCount is >= @@ -331,7 +337,7 @@ private void decrementAndRunTaskQueue() { } private void runTaskQueue() { - while (acquiredChannelCount < maxConnections) { + while (acquiredChannelCount.get() < maxConnections) { AcquireTask task = pendingAcquireQueue.poll(); if (task == null) { break; @@ -351,7 +357,7 @@ private void runTaskQueue() { // We should never have a negative value. assert pendingAcquireCount >= 0; - assert acquiredChannelCount >= 0; + assert acquiredChannelCount.get() >= 0; } // AcquireTask extends AcquireListener to reduce object creations and so GC pressure @@ -360,7 +366,7 @@ private final class AcquireTask extends AcquireListener { final long expireNanoTime = System.nanoTime() + acquireTimeoutNanos; ScheduledFuture timeoutFuture; - public AcquireTask(Promise promise) { + AcquireTask(Promise promise) { super(promise); // We need to create a new promise as we need to ensure the AcquireListener runs in the correct // EventLoop. @@ -430,34 +436,50 @@ public void acquired() { if (acquired) { return; } - acquiredChannelCount++; + acquiredChannelCount.incrementAndGet(); acquired = true; } } @Override public void close() { - executor.execute(new Runnable() { - @Override - public void run() { - if (!closed) { - closed = true; - for (;;) { - AcquireTask task = pendingAcquireQueue.poll(); - if (task == null) { - break; - } - ScheduledFuture f = task.timeoutFuture; - if (f != null) { - f.cancel(false); - } - task.promise.setFailure(new ClosedChannelException()); - } - acquiredChannelCount = 0; - pendingAcquireCount = 0; - FixedChannelPool.super.close(); + if (executor.inEventLoop()) { + close0(); + } else { + executor.submit(new Runnable() { + @Override + public void run() { + close0(); } + }).awaitUninterruptibly(); + } + } + + private void close0() { + if (!closed) { + closed = true; + for (;;) { + AcquireTask task = pendingAcquireQueue.poll(); + if (task == null) { + break; + } + ScheduledFuture f = task.timeoutFuture; + if (f != null) { + f.cancel(false); + } + task.promise.setFailure(new ClosedChannelException()); } - }); + acquiredChannelCount.set(0); + pendingAcquireCount = 0; + + // Ensure we dispatch this on another Thread as close0 will be called from the EventExecutor and we need + // to ensure we will not block in a EventExecutor. + GlobalEventExecutor.INSTANCE.execute(new Runnable() { + @Override + public void run() { + FixedChannelPool.super.close(); + } + }); + } } } diff --git a/transport/src/main/java/io/netty/channel/pool/SimpleChannelPool.java b/transport/src/main/java/io/netty/channel/pool/SimpleChannelPool.java index d8dd0eae9409..6fcfd4443fac 100644 --- a/transport/src/main/java/io/netty/channel/pool/SimpleChannelPool.java +++ b/transport/src/main/java/io/netty/channel/pool/SimpleChannelPool.java @@ -394,7 +394,8 @@ public void close() { if (channel == null) { break; } - channel.close(); + // Just ignore any errors that are reported back from close(). + channel.close().awaitUninterruptibly(); } } } diff --git a/transport/src/main/java/io/netty/channel/socket/DefaultServerSocketChannelConfig.java b/transport/src/main/java/io/netty/channel/socket/DefaultServerSocketChannelConfig.java index 57ac53dafd25..10c7bcb700ff 100644 --- a/transport/src/main/java/io/netty/channel/socket/DefaultServerSocketChannelConfig.java +++ b/transport/src/main/java/io/netty/channel/socket/DefaultServerSocketChannelConfig.java @@ -31,6 +31,7 @@ import static io.netty.channel.ChannelOption.SO_BACKLOG; import static io.netty.channel.ChannelOption.SO_RCVBUF; import static io.netty.channel.ChannelOption.SO_REUSEADDR; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; /** * The default {@link ServerSocketChannelConfig} implementation. @@ -141,9 +142,7 @@ public int getBacklog() { @Override public ServerSocketChannelConfig setBacklog(int backlog) { - if (backlog < 0) { - throw new IllegalArgumentException("backlog: " + backlog); - } + checkPositiveOrZero(backlog, "backlog"); this.backlog = backlog; return this; } diff --git a/transport/src/main/java/io/netty/channel/socket/nio/NioChannelOption.java b/transport/src/main/java/io/netty/channel/socket/nio/NioChannelOption.java new file mode 100644 index 000000000000..3f9550f101c7 --- /dev/null +++ b/transport/src/main/java/io/netty/channel/socket/nio/NioChannelOption.java @@ -0,0 +1,118 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.socket.nio; + +import io.netty.channel.ChannelException; +import io.netty.channel.ChannelOption; + +import java.io.IOException; +import java.nio.channels.Channel; +import java.nio.channels.ServerSocketChannel; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; + +/** + * Provides {@link ChannelOption} over a given {@link java.net.SocketOption} which is then passed through the underlying + * {@link java.nio.channels.NetworkChannel}. + */ +public final class NioChannelOption extends ChannelOption { + + private final java.net.SocketOption option; + + @SuppressWarnings("deprecation") + private NioChannelOption(java.net.SocketOption option) { + super(option.name()); + this.option = option; + } + + /** + * Returns a {@link ChannelOption} for the given {@link java.net.SocketOption}. + */ + public static ChannelOption of(java.net.SocketOption option) { + return new NioChannelOption(option); + } + + // It's important to not use java.nio.channels.NetworkChannel as otherwise the classes that sometimes call this + // method may not be used on Java 6, as method linking can happen eagerly even if this method was not actually + // called at runtime. + // + // See https://github.com/netty/netty/issues/8166 + + // Internal helper methods to remove code duplication between Nio*Channel implementations. + static boolean setOption(Channel jdkChannel, NioChannelOption option, T value) { + java.nio.channels.NetworkChannel channel = (java.nio.channels.NetworkChannel) jdkChannel; + if (!channel.supportedOptions().contains(option.option)) { + return false; + } + if (channel instanceof ServerSocketChannel && option.option == java.net.StandardSocketOptions.IP_TOS) { + // Skip IP_TOS as a workaround for a JDK bug: + // See http://mail.openjdk.java.net/pipermail/nio-dev/2018-August/005365.html + return false; + } + try { + channel.setOption(option.option, value); + return true; + } catch (IOException e) { + throw new ChannelException(e); + } + } + + static T getOption(Channel jdkChannel, NioChannelOption option) { + java.nio.channels.NetworkChannel channel = (java.nio.channels.NetworkChannel) jdkChannel; + + if (!channel.supportedOptions().contains(option.option)) { + return null; + } + if (channel instanceof ServerSocketChannel && option.option == java.net.StandardSocketOptions.IP_TOS) { + // Skip IP_TOS as a workaround for a JDK bug: + // See http://mail.openjdk.java.net/pipermail/nio-dev/2018-August/005365.html + return null; + } + try { + return channel.getOption(option.option); + } catch (IOException e) { + throw new ChannelException(e); + } + } + + @SuppressWarnings("unchecked") + static ChannelOption[] getOptions(Channel jdkChannel) { + java.nio.channels.NetworkChannel channel = (java.nio.channels.NetworkChannel) jdkChannel; + Set> supportedOpts = channel.supportedOptions(); + + if (channel instanceof ServerSocketChannel) { + List> extraOpts = new ArrayList>(supportedOpts.size()); + for (java.net.SocketOption opt : supportedOpts) { + if (opt == java.net.StandardSocketOptions.IP_TOS) { + // Skip IP_TOS as a workaround for a JDK bug: + // See http://mail.openjdk.java.net/pipermail/nio-dev/2018-August/005365.html + continue; + } + extraOpts.add(new NioChannelOption(opt)); + } + return extraOpts.toArray(new ChannelOption[0]); + } else { + ChannelOption[] extraOpts = new ChannelOption[supportedOpts.size()]; + + int i = 0; + for (java.net.SocketOption opt : supportedOpts) { + extraOpts[i++] = new NioChannelOption(opt); + } + return extraOpts; + } + } +} diff --git a/transport/src/main/java/io/netty/channel/socket/nio/NioDatagramChannelConfig.java b/transport/src/main/java/io/netty/channel/socket/nio/NioDatagramChannelConfig.java index 92cf677f7cb7..7e109af556e6 100644 --- a/transport/src/main/java/io/netty/channel/socket/nio/NioDatagramChannelConfig.java +++ b/transport/src/main/java/io/netty/channel/socket/nio/NioDatagramChannelConfig.java @@ -16,6 +16,7 @@ package io.netty.channel.socket.nio; import io.netty.channel.ChannelException; +import io.netty.channel.ChannelOption; import io.netty.channel.socket.DatagramChannelConfig; import io.netty.channel.socket.DefaultDatagramChannelConfig; import io.netty.util.internal.PlatformDependent; @@ -27,6 +28,7 @@ import java.net.SocketException; import java.nio.channels.DatagramChannel; import java.util.Enumeration; +import java.util.Map; /** * The default {@link NioDatagramChannelConfig} implementation. @@ -205,4 +207,29 @@ private void setOption0(Object option, Object value) { } } } + + @Override + public boolean setOption(ChannelOption option, T value) { + if (PlatformDependent.javaVersion() >= 7 && option instanceof NioChannelOption) { + return NioChannelOption.setOption(javaChannel, (NioChannelOption) option, value); + } + return super.setOption(option, value); + } + + @Override + public T getOption(ChannelOption option) { + if (PlatformDependent.javaVersion() >= 7 && option instanceof NioChannelOption) { + return NioChannelOption.getOption(javaChannel, (NioChannelOption) option); + } + return super.getOption(option); + } + + @SuppressWarnings("unchecked") + @Override + public Map, Object> getOptions() { + if (PlatformDependent.javaVersion() >= 7) { + return getOptions(super.getOptions(), NioChannelOption.getOptions(javaChannel)); + } + return super.getOptions(); + } } diff --git a/transport/src/main/java/io/netty/channel/socket/nio/NioServerSocketChannel.java b/transport/src/main/java/io/netty/channel/socket/nio/NioServerSocketChannel.java index f2d6f16f6567..128e531e56b5 100644 --- a/transport/src/main/java/io/netty/channel/socket/nio/NioServerSocketChannel.java +++ b/transport/src/main/java/io/netty/channel/socket/nio/NioServerSocketChannel.java @@ -17,6 +17,7 @@ import io.netty.channel.ChannelException; import io.netty.channel.ChannelMetadata; +import io.netty.channel.ChannelOption; import io.netty.channel.ChannelOutboundBuffer; import io.netty.util.internal.SocketUtils; import io.netty.channel.nio.AbstractNioMessageChannel; @@ -35,6 +36,7 @@ import java.nio.channels.SocketChannel; import java.nio.channels.spi.SelectorProvider; import java.util.List; +import java.util.Map; /** * A {@link io.netty.channel.socket.ServerSocketChannel} implementation which uses @@ -199,6 +201,35 @@ private NioServerSocketChannelConfig(NioServerSocketChannel channel, ServerSocke protected void autoReadCleared() { clearReadPending(); } + + @Override + public boolean setOption(ChannelOption option, T value) { + if (PlatformDependent.javaVersion() >= 7 && option instanceof NioChannelOption) { + return NioChannelOption.setOption(jdkChannel(), (NioChannelOption) option, value); + } + return super.setOption(option, value); + } + + @Override + public T getOption(ChannelOption option) { + if (PlatformDependent.javaVersion() >= 7 && option instanceof NioChannelOption) { + return NioChannelOption.getOption(jdkChannel(), (NioChannelOption) option); + } + return super.getOption(option); + } + + @SuppressWarnings("unchecked") + @Override + public Map, Object> getOptions() { + if (PlatformDependent.javaVersion() >= 7) { + return getOptions(super.getOptions(), NioChannelOption.getOptions(jdkChannel())); + } + return super.getOptions(); + } + + private ServerSocketChannel jdkChannel() { + return ((NioServerSocketChannel) channel).javaChannel(); + } } // Override just to to be able to call directly via unit tests. diff --git a/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java b/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java index e5010ac3b572..74431791ce40 100644 --- a/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java +++ b/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java @@ -20,6 +20,7 @@ import io.netty.channel.ChannelException; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelOption; import io.netty.channel.ChannelOutboundBuffer; import io.netty.channel.ChannelPromise; import io.netty.channel.EventLoop; @@ -44,6 +45,7 @@ import java.nio.channels.SelectionKey; import java.nio.channels.SocketChannel; import java.nio.channels.spi.SelectorProvider; +import java.util.Map; import java.util.concurrent.Executor; import static io.netty.channel.internal.ChannelUtils.MAX_BYTES_PER_GATHERING_WRITE_ATTEMPTED_LOW_THRESHOLD; @@ -461,7 +463,6 @@ protected Executor prepareToClose() { private final class NioSocketChannelConfig extends DefaultSocketChannelConfig { private volatile int maxBytesPerGatheringWrite = Integer.MAX_VALUE; - private NioSocketChannelConfig(NioSocketChannel channel, Socket javaSocket) { super(channel, javaSocket); calculateMaxBytesPerGatheringWrite(); @@ -479,6 +480,31 @@ public NioSocketChannelConfig setSendBufferSize(int sendBufferSize) { return this; } + @Override + public boolean setOption(ChannelOption option, T value) { + if (PlatformDependent.javaVersion() >= 7 && option instanceof NioChannelOption) { + return NioChannelOption.setOption(jdkChannel(), (NioChannelOption) option, value); + } + return super.setOption(option, value); + } + + @Override + public T getOption(ChannelOption option) { + if (PlatformDependent.javaVersion() >= 7 && option instanceof NioChannelOption) { + return NioChannelOption.getOption(jdkChannel(), (NioChannelOption) option); + } + return super.getOption(option); + } + + @SuppressWarnings("unchecked") + @Override + public Map, Object> getOptions() { + if (PlatformDependent.javaVersion() >= 7) { + return getOptions(super.getOptions(), NioChannelOption.getOptions(jdkChannel())); + } + return super.getOptions(); + } + void setMaxBytesPerGatheringWrite(int maxBytesPerGatheringWrite) { this.maxBytesPerGatheringWrite = maxBytesPerGatheringWrite; } @@ -494,5 +520,9 @@ private void calculateMaxBytesPerGatheringWrite() { setMaxBytesPerGatheringWrite(getSendBufferSize() << 1); } } + + private SocketChannel jdkChannel() { + return ((NioSocketChannel) channel).javaChannel(); + } } } diff --git a/transport/src/main/java/io/netty/channel/socket/oio/DefaultOioServerSocketChannelConfig.java b/transport/src/main/java/io/netty/channel/socket/oio/DefaultOioServerSocketChannelConfig.java index 05d9bb30a457..9bb9056f50a6 100644 --- a/transport/src/main/java/io/netty/channel/socket/oio/DefaultOioServerSocketChannelConfig.java +++ b/transport/src/main/java/io/netty/channel/socket/oio/DefaultOioServerSocketChannelConfig.java @@ -33,7 +33,10 @@ /** * Default {@link OioServerSocketChannelConfig} implementation + * + * @deprecated use NIO / EPOLL / KQUEUE transport. */ +@Deprecated public class DefaultOioServerSocketChannelConfig extends DefaultServerSocketChannelConfig implements OioServerSocketChannelConfig { diff --git a/transport/src/main/java/io/netty/channel/socket/oio/DefaultOioSocketChannelConfig.java b/transport/src/main/java/io/netty/channel/socket/oio/DefaultOioSocketChannelConfig.java index 3aa72fb48321..5261ebb82a52 100644 --- a/transport/src/main/java/io/netty/channel/socket/oio/DefaultOioSocketChannelConfig.java +++ b/transport/src/main/java/io/netty/channel/socket/oio/DefaultOioSocketChannelConfig.java @@ -33,7 +33,10 @@ /** * Default {@link OioSocketChannelConfig} implementation + * + * @deprecated use NIO / EPOLL / KQUEUE transport. */ +@Deprecated public class DefaultOioSocketChannelConfig extends DefaultSocketChannelConfig implements OioSocketChannelConfig { @Deprecated public DefaultOioSocketChannelConfig(SocketChannel channel, Socket javaSocket) { diff --git a/transport/src/main/java/io/netty/channel/socket/oio/OioDatagramChannel.java b/transport/src/main/java/io/netty/channel/socket/oio/OioDatagramChannel.java index 3e588d5a369d..abe9a431d573 100644 --- a/transport/src/main/java/io/netty/channel/socket/oio/OioDatagramChannel.java +++ b/transport/src/main/java/io/netty/channel/socket/oio/OioDatagramChannel.java @@ -16,6 +16,7 @@ package io.netty.channel.socket.oio; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; import io.netty.channel.AddressedEnvelope; import io.netty.channel.Channel; import io.netty.channel.ChannelException; @@ -53,7 +54,9 @@ * * @see AddressedEnvelope * @see DatagramPacket + * @deprecated use NIO / EPOLL / KQUEUE transport. */ +@Deprecated public class OioDatagramChannel extends AbstractOioMessageChannel implements DatagramChannel { @@ -120,7 +123,7 @@ public ChannelMetadata metadata() { /** * {@inheritDoc} * - * This can be safetly cast to {@link OioDatagramChannelConfig}. + * This can be safely cast to {@link OioDatagramChannelConfig}. */ @Override // TODO: Change return type to OioDatagramChannelConfig in next major release @@ -276,9 +279,7 @@ protected void doWrite(ChannelOutboundBuffer in) throws Exception { if (data.hasArray()) { tmpPacket.setData(data.array(), data.arrayOffset() + data.readerIndex(), length); } else { - byte[] tmp = new byte[length]; - data.getBytes(data.readerIndex(), tmp); - tmpPacket.setData(tmp); + tmpPacket.setData(ByteBufUtil.getBytes(data, data.readerIndex(), length)); } socket.send(tmpPacket); in.remove(); diff --git a/transport/src/main/java/io/netty/channel/socket/oio/OioDatagramChannelConfig.java b/transport/src/main/java/io/netty/channel/socket/oio/OioDatagramChannelConfig.java index 0ebea3dbfb82..5da805e819ba 100644 --- a/transport/src/main/java/io/netty/channel/socket/oio/OioDatagramChannelConfig.java +++ b/transport/src/main/java/io/netty/channel/socket/oio/OioDatagramChannelConfig.java @@ -24,6 +24,10 @@ import java.net.InetAddress; import java.net.NetworkInterface; +/** + * @deprecated use NIO / EPOLL / KQUEUE transport. + */ +@Deprecated public interface OioDatagramChannelConfig extends DatagramChannelConfig { /** * Sets the maximal time a operation on the underlying socket may block. diff --git a/transport/src/main/java/io/netty/channel/socket/oio/OioServerSocketChannel.java b/transport/src/main/java/io/netty/channel/socket/oio/OioServerSocketChannel.java index 336fcf1a6e73..bf91829b2574 100644 --- a/transport/src/main/java/io/netty/channel/socket/oio/OioServerSocketChannel.java +++ b/transport/src/main/java/io/netty/channel/socket/oio/OioServerSocketChannel.java @@ -38,7 +38,10 @@ * {@link ServerSocketChannel} which accepts new connections and create the {@link OioSocketChannel}'s for them. * * This implementation use Old-Blocking-IO. + * + * @deprecated use NIO / EPOLL / KQUEUE transport. */ +@Deprecated public class OioServerSocketChannel extends AbstractOioMessageChannel implements ServerSocketChannel { diff --git a/transport/src/main/java/io/netty/channel/socket/oio/OioServerSocketChannelConfig.java b/transport/src/main/java/io/netty/channel/socket/oio/OioServerSocketChannelConfig.java index 51bf6e01293b..caf3a5ebabb1 100644 --- a/transport/src/main/java/io/netty/channel/socket/oio/OioServerSocketChannelConfig.java +++ b/transport/src/main/java/io/netty/channel/socket/oio/OioServerSocketChannelConfig.java @@ -39,7 +39,10 @@ * {@link ChannelOption#SO_TIMEOUT}{@link #setSoTimeout(int)} * * + * + * @deprecated use NIO / EPOLL / KQUEUE transport. */ +@Deprecated public interface OioServerSocketChannelConfig extends ServerSocketChannelConfig { /** diff --git a/transport/src/main/java/io/netty/channel/socket/oio/OioSocketChannel.java b/transport/src/main/java/io/netty/channel/socket/oio/OioSocketChannel.java index 0c5580e93c29..935c316a88a8 100644 --- a/transport/src/main/java/io/netty/channel/socket/oio/OioSocketChannel.java +++ b/transport/src/main/java/io/netty/channel/socket/oio/OioSocketChannel.java @@ -39,7 +39,10 @@ /** * A {@link SocketChannel} which is using Old-Blocking-IO + * + * @deprecated use NIO / EPOLL / KQUEUE transport. */ +@Deprecated public class OioSocketChannel extends OioByteStreamChannel implements SocketChannel { private static final InternalLogger logger = InternalLoggerFactory.getInstance(OioSocketChannel.class); diff --git a/transport/src/main/java/io/netty/channel/socket/oio/OioSocketChannelConfig.java b/transport/src/main/java/io/netty/channel/socket/oio/OioSocketChannelConfig.java index 5b56bf179b4b..6b4574519030 100644 --- a/transport/src/main/java/io/netty/channel/socket/oio/OioSocketChannelConfig.java +++ b/transport/src/main/java/io/netty/channel/socket/oio/OioSocketChannelConfig.java @@ -39,7 +39,10 @@ * {@link ChannelOption#SO_TIMEOUT}{@link #setSoTimeout(int)} * * + * + * @deprecated use NIO / EPOLL / KQUEUE transport. */ +@Deprecated public interface OioSocketChannelConfig extends SocketChannelConfig { /** diff --git a/transport/src/main/java/io/netty/channel/socket/oio/package-info.java b/transport/src/main/java/io/netty/channel/socket/oio/package-info.java index e73c4affb601..786ca08556ec 100644 --- a/transport/src/main/java/io/netty/channel/socket/oio/package-info.java +++ b/transport/src/main/java/io/netty/channel/socket/oio/package-info.java @@ -17,5 +17,7 @@ /** * Old blocking I/O based socket channel API implementation - recommended for * a small number of connections (< 1000). + * + * @deprecated use NIO / EPOLL / KQUEUE transport. */ package io.netty.channel.socket.oio; diff --git a/transport/src/test/java/io/netty/channel/AbstractChannelTest.java b/transport/src/test/java/io/netty/channel/AbstractChannelTest.java index afbe27c04bed..9d5110ea8d8e 100644 --- a/transport/src/test/java/io/netty/channel/AbstractChannelTest.java +++ b/transport/src/test/java/io/netty/channel/AbstractChannelTest.java @@ -15,8 +15,12 @@ */ package io.netty.channel; +import java.io.IOException; +import java.net.InetSocketAddress; import java.net.SocketAddress; +import java.nio.channels.ClosedChannelException; +import io.netty.util.NetUtil; import org.junit.Test; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -82,6 +86,69 @@ public void ensureDefaultChannelId() { assertTrue(channelId instanceof DefaultChannelId); } + @Test + public void testClosedChannelExceptionCarryIOException() throws Exception { + final IOException ioException = new IOException(); + final Channel channel = new TestChannel() { + private boolean open = true; + private boolean active; + + @Override + protected AbstractUnsafe newUnsafe() { + return new AbstractUnsafe() { + @Override + public void connect( + SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { + active = true; + promise.setSuccess(); + } + }; + } + + @Override + protected void doClose() { + active = false; + open = false; + } + + @Override + protected void doWrite(ChannelOutboundBuffer in) throws Exception { + throw ioException; + } + + @Override + public boolean isOpen() { + return open; + } + + @Override + public boolean isActive() { + return active; + } + }; + + EventLoop loop = new DefaultEventLoop(); + try { + registerChannel(loop, channel); + channel.connect(new InetSocketAddress(NetUtil.LOCALHOST, 8888)).sync(); + assertSame(ioException, channel.writeAndFlush("").await().cause()); + + assertClosedChannelException(channel.writeAndFlush(""), ioException); + assertClosedChannelException(channel.write(""), ioException); + assertClosedChannelException(channel.bind(new InetSocketAddress(NetUtil.LOCALHOST, 8888)), ioException); + } finally { + channel.close(); + loop.shutdownGracefully(); + } + } + + private static void assertClosedChannelException(ChannelFuture future, IOException expected) + throws InterruptedException { + Throwable cause = future.await().cause(); + assertTrue(cause instanceof ClosedChannelException); + assertSame(expected, cause.getCause()); + } + private static void registerChannel(EventLoop eventLoop, Channel channel) throws Exception { DefaultChannelPromise future = new DefaultChannelPromise(channel); channel.unsafe().register(eventLoop, future); @@ -90,19 +157,16 @@ private static void registerChannel(EventLoop eventLoop, Channel channel) throws private static class TestChannel extends AbstractChannel { private static final ChannelMetadata TEST_METADATA = new ChannelMetadata(false); - private class TestUnsafe extends AbstractUnsafe { - @Override - public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { } - } + private final ChannelConfig config = new DefaultChannelConfig(this); - public TestChannel() { + TestChannel() { super(null); } @Override public ChannelConfig config() { - return new DefaultChannelConfig(this); + return config; } @Override @@ -122,7 +186,12 @@ public ChannelMetadata metadata() { @Override protected AbstractUnsafe newUnsafe() { - return new TestUnsafe(); + return new AbstractUnsafe() { + @Override + public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { + promise.setFailure(new UnsupportedOperationException()); + } + }; } @Override @@ -141,16 +210,16 @@ protected SocketAddress remoteAddress0() { } @Override - protected void doBind(SocketAddress localAddress) throws Exception { } + protected void doBind(SocketAddress localAddress) { } @Override - protected void doDisconnect() throws Exception { } + protected void doDisconnect() { } @Override - protected void doClose() throws Exception { } + protected void doClose() { } @Override - protected void doBeginRead() throws Exception { } + protected void doBeginRead() { } @Override protected void doWrite(ChannelOutboundBuffer in) throws Exception { } diff --git a/transport/src/test/java/io/netty/channel/ChannelInitializerTest.java b/transport/src/test/java/io/netty/channel/ChannelInitializerTest.java index 5ab48867bc0b..bd9415b17b1e 100644 --- a/transport/src/test/java/io/netty/channel/ChannelInitializerTest.java +++ b/transport/src/test/java/io/netty/channel/ChannelInitializerTest.java @@ -21,19 +21,26 @@ import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalChannel; import io.netty.channel.local.LocalServerChannel; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.Future; import org.junit.After; import org.junit.Before; import org.junit.Test; import java.util.Iterator; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertSame; public class ChannelInitializerTest { @@ -63,6 +70,47 @@ public void tearDown() { group.shutdownGracefully(0, TIMEOUT_MILLIS, TimeUnit.MILLISECONDS).syncUninterruptibly(); } + @Test + public void testInitChannelThrowsRegisterFirst() { + testInitChannelThrows(true); + } + + @Test + public void testInitChannelThrowsRegisterAfter() { + testInitChannelThrows(false); + } + + private void testInitChannelThrows(boolean registerFirst) { + final Exception exception = new Exception(); + final AtomicReference causeRef = new AtomicReference(); + + ChannelPipeline pipeline = new LocalChannel().pipeline(); + + if (registerFirst) { + group.register(pipeline.channel()).syncUninterruptibly(); + } + pipeline.addFirst(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + throw exception; + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + causeRef.set(cause); + super.exceptionCaught(ctx, cause); + } + }); + + if (!registerFirst) { + group.register(pipeline.channel()).syncUninterruptibly(); + } + pipeline.channel().close().syncUninterruptibly(); + pipeline.channel().closeFuture().syncUninterruptibly(); + + assertSame(exception, causeRef.get()); + } + @Test public void testChannelInitializerInInitializerCorrectOrdering() { final ChannelInboundHandlerAdapter handler1 = new ChannelInboundHandlerAdapter(); @@ -207,6 +255,135 @@ private void testChannelRegisteredEventPropagation(ChannelInitializer errorRef = new AtomicReference(); + LocalAddress addr = new LocalAddress("test"); + + final EventExecutor executor = new DefaultEventLoop() { + private final ScheduledExecutorService execService = Executors.newSingleThreadScheduledExecutor(); + + @Override + public void shutdown() { + execService.shutdown(); + } + + @Override + public boolean inEventLoop(Thread thread) { + // Always return false which will ensure we always call execute(...) + return false; + } + + @Override + public boolean isShuttingDown() { + return false; + } + + @Override + public Future shutdownGracefully(long quietPeriod, long timeout, TimeUnit unit) { + throw new IllegalStateException(); + } + + @Override + public Future terminationFuture() { + throw new IllegalStateException(); + } + + @Override + public boolean isShutdown() { + return execService.isShutdown(); + } + + @Override + public boolean isTerminated() { + return execService.isTerminated(); + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + return execService.awaitTermination(timeout, unit); + } + + @Override + public void execute(Runnable command) { + execService.execute(command); + } + }; + + final CountDownLatch latch = new CountDownLatch(1); + ServerBootstrap serverBootstrap = new ServerBootstrap() + .channel(LocalServerChannel.class) + .group(group) + .localAddress(addr) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(LocalChannel ch) { + ch.pipeline().addLast(executor, new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + invokeCount.incrementAndGet(); + ChannelHandlerContext ctx = ch.pipeline().context(this); + assertNotNull(ctx); + ch.pipeline().addAfter(ctx.executor(), + ctx.name(), null, new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + // just drop on the floor. + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) { + latch.countDown(); + } + }); + completeCount.incrementAndGet(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + if (cause instanceof AssertionError) { + errorRef.set(cause); + } + } + }); + } + }); + + Channel server = serverBootstrap.bind().sync().channel(); + + Bootstrap clientBootstrap = new Bootstrap() + .channel(LocalChannel.class) + .group(group) + .remoteAddress(addr) + .handler(new ChannelInboundHandlerAdapter()); + + Channel client = clientBootstrap.connect().sync().channel(); + client.writeAndFlush("Hello World").sync(); + + client.close().sync(); + server.close().sync(); + + client.closeFuture().sync(); + server.closeFuture().sync(); + + // Wait until the handler is removed from the pipeline and so no more events are handled by it. + latch.await(); + + assertEquals(1, invokeCount.get()); + assertEquals(invokeCount.get(), completeCount.get()); + + Throwable cause = errorRef.get(); + if (cause != null) { + throw cause; + } + + executor.shutdown(); + assertTrue(executor.awaitTermination(5, TimeUnit.SECONDS)); + } + private static void closeChannel(Channel c) { if (c != null) { c.close().syncUninterruptibly(); diff --git a/transport/src/test/java/io/netty/channel/ChannelOutboundBufferTest.java b/transport/src/test/java/io/netty/channel/ChannelOutboundBufferTest.java index 470249091680..c58569c43a31 100644 --- a/transport/src/test/java/io/netty/channel/ChannelOutboundBufferTest.java +++ b/transport/src/test/java/io/netty/channel/ChannelOutboundBufferTest.java @@ -19,10 +19,16 @@ import io.netty.buffer.CompositeByteBuf; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.util.CharsetUtil; +import io.netty.util.concurrent.DefaultThreadFactory; +import io.netty.util.concurrent.RejectedExecutionHandlers; +import io.netty.util.concurrent.SingleThreadEventExecutor; import org.junit.Test; import java.net.SocketAddress; import java.nio.ByteBuffer; +import java.util.Queue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.RejectedExecutionException; import static io.netty.buffer.Unpooled.*; import static org.hamcrest.Matchers.*; @@ -355,6 +361,103 @@ public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exceptio safeClose(ch); } + @Test(timeout = 5000) + public void testWriteTaskRejected() throws Exception { + final SingleThreadEventExecutor executor = new SingleThreadEventExecutor( + null, new DefaultThreadFactory("executorPool"), + true, 1, RejectedExecutionHandlers.reject()) { + @Override + protected void run() { + do { + Runnable task = takeTask(); + if (task != null) { + task.run(); + updateLastExecutionTime(); + } + } while (!confirmShutdown()); + } + + @Override + protected Queue newTaskQueue(int maxPendingTasks) { + return super.newTaskQueue(1); + } + }; + final CountDownLatch handlerAddedLatch = new CountDownLatch(1); + final CountDownLatch handlerRemovedLatch = new CountDownLatch(1); + EmbeddedChannel ch = new EmbeddedChannel(); + ch.pipeline().addLast(executor, "handler", new ChannelOutboundHandlerAdapter() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + promise.setFailure(new AssertionError("Should not be called")); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + handlerAddedLatch.countDown(); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) { + handlerRemovedLatch.countDown(); + } + }); + + // Lets wait until we are sure the handler was added. + handlerAddedLatch.await(); + + final CountDownLatch executeLatch = new CountDownLatch(1); + final CountDownLatch runLatch = new CountDownLatch(1); + executor.execute(new Runnable() { + @Override + public void run() { + try { + runLatch.countDown(); + executeLatch.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + }); + + runLatch.await(); + + executor.execute(new Runnable() { + @Override + public void run() { + // Will not be executed but ensure the pending count is 1. + } + }); + + assertEquals(1, executor.pendingTasks()); + assertEquals(0, ch.unsafe().outboundBuffer().totalPendingWriteBytes()); + + ByteBuf buffer = buffer(128).writeZero(128); + ChannelFuture future = ch.write(buffer); + ch.runPendingTasks(); + + assertTrue(future.cause() instanceof RejectedExecutionException); + assertEquals(0, buffer.refCnt()); + + // In case of rejected task we should not have anything pending. + assertEquals(0, ch.unsafe().outboundBuffer().totalPendingWriteBytes()); + executeLatch.countDown(); + + while (executor.pendingTasks() != 0) { + // Wait until there is no more pending task left. + Thread.sleep(10); + } + + ch.pipeline().remove("handler"); + + // Ensure we do not try to shutdown the executor before we handled everything for the Channel. Otherwise + // the Executor may reject when the Channel tries to add a task to it. + handlerRemovedLatch.await(); + + safeClose(ch); + + executor.shutdownGracefully(); + } + private static void safeClose(EmbeddedChannel ch) { ch.finish(); for (;;) { diff --git a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTailTest.java b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTailTest.java index 697db84af506..7eb624b6bc36 100644 --- a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTailTest.java +++ b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTailTest.java @@ -237,7 +237,7 @@ protected void onUnhandledInboundWritabilityChanged() { private static class MyChannelFactory implements ChannelFactory { private final MyChannel channel; - public MyChannelFactory(MyChannel channel) { + MyChannelFactory(MyChannel channel) { this.channel = channel; } @@ -365,7 +365,7 @@ public void connect(SocketAddress remoteAddress, SocketAddress localAddress, Cha private class MyChannelPipeline extends DefaultChannelPipeline { - public MyChannelPipeline(Channel channel) { + MyChannelPipeline(Channel channel) { super(channel); } diff --git a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java index 7cd51b2999f5..65209d50cfea 100644 --- a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java +++ b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java @@ -24,6 +24,7 @@ import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalEventLoopGroup; import io.netty.channel.local.LocalServerChannel; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.oio.OioEventLoopGroup; @@ -1156,6 +1157,117 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { } } + // Test for https://github.com/netty/netty/issues/8676. + @Test + public void testHandlerRemovedOnlyCalledWhenHandlerAddedCalled() throws Exception { + EventLoopGroup group = new DefaultEventLoopGroup(1); + try { + final AtomicReference errorRef = new AtomicReference(); + + // As this only happens via a race we will verify 500 times. This was good enough to have it failed most of + // the time. + for (int i = 0; i < 500; i++) { + + ChannelPipeline pipeline = new LocalChannel().pipeline(); + group.register(pipeline.channel()).sync(); + + final CountDownLatch latch = new CountDownLatch(1); + + pipeline.addLast(new ChannelInboundHandlerAdapter() { + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + // Block just for a bit so we have a chance to trigger the race mentioned in the issue. + latch.await(50, TimeUnit.MILLISECONDS); + } + }); + + // Close the pipeline which will call destroy0(). This will remove each handler in the pipeline and + // should call handlerRemoved(...) if and only if handlerAdded(...) was called for the handler before. + pipeline.close(); + + pipeline.addLast(new ChannelInboundHandlerAdapter() { + private boolean handerAddedCalled; + + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + handerAddedCalled = true; + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) { + if (!handerAddedCalled) { + errorRef.set(new AssertionError( + "handlerRemoved(...) called without handlerAdded(...) before")); + } + } + }); + + latch.countDown(); + + pipeline.channel().closeFuture().syncUninterruptibly(); + + // Schedule something on the EventLoop to ensure all other scheduled tasks had a chance to complete. + pipeline.channel().eventLoop().submit(new Runnable() { + @Override + public void run() { + // NOOP + } + }).syncUninterruptibly(); + Error error = errorRef.get(); + if (error != null) { + throw error; + } + } + } finally { + group.shutdownGracefully(); + } + } + + @Test + public void testWriteThrowsReleaseMessage() { + testWriteThrowsReleaseMessage0(false); + } + + @Test + public void testWriteAndFlushThrowsReleaseMessage() { + testWriteThrowsReleaseMessage0(true); + } + + private void testWriteThrowsReleaseMessage0(boolean flush) { + ReferenceCounted referenceCounted = new AbstractReferenceCounted() { + @Override + protected void deallocate() { + // NOOP + } + + @Override + public ReferenceCounted touch(Object hint) { + return this; + } + }; + assertEquals(1, referenceCounted.refCnt()); + + Channel channel = new LocalChannel(); + Channel channel2 = new LocalChannel(); + group.register(channel).syncUninterruptibly(); + group.register(channel2).syncUninterruptibly(); + + try { + if (flush) { + channel.writeAndFlush(referenceCounted, channel2.newPromise()); + } else { + channel.write(referenceCounted, channel2.newPromise()); + } + fail(); + } catch (IllegalArgumentException expected) { + // expected + } + assertEquals(0, referenceCounted.refCnt()); + + channel.close().syncUninterruptibly(); + channel2.close().syncUninterruptibly(); + } + @Test(timeout = 5000) public void handlerAddedStateUpdatedBeforeHandlerAddedDoneForceEventLoop() throws InterruptedException { handlerAddedStateUpdatedBeforeHandlerAddedDone(true); diff --git a/transport/src/test/java/io/netty/channel/DefaultFileRegionTest.java b/transport/src/test/java/io/netty/channel/DefaultFileRegionTest.java new file mode 100644 index 000000000000..e416bccba594 --- /dev/null +++ b/transport/src/test/java/io/netty/channel/DefaultFileRegionTest.java @@ -0,0 +1,120 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.util.internal.PlatformDependent; +import org.junit.Test; + +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.RandomAccessFile; +import java.nio.channels.Channels; +import java.nio.channels.WritableByteChannel; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +public class DefaultFileRegionTest { + + private static final byte[] data = new byte[1048576 * 10]; + + static { + PlatformDependent.threadLocalRandom().nextBytes(data); + } + + private static File newFile() throws IOException { + File file = File.createTempFile("netty-", ".tmp"); + file.deleteOnExit(); + + final FileOutputStream out = new FileOutputStream(file); + out.write(data); + out.close(); + return file; + } + + @Test + public void testCreateFromFile() throws IOException { + File file = newFile(); + try { + testFileRegion(new DefaultFileRegion(file, 0, data.length)); + } finally { + file.delete(); + } + } + + @Test + public void testCreateFromFileChannel() throws IOException { + File file = newFile(); + RandomAccessFile randomAccessFile = new RandomAccessFile(file, "r"); + try { + testFileRegion(new DefaultFileRegion(randomAccessFile.getChannel(), 0, data.length)); + } finally { + randomAccessFile.close(); + file.delete(); + } + } + + private static void testFileRegion(FileRegion region) throws IOException { + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + WritableByteChannel channel = Channels.newChannel(outputStream); + + try { + assertEquals(data.length, region.count()); + assertEquals(0, region.transferred()); + assertEquals(data.length, region.transferTo(channel, 0)); + assertEquals(data.length, region.count()); + assertEquals(data.length, region.transferred()); + assertArrayEquals(data, outputStream.toByteArray()); + } finally { + channel.close(); + } + } + + @Test + public void testTruncated() throws IOException { + File file = newFile(); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + WritableByteChannel channel = Channels.newChannel(outputStream); + RandomAccessFile randomAccessFile = new RandomAccessFile(file, "rw"); + + try { + FileRegion region = new DefaultFileRegion(randomAccessFile.getChannel(), 0, data.length); + + randomAccessFile.getChannel().truncate(data.length - 1024); + + assertEquals(data.length, region.count()); + assertEquals(0, region.transferred()); + + assertEquals(data.length - 1024, region.transferTo(channel, 0)); + assertEquals(data.length, region.count()); + assertEquals(data.length - 1024, region.transferred()); + try { + region.transferTo(channel, data.length - 1024); + fail(); + } catch (IOException expected) { + // expected + } + } finally { + channel.close(); + + randomAccessFile.close(); + file.delete(); + } + } +} diff --git a/transport/src/test/java/io/netty/channel/SimpleUserEventChannelHandlerTest.java b/transport/src/test/java/io/netty/channel/SimpleUserEventChannelHandlerTest.java new file mode 100644 index 000000000000..c0801c0be39f --- /dev/null +++ b/transport/src/test/java/io/netty/channel/SimpleUserEventChannelHandlerTest.java @@ -0,0 +1,101 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.buffer.DefaultByteBufHolder; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.*; + +public class SimpleUserEventChannelHandlerTest { + + private FooEventCatcher fooEventCatcher; + private AllEventCatcher allEventCatcher; + private EmbeddedChannel channel; + + @Before + public void setUp() { + fooEventCatcher = new FooEventCatcher(); + allEventCatcher = new AllEventCatcher(); + channel = new EmbeddedChannel(fooEventCatcher, allEventCatcher); + } + + @Test + public void testTypeMatch() { + FooEvent fooEvent = new FooEvent(); + channel.pipeline().fireUserEventTriggered(fooEvent); + assertEquals(1, fooEventCatcher.caughtEvents.size()); + assertEquals(0, allEventCatcher.caughtEvents.size()); + assertEquals(0, fooEvent.refCnt()); + assertFalse(channel.finish()); + } + + @Test + public void testTypeMismatch() { + BarEvent barEvent = new BarEvent(); + channel.pipeline().fireUserEventTriggered(barEvent); + assertEquals(0, fooEventCatcher.caughtEvents.size()); + assertEquals(1, allEventCatcher.caughtEvents.size()); + assertTrue(barEvent.release()); + assertFalse(channel.finish()); + } + + static final class FooEvent extends DefaultByteBufHolder { + FooEvent() { + super(Unpooled.buffer()); + } + } + + static final class BarEvent extends DefaultByteBufHolder { + BarEvent() { + super(Unpooled.buffer()); + } + } + + static final class FooEventCatcher extends SimpleUserEventChannelHandler { + + public List caughtEvents; + + FooEventCatcher() { + caughtEvents = new ArrayList(); + } + + @Override + protected void eventReceived(ChannelHandlerContext ctx, FooEvent evt) { + caughtEvents.add(evt); + } + } + + static final class AllEventCatcher extends ChannelInboundHandlerAdapter { + + public List caughtEvents; + + AllEventCatcher() { + caughtEvents = new ArrayList(); + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + caughtEvents.add(evt); + } + } +} diff --git a/transport/src/test/java/io/netty/channel/nio/NioEventLoopTest.java b/transport/src/test/java/io/netty/channel/nio/NioEventLoopTest.java index 887dce421c99..8b176bc71c02 100644 --- a/transport/src/test/java/io/netty/channel/nio/NioEventLoopTest.java +++ b/transport/src/test/java/io/netty/channel/nio/NioEventLoopTest.java @@ -19,14 +19,26 @@ import io.netty.channel.Channel; import io.netty.channel.EventLoop; import io.netty.channel.EventLoopGroup; +import io.netty.channel.SelectStrategy; +import io.netty.channel.SelectStrategyFactory; import io.netty.channel.socket.ServerSocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.util.IntSupplier; +import io.netty.util.concurrent.DefaultThreadFactory; import io.netty.util.concurrent.Future; +import org.hamcrest.core.IsInstanceOf; import org.junit.Test; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.nio.channels.SelectionKey; import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.nio.channels.spi.SelectorProvider; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import static org.junit.Assert.*; @@ -73,27 +85,8 @@ public void run() { } } - @Test(timeout = 5000L) - public void testScheduleBigDelayOverMax() { - EventLoopGroup group = new NioEventLoopGroup(1); - final EventLoop el = group.next(); - try { - el.schedule(new Runnable() { - @Override - public void run() { - // NOOP - } - }, Integer.MAX_VALUE, TimeUnit.DAYS); - fail(); - } catch (IllegalArgumentException expected) { - // expected - } - - group.shutdownGracefully(); - } - @Test - public void testScheduleBigDelay() { + public void testScheduleBigDelayNotOverflow() { EventLoopGroup group = new NioEventLoopGroup(1); final EventLoop el = group.next(); @@ -102,7 +95,7 @@ public void testScheduleBigDelay() { public void run() { // NOOP } - }, NioEventLoop.MAX_SCHEDULED_DAYS, TimeUnit.DAYS); + }, Long.MAX_VALUE, TimeUnit.MILLISECONDS); assertFalse(future.awaitUninterruptibly(1000)); assertTrue(future.cancel(true)); @@ -151,4 +144,118 @@ public void run() { group.shutdownGracefully(); } } + + @Test(timeout = 3000) + public void testSelectableChannel() throws Exception { + NioEventLoopGroup group = new NioEventLoopGroup(1); + NioEventLoop loop = (NioEventLoop) group.next(); + + try { + Channel channel = new NioServerSocketChannel(); + loop.register(channel).syncUninterruptibly(); + channel.bind(new InetSocketAddress(0)).syncUninterruptibly(); + + SocketChannel selectableChannel = SocketChannel.open(); + selectableChannel.configureBlocking(false); + selectableChannel.connect(channel.localAddress()); + + final CountDownLatch latch = new CountDownLatch(1); + + loop.register(selectableChannel, SelectionKey.OP_CONNECT, new NioTask() { + @Override + public void channelReady(SocketChannel ch, SelectionKey key) { + latch.countDown(); + } + + @Override + public void channelUnregistered(SocketChannel ch, Throwable cause) { + } + }); + + latch.await(); + + selectableChannel.close(); + channel.close().syncUninterruptibly(); + } finally { + group.shutdownGracefully(); + } + } + + @SuppressWarnings("deprecation") + @Test + public void testTaskRemovalOnShutdownThrowsNoUnsupportedOperationException() throws Exception { + final AtomicReference error = new AtomicReference(); + final Runnable task = new Runnable() { + @Override + public void run() { + // NOOP + } + }; + // Just run often enough to trigger it normally. + for (int i = 0; i < 1000; i++) { + NioEventLoopGroup group = new NioEventLoopGroup(1); + final NioEventLoop loop = (NioEventLoop) group.next(); + + Thread t = new Thread(new Runnable() { + @Override + public void run() { + try { + for (;;) { + loop.execute(task); + } + } catch (Throwable cause) { + error.set(cause); + } + } + }); + t.start(); + group.shutdownNow(); + t.join(); + group.terminationFuture().syncUninterruptibly(); + assertThat(error.get(), IsInstanceOf.instanceOf(RejectedExecutionException.class)); + error.set(null); + } + } + + @Test + public void testRebuildSelectorOnIOException() { + SelectStrategyFactory selectStrategyFactory = new SelectStrategyFactory() { + @Override + public SelectStrategy newSelectStrategy() { + return new SelectStrategy() { + + private boolean thrown; + + @Override + public int calculateStrategy(IntSupplier selectSupplier, boolean hasTasks) throws Exception { + if (!thrown) { + thrown = true; + throw new IOException(); + } + return -1; + } + }; + } + }; + + EventLoopGroup group = new NioEventLoopGroup(1, new DefaultThreadFactory("ioPool"), + SelectorProvider.provider(), selectStrategyFactory); + final NioEventLoop loop = (NioEventLoop) group.next(); + try { + Channel channel = new NioServerSocketChannel(); + Selector selector = loop.unwrappedSelector(); + + loop.register(channel).syncUninterruptibly(); + + Selector newSelector = ((NioEventLoop) channel.eventLoop()).unwrappedSelector(); + assertTrue(newSelector.isOpen()); + assertNotSame(selector, newSelector); + assertFalse(selector.isOpen()); + + channel.close().syncUninterruptibly(); + } finally { + group.shutdownGracefully(); + } + } + } diff --git a/transport/src/test/java/io/netty/channel/nio/SelectedSelectionKeySetTest.java b/transport/src/test/java/io/netty/channel/nio/SelectedSelectionKeySetTest.java index 5c32001b8ed7..88bb7b0a29a9 100644 --- a/transport/src/test/java/io/netty/channel/nio/SelectedSelectionKeySetTest.java +++ b/transport/src/test/java/io/netty/channel/nio/SelectedSelectionKeySetTest.java @@ -21,12 +21,10 @@ import org.mockito.MockitoAnnotations; import java.nio.channels.SelectionKey; +import java.util.Iterator; +import java.util.NoSuchElementException; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; public class SelectedSelectionKeySetTest { @Mock @@ -34,6 +32,9 @@ public class SelectedSelectionKeySetTest { @Mock private SelectionKey mockKey2; + @Mock + private SelectionKey mockKey3; + @Before public void setup() { MockitoAnnotations.initMocks(this); @@ -63,4 +64,49 @@ public void resetSet() { assertEquals(0, set.size()); assertTrue(set.isEmpty()); } + + @Test + public void iterator() { + SelectedSelectionKeySet set = new SelectedSelectionKeySet(); + assertTrue(set.add(mockKey)); + assertTrue(set.add(mockKey2)); + Iterator keys = set.iterator(); + assertTrue(keys.hasNext()); + assertSame(mockKey, keys.next()); + assertTrue(keys.hasNext()); + assertSame(mockKey2, keys.next()); + assertFalse(keys.hasNext()); + + try { + keys.next(); + fail(); + } catch (NoSuchElementException expected) { + // expected + } + + try { + keys.remove(); + fail(); + } catch (UnsupportedOperationException expected) { + // expected + } + } + + @Test + public void contains() { + SelectedSelectionKeySet set = new SelectedSelectionKeySet(); + assertTrue(set.add(mockKey)); + assertTrue(set.add(mockKey2)); + assertFalse(set.contains(mockKey)); + assertFalse(set.contains(mockKey2)); + assertFalse(set.contains(mockKey3)); + } + + @Test + public void remove() { + SelectedSelectionKeySet set = new SelectedSelectionKeySet(); + assertTrue(set.add(mockKey)); + assertFalse(set.remove(mockKey)); + assertFalse(set.remove(mockKey2)); + } } diff --git a/transport/src/test/java/io/netty/channel/socket/nio/AbstractNioChannelTest.java b/transport/src/test/java/io/netty/channel/socket/nio/AbstractNioChannelTest.java new file mode 100644 index 000000000000..5652a4ebac3a --- /dev/null +++ b/transport/src/test/java/io/netty/channel/socket/nio/AbstractNioChannelTest.java @@ -0,0 +1,79 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.socket.nio; + +import io.netty.channel.ChannelOption; +import io.netty.channel.nio.AbstractNioChannel; +import org.junit.Test; + +import java.io.IOException; +import java.net.SocketOption; +import java.net.StandardSocketOptions; +import java.nio.channels.NetworkChannel; + +import static org.junit.Assert.*; + +public abstract class AbstractNioChannelTest { + + protected abstract T newNioChannel(); + + protected abstract NetworkChannel jdkChannel(T channel); + + protected abstract SocketOption newInvalidOption(); + + @Test + public void testNioChannelOption() throws IOException { + T channel = newNioChannel(); + try { + NetworkChannel jdkChannel = jdkChannel(channel); + ChannelOption option = NioChannelOption.of(StandardSocketOptions.SO_REUSEADDR); + boolean value1 = jdkChannel.getOption(StandardSocketOptions.SO_REUSEADDR); + boolean value2 = channel.config().getOption(option); + + assertEquals(value1, value2); + + channel.config().setOption(option, !value2); + boolean value3 = jdkChannel.getOption(StandardSocketOptions.SO_REUSEADDR); + boolean value4 = channel.config().getOption(option); + assertEquals(value3, value4); + assertNotEquals(value1, value4); + } finally { + channel.unsafe().closeForcibly(); + } + } + + @Test + public void testInvalidNioChannelOption() { + T channel = newNioChannel(); + try { + ChannelOption option = NioChannelOption.of(newInvalidOption()); + assertFalse(channel.config().setOption(option, null)); + assertNull(channel.config().getOption(option)); + } finally { + channel.unsafe().closeForcibly(); + } + } + + @Test + public void testGetOptions() { + T channel = newNioChannel(); + try { + channel.config().getOptions(); + } finally { + channel.unsafe().closeForcibly(); + } + } +} diff --git a/transport/src/test/java/io/netty/channel/nio/NioDatagramChannelTest.java b/transport/src/test/java/io/netty/channel/socket/nio/NioDatagramChannelTest.java similarity index 78% rename from transport/src/test/java/io/netty/channel/nio/NioDatagramChannelTest.java rename to transport/src/test/java/io/netty/channel/socket/nio/NioDatagramChannelTest.java index 343bd1920e98..d0e235efcbfb 100644 --- a/transport/src/test/java/io/netty/channel/nio/NioDatagramChannelTest.java +++ b/transport/src/test/java/io/netty/channel/socket/nio/NioDatagramChannelTest.java @@ -13,24 +13,27 @@ * License for the specific language governing permissions and limitations * under the License. */ -package io.netty.channel.nio; +package io.netty.channel.socket.nio; import io.netty.bootstrap.Bootstrap; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelOption; import io.netty.channel.group.DefaultChannelGroup; +import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.DatagramChannel; -import io.netty.channel.socket.nio.NioDatagramChannel; import io.netty.util.ReferenceCountUtil; import io.netty.util.concurrent.GlobalEventExecutor; import org.junit.Assert; import org.junit.Test; import java.net.InetSocketAddress; +import java.net.SocketOption; +import java.net.StandardSocketOptions; +import java.nio.channels.NetworkChannel; -public class NioDatagramChannelTest { +public class NioDatagramChannelTest extends AbstractNioChannelTest { /** * Test try to reproduce issue #1335 @@ -61,4 +64,19 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { group.shutdownGracefully().sync(); } } + + @Override + protected NioDatagramChannel newNioChannel() { + return new NioDatagramChannel(); + } + + @Override + protected NetworkChannel jdkChannel(NioDatagramChannel channel) { + return channel.javaChannel(); + } + + @Override + protected SocketOption newInvalidOption() { + return StandardSocketOptions.TCP_NODELAY; + } } diff --git a/transport/src/test/java/io/netty/channel/socket/nio/NioServerSocketChannelTest.java b/transport/src/test/java/io/netty/channel/socket/nio/NioServerSocketChannelTest.java index 80779902f76b..ec502239771f 100644 --- a/transport/src/test/java/io/netty/channel/socket/nio/NioServerSocketChannelTest.java +++ b/transport/src/test/java/io/netty/channel/socket/nio/NioServerSocketChannelTest.java @@ -22,9 +22,12 @@ import java.io.IOException; import java.net.InetSocketAddress; +import java.net.SocketOption; +import java.net.StandardSocketOptions; +import java.nio.channels.NetworkChannel; import java.nio.channels.ServerSocketChannel; -public class NioServerSocketChannelTest { +public class NioServerSocketChannelTest extends AbstractNioChannelTest { @Test public void testCloseOnError() throws Exception { @@ -41,4 +44,19 @@ public void testCloseOnError() throws Exception { group.shutdownGracefully(); } } + + @Override + protected NioServerSocketChannel newNioChannel() { + return new NioServerSocketChannel(); + } + + @Override + protected NetworkChannel jdkChannel(NioServerSocketChannel channel) { + return channel.javaChannel(); + } + + @Override + protected SocketOption newInvalidOption() { + return StandardSocketOptions.IP_MULTICAST_IF; + } } diff --git a/transport/src/test/java/io/netty/channel/nio/NioSocketChannelTest.java b/transport/src/test/java/io/netty/channel/socket/nio/NioSocketChannelTest.java similarity index 94% rename from transport/src/test/java/io/netty/channel/nio/NioSocketChannelTest.java rename to transport/src/test/java/io/netty/channel/socket/nio/NioSocketChannelTest.java index 69b601a7578d..4819c44b04a7 100644 --- a/transport/src/test/java/io/netty/channel/nio/NioSocketChannelTest.java +++ b/transport/src/test/java/io/netty/channel/socket/nio/NioSocketChannelTest.java @@ -13,7 +13,7 @@ * License for the specific language governing permissions and limitations * under the License. */ -package io.netty.channel.nio; +package io.netty.channel.socket.nio; import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; @@ -30,9 +30,8 @@ import io.netty.channel.EventLoop; import io.netty.channel.EventLoopGroup; import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.SocketChannel; -import io.netty.channel.socket.nio.NioServerSocketChannel; -import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.util.CharsetUtil; import io.netty.util.NetUtil; import io.netty.util.internal.PlatformDependent; @@ -46,7 +45,10 @@ import java.net.ServerSocket; import java.net.Socket; import java.net.SocketAddress; +import java.net.SocketOption; +import java.net.StandardSocketOptions; import java.nio.channels.ClosedChannelException; +import java.nio.channels.NetworkChannel; import java.util.Queue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.LinkedBlockingQueue; @@ -55,7 +57,7 @@ import static org.junit.Assert.*; -public class NioSocketChannelTest { +public class NioSocketChannelTest extends AbstractNioChannelTest { /** * Reproduces the issue #1600 @@ -278,4 +280,19 @@ public void testShutdownOutputAndClose() throws IOException { group.shutdownGracefully(); } } + + @Override + protected NioSocketChannel newNioChannel() { + return new NioSocketChannel(); + } + + @Override + protected NetworkChannel jdkChannel(NioSocketChannel channel) { + return channel.javaChannel(); + } + + @Override + protected SocketOption newInvalidOption() { + return StandardSocketOptions.IP_MULTICAST_IF; + } } diff --git a/transport/test.log b/transport/test.log new file mode 100644 index 000000000000..e69de29bb2d1