Skip to content

Commit

Permalink
Fix waiter thread handle notification
Browse files Browse the repository at this point in the history
The threads that wait on a state of a task future are stored in a
collection. This collection got cleared when the task finished, and its
state was set to RESULT_READY. This operation was incorrect, as if some
threads wait on a notification about a new ancestor being added to a
task, then it will never get notified, as its thread handle got purged
from the collection.

This gets fixed by not clearing the waiting thread collection in case
the task finishes, so threads can get notified after finish.

The thread handles that are only interested in the RESULT_READY event
will still get removed not to keep them unnecessarily in the collection.

AddAncestorBlockingWaitTaskTest was extended to test this scenario. This
reliably happened when the str task finishes before the child starter
task (plus) can start it.

Related issue: #22
  • Loading branch information
Sipkab committed Jul 22, 2022
1 parent d4ccac6 commit 2b4a8cc
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 29 deletions.
59 changes: 48 additions & 11 deletions core/common/saker/build/task/TaskExecutionManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -2456,9 +2456,22 @@ public TaskDependencies getDependencies() {
}

private static class WaiterThreadHandle extends WeakReference<Thread> {
/**
* Initial state, the handle hasn't been used yet anywhere. It hasn't been added to futures for waiting.
*/
static final int STATE_INITIAL = 0;
/**
* The handle is waiting for notification, and has been added to the waiting thread collection of the relevant
* futures.
*/
static final int STATE_WAITING = 1;
/**
* The handle has been notified and should re-check the condition.
*/
static final int STATE_NOTIFIED = 2;
/**
* The handle is finished, no longer needs notification.
*/
static final int STATE_FINISHED = 3;

static final AtomicIntegerFieldUpdater<TaskExecutionManager.WaiterThreadHandle> AIFU_state = AtomicIntegerFieldUpdater
Expand All @@ -2472,6 +2485,17 @@ public WaiterThreadHandle(int triggerEvents) {
this.triggerEvents = triggerEvents;
}

/**
* Notifies and unparks the thread handle.
* <p>
* The method always unparks the associated thread, unless the state is {@link #STATE_FINISHED}.
* <p>
* The method sets the state to {@link #STATE_NOTIFIED}, and handles the waiting thread count adjustments.
*
* @param execmanager
* The execution manager.
* @return <code>false</code> if the thread handle is finished, and can be released.
*/
public boolean unparkNotify(TaskExecutionManager execmanager) {
while (true) {
int s = this.state;
Expand Down Expand Up @@ -3119,7 +3143,7 @@ private void setResultState(TaskExecutionManager execmanager, FutureState s, Fut
}
throw new AssertionError("Failed to set state for " + taskId + " (" + this.futureState + ") " + nstate);
}
unparkAllWaitingThreads(execmanager);
unparkWaitingThreadsForResult(execmanager);
}

protected TaskResultHolder<R> getWaitWithoutOutputChangeDetector(TaskExecutorContext<?> realcontext)
Expand Down Expand Up @@ -3937,7 +3961,8 @@ protected void deadlocked() {
|| s.state == STATE_INITIALIZING; s = this.futureState) {
if (ARFU_futureState.compareAndSet(this, s,
new DeadlockedFutureState<>(s.getFactory(), s.getInvocationConfiguration(), this.taskId))) {
for (WaiterThreadHandle t; (t = waitingThreads.poll()) != null;) {
ConcurrentLinkedQueue<WaiterThreadHandle> threadqueue = waitingThreads;
for (WaiterThreadHandle t; (t = threadqueue.poll()) != null;) {
LockSupport.unpark(t.get());
}
break;
Expand All @@ -3958,18 +3983,28 @@ protected void unparkWaitingThreads(TaskExecutionManager execmanager, int event)
}
}

protected void unparkAllWaitingThreads(TaskExecutionManager execmanager) {
ConcurrentLinkedQueue<WaiterThreadHandle> threadqueue = waitingThreads;
for (WaiterThreadHandle t; (t = threadqueue.poll()) != null;) {
t.unparkNotify(execmanager);
protected void unparkWaitingThreadsForResult(TaskExecutionManager execmanager) {
for (Iterator<WaiterThreadHandle> it = waitingThreads.iterator(); it.hasNext();) {
WaiterThreadHandle t = it.next();
int triggerevents = t.triggerEvents;
if ((triggerevents & STATE_RESULT_READY) == 0) {
//the thread is not interested in the RESULT_READY event
continue;
}
if (!t.unparkNotify(execmanager) || triggerevents == STATE_RESULT_READY) {
//thread handle finished
// OR
//only interested in the RESULT_READY event, so it can be removed
it.remove();
}
}
}

protected boolean unparkOneWaitingThread(TaskExecutionManager execmanager) {
for (Iterator<WaiterThreadHandle> it = waitingThreads.iterator(); it.hasNext();) {
WaiterThreadHandle t = it.next();
if (!t.unparkNotify(execmanager)) {
//thread handle finished
//thread handle finished, continue attempting to unpark the next one
it.remove();
} else {
return true;
Expand Down Expand Up @@ -6376,10 +6411,6 @@ private <R> void executeTaskRunning(TaskExecutionResult<?> previousExecutionResu
// throw exc;
// }
}
if (TestFlag.ENABLED) {
TestFlag.metric().taskFinished(taskid, factory, result, executionresult.getTaggedOutputs(),
executionresult.getMetaDatas());
}

executiondependencies.setSelfOutputChangeDetector(taskcontext.reportedOutputChangeDetector);

Expand All @@ -6393,6 +6424,12 @@ private <R> void executeTaskRunning(TaskExecutionResult<?> previousExecutionResu

future.finished(this, executionresult);

if (TestFlag.ENABLED) {
//call this after the future.finished() call
TestFlag.metric().taskFinished(taskid, factory, result, executionresult.getTaggedOutputs(),
executionresult.getMetaDatas());
}

if (hasabortedexception) {
taskRunningFailureExceptions.add(ImmutableUtils.makeImmutableMapEntry(taskid,
createFailException(taskid, taskrunningexception, abortexceptions)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@
*/
package testing.saker.build.tests.tasks;

import java.io.IOException;

import saker.build.task.TaskContext;
import saker.build.task.TaskFactory;
import saker.build.task.TaskFuture;
import testing.saker.SakerTest;
import testing.saker.build.flag.TestFlag;
import testing.saker.build.tests.CollectingMetricEnvironmentTestCase;
import testing.saker.build.tests.CollectingTestMetric;
import testing.saker.build.tests.ExecutionOrderer;
import testing.saker.build.tests.tasks.factories.ChildTaskStarterTaskFactory;
import testing.saker.build.tests.tasks.factories.StringTaskFactory;
Expand Down Expand Up @@ -65,7 +69,7 @@ public class AddAncestorBlockingWaitTaskTest extends CollectingMetricEnvironment
/**
* The plus task has finished, and its result has been waited for by main.
*/
private static final String SECTION_PLUS_FINISHED = "plus_started";
private static final String SECTION_PLUS_FINISHED = "plus_finished";
/**
* The str task has ben waited for by waiter.
*/
Expand All @@ -78,6 +82,8 @@ public class AddAncestorBlockingWaitTaskTest extends CollectingMetricEnvironment
private static ExecutionOrderer orderer;
private static volatile boolean gotStrTaskResultByWaiter = false;

private static volatile boolean waitStrFinishInStarter = false;

private static class StarterTaskFactory extends SelfStatelessTaskFactory<Void> {
private static final long serialVersionUID = 1L;

Expand All @@ -86,8 +92,24 @@ public Void run(TaskContext taskcontext) throws Exception {
taskcontext.getTaskUtilities().startTaskFuture(strTaskId("waiter"), new WaiterTaskFactory());
taskcontext.getTaskUtilities().startTaskFuture(strTaskId("blocker"), new BlockerStarterTaskFactory());
orderer.enter(SECTION_PLUS_STARTER);
taskcontext.getTaskUtilities().runTaskResult(strTaskId("plus"),
new ChildTaskStarterTaskFactory().add(strTaskId("str"), new StringTaskFactory("str")));
ChildTaskStarterTaskFactory childstarter = new ChildTaskStarterTaskFactory() {
@Override
public Void run(TaskContext context) throws Exception {
if (waitStrFinishInStarter) {
System.out.println("Wait result of str task before starting it...");
while (!((CollectingTestMetric) TestFlag.metric()).getRunTaskIdResults()
.containsKey(strTaskId("str"))) {
Thread.sleep(100);
}

System.out.println("Got result of str through test metric.");
}
return super.run(context);
}
};
childstarter.add(strTaskId("str"), new StringTaskFactory("str"));

taskcontext.getTaskUtilities().runTaskResult(strTaskId("plus"), childstarter);
orderer.enter(SECTION_PLUS_FINISHED);
return null;
}
Expand Down Expand Up @@ -126,6 +148,29 @@ public Void run(TaskContext taskcontext) throws Exception {

@Override
protected void runTestImpl() throws Throwable {
for (int i = 0; i < 10; i++) {
waitStrFinishInStarter = false;
runMainTask();
cleanProject();
System.out.println();

System.out.println("Wait str:");
waitStrFinishInStarter = true;
runMainTask();
cleanProject();
System.out.println();
}
}

private void cleanProject() throws IOException {
if (project != null) {
project.clean();
} else {
files.clearDirectoryRecursively(PATH_BUILD_DIRECTORY);
}
}

private void runMainTask() throws Throwable, AssertionError {
gotStrTaskResultByWaiter = false;
ExecutionOrderer orderer = new ExecutionOrderer();
orderer.addSection(SECTION_WAITER_START);
Expand All @@ -139,6 +184,8 @@ protected void runTestImpl() throws Throwable {

AddAncestorBlockingWaitTaskTest.orderer = new ExecutionOrderer(orderer);
runTask("main", main);
assertEquals(getMetric().getRunTaskIdFactories().keySet(),
strTaskIdSetOf("main", "blocker", "str", "waiter", "plus"));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import saker.build.task.identifier.TaskIdentifier;
import saker.build.thirdparty.saker.util.io.SerialUtils;

public class ChildTaskStarterTaskFactory implements TaskFactory<Void>, Externalizable {
public class ChildTaskStarterTaskFactory implements TaskFactory<Void>, Task<Void>, Externalizable {
private static final long serialVersionUID = 1L;

private Map<TaskIdentifier, TaskFactory<?>> namedChildTaskValues = new HashMap<>();
Expand Down Expand Up @@ -64,16 +64,15 @@ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundExcept

@Override
public Task<Void> createTask(ExecutionContext context) {
return new Task<Void>() {
@Override
public Void run(TaskContext context) {
for (Entry<? extends TaskIdentifier, ? extends TaskFactory<?>> entry : namedChildTaskValues
.entrySet()) {
context.getTaskUtilities().startTaskFuture(entry.getKey(), entry.getValue());
}
return null;
}
};
return this;
}

@Override
public Void run(TaskContext context) throws Exception {
for (Entry<? extends TaskIdentifier, ? extends TaskFactory<?>> entry : namedChildTaskValues.entrySet()) {
context.getTaskUtilities().startTaskFuture(entry.getKey(), entry.getValue());
}
return null;
}

@Override
Expand Down
24 changes: 20 additions & 4 deletions test/utils/src/testing/saker/build/tests/ExecutionOrderer.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/
package testing.saker.build.tests;

import java.time.Instant;
import java.time.format.DateTimeFormatter;
import java.util.LinkedList;
import java.util.Objects;

Expand All @@ -39,24 +41,34 @@ public void addSection(String id) {

public synchronized void enter(String id) throws InterruptedException {
try {
if (Thread.interrupted()) {
//check interruption before entering
//so if the thread is already interrupted when this method is called, then
//we throw an exception and dont consume a section (so errors are logged more appropriately.)
throw new InterruptedException(DateTimeFormatter.ISO_INSTANT.format(Instant.now())
+ " Interrupted while waiting for: " + id + " in " + order);
}
while (true) {
String first = order.peekFirst();
if (first == null) {
throw new IllegalArgumentException("No more sections.");
}
if (first.equals(id)) {
System.out.println("Reached: " + id);
System.out.println(
DateTimeFormatter.ISO_INSTANT.format(Instant.now()) + " ExecutionOrderer reached: " + id);
order.pollFirst();
this.notifyAll();
return;
}
if (!order.contains(id)) {
throw new IllegalArgumentException("No section found: " + id + " in " + order);
throw new IllegalArgumentException(DateTimeFormatter.ISO_INSTANT.format(Instant.now())
+ " No section found: " + id + " in " + order);
}
this.wait();
}
} catch (InterruptedException e) {
throw new InterruptedException("Interrupted while waiting for: " + id + " in " + order);
throw new InterruptedException(DateTimeFormatter.ISO_INSTANT.format(Instant.now())
+ " Interrupted while waiting for: " + id + " in " + order);
}
}

Expand All @@ -70,7 +82,11 @@ public boolean isAnySectionRemaining() {

@Override
public String toString() {
return "ExecutionOrderer[" + order + "]";
String orderstr;
synchronized (this) {
orderstr = order.toString();
}
return getClass().getSimpleName() + "[" + orderstr + "]";
}

}

0 comments on commit 2b4a8cc

Please sign in to comment.