Skip to content

Commit

Permalink
gnn compatibility wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
kasanari committed Mar 5, 2024
1 parent e9e1599 commit 91e8593
Showing 1 changed file with 48 additions and 1 deletion.
49 changes: 48 additions & 1 deletion malpzsim/wrappers/gym_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import gymnasium as gym
import gymnasium.utils.env_checker as env_checker

from gymnasium import Wrapper
from gymnasium import spaces
from gymnasium.core import RenderFrame
import numpy as np

from malpzsim.wrappers.wrapper import LazyWrapper
Expand Down Expand Up @@ -125,7 +127,52 @@ def num_assets(self):
def num_step_names(self):
    """Return the ``num_step_names`` value of the wrapped env's simulator."""
    simulator = self.env.sim
    return simulator.num_step_names

def _to_binary(val, max_val):
return np.array(list(np.binary_repr(val, width=max_val.bit_length())), dtype=np.int64)


def vec_to_binary(vec, max_val):
    """Encode every integer in ``vec`` as a fixed-width row of binary digits.

    The encoding is computed with vectorized shifts and masks; for values
    that fit in the width it yields the same digits as
    ``np.binary_repr(val, width=max_val.bit_length())``, including the
    two's-complement form of negative values.

    Args:
        vec: One-dimensional sequence or array of integers.
        max_val: Python integer whose ``bit_length()`` fixes the number of
            digits per row.

    Returns:
        An ``np.int64`` array of shape ``(len(vec), width)`` with the most
        significant bit first. An empty ``vec`` gives shape ``(0, width)`` so
        the result can still be concatenated along ``axis=1``.

    Raises:
        TypeError: If ``vec`` is not one-dimensional or holds non-integers.
        ValueError: If a value does not fit in ``width`` digits; a wider row
            would no longer match the fixed feature width.
    """
    # ``np.binary_repr`` never emits fewer than one digit, so mirror that.
    width = max(max_val.bit_length(), 1)

    values = np.asarray(vec)
    if values.ndim != 1:
        raise TypeError("vec must be one-dimensional")
    if values.size == 0:
        return np.zeros((0, width), dtype=np.int64)
    if values.dtype.kind not in "biu":
        raise TypeError(
            f"vec must contain integers, got dtype {values.dtype}"
        )
    values = values.astype(np.int64)

    lowest = -(1 << (width - 1))
    highest = (1 << width) - 1
    if values.min() < lowest or values.max() > highest:
        raise ValueError(
            f"values must fit in {width} binary digits "
            f"(allowed range {lowest}..{highest})"
        )

    # Shift each value so that every bit, MSB first, lands in position 0.
    shifts = np.arange(width - 1, -1, -1, dtype=np.int64)
    return (values[:, None] >> shifts) & 1


class LabeledGraphWrapper(Wrapper):
    """Expose the wrapped environment's observations as a labeled graph.

    Each node's ``observed_state``, ``asset_type`` and ``step_name`` entries
    are binary-encoded and concatenated into one feature row. The two
    entries of ``info["action_mask"]`` are surfaced in the observation as
    ``mask_0`` and ``mask_1``.
    """

    # Upper bound used to size the binary encoding of ``observed_state + 1``.
    # NOTE(review): presumably ``observed_state`` takes values in -1..1 so the
    # shifted value is non-negative -- confirm against the wrapped env.
    _MAX_STATE_LABEL = 3

    # dtype shared by every Box in the observation space and by the arrays
    # returned from ``_to_graph``.
    _OBS_DTYPE = np.int8

    def __init__(self, env: gym.Env) -> None:
        """Wrap ``env`` and declare the graph-shaped observation space.

        Args:
            env: Environment whose unwrapped instance exposes ``num_assets``
                and ``num_step_names`` and whose observation space has
                ``observed_state`` and ``edges`` entries.
        """
        super().__init__(env)

        self.num_assets = self.env.unwrapped.num_assets
        self.num_steps = self.env.unwrapped.num_step_names
        num_nodes = self.env.observation_space["observed_state"].shape[0]
        num_commands = 2
        # One group of bits per encoded label; must match ``_to_graph``.
        num_features = (
            self._MAX_STATE_LABEL.bit_length()
            + self.num_assets.bit_length()
            + self.num_steps.bit_length()
        )
        self.observation_space = spaces.Dict(
            {
                "nodes": spaces.Box(
                    0, 1, shape=(num_nodes, num_features), dtype=self._OBS_DTYPE
                ),
                "edges": self.env.observation_space["edges"],
                "mask_0": spaces.Box(
                    0, 1, shape=(num_commands,), dtype=self._OBS_DTYPE
                ),
                "mask_1": spaces.Box(
                    0, 1, shape=(num_nodes,), dtype=self._OBS_DTYPE
                ),
            }
        )

    def render(self) -> RenderFrame | list[RenderFrame] | None:
        """Delegate rendering to the wrapped environment."""
        return self.env.render()

    def reset(self, **kwargs: Any) -> tuple[Any, dict[str, Any]]:
        """Reset the wrapped environment and return a graph observation."""
        obs, info = self.env.reset(**kwargs)
        return self._to_graph(obs, info), info

    def step(
        self, action: Any
    ) -> tuple[Any, SupportsFloat, bool, bool, dict[str, Any]]:
        """Step the wrapped environment and return a graph observation."""
        obs, reward, terminated, truncated, info = self.env.step(action)
        return self._to_graph(obs, info), reward, terminated, truncated, info

    def _to_graph(self, obs: dict[str, Any], info) -> dict[str, Any]:
        """Build the graph observation from a raw observation and its info.

        Args:
            obs: Raw observation with ``observed_state``, ``asset_type``,
                ``step_name`` and ``edges`` entries.
            info: Info dict whose ``action_mask`` entry holds two masks.

        Returns:
            A dict with ``nodes``, ``edges``, ``mask_0`` and ``mask_1``.
            ``nodes`` and both masks are cast to the dtype declared in
            ``observation_space``; without the cast the int64 node features
            are rejected by the space's dtype check.
        """
        nodes = np.concatenate(
            [
                # Shift by one before encoding (see ``_MAX_STATE_LABEL``).
                vec_to_binary(obs["observed_state"] + 1, self._MAX_STATE_LABEL),
                vec_to_binary(obs["asset_type"], self.num_assets),
                vec_to_binary(obs["step_name"], self.num_steps),
            ],
            axis=1,
        ).astype(self._OBS_DTYPE, copy=False)
        action_mask = info["action_mask"]
        return {
            "nodes": nodes,
            "edges": obs["edges"],
            "mask_0": np.asarray(action_mask[0], dtype=self._OBS_DTYPE),
            "mask_1": np.asarray(action_mask[1], dtype=self._OBS_DTYPE),
        }

def register_envs():
gym.register("MALDefenderEnv-v0", entry_point=DefenderEnv)
Expand Down

0 comments on commit 91e8593

Please sign in to comment.