diff --git a/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java b/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java index 7855468ee61..cc3ac9f516e 100644 --- a/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java +++ b/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java @@ -132,6 +132,7 @@ final class CachingRlsLbClient { @GuardedBy("lock") private final RefCountedChildPolicyWrapperFactory refCountedChildPolicyWrapperFactory; private final ChannelLogger logger; + private final ChildPolicyWrapper fallbackChildPolicyWrapper; static { MetricInstrumentRegistry metricInstrumentRegistry @@ -226,6 +227,13 @@ private CachingRlsLbClient(Builder builder) { lbPolicyConfig.getLoadBalancingPolicy(), childLbResolvedAddressFactory, childLbHelperProvider, new BackoffRefreshListener()); + // TODO(creamsoup) wait until lb is ready + String defaultTarget = lbPolicyConfig.getRouteLookupConfig().defaultTarget(); + if (defaultTarget != null && !defaultTarget.isEmpty()) { + fallbackChildPolicyWrapper = refCountedChildPolicyWrapperFactory.createOrGet(defaultTarget); + } else { + fallbackChildPolicyWrapper = null; + } gaugeRegistration = helper.getMetricRecorder() .registerBatchCallback(new BatchCallback() { @@ -1022,12 +1030,8 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { } } - private ChildPolicyWrapper fallbackChildPolicyWrapper; - /** Uses Subchannel connected to default target. */ private PickResult useFallback(PickSubchannelArgs args) { - // TODO(creamsoup) wait until lb is ready - startFallbackChildPolicy(); SubchannelPicker picker = fallbackChildPolicyWrapper.getPicker(); if (picker == null) { return PickResult.withNoResult(); @@ -1052,17 +1056,6 @@ private String determineMetricsPickResult(PickResult pickResult) { } } - private void startFallbackChildPolicy() { - String defaultTarget = lbPolicyConfig.getRouteLookupConfig().defaultTarget(); - synchronized (lock) { - if (fallbackChildPolicyWrapper != null) { - return; - } - logger.log(ChannelLogLevel.DEBUG, "starting fallback to {0}", defaultTarget); - fallbackChildPolicyWrapper = refCountedChildPolicyWrapperFactory.createOrGet(defaultTarget); - } - } - // GuardedBy CachingRlsLbClient.lock void close() { synchronized (lock) { // Lock is already held, but ErrorProne can't tell diff --git a/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java b/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java index 7c5df2c96b3..4f086abc4a2 100644 --- a/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java +++ b/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java @@ -191,7 +191,9 @@ public void setUpMockMetricRecorder() { @After public void tearDown() throws Exception { - rlsLbClient.close(); + if (rlsLbClient != null) { + rlsLbClient.close(); + } assertWithMessage( "On client shut down, RlsLoadBalancer must shut down with all its child loadbalancers.") .that(lbProvider.loadBalancers).isEmpty(); @@ -372,12 +374,14 @@ public void get_updatesLbState() throws Exception { ArgumentCaptor pickerCaptor = ArgumentCaptor.forClass(SubchannelPicker.class); ArgumentCaptor stateCaptor = ArgumentCaptor.forClass(ConnectivityState.class); - inOrder.verify(helper, times(2)) + inOrder.verify(helper, times(3)) .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); assertThat(new HashSet<>(pickerCaptor.getAllValues())).hasSize(1); + // TRANSIENT_FAILURE is because the test setup pretends fallback is not available. assertThat(stateCaptor.getAllValues()) - .containsExactly(ConnectivityState.CONNECTING, ConnectivityState.READY); + .containsExactly(ConnectivityState.TRANSIENT_FAILURE, ConnectivityState.CONNECTING, + ConnectivityState.READY); Metadata headers = new Metadata(); PickResult pickResult = getPickResultForCreate(pickerCaptor, headers); assertThat(pickResult.getStatus().isOk()).isTrue(); @@ -439,7 +443,7 @@ public void timeout_not_changing_picked_subchannel() throws Exception { ArgumentCaptor pickerCaptor = ArgumentCaptor.forClass(SubchannelPicker.class); ArgumentCaptor stateCaptor = ArgumentCaptor.forClass(ConnectivityState.class); - verify(helper, times(4)).updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); + verify(helper, times(5)).updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); Metadata headers = new Metadata(); PickResult pickResult = getPickResultForCreate(pickerCaptor, headers); @@ -509,7 +513,7 @@ public void get_withAdaptiveThrottler() throws Exception { ArgumentCaptor pickerCaptor = ArgumentCaptor.forClass(SubchannelPicker.class); ArgumentCaptor stateCaptor = ArgumentCaptor.forClass(ConnectivityState.class); - inOrder.verify(helper, times(2)) + inOrder.verify(helper, times(3)) .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); Metadata headers = new Metadata(); @@ -699,6 +703,7 @@ public void metricGauges() throws ExecutionException, InterruptedException, Time // Shutdown rlsLbClient.close(); + rlsLbClient = null; verify(mockGaugeRegistration).close(); } diff --git a/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java b/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java index f3986cb89d5..354466f3caf 100644 --- a/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java +++ b/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java @@ -201,7 +201,13 @@ public void tearDown() { @Test public void lb_serverStatusCodeConversion() throws Exception { - deliverResolvedAddresses(); + helper.getSynchronizationContext().execute(() -> { + try { + deliverResolvedAddresses(); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); InOrder inOrder = inOrder(helper); inOrder.verify(helper) .updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); @@ -236,7 +242,13 @@ public void lb_serverStatusCodeConversion() throws Exception { @Test public void lb_working_withDefaultTarget_rlsResponding() throws Exception { - deliverResolvedAddresses(); + helper.getSynchronizationContext().execute(() -> { + try { + deliverResolvedAddresses(); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); InOrder inOrder = inOrder(helper); inOrder.verify(helper) .updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); @@ -257,7 +269,7 @@ public void lb_working_withDefaultTarget_rlsResponding() throws Exception { inOrder.verifyNoMoreInteractions(); assertThat(res.getStatus().isOk()).isTrue(); - assertThat(subchannels).hasSize(1); + assertThat(subchannels).hasSize(2); // includes fallback sub-channel FakeSubchannel searchSubchannel = subchannels.getLast(); assertThat(subchannelIsReady(searchSubchannel)).isFalse(); @@ -277,7 +289,7 @@ public void lb_working_withDefaultTarget_rlsResponding() throws Exception { // other rls picker itself is ready due to first channel. assertThat(res.getStatus().isOk()).isTrue(); assertThat(subchannelIsReady(res.getSubchannel())).isFalse(); - assertThat(subchannels).hasSize(2); + assertThat(subchannels).hasSize(3); // includes fallback sub-channel FakeSubchannel rescueSubchannel = subchannels.getLast(); // search subchannel is down, rescue subchannel is connecting @@ -393,7 +405,13 @@ public void lb_working_withoutDefaultTarget_noRlsResponse() throws Exception { public void lb_working_withDefaultTarget_noRlsResponse() throws Exception { fakeThrottler.nextResult = true; - deliverResolvedAddresses(); + helper.getSynchronizationContext().execute(() -> { + try { + deliverResolvedAddresses(); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); InOrder inOrder = inOrder(helper); inOrder.verify(helper) .updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); @@ -535,7 +553,13 @@ public void lb_working_withoutDefaultTarget() throws Exception { @Test public void lb_nameResolutionFailed() throws Exception { - deliverResolvedAddresses(); + helper.getSynchronizationContext().execute(() -> { + try { + deliverResolvedAddresses(); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); InOrder inOrder = inOrder(helper); inOrder.verify(helper) .updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); @@ -545,7 +569,7 @@ public void lb_nameResolutionFailed() throws Exception { assertThat(subchannelIsReady(res.getSubchannel())).isFalse(); inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); - assertThat(subchannels).hasSize(1); + assertThat(subchannels).hasSize(2); // includes fallback sub-channel FakeSubchannel searchSubchannel = subchannels.getLast(); searchSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY));