diff --git a/android/src/main/java/org/whispersystems/signalservice/api/util/RealtimeSleepTimer.java b/android/src/main/java/org/whispersystems/signalservice/api/util/RealtimeSleepTimer.java index 55c2eb7d85..c1c7e1bc90 100644 --- a/android/src/main/java/org/whispersystems/signalservice/api/util/RealtimeSleepTimer.java +++ b/android/src/main/java/org/whispersystems/signalservice/api/util/RealtimeSleepTimer.java @@ -24,6 +24,7 @@ public class RealtimeSleepTimer implements SleepTimer { private final AlarmReceiver alarmReceiver; private final Context context; + private boolean armed; public RealtimeSleepTimer(Context context) { this.context = context; @@ -31,24 +32,31 @@ public RealtimeSleepTimer(Context context) { } @Override - public void sleep(long millis) { - context.registerReceiver(alarmReceiver, - new IntentFilter(AlarmReceiver.WAKE_UP_THREAD_ACTION)); - - final long startTime = System.currentTimeMillis(); - alarmReceiver.setAlarm(millis); - - while (System.currentTimeMillis() - startTime < millis) { - try { - synchronized (this) { - wait(millis - System.currentTimeMillis() + startTime); - } - } catch (InterruptedException e) { - Log.w(TAG, e); - } + public void sleep(long millis) throws InterruptedException { + boolean arm; + + synchronized (this) { + if (!armed) { + armed = true; + arm = true; + } else { + arm = false; + } } - context.unregisterReceiver(alarmReceiver); + if (arm) { + context.registerReceiver(alarmReceiver, + new IntentFilter(AlarmReceiver.WAKE_UP_THREAD_ACTION)); + + alarmReceiver.setAlarm(millis); + synchronized (this) { + wait(millis); + } + } else { + synchronized (this) { + wait(); + } + } } private class AlarmReceiver extends BroadcastReceiver { @@ -81,9 +89,11 @@ public void onReceive(Context context, Intent intent) { Log.w(TAG, "Waking up."); synchronized (RealtimeSleepTimer.this) { + RealtimeSleepTimer.this.context.unregisterReceiver(this); + + RealtimeSleepTimer.this.armed = false; RealtimeSleepTimer.this.notifyAll(); } } } } - diff --git a/java/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessagePipe.java b/java/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessagePipe.java index 92daad698e..143ca24cfc 100644 --- a/java/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessagePipe.java +++ b/java/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessagePipe.java @@ -68,7 +68,7 @@ public class SignalServiceMessagePipe { * @throws TimeoutException */ public SignalServiceEnvelope read(long timeout, TimeUnit unit) - throws InvalidVersionException, IOException, TimeoutException + throws InvalidVersionException, IOException, TimeoutException, InterruptedException { return read(timeout, unit, new NullMessagePipeCallback()); } @@ -91,7 +91,7 @@ public SignalServiceEnvelope read(long timeout, TimeUnit unit) * @throws InvalidVersionException */ public SignalServiceEnvelope read(long timeout, TimeUnit unit, MessagePipeCallback callback) - throws TimeoutException, IOException, InvalidVersionException + throws TimeoutException, IOException, InvalidVersionException, InterruptedException { if (!credentialsProvider.isPresent()) { throw new IllegalArgumentException("You can't read messages if you haven't specified credentials"); diff --git a/java/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java b/java/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java index cc41cdcaf2..4950e13ae9 100644 --- a/java/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java +++ b/java/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java @@ -89,14 +89,14 @@ public class SignalServiceMessageSender { private static final String TAG = SignalServiceMessageSender.class.getSimpleName(); - private final PushServiceSocket socket; - private final SignalProtocolStore store; - private final SignalServiceAddress localAddress; - private final Optional eventListener; + private final PushServiceSocket socket; + private final SignalProtocolStore store; + private final SignalServiceAddress localAddress; + private final Optional eventListener; - private final AtomicReference> pipe; - private final AtomicReference> unidentifiedPipe; - private final AtomicBoolean isMultiDevice; + private final AtomicReference pipe; + private final AtomicReference unidentifiedPipe; + private final AtomicBoolean isMultiDevice; /** * Construct a SignalServiceMessageSender. @@ -113,8 +113,8 @@ public SignalServiceMessageSender(SignalServiceConfiguration urls, SignalProtocolStore store, String userAgent, boolean isMultiDevice, - Optional pipe, - Optional unidentifiedPipe, + AtomicReference pipe, + AtomicReference unidentifiedPipe, Optional eventListener) { this(urls, new StaticCredentialsProvider(user, password, null), store, userAgent, isMultiDevice, pipe, unidentifiedPipe, eventListener); @@ -125,15 +125,15 @@ public SignalServiceMessageSender(SignalServiceConfiguration urls, SignalProtocolStore store, String userAgent, boolean isMultiDevice, - Optional pipe, - Optional unidentifiedPipe, + AtomicReference pipe, + AtomicReference unidentifiedPipe, Optional eventListener) { this.socket = new PushServiceSocket(urls, credentialsProvider, userAgent); this.store = store; this.localAddress = new SignalServiceAddress(credentialsProvider.getUser()); - this.pipe = new AtomicReference<>(pipe); - this.unidentifiedPipe = new AtomicReference<>(unidentifiedPipe); + this.pipe = pipe; + this.unidentifiedPipe = unidentifiedPipe; this.isMultiDevice = new AtomicBoolean(isMultiDevice); this.eventListener = eventListener; } @@ -301,11 +301,6 @@ public void cancelInFlightRequests() { socket.cancelInFlightRequests(); } - public void setMessagePipe(SignalServiceMessagePipe pipe, SignalServiceMessagePipe unidentifiedPipe) { - this.pipe.set(Optional.fromNullable(pipe)); - this.unidentifiedPipe.set(Optional.fromNullable(unidentifiedPipe)); - } - public void setIsMultiDevice(boolean isMultiDevice) { this.isMultiDevice.set(isMultiDevice); } @@ -824,22 +819,22 @@ private SendMessageResult sendMessage(SignalServiceAddress recipient, for (int i=0;i<4;i++) { try { OutgoingPushMessageList messages = getEncryptedMessages(socket, recipient, unidentifiedAccess, timestamp, content, online); - Optional pipe = this.pipe.get(); - Optional unidentifiedPipe = this.unidentifiedPipe.get(); + SignalServiceMessagePipe pipe = this.pipe.get(); + SignalServiceMessagePipe unidentifiedPipe = this.unidentifiedPipe.get(); - if (pipe.isPresent() && !unidentifiedAccess.isPresent()) { + if (pipe != null && !unidentifiedAccess.isPresent()) { try { Log.w(TAG, "Transmitting over pipe..."); - SendMessageResponse response = pipe.get().send(messages, Optional.absent()); + SendMessageResponse response = pipe.send(messages, Optional.absent()); return SendMessageResult.success(recipient, false, response.getNeedsSync()); } catch (IOException e) { Log.w(TAG, e); Log.w(TAG, "Falling back to new connection..."); } - } else if (unidentifiedPipe.isPresent() && unidentifiedAccess.isPresent()) { + } else if (unidentifiedPipe != null && unidentifiedAccess.isPresent()) { try { Log.w(TAG, "Transmitting over unidentified pipe..."); - SendMessageResponse response = unidentifiedPipe.get().send(messages, unidentifiedAccess); + SendMessageResponse response = unidentifiedPipe.send(messages, unidentifiedAccess); return SendMessageResult.success(recipient, true, response.getNeedsSync()); } catch (IOException e) { Log.w(TAG, e); diff --git a/java/src/main/java/org/whispersystems/signalservice/internal/websocket/WebSocketConnection.java b/java/src/main/java/org/whispersystems/signalservice/internal/websocket/WebSocketConnection.java index 7075cf4305..69880ae18d 100644 --- a/java/src/main/java/org/whispersystems/signalservice/internal/websocket/WebSocketConnection.java +++ b/java/src/main/java/org/whispersystems/signalservice/internal/websocket/WebSocketConnection.java @@ -20,6 +20,7 @@ import java.util.Iterator; import java.util.LinkedList; import java.util.Map; +import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; @@ -127,22 +128,20 @@ public synchronized void disconnect() { } if (keepAliveSender != null) { - keepAliveSender.shutdown(); + keepAliveSender.interrupt(); keepAliveSender = null; } } public synchronized WebSocketRequestMessage readRequest(long timeoutMillis) - throws TimeoutException, IOException + throws TimeoutException, IOException, InterruptedException { if (client == null) { throw new IOException("Connection closed!"); } - long startTime = System.currentTimeMillis(); - - while (client != null && incomingRequests.isEmpty() && elapsedTime(startTime) < timeoutMillis) { - Util.wait(this, Math.max(1, timeoutMillis - elapsedTime(startTime))); + if (client != null && incomingRequests.isEmpty()) { + wait(timeoutMillis); } if (incomingRequests.isEmpty() && client == null) throw new IOException("Connection closed!"); @@ -183,20 +182,16 @@ public synchronized void sendResponse(WebSocketResponseMessage response) throws } } - private synchronized void sendKeepAlive() throws IOException { + private synchronized Future> sendKeepAlive() throws IOException { if (keepAliveSender != null && client != null) { - byte[] message = WebSocketMessage.newBuilder() - .setType(WebSocketMessage.Type.REQUEST) - .setRequest(WebSocketRequestMessage.newBuilder() - .setId(System.currentTimeMillis()) - .setPath("/v1/keepalive") - .setVerb("GET") - .build()).build() - .toByteArray(); - - if (!client.send(ByteString.of(message))) { - throw new IOException("Write failed!"); - } + WebSocketRequestMessage request = WebSocketRequestMessage.newBuilder() + .setId(System.currentTimeMillis()) + .setPath("/v1/keepalive") + .setVerb("GET") + .build(); + return sendRequest(request); + } else { + return null; } } @@ -249,7 +244,7 @@ public synchronized void onClosed(WebSocket webSocket, int code, String reason) } if (keepAliveSender != null) { - keepAliveSender.shutdown(); + keepAliveSender.interrupt(); keepAliveSender = null; } @@ -312,23 +307,43 @@ private Pair createTlsSocketFactory(TrustSto private class KeepAliveSender extends Thread { - private AtomicBoolean stop = new AtomicBoolean(false); - public void run() { - while (!stop.get()) { + Future future = null; + boolean severed = false; + + while (!interrupted()) { try { sleepTimer.sleep(TimeUnit.SECONDS.toMillis(KEEPALIVE_TIMEOUT_SECONDS)); - Log.w(TAG, "Sending keep alive..."); - sendKeepAlive(); - } catch (Throwable e) { - Log.w(TAG, e); + if (future != null) { + try { + future.get(0L, TimeUnit.SECONDS); + } catch (ExecutionException | TimeoutException e){ + severed = true; + } + } + } catch (InterruptedException e) { + Log.d(TAG, "Keep alive sender interrupted; exiting loop."); + break; } - } - } - public void shutdown() { - stop.set(true); + if (severed) { + Log.d(TAG, "No response to previous keep-alive; forcing new connection."); + + disconnect(); + synchronized(WebSocketConnection.this) { + WebSocketConnection.this.notifyAll(); + } + } else { + Log.d(TAG, "Sending keep alive..."); + + try { + future = sendKeepAlive(); + } catch (IOException e) { + Log.d(TAG, "Failed to send keep alive: " + e.getMessage()); + } + } + } } }