
Commit

fix merging
younik committed Nov 18, 2024
1 parent 4c44ec2 commit 6eab54b
Showing 2 changed files with 2 additions and 103 deletions.
103 changes: 1 addition & 102 deletions minigrid/wrappers.py
@@ -881,106 +881,6 @@ def step(self, action):
            reward += self.death_cost

        return obs, reward, terminated, truncated, info

class StochasticActionWrapper(ActionWrapper):
    """
    Add stochasticity to the actions.

    The chosen action is executed with probability `prob`. Otherwise, if
    `random_action` is provided it is executed instead; else a random
    action is sampled from the action space.
    """

    def __init__(self, env=None, prob=0.9, random_action=None):
        super().__init__(env)
        self.prob = prob
        self.random_action = random_action

    def action(self, action):
        """Return the chosen action, or a random/default one with probability `1 - prob`."""
        if np.random.uniform() < self.prob:
            return action
        else:
            if self.random_action is None:
                return self.np_random.integers(0, high=6)
            else:
                return self.random_action
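
A minimal usage sketch for the wrapper above (not part of this commit), assuming the standard "MiniGrid-Empty-5x5-v0" registration name and that the wrapper is importable from minigrid.wrappers:

    import gymnasium as gym
    from minigrid.wrappers import StochasticActionWrapper

    # With probability 0.1 the chosen action is replaced by action 0 (turn left).
    env = StochasticActionWrapper(gym.make("MiniGrid-Empty-5x5-v0"), prob=0.9, random_action=0)
    obs, _ = env.reset(seed=0)
    obs, reward, terminated, truncated, info = env.step(env.unwrapped.actions.forward)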


class NoDeath(Wrapper):
    """
    Wrapper to prevent death in specific cells (e.g., lava cells).
    Instead of dying, the agent will receive a negative reward.

    Example:
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import NoDeath
        >>>
        >>> env = gym.make("MiniGrid-LavaCrossingS9N1-v0")
        >>> _, _ = env.reset(seed=2)
        >>> _, _, _, _, _ = env.step(1)
        >>> _, reward, term, *_ = env.step(2)
        >>> reward, term
        (0, True)
        >>>
        >>> env = NoDeath(env, no_death_types=("lava",), death_cost=-1.0)
        >>> _, _ = env.reset(seed=2)
        >>> _, _, _, _, _ = env.step(1)
        >>> _, reward, term, *_ = env.step(2)
        >>> reward, term
        (-1.0, False)
        >>>
        >>>
        >>> env = gym.make("MiniGrid-Dynamic-Obstacles-5x5-v0")
        >>> _, _ = env.reset(seed=2)
        >>> _, reward, term, *_ = env.step(2)
        >>> reward, term
        (-1, True)
        >>>
        >>> env = NoDeath(env, no_death_types=("ball",), death_cost=-1.0)
        >>> _, _ = env.reset(seed=2)
        >>> _, reward, term, *_ = env.step(2)
        >>> reward, term
        (-2.0, False)
    """

    def __init__(self, env, no_death_types: tuple[str, ...], death_cost: float = -1.0):
        """A wrapper to prevent death in specific cells.

        Args:
            env: The environment to apply the wrapper
            no_death_types: List of strings to identify death cells
            death_cost: The negative reward received in death cells
        """
        assert "goal" not in no_death_types, "goal cannot be a death cell"

        super().__init__(env)
        self.death_cost = death_cost
        self.no_death_types = no_death_types

    def step(self, action):
        # In Dynamic-Obstacles, obstacles move after the agent moves,
        # so we need to check for collision before self.env.step()
        front_cell = self.unwrapped.grid.get(*self.unwrapped.front_pos)
        going_to_death = (
            action == self.unwrapped.actions.forward
            and front_cell is not None
            and front_cell.type in self.no_death_types
        )

        obs, reward, terminated, truncated, info = self.env.step(action)

        # We also check if the agent stays in death cells (e.g., lava)
        # without moving
        current_cell = self.unwrapped.grid.get(*self.unwrapped.agent_pos)
        in_death = current_cell is not None and current_cell.type in self.no_death_types

        if terminated and (going_to_death or in_death):
            terminated = False
            reward += self.death_cost

        return obs, reward, terminated, truncated, info


class MoveActionWrapper(Wrapper):
    """
@@ -1010,8 +910,7 @@ def step(self, action):
            else:
                for _ in range(left_turns):
                    self.env.step(0)

            return self.env.step(2)
        else:
            return self.env.step(action - 1)
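
A minimal usage sketch for MoveActionWrapper (not part of this commit); only a fragment of the class is visible in this hunk, so the sketch assumes the wrapper is importable from minigrid.wrappers and, based on the code above, that low-numbered actions move the agent in a grid direction by chaining turn actions with a forward step while higher actions map back to the remaining primitive actions:

    import gymnasium as gym
    from minigrid.wrappers import MoveActionWrapper

    env = MoveActionWrapper(gym.make("MiniGrid-Empty-5x5-v0"))
    obs, _ = env.reset(seed=0)
    # Assumed: action 0 is a directional move; the wrapper issues the needed
    # turn actions internally before stepping forward.
    obs, reward, terminated, truncated, info = env.step(0)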

2 changes: 1 addition & 1 deletion tests/test_wrappers.py
@@ -16,8 +16,8 @@
FlatObsWrapper,
FullyObsWrapper,
ImgObsWrapper,
NoDeath,
MoveActionWrapper,
NoDeath,
OneHotPartialObsWrapper,
PositionBonus,
ReseedWrapper,
