Skip to content

Commit

Permalink
Compatibility with lookup dict rewrite of mal toolbox
Browse files Browse the repository at this point in the history
  • Loading branch information
Nikolaos Kakouros committed Jan 16, 2025
1 parent d34c367 commit ed2fc8d
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 21 deletions.
30 changes: 15 additions & 15 deletions malsim/sims/mal_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,19 +106,19 @@ def create_blank_observation(self, default_obs_state=-1):
observation = {
# If no observability set for node, assume observable.
"is_observable": [step.extras.get('observable', 1)
for step in self.attack_graph.nodes],
for step in self.attack_graph.nodes.values()],
# Same goes for actionable.
"is_actionable": [step.extras.get('actionable', 1)
for step in self.attack_graph.nodes],
for step in self.attack_graph.nodes.values()],
"observed_state": num_steps * [default_obs_state],
"remaining_ttc": num_steps * [0],
"asset_type": [self._asset_type_to_index[step.asset.type]
for step in self.attack_graph.nodes],
for step in self.attack_graph.nodes.values()],
"asset_id": [step.asset.id
for step in self.attack_graph.nodes],
for step in self.attack_graph.nodes.values()],
"step_name": [self._step_name_to_index[
str(step.asset.type + ":" + step.name)]
for step in self.attack_graph.nodes],
for step in self.attack_graph.nodes.values()],
}

logger.debug(
Expand All @@ -127,13 +127,13 @@ def create_blank_observation(self, default_obs_state=-1):
# Add attack graph edges to observation
observation["attack_graph_edges"] = [
[self._id_to_index[attack_step.id], self._id_to_index[child.id]]
for attack_step in self.attack_graph.nodes
for attack_step in self.attack_graph.nodes.values()
for child in attack_step.children
]

# Add reverse attack graph edges for defense steps (required by some
# defender agent logic)
for attack_step in self.attack_graph.nodes:
for attack_step in self.attack_graph.nodes.values():
if attack_step.type == "defense":
for child in attack_step.children:
observation["attack_graph_edges"].append(
Expand Down Expand Up @@ -520,9 +520,9 @@ def _create_mapping_tables(self):
logger.debug("Creating and listing mapping tables.")

# Lookup lists index to attribute
self._index_to_id = [n.id for n in self.attack_graph.nodes]
self._index_to_id = [n.id for n in self.attack_graph.nodes.values()]
self._index_to_full_name = [n.full_name
for n in self.attack_graph.nodes]
for n in self.attack_graph.nodes.values()]
self._index_to_asset_type = [n.name for n in self.lang_graph.assets.values()]
self._index_to_step_name = [
n.full_name
Expand Down Expand Up @@ -575,7 +575,7 @@ def index_to_node(self, index: int) -> AttackGraphNode:
)

node_id = self._index_to_id[index]
node = self.attack_graph.get_node_by_id(node_id)
node = self.attack_graph.nodes[node_id]
if not node:
raise LookupError(
f'Index {index} (id: {node_id}), does not map to a node'
Expand Down Expand Up @@ -654,7 +654,7 @@ def _initialize_agents(self) -> dict[str, list[int]]:

# Initial actions for defender are all pre-enabled defenses
initial_actions[agent] = [self._id_to_index[node.id]
for node in self.attack_graph.nodes
for node in self.attack_graph.nodes.values()
if node.is_enabled_defense()]

else:
Expand Down Expand Up @@ -838,9 +838,9 @@ def _defender_step(
) -> tuple[list[int], list[AttackGraphNode]]:

actions = []
defense_step_node = self.attack_graph.get_node_by_id(
defense_step_node = self.attack_graph.nodes[
self._index_to_id[defense_step_index]
)
]
logger.info(
'Defender agent "%s" stepping through "%s"(%d).',
agent,
Expand Down Expand Up @@ -920,7 +920,7 @@ def _observe_attacker(
continue

node_id = self._index_to_id[step_index]
node = self.attack_graph.get_node_by_id(node_id)
node = self.attack_graph.nodes[node_id]
if node.type in ('or', 'and'):
# Attack step activated, set to 1 (enabled)
obs_state[step_index] = 1
Expand Down Expand Up @@ -977,7 +977,7 @@ def _reward_agents(self, performed_actions):
continue

node_id = self._index_to_id[action]
node = self.attack_graph.get_node_by_id(node_id)
node = self.attack_graph.nodes[node_id]
node_reward = node.extras.get('reward', 0)

if agent_type == "attacker":
Expand Down
2 changes: 1 addition & 1 deletion tests/test_example_scenarios.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_bfs_vs_bfs_state_and_reward():
assert defender_actions == [68, 249, 324, 325, 349, 350, 396, 397, 421, 422, 423, 457, 0, 31, 88, 113, 144, 181, 212, 252, 276, 326, 327, 351, 352, 374]

for step_index in attacker_actions:
node = sim.attack_graph.get_node_by_id(sim._index_to_id[step_index])
node = sim.attack_graph.nodes[sim._index_to_id[step_index]]
if node.is_compromised():
assert obs[defender_agent_id]['observed_state'][step_index]

Expand Down
10 changes: 5 additions & 5 deletions tests/test_mal_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_malsimulator_create_blank_observation_actionability_given(

for index, actionable in enumerate(blank_observation['is_actionable']):
node_id = sim._index_to_id[index]
node = sim.attack_graph.get_node_by_id(node_id)
node = sim.attack_graph.nodes[node_id]

# Below are the rules from the traininglang observability scenario
# made into if statements
Expand Down Expand Up @@ -760,7 +760,7 @@ def verify_attacker_obs_state(
"""Make sure obs state looks as expected"""
for index, state in enumerate(obs_state):
node_id = sim._index_to_id[index]
node = sim.attack_graph.get_node_by_id(node_id)
node = sim.attack_graph.nodes[node_id]
if state == 1:
assert node_id in expected_reached
elif state == 0:
Expand Down Expand Up @@ -959,7 +959,7 @@ def test_simulator_default_settings_defender_observation():
# Verify that all states in obs match the state of the attack graph
for index, state in enumerate(defender_observation):
step_id = sim._index_to_id[index]
node = sim.attack_graph.get_node_by_id(step_id)
node = sim.attack_graph.nodes[step_id]
if state == 1:
assert node.is_compromised()
else:
Expand All @@ -976,7 +976,7 @@ def test_simulator_default_settings_defender_observation():
# Verify that all states in obs match the state of the attack graph
for index, state in enumerate(defender_observation):
step_id = sim._index_to_id[index]
node = sim.attack_graph.get_node_by_id(step_id)
node = sim.attack_graph.nodes[step_id]
if state == 1:
assert node.is_compromised() or node.is_enabled_defense()
else:
Expand Down Expand Up @@ -1023,7 +1023,7 @@ def test_simulator_settings_defender_observation():
# is the latest performed step (User:3:compromise)
for index, state in enumerate(defender_observation):
step_id = sim._index_to_id[index]
node = sim.attack_graph.get_node_by_id(step_id)
node = sim.attack_graph.nodes[step_id]
if node == user_3_compromise:
assert state == 1 # Last performed step known active state
else:
Expand Down

0 comments on commit ed2fc8d

Please sign in to comment.