Skip to content

Commit

Permalink
Add detectors in agents
Browse files Browse the repository at this point in the history
  • Loading branch information
Nikolaos Kakouros committed Jan 29, 2025
1 parent 0a5bf0b commit 92b1ff7
Showing 1 changed file with 72 additions and 15 deletions.
87 changes: 72 additions & 15 deletions malsim/agents/searchers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import logging
import pprint
import re

from typing import Any, Optional
Expand All @@ -21,15 +22,24 @@ class BreadthFirstAttacker:
_insert_head: Optional[int] = 0
name = ' '.join(re.findall(r'[A-Z][^A-Z]*', __qualname__))

default_settings = {
'randomize': False,
'seed': None,
'wait_factor': 0,
}

def __init__(self, agent_config: dict) -> None:
self.targets: list[int] = []
self.current_target: Optional[int] = None
self.rng = None

if agent_config.get('randomize', False):
self.rng = np.random.default_rng(
agent_config.get('seed', np.random.SeedSequence().entropy)
)
self.logs: list[dict] = []

self.attack_graph = agent_config.pop('attack_graph')
self.settings = self.default_settings | agent_config

self.rng = np.random.default_rng(
self.settings['seed'] or np.random.SeedSequence()
)

def compute_action_from_dict(
self, observation: dict[str, Any], mask: tuple
Expand All @@ -39,7 +49,7 @@ def compute_action_from_dict(

new_targets = [step for step in attack_surface if step not in self.targets]

if self.rng:
if self.settings['randomize']:
self.rng.shuffle(new_targets)

for c in new_targets:
Expand All @@ -48,23 +58,70 @@ def compute_action_from_dict(
else:
self.targets.insert(self._insert_head, c)

# If self.current_target is not yet compromised, e.g. due to TTCs, keep
# using that as the target, else choose a new target.
if self.current_target not in attack_surface:
self.current_target, done = self._select_next_target()
act, self.current_target = self._select_next_target()

if done:
if act:
self._collect_logs(observation)
else:
logger.debug(
'%s agent does not have any valid targets it will terminate', self.name
)

return (int(not done), self.current_target)
return (int(act), self.current_target)

def _select_next_target(self) -> tuple[bool, Optional[int]]:
if self.current_target in self.targets:
# If self.current_target is not yet compromised, e.g. due to TTCs,
# keep using that as the target.
self.targets.remove(self.current_target)
self.targets.append(self.current_target)

act = np.random.choice(
[True, False],
p=[1 - self.settings['wait_factor'], self.settings['wait_factor']],
)

def _select_next_target(self) -> tuple[Optional[int], bool]:
try:
return self.targets.pop(), False
return act, self.targets.pop()
except IndexError:
return None, True
return False, None

def _collect_logs(self, observation):
for _, detector in self.attack_graph.nodes[
self.current_target
].detectors.items():
attack_step = self.attack_graph.nodes[self.current_target]
log = {
'timestamp': observation['timestamp'],
'_detector': detector.name,
'asset': str(attack_step.asset.name),
'attack_step': attack_step.name,
'agent': self.__class__.__name__,
#'context': {},
}

for label, lgasset in detector.context.items():
*_, asset = (
step.asset
for step in self.attack_graph.attackers[0].reached_attack_steps
if step.asset.type
in [subasset.name for subasset in lgasset.sub_assets]
)

log[label] = str(asset.name)

self.logs.append(log)

logger.info('Detector triggered on %s', attack_step.full_name)
logger.info(pprint.pformat(log))

def terminate(self):
self._write_logs()

def _write_logs(self):
with open('logs.json', 'w') as f:
json.dump(self.logs, f, indent=2)
self.logs = []


class DepthFirstAttacker(BreadthFirstAttacker):
Expand Down

0 comments on commit 92b1ff7

Please sign in to comment.