Skip to content

Commit

Permalink
add tests for proxy
Browse files Browse the repository at this point in the history
  • Loading branch information
ikhoon committed Sep 20, 2024
1 parent 4a808fc commit 979bcaa
Show file tree
Hide file tree
Showing 6 changed files with 190 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,6 @@ SslContext getOrCreateSslContext(SocketAddress remoteAddress, SessionProtocol de
}

private SslContext newSslContext(SocketAddress remoteAddress, SessionProtocol desiredProtocol) {
assert desiredProtocol.isTls();

final String hostname;
if (remoteAddress instanceof InetSocketAddress) {
hostname = ((InetSocketAddress) remoteAddress).getHostString();
Expand Down Expand Up @@ -203,8 +201,11 @@ private ChannelInitializer<Channel> clientChannelInitializer(SessionProtocol p,
return new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
if (closeSslContext) {
ch.closeFuture().addListener(unused -> releaseSslContext(sslCtx));
}
ch.pipeline().addLast(new HttpClientPipelineConfigurator(
clientFactory, webSocket, p, sslCtx, closeSslContext ? sslContextFactory : null));
clientFactory, webSocket, p, sslCtx));
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.annotations.VisibleForTesting;

import com.linecorp.armeria.common.Flags;
import com.linecorp.armeria.common.HttpObject;
import com.linecorp.armeria.common.HttpRequest;
Expand All @@ -63,7 +61,6 @@
import com.linecorp.armeria.internal.common.ArmeriaHttpUtil;
import com.linecorp.armeria.internal.common.CancellationScheduler;
import com.linecorp.armeria.internal.common.ReadSuppressingHandler;
import com.linecorp.armeria.internal.common.SslContextFactory;
import com.linecorp.armeria.internal.common.TrafficLoggingHandler;
import com.linecorp.armeria.internal.common.util.ChannelUtil;

Expand Down Expand Up @@ -153,8 +150,6 @@ private enum HttpPreference {
private final boolean webSocket;
@Nullable
private final SslContext sslCtx;
@Nullable
private final SslContextFactory sslContextFactory;
private final HttpPreference httpPreference;
@Nullable
private SocketAddress remoteAddress;
Expand All @@ -164,7 +159,7 @@ private enum HttpPreference {

HttpClientPipelineConfigurator(HttpClientFactory clientFactory,
boolean webSocket, SessionProtocol sessionProtocol,
SslContext sslCtx, @Nullable SslContextFactory sslContextFactory) {
SslContext sslCtx) {
this.clientFactory = clientFactory;
this.webSocket = webSocket;

Expand All @@ -181,13 +176,10 @@ private enum HttpPreference {

if (sessionProtocol.isTls()) {
this.sslCtx = sslCtx;
this.sslContextFactory = sslContextFactory;
http1 = H1;
http2 = H2;
} else {
this.sslCtx = null;
assert sslContextFactory == null;
this.sslContextFactory = null;
http1 = H1C;
http2 = H2C;
}
Expand Down Expand Up @@ -227,15 +219,6 @@ public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, Sock
ctx.connect(remoteAddress, localAddress, connectionPromise);
}

@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
super.channelInactive(ctx);
if (sslContextFactory != null) {
assert sslCtx != null;
sslContextFactory.release(sslCtx);
}
}

/**
* See <a href="https://http2.github.io/http2-spec/#discover-https">HTTP/2 specification</a>.
*/
Expand Down Expand Up @@ -819,12 +802,6 @@ private static HttpClientCodec newHttp1Codec(
return new HttpClientCodec(defaultMaxInitialLineLength, defaultMaxHeaderSize, defaultMaxChunkSize);
}

@VisibleForTesting
@Nullable
SslContextFactory sslContextFactory() {
return sslContextFactory;
}

/**
* Suppresses unnecessary read calls and deactivates the {@link HttpSession} associated with a channel when
* it is closed to ensure it isn't used anymore.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ public TlsKeyPair find(String hostname) {
return tlsKeyPair;
}

@Override
public List<X509Certificate> trustedCertificates() {
return trustedCertificates;
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import com.linecorp.armeria.internal.common.SslContextFactory;
import com.linecorp.armeria.internal.testing.MockAddressResolverGroup;
import com.linecorp.armeria.server.ServerBuilder;
import com.linecorp.armeria.server.ServerTlsConfig;
import com.linecorp.armeria.testing.junit5.server.SelfSignedCertificateExtension;
import com.linecorp.armeria.testing.junit5.server.ServerExtension;

Expand Down Expand Up @@ -71,17 +72,23 @@ class TlsProviderCacheTest {
static final ServerExtension server = new ServerExtension() {
@Override
protected void configure(ServerBuilder sb) {
sb.tls(serverFooCert.tlsKeyPair());
sb.tlsCustomizer(b -> {
b.clientAuth(ClientAuth.REQUIRE)
.trustManager(clientFooCert.certificate());
});
final TlsProvider tlsProvider =
TlsProvider.builder()
.setDefault(serverFooCert.tlsKeyPair())
.set("bar.com", serverBarCert.tlsKeyPair())
.trustedCertificates(clientFooCert.certificate(),
clientBarCert.certificate())
.build();
final ServerTlsConfig tlsConfig = ServerTlsConfig.builder()
.clientAuth(ClientAuth.REQUIRE)
.build();
sb.tlsProvider(tlsProvider, tlsConfig);

sb.virtualHost("bar.com")
.tls(serverBarCert.tlsKeyPair())
.tlsCustomizer(b -> {
b.clientAuth(ClientAuth.REQUIRE)
.trustManager(clientBarCert.certificate());
.service("/", (ctx, req) -> {
final CompletableFuture<HttpResponse> future =
startFuture.thenApply(unused -> HttpResponse.of("Hello, Bar!"));
return HttpResponse.of(future);
});

sb.service("/", (ctx, req) -> {
Expand All @@ -97,24 +104,6 @@ void setUp() {
startFuture = new CompletableFuture<>();
}

@Test
void test() {
final ClientFactory factory =
ClientFactory.builder()
.tlsCustomizer(b -> {
b.trustManager(serverFooCert.certificate(),
serverBarCert.certificate());
})
.build();
final BlockingWebClient client =
WebClient.builder("https://google.com")
.factory(factory)
.build()
.blocking();
final AggregatedHttpResponse res = client.get("/");
System.out.println(res);
}

@Test
void shouldCacheSslContext() {
// This test could be broken if multiple tests are running in parallel.
Expand Down Expand Up @@ -159,22 +148,33 @@ void shouldCacheSslContext() {
}
}

startFuture.complete(null);
for (AggregatedHttpResponse response : CompletableFutures.allAsList(responses).join()) {
assertThat(response.status()).isEqualTo(HttpStatus.OK);
assertThat(response.contentUtf8()).isEqualTo("Hello!");
}

await().untilAsserted(() -> {
assertThat(poolListener.opened()).isEqualTo(6);
});
assertThat(channels).hasSize(6);

final HttpClientFactory clientFactory = (HttpClientFactory) factory.unwrap();
final SslContextFactory sslContextFactory = clientFactory.sslContextFactory();
assertThat(sslContextFactory).isNotNull();
// Make sure the SslContext is reused after the connection is closed.
// Make sure the SslContext is reused
assertThat(sslContextFactory.numCachedContexts()).isEqualTo(2);

startFuture.complete(null);
final List<AggregatedHttpResponse> responses0 = CompletableFutures.allAsList(responses).join();
for (int i = 0; i < responses0.size(); i++) {
final AggregatedHttpResponse response = responses0.get(i);
assertThat(response.status()).isEqualTo(HttpStatus.OK);
if (i < 3) {
assertThat(response.contentUtf8()).isEqualTo("Hello!");
} else {
assertThat(response.contentUtf8()).isEqualTo("Hello, Bar!");
}
}

await().untilAsserted(() -> {
assertThat(poolListener.closed()).isEqualTo(6);
});
// Make sure a cached SslContext is released when all referenced channels are closed.
assertThat(sslContextFactory.numCachedContexts()).isEqualTo(0);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright 2024 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/

package com.linecorp.armeria.client;

import static org.assertj.core.api.Assertions.assertThat;

import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import com.linecorp.armeria.common.AggregatedHttpResponse;
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.TlsProvider;
import com.linecorp.armeria.server.ServerBuilder;
import com.linecorp.armeria.server.ServerTlsConfig;
import com.linecorp.armeria.testing.junit5.server.SelfSignedCertificateExtension;
import com.linecorp.armeria.testing.junit5.server.ServerExtension;

import io.netty.handler.ssl.ClientAuth;

class TlsProviderMTlsTest {
@Order(0)
@RegisterExtension
static SelfSignedCertificateExtension sscServer = new SelfSignedCertificateExtension();
@Order(0)
@RegisterExtension
static SelfSignedCertificateExtension sscClient = new SelfSignedCertificateExtension();

@Order(1)
@RegisterExtension
static ServerExtension server = new ServerExtension() {
@Override
protected void configure(ServerBuilder sb) {
final TlsProvider tlsProvider = TlsProvider.builder()
.setDefault(sscServer.tlsKeyPair())
.trustedCertificates(sscClient.certificate())
.build();
final ServerTlsConfig tlsConfig = ServerTlsConfig.builder()
.clientAuth(ClientAuth.REQUIRE)
.build();
sb.tlsProvider(tlsProvider, tlsConfig);

sb.service("/", (ctx, req) -> {
return HttpResponse.of(HttpStatus.OK);
});
}
};

@Test
void testMTls() {
final TlsProvider tlsProvider = TlsProvider
.builder()
.setDefault(sscClient.tlsKeyPair())
.trustedCertificates(sscServer.certificate())
.build();
try (ClientFactory factory = ClientFactory
.builder()
.tlsProvider(tlsProvider)
.connectTimeoutMillis(Long.MAX_VALUE)
.build()) {
final BlockingWebClient client = WebClient.builder(server.httpsUri())
.factory(factory)
.build()
.blocking();
final AggregatedHttpResponse res = client.get("/");
assertThat(res.status()).isEqualTo(HttpStatus.OK);
}
}
}
Loading

0 comments on commit 979bcaa

Please sign in to comment.