Skip to content

Commit

Permalink
Merge branch 'master' into prioritized-experience-replay
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin authored Nov 22, 2023
2 parents ec272b9 + e3dea4b commit a043cfd
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 15 deletions.
21 changes: 20 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@
Changelog
==========

Release 2.2.0a11 (WIP)
Release 2.2.1 (2023-11-17)
--------------------------
**Support for options at reset, bug fixes and better error messages**

.. note::

SB3 v2.2.0 was yanked after a breaking change was found in `GH#1751 <https://github.com/DLR-RM/stable-baselines3/issues/1751>`_.
Please use SB3 v2.2.1 and not v2.2.0.


Breaking Changes:
^^^^^^^^^^^^^^^^^
- Switched to ``ruff`` for sorting imports (isort is no longer needed), black and ruff version now require a minimum version
Expand All @@ -33,12 +39,24 @@ Bug Fixes:
- Fixed success reward dtype in ``SimpleMultiObsEnv`` (@NixGD)
- Fixed check_env for Sequence observation space (@corentinlger)
- Prevents instantiating BitFlippingEnv with conflicting observation spaces (@kylesayrs)
- Fixed ResourceWarning when loading and saving models (files were not closed); please note that only paths are closed automatically,
  the behavior stays the same for tempfiles (they need to be closed manually),
  and the behavior is now consistent when loading/saving a replay buffer

`SB3-Contrib`_
^^^^^^^^^^^^^^
- Added ``set_options`` for ``AsyncEval``
- Added ``rollout_buffer_class`` and ``rollout_buffer_kwargs`` arguments to TRPO

`RL Zoo`_
^^^^^^^^^
- Removed the `gym` dependency; the package is still required for some pretrained agents.
- Added `--eval-env-kwargs` to `train.py` (@Quentin18)
- Added `ppo_lstm` to hyperparams_opt.py (@technocrat13)
- Upgraded to `pybullet_envs_gymnasium>=0.4.0`
- Removed old hacks (for instance limiting offpolicy algorithms to one env at test time)
- Updated docker image, removed support for X server
- Replaced deprecated `optuna.suggest_uniform(...)` by `optuna.suggest_float(..., low=..., high=...)`

`SBX`_ (SB3 + Jax)
^^^^^^^^^^^^^^^^^^
Expand Down Expand Up @@ -67,6 +85,7 @@ Others:
- Switched to PyTorch 2.1.0 in the CI (fixes type annotations)
- Fixed ``stable_baselines3/common/policies.py`` type hints
- Switched to ``mypy`` only for checking types
- Added tests to check consistency when saving/loading files

Documentation:
^^^^^^^^^^^^^^
Expand Down
31 changes: 21 additions & 10 deletions stable_baselines3/common/save_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,14 +308,14 @@ def save_to_zip_file(
:param pytorch_variables: Other PyTorch variables expected to contain name and value of the variable.
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
"""
save_path = open_path(save_path, "w", verbose=0, suffix="zip")
file = open_path(save_path, "w", verbose=0, suffix="zip")
# data/params can be None, so do not
# try to serialize them blindly
if data is not None:
serialized_data = data_to_json(data)

# Create a zip-archive and write our objects there.
with zipfile.ZipFile(save_path, mode="w") as archive:
with zipfile.ZipFile(file, mode="w") as archive:
# Do not try to save "None" elements
if data is not None:
archive.writestr("data", serialized_data)
Expand All @@ -331,6 +331,9 @@ def save_to_zip_file(
# Save system info about the current python env
archive.writestr("system_info.txt", get_system_info(print_info=False)[1])

if isinstance(save_path, (str, pathlib.Path)):
file.close()


def save_to_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], obj: Any, verbose: int = 0) -> None:
    """
    Save an object to path using pickle.

    :param path: the path to open.
        if save_path is a str or pathlib.Path and mode is "w", single dispatch ensures that the
        path actually exists. If path is a io.BufferedIOBase the path exists.
    :param obj: The object to save.
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    """
    file = open_path(path, "w", verbose=verbose, suffix="pkl")
    try:
        # Use protocol>=4 to support saving replay buffers >= 4Gb
        # See https://docs.python.org/3/library/pickle.html
        pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL)
    finally:
        # Close the handle only when we opened it ourselves (path given as str/Path).
        # User-provided file objects (e.g. tempfiles) must stay open for the caller.
        # try/finally ensures no handle leaks even if pickle.dump raises.
        if isinstance(path, (str, pathlib.Path)):
            file.close()


def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose: int = 0) -> Any:
    """
    Load an object from a pickle file.

    :param path: the path to open.
        if save_path is a str or pathlib.Path and mode is "r", single dispatch ensures that the
        path actually exists. If path is a io.BufferedIOBase the path exists.
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    :return: The unpickled object.
    """
    file = open_path(path, "r", verbose=verbose, suffix="pkl")
    try:
        return pickle.load(file)
    finally:
        # Close the handle only when we opened it ourselves (path given as str/Path).
        # User-provided file objects (e.g. tempfiles) must stay open for the caller.
        # try/finally ensures no handle leaks even if pickle.load raises.
        if isinstance(path, (str, pathlib.Path)):
            file.close()


def load_from_zip_file(
Expand Down Expand Up @@ -391,14 +399,14 @@ def load_from_zip_file(
:return: Class parameters, model state_dicts (aka "params", dict of state_dict)
and dict of pytorch variables
"""
load_path = open_path(load_path, "r", verbose=verbose, suffix="zip")
file = open_path(load_path, "r", verbose=verbose, suffix="zip")

# set device to cpu if cuda is not available
device = get_device(device=device)

# Open the zip archive and load data
try:
with zipfile.ZipFile(load_path) as archive:
with zipfile.ZipFile(file) as archive:
namelist = archive.namelist()
# If data or parameters is not in the
# zip archive, assume they were stored
Expand Down Expand Up @@ -450,4 +458,7 @@ def load_from_zip_file(
except zipfile.BadZipFile as e:
# load_path wasn't a zip file
raise ValueError(f"Error: the file {load_path} wasn't a zip-file") from e
finally:
if isinstance(load_path, (str, pathlib.Path)):
file.close()
return data, params, pytorch_variables
6 changes: 3 additions & 3 deletions stable_baselines3/common/vec_env/subproc_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,14 @@ def step_async(self, actions: np.ndarray) -> None:
def step_wait(self) -> VecEnvStepReturn:
results = [remote.recv() for remote in self.remotes]
self.waiting = False
obs, rews, dones, infos, self.reset_infos = zip(*results)
return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos
obs, rews, dones, infos, self.reset_infos = zip(*results) # type: ignore[assignment]
return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos # type: ignore[return-value]

def reset(self) -> VecEnvObs:
for env_idx, remote in enumerate(self.remotes):
remote.send(("reset", (self._seeds[env_idx], self._options[env_idx])))
results = [remote.recv() for remote in self.remotes]
obs, self.reset_infos = zip(*results)
obs, self.reset_infos = zip(*results) # type: ignore[assignment]
# Seeds and options are only used once
self._reset_seeds()
self._reset_options()
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.2.0a11
2.2.1
36 changes: 36 additions & 0 deletions tests/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import os
import pathlib
import tempfile
import warnings
import zipfile
from collections import OrderedDict
Expand Down Expand Up @@ -747,3 +748,38 @@ def test_dqn_target_update_interval(tmp_path):
model = DQN.load(tmp_path / "dqn_cartpole")
os.remove(tmp_path / "dqn_cartpole.zip")
assert model.target_update_interval == 100


# Turn warnings into errors so a ResourceWarning fails the test
@pytest.mark.filterwarnings("error")
def test_no_resource_warning(tmp_path):
    """
    Check that file handles are closed when saving/loading via a path,
    but left open when the user provides the file object.
    See https://github.com/DLR-RM/stable-baselines3/issues/1751
    """
    # Save/load with a pathlib.Path: the file must be closed internally
    # (filename matches the algorithm under test, unlike the original "dqn_cartpole")
    PPO("MlpPolicy", "CartPole-v1").save(tmp_path / "ppo_cartpole")
    PPO.load(tmp_path / "ppo_cartpole")

    # Same check with a plain string path
    PPO("MlpPolicy", "CartPole-v1").save(str(tmp_path / "ppo_cartpole"))
    PPO.load(str(tmp_path / "ppo_cartpole"))

    # With an in-memory file object, SB3 must NOT close it: the caller owns it
    # (no seek needed: zipfile locates the archive directory itself)
    with tempfile.TemporaryFile() as fp:
        PPO("MlpPolicy", "CartPole-v1").save(fp)
        PPO.load(fp)
        assert not fp.closed

    # Same checks for the replay buffer (pickle-based save/load)
    model = SAC("MlpPolicy", "Pendulum-v1", buffer_size=200)
    model.save_replay_buffer(tmp_path / "replay")
    model.load_replay_buffer(tmp_path / "replay")

    model.save_replay_buffer(str(tmp_path / "replay"))
    model.load_replay_buffer(str(tmp_path / "replay"))

    with tempfile.TemporaryFile() as fp:
        model.save_replay_buffer(fp)
        # pickle reads sequentially, so rewind before loading
        fp.seek(0)
        model.load_replay_buffer(fp)
        assert not fp.closed

0 comments on commit a043cfd

Please sign in to comment.