diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/AbstractStreamingHttpConnection.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/AbstractStreamingHttpConnection.java index fb49281ed1..6c22f82953 100644 --- a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/AbstractStreamingHttpConnection.java +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/AbstractStreamingHttpConnection.java @@ -1,5 +1,5 @@ /* - * Copyright © 2018-2023 Apple Inc. and the ServiceTalk project authors + * Copyright © 2018-2019, 2021-2022 Apple Inc. and the ServiceTalk project authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ import io.servicetalk.concurrent.api.Publisher; import io.servicetalk.concurrent.api.Single; import io.servicetalk.concurrent.api.TerminalSignalConsumer; +import io.servicetalk.http.api.FilterableStreamingHttpConnection; import io.servicetalk.http.api.HttpConnectionContext; import io.servicetalk.http.api.HttpEventKey; import io.servicetalk.http.api.HttpExecutionContext; @@ -38,7 +39,6 @@ import io.servicetalk.transport.netty.internal.FlushStrategy; import io.servicetalk.transport.netty.internal.NettyConnectionContext; -import io.netty.channel.Channel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -66,7 +66,7 @@ import static java.util.Objects.requireNonNull; abstract class AbstractStreamingHttpConnection - implements NettyFilterableStreamingHttpConnection { + implements FilterableStreamingHttpConnection { private static final Logger LOGGER = LoggerFactory.getLogger(AbstractStreamingHttpConnection.class); static final IgnoreConsumedEvent ZERO_MAX_CONCURRENCY_EVENT = new IgnoreConsumedEvent<>(0); @@ -168,7 +168,7 @@ public void cancel() { } @Override - public final Single request(final StreamingHttpRequest request) { + public Single request(final StreamingHttpRequest request) { return defer(() -> { Publisher flatRequest; // See https://tools.ietf.org/html/rfc7230#section-3.3.3 @@ -260,11 +260,6 @@ public final StreamingHttpResponseFactory httpResponseFactory() { return reqRespFactory; } - @Override - public Channel nettyChannel() { - return connection.nettyChannel(); - } - @Override public final Completable onClose() { return connectionContext.onClose(); diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/NettyFilterableStreamingHttpConnection.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/NettyFilterableStreamingHttpConnection.java deleted file mode 100644 index 192e9fa431..0000000000 --- a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/NettyFilterableStreamingHttpConnection.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright © 2023 Apple Inc. and the ServiceTalk project authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.servicetalk.http.netty; - -import io.servicetalk.http.api.FilterableStreamingHttpConnection; - -import io.netty.channel.Channel; - -/** - * {@link FilterableStreamingHttpConnection} that also gives access to Netty {@link Channel}. - */ -interface NettyFilterableStreamingHttpConnection extends FilterableStreamingHttpConnection { - - /** - * Return the Netty {@link Channel} backing this connection. - * - * @return the Netty {@link Channel} backing this connection. - */ - Channel nettyChannel(); -} diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/PipelinedLBHttpConnectionFactory.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/PipelinedLBHttpConnectionFactory.java index f3951d1e76..b9b4d20edf 100644 --- a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/PipelinedLBHttpConnectionFactory.java +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/PipelinedLBHttpConnectionFactory.java @@ -47,8 +47,7 @@ final class PipelinedLBHttpConnectionFactory extends AbstractLB Single newFilterableConnection(final ResolvedAddress resolvedAddress, final TransportObserver observer) { assert config.h1Config() != null; - return buildStreaming(executionContext, resolvedAddress, config.tcpConfig(), config.h1Config(), - config.hasProxy(), observer) + return buildStreaming(executionContext, resolvedAddress, config, observer) .map(conn -> new PipelinedStreamingHttpConnection(conn, config.h1Config(), reqRespFactoryFunc.apply(HTTP_1_1), config.allowDropTrailersReadFromTransport())); } diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/ProxyConnectChannelSingle.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/ProxyConnectChannelSingle.java new file mode 100644 index 0000000000..c921e8ee04 --- /dev/null +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/ProxyConnectChannelSingle.java @@ -0,0 +1,206 @@ +/* + * Copyright © 2023 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.http.netty; + +import io.servicetalk.concurrent.SingleSource; +import io.servicetalk.concurrent.api.Single; +import io.servicetalk.http.api.HttpHeadersFactory; +import io.servicetalk.http.api.HttpRequestMetaData; +import io.servicetalk.http.api.HttpResponseMetaData; +import io.servicetalk.http.netty.ProxyConnectException.RetryableProxyConnectException; +import io.servicetalk.transport.api.ConnectionObserver; +import io.servicetalk.transport.api.ConnectionObserver.ProxyConnectObserver; +import io.servicetalk.transport.netty.internal.ChannelInitializer; +import io.servicetalk.transport.netty.internal.CloseHandler.InboundDataEndEvent; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import static io.servicetalk.http.api.HttpHeaderNames.HOST; +import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_1_1; +import static io.servicetalk.http.api.HttpRequestMetaDataFactory.newRequestMetaData; +import static io.servicetalk.http.api.HttpRequestMethod.CONNECT; +import static io.servicetalk.http.api.HttpResponseStatus.StatusClass.SUCCESSFUL_2XX; +import static io.servicetalk.transport.netty.internal.ChannelCloseUtils.assignConnectionError; + +/** + * A {@link Single} that adds a {@link ChannelHandler} into {@link ChannelPipeline} to perform HTTP/1.1 CONNECT + * exchange. + */ +final class ProxyConnectChannelSingle extends ChannelInitSingle { + + private final ConnectionObserver observer; + private final HttpHeadersFactory headersFactory; + private final String connectAddress; + + ProxyConnectChannelSingle(final Channel channel, + final ChannelInitializer channelInitializer, + final ConnectionObserver observer, + final HttpHeadersFactory headersFactory, + final String connectAddress) { + super(channel, channelInitializer); + this.observer = observer; + this.headersFactory = headersFactory; + this.connectAddress = connectAddress; + assert !channel.config().isAutoRead(); + } + + @Override + protected ChannelHandler newChannelHandler(final Subscriber subscriber) { + return new ProxyConnectHandler(observer, headersFactory, connectAddress, subscriber); + } + + private static final class ProxyConnectHandler extends ChannelDuplexHandler { + + private static final Logger LOGGER = LoggerFactory.getLogger(ProxyConnectHandler.class); + + private final ConnectionObserver observer; + private final HttpHeadersFactory headersFactory; + private final String connectAddress; + @Nullable + private Subscriber subscriber; + @Nullable + private ProxyConnectObserver connectObserver; + @Nullable + private HttpResponseMetaData response; + + private ProxyConnectHandler(final ConnectionObserver observer, + final HttpHeadersFactory headersFactory, + final String connectAddress, + final Subscriber subscriber) { + this.observer = observer; + this.headersFactory = headersFactory; + this.connectAddress = connectAddress; + this.subscriber = subscriber; + } + + @Override + public void handlerAdded(final ChannelHandlerContext ctx) { + if (ctx.channel().isActive()) { + sendConnectRequest(ctx); + } + } + + @Override + public void channelActive(final ChannelHandlerContext ctx) { + sendConnectRequest(ctx); + ctx.fireChannelActive(); + } + + private void sendConnectRequest(final ChannelHandlerContext ctx) { + final HttpRequestMetaData request = newRequestMetaData(HTTP_1_1, CONNECT, connectAddress, + headersFactory.newHeaders()).addHeader(HOST, connectAddress); + connectObserver = observer.onProxyConnect(request); + ctx.writeAndFlush(request).addListener(f -> { + if (f.isSuccess()) { + ctx.read(); + } else { + failSubscriber(ctx, new RetryableProxyConnectException( + "Failed to write CONNECT request", f.cause())); + } + }); + } + + @Override + public void channelRead(final ChannelHandlerContext ctx, final Object msg) { + if (msg instanceof HttpResponseMetaData) { + if (response != null) { + failSubscriber(ctx, new RetryableProxyConnectException( + "Received two responses for a single CONNECT request")); + return; + } + response = (HttpResponseMetaData) msg; + if (response.status().statusClass() != SUCCESSFUL_2XX) { + failSubscriber(ctx, new ProxyResponseException("Non-successful response '" + response.status() + + "' from proxy on CONNECT " + connectAddress, response.status())); + } + // We do not complete subscriber here because we need to wait for the HttpResponseDecoder state machine + // to complete. Completion will be signalled by InboundDataEndEvent. Any other messages before that are + // unexpected, see https://datatracker.ietf.org/doc/html/rfc9110#section-9.3.6 + // It also helps to make sure we do not propagate InboundDataEndEvent after the next handlers are added + // to the pipeline, potentially causing changes in their state machine. + } else { + failSubscriber(ctx, new RetryableProxyConnectException( + "Received unexpected message in the pipeline of type: " + msg.getClass().getName())); + } + } + + @Override + public void channelReadComplete(final ChannelHandlerContext ctx) throws Exception { + if (subscriber != null) { + ctx.read(); // Keep requesting until finished + } + ctx.fireChannelReadComplete(); + } + + @Override + public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception { + if (evt != InboundDataEndEvent.INSTANCE || subscriber == null) { + ctx.fireUserEventTriggered(evt); + return; + } + assert response != null; + assert connectObserver != null; + connectObserver.proxyConnectComplete(response); + ctx.pipeline().remove(this); + final Channel channel = ctx.channel(); + LOGGER.debug("{} Received successful response from proxy on CONNECT {}", channel, connectAddress); + final Subscriber subscriberCopy = subscriber; + subscriber = null; + subscriberCopy.onSuccess(channel); + } + + @Override + public void channelInactive(final ChannelHandlerContext ctx) { + if (subscriber != null) { + failSubscriber(ctx, new RetryableProxyConnectException( + "Connection closed before proxy CONNECT finished")); + return; + } + ctx.fireChannelInactive(); + } + + @Override + public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause) throws Exception { + if (subscriber != null) { + failSubscriber(ctx, new ProxyConnectException( + "Unexpected exception before proxy CONNECT finished", cause)); + return; + } + ctx.fireExceptionCaught(cause); + } + + private void failSubscriber(final ChannelHandlerContext ctx, final Throwable cause) { + assignConnectionError(ctx.channel(), cause); + if (subscriber != null) { + if (connectObserver != null) { + connectObserver.proxyConnectFailed(cause); + } + final SingleSource.Subscriber subscriberCopy = subscriber; + subscriber = null; + subscriberCopy.onError(cause); + } + ctx.close(); + } + } +} diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/ProxyConnectException.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/ProxyConnectException.java new file mode 100644 index 0000000000..262f74fed5 --- /dev/null +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/ProxyConnectException.java @@ -0,0 +1,51 @@ +/* + * Copyright © 2023 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.http.netty; + +import io.servicetalk.transport.api.RetryableException; + +import java.io.IOException; + +/** + * An exception while processing + * HTTP/1.1 CONNECT request. + */ +public class ProxyConnectException extends IOException { + + private static final long serialVersionUID = 4453075928788773272L; + + ProxyConnectException(final String message) { + super(message); + } + + ProxyConnectException(final String message, final Throwable cause) { + super(message, cause); + } + + static final class RetryableProxyConnectException extends ProxyConnectException + implements RetryableException { + + private static final long serialVersionUID = 5118637083568536242L; + + RetryableProxyConnectException(final String message) { + super(message); + } + + RetryableProxyConnectException(final String message, final Throwable cause) { + super(message, cause); + } + } +} diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/ProxyConnectLBHttpConnectionFactory.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/ProxyConnectLBHttpConnectionFactory.java index 5e0f4924dc..94787f308c 100644 --- a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/ProxyConnectLBHttpConnectionFactory.java +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/ProxyConnectLBHttpConnectionFactory.java @@ -16,39 +16,48 @@ package io.servicetalk.http.netty; import io.servicetalk.client.api.ConnectionFactoryFilter; -import io.servicetalk.concurrent.api.Completable; import io.servicetalk.concurrent.api.Single; import io.servicetalk.http.api.FilterableStreamingHttpConnection; import io.servicetalk.http.api.HttpExecutionContext; +import io.servicetalk.http.api.HttpExecutionStrategy; import io.servicetalk.http.api.StreamingHttpConnectionFilterFactory; -import io.servicetalk.http.api.StreamingHttpRequest; import io.servicetalk.http.api.StreamingHttpRequestResponseFactory; -import io.servicetalk.http.api.StreamingHttpResponse; import io.servicetalk.http.netty.AlpnChannelSingle.NoopChannelInitializer; +import io.servicetalk.http.netty.ProxyConnectException.RetryableProxyConnectException; import io.servicetalk.tcp.netty.internal.ReadOnlyTcpClientConfig; +import io.servicetalk.tcp.netty.internal.TcpClientChannelInitializer; +import io.servicetalk.tcp.netty.internal.TcpConnector; +import io.servicetalk.transport.api.ConnectionObserver; import io.servicetalk.transport.api.ExecutionStrategy; import io.servicetalk.transport.api.TransportObserver; +import io.servicetalk.transport.netty.internal.ChannelCloseUtils; +import io.servicetalk.transport.netty.internal.CloseHandler; import io.servicetalk.transport.netty.internal.DefaultNettyConnection; import io.servicetalk.transport.netty.internal.DeferSslHandler; -import io.servicetalk.transport.netty.internal.NoopTransportObserver.NoopConnectionObserver; import io.servicetalk.transport.netty.internal.StacklessClosedChannelException; import io.netty.channel.Channel; +import io.netty.channel.ChannelConfig; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelPipeline; +import java.nio.channels.ClosedChannelException; import javax.annotation.Nullable; +import javax.net.ssl.SSLException; -import static io.servicetalk.http.api.HttpContextKeys.HTTP_EXECUTION_STRATEGY_KEY; -import static io.servicetalk.http.api.HttpExecutionStrategies.offloadNone; -import static io.servicetalk.http.api.HttpHeaderNames.HOST; +import static io.netty.channel.ChannelOption.ALLOW_HALF_CLOSURE; +import static io.servicetalk.buffer.netty.BufferUtils.getByteBufAllocator; +import static io.servicetalk.http.api.HttpExecutionStrategies.customStrategyBuilder; import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_1_1; import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_2_0; -import static io.servicetalk.http.api.HttpResponseStatus.StatusClass.SUCCESSFUL_2XX; import static io.servicetalk.http.netty.AlpnLBHttpConnectionFactory.unknownAlpnProtocol; +import static io.servicetalk.http.netty.HeaderUtils.OBJ_EXPECT_CONTINUE; +import static io.servicetalk.http.netty.HttpDebugUtils.showPipeline; +import static io.servicetalk.http.netty.HttpExecutionContextUtils.channelExecutionContext; import static io.servicetalk.http.netty.HttpProtocolConfigs.h1Default; -import static io.servicetalk.http.netty.StreamingConnectionFactory.buildStreaming; -import static io.servicetalk.utils.internal.ThrowableUtils.addSuppressed; +import static io.servicetalk.transport.netty.internal.CloseHandler.forPipelinedRequestResponse; +import static java.lang.Boolean.FALSE; +import static java.lang.Boolean.TRUE; /** * {@link AbstractLBHttpConnectionFactory} implementation that handles HTTP/1.1 CONNECT when a client is configured to @@ -59,6 +68,8 @@ final class ProxyConnectLBHttpConnectionFactory extends AbstractLBHttpConnectionFactory { + private static final HttpExecutionStrategy OFFLOAD_SEND_STRATEGY = customStrategyBuilder().offloadSend().build(); + private final String connectAddress; ProxyConnectLBHttpConnectionFactory( @@ -70,137 +81,118 @@ final class ProxyConnectLBHttpConnectionFactory final ProtocolBinding protocolBinding) { super(config, executionContext, version -> reqRespFactory, connectStrategy, connectionFactoryFilter, connectionFilterFunction, protocolBinding); + assert config.hasProxy() : "Unexpected hasProxy flag"; assert config.tcpConfig().sslContext() != null : "Proxy CONNECT works only for TLS connections"; assert config.connectAddress() != null : "Address (authority) for CONNECT request is required"; - connectAddress = config.connectAddress().toString(); + this.connectAddress = config.connectAddress().toString(); } @Override Single newFilterableConnection(final ResolvedAddress resolvedAddress, final TransportObserver observer) { final H1ProtocolConfig h1Config = config.h1Config() != null ? config.h1Config() : h1Default(); - return buildStreaming(executionContext, resolvedAddress, config.tcpConfig(), h1Config, config.hasProxy(), - observer) - // Always create PipelinedStreamingHttpConnection because: - // 1. buildStreaming creates a CloseHandler for pipelined request-response - // 2. in case ALPN negotiates HTTP/1.x we won't need to change the connection - .map(c -> new PipelinedStreamingHttpConnection(c, h1Config, - reqRespFactoryFunc.apply(HTTP_1_1), config.allowDropTrailersReadFromTransport())) - .flatMap(this::processConnect); + // Because we initialize HTTP/1.1 connection, autoRead must be set to false. + return TcpConnector.connect(null, resolvedAddress, config.tcpConfig(), false, executionContext, + (channel, connectionObserver) -> createConnection(channel, connectionObserver, h1Config), observer); } - // Visible for testing - Single processConnect(final NettyFilterableStreamingHttpConnection c) { - try { - // Send CONNECT request: https://datatracker.ietf.org/doc/html/rfc9110#section-9.3.6 - // Host header value must be equal to CONNECT request target, see - // https://github.com/haproxy/haproxy/issues/1159 - // https://datatracker.ietf.org/doc/html/rfc7230#section-5.4: - // If the target URI includes an authority component, then a client MUST send a field-value - // for Host that is identical to that authority component - final StreamingHttpRequest request = c.connect(connectAddress).setHeader(HOST, connectAddress); - // No need to offload because there is no user code involved - request.context().put(HTTP_EXECUTION_STRATEGY_KEY, offloadNone()); - return c.request(request) - .flatMap(response -> { - // Successful response to CONNECT never has a message body, and we are not interested in payload - // body for any non-200 status code. Drain it asap to free connection and RS resources before - // starting TLS handshake or propagating an error. We do this after verifying the status to - // preserve ProxyResponseException even if draining fails with an exception. - if (response.status().statusClass() != SUCCESSFUL_2XX) { - return drainPropagateError(response, new ProxyResponseException(c + - " Non-successful response from proxy CONNECT " + connectAddress, response.status())) - .shareContextOnSubscribe(); - } - return response.messageBody().ignoreElements() - .concat(handshake(c)) - .shareContextOnSubscribe(); - }) - // Close recently created connection in case of any error while it connects to the proxy: - .onErrorResume(t -> closePropagateError(c, t)); - // We do not apply shareContextOnSubscribe() here to isolate a context for `CONNECT` request. - } catch (Throwable t) { - return closePropagateError(c, t); - } + private Single createConnection( + final Channel channel, final ConnectionObserver observer, final H1ProtocolConfig h1Config) { + final ChannelConfig channelConfig = channel.config(); + final CloseHandler closeHandler = forPipelinedRequestResponse(true, channelConfig); + // Disable half-closure to simplify ProxyConnectHandler implementation + channelConfig.setOption(ALLOW_HALF_CLOSURE, FALSE); + return new ProxyConnectChannelSingle(channel, + new TcpClientChannelInitializer(config.tcpConfig(), observer, config.hasProxy()) + .andThen(new HttpClientChannelInitializer( + getByteBufAllocator(executionContext.bufferAllocator()), h1Config, closeHandler)), + observer, h1Config.headersFactory(), connectAddress) + .flatMap(ch -> handshake(ch, observer)) + .flatMap(protocol -> finishConnectionInitialization(protocol, channel, closeHandler, observer)) + .onErrorMap(cause -> handleException(cause, channel)); } - private Single handshake( - final NettyFilterableStreamingHttpConnection connection) { - return Single.defer(() -> { - final Channel channel = connection.nettyChannel(); - assert channel.eventLoop().inEventLoop(); - - final Single result; - final DeferSslHandler deferSslHandler = channel.pipeline().get(DeferSslHandler.class); - if (deferSslHandler == null) { - if (!channel.isActive()) { - result = Single.failed(StacklessClosedChannelException.newInstance(connection + - " Connection is closed, either received a 'Connection: closed' header or" + - " closed by the proxy. Investigate logs on a proxy side to identify the cause.", - ProxyConnectLBHttpConnectionFactory.class, "handshake")); - } else { - result = Single.failed(new IllegalStateException(connection + - " Unexpected connection state: failed to find a handler of type " + - DeferSslHandler.class + " in the channel pipeline.")); - } + private static Single handshake(final Channel channel, final ConnectionObserver connectionObserver) { + assert channel.eventLoop().inEventLoop(); + + final Single result; + final DeferSslHandler deferSslHandler = channel.pipeline().get(DeferSslHandler.class); + if (deferSslHandler == null) { + if (!channel.isActive()) { + result = Single.failed(new ProxyConnectException(channel + + " Connection is closed, either received a 'Connection: closed' header or closed by the proxy", + StacklessClosedChannelException.newInstance( + ProxyConnectLBHttpConnectionFactory.class, "handshake"))); } else { - result = new AlpnChannelSingle(channel, NoopChannelInitializer.INSTANCE, __ -> deferSslHandler.ready()); - } - return result.shareContextOnSubscribe(); - }).flatMap(protocol -> { - final Single result; - switch (protocol) { - case AlpnIds.HTTP_1_1: - // Nothing to do, HTTP/1.1 pipeline is already initialized - result = Single.succeeded(connection); - break; - case AlpnIds.HTTP_2: - final Channel channel = connection.nettyChannel(); - assert channel.eventLoop().inEventLoop(); - removeH1Handlers(channel); - result = initializeH2Connection(channel); - break; - default: - result = unknownAlpnProtocol(protocol); - break; + result = Single.failed(new ProxyConnectException(channel + + " Unexpected connection state: failed to find a handler of type " + DeferSslHandler.class + + " in the channel pipeline.")); } - return result.shareContextOnSubscribe(); - }); + } else { + result = new AlpnChannelSingle(channel, NoopChannelInitializer.INSTANCE, ctx -> deferSslHandler.ready()); + } + return result.shareContextOnSubscribe(); + } + + private Single finishConnectionInitialization( + final String protocol, final Channel channel, final CloseHandler closeHandler, + final ConnectionObserver connectionObserver) { + assert channel.eventLoop().inEventLoop(); + + final ReadOnlyTcpClientConfig tcpConfig = config.tcpConfig(); + final Single result; + switch (protocol) { + case AlpnIds.HTTP_1_1: + assert config.h1Config() != null; + // Re-enable half-closure to let CloseHandler work with DuplexChannel + channel.config().setOption(ALLOW_HALF_CLOSURE, TRUE); + result = showPipeline(DefaultNettyConnection.initChannel(channel, + channelExecutionContext(channel, executionContext), closeHandler, + tcpConfig.flushStrategy(), tcpConfig.idleTimeoutMs(), tcpConfig.sslConfig(), + NoopChannelInitializer.INSTANCE, HTTP_1_1, connectionObserver, true, OBJ_EXPECT_CONTINUE), + HTTP_1_1, channel) + .map(conn -> new PipelinedStreamingHttpConnection(conn, config.h1Config(), + reqRespFactoryFunc.apply(HTTP_1_1), config.allowDropTrailersReadFromTransport())); + break; + case AlpnIds.HTTP_2: + removeH1Handlers(channel); + final H2ProtocolConfig h2Config = config.h2Config(); + assert h2Config != null; + result = H2ClientParentConnectionContext.initChannel(channel, executionContext, h2Config, + reqRespFactoryFunc.apply(HTTP_2_0), tcpConfig.flushStrategy(), tcpConfig.idleTimeoutMs(), + tcpConfig.sslConfig(), new H2ClientParentChannelInitializer(h2Config), connectionObserver, + config.allowDropTrailersReadFromTransport()); + break; + default: + result = unknownAlpnProtocol(protocol); + break; + } + return result.shareContextOnSubscribe(); } private static void removeH1Handlers(final Channel channel) { final ChannelPipeline pipeline = channel.pipeline(); - pipeline.remove(DefaultNettyConnection.handlerClass()); for (Class handlerClass : HttpClientChannelInitializer.handlers()) { pipeline.remove(handlerClass); } } - private Single initializeH2Connection(final Channel channel) { - final H2ProtocolConfig h2Config = config.h2Config(); - assert h2Config != null; - final ReadOnlyTcpClientConfig tcpConfig = config.tcpConfig(); - return H2ClientParentConnectionContext.initChannel(channel, executionContext, h2Config, - reqRespFactoryFunc.apply(HTTP_2_0), tcpConfig.flushStrategy(), tcpConfig.idleTimeoutMs(), - tcpConfig.sslConfig(), new H2ClientParentChannelInitializer(h2Config), - // FIXME: propagate real observer - NoopConnectionObserver.INSTANCE, config.allowDropTrailersReadFromTransport()); - } - - private static Single drainPropagateError( - final StreamingHttpResponse response, final Throwable error) { - return safeCompletePropagateError(response.messageBody().ignoreElements(), error); - } - - private static Single closePropagateError( - final FilterableStreamingHttpConnection connection, final Throwable error) { - return safeCompletePropagateError(connection.closeAsync(), error); - } - - private static Single safeCompletePropagateError( - final Completable completable, final Throwable error) { - return completable - .onErrorResume(completableError -> Completable.failed(addSuppressed(error, completableError))) - .concat(Single.failed(error)); + private static Throwable handleException(final Throwable cause, final Channel channel) { + // If the Channel is still active, close it in case of any error to free resources. + if (channel.isActive()) { + ChannelCloseUtils.close(channel, cause); + } + if (cause instanceof SSLException) { + return cause; + } + if (cause instanceof ClosedChannelException) { + return new RetryableProxyConnectException(channel + " Connection is closed, either received a " + + "'Connection: closed' header or closed by the proxy", cause); + } + if (!(cause instanceof ProxyConnectException || cause instanceof ProxyResponseException)) { + return new RetryableProxyConnectException(channel + + " Unexpected exception during an attempt to connect to a proxy", cause); + } + return cause; } } diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/ProxyResponseException.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/ProxyResponseException.java index 4c2c1c894a..03fd1f6727 100644 --- a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/ProxyResponseException.java +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/ProxyResponseException.java @@ -1,5 +1,5 @@ /* - * Copyright © 2019 Apple Inc. and the ServiceTalk project authors + * Copyright © 2019, 2023 Apple Inc. and the ServiceTalk project authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -41,9 +41,4 @@ public final class ProxyResponseException extends IOException implements Retryab public HttpResponseStatus status() { return status; } - - @Override - public String toString() { - return super.toString() + ": " + status; - } } diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/StreamingConnectionFactory.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/StreamingConnectionFactory.java index b4dd8df929..8cfe85ff80 100644 --- a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/StreamingConnectionFactory.java +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/StreamingConnectionFactory.java @@ -1,5 +1,5 @@ /* - * Copyright © 2018-2021 Apple Inc. and the ServiceTalk project authors + * Copyright © 2018-2023 Apple Inc. and the ServiceTalk project authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -50,13 +50,16 @@ private StreamingConnectionFactory() { static Single> buildStreaming( final HttpExecutionContext executionContext, final ResolvedAddress resolvedAddress, - final ReadOnlyTcpClientConfig originalTcpConfig, final H1ProtocolConfig h1Config, boolean hasProxy, - final TransportObserver observer) { - final ReadOnlyTcpClientConfig tcpConfig = withSslConfigPeerHost(resolvedAddress, originalTcpConfig); + final ReadOnlyHttpClientConfig roConfig, final TransportObserver observer) { + final ReadOnlyTcpClientConfig tcpConfig = withSslConfigPeerHost(resolvedAddress, roConfig.tcpConfig()); + final H1ProtocolConfig h1Config = roConfig.h1Config(); + assert h1Config != null; + assert !roConfig.hasProxy() || tcpConfig.sslContext() == null : + "This factory can be used only for non-proxied connections or for non-secure proxies"; // We disable auto read so we can handle stuff in the ConnectionFilter before we accept any content. return TcpConnector.connect(null, resolvedAddress, tcpConfig, false, executionContext, (channel, connectionObserver) -> createConnection(channel, executionContext, h1Config, tcpConfig, - new TcpClientChannelInitializer(tcpConfig, connectionObserver, hasProxy), + new TcpClientChannelInitializer(tcpConfig, connectionObserver, false), connectionObserver), observer); } diff --git a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/GracefulConnectionClosureHandlingTest.java b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/GracefulConnectionClosureHandlingTest.java index 2c995eedc6..528be8d6d1 100644 --- a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/GracefulConnectionClosureHandlingTest.java +++ b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/GracefulConnectionClosureHandlingTest.java @@ -178,7 +178,6 @@ private void setUp(HttpProtocol protocol, boolean secure, boolean initiateClosur } if (viaProxy) { assumeTrue(secure, "Proxy tunnel works only with secure connections"); - assumeTrue(protocol != HTTP_2, "Proxy is not supported with HTTP/2"); } HttpServerBuilder serverBuilder = forAddress(useUds ? newSocketAddress() : localAddress(0)) @@ -292,9 +291,8 @@ private static Collection data() { for (boolean initiateClosureFromClient : TRUE_FALSE) { for (boolean useUds : TRUE_FALSE) { for (boolean viaProxy : TRUE_FALSE) { - if (viaProxy && (useUds || protocol == HTTP_2 || !secure)) { + if (viaProxy && (useUds || !secure)) { // UDS cannot be used via proxy - // Proxy is not supported with HTTP/2 // Proxy tunnel works only with secure connections continue; } diff --git a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/HttpTransportObserverAsyncContextTest.java b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/HttpTransportObserverAsyncContextTest.java index 41e6a0c1a8..66a027da32 100644 --- a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/HttpTransportObserverAsyncContextTest.java +++ b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/HttpTransportObserverAsyncContextTest.java @@ -31,6 +31,7 @@ import io.servicetalk.transport.api.ConnectionObserver.WriteObserver; import io.servicetalk.transport.api.TransportObserver; import io.servicetalk.transport.netty.internal.NoopTransportObserver; +import io.servicetalk.transport.netty.internal.NoopTransportObserver.NoopProxyConnectObserver; import io.servicetalk.transport.netty.internal.NoopTransportObserver.NoopSecurityHandshakeObserver; import org.junit.jupiter.api.BeforeEach; @@ -198,6 +199,12 @@ public void onTransportHandshakeComplete() { // AsyncContext is unknown at this point because this event is triggered by network } + @Override + public ProxyConnectObserver onProxyConnect(final Object connectMsg) { + // AsyncContext is unknown at this point because this event is triggered by network + return NoopProxyConnectObserver.INSTANCE; + } + @Override public SecurityHandshakeObserver onSecurityHandshake() { // AsyncContext is unknown at this point because this event is triggered by network diff --git a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/HttpsProxyTest.java b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/HttpsProxyTest.java index 2503cf8d41..ebf7e12b19 100644 --- a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/HttpsProxyTest.java +++ b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/HttpsProxyTest.java @@ -18,53 +18,77 @@ import io.servicetalk.client.api.ConnectionFactory; import io.servicetalk.client.api.ConnectionFactoryFilter; import io.servicetalk.client.api.DelegatingConnectionFactory; +import io.servicetalk.client.api.TransportObserverConnectionFactoryFilter; import io.servicetalk.concurrent.api.Single; import io.servicetalk.context.api.ContextMap; import io.servicetalk.http.api.BlockingHttpClient; import io.servicetalk.http.api.FilterableStreamingHttpConnection; +import io.servicetalk.http.api.HttpConnectionContext; import io.servicetalk.http.api.HttpProtocolVersion; import io.servicetalk.http.api.HttpRequest; import io.servicetalk.http.api.HttpResponse; import io.servicetalk.http.api.ReservedBlockingHttpConnection; import io.servicetalk.test.resources.DefaultTestCerts; import io.servicetalk.transport.api.ClientSslConfigBuilder; +import io.servicetalk.transport.api.ConnectionObserver; +import io.servicetalk.transport.api.ConnectionObserver.ProxyConnectObserver; +import io.servicetalk.transport.api.ConnectionObserver.SecurityHandshakeObserver; import io.servicetalk.transport.api.HostAndPort; +import io.servicetalk.transport.api.RetryableException; import io.servicetalk.transport.api.ServerContext; import io.servicetalk.transport.api.ServerSslConfigBuilder; import io.servicetalk.transport.api.TransportObserver; import io.servicetalk.transport.netty.internal.ExecutionContextExtension; +import io.servicetalk.transport.netty.internal.NoopTransportObserver.NoopDataObserver; +import io.servicetalk.transport.netty.internal.NoopTransportObserver.NoopMultiplexedObserver; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.InOrder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.IOException; +import java.io.OutputStream; import java.net.InetSocketAddress; +import java.nio.channels.ClosedChannelException; import java.util.List; import java.util.concurrent.atomic.AtomicReference; import javax.annotation.Nullable; +import javax.net.ssl.SSLHandshakeException; import static io.servicetalk.concurrent.api.Single.succeeded; import static io.servicetalk.http.api.HttpContextKeys.HTTP_TARGET_ADDRESS_BEHIND_PROXY; +import static io.servicetalk.http.api.HttpHeaderNames.CONNECTION; +import static io.servicetalk.http.api.HttpHeaderNames.CONTENT_LENGTH; import static io.servicetalk.http.api.HttpHeaderNames.HOST; +import static io.servicetalk.http.api.HttpHeaderValues.CLOSE; import static io.servicetalk.http.api.HttpResponseStatus.INTERNAL_SERVER_ERROR; import static io.servicetalk.http.api.HttpResponseStatus.OK; import static io.servicetalk.http.api.HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED; import static io.servicetalk.http.api.HttpSerializers.textSerializerUtf8; -import static io.servicetalk.http.netty.HttpProtocol.HTTP_1; -import static io.servicetalk.http.netty.HttpProtocol.HTTP_2; import static io.servicetalk.http.netty.HttpProtocol.toConfigs; import static io.servicetalk.test.resources.DefaultTestCerts.serverPemHostname; import static io.servicetalk.transport.netty.internal.AddressUtils.serverHostAndPort; import static java.nio.charset.StandardCharsets.US_ASCII; -import static java.util.Arrays.asList; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; class HttpsProxyTest { @@ -83,6 +107,12 @@ class HttpsProxyTest { private final ProxyTunnel proxyTunnel = new ProxyTunnel(); private final AtomicReference targetAddress = new AtomicReference<>(); + private TransportObserver transportObserver; + private ConnectionObserver connectionObserver; + private ProxyConnectObserver proxyConnectObserver; + private SecurityHandshakeObserver securityHandshakeObserver; + private InOrder order; + @Nullable private HostAndPort proxyAddress; @Nullable @@ -92,14 +122,15 @@ class HttpsProxyTest { @Nullable private BlockingHttpClient client; - private static List> protocols() { - return asList(asList(HTTP_1), asList(HTTP_2), asList(HTTP_2, HTTP_1), asList(HTTP_1, HTTP_2)); + void setUp(List protocols) throws Exception { + setUp(protocols, false); } - void setUp(List protocols) throws Exception { + void setUp(List protocols, boolean failHandshake) throws Exception { + initMocks(); proxyAddress = proxyTunnel.startProxy(); startServer(protocols); - createClient(protocols); + createClient(protocols, failHandshake); } @AfterEach @@ -119,7 +150,21 @@ static void safeClose(@Nullable AutoCloseable closeable) { } } - void startServer(List protocols) throws Exception { + private void initMocks() { + transportObserver = mock(TransportObserver.class, "transportObserver"); + connectionObserver = mock(ConnectionObserver.class, "connectionObserver"); + proxyConnectObserver = mock(ProxyConnectObserver.class, "proxyConnectObserver"); + securityHandshakeObserver = mock(SecurityHandshakeObserver.class, "securityHandshakeObserver"); + when(transportObserver.onNewConnection(any(), any())).thenReturn(connectionObserver); + when(connectionObserver.onProxyConnect(any())).thenReturn(proxyConnectObserver); + lenient().when(connectionObserver.onSecurityHandshake()).thenReturn(securityHandshakeObserver); + lenient().when(connectionObserver.connectionEstablished(any())).thenReturn(NoopDataObserver.INSTANCE); + lenient().when(connectionObserver.multiplexedConnectionEstablished(any())) + .thenReturn(NoopMultiplexedObserver.INSTANCE); + order = inOrder(transportObserver, connectionObserver, proxyConnectObserver, securityHandshakeObserver); + } + + private void startServer(List protocols) throws Exception { serverContext = BuilderUtils.newServerBuilder(SERVER_CTX) .sslConfig(new ServerSslConfigBuilder(DefaultTestCerts::loadServerPem, DefaultTestCerts::loadServerKey).build()) @@ -129,19 +174,20 @@ void startServer(List protocols) throws Exception { serverAddress = serverHostAndPort(serverContext); } - void createClient(List protocols) { + private void createClient(List protocols, boolean failHandshake) { assert serverContext != null && proxyAddress != null; client = BuilderUtils.newClientBuilder(serverContext, CLIENT_CTX) .proxyAddress(proxyAddress) .sslConfig(new ClientSslConfigBuilder(DefaultTestCerts::loadServerCAPem) - .peerHost(serverPemHostname()).build()) + .peerHost(failHandshake ? "unknown" : serverPemHostname()).build()) .protocols(toConfigs(protocols)) + .appendConnectionFactoryFilter(new TransportObserverConnectionFactoryFilter<>(transportObserver)) .appendConnectionFactoryFilter(new TargetAddressCheckConnectionFactoryFilter(targetAddress, true)) .buildBlocking(); } @ParameterizedTest(name = "{displayName} [{index}] protocols={0}") - @MethodSource("protocols") + @MethodSource("io.servicetalk.http.netty.HttpProtocol#allCombinations") void testClientRequest(List protocols) throws Exception { setUp(protocols); assert client != null; @@ -149,23 +195,24 @@ void testClientRequest(List protocols) throws Exception { } @ParameterizedTest(name = "{displayName} [{index}] protocols={0}") - @MethodSource("protocols") + @MethodSource("io.servicetalk.http.netty.HttpProtocol#allCombinations") void testConnectionRequest(List protocols) throws Exception { setUp(protocols); assert client != null; assert proxyAddress != null; HttpProtocolVersion expectedVersion = protocols.get(0).version; try (ReservedBlockingHttpConnection connection = client.reserveConnection(client.get("/"))) { - assertThat(connection.connectionContext().protocol(), is(expectedVersion)); - assertThat(connection.connectionContext().sslConfig(), is(notNullValue())); - assertThat(connection.connectionContext().sslSession(), is(notNullValue())); - assertThat(((InetSocketAddress) connection.connectionContext().remoteAddress()).getPort(), - is(proxyAddress.port())); + HttpConnectionContext ctx = connection.connectionContext(); + assertThat(ctx.protocol(), is(expectedVersion)); + assertThat(ctx.sslConfig(), is(notNullValue())); + assertThat(ctx.sslSession(), is(notNullValue())); + assertThat(serverHostAndPort(ctx.remoteAddress()).port(), is(proxyAddress.port())); HttpRequest request = connection.get("/path"); assertThat(request.version(), is(expectedVersion)); assertResponse(connection.request(request), expectedVersion); } + order.verify(connectionObserver).connectionClosed(); } private void assertResponse(HttpResponse httpResponse, HttpProtocolVersion expectedVersion) { @@ -173,11 +220,21 @@ private void assertResponse(HttpResponse httpResponse, HttpProtocolVersion expec assertThat(httpResponse.version(), is(expectedVersion)); assertThat(proxyTunnel.connectCount(), is(1)); assertThat(httpResponse.payloadBody().toString(US_ASCII), is("host: " + serverAddress)); - assertThat(targetAddress.get(), is(equalTo(serverAddress.toString()))); + assertTargetAddress(); + + verifyProxyConnectComplete(); + order.verify(connectionObserver).onSecurityHandshake(); + order.verify(securityHandshakeObserver).handshakeComplete(any()); + if (expectedVersion.major() > 1) { + order.verify(connectionObserver).multiplexedConnectionEstablished(any()); + } else { + order.verify(connectionObserver).connectionEstablished(any()); + } + verifyNoMoreInteractions(transportObserver, proxyConnectObserver, securityHandshakeObserver); } @ParameterizedTest(name = "{displayName} [{index}] protocols={0}") - @MethodSource("protocols") + @MethodSource("io.servicetalk.http.netty.HttpProtocol#allCombinations") void testProxyAuthRequired(List protocols) throws Exception { setUp(protocols); proxyTunnel.basicAuthToken(AUTH_TOKEN); @@ -185,21 +242,132 @@ void testProxyAuthRequired(List protocols) throws Exception { ProxyResponseException e = assertThrows(ProxyResponseException.class, () -> client.request(client.get("/path"))); assertThat(e.status(), is(PROXY_AUTHENTICATION_REQUIRED)); - assertThat(targetAddress.get(), is(equalTo(serverAddress.toString()))); + assertTargetAddress(); + verifyProxyConnectFailed(e); } @ParameterizedTest(name = "{displayName} [{index}] protocols={0}") - @MethodSource("protocols") + @MethodSource("io.servicetalk.http.netty.HttpProtocol#allCombinations") void testBadProxyResponse(List protocols) throws Exception { setUp(protocols); proxyTunnel.badResponseProxy(); assert client != null; ProxyResponseException e = assertThrows(ProxyResponseException.class, () -> client.request(client.get("/path"))); + assertThat(e, is(instanceOf(RetryableException.class))); assertThat(e.status(), is(INTERNAL_SERVER_ERROR)); + assertTargetAddress(); + verifyProxyConnectFailed(e); + } + + @ParameterizedTest(name = "{displayName} [{index}] protocols={0}") + @MethodSource("io.servicetalk.http.netty.HttpProtocol#allCombinations") + void testProxyClosesConnection(List protocols) throws Exception { + setUp(protocols); + proxyTunnel.proxyRequestHandler((socket, host, port, protocol) -> { + socket.close(); + }); + assert client != null; + ProxyConnectException e = assertThrows(ProxyConnectException.class, + () -> client.request(client.get("/path"))); + assertThat(e, is(instanceOf(RetryableException.class))); + assertThat(e.getCause(), is(anyOf(nullValue(), instanceOf(ClosedChannelException.class)))); + assertTargetAddress(); + verifyProxyConnectFailed(e); + } + + @ParameterizedTest(name = "{displayName} [{index}] protocols={0}") + @MethodSource("io.servicetalk.http.netty.HttpProtocol#allCombinations") + void testProxyRespondsWithConnectionCloseHeader(List protocols) throws Exception { + setUp(protocols); + proxyTunnel.proxyRequestHandler((socket, host, port, protocol) -> { + final OutputStream os = socket.getOutputStream(); + os.write((protocol + ' ' + OK + "\r\n" + + CONNECTION + ": " + CLOSE + "\r\n" + + "\r\n").getBytes(UTF_8)); + os.flush(); + }); + assert client != null; + ProxyConnectException e = assertThrows(ProxyConnectException.class, + () -> client.request(client.get("/path"))); + assertThat(e, is(instanceOf(RetryableException.class))); + assertThat(e.getCause(), is(instanceOf(ClosedChannelException.class))); + assertTargetAddress(); + verifyProxyConnectComplete(); + } + + @ParameterizedTest(name = "{displayName} [{index}] protocols={0}") + @MethodSource("io.servicetalk.http.netty.HttpProtocol#allCombinations") + void testProxyRespondsAndClosesConnection(List protocols) throws Exception { + setUp(protocols); + proxyTunnel.proxyRequestHandler((socket, host, port, protocol) -> { + final OutputStream os = socket.getOutputStream(); + os.write((protocol + ' ' + OK + "\r\n\r\n").getBytes(UTF_8)); + os.flush(); + socket.close(); + }); + assert client != null; + ProxyConnectException e = assertThrows(ProxyConnectException.class, + () -> client.request(client.get("/path"))); + assertThat(e, is(instanceOf(RetryableException.class))); + assertThat(e.getCause(), is(instanceOf(IOException.class))); + assertTargetAddress(); + verifyProxyConnectComplete(); + } + + @ParameterizedTest(name = "{displayName} [{index}] protocols={0}") + @MethodSource("io.servicetalk.http.netty.HttpProtocol#allCombinations") + void testProxyRespondsWithPayloadBody(List protocols) throws Exception { + setUp(protocols); + String content = "content"; + proxyTunnel.proxyRequestHandler((socket, host, port, protocol) -> { + final OutputStream os = socket.getOutputStream(); + os.write((protocol + ' ' + OK + "\r\n" + + CONTENT_LENGTH + ": " + content.length() + "\r\n" + + "\r\n" + content).getBytes(UTF_8)); + os.flush(); + }); + assert client != null; + ProxyConnectException e = assertThrows(ProxyConnectException.class, + () -> client.request(client.get("/path"))); + assertThat(e, is(instanceOf(RetryableException.class))); + assertTargetAddress(); + verifyProxyConnectComplete(); + } + + @ParameterizedTest(name = "{displayName} [{index}] protocols={0}") + @MethodSource("io.servicetalk.http.netty.HttpProtocol#allCombinations") + void testHandshakeFailed(List protocols) throws Exception { + setUp(protocols, true); + assert client != null; + SSLHandshakeException e = assertThrows(SSLHandshakeException.class, + () -> client.request(client.get("/path"))); + assertThat(e, is(not(instanceOf(RetryableException.class)))); + assertTargetAddress(); + + verifyProxyConnectComplete(); + order.verify(connectionObserver).onSecurityHandshake(); + order.verify(securityHandshakeObserver).handshakeFailed(e); + } + + private void assertTargetAddress() { assertThat(targetAddress.get(), is(equalTo(serverAddress.toString()))); } + private void verifyProxyConnectFailed(Throwable cause) { + order.verify(transportObserver).onNewConnection(any(), any()); + order.verify(connectionObserver).onTransportHandshakeComplete(); + order.verify(connectionObserver).onProxyConnect(any()); + order.verify(proxyConnectObserver).proxyConnectFailed(cause); + } + + private void verifyProxyConnectComplete() { + order.verify(transportObserver).onNewConnection(any(), any()); + order.verify(connectionObserver).onTransportHandshakeComplete(); + order.verify(connectionObserver).onProxyConnect(any()); + order.verify(proxyConnectObserver).proxyConnectComplete(any()); + } + static final class TargetAddressCheckConnectionFactoryFilter implements ConnectionFactoryFilter { diff --git a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ProxyConnectLBHttpConnectionFactoryTest.java b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ProxyConnectLBHttpConnectionFactoryTest.java deleted file mode 100644 index 1a9f324586..0000000000 --- a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ProxyConnectLBHttpConnectionFactoryTest.java +++ /dev/null @@ -1,366 +0,0 @@ -/* - * Copyright © 2020-2023 Apple Inc. and the ServiceTalk project authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.servicetalk.http.netty; - -import io.servicetalk.client.api.ConnectionFactoryFilter; -import io.servicetalk.concurrent.Cancellable; -import io.servicetalk.concurrent.PublisherSource; -import io.servicetalk.concurrent.api.TestCompletable; -import io.servicetalk.concurrent.api.TestPublisher; -import io.servicetalk.concurrent.test.internal.TestSingleSubscriber; -import io.servicetalk.http.api.DefaultHttpHeadersFactory; -import io.servicetalk.http.api.DefaultStreamingHttpRequestResponseFactory; -import io.servicetalk.http.api.FilterableStreamingHttpConnection; -import io.servicetalk.http.api.HttpConnectionContext; -import io.servicetalk.http.api.HttpExecutionContext; -import io.servicetalk.http.api.HttpExecutionStrategy; -import io.servicetalk.http.api.StreamingHttpRequest; -import io.servicetalk.http.api.StreamingHttpRequestResponseFactory; -import io.servicetalk.http.api.StreamingHttpResponse; -import io.servicetalk.http.netty.AbstractLBHttpConnectionFactory.ProtocolBinding; -import io.servicetalk.transport.api.ClientSslConfig; -import io.servicetalk.transport.api.ClientSslConfigBuilder; -import io.servicetalk.transport.api.ConnectExecutionStrategy; -import io.servicetalk.transport.netty.internal.DeferSslHandler; - -import io.netty.channel.Channel; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandler; -import io.netty.channel.ChannelPipeline; -import io.netty.channel.EventLoop; -import io.netty.handler.ssl.SslHandler; -import io.netty.handler.ssl.SslHandshakeCompletionEvent; -import io.netty.util.Attribute; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import org.mockito.ArgumentCaptor; -import org.mockito.stubbing.Answer; - -import java.nio.channels.ClosedChannelException; -import java.util.concurrent.atomic.AtomicReference; - -import static io.servicetalk.buffer.netty.BufferAllocators.DEFAULT_ALLOCATOR; -import static io.servicetalk.concurrent.Cancellable.IGNORE_CANCEL; -import static io.servicetalk.concurrent.api.Single.failed; -import static io.servicetalk.concurrent.api.Single.succeeded; -import static io.servicetalk.concurrent.api.SourceAdapters.toSource; -import static io.servicetalk.concurrent.internal.DeliberateException.DELIBERATE_EXCEPTION; -import static io.servicetalk.http.api.HttpContextKeys.HTTP_EXECUTION_STRATEGY_KEY; -import static io.servicetalk.http.api.HttpExecutionStrategies.offloadNone; -import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_1_1; -import static io.servicetalk.http.api.HttpResponseStatus.INTERNAL_SERVER_ERROR; -import static io.servicetalk.http.api.HttpResponseStatus.OK; -import static java.util.concurrent.TimeUnit.MILLISECONDS; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.instanceOf; -import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.notNullValue; -import static org.hamcrest.Matchers.nullValue; -import static org.hamcrest.Matchers.sameInstance; -import static org.mockito.ArgumentCaptor.forClass; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doNothing; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -class ProxyConnectLBHttpConnectionFactoryTest { - - private static final ClientSslConfig DEFAULT_SSL_CONFIG = new ClientSslConfigBuilder().build(); - private static final StreamingHttpRequestResponseFactory REQ_RES_FACTORY = - new DefaultStreamingHttpRequestResponseFactory(DEFAULT_ALLOCATOR, DefaultHttpHeadersFactory.INSTANCE, - HTTP_1_1); - private static final String CONNECT_ADDRESS = "foo.bar"; - - private final NettyFilterableStreamingHttpConnection connection; - private final TestCompletable connectionClose; - private final TestPublisher messageBody; - private final TestSingleSubscriber subscriber; - private final ProxyConnectLBHttpConnectionFactory connectionFactory; - private final ChannelHandlerContext ctx; - - ProxyConnectLBHttpConnectionFactoryTest() { - HttpExecutionContext executionContext = new HttpExecutionContextBuilder().build(); - HttpConnectionContext connectionContext = mock(HttpConnectionContext.class); - when(connectionContext.executionContext()).thenReturn(executionContext); - connection = mock(NettyFilterableStreamingHttpConnection.class); - when(connection.connectionContext()).thenReturn(connectionContext); - connectionClose = new TestCompletable.Builder().build(subscriber -> { - subscriber.onSubscribe(IGNORE_CANCEL); - subscriber.onComplete(); - return subscriber; - }); - when(connection.closeAsync()).thenReturn(connectionClose); - - messageBody = new TestPublisher.Builder<>().build(subscriber -> { - subscriber.onSubscribe(new PublisherSource.Subscription() { - @Override - public void request(final long n) { - subscriber.onComplete(); - } - - @Override - public void cancel() { - // noop - } - }); - return subscriber; - }); - - subscriber = new TestSingleSubscriber<>(); - - HttpClientConfig config = new HttpClientConfig(); - config.connectAddress(CONNECT_ADDRESS); - config.tcpConfig().sslConfig(DEFAULT_SSL_CONFIG); - connectionFactory = new ProxyConnectLBHttpConnectionFactory<>(config.asReadOnly(), - executionContext, null, REQ_RES_FACTORY, ConnectExecutionStrategy.offloadNone(), - ConnectionFactoryFilter.identity(), mock(ProtocolBinding.class)); - - ctx = mock(ChannelHandlerContext.class); - } - - private ChannelPipeline configurePipeline() { - return configurePipeline(new AtomicReference<>()); - } - - private ChannelPipeline configurePipeline(AtomicReference handlerCaptor) { - ChannelPipeline pipeline = mock(ChannelPipeline.class); - configureConnectionNettyChannel(pipeline); - when(pipeline.addLast(any())).then((Answer) invocation -> { - ChannelInboundHandler handshakeAwait = invocation.getArgument(0); - handlerCaptor.set(handshakeAwait); - handshakeAwait.handlerAdded(ctx); - return pipeline; - }); - return pipeline; - } - - private void configureDeferSslHandler(ChannelPipeline pipeline, - AtomicReference handlerCaptor, - SslHandshakeCompletionEvent event) { - DeferSslHandler deferSslHandler = mock(DeferSslHandler.class); - when(pipeline.get(DeferSslHandler.class)).thenReturn(deferSslHandler); - doAnswer(invocation -> { - final ChannelInboundHandler handler = handlerCaptor.get(); - handler.userEventTriggered(ctx, event); - if (!event.isSuccess()) { - handler.exceptionCaught(ctx, event.cause()); - } - return null; - }).when(deferSslHandler).ready(); - when(pipeline.get(SslHandler.class)).thenReturn(mock(SslHandler.class)); - } - - @SuppressWarnings("unchecked") - private void configureConnectionNettyChannel(final ChannelPipeline pipeline) { - Channel channel = mock(Channel.class); - EventLoop eventLoop = mock(EventLoop.class); - when(eventLoop.inEventLoop()).thenReturn(true); - when(channel.eventLoop()).thenReturn(eventLoop); - when(channel.pipeline()).thenReturn(pipeline); - when(channel.attr(any())).thenReturn(mock(Attribute.class)); - when(pipeline.channel()).thenReturn(channel); - when(connection.nettyChannel()).thenReturn(channel); - when(ctx.channel()).thenReturn(channel); - when(ctx.pipeline()).thenReturn(pipeline); - } - - private void configureRequestSend() { - StreamingHttpResponse response = mock(StreamingHttpResponse.class); - when(response.status()).thenReturn(OK); - when(response.messageBody()).thenReturn(messageBody); - when(connection.request(any())).thenReturn(succeeded(response)); - } - - private void configureConnectRequest() { - when(connection.connect(any())).thenReturn(REQ_RES_FACTORY.connect(CONNECT_ADDRESS)); - } - - private void subscribeToProxyConnectionFactory() { - toSource(connectionFactory.processConnect(connection)).subscribe(subscriber); - } - - @Test - void newConnectRequestThrows() { - when(connection.connect(any())).thenThrow(DELIBERATE_EXCEPTION); - subscribeToProxyConnectionFactory(); - - assertThat(subscriber.awaitOnError(), is(DELIBERATE_EXCEPTION)); - verify(connection).connect(any()); - verify(connection, never()).request(any()); - assertConnectionClosed(); - } - - @Test - void connectRequestFails() { - when(connection.request(any())).thenReturn(failed(DELIBERATE_EXCEPTION)); - - configureConnectRequest(); - subscribeToProxyConnectionFactory(); - - Throwable error = subscriber.awaitOnError(); - assertThat("Unexpected error: " + error, error, is(DELIBERATE_EXCEPTION)); - assertConnectPayloadConsumed(false); - assertConnectionClosed(); - } - - @Test - void nonSuccessfulResponseCode() { - StreamingHttpResponse response = mock(StreamingHttpResponse.class); - when(response.status()).thenReturn(INTERNAL_SERVER_ERROR); - when(response.messageBody()).thenReturn(messageBody); - when(connection.request(any())).thenReturn(succeeded(response)); - - configureConnectRequest(); - subscribeToProxyConnectionFactory(); - - Throwable error = subscriber.awaitOnError(); - assertThat(error, instanceOf(ProxyResponseException.class)); - assertThat(((ProxyResponseException) error).status(), is(INTERNAL_SERVER_ERROR)); - assertConnectPayloadConsumed(true); - assertConnectionClosed(); - } - - @ParameterizedTest(name = "{displayName} [{index}] ttl={0}") - @ValueSource(booleans = {true, false}) - void noDeferSslHandler(boolean channelActive) { - ChannelPipeline pipeline = configurePipeline(); - // Do not configureDeferSslHandler(pipeline); - Channel channel = pipeline.channel(); - when(channel.isActive()).thenReturn(channelActive); - configureRequestSend(); - configureConnectRequest(); - subscribeToProxyConnectionFactory(); - - Throwable error = subscriber.awaitOnError(); - assertThat(error, is(notNullValue())); - if (channelActive) { - assertThat(error, instanceOf(IllegalStateException.class)); - assertThat(error.getMessage(), containsString(DeferSslHandler.class.getSimpleName())); - } else { - assertThat(error, instanceOf(ClosedChannelException.class)); - } - assertConnectPayloadConsumed(true); - assertConnectionClosed(); - } - - @Test - void getDeferSslHandlerThrows() { - ChannelPipeline pipeline = configurePipeline(); - when(pipeline.get(DeferSslHandler.class)).thenThrow(DELIBERATE_EXCEPTION); - - configureRequestSend(); - configureConnectRequest(); - subscribeToProxyConnectionFactory(); - - assertThat(subscriber.awaitOnError(), is(DELIBERATE_EXCEPTION)); - assertConnectPayloadConsumed(true); - assertConnectionClosed(); - } - - @Test - void deferSslHandlerReadyThrows() { - ChannelPipeline pipeline = configurePipeline(); - DeferSslHandler deferSslHandler = mock(DeferSslHandler.class); - when(pipeline.get(DeferSslHandler.class)).thenReturn(deferSslHandler); - doThrow(DELIBERATE_EXCEPTION).when(deferSslHandler).ready(); - - configureRequestSend(); - configureConnectRequest(); - subscribeToProxyConnectionFactory(); - - assertThat(subscriber.awaitOnError(), is(DELIBERATE_EXCEPTION)); - assertConnectPayloadConsumed(true); - assertConnectionClosed(); - } - - @Test - void sslHandshakeFailure() { - AtomicReference handlerCaptor = new AtomicReference<>(); - ChannelPipeline pipeline = configurePipeline(handlerCaptor); - configureDeferSslHandler(pipeline, handlerCaptor, new SslHandshakeCompletionEvent(DELIBERATE_EXCEPTION)); - - configureRequestSend(); - configureConnectRequest(); - subscribeToProxyConnectionFactory(); - - assertThat(subscriber.awaitOnError(), is(DELIBERATE_EXCEPTION)); - assertConnectPayloadConsumed(true); - assertConnectionClosed(); - } - - @Test - @Disabled("https://github.com/apple/servicetalk/issues/1010") - void cancelledBeforeSslHandshakeCompletionEvent() { - ChannelPipeline pipeline = configurePipeline(); - DeferSslHandler handler = mock(DeferSslHandler.class); - when(pipeline.get(DeferSslHandler.class)).thenReturn(handler); - doNothing().when(handler).ready(); // Do not generate any SslHandshakeCompletionEvent - - configureRequestSend(); - configureConnectRequest(); - subscribeToProxyConnectionFactory(); - - Cancellable cancellable = subscriber.awaitSubscription(); - assertThat(subscriber.pollTerminal(10, MILLISECONDS), is(nullValue())); - assertThat(connectionClose.isSubscribed(), is(false)); - cancellable.cancel(); - assertConnectPayloadConsumed(true); - assertConnectionClosed(); - } - - @Test - void successfulConnect() { - AtomicReference handlerCaptor = new AtomicReference<>(); - ChannelPipeline pipeline = configurePipeline(handlerCaptor); - configureDeferSslHandler(pipeline, handlerCaptor, SslHandshakeCompletionEvent.SUCCESS); - configureRequestSend(); - configureConnectRequest(); - subscribeToProxyConnectionFactory(); - - assertThat(subscriber.awaitOnSuccess(), is(sameInstance(this.connection))); - StreamingHttpRequest request = assertConnectPayloadConsumed(true); - assertExecutionStrategy(request, offloadNone()); - assertConnectionClosed(false); - } - - private StreamingHttpRequest assertConnectPayloadConsumed(boolean expected) { - ArgumentCaptor requestCaptor = forClass(StreamingHttpRequest.class); - verify(connection).connect(any()); - verify(connection).request(requestCaptor.capture()); - assertThat("CONNECT response payload body was " + (expected ? "was" : "unnecessarily") + " consumed", - messageBody.isSubscribed(), is(expected)); - return requestCaptor.getValue(); - } - - private static void assertExecutionStrategy(StreamingHttpRequest request, HttpExecutionStrategy expectedStrategy) { - assertThat(request.context().get(HTTP_EXECUTION_STRATEGY_KEY), is(expectedStrategy)); - } - - private void assertConnectionClosed() { - assertConnectionClosed(true); - } - - private void assertConnectionClosed(boolean closed) { - assertThat("Closure of the connection was not triggered", connectionClose.isSubscribed(), is(closed)); - } -} diff --git a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/SecurityHandshakeObserverTest.java b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/SecurityHandshakeObserverTest.java index 5376999809..272134cc75 100644 --- a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/SecurityHandshakeObserverTest.java +++ b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/SecurityHandshakeObserverTest.java @@ -17,13 +17,11 @@ import io.servicetalk.client.api.TransportObserverConnectionFactoryFilter; import io.servicetalk.http.api.BlockingHttpClient; -import io.servicetalk.http.api.SingleAddressHttpClientBuilder; import io.servicetalk.test.resources.DefaultTestCerts; import io.servicetalk.transport.api.ClientSslConfigBuilder; import io.servicetalk.transport.api.ConnectionInfo; import io.servicetalk.transport.api.ConnectionObserver; import io.servicetalk.transport.api.ConnectionObserver.SecurityHandshakeObserver; -import io.servicetalk.transport.api.HostAndPort; import io.servicetalk.transport.api.ServerContext; import io.servicetalk.transport.api.ServerSslConfigBuilder; import io.servicetalk.transport.api.TransportObserver; @@ -31,27 +29,22 @@ import io.servicetalk.transport.netty.internal.NoopTransportObserver.NoopDataObserver; import io.servicetalk.transport.netty.internal.NoopTransportObserver.NoopMultiplexedObserver; -import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; import org.mockito.ArgumentCaptor; import org.mockito.InOrder; -import java.net.InetSocketAddress; import java.nio.channels.ClosedChannelException; import java.util.List; import java.util.concurrent.CountDownLatch; -import java.util.function.UnaryOperator; import javax.net.ssl.SSLHandshakeException; import javax.net.ssl.SSLSession; import static io.servicetalk.http.api.HttpResponseStatus.OK; -import static io.servicetalk.http.netty.HttpProtocol.HTTP_1; import static io.servicetalk.http.netty.HttpProtocol.toConfigs; import static io.servicetalk.http.netty.TestServiceStreaming.SVC_ECHO; import static io.servicetalk.test.resources.DefaultTestCerts.serverPemHostname; -import static java.util.Collections.singletonList; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.hasItemInArray; @@ -151,23 +144,7 @@ void verifyHandshakeFailed(List protocols) throws Exception { verifyHandshakeObserved(protocols, true); } - @Test - void withProxyTunnel() throws Exception { - try (ProxyTunnel proxyTunnel = new ProxyTunnel()) { - HostAndPort proxyAddress = proxyTunnel.startProxy(); - verifyHandshakeObserved(singletonList(HTTP_1), false, true, - builder -> builder.proxyAddress(proxyAddress)); - } - } - private void verifyHandshakeObserved(List protocols, boolean failHandshake) throws Exception { - verifyHandshakeObserved(protocols, failHandshake, false, UnaryOperator.identity()); - } - - private void verifyHandshakeObserved(List protocols, boolean failHandshake, boolean hasProxy, - UnaryOperator> clientBuilderFunction) - throws Exception { - try (ServerContext serverContext = BuilderUtils.newServerBuilder(SERVER_CTX) .protocols(toConfigs(protocols)) .sslConfig(new ServerSslConfigBuilder( @@ -175,8 +152,7 @@ private void verifyHandshakeObserved(List protocols, boolean failH .transportObserver(serverTransportObserver) .listenStreamingAndAwait(new TestServiceStreaming()); - BlockingHttpClient client = clientBuilderFunction.apply( - BuilderUtils.newClientBuilder(serverContext, CLIENT_CTX)) + BlockingHttpClient client = BuilderUtils.newClientBuilder(serverContext, CLIENT_CTX) .protocols(toConfigs(protocols)) .sslConfig(new ClientSslConfigBuilder(DefaultTestCerts::loadServerCAPem) .peerHost(failHandshake ? "unknown" : serverPemHostname()).build()) @@ -196,19 +172,16 @@ private void verifyHandshakeObserved(List protocols, boolean failH bothClosed.await(); HttpProtocol expectedProtocol = protocols.get(0); verifyObservers(clientOrder, clientTransportObserver, clientConnectionObserver, - clientSecurityHandshakeObserver, expectedProtocol, failHandshake, hasProxy); + clientSecurityHandshakeObserver, expectedProtocol, failHandshake); verifyObservers(serverOrder, serverTransportObserver, serverConnectionObserver, - serverSecurityHandshakeObserver, expectedProtocol, failHandshake, false); + serverSecurityHandshakeObserver, expectedProtocol, failHandshake); } private static void verifyObservers(InOrder order, TransportObserver transportObserver, ConnectionObserver connectionObserver, SecurityHandshakeObserver securityHandshakeObserver, - HttpProtocol expectedProtocol, boolean failHandshake, boolean hasProxy) { + HttpProtocol expectedProtocol, boolean failHandshake) { order.verify(transportObserver).onNewConnection(any(), any()); order.verify(connectionObserver).onTransportHandshakeComplete(); - if (hasProxy) { - order.verify(connectionObserver).connectionEstablished(any()); - } order.verify(connectionObserver).onSecurityHandshake(); if (failHandshake) { ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Throwable.class); @@ -223,12 +196,10 @@ private static void verifyObservers(InOrder order, TransportObserver transportOb order.verify(connectionObserver).connectionClosed(exception); } else { order.verify(securityHandshakeObserver).handshakeComplete(any(SSLSession.class)); - if (!hasProxy) { - if (expectedProtocol.version.major() > 1) { - order.verify(connectionObserver).multiplexedConnectionEstablished(any()); - } else { - order.verify(connectionObserver).connectionEstablished(any()); - } + if (expectedProtocol.version.major() > 1) { + order.verify(connectionObserver).multiplexedConnectionEstablished(any()); + } else { + order.verify(connectionObserver).connectionEstablished(any()); } } verifyNoMoreInteractions(transportObserver, securityHandshakeObserver); diff --git a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/SslCertificateCompressionTest.java b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/SslCertificateCompressionTest.java index 3fab7bbc8b..df9ac60bb9 100644 --- a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/SslCertificateCompressionTest.java +++ b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/SslCertificateCompressionTest.java @@ -177,6 +177,11 @@ public void onFlush() { public void onTransportHandshakeComplete() { } + @Override + public ProxyConnectObserver onProxyConnect(final Object connectMsg) { + return NoopTransportObserver.NoopProxyConnectObserver.INSTANCE; + } + @Override public DataObserver connectionEstablished(final ConnectionInfo info) { return NoopTransportObserver.NoopDataObserver.INSTANCE; @@ -187,6 +192,10 @@ public MultiplexedObserver multiplexedConnectionEstablished(final ConnectionInfo return NoopTransportObserver.NoopMultiplexedObserver.INSTANCE; } + @Override + public void connectionWritabilityChanged(final boolean isWritable) { + } + @Override public void connectionClosed(final Throwable error) { } diff --git a/servicetalk-http-netty/src/testFixtures/java/io/servicetalk/http/netty/ProxyTunnel.java b/servicetalk-http-netty/src/testFixtures/java/io/servicetalk/http/netty/ProxyTunnel.java index 32b75fbf86..d1ecd21c91 100644 --- a/servicetalk-http-netty/src/testFixtures/java/io/servicetalk/http/netty/ProxyTunnel.java +++ b/servicetalk-http-netty/src/testFixtures/java/io/servicetalk/http/netty/ProxyTunnel.java @@ -26,7 +26,6 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; -import java.net.InetSocketAddress; import java.net.ServerSocket; import java.net.Socket; import java.util.Locale; @@ -43,6 +42,7 @@ import static io.servicetalk.http.api.HttpResponseStatus.BAD_REQUEST; import static io.servicetalk.http.api.HttpResponseStatus.INTERNAL_SERVER_ERROR; import static io.servicetalk.http.api.HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED; +import static io.servicetalk.transport.netty.internal.AddressUtils.serverHostAndPort; import static java.net.InetAddress.getLoopbackAddress; import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.concurrent.Executors.newCachedThreadPool; @@ -85,7 +85,6 @@ public void close() throws Exception { */ public HostAndPort startProxy() throws IOException { serverSocket = new ServerSocket(0, 50, getLoopbackAddress()); - final InetSocketAddress serverSocketAddress = (InetSocketAddress) serverSocket.getLocalSocketAddress(); executor.submit(() -> { while (!executor.isShutdown()) { final Socket socket = serverSocket.accept(); @@ -112,7 +111,7 @@ public HostAndPort startProxy() throws IOException { } final String authToken = this.authToken; if (authToken != null && !("basic " + authToken).equals(headers.proxyAuthorization)) { - proxyAuthRequired(socket); + proxyAuthRequired(socket, protocol); return; } handler.handle(socket, host, port, protocol); @@ -130,7 +129,7 @@ public HostAndPort startProxy() throws IOException { return null; }); - return HostAndPort.of(serverSocketAddress.getAddress().getHostAddress(), serverSocketAddress.getPort()); + return serverHostAndPort(serverSocket.getLocalSocketAddress()); } private static void badRequest(final Socket socket, final String cause) throws IOException { @@ -141,9 +140,9 @@ private static void badRequest(final Socket socket, final String cause) throws I os.flush(); } - private static void proxyAuthRequired(final Socket socket) throws IOException { + private static void proxyAuthRequired(final Socket socket, final String protocol) throws IOException { final OutputStream os = socket.getOutputStream(); - os.write((HTTP_1_1 + " " + PROXY_AUTHENTICATION_REQUIRED + "\r\n" + + os.write((protocol + ' ' + PROXY_AUTHENTICATION_REQUIRED + "\r\n" + PROXY_AUTHENTICATE + ": Basic realm=\"simple\"" + "\r\n" + "\r\n").getBytes(UTF_8)); os.flush(); @@ -160,6 +159,15 @@ public void badResponseProxy() { }; } + /** + * Override the default handler to the passed {@link ProxyRequestHandler}. + * + * @param handler {@link ProxyRequestHandler} to use + */ + public void proxyRequestHandler(final ProxyRequestHandler handler) { + this.handler = handler; + } + /** * Sets a required {@link HttpHeaderNames#PROXY_AUTHORIZATION} header value for "Basic" scheme to validate before * accepting a {@code CONNECT} request. @@ -257,8 +265,21 @@ private static void copyStream(final OutputStream out, final InputStream cin) th // Don't close either Stream! We close the socket outside the scope of this method (in a specific sequence). } + /** + * A handler that processes a parsed CONNECT request. + */ @FunctionalInterface - private interface ProxyRequestHandler { + public interface ProxyRequestHandler { + + /** + * Handle the parsed CONNECT request. + * + * @param socket {@link Socket} from a client to a proxy + * @param host Host to connect to + * @param port Port to connect to + * @param protocol String representation of a protocol used for incoming CONNECT request + * @throws IOException if any exception happens while working with I/O + */ void handle(Socket socket, String host, int port, String protocol) throws IOException; } diff --git a/servicetalk-transport-api/src/main/java/io/servicetalk/transport/api/BiTransportObserver.java b/servicetalk-transport-api/src/main/java/io/servicetalk/transport/api/BiTransportObserver.java index 939d24ed98..bd2ad62a9a 100644 --- a/servicetalk-transport-api/src/main/java/io/servicetalk/transport/api/BiTransportObserver.java +++ b/servicetalk-transport-api/src/main/java/io/servicetalk/transport/api/BiTransportObserver.java @@ -17,6 +17,7 @@ import io.servicetalk.transport.api.ConnectionObserver.DataObserver; import io.servicetalk.transport.api.ConnectionObserver.MultiplexedObserver; +import io.servicetalk.transport.api.ConnectionObserver.ProxyConnectObserver; import io.servicetalk.transport.api.ConnectionObserver.ReadObserver; import io.servicetalk.transport.api.ConnectionObserver.SecurityHandshakeObserver; import io.servicetalk.transport.api.ConnectionObserver.StreamObserver; @@ -77,6 +78,11 @@ public void onTransportHandshakeComplete() { second.onTransportHandshakeComplete(); } + @Override + public ProxyConnectObserver onProxyConnect(final Object connectMsg) { + return new BiProxyConnectObserver(first.onProxyConnect(connectMsg), second.onProxyConnect(connectMsg)); + } + @Override public SecurityHandshakeObserver onSecurityHandshake() { return new BiSecurityHandshakeObserver(first.onSecurityHandshake(), second.onSecurityHandshake()); @@ -112,6 +118,29 @@ public void connectionClosed() { } } + private static final class BiProxyConnectObserver implements ProxyConnectObserver { + + private final ProxyConnectObserver first; + private final ProxyConnectObserver second; + + private BiProxyConnectObserver(final ProxyConnectObserver first, final ProxyConnectObserver second) { + this.first = first; + this.second = second; + } + + @Override + public void proxyConnectFailed(final Throwable cause) { + first.proxyConnectFailed(cause); + second.proxyConnectFailed(cause); + } + + @Override + public void proxyConnectComplete(final Object responseMsg) { + first.proxyConnectComplete(responseMsg); + second.proxyConnectComplete(responseMsg); + } + } + private static final class BiSecurityHandshakeObserver implements SecurityHandshakeObserver { private final SecurityHandshakeObserver first; diff --git a/servicetalk-transport-api/src/main/java/io/servicetalk/transport/api/CatchAllTransportObserver.java b/servicetalk-transport-api/src/main/java/io/servicetalk/transport/api/CatchAllTransportObserver.java index 267a4eb9f2..aa235348fd 100644 --- a/servicetalk-transport-api/src/main/java/io/servicetalk/transport/api/CatchAllTransportObserver.java +++ b/servicetalk-transport-api/src/main/java/io/servicetalk/transport/api/CatchAllTransportObserver.java @@ -17,6 +17,7 @@ import io.servicetalk.transport.api.ConnectionObserver.DataObserver; import io.servicetalk.transport.api.ConnectionObserver.MultiplexedObserver; +import io.servicetalk.transport.api.ConnectionObserver.ProxyConnectObserver; import io.servicetalk.transport.api.ConnectionObserver.ReadObserver; import io.servicetalk.transport.api.ConnectionObserver.SecurityHandshakeObserver; import io.servicetalk.transport.api.ConnectionObserver.StreamObserver; @@ -24,6 +25,7 @@ import io.servicetalk.transport.api.NoopTransportObserver.NoopConnectionObserver; import io.servicetalk.transport.api.NoopTransportObserver.NoopDataObserver; import io.servicetalk.transport.api.NoopTransportObserver.NoopMultiplexedObserver; +import io.servicetalk.transport.api.NoopTransportObserver.NoopProxyConnectObserver; import io.servicetalk.transport.api.NoopTransportObserver.NoopReadObserver; import io.servicetalk.transport.api.NoopTransportObserver.NoopSecurityHandshakeObserver; import io.servicetalk.transport.api.NoopTransportObserver.NoopStreamObserver; @@ -87,6 +89,12 @@ public void onTransportHandshakeComplete() { safeReport(observer::onTransportHandshakeComplete, observer, "flush"); } + @Override + public ProxyConnectObserver onProxyConnect(final Object connectMsg) { + return safeReport(() -> observer.onProxyConnect(connectMsg), observer, "proxy connect", + CatchAllProxyConnectObserver::new, NoopProxyConnectObserver.INSTANCE); + } + @Override public SecurityHandshakeObserver onSecurityHandshake() { return safeReport(observer::onSecurityHandshake, observer, "security handshake", @@ -123,6 +131,25 @@ public void connectionClosed() { } } + private static final class CatchAllProxyConnectObserver implements ProxyConnectObserver { + + private final ProxyConnectObserver observer; + + private CatchAllProxyConnectObserver(final ProxyConnectObserver observer) { + this.observer = observer; + } + + @Override + public void proxyConnectFailed(final Throwable cause) { + safeReport(() -> observer.proxyConnectFailed(cause), observer, "proxy connect failed", cause); + } + + @Override + public void proxyConnectComplete(final Object responseMsg) { + safeReport(() -> observer.proxyConnectComplete(responseMsg), observer, "proxy connect complete"); + } + } + private static final class CatchAllSecurityHandshakeObserver implements SecurityHandshakeObserver { private final SecurityHandshakeObserver observer; diff --git a/servicetalk-transport-api/src/main/java/io/servicetalk/transport/api/ConnectionObserver.java b/servicetalk-transport-api/src/main/java/io/servicetalk/transport/api/ConnectionObserver.java index 6348139344..56d7bd2827 100644 --- a/servicetalk-transport-api/src/main/java/io/servicetalk/transport/api/ConnectionObserver.java +++ b/servicetalk-transport-api/src/main/java/io/servicetalk/transport/api/ConnectionObserver.java @@ -15,6 +15,8 @@ */ package io.servicetalk.transport.api; +import io.servicetalk.transport.api.NoopTransportObserver.NoopProxyConnectObserver; + import javax.annotation.Nullable; import javax.net.ssl.SSLSession; @@ -53,6 +55,18 @@ public interface ConnectionObserver { */ void onTransportHandshakeComplete(); + /** + * Callback when a proxy connect is initiated. + *

+ * For a typical connection, this callback is invoked after {@link #onTransportHandshakeComplete()}. + * + * @param connectMsg a message sent to a proxy in request to establish a connection to the target server + * @return a new {@link ProxyConnectObserver} that provides visibility into proxy connect events. + */ + default ProxyConnectObserver onProxyConnect(Object connectMsg) { // FIXME: 0.43 - consider removing default impl + return NoopProxyConnectObserver.INSTANCE; + } + /** * Callback when a security handshake is initiated. *

@@ -64,8 +78,8 @@ public interface ConnectionObserver { * Open is available and configured, it may not actually happen if the * Fast Open Cookie is {@code null} or * rejected by the server. - *

  • For a proxy connections, the handshake may happen after the - * {@link #connectionEstablished(ConnectionInfo)}.
  • + *
  • For a proxy connections, the handshake may happen after an observer returned by + * {@link #onProxyConnect(Object)} completes successfully.
  • * * * @return a new {@link SecurityHandshakeObserver} that provides visibility into security handshake events @@ -111,6 +125,31 @@ default void connectionWritabilityChanged(boolean isWritable) { // FIXME: 0.43 - */ void connectionClosed(); + /** + * An observer interface that provides visibility into proxy connect events for establishing a tunnel. + *

    + * Either {@link #proxyConnectComplete(Object)} or {@link #proxyConnectFailed(Throwable)} will be invoked to signal + * successful or failed connection via proxy tunnel. + */ + interface ProxyConnectObserver { + + /** + * Callback when the proxy connect attempt is failed. + * + * @param cause the cause of proxy connect failure + */ + void proxyConnectFailed(Throwable cause); + + /** + * Callback when the proxy connect attempt is complete successfully. + * + * @param responseMsg an object that represents a response message. The actual message type depends upon proxy + * protocol implementation (e.g. HTTP Tunnel, + * SOCKS, etc.) + */ + void proxyConnectComplete(Object responseMsg); + } + /** * An observer interface that provides visibility into security handshake events. *

    diff --git a/servicetalk-transport-api/src/main/java/io/servicetalk/transport/api/NoopTransportObserver.java b/servicetalk-transport-api/src/main/java/io/servicetalk/transport/api/NoopTransportObserver.java index 3a244c2369..fce93727e0 100644 --- a/servicetalk-transport-api/src/main/java/io/servicetalk/transport/api/NoopTransportObserver.java +++ b/servicetalk-transport-api/src/main/java/io/servicetalk/transport/api/NoopTransportObserver.java @@ -17,6 +17,7 @@ import io.servicetalk.transport.api.ConnectionObserver.DataObserver; import io.servicetalk.transport.api.ConnectionObserver.MultiplexedObserver; +import io.servicetalk.transport.api.ConnectionObserver.ProxyConnectObserver; import io.servicetalk.transport.api.ConnectionObserver.ReadObserver; import io.servicetalk.transport.api.ConnectionObserver.SecurityHandshakeObserver; import io.servicetalk.transport.api.ConnectionObserver.StreamObserver; @@ -62,6 +63,11 @@ public void onFlush() { public void onTransportHandshakeComplete() { } + @Override + public ProxyConnectObserver onProxyConnect(final Object connectMsg) { + return NoopProxyConnectObserver.INSTANCE; + } + @Override public SecurityHandshakeObserver onSecurityHandshake() { return NoopSecurityHandshakeObserver.INSTANCE; @@ -90,6 +96,23 @@ public void connectionClosed() { } } + static final class NoopProxyConnectObserver implements ProxyConnectObserver { + + static final ProxyConnectObserver INSTANCE = new NoopProxyConnectObserver(); + + private NoopProxyConnectObserver() { + // Singleton + } + + @Override + public void proxyConnectFailed(final Throwable cause) { + } + + @Override + public void proxyConnectComplete(final Object responseMsg) { + } + } + static final class NoopSecurityHandshakeObserver implements SecurityHandshakeObserver { static final SecurityHandshakeObserver INSTANCE = new NoopSecurityHandshakeObserver(); diff --git a/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/CloseHandler.java b/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/CloseHandler.java index c1cd824226..e283130d99 100644 --- a/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/CloseHandler.java +++ b/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/CloseHandler.java @@ -376,8 +376,11 @@ private OutboundDataEndEvent() { /** * {@link NettyUserEvent} to indicate the end of inbound data was observed at the transport. */ - static final class InboundDataEndEvent extends NettyUserEvent { - static final InboundDataEndEvent INSTANCE = new InboundDataEndEvent(); + public static final class InboundDataEndEvent extends NettyUserEvent { + /** + * {@link NettyUserEvent} instance to indicate an inbound end of data. + */ + public static final InboundDataEndEvent INSTANCE = new InboundDataEndEvent(); /** * {@link NettyUserEvent} to indicate the end of inbound data was observed at the transport. diff --git a/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/DefaultNettyConnection.java b/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/DefaultNettyConnection.java index bc6b0418f6..2099d5ea52 100644 --- a/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/DefaultNettyConnection.java +++ b/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/DefaultNettyConnection.java @@ -52,7 +52,6 @@ import io.netty.channel.Channel; import io.netty.channel.ChannelConfig; import io.netty.channel.ChannelDuplexHandler; -import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPromise; @@ -525,18 +524,6 @@ protected void handleSubscribe( }; } - /** - * Return {@link Class} of the {@link ChannelHandler} in case there is a need to remove the handler from the - * {@link ChannelPipeline}. - * - * @return {@link Class} of the {@link ChannelHandler} in case there is a need to remove the handler from the - * {@link ChannelPipeline}. - */ - // FIXME: remove this method in a follow-up RP, after refactoring for ProxyConnectObserver is complete. - public static Class handlerClass() { - return NettyToStChannelHandler.class; - } - private static boolean shouldWaitForSslHandshake(@Nullable final SSLSession sslSession, @Nullable final SslConfig sslConfig, final ChannelPipeline pipeline) { diff --git a/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/NoopTransportObserver.java b/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/NoopTransportObserver.java index 6ec5e8ed60..cb6b4828e0 100644 --- a/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/NoopTransportObserver.java +++ b/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/NoopTransportObserver.java @@ -19,6 +19,7 @@ import io.servicetalk.transport.api.ConnectionObserver; import io.servicetalk.transport.api.ConnectionObserver.DataObserver; import io.servicetalk.transport.api.ConnectionObserver.MultiplexedObserver; +import io.servicetalk.transport.api.ConnectionObserver.ProxyConnectObserver; import io.servicetalk.transport.api.ConnectionObserver.ReadObserver; import io.servicetalk.transport.api.ConnectionObserver.SecurityHandshakeObserver; import io.servicetalk.transport.api.ConnectionObserver.StreamObserver; @@ -71,6 +72,11 @@ public void onFlush() { public void onTransportHandshakeComplete() { } + @Override + public ProxyConnectObserver onProxyConnect(final Object connectMsg) { + return NoopProxyConnectObserver.INSTANCE; + } + @Override public SecurityHandshakeObserver onSecurityHandshake() { return NoopSecurityHandshakeObserver.INSTANCE; @@ -99,6 +105,26 @@ public void connectionClosed() { } } + /** + * Noop version of {@link ProxyConnectObserver}. + */ + public static final class NoopProxyConnectObserver implements ProxyConnectObserver { + + public static final ProxyConnectObserver INSTANCE = new NoopProxyConnectObserver(); + + private NoopProxyConnectObserver() { + // Singleton + } + + @Override + public void proxyConnectFailed(final Throwable cause) { + } + + @Override + public void proxyConnectComplete(final Object responseMsg) { + } + } + /** * Noop version of {@link SecurityHandshakeObserver}. */ diff --git a/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/StacklessClosedChannelException.java b/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/StacklessClosedChannelException.java index 99b3c73c61..3242897b3d 100644 --- a/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/StacklessClosedChannelException.java +++ b/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/StacklessClosedChannelException.java @@ -18,7 +18,6 @@ import io.servicetalk.concurrent.internal.ThrowableUtils; import java.nio.channels.ClosedChannelException; -import javax.annotation.Nullable; /** * {@link ClosedChannelException} that will not not fill in the stacktrace but use a cheaper way of producing @@ -27,22 +26,7 @@ public final class StacklessClosedChannelException extends ClosedChannelException { private static final long serialVersionUID = -5021225720136487769L; - @Nullable - private final String message; - - private StacklessClosedChannelException() { - this(null); - } - - private StacklessClosedChannelException(@Nullable final String message) { - this.message = message; - } - - @Nullable - @Override - public String getMessage() { - return message; - } + private StacklessClosedChannelException() { } @Override public Throwable fillInStackTrace() { @@ -57,21 +41,7 @@ public Throwable fillInStackTrace() { * @param method The method from which it will be thrown. * @return a new instance. */ - public static StacklessClosedChannelException newInstance(final Class clazz, final String method) { + public static StacklessClosedChannelException newInstance(Class clazz, String method) { return ThrowableUtils.unknownStackTrace(new StacklessClosedChannelException(), clazz, method); } - - /** - * Creates a new {@link StacklessClosedChannelException} instance. - * - * @param message The description message for more information. - * @param clazz The class in which this {@link StacklessClosedChannelException} will be used. - * @param method The method from which it will be thrown. - * @return a new instance. - */ - public static StacklessClosedChannelException newInstance(final String message, - final Class clazz, - final String method) { - return ThrowableUtils.unknownStackTrace(new StacklessClosedChannelException(message), clazz, method); - } }