Skip to content

Commit

Permalink
run ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
kasanari committed Mar 1, 2024
1 parent cc1cb58 commit 405df8b
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 123 deletions.
15 changes: 7 additions & 8 deletions malpzsim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,15 @@

from malpzsim.wrappers.wrapper import LazyWrapper
from malpzsim.wrappers.gym_wrapper import AttackerEnv, DefenderEnv, register_envs

"""
MAL Petting Zoo Simulator
"""

__title__ = 'malpzsim'
__version__ = '0.0.6'
__authors__ = ['Andrei Buhaiu',
'Jakob Nyberg']
__license__ = 'Apache 2.0'
__docformat__ = 'restructuredtext en'

__all__ = ('LazyWrapper', 'AttackerEnv', 'DefenderEnv', 'register_envs')
__title__ = "malpzsim"
__version__ = "0.0.6"
__authors__ = ["Andrei Buhaiu", "Jakob Nyberg"]
__license__ = "Apache 2.0"
__docformat__ = "restructuredtext en"

__all__ = ("LazyWrapper", "AttackerEnv", "DefenderEnv", "register_envs")
14 changes: 8 additions & 6 deletions malpzsim/agents/keyboard_input.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import numpy as np
import logging

AGENT_ATTACKER = 'attacker'
AGENT_DEFENDER = 'defender'
AGENT_ATTACKER = "attacker"
AGENT_DEFENDER = "defender"

logger = logging.getLogger(__name__)

null_action = (0, None)

class KeyboardAgent():

class KeyboardAgent:
def __init__(self, vocab):
logger.debug('Create Keyboard agent.')
logger.debug("Create Keyboard agent.")
self.vocab = vocab

def compute_action_from_dict(self, obs: dict, mask: tuple) -> tuple:
Expand Down Expand Up @@ -55,7 +56,8 @@ def get_action_object(user_input: str) -> tuple:
print("Invalid action.")

node, a = get_action_object(user_input)
print(f"Selected action: {action_strings[node] if node is not None else 'wait'}")
print(
f"Selected action: {action_strings[node] if node is not None else 'wait'}"
)

return (a, available_actions[node] if a != 0 else -1)

59 changes: 42 additions & 17 deletions malpzsim/agents/searchers.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,39 @@
import logging
import copy

from collections import deque
from typing import Any, Deque, Dict, List, Set, Type, Union
from typing import Any, Deque, Dict, List, Set, Union

import numpy as np

logger = logging.getLogger(__name__)

def get_new_targets(observation: dict, discovered_targets: Set[int], mask: tuple) -> List[int]:

def get_new_targets(
observation: dict, discovered_targets: Set[int], mask: tuple
) -> List[int]:
attack_surface = mask[1]
surface_indexes = list(np.flatnonzero(attack_surface))
new_targets = [idx for idx in surface_indexes if idx not in discovered_targets]
return new_targets, surface_indexes

class BreadthFirstAttacker():

class BreadthFirstAttacker:
def __init__(self, agent_config: dict) -> None:
self.targets: Deque[int] = deque([])
self.current_target: int = None
seed = agent_config["seed"] if agent_config.get("seed", None) else np.random.SeedSequence().entropy
self.rng = np.random.default_rng(seed) if agent_config.get("randomize", False) else None
seed = (
agent_config["seed"]
if agent_config.get("seed", None)
else np.random.SeedSequence().entropy
)
self.rng = (
np.random.default_rng(seed)
if agent_config.get("randomize", False)
else None
)

def compute_action_from_dict(self, observation: Dict[str, Any], mask: tuple):
new_targets, surface_indexes = get_new_targets(observation,
self.targets,
mask)
new_targets, surface_indexes = get_new_targets(observation, self.targets, mask)

# Add new targets to the back of the queue
# if desired, shuffle the new targets to make the attacker more unpredictable
Expand All @@ -40,14 +49,18 @@ def compute_action_from_dict(self, observation: Dict[str, Any], mask: tuple):
self.current_target = None if done else self.current_target
action = 0 if done else 1
if action == 0:
logger.debug('Attacker Breadth First agent does not have '
'any valid targets it will terminate')
logger.debug(
"Attacker Breadth First agent does not have "
"any valid targets it will terminate"
)

return (action, self.current_target)

@staticmethod
def select_next_target(
current_target: int, targets: Union[List[int], Deque[int]], attack_surface: Set[int]
current_target: int,
targets: Union[List[int], Deque[int]],
attack_surface: Set[int],
) -> int:
# If the current target was not compromised, put it
# back, but on the bottom of the stack.
Expand All @@ -63,12 +76,21 @@ def select_next_target(

return current_target, False

class DepthFirstAttacker():

class DepthFirstAttacker:
def __init__(self, agent_config: dict) -> None:
self.current_target = -1
self.targets: List[int] = []
seed = agent_config["seed"] if agent_config.get("seed", None) else np.random.SeedSequence().entropy
self.rng = np.random.default_rng(seed) if agent_config.get("randomize", False) else None
seed = (
agent_config["seed"]
if agent_config.get("seed", None)
else np.random.SeedSequence().entropy
)
self.rng = (
np.random.default_rng(seed)
if agent_config.get("randomize", False)
else None
)

def compute_action_from_dict(self, observation: Dict[str, Any], mask: tuple):
new_targets, surface_indexes = get_new_targets(observation, self.targets, mask)
Expand All @@ -89,7 +111,9 @@ def compute_action_from_dict(self, observation: Dict[str, Any], mask: tuple):

@staticmethod
def select_next_target(
current_target: int, targets: Union[List[int], Deque[int]], attack_surface: Set[int]
current_target: int,
targets: Union[List[int], Deque[int]],
attack_surface: Set[int],
) -> int:
if current_target in attack_surface:
return current_target, False
Expand All @@ -102,7 +126,8 @@ def select_next_target(

return current_target, False


AGENTS = {
BreadthFirstAttacker.__name__: BreadthFirstAttacker,
DepthFirstAttacker.__name__: DepthFirstAttacker,
}
}
10 changes: 3 additions & 7 deletions malpzsim/sims/mal_petting_zoo_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@
import copy
import logging
import functools
from typing import List, Tuple, Optional
from typing import Optional
import numpy as np

import maltoolbox
from maltoolbox.model.model import Model
from maltoolbox.language.languagegraph import LanguageGraph
from maltoolbox.attackgraph.attackgraph import AttackGraph
from maltoolbox.attackgraph.attacker import Attacker
from maltoolbox.attackgraph.node import AttackGraphNode
import maltoolbox.attackgraph.analyzers.apriori as apriori
import maltoolbox.attackgraph.query as query
from maltoolbox.ingestors import neo4j
Expand Down Expand Up @@ -76,11 +74,10 @@ def __init__(
self.init(self.max_iter)

def create_blank_observation(self):
num_actions = 2
# For now, an `object` is an attack step
num_objects = len(self.attack_graph.nodes)
num_lang_asset_types = len(self.lang_graph.assets)
num_lang_attack_steps = len(self.lang_graph.attack_steps)
len(self.lang_graph.assets)
len(self.lang_graph.attack_steps)

observation = {
"is_observable": num_objects * [1],
Expand Down Expand Up @@ -160,7 +157,6 @@ def _format_info(self, info):

@functools.lru_cache(maxsize=None)
def observation_space(self, agent):
num_actions = 2
# For now, an `object` is an attack step
num_objects = len(self.attack_graph.nodes)
num_lang_asset_types = len(self.lang_graph.assets)
Expand Down
2 changes: 1 addition & 1 deletion malpzsim/wrappers/gym_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Any, Dict, SupportsFloat

import gymnasium as gym
from gymnasium.core import RenderFrame
import gymnasium.utils.env_checker as env_checker

import numpy as np
Expand Down Expand Up @@ -113,6 +112,7 @@ def register_envs():
gym.register("MALDefenderEnv-v0", entry_point=DefenderEnv)
gym.register("MALAttackerEnv-v0", entry_point=AttackerEnv)


if __name__ == "__main__":
gym.register("MALDefenderEnv-v0", entry_point=DefenderEnv)
env = gym.make(
Expand Down
37 changes: 18 additions & 19 deletions malpzsim/wrappers/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@


class LazyWrapper(ParallelEnv):

def __init__(self, **kwargs):
lang_file = kwargs.pop("lang_file")
model_file = kwargs.pop("model_file")
Expand Down Expand Up @@ -41,24 +40,24 @@ def __init__(self, **kwargs):

# TODO - This is a temporary fix to set the rewards for the nodes in the attack graph

sim.attack_graph.get_node_by_id('Application:0:notPresent').reward = 2
sim.attack_graph.get_node_by_id('Application:0:supplyChainAuditing').reward = 7
sim.attack_graph.get_node_by_id('Application:1:notPresent').reward = 3
sim.attack_graph.get_node_by_id('Application:1:supplyChainAuditing').reward = 7
sim.attack_graph.get_node_by_id('SoftwareVulnerability:2:notPresent').reward = 4
sim.attack_graph.get_node_by_id('Data:3:notPresent').reward = 1
sim.attack_graph.get_node_by_id('Credentials:4:notPhishable').reward = 7
sim.attack_graph.get_node_by_id('Identity:5:notPresent').reward = 3.5
sim.attack_graph.get_node_by_id('ConnectionRule:6:restricted').reward = 4
sim.attack_graph.get_node_by_id('ConnectionRule:6:payloadInspection').reward = 3
sim.attack_graph.get_node_by_id('Application:7:notPresent').reward = 2
sim.attack_graph.get_node_by_id('Application:7:supplyChainAuditing').reward = 7

sim.attack_graph.get_node_by_id('Application:0:fullAccess').reward = 5
sim.attack_graph.get_node_by_id('Application:1:fullAccess').reward = 2
sim.attack_graph.get_node_by_id('Identity:5:assume').reward = 2
sim.attack_graph.get_node_by_id('Application:7:fullAccess').reward = 6
sim.attack_graph.get_node_by_id("Application:0:notPresent").reward = 2
sim.attack_graph.get_node_by_id("Application:0:supplyChainAuditing").reward = 7
sim.attack_graph.get_node_by_id("Application:1:notPresent").reward = 3
sim.attack_graph.get_node_by_id("Application:1:supplyChainAuditing").reward = 7
sim.attack_graph.get_node_by_id("SoftwareVulnerability:2:notPresent").reward = 4
sim.attack_graph.get_node_by_id("Data:3:notPresent").reward = 1
sim.attack_graph.get_node_by_id("Credentials:4:notPhishable").reward = 7
sim.attack_graph.get_node_by_id("Identity:5:notPresent").reward = 3.5
sim.attack_graph.get_node_by_id("ConnectionRule:6:restricted").reward = 4
sim.attack_graph.get_node_by_id("ConnectionRule:6:payloadInspection").reward = 3
sim.attack_graph.get_node_by_id("Application:7:notPresent").reward = 2
sim.attack_graph.get_node_by_id("Application:7:supplyChainAuditing").reward = 7

sim.attack_graph.get_node_by_id("Application:0:fullAccess").reward = 5
sim.attack_graph.get_node_by_id("Application:1:fullAccess").reward = 2
sim.attack_graph.get_node_by_id("Identity:5:assume").reward = 2
sim.attack_graph.get_node_by_id("Application:7:fullAccess").reward = 6

self.sim = sim

def step(
Expand Down
21 changes: 9 additions & 12 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,29 @@

null_action = (0, None)

lang_file = 'org.mal-lang.coreLang-1.0.0.mar'
lang_file = "org.mal-lang.coreLang-1.0.0.mar"
lang_spec = specification.load_language_specification_from_mar(lang_file)
specification.save_language_specification_to_json(lang_spec, 'lang_spec.json')
specification.save_language_specification_to_json(lang_spec, "lang_spec.json")
lang_classes_factory = classes_factory.LanguageClassesFactory(lang_spec)
lang_classes_factory.create_classes()

lang_graph = mallanguagegraph.LanguageGraph()
lang_graph.generate_graph(lang_spec)

model = malmodel.Model('Test Model', lang_spec, lang_classes_factory)
model.load_from_file('example_model.json')
model = malmodel.Model("Test Model", lang_spec, lang_classes_factory)
model.load_from_file("example_model.json")

attack_graph = malattackgraph.AttackGraph()
attack_graph.generate_graph(lang_spec, model)
attack_graph.attach_attackers(model)
attack_graph.save_to_file('tmp/attack_graph.json')
attack_graph.save_to_file("tmp/attack_graph.json")

env = MalPettingZooSimulator(lang_graph,
model,
attack_graph,
max_iter = 5)
env = MalPettingZooSimulator(lang_graph, model, attack_graph, max_iter=5)

env.register_attacker('attacker', 0)
env.register_defender('defender')
env.register_attacker("attacker", 0)
env.register_defender("defender")

logger.debug('Run Parrallel API test.')
logger.debug("Run Parrallel API test.")
parallel_api_test(env, num_cycles=50)

env.close()
Loading

0 comments on commit 405df8b

Please sign in to comment.