Skip to content

Commit

Permalink
Allow customization of netty channel handles before and during decomp…
Browse files Browse the repository at this point in the history
…ression (#10261)
  • Loading branch information
cwperks authored Oct 5, 2023
1 parent 66aef13 commit dad525a
Show file tree
Hide file tree
Showing 6 changed files with 229 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Add Doc Status Counter for Indexing Engine ([#4562](https://github.com/opensearch-project/OpenSearch/issues/4562))
- Add unreferenced file cleanup count to merge stats ([#10204](https://github.com/opensearch-project/OpenSearch/pull/10204))
- [Remote Store] Add support to restrict creation & deletion if system repository and mutation of immutable settings of system repository ([#9839](https://github.com/opensearch-project/OpenSearch/pull/9839))
- Improve compressed request handling ([#10261](https://github.com/opensearch-project/OpenSearch/pull/10261))

### Dependencies
- Bump `peter-evans/create-or-update-comment` from 2 to 3 ([#9575](https://github.com/opensearch-project/OpenSearch/pull/9575))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.http.netty4;

import org.opensearch.OpenSearchNetty4IntegTestCase;
import org.opensearch.core.common.transport.TransportAddress;
import org.opensearch.http.HttpServerTransport;
import org.opensearch.plugins.Plugin;
import org.opensearch.test.OpenSearchIntegTestCase.ClusterScope;
import org.opensearch.test.OpenSearchIntegTestCase.Scope;
import org.opensearch.transport.Netty4BlockingPlugin;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;

import io.netty.buffer.ByteBufUtil;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http2.HttpConversionUtil;
import io.netty.util.ReferenceCounted;

import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.equalTo;
import static io.netty.handler.codec.http.HttpHeaderNames.HOST;

@ClusterScope(scope = Scope.TEST, supportsDedicatedMasters = false, numDataNodes = 1)
public class Netty4HeaderVerifierIT extends OpenSearchNetty4IntegTestCase {

@Override
protected boolean addMockHttpTransport() {
return false; // enable http
}

@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Collections.singletonList(Netty4BlockingPlugin.class);
}

public void testThatNettyHttpServerRequestBlockedWithHeaderVerifier() throws Exception {
HttpServerTransport httpServerTransport = internalCluster().getInstance(HttpServerTransport.class);
TransportAddress[] boundAddresses = httpServerTransport.boundAddress().boundAddresses();
TransportAddress transportAddress = randomFrom(boundAddresses);

final FullHttpRequest blockedRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
blockedRequest.headers().add("blockme", "Not Allowed");
blockedRequest.headers().add(HOST, "localhost");
blockedRequest.headers().add(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), "http");

final List<FullHttpResponse> responses = new ArrayList<>();
try (Netty4HttpClient nettyHttpClient = Netty4HttpClient.http2()) {
try {
FullHttpResponse blockedResponse = nettyHttpClient.send(transportAddress.address(), blockedRequest);
responses.add(blockedResponse);
String blockedResponseContent = new String(ByteBufUtil.getBytes(blockedResponse.content()), StandardCharsets.UTF_8);
assertThat(blockedResponseContent, containsString("Hit header_verifier"));
assertThat(blockedResponse.status().code(), equalTo(401));
} finally {
responses.forEach(ReferenceCounted::release);
}
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.transport;

import org.opensearch.common.network.NetworkService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.BigArrays;
import org.opensearch.common.util.PageCacheRecycler;
import org.opensearch.core.indices.breaker.CircuitBreakerService;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.http.HttpServerTransport;
import org.opensearch.http.netty4.Netty4HttpServerTransport;
import org.opensearch.telemetry.tracing.Tracer;
import org.opensearch.threadpool.ThreadPool;

import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Map;
import java.util.function.Supplier;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.util.ReferenceCountUtil;

public class Netty4BlockingPlugin extends Netty4ModulePlugin {

public class Netty4BlockingHttpServerTransport extends Netty4HttpServerTransport {

public Netty4BlockingHttpServerTransport(
Settings settings,
NetworkService networkService,
BigArrays bigArrays,
ThreadPool threadPool,
NamedXContentRegistry xContentRegistry,
Dispatcher dispatcher,
ClusterSettings clusterSettings,
SharedGroupFactory sharedGroupFactory,
Tracer tracer
) {
super(
settings,
networkService,
bigArrays,
threadPool,
xContentRegistry,
dispatcher,
clusterSettings,
sharedGroupFactory,
tracer
);
}

@Override
protected ChannelInboundHandlerAdapter createHeaderVerifier() {
return new ExampleBlockingNetty4HeaderVerifier();
}
}

@Override
public Map<String, Supplier<HttpServerTransport>> getHttpTransports(
Settings settings,
ThreadPool threadPool,
BigArrays bigArrays,
PageCacheRecycler pageCacheRecycler,
CircuitBreakerService circuitBreakerService,
NamedXContentRegistry xContentRegistry,
NetworkService networkService,
HttpServerTransport.Dispatcher dispatcher,
ClusterSettings clusterSettings,
Tracer tracer
) {
return Collections.singletonMap(
NETTY_HTTP_TRANSPORT_NAME,
() -> new Netty4BlockingHttpServerTransport(
settings,
networkService,
bigArrays,
threadPool,
xContentRegistry,
dispatcher,
clusterSettings,
getSharedGroupFactory(settings),
tracer
)
);
}

/** POC for how an external header verifier would be implemented */
public class ExampleBlockingNetty4HeaderVerifier extends SimpleChannelInboundHandler<DefaultHttpRequest> {

@Override
public void channelRead0(ChannelHandlerContext ctx, DefaultHttpRequest msg) throws Exception {
ReferenceCountUtil.retain(msg);
if (isBlocked(msg)) {
ByteBuf buf = Unpooled.copiedBuffer("Hit header_verifier".getBytes(StandardCharsets.UTF_8));
final FullHttpResponse response = new DefaultFullHttpResponse(msg.protocolVersion(), HttpResponseStatus.UNAUTHORIZED, buf);
ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
ReferenceCountUtil.release(msg);
} else {
// Lets the request pass to the next channel handler
ctx.fireChannelRead(msg);
}
}

private boolean isBlocked(HttpRequest request) {
final boolean shouldBlock = request.headers().contains("blockme");

return shouldBlock;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ public ChannelHandler configureServerChannelHandler() {
return new HttpChannelHandler(this, handlingSettings);
}

protected static final AttributeKey<Netty4HttpChannel> HTTP_CHANNEL_KEY = AttributeKey.newInstance("opensearch-http-channel");
public static final AttributeKey<Netty4HttpChannel> HTTP_CHANNEL_KEY = AttributeKey.newInstance("opensearch-http-channel");
protected static final AttributeKey<Netty4HttpServerChannel> HTTP_SERVER_CHANNEL_KEY = AttributeKey.newInstance(
"opensearch-http-server-channel"
);
Expand Down Expand Up @@ -419,8 +419,8 @@ protected void channelRead0(ChannelHandlerContext ctx, HttpMessage msg) throws E
// If this handler is hit then no upgrade has been attempted and the client is just talking HTTP
final ChannelPipeline pipeline = ctx.pipeline();
pipeline.addAfter(ctx.name(), "handler", getRequestHandler());
pipeline.replace(this, "decoder_compress", new HttpContentDecompressor());

pipeline.replace(this, "header_verifier", transport.createHeaderVerifier());
pipeline.addAfter("header_verifier", "decoder_compress", transport.createDecompressor());
pipeline.addAfter("decoder_compress", "aggregator", aggregator);
if (handlingSettings.isCompression()) {
pipeline.addAfter(
Expand All @@ -446,7 +446,8 @@ protected void configureDefaultHttpPipeline(ChannelPipeline pipeline) {
);
decoder.setCumulator(ByteToMessageDecoder.COMPOSITE_CUMULATOR);
pipeline.addLast("decoder", decoder);
pipeline.addLast("decoder_compress", new HttpContentDecompressor());
pipeline.addLast("header_verifier", transport.createHeaderVerifier());
pipeline.addLast("decoder_compress", transport.createDecompressor());
pipeline.addLast("encoder", new HttpResponseEncoder());
final HttpObjectAggregator aggregator = new HttpObjectAggregator(handlingSettings.getMaxContentLength());
aggregator.setMaxCumulationBufferComponents(transport.maxCompositeBufferComponents);
Expand Down Expand Up @@ -487,13 +488,13 @@ protected void initChannel(Channel childChannel) throws Exception {

final HttpObjectAggregator aggregator = new HttpObjectAggregator(handlingSettings.getMaxContentLength());
aggregator.setMaxCumulationBufferComponents(transport.maxCompositeBufferComponents);

childChannel.pipeline()
.addLast(new LoggingHandler(LogLevel.DEBUG))
.addLast(new Http2StreamFrameToHttpObjectCodec(true))
.addLast("byte_buf_sizer", byteBufSizer)
.addLast("read_timeout", new ReadTimeoutHandler(transport.readTimeoutMillis, TimeUnit.MILLISECONDS))
.addLast("decoder_decompress", new HttpContentDecompressor());
.addLast("header_verifier", transport.createHeaderVerifier())
.addLast("decoder_decompress", transport.createDecompressor());

if (handlingSettings.isCompression()) {
childChannel.pipeline()
Expand Down Expand Up @@ -531,4 +532,21 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
}
}
}

/**
* Extension point that allows a NetworkPlugin to extend the netty pipeline and inspect headers after request decoding
*/
protected ChannelInboundHandlerAdapter createHeaderVerifier() {
// pass-through
return new ChannelInboundHandlerAdapter();
}

/**
* Extension point that allows a NetworkPlugin to override the default netty HttpContentDecompressor and supply a custom decompressor.
*
* Used in instances to conditionally decompress depending on the outcome from header verification
*/
protected ChannelInboundHandlerAdapter createDecompressor() {
return new HttpContentDecompressor();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ public Map<String, Supplier<HttpServerTransport>> getHttpTransports(
);
}

private SharedGroupFactory getSharedGroupFactory(Settings settings) {
SharedGroupFactory getSharedGroupFactory(Settings settings) {
SharedGroupFactory groupFactory = this.groupFactory.get();
if (groupFactory != null) {
assert groupFactory.getSettings().equals(settings) : "Different settings than originally provided";
Expand Down
2 changes: 1 addition & 1 deletion server/src/main/java/org/opensearch/rest/RestHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ default List<ReplacedRoute> replacedRoutes() {
}

/**
* Controls whether requests handled by this class are allowed to to access system indices by default.
* Controls whether requests handled by this class are allowed to access system indices by default.
* @return {@code true} if requests handled by this class should be allowed to access system indices.
*/
default boolean allowSystemIndexAccessByDefault() {
Expand Down

0 comments on commit dad525a

Please sign in to comment.