Skip to content

Commit

Permalink
Simplify agents
Browse files Browse the repository at this point in the history
  • Loading branch information
nkakouros authored and Nikolaos Kakouros committed Jan 22, 2025
1 parent bf769ab commit 6817394
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 139 deletions.
18 changes: 14 additions & 4 deletions malsim/agents/decision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,25 @@ def get_next_action(
self,
agent: MalSimAgentStateView,
**kwargs
) -> Optional[AttackGraphNode]: ...
) -> Optional[AttackGraphNode]:
"""
Select next action the agent will work with.
Attributes:
agent: Current state of and other info about the agent from the simulator
Returns:
The selected action or None if there are no actions to select from.
"""
...

class PassiveAgent(DecisionAgent):
def __init__(self, info):
return
def __init__(self, *args, **kwargs):
...

def get_next_action(
self,
agent: MalSimAgentStateView,
**kwargs
) -> Optional[AttackGraphNode]:
return None
...
190 changes: 55 additions & 135 deletions malsim/agents/searchers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations
import logging
import re

from collections import deque
from typing import Deque, List, Set, Union, Optional, TYPE_CHECKING
from typing import Optional, TYPE_CHECKING

import numpy as np

from .decision_agent import DecisionAgent
Expand All @@ -13,161 +15,79 @@

logger = logging.getLogger(__name__)

def get_new_targets(
discovered_targets: set[int],
possible_actions: set[int]
) -> list[int]:
"""Return targets that are not already discovered"""
new_targets = [id for id in possible_actions
if id not in discovered_targets]
return new_targets

class BreadthFirstAttacker(DecisionAgent):
def __init__(self, agent_config: dict) -> None:
self.targets: Deque[int] = deque([])
self.current_target: int = None

seed = (
agent_config["seed"]
if agent_config.get("seed", None)
else np.random.SeedSequence().entropy
)
self.rng = (
np.random.default_rng(seed)
if agent_config.get("randomize", False)
else None
)

def get_next_action(
self, agent: MalSimAgentStateView, **kwargs
) -> Optional[AttackGraphNode]:

# Create a dict of possible actions
# mapping id to node
possible_actions = {
n.id: n for n in agent.action_surface
if not n.is_compromised()
}

# Get targets that are not discovered yet
new_targets = get_new_targets(
self.targets, possible_actions.keys()
)
"""A Breadth-First agent, with possible randomization at each level."""

# Add new targets to the back of the queue
# if desired, shuffle the new targets to
# make the attacker more unpredictable
if self.rng:
self.rng.shuffle(new_targets)
for c in new_targets:
self.targets.appendleft(c)

# Select next target
self.current_target = self.select_next_target(
self.current_target,
self.targets,
possible_actions.keys()
)

# Convert the current target id to AttackGraphNode
action_node = None
if self.current_target is not None:
action_node = possible_actions[self.current_target]
_extend_method = "extendleft"
# Controls where newly discovered steps will be appended to the list of
# available actions. Currently used to differentiate between BFS and DFS
# agents.

return action_node
name = ' '.join(re.findall(r'[A-Z][^A-Z]*', __qualname__))
# A human-friendly name for the agent.

@staticmethod
def select_next_target(
previous_target: int,
targets: Union[List[int], Deque[int]],
attack_surface: Set[int],
) -> Optional[int]:
"""Select a target from attack surface
by going through the target queue"""
default_settings = {
'randomize': False,
# Whether to randomize next target selection, still respecting the
# policy of the agent (e.g. BFS or DFS).
'seed': None,
# The random seed to initialize the randomness engine with.
}

next_target = None
if previous_target in attack_surface:
# If the current target was not compromised, put it
# back, but on the bottom of the stack.
targets.appendleft(previous_target)
next_target = targets.pop()

while next_target not in attack_surface:
if len(targets) == 0:
return None

next_target = targets.pop()
def __init__(self, agent_config: dict) -> None:
"""Initialize a BFS agent.
return next_target
Args:
agent_config: Dict with settings to override defaults
"""
self.targets: deque[AttackGraphNode] = deque()
self.current_target: Optional[AttackGraphNode] = None

self.settings = self.default_settings | agent_config

class DepthFirstAttacker(DecisionAgent):
def __init__(self, agent_config: dict) -> None:
self.current_target = -1
self.targets: List[int] = []
seed = (
agent_config["seed"]
if agent_config.get("seed", None)
else np.random.SeedSequence().entropy
)
self.rng = (
np.random.default_rng(seed)
if agent_config.get("randomize", False)
else None
self.rng = np.random.default_rng(
self.settings['seed'] or np.random.SeedSequence()
)

def get_next_action(
self, agent: MalSimAgentStateView, **kwargs
) -> Optional[AttackGraphNode]:
self._update_targets(agent.action_surface)
self._select_next_target()

# Create a dict of possible actions
# mapping id to node
possible_actions = {
n.id: n for n in agent.action_surface
if not n.is_compromised()
}
return self.current_target

# Get targets that are not discovered yet
new_targets = get_new_targets(
self.targets, possible_actions.keys()
)
def _update_targets(self, action_surface):
new_targets = [
step
for step in action_surface
if step not in self.targets and not step.is_compromised()
]

# Add new targets to the top of the stack
if self.rng:
if self.settings['randomize']:
self.rng.shuffle(new_targets)
for c in new_targets:
self.targets.append(c)

self.current_target = self.select_next_target(
self.current_target, self.targets, possible_actions.keys()
)

# Convert the current target id to AttackGraphNode
action_node = None
if self.current_target is not None:
action_node = possible_actions[self.current_target]

return action_node
if self.current_target in new_targets:
# If self.current_target is not yet compromised, e.g. due to TTCs,
# keep using that as the target.
new_targets.remove(self.current_target)
new_targets.append(self.current_target)

@staticmethod
def select_next_target(
previous_target: int,
targets: Union[List[int], Deque[int]],
attack_surface: Set[int],
) -> Optional[int]:
if previous_target in attack_surface:
return previous_target
# Enabled defenses may remove previously possible attack steps.
self.targets = deque(filter(lambda n: n.is_viable, self.targets))

next_target = None
while next_target not in attack_surface:
if len(targets) == 0:
return None
next_target = targets.pop()
getattr(self.targets, self._extend_method)(new_targets)

return next_target
def _select_next_target(self) -> None:
"""
Implement the actual next target selection logic.
"""
try:
self.current_target = self.targets.pop()
except IndexError:
self.current_target = None


AGENTS = {
BreadthFirstAttacker.__name__: BreadthFirstAttacker,
DepthFirstAttacker.__name__: DepthFirstAttacker,
}
class DepthFirstAttacker(BreadthFirstAttacker):
_extend_method = "extend"

0 comments on commit 6817394

Please sign in to comment.