diff --git a/server/src/internalClusterTest/java/org/opensearch/indices/replication/SegmentReplicationBaseIT.java b/server/src/internalClusterTest/java/org/opensearch/indices/replication/SegmentReplicationBaseIT.java index 52fe85b51cebd..8fef88cbe6820 100644 --- a/server/src/internalClusterTest/java/org/opensearch/indices/replication/SegmentReplicationBaseIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/indices/replication/SegmentReplicationBaseIT.java @@ -24,7 +24,9 @@ import org.opensearch.index.IndexService; import org.opensearch.index.SegmentReplicationPerGroupStats; import org.opensearch.index.SegmentReplicationShardStats; +import org.opensearch.index.engine.Engine; import org.opensearch.index.shard.IndexShard; +import org.opensearch.index.shard.ShardId; import org.opensearch.index.store.Store; import org.opensearch.index.store.StoreFileMetadata; import org.opensearch.indices.IndicesService; @@ -158,6 +160,7 @@ protected void verifyStoreContent() throws Exception { final String indexName = primaryRouting.getIndexName(); final List replicaRouting = shardRoutingTable.replicaShards(); final IndexShard primaryShard = getIndexShard(clusterState, primaryRouting, indexName); + final int primaryDocCount = getDocCountFromShard(primaryShard); final Map primarySegmentMetadata = primaryShard.getSegmentMetadataMap(); for (ShardRouting replica : replicaRouting) { IndexShard replicaShard = getIndexShard(clusterState, replica, indexName); @@ -165,6 +168,8 @@ protected void verifyStoreContent() throws Exception { primarySegmentMetadata, replicaShard.getSegmentMetadataMap() ); + final int replicaDocCount = getDocCountFromShard(replicaShard); + assertEquals("Doc counts should match", primaryDocCount, replicaDocCount); if (recoveryDiff.missing.isEmpty() == false || recoveryDiff.different.isEmpty() == false) { fail( "Expected no missing or different segments between primary and replica but diff was missing: " @@ -185,10 +190,30 @@ protected void verifyStoreContent() throws Exception { }, 1, TimeUnit.MINUTES); } + private int getDocCountFromShard(IndexShard shard) { + try (final Engine.Searcher searcher = shard.acquireSearcher("test")) { + return searcher.getDirectoryReader().numDocs(); + } + } + private IndexShard getIndexShard(ClusterState state, ShardRouting routing, String indexName) { - return getIndexShard(state.nodes().get(routing.currentNodeId()).getName(), indexName); + return getIndexShard(state.nodes().get(routing.currentNodeId()).getName(), routing.shardId(), indexName); + } + + /** + * Fetch IndexShard by shardId, multiple shards per node allowed. + */ + protected IndexShard getIndexShard(String node, ShardId shardId, String indexName) { + final Index index = resolveIndex(indexName); + IndicesService indicesService = internalCluster().getInstance(IndicesService.class, node); + IndexService indexService = indicesService.indexServiceSafe(index); + final Optional id = indexService.shardIds().stream().filter(sid -> sid == shardId.id()).findFirst(); + return indexService.getShard(id.get()); } + /** + * Fetch IndexShard, assumes only a single shard per node. + */ protected IndexShard getIndexShard(String node, String indexName) { final Index index = resolveIndex(indexName); IndicesService indicesService = internalCluster().getInstance(IndicesService.class, node); diff --git a/server/src/internalClusterTest/java/org/opensearch/indices/replication/SegmentReplicationSuiteIT.java b/server/src/internalClusterTest/java/org/opensearch/indices/replication/SegmentReplicationSuiteIT.java new file mode 100644 index 0000000000000..9025c1cc79884 --- /dev/null +++ b/server/src/internalClusterTest/java/org/opensearch/indices/replication/SegmentReplicationSuiteIT.java @@ -0,0 +1,88 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.indices.replication; + +import org.junit.Before; +import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.common.settings.Settings; +import org.opensearch.indices.replication.common.ReplicationType; +import org.opensearch.test.OpenSearchIntegTestCase; + +@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.SUITE, minNumDataNodes = 2) +public class SegmentReplicationSuiteIT extends SegmentReplicationBaseIT { + + @Before + public void setup() { + internalCluster().startClusterManagerOnlyNode(); + createIndex(INDEX_NAME); + } + + @Override + public Settings indexSettings() { + final Settings.Builder builder = Settings.builder() + .put(super.indexSettings()) + // reset shard & replica count to random values set by OpenSearchIntegTestCase. + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numberOfShards()) + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, numberOfReplicas()) + .put(IndexMetadata.SETTING_REPLICATION_TYPE, ReplicationType.SEGMENT); + + // TODO: Randomly enable remote store on these tests. + return builder.build(); + } + + public void testBasicReplication() throws Exception { + final int docCount = scaledRandomIntBetween(10, 200); + for (int i = 0; i < docCount; i++) { + client().prepareIndex(INDEX_NAME).setId(Integer.toString(i)).setSource("field", "value" + i).execute().get(); + } + refresh(); + ensureGreen(INDEX_NAME); + verifyStoreContent(); + } + + public void testDropRandomNodeDuringReplication() throws Exception { + internalCluster().ensureAtLeastNumDataNodes(2); + internalCluster().startClusterManagerOnlyNodes(1); + + final int docCount = scaledRandomIntBetween(10, 200); + for (int i = 0; i < docCount; i++) { + client().prepareIndex(INDEX_NAME).setId(Integer.toString(i)).setSource("field", "value" + i).execute().get(); + } + refresh(); + + internalCluster().restartRandomDataNode(); + + ensureYellow(INDEX_NAME); + client().prepareIndex(INDEX_NAME).setId(Integer.toString(docCount)).setSource("field", "value" + docCount).execute().get(); + internalCluster().startDataOnlyNode(); + client().admin().indices().delete(new DeleteIndexRequest(INDEX_NAME)).actionGet(); + } + + public void testDeleteIndexWhileReplicating() throws Exception { + internalCluster().startClusterManagerOnlyNode(); + final int docCount = scaledRandomIntBetween(10, 200); + for (int i = 0; i < docCount; i++) { + client().prepareIndex(INDEX_NAME).setId(Integer.toString(i)).setSource("field", "value" + i).execute().get(); + } + refresh(INDEX_NAME); + client().admin().indices().delete(new DeleteIndexRequest(INDEX_NAME)).actionGet(); + } + + public void testFullRestartDuringReplication() throws Exception { + internalCluster().startNode(); + final int docCount = scaledRandomIntBetween(10, 200); + for (int i = 0; i < docCount; i++) { + client().prepareIndex(INDEX_NAME).setId(Integer.toString(i)).setSource("field", "value" + i).execute().get(); + } + refresh(INDEX_NAME); + internalCluster().fullRestart(); + ensureGreen(INDEX_NAME); + } +} diff --git a/server/src/main/java/org/opensearch/indices/recovery/RecoverySourceHandler.java b/server/src/main/java/org/opensearch/indices/recovery/RecoverySourceHandler.java index 5e278f06cfb8f..0b343fb0b0871 100644 --- a/server/src/main/java/org/opensearch/indices/recovery/RecoverySourceHandler.java +++ b/server/src/main/java/org/opensearch/indices/recovery/RecoverySourceHandler.java @@ -835,7 +835,7 @@ void finalizeRecovery(long targetLocalCheckpoint, long trimAboveSeqNo, ActionLis } else { // Force round of segment replication to update its checkpoint to primary's if (shard.indexSettings().isSegRepEnabled()) { - recoveryTarget.forceSegmentFileSync(); + cancellableThreads.execute(recoveryTarget::forceSegmentFileSync); } } stopWatch.stop(); diff --git a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationState.java b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationState.java index 7a996ec7aedaa..226ccbaf01afa 100644 --- a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationState.java +++ b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationState.java @@ -45,8 +45,7 @@ public enum Stage { GET_CHECKPOINT_INFO((byte) 3), FILE_DIFF((byte) 4), GET_FILES((byte) 5), - FINALIZE_REPLICATION((byte) 6), - CANCELLED((byte) 7); + FINALIZE_REPLICATION((byte) 6); private static final Stage[] STAGES = new Stage[Stage.values().length]; @@ -245,14 +244,6 @@ public void setStage(Stage stage) { overallTimer.stop(); timingData.put("OVERALL", overallTimer.time()); break; - case CANCELLED: - if (this.stage == Stage.DONE) { - throw new IllegalStateException("can't move replication to Cancelled state from Done."); - } - this.stage = Stage.CANCELLED; - overallTimer.stop(); - timingData.put("OVERALL", overallTimer.time()); - break; default: throw new IllegalArgumentException("unknown SegmentReplicationState.Stage [" + stage + "]"); } diff --git a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTarget.java b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTarget.java index 9d724d6cc9dcf..f59a7c2368689 100644 --- a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTarget.java +++ b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTarget.java @@ -12,16 +12,15 @@ import org.apache.lucene.index.CorruptIndexException; import org.apache.lucene.index.IndexFormatTooNewException; import org.apache.lucene.index.IndexFormatTooOldException; -import org.apache.lucene.index.SegmentInfos; +import org.apache.lucene.store.AlreadyClosedException; import org.apache.lucene.store.BufferedChecksumIndexInput; import org.apache.lucene.store.ByteBuffersDataInput; import org.apache.lucene.store.ByteBuffersIndexInput; import org.apache.lucene.store.ChecksumIndexInput; -import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchCorruptionException; import org.opensearch.OpenSearchException; import org.opensearch.action.ActionListener; import org.opensearch.action.StepListener; -import org.opensearch.common.CheckedConsumer; import org.opensearch.common.UUIDs; import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.lucene.Lucene; @@ -39,6 +38,8 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.Arrays; +import java.util.List; +import java.util.Locale; /** * Represents the target of a replication event. @@ -102,17 +103,11 @@ public SegmentReplicationTarget retryCopy() { @Override public String description() { - return "Segment replication from " + source.toString(); + return String.format(Locale.ROOT, "Id:[%d] Shard:[%s] Source:[%s]", getId(), shardId(), source.getDescription()); } @Override public void notifyListener(ReplicationFailedException e, boolean sendShardFailure) { - // Cancellations still are passed to our SegmentReplicationListener as failures, if we have failed because of cancellation - // update the stage. - final Throwable cancelledException = ExceptionsHelper.unwrap(e, CancellableThreads.ExecutionCancelledException.class); - if (cancelledException != null) { - state.setStage(SegmentReplicationState.Stage.CANCELLED); - } listener.onFailure(state(), e, sendShardFailure); } @@ -141,141 +136,115 @@ public void writeFileChunk( /** * Start the Replication event. + * * @param listener {@link ActionListener} listener. */ public void startReplication(ActionListener listener) { cancellableThreads.setOnCancel((reason, beforeCancelEx) -> { - // This method only executes when cancellation is triggered by this node and caught by a call to checkForCancel, - // SegmentReplicationSource does not share CancellableThreads. - final CancellableThreads.ExecutionCancelledException executionCancelledException = - new CancellableThreads.ExecutionCancelledException("replication was canceled reason [" + reason + "]"); - notifyListener(new ReplicationFailedException("Segment replication failed", executionCancelledException), false); - throw executionCancelledException; + throw new CancellableThreads.ExecutionCancelledException("replication was canceled reason [" + reason + "]"); }); + // TODO: Remove this useless state. state.setStage(SegmentReplicationState.Stage.REPLICATING); final StepListener checkpointInfoListener = new StepListener<>(); final StepListener getFilesListener = new StepListener<>(); - final StepListener finalizeListener = new StepListener<>(); - cancellableThreads.checkForCancel(); - logger.trace("[shardId {}] Replica starting replication [id {}]", shardId().getId(), getId()); + logger.trace(new ParameterizedMessage("Starting Replication Target: {}", description())); // Get list of files to copy from this checkpoint. state.setStage(SegmentReplicationState.Stage.GET_CHECKPOINT_INFO); + cancellableThreads.checkForCancel(); source.getCheckpointMetadata(getId(), checkpoint, checkpointInfoListener); - checkpointInfoListener.whenComplete(checkpointInfo -> getFiles(checkpointInfo, getFilesListener), listener::onFailure); - getFilesListener.whenComplete( - response -> finalizeReplication(checkpointInfoListener.result(), finalizeListener), - listener::onFailure - ); - finalizeListener.whenComplete(r -> listener.onResponse(null), listener::onFailure); + checkpointInfoListener.whenComplete(checkpointInfo -> { + final List filesToFetch = getFiles(checkpointInfo); + state.setStage(SegmentReplicationState.Stage.GET_FILES); + cancellableThreads.checkForCancel(); + source.getSegmentFiles(getId(), checkpointInfo.getCheckpoint(), filesToFetch, indexShard, getFilesListener); + }, listener::onFailure); + + getFilesListener.whenComplete(response -> { + finalizeReplication(checkpointInfoListener.result()); + listener.onResponse(null); + }, listener::onFailure); } - private void getFiles(CheckpointInfoResponse checkpointInfo, StepListener getFilesListener) - throws IOException { + private List getFiles(CheckpointInfoResponse checkpointInfo) throws IOException { cancellableThreads.checkForCancel(); state.setStage(SegmentReplicationState.Stage.FILE_DIFF); final Store.RecoveryDiff diff = Store.segmentReplicationDiff(checkpointInfo.getMetadataMap(), indexShard.getSegmentMetadataMap()); - logger.trace("Replication diff for checkpoint {} {}", checkpointInfo.getCheckpoint(), diff); + logger.trace(() -> new ParameterizedMessage("Replication diff for checkpoint {} {}", checkpointInfo.getCheckpoint(), diff)); /* * Segments are immutable. So if the replica has any segments with the same name that differ from the one in the incoming * snapshot from source that means the local copy of the segment has been corrupted/changed in some way and we throw an * IllegalStateException to fail the shard */ if (diff.different.isEmpty() == false) { - IllegalStateException illegalStateException = new IllegalStateException( + throw new OpenSearchCorruptionException( new ParameterizedMessage( "Shard {} has local copies of segments that differ from the primary {}", indexShard.shardId(), diff.different ).getFormattedMessage() ); - ReplicationFailedException rfe = new ReplicationFailedException( - indexShard.shardId(), - "different segment files", - illegalStateException - ); - fail(rfe, true); - throw rfe; } for (StoreFileMetadata file : diff.missing) { state.getIndex().addFileDetail(file.name(), file.length(), false); } - // always send a req even if not fetching files so the primary can clear the copyState for this shard. - state.setStage(SegmentReplicationState.Stage.GET_FILES); - cancellableThreads.checkForCancel(); - source.getSegmentFiles(getId(), checkpointInfo.getCheckpoint(), diff.missing, indexShard, getFilesListener); + return diff.missing; } - private void finalizeReplication(CheckpointInfoResponse checkpointInfoResponse, ActionListener listener) { + private void finalizeReplication(CheckpointInfoResponse checkpointInfoResponse) throws OpenSearchCorruptionException { // TODO: Refactor the logic so that finalize doesn't have to be invoked for remote store as source if (source instanceof RemoteStoreReplicationSource) { - ActionListener.completeWith(listener, () -> { - state.setStage(SegmentReplicationState.Stage.FINALIZE_REPLICATION); - return null; - }); + state.setStage(SegmentReplicationState.Stage.FINALIZE_REPLICATION); return; } - ActionListener.completeWith(listener, () -> { - cancellableThreads.checkForCancel(); - state.setStage(SegmentReplicationState.Stage.FINALIZE_REPLICATION); - Store store = null; + cancellableThreads.checkForCancel(); + state.setStage(SegmentReplicationState.Stage.FINALIZE_REPLICATION); + Store store = null; + try { + store = store(); + store.incRef(); + store.buildInfosFromBytes( + multiFileWriter.getTempFileNames(), + checkpointInfoResponse.getInfosBytes(), + checkpointInfoResponse.getCheckpoint().getSegmentsGen(), + indexShard::finalizeReplication + ); + } catch (CorruptIndexException | IndexFormatTooNewException | IndexFormatTooOldException ex) { + // this is a fatal exception at this stage. + // this means we transferred files from the remote that have not be checksummed and they are + // broken. We have to clean up this shard entirely, remove all files and bubble it up to the + // source shard since this index might be broken there as well? The Source can handle this and checks + // its content on disk if possible. try { - store = store(); - store.incRef(); - CheckedConsumer finalizeReplication = indexShard::finalizeReplication; - store.buildInfosFromBytes( - multiFileWriter.getTempFileNames(), - checkpointInfoResponse.getInfosBytes(), - checkpointInfoResponse.getCheckpoint().getSegmentsGen(), - finalizeReplication - ); - } catch (CorruptIndexException | IndexFormatTooNewException | IndexFormatTooOldException ex) { - // this is a fatal exception at this stage. - // this means we transferred files from the remote that have not be checksummed and they are - // broken. We have to clean up this shard entirely, remove all files and bubble it up to the - // source shard since this index might be broken there as well? The Source can handle this and checks - // its content on disk if possible. try { - try { - store.removeCorruptionMarker(); - } finally { - Lucene.cleanLuceneIndex(store.directory()); // clean up and delete all files - } - } catch (Exception e) { - logger.debug("Failed to clean lucene index", e); - ex.addSuppressed(e); - } - ReplicationFailedException rfe = new ReplicationFailedException( - indexShard.shardId(), - "failed to clean after replication", - ex - ); - fail(rfe, true); - throw rfe; - } catch (OpenSearchException ex) { - /* - Ignore closed replication target as it can happen due to index shard closed event in a separate thread. - In such scenario, ignore the exception - */ - assert cancellableThreads.isCancelled() : "Replication target closed but segment replication not cancelled"; - logger.info("Replication target closed", ex); - } catch (Exception ex) { - ReplicationFailedException rfe = new ReplicationFailedException( - indexShard.shardId(), - "failed to clean after replication", - ex - ); - fail(rfe, true); - throw rfe; - } finally { - if (store != null) { - store.decRef(); + store.removeCorruptionMarker(); + } finally { + Lucene.cleanLuceneIndex(store.directory()); // clean up and delete all files } + } catch (Exception e) { + logger.debug("Failed to clean lucene index", e); + ex.addSuppressed(e); } - return null; - }); + throw new OpenSearchCorruptionException(ex); + } catch (AlreadyClosedException ex) { + // In this case the shard is closed at some point while updating the reader. + // This can happen when the engine is closed in a separate thread. + logger.warn("Shard is already closed, closing replication"); + } catch (OpenSearchException ex) { + /* + Ignore closed replication target as it can happen due to index shard closed event in a separate thread. + In such scenario, ignore the exception + */ + assert cancellableThreads.isCancelled() : "Replication target closed but segment replication not cancelled"; + } catch (Exception ex) { + throw new OpenSearchCorruptionException(ex); + } finally { + if (store != null) { + store.decRef(); + } + } } /** @@ -288,10 +257,15 @@ private ChecksumIndexInput toIndexInput(byte[] input) { ); } + /** + * Trigger a cancellation, this method will not close the target a subsequent call to #fail is required from target service. + */ @Override - protected void onCancel(String reason) { - cancellableThreads.cancel(reason); - source.cancel(); - multiFileWriter.close(); + public void cancel(String reason) { + if (finished.get() == false) { + logger.trace(new ParameterizedMessage("Cancelling replication for target {}", description())); + cancellableThreads.cancel(reason); + source.cancel(); + } } } diff --git a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java index a7e0c0ec887ab..ac93412eef725 100644 --- a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java +++ b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java @@ -11,14 +11,15 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; -import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchCorruptionException; import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ChannelActionListener; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.routing.ShardRouting; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Nullable; import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.CancellableThreads; +import org.opensearch.common.util.concurrent.AbstractRunnable; import org.opensearch.common.util.concurrent.ConcurrentCollections; import org.opensearch.index.shard.IndexEventListener; import org.opensearch.index.shard.IndexShard; @@ -43,10 +44,8 @@ import org.opensearch.transport.TransportResponse; import org.opensearch.transport.TransportService; -import java.io.IOException; import java.util.Map; import java.util.Optional; -import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicLong; import static org.opensearch.indices.replication.SegmentReplicationSourceService.Actions.UPDATE_VISIBLE_CHECKPOINT; @@ -145,7 +144,7 @@ public SegmentReplicationTargetService( @Override public void beforeIndexShardClosed(ShardId shardId, @Nullable IndexShard indexShard, Settings indexSettings) { if (indexShard != null && indexShard.indexSettings().isSegRepEnabled()) { - onGoingReplications.cancelForShard(shardId, "shard closed"); + onGoingReplications.requestCancel(indexShard.shardId(), "Shard closing"); latestReceivedCheckpoint.remove(shardId); } } @@ -167,7 +166,7 @@ public void afterIndexShardStarted(IndexShard indexShard) { @Override public void shardRoutingChanged(IndexShard indexShard, @Nullable ShardRouting oldRouting, ShardRouting newRouting) { if (oldRouting != null && indexShard.indexSettings().isSegRepEnabled() && oldRouting.primary() == false && newRouting.primary()) { - onGoingReplications.cancelForShard(indexShard.shardId(), "shard has been promoted to primary"); + onGoingReplications.requestCancel(indexShard.shardId(), "Shard has been promoted to primary"); latestReceivedCheckpoint.remove(indexShard.shardId()); } } @@ -224,11 +223,13 @@ public synchronized void onNewCheckpoint(final ReplicationCheckpoint receivedChe if (ongoingReplicationTarget != null) { if (ongoingReplicationTarget.getCheckpoint().getPrimaryTerm() < receivedCheckpoint.getPrimaryTerm()) { logger.trace( - "Cancelling ongoing replication from old primary with primary term {}", - ongoingReplicationTarget.getCheckpoint().getPrimaryTerm() + () -> new ParameterizedMessage( + "Cancelling ongoing replication {} from old primary with primary term {}", + ongoingReplicationTarget.description(), + ongoingReplicationTarget.getCheckpoint().getPrimaryTerm() + ) ); - onGoingReplications.cancel(ongoingReplicationTarget.getId(), "Cancelling stuck target after new primary"); - completedReplications.put(replicaShard.shardId(), ongoingReplicationTarget); + ongoingReplicationTarget.cancel("Cancelling stuck target after new primary"); } else { logger.trace( () -> new ParameterizedMessage( @@ -268,21 +269,20 @@ public void onReplicationFailure( ReplicationFailedException e, boolean sendShardFailure ) { - logger.trace( + logger.error( () -> new ParameterizedMessage( "[shardId {}] [replication id {}] Replication failed, timing data: {}", replicaShard.shardId().getId(), state.getReplicationId(), state.getTimingData() - ) + ), + e ); if (sendShardFailure == true) { - logger.error("replication failure", e); - replicaShard.failShard("replication failure", e); + failShard(e, replicaShard); } } }); - } } else { logger.trace( @@ -305,7 +305,14 @@ protected void updateVisibleCheckpoint(long replicationId, IndexShard replicaSha final TransportRequestOptions options = TransportRequestOptions.builder() .withTimeout(recoverySettings.internalActionTimeout()) .build(); - logger.debug("Updating replication checkpoint to {}", request.getCheckpoint()); + logger.trace( + () -> new ParameterizedMessage( + "Updating Primary shard that replica {}-{} is synced to checkpoint {}", + replicaShard.shardId(), + replicaShard.routingEntry().allocationId(), + request.getCheckpoint() + ) + ); RetryableTransportClient transportClient = new RetryableTransportClient( transportService, getPrimaryNode(primaryShard), @@ -315,19 +322,23 @@ protected void updateVisibleCheckpoint(long replicationId, IndexShard replicaSha final ActionListener listener = new ActionListener<>() { @Override public void onResponse(Void unused) { - logger.debug( - "Successfully updated replication checkpoint {} for replica {}", - replicaShard.shardId(), - request.getCheckpoint() + logger.trace( + () -> new ParameterizedMessage( + "Successfully updated replication checkpoint {} for replica {}", + replicaShard.shardId(), + request.getCheckpoint() + ) ); } @Override public void onFailure(Exception e) { logger.error( - "Failed to update visible checkpoint for replica {}, {}: {}", - replicaShard.shardId(), - request.getCheckpoint(), + () -> new ParameterizedMessage( + "Failed to update visible checkpoint for replica {}, {}:", + replicaShard.shardId(), + request.getCheckpoint() + ), e ); } @@ -350,6 +361,13 @@ private DiscoveryNode getPrimaryNode(ShardRouting primaryShard) { protected boolean processLatestReceivedCheckpoint(IndexShard replicaShard, Thread thread) { final ReplicationCheckpoint latestPublishedCheckpoint = latestReceivedCheckpoint.get(replicaShard.shardId()); if (latestPublishedCheckpoint != null && latestPublishedCheckpoint.isAheadOf(replicaShard.getLatestReplicationCheckpoint())) { + logger.trace( + () -> new ParameterizedMessage( + "Processing latest received checkpoint for shard {} {}", + replicaShard.shardId(), + latestPublishedCheckpoint + ) + ); Runnable runnable = () -> onNewCheckpoint(latestReceivedCheckpoint.get(replicaShard.shardId()), replicaShard); // Checks if we are using same thread and forks if necessary. if (thread == Thread.currentThread()) { @@ -381,7 +399,15 @@ public SegmentReplicationTarget startReplication(final IndexShard indexShard, fi // pkg-private for integration tests void startReplication(final SegmentReplicationTarget target) { - final long replicationId = onGoingReplications.start(target, recoverySettings.activityTimeout()); + final long replicationId; + try { + replicationId = onGoingReplications.startSafe(target, recoverySettings.activityTimeout()); + } catch (ReplicationFailedException e) { + // replication already running for shard. + target.fail(e, false); + return; + } + logger.trace(() -> new ParameterizedMessage("Added new replication to collection {}", target.description())); threadPool.generic().execute(new ReplicationRunner(replicationId)); } @@ -410,7 +436,7 @@ default void onFailure(ReplicationState state, ReplicationFailedException e, boo /** * Runnable implementation to trigger a replication event. */ - private class ReplicationRunner implements Runnable { + private class ReplicationRunner extends AbstractRunnable { final long replicationId; @@ -419,47 +445,49 @@ public ReplicationRunner(long replicationId) { } @Override - public void run() { + public void onFailure(Exception e) { + try (final ReplicationRef ref = onGoingReplications.get(replicationId)) { + logger.error(() -> new ParameterizedMessage("Error during segment replication, {}", ref.get().description()), e); + } + onGoingReplications.fail(replicationId, new ReplicationFailedException("Unexpected Error during replication", e), false); + } + + @Override + public void doRun() { start(replicationId); } } private void start(final long replicationId) { + final SegmentReplicationTarget target; try (ReplicationRef replicationRef = onGoingReplications.get(replicationId)) { // This check is for handling edge cases where the reference is removed before the ReplicationRunner is started by the // threadpool. if (replicationRef == null) { return; } - SegmentReplicationTarget target = onGoingReplications.getTarget(replicationId); - replicationRef.get().startReplication(new ActionListener<>() { - @Override - public void onResponse(Void o) { - onGoingReplications.markAsDone(replicationId); - if (target.state().getIndex().recoveredFileCount() != 0 && target.state().getIndex().recoveredBytes() != 0) { - completedReplications.put(target.shardId(), target); - } - + target = replicationRef.get(); + } + target.startReplication(new ActionListener<>() { + @Override + public void onResponse(Void o) { + logger.trace(() -> new ParameterizedMessage("Finished replicating {} marking as done.", target.description())); + onGoingReplications.markAsDone(replicationId); + if (target.state().getIndex().recoveredFileCount() != 0 && target.state().getIndex().recoveredBytes() != 0) { + completedReplications.put(target.shardId(), target); } + } - @Override - public void onFailure(Exception e) { - Throwable cause = ExceptionsHelper.unwrapCause(e); - if (cause instanceof CancellableThreads.ExecutionCancelledException) { - if (onGoingReplications.getTarget(replicationId) != null) { - IndexShard indexShard = onGoingReplications.getTarget(replicationId).indexShard(); - // if the target still exists in our collection, the primary initiated the cancellation, fail the replication - // but do not fail the shard. Cancellations initiated by this node from Index events will be removed with - // onGoingReplications.cancel and not appear in the collection when this listener resolves. - onGoingReplications.fail(replicationId, new ReplicationFailedException(indexShard, cause), false); - completedReplications.put(target.shardId(), target); - } - } else { - onGoingReplications.fail(replicationId, new ReplicationFailedException("Segment Replication failed", e), false); - } + @Override + public void onFailure(Exception e) { + logger.error(() -> new ParameterizedMessage("Exception replicating {} marking as failed.", target.description()), e); + if (e instanceof OpenSearchCorruptionException) { + onGoingReplications.fail(replicationId, new ReplicationFailedException("Store corruption during replication", e), true); + return; } - }); - } + onGoingReplications.fail(replicationId, new ReplicationFailedException("Segment Replication failed", e), false); + } + }); } private class FileChunkTransportRequestHandler implements TransportRequestHandler { @@ -484,27 +512,31 @@ public void messageReceived(final FileChunkRequest request, TransportChannel cha private class ForceSyncTransportRequestHandler implements TransportRequestHandler { @Override public void messageReceived(final ForceSyncRequest request, TransportChannel channel, Task task) throws Exception { - assert indicesService != null; - final IndexShard indexShard = indicesService.getShardOrNull(request.getShardId()); - // Proceed with round of segment replication only when it is allowed - if (indexShard == null || indexShard.getReplicationEngine().isEmpty()) { - logger.info("Ignore force segment replication sync as it is not allowed"); - channel.sendResponse(TransportResponse.Empty.INSTANCE); - return; - } + forceReplication(request, new ChannelActionListener<>(channel, Actions.FORCE_SYNC, request)); + } + } + + private void forceReplication(ForceSyncRequest request, ActionListener listener) { + final ShardId shardId = request.getShardId(); + assert indicesService != null; + final IndexShard indexShard = indicesService.getShardOrNull(shardId); + // Proceed with round of segment replication only when it is allowed + if (indexShard == null || indexShard.getReplicationEngine().isEmpty()) { + listener.onResponse(TransportResponse.Empty.INSTANCE); + } else { startReplication(indexShard, new SegmentReplicationTargetService.SegmentReplicationListener() { @Override public void onReplicationDone(SegmentReplicationState state) { - logger.trace( - () -> new ParameterizedMessage( - "[shardId {}] [replication id {}] Replication complete to {}, timing data: {}", - indexShard.shardId().getId(), - state.getReplicationId(), - indexShard.getLatestReplicationCheckpoint(), - state.getTimingData() - ) - ); try { + logger.trace( + () -> new ParameterizedMessage( + "[shardId {}] [replication id {}] Force replication Sync complete to {}, timing data: {}", + shardId, + state.getReplicationId(), + indexShard.getLatestReplicationCheckpoint(), + state.getTimingData() + ) + ); // Promote engine type for primary target if (indexShard.recoveryState().getPrimary() == true) { indexShard.resetToWriteableEngine(); @@ -512,33 +544,40 @@ public void onReplicationDone(SegmentReplicationState state) { // Update the replica's checkpoint on primary's replication tracker. updateVisibleCheckpoint(state.getReplicationId(), indexShard); } - channel.sendResponse(TransportResponse.Empty.INSTANCE); - } catch (InterruptedException | TimeoutException | IOException e) { - throw new RuntimeException(e); + listener.onResponse(TransportResponse.Empty.INSTANCE); + } catch (Exception e) { + logger.error("Error while marking replication completed", e); + listener.onFailure(e); } } @Override public void onReplicationFailure(SegmentReplicationState state, ReplicationFailedException e, boolean sendShardFailure) { - logger.trace( + logger.error( () -> new ParameterizedMessage( "[shardId {}] [replication id {}] Replication failed, timing data: {}", indexShard.shardId().getId(), state.getReplicationId(), state.getTimingData() - ) + ), + e ); - if (sendShardFailure == true) { - indexShard.failShard("replication failure", e); - } - try { - channel.sendResponse(e); - } catch (IOException ex) { - throw new RuntimeException(ex); + if (sendShardFailure) { + failShard(e, indexShard); } + listener.onFailure(e); } }); } } + private void failShard(ReplicationFailedException e, IndexShard indexShard) { + try { + indexShard.failShard("unrecoverable replication failure", e); + } catch (Exception inner) { + logger.error("Error attempting to fail shard", inner); + e.addSuppressed(inner); + } + } + } diff --git a/server/src/main/java/org/opensearch/indices/replication/common/ReplicationCollection.java b/server/src/main/java/org/opensearch/indices/replication/common/ReplicationCollection.java index e918ac0a79691..c65ef27969154 100644 --- a/server/src/main/java/org/opensearch/indices/replication/common/ReplicationCollection.java +++ b/server/src/main/java/org/opensearch/indices/replication/common/ReplicationCollection.java @@ -70,6 +70,26 @@ public ReplicationCollection(Logger logger, ThreadPool threadPool) { this.threadPool = threadPool; } + /** + * Starts a new target event for a given shard, fails the given target if this shard is already replicating. + * @param target ReplicationTarget to start + * @param activityTimeout timeout for entire replication event + * @return The replication id + */ + public long startSafe(T target, TimeValue activityTimeout) { + synchronized (onGoingTargetEvents) { + final boolean isPresent = onGoingTargetEvents.values() + .stream() + .map(ReplicationTarget::shardId) + .anyMatch(t -> t.equals(target.shardId())); + if (isPresent) { + throw new ReplicationFailedException("Shard " + target.shardId() + " is already replicating"); + } else { + return start(target, activityTimeout); + } + } + } + /** * Starts a new target event for the given shard, source node and state * @@ -234,6 +254,22 @@ public boolean cancelForShard(ShardId shardId, String reason) { return cancelled; } + /** + * Trigger cancel on the target but do not remove it from the collection. + * This is intended to be called to ensure replication events are removed from the collection + * only when the target has closed. + * + * @param shardId {@link ShardId} shard events to cancel + * @param reason {@link String} reason for cancellation + */ + public void requestCancel(ShardId shardId, String reason) { + for (T value : onGoingTargetEvents.values()) { + if (value.shardId().equals(shardId)) { + value.cancel(reason); + } + } + } + /** * Get target for shard * diff --git a/server/src/main/java/org/opensearch/indices/replication/common/ReplicationTarget.java b/server/src/main/java/org/opensearch/indices/replication/common/ReplicationTarget.java index 4d75ff4896706..344a4040be119 100644 --- a/server/src/main/java/org/opensearch/indices/replication/common/ReplicationTarget.java +++ b/server/src/main/java/org/opensearch/indices/replication/common/ReplicationTarget.java @@ -173,6 +173,7 @@ public void cancel(String reason) { public void fail(ReplicationFailedException e, boolean sendShardFailure) { if (finished.compareAndSet(false, true)) { try { + logger.debug("marking target " + description() + " as failed", e); notifyListener(e, sendShardFailure); } finally { try { diff --git a/server/src/test/java/org/opensearch/index/shard/SegmentReplicationIndexShardTests.java b/server/src/test/java/org/opensearch/index/shard/SegmentReplicationIndexShardTests.java index b3876a8ea8fd0..1f5980ba9bfe0 100644 --- a/server/src/test/java/org/opensearch/index/shard/SegmentReplicationIndexShardTests.java +++ b/server/src/test/java/org/opensearch/index/shard/SegmentReplicationIndexShardTests.java @@ -10,6 +10,7 @@ import org.apache.lucene.codecs.Codec; import org.apache.lucene.index.SegmentInfos; +import org.apache.lucene.store.AlreadyClosedException; import org.junit.Assert; import org.opensearch.ExceptionsHelper; import org.opensearch.action.ActionListener; @@ -86,11 +87,12 @@ import static org.hamcrest.Matchers.instanceOf; import static org.mockito.Mockito.any; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.mockito.Mockito.spy; public class SegmentReplicationIndexShardTests extends OpenSearchIndexLevelReplicationTestCase { @@ -1156,6 +1158,125 @@ public void getSegmentFiles( } } + public void testCloseShardDuringFinalize() throws Exception { + try (ReplicationGroup shards = createGroup(1, settings, new NRTReplicationEngineFactory())) { + shards.startAll(); + IndexShard primary = shards.getPrimary(); + final IndexShard replica = shards.getReplicas().get(0); + final IndexShard replicaSpy = spy(replica); + + primary.refresh("Test"); + + doThrow(AlreadyClosedException.class).when(replicaSpy).finalizeReplication(any()); + + replicateSegments(primary, List.of(replicaSpy)); + } + } + + public void testCloseShardWhileGettingCheckpoint() throws Exception { + try (ReplicationGroup shards = createGroup(1, settings, new NRTReplicationEngineFactory())) { + shards.startAll(); + IndexShard primary = shards.getPrimary(); + final IndexShard replica = shards.getReplicas().get(0); + + primary.refresh("Test"); + + final SegmentReplicationSourceFactory sourceFactory = mock(SegmentReplicationSourceFactory.class); + final SegmentReplicationTargetService targetService = newTargetService(sourceFactory); + SegmentReplicationSource source = new TestReplicationSource() { + + ActionListener listener; + + @Override + public void getCheckpointMetadata( + long replicationId, + ReplicationCheckpoint checkpoint, + ActionListener listener + ) { + // set the listener, we will only fail it once we cancel the source. + this.listener = listener; + // shard is closing while we are copying files. + targetService.beforeIndexShardClosed(replica.shardId, replica, Settings.EMPTY); + } + + @Override + public void getSegmentFiles( + long replicationId, + ReplicationCheckpoint checkpoint, + List filesToFetch, + IndexShard indexShard, + ActionListener listener + ) { + Assert.fail("Unreachable"); + } + + @Override + public void cancel() { + // simulate listener resolving, but only after we have issued a cancel from beforeIndexShardClosed . + final RuntimeException exception = new CancellableThreads.ExecutionCancelledException("retryable action was cancelled"); + listener.onFailure(exception); + } + }; + when(sourceFactory.get(any())).thenReturn(source); + startReplicationAndAssertCancellation(replica, targetService); + + shards.removeReplica(replica); + closeShards(replica); + } + } + + public void testBeforeIndexShardClosedWhileCopyingFiles() throws Exception { + try (ReplicationGroup shards = createGroup(1, settings, new NRTReplicationEngineFactory())) { + shards.startAll(); + IndexShard primary = shards.getPrimary(); + final IndexShard replica = shards.getReplicas().get(0); + + primary.refresh("Test"); + + final SegmentReplicationSourceFactory sourceFactory = mock(SegmentReplicationSourceFactory.class); + final SegmentReplicationTargetService targetService = newTargetService(sourceFactory); + SegmentReplicationSource source = new TestReplicationSource() { + + ActionListener listener; + + @Override + public void getCheckpointMetadata( + long replicationId, + ReplicationCheckpoint checkpoint, + ActionListener listener + ) { + resolveCheckpointInfoResponseListener(listener, primary); + } + + @Override + public void getSegmentFiles( + long replicationId, + ReplicationCheckpoint checkpoint, + List filesToFetch, + IndexShard indexShard, + ActionListener listener + ) { + // set the listener, we will only fail it once we cancel the source. + this.listener = listener; + // shard is closing while we are copying files. + targetService.beforeIndexShardClosed(replica.shardId, replica, Settings.EMPTY); + } + + @Override + public void cancel() { + // simulate listener resolving, but only after we have issued a cancel from beforeIndexShardClosed . + final RuntimeException exception = new CancellableThreads.ExecutionCancelledException("retryable action was cancelled"); + listener.onFailure(exception); + } + }; + when(sourceFactory.get(any())).thenReturn(source); + startReplicationAndAssertCancellation(replica, targetService); + + shards.removeReplica(replica); + closeShards(replica); + } + } + public void testPrimaryCancelsExecution() throws Exception { try (ReplicationGroup shards = createGroup(1, settings, new NRTReplicationEngineFactory())) { shards.startAll(); @@ -1245,7 +1366,6 @@ public void onReplicationDone(SegmentReplicationState state) { @Override public void onReplicationFailure(SegmentReplicationState state, ReplicationFailedException e, boolean sendShardFailure) { assertFalse(sendShardFailure); - assertEquals(SegmentReplicationState.Stage.CANCELLED, state.getStage()); latch.countDown(); } } diff --git a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java index c632f2843cba2..32203c28c8ed8 100644 --- a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java +++ b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java @@ -8,8 +8,8 @@ package org.opensearch.indices.replication; +import org.apache.lucene.store.AlreadyClosedException; import org.junit.Assert; -import org.mockito.Mockito; import org.opensearch.OpenSearchException; import org.opensearch.Version; import org.opensearch.action.ActionListener; @@ -26,6 +26,7 @@ import org.opensearch.index.replication.TestReplicationSource; import org.opensearch.index.shard.IndexShard; import org.opensearch.index.shard.IndexShardTestCase; +import org.opensearch.index.shard.ShardId; import org.opensearch.index.store.StoreFileMetadata; import org.opensearch.indices.IndicesService; import org.opensearch.indices.recovery.ForceSyncRequest; @@ -49,6 +50,7 @@ import java.util.concurrent.TimeUnit; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doNothing; @@ -62,7 +64,6 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; -import static org.opensearch.indices.replication.SegmentReplicationState.Stage.CANCELLED; public class SegmentReplicationTargetServiceTests extends IndexShardTestCase { @@ -240,71 +241,81 @@ public void testAlreadyOnNewCheckpoint() { verify(spy, times(0)).startReplication(any(), any()); } - public void testShardAlreadyReplicating() throws InterruptedException { - // Create a spy of Target Service so that we can verify invocation of startReplication call with specific checkpoint on it. - SegmentReplicationTargetService serviceSpy = spy(sut); - final SegmentReplicationTarget target = new SegmentReplicationTarget( - replicaShard, - replicationSource, - mock(SegmentReplicationTargetService.SegmentReplicationListener.class) - ); - // Create a Mockito spy of target to stub response of few method calls. - final SegmentReplicationTarget targetSpy = Mockito.spy(target); - CountDownLatch latch = new CountDownLatch(1); - // Mocking response when startReplication is called on targetSpy we send a new checkpoint to serviceSpy and later reduce countdown - // of latch. - doAnswer(invocation -> { - final ActionListener listener = invocation.getArgument(0); - // a new checkpoint arrives before we've completed. - serviceSpy.onNewCheckpoint(aheadCheckpoint, replicaShard); - listener.onResponse(null); - latch.countDown(); - return null; - }).when(targetSpy).startReplication(any()); - doNothing().when(targetSpy).onDone(); - - // start replication of this shard the first time. - serviceSpy.startReplication(targetSpy); + public void testShardAlreadyReplicating() { + sut.startReplication(replicaShard, mock(SegmentReplicationTargetService.SegmentReplicationListener.class)); + sut.startReplication(replicaShard, new SegmentReplicationTargetService.SegmentReplicationListener() { + @Override + public void onReplicationDone(SegmentReplicationState state) { + Assert.fail("Should not succeed"); + } - // wait for the new checkpoint to arrive, before the listener completes. - latch.await(30, TimeUnit.SECONDS); - verify(targetSpy, times(0)).cancel(any()); - verify(serviceSpy, times(0)).startReplication(eq(replicaShard), any()); + @Override + public void onReplicationFailure(SegmentReplicationState state, ReplicationFailedException e, boolean sendShardFailure) { + assertEquals("Shard " + replicaShard.shardId() + " is already replicating", e.getMessage()); + assertFalse(sendShardFailure); + } + }); } - public void testOnNewCheckpointFromNewPrimaryCancelOngoingReplication() throws IOException, InterruptedException { + public void testOnNewCheckpointFromNewPrimaryCancelOngoingReplication() throws InterruptedException { // Create a spy of Target Service so that we can verify invocation of startReplication call with specific checkpoint on it. SegmentReplicationTargetService serviceSpy = spy(sut); + doNothing().when(serviceSpy).updateVisibleCheckpoint(anyLong(), any()); + // skip post replication actions so we can assert execution counts. This will continue to process bc replica's pterm is not advanced + // post replication. + doReturn(true).when(serviceSpy).processLatestReceivedCheckpoint(any(), any()); // Create a Mockito spy of target to stub response of few method calls. - final SegmentReplicationTarget targetSpy = spy( - new SegmentReplicationTarget( - replicaShard, - replicationSource, - mock(SegmentReplicationTargetService.SegmentReplicationListener.class) - ) - ); CountDownLatch latch = new CountDownLatch(1); - // Mocking response when startReplication is called on targetSpy we send a new checkpoint to serviceSpy and later reduce countdown - // of latch. - doAnswer(invocation -> { - // short circuit loop on new checkpoint request - doReturn(null).when(serviceSpy).startReplication(eq(replicaShard), any()); - // a new checkpoint arrives before we've completed. - serviceSpy.onNewCheckpoint(newPrimaryCheckpoint, replicaShard); - try { - invocation.callRealMethod(); - } catch (CancellableThreads.ExecutionCancelledException e) { + SegmentReplicationSource source = new TestReplicationSource() { + + ActionListener listener; + + @Override + public void getCheckpointMetadata( + long replicationId, + ReplicationCheckpoint checkpoint, + ActionListener listener + ) { + // set the listener, we will only fail it once we cancel the source. + this.listener = listener; latch.countDown(); + // do not resolve this listener yet, wait for cancel to hit. } - return null; - }).when(targetSpy).startReplication(any()); + + @Override + public void getSegmentFiles( + long replicationId, + ReplicationCheckpoint checkpoint, + List filesToFetch, + IndexShard indexShard, + ActionListener listener + ) { + Assert.fail("Unreachable"); + } + + @Override + public void cancel() { + // simulate listener resolving, but only after we have issued a cancel from beforeIndexShardClosed . + final RuntimeException exception = new CancellableThreads.ExecutionCancelledException("retryable action was cancelled"); + listener.onFailure(exception); + } + }; + + final SegmentReplicationTarget targetSpy = spy( + new SegmentReplicationTarget(replicaShard, source, mock(SegmentReplicationTargetService.SegmentReplicationListener.class)) + ); // start replication. This adds the target to on-ongoing replication collection serviceSpy.startReplication(targetSpy); + + // wait until we get to getCheckpoint step. latch.await(); - // wait for the new checkpoint to arrive, before the listener completes. - assertEquals(CANCELLED, targetSpy.state().getStage()); + + // new checkpoint arrives with higher pterm. + serviceSpy.onNewCheckpoint(newPrimaryCheckpoint, replicaShard); + + // ensure the old target is cancelled. and new iteration kicks off. verify(targetSpy, times(1)).cancel("Cancelling stuck target after new primary"); verify(serviceSpy, times(1)).startReplication(eq(replicaShard), any()); } @@ -467,6 +478,7 @@ public void testForceSegmentSyncHandler() throws Exception { } public void testForceSegmentSyncHandlerWithFailure() throws Exception { + allowShardFailures(); IndexShard spyReplicaShard = spy(replicaShard); ForceSyncRequest forceSyncRequest = new ForceSyncRequest(1L, 1L, replicaShard.shardId()); when(indicesService.getShardOrNull(forceSyncRequest.getShardId())).thenReturn(spyReplicaShard); @@ -488,4 +500,57 @@ public void testForceSegmentSyncHandlerWithFailure() throws Exception { assertTrue(nestedException instanceof IOException); assertTrue(nestedException.getMessage().contains("dummy failure")); } + + public void testForceSync_ShardDoesNotExist() { + ForceSyncRequest forceSyncRequest = new ForceSyncRequest(1L, 1L, new ShardId("no", "", 0)); + when(indicesService.getShardOrNull(forceSyncRequest.getShardId())).thenReturn(null); + transportService.submitRequest( + localNode, + SegmentReplicationTargetService.Actions.FORCE_SYNC, + forceSyncRequest, + TransportRequestOptions.builder().withTimeout(TRANSPORT_TIMEOUT).build(), + EmptyTransportResponseHandler.INSTANCE_SAME + ).txGet(); + } + + public void testForceSegmentSyncHandlerWithFailure_AlreadyClosedException_swallowed() throws Exception { + IndexShard spyReplicaShard = spy(replicaShard); + ForceSyncRequest forceSyncRequest = new ForceSyncRequest(1L, 1L, replicaShard.shardId()); + when(indicesService.getShardOrNull(forceSyncRequest.getShardId())).thenReturn(spyReplicaShard); + + AlreadyClosedException exception = new AlreadyClosedException("shard closed"); + doThrow(exception).when(spyReplicaShard).finalizeReplication(any()); + + // prevent shard failure to avoid test setup assertion + doNothing().when(spyReplicaShard).failShard(eq("replication failure"), any()); + transportService.submitRequest( + localNode, + SegmentReplicationTargetService.Actions.FORCE_SYNC, + forceSyncRequest, + TransportRequestOptions.builder().withTimeout(TRANSPORT_TIMEOUT).build(), + EmptyTransportResponseHandler.INSTANCE_SAME + ).txGet(); + } + + public void testTargetCancelledBeforeStartInvoked() { + final SegmentReplicationTarget target = new SegmentReplicationTarget( + replicaShard, + mock(SegmentReplicationSource.class), + new SegmentReplicationTargetService.SegmentReplicationListener() { + @Override + public void onReplicationDone(SegmentReplicationState state) { + Assert.fail(); + } + + @Override + public void onReplicationFailure(SegmentReplicationState state, ReplicationFailedException e, boolean sendShardFailure) { + // failures leave state object in last entered stage. + assertEquals(SegmentReplicationState.Stage.GET_CHECKPOINT_INFO, state.getStage()); + assertTrue(e.getCause() instanceof CancellableThreads.ExecutionCancelledException); + } + } + ); + target.cancel("test"); + sut.startReplication(target); + } } diff --git a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetTests.java b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetTests.java index ac8904527f7fb..4fb1edb4e496e 100644 --- a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetTests.java +++ b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetTests.java @@ -27,6 +27,7 @@ import org.junit.Assert; import org.mockito.Mockito; import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchCorruptionException; import org.opensearch.action.ActionListener; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.common.settings.Settings; @@ -373,8 +374,9 @@ public void onResponse(Void replicationResponse) { @Override public void onFailure(Exception e) { - assert (e instanceof ReplicationFailedException); - assert (e.getMessage().contains("different segment files")); + assertTrue(e instanceof OpenSearchCorruptionException); + assertTrue(e.getMessage().contains("has local copies of segments that differ from the primary")); + segrepTarget.fail(new ReplicationFailedException(e), false); } }); } diff --git a/server/src/test/java/org/opensearch/recovery/ReplicationCollectionTests.java b/server/src/test/java/org/opensearch/recovery/ReplicationCollectionTests.java index 75ac1075e8ee0..0bb1bc210d946 100644 --- a/server/src/test/java/org/opensearch/recovery/ReplicationCollectionTests.java +++ b/server/src/test/java/org/opensearch/recovery/ReplicationCollectionTests.java @@ -38,12 +38,15 @@ import org.opensearch.index.shard.IndexShard; import org.opensearch.index.shard.ShardId; import org.opensearch.index.store.Store; +import org.opensearch.indices.replication.SegmentReplicationSource; +import org.opensearch.indices.replication.SegmentReplicationTarget; import org.opensearch.indices.replication.common.ReplicationCollection; import org.opensearch.indices.replication.common.ReplicationFailedException; import org.opensearch.indices.replication.common.ReplicationListener; import org.opensearch.indices.replication.common.ReplicationState; import org.opensearch.indices.recovery.RecoveryState; import org.opensearch.indices.recovery.RecoveryTarget; +import org.opensearch.indices.replication.common.ReplicationTarget; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -51,6 +54,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.lessThan; +import static org.mockito.Mockito.mock; public class ReplicationCollectionTests extends OpenSearchIndexLevelReplicationTestCase { static final ReplicationListener listener = new ReplicationListener() { @@ -108,7 +112,30 @@ public void onFailure(ReplicationState state, ReplicationFailedException e, bool } } - public void testMultiReplicationsForSingleShard() throws Exception { + public void testStartMultipleReplicationsForSingleShard() throws Exception { + try (ReplicationGroup shards = createGroup(0)) { + shards.startAll(); + final ReplicationCollection collection = new ReplicationCollection<>(logger, threadPool); + final IndexShard shard = shards.addReplica(); + shards.recoverReplica(shard); + final SegmentReplicationTarget target1 = new SegmentReplicationTarget( + shard, + mock(SegmentReplicationSource.class), + mock(ReplicationListener.class) + ); + final SegmentReplicationTarget target2 = new SegmentReplicationTarget( + shard, + mock(SegmentReplicationSource.class), + mock(ReplicationListener.class) + ); + collection.startSafe(target1, TimeValue.timeValueMinutes(30)); + assertThrows(ReplicationFailedException.class, () -> collection.startSafe(target2, TimeValue.timeValueMinutes(30))); + target1.decRef(); + target2.decRef(); + } + } + + public void testGetReplicationTargetMultiReplicationsForSingleShard() throws Exception { try (ReplicationGroup shards = createGroup(0)) { final ReplicationCollection collection = new ReplicationCollection<>(logger, threadPool); final IndexShard shard1 = shards.addReplica();