From 6817394d4d9b675933432f8524401ff2f6535a0e Mon Sep 17 00:00:00 2001 From: Nikolaos Kakouros Date: Mon, 20 Jan 2025 18:32:27 +0000 Subject: [PATCH] Simplify agents --- malsim/agents/decision_agent.py | 18 ++- malsim/agents/searchers.py | 190 +++++++++----------------------- 2 files changed, 69 insertions(+), 139 deletions(-) diff --git a/malsim/agents/decision_agent.py b/malsim/agents/decision_agent.py index a03687c..09943bc 100644 --- a/malsim/agents/decision_agent.py +++ b/malsim/agents/decision_agent.py @@ -15,15 +15,25 @@ def get_next_action( self, agent: MalSimAgentStateView, **kwargs - ) -> Optional[AttackGraphNode]: ... + ) -> Optional[AttackGraphNode]: + """ + Select next action the agent will work with. + + Attributes: + agent: Current state of and other info about the agent from the simulator + + Returns: + The selected action or None if there are no actions to select from. + """ + ... class PassiveAgent(DecisionAgent): - def __init__(self, info): - return + def __init__(self, *args, **kwargs): + ... def get_next_action( self, agent: MalSimAgentStateView, **kwargs ) -> Optional[AttackGraphNode]: - return None + ... diff --git a/malsim/agents/searchers.py b/malsim/agents/searchers.py index c9976f7..84e8bc5 100644 --- a/malsim/agents/searchers.py +++ b/malsim/agents/searchers.py @@ -1,8 +1,10 @@ from __future__ import annotations import logging +import re from collections import deque -from typing import Deque, List, Set, Union, Optional, TYPE_CHECKING +from typing import Optional, TYPE_CHECKING + import numpy as np from .decision_agent import DecisionAgent @@ -13,161 +15,79 @@ logger = logging.getLogger(__name__) -def get_new_targets( - discovered_targets: set[int], - possible_actions: set[int] -) -> list[int]: - """Return targets that are not already discovered""" - new_targets = [id for id in possible_actions - if id not in discovered_targets] - return new_targets class BreadthFirstAttacker(DecisionAgent): - def __init__(self, agent_config: dict) -> None: - self.targets: Deque[int] = deque([]) - self.current_target: int = None - - seed = ( - agent_config["seed"] - if agent_config.get("seed", None) - else np.random.SeedSequence().entropy - ) - self.rng = ( - np.random.default_rng(seed) - if agent_config.get("randomize", False) - else None - ) - - def get_next_action( - self, agent: MalSimAgentStateView, **kwargs - ) -> Optional[AttackGraphNode]: - - # Create a dict of possible actions - # mapping id to node - possible_actions = { - n.id: n for n in agent.action_surface - if not n.is_compromised() - } - - # Get targets that are not discovered yet - new_targets = get_new_targets( - self.targets, possible_actions.keys() - ) + """A Breadth-First agent, with possible randomization at each level.""" - # Add new targets to the back of the queue - # if desired, shuffle the new targets to - # make the attacker more unpredictable - if self.rng: - self.rng.shuffle(new_targets) - for c in new_targets: - self.targets.appendleft(c) - - # Select next target - self.current_target = self.select_next_target( - self.current_target, - self.targets, - possible_actions.keys() - ) - - # Convert the current target id to AttackGraphNode - action_node = None - if self.current_target is not None: - action_node = possible_actions[self.current_target] + _extend_method = "extendleft" + # Controls where newly discovered steps will be appended to the list of + # available actions. Currently used to differentiate between BFS and DFS + # agents. - return action_node + name = ' '.join(re.findall(r'[A-Z][^A-Z]*', __qualname__)) + # A human-friendly name for the agent. - @staticmethod - def select_next_target( - previous_target: int, - targets: Union[List[int], Deque[int]], - attack_surface: Set[int], - ) -> Optional[int]: - """Select a target from attack surface - by going through the target queue""" + default_settings = { + 'randomize': False, + # Whether to randomize next target selection, still respecting the + # policy of the agent (e.g. BFS or DFS). + 'seed': None, + # The random seed to initialize the randomness engine with. + } - next_target = None - if previous_target in attack_surface: - # If the current target was not compromised, put it - # back, but on the bottom of the stack. - targets.appendleft(previous_target) - next_target = targets.pop() - - while next_target not in attack_surface: - if len(targets) == 0: - return None - - next_target = targets.pop() + def __init__(self, agent_config: dict) -> None: + """Initialize a BFS agent. - return next_target + Args: + agent_config: Dict with settings to override defaults + """ + self.targets: deque[AttackGraphNode] = deque() + self.current_target: Optional[AttackGraphNode] = None + self.settings = self.default_settings | agent_config -class DepthFirstAttacker(DecisionAgent): - def __init__(self, agent_config: dict) -> None: - self.current_target = -1 - self.targets: List[int] = [] - seed = ( - agent_config["seed"] - if agent_config.get("seed", None) - else np.random.SeedSequence().entropy - ) - self.rng = ( - np.random.default_rng(seed) - if agent_config.get("randomize", False) - else None + self.rng = np.random.default_rng( + self.settings['seed'] or np.random.SeedSequence() ) def get_next_action( self, agent: MalSimAgentStateView, **kwargs ) -> Optional[AttackGraphNode]: + self._update_targets(agent.action_surface) + self._select_next_target() - # Create a dict of possible actions - # mapping id to node - possible_actions = { - n.id: n for n in agent.action_surface - if not n.is_compromised() - } + return self.current_target - # Get targets that are not discovered yet - new_targets = get_new_targets( - self.targets, possible_actions.keys() - ) + def _update_targets(self, action_surface): + new_targets = [ + step + for step in action_surface + if step not in self.targets and not step.is_compromised() + ] - # Add new targets to the top of the stack - if self.rng: + if self.settings['randomize']: self.rng.shuffle(new_targets) - for c in new_targets: - self.targets.append(c) - - self.current_target = self.select_next_target( - self.current_target, self.targets, possible_actions.keys() - ) - - # Convert the current target id to AttackGraphNode - action_node = None - if self.current_target is not None: - action_node = possible_actions[self.current_target] - return action_node + if self.current_target in new_targets: + # If self.current_target is not yet compromised, e.g. due to TTCs, + # keep using that as the target. + new_targets.remove(self.current_target) + new_targets.append(self.current_target) - @staticmethod - def select_next_target( - previous_target: int, - targets: Union[List[int], Deque[int]], - attack_surface: Set[int], - ) -> Optional[int]: - if previous_target in attack_surface: - return previous_target + # Enabled defenses may remove previously possible attack steps. + self.targets = deque(filter(lambda n: n.is_viable, self.targets)) - next_target = None - while next_target not in attack_surface: - if len(targets) == 0: - return None - next_target = targets.pop() + getattr(self.targets, self._extend_method)(new_targets) - return next_target + def _select_next_target(self) -> None: + """ + Implement the actual next target selection logic. + """ + try: + self.current_target = self.targets.pop() + except IndexError: + self.current_target = None -AGENTS = { - BreadthFirstAttacker.__name__: BreadthFirstAttacker, - DepthFirstAttacker.__name__: DepthFirstAttacker, -} +class DepthFirstAttacker(BreadthFirstAttacker): + _extend_method = "extend"