Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Abstracting outbound side of transport (#13293) #13656

Merged
merged 1 commit into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x]
### Added
- Add useCompoundFile index setting ([#13478](https://github.com/opensearch-project/OpenSearch/pull/13478))
- Make outbound side of transport protocol dependent ([#13293](https://github.com/opensearch-project/OpenSearch/pull/13293))

### Dependencies
- Bump `com.github.spullara.mustache.java:compiler` from 0.9.10 to 0.9.13 ([#13329](https://github.com/opensearch-project/OpenSearch/pull/13329), [#13559](https://github.com/opensearch-project/OpenSearch/pull/13559))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@

package org.opensearch.transport;

import org.opensearch.Version;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.BigArrays;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.telemetry.tracing.Tracer;
import org.opensearch.threadpool.ThreadPool;
Expand All @@ -57,7 +59,12 @@ public class InboundHandler {
private final Map<String, ProtocolMessageHandler> protocolMessageHandlers;

InboundHandler(
String nodeName,
Version version,
String[] features,
StatsTracker statsTracker,
ThreadPool threadPool,
BigArrays bigArrays,
OutboundHandler outboundHandler,
NamedWriteableRegistry namedWriteableRegistry,
TransportHandshaker handshaker,
Expand All @@ -70,7 +77,12 @@ public class InboundHandler {
this.protocolMessageHandlers = Map.of(
NativeInboundMessage.NATIVE_PROTOCOL,
new NativeMessageHandler(
nodeName,
version,
features,
statsTracker,
threadPool,
bigArrays,
outboundHandler,
namedWriteableRegistry,
handshaker,
Expand All @@ -83,6 +95,7 @@ public class InboundHandler {
}

void setMessageListener(TransportMessageListener listener) {
protocolMessageHandlers.values().forEach(handler -> handler.setMessageListener(listener));
if (messageListener == TransportMessageListener.NOOP_LISTENER) {
messageListener = listener;
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.lucene.util.BytesRef;
import org.opensearch.Version;
import org.opensearch.common.util.BigArrays;
import org.opensearch.common.util.concurrent.AbstractRunnable;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.common.io.stream.ByteBufferStreamInput;
Expand All @@ -51,6 +52,7 @@
import org.opensearch.telemetry.tracing.Tracer;
import org.opensearch.telemetry.tracing.channels.TraceableTcpTransportChannel;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.nativeprotocol.NativeOutboundHandler;

import java.io.EOFException;
import java.io.IOException;
Expand All @@ -71,7 +73,7 @@ public class NativeMessageHandler implements ProtocolMessageHandler {
private static final Logger logger = LogManager.getLogger(NativeMessageHandler.class);

private final ThreadPool threadPool;
private final OutboundHandler outboundHandler;
private final NativeOutboundHandler outboundHandler;
private final NamedWriteableRegistry namedWriteableRegistry;
private final TransportHandshaker handshaker;
private final TransportKeepAlive keepAlive;
Expand All @@ -81,7 +83,12 @@ public class NativeMessageHandler implements ProtocolMessageHandler {
private final Tracer tracer;

NativeMessageHandler(
String nodeName,
Version version,
String[] features,
StatsTracker statsTracker,
ThreadPool threadPool,
BigArrays bigArrays,
OutboundHandler outboundHandler,
NamedWriteableRegistry namedWriteableRegistry,
TransportHandshaker handshaker,
Expand All @@ -91,7 +98,7 @@ public class NativeMessageHandler implements ProtocolMessageHandler {
TransportKeepAlive keepAlive
) {
this.threadPool = threadPool;
this.outboundHandler = outboundHandler;
this.outboundHandler = new NativeOutboundHandler(nodeName, version, features, statsTracker, threadPool, bigArrays, outboundHandler);
this.namedWriteableRegistry = namedWriteableRegistry;
this.handshaker = handshaker;
this.requestHandlers = requestHandlers;
Expand Down Expand Up @@ -491,4 +498,9 @@ public void onFailure(Exception e) {
}
}

@Override
public void setMessageListener(TransportMessageListener listener) {
outboundHandler.setMessageListener(listener);
}

}
171 changes: 13 additions & 158 deletions server/src/main/java/org/opensearch/transport/OutboundHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,164 +35,47 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.Version;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.common.CheckedSupplier;
import org.opensearch.common.io.stream.ReleasableBytesStreamOutput;
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.lease.Releasables;
import org.opensearch.common.network.CloseableChannel;
import org.opensearch.common.transport.NetworkExceptionHelper;
import org.opensearch.common.util.BigArrays;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.util.io.IOUtils;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.NotifyOnceListener;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.transport.TransportAddress;
import org.opensearch.core.transport.TransportResponse;
import org.opensearch.threadpool.ThreadPool;

import java.io.IOException;
import java.util.Set;

/**
* Outbound data handler
*
* @opensearch.internal
*/
final class OutboundHandler {
public final class OutboundHandler {

private static final Logger logger = LogManager.getLogger(OutboundHandler.class);

private final String nodeName;
private final Version version;
private final String[] features;
private final StatsTracker statsTracker;
private final ThreadPool threadPool;
private final BigArrays bigArrays;
private volatile TransportMessageListener messageListener = TransportMessageListener.NOOP_LISTENER;

OutboundHandler(
String nodeName,
Version version,
String[] features,
StatsTracker statsTracker,
ThreadPool threadPool,
BigArrays bigArrays
) {
this.nodeName = nodeName;
this.version = version;
this.features = features;
public OutboundHandler(StatsTracker statsTracker, ThreadPool threadPool) {
this.statsTracker = statsTracker;
this.threadPool = threadPool;
this.bigArrays = bigArrays;
}

void sendBytes(TcpChannel channel, BytesReference bytes, ActionListener<Void> listener) {
SendContext sendContext = new SendContext(channel, () -> bytes, listener);
SendContext sendContext = new SendContext(statsTracker, channel, () -> bytes, listener);
try {
internalSend(channel, sendContext);
sendBytes(channel, sendContext);
} catch (IOException e) {
// This should not happen as the bytes are already serialized
throw new AssertionError(e);
}
}

/**
* Sends the request to the given channel. This method should be used to send {@link TransportRequest}
* objects back to the caller.
*/
void sendRequest(
final DiscoveryNode node,
final TcpChannel channel,
final long requestId,
final String action,
final TransportRequest request,
final TransportRequestOptions options,
final Version channelVersion,
final boolean compressRequest,
final boolean isHandshake
) throws IOException, TransportException {
Version version = Version.min(this.version, channelVersion);
OutboundMessage.Request message = new OutboundMessage.Request(
threadPool.getThreadContext(),
features,
request,
version,
action,
requestId,
isHandshake,
compressRequest
);
ActionListener<Void> listener = ActionListener.wrap(() -> messageListener.onRequestSent(node, requestId, action, request, options));
sendMessage(channel, message, listener);
}

/**
* Sends the response to the given channel. This method should be used to send {@link TransportResponse}
* objects back to the caller.
*
* @see #sendErrorResponse(Version, Set, TcpChannel, long, String, Exception) for sending error responses
*/
void sendResponse(
final Version nodeVersion,
final Set<String> features,
final TcpChannel channel,
final long requestId,
final String action,
final TransportResponse response,
final boolean compress,
final boolean isHandshake
) throws IOException {
Version version = Version.min(this.version, nodeVersion);
OutboundMessage.Response message = new OutboundMessage.Response(
threadPool.getThreadContext(),
features,
response,
version,
requestId,
isHandshake,
compress
);
ActionListener<Void> listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, response));
sendMessage(channel, message, listener);
}

/**
* Sends back an error response to the caller via the given channel
*/
void sendErrorResponse(
final Version nodeVersion,
final Set<String> features,
final TcpChannel channel,
final long requestId,
final String action,
final Exception error
) throws IOException {
Version version = Version.min(this.version, nodeVersion);
TransportAddress address = new TransportAddress(channel.getLocalAddress());
RemoteTransportException tx = new RemoteTransportException(nodeName, address, action, error);
OutboundMessage.Response message = new OutboundMessage.Response(
threadPool.getThreadContext(),
features,
tx,
version,
requestId,
false,
false
);
ActionListener<Void> listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, error));
sendMessage(channel, message, listener);
}

private void sendMessage(TcpChannel channel, OutboundMessage networkMessage, ActionListener<Void> listener) throws IOException {
MessageSerializer serializer = new MessageSerializer(networkMessage, bigArrays);
SendContext sendContext = new SendContext(channel, serializer, listener, serializer);
internalSend(channel, sendContext);
}

private void internalSend(TcpChannel channel, SendContext sendContext) throws IOException {
public void sendBytes(TcpChannel channel, SendContext sendContext) throws IOException {
channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis());
BytesReference reference = sendContext.get();
// stash thread context so that channel event loop is not polluted by thread context
Expand All @@ -205,59 +88,30 @@ private void internalSend(TcpChannel channel, SendContext sendContext) throws IO
}
}

void setMessageListener(TransportMessageListener listener) {
if (messageListener == TransportMessageListener.NOOP_LISTENER) {
messageListener = listener;
} else {
throw new IllegalStateException("Cannot set message listener twice");
}
}

/**
* Internal message serializer
*
* @opensearch.internal
*/
private static class MessageSerializer implements CheckedSupplier<BytesReference, IOException>, Releasable {

private final OutboundMessage message;
private final BigArrays bigArrays;
private volatile ReleasableBytesStreamOutput bytesStreamOutput;

private MessageSerializer(OutboundMessage message, BigArrays bigArrays) {
this.message = message;
this.bigArrays = bigArrays;
}

@Override
public BytesReference get() throws IOException {
bytesStreamOutput = new ReleasableBytesStreamOutput(bigArrays);
return message.serialize(bytesStreamOutput);
}

@Override
public void close() {
IOUtils.closeWhileHandlingException(bytesStreamOutput);
}
}

private class SendContext extends NotifyOnceListener<Void> implements CheckedSupplier<BytesReference, IOException> {

public static class SendContext extends NotifyOnceListener<Void> implements CheckedSupplier<BytesReference, IOException> {
private final StatsTracker statsTracker;
private final TcpChannel channel;
private final CheckedSupplier<BytesReference, IOException> messageSupplier;
private final ActionListener<Void> listener;
private final Releasable optionalReleasable;
private long messageSize = -1;

private SendContext(
SendContext(
StatsTracker statsTracker,
TcpChannel channel,
CheckedSupplier<BytesReference, IOException> messageSupplier,
ActionListener<Void> listener
) {
this(channel, messageSupplier, listener, null);
this(statsTracker, channel, messageSupplier, listener, null);
}

private SendContext(
public SendContext(
StatsTracker statsTracker,
TcpChannel channel,
CheckedSupplier<BytesReference, IOException> messageSupplier,
ActionListener<Void> listener,
Expand All @@ -267,6 +121,7 @@ private SendContext(
this.messageSupplier = messageSupplier;
this.listener = listener;
this.optionalReleasable = optionalReleasable;
this.statsTracker = statsTracker;
}

public BytesReference get() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,25 @@
*/
public interface ProtocolMessageHandler {

/**
* Handles the message received on the channel.
* @param channel the channel on which the message was received
* @param message the message received
* @param startTime the start time
* @param slowLogThresholdMs the threshold for slow logs
* @param messageListener the message listener
*/
public void messageReceived(
TcpChannel channel,
ProtocolInboundMessage message,
long startTime,
long slowLogThresholdMs,
TransportMessageListener messageListener
) throws IOException;

/**
* Sets the message listener to be used by the handler.
* @param listener the message listener
*/
public void setMessageListener(TransportMessageListener listener);
}
Loading
Loading