diff --git a/highway_env/__init__.py b/highway_env/__init__.py index 7241ff142..2493d128a 100644 --- a/highway_env/__init__.py +++ b/highway_env/__init__.py @@ -16,6 +16,7 @@ os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "1" from gymnasium.envs.registration import register +from highway_env.envs.common.abstract import MultiAgentWrapper def register_highway_envs(): @@ -56,7 +57,8 @@ def register_highway_envs(): register( id="intersection-multi-agent-v1", - entry_point="highway_env.envs:TupleMultiAgentIntersectionEnv", + entry_point="highway_env.envs:MultiAgentIntersectionEnv", + additional_wrappers=(MultiAgentWrapper.wrapper_spec(),), ) # lane_keeping_env.py diff --git a/highway_env/envs/intersection_env.py b/highway_env/envs/intersection_env.py index 4b512cec9..d55e32035 100644 --- a/highway_env/envs/intersection_env.py +++ b/highway_env/envs/intersection_env.py @@ -3,7 +3,7 @@ import numpy as np from highway_env import utils -from highway_env.envs.common.abstract import AbstractEnv, MultiAgentWrapper +from highway_env.envs.common.abstract import AbstractEnv from highway_env.road.lane import AbstractLane, CircularLane, LineType, StraightLane from highway_env.road.regulation import RegulatedRoad from highway_env.road.road import RoadNetwork @@ -423,6 +423,3 @@ def default_config(cls) -> dict: } ) return config - - -TupleMultiAgentIntersectionEnv = MultiAgentWrapper(MultiAgentIntersectionEnv)