
Fix tests for mps support #1

Open: wants to merge 50 commits into master
Conversation

@deathcoder (Owner) commented Sep 17, 2024

Description

closes DLR-RM#914

When I started from the base branch feat/mps-support there were 45 failing tests, which I now consider fixed. A few things to note:

  • In most cases I added a check: if the MPS device is available, I apply various casts to make sure tensors are float32 and remain float32. I'm not sure whether this approach is correct, but I'm happy to change it to something else that also works.
  • I decided to skip the test_float64_action_space tests entirely, since float64 is not supported on MPS.
  • The test test_save_load[True-SAC] only fails when running the full suite or all test_save_load tests (make pytest or python3 -m pytest -v -k 'test_save_load'). If I instead run just the single breaking test (python3 -m pytest -v -k 'test_save_load[True-SAC]'), it passes 🤷‍♂️. I also ran the test file in PyCharm and it passes there too, so I'm not sure what the issue is; I can add the stack trace of the failing test in a comment if needed.
  • I'm not sure about a few things regarding this template. I think these are not breaking changes, but for example I force a cast in vec_normalize:normalize_reward that could perhaps be considered breaking?
  • I also looked into the changelog, but I couldn't figure out how to edit it.
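The "cast to float32 when MPS is available" check described above boils down to a small dtype-selection rule. A minimal sketch of that rule, written as a standalone helper (the function name and signature are illustrative, not code from this PR):

```python
import numpy as np

def mps_safe_dtype(mps_available: bool, requested: np.dtype) -> np.dtype:
    """Pick a dtype that the target backend can actually hold.

    The MPS backend has no float64 support, so float64 requests are
    downcast to float32 there; every other device keeps the requested
    precision unchanged.
    """
    if mps_available and np.dtype(requested) == np.float64:
        return np.dtype(np.float32)
    return np.dtype(requested)
```

In the actual PR the equivalent guard is applied at tensor-creation sites (e.g. in the buffers and `obs_as_tensor`) rather than through a helper like this.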

Here is the full list of fixed tests:

  • FAILED tests/test_dict_env.py::test_dict_vec_framestack[False-PPO] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

  • FAILED tests/test_dict_env.py::test_dict_vec_framestack[False-A2C] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

  • FAILED tests/test_dict_env.py::test_dict_vec_framestack[False-DQN] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

  • FAILED tests/test_dict_env.py::test_dict_vec_framestack[False-DDPG] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

  • FAILED tests/test_dict_env.py::test_dict_vec_framestack[False-SAC] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

  • FAILED tests/test_dict_env.py::test_dict_vec_framestack[False-TD3] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

  • FAILED tests/test_dict_env.py::test_dict_vec_framestack[True-PPO] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

  • FAILED tests/test_dict_env.py::test_dict_vec_framestack[True-A2C] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

  • FAILED tests/test_dict_env.py::test_dict_vec_framestack[True-DQN] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

  • FAILED tests/test_dict_env.py::test_dict_vec_framestack[True-DDPG] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

  • FAILED tests/test_dict_env.py::test_dict_vec_framestack[True-SAC] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

  • FAILED tests/test_dict_env.py::test_dict_vec_framestack[True-TD3] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

  • FAILED tests/test_envs.py::test_bit_flipping[kwargs1] - OverflowError: Python integer 128 out of bounds for int8

  • FAILED tests/test_envs.py::test_bit_flipping[kwargs2] - OverflowError: Python integer 255 out of bounds for int8

  • FAILED tests/test_envs.py::test_bit_flipping[kwargs3] - OverflowError: Python integer 255 out of bounds for int8

  • FAILED tests/test_her.py::test_her[True-SAC] - OverflowError: Python integer 255 out of bounds for int8

  • FAILED tests/test_her.py::test_her[True-TD3] - OverflowError: Python integer 255 out of bounds for int8

  • FAILED tests/test_her.py::test_her[True-DDPG] - OverflowError: Python integer 255 out of bounds for int8

  • FAILED tests/test_her.py::test_her[True-DQN] - OverflowError: Python integer 255 out of bounds for int8

  • FAILED tests/test_her.py::test_multiprocessing[True-TD3] - EOFError

  • FAILED tests/test_her.py::test_multiprocessing[True-DQN] - EOFError

  • FAILED tests/test_train_eval_mode.py::test_td3_train_with_batch_norm - AssertionError: assert ~tensor(True, device='mps:0')

  • FAILED tests/test_vec_normalize.py::test_get_original - AssertionError: assert dtype('float32') == dtype('float64')

  • FAILED tests/test_vec_normalize.py::test_get_original_dict - AssertionError: assert dtype('float32') == dtype('float64')

  • FAILED tests/test_her.py::test_save_load[True-SAC] - ValueError: Expected parameter scale (Tensor of shape (64, 4)) of distribution Normal(loc: torch.Size([64, 4]), scale: torch.Size([64, 4])) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values:

Unsupported tests fixed by skipping

  • FAILED tests/test_spaces.py::test_float64_action_space[action_space0-obs_space1-SAC] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
  • FAILED tests/test_spaces.py::test_float64_action_space[action_space0-obs_space1-TD3] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
  • FAILED tests/test_spaces.py::test_float64_action_space[action_space0-obs_space1-PPO] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
  • FAILED tests/test_spaces.py::test_float64_action_space[action_space0-obs_space1-DDPG] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
  • FAILED tests/test_spaces.py::test_float64_action_space[action_space0-obs_space1-A2C] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
  • FAILED tests/test_spaces.py::test_float64_action_space[action_space0-obs_space3-SAC] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
  • FAILED tests/test_spaces.py::test_float64_action_space[action_space0-obs_space3-TD3] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
  • FAILED tests/test_spaces.py::test_float64_action_space[action_space0-obs_space3-PPO] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
  • FAILED tests/test_spaces.py::test_float64_action_space[action_space0-obs_space3-DDPG] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
  • FAILED tests/test_spaces.py::test_float64_action_space[action_space0-obs_space3-A2C] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
  • FAILED tests/test_spaces.py::test_float64_action_space[action_space1-obs_space1-SAC] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
  • FAILED tests/test_spaces.py::test_float64_action_space[action_space1-obs_space1-TD3] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
  • FAILED tests/test_spaces.py::test_float64_action_space[action_space1-obs_space1-PPO] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
  • FAILED tests/test_spaces.py::test_float64_action_space[action_space1-obs_space1-DDPG] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
  • FAILED tests/test_spaces.py::test_float64_action_space[action_space1-obs_space1-A2C] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
  • FAILED tests/test_spaces.py::test_float64_action_space[action_space1-obs_space3-SAC] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
  • FAILED tests/test_spaces.py::test_float64_action_space[action_space1-obs_space3-TD3] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
  • FAILED tests/test_spaces.py::test_float64_action_space[action_space1-obs_space3-PPO] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
  • FAILED tests/test_spaces.py::test_float64_action_space[action_space1-obs_space3-DDPG] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
  • FAILED tests/test_spaces.py::test_float64_action_space[action_space1-obs_space3-A2C] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
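Skipping these unsupported tests can be sketched with a pytest marker. The flag below is a stand-in for the torch query the real tests would use (roughly `hasattr(th, "backends") and th.backends.mps.is_built()`), and the test body is a placeholder, not the actual parametrized tests from tests/test_spaces.py:

```python
import pytest

# Stand-in flag; in the real suite this would be computed from torch,
# so the test runs everywhere except on MPS-enabled builds.
MPS_BUILT = False

@pytest.mark.skipif(MPS_BUILT, reason="float64 is not supported by the MPS backend")
def test_float64_action_space_sketch():
    # Placeholder body; the real tests exercise float64 action spaces.
    assert True
```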

Motivation and Context

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist

  • I've read the CONTRIBUTION guide (required)
  • I have updated the changelog accordingly (required).
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.
  • I have opened an associated PR on the SB3-Contrib repository (if necessary)
  • I have opened an associated PR on the RL-Zoo3 repository (if necessary)
  • I have reformatted the code using make format (required)
  • I have checked the codestyle using make check-codestyle and make lint (required)
  • I have ensured make pytest and make type both pass. (required)
  • I have checked that the documentation builds using make doc (required)

Note: You can run most of the checks using make commit-checks.

Note: we are using a maximum length of 127 characters per line

Summary by CodeRabbit

  • New Features

    • Introduced the CrossQ algorithm in the SB3 Contrib section.
    • Added a method to recommend CPU usage for algorithms not optimized for GPU.
    • Enhanced reward normalization process with improved data type handling.
    • Updated documentation to include new features and algorithms.
  • Bug Fixes

    • Resolved warnings related to GPU usage for the PPO model.
    • Improved memory management in buffer classes.
  • Documentation

    • Updated changelog to reflect new features, breaking changes, and upgrade recommendations.
    • Clarified contribution guidelines and testing instructions for contributors.
    • Added Zenodo DOI for citing specific versions of SB3.
    • Updated installation instructions to reflect changes in prerequisites and package versions.

araffin and others added 30 commits July 4, 2022 14:51

coderabbitai bot commented Sep 17, 2024

Walkthrough

The pull request introduces enhancements to the stable-baselines3 codebase, focusing on improved error handling, GPU compatibility for Apple Silicon through the MPS backend, and better data type management in various functions. Key updates include a refined environment checker, the addition of functions to save the cloudpickle version, and modifications to tensor creation logic to support MPS. These changes collectively aim to enhance performance, compatibility, and robustness in the library.

Changes

Files Change Summary
.github/workflows/ci.yml Updated Python version matrix, enhanced dependency installation commands, added conditional installation for gymnasium, and modified torch version specification to include +cpu.
docs/misc/changelog.rst Updated changelog for SB3 v2.5.0a0, including increased minimum required versions for PyTorch and Gymnasium, added support for Python 3.12, and noted various breaking changes and new features.
stable_baselines3/common/buffers.py Modified type hints from Tuple and Dict to tuple and dict, updated tensor handling in BaseBuffer and derived classes, and enhanced compatibility with MPS backend.
stable_baselines3/common/utils.py Updated documentation for set_random_seed, get_device, and obs_as_tensor, added get_available_accelerator() function, and modified get_system_info to reflect new hardware information.
stable_baselines3/common/vec_env/vec_normalize.py Updated type hints and added comments for clarity, ensuring compatibility with the new type hinting conventions.
stable_baselines3/common/on_policy_algorithm.py Added _maybe_recommend_cpu method to warn users when running on GPU with MlpPolicy, advising CPU usage for better performance.
README.md Updated CI and coverage report badge links, added content about OpenRL Benchmark logs, revised "Planned features" section, and included new algorithm "CrossQ".
tests/test_run.py Added import for torch, updated test_ppo_warnings to check for GPU usage warnings.
tests/test_utils.py Enhanced tests for vectorized environments, removed type hints from test_evaluate_policy, and updated assertions in test_get_system_info.

Assessment against linked issues

  • Support for GPU acceleration using the Apple Silicon chips (#914)
  • Ensure stable-baselines3 recognizes the MPS device for computations (#914)
  • Prevent errors when using unsupported data types with MPS (#914): it's unclear if all unsupported data types are handled.

🐇 In the code we hop and play,
With MPS support here to stay.
Tensors dance on silicon bright,
As errors fade, and code takes flight!
Cloudpickle's version saved with glee,
In this patch, we’re wild and free! 🌟



@coderabbitai bot left a comment

Actionable comments posted: 6

Outside diff range and nitpick comments (3)
stable_baselines3/common/buffers.py (1)

138-139: Add explanatory comment for clarity

To improve code readability and maintainability, consider adding a comment that explains why casting to float32 is necessary when the MPS backend is available.

Apply this diff to add the comment:

 def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:
     if copy:
+        # Cast to float32 since MPS backend does not support float64 tensors
         if hasattr(th, "backends") and th.backends.mps.is_built():
             return th.tensor(array, dtype=th.float32, device=self.device)
         return th.tensor(array, device=self.device)
     return th.as_tensor(array, device=self.device)
docs/misc/changelog.rst (2)

561-561: Correct capitalization of 'macOS' and 'MPS'

Apple's operating system is styled as 'macOS', and 'MPS' (Metal Performance Shaders) is typically capitalized. Please update the line for correct branding.

Apply this diff to correct the capitalization:

- - Use MacOS Metal "mps" device when available
+ - Use macOS Metal "MPS" device when available

619-619: Enclose 'cloudpickle' in double backticks for consistency

In the changelog, package names are usually formatted with double backticks to denote code. Please wrap 'cloudpickle' with double backticks for consistent styling.

Apply this diff to update the formatting:

- - Save cloudpickle version
+ - Save ``cloudpickle`` version
Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits

Files that changed from the base of the PR and between 512eea9 and 1c25053.

Files selected for processing (7)
  • docs/misc/changelog.rst (2 hunks)
  • stable_baselines3/common/buffers.py (1 hunks)
  • stable_baselines3/common/envs/bit_flipping_env.py (1 hunks)
  • stable_baselines3/common/utils.py (5 hunks)
  • stable_baselines3/common/vec_env/vec_normalize.py (1 hunks)
  • tests/test_spaces.py (2 hunks)
  • tests/test_utils.py (1 hunks)
Files skipped from review due to trivial changes (1)
  • stable_baselines3/common/vec_env/vec_normalize.py
Additional comments not posted (8)
tests/test_spaces.py (1)

7-7: Importing torch as 'th' is necessary

The import statement is required for accessing th.backends.mps in the subsequent condition.

stable_baselines3/common/utils.py (3)

32-33: Good addition of parameter descriptions

Including :param seed: and :param using_cuda: in the docstring improves clarity and helps users understand the function's parameters.


148-150: Enhance device selection logic to include MPS

Using get_available_accelerator() in the device selection logic allows the function to automatically select the best available hardware, including MPS devices on Apple Silicon. This improves compatibility and performance.


551-551: Include accelerator information in system diagnostics

Adding the available accelerator to the system info enhances the utility of diagnostic outputs, helping users understand which hardware is being utilized.

tests/test_utils.py (2)

448-448: Include "Cloudpickle" in system info assertions

The addition of assert "Cloudpickle" in info_str ensures that the system information includes the Cloudpickle version, which is helpful for debugging serialization issues.


445-445: Verify that "Accelerator" is included in get_system_info output on all platforms

Changing the assertion from "GPU Enabled" to "Accelerator" broadens the check to include various hardware accelerators, such as Apple's MPS on Apple Silicon. Ensure that get_system_info consistently outputs the "Accelerator" keyword across different environments, including those without accelerators, to prevent potential test failures.

You can run the following script to verify the presence of "Accelerator" in the get_system_info output on different platforms:

Verification successful

"Accelerator" is consistently included in get_system_info output across all platforms

The verification process confirms that the "Accelerator" keyword is consistently included in the get_system_info() output, regardless of the platform or available hardware. The get_available_accelerator() function ensures that an appropriate value ("mps", "cuda", or "cpu") is always returned, which is then included in the system info dictionary. This implementation guarantees that the assertion in the test file will pass on all supported platforms.

Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify that "Accelerator" is present in `get_system_info` output.

# Expected: The output should include a line containing "Accelerator"
python -c 'from stable_baselines3.common.utils import get_system_info; print(get_system_info()[1])' | grep -i 'Accelerator' || echo "Accelerator keyword not found."

Length of output: 268


Script:

#!/bin/bash
# Find the implementation of get_system_info function
ast-grep --lang python --pattern 'def get_system_info():
  $$$
'

# Search for "Accelerator" in stable_baselines3/common/utils.py
rg "Accelerator" stable_baselines3/common/utils.py

Length of output: 169


Script:

#!/bin/bash
# Find the get_system_info function with context
rg -A 20 "def get_system_info\(\):" stable_baselines3/common/utils.py

# Find the implementation of get_available_accelerator function
rg -A 10 "def get_available_accelerator\(\):" stable_baselines3/common/utils.py

Length of output: 153


Script:

#!/bin/bash
# Check if the file exists
ls -l stable_baselines3/common/utils.py

# If the file exists, display its content
cat stable_baselines3/common/utils.py

# If the file doesn't exist, search for files containing "Accelerator"
rg "Accelerator" .

Length of output: 21862

stable_baselines3/common/buffers.py (2)

138-139: Properly casting tensors to float32 for MPS backend compatibility

The added check correctly ensures that tensors are cast to float32 when using the MPS backend, which does not support float64 tensors. This change enhances compatibility with Apple Silicon devices utilizing MPS.


138-139: Verify that all tensors are appropriately cast to float32 when using MPS

Since the MPS backend does not support float64, please verify that there are no other instances in the codebase where tensors might be inadvertently created with float64 dtype when using the MPS device.

Run the following script to identify potential tensor creations without explicit dtype, which could default to float64 and cause issues with the MPS backend:

@@ -78,11 +78,11 @@ def convert_if_needed(self, state: np.ndarray) -> Union[int, np.ndarray]:
if self.discrete_obs_space:
# The internal state is the binary representation of the
# observed one
- return int(sum(state[i] * 2**i for i in range(len(state))))
+ return int(sum(int(state[i]) * 2**i for i in range(len(state))))

Simplify state conversion using NumPy vectorization

Currently, the state conversion uses a generator expression with explicit loops and integer casting:

return int(sum(int(state[i]) * 2**i for i in range(len(state))))

This can be simplified and optimized by leveraging NumPy's vectorized operations, which are more efficient and concise. Consider rewriting the code using np.dot:

Apply this diff to simplify the code:

- return int(sum(int(state[i]) * 2**i for i in range(len(state))))
+ return int(state.dot(2 ** np.arange(len(state))))

This approach eliminates the explicit loop and casting, improving performance and readability.
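To confirm that the suggested vectorized form matches the original loop, a quick standalone comparison can be run (the bit pattern here is made up for illustration):

```python
import numpy as np

# Example binary state, little-endian bit order, stored as int8 as in
# the bit-flipping env.
state = np.array([1, 0, 1, 1], dtype=np.int8)

# Loop version with explicit int() casts (avoids int8 overflow).
loop_val = int(sum(int(state[i]) * 2**i for i in range(len(state))))

# Vectorized version: dot product against powers of two. np.arange
# yields int64, so the dot product is computed without int8 overflow.
vec_val = int(state.dot(2 ** np.arange(len(state))))

assert loop_val == vec_val == 13  # 1 + 4 + 8
```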


  if self.image_obs_space:
      size = np.prod(self.image_shape)
- image = np.concatenate((state * 255, np.zeros(size - len(state), dtype=np.uint8)))
+ image = np.concatenate((state.astype(np.uint8) * 255, np.zeros(size - len(state), dtype=np.uint8)))

Optimize image creation by preallocating the array

Instead of concatenating arrays to create the image, you can preallocate the array and assign values directly. This avoids unnecessary memory allocation and improves performance:

Apply this diff to optimize the code:

- image = np.concatenate((state.astype(np.uint8) * 255, np.zeros(size - len(state), dtype=np.uint8)))
+ image = np.zeros(size, dtype=np.uint8)
+ image[:len(state)] = state.astype(np.uint8) * 255

This refactored code preallocates a zero-filled array of the required size and assigns the scaled state values directly to the beginning of the array.

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
image = np.concatenate((state.astype(np.uint8) * 255, np.zeros(size - len(state), dtype=np.uint8)))
image = np.zeros(size, dtype=np.uint8)
image[:len(state)] = state.astype(np.uint8) * 255
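A standalone check (with illustrative sizes) shows the preallocated version produces the same buffer as the concatenation it replaces:

```python
import numpy as np

state = np.array([1, 0, 1], dtype=np.int8)  # example binary state
size = 8  # example flattened image size

# Concatenation version: build two arrays, then join them.
concat = np.concatenate(
    (state.astype(np.uint8) * 255, np.zeros(size - len(state), dtype=np.uint8))
)

# Preallocation version: one zero-filled buffer, values written in place.
prealloc = np.zeros(size, dtype=np.uint8)
prealloc[: len(state)] = state.astype(np.uint8) * 255

assert np.array_equal(concat, prealloc)
```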

Comment on lines +141 to +145
For now, it supports only CPU and CUDA.
By default, it tries to use the GPU.

- :param device: One for 'auto', 'cuda', 'cpu'
+ :param device: One of "auto", "cuda", "cpu",
+     or any PyTorch supported device (for instance "mps")

Update docstring to reflect MPS support

The docstring states that the function supports only CPU and CUDA devices, but with the addition of MPS support, this is no longer accurate. Please update the docstring to include MPS and reflect the current functionality.

Comment on lines +486 to +487
if hasattr(th, "backends") and th.backends.mps.is_built():
return {key: th.as_tensor(_obs, dtype=th.float32, device=device) for (key, _obs) in obs.items()}

Simplify MPS handling in obs_as_tensor

The current implementation introduces code duplication when handling observations for MPS devices. Consider refactoring to streamline the code and ensure compatibility with different PyTorch versions.

Apply this diff to simplify the code:

    elif isinstance(obs, dict):
-       if hasattr(th, "backends") and th.backends.mps.is_built():
-           return {key: th.as_tensor(_obs, dtype=th.float32, device=device) for (key, _obs) in obs.items()}
-       return {key: th.as_tensor(_obs, device=device) for (key, _obs) in obs.items()}
+       tensor_kwargs = {'device': device}
+       if device.type == 'mps':
+           tensor_kwargs['dtype'] = th.float32
+       return {key: th.as_tensor(_obs, **tensor_kwargs) for (key, _obs) in obs.items()}

This refactor reduces code duplication and uses the device type to handle dtype settings appropriately.

Committable suggestion

Suggested change
if hasattr(th, "backends") and th.backends.mps.is_built():
return {key: th.as_tensor(_obs, dtype=th.float32, device=device) for (key, _obs) in obs.items()}
elif isinstance(obs, dict):
tensor_kwargs = {'device': device}
if device.type == 'mps':
tensor_kwargs['dtype'] = th.float32
return {key: th.as_tensor(_obs, **tensor_kwargs) for (key, _obs) in obs.items()}
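The kwargs-selection pattern in this refactor can be exercised without torch by substituting np.asarray for th.as_tensor; the function name and the string device type are illustrative stand-ins:

```python
import numpy as np

def obs_to_arrays(obs: dict, device_type: str) -> dict:
    """Sketch of the suggested refactor's shape.

    The dtype kwargs are chosen once, based on the device type, and then
    shared by every dict entry, instead of duplicating the whole
    comprehension for the MPS branch.
    """
    kwargs = {}
    if device_type == "mps":
        # MPS cannot hold float64, so force float32 on that device.
        kwargs["dtype"] = np.float32
    return {key: np.asarray(value, **kwargs) for key, value in obs.items()}
```

The same structure applies with `th.as_tensor(..., **tensor_kwargs)` once a real `th.device` is in hand, since `device.type` exposes the `"mps"` / `"cuda"` / `"cpu"` string.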

@coderabbitai bot left a comment

Actionable comments posted: 1

Outside diff range and nitpick comments (1)
.github/workflows/ci.yml (1)

8-10: Consider updating branch names to 'main' if applicable

The workflow is configured to trigger on pushes and pull requests to the master branch:

branches: [master]

If your repository has migrated to using main as the default branch, consider updating the branch name to main to reflect this change.

Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits

Files that changed from the base of the PR and between 1c25053 and 9970f51.

Files selected for processing (2)
  • .github/workflows/ci.yml (2 hunks)
  • stable_baselines3/common/vec_env/vec_normalize.py (2 hunks)
Files skipped from review as they are similar to previous changes (1)
  • stable_baselines3/common/vec_env/vec_normalize.py
Additional context used
actionlint
.github/workflows/ci.yml

32-32: shellcheck reported issue in this script: SC2102:info:11:14: Ranges can only match single chars (mentioned due to duplicates)

(shellcheck)

Additional comments not posted (1)
.github/workflows/ci.yml (1)

11-11: Good addition of workflow_dispatch trigger

Adding workflow_dispatch allows manual triggering of the workflow, which is helpful for on-demand builds and tests.

- name: Test with pytest
  run: |
    make pytest
    pip install .[extra_no_roms,tests,docs]

Quote the pip install argument to prevent shell errors

The command:

pip install .[extra_no_roms,tests,docs]

may be misinterpreted by the shell, since square brackets are pattern characters for filename expansion (globbing). To prevent potential shell errors, it's recommended to quote the argument.

Apply this diff to fix the issue:

-pip install .[extra_no_roms,tests,docs]
+pip install '.[extra_no_roms,tests,docs]'

araffin and others added 6 commits September 18, 2024 14:28
* Update documentation

Added comment to PPO documentation that CPU should primarily be used unless using CNN as well as sample code. Added warning to user for both PPO and A2C that CPU should be used if the user is running GPU without using a CNN, reference Issue DLR-RM#1245.

* Add warning to base class and add test

---------

Co-authored-by: Antonin RAFFIN <[email protected]>
…2026)

* Use uv on GitHub CI for faster download and update changelog

* Fix new mypy issues
* Update readme and clarify planned features

* Fix rtd python version

* Fix pip version for rtd

* Update rtd ubuntu and mambaforge

* Add upper bound for gymnasium

* [ci skip] Update readme
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Outside diff range and nitpick comments (5)
.github/workflows/ci.yml (1)

34-38: Consider version pinning for UV and improving documentation.

While using UV for faster package installation is beneficial, consider:

  1. Pin the UV version to ensure consistent behavior across builds
  2. Expand the comment to explain why CPU version is used (e.g., for CI environment compatibility)

Apply this diff to implement the suggestions:

 # Use uv for faster downloads
-pip install uv
+pip install uv==0.1.25  # Pin to specific version for consistency
-# cpu version of pytorch
+# Use CPU version of PyTorch for CI compatibility (GPU not available in runners)
 # See https://github.com/astral-sh/uv/issues/1497
 uv pip install --system torch==2.3.1+cpu --index https://download.pytorch.org/whl/cpu
stable_baselines3/common/vec_env/vec_normalize.py (1)

128-140: Consider creating a shared utility for dtype casting.

Since MPS compatibility requires float32 casting in multiple places, consider moving this functionality to a utility module (e.g., stable_baselines3.common.utils) where it can be reused for observations, actions, and other numeric data.

This would:

  1. Centralize MPS compatibility logic
  2. Make it easier to maintain consistent dtype handling
  3. Reduce code duplication if similar casting is needed elsewhere
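A minimal sketch of what such a shared utility might look like (the function name and its placement in `stable_baselines3.common.utils` are assumptions, not part of this PR):

```python
import numpy as np

def maybe_cast_for_mps(array: np.ndarray, device_type: str) -> np.ndarray:
    """Downcast float64 arrays to float32 when targeting a device
    (such as MPS) that lacks float64 support; other dtypes pass through."""
    if device_type == "mps" and array.dtype == np.float64:
        return array.astype(np.float32)
    return array

rewards = np.array([1.0, -0.5])  # float64 by default
casted = maybe_cast_for_mps(rewards, "mps")
```

Buffers, `VecNormalize`, and the observation-conversion helpers could then all route through this single function instead of duplicating the check.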
docs/misc/changelog.rst (3)

656-656: Organize changelog entries in proper sections

The changes for adding MacOS Metal support and cloudpickle version saving are scattered in the file. These should be organized under the appropriate sections (New Features) in version 2.4.0a9.

Apply this organization:

Release 2.4.0a9 (WIP)
--------------------------

New Features:
^^^^^^^^^^^^^
+ - Use MacOS Metal "mps" device when available
+ - Save cloudpickle version

Also applies to: 714-714


Line range hint 1-50: Improve changelog format consistency

The changelog entries would benefit from:

  1. Adding version numbers to bug fixes for better tracking
  2. Using consistent punctuation (some entries end with periods, others don't)
  3. Using consistent capitalization at the start of each entry

Example format:

Bug Fixes:
^^^^^^^^^^
- Fixed memory leak when loading learner from storage (v2.4.0)
- Fixed error when loading a model with net_arch manually set to None (v2.4.0)

Line range hint 1-800: Improve changelog section organization

The "Others" section often contains entries that could be better categorized. Consider:

  1. Creating a "Development" section for CI/testing changes
  2. Moving performance improvements to "New Features"
  3. Moving type hint fixes to a dedicated "Type Hints" section

This would make the changelog more organized and easier to navigate.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 9970f51 and 5e7372d.

📒 Files selected for processing (5)
  • .github/workflows/ci.yml (1 hunks)
  • docs/misc/changelog.rst (2 hunks)
  • stable_baselines3/common/buffers.py (1 hunks)
  • stable_baselines3/common/utils.py (5 hunks)
  • stable_baselines3/common/vec_env/vec_normalize.py (2 hunks)
🔇 Additional comments (6)
.github/workflows/ci.yml (2)

46-48: Quote package specifiers and improve documentation.

The installation command needs proper quoting, and the comment could be more descriptive.

Apply this diff:

-uv pip install --system .[extra_no_roms,tests,docs]
+uv pip install --system '.[extra_no_roms,tests,docs]'
-# Use headless version
+# Use headless version of opencv as GUI is not needed in CI environment
 uv pip install --system opencv-python-headless

41-44: ⚠️ Potential issue

Security: Verify ROM package integrity and use official sources.

The current implementation downloads ROMs from a personal Gist without verifying the package integrity, which poses security risks:

  1. No checksum verification for the downloaded content
  2. Using an unofficial source (personal Gist) instead of official channels

Consider implementing checksum verification:

 uv pip install --system autorom
-wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36926eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
+# Download from official source and verify checksum
+EXPECTED_SHA256="<expected_checksum>"
+wget <official_source_url> -O Roms.tar.gz.b64
+echo "$EXPECTED_SHA256 Roms.tar.gz.b64" | sha256sum --check
 base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
stable_baselines3/common/vec_env/vec_normalize.py (1)

271-272: LGTM! Verify float32 compatibility across the codebase.

The addition of _maybe_cast_reward call is correctly placed after normalization operations. This ensures MPS compatibility while maintaining existing behavior for other devices.

Let's verify that this change doesn't impact other parts of the codebase that might expect float64:

✅ Verification successful

Let's examine the specific usage in these files to make a proper assessment.


LGTM! The float32 casting is safely implemented

The implementation is correct and safe because:

  • The _maybe_cast_reward function only converts float64 to float32, leaving other dtypes unchanged
  • The test suite explicitly verifies that reward dtypes are preserved through assert orig_rewards.dtype == rewards.dtype
  • The casting is done after normalization operations, ensuring numerical stability
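As a hedged sketch of that flow (function and parameter names assumed; a simplification of `VecNormalize.normalize_reward` followed by the cast):

```python
import numpy as np

def normalize_reward_sketch(reward: np.ndarray, return_var: float,
                            epsilon: float = 1e-8, clip_reward: float = 10.0) -> np.ndarray:
    # Normalize by the running return std, clip to the reward range,
    # then apply the float64 -> float32 cast as the final step.
    normalized = np.clip(reward / np.sqrt(return_var + epsilon), -clip_reward, clip_reward)
    if normalized.dtype == np.float64:
        normalized = normalized.astype(np.float32)
    return normalized

out = normalize_reward_sketch(np.array([5.0, -30.0]), return_var=4.0)
```

Doing the cast last means clipping and division happen at full precision before the dtype is narrowed.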
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for explicit float64 dtype usage with rewards
rg -l "float64.*reward|reward.*float64"

# Search for reward type assertions or checks
rg -l "assert.*reward.*dtype|reward.*dtype.*assert"

Length of output: 173


Script:

#!/bin/bash
# Check float64 usage in vec_normalize.py
rg "float64.*reward|reward.*float64" stable_baselines3/common/vec_env/vec_normalize.py -B 2 -A 2

# Check reward dtype assertions in tests
rg "assert.*reward.*dtype|reward.*dtype.*assert" tests/test_vec_normalize.py -B 2 -A 2

Length of output: 957

stable_baselines3/common/utils.py (3)

143-147: Docstring still needs updating


488-489: Implementation can be simplified


553-553: LGTM!

The change correctly uses the new get_available_accelerator function to report the available hardware accelerator.

Comment on lines 128 to 140
    @staticmethod
    def _maybe_cast_reward(reward: np.ndarray) -> np.ndarray:
        """
        Cast `np.float64` reward datatype to `np.float32`,
        keep the others dtype unchanged.

        :param dtype: The original action space dtype
        :return: ``np.float32`` if the dtype was float64,
            the original dtype otherwise.
        """
        if reward.dtype == np.float64:
            return reward.astype(np.float32)
        return reward

⚠️ Potential issue

Fix docstring parameter and add input validation.

The implementation looks good but has a few minor issues:

  1. The docstring refers to a dtype parameter that doesn't exist
  2. Missing input type validation

Apply these changes:

 @staticmethod
 def _maybe_cast_reward(reward: np.ndarray) -> np.ndarray:
     """
     Cast `np.float64` reward datatype to `np.float32`,
     keep the others dtype unchanged.

-    :param dtype: The original action space dtype
+    :param reward: The reward array to potentially cast
     :return: ``np.float32`` if the dtype was float64,
         the original dtype otherwise.
     """
+    if not isinstance(reward, np.ndarray):
+        raise TypeError(f"Expected numpy array, got {type(reward)}")
     if reward.dtype == np.float64:
         return reward.astype(np.float32)
     return reward
📝 Committable suggestion


Suggested change
-    @staticmethod
-    def _maybe_cast_reward(reward: np.ndarray) -> np.ndarray:
-        """
-        Cast `np.float64` reward datatype to `np.float32`,
-        keep the others dtype unchanged.
-
-        :param dtype: The original action space dtype
-        :return: ``np.float32`` if the dtype was float64,
-            the original dtype otherwise.
-        """
-        if reward.dtype == np.float64:
-            return reward.astype(np.float32)
-        return reward
+    @staticmethod
+    def _maybe_cast_reward(reward: np.ndarray) -> np.ndarray:
+        """
+        Cast `np.float64` reward datatype to `np.float32`,
+        keep the others dtype unchanged.
+
+        :param reward: The reward array to potentially cast
+        :return: ``np.float32`` if the dtype was float64,
+            the original dtype otherwise.
+        """
+        if not isinstance(reward, np.ndarray):
+            raise TypeError(f"Expected numpy array, got {type(reward)}")
+        if reward.dtype == np.float64:
+            return reward.astype(np.float32)
+        return reward

Comment on lines +138 to +139
        if hasattr(th, "backends") and th.backends.mps.is_built():
            return th.tensor(array, dtype=th.float32, device=self.device)

🛠️ Refactor suggestion

Modify the MPS check in the to_torch method for better compatibility

Consider changing the conditional check to determine if the device is MPS by checking self.device.type == "mps" instead of hasattr(th, "backends") and th.backends.mps.is_built(). This ensures more robust detection of the MPS backend and better compatibility across different PyTorch versions.

Apply this diff to modify the condition:

 def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:
     if copy:
-        if hasattr(th, "backends") and th.backends.mps.is_built():
+        if self.device.type == "mps":
             return th.tensor(array, dtype=th.float32, device=self.device)
         return th.tensor(array, device=self.device)
     return th.as_tensor(array, device=self.device)
📝 Committable suggestion


Suggested change
-        if hasattr(th, "backends") and th.backends.mps.is_built():
-            return th.tensor(array, dtype=th.float32, device=self.device)
+        if self.device.type == "mps":
+            return th.tensor(array, dtype=th.float32, device=self.device)

pseudo-rnd-thoughts and others added 4 commits November 4, 2024 12:03
* Update Gymnasium to v1.0.0a1

* Comment out `gymnasium.wrappers.monitor` (todo update to VideoRecord)

* Fix ruff warnings

* Register Atari envs

* Update `getattr` to `Env.get_wrapper_attr`

* Reorder imports

* Fix `seed` order

* Fix collecting `max_steps`

* Copy and paste video recorder to prevent the need to rewrite the vec vide recorder wrapper

* Use `typing.List` rather than list

* Fix env attribute forwarding

* Separate out env attribute collection from its utilisation

* Update for Gymnasium alpha 2

* Remove assert for OrderedDict

* Update setup.py

* Add type: ignore

* Test with Gymnasium main

* Remove `gymnasium.logger.debug/info`

* Fix github CI yaml

* Run gym 0.29.1 on python 3.10

* Update lower bounds

* Integrate video recorder

* Remove ordered dict

* Update changelog

---------

Co-authored-by: Antonin RAFFIN <[email protected]>
* Allow new sphinx version

* Add note about SAC ent coeff and add DQN tutorial link
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 5

🧹 Outside diff range and nitpick comments (5)
docs/conda_env.yml (1)

7-15: Consider adding MPS-specific testing dependencies.

Given that this PR focuses on MPS support, consider adding testing dependencies specific to MPS validation:

  1. PyTorch should be installed with MPS support
  2. Testing tools for GPU compatibility might be needed

Consider adding a separate test environment configuration with:

name: mps-test
channels:
  - pytorch
  - conda-forge
dependencies:
  - pip=24.2
  - python=3.11
  - pytorch=2.5.0  # Ensure MPS support is included
  - pip:
    - pytest-gpu  # For GPU-specific testing
    - gymnasium>=0.28.1,<0.30
    - numpy>=1.20,<2.0
CONTRIBUTING.md (1)

9-9: Use proper markdown link syntax for the issues URL.

The URL change to point to "help wanted" issues is a good improvement for new contributors. However, let's enhance readability by using proper markdown link syntax.

Apply this change:

-    - Look at the outstanding issues here: https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted
+    - Look at the outstanding issues [here](https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted)
🧰 Tools
🪛 Markdownlint

9-9: null
Bare URL used

(MD034, no-bare-urls)

docs/modules/ppo.rst (1)

95-104: Code example looks good but could use environment safety checks.

The code example effectively demonstrates the recommended setup. However, consider adding environment safety checks for a more robust example.

 if __name__=="__main__":
+    try:
         env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)
         model = PPO("MlpPolicy", env, device="cpu")
         model.learn(total_timesteps=25_000)
+    finally:
+        env.close()
tests/test_run.py (1)

215-219: Consider enhancing the docstring with MPS-specific context.

While the docstring is accurate, it could be more specific about the GPU-related warnings, especially in the context of MPS support.

Consider updating it to:

    """
    Test that PPO warns and errors correctly on
    problematic rollout buffer sizes,
-   and recommend using CPU.
+   and recommends using CPU when inappropriate GPU usage is detected
+   (especially relevant for MPS/Apple Silicon support).
    """
README.md (1)

202-205: Consider adding MPS support information

Given that this PR focuses on MPS support and its limitations (especially with float64), consider adding a note about device compatibility in the documentation.

Example addition:

+ Note: When using Apple Silicon (M1/M2) GPUs with the MPS backend, be aware that some operations, particularly those involving float64 data types, are not supported. Please ensure your inputs are in float32 format.
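A short illustration of the proposed note (plain PyTorch, not SB3-specific API; the explicit `dtype` argument is the workaround):

```python
import numpy as np
import torch as th

obs = np.random.rand(4)  # numpy produces float64 by default
# On an MPS device, th.as_tensor(obs, device="mps") raises:
#   TypeError: Cannot convert a MPS Tensor to float64 dtype ...
tensor = th.as_tensor(obs, dtype=th.float32)  # explicit float32 works on every backend
```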
🧰 Tools
🪛 LanguageTool

[uncategorized] ~204-~204: Loose punctuation mark.
Context: ...point in the action space. * Discrete: A list of possible actions, where each ...

(UNLIKELY_OPENING_PUNCTUATION)


[uncategorized] ~205-~205: Loose punctuation mark.
Context: ... actions can be used. * MultiDiscrete: A list of possible actions, where each ...

(UNLIKELY_OPENING_PUNCTUATION)

🪛 Markdownlint

203-203: Expected: dash; Actual: asterisk
Unordered list style

(MD004, ul-style)


204-204: Expected: dash; Actual: asterisk
Unordered list style

(MD004, ul-style)


205-205: Expected: dash; Actual: asterisk
Unordered list style

(MD004, ul-style)


203-203: Expected: 0; Actual: 1
Unordered list indentation

(MD007, ul-indent)


204-204: Expected: 0; Actual: 1
Unordered list indentation

(MD007, ul-indent)


205-205: Expected: 0; Actual: 1
Unordered list indentation

(MD007, ul-indent)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 5e7372d and 4c03a25.

📒 Files selected for processing (14)
  • .readthedocs.yml (1 hunks)
  • CONTRIBUTING.md (1 hunks)
  • README.md (9 hunks)
  • docs/conda_env.yml (1 hunks)
  • docs/guide/algos.rst (1 hunks)
  • docs/guide/sb3_contrib.rst (1 hunks)
  • docs/index.rst (1 hunks)
  • docs/misc/changelog.rst (5 hunks)
  • docs/modules/ppo.rst (1 hunks)
  • stable_baselines3/common/on_policy_algorithm.py (2 hunks)
  • stable_baselines3/common/utils.py (7 hunks)
  • stable_baselines3/version.txt (1 hunks)
  • tests/test_run.py (3 hunks)
  • tests/test_utils.py (2 hunks)
✅ Files skipped from review due to trivial changes (1)
  • stable_baselines3/version.txt
🚧 Files skipped from review as they are similar to previous changes (1)
  • docs/misc/changelog.rst
🧰 Additional context used
🪛 Markdownlint
CONTRIBUTING.md

9-9: null
Bare URL used

(MD034, no-bare-urls)

🔇 Additional comments (25)
docs/conda_env.yml (3)

4-4: LGTM: Channel update aligns with best practices.

Switching to conda-forge as the primary channel is a good practice as it often provides more up-to-date packages and better maintenance.


11-11: Version constraints need documentation update.

The updated version constraints for gymnasium and numpy are reasonable, but:

  1. The gymnasium constraint >=0.28.1,<0.30 suggests API stability concerns
  2. The numpy constraint >=1.20,<2.0 is quite broad

Let's verify if these constraints are documented:

#!/bin/bash
# Check if version constraints are documented
rg -g "*.rst|*.md" "gymnasium|numpy" -A 2 "version|requirement|dependency"

Consider adding a comment in the yml file explaining the version constraints, particularly for gymnasium which has a narrow range.

Also applies to: 15-15


7-9: Verify compatibility with existing codebase.

The significant version upgrades warrant careful consideration:

  1. Python 3.11 upgrade from 3.8 is a major jump
  2. PyTorch 2.5.0 upgrade from 1.13.0 introduces substantial changes

While these updates are necessary for MPS support, they might introduce breaking changes.

Let's verify the minimum version requirements in the codebase:

✅ Verification successful

The version upgrades are compatible with the codebase requirements

Based on the setup.py findings:

  • Python ≥3.8 is required, and Python 3.11 is explicitly supported (listed in classifiers)
  • PyTorch ≥1.13 is required, making PyTorch 2.5.0 compatible
  • The codebase has no strict upper version bounds that would conflict with these upgrades
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check for version-specific code or requirements
rg -g '!docs/*' -g '!*.yml' "python|torch" -A 2 "version|requires|dependency"

Length of output: 241


Script:

#!/bin/bash
# Search for Python version checks in the codebase
rg "python_version|sys\.version|python_requires" -g "!docs/*" -g "!*.yml"

# Search for torch/pytorch version requirements
rg "torch\.(version|__version__|cuda)" -g "!docs/*" -g "!*.yml"

# Check setup.py or pyproject.toml for dependencies
fd "setup.py|pyproject.toml" --exec cat {}

# Look for requirements files
fd "requirements" --exec cat {}

Length of output: 7800

.readthedocs.yml (1)

21-21: LGTM! Mambaforge version update aligns with dependency changes.

The upgrade to mambaforge-23.11 aligns well with the Python 3.11 upgrade in conda_env.yml and supports the newer PyTorch 2.5.0 required for improved MPS compatibility.

#!/bin/bash
# Verify conda environment configuration alignment
echo "Checking conda environment configuration..."
cat docs/conda_env.yml
docs/index.rst (2)

116-117: LGTM! Valuable addition for academic citations.

The Zenodo DOI note is a helpful addition that enables precise version referencing in academic work.


123-123: LGTM! Improved terminology.

The change from "repo" to "repository" enhances the documentation's professionalism.

docs/guide/algos.rst (1)

13-13: Verify relevance of CrossQ documentation to MPS support.

This documentation change appears unrelated to the PR's stated objectives of fixing MPS support and GPU compatibility for Apple Silicon. Could you clarify if this change was intentionally included in this PR?

✅ Verification successful

Let me gather more specific information about CrossQ and its relationship with MPS support.


Let me try a different approach to check the changes in this PR.


CrossQ documentation is part of a larger set of changes and is correctly included

The documentation change for CrossQ is intentionally included as part of this PR because:

  1. The PR appears to be a merge of multiple features, including both MPS support and the addition of the new CrossQ algorithm
  2. The changelog explicitly mentions "New algorithm: CrossQ in SB3 Contrib" in the same release
  3. CrossQ is being added as a new algorithm in the SB3-Contrib repository with proper documentation updates across multiple files
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check if CrossQ is related to MPS support changes
# Look for CrossQ-related changes in test files that were modified for MPS support
rg -l "CrossQ" $(git diff --name-only | grep "test")

# Look for MPS-related changes in CrossQ implementation
rg "mps" $(git ls-files | grep -i "crossq")

Length of output: 1655


Script:

#!/bin/bash
# Check git diff for the specific changes made to CrossQ in the documentation
git diff docs/guide/algos.rst | grep -A 5 -B 5 "CrossQ"

# Look for CrossQ implementation files and their changes
fd -e py -e pyx -e pyd crossq -i
git diff $(fd -e py -e pyx -e pyd crossq -i)

# Check if there are any MPS-related changes in the PR
git diff | grep -A 5 -B 5 "mps"

Length of output: 169


Script:

#!/bin/bash
# Get the full context of changes in the algos.rst file
git diff docs/guide/algos.rst

# Look for any CrossQ-related files or implementations in the current codebase
rg -l "class CrossQ" 

# Check all changes in this PR that might be related to MPS support
git log -p HEAD~1..HEAD | grep -A 5 -B 5 "mps\|CrossQ"

Length of output: 4842

docs/modules/ppo.rst (2)

91-94: LGTM! Well-placed performance guidance.

The note provides valuable guidance about PPO's performance characteristics and hardware utilization. This is particularly relevant given the PR's focus on MPS support and helps users make informed decisions about their setup.


106-106: LGTM! Comprehensive references provided.

The added references to the vectorized environments documentation, relevant GitHub issue, and Colab notebook provide valuable resources for users to learn more about multiprocessing optimization.

tests/test_run.py (2)

4-4: LGTM: Import addition is appropriate.

The torch import is necessary for the new GPU-related test case and follows consistent aliasing conventions.


242-246: 🛠️ Refactor suggestion

Expand GPU warning test coverage for MPS devices.

The current test only verifies the warning for CUDA devices, but given that this PR specifically addresses MPS support, we should include MPS-specific test cases.

Consider adding an MPS-specific test case:

    with pytest.warns(UserWarning, match="You are trying to run PPO on the GPU"):
        model = PPO("MlpPolicy", "Pendulum-v1")
        # Test CUDA device
        model.device = th.device("cuda")
        model._maybe_recommend_cpu()
+
+    # Test MPS device specifically
+    if hasattr(th, "backends") and th.backends.mps.is_available():
+        with pytest.warns(UserWarning, match="You are trying to run PPO on the GPU"):
+            model = PPO("MlpPolicy", "Pendulum-v1")
+            model.device = th.device("mps")
+            model._maybe_recommend_cpu()

Let's verify the current MPS-related test coverage:

stable_baselines3/common/on_policy_algorithm.py (2)

3-3: LGTM: Clean integration of CPU recommendation feature

The warning import and method call placement are appropriate, occurring after policy initialization when we can properly check the policy type.

Also applies to: 139-140


142-160: Verify tensor dtype handling for MPS compatibility

Given that MPS doesn't support float64, we should verify how tensor dtypes are handled in this class.

Let's check for float64 tensor usage in the codebase:

README.md (7)

2-3: LGTM: CI badge updates are accurate

The CI badge links have been correctly updated to point to GitHub Actions workflows.


25-26: LGTM: Valuable addition about performance tracking

The addition of OpenRL Benchmark information helps users find detailed performance logs and reports.


46-52: LGTM: Clear project status and future direction

The updates effectively communicate:

  • Project's stable status
  • Focus on maintenance and bug fixes
  • Active development in associated repositories

90-90: LGTM: Updated experimental features list

The addition of CrossQ to the experimental features list maintains consistency with the algorithm table.


108-108: LGTM: Simplified Windows heading

The simplified heading is more inclusive while maintaining access to detailed installation instructions.


115-116: LGTM: Clear installation command

The installation command is well-formatted and clearly shows the optional extras.


187-187: LGTM: Algorithm table update

CrossQ has been correctly added to the algorithm table with appropriate support indicators.

stable_baselines3/common/utils.py (3)

65-65: LGTM! Explicit float casting improves type safety.

The explicit float casting aligns with the function's return type hint and prevents potential type-related issues.


553-553: LGTM! Consistent with new accelerator detection.

The update correctly uses the new get_available_accelerator function to report the system's accelerator type.


Line range hint 1-553: Verify MPS support implementation across codebase.

Let's verify the consistency of MPS support implementation:

  1. Check for other MPS-related code that might need similar updates
  2. Verify float32 dtype enforcement for MPS
✅ Verification successful

Based on the shell script results, I can now generate the final response:

MPS support implementation is correctly handled with proper float32 enforcement

The codebase shows consistent handling of MPS support across three key files:

  • stable_baselines3/common/utils.py: Properly enforces float32 for MPS tensors in dict observations
  • stable_baselines3/common/buffers.py: Already handles dtype casting from float64 to float32
  • tests/test_spaces.py: Correctly skips float64 tests for MPS devices

The implementation properly handles:

  1. MPS device detection and float32 enforcement for tensor conversion
  2. Automatic float64 to float32 casting in buffers
  3. Test suite compatibility with MPS limitations
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check if MPS handling is consistent across the codebase
# Test: Look for other MPS-related checks that might need updating

echo "=== Checking for MPS-related code ==="
rg "mps.is_built|mps.is_available" -l

echo -e "\n=== Checking for tensor dtype handling ==="
rg "as_tensor.*dtype.*float|float32.*mps|mps.*float32" -A 2 -B 2

echo -e "\n=== Checking for potential float64 usage with MPS ==="
rg "float64|torch.double" -A 2 -B 2

Length of output: 8554

tests/test_utils.py (2)

445-448: Improved device detection for better MPS support.

The change from "GPU Enabled" to "Accelerator" is a good improvement as it:

  1. Makes the test more generic to support different acceleration backends (including MPS)
  2. Better aligns with PyTorch's device abstraction

The addition of "Cloudpickle" version check is also valuable for tracking dependencies.


180-180: Verify if type hint removal is necessary for MPS support.

The removal of type hints from the direct_policy parameter appears to be related to MPS device compatibility. While the change is acceptable, we should verify if this is the minimum required change.

@@ -16,6 +16,6 @@ conda:
   environment: docs/conda_env.yml
 
 build:
-  os: ubuntu-22.04
+  os: ubuntu-24.04

💡 Codebase verification

Ubuntu 24.04 is not yet officially supported by Read the Docs

Ubuntu 24.04 (Noble Numbat) was released in April 2024, and Read the Docs typically takes some time to officially support new Ubuntu LTS versions. I recommend keeping ubuntu-22.04 for now to ensure stable documentation builds. You can upgrade to 24.04 once it's officially supported by Read the Docs.

  • Revert the OS version in .readthedocs.yml back to ubuntu-22.04
  • Keep the Python environment update to mambaforge-23.11 as it's independent of the OS version
🔗 Analysis chain

Verify Ubuntu 24.04 support on Read the Docs.

While upgrading to Ubuntu 24.04 LTS is forward-thinking, we should verify that Read the Docs officially supports this version to ensure stable documentation builds.


Let me gather more information about Read the Docs' Ubuntu support.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check if ubuntu-24.04 is mentioned in the RTD documentation or changelog
rg -i "ubuntu.?24\.04|jammy" README.md docs/

Length of output: 474


Script:

#!/bin/bash
# Check .readthedocs.yml history to understand the OS version transition
git log -p -- .readthedocs.yml | grep -A 5 -B 5 "os:"

# Fetch current .readthedocs.yml content
cat .readthedocs.yml

Length of output: 944

@@ -10,6 +10,7 @@ Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary``
 =================== =========== ============ ================= =============== ================
 ARS [#f1]_          ✔️          ✔️           ❌                ❌              ✔️
 A2C                 ✔️          ✔️           ✔️                ✔️              ✔️
+CrossQ [#f1]_       ✔️          ❌           ❌                ❌              ✔️

⚠️ Potential issue

Inconsistency in CrossQ implementation location.

There appears to be an inconsistency in the documentation. The table indicates CrossQ is implemented in SB3 Contrib (via footnote [#f1]), but the note at the bottom of the file mentions it's in the SBX repo. Please clarify which repository actually contains the CrossQ implementation.

Comment on lines +142 to +160
    def _maybe_recommend_cpu(self, mlp_class_name: str = "ActorCriticPolicy") -> None:
        """
        Recommend to use CPU only when using A2C/PPO with MlpPolicy.

        :param: The name of the class for the default MlpPolicy.
        """
        policy_class_name = self.policy_class.__name__
        if self.device != th.device("cpu") and policy_class_name == mlp_class_name:
            warnings.warn(
                f"You are trying to run {self.__class__.__name__} on the GPU, "
                "but it is primarily intended to run on the CPU when not using a CNN policy "
                f"(you are using {policy_class_name} which should be a MlpPolicy). "
                "See https://github.com/DLR-RM/stable-baselines3/issues/1245 "
                "for more info. "
                "You can pass `device='cpu'` or `export CUDA_VISIBLE_DEVICES=` to force using the CPU."
                "Note: The model will train, but the GPU utilization will be poor and "
                "the training might take longer than on CPU.",
                UserWarning,
            )

🛠️ Refactor suggestion

Adjust default policy class name and enhance MPS-specific guidance

The implementation is good but needs two adjustments:

  1. The default mlp_class_name parameter value doesn't match the intended behavior:

    • The docstring mentions MlpPolicy, but the default is "ActorCriticPolicy"
    • This could trigger warnings for non-MLP policies
  2. Given this PR's focus on MPS support, the warning message should include MPS-specific guidance.

Here's the suggested implementation:

-    def _maybe_recommend_cpu(self, mlp_class_name: str = "ActorCriticPolicy") -> None:
+    def _maybe_recommend_cpu(self, mlp_class_name: str = "MlpPolicy") -> None:
         """
         Recommend to use CPU only when using A2C/PPO with MlpPolicy.
 
         :param: The name of the class for the default MlpPolicy.
         """
         policy_class_name = self.policy_class.__name__
         if self.device != th.device("cpu") and policy_class_name == mlp_class_name:
             warnings.warn(
                 f"You are trying to run {self.__class__.__name__} on the GPU, "
                 "but it is primarily intended to run on the CPU when not using a CNN policy "
                 f"(you are using {policy_class_name} which should be a MlpPolicy). "
+                "This is especially important for MPS (Apple Silicon GPU) users. "
                 "See https://github.com/DLR-RM/stable-baselines3/issues/1245 "
                 "for more info. "
                 "You can pass `device='cpu'` or `export CUDA_VISIBLE_DEVICES=` to force using the CPU."
                 "Note: The model will train, but the GPU utilization will be poor and "
                 "the training might take longer than on CPU.",
                 UserWarning,
             )

Comment on lines +150 to +156
# MPS/CUDA by default
if device == "auto":
device = "cuda"
device = get_available_accelerator()
# Force conversion to th.device
device = th.device(device)

# Cuda not available
# CUDA not available

🛠️ Refactor suggestion

Enhance device handling logic.

The current implementation could be improved to handle device availability more robustly:

  1. Add explicit MPS availability check
  2. Consider handling potential MPS-related errors

Apply this diff to improve the implementation:

     # MPS/CUDA by default
     if device == "auto":
         device = get_available_accelerator()
     # Force conversion to th.device
     device = th.device(device)
 
-    # CUDA not available
-    if device.type == th.device("cuda").type and not th.cuda.is_available():
+    # Check device availability
+    if device.type == "cuda" and not th.cuda.is_available():
+        return th.device("cpu")
+    elif device.type == "mps":
+        try:
+            if not (hasattr(th, "backends") and th.backends.mps.is_built() and th.backends.mps.is_available()):
+                return th.device("cpu")
+        except Exception:
+            return th.device("cpu")

Comment on lines +523 to +536
def get_available_accelerator() -> str:
"""
Return the available accelerator
(currently checking only for CUDA and MPS device)
"""
if hasattr(th, "backends") and th.backends.mps.is_built():
# MacOS Metal GPU
th.set_default_dtype(th.float32)
return "mps"
elif th.cuda.is_available():
return "cuda"
else:
return "cpu"


⚠️ Potential issue

Improve accelerator detection implementation.

The current implementation has several issues:

  1. Setting global dtype could have unintended side effects
  2. Missing proper MPS availability check
  3. Missing error handling
  4. Docstring needs improvement

Apply this diff to fix these issues:

 def get_available_accelerator() -> str:
     """
     Return the available accelerator
-    (currently checking only for CUDA and MPS device)
+    (checking for CUDA and MPS devices)
+    
+    Note: MPS (Apple Metal) requires float32 dtype and may not be available
+    on older macOS versions even if built.
+    
+    Returns:
+        str: The available accelerator type ('mps', 'cuda', or 'cpu')
     """
-    if hasattr(th, "backends") and th.backends.mps.is_built():
-        # MacOS Metal GPU
-        th.set_default_dtype(th.float32)
-        return "mps"
-    elif th.cuda.is_available():
+    try:
+        if (
+            hasattr(th, "backends")
+            and th.backends.mps.is_built()
+            and th.backends.mps.is_available()
+        ):
+            # MacOS Metal GPU
+            return "mps"
+    except Exception:  # Catch any MPS-related errors
+        pass
+    if th.cuda.is_available():
         return "cuda"
-    else:
-        return "cpu"
+    return "cpu"


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 10

🧹 Outside diff range and nitpick comments (64)
stable_baselines3/common/vec_env/vec_video_recorder.py (7)

2-2: Redundant Import of 'os.path'

The separate import of os.path is unnecessary: os.path is a submodule of os and is accessible once os is imported, so this import is redundant.

Consider removing the redundant import:

-import os.path

74-74: Type Hint Compatibility for 'recorded_frames'

The type hint list[np.ndarray] uses the built-in list type with subscripts, which is valid in Python 3.9 and above. If you need to maintain compatibility with earlier Python versions, consider importing List from the typing module.

Update the type hint for backward compatibility:

-from typing import Callable
+from typing import Callable, List
...
-self.recorded_frames: list[np.ndarray] = []
+self.recorded_frames: List[np.ndarray] = []

76-79: Clarify Installation Instructions for 'moviepy' Dependency

The error message suggests installing gymnasium[other] to satisfy the moviepy dependency, which might not be intuitive for users. Providing direct installation instructions for moviepy enhances clarity.

Update the error message:

-raise error.DependencyNotInstalled("MoviePy is not installed, run `pip install 'gymnasium[other]'`") from e
+raise error.DependencyNotInstalled("MoviePy is not installed. Please install it using `pip install moviepy`") from e

100-102: Adjust Condition for Stopping Video Recording

The condition len(self.recorded_frames) > self.video_length may result in recording more frames than desired. To stop recording when the desired number of frames is reached, use >= instead.

Update the condition:

-if len(self.recorded_frames) > self.video_length:
+if len(self.recorded_frames) >= self.video_length:

119-121: Update Deprecated Logger Method from 'warn' to 'warning'

The method logger.warn is deprecated in favor of logger.warning. Updating it ensures compatibility with newer versions of the logging module.

Change the logger call:

-logger.warn(
+logger.warning(

141-141: Update Deprecated Logger Method from 'warn' to 'warning'

As above, replace logger.warn with logger.warning for consistency and to avoid deprecation warnings.

Modify the logger call:

-logger.warn("Ignored saving a video as there were zero frames to save.")
+logger.warning("Ignored saving a video as there were zero frames to save.")

154-154: Update Deprecated Logger Method from 'warn' to 'warning'

Again, update the deprecated logging method to maintain code consistency.

Amend the logger call:

-logger.warn("Unable to save last video! Did you call close()?")
+logger.warning("Unable to save last video! Did you call close()?")
pyproject.toml (2)

Line range hint 39-43: Consider adding MPS-specific test configurations

Given the focus on MPS support, consider adding:

  1. A pytest marker for MPS-specific tests
  2. Warning filters for PyTorch MPS-related warnings

Add the following configurations:

 markers = [
     "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')",
+    "mps: marks tests that require MPS device (deselect with '-m \"not mps\"')",
 ]

 filterwarnings = [
     # Tensorboard warnings
     "ignore::DeprecationWarning:tensorboard",
     # Gymnasium warnings
     "ignore::UserWarning:gymnasium",
     # tqdm warning about rich being experimental
     "ignore:rich is experimental",
+    # PyTorch MPS warnings
+    "ignore::UserWarning:torch.backends.mps",
 ]

Line range hint 53-63: Consider enabling branch coverage for better MPS code path testing

The current coverage configuration has branch coverage disabled. Given that MPS support involves device-specific code paths, enabling branch coverage could help ensure better test coverage of device-specific logic.

Consider updating the coverage configuration:

 [tool.coverage.run]
 disable_warnings = ["couldnt-parse"]
-branch = false
+branch = true
 omit = [
     "tests/*",
     "setup.py",
     # Require graphical interface
     "stable_baselines3/common/results_plotter.py",
     # Require ffmpeg
     "stable_baselines3/common/vec_env/vec_video_recorder.py",
 ]
stable_baselines3/common/vec_env/vec_frame_stack.py (1)

43-43: Consider documenting MPS dtype requirements.

Since this is a key method where observations are initialized, consider adding a docstring note about MPS dtype requirements to help users understand potential limitations.

-    def reset(self) -> Union[np.ndarray, dict[str, np.ndarray]]:
+    def reset(self) -> Union[np.ndarray, dict[str, np.ndarray]]:
+        """
+        Reset all environments
+        
+        Note: When using MPS (Metal Performance Shaders), ensure observations
+        are float32 as MPS doesn't support float64 operations.
+        """
stable_baselines3/common/vec_env/util.py (2)

14-14: Consider narrowing the key type for better type safety.

While using Any as the key type provides flexibility, it might be too permissive. Consider using a more specific Union type like dict[Union[str, int, None], np.ndarray] to explicitly document the expected key types (strings for Dict spaces, integers for Tuple spaces, and None for unstructured spaces).
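As an illustration, the narrowed annotation could look like the following sketch (the alias name is hypothetical, not part of SB3):

```python
from typing import Union

import numpy as np

# Hypothetical alias: str keys for Dict spaces, int keys for Tuple spaces,
# and None for unstructured spaces.
ObsDictKey = Union[str, int, None]

def describe_obs(obs: dict[ObsDictKey, np.ndarray]) -> list[str]:
    # A type checker can now flag unexpected key types (e.g. float keys).
    return [f"{key}: {value.shape}" for key, value in obs.items()]
```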


51-51: Update assertion message to reflect dict usage.

The assertion message still mentions "ordered subspaces" despite moving away from OrderedDict. Consider updating it to better reflect the current implementation.

-        assert isinstance(obs_space.spaces, dict), "Dict space must have ordered subspaces"
+        assert isinstance(obs_space.spaces, dict), "Dict space must have dictionary of subspaces"
stable_baselines3/common/type_aliases.py (1)

84-88: Consider documenting MPS-specific behavior for numpy array handling.

While the type hint updates are correct, consider adding documentation about how numpy arrays are handled when using the MPS backend, particularly regarding:

  • Dtype restrictions (float32 vs float64)
  • Device transfer behavior
  • Performance implications

Example docstring addition:

 def predict(
     self,
     observation: Union[np.ndarray, dict[str, np.ndarray]],
     state: Optional[tuple[np.ndarray, ...]] = None,
     episode_start: Optional[np.ndarray] = None,
     deterministic: bool = False,
 ) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]:
-    """
+    """
+    Note: When using MPS backend, numpy arrays are automatically converted to float32
+    as MPS doesn't support float64 operations.
+    
     Get the policy action from an observation (and optional hidden state).
stable_baselines3/common/vec_env/vec_check_nan.py (1)

Line range hint 50-64: Consider enhancing NaN/inf checks for MPS compatibility.

Given this PR's focus on MPS support and considering that MPS has specific numerical precision requirements, consider enhancing this function to:

  1. Add specific checks for MPS-related numerical issues
  2. Include device-specific context in error messages

Here's a suggested enhancement:

-    def check_array_value(self, name: str, value: np.ndarray) -> list[tuple[str, str]]:
+    def check_array_value(self, name: str, value: np.ndarray, device: str = "cpu") -> list[tuple[str, str]]:
         """
         Check for inf and NaN for a single numpy array.
 
         :param name: Name of the value being check
         :param value: Value (numpy array) to check
+        :param device: The compute device (e.g., "cpu", "cuda", "mps")
         :return: A list of issues found.
         """
         found = []
+        # MPS-specific checks
+        if device == "mps" and value.dtype == np.float64:
+            found.append((name, "float64_on_mps"))
         has_nan = np.any(np.isnan(value))
         has_inf = self.check_inf and np.any(np.isinf(value))
         if has_inf:
             found.append((name, "inf"))
         if has_nan:
             found.append((name, "nan"))
         return found
stable_baselines3/common/results_plotter.py (3)

Line range hint 47-70: Consider explicit dtype casting for MPS compatibility

Since MPS doesn't support float64, consider explicitly casting numpy arrays to float32 when extracting values from the DataFrame.

    if x_axis == X_TIMESTEPS:
-        x_var = np.cumsum(data_frame.l.values)
-        y_var = data_frame.r.values
+        x_var = np.cumsum(data_frame.l.values).astype(np.float32)
+        y_var = data_frame.r.values.astype(np.float32)
    elif x_axis == X_EPISODES:
-        x_var = np.arange(len(data_frame))
-        y_var = data_frame.r.values
+        x_var = np.arange(len(data_frame), dtype=np.float32)
+        y_var = data_frame.r.values.astype(np.float32)
    elif x_axis == X_WALLTIME:
        # Convert to hours
-        x_var = data_frame.t.values / 3600.0
-        y_var = data_frame.r.values
+        x_var = (data_frame.t.values / 3600.0).astype(np.float32)
+        y_var = data_frame.r.values.astype(np.float32)

Line range hint 72-97: Handle potential float64 outputs from numpy operations

The np.mean operation might return float64 results which are incompatible with MPS. Consider explicitly casting the results to float32.

    if x.shape[0] >= EPISODES_WINDOW:
        # Compute and plot rolling mean with window of size EPISODE_WINDOW
-        x, y_mean = window_func(x, y, EPISODES_WINDOW, np.mean)
+        x, y_mean = window_func(x, y, EPISODES_WINDOW, lambda x, **kwargs: np.mean(x, **kwargs).astype(np.float32))

Line range hint 1-124: Consider comprehensive dtype management for MPS support

While the type hint updates are good, this plotting utility needs a more comprehensive approach to dtype management for proper MPS support. Consider:

  1. Adding a global configuration for dtype (float32/float64)
  2. Implementing consistent dtype handling across all numpy operations
  3. Adding tests specifically for MPS compatibility
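For item 1, the global dtype configuration could be a small module-level switch like this sketch (the function and variable names are assumptions, not existing SB3 API):

```python
import numpy as np

# Hypothetical module-level switch: float32 when targeting MPS, float64 elsewhere.
_PLOT_DTYPE = np.dtype(np.float32)

def set_plot_dtype(dtype) -> None:
    """Select the dtype used by all plotting helpers."""
    global _PLOT_DTYPE
    _PLOT_DTYPE = np.dtype(dtype)

def as_plot_array(values) -> np.ndarray:
    """Convert incoming data to the configured dtype in one place."""
    return np.asarray(values, dtype=_PLOT_DTYPE)
```

Routing every array conversion through one helper keeps the float32/float64 decision in a single spot instead of scattering `.astype(...)` calls.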
stable_baselines3/common/vec_env/__init__.py (1)

2-2: Consider adding MPS-specific environment wrapper.

Given the PR's objective of adding MPS support and handling float32/float64 tensor conversions, consider adding an MpsVecWrapper class to this file. This wrapper could:

  1. Automatically handle tensor dtype conversions for MPS compatibility
  2. Provide clear error messages for unsupported operations
  3. Ensure consistent device placement across the vectorized environment

This would centralize MPS-specific logic and make it easier to maintain.

Would you like me to help design the MPS wrapper class with the necessary tensor conversion logic?
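The core of such a wrapper could be as small as the following sketch (class and method names are invented for illustration; a real implementation would subclass `VecEnvWrapper`):

```python
import numpy as np

class MpsVecWrapperSketch:
    """Hypothetical wrapper: downcast float64 arrays before they reach MPS tensors."""

    def __init__(self, venv):
        self.venv = venv

    @staticmethod
    def _to_float32(arr: np.ndarray) -> np.ndarray:
        # MPS does not support float64, so downcast eagerly.
        return arr.astype(np.float32) if arr.dtype == np.float64 else arr

    def step_wait(self):
        obs, rewards, dones, infos = self.venv.step_wait()
        return self._to_float32(obs), self._to_float32(rewards), dones, infos
```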

docs/guide/install.rst (1)

10-10: Add note about MPS support

Consider adding information about MPS support for Apple Silicon users, as this is a key feature enabled by these version updates.

Add a note like:

Stable-Baselines3 requires python 3.9+ and PyTorch >= 2.3
+
+.. note::
+
+    PyTorch 2.3+ includes improved support for Apple Silicon GPUs through the Metal Performance Shaders (MPS) backend.
+    Users with Apple Silicon Macs can leverage GPU acceleration for improved performance.
docs/modules/sac.rst (1)

38-40: LGTM! Consider adding brief explanation of log-space benefits.

The added note about temperature optimization in log-space is accurate and well-placed. The references to GitHub issues provide good empirical evidence for the stability benefits.

Consider expanding the note slightly to explain why log-space optimization is more stable (e.g., better numerical stability due to avoiding very small values, consistent with how neural networks often handle similar coefficients). This would help users better understand the implementation choice.

 .. note::
    When automatically adjusting the temperature (alpha/entropy coefficient), we optimize the logarithm of the entropy coefficient instead of the entropy coefficient itself. This is consistent with the original implementation and has proven to be more stable
-    (see issues `GH#36 <https://github.com/DLR-RM/stable-baselines3/issues/36>`_, `#55 <https://github.com/araffin/sbx/issues/55>`_ and others).
+    due to better numerical stability when dealing with small coefficient values (see issues `GH#36 <https://github.com/DLR-RM/stable-baselines3/issues/36>`_, `#55 <https://github.com/araffin/sbx/issues/55>`_ and others).
stable_baselines3/common/sb2_compat/rmsprop_tf_like.py (1)

Line range hint 37-86: Consider adding device-specific dtype handling

To ensure robust MPS support, consider adding device-specific dtype handling in the optimizer initialization. This could be implemented as a utility function that ensures tensor operations use supported dtypes for the target device.

Consider these approaches:

  1. Add a utility function in a common location to handle dtype compatibility
  2. Add device checks in the optimizer initialization
  3. Document device-specific dtype requirements

Example utility function:

def ensure_compatible_dtype(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Ensures tensor dtype is compatible with the given device."""
    if device.type == 'mps' and tensor.dtype == torch.float64:
        return tensor.to(torch.float32)
    return tensor
stable_baselines3/ddpg/ddpg.py (2)

Line range hint 76-77: Document MPS device support in parameters

Since this PR adds MPS support, the device parameter documentation should explicitly mention MPS compatibility.

Add MPS-specific information to the docstring:

-    :param device: Device (cpu, cuda, ...) on which the code should be run.
-        Setting it to auto, the code will be run on the GPU if possible.
+    :param device: Device (cpu, cuda, mps, ...) on which the code should be run.
+        Setting it to auto, the code will be run on the GPU if possible.
+        For Apple Silicon (M1/M2) users, MPS backend will be used when available.

Line range hint 1-127: Overall: MPS support needs more comprehensive implementation

While the type hint updates are valuable, this file requires additional changes to fully support MPS devices:

  1. Add MPS-specific device handling utilities
  2. Implement proper tensor dtype management for MPS compatibility
  3. Add MPS-specific training optimizations
  4. Improve documentation to cover MPS usage

These changes are essential to meet the PR's objective of proper PyTorch GPU compatibility on Apple Silicon.

stable_baselines3/common/envs/identity_env.py (1)

Test environments need MPS device coverage

The test environments (IdentityEnv and FakeImageEnv) lack explicit MPS device testing:

  • Current MPS tests are limited to basic functionality in test_spaces.py and test_save_load.py
  • No MPS-specific test coverage for these test environments
  • These environments are used across multiple test files but without device-specific testing
🔗 Analysis chain

Line range hint 1-161: Verify test coverage for MPS compatibility

While the test environments themselves are MPS-compatible (using uint8 and float32 dtypes), we should ensure they are included in the test suite with the MPS device.

Consider adding explicit test cases for these environments with MPS device to ensure they work correctly with Apple Silicon GPUs.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for test files that use these environments
rg -l "IdentityEnv|FakeImageEnv" tests/

Length of output: 199


Script:

#!/bin/bash
# Check if these test files include MPS device testing
rg -A 5 "mps|device" tests/test_identity.py tests/test_envs.py tests/test_cnn.py

# Check if there are any MPS-related tests in the test suite
rg -l "mps" tests/

Length of output: 146


Script:

#!/bin/bash
# Check how MPS is being used in the existing tests
rg -B 2 -A 5 "mps" tests/test_spaces.py tests/test_save_load.py

# Check if there's any device-related configuration or setup in test configuration
rg -l "device|cuda|torch" tests/conftest.py

Length of output: 1252
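The explicit MPS test cases suggested above could gate on device availability with a small, testable helper (the helper name is an assumption; the SB3-specific test body is elided):

```python
def accelerator_unavailable(device: str, cuda_available: bool, mps_available: bool) -> bool:
    """True when the requested accelerator cannot run on this machine.

    In a real test the flags would come from th.cuda.is_available() and
    th.backends.mps.is_available().
    """
    if device == "cuda":
        return not cuda_available
    if device == "mps":
        return not mps_available
    return False  # CPU is always available

# In a pytest file this would gate a parametrized test, e.g.:
#
#   @pytest.mark.parametrize("device", ["cpu", "cuda", "mps"])
#   def test_identity_env_on_device(device):
#       if accelerator_unavailable(device, th.cuda.is_available(),
#                                  th.backends.mps.is_available()):
#           pytest.skip(f"{device} not available")
#       ...train IdentityEnv for a few steps on `device`...
```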

stable_baselines3/common/envs/multi_input_envs.py (1)

Line range hint 52-65: Consider standardizing dtypes for better MPS compatibility

The environment mixes different dtypes (float64 for vectors, uint8 for images) which could cause issues with MPS support. Consider:

  1. Standardizing on float32 for all floating-point data
  2. Adding explicit dtype conversion in the observation space
  3. Documenting dtype requirements for MPS compatibility

This would help address the test failures mentioned in the PR objectives and improve compatibility with Apple Silicon GPUs.
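Item 2 could be handled by a small conversion helper applied to each dict observation (a sketch; the function name is an assumption):

```python
import numpy as np

def standardize_obs_dtypes(obs: dict) -> dict:
    """Cast every floating-point entry to float32 (MPS-safe);
    leave integer data such as uint8 images untouched."""
    out = {}
    for key, value in obs.items():
        arr = np.asarray(value)
        if np.issubdtype(arr.dtype, np.floating):
            arr = arr.astype(np.float32, copy=False)
        out[key] = arr
    return out
```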

tests/test_gae.py (1)

Line range hint 108-122: Consider adding MPS device compatibility tests

Given that this PR aims to improve MPS support, consider:

  1. Adding test cases for MPS device compatibility
  2. Ensuring tensor operations in CustomPolicy handle float32 correctly

Example test case to add:

@pytest.mark.parametrize("device", ["cpu", "cuda", "mps"])
def test_policy_device_compatibility(device):
    if not th.cuda.is_available() and device == "cuda":
        pytest.skip("CUDA not available")
    if device == "mps" and not th.backends.mps.is_available():
        pytest.skip("MPS not available")
    
    env = CustomEnv()
    policy = CustomPolicy(
        observation_space=env.observation_space,
        action_space=env.action_space,
        device=device
    )
    # Verify tensor operations work with float32 on the device
    obs = env.observation_space.sample()
    obs_tensor = th.as_tensor(obs, device=device, dtype=th.float32)
    actions, values, log_prob = policy(obs_tensor)
    assert actions.device.type == device
    assert values.device.type == device
    assert log_prob.device.type == device
stable_baselines3/common/env_util.py (2)

46-50: Consider adding MPS-specific environment kwargs documentation

While the type hint updates are correct, given this PR's focus on MPS support, consider documenting any MPS-specific environment kwargs that users might need to set (e.g., device selection).

Add documentation like:

    :param env_kwargs: Optional keyword argument to pass to the env constructor
+                     For MPS support, you may need to specify device-related parameters.

Line range hint 1-170: Consider adding MPS device detection utility

Given this PR's focus on MPS support, consider adding a utility function in this file to detect MPS availability and handle device selection consistently across the codebase. This would help users properly initialize environments with MPS support.

Example utility function:

import torch

def get_device() -> str:
    """
    Returns the appropriate device (mps, cuda, or cpu) based on availability.
    """
    if torch.backends.mps.is_available():
        return "mps"
    elif torch.cuda.is_available():
        return "cuda"
    return "cpu"
stable_baselines3/common/vec_env/stacked_observations.py (1)

123-124: LGTM: Type hint modernization complete

The parameter and return type annotations have been consistently updated to use built-in collection types. These changes complete the type hint modernization in this file while maintaining the original functionality.

Consider adding a note in the changelog about the type hint modernization, as it affects the entire codebase and requires Python 3.9+.

stable_baselines3/common/preprocessing.py (2)

110-110: Consider adding type information to error message.

The assertion error message could be more helpful by including the expected type information.

-assert isinstance(obs, dict), f"Expected dict, got {type(obs)}"
+assert isinstance(obs, dict), f"Expected dict[str, torch.Tensor], got {type(obs)}"

Line range hint 119-122: Image normalization is MPS-compatible.

The current implementation using .float() defaults to float32, which is compatible with MPS devices. The normalization operation obs.float() / 255.0 will work correctly on Apple Silicon GPUs.

Consider adding a comment documenting that this operation produces float32 tensors, which is important for MPS compatibility.

stable_baselines3/common/monitor.py (1)

Line range hint 1-256: Consider adding MPS-specific type checking

Given that this file handles critical reward tracking and metric accumulation, consider adding explicit type checking or conversion for MPS compatibility:

  1. Add a utility function to ensure rewards are always float32 when using MPS
  2. Consider adding warning logs when float64 values are detected with MPS

Would you like help implementing these suggestions?
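Both suggestions could be combined in one helper along these lines (a sketch; the function name and warning text are assumptions):

```python
import warnings

import numpy as np

def mps_safe_reward(reward, device_type: str = "mps") -> np.ndarray:
    """Coerce rewards to float32 when targeting MPS, warning on float64 input."""
    arr = np.asarray(reward)
    if device_type == "mps" and arr.dtype == np.float64:
        warnings.warn(
            "float64 reward detected; casting to float32 for MPS", UserWarning
        )
        arr = arr.astype(np.float32)
    return arr
```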

stable_baselines3/a2c/a2c.py (1)

Line range hint 68-90: Consider explicit MPS device handling

While the device parameter supports MPS through the Union type, consider adding explicit documentation about MPS support and any limitations.

Add MPS-related documentation to the docstring:

 :param device: Device (cpu, cuda, ...) on which the code should be run.
-        Setting it to auto, the code will be run on the GPU if possible.
+        Setting it to auto, the code will be run on the GPU if possible.
+        For Apple Silicon, MPS device is supported but requires float32 tensors.
stable_baselines3/dqn/policies.py (2)

285-291: Handle mixed observation types for MPS compatibility

The MultiInputPolicy needs to handle dictionary observations which might contain mixed types. Ensure all observation processing is compatible with MPS limitations:

  1. Convert all numerical observations to float32
  2. Handle non-numerical observations appropriately

Consider adding a preprocessing step in CombinedExtractor to ensure MPS compatibility:

def preprocess_observation(self, obs: dict[str, Any]) -> dict[str, Any]:
    if self.device.type == "mps":
        return {
            key: value.to(dtype=th.float32) if isinstance(value, th.Tensor) else value
            for key, value in obs.items()
        }
    return obs

Line range hint 1-291: Consider adding explicit MPS support handling

While the type hint updates are good, this file lacks explicit MPS support handling. Consider:

  1. Adding MPS device detection
  2. Implementing fallback mechanisms for unsupported operations
  3. Adding warnings for unsupported features (e.g., float64 operations)

Add a device compatibility check in the base policy:

def _check_device_compatibility(self) -> None:
    if self.device.type == "mps":
        if any(param.dtype == th.float64 for param in self.parameters()):
            raise ValueError(
                "MPS device does not support float64. "
                "Please ensure all parameters are float32."
            )
stable_baselines3/td3/td3.py (2)

Line range hint 156-174: Add explicit dtype handling for MPS compatibility

Given that MPS doesn't support float64, we should ensure all tensor operations use float32. Consider adding explicit dtype handling for the noise generation and subsequent operations:

- noise = replay_data.actions.clone().data.normal_(0, self.target_policy_noise)
+ # Ensure float32 dtype for MPS compatibility
+ noise = replay_data.actions.clone().to(dtype=th.float32).data.normal_(0, self.target_policy_noise)

Also consider adding a validation in the constructor to ensure proper dtype usage:

def _setup_model(self) -> None:
    super()._setup_model()
    if self.device.type == "mps":
        # Ensure all model parameters are float32
        self.policy.to(dtype=th.float32)

Line range hint 1-236: Consider implementing comprehensive MPS support strategy

While the type hint updates are good, the TD3 implementation needs a more comprehensive strategy for MPS support:

  1. Add a device compatibility check in the constructor
  2. Implement a central dtype management system
  3. Add validation for unsupported operations on MPS
  4. Document MPS-specific limitations and behaviors

This would help prevent runtime errors and provide better user experience on Apple Silicon devices.

Consider creating a base mixin or utility class that handles MPS-specific compatibility:

class MPSCompatibilityMixin:
    def validate_mps_compatibility(self):
        if self.device.type == "mps":
            # Validate all model parameters are float32
            # Check for unsupported operations
            # Set appropriate flags/warnings
            pass

    def ensure_tensor_compatibility(self, tensor: th.Tensor) -> th.Tensor:
        if self.device.type == "mps":
            return tensor.to(dtype=th.float32)
        return tensor
stable_baselines3/common/atari_wrappers.py (1)

Line range hint 209-227: Consider future GPU-accelerated preprocessing optimization

The image preprocessing pipeline (grayscale conversion, resizing, etc.) currently runs on CPU using OpenCV and NumPy. While this is outside the scope of the current MPS compatibility fixes, consider creating a future enhancement ticket to evaluate GPU acceleration of the preprocessing pipeline on Apple Silicon, which could potentially improve performance.

stable_baselines3/dqn/dqn.py (2)

Line range hint 102-102: Add explicit MPS device handling

Given the PR's objective to support Apple Silicon GPUs, consider adding explicit MPS device handling:

  1. Add device compatibility checks
  2. Ensure proper fallback to CPU when MPS is unavailable

Example implementation:

    def __init__(
        self,
        policy: Union[str, type[DQNPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 1e-4,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ) -> None:
+       # Handle MPS device availability
+       if device == "auto" and th.backends.mps.is_available():
+           device = "mps"
+       elif device == "mps" and not th.backends.mps.is_available():
+           warnings.warn("MPS device requested but not available. Falling back to CPU.")
+           device = "cpu"
        super().__init__(
            policy,
            env,
            device=device,
            ...
        )

Line range hint 191-207: Ensure float32 tensor operations for MPS compatibility

Since MPS doesn't support float64, ensure all tensor operations use float32 dtype:

  1. Q-value computations
  2. Loss calculations

Apply this change to ensure float32 compatibility:

            with th.no_grad():
                # Compute the next Q-values using the target network
-               next_q_values = self.q_net_target(replay_data.next_observations)
+               next_q_values = self.q_net_target(replay_data.next_observations.to(dtype=th.float32))
                # Follow greedy policy: use the one with the highest value
                next_q_values, _ = next_q_values.max(dim=1)
                # Avoid potential broadcast issue
                next_q_values = next_q_values.reshape(-1, 1)
                # 1-step TD target
-               target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
+               target_q_values = (replay_data.rewards.to(dtype=th.float32) + 
+                                (1 - replay_data.dones.to(dtype=th.float32)) * 
+                                self.gamma * next_q_values)

            # Get current Q-values estimates
-           current_q_values = self.q_net(replay_data.observations)
+           current_q_values = self.q_net(replay_data.observations.to(dtype=th.float32))
stable_baselines3/common/torch_layers.py (1)

Line range hint 1-316: Consider implementing centralized tensor dtype management

To better support MPS and handle tensor dtype consistently, consider:

  1. Implementing a centralized tensor factory that enforces correct dtype based on device type
  2. Adding a device-aware tensor conversion utility
  3. Creating a configuration option to explicitly control default tensor dtype

This would make it easier to maintain dtype compatibility across different devices (CPU, CUDA, MPS) and prevent scattered dtype conversions throughout the codebase.
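
A minimal sketch of what such a centralized factory could look like (illustrated with NumPy dtypes; the helper names default_dtype and make_array are illustrative, not part of SB3):

```python
import numpy as np

def default_dtype(device: str) -> np.dtype:
    # MPS has no float64 support, so fall back to float32 on that device
    return np.dtype(np.float32) if device == "mps" else np.dtype(np.float64)

def make_array(shape: tuple[int, ...], device: str) -> np.ndarray:
    # Single place where the dtype decision is made for every new array
    return np.zeros(shape, dtype=default_dtype(device))
```

Routing all array/tensor creation through one factory like this keeps the dtype policy in a single location instead of scattering `.to(dtype=...)` calls across the codebase.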

stable_baselines3/sac/sac.py (3)

Line range hint 170-171: Ensure float32 dtype for MPS compatibility

Given that MPS doesn't support float64, we should explicitly specify float32 dtype when creating tensors from numpy arrays:

  self.target_entropy = float(-np.prod(self.env.action_space.shape).astype(np.float32))
+ # Ensure tensor operations use float32
+ self.target_entropy_tensor = th.tensor(self.target_entropy, dtype=th.float32, device=self.device)

Line range hint 191-195: Explicitly specify float32 dtype for learned entropy coefficient

To maintain MPS compatibility, ensure the entropy coefficient tensor uses float32:

- self.log_ent_coef = th.log(th.ones(1, device=self.device) * init_value).requires_grad_(True)
+ self.log_ent_coef = th.log(th.ones(1, dtype=th.float32, device=self.device) * init_value).requires_grad_(True)

Line range hint 118-119: Add MPS device validation

To properly support Apple Silicon GPU, we should add explicit MPS device validation:

+ @staticmethod
+ def _is_mps_available() -> bool:
+     return hasattr(th, "backends") and hasattr(th.backends, "mps") and th.backends.mps.is_available()

  def __init__(
      self,
      ...,
      device: Union[th.device, str] = "auto",
      ...
  ):
+     # Validate MPS device availability
+     if device == "mps" and not self._is_mps_available():
+         device = "cpu"
+         if self.verbose > 0:
+             print("Warning: MPS device not available. Using CPU instead.")
tests/test_vec_normalize.py (1)

Line range hint 249-266: Improve test clarity and maintainability.

Consider the following improvements to enhance test clarity:

  1. Extract magic numbers in assertions to named constants with clear meaning:
-assert np.allclose(env.obs_rms.mean, 0.5, atol=1e-4)
-assert np.allclose(env.ret_rms.mean, 0.0, atol=1e-4)
+EXPECTED_INITIAL_OBS_MEAN = 0.5
+EXPECTED_INITIAL_RET_MEAN = 0.0
+assert np.allclose(env.obs_rms.mean, EXPECTED_INITIAL_OBS_MEAN, atol=1e-4)
+assert np.allclose(env.ret_rms.mean, EXPECTED_INITIAL_RET_MEAN, atol=1e-4)
  2. Add docstrings explaining the expected behavior and why specific values are expected
stable_baselines3/common/vec_env/base_vec_env.py (1)

Line range hint 151-187: Fix incorrect parameter descriptions in docstring

The docstring for env_is_wrapped method contains copy-pasted parameter descriptions from another method. Please update them to match the actual parameters.

     def env_is_wrapped(self, wrapper_class: type[gym.Wrapper], indices: VecEnvIndices = None) -> list[bool]:
         """
         Check if environments are wrapped with a given wrapper.
 
-        :param method_name: The name of the environment method to invoke.
-        :param indices: Indices of envs whose method to call
-        :param method_args: Any positional arguments to provide in the call
-        :param method_kwargs: Any keyword arguments to provide in the call
+        :param wrapper_class: The wrapper class to check for
+        :param indices: Indices of envs to check
         :return: True if the env is wrapped, False otherwise, for each env queried.
         """
stable_baselines3/her/her_replay_buffer.py (2)

3-3: LGTM! Consider documenting Python version requirement.

The update to use built-in types for type hints (PEP 585) is good. Since this feature requires Python 3.9+, consider documenting this requirement in the project's README or requirements.txt.
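
For reference, the PEP 585 style uses the built-in generics directly, which only parse on Python 3.9+ (the function name here is illustrative, not from the codebase):

```python
def stack_goals(buffers: dict[str, list[float]]) -> tuple[list[float], ...]:
    # Built-in generics (dict[...], list[...], tuple[...]) require Python 3.9+
    # unless `from __future__ import annotations` is used
    return tuple(buffers[key] for key in sorted(buffers))
```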


137-142: Consider adding float32 conversion for MPS compatibility.

While the type hint modernization looks good, given that MPS doesn't support float64, consider adding explicit float32 conversion when handling numpy arrays to ensure compatibility with MPS devices.

Example implementation:

 def add(  # type: ignore[override]
     self,
     obs: dict[str, np.ndarray],
     next_obs: dict[str, np.ndarray],
     action: np.ndarray,
     reward: np.ndarray,
     done: np.ndarray,
     infos: list[dict[str, Any]],
 ) -> None:
+    # Ensure float32 dtype for MPS compatibility
+    for key in obs:
+        if obs[key].dtype == np.float64:
+            obs[key] = obs[key].astype(np.float32)
+    for key in next_obs:
+        if next_obs[key].dtype == np.float64:
+            next_obs[key] = next_obs[key].astype(np.float32)
+    if action.dtype == np.float64:
+        action = action.astype(np.float32)
+    if reward.dtype == np.float64:
+        reward = reward.astype(np.float32)
tests/test_logger.py (2)

Line range hint 595-607: LGTM: Well-structured test environment

The DummySuccessEnv implementation is clean and properly implements the Gymnasium interface. It effectively simulates success/failure scenarios for testing success rate logging.

Consider adding edge cases to test:

  1. All successes ([True] * STATS_WINDOW_SIZE)
  2. All failures ([False] * STATS_WINDOW_SIZE)
  3. Alternating success/failure ([True, False] * (STATS_WINDOW_SIZE // 2))
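
A hedged sketch of how those edge cases could be checked against a rolling window (assumes a window size of 100, mirroring SB3's stats_window_size default; success_rate is a stand-in for the logger's internal computation, not its actual API):

```python
from collections import deque

STATS_WINDOW_SIZE = 100  # assumed default, matching SB3's stats_window_size

def success_rate(outcomes: list[bool], window: int = STATS_WINDOW_SIZE) -> float:
    # Keep only the most recent `window` outcomes, as the rollout buffer does
    buf = deque(outcomes, maxlen=window)
    return sum(buf) / len(buf) if buf else 0.0
```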

610-619: LGTM: Comprehensive success rate testing

The test effectively verifies success rate logging with different success patterns (30%, 50%, 80%). The assertions correctly validate the logger's tracking of success rates.

Consider adding error handling for edge cases:

-    assert logger.name_to_value["rollout/success_rate"] == 0.3
+    assert abs(logger.name_to_value["rollout/success_rate"] - 0.3) < 1e-6

This would handle potential floating-point precision issues.

stable_baselines3/sac/policies.py (3)

Line range hint 147-166: Ensure tensor operations use float32 for MPS compatibility

The get_action_dist_params method performs tensor operations without explicit dtype handling. For MPS compatibility, ensure all tensor operations use float32.

 def get_action_dist_params(self, obs: PyTorchObs) -> tuple[th.Tensor, th.Tensor, dict[str, th.Tensor]]:
     features = self.extract_features(obs, self.features_extractor)
     latent_pi = self.latent_pi(features)
-    mean_actions = self.mu(latent_pi)
+    # Ensure float32 dtype for MPS compatibility
+    mean_actions = self.mu(latent_pi.to(dtype=th.float32))

     if self.use_sde:
         return mean_actions, self.log_std, dict(latent_sde=latent_pi)
     # Unstructured exploration (Original implementation)
-    log_std = self.log_std(latent_pi)
+    log_std = self.log_std(latent_pi.to(dtype=th.float32))
     # Original Implementation to cap the standard deviation
     log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
     return mean_actions, log_std, {}

Line range hint 312-331: Update constructor parameters to include dtype

The _get_constructor_parameters method should include the dtype parameter in the returned dictionary to ensure proper reconstruction of the policy.

 def _get_constructor_parameters(self) -> dict[str, Any]:
     data = super()._get_constructor_parameters()

     data.update(
         dict(
             net_arch=self.net_arch,
             activation_fn=self.net_args["activation_fn"],
+            dtype=self.dtype,  # Add dtype to constructor parameters
             use_sde=self.actor_kwargs["use_sde"],
             log_std_init=self.actor_kwargs["log_std_init"],
             use_expln=self.actor_kwargs["use_expln"],
             clip_mean=self.actor_kwargs["clip_mean"],
             n_critics=self.critic_kwargs["n_critics"],

Line range hint 1-481: Consider implementing a global dtype configuration

To ensure consistent tensor dtype handling across the entire codebase, consider implementing a global configuration mechanism for tensor dtypes. This would help manage the MPS compatibility requirement for float32 tensors while maintaining flexibility for other backends.

Key recommendations:

  1. Add a global dtype configuration in the base policy
  2. Implement dtype validation for MPS devices
  3. Add utility functions for tensor dtype conversion
  4. Update all tensor creation operations to use the configured dtype
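
As a sketch of the validation idea in item 2, a helper could reject float64 parameters up front when the target device is MPS (NumPy stands in for torch here; validate_dtypes is an illustrative name, not an SB3 API):

```python
import numpy as np

def validate_dtypes(params: dict[str, np.ndarray], device: str) -> None:
    # On MPS, any float64 parameter would fail at runtime, so fail early instead
    if device != "mps":
        return
    bad = [name for name, p in params.items() if p.dtype == np.float64]
    if bad:
        raise ValueError(f"float64 parameters are not supported on MPS: {bad}")
```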
stable_baselines3/common/save_util.py (1)

Line range hint 76-383: Consider documenting MPS device handling

While the type hint updates are solid, consider adding documentation about MPS device support in the docstrings, particularly for load_from_zip_file and save_to_zip_file functions, as they are critical for model serialization with different devices.

stable_baselines3/common/env_checker.py (1)

Line range hint 175-199: Consider adding MPS dtype compatibility check

Since MPS doesn't support float64, and this function handles numpy arrays for reward computation, consider adding a check to ensure tensor dtypes are compatible with MPS when it's available.

Add a dtype check before the reward computation:

 def _check_goal_env_compute_reward(
     obs: dict[str, Union[np.ndarray, int]],
     env: gym.Env,
     reward: float,
     info: dict[str, Any],
 ) -> None:
     """
     Check that reward is computed with `compute_reward`
     and that the implementation is vectorized.
     """
     achieved_goal, desired_goal = obs["achieved_goal"], obs["desired_goal"]
+    # Ensure MPS compatibility by checking dtype
+    if hasattr(env, "device") and str(env.device) == "mps":
+        if (isinstance(achieved_goal, np.ndarray) and achieved_goal.dtype == np.float64) or \
+           (isinstance(desired_goal, np.ndarray) and desired_goal.dtype == np.float64):
+            warnings.warn(
+                "MPS device detected with float64 arrays. MPS doesn't support float64. "
+                "Consider converting arrays to float32."
+            )
stable_baselines3/common/logger.py (1)

332-333: Add error handling for file operations

The file operations could benefit from proper error handling to gracefully handle I/O errors.

Consider wrapping the file operations in a try-except block:

-        self.file = open(filename, "w+")
+        try:
+            self.file = open(filename, "w+")
+        except IOError as e:
+            raise RuntimeError(f"Failed to open CSV file {filename}: {e}")
stable_baselines3/common/distributions.py (1)

Line range hint 644-653: Consider standardizing epsilon handling across distributions

The TanhBijector.inverse method uses dtype-specific epsilon (th.finfo(y.dtype).eps), while other parts of the code use fixed values (e.g., 1e-6). Consider standardizing this approach across all distributions for consistent numerical stability across different hardware and backends.

Example implementation:

 def __init__(self, epsilon: float = 1e-6):
     super().__init__()
-    self.epsilon = epsilon
+    # Use dtype-specific epsilon for better numerical stability
+    self.epsilon = th.finfo(th.float32).eps  # Default to float32 for MPS compatibility
stable_baselines3/common/base_class.py (1)

Line range hint 1-7: Add MPS-specific documentation

Given that MPS doesn't support float64 tensors, consider adding documentation in the class docstring about tensor dtype requirements and restrictions when using MPS devices. This would help users and derived classes handle dtype conversions appropriately.

Add the following to the class docstring:

 """Abstract base classes for RL algorithms."""
+
+"""
+Note on Apple Silicon GPU Support:
+When using the MPS (Metal Performance Shaders) device, ensure all tensors are
+float32 as float64 operations are not supported. Derived classes should handle
+appropriate tensor casting when MPS device is detected.
+"""
stable_baselines3/common/policies.py (2)

Line range hint 971-979: Add explicit dtype handling for MPS compatibility

Given that MPS doesn't support float64 (as mentioned in the PR objectives), we should add explicit dtype handling in the forward pass to ensure tensors are cast to float32 when using MPS device.

Add dtype handling in the forward method:

 def forward(self, obs: th.Tensor, actions: th.Tensor) -> tuple[th.Tensor, ...]:
     # Learn the features extractor using the policy loss only
     # when the features_extractor is shared with the actor
     with th.set_grad_enabled(not self.share_features_extractor):
         features = self.extract_features(obs, self.features_extractor)
+    # Ensure float32 dtype for MPS compatibility
+    if features.device.type == "mps":
+        features = features.to(dtype=th.float32)
+        actions = actions.to(dtype=th.float32)
     qvalue_input = th.cat([features, actions], dim=1)
     return tuple(q_net(qvalue_input) for q_net in self.q_networks)

Line range hint 1-971: Consider centralizing MPS dtype handling

Given the MPS float64 limitations, consider implementing a centralized dtype management system that automatically handles tensor dtype conversions when using MPS device. This could be implemented as a utility function or decorator that wraps tensor operations.

Consider creating a utility function in stable_baselines3/common/utils.py:

def ensure_mps_compatible_dtype(tensor: th.Tensor) -> th.Tensor:
    """Ensures tensor dtype is compatible with MPS device."""
    if tensor.device.type == "mps" and tensor.dtype == th.float64:
        return tensor.to(dtype=th.float32)
    return tensor
docs/misc/changelog.rst (4)

11-13: Consider adding migration guide for breaking changes

Since there are significant breaking changes (PyTorch 2.3.0 requirement and Python 3.8 removal), consider adding a migration guide section to help users upgrade smoothly.

Would you like me to help draft a migration guide section?


16-18: Expand on NumPy v2.0 support details

The changelog would benefit from more details about the NumPy v2.0 support changes, specifically:

  • What changes were made to VecNormalize
  • How the bit flipping env was updated
  • Any potential breaking changes or considerations for users

114-116: Add more context to PPO documentation update

The PPO documentation update regarding CPU usage with MlpPolicy should include:

  • Reasoning behind the recommendation
  • Performance implications
  • When users should/shouldn't follow this recommendation

Line range hint 1-1127: Standardize formatting across changelog

For better readability and consistency:

  1. Standardize punctuation at the end of bullet points
  2. Maintain consistent spacing between items within sections
  3. Use consistent capitalization in section titles
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 4c03a25 and 0ec37d8.

📒 Files selected for processing (70)
  • .github/workflows/ci.yml (2 hunks)
  • README.md (9 hunks)
  • docs/conda_env.yml (1 hunks)
  • docs/conf.py (1 hunks)
  • docs/guide/install.rst (1 hunks)
  • docs/index.rst (2 hunks)
  • docs/misc/changelog.rst (7 hunks)
  • docs/modules/dqn.rst (1 hunks)
  • docs/modules/sac.rst (1 hunks)
  • pyproject.toml (1 hunks)
  • setup.py (4 hunks)
  • stable_baselines3/a2c/a2c.py (3 hunks)
  • stable_baselines3/common/atari_wrappers.py (2 hunks)
  • stable_baselines3/common/base_class.py (14 hunks)
  • stable_baselines3/common/buffers.py (8 hunks)
  • stable_baselines3/common/callbacks.py (12 hunks)
  • stable_baselines3/common/distributions.py (13 hunks)
  • stable_baselines3/common/env_checker.py (2 hunks)
  • stable_baselines3/common/env_util.py (5 hunks)
  • stable_baselines3/common/envs/bit_flipping_env.py (5 hunks)
  • stable_baselines3/common/envs/identity_env.py (4 hunks)
  • stable_baselines3/common/envs/multi_input_envs.py (4 hunks)
  • stable_baselines3/common/evaluation.py (2 hunks)
  • stable_baselines3/common/logger.py (14 hunks)
  • stable_baselines3/common/monitor.py (8 hunks)
  • stable_baselines3/common/noise.py (2 hunks)
  • stable_baselines3/common/off_policy_algorithm.py (6 hunks)
  • stable_baselines3/common/on_policy_algorithm.py (5 hunks)
  • stable_baselines3/common/policies.py (17 hunks)
  • stable_baselines3/common/preprocessing.py (4 hunks)
  • stable_baselines3/common/results_plotter.py (5 hunks)
  • stable_baselines3/common/running_mean_std.py (1 hunks)
  • stable_baselines3/common/save_util.py (5 hunks)
  • stable_baselines3/common/sb2_compat/rmsprop_tf_like.py (2 hunks)
  • stable_baselines3/common/torch_layers.py (6 hunks)
  • stable_baselines3/common/type_aliases.py (3 hunks)
  • stable_baselines3/common/utils.py (10 hunks)
  • stable_baselines3/common/vec_env/__init__.py (3 hunks)
  • stable_baselines3/common/vec_env/base_vec_env.py (10 hunks)
  • stable_baselines3/common/vec_env/dummy_vec_env.py (4 hunks)
  • stable_baselines3/common/vec_env/patch_gym.py (1 hunks)
  • stable_baselines3/common/vec_env/stacked_observations.py (3 hunks)
  • stable_baselines3/common/vec_env/subproc_vec_env.py (9 hunks)
  • stable_baselines3/common/vec_env/util.py (4 hunks)
  • stable_baselines3/common/vec_env/vec_check_nan.py (1 hunks)
  • stable_baselines3/common/vec_env/vec_frame_stack.py (2 hunks)
  • stable_baselines3/common/vec_env/vec_monitor.py (2 hunks)
  • stable_baselines3/common/vec_env/vec_normalize.py (9 hunks)
  • stable_baselines3/common/vec_env/vec_transpose.py (3 hunks)
  • stable_baselines3/common/vec_env/vec_video_recorder.py (4 hunks)
  • stable_baselines3/ddpg/ddpg.py (2 hunks)
  • stable_baselines3/dqn/dqn.py (6 hunks)
  • stable_baselines3/dqn/policies.py (7 hunks)
  • stable_baselines3/her/her_replay_buffer.py (4 hunks)
  • stable_baselines3/ppo/ppo.py (3 hunks)
  • stable_baselines3/sac/policies.py (9 hunks)
  • stable_baselines3/sac/sac.py (5 hunks)
  • stable_baselines3/td3/policies.py (7 hunks)
  • stable_baselines3/td3/td3.py (4 hunks)
  • stable_baselines3/version.txt (1 hunks)
  • tests/test_dict_env.py (3 hunks)
  • tests/test_distributions.py (1 hunks)
  • tests/test_env_checker.py (3 hunks)
  • tests/test_gae.py (4 hunks)
  • tests/test_logger.py (3 hunks)
  • tests/test_spaces.py (3 hunks)
  • tests/test_tensorboard.py (3 hunks)
  • tests/test_utils.py (4 hunks)
  • tests/test_vec_envs.py (4 hunks)
  • tests/test_vec_normalize.py (5 hunks)
✅ Files skipped from review due to trivial changes (2)
  • docs/conf.py
  • stable_baselines3/common/vec_env/vec_monitor.py
🚧 Files skipped from review as they are similar to previous changes (9)
  • .github/workflows/ci.yml
  • README.md
  • docs/conda_env.yml
  • docs/index.rst
  • stable_baselines3/common/buffers.py
  • stable_baselines3/common/on_policy_algorithm.py
  • stable_baselines3/common/utils.py
  • stable_baselines3/common/vec_env/vec_normalize.py
  • stable_baselines3/version.txt
🔇 Additional comments (155)
stable_baselines3/common/vec_env/vec_video_recorder.py (1)

21-21: Verify License Compliance for Adapted Code

You mentioned that the video recorder code was adapted from Gymnasium v1.0. Ensure that the use of this code complies with Gymnasium's license terms, and include appropriate attribution if required.

Would you like assistance in reviewing the licensing requirements?

stable_baselines3/common/running_mean_std.py (2)

5-5: LGTM: Type hint modernization

The update from Tuple to tuple aligns with modern Python type hinting introduced in Python 3.9+.


Line range hint 12-13: Consider float32 support for MPS compatibility

The class currently uses np.float64 for mean and variance calculations, but MPS (Metal Performance Shaders) doesn't support float64 operations. This could cause issues when using the library with Apple Silicon GPUs.

Consider:

  1. Adding a dtype parameter to support both float32 and float64:
-    def __init__(self, epsilon: float = 1e-4, shape: tuple[int, ...] = ()):
+    def __init__(self, epsilon: float = 1e-4, shape: tuple[int, ...] = (), dtype: np.dtype = np.float64):
-        self.mean = np.zeros(shape, np.float64)
-        self.var = np.ones(shape, np.float64)
+        self.mean = np.zeros(shape, dtype)
+        self.var = np.ones(shape, dtype)
  2. Or defaulting to float32 when MPS is detected:
def __init__(self, epsilon: float = 1e-4, shape: tuple[int, ...] = ()):
    # Note: requires `import torch as th` in this module
    dtype = np.float32 if th.backends.mps.is_available() else np.float64
    self.mean = np.zeros(shape, dtype)
    self.var = np.ones(shape, dtype)

Let's check if this class is used in conjunction with MPS device:

stable_baselines3/common/vec_env/vec_frame_stack.py (2)

1-2: LGTM! Good modernization of imports.

The change from typing.Mapping to collections.abc.Mapping follows Python's best practices for type hints. This is a positive change that maintains backward compatibility while preparing for future Python versions.


33-37: Verify tensor handling for MPS compatibility.

While the type hint updates look good, given the PR's focus on MPS support, we should verify that the stacked observations maintain the correct dtype (float32) when using MPS.
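
One way to spot-check this property in isolation (a NumPy sketch of the stacking behavior, not the VecFrameStack API itself):

```python
import numpy as np

def stack_frames(frames: list[np.ndarray]) -> np.ndarray:
    # np.concatenate preserves the input dtype, so float32 frames
    # stay float32 after stacking (and hence remain MPS-compatible)
    return np.concatenate(frames, axis=-1)

obs = [np.zeros((84, 84, 1), dtype=np.float32) for _ in range(4)]
stacked = stack_frames(obs)
```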

stable_baselines3/common/vec_env/util.py (2)

5-5: LGTM! Import changes align with type hint modernization.

The simplified imports reflect the move away from OrderedDict dependency, which is a good practice for reducing unnecessary dependencies.


54-57: Verify removal of OrderedDict dependencies.

The type ignore comments suggest potential type checking issues. Let's verify that all OrderedDict dependencies have been properly removed from related files.

Also applies to: 65-65

✅ Verification successful

OrderedDict dependencies have been properly removed from this file

The verification shows that the file stable_baselines3/common/vec_env/util.py has no OrderedDict imports or usage. The type ignore comments are related to type checking for dictionary assignments and not related to OrderedDict. While OrderedDict is still used in some test files and other parts of the codebase, its removal from this specific implementation is correct and complete.

The changelog entry also confirms this was an intentional change: "Removed unnecessary copy_obs_dict method for SubprocVecEnv, remove the use of ordered dict and rename flatten_obs to stack_obs"

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for any remaining OrderedDict references in the codebase
# that might affect this implementation.

echo "Checking for OrderedDict imports and usages..."
rg -l "OrderedDict" --type python

echo "Checking for ordered dictionary references in comments and strings..."
rg -l "ordered.*dict" --type python

Length of output: 385


Script:

#!/bin/bash
# Let's try without specifying the file type since it's not recognized
echo "Checking for OrderedDict imports and usages..."
rg "OrderedDict"

echo -e "\nChecking for ordered dictionary references in comments and strings..."
rg "ordered.*dict"

echo -e "\nLet's also check the specific file and its imports..."
cat "stable_baselines3/common/vec_env/util.py"

Length of output: 4335

tests/test_tensorboard.py (2)

36-38: LGTM! Type hint update is correct.

The change from Dict to dict for metric_dict is appropriate and maintains consistency with the previous type hint update.


27-30: LGTM! Verify tensorboard logging on MPS devices.

The type hint update from Dict to dict aligns with modern Python practices. The type ignore for gamma is properly documented.

Let's verify if there are any MPS-specific tensorboard logging issues:

✅ Verification successful

The previous search didn't give us enough context about MPS-specific test configurations. Let's check test files and device handling more specifically to determine if there are any MPS-related concerns for tensorboard logging.


Based on the gathered information, I can now generate the final response:

No MPS-specific concerns for tensorboard logging

The type hint changes in test_tensorboard.py are safe. The codebase properly handles device management:

  • Tensorboard logging is device-agnostic and operates on CPU-converted data before logging
  • The codebase has proper MPS device support with appropriate fallbacks
  • No MPS-specific test skips or conditions affect tensorboard functionality
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for any MPS-related tensor operations that might affect tensorboard logging
rg -l "mps" | xargs rg "tensorboard|logger"

# Check if there are any MPS-specific test skips or conditions
rg "skipif.*mps|pytest.*mps"

Length of output: 7358


Script:

#!/bin/bash
# Let's check for any MPS-specific test configurations or device handling in test files
rg -A 5 "device|mps" tests/

# Check for any MPS-specific test configurations in conftest.py
fd "conftest.py" | xargs cat

# Check for any device-specific handling in tensorboard-related files
rg -l "tensorboard" | xargs rg "device|mps|cuda"

Length of output: 20413

stable_baselines3/common/type_aliases.py (2)

4-4: LGTM! Good modernization of type hints.

The removal of Tuple, Dict, and List imports in favor of using built-in types for type hints aligns with PEP 585 and modern Python practices.


16-23: Verify MPS compatibility with these type definitions.

The type updates look good and follow modern Python practices. However, given that this PR adds MPS support, we should verify that these type definitions work correctly with MPS operations.

✅ Verification successful

Type aliases are compatible with MPS device support

Based on the verification results, these type definitions are compatible with MPS operations:

  • The codebase properly handles MPS device through get_device() utility which supports MPS alongside CPU and CUDA
  • The type aliases (GymObs, TensorDict, etc.) are used consistently with device operations:
    • obs_as_tensor() converts observations to tensors on any device including MPS
    • Tensor operations use these types when moving data between devices
    • Buffer implementations handle these types correctly across different devices
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Search for any MPS-specific type handling in the codebase
# that might interact with these type aliases

# Look for MPS device checks or tensor conversions
rg -l "mps" | xargs rg "\.to\(|device|dtype"

# Look for any special handling of these types with MPS
rg -l "GymObs|GymStepReturn|TensorDict" | xargs rg "mps"

Length of output: 13193

docs/modules/dqn.rst (1)

28-28: Verify the tutorial link exists and is accessible.

The added tutorial link appears to be a valuable resource for users learning about DQN.

✅ Verification successful

Tutorial link is valid and provides relevant DQN content

The added tutorial link points to a public GitHub repository that specifically focuses on Deep Q-Network (DQN) and Fitted Q-Iteration. The repository exists, is actively maintained (not archived), and contains relevant educational content as described in its description.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify that the tutorial repository exists and is accessible
gh repo view araffin/rlss23-dqn-tutorial --json name,description,isArchived,defaultBranch

Length of output: 1222


Script:

#!/bin/bash
# Description: Verify that the tutorial repository exists and is accessible (with correct JSON fields)
gh repo view araffin/rlss23-dqn-tutorial --json name,description,isArchived,url,visibility

Length of output: 320

stable_baselines3/common/vec_env/patch_gym.py (1)

Line range hint 1-92: Verify impact on MPS support

While this file handles environment compatibility, it doesn't contain any direct MPS-related changes. Let's verify if any environment conversion logic needs to be updated to handle MPS-specific tensor operations.

stable_baselines3/common/vec_env/vec_check_nan.py (1)

50-50: LGTM! Verify minimum Python version requirement.

The update to use built-in type annotations (list[tuple[str, str]]) aligns with PEP 585 and modern Python practices.

Let's verify the project's minimum Python version requirement:

✅ Verification successful

Type hint modernization is compatible with project's Python requirements

The change to use built-in type annotations (list[tuple[str, str]]) is fully compatible with the project's minimum Python version requirement of >=3.9 as specified in setup.py. This syntax is supported since Python 3.9 (PEP 585).
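For reference, the PEP 585 form is subscriptable at runtime on Python 3.9+, so no typing import is required:

```python
def parse_pairs(raw: list[tuple[str, str]]) -> dict[str, str]:
    # PEP 585: list/tuple/dict accept subscripts directly on 3.9+,
    # replacing typing.List/Tuple/Dict in annotations
    return {key: value for key, value in raw}

print(parse_pairs([("device", "mps"), ("dtype", "float32")]))
# {'device': 'mps', 'dtype': 'float32'}
```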

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check setup.py or pyproject.toml for Python version requirement
rg -A 5 "python_requires|requires-python" 

Length of output: 405

stable_baselines3/common/results_plotter.py (2)

Line range hint 102-124: Verify dtype of loaded results

The load_results function might return data in float64 format. Consider adding dtype verification or explicit casting in the data loading process.
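A defensive cast at the load boundary could look like this (hypothetical helper; it does not modify load_results itself):

```python
import numpy as np

def as_float32(values: np.ndarray) -> np.ndarray:
    # Downcast float64 monitor columns to float32 so later tensor
    # conversion is safe on MPS; other dtypes pass through unchanged
    if values.dtype == np.float64:
        return values.astype(np.float32)
    return values

rewards = np.array([1.0, 0.5, -0.25])  # float64 by default
print(as_float32(rewards).dtype)  # float32
```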


Line range hint 32-45: Verify numpy array dtype compatibility with MPS

While the type hint update looks good, given that MPS doesn't support float64, we should verify the dtype of numpy arrays being processed here.

✅ Verification successful

Based on the investigation, I can now provide a final response:

No dtype-related concerns for MPS compatibility

The window_func and its dependencies are safe for MPS compatibility because:

  1. The input arrays come from pandas DataFrame values (data_frame.l.values and data_frame.r.values) which default to float32 for numerical data
  2. The rolling window operation uses as_strided which creates a view without changing the underlying data type
  3. No explicit dtype conversions or float64 operations are performed in the results plotting utilities
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for any explicit dtype specifications or float64 usage in test files
rg -l "float64|dtype=" "tests/" | while read -r file; do
  echo "=== $file ==="
  rg "window_func.*float64|window_func.*dtype=" "$file" -A 2
done

Length of output: 1162


Script:

#!/bin/bash
# Let's check the actual usage of window_func in the codebase and any test cases
echo "=== Checking window_func usage ==="
rg "window_func" -B 2 -A 2

# Also check the rolling_window implementation since it's used by window_func
echo -e "\n=== Checking rolling_window implementation ==="
rg "def rolling_window" -A 5 stable_baselines3/common/results_plotter.py

# Check if there are any explicit dtype conversions in results_plotter.py
echo -e "\n=== Checking dtype usage in results_plotter.py ==="
rg "dtype" stable_baselines3/common/results_plotter.py

# Let's also check the test files that might use these plotting utilities
echo -e "\n=== Checking test files using results_plotter ==="
rg "from stable_baselines3.common.results_plotter" tests/

Length of output: 1709


Script:

#!/bin/bash
# Let's check the implementation of rolling_window and its usage in more detail
echo "=== Checking rolling_window implementation ==="
rg -A 15 "def rolling_window" stable_baselines3/common/results_plotter.py

# Check where the input arrays for window_func come from by looking at ts2xy
echo -e "\n=== Checking ts2xy implementation ==="
rg -A 15 "def ts2xy" stable_baselines3/common/results_plotter.py

# Check if there are any numpy array creation or conversion in the file
echo -e "\n=== Checking numpy array operations ==="
rg "np\.(array|asarray)" stable_baselines3/common/results_plotter.py

# Check the monitor implementation since it's likely the source of the data
echo -e "\n=== Checking monitor data types ==="
rg "dtype" stable_baselines3/common/monitor.py

Length of output: 1889

stable_baselines3/common/vec_env/__init__.py (2)

45-45: LGTM! Consistent type hint update.

The change maintains consistency with the previous update. Since this function is crucial for wrapper detection, ensure it's tested with any MPS-specific wrappers that might be added.

#!/bin/bash
# Check if there are any MPS-specific wrapper tests
rg -l "mps.*test.*vec.*wrap|test.*mps.*vec.*wrap"

19-19: LGTM! Type hint update follows PEP 585.

The change from Type[VecEnvWrapperT] to type[VecEnvWrapperT] aligns with PEP 585's recommendation to use built-in types for annotations.

✅ Verification successful

Type hint update is compatible with project's Python requirements

The change from Type[VecEnvWrapperT] to type[VecEnvWrapperT] is safe as the project requires Python >=3.9 (found in setup.py), which fully supports this syntax. The update aligns with PEP 585's recommendation to use built-in types for annotations.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Verify minimum Python version requirement in setup.py or pyproject.toml
# as this syntax requires Python 3.7+
fd -e toml -e cfg -e py -x rg -l "python.*3\.[789]" {}

Length of output: 245


Script:

#!/bin/bash
# Check Python version requirement in setup.py
cat setup.py | grep -A 5 "python_requires"

# Check for any type hint related imports or comments
rg "from typing import Type|from __future__ import annotations|# type:|typing import|PEP 585" 

# Look for any version compatibility notes in README
cat README.md | grep -i "python.*version\|compatibility"

Length of output: 22594

stable_baselines3/common/vec_env/vec_transpose.py (3)

76-76: LGTM! Type annotation is accurate and follows conventions.

The type annotation correctly specifies that the method handles both numpy arrays and dictionaries.


109-109: LGTM! Return type annotation is consistent.

The return type annotation matches the transpose_observations method, maintaining consistency throughout the class.


Line range hint 76-94: Verify dtype handling for MPS compatibility

Given that MPS doesn't support float64, we should verify that image data is properly handled when using MPS device. While the Box space uses uint8, the actual observations might need explicit float32 casting when used with MPS.

Let's check for any float64 usage in observation handling:
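The MPS-safe path for image observations amounts to a uint8-to-float32 conversion, sketched here with assumed names (the real transpose logic lives in VecTransposeImage):

```python
import numpy as np
import torch as th

def image_obs_to_float_tensor(obs: np.ndarray) -> th.Tensor:
    # uint8 HxWxC images are transposed to CxHxW and normalized to
    # float32 in [0, 1]; float32 is supported on CPU, CUDA and MPS
    tensor = th.as_tensor(obs.transpose(2, 0, 1))  # HWC -> CHW
    return tensor.float() / 255.0                  # uint8 -> float32

img = np.full((4, 4, 3), 255, dtype=np.uint8)
out = image_obs_to_float_tensor(img)
print(out.shape, out.dtype)  # torch.Size([3, 4, 4]) torch.float32
```

Because the cast targets float32 rather than the default float dtype, no float64 tensor is ever created.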

docs/guide/install.rst (1)

10-10: LGTM: Version requirements updated for MPS support

The updated requirements align with the PR's objective to support MPS (Metal Performance Shaders) for Apple Silicon.

setup.py (4)

112-126: Well-structured organization of extra dependencies

The consolidation of dependencies into the 'extra' section improves clarity and maintainability. The grouping by functionality (render, tensorboard, memory checking, etc.) is logical and well-documented.


153-153: Verify testing coverage for Python 3.12

The addition of Python 3.12 support requires verification of test coverage and compatibility.

#!/bin/bash
# Check CI configuration for Python 3.12
echo "Checking CI configuration..."
rg -l "python-version.*3\.12" .github/workflows/

# Check test results for Python 3.12
echo "Checking test configurations..."
rg -l "python.*3\.12" tests/

79-81: Consider tightening numpy version constraint and verify torch compatibility

  1. The numpy upper bound <3.0 might be too permissive. Consider tightening it to <2.0 to prevent potential compatibility issues with major version changes.

  2. The torch requirement >=2.3 is appropriate for MPS support, but let's verify compatibility with the codebase.

✅ Verification successful

Based on the comprehensive analysis of the codebase, I can now provide a final response regarding the dependency version constraints:

The numpy version constraint <3.0 is appropriate, and torch >=2.3 is well integrated

  1. The codebase extensively uses numpy's float32/64 datatypes across multiple components (buffers, running mean std, etc.) but doesn't rely on numpy-specific APIs that might break in major versions.

  2. The torch integration is robust:

  • Proper device handling across policies and algorithms
  • MPS support is properly integrated through the device management system
  • Consistent dtype handling in buffers with explicit float32 conversions

The version constraints in setup.py are well-aligned with the codebase's requirements and the PR's objectives for Apple Silicon support.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check for torch.mps usage in the codebase
echo "Checking for MPS-related code..."
rg -l "torch\.(mps|backends\.mps)"

# Check for numpy version references
echo "Checking for numpy version dependencies..."
rg -l "numpy([><=]=|==)[0-9]"

Length of output: 266


Script:

#!/bin/bash
# Let's check the actual content of conda_env.yml to see numpy version
echo "Checking conda_env.yml numpy version..."
cat docs/conda_env.yml

# Let's check for any numpy-specific functionality usage
echo -e "\nChecking for numpy array operations..."
rg "np\.(float|int)(32|64)" -l

# Let's also check for torch dtype usage which might affect numpy compatibility
echo -e "\nChecking for torch dtype specifications..."
rg "torch\.(float|int)(32|64)" -l

# Let's verify if there are any direct MPS device assignments
echo -e "\nChecking for device assignments..."
rg "device\s*=\s*['\"]mps['\"]" -l

Length of output: 1784


Script:

#!/bin/bash
# Let's check the actual numpy usage in files that use numpy dtypes
echo "Checking numpy dtype usage in running_mean_std.py..."
cat stable_baselines3/common/running_mean_std.py

echo -e "\nChecking numpy dtype usage in buffers.py..."
cat stable_baselines3/common/buffers.py

# Let's also check torch device handling
echo -e "\nChecking torch device handling..."
rg "\.to\(.*device" -l

# Let's verify torch dtype conversions
echo -e "\nChecking torch dtype conversions..."
rg "\.to\(dtype=" -l

Length of output: 37606


138-138: Verify documentation of Python version requirement change

The increase in minimum Python version to 3.9 is a breaking change that needs proper documentation.

✅ Verification successful

Python version requirement change is properly documented

The change in minimum Python version requirement from 3.8 to 3.9 is properly documented:

  1. The change is mentioned in the changelog under Release 2.5.0a0 as a breaking change: "Removed support for Python 3.8"
  2. The setup.py file correctly specifies the new requirement with python_requires=">=3.9"
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check if the Python version change is documented
echo "Checking changelog for Python 3.9 requirement..."
rg -l "python.*3\.9|minimum.*python" docs/

# Check for any remaining Python 3.8 references
echo "Checking for outdated Python version references..."
rg -l "python.*3\.8" --glob '!setup.py'

Length of output: 341


Script:

#!/bin/bash
# Check changelog files with common naming patterns
echo "Checking changelog files..."
fd -t f -e md -e rst "change|history|release|news" -i

# If found, check their content for Python version mentions
echo "Checking content of potential changelog files..."
fd -t f -e md -e rst "change|history|release|news" -i -x cat {}

# Check README for Python version mentions
echo "Checking README..."
fd -t f "readme" -i -x cat {}

# Check documentation files for Python version mentions
echo "Checking docs directory..."
fd -t f -e md -e rst . -x cat {} docs/ 2>/dev/null || true

Length of output: 424763


Script:

#!/bin/bash
# Check changelog for Python 3.8 removal and Python version mentions
echo "Checking changelog for Python version changes..."
rg -l "python.*3\.8|minimum.*python" docs/misc/changelog.rst

# Check for Python version mentions in docs
echo "Checking docs for Python version mentions..."
rg -l "python.*3\.8|minimum.*python" docs/

# Check setup.py for Python version requirement
echo "Checking setup.py for Python version requirement..."
cat setup.py | grep "python_requires"

Length of output: 541

stable_baselines3/common/noise.py (3)

3-3: LGTM! Modern import practice adopted.

The change to import Iterable from collections.abc instead of typing follows modern Python practices and PEP recommendations.


157-157: LGTM! Updated type hints to use built-in generics.

The change from List[ActionNoise] to list[ActionNoise] aligns with Python's move towards using built-in types for annotations (PEP 585).

Also applies to: 161-161


Line range hint 36-36: Verify NumPy dtype usage across the codebase.

The noise classes already use np.float32 as the default dtype, which is compatible with MPS. Let's verify that this dtype is consistently used when converting noise to tensors elsewhere in the codebase.

Also applies to: 64-64

✅ Verification successful

Let me gather more information about the dtype handling in these specific files to ensure there are no compatibility issues.


No changes needed for MPS support in noise handling

The verification confirms that the noise implementation is MPS-compatible:

  • Noise is generated using NumPy with explicit dtype control in noise.py via astype(self._dtype)
  • When noise is added to actions, it's done through direct arithmetic operations that preserve the dtype
  • No explicit tensor conversions of noise values are found that could cause dtype issues
  • The noise is properly handled by the existing tensor operations in TD3, off-policy algorithm, and distribution classes
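The dtype-preserving arithmetic referenced above is easy to check with a small numpy example mirroring how noise is added to actions:

```python
import numpy as np

rng = np.random.default_rng(0)
action = np.zeros(2, dtype=np.float32)
# noise.py controls the dtype explicitly via astype(self._dtype)
noise = rng.normal(size=2).astype(np.float32)
noisy = action + noise
print(noisy.dtype)  # float32: the addition preserves the dtype
```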
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for potential dtype conversion issues when noise values are used
# Look for patterns where noise values might be converted to tensors with wrong dtype

# Search for tensor conversion of noise values
rg "torch\.(tensor|from_numpy|FloatTensor).*noise" -A 2

# Search for noise application to actions
rg "action.*\+.*noise" -A 2

Length of output: 1030


Script:

#!/bin/bash
# Check dtype handling in files where noise is used
rg "dtype" -B 2 -A 2 tests/test_sde.py stable_baselines3/td3/td3.py stable_baselines3/common/off_policy_algorithm.py stable_baselines3/common/distributions.py

# Check noise-related tensor operations
ast-grep --pattern 'torch.tensor($noise, $_)'
ast-grep --pattern 'torch.from_numpy($noise)'

# Look for any explicit dtype specifications in noise-related operations
rg "noise.*dtype" -B 2 -A 2

Length of output: 955

tests/test_spaces.py (1)

28-28: LGTM: Type hint modernization

The change from Optional[Dict] to Optional[dict] aligns with modern Python type hint conventions and maintains consistency across the codebase.

stable_baselines3/common/sb2_compat/rmsprop_tf_like.py (3)

1-2: LGTM: Modern import practice adopted

Good modernization of imports by using collections.abc.Iterable instead of typing.Iterable, following PEP 585 recommendations.


71-71: LGTM: Consistent type hint modernization

Good update to use built-in dict type annotation instead of typing.Dict, maintaining consistency with PEP 585.


Line range hint 93-95: Verify tensor dtype compatibility with MPS

Since MPS doesn't support float64, we should verify that tensor initialization is compatible with MPS devices. The ones_like and zeros_like operations inherit dtype from input parameters, which could cause issues if parameters are float64.

Let's verify the tensor operations:
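The dtype-inheritance behavior can be demonstrated directly on CPU (on MPS, the float64 variant would be the problematic case):

```python
import torch as th

# State buffers created with ones_like/zeros_like inherit the
# parameter dtype, so float64 parameters would yield float64 state
p64 = th.zeros(3, dtype=th.float64)
p32 = th.zeros(3, dtype=th.float32)

print(th.ones_like(p64).dtype)   # torch.float64 -- unsupported on MPS
print(th.zeros_like(p32).dtype)  # torch.float32 -- safe everywhere
```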

stable_baselines3/ddpg/ddpg.py (3)

Line range hint 110-127: Consider MPS-specific training optimizations

Since the learn method is crucial for training, consider adding MPS-specific optimizations or error handling:

  1. Handle potential MPS-specific out-of-memory scenarios
  2. Add warnings about float64 operations during training
  3. Consider adding device-specific performance logging
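A sketch of what such handling could look like, assuming torch.backends.mps.is_available() and torch.mps.empty_cache() are the relevant entry points (helper names are illustrative, not SB3 API):

```python
import warnings
import torch as th

def maybe_free_mps_cache() -> bool:
    # Release cached MPS memory between training phases;
    # a no-op on machines without an MPS backend
    if th.backends.mps.is_available():
        th.mps.empty_cache()
        return True
    return False

def warn_if_float64(tensor: th.Tensor) -> None:
    # Surface the float64 limitation early instead of failing mid-training
    if tensor.dtype == th.float64:
        warnings.warn("float64 tensor detected; MPS supports float32 only")

print(maybe_free_mps_cache())  # False on machines without MPS
```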

Let's check for existing MPS-related training code:


Line range hint 79-108: Add MPS-specific tensor dtype handling

The PR objectives mention issues with float64 tensors on MPS devices. Consider adding dtype checks and conversions in the model setup.

Let's verify tensor dtype handling across the codebase:


Line range hint 1-13: Consider adding MPS-specific device handling utilities

Given that this PR aims to improve MPS support, consider adding utility functions or type aliases for MPS device handling to ensure consistent device management across the codebase.

Let's check if other files have implemented MPS-related utilities that should be imported here:

stable_baselines3/common/envs/identity_env.py (4)

37-37: LGTM: Type hint modernization

The update from Tuple to tuple aligns with modern Python type hinting practices using built-in types.


45-45: LGTM: Consistent type hint update

The type hint update maintains the same structure while using modern built-in types.


77-77: LGTM: Type hint consistency maintained

The type hint update is consistent with the parent class while preserving the numpy array type specification.


145-145: LGTM: Type hint update with MPS compatibility

The type hint modernization is consistent with other environments. The uint8 observation space is compatible with MPS, avoiding potential float64 issues mentioned in the PR objectives.

tests/test_env_checker.py (3)

138-138: LGTM: Type hint update follows modern conventions

The update from Dict to dict aligns with PEP 585 which made built-in collections support type hints.


165-165: LGTM: Return type annotation updated correctly

The update to use lowercase tuple[int, dict] follows PEP 585 conventions while maintaining the correct return type specification.


173-173: LGTM: Step method signature follows Gymnasium API

The updated type hints follow PEP 585 conventions while maintaining compatibility with the Gymnasium environment interface (observation, reward, terminated, truncated, info).

✅ Verification successful

Step method signature follows Gymnasium API and codebase conventions

The type hints are consistent with other environments in the codebase, particularly matching the patterns in:

  • stable_baselines3/common/monitor.py
  • stable_baselines3/common/envs/identity_env.py

All typed implementations follow the same tuple structure with 5 elements: observation, reward, terminated, truncated, and info dictionary, using PEP 585 type hints.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the step method signature matches Gymnasium's API
# Expected: The signature should match other environments in the codebase

# Search for other step method signatures to confirm consistency
rg -U "def step.*->.*tuple\[.*\]" --type python

Length of output: 80


Script:

#!/bin/bash
# Search for step method signatures in Python files
rg -t py -U "def step.*->.*tuple\["

# Also search without type hints to see other step implementations
rg -t py -U "def step\([^)]*\):"

Length of output: 1607

stable_baselines3/common/evaluation.py (2)

2-2: Verify Python version compatibility for type hint changes

The changes from Dict/List/Tuple to dict/list/tuple follow PEP 585, but require Python 3.9+. Given that this is a widely-used RL library, we should ensure backward compatibility.

Consider one of these approaches:

  1. Keep using typing.Dict/List/Tuple for backward compatibility
  2. Add runtime version check:
import sys
if sys.version_info < (3, 9):
    from typing import Dict, List, Tuple

Also applies to: 17-17, 21-21

✅ Verification successful

Let me try another approach to verify the Python version compatibility.


Type hint changes are compatible with project requirements

The changes from typing.Dict/List/Tuple to built-in dict/list/tuple are appropriate since:

  • Project explicitly requires Python 3.9+ in setup.py: python_requires=">=3.9"
  • CI tests run on Python 3.9-3.12
  • pyproject.toml confirms target version: target-version = "py39"

These type hint changes align with PEP 585 and the project's minimum Python version requirements.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check minimum Python version requirement in setup files and CI configs
echo "Checking Python version requirements..."
rg -l "python" . | grep -E "setup.*|requirements.*|pyproject.*|tox.*|\.github/workflows/.*"

Length of output: 337


Script:

#!/bin/bash
# Check setup.py, pyproject.toml and other config files for Python version
echo "=== Checking setup.py ==="
cat setup.py 2>/dev/null || true

echo -e "\n=== Checking pyproject.toml ==="
cat pyproject.toml 2>/dev/null || true

echo -e "\n=== Checking CI workflow files ==="
find .github/workflows -type f -name "*.yml" -exec cat {} \; 2>/dev/null || true

echo -e "\n=== Checking tox.ini ==="
cat tox.ini 2>/dev/null || true

echo -e "\n=== Checking requirements files ==="
find . -maxdepth 2 -type f -name "*requirements*.txt" -exec cat {} \; 2>/dev/null || true

Length of output: 9882


Line range hint 89-94: Verify tensor type compatibility for MPS device

Given that MPS doesn't support float64, we should ensure proper tensor type handling during policy evaluation. The observations and actions between model.predict() and env.step() need to maintain compatible dtypes.

Consider adding explicit tensor type checks:

import torch

def _ensure_tensor_type(tensor, device):
    # MPS has no float64 kernels, so downcast before moving tensors there
    if device == "mps" and tensor.dtype == torch.float64:
        return tensor.to(torch.float32)
    return tensor
✅ Verification successful

Let me try a different approach to gather information about tensor handling and MPS compatibility.


Based on the search results, I can now generate the final response:

MPS device compatibility is properly handled

The codebase already has proper MPS (Metal Performance Shaders) device handling:

  • Automatic float32 conversion is enforced when MPS is detected via th.set_default_dtype(th.float32)
  • Device-specific tensor creation is handled in obs_as_tensor() with MPS-specific logic
  • The codebase skips float64 tests when MPS is detected
  • Tensor device placement is consistently managed through .to(device) calls
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for tensor dtype handling in model prediction and environment interaction
echo "Checking tensor dtype handling..."
rg -A 3 "dtype|float32|float64" --type python

# Check for MPS-related error handling
echo "Checking MPS-related error handling..."
ast-grep --pattern 'try {
  $$$
  $TENSOR = $$.to($DEVICE)
  $$$
} catch ($ERROR) {
  $$$
}'

Length of output: 337


Script:

#!/bin/bash
# Search for tensor dtype handling in Python files
echo "Checking tensor dtype handling..."
rg -t py "dtype|float32|float64" -A 3

# Look for device-related code
echo "Checking device handling..."
rg -t py "\.to\(.*device" -A 2

# Search for observation handling in evaluation
echo "Checking observation handling..."
rg -t py "observations.*to\(" -A 2

# Look for MPS-related code
echo "Checking MPS-related code..."
rg -t py "mps" -A 2

Length of output: 75816

stable_baselines3/common/envs/multi_input_envs.py (2)

97-97: LGTM: Type hint update

The update to use lowercase dict type hint aligns with modern Python type hinting practices (PEP 585).


Line range hint 169-183: Verify test determinism with random starts

The reset implementation looks good, but the random start behavior could affect test determinism. Ensure that:

  1. Tests properly handle random starts by setting specific seeds
  2. Test failures mentioned in PR objectives aren't related to non-deterministic resets

Let's check for test files using this environment:

#!/bin/bash
# Search for test files importing SimpleMultiObsEnv
rg -l "SimpleMultiObsEnv" tests/
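Deterministic resets are normally pinned by passing an explicit seed; the pattern looks like this (a toy stand-in environment, not SimpleMultiObsEnv itself):

```python
import numpy as np

class TinyRandomStartEnv:
    # Minimal stand-in for an env with random start positions
    def reset(self, seed=None):
        rng = np.random.default_rng(seed)
        self.pos = int(rng.integers(0, 16))  # random start cell
        return self.pos, {}

env = TinyRandomStartEnv()
a, _ = env.reset(seed=42)
b, _ = env.reset(seed=42)
print(a == b)  # True: same seed, same start, deterministic test
```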
tests/test_gae.py (3)

26-26: LGTM: Type hint modernization

The change from Dict to dict aligns with Python's typing best practices by using built-in types.


56-56: LGTM: Consistent type hint update

The change maintains consistency with the previous type hint modernization.


76-76: Verify compatibility with custom environments

The change from direct attribute access to get_wrapper_attr("max_steps") is more robust for wrapped environments, but could be a breaking change for custom environments that don't properly expose the max_steps attribute through the wrapper chain.

Let's verify the compatibility:

Consider:

  1. Adding a fallback mechanism for direct attribute access
  2. Documenting this change in the changelog as it might affect custom environment implementations
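The suggested fallback could be sketched as a small helper (hypothetical; get_wrapper_attr is the Gymnasium API, and the fallback covers custom envs that predate it):

```python
def wrapper_attr_or_direct(env, name: str):
    # Prefer Gymnasium's get_wrapper_attr, which searches the whole
    # wrapper chain; fall back to plain attribute access for custom
    # envs that do not expose that method
    getter = getattr(env, "get_wrapper_attr", None)
    if callable(getter):
        try:
            return getter(name)
        except AttributeError:
            pass
    return getattr(env, name)

class LegacyEnv:  # no get_wrapper_attr at all
    max_steps = 100

print(wrapper_attr_or_direct(LegacyEnv(), "max_steps"))  # 100
```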
stable_baselines3/common/vec_env/dummy_vec_env.py (7)

3-5: LGTM: Modern import organization

The changes follow Python's modern type hinting practices by using collections.abc instead of typing for Sequence.


30-30: LGTM: Updated type hint syntax

The type hint update follows PEP 585, using the built-in list type instead of typing.List.


50-50: LGTM: Consistent type hint update

The type hint change maintains consistency with the modern typing approach used throughout the file.


127-130: LGTM: Consistent wrapper method access

The change maintains consistency with the wrapper attribute access pattern used in get_attr.


132-142: LGTM: Type hint modernization

The changes consistently apply the modern type hinting approach using built-in collections.


116-119: Verify wrapper attribute access pattern

The change from direct getattr to get_wrapper_attr suggests a more robust way of accessing wrapped environment attributes. Ensure this pattern is consistently used across the codebase.

✅ Verification successful

Change to get_wrapper_attr is consistent with codebase patterns

The verification shows that the change from getattr to get_wrapper_attr is correct and consistent with the codebase's design:

  • get_wrapper_attr is used consistently in vectorized environments (DummyVecEnv and SubprocVecEnv) for attribute access
  • The pattern is used in both attribute retrieval (get_attr) and method calls (env_method)
  • The implementation follows the proper wrapper attribute access pattern established in base_vec_env.py
  • Direct getattr usage is primarily found in utility functions and base classes, not in environment wrappers
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for attribute access patterns in environment wrappers
ast-grep --pattern 'getattr($env, $attr)' 
ast-grep --pattern '$env.get_wrapper_attr($attr)'

Length of output: 95


Script:

#!/bin/bash
# Search for attribute access patterns in environment wrappers using ripgrep
rg "getattr\(.*?\)" --type py -A 2 -B 2
rg "get_wrapper_attr\(.*?\)" --type py -A 2 -B 2

Length of output: 6547


114-114: Verify performance impact of using deepcopy

While using deepcopy ensures proper isolation of observation data (which could be important for MPS tensor handling), it might impact performance. Consider benchmarking this change, especially for environments with large observation spaces.

stable_baselines3/common/env_util.py (2)

12-12: LGTM: Type hint modernization

The change from Type[gym.Wrapper] to type[gym.Wrapper] follows PEP 585 recommendations and is valid on the project's Python >=3.9 baseline.

#!/bin/bash
# Verify type hint compatibility across the codebase
rg "Type\[gym\.Wrapper\]" --type python

137-141: Verify Atari wrapper compatibility with MPS device

While the type hints are updated correctly, we should ensure that the AtariWrapper is compatible with MPS device tensors, especially since the PR mentions float32/float64 compatibility issues.

stable_baselines3/common/vec_env/stacked_observations.py (3)

2-3: LGTM: Good practice using collections.abc

Using collections.abc.Mapping instead of typing.Mapping is more explicit and follows Python's best practices for runtime type checking.


70-70: LGTM: Consistent type hint modernization

The return type annotation update is consistent with the codebase's modernization to use built-in collection types.


10-10: Verify minimum Python version requirement

The use of built-in generics (dict[...] instead of Dict[...]) requires Python 3.9+.

✅ Verification successful

Python version requirement is compatible with built-in generics

The project's minimum Python version requirement is 3.9+ as specified in setup.py, which is compatible with the usage of built-in generics (PEP 585). The change from Dict to dict is safe.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check if the project's minimum Python version is compatible with PEP 585
grep -r "python_requires" .

Length of output: 67

stable_baselines3/common/preprocessing.py (2)

145-145: Type hint update looks good.

The change from Tuple to tuple aligns with modern Python typing conventions.


93-96: Type hint modernization looks good, but verify MPS compatibility.

The update to use built-in dict type is correct and follows modern Python typing conventions. However, given the PR's focus on MPS support, we should ensure tensor operations are MPS-compatible.

Let's verify tensor dtype handling across the codebase:

✅ Verification successful

MPS compatibility is maintained in preprocessing operations

The preprocessing code handles tensors safely for MPS compatibility because:

  • Image normalization uses simple division (obs.float() / 255.0) which is MPS-compatible
  • All tensor operations use basic arithmetic and float() conversion without complex dtype specifications
  • No explicit dtype conversions that could trigger MPS limitations (like float64)
  • One-hot encoding and tensor shape operations use standard PyTorch operations

The type hint modernization from Dict to dict is also correct and consistent with modern Python conventions.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for tensor creation and dtype specifications
rg "torch\.(tensor|float|zeros|ones)" -A 2
rg "\.to\(.*dtype" -A 2

Length of output: 2984


Script:

#!/bin/bash
# Let's check the preprocessing.py file for tensor operations and dtype handling
rg -A 3 "dtype|\.to\(" stable_baselines3/common/preprocessing.py

# Also check for image normalization code which is relevant for MPS
rg -A 5 "normalize_images|is_image_space" stable_baselines3/common/preprocessing.py

Length of output: 3164

stable_baselines3/common/monitor.py (7)

8-8: LGTM: Type hint imports simplified

The change to use built-in types as generics aligns with PEP 585 and modern Python typing practices.


36-37: LGTM: Type hints updated correctly

The use of tuple instead of Tuple is appropriate here as these parameters should be immutable sequences of strings.


64-64: LGTM: Return type hint updated correctly

The return type hint update maintains the correct typing for gym environment reset method.


129-129: LGTM: Return type hints updated appropriately

The type hint updates for episode metrics getter methods are correct and maintain consistency with the class attributes.

Also applies to: 137-137, 145-145


178-179: LGTM: Type hints updated consistently

The type hint updates in ResultsWriter and utility functions maintain consistency with the rest of the file and don't affect the logging functionality.

Also applies to: 203-203, 220-220


55-59: Verify float precision handling for MPS compatibility

Given that MPS doesn't support float64, we should verify that these lists storing float values (especially rewards and episode_returns) don't cause precision issues when running on MPS devices.

Also applies to: 62-62


85-85: Verify reward type casting for MPS compatibility

The step method handles reward values which need to be compatible with MPS limitations. Ensure that:

  1. Reward values are properly cast to float32 when using MPS
  2. No precision loss occurs during reward accumulation
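The reward-precision concern reduces to a single cast at accumulation time, e.g. (illustrative helper; Monitor itself accumulates Python floats):

```python
import numpy as np

def accumulate_reward(rewards: list, reward) -> None:
    # Store rewards as float32 so the episode return never enters
    # the float64 path that MPS cannot handle
    rewards.append(np.float32(reward))

rewards: list = []
for r in (0.1, 0.2, 0.7):
    accumulate_reward(rewards, r)
episode_return = np.sum(rewards, dtype=np.float32)
print(episode_return.dtype)  # float32
```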
stable_baselines3/a2c/a2c.py (3)

1-1: LGTM: Modern type hint usage

The change to use built-in types aligns with PEP 585, which is a good modernization practice.


Line range hint 134-187: Review tensor dtype handling for MPS compatibility

The training loop performs several tensor operations that might need explicit dtype handling for MPS compatibility:

  1. Action conversion for discrete spaces
  2. Value and advantage calculations
  3. Loss computations

Consider adding explicit float32 casting for MPS device compatibility.

#!/bin/bash
# Description: Check tensor dtype handling across the codebase

echo "Checking for explicit dtype handling in tensor operations..."
rg "to\(.*float|dtype=" stable_baselines3/a2c/

echo "Checking for existing MPS-related dtype handling..."
rg -i "mps.*float|float.*mps" stable_baselines3/

Consider adding dtype checks and conversions:

 def train(self) -> None:
+    # Ensure float32 dtype for MPS compatibility
+    if self.device.type == "mps":
+        self.policy = self.policy.to(dtype=th.float32)

60-64: Verify MPS compatibility for policy implementations

While the type hint changes look good, given this PR's focus on MPS support, we should verify that all policy implementations (ActorCriticPolicy, ActorCriticCnnPolicy, MultiInputActorCriticPolicy) properly handle the MPS device.

✅ Verification successful

MPS device support is properly handled through base policy implementation

The verification shows that device handling is implemented at the BasePolicy level through:

  • A device property that infers the correct device from policy parameters
  • Device-agnostic tensor operations using self.device
  • Proper device handling in load/save operations with device mapping
  • Tensor operations properly use the policy's device

Since all three policies (ActorCriticPolicy, ActorCriticCnnPolicy, MultiInputActorCriticPolicy) inherit from BasePolicy, they automatically inherit the device handling capabilities. The implementation is device-agnostic and will work with any PyTorch-supported device, including MPS.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check if policy implementations handle MPS device properly
# Look for device handling in policy implementations

echo "Checking device handling in policy implementations..."
rg -l "def to\(.*device" stable_baselines3/common/policies.py

echo "Checking for any MPS-specific handling..."
rg -i "mps" stable_baselines3/common/policies.py

Length of output: 438


Script:

#!/bin/bash
# Let's check the policy implementations more thoroughly

# First, check the base policy class for device handling
echo "=== Checking BasePolicy device handling ==="
rg -A 5 "def to\(" stable_baselines3/common/policies.py

# Check for any device-related code in policies
echo -e "\n=== Checking device-related code in policies ==="
rg "device|cuda|cpu" stable_baselines3/common/policies.py

# Look for tensor operations that might need device handling
echo -e "\n=== Checking tensor operations ==="
rg "torch\.(tensor|Tensor|FloatTensor)" stable_baselines3/common/policies.py

# Check the actual policy implementations
echo -e "\n=== Checking specific policy implementations ==="
ast-grep --pattern 'class ActorCriticPolicy'
ast-grep --pattern 'class ActorCriticCnnPolicy'
ast-grep --pattern 'class MultiInputActorCriticPolicy'

Length of output: 1841
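The device-inference pattern the verification above describes can be sketched as follows (a simplified stand-in, not SB3's actual BasePolicy):

```python
import torch as th
import torch.nn as nn

class TinyPolicy(nn.Module):
    """Minimal sketch of a device property inferred from parameters,
    mirroring the pattern described above (not SB3's BasePolicy)."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 2)

    @property
    def device(self) -> th.device:
        # Infer the device from the first parameter; fall back to CPU
        # for parameterless modules.
        for param in self.parameters():
            return param.device
        return th.device("cpu")

policy = TinyPolicy()
# Tensors follow the policy's device, so the same code runs on cpu/cuda/mps.
obs = th.zeros(1, 4, device=policy.device)
```

Because subclasses inherit the property, device-agnostic code like this works unchanged when the parameters happen to live on an MPS device.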

stable_baselines3/common/envs/bit_flipping_env.py (5)

2-2: LGTM: Type imports are correctly organized

The addition of Optional and Union type hints is appropriate for the type annotations used throughout the file.


169-169: LGTM: Return type annotation is accurate

The update from Dict to dict aligns with modern Python type hints, and the Union[int, np.ndarray] correctly represents the possible value types.


184-185: LGTM: Type hints follow Gymnasium interface

The type hints for the reset method correctly implement the Gymnasium environment interface, with proper use of Optional and dict types.


80-81: Verify MPS compatibility with int32 conversion

The conversion to int32 for NumPy 2.0 compatibility is good. However, since this PR aims to improve MPS support, we should verify that int32 is supported on the MPS backend.

✅ Verification successful

Based on the search results, I can now generate the final response since I have enough information about MPS and type handling in the codebase:

int32 conversion is compatible with MPS backend

The codebase already handles MPS compatibility well:

  • The obs_as_tensor function in common/utils.py has special handling for MPS devices
  • Type conversions are handled safely before tensor creation
  • The codebase explicitly checks for MPS availability using th.backends.mps.is_built()
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Search for other int32 conversions and MPS-related type handling
# to ensure consistency across the codebase

# Look for similar int32 conversions
rg "astype\(np\.int32\)" -A 2 -B 2

# Look for MPS-related type handling
rg "mps|device" -g "*.py" -A 2 -B 2

Length of output: 61670
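The "type conversions are handled safely before tensor creation" point can be illustrated on the numpy side alone (an illustrative helper, not SB3's obs_as_tensor): downcasting float64 observations to float32 before they reach the tensor backend sidesteps the missing float64 support on MPS, while int32 passes through unchanged.

```python
import numpy as np

def sanitize_obs_dtype(obs: np.ndarray) -> np.ndarray:
    """Illustrative pre-tensor dtype fix-up (not SB3's actual code):
    float64 is downcast to float32 for MPS compatibility; integer
    dtypes such as int32 are left untouched."""
    if obs.dtype == np.float64:
        return obs.astype(np.float32)
    return obs
```

Applying this at the observation boundary means the later `th.as_tensor(...)` call never sees a float64 array on an MPS device.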


213-213: Consider MPS compatibility with float32 return type

While the type hints are accurate, the np.float32 return type should be verified for MPS compatibility, as the PR objectives mention issues with float32/float64 handling on MPS devices.

tests/test_distributions.py (1)

57-57: LGTM! Type hint modernization.

The update from Tuple to tuple aligns with PEP 585 and modern Python type hinting practices.

Let's check if similar type hint updates are needed in other test files:

✅ Verification successful

Type hint modernization is consistent with the codebase

The codebase already follows modern type hint practices:

  • Return type hints in test_env_checker.py and test_distributions.py use lowercase tuple[...]
  • The only uppercase Tuple usages are for gym spaces construction, which is correct as those refer to the actual spaces.Tuple class
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Find other test files that might need similar type hint updates
# Look for usage of typing.Tuple in test files

rg -l "from typing import.*Tuple" "tests/"
rg "-> Tuple\[" "tests/"

Length of output: 472


Script:

#!/bin/bash
# Let's fix the ripgrep command syntax and try again
# Search for Tuple type hints in test files

# Look for typing.Tuple imports
rg "from typing.*import.*Tuple" tests/

# Look for Tuple type hints in return annotations (escaping the arrow properly)
rg "\-\> Tuple\[" tests/

# Also search for other potential Tuple usages in type hints
rg ": Tuple\[" tests/

Length of output: 172


Script:

#!/bin/bash
# Let's try a simpler approach with basic pattern matching
# Search for any mentions of 'Tuple' in test files

# Look for Tuple usage in test files
rg "Tuple" tests/

# Also check for tuple type hints to see the current style
rg "tuple\[" tests/

Length of output: 618

stable_baselines3/common/vec_env/subproc_vec_env.py (7)

3-4: LGTM: Type hint modernization

The change to use collections.abc.Sequence instead of typing.Sequence is a good modernization that enables better runtime type checking.


30-30: LGTM: PEP 585 compliant type hints

The change to use dict instead of Dict aligns with PEP 585 and modern Python type hinting practices.


95-95: LGTM: Consistent type hint update

The parameter type hint update follows the same modernization pattern seen throughout the file.


168-168: LGTM: Consistent method signature updates

The return type hint updates maintain consistency with the modernization effort throughout the codebase.

Also applies to: 183-183, 190-190, 197-197


132-132: Verify tensor stacking behavior on MPS devices

The change from flattening to stacking observations could have implications for tensor operations on MPS devices. We should ensure this works correctly with the MPS backend.

Also applies to: 142-142

✅ Verification successful

Based on the search results, I can now provide a final response:

MPS device support is properly handled with observation stacking

The change from _flatten_obs to _stack_obs is part of a broader refactoring that improves tensor operations across all devices, including MPS. The codebase shows:

  • The _stack_obs function properly handles different observation space types using numpy stacking operations
  • Device handling is abstracted through PyTorch's device system which explicitly supports MPS ("mps") as noted in utils.py
  • No device-specific tensor operations are used in the stacking logic that would cause issues on MPS
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for tensor operations and device handling
rg "stack|flatten|to\(.*device" --type python
# Look for any MPS-related test cases
rg "test.*mps" --type python

Length of output: 146


Script:

#!/bin/bash
# Search for tensor operations and device handling with correct syntax
rg "stack|flatten|to\(.*device"

# Look for any MPS-related test cases
rg "test.*mps"

# Search for device-specific code
rg "device.*mps|mps.*device"

# Search for observation space handling
rg "_stack_obs|_flatten_obs"

Length of output: 36422


209-233: Consider performance implications of stacking operations

The refactored observation stacking logic might have different performance characteristics on MPS devices compared to the previous flattening approach. Key points:

  1. The function now explicitly handles Dict and Tuple spaces with separate stacking logic
  2. Each stacking operation creates a new tensor, which could impact memory usage
  3. The behavior might differ between CPU, CUDA, and MPS devices

Consider benchmarking the performance impact of these changes, especially on MPS devices. You might want to:

  1. Profile memory usage patterns
  2. Compare performance between devices
  3. Consider adding device-specific optimizations if needed
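The per-space stacking logic described above can be sketched in plain numpy (a simplified version under assumed space labels, not the actual `_stack_obs`):

```python
import numpy as np

def stack_obs(obs_list, space_kind):
    """Simplified sketch of per-space stacking (not SB3's _stack_obs).

    obs_list: one observation per sub-environment.
    space_kind: "dict", "tuple", or "box" (assumed labels for illustration).
    """
    if space_kind == "dict":
        # Stack each key separately so every entry gains a leading env axis.
        keys = obs_list[0].keys()
        return {k: np.stack([o[k] for o in obs_list]) for k in keys}
    if space_kind == "tuple":
        n = len(obs_list[0])
        return tuple(np.stack([o[i] for o in obs_list]) for i in range(n))
    return np.stack(obs_list)

batched = stack_obs([{"pos": np.zeros(2)}, {"pos": np.ones(2)}], "dict")
# batched["pos"] now has shape (n_envs, *obs_shape), i.e. (2, 2) here
```

Each `np.stack` allocates a new array, which is the memory-usage point raised above; the stacking itself happens in numpy, before any device placement.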

57-57: Verify the get_wrapper_attr implementation

The change from getattr to get_wrapper_attr improves encapsulation, but we should verify that the wrapper method properly handles all attribute access cases, especially for MPS-specific attributes.

Also applies to: 60-60

✅ Verification successful

The previous results show usage of get_wrapper_attr but not its implementation. Let's search for its definition and related wrapper patterns to understand how it handles attribute access.


The previous results show extensive wrapper usage but we still need to find the implementation of get_wrapper_attr. Let's search in base classes and gym wrapper implementations.


Let's try to find the implementation with more context and related attribute access patterns.


Based on the search results, I can now generate the final response. The code shows that get_wrapper_attr is used consistently across the vectorized environment implementations, with DummyVecEnv using it to access attributes and SubprocVecEnv using it for both method and attribute access. The change from getattr to get_wrapper_attr is part of the standard vectorized environment interface.

The change to get_wrapper_attr is a valid improvement for attribute access

The switch from getattr to get_wrapper_attr is correct and consistent with the codebase's design:

  • It's used in both DummyVecEnv and SubprocVecEnv implementations
  • Properly handles attribute access through environment wrappers
  • Part of the standard vectorized environment interface
  • No MPS-specific attributes are involved, this is a general improvement
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for get_wrapper_attr implementation and usage
rg "get_wrapper_attr" -A 5
# Search for any MPS-related attributes that might be accessed
rg "mps|cuda|device" --type python

Length of output: 3900


Script:

#!/bin/bash
# Search for get_wrapper_attr implementation in base classes and wrappers
ast-grep --pattern 'def get_wrapper_attr($$$)'

# Search for any wrapper-related imports and base classes
rg "class.*Wrapper|from.*import.*Wrapper" -A 2

# Search for any unwrapped attribute access patterns
rg "unwrapped\." -A 2

Length of output: 18114


Script:

#!/bin/bash
# Search for get_wrapper_attr implementation in gym.Wrapper or base classes
ast-grep --pattern 'class Wrapper($$$)'

# Search for get_wrapper_attr in base VecEnv class
rg "get_wrapper_attr" stable_baselines3/common/vec_env/base_vec_env.py -A 5

# Search for get_attr implementation in base classes
ast-grep --pattern 'def get_attr($$$)'

Length of output: 158


Script:

#!/bin/bash
# Search for get_wrapper_attr in the entire codebase with more context
rg -B 5 -A 10 "def get_wrapper_attr"

# Search for get_attr in base classes with more context
rg -B 5 -A 10 "def get_attr.*:"

# Search for any attribute access patterns in the environment
rg "env\." -B 2 -A 2 stable_baselines3/common/vec_env/

Length of output: 45146

stable_baselines3/dqn/policies.py (4)

38-39: Verify type hint compatibility across Python versions

The change from List to list and Type to type aligns with PEP 585, but we should ensure compatibility with all supported Python versions.

#!/bin/bash
# Check minimum Python version in setup.py and pyproject.toml
echo "Checking Python version requirements:"
rg "python_requires|requires-python" .

240-246: Verify CNN operations compatibility with MPS backend

The CnnPolicy uses NatureCNN which involves multiple convolution operations. Ensure all CNN operations are compatible with the MPS backend, particularly regarding tensor types and operations.

#!/bin/bash
# Search for CNN-related tests with MPS devices
rg "test.*cnn.*mps|test.*conv.*mps" tests/
# Check for any known CNN operation issues with MPS
rg -A 5 "conv.*mps|cnn.*mps" .

Line range hint 1-291: Enhance test coverage for MPS support

While the type hint updates improve code quality, additional test coverage is needed for MPS support:

  1. Test model training on MPS devices
  2. Test model inference on MPS devices
  3. Test save/load functionality with MPS devices
  4. Test fallback behavior for unsupported operations
#!/bin/bash
# Check current test coverage for MPS
echo "Searching for MPS-related tests:"
rg -l "test.*mps" tests/

Line range hint 74-84: Verify model serialization compatibility with MPS devices

The constructor parameters are crucial for model serialization. Since this PR adds MPS support, we should ensure models can be properly saved and loaded when using MPS devices.

stable_baselines3/td3/td3.py (4)

Line range hint 68-72: LGTM: Type hint modernization

The update from Dict to dict aligns with PEP 585, which made built-in collections generic. This is a safe change that improves code consistency without affecting runtime behavior.


231-232: LGTM: Return type modernization

The update from List to list in the return type annotation is a safe modernization change.


234-236: LGTM: Return type modernization

The update to use lowercase tuple[list[str], list[str]] is a safe type hint modernization that doesn't affect the model saving functionality.


81-81: LGTM: Constructor parameter type updates

The type hint updates align with modern Python typing conventions. However, given the PR's focus on MPS support and float32/float64 compatibility:

Let's verify tensor dtype handling in the implementation:

Also applies to: 89-89, 92-93, 100-100

✅ Verification successful

Based on the search results, I can now verify the tensor dtype handling in the codebase:

  1. The codebase consistently uses th.float32 as the default dtype for tensors, with explicit handling in common/utils.py where it sets th.set_default_dtype(th.float32).

  2. There's proper handling of float64 to float32 conversion in buffers and proper MPS device compatibility checks:

  • common/buffers.py explicitly casts np.float64 actions to np.float32
  • test_spaces.py has proper MPS checks: pytest.skip("MPS framework doesn't support float64")
  3. Device handling is consistent with proper .to(device) calls for all network components and tensors.

Type hint changes are safe and tensor dtype handling is robust

The constructor parameter type changes are purely cosmetic Python type hints that don't affect runtime behavior. The codebase already has proper float32/float64 handling and MPS device compatibility measures in place.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for tensor creation and dtype handling
rg -A 2 "th\.(tensor|float|zeros|ones)" 
# Search for potential float64 usage
rg -A 2 "float64|torch\.double"
# Search for device-specific tensor operations
rg -A 2 "to\(.*device"

Length of output: 16163

stable_baselines3/common/atari_wrappers.py (2)

67-67: LGTM: Type annotation update

The change from Dict to dict aligns with modern Python type hints and maintains consistency across the codebase.


Line range hint 67-71: Verify NumPy operations compatibility with MPS

Since this PR focuses on MPS support, we should verify that the NumPy operations in the reset loop remain compatible when the environment observations are eventually converted to PyTorch tensors on MPS device.

✅ Verification successful

Based on the search results, I can now provide a final response:

No MPS compatibility concerns in the Atari wrapper's reset loop

The code in the Atari wrapper's reset loop is safe for MPS device compatibility because:

  1. The observations remain as NumPy arrays within the wrapper layer and no tensor operations occur here
  2. The actual tensor conversion and device placement happens later in the policy layer through obs_as_tensor() function, which already handles device compatibility
  3. The preprocessing pipeline (in preprocessing.py) properly handles the conversion from NumPy arrays to PyTorch tensors with appropriate dtype casting
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for tensor device conversion points in the codebase
# Look for places where numpy arrays from env are converted to torch tensors

# Search for numpy to torch conversion patterns
rg -A 2 "torch\.(from_numpy|tensor|FloatTensor)" 

# Search for device-specific handling
rg -A 2 "(to\(.*device)|\.cuda\(|\.to\(.*mps)" 

Length of output: 2303


Script:

#!/bin/bash
# Let's try a different approach to search for device-related patterns
# Search for numpy to torch conversion and device handling with proper regex

# Look for tensor device conversions and numpy operations in the atari wrappers
rg -A 3 "to\(.*device" stable_baselines3/common/atari_wrappers.py
rg -A 3 "\.cuda\(" stable_baselines3/common/atari_wrappers.py
rg -A 3 "\.to\(" stable_baselines3/common/atari_wrappers.py
rg -A 3 "torch\.(from_numpy|tensor|FloatTensor)" stable_baselines3/common/atari_wrappers.py

# Check how observations are handled in the VecEnv wrappers
rg -A 3 "torch\.(from_numpy|tensor|FloatTensor)" stable_baselines3/common/vec_env/
rg -A 3 "(to\(|\.cuda\()" stable_baselines3/common/vec_env/

Length of output: 434


Script:

#!/bin/bash
# Let's check how observations are processed in the environment wrappers and base classes

# Look for observation handling in base environment classes
rg -A 5 "def step" stable_baselines3/common/base_class.py
rg -A 5 "def reset" stable_baselines3/common/base_class.py

# Check observation preprocessing
rg -A 5 "observation" stable_baselines3/common/preprocessing.py
rg -A 5 "obs" stable_baselines3/common/preprocessing.py

# Look for numpy operations in the atari wrapper
rg -A 3 "np\." stable_baselines3/common/atari_wrappers.py

# Check device handling in policy classes
rg -A 3 "device" stable_baselines3/common/policies.py

Length of output: 19219

tests/test_dict_env.py (3)

75-78: LGTM: Type hint improvement in reset method

The change from Dict to dict type hint is more Pythonic and aligns with PEP 585 recommendations.


120-123: LGTM: Removed redundant seed call

Good cleanup - removing the redundant seed call since it's already handled in the reset method.


Line range hint 1-300: Verify MPS compatibility with dictionary observations

Since this PR focuses on MPS support, we should verify that dictionary observations work correctly with MPS device tensors.

stable_baselines3/dqn/dqn.py (2)

65-65: Type hint updates look good!

The changes from typing module types to built-in types (e.g., Dict → dict, List → list, Tuple → tuple) align with modern Python practices and PEP 585.

Also applies to: 78-78, 86-86, 88-89, 98-98, 230-231, 234-234, 276-276, 279-279


Line range hint 1-283: Verify consistent MPS support across related files

To ensure comprehensive MPS support, verify similar changes in related files:

  1. Policy implementations
  2. Network architectures
  3. Buffer implementations
stable_baselines3/td3/policies.py (4)

119-125: Verify optimizer compatibility with MPS backend

The type hint updates are good. However, for MPS support:

  1. Ensure the optimizer (optimizer_class) supports MPS device tensors
  2. Verify that feature extractors handle MPS tensor operations correctly
#!/bin/bash
# Search for any custom optimizers that might need MPS compatibility checks
rg "class.*\(.*optim\.Optimizer\)" --type py

342-348: Verify dict observation space handling on MPS

The type hints are good, but for MPS support:

  1. Ensure CombinedExtractor properly handles dict observation spaces on MPS
  2. Verify that all observation types in the dict support MPS operations
#!/bin/bash
# Check CombinedExtractor implementation for MPS compatibility
ast-grep --pattern 'class CombinedExtractor(BaseFeaturesExtractor):
  $$$
  def forward($$$):
    $$$
'

288-294: Ensure CNN operations are optimized for MPS

While the type hints are correct, the NatureCNN feature extractor needs special attention for MPS:

  1. Verify that all CNN operations support MPS backend
  2. Consider adding MPS-specific optimizations for CNN operations
#!/bin/bash
# Check CNN implementation for potential MPS compatibility issues
ast-grep --pattern 'class NatureCNN(BaseFeaturesExtractor):
  $$$
  def forward($$$):
    $$$
'

39-42: Consider MPS compatibility for network architecture and activation functions

The type hint updates look good, but given this PR's focus on MPS support, ensure that:

  1. The network architecture (net_arch) supports MPS-compatible tensor operations
  2. The activation function (activation_fn) is compatible with MPS backend
✅ Verification successful

No MPS-specific concerns with activation functions

The codebase shows proper MPS support implementation with the following evidence:

  1. MPS device detection is properly handled in common/utils.py with th.backends.mps.is_built()
  2. The default activation function (nn.ReLU) is widely used and fully supported by PyTorch's MPS backend
  3. All tensor operations are properly handled with device-agnostic code

The type hint changes are correct and the activation function implementation is compatible with MPS without requiring any modifications.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check if any unsupported activation functions are used with MPS
ast-grep --pattern 'activation_fn: type[nn.Module] = $activation' | grep -v "ReLU\|Tanh\|Sigmoid"

Length of output: 110


Script:

#!/bin/bash
# Let's verify activation functions more comprehensively
# First, check all activation function usages in the TD3 module
rg "activation_fn" stable_baselines3/td3/ -B 2 -A 2

# Then check if there are any MPS-specific handling or restrictions
rg "mps" stable_baselines3/ -B 2 -A 2

# Finally, check the default activation functions in parent classes
rg "activation_fn.*=.*nn\." stable_baselines3/ -B 2 -A 2

Length of output: 17932
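The detection pattern referenced above (`th.backends.mps.is_built()` in common/utils.py) suggests a selection helper along these lines (a hedged sketch with an assumed cuda > mps > cpu priority, not SB3's actual get_device):

```python
import torch as th

def pick_device(device: str = "auto") -> th.device:
    """Sketch of auto device selection including MPS (assumed priority
    order cuda > mps > cpu; not SB3's actual get_device)."""
    if device == "auto":
        if th.cuda.is_available():
            return th.device("cuda")
        mps = getattr(th.backends, "mps", None)  # absent on older torch builds
        if mps is not None and mps.is_available():
            return th.device("mps")
        return th.device("cpu")
    return th.device(device)
```

An explicit request such as `pick_device("cpu")` bypasses auto-detection entirely, which is useful when reproducing MPS-specific test failures on CPU.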

stable_baselines3/ppo/ppo.py (2)

74-75: LGTM: Type hint modernization

The updates to type hints (changing Dict to dict, Type to type) align with PEP 585 and modern Python practices. These changes improve code consistency without affecting runtime behavior.

Also applies to: 82-82, 98-99, 103-103


Line range hint 201-201: Add MPS compatibility checks for tensor operations

Given that MPS doesn't support float64, we should ensure all tensor operations are using float32. Consider adding dtype checks or conversions at key points:

  1. Advantage normalization
  2. Policy loss calculation
  3. Value prediction clipping
  4. MSE loss calculation

Here's a suggested approach:

def _ensure_float32(self, tensor: th.Tensor) -> th.Tensor:
    """Convert tensor to float32 if on MPS device."""
    if self.device.type == 'mps' and tensor.dtype == th.float64:
        return tensor.to(th.float32)
    return tensor

# Then use it in tensor operations:
advantages = self._ensure_float32(rollout_data.advantages)
if self.normalize_advantage and len(advantages) > 1:
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

Let's verify the tensor dtypes in the codebase:

Also applies to: 211-212, 228-228, 232-232

stable_baselines3/common/torch_layers.py (4)

113-119: LGTM: Type hint modernization

The updates to type hints follow PEP 585, using built-in collection types. This is a safe change that improves code maintainability.


252-252: LGTM: Return type annotation update

The change to use tuple instead of Tuple aligns with modern Python type hinting practices.


214-215: Verify tensor dtype compatibility with MPS

While the type hint changes look good, given that MPS doesn't support float64, we should ensure that tensors created in the policy and value networks are explicitly cast to float32.

Also applies to: 220-221

✅ Verification successful

Let me analyze the output to verify tensor dtype compatibility with MPS.

The search results show several important findings:

  1. The codebase explicitly sets default dtype to float32:
th.set_default_dtype(th.float32)
  2. All tensor creation from numpy arrays in observation preprocessing is explicitly cast to float32:
return {key: th.as_tensor(_obs, dtype=th.float32, device=device) for (key, _obs) in obs.items()}
  3. All numpy arrays in buffers are initialized with float32:
self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=np.float32)
  4. No explicit float64/double dtype specifications were found in tensor creation.

Let me generate the final response:

No tensor dtype compatibility issues found

The codebase already enforces float32 tensors through:

  • Default PyTorch dtype set to float32
  • Explicit float32 casting in observation preprocessing
  • float32 numpy arrays in replay buffers
  • No explicit float64/double tensor creation
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for potential float64 tensor creation in the codebase
# Look for tensor creation patterns that might default to float64

# Search for tensor creation patterns
rg -A 2 "torch\.(tensor|as_tensor|zeros|ones|empty|full|arange|linspace)" 

# Search for dtype specifications
rg "dtype=torch\.(float64|double)"

Length of output: 48851


291-291: Verify image tensor handling for MPS compatibility

While the type hint update is correct, this class processes images and we should ensure proper tensor dtype handling for MPS compatibility.

✅ Verification successful

Let me analyze the search results to verify the image tensor handling for MPS compatibility.

From the search results, I can see:

  1. In CombinedExtractor and NatureCNN, image tensors are properly handled:
# In NatureCNN
n_flatten = self.cnn(th.as_tensor(observation_space.sample()[None]).float()).shape[1]

# In CombinedExtractor
encoded_tensor_list.append(extractor(observations[key]))
  2. The preprocessing function properly handles image observations:
# In preprocessing.py
if is_image_space(observation_space):
    return obs.float() / 255.0
return obs.float()
  3. The codebase consistently uses float32 as the default dtype:
# In utils.py
th.set_default_dtype(th.float32)
  4. No explicit dtype conversions or device-specific handling that could cause MPS compatibility issues were found.

Image tensor handling is MPS compatible

The image tensor handling in CombinedExtractor follows best practices:

  • Uses .float() for proper dtype conversion
  • Relies on PyTorch's default float32 dtype
  • No explicit device-specific operations that could break MPS compatibility
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check image tensor handling in NatureCNN and related classes

# Look for image tensor creation and conversion
rg -A 3 "observation_space\.sample\(\)|observations\[key\]" 

# Check for any explicit dtype conversions in image processing
rg -A 2 "\.(float|double|to\(|type\()"

Length of output: 48851

stable_baselines3/sac/sac.py (2)

80-80: Type hint modernization looks good!

The updates to use lowercase variants (dict, type, tuple) instead of their capitalized counterparts align with PEP 585 and modern Python type hinting practices.

Also applies to: 92-92, 100-100, 103-103, 104-104, 114-114


Line range hint 246-252: Consider batch normalization behavior on MPS

The batch normalization statistics update might need special handling for MPS devices.

Consider adding a comment documenting any specific considerations for batch normalization behavior on MPS devices.
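The running-statistics concern can be illustrated with a numpy sketch (illustrative only; SAC's batch-norm statistics actually live in PyTorch buffers): keeping the running mean and variance in float32 throughout avoids silently promoting the update to float64, which an MPS tensor could not hold.

```python
import numpy as np

def update_running_stats(mean, var, batch, momentum=0.1):
    """Exponential-moving-average update kept in float32 throughout
    (illustrative sketch, not SAC's batch-norm code)."""
    batch = np.asarray(batch, dtype=np.float32)
    batch_mean = batch.mean(dtype=np.float32)
    batch_var = batch.var(dtype=np.float32)
    # Wrap in float32 so Python-float intermediates cannot widen the result.
    new_mean = np.float32((1 - momentum) * mean + momentum * batch_mean)
    new_var = np.float32((1 - momentum) * var + momentum * batch_var)
    return new_mean, new_var

m, v = update_running_stats(np.float32(0.0), np.float32(1.0), [1.0, 2.0, 3.0])
```

The same discipline applies to any statistics accumulated outside the model (e.g., VecNormalize), where Python-float arithmetic defaults to float64.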

tests/test_vec_normalize.py (2)

25-25: LGTM! Type annotations updated to use built-in generics.

The changes consistently update type annotations across all test environment classes to use built-in generic types (dict) instead of the typing module equivalents (Dict). This is a good modernization that aligns with Python's type system evolution since Python 3.9+.

Also applies to: 42-42, 65-65, 97-97


Line range hint 57-59: Consider adding test coverage for float64 handling.

Given that the PR objectives mention skipping float64 action space tests due to MPS limitations, it would be valuable to add test cases that verify this behavior. Consider:

  1. Adding a test environment with float64 observation/action spaces
  2. Adding test cases that verify proper error handling when float64 tensors are used with MPS

Let's check if there are any existing float64 test cases:

Also applies to: 89-93
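A skip predicate for such tests might look like this (hypothetical helper name; the existing tests use a direct pytest.skip with an MPS check, as quoted from test_spaces.py earlier):

```python
def should_skip_on_mps(device_type: str, dtype_name: str) -> bool:
    """Hypothetical predicate: skip float64 action/observation-space
    tests when the target device is MPS, which lacks float64."""
    return device_type == "mps" and dtype_name == "float64"

# Intended pytest usage (sketch):
#
#   if should_skip_on_mps(model.device.type, str(env.action_space.dtype)):
#       pytest.skip("MPS framework doesn't support float64")
```

Centralizing the check in one helper keeps the skip condition consistent across the float64 tests this PR disables.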

stable_baselines3/common/vec_env/base_vec_env.py (4)

69-73: LGTM! Type hints updated consistently

The instance variable type hints are correctly updated to use built-in types while maintaining the same initialization logic and data structures.


Line range hint 383-423: LGTM! Wrapper methods updated consistently

The method signatures in VecEnvWrapper are correctly updated to use built-in types while maintaining the proper delegation to the wrapped environment.


Line range hint 1-477: Verify MPS device handling requirements

Given that this PR aims to add MPS support and handle float32/float64 tensor compatibility, we should verify if this base class needs additional methods or attributes to support MPS device selection and tensor type handling.

#!/bin/bash
# Check for MPS-related code in the codebase
echo "Checking for MPS-related code:"
rg "mps|device|float32|float64" --type python

# Check for tensor operations that might need MPS support
echo "Checking for tensor operations:"
ast-grep --pattern 'np.array($$$)' 

18-21: Verify backward compatibility of type alias changes

The update from typing types to built-in types (e.g., Dict → dict) aligns with PEP 585, but we should ensure compatibility with older Python versions if they're still supported.

✅ Verification successful

Type hint changes are safe and compatible with supported Python versions

The changes to built-in types (e.g., dict instead of Dict) are fully compatible with the codebase's minimum Python requirement of 3.9+ as specified in both setup.py and pyproject.toml. Python 3.9 introduced PEP 585 which made built-in types available for type annotations.

  • The codebase already uses modern type hints in other files, with only one file still using legacy typing imports
  • The minimum Python version (3.9) is explicitly set in both setup.py (python_requires=">=3.9") and pyproject.toml (target-version = "py39")
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check Python version support in setup.py and pyproject.toml
echo "Checking Python version requirements:"
fd -e toml -e py -x grep -l "python" {} \; | xargs cat

# Check if any code still uses old typing types
echo "Checking for mixed usage of typing types:"
rg "from typing import .*Dict|from typing import .*List|from typing import .*Tuple" 

Length of output: 7892


Script:

#!/bin/bash
# Check Python version support in setup.py and pyproject.toml
echo "=== Python Version Requirements ==="
find . -maxdepth 1 -type f -name "setup.py" -o -name "pyproject.toml" | xargs cat

# Check for typing imports across the codebase
echo -e "\n=== Typing Import Usage ==="
rg --type py "from typing import (.*Dict.*|.*List.*|.*Tuple.*)" || echo "No matches found"

# Check minimum Python version in the codebase
echo -e "\n=== Python Version References ==="
rg --type py "python_requires|Python \d|requires python" || echo "No version references found"

Length of output: 7770
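As a reminder of what PEP 585 enables on Python 3.9+, built-in containers are directly subscriptable in annotations:

```python
def tally(scores: dict[str, int]) -> list[int]:
    # PEP 585: built-in dict/list work as generic types on Python 3.9+,
    # no `from typing import Dict, List` needed.
    return sorted(scores.values())

assert tally({"b": 2, "a": 1}) == [1, 2]
```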

stable_baselines3/her/her_replay_buffer.py (2)

101-101: LGTM! Type hint modernization is correct.

The update to use dict[str, Any] instead of Dict[str, Any] aligns with modern Python type hinting standards while maintaining the same functionality.


112-112: LGTM! Type hint modernization is correct.

The update to use dict[str, Any] instead of Dict[str, Any] aligns with modern Python type hinting standards while maintaining the same functionality.

tests/test_logger.py (1)

5-5: LGTM: Modern type hint usage

The change from typing.Sequence to collections.abc.Sequence follows Python's type hint best practices. Since Python 3.9, collections.abc is the recommended source for container types.

stable_baselines3/sac/policies.py (1)

403-413: Ensure feature extractors handle tensor dtypes correctly

Both CnnPolicy and MultiInputPolicy use different feature extractors (NatureCNN and CombinedExtractor) which need to handle tensor dtypes consistently for MPS compatibility.

Let's verify the feature extractors' tensor dtype handling:

Also applies to: 469-479
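One way to keep inputs uniform across extractors (a sketch; `normalize_obs_dtype` is a hypothetical helper, not SB3 API) is to downcast float64 before tensor conversion:

```python
import numpy as np

def normalize_obs_dtype(obs):
    # Downcast float64 to float32 so NatureCNN/CombinedExtractor-style modules
    # see a uniform dtype regardless of the source observation space.
    if isinstance(obs, dict):
        return {k: normalize_obs_dtype(v) for k, v in obs.items()}
    if isinstance(obs, np.ndarray) and obs.dtype == np.float64:
        return obs.astype(np.float32)
    return obs

mixed = {"img": np.zeros((2, 2), dtype=np.uint8), "vec": np.ones(3, dtype=np.float64)}
out = normalize_obs_dtype(mixed)
```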

stable_baselines3/common/save_util.py (5)

379-379: LGTM: Type hint consistency maintained

The type hint update maintains consistency with other changes in the file.


296-298: LGTM: Verify PyTorch state dict compatibility with MPS

The type hint updates are correct. Since this function handles PyTorch state dictionaries, ensure it's tested with models using the MPS device.

#!/bin/bash
# Search for PyTorch MPS device handling
rg "to\(.*mps.*\)" --type py

383-383: Verify tensor device handling for MPS

While the type hint update is correct, ensure that tensor device conversion is properly handled when loading models on MPS devices.

#!/bin/bash
# Search for tensor device conversion patterns
ast-grep --pattern 'th.load($_, map_location=$_)'

131-131: LGTM: Consistent type hint update

The type hint modernization is consistent with the codebase changes. Since this function is critical for model loading, ensure it's tested with MPS device tensors.

#!/bin/bash
# Search for MPS-related tests
rg -l "mps|Metal Performance Shaders" tests/

76-76: LGTM: Type hint modernization

The change from Dict to dict aligns with PEP 585. Ensure your minimum Python version requirement is >=3.9 where this syntax is supported.

✅ Verification successful

Type hint modernization is compatible with project requirements

The change from Dict to dict is safe as the project explicitly requires Python >=3.9 (confirmed in setup.py's python_requires=">=3.9"), which fully supports this PEP 585 syntax. The project's pyproject.toml also confirms targeting Python 3.9 with target-version = "py39".

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check Python version requirement in setup.py or pyproject.toml
fd -e toml -e cfg -e txt "setup|requirements|pyproject" -x cat {} \; | grep -i "python"

Length of output: 249


Script:

#!/bin/bash
# Check Python version requirements in setup.py
fd "setup.py" -x cat {}

# Also check pyproject.toml if it exists
fd "pyproject.toml" -x cat {}

# Check any requirements files
fd "requirements" -x cat {}

Length of output: 7237

tests/test_vec_envs.py (4)

7-7: LGTM: Import addition is appropriate.

The addition of Optional from typing module is necessary for the type hints being introduced.


33-35: LGTM: Type hints follow Gymnasium API conventions.

The type annotations for current_options and the reset method signature follow the modern Gymnasium API conventions, improving type safety and code clarity.


196-196: LGTM: Consistent type hint implementation.

The reset method signature update maintains consistency with the CustomGymEnv implementation and Gymnasium API.


310-310: LGTM: Simplified dict assertion is appropriate.

The change from collections.OrderedDict to dict is appropriate since Python 3.7+ preserves dictionary order by default. This simplification maintains the same functionality while making the code cleaner.
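The ordering guarantee the comment relies on is easy to demonstrate:

```python
d = {}
d["second"] = 2
d["first"] = 1
# Plain dict preserves insertion order on Python 3.7+,
# so OrderedDict is no longer needed for this purpose.
assert list(d) == ["second", "first"]
```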

stable_baselines3/common/env_checker.py (1)

175-178: LGTM! Type hints improvement

The updated type hints for dictionary parameters provide better type safety and code clarity. This change aligns with modern Python practices and helps with static type checking.

tests/test_utils.py (3)

4-4: LGTM: Proper initialization of Atari environments

The addition of ale_py import and registration ensures that Atari environments are properly initialized for testing.

Also applies to: 28-28


448-451: LGTM: Improved system info checks

The changes enhance system information reporting:

  1. Using "Accelerator" instead of "GPU Enabled" better supports various acceleration types (including MPS)
  2. Adding Cloudpickle version check helps with environment debugging

183-183: Verify policy evaluation with different float types on MPS

While removing type hints makes the code more flexible, we should ensure that policy evaluation works correctly with both float32 and float64 types on MPS devices.
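A dtype-agnostic evaluation path can be exercised without any device at hand by casting before prediction (a sketch; `predict_stub` is a hypothetical stand-in for `model.predict`):

```python
import numpy as np

def predict_stub(obs: np.ndarray) -> np.ndarray:
    # MPS-style backends accept float32 only; enforce that contract here.
    assert obs.dtype == np.float32
    return np.zeros(1, dtype=np.float32)

# Both source dtypes should evaluate cleanly after a float32 cast.
for dtype in (np.float32, np.float64):
    obs = np.ones(4, dtype=dtype)
    action = predict_stub(obs.astype(np.float32, copy=False))
    assert action.dtype == np.float32
```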

stable_baselines3/common/logger.py (2)

8-8: LGTM! Good practice using collections.abc

The change from collections to collections.abc for importing Mapping and Sequence follows Python's recommended practices and helps avoid deprecation warnings.


118-118: LGTM! Consistent modernization of type hints

The changes systematically update type hints to use built-in generics (PEP 585), making the code more maintainable and future-proof while maintaining type safety.

Also applies to: 140-140, 176-176, 264-264, 290-290, 337-337, 403-403, 485-488, 494-494, 504-504, 517-517, 628-628, 639-639

stable_baselines3/common/off_policy_algorithm.py (4)

7-7: LGTM! Import optimization

The import statement has been simplified to include only the essential typing imports, improving code clarity.


82-82: LGTM! Type hint modernization

The type hints have been updated to use built-in generic types (e.g., dict, list, type) instead of their typing module counterparts (e.g., Dict, List, Type). This change:

  • Aligns with PEP 585
  • Improves code readability
  • Prepares the codebase for future Python versions

Also applies to: 90-90, 93-94, 96-96, 108-108


259-259: LGTM! Method type hint improvements

The return and parameter type hints in methods have been updated to use modern type annotations. These changes:

  • Maintain consistency with the overall type hint modernization
  • Improve static type checking capabilities
  • Enhance IDE support for better code assistance

Also applies to: 365-365, 445-448


Line range hint 1-613: Verify MPS device compatibility

While the type hint updates improve code quality, this core algorithm file might need additional changes to properly support MPS devices, particularly around tensor operations and device handling.

✅ Verification successful

MPS device support is compatible with current implementation

The code in off_policy_algorithm.py is already device-agnostic and compatible with MPS:

  • Device handling is done through PyTorch's standard .to(device) mechanism
  • No direct CUDA-specific code or device type checks
  • No hard-coded tensor operations that would conflict with MPS
  • The device fallback logic in the docstring is generic enough to work with MPS

The implementation properly delegates device handling to PyTorch's device management system, which includes support for MPS devices.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for tensor creation and device handling patterns
rg -l "to\(.*device" | while read -r file; do
  echo "=== $file ==="
  rg "to\(.*device" "$file"
done

# Search for float type specifications that might need adjustment for MPS
rg "dtype=th\.float64|\.double\(\)" 

Length of output: 2363


Script:

#!/bin/bash
# Search for device handling in the base algorithm class
ast-grep --pattern 'device = $_' stable_baselines3/common/off_policy_algorithm.py

# Search for tensor operations that might need MPS compatibility
rg "th\.(tensor|FloatTensor|from_numpy)" stable_baselines3/common/off_policy_algorithm.py

# Check if there are any direct device type checks
rg "device\.(type|startswith)" stable_baselines3/common/off_policy_algorithm.py

# Look for any CUDA specific code
rg "(cuda|cpu)" stable_baselines3/common/off_policy_algorithm.py

Length of output: 408

stable_baselines3/common/callbacks.py (4)

4-4: LGTM: Import statement update aligns with type hint modernization.

The change to only import specific typing constructs is appropriate, as the built-in dict and list types are now used for type hints.


48-49: LGTM: Consistent type hint modernization across callback implementations.

The changes systematically update type hints from typing.Dict and typing.List to built-in dict and list types across all callback implementations. This modernization:

  • Follows PEP 585 recommendations for using built-in collection types as generic types
  • Maintains consistent typing across the codebase
  • Improves code readability

Also applies to: 78-78, 128-128, 137-137, 331-331, 430-430, 533-533


408-413: LGTM: Type annotations in EvalCallback enhance code clarity.

The type annotations for evaluation metrics storage are well-defined and appropriate:

  • evaluations_results: List of episode rewards
  • evaluations_timesteps: List of timesteps
  • evaluations_length: List of episode lengths
  • evaluations_successes: List of success flags

Line range hint 1-700: Verify callback behavior with MPS device.

Given that this PR aims to add MPS support, we should verify that these callbacks handle MPS-specific scenarios correctly, particularly in EvalCallback where model evaluation occurs.

✅ Verification successful

Let me check the device handling in the base classes and callbacks to make a proper assessment.


Let me check how device handling is implemented in evaluation-related code since that's most relevant for callbacks.


Based on the gathered information, I can now provide a final response:

No device-specific changes needed for callbacks

The callbacks, particularly EvalCallback, don't require MPS-specific handling because:

  1. Device management is already properly abstracted in the base class (BaseAlgorithm) which handles device selection and model placement
  2. Callbacks operate at a higher level, dealing with metrics and model management rather than direct tensor operations
  3. The evaluation process inherits device settings from the model instance, as seen in the base class implementation

The existing device-agnostic design ensures that callbacks will work correctly with any device, including MPS, as they:

  • Don't perform direct tensor operations
  • Use the model's device settings during evaluation
  • Handle device-independent operations like saving checkpoints and logging metrics
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check for any MPS-specific handling in the codebase
rg -l "mps|device" --type py

# Look for device-related test cases
rg -l "test.*device" --type py

Length of output: 1132


Script:

#!/bin/bash
# Check device handling in base class and callbacks
rg -A 5 "device" stable_baselines3/common/base_class.py stable_baselines3/common/callbacks.py

# Check if there are any device-specific tests for callbacks
rg -A 5 "device.*test.*callback" tests/

Length of output: 10453


Script:

#!/bin/bash
# Check device handling in evaluation code
rg -A 5 "device" stable_baselines3/common/evaluation.py

# Check if there are any device-specific test cases for evaluation
rg -A 5 "test.*eval.*device" tests/

Length of output: 93

stable_baselines3/common/distributions.py (3)

4-4: LGTM: Type hint modernization

The updates to use built-in types (tuple, dict, list) instead of typing module equivalents (Tuple, Dict, List) align with modern Python type hinting practices. These changes improve compatibility with newer Python versions while maintaining the same type safety.

Also applies to: 33-33, 101-101, 138-138, 193-193, 257-257, 308-308, 321-321, 365-365, 415-415, 516-516, 614-614, 664-664


Line range hint 1-700: Verify complete MPS compatibility

Given that MPS support is a key objective of this PR, we should ensure all tensor operations throughout the distributions are compatible.

#!/bin/bash
# Description: Comprehensive check for MPS compatibility issues

# Check for any remaining float64 references
rg "float64|double|Float64|Double"

# Look for numpy operations that might introduce float64
rg "np\.(array|zeros|ones)"

# Check for tensor operations that might need dtype specification
ast-grep --pattern 'torch.$$$($$$)' 

Line range hint 516-614: Verify tensor dtype compatibility with MPS backend

The StateDependentNoiseDistribution class performs several tensor operations that might be affected by MPS backend limitations. We should ensure all tensor operations are using float32 to maintain compatibility with MPS.

✅ Verification successful

Based on the search results, I can see that there are no explicit float64 tensor operations in the StateDependentNoiseDistribution class. The tensor operations in the code use PyTorch's default dtype, which is float32 unless explicitly changed. The only float64 usages found in the codebase are in test files or numpy arrays in other components.

Let me verify the tensor creation operations in the specific file to ensure there are no implicit float64 conversions:


Let me try a different approach to check the tensor operations in the distributions file:


Based on the search results, I can now provide a definitive response about the tensor operations in the StateDependentNoiseDistribution class:

MPS compatibility verified - no explicit float64 operations found

All tensor operations in the StateDependentNoiseDistribution class use PyTorch's default dtype (float32) with no explicit float64 conversions:

  • Tensor creation uses th.ones() without dtype specification
  • Normal distribution operations use default dtype
  • Matrix operations (th.mm, th.bmm) and math operations maintain input dtype
  • No to(torch.float64) or similar conversions found

The code is compatible with MPS backend as it doesn't force float64 operations anywhere in the implementation.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for potential float64 usage in tensor operations

# Search for tensor creation and conversion operations
rg -A 2 "torch\.(tensor|ones|zeros|from_numpy|as_tensor)" 

# Search for specific float dtype specifications
rg "dtype=(torch\.)?(float|double|float32|float64)"

Length of output: 29842


Script:

#!/bin/bash
# Check tensor creation and dtype specifications in the distributions file
rg "torch\.(tensor|ones|zeros|from_numpy|as_tensor)" stable_baselines3/common/distributions.py

# Check for any float dtype specifications in the distributions file
rg "dtype=" stable_baselines3/common/distributions.py

Length of output: 150


Script:

#!/bin/bash
# Check all tensor-related operations in the StateDependentNoiseDistribution class
ast-grep --pattern 'th.ones($$$)' stable_baselines3/common/distributions.py

# Also check for any Normal distribution initialization
ast-grep --pattern 'Normal($$$)' stable_baselines3/common/distributions.py

# Check for any tensor operations in the file
rg -A 2 "th\." stable_baselines3/common/distributions.py

Length of output: 12194

stable_baselines3/common/base_class.py (3)

9-10: LGTM: Modern type hint import

The change from typing.Iterable to collections.abc.Iterable follows Python's type hint modernization guidelines. This is a good practice as it prepares the codebase for future Python versions where these types will be deprecated from the typing module.


98-98: LGTM: Consistent type hint modernization

The type hint changes follow PEP 585 guidelines by using built-in collection types (e.g., list, dict) instead of their typing module equivalents. This modernization improves code readability and maintainability while maintaining type safety.

Also applies to: 111-111, 145-145, 148-148, 439-439, 539-543


644-648: Verify MPS device handling in load method

While the type hints have been updated, we should verify that the method properly handles MPS devices when loading models, especially given the PR's objective of supporting PyTorch GPU compatibility on Apple Silicon.

Also applies to: 804-804

✅ Verification successful

Based on the gathered information, I can now provide the final response:

MPS device support is properly implemented

The codebase already has comprehensive MPS device support:

  • get_device() function explicitly supports MPS devices as per the docstring
  • MPS device detection is implemented in get_available_accelerator() function
  • The device handling is uniform across the codebase through the get_device() function
  • The load() method uses this device handling infrastructure through get_device() calls
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check if there are any explicit MPS device handling in the codebase

# Search for MPS-related device handling
rg -i "mps|metal" --type python

# Search for device-related conditionals that might need MPS support
ast-grep --pattern 'if $device.type == $_: $$$'

Length of output: 114


Script:

#!/bin/bash
# Let's try different approaches to find device handling

# Search for MPS-related device handling without file type restriction
rg -i "mps|metal"

# Search for device handling in Python files
rg "device.*=|device.*type" 

# Look for the load method implementation to check device handling
ast-grep --pattern 'def load(cls, path: $_, env: $_, device: $_, custom_objects: $_)'

Length of output: 10948


Script:

#!/bin/bash
# Let's check the actual implementation of the load method and device handling

# Look for the load method implementation in base_class.py
rg -A 20 "def load\(cls.*device.*custom_objects" stable_baselines3/common/base_class.py

# Check device handling in get_device function
rg -A 15 "def get_device" stable_baselines3/common/utils.py

Length of output: 722

stable_baselines3/common/policies.py (2)

8-8: Type hint improvements look good!

The changes to use built-in types (dict, list, tuple) instead of typing module types (Dict, List, Tuple) align with modern Python type hinting practices and improve code readability.

Also applies to: 67-68, 71-72, 213-213, 236-236, 333-334, 337-337, 453-454, 461-462, 465-466, 537-537, 636-636, 662-662, 719-719, 803-804, 811-812, 815-816, 876-877, 884-885, 888-889, 945-945, 948-948, 964-964, 971-971


964-970: Verify q_networks initialization with MPS device

The q_networks initialization looks correct, but we should verify its compatibility with MPS devices since this PR aims to add MPS support.

✅ Verification successful

Let me gather more specific information about device handling in the context of q_networks.


Let me check one more aspect to verify device handling in the ContinuousCritic class and its parent classes.


Let me try one final check to specifically look at device handling in the policy classes.


Based on the gathered information, I can now provide the final response:

Device handling for q_networks is properly implemented through inheritance

The ContinuousCritic class inherits from BaseModel, which handles device movement through PyTorch's standard .to(device) mechanism. The q_networks initialization is compatible with any device (including MPS) as it relies on PyTorch's native device management system. No specific MPS-related modifications are needed for the q_networks implementation.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for any MPS-related device handling in the codebase
rg -l "mps|cuda|cpu" | grep -v "test"

# Search for q_networks usage with device specifications
ast-grep --pattern 'q_networks = $_'

Length of output: 932


Script:

#!/bin/bash
# Check how devices are handled in policies.py and related files
rg "to\(.*device" stable_baselines3/common/policies.py -B 2 -A 2

# Look for device handling in the ContinuousCritic class
ast-grep --pattern 'class ContinuousCritic {
  $$$
  to($device) {
    $$$
  }
  $$$
}'

# Check q_networks usage in relation to device handling
rg "q_networks.*device|device.*q_networks" -B 2 -A 2

Length of output: 342


Script:

#!/bin/bash
# Check the full ContinuousCritic class implementation
ast-grep --pattern 'class ContinuousCritic {
  $$$
}'

# Look for device handling in base classes
rg "class.*BaseModel|class.*BaseFeaturesExtractor|class.*BasePolicy" -A 10 stable_baselines3/common/policies.py

# Check if there's any specific MPS handling code
rg "mps" stable_baselines3/common/policies.py -B 2 -A 2

Length of output: 15429


Script:

#!/bin/bash
# Check device handling in policy classes
rg "def to\(|\.to\(" stable_baselines3/common/policies.py -B 2 -A 2

# Look for base policy class device handling
ast-grep --pattern 'class BasePolicy {
  $$$
  to($device) {
    $$$
  }
  $$$
}'

# Check for any device-related methods in ContinuousCritic
rg "ContinuousCritic" stable_baselines3/common/policies.py -A 20

Length of output: 1362

self.close_video_recorder()
self._capture_frame()
if len(self.recorded_frames) > self.video_length:
print(f"Saving video to {self.video_path}")

🛠️ Refactor suggestion

Use Logger Instead of 'print' Statements

Using print statements in library code is discouraged as it can clutter the standard output. It's better to use the logger for such messages.

Replace the print statement with logger.info:

-import os
+import os
+from gymnasium import error, logger
...
-print(f"Saving video to {self.video_path}")
+logger.info(f"Saving video to {self.video_path}")

Comment on lines +4 to +5
# Assume Python 3.9
target-version = "py39"

💡 Codebase verification

Python version requirements need to be updated in documentation

The Python version requirements are inconsistent across the project:

  • pyproject.toml and setup.py correctly specify Python 3.9+
  • Installation guide correctly states Python 3.9+
  • Migration guide still mentions Python 3.7+ (outdated)
  • Changelog has old entries but no mention of the Python 3.9 requirement change
  • Conda environment uses Python 3.11, which is compatible

Please update:

  • Migration guide to reflect the Python 3.9+ requirement
  • Add an entry in the changelog about the minimum Python version change to 3.9
🔗 Analysis chain

Verify Python version compatibility across the project

The change to Python 3.9 as the target version needs to be consistently reflected across the project configuration files and documentation.

Consider adding a note in the changelog about this version requirement change, as it could be a breaking change for users still on Python 3.8.


Let's check the content of the environment file and documentation files to ensure version consistency.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check Python version requirements across project files

# Check setup.py/setup.cfg for Python version
echo "Checking setup files for Python version:"
rg -A 2 "python_requires|requires-python"

# Check environment files
echo -e "\nChecking environment files:"
fd -e yml -e yaml | xargs rg -l "python"

# Check documentation
echo -e "\nChecking documentation:"
fd -e md -e rst | xargs rg -l "python.*3\.[0-9]"

Length of output: 558


Script:

#!/bin/bash
# Check Python version in conda environment file
echo "Python version in conda_env.yml:"
cat docs/conda_env.yml | grep -A 2 "python"

# Check Python version mentions in documentation
echo -e "\nPython version in changelog:"
rg "python.*3\.[0-9]" docs/misc/changelog.rst

echo -e "\nPython version in installation guide:"
rg "python.*3\.[0-9]" docs/guide/install.rst

echo -e "\nPython version in migration guide:"
rg "python.*3\.[0-9]" docs/guide/migration.rst

Length of output: 1026

@@ -43,7 +43,7 @@ def _patch_env(env: Union["gym.Env", gymnasium.Env]) -> gymnasium.Env:  # pragma
             "Missing shimmy installation. You provided an OpenAI Gym environment. "
             "Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. "
             "In order to use OpenAI Gym environments with SB3, you need to "
-            "install shimmy (`pip install 'shimmy>=0.2.1'`)."
+            "install shimmy (`pip install 'shimmy>=2.0'`)."

⚠️ Potential issue

Fix inconsistent shimmy version requirements

There's an inconsistency in the required shimmy versions between the two functions:

  • _patch_env requires shimmy>=2.0
  • _convert_space requires shimmy>=0.2.1

This could lead to compatibility issues if users install an older version of shimmy based on the second error message.

Apply this diff to maintain consistency:

             "Missing shimmy installation. You provided an OpenAI Gym space. "
             "Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. "
             "In order to use OpenAI Gym space with SB3, you need to "
-            "install shimmy (`pip install 'shimmy>=0.2.1'`)."
+            "install shimmy (`pip install 'shimmy>=2.0'`)."

Also applies to: 89-92

@@ -7,7 +7,7 @@ Installation
 Prerequisites
 -------------
 
-Stable-Baselines3 requires python 3.8+ and PyTorch >= 1.13
+Stable-Baselines3 requires python 3.9+ and PyTorch >= 2.3
 
 Windows
 ~~~~~~~

⚠️ Potential issue

Fix Python version inconsistency in Windows section

The Windows section still references Python 3.8, which contradicts the new 3.9+ requirement.

Apply this change:

-You need an environment with Python version 3.8 or above.
+You need an environment with Python version 3.9 or above.

Committable suggestion skipped: line range outside the PR's diff.

Comment on lines +155 to +156
    if hasattr(th, "backends") and th.backends.mps.is_built():
        pytest.skip("MPS framework doesn't support float64")

🛠️ Refactor suggestion

Improve MPS availability check and skip message

The current implementation has several areas for improvement:

  1. Use is_available() instead of is_built() to accurately check MPS availability
  2. Make the skip message more informative about the specific limitation
  3. Simplify the condition structure

Apply this diff to improve the implementation:

-    if hasattr(th, "backends") and th.backends.mps.is_built():
-        pytest.skip("MPS framework doesn't support float64")
+    if hasattr(th.backends, "mps") and th.backends.mps.is_available():
+        pytest.skip("Skipping float64 tests: MPS backend does not support float64 dtype operations")

This change:

  • Uses the recommended is_available() check
  • Provides a clearer skip message explaining why the test is skipped
  • Simplifies the condition by checking hasattr(th.backends, "mps") directly
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-    if hasattr(th, "backends") and th.backends.mps.is_built():
-        pytest.skip("MPS framework doesn't support float64")
+    if hasattr(th.backends, "mps") and th.backends.mps.is_available():
+        pytest.skip("Skipping float64 tests: MPS backend does not support float64 dtype operations")

@@ -73,7 +73,7 @@ def __init__(
         self.init_possible_transitions()
 
         self.num_col = num_col
-        self.state_mapping: List[Dict[str, np.ndarray]] = []
+        self.state_mapping: list[dict[str, np.ndarray]] = []

💡 Codebase verification

⚠️ Potential issue

Consider using float32 for vector observations for MPS compatibility

The observation space indeed uses float64 for the 'vec' component:

"vec": spaces.Box(0, 1, (self.vector_size,), dtype=np.float64)

This could cause compatibility issues with MPS (Metal Performance Shaders) which doesn't support float64. Consider updating the vector observation space to use float32 for better compatibility with Apple Silicon GPUs.
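The fix amounts to declaring the space with `dtype=np.float32`, so sampled observations come out MPS-safe; a numpy-only illustration of the resulting arrays (gymnasium's `spaces.Box` enforces the declared dtype on samples):

```python
import numpy as np

# Stand-in for what spaces.Box(0, 1, (vector_size,), dtype=np.float32) samples.
vector_size = 5
rng = np.random.default_rng(0)
obs = rng.uniform(0.0, 1.0, size=vector_size).astype(np.float32)

assert obs.dtype == np.float32          # MPS-safe dtype
assert obs.nbytes == vector_size * 4    # half the memory of float64
```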

🔗 Analysis chain

Consider using float32 for vector observations for MPS compatibility

While the type hint update is good, the observation space uses float64 dtype which isn't supported by MPS. Consider updating the vector observation space to use float32 for better compatibility with Apple Silicon GPUs.

Let's verify the observation space dtype usage:

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for other observation space definitions using float64
rg -B 2 -A 2 "spaces\.Box.*float64" 

Length of output: 171


Script:

#!/bin/bash
# Search for observation space definitions in the file
rg -B 2 -A 2 "spaces\.Box" stable_baselines3/common/envs/multi_input_envs.py

# Also check the state_mapping initialization and usage
rg -B 2 -A 2 "state_mapping" stable_baselines3/common/envs/multi_input_envs.py

Length of output: 1803

Comment on lines +116 to +122
         net_arch: Optional[list[int]] = None,
         activation_fn: type[nn.Module] = nn.ReLU,
         features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
         features_extractor_kwargs: Optional[dict[str, Any]] = None,
         normalize_images: bool = True,
-        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
-        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
+        optimizer_kwargs: Optional[dict[str, Any]] = None,

⚠️ Potential issue

Ensure tensor type compatibility with MPS devices

Since MPS doesn't support float64, we should ensure that:

  1. The network architecture initialization uses float32
  2. The optimizer configuration is compatible with MPS devices

Consider adding explicit tensor type checks:

def _build(self, lr_schedule: Schedule) -> None:
    if self.device.type == "mps":
        # Ensure float32 is used for MPS compatibility
        self.q_net.to(dtype=th.float32)
        self.q_net_target.to(dtype=th.float32)
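As a standalone sketch of what that cast does (`q_net` here is a stand-in `nn.Linear`, not the PR's actual module), `Module.to(dtype=...)` converts all parameters in place:

```python
import torch as th
import torch.nn as nn

# Hypothetical network whose parameters start as float64
q_net = nn.Linear(4, 2).double()
# One cast converts every parameter to float32 for MPS compatibility
q_net = q_net.to(dtype=th.float32)
```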

Comment on lines 229 to +234
     self,
-    observation: Union[np.ndarray, Dict[str, np.ndarray]],
-    state: Optional[Tuple[np.ndarray, ...]] = None,
+    observation: Union[np.ndarray, dict[str, np.ndarray]],
+    state: Optional[tuple[np.ndarray, ...]] = None,
     episode_start: Optional[np.ndarray] = None,
     deterministic: bool = False,
-) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]:

⚠️ Potential issue

Add float32 casting for observations in predict method

To maintain MPS compatibility, ensure observations are cast to float32 before processing.

Apply this change:

    def predict(
        self,
        observation: Union[np.ndarray, dict[str, np.ndarray]],
        state: Optional[tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
    ) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]:
+       # Ensure float32 dtype for MPS compatibility
+       if isinstance(observation, dict):
+           observation = {k: v.astype(np.float32) if v.dtype == np.float64 else v 
+                         for k, v in observation.items()}
+       elif isinstance(observation, np.ndarray) and observation.dtype == np.float64:
+           observation = observation.astype(np.float32)
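The suggested cast can also be factored into a standalone helper; `to_float32` is a hypothetical name for this sketch, not part of the stable-baselines3 API:

```python
import numpy as np

def to_float32(observation):
    # Cast float64 arrays to float32 for MPS compatibility;
    # leave other dtypes (e.g. uint8 images) untouched.
    if isinstance(observation, dict):
        return {k: v.astype(np.float32) if v.dtype == np.float64 else v
                for k, v in observation.items()}
    if isinstance(observation, np.ndarray) and observation.dtype == np.float64:
        return observation.astype(np.float32)
    return observation

obs = {"vec": np.zeros(3, dtype=np.float64), "img": np.zeros((2, 2), dtype=np.uint8)}
cast = to_float32(obs)
```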

Comment on lines +219 to +229
     net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
     activation_fn: type[nn.Module] = nn.ReLU,
     use_sde: bool = False,
     log_std_init: float = -3,
     use_expln: bool = False,
     clip_mean: float = 2.0,
-    features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
-    features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+    features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
+    features_extractor_kwargs: Optional[dict[str, Any]] = None,
     normalize_images: bool = True,
-    optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
-    optimizer_kwargs: Optional[Dict[str, Any]] = None,
+    optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
+    optimizer_kwargs: Optional[dict[str, Any]] = None,

🛠️ Refactor suggestion

Propagate tensor dtype handling through the policy class

To ensure consistent tensor dtype handling across the entire policy, consider adding a dtype parameter and propagating it to both actor and critic networks.

 def __init__(
     self,
     observation_space: spaces.Space,
     action_space: spaces.Box,
     lr_schedule: Schedule,
     net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
     activation_fn: type[nn.Module] = nn.ReLU,
+    dtype: th.dtype = th.float32,
     use_sde: bool = False,
     log_std_init: float = -3,
     use_expln: bool = False,
     clip_mean: float = 2.0,
     features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
     features_extractor_kwargs: Optional[dict[str, Any]] = None,
     normalize_images: bool = True,
     optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
     optimizer_kwargs: Optional[dict[str, Any]] = None,
     n_critics: int = 2,
     share_features_extractor: bool = False,
 ):

Committable suggestion skipped: line range outside the PR's diff.

Comment on lines +54 to +57
     net_arch: list[int],
     features_extractor: nn.Module,
     features_dim: int,
-    activation_fn: Type[nn.Module] = nn.ReLU,
+    activation_fn: type[nn.Module] = nn.ReLU,

🛠️ Refactor suggestion

Consider adding explicit tensor dtype handling for MPS compatibility

Since this PR aims to support MPS devices which only work with float32, consider adding explicit tensor dtype handling in the Actor class initialization to ensure all tensors are created with the correct dtype.

 def __init__(
     self,
     observation_space: spaces.Space,
     action_space: spaces.Box,
     net_arch: list[int],
     features_extractor: nn.Module,
     features_dim: int,
     activation_fn: type[nn.Module] = nn.ReLU,
+    dtype: th.dtype = th.float32,
     use_sde: bool = False,
     log_std_init: float = -3,
     full_std: bool = True,
     use_expln: bool = False,
     clip_mean: float = 2.0,
     normalize_images: bool = True,
 ):

Committable suggestion skipped: line range outside the PR's diff.


Successfully merging this pull request may close these issues.

Supporting PyTorch GPU compatibility on Apple Silicon chips
5 participants