Skip to content

Commit

Permalink
Merge pull request #1 from kasanari/main
Browse files Browse the repository at this point in the history
Add single player gym wrappers, and make sim compatible with gym check
  • Loading branch information
andrewbwm authored Mar 1, 2024
2 parents ba0098a + ea8ec4b commit efebe4d
Show file tree
Hide file tree
Showing 5 changed files with 209 additions and 8 deletions.
5 changes: 3 additions & 2 deletions malpzsim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
# limitations under the License.
#


from malpzsim.wrappers.wrapper import LazyWrapper
from malpzsim.wrappers.gym_wrapper import AttackerEnv, DefenderEnv
"""
MAL Petting Zoo Simulator
"""
Expand All @@ -27,5 +28,5 @@
__license__ = 'Apache 2.0'
__docformat__ = 'restructuredtext en'

__all__ = ()
__all__ = ('LazyWrapper', 'AttackerEnv', 'DefenderEnv')

5 changes: 5 additions & 0 deletions malpzsim/agents/searchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,8 @@ def select_next_target(
current_target = targets.pop()

return current_target, False

AGENTS = {
BreadthFirstAttacker.__name__: BreadthFirstAttacker,
DepthFirstAttacker.__name__: DepthFirstAttacker,
}
20 changes: 14 additions & 6 deletions malpzsim/sims/mal_petting_zoo_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ def create_blank_observation(self):
num_lang_attack_steps = len(self.lang_graph.attack_steps)

observation = {
'action' : num_actions * [0],
'step' : num_objects * [0],
'is_observable' : num_objects * [1],
'observed_state' : num_objects * [-1],
'remaining_ttc' : num_objects * [0]
Expand Down Expand Up @@ -80,15 +78,25 @@ def create_blank_observation(self):
[self._id_to_index[attack_step.id],
self._id_to_index[child.id]])

return observation
np_obs = {
'is_observable' : np.array(observation['is_observable'], dtype=np.int8),
'observed_state' : np.array(observation['observed_state'], dtype=np.int8),
'remaining_ttc' : np.array(observation['remaining_ttc'], dtype=np.int64),
'asset_type' : np.array(observation['asset_type'], dtype=np.int64),
'asset_id' : np.array(observation['asset_id'], dtype=np.int64),
'step_name' : np.array(observation['step_name'], dtype=np.int64),
'edges' : np.array(observation['edges'], dtype=np.int64)
}

return np_obs

def _format_full_observation(self, observation):
'''
Return a formatted string of the entire observation. This includes
sections that will not change over time, these define the structure of
the attack graph.
'''
obs_str = f'Action: {observation["action"]}\n'
obs_str = f'Action: {observation.get("action", "")}\n'

str_format = '{:<5} {:<6} {:<5} {:<5} {:<5} {:<5} {:<}\n'
header = str_format.format(
Expand Down Expand Up @@ -184,7 +192,7 @@ def observation_space(self, agent):
'edges': Box(
0,
num_objects,
shape=(2, num_edges),
shape=(num_edges, 2),
dtype=np.int64,
), # edges between steps
}
Expand All @@ -195,7 +203,7 @@ def action_space(self, agent):
num_actions = 2 # two actions: wait or use
# For now, an `object` is an attack step
num_objects = len(self.attack_graph.nodes)
return MultiDiscrete([num_actions, num_objects])
return MultiDiscrete([num_actions, num_objects], dtype=np.int64)

def reset(self,
seed: Optional[int] = None,
Expand Down
127 changes: 127 additions & 0 deletions malpzsim/wrappers/gym_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from typing import Any, Dict, SupportsFloat

import gymnasium as gym
from gymnasium.core import RenderFrame
import gymnasium.utils.env_checker as env_checker

import numpy as np

from malpzsim.wrappers.wrapper import LazyWrapper
from malpzsim.agents import searchers


AGENT_ATTACKER = "attacker"
AGENT_DEFENDER = "defender"


class AttackerEnv(gym.Env):
metadata = {"render_modes": []}

def __init__(self, **kwargs: Any) -> None:
self.render_mode = kwargs.pop("render_mode", None)
agents = {AGENT_ATTACKER: AGENT_ATTACKER}
self.env = LazyWrapper(agents=agents, **kwargs)
self.observation_space = self.env.observation_space(AGENT_ATTACKER)
self.action_space = self.env.action_space(AGENT_ATTACKER)
super().__init__()

def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[Any, dict[str, Any]]:
super().reset(seed=seed, options=options)
obs, info = self.env.reset(seed=seed, options=options)
return obs[AGENT_ATTACKER], info[AGENT_ATTACKER]

def step(
self, action: Any
) -> tuple[Any, SupportsFloat, bool, bool, dict[str, Any]]:
obs: Dict[str, Any]
obs, rewards, terminated, truncated, infos = self.env.step(
{AGENT_ATTACKER: action}
)
return (
obs[AGENT_ATTACKER],
rewards[AGENT_ATTACKER],
terminated[AGENT_ATTACKER],
truncated[AGENT_ATTACKER],
infos[AGENT_ATTACKER],
)

def render(self):
return self.env.render()


class DefenderEnv(gym.Env):
metadata = {"render_modes": []}

def __init__(self, **kwargs: Dict[str, Any]) -> None:
attacker_class: str = str(kwargs.pop("attacker_class", "BreadthFirstAttacker"))
self.randomize = kwargs.pop("randomize_attacker_behavior", False)
self.render_mode = kwargs.pop("render_mode", None)
agents = {AGENT_ATTACKER: AGENT_ATTACKER, AGENT_DEFENDER: AGENT_DEFENDER}
self.env = LazyWrapper(agents=agents, **kwargs)
self.attacker_class = searchers.AGENTS[attacker_class]
self.observation_space = self.env.observation_space(AGENT_DEFENDER)
self.action_space = self.env.action_space(AGENT_DEFENDER)

def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[Any, dict[str, Any]]:
super().reset(seed=seed, options=options)
self.attacker = self.attacker_class({"seed": seed, "randomize": self.randomize})
obs, info = self.env.reset(seed=seed, options=options)
self.attacker_obs = obs[AGENT_ATTACKER]
self.attacker_mask = info[AGENT_ATTACKER]["action_mask"]
return obs[AGENT_DEFENDER], info[AGENT_DEFENDER]

def step(
self, action: Any
) -> tuple[Any, SupportsFloat, bool, bool, dict[str, Any]]:
attacker_action = self.attacker.compute_action_from_dict(
self.attacker_obs, self.attacker_mask
)
obs: Dict[str, Any]
obs, rewards, terminated, truncated, infos = self.env.step(
{AGENT_DEFENDER: action, AGENT_ATTACKER: attacker_action}
)
self.attacker_obs = obs[AGENT_ATTACKER]
self.attacker_mask = infos[AGENT_ATTACKER]["action_mask"]
return (
obs[AGENT_DEFENDER],
rewards[AGENT_DEFENDER],
terminated[AGENT_DEFENDER],
truncated[AGENT_DEFENDER],
infos[AGENT_DEFENDER],
)

def render(self):
return self.env.render()

@staticmethod
def add_reverse_edges(edges: np.ndarray, defense_steps: set) -> np.ndarray:
# Add reverse edges from the defense steps children to the defense steps
# themselves
if defense_steps is not None:
for p, c in zip(edges[0, :], edges[1, :]):
if p in defense_steps:
new_edge = np.array([c, p]).reshape((2, 1))
edges = np.concatenate((edges, new_edge), axis=1)
return edges


if __name__ == "__main__":
gym.register("DefenderEnv-v0", entry_point=DefenderEnv)
env = gym.make(
"DefenderEnv-v0",
model_file="/storage/GitHub/mal-petting-zoo-simulator/tests/example_model.json",
lang_file="/storage/GitHub/mal-petting-zoo-simulator/tests/org.mal-lang.coreLang-1.0.0.mar",
)
env_checker.check_env(env.unwrapped)

gym.register("AttackerEnv-v0", entry_point=AttackerEnv)
env = gym.make(
"AttackerEnv-v0",
model_file="/storage/GitHub/mal-petting-zoo-simulator/tests/example_model.json",
lang_file="/storage/GitHub/mal-petting-zoo-simulator/tests/org.mal-lang.coreLang-1.0.0.mar",
)
env_checker.check_env(env.unwrapped)
60 changes: 60 additions & 0 deletions malpzsim/wrappers/wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from gymnasium.spaces.space import Space
from pettingzoo.utils.env import ParallelEnv

from maltoolbox.language import classes_factory
from maltoolbox.language import specification
from maltoolbox.language import languagegraph as mallanguagegraph
from maltoolbox.attackgraph import attackgraph as malattackgraph
from maltoolbox.model import model as malmodel

from malpzsim.sims.mal_petting_zoo_simulator import MalPettingZooSimulator
from typing import Any


class LazyWrapper(ParallelEnv):

def __init__(self, **kwargs):
lang_file = kwargs.pop("lang_file")
model_file = kwargs.pop("model_file")
agents = kwargs.pop("agents", {})
lang_spec = specification.load_language_specification_from_mar(str(lang_file))
lang_classes_factory = classes_factory.LanguageClassesFactory(lang_spec)
lang_classes_factory.create_classes()

lang_graph = mallanguagegraph.LanguageGraph()
lang_graph.generate_graph(lang_spec)

model = malmodel.Model("Test Model", lang_spec, lang_classes_factory)
model.load_from_file(model_file)

attack_graph = malattackgraph.AttackGraph()
attack_graph.generate_graph(lang_spec, model)
attack_graph.attach_attackers(model)

sim = MalPettingZooSimulator(lang_graph, model, attack_graph, **kwargs)

for agent_class, agent_id in agents.items():
if agent_class == "attacker":
sim.register_attacker(agent_id, 0)
elif agent_class == "defender":
sim.register_defender(agent_id)

self.sim = sim

def step(
self, actions: dict
) -> tuple[
dict, dict[Any, float], dict[Any, bool], dict[Any, bool], dict[Any, dict]
]:
return self.sim.step(actions)

def reset(
self, seed: int | None = None, options: dict | None = None
) -> tuple[dict, dict[Any, dict]]:
return self.sim.reset(seed, options)

def observation_space(self, agent: Any) -> Space:
return self.sim.observation_space(agent)

def action_space(self, agent: Any) -> Space:
return self.sim.action_space(agent)

0 comments on commit efebe4d

Please sign in to comment.