diff --git a/malpzsim/__init__.py b/malpzsim/__init__.py
index 2d607701..e1213c36 100644
--- a/malpzsim/__init__.py
+++ b/malpzsim/__init__.py
@@ -15,7 +15,8 @@
 # limitations under the License.
 #
 
-
+from malpzsim.wrappers.wrapper import LazyWrapper
+from malpzsim.wrappers.gym_wrapper import AttackerEnv, DefenderEnv
 """
 MAL Petting Zoo Simulator
 """
@@ -27,5 +28,5 @@
 __license__ = 'Apache 2.0'
 __docformat__ = 'restructuredtext en'
 
-__all__ = ()
+__all__ = ('LazyWrapper', 'AttackerEnv', 'DefenderEnv')
 
diff --git a/malpzsim/agents/searchers.py b/malpzsim/agents/searchers.py
index 1993e288..ae9a1770 100644
--- a/malpzsim/agents/searchers.py
+++ b/malpzsim/agents/searchers.py
@@ -101,3 +101,10 @@ def select_next_target(
         current_target = targets.pop()
 
     return current_target, False
+
+# Registry mapping attacker class names to the classes themselves, so callers
+# (e.g. DefenderEnv) can select an attacker implementation by string name.
+AGENTS = {
+    BreadthFirstAttacker.__name__: BreadthFirstAttacker,
+    DepthFirstAttacker.__name__: DepthFirstAttacker,
+}
diff --git a/malpzsim/sims/mal_petting_zoo_simulator.py b/malpzsim/sims/mal_petting_zoo_simulator.py
index ad18bc4f..0e10ff54 100644
--- a/malpzsim/sims/mal_petting_zoo_simulator.py
+++ b/malpzsim/sims/mal_petting_zoo_simulator.py
@@ -51,8 +51,6 @@ def create_blank_observation(self):
         num_lang_attack_steps = len(self.lang_graph.attack_steps)
 
         observation = {
-            'action' : num_actions * [0],
-            'step' : num_objects * [0],
             'is_observable' : num_objects * [1],
             'observed_state' : num_objects * [-1],
             'remaining_ttc' : num_objects * [0]
@@ -81,6 +79,18 @@ def create_blank_observation(self):
                     [self._id_to_index[attack_step.id], self._id_to_index[child.id]])
 
-        return observation
+        # Convert the plain-list observation into numpy arrays so it matches
+        # the dtypes declared in observation_space().
+        np_obs = {
+            'is_observable' : np.array(observation['is_observable'], dtype=np.int8),
+            'observed_state' : np.array(observation['observed_state'], dtype=np.int8),
+            'remaining_ttc' : np.array(observation['remaining_ttc'], dtype=np.int64),
+            'asset_type' : np.array(observation['asset_type'], dtype=np.int64),
+            'asset_id' : np.array(observation['asset_id'], dtype=np.int64),
+            'step_name' : np.array(observation['step_name'], dtype=np.int64),
+            'edges' : np.array(observation['edges'], dtype=np.int64)
+        }
+
+        return np_obs
 
     def _format_full_observation(self, observation):
         '''
@@ -88,7 +98,7 @@ def _format_full_observation(self, observation):
         sections that will not change over time, these define the
         structure of the attack graph.
         '''
-        obs_str = f'Action: {observation["action"]}\n'
+        obs_str = f'Action: {observation.get("action", "")}\n'
 
         str_format = '{:<5} {:<6} {:<5} {:<5} {:<5} {:<5} {:<}\n'
         header = str_format.format(
@@ -184,7 +194,7 @@ def observation_space(self, agent):
             'edges': Box(
                 0,
                 num_objects,
-                shape=(2, num_edges),
+                shape=(num_edges, 2),
                 dtype=np.int64,
             ), # edges between steps
         }
@@ -195,7 +205,7 @@ def action_space(self, agent):
         num_actions = 2 # two actions: wait or use
         # For now, an `object` is an attack step
         num_objects = len(self.attack_graph.nodes)
-        return MultiDiscrete([num_actions, num_objects])
+        return MultiDiscrete([num_actions, num_objects], dtype=np.int64)
 
     def reset(self,
               seed: Optional[int] = None,
diff --git a/malpzsim/wrappers/gym_wrapper.py b/malpzsim/wrappers/gym_wrapper.py
new file mode 100644
index 00000000..765f3531
--- /dev/null
+++ b/malpzsim/wrappers/gym_wrapper.py
@@ -0,0 +1,134 @@
+"""Gymnasium single-agent views of the multi-agent MAL petting-zoo simulator."""
+from typing import Any, Dict, SupportsFloat
+
+import gymnasium as gym
+import gymnasium.utils.env_checker as env_checker
+import numpy as np
+
+from malpzsim.wrappers.wrapper import LazyWrapper
+from malpzsim.agents import searchers
+
+
+AGENT_ATTACKER = "attacker"
+AGENT_DEFENDER = "defender"
+
+
+class AttackerEnv(gym.Env):
+    """Single-agent attacker view of the simulator; no defender is registered."""
+
+    metadata = {"render_modes": []}
+
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__()
+        self.render_mode = kwargs.pop("render_mode", None)
+        agents = {AGENT_ATTACKER: AGENT_ATTACKER}
+        self.env = LazyWrapper(agents=agents, **kwargs)
+        self.observation_space = self.env.observation_space(AGENT_ATTACKER)
+        self.action_space = self.env.action_space(AGENT_ATTACKER)
+
+    def reset(
+        self, *, seed: int | None = None, options: dict[str, Any] | None = None
+    ) -> tuple[Any, dict[str, Any]]:
+        super().reset(seed=seed, options=options)
+        obs, info = self.env.reset(seed=seed, options=options)
+        return obs[AGENT_ATTACKER], info[AGENT_ATTACKER]
+
+    def step(
+        self, action: Any
+    ) -> tuple[Any, SupportsFloat, bool, bool, dict[str, Any]]:
+        obs: Dict[str, Any]
+        obs, rewards, terminated, truncated, infos = self.env.step(
+            {AGENT_ATTACKER: action}
+        )
+        return (
+            obs[AGENT_ATTACKER],
+            rewards[AGENT_ATTACKER],
+            terminated[AGENT_ATTACKER],
+            truncated[AGENT_ATTACKER],
+            infos[AGENT_ATTACKER],
+        )
+
+    def render(self):
+        return self.env.render()
+
+
+class DefenderEnv(gym.Env):
+    """Single-agent defender view; the attacker is driven by a scripted searcher."""
+
+    metadata = {"render_modes": []}
+
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__()
+        attacker_class: str = str(kwargs.pop("attacker_class", "BreadthFirstAttacker"))
+        self.randomize = kwargs.pop("randomize_attacker_behavior", False)
+        self.render_mode = kwargs.pop("render_mode", None)
+        agents = {AGENT_ATTACKER: AGENT_ATTACKER, AGENT_DEFENDER: AGENT_DEFENDER}
+        self.env = LazyWrapper(agents=agents, **kwargs)
+        self.attacker_class = searchers.AGENTS[attacker_class]
+        self.observation_space = self.env.observation_space(AGENT_DEFENDER)
+        self.action_space = self.env.action_space(AGENT_DEFENDER)
+
+    def reset(
+        self, *, seed: int | None = None, options: dict[str, Any] | None = None
+    ) -> tuple[Any, dict[str, Any]]:
+        super().reset(seed=seed, options=options)
+        self.attacker = self.attacker_class({"seed": seed, "randomize": self.randomize})
+        obs, info = self.env.reset(seed=seed, options=options)
+        self.attacker_obs = obs[AGENT_ATTACKER]
+        self.attacker_mask = info[AGENT_ATTACKER]["action_mask"]
+        return obs[AGENT_DEFENDER], info[AGENT_DEFENDER]
+
+    def step(
+        self, action: Any
+    ) -> tuple[Any, SupportsFloat, bool, bool, dict[str, Any]]:
+        # Let the scripted attacker act on its view of the previous step.
+        attacker_action = self.attacker.compute_action_from_dict(
+            self.attacker_obs, self.attacker_mask
+        )
+        obs: Dict[str, Any]
+        obs, rewards, terminated, truncated, infos = self.env.step(
+            {AGENT_DEFENDER: action, AGENT_ATTACKER: attacker_action}
+        )
+        self.attacker_obs = obs[AGENT_ATTACKER]
+        self.attacker_mask = infos[AGENT_ATTACKER]["action_mask"]
+        return (
+            obs[AGENT_DEFENDER],
+            rewards[AGENT_DEFENDER],
+            terminated[AGENT_DEFENDER],
+            truncated[AGENT_DEFENDER],
+            infos[AGENT_DEFENDER],
+        )
+
+    def render(self):
+        return self.env.render()
+
+    @staticmethod
+    def add_reverse_edges(edges: np.ndarray, defense_steps: set | None) -> np.ndarray:
+        # Add reverse edges from the children of defense steps back to the
+        # defense steps. `edges` holds (num_edges, 2) rows of [parent, child],
+        # matching the shape declared in the simulator's observation space.
+        if defense_steps is not None:
+            for parent, child in zip(edges[:, 0], edges[:, 1]):
+                if parent in defense_steps:
+                    new_edge = np.array([child, parent]).reshape((1, 2))
+                    edges = np.concatenate((edges, new_edge), axis=0)
+        return edges
+
+
+if __name__ == "__main__":
+    # Smoke test; paths are relative to the repository root (no machine-specific paths).
+    gym.register("DefenderEnv-v0", entry_point=DefenderEnv)
+    env = gym.make(
+        "DefenderEnv-v0",
+        model_file="tests/example_model.json",
+        lang_file="tests/org.mal-lang.coreLang-1.0.0.mar",
+    )
+    env_checker.check_env(env.unwrapped)
+
+    gym.register("AttackerEnv-v0", entry_point=AttackerEnv)
+    env = gym.make(
+        "AttackerEnv-v0",
+        model_file="tests/example_model.json",
+        lang_file="tests/org.mal-lang.coreLang-1.0.0.mar",
+    )
+    env_checker.check_env(env.unwrapped)
diff --git a/malpzsim/wrappers/wrapper.py b/malpzsim/wrappers/wrapper.py
new file mode 100644
index 00000000..ebf5b27d
--- /dev/null
+++ b/malpzsim/wrappers/wrapper.py
@@ -0,0 +1,68 @@
+"""Lazily build a MalPettingZooSimulator from language and model files."""
+from typing import Any
+
+from gymnasium.spaces.space import Space
+from pettingzoo.utils.env import ParallelEnv
+
+from maltoolbox.language import classes_factory
+from maltoolbox.language import specification
+from maltoolbox.language import languagegraph as mallanguagegraph
+from maltoolbox.attackgraph import attackgraph as malattackgraph
+from maltoolbox.model import model as malmodel
+
+from malpzsim.sims.mal_petting_zoo_simulator import MalPettingZooSimulator
+
+
+class LazyWrapper(ParallelEnv):
+    """ParallelEnv that constructs the simulator from file paths on demand.
+
+    Keyword arguments: `lang_file` (.mar language archive), `model_file`
+    (model JSON) and `agents` (mapping of agent kind -> agent id); all
+    remaining kwargs are forwarded to MalPettingZooSimulator.
+    """
+
+    def __init__(self, **kwargs):
+        lang_file = kwargs.pop("lang_file")
+        model_file = kwargs.pop("model_file")
+        agents = kwargs.pop("agents", {})
+        lang_spec = specification.load_language_specification_from_mar(str(lang_file))
+        lang_classes_factory = classes_factory.LanguageClassesFactory(lang_spec)
+        lang_classes_factory.create_classes()
+
+        lang_graph = mallanguagegraph.LanguageGraph()
+        lang_graph.generate_graph(lang_spec)
+
+        model = malmodel.Model("Test Model", lang_spec, lang_classes_factory)
+        model.load_from_file(model_file)
+
+        attack_graph = malattackgraph.AttackGraph()
+        attack_graph.generate_graph(lang_spec, model)
+        attack_graph.attach_attackers(model)
+
+        sim = MalPettingZooSimulator(lang_graph, model, attack_graph, **kwargs)
+
+        for agent_class, agent_id in agents.items():
+            if agent_class == "attacker":
+                sim.register_attacker(agent_id, 0)
+            elif agent_class == "defender":
+                sim.register_defender(agent_id)
+
+        self.sim = sim
+
+    def step(
+        self, actions: dict
+    ) -> tuple[
+        dict, dict[Any, float], dict[Any, bool], dict[Any, bool], dict[Any, dict]
+    ]:
+        return self.sim.step(actions)
+
+    def reset(
+        self, seed: int | None = None, options: dict | None = None
+    ) -> tuple[dict, dict[Any, dict]]:
+        return self.sim.reset(seed, options)
+
+    def observation_space(self, agent: Any) -> Space:
+        return self.sim.observation_space(agent)
+
+    def action_space(self, agent: Any) -> Space:
+        return self.sim.action_space(agent)