From 7850e2f0f595bd2f08ec1fcb5513beeb3de27b0a Mon Sep 17 00:00:00 2001 From: Hannah Law Date: Tue, 7 Nov 2023 15:52:59 +0000 Subject: [PATCH] Refactor and add number of threads selector. --- src/main/java/attack/AttackHandler.java | 78 ++++++++++++++++++- .../java/connection/ConnectionFactory.java | 10 +-- .../java/connection/WebSocketConnection.java | 10 +-- src/main/java/logger/Logger.java | 2 +- .../java/queue/SendMessageQueueConsumer.java | 16 ++-- .../queue/TableBlockingQueueConsumer.java | 10 +-- .../queue/TableBlockingQueueProducer.java | 1 - src/main/java/ui/WebSocketFrame.java | 19 +++-- .../java/ui/attack/WebSocketAttackPanel.java | 48 ++++-------- .../java/ui/editor/WebSocketEditorPanel.java | 19 +++-- 10 files changed, 138 insertions(+), 75 deletions(-) diff --git a/src/main/java/attack/AttackHandler.java b/src/main/java/attack/AttackHandler.java index b92f532..507d2af 100644 --- a/src/main/java/attack/AttackHandler.java +++ b/src/main/java/attack/AttackHandler.java @@ -8,33 +8,54 @@ import data.ConnectionMessage; import data.WebSocketConnectionMessage; import logger.Logger; +import logger.LoggerLevel; import org.python.util.PythonInterpreter; +import queue.SendMessageQueueConsumer; +import queue.TableBlockingQueueConsumer; import queue.TableBlockingQueueProducer; +import ui.attack.table.WebSocketMessageTableModel; import java.time.LocalDateTime; import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; public class AttackHandler { + private final Logger logger; + private final BlockingQueue sendMessageQueue; + private final BlockingQueue tableBlockingQueue; + private final WebSocketMessageTableModel webSocketMessageTableModel; + private final WebSocketMessage baseWebSocketMessage; + private final AtomicBoolean isAttackRunning; private final PythonInterpreter interpreter; + private ExecutorService sendMessageExecutorService; + private ExecutorService tableExecutorService; public AttackHandler( Logger logger, WebSockets webSockets, - AtomicBoolean isProcessing, BlockingQueue sendMessageQueue, BlockingQueue tableBlockingQueue, - WebSocketMessage baseWebSocketMessage + WebSocketMessageTableModel webSocketMessageTableModel, + WebSocketMessage baseWebSocketMessage, + AtomicBoolean isAttackRunning ) { + this.logger = logger; + this.sendMessageQueue = sendMessageQueue; + this.tableBlockingQueue = tableBlockingQueue; + this.webSocketMessageTableModel = webSocketMessageTableModel; + this.baseWebSocketMessage = baseWebSocketMessage; + this.isAttackRunning = isAttackRunning; interpreter = new PythonInterpreter(); interpreter.setOut(logger.outputStream()); interpreter.setErr(logger.errorStream()); interpreter.set("base_websocket", baseWebSocketMessage); - interpreter.set("websocket_connection", new ConnectionFactory(logger, webSockets, isProcessing, sendMessageQueue)); + interpreter.set("websocket_connection", new ConnectionFactory(logger, webSockets, sendMessageQueue, isAttackRunning)); interpreter.set("results_table", new TableBlockingQueueProducer(logger, tableBlockingQueue)); } @@ -58,6 +79,57 @@ public void executeCallback(WebSocketConnectionMessage webSocketConnectionMessag interpreter.exec(String.format("%s(%s)", callbackMethod, messageParameterName)); } + public BlockingQueue getSendMessageQueue() + { + return sendMessageQueue; + } + + public BlockingQueue getTableBlockingQueue() + { + return tableBlockingQueue; + } + + public WebSocketMessageTableModel getWebSocketMessageTableModel() + { + return webSocketMessageTableModel; + } + + public WebSocketMessage getBaseWebSocketMessage() + { + return baseWebSocketMessage; + } + + public AtomicBoolean getIsAttackRunning() + { + return isAttackRunning; + } + + public void startConsumers(int numberOfSendThreads) + { + + sendMessageExecutorService = Executors.newFixedThreadPool(numberOfSendThreads); + sendMessageExecutorService.execute(new SendMessageQueueConsumer(logger, this, isAttackRunning)); + + logger.logOutput(LoggerLevel.DEBUG, "Number of threads attack started with: " + numberOfSendThreads); + + tableExecutorService = Executors.newSingleThreadExecutor(); + tableExecutorService.execute(new TableBlockingQueueConsumer(logger, webSocketMessageTableModel, tableBlockingQueue, isAttackRunning)); + + logger.logOutput(LoggerLevel.DEBUG, "Table thread started."); + } + + public void shutdownConsumers() + { + sendMessageExecutorService.shutdownNow(); + logger.logOutput(LoggerLevel.DEBUG, "sendMessageExecutorService shutdown? " + sendMessageExecutorService.isShutdown()); + + tableExecutorService.shutdownNow(); + logger.logOutput(LoggerLevel.DEBUG, "tableExecutorService shutdown? " + tableExecutorService.isShutdown()); + + sendMessageQueue.clear(); + tableBlockingQueue.clear(); + } + private static class DecoratedConnectionMessage implements ConnectionMessage { private final WebSocketConnectionMessage webSocketConnectionMessage; diff --git a/src/main/java/connection/ConnectionFactory.java b/src/main/java/connection/ConnectionFactory.java index e0b6b31..c1cd152 100644 --- a/src/main/java/connection/ConnectionFactory.java +++ b/src/main/java/connection/ConnectionFactory.java @@ -12,24 +12,24 @@ public class ConnectionFactory { private final Logger logger; private final WebSockets webSockets; - private final AtomicBoolean isProcessing; private final BlockingQueue sendMessageQueue; + private final AtomicBoolean isAttackRunning; public ConnectionFactory( Logger logger, WebSockets webSockets, - AtomicBoolean isProcessing, - BlockingQueue sendMessageQueue + BlockingQueue sendMessageQueue, + AtomicBoolean isAttackRunning ) { this.logger = logger; this.webSockets = webSockets; - this.isProcessing = isProcessing; this.sendMessageQueue = sendMessageQueue; + this.isAttackRunning = isAttackRunning; } public Connection create(WebSocketMessage baseWebSocketMessage) { - return new WebSocketConnection(logger, webSockets, isProcessing, baseWebSocketMessage, sendMessageQueue); + return new WebSocketConnection(logger, webSockets, sendMessageQueue, baseWebSocketMessage, isAttackRunning); } } diff --git a/src/main/java/connection/WebSocketConnection.java b/src/main/java/connection/WebSocketConnection.java index ad97ab3..b1f9ad9 100644 --- a/src/main/java/connection/WebSocketConnection.java +++ b/src/main/java/connection/WebSocketConnection.java @@ -18,22 +18,22 @@ public class WebSocketConnection implements Connection { private final Logger logger; private final WebSockets webSockets; - private final AtomicBoolean isProcessing; private final BlockingQueue sendMessageQueue; + private final AtomicBoolean isAttackRunning; private final ExtensionWebSocket extensionWebSocket; WebSocketConnection( Logger logger, WebSockets webSockets, - AtomicBoolean isProcessing, + BlockingQueue sendMessageQueue, WebSocketMessage baseWebSocketMessage, - BlockingQueue sendMessageQueue + AtomicBoolean isAttackRunning ) { this.logger = logger; this.webSockets = webSockets; - this.isProcessing = isProcessing; this.sendMessageQueue = sendMessageQueue; + this.isAttackRunning = isAttackRunning; extensionWebSocket = createExtensionWebSocket(baseWebSocketMessage); } @@ -41,7 +41,7 @@ public class WebSocketConnection implements Connection @Override public void queue(String payload) { - if (isProcessing.get()) + if (isAttackRunning.get()) {try { sendMessageQueue.put(new WebSocketConnectionMessage(payload, Direction.CLIENT_TO_SERVER, LocalDateTime.now(), null, this)); diff --git a/src/main/java/logger/Logger.java b/src/main/java/logger/Logger.java index 8b7e65d..5469ccf 100644 --- a/src/main/java/logger/Logger.java +++ b/src/main/java/logger/Logger.java @@ -17,7 +17,7 @@ public Logger(Logging logging) { this.logging = logging; - debugLogLevel = false; + debugLogLevel = true; errorLogLevel = true; } diff --git a/src/main/java/queue/SendMessageQueueConsumer.java b/src/main/java/queue/SendMessageQueueConsumer.java index 8263c68..ebeb313 100644 --- a/src/main/java/queue/SendMessageQueueConsumer.java +++ b/src/main/java/queue/SendMessageQueueConsumer.java @@ -12,28 +12,28 @@ public class SendMessageQueueConsumer implements Runnable { private final Logger logger; - private final AtomicBoolean isProcessing; - private final BlockingQueue sendMessageQueue; private final AttackHandler attackHandler; + private final AtomicBoolean isAttackRunning; + private final BlockingQueue sendMessageQueue; public SendMessageQueueConsumer( Logger logger, - AtomicBoolean isProcessing, - BlockingQueue sendMessageQueue, - AttackHandler attackHandler + AttackHandler attackHandler, + AtomicBoolean isAttackRunning ) { this.logger = logger; - this.isProcessing = isProcessing; - this.sendMessageQueue = sendMessageQueue; this.attackHandler = attackHandler; + this.isAttackRunning = isAttackRunning; + + sendMessageQueue = attackHandler.getSendMessageQueue(); } @Override public void run() { - while (isProcessing.get()) + while (isAttackRunning.get()) { try { diff --git a/src/main/java/queue/TableBlockingQueueConsumer.java b/src/main/java/queue/TableBlockingQueueConsumer.java index 49ec26e..c4b3971 100644 --- a/src/main/java/queue/TableBlockingQueueConsumer.java +++ b/src/main/java/queue/TableBlockingQueueConsumer.java @@ -14,25 +14,25 @@ public class TableBlockingQueueConsumer implements Runnable private final Logger logger; private final BlockingQueue queue; private final WebSocketMessageTableModel tableModel; - private final AtomicBoolean isRunning; + private final AtomicBoolean isAttackRunning; public TableBlockingQueueConsumer( Logger logger, - BlockingQueue queue, WebSocketMessageTableModel tableModel, - AtomicBoolean isRunning + BlockingQueue queue, + AtomicBoolean isAttackRunning ) { this.logger = logger; this.queue = queue; this.tableModel = tableModel; - this.isRunning = isRunning; + this.isAttackRunning = isAttackRunning; } @Override public void run() { - while (isRunning.get()) + while (isAttackRunning.get()) { try { diff --git a/src/main/java/queue/TableBlockingQueueProducer.java b/src/main/java/queue/TableBlockingQueueProducer.java index 607f4fd..f356f0d 100644 --- a/src/main/java/queue/TableBlockingQueueProducer.java +++ b/src/main/java/queue/TableBlockingQueueProducer.java @@ -17,7 +17,6 @@ public TableBlockingQueueProducer( BlockingQueue tableBlockingQueue ) { - this.logger = logger; this.tableBlockingQueue = tableBlockingQueue; } diff --git a/src/main/java/ui/WebSocketFrame.java b/src/main/java/ui/WebSocketFrame.java index 350a362..1089cff 100644 --- a/src/main/java/ui/WebSocketFrame.java +++ b/src/main/java/ui/WebSocketFrame.java @@ -10,6 +10,7 @@ import data.WebSocketConnectionMessage; import logger.Logger; import ui.attack.WebSocketAttackPanel; +import ui.attack.table.WebSocketMessageTableModel; import ui.editor.WebSocketEditorPanel; import javax.swing.*; @@ -27,8 +28,8 @@ public class WebSocketFrame extends JFrame private final Persistence persistence; private final WebSockets webSockets; private final WebSocketMessage webSocketMessage; - private final AtomicBoolean isProcessing; - private final AtomicBoolean isRunning; + private final AtomicBoolean isAttackRunning; + private AttackHandler attackHandler; public WebSocketFrame( Logger logger, @@ -44,8 +45,7 @@ public WebSocketFrame( this.webSockets = webSockets; this.webSocketMessage = webSocketMessage; - isProcessing = new AtomicBoolean(true); - isRunning = new AtomicBoolean(true); + isAttackRunning = new AtomicBoolean(true); initComponents(); @@ -54,7 +54,8 @@ public WebSocketFrame( @Override public void windowClosed(WindowEvent e) { - isRunning.set(false); + isAttackRunning.set(false); + attackHandler.shutdownConsumers(); } }); } @@ -72,10 +73,12 @@ private void initComponents() BlockingQueue sendMessageQueue = new LinkedBlockingQueue<>(); BlockingQueue tableBlockingQueue = new LinkedBlockingQueue<>(); - AttackHandler attackHandler = new AttackHandler(logger, webSockets, isProcessing, sendMessageQueue, tableBlockingQueue, webSocketMessage); + WebSocketMessageTableModel webSocketMessageTableModel = new WebSocketMessageTableModel(); - cardDeck.add(new WebSocketEditorPanel(logger, userInterface, persistence, cardLayout, cardDeck, attackHandler, webSocketMessage), "editorPanel"); - cardDeck.add(new WebSocketAttackPanel(logger, userInterface, cardLayout, cardDeck, attackHandler, sendMessageQueue, tableBlockingQueue, isProcessing, isRunning), "attackPanel"); + attackHandler = new AttackHandler(logger, webSockets, sendMessageQueue, tableBlockingQueue, webSocketMessageTableModel, webSocketMessage, isAttackRunning); + + cardDeck.add(new WebSocketEditorPanel(logger, userInterface, persistence, cardLayout, cardDeck, attackHandler), "editorPanel"); + cardDeck.add(new WebSocketAttackPanel(userInterface, cardLayout, cardDeck, attackHandler), "attackPanel"); this.getContentPane().add(cardDeck); this.pack(); diff --git a/src/main/java/ui/attack/WebSocketAttackPanel.java b/src/main/java/ui/attack/WebSocketAttackPanel.java index bee2a56..15f7b41 100644 --- a/src/main/java/ui/attack/WebSocketAttackPanel.java +++ b/src/main/java/ui/attack/WebSocketAttackPanel.java @@ -4,19 +4,10 @@ import burp.api.montoya.ui.UserInterface; import burp.api.montoya.ui.editor.EditorOptions; import burp.api.montoya.ui.editor.WebSocketMessageEditor; -import data.ConnectionMessage; -import data.WebSocketConnectionMessage; -import logger.Logger; -import queue.SendMessageQueueConsumer; -import queue.TableBlockingQueueConsumer; import ui.attack.table.WebSocketMessageTable; -import ui.attack.table.WebSocketMessageTableModel; import javax.swing.*; import java.awt.*; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; public class WebSocketAttackPanel extends JPanel @@ -24,19 +15,14 @@ public class WebSocketAttackPanel extends JPanel private final UserInterface userInterface; private final CardLayout cardLayout; private final JPanel cardDeck; - private final AtomicBoolean isProcessing; - private WebSocketMessageTableModel messageTableModel; + private final AttackHandler attackHandler; + private final AtomicBoolean isAttackRunning; public WebSocketAttackPanel( - Logger logger, UserInterface userInterface, CardLayout cardLayout, JPanel cardDeck, - AttackHandler attackHandler, - BlockingQueue sendMessageQueue, - BlockingQueue tableBlockingQueue, - AtomicBoolean isProcessing, - AtomicBoolean isRunning + AttackHandler attackHandler ) { super(new BorderLayout()); @@ -44,15 +30,11 @@ public WebSocketAttackPanel( this.userInterface = userInterface; this.cardLayout = cardLayout; this.cardDeck = cardDeck; - this.isProcessing = isProcessing; + this.attackHandler = attackHandler; - initComponents(); - - ExecutorService sendMessageExecutorService = Executors.newSingleThreadExecutor(); - sendMessageExecutorService.execute(new SendMessageQueueConsumer(logger, isProcessing, sendMessageQueue, attackHandler)); + isAttackRunning = attackHandler.getIsAttackRunning(); - ExecutorService executorService = Executors.newSingleThreadExecutor(); - executorService.execute(new TableBlockingQueueConsumer(logger, tableBlockingQueue, messageTableModel, isRunning)); + initComponents(); } private void initComponents() @@ -69,9 +51,7 @@ private Component getWebSocketMessageDisplay() private Component getWebSocketMessageTable(WebSocketMessageEditor webSocketMessageEditor) { - messageTableModel = new WebSocketMessageTableModel(); - - return new WebSocketMessageTable(messageTableModel, webSocketMessageEditor); + return new WebSocketMessageTable(attackHandler.getWebSocketMessageTableModel(), webSocketMessageEditor); } private WebSocketMessageEditor getWebSocketMessageEditor() @@ -83,19 +63,23 @@ private Component getHaltConfigureButton() { JButton haltConfigureButton = new JButton("Halt"); haltConfigureButton.addActionListener(l -> { - if (isProcessing.get()) + if (isAttackRunning.get()) { - isProcessing.set(false); + isAttackRunning.set(false); + + attackHandler.shutdownConsumers(); + haltConfigureButton.setText("Configure"); } else { - cardLayout.show(cardDeck, "editorPanel"); - messageTableModel.clear(); + attackHandler.getWebSocketMessageTableModel().clear(); haltConfigureButton.setText("Halt"); - isProcessing.set(true); + isAttackRunning.set(true); + + cardLayout.show(cardDeck, "editorPanel"); } }); diff --git a/src/main/java/ui/editor/WebSocketEditorPanel.java b/src/main/java/ui/editor/WebSocketEditorPanel.java index 8c336fc..4d0cd49 100644 --- a/src/main/java/ui/editor/WebSocketEditorPanel.java +++ b/src/main/java/ui/editor/WebSocketEditorPanel.java @@ -5,7 +5,6 @@ import burp.api.montoya.persistence.Persistence; import burp.api.montoya.ui.Theme; import burp.api.montoya.ui.UserInterface; -import burp.api.montoya.ui.contextmenu.WebSocketMessage; import burp.api.montoya.ui.editor.WebSocketMessageEditor; import logger.Logger; import logger.LoggerLevel; @@ -36,9 +35,9 @@ public class WebSocketEditorPanel extends JPanel private final CardLayout cardLayout; private final JPanel cardDeck; private final AttackHandler attackHandler; - private final WebSocketMessage webSocketMessage; private JComboBox scriptComboBox; private WebSocketMessageEditor webSocketsMessageEditor; + private JSpinner numberOfThreadsSpinner; public WebSocketEditorPanel( Logger logger, @@ -46,8 +45,7 @@ public WebSocketEditorPanel( Persistence persistence, CardLayout cardLayout, JPanel cardDeck, - AttackHandler attackHandler, - WebSocketMessage webSocketMessage + AttackHandler attackHandler ) { this.logger = logger; @@ -56,7 +54,6 @@ public WebSocketEditorPanel( this.cardLayout = cardLayout; this.cardDeck = cardDeck; this.attackHandler = attackHandler; - this.webSocketMessage = webSocketMessage; this.setLayout(new BorderLayout()); @@ -74,7 +71,7 @@ private void initComponents() private Component getWebSocketMessageEditor() { webSocketsMessageEditor = userInterface.createWebSocketMessageEditor(); - webSocketsMessageEditor.setContents(webSocketMessage.payload()); + webSocketsMessageEditor.setContents(attackHandler.getBaseWebSocketMessage().payload()); return webSocketsMessageEditor.uiComponent(); } @@ -141,12 +138,18 @@ private Component getButtonPanel() JPanel buttonPanel = new JPanel(); scriptComboBox = getScriptComboBox(); - buttonPanel.add(scriptComboBox); JButton selectScriptsDirectoryButton = getScriptsDirectoryButton(); buttonPanel.add(selectScriptsDirectoryButton); + JLabel threadsLabel = new JLabel("Number of threads:"); + buttonPanel.add(threadsLabel); + + SpinnerNumberModel spinnerModel = new SpinnerNumberModel(1, 0, 50, 1); + numberOfThreadsSpinner = new JSpinner(spinnerModel); + buttonPanel.add(numberOfThreadsSpinner); + return buttonPanel; } @@ -291,6 +294,8 @@ private JButton getAttackButton(RSyntaxTextArea rSyntaxTextArea) } }).start(); + attackHandler.startConsumers((int) numberOfThreadsSpinner.getValue()); + SwingUtilities.invokeLater(() -> cardLayout.show(cardDeck, "attackPanel")); });