diff --git a/server/src/main/java/org/opensearch/common/blobstore/transfer/RemoteTransferContainer.java b/server/src/main/java/org/opensearch/common/blobstore/transfer/RemoteTransferContainer.java index 5808f51f01efc..077e02988d8b3 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/transfer/RemoteTransferContainer.java +++ b/server/src/main/java/org/opensearch/common/blobstore/transfer/RemoteTransferContainer.java @@ -19,6 +19,7 @@ import org.opensearch.common.blobstore.stream.write.WriteContext; import org.opensearch.common.blobstore.stream.write.WritePriority; import org.opensearch.common.blobstore.transfer.stream.OffsetRangeInputStream; +import org.opensearch.common.blobstore.transfer.stream.RateLimitingOffsetRangeInputStream; import org.opensearch.common.blobstore.transfer.stream.ResettableCheckedInputStream; import org.opensearch.common.io.InputStreamContainer; import org.opensearch.common.util.ByteUtils; @@ -27,6 +28,7 @@ import java.io.IOException; import java.io.InputStream; import java.util.Objects; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.zip.CRC32; import com.jcraft.jzlib.JZlib; @@ -51,6 +53,7 @@ public class RemoteTransferContainer implements Closeable { private final long expectedChecksum; private final OffsetRangeInputStreamSupplier offsetRangeInputStreamSupplier; private final boolean isRemoteDataIntegritySupported; + private final AtomicBoolean closed = new AtomicBoolean(); private static final Logger log = LogManager.getLogger(RemoteTransferContainer.class); @@ -160,6 +163,10 @@ private LocalStreamSupplier getMultipartStreamSupplier( return () -> { try { OffsetRangeInputStream offsetRangeInputStream = offsetRangeInputStreamSupplier.get(size, position); + if (offsetRangeInputStream instanceof RateLimitingOffsetRangeInputStream) { + RateLimitingOffsetRangeInputStream rangeIndexInputStream = (RateLimitingOffsetRangeInputStream) offsetRangeInputStream; + rangeIndexInputStream.setClose(closed); + } InputStream inputStream = !isRemoteDataIntegrityCheckPossible() ? new ResettableCheckedInputStream(offsetRangeInputStream, fileName) : offsetRangeInputStream; @@ -226,27 +233,7 @@ private long getActualChecksum() { @Override public void close() throws IOException { - if (inputStreams.get() == null) { - log.warn("Input streams cannot be closed since they are not yet set for multi stream upload"); - return; - } - - boolean closeStreamException = false; - for (InputStream is : Objects.requireNonNull(inputStreams.get())) { - try { - if (is != null) { - is.close(); - } - } catch (IOException ex) { - closeStreamException = true; - // Attempting to close all streams first before throwing exception. - log.error("Multipart stream failed to close ", ex); - } - } - - if (closeStreamException) { - throw new IOException("Closure of some of the multi-part streams failed."); - } + closed.set(true); } /** diff --git a/server/src/main/java/org/opensearch/common/blobstore/transfer/stream/OffsetRangeIndexInputStream.java b/server/src/main/java/org/opensearch/common/blobstore/transfer/stream/OffsetRangeIndexInputStream.java index 7518f9ac569b9..07bd4be6ba622 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/transfer/stream/OffsetRangeIndexInputStream.java +++ b/server/src/main/java/org/opensearch/common/blobstore/transfer/stream/OffsetRangeIndexInputStream.java @@ -8,10 +8,16 @@ package org.opensearch.common.blobstore.transfer.stream; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.store.AlreadyClosedException; import org.apache.lucene.store.IndexInput; +import org.opensearch.common.concurrent.RefCountedReleasable; import org.opensearch.common.lucene.store.InputStreamIndexInput; +import org.opensearch.common.util.concurrent.RunOnce; import java.io.IOException; +import java.util.concurrent.atomic.AtomicBoolean; /** * OffsetRangeIndexInputStream extends InputStream to read from a specified offset using IndexInput @@ -19,9 +25,14 @@ * @opensearch.internal */ public class OffsetRangeIndexInputStream extends OffsetRangeInputStream { - + private static final Logger logger = LogManager.getLogger(OffsetRangeIndexInputStream.class); private final InputStreamIndexInput inputStreamIndexInput; private final IndexInput indexInput; + private AtomicBoolean closed; + + private final OffsetRangeRefCount offsetRangeRefCount; + + private final RunOnce closeOnce; /** * Construct a new OffsetRangeIndexInputStream object @@ -35,16 +46,50 @@ public OffsetRangeIndexInputStream(IndexInput indexInput, long size, long positi indexInput.seek(position); this.indexInput = indexInput; this.inputStreamIndexInput = new InputStreamIndexInput(indexInput, size); + ClosingStreams closingStreams = new ClosingStreams(); + closingStreams.indexInput = indexInput; + closingStreams.inputStreamIndexInput = inputStreamIndexInput; + offsetRangeRefCount = new OffsetRangeRefCount(closingStreams); + closeOnce = new RunOnce(offsetRangeRefCount::decRef); + } + + public void setClose(AtomicBoolean close) { + this.closed = close; } @Override public int read(byte[] b, int off, int len) throws IOException { - return inputStreamIndexInput.read(b, off, len); + ensureOpen(); + try (OffsetRangeRefCount ignored = getStreamReference()) { + return inputStreamIndexInput.read(b, off, len); + } + } + + private OffsetRangeRefCount getStreamReference() { + boolean successIncrement = offsetRangeRefCount.tryIncRef(); + if (successIncrement == false) { + throw alreadyClosed("OffsetRangeIndexInputStream is already unreferenced."); + } + return offsetRangeRefCount; + } + + private void ensureOpen() { + if (closed != null && closed.get() == true) { + logger.debug("Read on stream was attempted after during the close of overall file stream!"); + throw alreadyClosed("Already closed: "); + } + } + + AlreadyClosedException alreadyClosed(String msg) { + return new AlreadyClosedException(msg + this); } @Override public int read() throws IOException { - return inputStreamIndexInput.read(); + ensureOpen(); + try (OffsetRangeRefCount ignored = getStreamReference()) { + return inputStreamIndexInput.read(); + } } @Override @@ -67,9 +112,37 @@ public long getFilePointer() throws IOException { return indexInput.getFilePointer(); } + @Override + public String toString() { + return "OffsetRangeIndexInputStream{" + "indexInput=" + indexInput + ", closed=" + closed + '}'; + } + + private static class ClosingStreams { + private InputStreamIndexInput inputStreamIndexInput; + private IndexInput indexInput; + } + + private static class OffsetRangeRefCount extends RefCountedReleasable { + private static final Logger logger = LogManager.getLogger(OffsetRangeRefCount.class); + + public OffsetRangeRefCount(ClosingStreams ref) { + super("OffsetRangeRefCount", ref, () -> { + try { + ref.inputStreamIndexInput.close(); + } catch (IOException ex) { + logger.error("Failed to close indexStreamIndexInput", ex); + } + try { + ref.indexInput.close(); + } catch (IOException ex) { + logger.error("Failed to close indexInput", ex); + } + }); + } + } + @Override public void close() throws IOException { - inputStreamIndexInput.close(); - indexInput.close(); + closeOnce.run(); } } diff --git a/server/src/main/java/org/opensearch/common/blobstore/transfer/stream/OffsetRangeInputStream.java b/server/src/main/java/org/opensearch/common/blobstore/transfer/stream/OffsetRangeInputStream.java index e8b889db1f3b0..fd7721664d23a 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/transfer/stream/OffsetRangeInputStream.java +++ b/server/src/main/java/org/opensearch/common/blobstore/transfer/stream/OffsetRangeInputStream.java @@ -10,6 +10,7 @@ import java.io.IOException; import java.io.InputStream; +import java.util.concurrent.atomic.AtomicBoolean; /** * OffsetRangeInputStream is an abstract class that extends from {@link InputStream} @@ -19,4 +20,8 @@ */ public abstract class OffsetRangeInputStream extends InputStream { public abstract long getFilePointer() throws IOException; + + public void setClose(AtomicBoolean close) { + // Nothing + } } diff --git a/server/src/main/java/org/opensearch/common/blobstore/transfer/stream/RateLimitingOffsetRangeInputStream.java b/server/src/main/java/org/opensearch/common/blobstore/transfer/stream/RateLimitingOffsetRangeInputStream.java index b455999bbed0c..1007151ce121c 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/transfer/stream/RateLimitingOffsetRangeInputStream.java +++ b/server/src/main/java/org/opensearch/common/blobstore/transfer/stream/RateLimitingOffsetRangeInputStream.java @@ -12,6 +12,7 @@ import org.opensearch.common.StreamLimiter; import java.io.IOException; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Supplier; /** @@ -40,6 +41,10 @@ public RateLimitingOffsetRangeInputStream( this.delegate = delegate; } + public void setClose(AtomicBoolean close) { + delegate.setClose(close); + } + @Override public int read() throws IOException { int b = delegate.read(); diff --git a/server/src/test/java/org/opensearch/common/blobstore/transfer/RemoteTransferContainerTests.java b/server/src/test/java/org/opensearch/common/blobstore/transfer/RemoteTransferContainerTests.java index a33e5f453d1e1..8c8a54b564dc4 100644 --- a/server/src/test/java/org/opensearch/common/blobstore/transfer/RemoteTransferContainerTests.java +++ b/server/src/test/java/org/opensearch/common/blobstore/transfer/RemoteTransferContainerTests.java @@ -8,21 +8,35 @@ package org.opensearch.common.blobstore.transfer; +import org.apache.lucene.store.AlreadyClosedException; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.RateLimiter; import org.opensearch.common.StreamContext; import org.opensearch.common.blobstore.stream.write.WriteContext; import org.opensearch.common.blobstore.stream.write.WritePriority; import org.opensearch.common.blobstore.transfer.stream.OffsetRangeFileInputStream; +import org.opensearch.common.blobstore.transfer.stream.OffsetRangeIndexInputStream; import org.opensearch.common.blobstore.transfer.stream.OffsetRangeInputStream; +import org.opensearch.common.blobstore.transfer.stream.RateLimitingOffsetRangeInputStream; import org.opensearch.common.blobstore.transfer.stream.ResettableCheckedInputStream; import org.opensearch.common.io.InputStreamContainer; import org.opensearch.test.OpenSearchTestCase; import org.junit.Before; import java.io.IOException; +import java.io.InputStream; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.StandardOpenOption; +import java.util.Arrays; import java.util.UUID; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; + +import org.mockito.Mockito; +import org.mockito.stubbing.Answer; public class RemoteTransferContainerTests extends OpenSearchTestCase { @@ -92,25 +106,37 @@ private void testSupplyStreamContext( int partCount = streamContext.getNumberOfParts(); assertEquals(expectedPartCount, partCount); Thread[] threads = new Thread[partCount]; + InputStream[] streams = new InputStream[partCount]; long totalContentLength = remoteTransferContainer.getContentLength(); assert partSize * (partCount - 1) + lastPartSize == totalContentLength : "part sizes and last part size don't add up to total content length"; logger.info("partSize: {}, lastPartSize: {}, partCount: {}", partSize, lastPartSize, streamContext.getNumberOfParts()); - for (int partIdx = 0; partIdx < partCount; partIdx++) { - int finalPartIdx = partIdx; - long expectedPartSize = (partIdx == partCount - 1) ? lastPartSize : partSize; - threads[partIdx] = new Thread(() -> { + try { + for (int partIdx = 0; partIdx < partCount; partIdx++) { + int finalPartIdx = partIdx; + long expectedPartSize = (partIdx == partCount - 1) ? lastPartSize : partSize; + threads[partIdx] = new Thread(() -> { + try { + InputStreamContainer inputStreamContainer = streamContext.provideStream(finalPartIdx); + streams[finalPartIdx] = inputStreamContainer.getInputStream(); + assertEquals(expectedPartSize, inputStreamContainer.getContentLength()); + } catch (IOException e) { + fail("IOException during stream creation"); + } + }); + threads[partIdx].start(); + } + for (int i = 0; i < partCount; i++) { + threads[i].join(); + } + } finally { + Arrays.stream(streams).forEach(stream -> { try { - InputStreamContainer inputStreamContainer = streamContext.provideStream(finalPartIdx); - assertEquals(expectedPartSize, inputStreamContainer.getContentLength()); + stream.close(); } catch (IOException e) { - fail("IOException during stream creation"); + throw new RuntimeException(e); } }); - threads[partIdx].start(); - } - for (int i = 0; i < partCount; i++) { - threads[i].join(); } } @@ -182,6 +208,7 @@ public OffsetRangeInputStream get(long size, long position) throws IOException { } private void testTypeOfProvidedStreams(boolean isRemoteDataIntegritySupported) throws IOException { + InputStream inputStream = null; try ( RemoteTransferContainer remoteTransferContainer = new RemoteTransferContainer( testFile.getFileName().toString(), @@ -201,12 +228,127 @@ public OffsetRangeInputStream get(long size, long position) throws IOException { ) { StreamContext streamContext = remoteTransferContainer.supplyStreamContext(16); InputStreamContainer inputStreamContainer = streamContext.provideStream(0); + inputStream = inputStreamContainer.getInputStream(); if (shouldOffsetInputStreamsBeChecked(isRemoteDataIntegritySupported)) { assertTrue(inputStreamContainer.getInputStream() instanceof ResettableCheckedInputStream); } else { assertTrue(inputStreamContainer.getInputStream() instanceof OffsetRangeInputStream); } assertThrows(RuntimeException.class, () -> remoteTransferContainer.supplyStreamContext(16)); + } finally { + if (inputStream != null) { + inputStream.close(); + } + } + } + + public void testStreamReadAccessAfterClose() throws IOException, InterruptedException { + Supplier rateLimiterSupplier = Mockito.mock(Supplier.class); + Mockito.when(rateLimiterSupplier.get()).thenReturn(null); + CountDownLatch readInvokedLatch = new CountDownLatch(1); + AtomicBoolean readAfterClose = new AtomicBoolean(); + CountDownLatch streamClosed = new CountDownLatch(1); + AtomicBoolean indexInputClosed = new AtomicBoolean(); + try ( + RemoteTransferContainer remoteTransferContainer = new RemoteTransferContainer( + testFile.getFileName().toString(), + testFile.getFileName().toString(), + TEST_FILE_SIZE_BYTES, + true, + WritePriority.NORMAL, + new RemoteTransferContainer.OffsetRangeInputStreamSupplier() { + @Override + public OffsetRangeInputStream get(long size, long position) throws IOException { + IndexInput indexInput = Mockito.mock(IndexInput.class); + Mockito.doAnswer(invocation -> { + indexInputClosed.set(true); + return null; + }).when(indexInput).close(); + Mockito.when(indexInput.getFilePointer()).thenAnswer((Answer) invocation -> { + if (readAfterClose.get() == false) { + return 0L; + } + readInvokedLatch.countDown(); + boolean closedSuccess = streamClosed.await(30, TimeUnit.SECONDS); + assertTrue(closedSuccess); + assertFalse(indexInputClosed.get()); + return 0L; + }); + + OffsetRangeIndexInputStream offsetRangeIndexInputStream = new OffsetRangeIndexInputStream( + indexInput, + size, + position + ); + return new RateLimitingOffsetRangeInputStream(offsetRangeIndexInputStream, rateLimiterSupplier, null); + } + }, + 0, + true + ) + ) { + StreamContext streamContext = remoteTransferContainer.supplyStreamContext(16); + InputStreamContainer inputStreamContainer = streamContext.provideStream(0); + assertTrue(inputStreamContainer.getInputStream() instanceof RateLimitingOffsetRangeInputStream); + CountDownLatch latch = new CountDownLatch(1); + new Thread(() -> { + try { + readAfterClose.set(true); + inputStreamContainer.getInputStream().readAllBytes(); + } catch (IOException e) { + throw new RuntimeException(e); + } finally { + latch.countDown(); + } + }).start(); + boolean successReadWait = readInvokedLatch.await(30, TimeUnit.SECONDS); + assertTrue(successReadWait); + // Closing stream here. Test Multiple invocations of close. Shouldn't throw any exception + inputStreamContainer.getInputStream().close(); + inputStreamContainer.getInputStream().close(); + streamClosed.countDown(); + boolean processed = latch.await(30, TimeUnit.SECONDS); + assertTrue(processed); + assertTrue(readAfterClose.get()); + assertTrue(indexInputClosed.get()); + + // Test Multiple invocations of close. Shouldn't throw any exception + inputStreamContainer.getInputStream().close(); + inputStreamContainer.getInputStream().close(); + + } + } + + public void testReadAccessWhenStreamClosed() throws IOException { + Supplier rateLimiterSupplier = Mockito.mock(Supplier.class); + Mockito.when(rateLimiterSupplier.get()).thenReturn(null); + try ( + RemoteTransferContainer remoteTransferContainer = new RemoteTransferContainer( + testFile.getFileName().toString(), + testFile.getFileName().toString(), + TEST_FILE_SIZE_BYTES, + true, + WritePriority.NORMAL, + new RemoteTransferContainer.OffsetRangeInputStreamSupplier() { + @Override + public OffsetRangeInputStream get(long size, long position) throws IOException { + IndexInput indexInput = Mockito.mock(IndexInput.class); + OffsetRangeIndexInputStream offsetRangeIndexInputStream = new OffsetRangeIndexInputStream( + indexInput, + size, + position + ); + return new RateLimitingOffsetRangeInputStream(offsetRangeIndexInputStream, rateLimiterSupplier, null); + } + }, + 0, + true + ) + ) { + StreamContext streamContext = remoteTransferContainer.supplyStreamContext(16); + InputStreamContainer inputStreamContainer = streamContext.provideStream(0); + inputStreamContainer.getInputStream().close(); + assertThrows(AlreadyClosedException.class, () -> inputStreamContainer.getInputStream().readAllBytes()); } }