Skip to content

Commit

Permalink
Rewrite agents
Browse files Browse the repository at this point in the history
  • Loading branch information
Nikolaos Kakouros committed Jan 13, 2025
1 parent 40a4abc commit 601eb41
Showing 1 changed file with 50 additions and 108 deletions.
158 changes: 50 additions & 108 deletions malsim/agents/searchers.py
Original file line number Diff line number Diff line change
@@ -1,137 +1,79 @@
import logging
import re

from collections import deque
from typing import Any, Deque, Dict, List, Set, Union
from typing import Any, Optional

import numpy as np

logger = logging.getLogger(__name__)


def get_new_targets(
observation: dict, discovered_targets: Set[int], mask: tuple
) -> List[int]:
attack_surface = mask[1]
surface_indexes = list(np.flatnonzero(attack_surface))
new_targets = [idx for idx in surface_indexes if idx not in discovered_targets]
return new_targets, surface_indexes


class PassiveAttacker:
def compute_action_from_dict(self, observation, mask):
return (0, None)

class BreadthFirstAttacker:
def __init__(self, agent_config: dict) -> None:
self.targets: Deque[int] = deque([])
self.current_target: int = None
seed = (
agent_config["seed"]
if agent_config.get("seed", None)
else np.random.SeedSequence().entropy
)
self.rng = (
np.random.default_rng(seed)
if agent_config.get("randomize", False)
else None
)
pass

def compute_action_from_dict(self, observation: Dict[str, Any], mask: tuple):
new_targets, surface_indexes = get_new_targets(observation, self.targets, mask)

# Add new targets to the back of the queue
# if desired, shuffle the new targets to make the attacker more unpredictable
if self.rng:
self.rng.shuffle(new_targets)
for c in new_targets:
self.targets.appendleft(c)

self.current_target, done = self.select_next_target(
self.current_target, self.targets, surface_indexes
)

self.current_target = None if done else self.current_target
action = 0 if done else 1
if action == 0:
logger.debug(
"Attacker Breadth First agent does not have "
"any valid targets it will terminate"
)

return (action, self.current_target)
def compute_action_from_dict(self, observation, mask):
return (0, None)

@staticmethod
def select_next_target(
current_target: int,
targets: Union[List[int], Deque[int]],
attack_surface: Set[int],
) -> int:
# If the current target was not compromised, put it
# back, but on the bottom of the stack.
if current_target in attack_surface:
targets.appendleft(current_target)
current_target = targets.pop()

while current_target not in attack_surface:
if len(targets) == 0:
return None, True
class BreadthFirstAttacker:
_insert_head: Optional[int] = 0
name = ' '.join(re.findall(r'[A-Z][^A-Z]*', __qualname__))

current_target = targets.pop()
default_settings = {
'randomize': False,
'seed': None,
}

return current_target, False
def __init__(self, agent_config: dict) -> None:
self.targets: list[int] = []
self.current_target: Optional[int] = None

self.attack_graph = agent_config.pop('attack_graph')
self.settings = self.default_settings | agent_config

class DepthFirstAttacker:
def __init__(self, agent_config: dict) -> None:
self.current_target = -1
self.targets: List[int] = []
seed = (
agent_config["seed"]
if agent_config.get("seed", None)
else np.random.SeedSequence().entropy
)
self.rng = (
np.random.default_rng(seed)
if agent_config.get("randomize", False)
else None
self.rng = np.random.default_rng(
self.settings['seed'] or np.random.SeedSequence()
)

def compute_action_from_dict(self, observation: Dict[str, Any], mask: tuple):
new_targets, surface_indexes = get_new_targets(observation, self.targets, mask)
def compute_action_from_dict(
self, observation: dict[str, Any], mask: tuple
) -> tuple[int, Optional[int]]:
# mask[1] has 1s for actionable steps the agent has not compromised yet.
attack_surface = list(np.flatnonzero(mask[1]))

new_targets = [step for step in attack_surface if step not in self.targets]

# Add new targets to the top of the stack
if self.rng:
if self.settings['randomize']:
self.rng.shuffle(new_targets)
for c in new_targets:
self.targets.append(c)

self.current_target, done = self.select_next_target(
self.current_target, self.targets, surface_indexes
)
for c in new_targets:
if self._insert_head is None:
self.targets.append(c)
else:
self.targets.insert(self._insert_head, c)

self.current_target = None if done else self.current_target
action = 0 if done else 1
return (action, self.current_target)
self.current_target, done = self._select_next_target()

@staticmethod
def select_next_target(
current_target: int,
targets: Union[List[int], Deque[int]],
attack_surface: Set[int],
) -> int:
if current_target in attack_surface:
return current_target, False
if done:
logger.debug(
'%s agent does not have any valid targets it will terminate', self.name
)

while current_target not in attack_surface:
if len(targets) == 0:
return None, True
return (int(not done), self.current_target)

current_target = targets.pop()
def _select_next_target(self) -> tuple[Optional[int], bool]:
if self.current_target in self.targets:
# If self.current_target is not yet compromised, e.g. due to TTCs,
# keep using that as the target.
self.targets.remove(self.current_target)
self.targets.append(self.current_target)

return current_target, False
try:
return self.targets.pop(), False
except IndexError:
return None, True


AGENTS = {
BreadthFirstAttacker.__name__: BreadthFirstAttacker,
DepthFirstAttacker.__name__: DepthFirstAttacker,
}
class DepthFirstAttacker(BreadthFirstAttacker):
_insert_head = None

0 comments on commit 601eb41

Please sign in to comment.