diff --git a/vertx-core/src/main/java/io/vertx/core/http/impl/WebSocketGroup.java b/vertx-core/src/main/java/io/vertx/core/http/impl/WebSocketGroup.java index adf351476b2..a37dc7d6a6b 100644 --- a/vertx-core/src/main/java/io/vertx/core/http/impl/WebSocketGroup.java +++ b/vertx-core/src/main/java/io/vertx/core/http/impl/WebSocketGroup.java @@ -15,6 +15,7 @@ import io.vertx.core.http.WebSocketClientOptions; import io.vertx.core.http.WebSocketConnectOptions; import io.vertx.core.internal.ContextInternal; +import io.vertx.core.internal.PromiseInternal; import io.vertx.core.internal.resource.ManagedResource; import io.vertx.core.spi.metrics.ClientMetrics; import io.vertx.core.spi.metrics.PoolMetrics; @@ -47,7 +48,6 @@ private static class Waiter { private final HttpChannelConnector connector; private final Deque waiters; private int inflightConnections; - private final ClientMetrics clientMetrics; private final PoolMetrics poolMetrics; @@ -72,19 +72,7 @@ public Future requestConnection(ContextInternal ctx, WebSocketConnect return fut; } - private void onEvict() { - decRefCount(); - Waiter h; - synchronized (WebSocketGroup.this) { - if (--inflightConnections > maxPoolSize || waiters.isEmpty()) { - return; - } - h = waiters.poll(); - } - tryConnect(h.context, h.connectOptions).onComplete(h.promise); - } - - private Future tryConnect(ContextInternal ctx, WebSocketConnectOptions connectOptions) { + private void connect(ContextInternal ctx, WebSocketConnectOptions connectOptions, Promise promise) { ContextInternal eventLoopContext; if (ctx.isEventLoopContext()) { eventLoopContext = ctx; @@ -92,50 +80,80 @@ private Future tryConnect(ContextInternal ctx, WebSocketConnectOption eventLoopContext = ctx.owner().createEventLoopContext(ctx.nettyEventLoop(), ctx.workerPool(), ctx.classLoader()); } Future fut = connector.httpConnect(eventLoopContext); - return fut.compose(c -> { - if (!incRefCount()) { - c.close(); - return Future.failedFuture(new VertxException("Connection closed", true)); - } - long timeout = Math.max(connectOptions.getTimeout(), 0L); - if (connectOptions.getIdleTimeout() >= 0L) { - timeout = connectOptions.getIdleTimeout(); - } - Http1xClientConnection ci = (Http1xClientConnection) c; - Promise promise = ctx.promise(); - ci.toWebSocket( - ctx, - connectOptions.getURI(), - connectOptions.getHeaders(), - connectOptions.getAllowOriginHeader(), - options, - connectOptions.getVersion(), - connectOptions.getSubProtocols(), - timeout, - connectOptions.isRegisterWriteHandlers(), - options.getMaxFrameSize(), - promise); - return promise.future().andThen(ar -> { - if (ar.succeeded()) { - WebSocketImpl wsi = (WebSocketImpl) ar.result(); - wsi.evictionHandler(v -> onEvict()); - } else { - onEvict(); + fut.onComplete(ar -> { + if (ar.succeeded()) { + HttpClientConnectionInternal c = ar.result(); + if (!incRefCount()) { + c.close(); + promise.fail(new VertxException("Connection closed", true)); + return; } - }); + long timeout = Math.max(connectOptions.getTimeout(), 0L); + if (connectOptions.getIdleTimeout() >= 0L) { + timeout = connectOptions.getIdleTimeout(); + } + Http1xClientConnection ci = (Http1xClientConnection) c; + ci.toWebSocket( + ctx, + connectOptions.getURI(), + connectOptions.getHeaders(), + connectOptions.getAllowOriginHeader(), + options, + connectOptions.getVersion(), + connectOptions.getSubProtocols(), + timeout, + connectOptions.isRegisterWriteHandlers(), + options.getMaxFrameSize(), + promise); + } else { + promise.fail(ar.cause()); + } }); } - protected Future requestConnection2(ContextInternal ctx, WebSocketConnectOptions connectOptions, long timeout) { + private void release() { + Waiter waiter; + synchronized (WebSocketGroup.this) { + if (--inflightConnections > maxPoolSize || waiters.isEmpty()) { + return; + } + waiter = waiters.poll(); + } + connect(waiter.context, waiter.connectOptions, waiter.promise); + } + + private Future tryAcquire(ContextInternal ctx, WebSocketConnectOptions options) { synchronized (this) { if (inflightConnections >= maxPoolSize) { - Waiter waiter = new Waiter(ctx, connectOptions); + Waiter waiter = new Waiter(ctx, options); waiters.add(waiter); return waiter.promise.future(); } inflightConnections++; } - return tryConnect(ctx, connectOptions); + return null; + } + + protected Future requestConnection2(ContextInternal ctx, WebSocketConnectOptions connectOptions, long timeout) { + Future res = tryAcquire(ctx, connectOptions); + if (res == null) { + PromiseInternal promise = ctx.promise(); + connect(ctx, connectOptions, promise); + res = promise.future(); + } + res.andThen(ar -> { + if (ar.succeeded()) { + WebSocketImpl wsi = (WebSocketImpl) ar.result(); + wsi.evictionHandler(v -> { + decRefCount(); + release(); + }); + } else { + decRefCount(); + release(); + } + }); + return res; } @Override diff --git a/vertx-core/src/test/java/io/vertx/tests/http/WebSocketTest.java b/vertx-core/src/test/java/io/vertx/tests/http/WebSocketTest.java index 97fb3757470..db08798b02b 100644 --- a/vertx-core/src/test/java/io/vertx/tests/http/WebSocketTest.java +++ b/vertx-core/src/test/java/io/vertx/tests/http/WebSocketTest.java @@ -46,6 +46,7 @@ import io.vertx.core.net.NetSocket; import io.vertx.core.net.SocketAddress; import io.vertx.test.core.CheckingSender; +import io.vertx.test.core.Repeat; import io.vertx.test.core.TestUtils; import io.vertx.test.core.VertxTestBase; import io.vertx.test.http.HttpTestBase; @@ -95,7 +96,6 @@ import static io.vertx.test.http.HttpTestBase.DEFAULT_HTTP_HOST; import static io.vertx.test.http.HttpTestBase.DEFAULT_HTTP_HOST_AND_PORT; import static io.vertx.test.http.HttpTestBase.DEFAULT_HTTP_PORT; -import static org.junit.Assume.assumeTrue; /** * @author Tim Fox @@ -3987,4 +3987,46 @@ public void testCustomResponseHeadersBeforeUpgrade() throws InterruptedException })); await(); } + + @Test + public void testPoolShouldNotStarveOnConnectError() throws Exception { + + server = vertx.createHttpServer(); + + CountDownLatch shutdownLatch = new CountDownLatch(1); + AtomicInteger accepted = new AtomicInteger(); + server.webSocketHandler(ws -> { + ws.shutdownHandler(v -> shutdownLatch.countDown()); + assertTrue(accepted.getAndIncrement() == 0); + }); + + server.listen(DEFAULT_HTTP_PORT, DEFAULT_HTTP_HOST).toCompletionStage().toCompletableFuture().get(); + + int maxConnections = 5; + + client = vertx.createWebSocketClient(new WebSocketClientOptions() + .setMaxConnections(maxConnections) + .setConnectTimeout(4000)); + + Future wsFut = client.connect(DEFAULT_HTTP_PORT, DEFAULT_HTTP_HOST, "/").andThen(onSuccess(v -> { + })); + + // Finish handshake + wsFut.toCompletionStage().toCompletableFuture().get(10, TimeUnit.SECONDS); + + // This test requires a server socket to respond for the first connection + // Subsequent connections need to fail (connect error) + server.shutdown(30, TimeUnit.SECONDS); + awaitLatch(shutdownLatch); + + int num = maxConnections + 10; + CountDownLatch latch = new CountDownLatch(num); + for (int i = 0;i < num;i++) { + client.connect(DEFAULT_HTTP_PORT, DEFAULT_HTTP_HOST, "/").onComplete(ar -> { + latch.countDown(); + }); + } + + awaitLatch(latch, 10, TimeUnit.SECONDS); + } }