-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix tests for mps support #1
base: master
Are you sure you want to change the base?
Changes from 36 commits
ace0516
9ac6225
2dcbef9
b00ca7f
06a2124
6d868c0
8d79e96
3276cb0
f4f6073
64327c7
0344c3c
fa196ab
efd086e
7f11843
c60f681
92e8d11
b235c8e
d4d0536
0311b62
086f79a
fe606fc
34f4819
ef39571
d26324c
1e5dc90
40ed03c
e83924b
b707480
81e3c63
f0e54a7
d47c586
b85a2a5
1c25053
f822ef5
1ac4a60
9970f51
955382e
56c153f
3d59b5c
dd3d0ac
5e7372d
263e657
8f0b488
e4f4f12
7c71688
4c03a25
020ee42
daaebd0
9489b1a
0ec37d8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -135,6 +135,8 @@ def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor: | |||||||||
:return: | ||||||||||
""" | ||||||||||
if copy: | ||||||||||
if hasattr(th, "backends") and th.backends.mps.is_built(): | ||||||||||
return th.tensor(array, dtype=th.float32, device=self.device) | ||||||||||
Comment on lines
+139
to
+140
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Modify the MPS check in the Consider changing the conditional check to determine if the device is MPS by checking Apply this diff to modify the condition: def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:
if copy:
- if hasattr(th, "backends") and th.backends.mps.is_built():
+ if self.device.type == "mps":
return th.tensor(array, dtype=th.float32, device=self.device)
return th.tensor(array, device=self.device)
return th.as_tensor(array, device=self.device) 📝 Committable suggestion
Suggested change
|
||||||||||
return th.tensor(array, device=self.device) | ||||||||||
return th.as_tensor(array, device=self.device) | ||||||||||
|
||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -78,11 +78,11 @@ def convert_if_needed(self, state: np.ndarray) -> Union[int, np.ndarray]: | |||||||
if self.discrete_obs_space: | ||||||||
# The internal state is the binary representation of the | ||||||||
# observed one | ||||||||
return int(sum(state[i] * 2**i for i in range(len(state)))) | ||||||||
return int(sum(int(state[i]) * 2**i for i in range(len(state)))) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Simplify state conversion using NumPy vectorization Currently, the state conversion uses a generator expression with explicit loops and integer casting: return int(sum(int(state[i]) * 2**i for i in range(len(state)))) This can be simplified and optimized by leveraging NumPy's vectorized operations, which are more efficient and concise. Consider rewriting the code using Apply this diff to simplify the code: - return int(sum(int(state[i]) * 2**i for i in range(len(state))))
+ return int(state.dot(2 ** np.arange(len(state)))) This approach eliminates the explicit loop and casting, improving performance and readability. |
||||||||
|
||||||||
if self.image_obs_space: | ||||||||
size = np.prod(self.image_shape) | ||||||||
image = np.concatenate((state * 255, np.zeros(size - len(state), dtype=np.uint8))) | ||||||||
image = np.concatenate((state.astype(np.uint8) * 255, np.zeros(size - len(state), dtype=np.uint8))) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Optimize image creation by preallocating the array Instead of concatenating arrays to create the image, you can preallocate the array and assign values directly. This avoids unnecessary memory allocation and improves performance: Apply this diff to optimize the code: - image = np.concatenate((state.astype(np.uint8) * 255, np.zeros(size - len(state), dtype=np.uint8)))
+ image = np.zeros(size, dtype=np.uint8)
+ image[:len(state)] = state.astype(np.uint8) * 255 This refactored code preallocates a zero-filled array of the required size and assigns the scaled state values directly to the beginning of the array. Committable suggestion
Suggested change
|
||||||||
return image.reshape(self.image_shape).astype(np.uint8) | ||||||||
return state | ||||||||
|
||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -29,8 +29,8 @@ def set_random_seed(seed: int, using_cuda: bool = False) -> None: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Seed the different random generators. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
:param seed: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
:param using_cuda: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
:param seed: Seed | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
:param using_cuda: Whether CUDA is currently used | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Seed python RNG | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
random.seed(seed) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -138,19 +138,20 @@ def get_device(device: Union[th.device, str] = "auto") -> th.device: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Retrieve PyTorch device. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
It checks that the requested device is available first. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
For now, it supports only cpu and cuda. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
By default, it tries to use the gpu. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
For now, it supports only CPU and CUDA. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
By default, it tries to use the GPU. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
:param device: One for 'auto', 'cuda', 'cpu' | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
:param device: One of "auto", "cuda", "cpu", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
or any PyTorch supported device (for instance "mps") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+144
to
+148
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update docstring to reflect MPS support The docstring states that the function supports only CPU and CUDA devices, but with the addition of MPS support, this is no longer accurate. Please update the docstring to include MPS and reflect the current functionality. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
:return: Supported Pytorch device | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Cuda by default | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# MPS/CUDA by default | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if device == "auto": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
device = "cuda" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
device = get_available_accelerator() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Force conversion to th.device | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
device = th.device(device) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Cuda not available | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# CUDA not available | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+151
to
+157
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Enhance device handling logic. The current implementation could be improved to handle device availability more robustly:
Apply this diff to improve the implementation: # MPS/CUDA by default
if device == "auto":
device = get_available_accelerator()
# Force conversion to th.device
device = th.device(device)
- # CUDA not available
- if device.type == th.device("cuda").type and not th.cuda.is_available():
+ # Check device availability
+ if device.type == "cuda" and not th.cuda.is_available():
+ return th.device("cpu")
+ elif device.type == "mps":
+ try:
+ if not (hasattr(th, "backends") and th.backends.mps.is_built() and th.backends.mps.is_available()):
+ return th.device("cpu")
+ except:
+ return th.device("cpu") 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if device.type == th.device("cuda").type and not th.cuda.is_available(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return th.device("cpu") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -482,6 +483,8 @@ def obs_as_tensor(obs: Union[np.ndarray, Dict[str, np.ndarray]], device: th.devi | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if isinstance(obs, np.ndarray): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return th.as_tensor(obs, device=device) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif isinstance(obs, dict): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if hasattr(th, "backends") and th.backends.mps.is_built(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return {key: th.as_tensor(_obs, dtype=th.float32, device=device) for (key, _obs) in obs.items()} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+489
to
+490
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Simplify MPS handling in The current implementation introduces code duplication when handling observations for MPS devices. Consider refactoring to streamline the code and ensure compatibility with different PyTorch versions. Apply this diff to simplify the code: elif isinstance(obs, dict):
- if hasattr(th, "backends") and th.backends.mps.is_built():
- return {key: th.as_tensor(_obs, dtype=th.float32, device=device) for (key, _obs) in obs.items()}
- return {key: th.as_tensor(_obs, device=device) for (key, _obs) in obs.items()}
+ tensor_kwargs = {'device': device}
+ if device.type == 'mps':
+ tensor_kwargs['dtype'] = th.float32
+ return {key: th.as_tensor(_obs, **tensor_kwargs) for (key, _obs) in obs.items()} This refactor reduces code duplication and uses the device type to handle dtype settings appropriately. Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return {key: th.as_tensor(_obs, device=device) for (key, _obs) in obs.items()} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
raise Exception(f"Unrecognized type of observation {type(obs)}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -515,6 +518,21 @@ def should_collect_more_steps( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def get_available_accelerator() -> str: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Return the available accelerator | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
(currently checking only for CUDA and MPS device) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if hasattr(th, "backends") and th.backends.mps.is_built(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# MacOS Metal GPU | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
th.set_default_dtype(th.float32) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return "mps" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif th.cuda.is_available(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return "cuda" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return "cpu" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+524
to
+537
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Improve accelerator detection implementation. The current implementation has several issues:
Apply this diff to fix these issues: def get_available_accelerator() -> str:
"""
Return the available accelerator
- (currently checking only for CUDA and MPS device)
+ (checking for CUDA and MPS devices)
+
+ Note: MPS (Apple Metal) requires float32 dtype and may not be available
+ on older macOS versions even if built.
+
+ Returns:
+ str: The available accelerator type ('mps', 'cuda', or 'cpu')
"""
- if hasattr(th, "backends") and th.backends.mps.is_built():
- # MacOS Metal GPU
- th.set_default_dtype(th.float32)
- return "mps"
+ try:
+ if (hasattr(th, "backends") and th.backends.mps.is_built()
+ and th.backends.mps.is_available()):
+ # MacOS Metal GPU
+ return "mps"
+ except: # Catch any MPS-related errors
+ pass
elif th.cuda.is_available():
return "cuda"
else:
return "cpu" 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Retrieve system and python env info for the current system. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -530,7 +548,7 @@ def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"Python": platform.python_version(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"Stable-Baselines3": sb3.__version__, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"PyTorch": th.__version__, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"GPU Enabled": str(th.cuda.is_available()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"Accelerator": get_available_accelerator(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"Numpy": np.__version__, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"Cloudpickle": cloudpickle.__version__, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"Gymnasium": gym.__version__, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -125,6 +125,20 @@ def _sanity_checks(self) -> None: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
f"not {self.observation_space}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
@staticmethod | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def _maybe_cast_reward(reward: np.ndarray) -> np.ndarray: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Cast `np.float64` reward datatype to `np.float32`, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
keep the others dtype unchanged. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
:param dtype: The original action space dtype | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
:return: ``np.float32`` if the dtype was float64, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
the original dtype otherwise. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if reward.dtype == np.float64: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return reward.astype(np.float32) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return reward | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix docstring parameter and add input validation. The implementation looks good but has a few minor issues:
Apply these changes: @staticmethod
def _maybe_cast_reward(reward: np.ndarray) -> np.ndarray:
"""
Cast `np.float64` reward datatype to `np.float32`,
keep the others dtype unchanged.
- :param dtype: The original action space dtype
+ :param reward: The reward array to potentially cast
:return: ``np.float32`` if the dtype was float64,
the original dtype otherwise.
"""
+ if not isinstance(reward, np.ndarray):
+ raise TypeError(f"Expected numpy array, got {type(reward)}")
if reward.dtype == np.float64:
return reward.astype(np.float32)
return reward 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def __getstate__(self) -> Dict[str, Any]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Gets state for pickling. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -254,7 +268,8 @@ def normalize_reward(self, reward: np.ndarray) -> np.ndarray: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if self.norm_reward: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
reward = np.clip(reward / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return reward | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return self._maybe_cast_reward(reward) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def unnormalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Avoid modifying by reference the original object | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -4,6 +4,7 @@ | |||||||||
import gymnasium as gym | ||||||||||
import numpy as np | ||||||||||
import pytest | ||||||||||
import torch as th | ||||||||||
from gymnasium import spaces | ||||||||||
from gymnasium.spaces.space import Space | ||||||||||
|
||||||||||
|
@@ -151,6 +152,8 @@ def test_discrete_obs_space(model_class, env): | |||||||||
], | ||||||||||
) | ||||||||||
def test_float64_action_space(model_class, obs_space, action_space): | ||||||||||
if hasattr(th, "backends") and th.backends.mps.is_built(): | ||||||||||
pytest.skip("MPS framework doesn't support float64") | ||||||||||
Comment on lines
+155
to
+156
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Improve MPS availability check and skip message The current implementation has several areas for improvement:
Apply this diff to improve the implementation: - if hasattr(th, "backends") and th.backends.mps.is_built():
- pytest.skip("MPS framework doesn't support float64")
+ if hasattr(th.backends, "mps") and th.backends.mps.is_available():
+ pytest.skip("Skipping float64 tests: MPS backend does not support float64 dtype operations") This change:
📝 Committable suggestion
Suggested change
|
||||||||||
env = DummyEnv(obs_space, action_space) | ||||||||||
env = gym.wrappers.TimeLimit(env, max_episode_steps=200) | ||||||||||
if isinstance(env.observation_space, spaces.Dict): | ||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Quote the pip install argument to prevent shell errors
The command:
may be misinterpreted by the shell due to the square brackets, which the shell might interpret as pattern characters for filename expansion (globbing). To prevent potential shell errors, it's recommended to quote the argument.
Apply this diff to fix the issue: