diff --git a/src/main/java/gov/usgs/earthquake/aws/AwsProductReceiver.java b/src/main/java/gov/usgs/earthquake/aws/AwsProductReceiver.java index b0c7f4a4..f056c593 100644 --- a/src/main/java/gov/usgs/earthquake/aws/AwsProductReceiver.java +++ b/src/main/java/gov/usgs/earthquake/aws/AwsProductReceiver.java @@ -61,6 +61,12 @@ public class AwsProductReceiver extends DefaultNotificationReceiver implements W protected Long lastBroadcastId = null; protected boolean processBroadcast = false; + /** Used to coordinate sending products_created_after message. */ + protected boolean sendProductsCreatedAfterFlag = false; + protected boolean sendProductsCreatedAfterRunning = false; + protected Object sendProductsCreatedAfterSync = new Object(); + protected Thread sendProductsCreatedAfterThread; + @Override public void configure(Config config) throws Exception { @@ -260,6 +266,81 @@ protected void onProductsCreatedAfter(final JsonObject json) throws Exception { } } + /** + * Request background thread to send "products_created_after" message. + * + * @throws IOException + */ + protected void sendProductsCreatedAfter() throws IOException { + synchronized(sendProductsCreatedAfterSync) { + // set flag that we want to send products created after + sendProductsCreatedAfterFlag = true; + // wake up background thread that sends message + sendProductsCreatedAfterSync.notifyAll(); + } + } + + /** + * Start background thread that sends "products_created_after" messages. + * + * @return started thread. + */ + protected void startSendProductsCreatedAfterThread() { + if (sendProductsCreatedAfterThread != null) { + throw new IllegalStateException("sendProductsCreatedThread already exists"); + } + sendProductsCreatedAfterFlag = false; + sendProductsCreatedAfterRunning = true; + sendProductsCreatedAfterThread = new Thread(() -> { + while (sendProductsCreatedAfterRunning) { + try { + synchronized (sendProductsCreatedAfterSync) { + if (!sendProductsCreatedAfterFlag) { + // wait until ready to send + sendProductsCreatedAfterSync.wait(); + continue; + } + } + + // ready to send, try to keep queue size manageable + throttleQueues(); + + synchronized (sendProductsCreatedAfterSync) { + // now send actual products created after message + try { + _sendProductsCreatedAfter(); + // message sent, reset flag + sendProductsCreatedAfterFlag = false; + } catch (IOException e) { + LOGGER.log( + Level.WARNING, + "[" + getName() + "] Exception sending products_created_after", + e); + } + } + } catch (InterruptedException ie) { + // interrupted usually means shutting down thread + } + } + }); + sendProductsCreatedAfterThread.start(); + } + + protected void stopProductsCreatedAfterThread() { + try { + sendProductsCreatedAfterRunning = false; + sendProductsCreatedAfterThread.interrupt(); + sendProductsCreatedAfterThread.join(); + } catch (Exception e) { + LOGGER.log( + Level.WARNING, + "[" + getName() + "] exception stopping sendProductsCreatedAfterThread", + e); + } finally { + sendProductsCreatedAfterThread = null; + } + } + /** * Send an "action"="products_created_after" request, which is part of the * catch up process. @@ -268,7 +349,7 @@ protected void onProductsCreatedAfter(final JsonObject json) throws Exception { * then one "action"="products_created_after" message to indicate the request * is complete. */ - protected void sendProductsCreatedAfter() throws IOException { + protected void _sendProductsCreatedAfter() throws IOException { // set default for created after if (this.createdAfter == null) { this.createdAfter = Instant.now().minusSeconds(7 * 86400); @@ -323,6 +404,9 @@ public void startup() throws Exception{ createdAfter = Instant.parse(json.getString(CREATED_AFTER_PROPERTY)); } + // start background thread for products_create_after messages + startSendProductsCreatedAfterThread(); + //open websocket client = new WebSocketClient(uri, this, attempts, timeout, true); } @@ -333,8 +417,11 @@ public void startup() throws Exception{ */ @Override public void shutdown() throws Exception{ + stopProductsCreatedAfterThread(); //close socket - client.shutdown(); + try { + client.shutdown(); + } catch (Exception e) {} super.shutdown(); } diff --git a/src/test/java/gov/usgs/earthquake/aws/AwsProductReceiverTest.java b/src/test/java/gov/usgs/earthquake/aws/AwsProductReceiverTest.java new file mode 100644 index 00000000..084d6d1b --- /dev/null +++ b/src/test/java/gov/usgs/earthquake/aws/AwsProductReceiverTest.java @@ -0,0 +1,194 @@ +package gov.usgs.earthquake.aws; + +import java.time.Instant; + +import javax.json.Json; +import javax.json.JsonObject; + +import org.junit.Assert; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; + +import gov.usgs.earthquake.distribution.Notification; +import gov.usgs.earthquake.product.Product; +import gov.usgs.earthquake.product.ProductId; +import gov.usgs.earthquake.product.io.JsonProduct; + +public class AwsProductReceiverTest { + + TestAwsProductReceiver receiver; + + @BeforeEach + public void before() throws Exception { + receiver = new TestAwsProductReceiver(); + receiver.startSendProductsCreatedAfterThread(); + } + + @AfterEach + public void after() throws Exception { + receiver.stopProductsCreatedAfterThread(); + receiver = null; + } + + @Test + public void testSwitchToBroadcast() throws Exception { + TestSession testSession = new TestSession(); + // connect + receiver.onOpen(testSession); + + // receive broadcast + Instant created = Instant.now(); + receiver.onMessage(getNotification("broadcast", 10, created).toString()); + Assert.assertFalse("not in broadcast mode yet", receiver.isProcessBroadcast()); + Assert.assertNull("didn't process broadcast", receiver.lastJsonNotification); + Assert.assertEquals("saved broadcast id", + Long.valueOf(10L), receiver.getLastBroadcastId()); + + // receive response to products_created_after + receiver.onMessage(getNotification("product", 10, created).toString()); + Assert.assertNotNull( + "processed product during catch up", + receiver.lastJsonNotification); + + // receive end of products_created_after response + receiver.onMessage(getProductsCreatedAfter(created, 1).toString()); + Assert.assertTrue("switched to broadcast mode", receiver.isProcessBroadcast()); + } + + + @Test + public void testBroadcastOutOfOrder() throws Exception { + TestSession testSession = new TestSession(); + // connect + receiver.onOpen(testSession); + // enable broadcast mode + receiver.setProcessBroadcast(true); + receiver.setLastBroadcastId(10L); + + // receive broadcast in order + receiver.onMessage(getNotification("broadcast", 11, Instant.now()).toString()); + Assert.assertTrue("still in broadcast mode", receiver.isProcessBroadcast()); + Assert.assertNotNull("processed broadcast", receiver.lastJsonNotification); + Assert.assertEquals("saved broadcast id", + Long.valueOf(11L), receiver.getLastBroadcastId()); + + // clear any previous products created after message + testSession.testBasicRemote.lastSendText = null; + receiver.lastJsonNotification = null; + + // receive broadcast out of order + receiver.onMessage(getNotification("broadcast", 13, Instant.now()).toString()); + Assert.assertFalse("no longer in broadcast mode", receiver.isProcessBroadcast()); + Assert.assertNull("did not broadcast", receiver.lastJsonNotification); + Assert.assertEquals("still saved broadcast id", + Long.valueOf(13L), receiver.getLastBroadcastId()); + String sent = testSession.waitForBasicSendText(100L); + Assert.assertTrue( + "sent products_created_after", + sent.contains("\"action\":\"products_created_after\"")); + } + + @Test + public void testStartCatchUpWhenConnected() throws Exception { + // set flags being tested to opposite values + receiver.setProcessBroadcast(true); + TestSession testSession = new TestSession(); + + // call onOpen to simulate connection + receiver.onOpen(testSession); + + // processBroadcast disabled and created after request sent + Assert.assertFalse("not in process broadcast mode", receiver.isProcessBroadcast()); + + String sent = testSession.waitForBasicSendText(1000L); + Assert.assertTrue( + "sent products_created_after", + sent.contains("\"action\":\"products_created_after\"")); + } + + @Test + public void testThrottleQueue() throws Exception { + TestSession testSession = new TestSession(); + TestListenerNotifier testNotifier = new TestListenerNotifier(receiver); + receiver.setNotifier(testNotifier); + // simulate queue that needs to be throttled + testNotifier.queueSize = 5001; + testNotifier.setThrottleStartThreshold(5000); + testNotifier.setThrottleStopThreshold(2500); + testNotifier.setThrottleWaitInterval(100L); + + // call onOpen to simulate connection + receiver.onOpen(testSession); + Assert.assertNull( + "throttling should prevent message", + testSession.waitForBasicSendText(100L)); + + // now simulate queue that is done throttling + testNotifier.queueSize = 2499; + String sent = testSession.waitForBasicSendText(500L); + Assert.assertTrue( + "sent products_created_after", + sent.contains("\"action\":\"products_created_after\"")); + } + + static JsonObject getNotification(final String action, final long id, final Instant created) throws Exception { + Product product = new Product(new ProductId("source", "type", "code")); + return Json.createObjectBuilder() + .add("action", action) + .add("notification", + Json.createObjectBuilder() + .add("id", id) + .add("created", created.toString()) + .add("product", new JsonProduct().getJsonObject(product))) + .build(); + } + + static JsonObject getProductsCreatedAfter(final Instant createdAfter, final int count) { + return Json.createObjectBuilder() + .add("action", "products_created_after") + .add("created_after", createdAfter.toString()) + .add("count", count) + .build(); + } + + /** + * Stub socket connections to test message handling behavior. + */ + static class TestAwsProductReceiver extends AwsProductReceiver { + public JsonNotification lastJsonNotification; + public boolean onJsonNotificationCalled = false; + + @Override + protected void onJsonNotification(JsonNotification notification) throws Exception { + onJsonNotificationCalled = true; + lastJsonNotification = notification; + super.onJsonNotification(notification); + } + + @Override + public void receiveNotification(Notification notification) throws Exception { + // skip actual processing + } + + @Override + public void writeTrackingData() { + // skip tracking + } + + // getter/setter to control state for testing + + public Instant getCreatedAfter() { return this.createdAfter; } + public void setCreatedAfter(final Instant c) { this.createdAfter = c; } + + public JsonNotification getLastBroadcast() { return this.lastBroadcast; } + public void setLastBroadcast(final JsonNotification j) { this.lastBroadcast = j; } + + public Long getLastBroadcastId() { return this.lastBroadcastId; } + public void setLastBroadcastId(final Long l) { this.lastBroadcastId = l; } + + public boolean isProcessBroadcast() { return this.processBroadcast; } + public void setProcessBroadcast(final boolean b) { this.processBroadcast = b; } + } + +} diff --git a/src/test/java/gov/usgs/earthquake/aws/TestBasicRemote.java b/src/test/java/gov/usgs/earthquake/aws/TestBasicRemote.java new file mode 100644 index 00000000..cb51ba27 --- /dev/null +++ b/src/test/java/gov/usgs/earthquake/aws/TestBasicRemote.java @@ -0,0 +1,97 @@ +package gov.usgs.earthquake.aws; + +import java.io.IOException; +import java.io.OutputStream; +import java.io.Writer; +import java.nio.ByteBuffer; +import java.util.logging.Logger; + +import javax.websocket.EncodeException; +import javax.websocket.RemoteEndpoint.Basic; + +/** + * TestBasicRemote captures data sent via sendText calls. + */ +public class TestBasicRemote implements Basic { + + private static final Logger LOGGER = Logger.getLogger(TestBasicRemote.class.getName()); + + public String lastSendText = null; + private Object sendTextSync = new Object(); + + @Override + public void sendText(String text) throws IOException { + LOGGER.info("sendText called with " + text); + synchronized(sendTextSync) { + lastSendText = text; + sendTextSync.notifyAll(); + } + } + + public String waitForSendText(final long timeoutMillis) throws InterruptedException { + synchronized(sendTextSync) { + if (lastSendText == null) { + sendTextSync.wait(timeoutMillis); + } + return lastSendText; + } + }; + + // other methods in interface are not used at this time + + @Override + public void setBatchingAllowed(boolean allowed) throws IOException { + throw new RuntimeException("Not implemented"); + } + + @Override + public boolean getBatchingAllowed() { + throw new RuntimeException("Not implemented"); + } + + @Override + public void flushBatch() throws IOException { + throw new RuntimeException("Not implemented"); + } + + @Override + public void sendPing(ByteBuffer applicationData) throws IOException, IllegalArgumentException { + throw new RuntimeException("Not implemented"); + } + + @Override + public void sendPong(ByteBuffer applicationData) throws IOException, IllegalArgumentException { + throw new RuntimeException("Not implemented"); + } + + @Override + public void sendBinary(ByteBuffer data) throws IOException { + throw new RuntimeException("Not implemented"); + } + + @Override + public void sendText(String partialMessage, boolean isLast) throws IOException { + throw new RuntimeException("Not implemented"); + } + + @Override + public void sendBinary(ByteBuffer partialByte, boolean isLast) throws IOException { + throw new RuntimeException("Not implemented"); + } + + @Override + public OutputStream getSendStream() throws IOException { + throw new RuntimeException("Not implemented"); + } + + @Override + public Writer getSendWriter() throws IOException { + throw new RuntimeException("Not implemented"); + } + + @Override + public void sendObject(Object data) throws IOException, EncodeException { + throw new RuntimeException("Not implemented"); + } + +} diff --git a/src/test/java/gov/usgs/earthquake/aws/TestListenerNotifier.java b/src/test/java/gov/usgs/earthquake/aws/TestListenerNotifier.java new file mode 100644 index 00000000..a428912c --- /dev/null +++ b/src/test/java/gov/usgs/earthquake/aws/TestListenerNotifier.java @@ -0,0 +1,23 @@ +package gov.usgs.earthquake.aws; + +import gov.usgs.earthquake.distribution.DefaultNotificationReceiver; +import gov.usgs.earthquake.distribution.ExecutorListenerNotifier; + +public class TestListenerNotifier extends ExecutorListenerNotifier { + + // override queueSize for testing + public Integer queueSize = null; + + public TestListenerNotifier(DefaultNotificationReceiver receiver) { + super(receiver); + } + + @Override + public Integer getMaxQueueSize() { + if (queueSize != null) { + return queueSize; + } + return super.getMaxQueueSize(); + } + +} diff --git a/src/test/java/gov/usgs/earthquake/aws/TestSession.java b/src/test/java/gov/usgs/earthquake/aws/TestSession.java new file mode 100644 index 00000000..1121e65c --- /dev/null +++ b/src/test/java/gov/usgs/earthquake/aws/TestSession.java @@ -0,0 +1,179 @@ +package gov.usgs.earthquake.aws; + +import java.io.IOException; +import java.net.URI; +import java.security.Principal; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import javax.websocket.CloseReason; +import javax.websocket.Extension; +import javax.websocket.MessageHandler; +import javax.websocket.MessageHandler.Partial; +import javax.websocket.MessageHandler.Whole; +import javax.websocket.RemoteEndpoint.Async; +import javax.websocket.RemoteEndpoint.Basic; +import javax.websocket.Session; +import javax.websocket.WebSocketContainer; + + +/** + * Test session provides testing basic remote to capture sendText calls. + */ +public class TestSession implements Session { + + TestBasicRemote testBasicRemote = new TestBasicRemote(); + + @Override + public Basic getBasicRemote() { + return testBasicRemote; + } + + @Override + public String getId() { + return "test session"; + } + + public String waitForBasicSendText(final long timeoutMillis) throws InterruptedException { + return testBasicRemote.waitForSendText(timeoutMillis); + } + + // other interface methods not implemented at this time. + + @Override + public WebSocketContainer getContainer() { + throw new RuntimeException("Not implemented"); + } + + @Override + public void addMessageHandler(MessageHandler handler) throws IllegalStateException { + throw new RuntimeException("Not implemented"); + } + + @Override + public void addMessageHandler(Class clazz, Whole handler) { + throw new RuntimeException("Not implemented"); + } + + @Override + public void addMessageHandler(Class clazz, Partial handler) { + throw new RuntimeException("Not implemented"); + } + + @Override + public Set getMessageHandlers() { + throw new RuntimeException("Not implemented"); + } + + @Override + public void removeMessageHandler(MessageHandler handler) { + throw new RuntimeException("Not implemented"); + } + + @Override + public String getProtocolVersion() { + throw new RuntimeException("Not implemented"); + } + + @Override + public String getNegotiatedSubprotocol() { + throw new RuntimeException("Not implemented"); + } + + @Override + public List getNegotiatedExtensions() { + throw new RuntimeException("Not implemented"); + } + + @Override + public boolean isSecure() { + throw new RuntimeException("Not implemented"); + } + + @Override + public boolean isOpen() { + throw new RuntimeException("Not implemented"); + } + + @Override + public long getMaxIdleTimeout() { + throw new RuntimeException("Not implemented"); + } + + @Override + public void setMaxIdleTimeout(long milliseconds) { + throw new RuntimeException("Not implemented"); + } + + @Override + public void setMaxBinaryMessageBufferSize(int length) { + throw new RuntimeException("Not implemented"); + } + + @Override + public int getMaxBinaryMessageBufferSize() { + throw new RuntimeException("Not implemented"); + } + + @Override + public void setMaxTextMessageBufferSize(int length) { + throw new RuntimeException("Not implemented"); + } + + @Override + public int getMaxTextMessageBufferSize() { + throw new RuntimeException("Not implemented"); + } + + @Override + public Async getAsyncRemote() { + throw new RuntimeException("Not implemented"); + } + + @Override + public void close() throws IOException { + throw new RuntimeException("Not implemented"); + } + + @Override + public void close(CloseReason closeReason) throws IOException { + throw new RuntimeException("Not implemented"); + } + + @Override + public URI getRequestURI() { + throw new RuntimeException("Not implemented"); + } + + @Override + public Map> getRequestParameterMap() { + throw new RuntimeException("Not implemented"); + } + + @Override + public String getQueryString() { + throw new RuntimeException("Not implemented"); + } + + @Override + public Map getPathParameters() { + throw new RuntimeException("Not implemented"); + } + + @Override + public Map getUserProperties() { + throw new RuntimeException("Not implemented"); + } + + @Override + public Principal getUserPrincipal() { + throw new RuntimeException("Not implemented"); + } + + @Override + public Set getOpenSessions() { + throw new RuntimeException("Not implemented"); + } + +}