Fix tests for mps support #1

Open · wants to merge 50 commits into base: master
Changes from 36 commits
Commits (50)
ace0516 Use MPS device when available (araffin, Jul 4, 2022)
9ac6225 Merge branch 'master' into feat/mps-support (araffin, Aug 13, 2022)
2dcbef9 Update test (araffin, Aug 13, 2022)
b00ca7f Merge branch 'master' into feat/mps-support (araffin, Aug 16, 2022)
06a2124 Merge branch 'master' into feat/mps-support (qgallouedec, Sep 28, 2022)
6d868c0 Merge branch 'master' into feat/mps-support (qgallouedec, Oct 4, 2022)
8d79e96 Merge branch 'master' into feat/mps-support (qgallouedec, Oct 7, 2022)
3276cb0 Merge branch 'master' into feat/mps-support (qgallouedec, Oct 10, 2022)
f4f6073 Merge branch 'master' into feat/mps-support (qgallouedec, Oct 14, 2022)
64327c7 Merge branch 'master' into feat/mps-support (qgallouedec, Oct 17, 2022)
0344c3c Merge branch 'master' into feat/mps-support (araffin, Oct 24, 2022)
fa196ab Merge branch 'master' into feat/mps-support (qgallouedec, Nov 2, 2022)
efd086e Merge branch 'master' into feat/mps-support (araffin, Nov 18, 2022)
7f11843 Merge branch 'master' into feat/mps-support (qgallouedec, Dec 7, 2022)
c60f681 Merge branch 'master' into feat/mps-support (qgallouedec, Dec 20, 2022)
92e8d11 Merge branch 'master' into feat/mps-support (araffin, Jan 13, 2023)
b235c8e Merge branch 'master' into feat/mps-support (qgallouedec, Feb 14, 2023)
d4d0536 Merge branch 'master' into feat/mps-support (araffin, Apr 3, 2023)
0311b62 Merge branch 'master' into feat/mps-support (araffin, Apr 21, 2023)
086f79a Merge branch 'master' into feat/mps-support (araffin, May 3, 2023)
fe606fc Merge branch 'master' into feat/mps-support (araffin, May 24, 2023)
34f4819 Merge branch 'master' into feat/mps-support (qgallouedec, Jun 30, 2023)
ef39571 Merge branch 'master' into feat/mps-support (araffin, Aug 17, 2023)
d26324c Merge branch 'master' into feat/mps-support (araffin, Aug 30, 2023)
1e5dc90 Merge branch 'master' into feat/mps-support (araffin, Oct 6, 2023)
40ed03c mps.is_available -> mps.is_built (qgallouedec, Oct 6, 2023)
e83924b docstring (qgallouedec, Oct 6, 2023)
b707480 Merge branch 'master' into feat/mps-support (qgallouedec, Nov 2, 2023)
81e3c63 Merge branch 'master' into feat/mps-support (araffin, Nov 16, 2023)
f0e54a7 Merge branch 'master' into feat/mps-support (araffin, Jan 10, 2024)
d47c586 Merge branch 'master' into feat/mps-support (araffin, Apr 18, 2024)
b85a2a5 Fix warning (araffin, Apr 18, 2024)
1c25053 Fix tests (deathcoder, Sep 14, 2024)
f822ef5 Attempt fix ci: only cast reward from float64 to float32 (deathcoder, Sep 17, 2024)
1ac4a60 allow running workflows from ui (deathcoder, Sep 17, 2024)
9970f51 Merge pull request #2 from deathcoder/attempt-fix-ci (deathcoder, Sep 17, 2024)
955382e Merge branch 'master' into feat/mps-support (araffin, Sep 18, 2024)
56c153f Add warning when using PPO on GPU and update doc (#2017) (Dev1nW, Oct 7, 2024)
3d59b5c Use uv on GitHub CI for faster download and update changelog (#2026) (araffin, Oct 24, 2024)
dd3d0ac Update readme and clarify planned features (#2030) (araffin, Oct 29, 2024)
5e7372d Merge branch 'feat/mps-support' into feat/mps-support (araffin, Oct 29, 2024)
263e657 Merge branch 'master' into feat/mps-support (araffin, Oct 29, 2024)
8f0b488 Update Gymnasium to v1.0.0 (#1837) (pseudo-rnd-thoughts, Nov 4, 2024)
e4f4f12 Add note about SAC ent coeff optimization (#2037) (araffin, Nov 8, 2024)
7c71688 Merge branch 'master' into feat/mps-support (araffin, Nov 8, 2024)
4c03a25 Merge remote-tracking branch 'origin/feat/mps-support' into feat/mps-… (araffin, Nov 8, 2024)
020ee42 Release 2.4.0 (#2040) (araffin, Nov 18, 2024)
daaebd0 Drop python 3.8 and add python 3.12 support (#2041) (araffin, Nov 18, 2024)
9489b1a Merge branch 'master' into feat/mps-support (araffin, Nov 18, 2024)
0ec37d8 Merge branch 'feat/mps-support' into feat/mps-support (araffin, Nov 18, 2024)
72 changes: 36 additions & 36 deletions .github/workflows/ci.yml
@@ -5,10 +5,10 @@ name: CI

 on:
   push:
-    branches: [ master ]
+    branches: [master]
   pull_request:
-    branches: [ master ]
-
+    branches: [master]
+  workflow_dispatch:
 jobs:
   build:
     env:
@@ -23,38 +23,38 @@ jobs:
         python-version: ["3.8", "3.9", "3.10", "3.11"]

     steps:
-    - uses: actions/checkout@v3
-    - name: Set up Python ${{ matrix.python-version }}
-      uses: actions/setup-python@v4
-      with:
-        python-version: ${{ matrix.python-version }}
-    - name: Install dependencies
-      run: |
-        python -m pip install --upgrade pip
-        # cpu version of pytorch
-        pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
-
-        # Install Atari Roms
-        pip install autorom
-        wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
-        base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
-        AutoROM --accept-license --source-file Roms.tar.gz
-
-        pip install .[extra_no_roms,tests,docs]
-        # Use headless version
-        pip install opencv-python-headless
-    - name: Lint with ruff
-      run: |
-        make lint
-    - name: Build the doc
-      run: |
-        make doc
-    - name: Check codestyle
-      run: |
-        make check-codestyle
-    - name: Type check
-      run: |
-        make type
-    - name: Test with pytest
-      run: |
-        make pytest
+      - uses: actions/checkout@v3
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          # cpu version of pytorch
+          pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
+
+          # Install Atari Roms
+          pip install autorom
+          wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
+          base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
+          AutoROM --accept-license --source-file Roms.tar.gz
+
+          pip install .[extra_no_roms,tests,docs]
Quote the pip install argument to prevent shell errors

The command

pip install .[extra_no_roms,tests,docs]

may be misinterpreted by the shell, which can treat the square brackets as pattern characters for filename expansion (globbing); in zsh, for example, an unmatched glob aborts the command with "no matches found". To prevent such errors, quote the argument:

-pip install .[extra_no_roms,tests,docs]
+pip install '.[extra_no_roms,tests,docs]'
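For illustration, a small sketch using Python's fnmatch module (which follows the same glob rules most shells use): the bracketed extras form a character class, so the unquoted argument means "." followed by exactly one character from the set, not the literal string.

    import fnmatch

    # Hypothetical illustration: the unquoted argument is a glob pattern,
    # "." followed by ONE character from {e, x, t, r, a, _, n, o, m, s, ",", d, c, g}.
    pattern = ".[extra_no_roms,tests,docs]"
    print(fnmatch.fnmatch(".e", pattern))     # True: a file named ".e" would be expanded instead
    print(fnmatch.fnmatch(pattern, pattern))  # False: the literal argument itself does not match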

+          # Use headless version
+          pip install opencv-python-headless
+      - name: Lint with ruff
+        run: |
+          make lint
+      - name: Build the doc
+        run: |
+          make doc
+      - name: Check codestyle
+        run: |
+          make check-codestyle
+      - name: Type check
+        run: |
+          make type
+      - name: Test with pytest
+        run: |
+          make pytest
2 changes: 2 additions & 0 deletions docs/misc/changelog.rst
@@ -558,6 +558,7 @@ New Features:
 - Added checkpoints for replay buffer and ``VecNormalize`` statistics (@anand-bala)
 - Added option for ``Monitor`` to append to existing file instead of overriding (@sidney-tio)
 - The env checker now raises an error when using dict observation spaces and observation keys don't match observation space keys
+- Use MacOS Metal "mps" device when available

`SB3-Contrib`_
^^^^^^^^^^^^^^
@@ -615,6 +616,7 @@ Breaking Changes:

New Features:
^^^^^^^^^^^^^
+- Save cloudpickle version


`SB3-Contrib`_
2 changes: 2 additions & 0 deletions stable_baselines3/common/buffers.py
@@ -135,6 +135,8 @@ def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:
         :return:
         """
         if copy:
+            if hasattr(th, "backends") and th.backends.mps.is_built():
+                return th.tensor(array, dtype=th.float32, device=self.device)
Comment on lines +139 to +140
🛠️ Refactor suggestion

Modify the MPS check in the to_torch method for better compatibility

Consider checking self.device.type == "mps" instead of hasattr(th, "backends") and th.backends.mps.is_built(): the device type reflects where the tensor will actually be placed, and it behaves consistently across PyTorch versions.

Apply this diff to modify the condition:

 def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:
     if copy:
-        if hasattr(th, "backends") and th.backends.mps.is_built():
+        if self.device.type == "mps":
             return th.tensor(array, dtype=th.float32, device=self.device)
         return th.tensor(array, device=self.device)
     return th.as_tensor(array, device=self.device)
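For background, a minimal sketch of the limitation driving this cast (assuming an Apple-Silicon machine with a PyTorch build that includes MPS):

    import numpy as np
    import torch as th

    arr = np.zeros(3, dtype=np.float64)
    if th.backends.mps.is_available():
        # th.tensor(arr, device="mps") would raise: the MPS framework does not
        # support float64, hence the explicit float32 cast in to_torch above.
        t = th.tensor(arr, dtype=th.float32, device="mps")
        print(t.dtype)  # torch.float32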

4 changes: 2 additions & 2 deletions stable_baselines3/common/envs/bit_flipping_env.py
@@ -78,11 +78,11 @@ def convert_if_needed(self, state: np.ndarray) -> Union[int, np.ndarray]:
         if self.discrete_obs_space:
             # The internal state is the binary representation of the
             # observed one
-            return int(sum(state[i] * 2**i for i in range(len(state))))
+            return int(sum(int(state[i]) * 2**i for i in range(len(state))))
Simplify state conversion using NumPy vectorization

Currently, the state conversion uses a generator expression with an explicit Python loop and per-element integer casting:

return int(sum(int(state[i]) * 2**i for i in range(len(state))))

This can be simplified and optimized by leveraging NumPy's vectorized operations, which are more efficient and concise. Consider rewriting the code using np.dot:

Apply this diff to simplify the code:

- return int(sum(int(state[i]) * 2**i for i in range(len(state))))
+ return int(state.dot(2 ** np.arange(len(state))))

This approach eliminates the explicit loop and casting, improving performance and readability.
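A quick sanity check with a made-up bit vector (not from the PR) confirming both forms agree:

    import numpy as np

    state = np.array([1, 0, 1, 1], dtype=np.float32)  # hypothetical 4-bit state
    loop_result = int(sum(int(state[i]) * 2**i for i in range(len(state))))
    vec_result = int(state.dot(2 ** np.arange(len(state))))
    assert loop_result == vec_result == 13  # 1*1 + 0*2 + 1*4 + 1*8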


         if self.image_obs_space:
             size = np.prod(self.image_shape)
-            image = np.concatenate((state * 255, np.zeros(size - len(state), dtype=np.uint8)))
+            image = np.concatenate((state.astype(np.uint8) * 255, np.zeros(size - len(state), dtype=np.uint8)))
Optimize image creation by preallocating the array

Instead of concatenating arrays to create the image, you can preallocate the array and assign values directly. This avoids unnecessary memory allocation and improves performance:

Apply this diff to optimize the code:

- image = np.concatenate((state.astype(np.uint8) * 255, np.zeros(size - len(state), dtype=np.uint8)))
+ image = np.zeros(size, dtype=np.uint8)
+ image[:len(state)] = state.astype(np.uint8) * 255

This refactored code preallocates a zero-filled array of the required size and assigns the scaled state values directly to the beginning of the array.
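A small check with hypothetical values that the two constructions produce identical arrays:

    import numpy as np

    state = np.array([1, 0, 1], dtype=np.float32)  # hypothetical state
    size = 36  # hypothetical flattened image size
    concatenated = np.concatenate((state.astype(np.uint8) * 255, np.zeros(size - len(state), dtype=np.uint8)))
    preallocated = np.zeros(size, dtype=np.uint8)
    preallocated[: len(state)] = state.astype(np.uint8) * 255
    assert np.array_equal(concatenated, preallocated)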


             return image.reshape(self.image_shape).astype(np.uint8)
         return state

36 changes: 27 additions & 9 deletions stable_baselines3/common/utils.py
@@ -29,8 +29,8 @@ def set_random_seed(seed: int, using_cuda: bool = False) -> None:
     """
     Seed the different random generators.

-    :param seed:
-    :param using_cuda:
+    :param seed: Seed
+    :param using_cuda: Whether CUDA is currently used
     """
     # Seed python RNG
     random.seed(seed)
@@ -138,19 +138,20 @@ def get_device(device: Union[th.device, str] = "auto") -> th.device:
     """
     Retrieve PyTorch device.
     It checks that the requested device is available first.
-    For now, it supports only cpu and cuda.
-    By default, it tries to use the gpu.
+    For now, it supports only CPU and CUDA.
+    By default, it tries to use the GPU.

-    :param device: One for 'auto', 'cuda', 'cpu'
+    :param device: One of "auto", "cuda", "cpu",
+        or any PyTorch supported device (for instance "mps")
Comment on lines +144 to +148

Update docstring to reflect MPS support

The docstring states that the function supports only CPU and CUDA devices, but with the addition of MPS support, this is no longer accurate. Please update the docstring to include MPS and reflect the current functionality.

     :return: Supported Pytorch device
     """
-    # Cuda by default
+    # MPS/CUDA by default
     if device == "auto":
-        device = "cuda"
+        device = get_available_accelerator()
     # Force conversion to th.device
     device = th.device(device)

-    # Cuda not available
+    # CUDA not available
Comment on lines +151 to +157
🛠️ Refactor suggestion

Enhance device handling logic.

The current implementation could be improved to handle device availability more robustly:

  1. Add explicit MPS availability check
  2. Consider handling potential MPS-related errors

Apply this diff to improve the implementation:

     # MPS/CUDA by default
     if device == "auto":
         device = get_available_accelerator()
     # Force conversion to th.device
     device = th.device(device)
 
-    # CUDA not available
-    if device.type == th.device("cuda").type and not th.cuda.is_available():
+    # Check device availability
+    if device.type == "cuda" and not th.cuda.is_available():
+        return th.device("cpu")
+    elif device.type == "mps":
+        try:
+            if not (hasattr(th, "backends") and th.backends.mps.is_built() and th.backends.mps.is_available()):
+                return th.device("cpu")
+        except:
+            return th.device("cpu")
     if device.type == th.device("cuda").type and not th.cuda.is_available():
         return th.device("cpu")
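A usage sketch of the resulting behaviour (assuming this branch is installed): "auto" now resolves through get_available_accelerator, and requesting an unavailable device falls back to CPU.

    from stable_baselines3.common.utils import get_device

    print(get_device("auto"))  # mps, cuda, or cpu depending on the host
    print(get_device("cuda"))  # th.device("cpu") when no CUDA device is present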

@@ -482,6 +483,8 @@ def obs_as_tensor(obs: Union[np.ndarray, Dict[str, np.ndarray]], device: th.device)
     if isinstance(obs, np.ndarray):
         return th.as_tensor(obs, device=device)
     elif isinstance(obs, dict):
+        if hasattr(th, "backends") and th.backends.mps.is_built():
+            return {key: th.as_tensor(_obs, dtype=th.float32, device=device) for (key, _obs) in obs.items()}
Comment on lines +489 to +490
Simplify MPS handling in obs_as_tensor

The current implementation introduces code duplication when handling observations for MPS devices. Consider refactoring to streamline the code and ensure compatibility with different PyTorch versions.

Apply this diff to simplify the code:

    elif isinstance(obs, dict):
-       if hasattr(th, "backends") and th.backends.mps.is_built():
-           return {key: th.as_tensor(_obs, dtype=th.float32, device=device) for (key, _obs) in obs.items()}
-       return {key: th.as_tensor(_obs, device=device) for (key, _obs) in obs.items()}
+       tensor_kwargs = {'device': device}
+       if device.type == 'mps':
+           tensor_kwargs['dtype'] = th.float32
+       return {key: th.as_tensor(_obs, **tensor_kwargs) for (key, _obs) in obs.items()}

This refactor reduces code duplication and uses the device type to handle dtype settings appropriately.
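A tiny illustration of the kwargs pattern with a made-up dict observation (values here are hypothetical):

    import numpy as np
    import torch as th

    obs = {"pos": np.zeros(2, dtype=np.float64)}  # hypothetical dict observation
    device = th.device("cpu")
    tensor_kwargs = {"device": device}
    if device.type == "mps":
        tensor_kwargs["dtype"] = th.float32  # only force float32 on MPS
    tensors = {k: th.as_tensor(v, **tensor_kwargs) for k, v in obs.items()}
    print(tensors["pos"].dtype)  # float64 on cpu; float32 would be forced on mps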

         return {key: th.as_tensor(_obs, device=device) for (key, _obs) in obs.items()}
     else:
         raise Exception(f"Unrecognized type of observation {type(obs)}")
@@ -515,6 +518,21 @@ def should_collect_more_steps(
     )


+def get_available_accelerator() -> str:
+    """
+    Return the available accelerator
+    (currently checking only for CUDA and MPS device)
+    """
+    if hasattr(th, "backends") and th.backends.mps.is_built():
+        # MacOS Metal GPU
+        th.set_default_dtype(th.float32)
+        return "mps"
+    elif th.cuda.is_available():
+        return "cuda"
+    else:
+        return "cpu"
+
Comment on lines +524 to +537
⚠️ Potential issue

Improve accelerator detection implementation.

The current implementation has several issues:

  1. Setting global dtype could have unintended side effects
  2. Missing proper MPS availability check
  3. Missing error handling
  4. Docstring needs improvement

Apply this diff to fix these issues:

 def get_available_accelerator() -> str:
     """
     Return the available accelerator
-    (currently checking only for CUDA and MPS device)
+    (checking for CUDA and MPS devices)
+    
+    Note: MPS (Apple Metal) requires float32 dtype and may not be available
+    on older macOS versions even if built.
+    
+    Returns:
+        str: The available accelerator type ('mps', 'cuda', or 'cpu')
     """
-    if hasattr(th, "backends") and th.backends.mps.is_built():
-        # MacOS Metal GPU
-        th.set_default_dtype(th.float32)
-        return "mps"
+    try:
+        if (hasattr(th, "backends") and th.backends.mps.is_built()
+                and th.backends.mps.is_available()):
+            # MacOS Metal GPU
+            return "mps"
+    except Exception:  # Catch any MPS-related errors
+        pass
-    elif th.cuda.is_available():
+    if th.cuda.is_available():
         return "cuda"
-    else:
-        return "cpu"
+    return "cpu"


 def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]:
     """
     Retrieve system and python env info for the current system.

@@ -530,7 +548,7 @@ def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]:
"Python": platform.python_version(),
"Stable-Baselines3": sb3.__version__,
"PyTorch": th.__version__,
"GPU Enabled": str(th.cuda.is_available()),
"Accelerator": get_available_accelerator(),
"Numpy": np.__version__,
"Cloudpickle": cloudpickle.__version__,
"Gymnasium": gym.__version__,
17 changes: 16 additions & 1 deletion stable_baselines3/common/vec_env/vec_normalize.py
@@ -125,6 +125,20 @@ def _sanity_checks(self) -> None:
                 f"not {self.observation_space}"
             )

+    @staticmethod
+    def _maybe_cast_reward(reward: np.ndarray) -> np.ndarray:
+        """
+        Cast `np.float64` reward datatype to `np.float32`,
+        keep the others dtype unchanged.
+
+        :param dtype: The original action space dtype
+        :return: ``np.float32`` if the dtype was float64,
+            the original dtype otherwise.
+        """
+        if reward.dtype == np.float64:
+            return reward.astype(np.float32)
+        return reward
⚠️ Potential issue

Fix docstring parameter and add input validation.

The implementation looks good but has a few minor issues:

  1. The docstring refers to a dtype parameter that doesn't exist
  2. Missing input type validation

Apply these changes:

 @staticmethod
 def _maybe_cast_reward(reward: np.ndarray) -> np.ndarray:
     """
     Cast `np.float64` reward datatype to `np.float32`,
     keep the others dtype unchanged.

-    :param dtype: The original action space dtype
+    :param reward: The reward array to potentially cast
     :return: ``np.float32`` if the dtype was float64,
         the original dtype otherwise.
     """
+    if not isinstance(reward, np.ndarray):
+        raise TypeError(f"Expected numpy array, got {type(reward)}")
     if reward.dtype == np.float64:
         return reward.astype(np.float32)
     return reward
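A quick check of the casting behaviour with made-up rewards (assuming this branch is installed):

    import numpy as np
    from stable_baselines3.common.vec_env import VecNormalize

    rewards = np.array([1.0, -0.5], dtype=np.float64)  # hypothetical float64 rewards
    assert VecNormalize._maybe_cast_reward(rewards).dtype == np.float32
    ints = np.array([1, 2])  # non-float64 dtypes pass through unchanged
    assert VecNormalize._maybe_cast_reward(ints).dtype == ints.dtype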


     def __getstate__(self) -> Dict[str, Any]:
         """
         Gets state for pickling.
@@ -254,7 +268,8 @@ def normalize_reward(self, reward: np.ndarray) -> np.ndarray:
         """
         if self.norm_reward:
             reward = np.clip(reward / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward)
-        return reward
+
+        return self._maybe_cast_reward(reward)

     def unnormalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
         # Avoid modifying by reference the original object
3 changes: 3 additions & 0 deletions tests/test_spaces.py
@@ -4,6 +4,7 @@
 import gymnasium as gym
 import numpy as np
 import pytest
+import torch as th
 from gymnasium import spaces
 from gymnasium.spaces.space import Space

@@ -151,6 +152,8 @@ def test_discrete_obs_space(model_class, env):
     ],
 )
 def test_float64_action_space(model_class, obs_space, action_space):
+    if hasattr(th, "backends") and th.backends.mps.is_built():
+        pytest.skip("MPS framework doesn't support float64")
Comment on lines +155 to +156
🛠️ Refactor suggestion

Improve MPS availability check and skip message

The current implementation has several areas for improvement:

  1. Use is_available() instead of is_built() to accurately check MPS availability
  2. Make the skip message more informative about the specific limitation
  3. Simplify the condition structure

Apply this diff to improve the implementation:

-    if hasattr(th, "backends") and th.backends.mps.is_built():
-        pytest.skip("MPS framework doesn't support float64")
+    if hasattr(th.backends, "mps") and th.backends.mps.is_available():
+        pytest.skip("Skipping float64 tests: MPS backend does not support float64 dtype operations")

This change:

  • Uses the recommended is_available() check
  • Provides a clearer skip message explaining why the test is skipped
  • Simplifies the condition by checking hasattr(th.backends, "mps") directly
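For reference, a sketch of the built-versus-available distinction driving this suggestion:

    import torch as th

    # is_built(): this PyTorch binary was compiled with MPS support
    # is_available(): the current machine can actually use MPS right now
    print(th.backends.mps.is_built(), th.backends.mps.is_available())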

     env = DummyEnv(obs_space, action_space)
     env = gym.wrappers.TimeLimit(env, max_episode_steps=200)
     if isinstance(env.observation_space, spaces.Dict):
3 changes: 2 additions & 1 deletion tests/test_utils.py
@@ -442,9 +442,10 @@ def test_get_system_info():
     assert info["Stable-Baselines3"] == str(sb3.__version__)
     assert "Python" in info_str
     assert "PyTorch" in info_str
-    assert "GPU Enabled" in info_str
+    assert "Accelerator" in info_str
     assert "Numpy" in info_str
     assert "Gym" in info_str
     assert "Cloudpickle" in info_str


def test_is_vectorized_observation():