Update TD3/DDPG/DQN defaults for consistency
araffin committed Dec 13, 2023
1 parent 373166d commit fb0d36d
Showing 5 changed files with 66 additions and 10 deletions.
56 changes: 56 additions & 0 deletions docs/misc/changelog.rst
@@ -3,6 +3,62 @@
Changelog
==========

Release 2.3.0a0 (WIP)
--------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^
- The default hyperparameters of ``TD3`` and ``DDPG`` have been changed to be consistent with ``SAC``

.. code-block:: python

    # SB3 < 2.3.0 default hyperparameters
    # model = TD3("MlpPolicy", env, train_freq=(1, "episode"), gradient_steps=-1, batch_size=100, learning_rate=1e-3)
    # SB3 >= 2.3.0:
    model = TD3("MlpPolicy", env, train_freq=1, gradient_steps=1, batch_size=256, learning_rate=3e-4)

.. note::

    One inconsistency remains: the default network architecture for ``TD3/DDPG`` is ``[400, 300]`` instead of ``[256, 256]`` for SAC
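
If you also want ``TD3``/``DDPG`` to use ``SAC``'s architecture, a minimal sketch of overriding it via ``policy_kwargs`` (the override is shown only for illustration; ``[400, 300]`` stays the built-in default and ``env`` is assumed to be an existing environment):

.. code-block:: python

    from stable_baselines3 import TD3

    # Illustrative override: align the actor/critic architecture with SAC's [256, 256]
    model = TD3("MlpPolicy", env, policy_kwargs=dict(net_arch=[256, 256]))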


- The default ``learning_starts`` parameter of ``DQN`` has been changed to be consistent with the other off-policy algorithms


.. code-block:: python

    # SB3 < 2.3.0 default hyperparameters; 50_000 corresponded to the Atari default
    # model = DQN("MlpPolicy", env, learning_starts=50_000)
    # SB3 >= 2.3.0:
    model = DQN("MlpPolicy", env, learning_starts=100)
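
To keep the previous Atari-oriented behavior after upgrading, the old value can simply be passed explicitly. A minimal sketch (the value is just the removed default; ``env`` is assumed to be an existing environment):

.. code-block:: python

    from stable_baselines3 import DQN

    # Restore the pre-2.3.0 default explicitly (as used for Atari-style setups)
    model = DQN("MlpPolicy", env, learning_starts=50_000)
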
New Features:
^^^^^^^^^^^^^

Bug Fixes:
^^^^^^^^^^

`SB3-Contrib`_
^^^^^^^^^^^^^^

`RL Zoo`_
^^^^^^^^^

`SBX`_ (SB3 + Jax)
^^^^^^^^^^^^^^^^^^

Deprecations:
^^^^^^^^^^^^^

Others:
^^^^^^^

Documentation:
^^^^^^^^^^^^^^



Release 2.2.1 (2023-11-17)
--------------------------
**Support for options at reset, bug fixes and better error messages**
8 changes: 4 additions & 4 deletions stable_baselines3/ddpg/ddpg.py
@@ -57,14 +57,14 @@ def __init__(
         self,
         policy: Union[str, Type[TD3Policy]],
         env: Union[GymEnv, str],
-        learning_rate: Union[float, Schedule] = 1e-3,
+        learning_rate: Union[float, Schedule] = 3e-4,
         buffer_size: int = 1_000_000,  # 1e6
         learning_starts: int = 100,
-        batch_size: int = 100,
+        batch_size: int = 256,
         tau: float = 0.005,
         gamma: float = 0.99,
-        train_freq: Union[int, Tuple[int, str]] = (1, "episode"),
-        gradient_steps: int = -1,
+        train_freq: Union[int, Tuple[int, str]] = 1,
+        gradient_steps: int = 1,
         action_noise: Optional[ActionNoise] = None,
         replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
2 changes: 1 addition & 1 deletion stable_baselines3/dqn/dqn.py
@@ -79,7 +79,7 @@ def __init__(
         env: Union[GymEnv, str],
         learning_rate: Union[float, Schedule] = 1e-4,
         buffer_size: int = 1_000_000,  # 1e6
-        learning_starts: int = 50000,
+        learning_starts: int = 100,
         batch_size: int = 32,
         tau: float = 1.0,
         gamma: float = 0.99,
8 changes: 4 additions & 4 deletions stable_baselines3/td3/td3.py
@@ -80,14 +80,14 @@ def __init__(
         self,
         policy: Union[str, Type[TD3Policy]],
         env: Union[GymEnv, str],
-        learning_rate: Union[float, Schedule] = 1e-3,
+        learning_rate: Union[float, Schedule] = 3e-4,
         buffer_size: int = 1_000_000,  # 1e6
         learning_starts: int = 100,
-        batch_size: int = 100,
+        batch_size: int = 256,
         tau: float = 0.005,
         gamma: float = 0.99,
-        train_freq: Union[int, Tuple[int, str]] = (1, "episode"),
-        gradient_steps: int = -1,
+        train_freq: Union[int, Tuple[int, str]] = 1,
+        gradient_steps: int = 1,
         action_noise: Optional[ActionNoise] = None,
         replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
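
To reproduce the pre-2.3.0 behavior of ``TD3``/``DDPG`` after upgrading, the removed defaults above can be passed explicitly. A minimal sketch (values copied from the old defaults; ``env`` is assumed to be an existing environment, and DDPG accepts the same arguments):

.. code-block:: python

    from stable_baselines3 import TD3

    # Pre-2.3.0 defaults, passed explicitly: episode-based updates and a lower learning rate
    model = TD3(
        "MlpPolicy",
        env,
        learning_rate=1e-3,
        batch_size=100,
        train_freq=(1, "episode"),
        gradient_steps=-1,
    )
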
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
@@ -1 +1 @@
-2.2.1
+2.3.0a0
