Fix racing condition in BaseQueryRewriteContext #17124

Open
wants to merge 2 commits into
base: main
2 changes: 1 addition & 1 deletion CHANGELOG-3.0.md
@@ -54,7 +54,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Fix compression support for h2c protocol ([#4944](https://github.com/opensearch-project/OpenSearch/pull/4944))
- Don't over-allocate in HeapBufferedAsyncEntityConsumer in order to consume the response ([#9993](https://github.com/opensearch-project/OpenSearch/pull/9993))
- Fix swapped field formats in nodes API where `total_indexing_buffer_in_bytes` and `total_indexing_buffer` values were reversed ([#17070](https://github.com/opensearch-project/OpenSearch/pull/17070))

- Fix racing condition in BaseQueryRewriteContext ([#17124](https://github.com/opensearch-project/OpenSearch/pull/17124))

### Security

@@ -17,6 +17,7 @@

import java.util.ArrayList;
import java.util.List;
+import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.LongSupplier;

@@ -32,7 +33,7 @@ public class BaseQueryRewriteContext implements QueryRewriteContext {
private final NamedWriteableRegistry writeableRegistry;
protected final Client client;
protected final LongSupplier nowInMillis;
-private final List<BiConsumer<Client, ActionListener<?>>> asyncActions = new ArrayList<>();
+private final AtomicReference<List<BiConsumer<Client, ActionListener<?>>>> asyncActionsRef = new AtomicReference<>(new ArrayList<>());
private final boolean validate;

public BaseQueryRewriteContext(
@@ -90,21 +91,27 @@ public QueryShardContext convertToShardContext() {
* from an index.
*/
public void registerAsyncAction(BiConsumer<Client, ActionListener<?>> asyncAction) {
-asyncActions.add(asyncAction);
+asyncActionsRef.updateAndGet(list -> {
+List<BiConsumer<Client, ActionListener<?>>> newList = new ArrayList<>(list);
+newList.add(asyncAction);
+return newList;
+});
Comment on lines +94 to +98
Member

I don't think we need to copy the list here; copying gives us nothing over a CopyOnWriteArrayList, which wouldn't need the atomic reference at all. Can't we simply update the list in place, wrapped in the atomic reference for thread safety? (CC @msfroh to confirm this suggestion):

Suggested change
-asyncActionsRef.updateAndGet(list -> {
-List<BiConsumer<Client, ActionListener<?>>> newList = new ArrayList<>(list);
-newList.add(asyncAction);
-return newList;
-});
+asyncActionsRef.updateAndGet(list -> {
+list.add(asyncAction);
+return list;
+});

Contributor Author

I tried that. If I don't return a new list, the tests fail, so updating the list in place still causes a race condition during registration.

BaseQueryRewriteContextTests > testRacingConditionFixed FAILED
    java.lang.AssertionError: expected:<5000> but was:<4874>
        at org.junit.Assert.fail(Assert.java:89)
        at org.junit.Assert.failNotEquals(Assert.java:835)
        at org.junit.Assert.assertEquals(Assert.java:647)
        at org.junit.Assert.assertEquals(Assert.java:633)
        at org.opensearch.index.query.BaseQueryRewriteContextTests.testRacingConditionFixed(BaseQueryRewriteContextTests.java:167)

BaseQueryRewriteContextTests > testConcurrentRegistrationAndExecution FAILED
    java.lang.AssertionError: expected:<10000> but was:<7802>
        at org.junit.Assert.fail(Assert.java:89)
        at org.junit.Assert.failNotEquals(Assert.java:835)
        at org.junit.Assert.assertEquals(Assert.java:647)
        at org.junit.Assert.assertEquals(Assert.java:633)
        at org.opensearch.index.query.BaseQueryRewriteContextTests.testConcurrentRegistrationAndExecution(BaseQueryRewriteContextTests.java:108)
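
One way to read this failure: `updateAndGet` only guards the swap of the reference. When the lambda mutates the shared list and returns the same instance, the compare-and-set sees an unchanged reference and does not force a retry in the registration-only case, so concurrent `ArrayList.add` calls can still interleave and drop elements. Below is a minimal sketch of the copying variant, written outside OpenSearch. The class `CopyOnUpdateDemo`, its methods, and the use of `Runnable` in place of the `BiConsumer` actions are assumptions made for illustration, not code from this PR.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;

// Hypothetical demo class; not part of OpenSearch.
final class CopyOnUpdateDemo {
    private final AtomicReference<List<Runnable>> actionsRef = new AtomicReference<>(new ArrayList<>());

    // Each attempt builds a fresh list. If another thread swapped the reference first,
    // updateAndGet re-runs the lambda against the newer list, so no registration is lost.
    void register(Runnable action) {
        actionsRef.updateAndGet(list -> {
            List<Runnable> copy = new ArrayList<>(list);
            copy.add(action);
            return copy;
        });
    }

    int size() {
        return actionsRef.get().size();
    }

    // Registers perThread actions from each of the given number of threads
    // and returns how many actions ended up in the list.
    static int registerConcurrently(int threads, int perThread) {
        CopyOnUpdateDemo demo = new CopyOnUpdateDemo();
        CountDownLatch start = new CountDownLatch(1);
        List<Thread> workers = new ArrayList<>();
        for (int t = 0; t < threads; t++) {
            Thread worker = new Thread(() -> {
                try {
                    start.await(); // release all workers at the same moment
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    return;
                }
                for (int i = 0; i < perThread; i++) {
                    demo.register(() -> {});
                }
            });
            workers.add(worker);
            worker.start();
        }
        start.countDown();
        for (Thread worker : workers) {
            try {
                worker.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException(e);
            }
        }
        return demo.size();
    }

    public static void main(String[] args) {
        System.out.println(registerConcurrently(4, 1000));
    }
}
```

The cost of this approach is one full copy per registration (and per retry), which is the overhead the review thread is weighing against a purpose-built concurrent collection.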

Collaborator

Hmm... if we need to copy the whole list anyway, I think I like @dbwiddis's suggestion of using CopyOnWriteArrayList over my AtomicReference suggestion.
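
For comparison, here is a rough sketch of what the `CopyOnWriteArrayList` alternative could look like for registration. The class `CowRegistryDemo` and its methods are hypothetical, and `Runnable` again stands in for the `BiConsumer` actions; this is not the change that was made in the PR.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;

// Hypothetical demo class; not part of OpenSearch.
final class CowRegistryDemo {
    // CopyOnWriteArrayList copies its backing array on every write under an internal lock,
    // so concurrent add() calls are safe without an AtomicReference wrapper.
    private final List<Runnable> actions = new CopyOnWriteArrayList<>();

    void register(Runnable action) {
        actions.add(action);
    }

    boolean hasActions() {
        return actions.isEmpty() == false;
    }

    int size() {
        return actions.size();
    }

    // Registers perThread actions from each thread and returns the final list size.
    static int registerConcurrently(int threads, int perThread) {
        CowRegistryDemo demo = new CowRegistryDemo();
        CountDownLatch start = new CountDownLatch(1);
        List<Thread> workers = new ArrayList<>();
        for (int t = 0; t < threads; t++) {
            Thread worker = new Thread(() -> {
                try {
                    start.await(); // release all workers together
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    return;
                }
                for (int i = 0; i < perThread; i++) {
                    demo.register(() -> {});
                }
            });
            workers.add(worker);
            worker.start();
        }
        start.countDown();
        for (Thread worker : workers) {
            try {
                worker.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException(e);
            }
        }
        return demo.size();
    }

    public static void main(String[] args) {
        System.out.println(registerConcurrently(4, 500));
    }
}
```

One caveat, stated as an assumption about how `executeAsyncActions` would use such a list: taking a snapshot and then calling `clear()` are two separate steps, so an action registered between them could be dropped unless the drain is handled with extra care.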

}

/**
* Returns <code>true</code> if there are any registered async actions.
*/
public boolean hasAsyncActions() {
-return asyncActions.isEmpty() == false;
+return asyncActionsRef.get().isEmpty() == false;
}

/**
* Executes all registered async actions and notifies the listener once it's done. The value that is passed to the listener is always
* <code>null</code>. The list of registered actions is cleared once this method returns.
*/
public void executeAsyncActions(ActionListener listener) {
+// get asyncActions before execute
+List<BiConsumer<Client, ActionListener<?>>> asyncActions = asyncActionsRef.getAndSet(new ArrayList<>());
if (asyncActions.isEmpty()) {
listener.onResponse(null);
return;
@@ -126,11 +133,13 @@ public void onFailure(Exception e) {
}
}
};
-// make a copy to prevent concurrent modification exception
-List<BiConsumer<Client, ActionListener<?>>> biConsumers = new ArrayList<>(asyncActions);
-asyncActions.clear();
-for (BiConsumer<Client, ActionListener<?>> action : biConsumers) {
-action.accept(client, internalListener);
+
+for (BiConsumer<Client, ActionListener<?>> action : asyncActions) {
+if (action != null) {
+action.accept(client, internalListener);
+} else {
+countDown.countDown();
+}
}
}
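
The drain step in `executeAsyncActions` relies on `getAndSet` to take the current batch and leave an empty list behind in a single atomic step. A small sketch of that pattern in isolation follows; the class `DrainDemo`, its method names, and the `String` items are hypothetical stand-ins, not the PR's code.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;

// Hypothetical demo class; not part of OpenSearch.
final class DrainDemo {
    private final AtomicReference<List<String>> pending = new AtomicReference<>(new ArrayList<>());

    // Copy-on-update registration, mirroring the shape of the PR's registerAsyncAction.
    void register(String item) {
        pending.updateAndGet(list -> {
            List<String> copy = new ArrayList<>(list);
            copy.add(item);
            return copy;
        });
    }

    // Atomically hands the current batch to the caller and installs a fresh empty list.
    // A registration racing with this call lands either in the returned batch or in the
    // new list, because a failed compare-and-set in register() is retried against it.
    List<String> drain() {
        return pending.getAndSet(new ArrayList<>());
    }

    public static void main(String[] args) {
        DrainDemo demo = new DrainDemo();
        demo.register("a");
        demo.register("b");
        System.out.println(demo.drain()); // first batch
        demo.register("c");
        System.out.println(demo.drain()); // only what was registered after the first drain
        System.out.println(demo.drain()); // nothing pending
    }
}
```

Because the drained list is never published again, the caller can iterate it without a defensive copy, which is why the old "make a copy" step is no longer needed in this shape.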

@@ -0,0 +1,182 @@
/*
 * SPDX-License-Identifier: Apache-2.0
 *
 * The OpenSearch Contributors require contributions made to
 * this file be licensed under the Apache-2.0 license or a
 * compatible open source license.
 */

package org.opensearch.index.query;

import org.opensearch.client.Client;
import org.opensearch.common.util.concurrent.CountDown;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.junit.Before;
import org.junit.Test;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;

/**
 * Unit tests for the BaseQueryRewriteContext class to verify the fix for racing conditions
 * in async action registration and execution.
 */
public class BaseQueryRewriteContextTests {
    private BaseQueryRewriteContext context;
    private Client mockClient;

    @Before
    public void setUp() {
        mockClient = mock(Client.class);
        context = new BaseQueryRewriteContext(
            mock(NamedXContentRegistry.class),
            mock(NamedWriteableRegistry.class),
            mockClient,
            () -> System.currentTimeMillis()
        );
    }

    /**
     * Tests concurrent registration and execution of async actions.
     *
     * This test simulates a scenario where multiple threads are simultaneously
     * registering a large number of async actions, followed by a single execution
     * of all registered actions. It verifies that:
     * 1. All registered actions are executed correctly.
     * 2. The total number of executed actions matches the expected count.
     * 3. There are no remaining async actions after execution.
     * 4. No exceptions occur during the process, indicating thread-safety.
     *
     * @throws InterruptedException if the test is interrupted while waiting for threads to complete
     */
    @Test
    public void testConcurrentRegistrationAndExecution() throws InterruptedException {
        int numThreads = 10;
        int actionsPerThread = 1000;
        ExecutorService executorService = Executors.newFixedThreadPool(numThreads);
        CountDown startCountDown = new CountDown(1);
        CountDown endCountDown = new CountDown(numThreads);
        AtomicInteger totalExecutedActions = new AtomicInteger(0);

        for (int i = 0; i < numThreads; i++) {
            executorService.submit(() -> {
                while (startCountDown.isCountedDown() == false) {
                    Thread.yield();
                }
                for (int j = 0; j < actionsPerThread; j++) {
                    context.registerAsyncAction((client, listener) -> {
                        totalExecutedActions.incrementAndGet();
                        listener.onResponse(null);
                    });
                }
                endCountDown.countDown();
            });
        }

        startCountDown.countDown();
        while (endCountDown.isCountedDown() == false) {
            Thread.yield();
        }

        CountDown executionCountDown = new CountDown(1);
        context.executeAsyncActions(new ActionListener<Void>() {
            @Override
            public void onResponse(Void aVoid) {
                executionCountDown.countDown();
            }

            @Override
            public void onFailure(Exception e) {
                fail("Execution failed: " + e.getMessage());
            }
        });

        while (executionCountDown.isCountedDown() == false) {
            Thread.yield();
        }
        ensureAllActionsExecuted();
        assertEquals(numThreads * actionsPerThread, totalExecutedActions.get());
        assertFalse(context.hasAsyncActions());

        executorService.shutdown();
        assertTrue(executorService.awaitTermination(10, TimeUnit.SECONDS));
    }

    /**
     * Tests the fix for the racing condition by simulating concurrent registration and execution.
     *
     * This test creates a scenario where multiple threads are simultaneously:
     * 1. Registering async actions
     * 2. Periodically executing the registered actions
     *
     * It verifies that:
     * 1. No exceptions occur during the process, indicating thread-safety.
     * 2. All actions are eventually executed, leaving no remaining async actions.
     * 3. The fix prevents any race conditions that could occur in this situation.
     *
     * @throws InterruptedException if the test is interrupted while waiting for threads to complete
     */
    @Test
    public void testRacingConditionFixed() throws InterruptedException {
        int numThreads = 5;
        int actionsPerThread = 1000;
        ExecutorService executorService = Executors.newFixedThreadPool(numThreads);
        CountDown startCountDown = new CountDown(1);
        CountDown endCountDown = new CountDown(numThreads);
        AtomicInteger totalExecutedActions = new AtomicInteger(0);

        for (int i = 0; i < numThreads; i++) {
            executorService.submit(() -> {
                while (startCountDown.isCountedDown() == false) {
                    Thread.yield();
                }
                for (int j = 0; j < actionsPerThread; j++) {
                    context.registerAsyncAction((client, listener) -> {
                        totalExecutedActions.incrementAndGet();
                        listener.onResponse(null);
                    });
                    if (j % 100 == 0) {
                        context.executeAsyncActions(ActionListener.wrap(v -> {}, e -> fail("Execution failed: " + e.getMessage())));
                    }
                }
                endCountDown.countDown();
            });
        }

        startCountDown.countDown();
        while (endCountDown.isCountedDown() == false) {
            Thread.yield();
        }

        // Final execution to ensure all remaining actions are processed
        context.executeAsyncActions(ActionListener.wrap(v -> {}, e -> fail("Final execution failed: " + e.getMessage())));

        executorService.shutdown();
        assertTrue(executorService.awaitTermination(30, TimeUnit.SECONDS)); // Increased timeout
        ensureAllActionsExecuted();
        assertEquals(numThreads * actionsPerThread, totalExecutedActions.get());
        assertFalse(context.hasAsyncActions());
    }

    private void ensureAllActionsExecuted() {
        int maxAttempts = 10;
        for (int i = 0; i < maxAttempts && context.hasAsyncActions(); i++) {
            context.executeAsyncActions(ActionListener.wrap(v -> {}, e -> fail("Execution failed: " + e.getMessage())));
            try {
                Thread.sleep(100); // Give some time for actions to complete
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
    }
}