fix(s3stream): fix potential incomplete future in StreamReader (#826)
Signed-off-by: Shichao Nie <[email protected]>
SCNieh authored Dec 12, 2023
1 parent 8870c10 commit be55d79
Showing 3 changed files with 75 additions and 33 deletions.
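What the patch addresses: StreamReader parks placeholder CompletableFutures in inflightReadAheadTaskMap, keyed by (streamId, offset), so that a concurrent cache miss waits on the inflight read-ahead instead of spawning a duplicate. Before this change the placeholders were completed only inside the per-data-block read callbacks, so a failure earlier in the pipeline (for example while fetching data block indices) could leave a placeholder incomplete forever and hang every waiter. The fix records each registered key in a per-read taskKeySet and completes whatever is left in a whenComplete backstop. A minimal sketch of that pattern, with hypothetical simplified names (InflightDedup, TaskKey and doReadAhead are stand-ins, not the AutoMQ API):

    import java.util.Map;
    import java.util.Set;
    import java.util.concurrent.CompletableFuture;
    import java.util.concurrent.ConcurrentHashMap;

    // Sketch of the placeholder-future dedup pattern this commit hardens.
    // InflightDedup, TaskKey and doReadAhead are hypothetical stand-ins.
    class InflightDedup {
        record TaskKey(long streamId, long startOffset) {}

        final Map<TaskKey, CompletableFuture<Void>> inflight = new ConcurrentHashMap<>();

        CompletableFuture<Void> readAhead(long streamId, long startOffset) {
            TaskKey key = new TaskKey(streamId, startOffset);
            inflight.putIfAbsent(key, new CompletableFuture<>()); // placeholder for waiters
            Set<TaskKey> registered = ConcurrentHashMap.newKeySet();
            registered.add(key);
            return doReadAhead(registered)
                // Backstop: succeed or fail, every placeholder registered by this
                // read gets completed, so no waiter can hang on a forgotten future.
                .whenComplete((nil, ex) -> registered.forEach(k -> complete(k, ex)));
        }

        void complete(TaskKey key, Throwable ex) {
            inflight.computeIfPresent(key, (k, cf) -> {
                if (ex != null) cf.completeExceptionally(ex); else cf.complete(null);
                return null; // null return removes the entry atomically
            });
        }

        CompletableFuture<Void> doReadAhead(Set<TaskKey> registered) {
            return CompletableFuture.completedFuture(null); // stand-in for the real pipeline
        }
    }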
71 changes: 50 additions & 21 deletions s3stream/src/main/java/com/automq/stream/s3/cache/StreamReader.java
@@ -33,10 +33,12 @@
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Set;
 import java.util.UUID;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentHashMap;
@@ -111,9 +113,15 @@ public CompletableFuture<List<StreamRecordBatch>> syncReadAhead(long streamId, l
         DefaultS3BlockCache.ReadAheadTaskKey readAheadTaskKey = new DefaultS3BlockCache.ReadAheadTaskKey(streamId, startOffset);
         // put a placeholder task at start offset to prevent next cache miss request spawn duplicated read ahead task
         inflightReadAheadTaskMap.putIfAbsent(readAheadTaskKey, new CompletableFuture<>());
-        return getDataBlockIndices(streamId, endOffset, context).thenComposeAsync(v ->
-            handleSyncReadAhead(streamId, startOffset, endOffset, maxBytes, agent, uuid, timer, context),
-            streamReaderExecutor);
+        context.taskKeySet.add(readAheadTaskKey);
+        return getDataBlockIndices(streamId, endOffset, context)
+            .thenComposeAsync(v ->
+                handleSyncReadAhead(streamId, startOffset, endOffset, maxBytes, agent, uuid, timer, context), streamReaderExecutor)
+            .whenComplete((nil, ex) -> {
+                for (DefaultS3BlockCache.ReadAheadTaskKey key : context.taskKeySet) {
+                    completeInflightTask(key, ex);
+                }
+            });
     }

     CompletableFuture<List<StreamRecordBatch>> handleSyncReadAhead(long streamId, long startOffset, long endOffset,
@@ -132,7 +140,7 @@ CompletableFuture<List<StreamRecordBatch>> handleSyncReadAhead(long streamId, lo
         List<String> sortedDataBlockKeyList = new ArrayList<>();

         // collect all data blocks to read from S3
-        List<Pair<ObjectReader, StreamDataBlock>> streamDataBlocksToRead = collectStreamDataBlocksToRead(streamId, context);
+        List<Pair<ObjectReader, StreamDataBlock>> streamDataBlocksToRead = collectStreamDataBlocksToRead(context);

         // reserve all data blocks to read
         List<DataBlockReadAccumulator.ReserveResult> reserveResults = dataBlockReadAccumulator.reserveDataBlock(streamDataBlocksToRead);
@@ -180,9 +188,12 @@ CompletableFuture<List<StreamRecordBatch>> handleSyncReadAhead(long streamId, lo
                         streamId, startOffset, endOffset, streamDataBlock, ex);
                 }
                 completeInflightTask(taskKey, ex);
+                context.taskKeySet.remove(taskKey);
                 if (readIndex == 0) {
                     // in case of first data block and startOffset is not aligned with start of data block
-                    completeInflightTask(new DefaultS3BlockCache.ReadAheadTaskKey(streamId, startOffset), ex);
+                    DefaultS3BlockCache.ReadAheadTaskKey key = new DefaultS3BlockCache.ReadAheadTaskKey(streamId, startOffset);
+                    completeInflightTask(key, ex);
+                    context.taskKeySet.remove(key);
                 }
             }));
             if (reserveResult.reserveSize() > 0) {
@@ -249,9 +260,14 @@ public void asyncReadAhead(long streamId, long startOffset, long endOffset, int
         DefaultS3BlockCache.ReadAheadTaskKey readAheadTaskKey = new DefaultS3BlockCache.ReadAheadTaskKey(streamId, startOffset);
         // put a placeholder task at start offset to prevent next cache miss request spawn duplicated read ahead task
         inflightReadAheadTaskMap.putIfAbsent(readAheadTaskKey, new CompletableFuture<>());
-        getDataBlockIndices(streamId, endOffset, context).thenAcceptAsync(v ->
-            handleAsyncReadAhead(streamId, startOffset, endOffset, maxBytes, agent, timer, context),
-            streamReaderExecutor);
+        getDataBlockIndices(streamId, endOffset, context)
+            .thenAcceptAsync(v ->
+                handleAsyncReadAhead(streamId, startOffset, endOffset, maxBytes, agent, timer, context), streamReaderExecutor)
+            .whenComplete((nil, ex) -> {
+                for (DefaultS3BlockCache.ReadAheadTaskKey key : context.taskKeySet) {
+                    completeInflightTask(key, ex);
+                }
+            });
     }

     CompletableFuture<Void> handleAsyncReadAhead(long streamId, long startOffset, long endOffset, int maxBytes, ReadAheadAgent agent,
@@ -266,7 +282,7 @@ CompletableFuture<Void> handleAsyncReadAhead(long streamId, long startOffset, lo

         List<CompletableFuture<Void>> cfList = new ArrayList<>();
         // collect all data blocks to read from S3
-        List<Pair<ObjectReader, StreamDataBlock>> streamDataBlocksToRead = collectStreamDataBlocksToRead(streamId, context);
+        List<Pair<ObjectReader, StreamDataBlock>> streamDataBlocksToRead = collectStreamDataBlocksToRead(context);

         // concurrently read all data blocks
         for (int i = 0; i < streamDataBlocksToRead.size(); i++) {
@@ -282,7 +298,7 @@ CompletableFuture<Void> handleAsyncReadAhead(long streamId, long startOffset, lo
             DefaultS3BlockCache.ReadAheadTaskKey taskKey = new DefaultS3BlockCache.ReadAheadTaskKey(streamId, streamDataBlock.getStartOffset());
             DataBlockReadAccumulator.ReserveResult reserveResult = dataBlockReadAccumulator.reserveDataBlock(List.of(pair)).get(0);
             int readIndex = i;
-            cfList.add(reserveResult.cf().thenAcceptAsync(dataBlock -> {
+            CompletableFuture<Void> cf = reserveResult.cf().thenAcceptAsync(dataBlock -> {
                 if (dataBlock.records().isEmpty()) {
                     return;
                 }
@@ -302,11 +318,15 @@ CompletableFuture<Void> handleAsyncReadAhead(long streamId, long startOffset, lo
                 }
                 inflightReadThrottle.release(uuid);
                 completeInflightTask(taskKey, ex);
+                context.taskKeySet.remove(taskKey);
                 if (readIndex == 0) {
                     // in case of first data block and startOffset is not aligned with start of data block
-                    completeInflightTask(new DefaultS3BlockCache.ReadAheadTaskKey(streamId, startOffset), ex);
+                    DefaultS3BlockCache.ReadAheadTaskKey key = new DefaultS3BlockCache.ReadAheadTaskKey(streamId, startOffset);
+                    completeInflightTask(key, ex);
+                    context.taskKeySet.remove(key);
                 }
-            }));
+            });
+            cfList.add(cf);

             if (LOGGER.isDebugEnabled()) {
                 LOGGER.debug("[S3BlockCache] async ra acquire size: {}, uuid={}, stream={}, {}-{}, {}",
@@ -316,7 +336,10 @@ CompletableFuture<Void> handleAsyncReadAhead(long streamId, long startOffset, lo
                 inflightReadThrottle.acquire(uuid, reserveResult.reserveSize()).thenAcceptAsync(nil -> {
                     // read data block
                     dataBlockReadAccumulator.readDataBlock(objectReader, streamDataBlock.dataBlockIndex());
-                }, streamReaderExecutor);
+                }, streamReaderExecutor).exceptionally(ex -> {
+                    cf.completeExceptionally(ex);
+                    return null;
+                });
             }
         }
         return CompletableFuture.allOf(cfList.toArray(CompletableFuture[]::new)).whenComplete((ret, ex) -> {
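A note on the exceptionally bridge above: the throttle-acquire/readDataBlock chain runs on a side path, while callers wait on cf via the allOf(cfList) in the return statement. If the side path failed before readDataBlock could trigger completion, cf would never finish, and neither would allOf. Forwarding the failure with exceptionally closes that gap. A self-contained sketch of the same shape, with hypothetical stand-ins (acquireQuota, readBlock):

    import java.util.concurrent.CompletableFuture;

    // Sketch: a failure on a side chain must be forwarded into the future the
    // caller waits on, or allOf() never completes. acquireQuota and readBlock
    // are hypothetical stand-ins.
    public class SideChainPropagation {
        static CompletableFuture<Void> acquireQuota() {
            return CompletableFuture.failedFuture(new RuntimeException("quota exhausted"));
        }

        public static void main(String[] args) {
            CompletableFuture<Void> cf = new CompletableFuture<>(); // normally completed by the read callback
            acquireQuota()
                .thenAccept(nil -> { /* readBlock() would eventually complete cf */ })
                .exceptionally(ex -> {
                    cf.completeExceptionally(ex); // bridge the side-chain failure
                    return null;
                });
            CompletableFuture.allOf(cf)
                .whenComplete((r, ex) -> System.out.println("allOf done, ex = " + ex));
        }
    }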
@@ -376,6 +399,12 @@ CompletableFuture<Void> getDataBlockIndices(long streamId, long endOffset, ReadC
             return CompletableFuture.completedFuture(null);
         }

+        for (StreamDataBlock streamDataBlock : streamDataBlocks) {
+            DefaultS3BlockCache.ReadAheadTaskKey taskKey = new DefaultS3BlockCache.ReadAheadTaskKey(streamId, streamDataBlock.getStartOffset());
+            inflightReadAheadTaskMap.putIfAbsent(taskKey, new CompletableFuture<>());
+            context.taskKeySet.add(taskKey);
+        }
+
         S3ObjectMetadata objectMetadata = context.objects.get(context.objectIndex);
         long lastOffset = streamDataBlocks.get(streamDataBlocks.size() - 1).getEndOffset();
         context.lastOffset = Math.max(lastOffset, context.lastOffset);
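getDataBlockIndices now registers a placeholder future for every data block it finds and records the key in context.taskKeySet, so the whenComplete backstop knows exactly what to clean up. How a second cache miss would piggyback on such a placeholder is sketched below under simplified, hypothetical types (Key, the inflight map, and the string results are stand-ins; the real waiter logic presumably lives on the DefaultS3BlockCache side):

    import java.util.Map;
    import java.util.concurrent.CompletableFuture;
    import java.util.concurrent.ConcurrentHashMap;

    // Waiter-side sketch with simplified, hypothetical types: a second cache
    // miss at an offset covered by an inflight read-ahead waits on the
    // placeholder instead of spawning a duplicate S3 read.
    class ReadDedupWaiter {
        record Key(long streamId, long startOffset) {}

        static final Map<Key, CompletableFuture<Void>> inflight = new ConcurrentHashMap<>();

        static CompletableFuture<String> read(long streamId, long startOffset) {
            CompletableFuture<Void> pending = inflight.get(new Key(streamId, startOffset));
            if (pending != null) {
                // Placeholder present: wait for the read-ahead, then serve from cache.
                return pending.thenApply(nil -> "served from block cache after waiting");
            }
            return CompletableFuture.completedFuture("cache miss: spawn read-ahead"); // stand-in
        }
    }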
@@ -392,25 +421,23 @@
     }

     private void completeInflightTask(DefaultS3BlockCache.ReadAheadTaskKey key, Throwable ex) {
-        CompletableFuture<Void> inflightReadAheadTask = inflightReadAheadTaskMap.remove(key);
-        if (inflightReadAheadTask != null) {
+        inflightReadAheadTaskMap.computeIfPresent(key, (k, v) -> {
             if (ex != null) {
-                inflightReadAheadTask.completeExceptionally(ex);
+                v.completeExceptionally(ex);
             } else {
-                inflightReadAheadTask.complete(null);
+                v.complete(null);
             }
-        }
+            return null;
+        });
     }

-    private List<Pair<ObjectReader, StreamDataBlock>> collectStreamDataBlocksToRead(long streamId, ReadContext context) {
+    private List<Pair<ObjectReader, StreamDataBlock>> collectStreamDataBlocksToRead(ReadContext context) {
         List<Pair<ObjectReader, StreamDataBlock>> result = new ArrayList<>();
         for (Pair<Long, List<StreamDataBlock>> entry : context.streamDataBlocksPair) {
             long objectId = entry.getKey();
             ObjectReader objectReader = context.objectReaderMap.get(objectId);
             for (StreamDataBlock streamDataBlock : entry.getValue()) {
                 result.add(Pair.of(objectReader, streamDataBlock));
-                DefaultS3BlockCache.ReadAheadTaskKey taskKey = new DefaultS3BlockCache.ReadAheadTaskKey(streamId, streamDataBlock.getStartOffset());
-                inflightReadAheadTaskMap.putIfAbsent(taskKey, new CompletableFuture<>());
             }
         }
         return result;
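The completeInflightTask rewrite above swaps remove-then-complete for ConcurrentHashMap.computeIfPresent: the remapping function runs atomically for that key, and returning null removes the mapping in the same step, so the future is completed while the entry is still visible and then dropped atomically. A standalone illustration:

    import java.util.concurrent.CompletableFuture;
    import java.util.concurrent.ConcurrentHashMap;

    // Demonstrates complete-and-remove in one atomic step: computeIfPresent runs
    // the remapping function atomically for the key, and a null return removes
    // the mapping.
    public class CompleteAndRemove {
        public static void main(String[] args) {
            ConcurrentHashMap<String, CompletableFuture<Void>> map = new ConcurrentHashMap<>();
            map.put("stream-233@70", new CompletableFuture<>());

            map.computeIfPresent("stream-233@70", (k, cf) -> {
                cf.complete(null); // waiters are released while the entry is still mapped
                return null;       // and the entry is removed in the same atomic step
            });

            System.out.println(map.containsKey("stream-233@70")); // false
        }
    }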
@@ -431,6 +458,7 @@ static class ReadContext {
         List<S3ObjectMetadata> objects;
         List<Pair<Long, List<StreamDataBlock>>> streamDataBlocksPair;
         Map<Long, ObjectReader> objectReaderMap;
+        Set<DefaultS3BlockCache.ReadAheadTaskKey> taskKeySet;
         int objectIndex;
         long nextStartOffset;
         int nextMaxBytes;
@@ -442,6 +470,7 @@ public ReadContext(long startOffset, int maxBytes) {
             this.objectIndex = 0;
             this.streamDataBlocksPair = new LinkedList<>();
             this.objectReaderMap = new HashMap<>();
+            this.taskKeySet = new HashSet<>();
             this.nextStartOffset = startOffset;
             this.nextMaxBytes = maxBytes;
         }
8 changes: 7 additions & 1 deletion (test file; path not shown in this view)

@@ -17,6 +17,7 @@

 package com.automq.stream.s3;

+import com.automq.stream.s3.cache.CacheAccessType;
 import com.automq.stream.s3.cache.DefaultS3BlockCache;
 import com.automq.stream.s3.cache.ReadDataBlock;
 import com.automq.stream.s3.model.StreamRecordBatch;
@@ -26,6 +27,7 @@
 import com.automq.stream.s3.metadata.ObjectUtils;
 import com.automq.stream.s3.metadata.S3ObjectMetadata;
 import com.automq.stream.s3.metadata.S3ObjectType;
+import com.automq.stream.utils.Threads;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Tag;
 import org.junit.jupiter.api.Test;
@@ -128,8 +130,12 @@ public void testRead_readAhead() throws ExecutionException, InterruptedException
         verify(s3Operator, timeout(1000).times(2)).rangeRead(eq(ObjectUtils.genKey(0, 1)), ArgumentMatchers.anyLong(), ArgumentMatchers.anyLong(), ArgumentMatchers.any());
         verify(objectManager, timeout(1000).times(1)).getObjects(eq(233L), eq(30L), eq(-1L), eq(2));

+        Threads.sleep(1000);
+
         // expect read ahead already cached the records
-        List<StreamRecordBatch> records = s3BlockCache.read(233L, 20L, 30L, 10000).get().getRecords();
+        ReadDataBlock ret = s3BlockCache.read(233L, 20L, 30L, 10000).get();
+        assertEquals(CacheAccessType.BLOCK_CACHE_HIT, ret.getCacheAccessType());
+        List<StreamRecordBatch> records = ret.getRecords();
         assertEquals(1, records.size());
         assertEquals(20L, records.get(0).getBaseOffset());
     }
29 changes: 18 additions & 11 deletions StreamReaderTest (path not shown in this view)

@@ -44,8 +44,11 @@
 import java.util.concurrent.CompletionException;
 import java.util.concurrent.atomic.AtomicInteger;

+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.anyLong;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doReturn;

 public class StreamReaderTest {

@@ -94,21 +97,25 @@ public void testSyncReadAheadInflight() {
         ObjectManager objectManager = Mockito.mock(ObjectManager.class);
         BlockCache blockCache = Mockito.mock(BlockCache.class);
         Map<ReadAheadTaskKey, CompletableFuture<Void>> inflightReadAheadTasks = new HashMap<>();
-        StreamReader streamReader = new StreamReader(s3Operator, objectManager, blockCache, cache, accumulator, inflightReadAheadTasks, new InflightReadThrottle());
+        StreamReader streamReader = Mockito.spy(new StreamReader(s3Operator, objectManager, blockCache, cache, accumulator, inflightReadAheadTasks, new InflightReadThrottle()));

         long streamId = 233L;
         long startOffset = 70;
-        StreamReader.ReadContext context = new StreamReader.ReadContext(startOffset, 256);
-        ObjectReader.DataBlockIndex index1 = new ObjectReader.DataBlockIndex(0, 0, 256, 128);
-        context.streamDataBlocksPair = List.of(
-            new ImmutablePair<>(1L, List.of(
-                new StreamDataBlock(233L, 64, 128, 1, index1))));
+        long endOffset = 1024;
+        int maxBytes = 64;
+        long objectId = 1;
+        S3ObjectMetadata metadata = new S3ObjectMetadata(objectId, -1, S3ObjectType.STREAM);
+        doAnswer(invocation -> CompletableFuture.completedFuture(List.of(metadata)))
+            .when(objectManager).getObjects(eq(streamId), eq(startOffset), anyLong(), anyInt());

         ObjectReader reader = Mockito.mock(ObjectReader.class);
-        Mockito.when(reader.read(index1)).thenReturn(new CompletableFuture<>());
-        context.objectReaderMap = new HashMap<>(Map.of(1L, reader));
-        inflightReadAheadTasks.put(new ReadAheadTaskKey(233L, startOffset), new CompletableFuture<>());
-        streamReader.handleSyncReadAhead(233L, startOffset,
-            999, 64, Mockito.mock(ReadAheadAgent.class), UUID.randomUUID(), new TimerUtil(), context);
+        ObjectReader.DataBlockIndex index1 = new ObjectReader.DataBlockIndex(0, 0, 256, 128);
+        doReturn(reader).when(streamReader).getObjectReader(metadata);
+        doAnswer(invocation -> CompletableFuture.completedFuture(new ObjectReader.FindIndexResult(true, -1, -1,
+            List.of(new StreamDataBlock(streamId, 64, 128, objectId, index1))))).when(reader).find(eq(streamId), eq(startOffset), anyLong(), eq(maxBytes));
+        doReturn(new CompletableFuture<>()).when(reader).read(index1);
+
+        streamReader.syncReadAhead(streamId, startOffset, endOffset, maxBytes, Mockito.mock(ReadAheadAgent.class), UUID.randomUUID());
+        Threads.sleep(1000);
         Assertions.assertEquals(2, inflightReadAheadTasks.size());
         Assertions.assertTrue(inflightReadAheadTasks.containsKey(new ReadAheadTaskKey(233L, startOffset)));
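The reworked test no longer hand-builds a ReadContext and calls handleSyncReadAhead directly; it spies the real StreamReader, stubs getObjectReader and find with doReturn/doAnswer, and drives the public syncReadAhead entry point so the new placeholder bookkeeping is exercised end to end. Returning a never-completed future from reader.read(index1) keeps the read-ahead inflight, which is what the assertions on inflightReadAheadTasks rely on. The spy-plus-doReturn shape in isolation (Fetcher is a hypothetical stand-in, not an AutoMQ class):

    import static org.mockito.ArgumentMatchers.anyLong;
    import static org.mockito.Mockito.doReturn;
    import static org.mockito.Mockito.spy;

    import java.util.concurrent.CompletableFuture;

    // Isolated sketch of the stubbing style used above; Fetcher is a
    // hypothetical stand-in, not an AutoMQ class.
    public class SpyStubSketch {
        static class Fetcher {
            CompletableFuture<String> fetch(long id) {
                return CompletableFuture.completedFuture("from S3");
            }
        }

        public static void main(String[] args) {
            Fetcher fetcher = spy(new Fetcher());
            // doReturn(...).when(spy) stubs without invoking the real fetch();
            // the never-completed future models a read that stays inflight.
            doReturn(new CompletableFuture<String>()).when(fetcher).fetch(anyLong());
            System.out.println(fetcher.fetch(42L).isDone()); // false: still pending
        }
    }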