Skip to content

Commit

Permalink
core: Server side cancellations should promptly inform server (#2963)
Browse files Browse the repository at this point in the history
When a cancellation happens, the ServerCall and Context get notified. Rather than serializing on the normal work queue (which may be doing user computation), we should execute the notification immediately, thereby allowing the user computation to see the cancellation.
  • Loading branch information
zpencer authored May 30, 2017
1 parent 37e2131 commit 8a1217d
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 21 deletions.
12 changes: 12 additions & 0 deletions core/src/main/java/io/grpc/internal/ServerCallImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import static io.grpc.internal.GrpcUtil.MESSAGE_ENCODING_KEY;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.util.concurrent.MoreExecutors;
import io.grpc.Attributes;
import io.grpc.Codec;
import io.grpc.Compressor;
Expand Down Expand Up @@ -221,6 +222,17 @@ public ServerStreamListenerImpl(
this.call = checkNotNull(call, "call");
this.listener = checkNotNull(listener, "listener must not be null");
this.context = checkNotNull(context, "context");
// Wire ourselves up so that if the context is cancelled, our flag call.cancelled also
// reflects the new state. Use a DirectExecutor so that it happens in the same thread
// as the caller of {@link Context#cancel}.
this.context.addListener(
new Context.CancellationListener() {
@Override
public void cancelled(Context context) {
ServerStreamListenerImpl.this.call.cancelled = true;
}
},
MoreExecutors.directExecutor());
}

@SuppressWarnings("Finally") // The code avoids suppressing the exception thrown from try
Expand Down
40 changes: 33 additions & 7 deletions core/src/main/java/io/grpc/internal/ServerImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,8 @@ public void streamCreated(
}

final JumpToApplicationThreadServerStreamListener jumpListener
= new JumpToApplicationThreadServerStreamListener(wrappedExecutor, stream, context);
= new JumpToApplicationThreadServerStreamListener(
wrappedExecutor, executor, stream, context);
stream.setListener(jumpListener);
// Run in wrappedExecutor so jumpListener.setListener() is called before any callbacks
// are delivered, including any errors. Callbacks can still be triggered, but they will be
Expand Down Expand Up @@ -512,18 +513,23 @@ public void onReady() {}
@VisibleForTesting
static class JumpToApplicationThreadServerStreamListener implements ServerStreamListener {
private final Executor callExecutor;
private final Executor cancelExecutor;
private final Context.CancellableContext context;
private final ServerStream stream;
// Only accessed from callExecutor.
private ServerStreamListener listener;

public JumpToApplicationThreadServerStreamListener(Executor executor,
ServerStream stream, Context.CancellableContext context) {
Executor cancelExecutor, ServerStream stream, Context.CancellableContext context) {
this.callExecutor = executor;
this.cancelExecutor = cancelExecutor;
this.stream = stream;
this.context = context;
}

/**
* This call MUST be serialized on callExecutor to avoid races.
*/
private ServerStreamListener getListener() {
if (listener == null) {
throw new IllegalStateException("listener unset");
Expand Down Expand Up @@ -584,16 +590,20 @@ public void runInContext() {

@Override
public void closed(final Status status) {
// For cancellations, promptly inform any users of the context that their work should be
// aborted. Otherwise, we can wait until pending work is done.
if (!status.isOk()) {
// The callExecutor might be busy doing user work. To avoid waiting, use an executor that
// is not serializing.
cancelExecutor.execute(new ContextCloser(context, status.getCause()));
}
callExecutor.execute(new ContextRunnable(context) {
@Override
public void runInContext() {
try {
getListener().closed(status);
} finally {
// Regardless of the status code we cancel the context so that listeners
// are aware that the call is done.
if (status.isOk()) {
context.cancel(status.getCause());
}
getListener().closed(status);
}
});
}
Expand All @@ -616,4 +626,20 @@ public void runInContext() {
});
}
}

@VisibleForTesting
static class ContextCloser implements Runnable {
private final Context.CancellableContext context;
private final Throwable cause;

ContextCloser(Context.CancellableContext context, Throwable cause) {
this.context = context;
this.cause = cause;
}

@Override
public void run() {
context.cancel(cause);
}
}
}
36 changes: 33 additions & 3 deletions core/src/test/java/io/grpc/internal/FakeClock.java
Original file line number Diff line number Diff line change
Expand Up @@ -222,18 +222,38 @@ public Ticker getTicker() {
* @return the number of tasks run by this call
*/
public int runDueTasks() {
return runDueTasks(new TaskFilter() {
@Override
public boolean shouldRun(Runnable runnable) {
return true;
}
});
}

/**
* Run all due tasks that match the {@link TaskFilter}.
*
* @return the number of tasks run by this call
*/
public int runDueTasks(TaskFilter filter) {
int count = 0;
List<ScheduledTask> putBack = new ArrayList<ScheduledTask>();
while (true) {
ScheduledTask task = tasks.peek();
if (task == null || task.dueTimeNanos > currentTimeNanos) {
break;
}
if (tasks.remove(task)) {
task.command.run();
task.complete();
count++;
if (filter.shouldRun(task.command)) {
task.command.run();
task.complete();
count++;
} else {
putBack.add(task);
}
}
}
tasks.addAll(putBack);
return count;
}

Expand Down Expand Up @@ -288,4 +308,14 @@ public long currentTimeMillis() {
// Normally millis and nanos are of different epochs. Add an offset to simulate that.
return TimeUnit.NANOSECONDS.toMillis(currentTimeNanos + 123456789L);
}

/**
* A filter that allows us to have fine grained control over which tasks are run.
*/
public interface TaskFilter {
/**
* Inspect the Runnable and returns true if it should be run.
*/
boolean shouldRun(Runnable runnable);
}
}
32 changes: 32 additions & 0 deletions core/src/test/java/io/grpc/internal/FakeClockTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
Expand Down Expand Up @@ -171,6 +172,37 @@ public void testPendingAndDueTasks() {
assertEquals(0, fakeClock.getDueTasks().size());
}

@Test
public void testTaskFilter() {
FakeClock fakeClock = new FakeClock();
ScheduledExecutorService scheduledExecutorService = fakeClock.getScheduledExecutorService();
final AtomicBoolean selectedDone = new AtomicBoolean();
final AtomicBoolean ignoredDone = new AtomicBoolean();
final Runnable selectedRunnable = new Runnable() {
@Override
public void run() {
selectedDone.set(true);
}
};
Runnable ignoredRunnable = new Runnable() {
@Override
public void run() {
ignoredDone.set(true);
}
};
scheduledExecutorService.execute(selectedRunnable);
scheduledExecutorService.execute(ignoredRunnable);
assertEquals(2, fakeClock.numPendingTasks());
assertEquals(1, fakeClock.runDueTasks(new FakeClock.TaskFilter() {
@Override
public boolean shouldRun(Runnable runnable) {
return runnable == selectedRunnable;
}
}));
assertTrue(selectedDone.get());
assertFalse(ignoredDone.get());
}

private Runnable newRunnable() {
return new Runnable() {
@Override
Expand Down
94 changes: 83 additions & 11 deletions core/src/test/java/io/grpc/internal/ServerImplTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,13 @@ public class ServerImplTest {
private static final Context.CancellableContext SERVER_CONTEXT =
Context.ROOT.withValue(SERVER_ONLY, "yes").withCancellation();
private static final ImmutableList<ServerTransportFilter> NO_FILTERS = ImmutableList.of();
private static final FakeClock.TaskFilter CONTEXT_CLOSER_TASK_FITLER =
new FakeClock.TaskFilter() {
@Override
public boolean shouldRun(Runnable runnable) {
return runnable instanceof ServerImpl.ContextCloser;
}
};

private final CompressorRegistry compressorRegistry = CompressorRegistry.getDefaultInstance();
private final DecompressorRegistry decompressorRegistry =
Expand Down Expand Up @@ -766,20 +773,23 @@ private void checkContext() {
assertTrue(onHalfCloseCalled.get());

streamListener.closed(Status.CANCELLED);
assertEquals(1, executor.runDueTasks(CONTEXT_CLOSER_TASK_FITLER));
assertEquals(1, executor.runDueTasks());
assertTrue(onCancelCalled.get());

// Close should never be called if asserts in listener pass.
verify(stream, times(0)).close(isA(Status.class), isNotNull(Metadata.class));
}

@Test
public void testClientCancelTriggersContextCancellation() throws Exception {
private ServerStreamListener testClientClose_setup(
final AtomicReference<ServerCall<String, Integer>> callReference,
final AtomicReference<Context> context,
final AtomicBoolean contextCancelled) throws Exception {
createAndStartServer(NO_FILTERS);
final AtomicBoolean contextCancelled = new AtomicBoolean(false);
callListener = new ServerCall.Listener<String>() {
@Override
public void onReady() {
context.set(Context.current());
Context.current().addListener(new Context.CancellationListener() {
@Override
public void cancelled(Context context) {
Expand All @@ -789,8 +799,6 @@ public void cancelled(Context context) {
}
};

final AtomicReference<ServerCall<String, Integer>> callReference
= new AtomicReference<ServerCall<String, Integer>>();
MethodDescriptor<String, Integer> method = MethodDescriptor.<String, Integer>newBuilder()
.setType(MethodDescriptor.MethodType.UNKNOWN)
.setFullMethodName("Waiter/serve")
Expand Down Expand Up @@ -822,9 +830,55 @@ public ServerCall.Listener<String> startCall(
assertNotNull(streamListener);

streamListener.onReady();
assertEquals(1, executor.runDueTasks());
return streamListener;
}

@Test
public void testClientClose_cancelTriggersImmediateCancellation() throws Exception {
AtomicBoolean contextCancelled = new AtomicBoolean(false);
AtomicReference<Context> context = new AtomicReference<Context>();
AtomicReference<ServerCall<String, Integer>> callReference
= new AtomicReference<ServerCall<String, Integer>>();

ServerStreamListener streamListener = testClientClose_setup(callReference,
context, contextCancelled);

// For close status being non OK:
// isCancelled is expected to be true immediately after calling closed(), without needing
// to wait for the main executor to run any tasks.
assertFalse(callReference.get().isCancelled());
assertFalse(context.get().isCancelled());
streamListener.closed(Status.CANCELLED);
assertEquals(1, executor.runDueTasks(CONTEXT_CLOSER_TASK_FITLER));
assertTrue(callReference.get().isCancelled());
assertTrue(context.get().isCancelled());

assertEquals(1, executor.runDueTasks());
assertTrue(contextCancelled.get());
}

@Test
public void testClientClose_OkTriggersDelayedCancellation() throws Exception {
AtomicBoolean contextCancelled = new AtomicBoolean(false);
AtomicReference<Context> context = new AtomicReference<Context>();
AtomicReference<ServerCall<String, Integer>> callReference
= new AtomicReference<ServerCall<String, Integer>>();

ServerStreamListener streamListener = testClientClose_setup(callReference,
context, contextCancelled);

// For close status OK:
// isCancelled is expected to be true after all pending work is done
assertFalse(callReference.get().isCancelled());
assertFalse(context.get().isCancelled());
streamListener.closed(Status.OK);
assertFalse(callReference.get().isCancelled());
assertFalse(context.get().isCancelled());

assertEquals(1, executor.runDueTasks());
assertTrue(callReference.get().isCancelled());
assertTrue(context.get().isCancelled());
assertTrue(contextCancelled.get());
}

Expand Down Expand Up @@ -903,7 +957,10 @@ public void handlerRegistryPriorities() throws Exception {
public void messageRead_errorCancelsCall() throws Exception {
JumpToApplicationThreadServerStreamListener listener
= new JumpToApplicationThreadServerStreamListener(
executor.getScheduledExecutorService(), stream, Context.ROOT.withCancellation());
executor.getScheduledExecutorService(),
executor.getScheduledExecutorService(),
stream,
Context.ROOT.withCancellation());
ServerStreamListener mockListener = mock(ServerStreamListener.class);
listener.setListener(mockListener);

Expand All @@ -925,7 +982,10 @@ public void messageRead_errorCancelsCall() throws Exception {
public void messageRead_runtimeExceptionCancelsCall() throws Exception {
JumpToApplicationThreadServerStreamListener listener
= new JumpToApplicationThreadServerStreamListener(
executor.getScheduledExecutorService(), stream, Context.ROOT.withCancellation());
executor.getScheduledExecutorService(),
executor.getScheduledExecutorService(),
stream,
Context.ROOT.withCancellation());
ServerStreamListener mockListener = mock(ServerStreamListener.class);
listener.setListener(mockListener);

Expand All @@ -947,7 +1007,10 @@ public void messageRead_runtimeExceptionCancelsCall() throws Exception {
public void halfClosed_errorCancelsCall() {
JumpToApplicationThreadServerStreamListener listener
= new JumpToApplicationThreadServerStreamListener(
executor.getScheduledExecutorService(), stream, Context.ROOT.withCancellation());
executor.getScheduledExecutorService(),
executor.getScheduledExecutorService(),
stream,
Context.ROOT.withCancellation());
ServerStreamListener mockListener = mock(ServerStreamListener.class);
listener.setListener(mockListener);

Expand All @@ -968,7 +1031,10 @@ public void halfClosed_errorCancelsCall() {
public void halfClosed_runtimeExceptionCancelsCall() {
JumpToApplicationThreadServerStreamListener listener
= new JumpToApplicationThreadServerStreamListener(
executor.getScheduledExecutorService(), stream, Context.ROOT.withCancellation());
executor.getScheduledExecutorService(),
executor.getScheduledExecutorService(),
stream,
Context.ROOT.withCancellation());
ServerStreamListener mockListener = mock(ServerStreamListener.class);
listener.setListener(mockListener);

Expand All @@ -989,7 +1055,10 @@ public void halfClosed_runtimeExceptionCancelsCall() {
public void onReady_errorCancelsCall() {
JumpToApplicationThreadServerStreamListener listener
= new JumpToApplicationThreadServerStreamListener(
executor.getScheduledExecutorService(), stream, Context.ROOT.withCancellation());
executor.getScheduledExecutorService(),
executor.getScheduledExecutorService(),
stream,
Context.ROOT.withCancellation());
ServerStreamListener mockListener = mock(ServerStreamListener.class);
listener.setListener(mockListener);

Expand All @@ -1010,7 +1079,10 @@ public void onReady_errorCancelsCall() {
public void onReady_runtimeExceptionCancelsCall() {
JumpToApplicationThreadServerStreamListener listener
= new JumpToApplicationThreadServerStreamListener(
executor.getScheduledExecutorService(), stream, Context.ROOT.withCancellation());
executor.getScheduledExecutorService(),
executor.getScheduledExecutorService(),
stream,
Context.ROOT.withCancellation());
ServerStreamListener mockListener = mock(ServerStreamListener.class);
listener.setListener(mockListener);

Expand Down

0 comments on commit 8a1217d

Please sign in to comment.