diff --git a/java/client/src/main/java/glide/connectors/handlers/ChannelHandler.java b/java/client/src/main/java/glide/connectors/handlers/ChannelHandler.java index 2874fc13b8..32637219d1 100644 --- a/java/client/src/main/java/glide/connectors/handlers/ChannelHandler.java +++ b/java/client/src/main/java/glide/connectors/handlers/ChannelHandler.java @@ -6,8 +6,11 @@ import io.netty.bootstrap.Bootstrap; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; import io.netty.channel.unix.DomainSocketAddress; import java.util.concurrent.CompletableFuture; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; import redis_request.RedisRequestOuterClass.RedisRequest; import response.ResponseOuterClass.Response; @@ -17,8 +20,6 @@ */ public class ChannelHandler { - private static final String THREAD_POOL_NAME = "glide-channel"; - protected final Channel channel; protected final CallbackDispatcher callbackDispatcher; @@ -41,6 +42,8 @@ public ChannelHandler( .channel(threadPoolResource.getDomainSocketChannelClass()) .handler(new ProtobufSocketChannelInitializer(callbackDispatcher)) .connect(new DomainSocketAddress(socketPath)) + // TODO .addListener(new NettyFutureErrorHandler()) + // we need to use connection promise here for that ^ .sync() .channel(); this.callbackDispatcher = callbackDispatcher; @@ -58,9 +61,11 @@ public CompletableFuture write(RedisRequest.Builder request, boolean f request.setCallbackIdx(commandId.getKey()); if (flush) { - channel.writeAndFlush(request.build()); + channel + .writeAndFlush(request.build()) + .addListener(new NettyFutureErrorHandler(commandId.getValue())); } else { - channel.write(request.build()); + channel.write(request.build()).addListener(new NettyFutureErrorHandler(commandId.getValue())); } return commandId.getValue(); } @@ -73,7 +78,7 @@ public CompletableFuture write(RedisRequest.Builder request, boolean f */ public CompletableFuture connect(ConnectionRequest request) { var future = callbackDispatcher.registerConnection(); - channel.writeAndFlush(request); + channel.writeAndFlush(request).addListener(new NettyFutureErrorHandler(future)); return future; } @@ -82,4 +87,25 @@ public ChannelFuture close() { callbackDispatcher.shutdownGracefully(); return channel.close(); } + + /** + * Propagate an error from Netty's {@link ChannelFuture} and complete the {@link + * CompletableFuture} promise. + */ + @RequiredArgsConstructor + private static class NettyFutureErrorHandler implements ChannelFutureListener { + + private final CompletableFuture promise; + + @Override + public void operationComplete(@NonNull ChannelFuture channelFuture) throws Exception { + if (channelFuture.isCancelled()) { + promise.cancel(false); + } + var cause = channelFuture.cause(); + if (cause != null) { + promise.completeExceptionally(cause); + } + } + } } diff --git a/java/client/src/main/java/glide/connectors/handlers/ReadHandler.java b/java/client/src/main/java/glide/connectors/handlers/ReadHandler.java index 44e8d75a1c..29b7f4c01b 100644 --- a/java/client/src/main/java/glide/connectors/handlers/ReadHandler.java +++ b/java/client/src/main/java/glide/connectors/handlers/ReadHandler.java @@ -29,8 +29,10 @@ public void channelRead(@NonNull ChannelHandlerContext ctx, @NonNull Object msg) /** Handles uncaught exceptions from {@link #channelRead(ChannelHandlerContext, Object)}. */ @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + // TODO: log thru logger System.out.printf("=== exceptionCaught %s %s %n", ctx, cause); - cause.printStackTrace(System.err); - super.exceptionCaught(ctx, cause); + + callbackDispatcher.distributeClosingException( + "An unhandled error while reading from UDS channel: " + cause); } } diff --git a/java/client/src/test/java/glide/connection/ConnectionWithGlideMockTests.java b/java/client/src/test/java/glide/connection/ConnectionWithGlideMockTests.java new file mode 100644 index 0000000000..52be39bacc --- /dev/null +++ b/java/client/src/test/java/glide/connection/ConnectionWithGlideMockTests.java @@ -0,0 +1,196 @@ +/** Copyright GLIDE-for-Redis Project Contributors - SPDX Identifier: Apache-2.0 */ +package glide.connection; + +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.junit.jupiter.api.Assertions.assertAll; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import connection_request.ConnectionRequestOuterClass.ConnectionRequest; +import connection_request.ConnectionRequestOuterClass.NodeAddress; +import glide.api.RedisClient; +import glide.api.models.exceptions.ClosingException; +import glide.connectors.handlers.CallbackDispatcher; +import glide.connectors.handlers.ChannelHandler; +import glide.connectors.resources.Platform; +import glide.managers.CommandManager; +import glide.managers.ConnectionManager; +import glide.utils.RustCoreLibMockTestBase; +import glide.utils.RustCoreMock; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeoutException; +import lombok.SneakyThrows; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import redis_request.RedisRequestOuterClass.RedisRequest; +import response.ResponseOuterClass.Response; + +public class ConnectionWithGlideMockTests extends RustCoreLibMockTestBase { + + private ChannelHandler channelHandler = null; + + @BeforeEach + @SneakyThrows + public void createTestClient() { + channelHandler = + new ChannelHandler( + new CallbackDispatcher(), socketPath, Platform.getThreadPoolResourceSupplier().get()); + } + + @AfterEach + public void closeTestClient() { + channelHandler.close(); + } + + private Future testConnection() { + return channelHandler.connect(createConnectionRequest()); + } + + private static ConnectionRequest createConnectionRequest() { + return ConnectionRequest.newBuilder() + .addAddresses(NodeAddress.newBuilder().setHost("dummyhost").setPort(42).build()) + .build(); + } + + @BeforeAll + public static void init() { + startRustCoreLibMock(null); + } + + @Test + @SneakyThrows + // as of #710 https://github.com/aws/babushka/pull/710 - connection response is empty + public void can_connect_with_empty_response() { + RustCoreMock.updateGlideMock( + new RustCoreMock.GlideMockProtobuf() { + @Override + public Response connection(ConnectionRequest request) { + return Response.newBuilder().build(); + } + + @Override + public Response.Builder redisRequest(RedisRequest request) { + return null; + } + }); + + var connectionResponse = testConnection().get(); + assertAll( + () -> assertFalse(connectionResponse.hasClosingError()), + () -> assertFalse(connectionResponse.hasRequestError()), + () -> assertFalse(connectionResponse.hasRespPointer())); + } + + @Test + @SneakyThrows + public void can_connect_with_ok_response() { + RustCoreMock.updateGlideMock( + new RustCoreMock.GlideMockProtobuf() { + @Override + public Response connection(ConnectionRequest request) { + return OK().build(); + } + + @Override + public Response.Builder redisRequest(RedisRequest request) { + return null; + } + }); + + var connectionResponse = testConnection().get(); + assertAll( + () -> assertTrue(connectionResponse.hasConstantResponse()), + () -> assertFalse(connectionResponse.hasClosingError()), + () -> assertFalse(connectionResponse.hasRequestError()), + () -> assertFalse(connectionResponse.hasRespPointer())); + } + + @Test + public void cant_connect_when_no_response() { + RustCoreMock.updateGlideMock( + new RustCoreMock.GlideMockProtobuf() { + @Override + public Response connection(ConnectionRequest request) { + return null; + } + + @Override + public Response.Builder redisRequest(RedisRequest request) { + return null; + } + }); + + assertThrows(TimeoutException.class, () -> testConnection().get(1, SECONDS)); + } + + @Test + @SneakyThrows + public void cant_connect_when_negative_response() { + RustCoreMock.updateGlideMock( + new RustCoreMock.GlideMockProtobuf() { + @Override + public Response connection(ConnectionRequest request) { + return Response.newBuilder().setClosingError("You shall not pass!").build(); + } + + @Override + public Response.Builder redisRequest(RedisRequest request) { + return null; + } + }); + + var exception = assertThrows(ExecutionException.class, () -> testConnection().get(1, SECONDS)); + assertAll( + () -> assertTrue(exception.getCause() instanceof ClosingException), + () -> assertEquals("You shall not pass!", exception.getCause().getMessage())); + } + + @Test + @SneakyThrows + public void rethrow_error_on_read_when_malformed_packet_received() { + RustCoreMock.updateGlideMock(request -> new byte[] {-1}); + + var exception = assertThrows(ExecutionException.class, () -> testConnection().get(1, SECONDS)); + assertAll( + () -> assertTrue(exception.getCause() instanceof ClosingException), + () -> + assertTrue( + exception + .getCause() + .getMessage() + .contains("An unhandled error while reading from UDS channel"))); + } + + @Test + @SneakyThrows + public void rethrow_error_if_UDS_channel_closed() { + var client = new TestClient(channelHandler); + stopRustCoreLibMock(); + try { + var exception = + assertThrows( + ExecutionException.class, () -> client.customCommand(new String[0]).get(1, SECONDS)); + assertTrue(exception.getCause() instanceof RuntimeException); + + // Not a public class, can't import + assertEquals( + "io.netty.channel.StacklessClosedChannelException", + exception.getCause().getCause().getClass().getName()); + } finally { + // restart mock to let other tests pass if this one failed + startRustCoreLibMock(null); + } + } + + private static class TestClient extends RedisClient { + + public TestClient(ChannelHandler channelHandler) { + super(new ConnectionManager(channelHandler), new CommandManager(channelHandler)); + } + } +} diff --git a/java/client/src/test/java/glide/utils/RustCoreLibMockTestBase.java b/java/client/src/test/java/glide/utils/RustCoreLibMockTestBase.java new file mode 100644 index 0000000000..ecf59e4a17 --- /dev/null +++ b/java/client/src/test/java/glide/utils/RustCoreLibMockTestBase.java @@ -0,0 +1,49 @@ +/** Copyright GLIDE-for-Redis Project Contributors - SPDX Identifier: Apache-2.0 */ +package glide.utils; + +import glide.connectors.handlers.ChannelHandler; +import glide.ffi.resolvers.SocketListenerResolver; +import lombok.SneakyThrows; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; + +public class RustCoreLibMockTestBase { + + /** + * Pass this socket path to {@link ChannelHandler} or mock {@link + * SocketListenerResolver#getSocket()} to return it. + */ + protected static String socketPath = null; + + @SneakyThrows + public static void startRustCoreLibMock(RustCoreMock.GlideMock rustCoreLibMock) { + assert socketPath == null + : "Previous `RustCoreMock` wasn't stopped. Ensure that your test class inherits" + + " `RustCoreLibMockTestBase`."; + + socketPath = RustCoreMock.start(rustCoreLibMock); + } + + @BeforeEach + public void preTestCheck() { + assert socketPath != null + : "You missed to call `startRustCoreLibMock` in a `@BeforeAll` method of your test class" + + " inherited from `RustCoreLibMockTestBase`."; + } + + @AfterEach + public void afterTestCheck() { + assert !RustCoreMock.failed() : "Error occurred in `RustCoreMock`"; + } + + @AfterAll + @SneakyThrows + public static void stopRustCoreLibMock() { + assert socketPath != null + : "You missed to call `startRustCoreLibMock` in a `@AfterAll` method of your test class" + + " inherited from `RustCoreLibMockTestBase`."; + RustCoreMock.stop(); + socketPath = null; + } +} diff --git a/java/client/src/test/java/glide/utils/RustCoreMock.java b/java/client/src/test/java/glide/utils/RustCoreMock.java new file mode 100644 index 0000000000..8ef787948e --- /dev/null +++ b/java/client/src/test/java/glide/utils/RustCoreMock.java @@ -0,0 +1,189 @@ +/** Copyright GLIDE-for-Redis Project Contributors - SPDX Identifier: Apache-2.0 */ +package glide.utils; + +import connection_request.ConnectionRequestOuterClass.ConnectionRequest; +import glide.connectors.resources.Platform; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.epoll.EpollServerDomainSocketChannel; +import io.netty.channel.kqueue.KQueueServerDomainSocketChannel; +import io.netty.channel.unix.DomainSocketAddress; +import io.netty.channel.unix.DomainSocketChannel; +import io.netty.handler.codec.protobuf.ProtobufEncoder; +import io.netty.handler.codec.protobuf.ProtobufVarint32FrameDecoder; +import io.netty.handler.codec.protobuf.ProtobufVarint32LengthFieldPrepender; +import java.nio.file.Files; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import lombok.RequiredArgsConstructor; +import lombok.SneakyThrows; +import redis_request.RedisRequestOuterClass.RedisRequest; +import response.ResponseOuterClass.ConstantResponse; +import response.ResponseOuterClass.Response; + +public class RustCoreMock { + + @FunctionalInterface + public interface GlideMock { + default boolean isRaw() { + return true; + } + + byte[] handle(byte[] request); + } + + public abstract static class GlideMockProtobuf implements GlideMock { + @Override + public boolean isRaw() { + return false; + } + + @Override + public byte[] handle(byte[] request) { + return new byte[0]; + } + + /** Return `null` to do not reply. */ + public abstract Response connection(ConnectionRequest request); + + /** Return `null` to do not reply. */ + public abstract Response.Builder redisRequest(RedisRequest request); + + public Response redisRequestWithCallbackId(RedisRequest request) { + var responseDraft = redisRequest(request); + return responseDraft == null + ? null + : responseDraft.setCallbackIdx(request.getCallbackIdx()).build(); + } + + public static Response.Builder OK() { + return Response.newBuilder().setConstantResponse(ConstantResponse.OK); + } + } + + public abstract static class GlideMockConnectAll extends GlideMockProtobuf { + @Override + public Response connection(ConnectionRequest request) { + return Response.newBuilder().build(); + } + } + + /** Thread pool supplied to Netty to perform all async IO. */ + private final EventLoopGroup group; + + private final Channel channel; + + private final String socketPath; + + private static RustCoreMock instance; + + private GlideMock messageProcessor; + + /** Update {@link GlideMock} into a running {@link RustCoreMock}. */ + public static void updateGlideMock(GlideMock newMock) { + instance.messageProcessor = newMock; + } + + private final AtomicBoolean failed = new AtomicBoolean(false); + + /** Get and clear failure status. */ + public static boolean failed() { + return instance.failed.compareAndSet(true, false); + } + + @SneakyThrows + private RustCoreMock() { + var threadPoolResource = Platform.getThreadPoolResourceSupplier().get(); + socketPath = Files.createTempFile("GlideCoreMock", null).toString(); + group = threadPoolResource.getEventLoopGroup(); + channel = + new ServerBootstrap() + .group(group) + .channel( + Platform.getCapabilities().isEPollAvailable() + ? EpollServerDomainSocketChannel.class + : KQueueServerDomainSocketChannel.class) + .childHandler( + new ChannelInitializer() { + + @Override + protected void initChannel(DomainSocketChannel ch) throws Exception { + ch.pipeline() + // https://netty.io/4.1/api/io/netty/handler/codec/protobuf/ProtobufEncoder.html + .addLast("frameDecoder", new ProtobufVarint32FrameDecoder()) + .addLast("frameEncoder", new ProtobufVarint32LengthFieldPrepender()) + .addLast("protobufEncoder", new ProtobufEncoder()) + .addLast(new UdsServer(ch)); + } + }) + .bind(new DomainSocketAddress(socketPath)) + .syncUninterruptibly() + .channel(); + } + + public static String start(GlideMock messageProcessor) { + if (instance != null) { + stop(); + } + instance = new RustCoreMock(); + instance.messageProcessor = messageProcessor; + return instance.socketPath; + } + + @SneakyThrows + public static void stop() { + if (instance != null) { + instance.channel.close().syncUninterruptibly(); + instance.group.shutdownGracefully().get(5, TimeUnit.SECONDS); + instance = null; + } + } + + @RequiredArgsConstructor + private class UdsServer extends ChannelInboundHandlerAdapter { + + private final Channel ch; + + // This works with only one connected client. + // TODO Rework with `channelActive` override. + private final AtomicBoolean anybodyConnected = new AtomicBoolean(false); + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + var buf = (ByteBuf) msg; + var bytes = new byte[buf.readableBytes()]; + buf.readBytes(bytes); + buf.release(); + if (messageProcessor.isRaw()) { + ch.writeAndFlush(Unpooled.copiedBuffer(messageProcessor.handle(bytes))); + return; + } + var handler = (GlideMockProtobuf) messageProcessor; + Response response = null; + if (!anybodyConnected.get()) { + var connection = ConnectionRequest.parseFrom(bytes); + response = handler.connection(connection); + anybodyConnected.setPlain(true); + } else { + var request = RedisRequest.parseFrom(bytes); + response = handler.redisRequestWithCallbackId(request); + } + if (response != null) { + ctx.writeAndFlush(response); + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + cause.printStackTrace(); + ctx.close(); + failed.setPlain(true); + } + } +}