diff --git a/CHANGELOG.md b/CHANGELOG.md
index 745e2b14..13e167ed 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,16 @@
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
+## [1.1.0] - 2024-02-12
+### Added
+- MultiCategorical mixin to operate on MultiDiscrete action spaces
+
+### Changed (breaking changes)
+- Rename the `ManualTrainer` to `StepTrainer`
+- Output training/evaluation progress messages to the system's stdout
+- Get single observation/action spaces for vectorized environments
+- Update Isaac Orbit environment wrapper
+
## [1.0.0] - 2023-08-16
Transition from pre-release versions (`1.0.0-rc.1` and `1.0.0-rc.2`) to a stable version.
diff --git a/docs/source/_static/imgs/model_categorical_cnn-dark.svg b/docs/source/_static/imgs/model_categorical_cnn-dark.svg
index 1c312ccc..9ae9bf44 100755
--- a/docs/source/_static/imgs/model_categorical_cnn-dark.svg
+++ b/docs/source/_static/imgs/model_categorical_cnn-dark.svg
@@ -1 +1 @@
-
+
diff --git a/docs/source/_static/imgs/model_categorical_cnn-light.svg b/docs/source/_static/imgs/model_categorical_cnn-light.svg
index da4893c4..030210e4 100755
--- a/docs/source/_static/imgs/model_categorical_cnn-light.svg
+++ b/docs/source/_static/imgs/model_categorical_cnn-light.svg
@@ -1 +1 @@
-
+
diff --git a/docs/source/_static/imgs/model_deterministic_cnn-dark.svg b/docs/source/_static/imgs/model_deterministic_cnn-dark.svg
index df29d7a3..f901f94d 100755
--- a/docs/source/_static/imgs/model_deterministic_cnn-dark.svg
+++ b/docs/source/_static/imgs/model_deterministic_cnn-dark.svg
@@ -1 +1 @@
-
+
diff --git a/docs/source/_static/imgs/model_deterministic_cnn-light.svg b/docs/source/_static/imgs/model_deterministic_cnn-light.svg
index 7ddb76f7..d12de8de 100755
--- a/docs/source/_static/imgs/model_deterministic_cnn-light.svg
+++ b/docs/source/_static/imgs/model_deterministic_cnn-light.svg
@@ -1 +1 @@
-
+
diff --git a/docs/source/_static/imgs/model_gaussian_cnn-dark.svg b/docs/source/_static/imgs/model_gaussian_cnn-dark.svg
index 411fa4e2..a48fa5c8 100755
--- a/docs/source/_static/imgs/model_gaussian_cnn-dark.svg
+++ b/docs/source/_static/imgs/model_gaussian_cnn-dark.svg
@@ -1 +1 @@
-
+
diff --git a/docs/source/_static/imgs/model_gaussian_cnn-light.svg b/docs/source/_static/imgs/model_gaussian_cnn-light.svg
index 426753a5..276e8b87 100755
--- a/docs/source/_static/imgs/model_gaussian_cnn-light.svg
+++ b/docs/source/_static/imgs/model_gaussian_cnn-light.svg
@@ -1 +1 @@
-
+
diff --git a/docs/source/_static/imgs/model_multicategorical-dark.svg b/docs/source/_static/imgs/model_multicategorical-dark.svg
new file mode 100755
index 00000000..dc246c43
--- /dev/null
+++ b/docs/source/_static/imgs/model_multicategorical-dark.svg
@@ -0,0 +1 @@
+
diff --git a/docs/source/_static/imgs/model_multicategorical-light.svg b/docs/source/_static/imgs/model_multicategorical-light.svg
new file mode 100755
index 00000000..e5e21a81
--- /dev/null
+++ b/docs/source/_static/imgs/model_multicategorical-light.svg
@@ -0,0 +1 @@
+
diff --git a/docs/source/api/agents/a2c.rst b/docs/source/api/agents/a2c.rst
index 23721bc0..c98f6448 100644
--- a/docs/source/api/agents/a2c.rst
+++ b/docs/source/api/agents/a2c.rst
@@ -164,6 +164,9 @@ The implementation supports the following `Gym spaces ` / :ref:`Gaussian ` / :ref:`MultivariateGaussian `
+ - :ref:`Categorical <models_categorical>` /
+ |br| :ref:`Multi-Categorical <models_multicategorical>` /
+ |br| :ref:`Gaussian <models_gaussian>` /
+ |br| :ref:`MultivariateGaussian <models_multivariate_gaussian>`
* - :math:`V_\phi(s)`
- Value
- :literal:`"value"`
diff --git a/docs/source/api/agents/amp.rst b/docs/source/api/agents/amp.rst
index 57cc9180..c993b67a 100644
--- a/docs/source/api/agents/amp.rst
+++ b/docs/source/api/agents/amp.rst
@@ -162,6 +162,10 @@ The implementation supports the following `Gym spaces ` / :ref:`MultivariateGaussian `
+ - :ref:`Gaussian <models_gaussian>` /
+ |br| :ref:`MultivariateGaussian <models_multivariate_gaussian>`
* - :math:`V_\phi(s)`
- Value
- :literal:`"value"`
diff --git a/docs/source/api/agents/cem.rst b/docs/source/api/agents/cem.rst
index 0c1ac587..68245818 100644
--- a/docs/source/api/agents/cem.rst
+++ b/docs/source/api/agents/cem.rst
@@ -119,6 +119,9 @@ The implementation supports the following `Gym spaces `
+ - :ref:`Categorical <models_categorical>` /
+ |br| :ref:`Multi-Categorical <models_multicategorical>`
.. raw:: html
diff --git a/docs/source/api/agents/ddpg.rst b/docs/source/api/agents/ddpg.rst
index 86eb2fd0..00972f3f 100644
--- a/docs/source/api/agents/ddpg.rst
+++ b/docs/source/api/agents/ddpg.rst
@@ -159,6 +159,9 @@ The implementation supports the following `Gym spaces ` / :ref:`Gaussian ` / :ref:`MultivariateGaussian `
+ - :ref:`Categorical <models_categorical>` /
+ |br| :ref:`Multi-Categorical <models_multicategorical>` /
+ |br| :ref:`Gaussian <models_gaussian>` /
+ |br| :ref:`MultivariateGaussian <models_multivariate_gaussian>`
* - :math:`V_\phi(s)`
- Value
- :literal:`"value"`
diff --git a/docs/source/api/agents/q_learning.rst b/docs/source/api/agents/q_learning.rst
index 89452096..29c1b29b 100644
--- a/docs/source/api/agents/q_learning.rst
+++ b/docs/source/api/agents/q_learning.rst
@@ -99,6 +99,9 @@ The implementation supports the following `Gym spaces ` / :ref:`MultivariateGaussian `
+ - :ref:`Gaussian <models_gaussian>` /
+ |br| :ref:`MultivariateGaussian <models_multivariate_gaussian>`
* - :math:`V_\phi(s)`
- Value
- :literal:`"value"`
diff --git a/docs/source/api/agents/sac.rst b/docs/source/api/agents/sac.rst
index 457b0542..c55720cc 100644
--- a/docs/source/api/agents/sac.rst
+++ b/docs/source/api/agents/sac.rst
@@ -160,6 +160,9 @@ The implementation supports the following `Gym spaces ` / :ref:`MultivariateGaussian `
+ - :ref:`Gaussian <models_gaussian>` /
+ |br| :ref:`MultivariateGaussian <models_multivariate_gaussian>`
* - :math:`Q_{\phi 1}(s, a)`
- Q1-network (critic 1)
- :literal:`"critic_1"`
diff --git a/docs/source/api/agents/sarsa.rst b/docs/source/api/agents/sarsa.rst
index 87a6d8bf..e7759b54 100644
--- a/docs/source/api/agents/sarsa.rst
+++ b/docs/source/api/agents/sarsa.rst
@@ -99,6 +99,9 @@ The implementation supports the following `Gym spaces ` / :ref:`MultivariateGaussian `
+ - :ref:`Gaussian <models_gaussian>` /
+ |br| :ref:`MultivariateGaussian <models_multivariate_gaussian>`
* - :math:`V_\phi(s)`
- Value
- :literal:`"value"`
diff --git a/docs/source/api/models.rst b/docs/source/api/models.rst
index 48083b81..522d5eb6 100644
--- a/docs/source/api/models.rst
+++ b/docs/source/api/models.rst
@@ -6,6 +6,7 @@ Models
Tabular
Categorical
+ Multi-Categorical <models/multicategorical>
Gaussian
Multivariate Gaussian
Deterministic
@@ -29,6 +30,9 @@ Models (or agent models) refer to a representation of the agent's policy, value
* - :doc:`Categorical model ` (discrete domain)
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\blacksquare`
+ * - :doc:`Multi-Categorical model <models/multicategorical>` (discrete domain)
+ - .. centered:: :math:`\blacksquare`
+ - .. centered:: :math:`\square`
* - :doc:`Gaussian model ` (continuous domain)
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\blacksquare`
diff --git a/docs/source/api/models/categorical.rst b/docs/source/api/models/categorical.rst
index 19543ab4..c285896d 100644
--- a/docs/source/api/models/categorical.rst
+++ b/docs/source/api/models/categorical.rst
@@ -163,6 +163,24 @@ Usage
:start-after: [start-cnn-functional-torch]
:end-before: [end-cnn-functional-torch]
+ .. group-tab:: |_4| |jax| |_4|
+
+ .. tabs::
+
+ .. group-tab:: setup-style
+
+ .. literalinclude:: ../../snippets/categorical_model.py
+ :language: python
+ :start-after: [start-cnn-setup-jax]
+ :end-before: [end-cnn-setup-jax]
+
+ .. group-tab:: compact-style
+
+ .. literalinclude:: ../../snippets/categorical_model.py
+ :language: python
+ :start-after: [start-cnn-compact-jax]
+ :end-before: [end-cnn-compact-jax]
+
.. tab:: RNN
.. image:: ../../_static/imgs/model_categorical_rnn-light.svg
diff --git a/docs/source/api/models/deterministic.rst b/docs/source/api/models/deterministic.rst
index e5c24b87..30e5ef91 100644
--- a/docs/source/api/models/deterministic.rst
+++ b/docs/source/api/models/deterministic.rst
@@ -163,6 +163,24 @@ Usage
:start-after: [start-cnn-functional-torch]
:end-before: [end-cnn-functional-torch]
+ .. group-tab:: |_4| |jax| |_4|
+
+ .. tabs::
+
+ .. group-tab:: setup-style
+
+ .. literalinclude:: ../../snippets/deterministic_model.py
+ :language: python
+ :start-after: [start-cnn-setup-jax]
+ :end-before: [end-cnn-setup-jax]
+
+ .. group-tab:: compact-style
+
+ .. literalinclude:: ../../snippets/deterministic_model.py
+ :language: python
+ :start-after: [start-cnn-compact-jax]
+ :end-before: [end-cnn-compact-jax]
+
.. tab:: RNN
.. image:: ../../_static/imgs/model_deterministic_rnn-light.svg
diff --git a/docs/source/api/models/gaussian.rst b/docs/source/api/models/gaussian.rst
index 721d79bd..1aa50c05 100644
--- a/docs/source/api/models/gaussian.rst
+++ b/docs/source/api/models/gaussian.rst
@@ -163,6 +163,24 @@ Usage
:start-after: [start-cnn-functional-torch]
:end-before: [end-cnn-functional-torch]
+ .. group-tab:: |_4| |jax| |_4|
+
+ .. tabs::
+
+ .. group-tab:: setup-style
+
+ .. literalinclude:: ../../snippets/gaussian_model.py
+ :language: python
+ :start-after: [start-cnn-setup-jax]
+ :end-before: [end-cnn-setup-jax]
+
+ .. group-tab:: compact-style
+
+ .. literalinclude:: ../../snippets/gaussian_model.py
+ :language: python
+ :start-after: [start-cnn-compact-jax]
+ :end-before: [end-cnn-compact-jax]
+
.. tab:: RNN
.. image:: ../../_static/imgs/model_gaussian_rnn-light.svg
diff --git a/docs/source/api/models/multicategorical.rst b/docs/source/api/models/multicategorical.rst
new file mode 100644
index 00000000..a8c89065
--- /dev/null
+++ b/docs/source/api/models/multicategorical.rst
@@ -0,0 +1,401 @@
+.. _models_multicategorical:
+
+Multi-Categorical model
+=======================
+
+Multi-Categorical models run **discrete-domain stochastic** policies.
+
+.. raw:: html
+
+
+
+skrl provides a Python mixin (:literal:`MultiCategoricalMixin`) to assist in the creation of these types of models, allowing users to have full control over the function approximator definitions and architectures. Note that the use of this mixin must comply with the following rules:
+
+* The definition of multiple inheritance must always include the :ref:`Model ` base class at the end.
+
+* The :ref:`Model ` base class constructor must be invoked before the mixin's constructor.
+
+.. warning::
+
+ For models in JAX/Flax it is imperative to define all parameters (except ``observation_space``, ``action_space`` and ``device``) with default values to avoid errors (``TypeError: __init__() missing N required positional argument``) during initialization.
+
+ In addition, it is necessary to initialize the model's ``state_dict`` (via the ``init_state_dict`` method) after its instantiation to avoid errors (``AttributeError: object has no attribute "state_dict". If "state_dict" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply'``) during its use.
+
+.. tabs::
+
+ .. group-tab:: |_4| |pytorch| |_4|
+
+ .. literalinclude:: ../../snippets/multicategorical_model.py
+ :language: python
+ :emphasize-lines: 1, 3-4
+ :start-after: [start-definition-torch]
+ :end-before: [end-definition-torch]
+
+ .. group-tab:: |_4| |jax| |_4|
+
+ .. literalinclude:: ../../snippets/multicategorical_model.py
+ :language: python
+ :emphasize-lines: 1, 3-4
+ :start-after: [start-definition-jax]
+ :end-before: [end-definition-jax]
+
+.. raw:: html
+
+
+
+Concept
+-------
+
+.. image:: ../../_static/imgs/model_multicategorical-light.svg
+ :width: 100%
+ :align: center
+ :class: only-light
+ :alt: Multi-Categorical model
+
+.. image:: ../../_static/imgs/model_multicategorical-dark.svg
+ :width: 100%
+ :align: center
+ :class: only-dark
+ :alt: Multi-Categorical model
+
+.. raw:: html
+
+
+
+Usage
+-----
+
+* Multi-Layer Perceptron (**MLP**)
+* Convolutional Neural Network (**CNN**)
+* Recurrent Neural Network (**RNN**)
+* Gated Recurrent Unit RNN (**GRU**)
+* Long Short-Term Memory RNN (**LSTM**)
+
+.. tabs::
+
+ .. tab:: MLP
+
+ .. image:: ../../_static/imgs/model_categorical_mlp-light.svg
+ :width: 40%
+ :align: center
+ :class: only-light
+
+ .. image:: ../../_static/imgs/model_categorical_mlp-dark.svg
+ :width: 40%
+ :align: center
+ :class: only-dark
+
+ .. raw:: html
+
+
+
+ .. tabs::
+
+ .. group-tab:: |_4| |pytorch| |_4|
+
+ .. tabs::
+
+ .. group-tab:: nn.Sequential
+
+ .. literalinclude:: ../../snippets/multicategorical_model.py
+ :language: python
+ :start-after: [start-mlp-sequential-torch]
+ :end-before: [end-mlp-sequential-torch]
+
+ .. group-tab:: nn.functional
+
+ .. literalinclude:: ../../snippets/multicategorical_model.py
+ :language: python
+ :start-after: [start-mlp-functional-torch]
+ :end-before: [end-mlp-functional-torch]
+
+ .. group-tab:: |_4| |jax| |_4|
+
+ .. tabs::
+
+ .. group-tab:: setup-style
+
+ .. literalinclude:: ../../snippets/multicategorical_model.py
+ :language: python
+ :start-after: [start-mlp-setup-jax]
+ :end-before: [end-mlp-setup-jax]
+
+ .. group-tab:: compact-style
+
+ .. literalinclude:: ../../snippets/multicategorical_model.py
+ :language: python
+ :start-after: [start-mlp-compact-jax]
+ :end-before: [end-mlp-compact-jax]
+
+ .. tab:: CNN
+
+ .. image:: ../../_static/imgs/model_categorical_cnn-light.svg
+ :width: 100%
+ :align: center
+ :class: only-light
+
+ .. image:: ../../_static/imgs/model_categorical_cnn-dark.svg
+ :width: 100%
+ :align: center
+ :class: only-dark
+
+ .. raw:: html
+
+
+
+ .. tabs::
+
+ .. group-tab:: |_4| |pytorch| |_4|
+
+ .. tabs::
+
+ .. group-tab:: nn.Sequential
+
+ .. literalinclude:: ../../snippets/multicategorical_model.py
+ :language: python
+ :start-after: [start-cnn-sequential-torch]
+ :end-before: [end-cnn-sequential-torch]
+
+ .. group-tab:: nn.functional
+
+ .. literalinclude:: ../../snippets/multicategorical_model.py
+ :language: python
+ :start-after: [start-cnn-functional-torch]
+ :end-before: [end-cnn-functional-torch]
+
+ .. group-tab:: |_4| |jax| |_4|
+
+ .. tabs::
+
+ .. group-tab:: setup-style
+
+ .. literalinclude:: ../../snippets/multicategorical_model.py
+ :language: python
+ :start-after: [start-cnn-setup-jax]
+ :end-before: [end-cnn-setup-jax]
+
+ .. group-tab:: compact-style
+
+ .. literalinclude:: ../../snippets/multicategorical_model.py
+ :language: python
+ :start-after: [start-cnn-compact-jax]
+ :end-before: [end-cnn-compact-jax]
+
+ .. tab:: RNN
+
+ .. image:: ../../_static/imgs/model_categorical_rnn-light.svg
+ :width: 90%
+ :align: center
+ :class: only-light
+
+ .. image:: ../../_static/imgs/model_categorical_rnn-dark.svg
+ :width: 90%
+ :align: center
+ :class: only-dark
+
+ where:
+
+ .. math::
+ \begin{aligned}
+ N ={} & \text{batch size} \\
+ L ={} & \text{sequence length} \\
+ D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
+ H_{in} ={} & \text{input_size} \\
+ H_{out} ={} & \text{hidden_size}
+ \end{aligned}
+
+ .. raw:: html
+
+
+
+ The following points are relevant in the definition of recurrent models (a minimal sketch of the specification is shown after this list):
+
+ * The ``.get_specification()`` method must be overwritten to return, under a dictionary key ``"rnn"``, a sub-dictionary that includes the sequence length (under key ``"sequence_length"``) as a number and a list of the dimensions (under key ``"sizes"``) of each initial hidden state
+
+ * The ``.compute()`` method's ``inputs`` parameter will have, at least, the following items in the dictionary:
+
+ * ``"states"``: state of the environment used to make the decision
+ * ``"taken_actions"``: actions taken by the policy for the given states, if applicable
+ * ``"terminated"``: episode termination status for sampled environment transitions. This key is only defined during the training process
+ * ``"rnn"``: list of initial hidden states ordered according to the model specification
+
+ * The ``.compute()`` method must include, under the ``"rnn"`` key of the returned dictionary, a list of each final hidden state
+
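+ For reference, a minimal sketch of the specification described in the first point is shown below (``self.num_envs``, ``self.num_layers``, ``self.hidden_size`` and ``self.sequence_length`` are assumed to be defined in the model's constructor, as in the complete snippets that follow):
+
+ .. code-block:: python
+
+     def get_specification(self):
+         # one entry in "sizes" per initial hidden state,
+         # each of shape (D * num_layers, N, Hout)
+         return {"rnn": {"sequence_length": self.sequence_length,
+                         "sizes": [(self.num_layers, self.num_envs, self.hidden_size)]}}
+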
+ .. raw:: html
+
+
+
+ .. tabs::
+
+ .. group-tab:: |_4| |pytorch| |_4|
+
+ .. tabs::
+
+ .. group-tab:: nn.Sequential
+
+ .. literalinclude:: ../../snippets/multicategorical_model.py
+ :language: python
+ :start-after: [start-rnn-sequential-torch]
+ :end-before: [end-rnn-sequential-torch]
+
+ .. group-tab:: nn.functional
+
+ .. literalinclude:: ../../snippets/multicategorical_model.py
+ :language: python
+ :start-after: [start-rnn-functional-torch]
+ :end-before: [end-rnn-functional-torch]
+
+ .. tab:: GRU
+
+ .. image:: ../../_static/imgs/model_categorical_rnn-light.svg
+ :width: 90%
+ :align: center
+ :class: only-light
+
+ .. image:: ../../_static/imgs/model_categorical_rnn-dark.svg
+ :width: 90%
+ :align: center
+ :class: only-dark
+
+ where:
+
+ .. math::
+ \begin{aligned}
+ N ={} & \text{batch size} \\
+ L ={} & \text{sequence length} \\
+ D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
+ H_{in} ={} & \text{input_size} \\
+ H_{out} ={} & \text{hidden_size}
+ \end{aligned}
+
+ .. raw:: html
+
+
+
+ The following points are relevant in the definition of recurrent models:
+
+ * The ``.get_specification()`` method must be overwritten to return, under a dictionary key ``"rnn"``, a sub-dictionary that includes the sequence length (under key ``"sequence_length"``) as a number and a list of the dimensions (under key ``"sizes"``) of each initial hidden state
+
+ * The ``.compute()`` method's ``inputs`` parameter will have, at least, the following items in the dictionary:
+
+ * ``"states"``: state of the environment used to make the decision
+ * ``"taken_actions"``: actions taken by the policy for the given states, if applicable
+ * ``"terminated"``: episode termination status for sampled environment transitions. This key is only defined during the training process
+ * ``"rnn"``: list of initial hidden states ordered according to the model specification
+
+ * The ``.compute()`` method must include, under the ``"rnn"`` key of the returned dictionary, a list of each final hidden state
+
+ .. raw:: html
+
+
+
+ .. tabs::
+
+ .. group-tab:: |_4| |pytorch| |_4|
+
+ .. tabs::
+
+ .. group-tab:: nn.Sequential
+
+ .. literalinclude:: ../../snippets/multicategorical_model.py
+ :language: python
+ :start-after: [start-gru-sequential-torch]
+ :end-before: [end-gru-sequential-torch]
+
+ .. group-tab:: nn.functional
+
+ .. literalinclude:: ../../snippets/multicategorical_model.py
+ :language: python
+ :start-after: [start-gru-functional-torch]
+ :end-before: [end-gru-functional-torch]
+
+ .. tab:: LSTM
+
+ .. image:: ../../_static/imgs/model_categorical_rnn-light.svg
+ :width: 90%
+ :align: center
+ :class: only-light
+
+ .. image:: ../../_static/imgs/model_categorical_rnn-dark.svg
+ :width: 90%
+ :align: center
+ :class: only-dark
+
+ where:
+
+ .. math::
+ \begin{aligned}
+ N ={} & \text{batch size} \\
+ L ={} & \text{sequence length} \\
+ D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
+ H_{in} ={} & \text{input_size} \\
+ H_{cell} ={} & \text{hidden_size} \\
+ H_{out} ={} & \text{proj_size if } \text{proj_size}>0 \text{ otherwise hidden_size} \\
+ \end{aligned}
+
+ .. raw:: html
+
+
+
+ The following points are relevant in the definition of recurrent models:
+
+ * The ``.get_specification()`` method must be overwritten to return, under a dictionary key ``"rnn"``, a sub-dictionary that includes the sequence length (under key ``"sequence_length"``) as a number and a list of the dimensions (under key ``"sizes"``) of each initial hidden/cell states
+
+ * The ``.compute()`` method's ``inputs`` parameter will have, at least, the following items in the dictionary:
+
+ * ``"states"``: state of the environment used to make the decision
+ * ``"taken_actions"``: actions taken by the policy for the given states, if applicable
+ * ``"terminated"``: episode termination status for sampled environment transitions. This key is only defined during the training process
+ * ``"rnn"``: list of initial hidden/cell states ordered according to the model specification
+
+ * The ``.compute()`` method must include, under the ``"rnn"`` key of the returned dictionary, a list of each final hidden/cell states
+
+ .. raw:: html
+
+
+
+ .. tabs::
+
+ .. group-tab:: |_4| |pytorch| |_4|
+
+ .. tabs::
+
+ .. group-tab:: nn.Sequential
+
+ .. literalinclude:: ../../snippets/multicategorical_model.py
+ :language: python
+ :start-after: [start-lstm-sequential-torch]
+ :end-before: [end-lstm-sequential-torch]
+
+ .. group-tab:: nn.functional
+
+ .. literalinclude:: ../../snippets/multicategorical_model.py
+ :language: python
+ :start-after: [start-lstm-functional-torch]
+ :end-before: [end-lstm-functional-torch]
+
+.. raw:: html
+
+
+
+API (PyTorch)
+-------------
+
+.. autoclass:: skrl.models.torch.multicategorical.MultiCategoricalMixin
+ :show-inheritance:
+ :members:
+
+ .. automethod:: __init__
+
+.. raw:: html
+
+
+
+API (JAX)
+---------
+
+.. autoclass:: skrl.models.jax.multicategorical.MultiCategoricalMixin
+ :show-inheritance:
+ :members:
+
+ .. automethod:: __init__
diff --git a/docs/source/api/multi_agents/ippo.rst b/docs/source/api/multi_agents/ippo.rst
index 8557a028..9a259326 100644
--- a/docs/source/api/multi_agents/ippo.rst
+++ b/docs/source/api/multi_agents/ippo.rst
@@ -171,6 +171,9 @@ The implementation supports the following `Gym spaces ` / :ref:`Gaussian ` / :ref:`MultivariateGaussian `
+ - :ref:`Categorical <models_categorical>` /
+ |br| :ref:`Multi-Categorical <models_multicategorical>` /
+ |br| :ref:`Gaussian <models_gaussian>` /
+ |br| :ref:`MultivariateGaussian <models_multivariate_gaussian>`
* - :math:`V_\phi(s)`
- Value
- :literal:`"value"`
diff --git a/docs/source/api/multi_agents/mappo.rst b/docs/source/api/multi_agents/mappo.rst
index ea8b82ad..c875ac6a 100644
--- a/docs/source/api/multi_agents/mappo.rst
+++ b/docs/source/api/multi_agents/mappo.rst
@@ -172,6 +172,9 @@ The implementation supports the following `Gym spaces ` / :ref:`Gaussian ` / :ref:`MultivariateGaussian `
+ - :ref:`Categorical <models_categorical>` /
+ |br| :ref:`Multi-Categorical <models_multicategorical>` /
+ |br| :ref:`Gaussian <models_gaussian>` /
+ |br| :ref:`MultivariateGaussian <models_multivariate_gaussian>`
* - :math:`V_\phi(s)`
- Value
- :literal:`"value"`
diff --git a/docs/source/api/trainers.rst b/docs/source/api/trainers.rst
index 45d4e5f1..83055038 100644
--- a/docs/source/api/trainers.rst
+++ b/docs/source/api/trainers.rst
@@ -6,7 +6,8 @@ Trainers
Sequential
Parallel
- Manual
+ Step <trainers/step>
+ Manual training <trainers/manual>
Trainers are responsible for orchestrating and managing the training/evaluation of agents and their interactions with the environment.
@@ -26,7 +27,10 @@ Trainers are responsible for orchestrating and managing the training/evaluation
* - :doc:`Parallel trainer `
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
- * - :doc:`Manual trainer `
+ * - :doc:`Step trainer <trainers/step>`
+ - .. centered:: :math:`\blacksquare`
+ - .. centered:: :math:`\blacksquare`
+ * - :doc:`Manual training <trainers/manual>`
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\blacksquare`
diff --git a/docs/source/api/trainers/manual.rst b/docs/source/api/trainers/manual.rst
index 439755e0..61047610 100644
--- a/docs/source/api/trainers/manual.rst
+++ b/docs/source/api/trainers/manual.rst
@@ -1,5 +1,5 @@
-Manual trainer
-==============
+Manual training
+===============
Train agents by manually controlling the training/evaluation loop.
@@ -33,60 +33,40 @@ Usage
.. group-tab:: |_4| |pytorch| |_4|
- .. literalinclude:: ../../snippets/trainer.py
- :language: python
- :start-after: [pytorch-start-manual]
- :end-before: [pytorch-end-manual]
+ .. tabs::
- .. group-tab:: |_4| |jax| |_4|
-
- .. literalinclude:: ../../snippets/trainer.py
- :language: python
- :start-after: [jax-start-manual]
- :end-before: [jax-end-manual]
-
-.. raw:: html
+ .. group-tab:: Training
-
+ .. literalinclude:: ../../snippets/trainer.py
+ :language: python
+ :start-after: [pytorch-start-manual-training]
+ :end-before: [pytorch-end-manual-training]
-Configuration
--------------
+ .. group-tab:: Evaluation
-.. literalinclude:: ../../../../skrl/trainers/torch/manual.py
- :language: python
- :lines: 14-19
- :linenos:
+ .. literalinclude:: ../../snippets/trainer.py
+ :language: python
+ :start-after: [pytorch-start-manual-evaluation]
+ :end-before: [pytorch-end-manual-evaluation]
-.. raw:: html
+ .. group-tab:: |_4| |jax| |_4|
-
+ .. tabs::
-API (PyTorch)
--------------
+ .. group-tab:: Training
-.. autoclass:: skrl.trainers.torch.manual.MANUAL_TRAINER_DEFAULT_CONFIG
+ .. literalinclude:: ../../snippets/trainer.py
+ :language: python
+ :start-after: [jax-start-manual-training]
+ :end-before: [jax-end-manual-training]
-.. autoclass:: skrl.trainers.torch.manual.ManualTrainer
- :undoc-members:
- :show-inheritance:
- :inherited-members:
- :members:
+ .. group-tab:: Evaluation
- .. automethod:: __init__
+ .. literalinclude:: ../../snippets/trainer.py
+ :language: python
+ :start-after: [jax-start-manual-evaluation]
+ :end-before: [jax-end-manual-evaluation]
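+
+The training/evaluation snippets above follow the same general pattern. As a quick reference, a minimal sketch of the training portion (PyTorch) is shown below; it assumes a wrapped environment ``env``, an already instantiated and configured ``agent`` and a number of ``timesteps``, and omits details such as rendering (see the snippets above for the exact, complete code):
+
+.. code-block:: python
+
+    import torch
+
+    # initialize the agent before interacting with the environment
+    agent.init()
+
+    states, infos = env.reset()
+
+    for timestep in range(timesteps):
+        # pre-interaction stage (e.g. learning rate scheduling)
+        agent.pre_interaction(timestep=timestep, timesteps=timesteps)
+
+        with torch.no_grad():
+            # compute the actions and interact with the environment
+            actions = agent.act(states, timestep=timestep, timesteps=timesteps)[0]
+            next_states, rewards, terminated, truncated, infos = env.step(actions)
+
+            # record the transition in the agent's memory
+            agent.record_transition(states=states,
+                                    actions=actions,
+                                    rewards=rewards,
+                                    next_states=next_states,
+                                    terminated=terminated,
+                                    truncated=truncated,
+                                    infos=infos,
+                                    timestep=timestep,
+                                    timesteps=timesteps)
+
+        # post-interaction stage (e.g. algorithm update, logging, checkpointing)
+        agent.post_interaction(timestep=timestep, timesteps=timesteps)
+
+        # reset the environments if any episode has finished
+        if terminated.any() or truncated.any():
+            with torch.no_grad():
+                states, infos = env.reset()
+        else:
+            states = next_states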
.. raw:: html
-
-API (JAX)
----------
-
-.. autoclass:: skrl.trainers.jax.manual.MANUAL_TRAINER_DEFAULT_CONFIG
-
-.. autoclass:: skrl.trainers.jax.manual.ManualTrainer
- :undoc-members:
- :show-inheritance:
- :inherited-members:
- :members:
-
- .. automethod:: __init__
diff --git a/docs/source/api/trainers/parallel.rst b/docs/source/api/trainers/parallel.rst
index beff9be4..241b92d9 100644
--- a/docs/source/api/trainers/parallel.rst
+++ b/docs/source/api/trainers/parallel.rst
@@ -55,8 +55,8 @@ Configuration
.. literalinclude:: ../../../../skrl/trainers/torch/parallel.py
:language: python
- :lines: 15-20
- :linenos:
+ :start-after: [start-config-dict-torch]
+ :end-before: [end-config-dict-torch]
.. raw:: html
diff --git a/docs/source/api/trainers/sequential.rst b/docs/source/api/trainers/sequential.rst
index a4ee9095..6728dfb4 100644
--- a/docs/source/api/trainers/sequential.rst
+++ b/docs/source/api/trainers/sequential.rst
@@ -54,8 +54,8 @@ Configuration
.. literalinclude:: ../../../../skrl/trainers/torch/sequential.py
:language: python
- :lines: 14-19
- :linenos:
+ :start-after: [start-config-dict-torch]
+ :end-before: [end-config-dict-torch]
.. raw:: html
diff --git a/docs/source/api/trainers/step.rst b/docs/source/api/trainers/step.rst
new file mode 100644
index 00000000..ebb40e2c
--- /dev/null
+++ b/docs/source/api/trainers/step.rst
@@ -0,0 +1,92 @@
+Step trainer
+============
+
+Train agents by controlling the training/evaluation loop step by step.
+
+.. raw:: html
+
+
+
+Concept
+-------
+
+.. image:: ../../_static/imgs/manual_trainer-light.svg
+ :width: 100%
+ :align: center
+ :class: only-light
+ :alt: Step-by-step trainer
+
+.. image:: ../../_static/imgs/manual_trainer-dark.svg
+ :width: 100%
+ :align: center
+ :class: only-dark
+ :alt: Step-by-step trainer
+
+.. raw:: html
+
+
+
+Usage
+-----
+
+.. tabs::
+
+ .. group-tab:: |_4| |pytorch| |_4|
+
+ .. literalinclude:: ../../snippets/trainer.py
+ :language: python
+ :start-after: [pytorch-start-step]
+ :end-before: [pytorch-end-step]
+
+ .. group-tab:: |_4| |jax| |_4|
+
+ .. literalinclude:: ../../snippets/trainer.py
+ :language: python
+ :start-after: [jax-start-step]
+ :end-before: [jax-end-step]
+
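+As a quick reference, a minimal sketch of how the step-based loop might be driven is shown below; it assumes a wrapped environment ``env`` and an already configured ``agent``, and that the trainer's ``train()`` / ``eval()`` methods are called once per timestep as in the snippets above:
+
+.. code-block:: python
+
+    from skrl.trainers.torch import StepTrainer
+
+    # assumes a wrapped environment `env` and an agent `agent` already exist
+    cfg = {"timesteps": 50000, "headless": False}
+    trainer = StepTrainer(cfg=cfg, env=env, agents=agent)
+
+    # advance the interaction one timestep per call
+    for timestep in range(cfg["timesteps"]):
+        trainer.train(timestep=timestep)
+        # or, for evaluation:
+        # trainer.eval(timestep=timestep)
+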
+.. raw:: html
+
+
+
+Configuration
+-------------
+
+.. literalinclude:: ../../../../skrl/trainers/torch/step.py
+ :language: python
+ :start-after: [start-config-dict-torch]
+ :end-before: [end-config-dict-torch]
+
+.. raw:: html
+
+
+
+API (PyTorch)
+-------------
+
+.. autoclass:: skrl.trainers.torch.step.STEP_TRAINER_DEFAULT_CONFIG
+
+.. autoclass:: skrl.trainers.torch.step.StepTrainer
+ :undoc-members:
+ :show-inheritance:
+ :inherited-members:
+ :members:
+
+ .. automethod:: __init__
+
+.. raw:: html
+
+
+
+API (JAX)
+---------
+
+.. autoclass:: skrl.trainers.jax.step.STEP_TRAINER_DEFAULT_CONFIG
+
+.. autoclass:: skrl.trainers.jax.step.StepTrainer
+ :undoc-members:
+ :show-inheritance:
+ :inherited-members:
+ :members:
+
+ .. automethod:: __init__
diff --git a/docs/source/api/utils/isaacgym_utils.rst b/docs/source/api/utils/isaacgym_utils.rst
index 88d5a188..882fbd81 100644
--- a/docs/source/api/utils/isaacgym_utils.rst
+++ b/docs/source/api/utils/isaacgym_utils.rst
@@ -107,7 +107,6 @@ Usage
.. literalinclude:: ../../snippets/isaacgym_utils.py
:language: python
- :linenos:
:emphasize-lines: 4, 8, 56, 65-68
.. raw:: html
diff --git a/docs/source/api/utils/postprocessing.rst b/docs/source/api/utils/postprocessing.rst
index bc575f85..5eeb4ec5 100644
--- a/docs/source/api/utils/postprocessing.rst
+++ b/docs/source/api/utils/postprocessing.rst
@@ -25,7 +25,6 @@ Usage
.. literalinclude:: ../../snippets/utils_postprocessing.py
:language: python
- :linenos:
:emphasize-lines: 1, 5-6
:start-after: [start-memory_file_iterator-torch]
:end-before: [end-memory_file_iterator-torch]
@@ -34,7 +33,6 @@ Usage
.. literalinclude:: ../../snippets/utils_postprocessing.py
:language: python
- :linenos:
:emphasize-lines: 1, 5-6
:start-after: [start-memory_file_iterator-numpy]
:end-before: [end-memory_file_iterator-numpy]
@@ -43,7 +41,6 @@ Usage
.. literalinclude:: ../../snippets/utils_postprocessing.py
:language: python
- :linenos:
:emphasize-lines: 1, 5-6
:start-after: [start-memory_file_iterator-csv]
:end-before: [end-memory_file_iterator-csv]
@@ -101,7 +98,6 @@ Usage
.. literalinclude:: ../../snippets/utils_postprocessing.py
:language: python
- :linenos:
:emphasize-lines: 1, 5-7
:start-after: [start-tensorboard_file_iterator-list]
:end-before: [end-tensorboard_file_iterator-list]
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 42f37f8a..f04c7874 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -16,7 +16,7 @@
if skrl.__version__ != "unknown":
release = version = skrl.__version__
else:
- release = version = "1.0.0"
+ release = version = "1.1.0"
master_doc = "index"
diff --git a/docs/source/examples/real_world/franka_emika_panda/reaching_franka_omniverse_isaacgym_env.py b/docs/source/examples/real_world/franka_emika_panda/reaching_franka_omniverse_isaacgym_env.py
index 9e7f07bd..248f342b 100644
--- a/docs/source/examples/real_world/franka_emika_panda/reaching_franka_omniverse_isaacgym_env.py
+++ b/docs/source/examples/real_world/franka_emika_panda/reaching_franka_omniverse_isaacgym_env.py
@@ -1,9 +1,6 @@
import torch
import numpy as np
-from omni.isaac.core.utils.extensions import enable_extension
-enable_extension("omni.replicator.isaac") # required by OIGE
-
from omniisaacgymenvs.tasks.base.rl_task import RLTask
from omniisaacgymenvs.robots.articulations.franka import Franka as Robot
@@ -27,6 +24,8 @@
"headless": True,
"sim_device": "gpu",
"enable_livestream": False,
+ "warp": False,
+ "seed": 42,
"task": {"name": "ReachingFranka",
"physics_engine": "physx",
"env": {"numEnvs": 1024,
@@ -86,6 +85,7 @@
"rest_offset": 0.0},
"target": {"override_usd_defaults": False,
"fixed_base": True,
+ "make_kinematic": True,
"enable_self_collisions": False,
"enable_gyroscopic_forces": True,
"solver_position_iteration_count": 4,
diff --git a/docs/source/examples/real_world/franka_emika_panda/reaching_franka_omniverse_isaacgym_skrl_eval.py b/docs/source/examples/real_world/franka_emika_panda/reaching_franka_omniverse_isaacgym_skrl_eval.py
index ed4416b4..204a5584 100644
--- a/docs/source/examples/real_world/franka_emika_panda/reaching_franka_omniverse_isaacgym_skrl_eval.py
+++ b/docs/source/examples/real_world/franka_emika_panda/reaching_franka_omniverse_isaacgym_skrl_eval.py
@@ -8,6 +8,11 @@
from skrl.trainers.torch import SequentialTrainer
from skrl.utils.omniverse_isaacgym_utils import get_env_instance
from skrl.envs.torch import wrap_env
+from skrl.utils import set_seed
+
+
+# Seed for reproducibility
+seed = set_seed() # e.g. `set_seed(42)` for fixed seed
# Define only the policy for evaluation
@@ -37,6 +42,7 @@ def compute(self, inputs, role):
from omniisaacgymenvs.utils.config_utils.sim_config import SimConfig
from reaching_franka_omniverse_isaacgym_env import ReachingFrankaTask, TASK_CFG
+TASK_CFG["seed"] = seed
TASK_CFG["headless"] = headless
TASK_CFG["task"]["env"]["numEnvs"] = 64
TASK_CFG["task"]["env"]["controlSpace"] = "joint" # "joint" or "cartesian"
diff --git a/docs/source/examples/real_world/franka_emika_panda/reaching_franka_omniverse_isaacgym_skrl_train.py b/docs/source/examples/real_world/franka_emika_panda/reaching_franka_omniverse_isaacgym_skrl_train.py
index 899df201..75d28c3e 100644
--- a/docs/source/examples/real_world/franka_emika_panda/reaching_franka_omniverse_isaacgym_skrl_train.py
+++ b/docs/source/examples/real_world/franka_emika_panda/reaching_franka_omniverse_isaacgym_skrl_train.py
@@ -13,8 +13,8 @@
from skrl.utils import set_seed
-# set the seed for reproducibility
-set_seed(42)
+# Seed for reproducibility
+seed = set_seed() # e.g. `set_seed(42)` for fixed seed
# Define the models (stochastic and deterministic models) for the agent using helper mixin.
@@ -62,6 +62,7 @@ def compute(self, inputs, role):
from omniisaacgymenvs.utils.config_utils.sim_config import SimConfig
from reaching_franka_omniverse_isaacgym_env import ReachingFrankaTask, TASK_CFG
+TASK_CFG["seed"] = seed
TASK_CFG["headless"] = headless
TASK_CFG["task"]["env"]["numEnvs"] = 1024
TASK_CFG["task"]["env"]["controlSpace"] = "joint" # "joint" or "cartesian"
diff --git a/docs/source/examples/real_world/kuka_lbr_iiwa/reaching_iiwa_omniverse_isaacgym_env.py b/docs/source/examples/real_world/kuka_lbr_iiwa/reaching_iiwa_omniverse_isaacgym_env.py
index f3a5a458..495892cd 100644
--- a/docs/source/examples/real_world/kuka_lbr_iiwa/reaching_iiwa_omniverse_isaacgym_env.py
+++ b/docs/source/examples/real_world/kuka_lbr_iiwa/reaching_iiwa_omniverse_isaacgym_env.py
@@ -1,9 +1,6 @@
import torch
import numpy as np
-from omni.isaac.core.utils.extensions import enable_extension
-enable_extension("omni.replicator.isaac") # required by OIGE
-
from omniisaacgymenvs.tasks.base.rl_task import RLTask
from omni.isaac.core.prims import RigidPrimView
@@ -28,6 +25,8 @@
"headless": True,
"sim_device": "gpu",
"enable_livestream": False,
+ "warp": False,
+ "seed": 42,
"task": {"name": "ReachingIiwa",
"physics_engine": "physx",
"env": {"numEnvs": 1024,
@@ -87,6 +86,7 @@
"rest_offset": 0.0},
"target": {"override_usd_defaults": False,
"fixed_base": True,
+ "make_kinematic": True,
"enable_self_collisions": False,
"enable_gyroscopic_forces": True,
"solver_position_iteration_count": 4,
diff --git a/docs/source/examples/real_world/kuka_lbr_iiwa/reaching_iiwa_omniverse_isaacgym_skrl_eval.py b/docs/source/examples/real_world/kuka_lbr_iiwa/reaching_iiwa_omniverse_isaacgym_skrl_eval.py
index c9494217..c6cb93b8 100644
--- a/docs/source/examples/real_world/kuka_lbr_iiwa/reaching_iiwa_omniverse_isaacgym_skrl_eval.py
+++ b/docs/source/examples/real_world/kuka_lbr_iiwa/reaching_iiwa_omniverse_isaacgym_skrl_eval.py
@@ -8,6 +8,11 @@
from skrl.trainers.torch import SequentialTrainer
from skrl.utils.omniverse_isaacgym_utils import get_env_instance
from skrl.envs.torch import wrap_env
+from skrl.utils import set_seed
+
+
+# Seed for reproducibility
+seed = set_seed() # e.g. `set_seed(42)` for fixed seed
# Define only the policy for evaluation
@@ -37,6 +42,7 @@ def compute(self, inputs, role):
from omniisaacgymenvs.utils.config_utils.sim_config import SimConfig
from reaching_iiwa_omniverse_isaacgym_env import ReachingIiwaTask, TASK_CFG
+TASK_CFG["seed"] = seed
TASK_CFG["headless"] = headless
TASK_CFG["task"]["env"]["numEnvs"] = 64
TASK_CFG["task"]["env"]["controlSpace"] = "joint" # "joint" or "cartesian"
diff --git a/docs/source/examples/real_world/kuka_lbr_iiwa/reaching_iiwa_omniverse_isaacgym_skrl_train.py b/docs/source/examples/real_world/kuka_lbr_iiwa/reaching_iiwa_omniverse_isaacgym_skrl_train.py
index d109085a..861cd77e 100644
--- a/docs/source/examples/real_world/kuka_lbr_iiwa/reaching_iiwa_omniverse_isaacgym_skrl_train.py
+++ b/docs/source/examples/real_world/kuka_lbr_iiwa/reaching_iiwa_omniverse_isaacgym_skrl_train.py
@@ -13,8 +13,8 @@
from skrl.utils import set_seed
-# set the seed for reproducibility
-set_seed(42)
+# Seed for reproducibility
+seed = set_seed() # e.g. `set_seed(42)` for fixed seed
# Define the models (stochastic and deterministic models) for the agent using helper mixin.
@@ -62,6 +62,7 @@ def compute(self, inputs, role):
from omniisaacgymenvs.utils.config_utils.sim_config import SimConfig
from reaching_iiwa_omniverse_isaacgym_env import ReachingIiwaTask, TASK_CFG
+TASK_CFG["seed"] = seed
TASK_CFG["headless"] = headless
TASK_CFG["task"]["env"]["numEnvs"] = 1024
TASK_CFG["task"]["env"]["controlSpace"] = "joint" # "joint" or "cartesian"
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 5ab4e515..d587c727 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -154,6 +154,7 @@ Models
* :doc:`Tabular model ` (discrete domain)
* :doc:`Categorical model ` (discrete domain)
+ * :doc:`Multi-Categorical model <api/models/multicategorical>` (discrete domain)
* :doc:`Gaussian model ` (continuous domain)
* :doc:`Multivariate Gaussian model ` (continuous domain)
* :doc:`Deterministic model ` (continuous domain)
@@ -165,7 +166,7 @@ Trainers
* :doc:`Sequential trainer `
* :doc:`Parallel trainer `
- * :doc:`Manual trainer `
+ * :doc:`Step trainer <api/trainers/step>`
Resources
^^^^^^^^^
diff --git a/docs/source/intro/getting_started.rst b/docs/source/intro/getting_started.rst
index db64b244..75320efe 100644
--- a/docs/source/intro/getting_started.rst
+++ b/docs/source/intro/getting_started.rst
@@ -762,7 +762,7 @@ The following code snippets show how to train/evaluate RL systems using the avai
:start-after: [pytorch-start-parallel]
:end-before: [pytorch-end-parallel]
- .. tab:: Manual trainer
+ .. tab:: Step trainer
.. tabs::
@@ -770,15 +770,15 @@ The following code snippets show how to train/evaluate RL systems using the avai
.. literalinclude:: ../snippets/trainer.py
:language: python
- :start-after: [pytorch-start-manual]
- :end-before: [pytorch-end-manual]
+ :start-after: [pytorch-start-step]
+ :end-before: [pytorch-end-step]
.. group-tab:: |_4| |jax| |_4|
.. literalinclude:: ../snippets/trainer.py
:language: python
- :start-after: [jax-start-manual]
- :end-before: [jax-end-manual]
+ :start-after: [jax-start-step]
+ :end-before: [jax-end-step]
.. raw:: html
diff --git a/docs/source/snippets/categorical_model.py b/docs/source/snippets/categorical_model.py
index 40800ab2..7afababc 100644
--- a/docs/source/snippets/categorical_model.py
+++ b/docs/source/snippets/categorical_model.py
@@ -242,6 +242,103 @@ def compute(self, inputs, role):
unnormalized_log_prob=True)
# [end-cnn-functional-torch]
+# [start-cnn-setup-jax]
+import flax.linen as nn
+
+from skrl.models.jax import Model, CategoricalMixin
+
+
+# define the model
+class CNN(CategoricalMixin, Model):
+ def __init__(self, observation_space, action_space, device=None, unnormalized_log_prob=True, **kwargs):
+ Model.__init__(self, observation_space, action_space, device, **kwargs)
+ CategoricalMixin.__init__(self, unnormalized_log_prob)
+
+ def setup(self):
+ self.conv1 = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")
+ self.conv2 = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")
+ self.conv3 = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")
+ self.fc1 = nn.Dense(512)
+ self.fc2 = nn.Dense(16)
+ self.fc3 = nn.Dense(64)
+ self.fc4 = nn.Dense(32)
+ self.fc5 = nn.Dense(self.num_actions)
+
+ def __call__(self, inputs, role):
+ x = inputs["states"].reshape((-1, *self.observation_space.shape))
+ x = self.conv1(x)
+ x = nn.relu(x)
+ x = self.conv2(x)
+ x = nn.relu(x)
+ x = self.conv3(x)
+ x = nn.relu(x)
+ x = x.reshape((x.shape[0], -1))
+ x = self.fc1(x)
+ x = nn.relu(x)
+ x = self.fc2(x)
+ x = nn.tanh(x)
+ x = self.fc3(x)
+ x = nn.tanh(x)
+ x = self.fc4(x)
+ x = nn.tanh(x)
+ x = self.fc5(x)
+ return x, {}
+
+
+# instantiate the model (assumes there is a wrapped environment: env)
+policy = CNN(observation_space=env.observation_space,
+ action_space=env.action_space,
+ device=env.device,
+ unnormalized_log_prob=True)
+
+# initialize model's state dict
+policy.init_state_dict("policy")
+# [end-cnn-setup-jax]
+
+# [start-cnn-compact-jax]
+import flax.linen as nn
+
+from skrl.models.jax import Model, CategoricalMixin
+
+
+# define the model
+class CNN(CategoricalMixin, Model):
+ def __init__(self, observation_space, action_space, device=None, unnormalized_log_prob=True, **kwargs):
+ Model.__init__(self, observation_space, action_space, device, **kwargs)
+ CategoricalMixin.__init__(self, unnormalized_log_prob)
+
+ @nn.compact # marks the given module method allowing inlined submodules
+ def __call__(self, inputs, role):
+ x = inputs["states"].reshape((-1, *self.observation_space.shape))
+ x = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")(x)
+ x = nn.relu(x)
+ x = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")(x)
+ x = nn.relu(x)
+ x = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")(x)
+ x = nn.relu(x)
+ x = x.reshape((x.shape[0], -1))
+ x = nn.Dense(512)(x)
+ x = nn.relu(x)
+ x = nn.Dense(16)(x)
+ x = nn.tanh(x)
+ x = nn.Dense(64)(x)
+ x = nn.tanh(x)
+ x = nn.Dense(32)(x)
+ x = nn.tanh(x)
+ x = nn.Dense(self.num_actions)(x)
+ return x, {}
+
+
+# instantiate the model (assumes there is a wrapped environment: env)
+policy = CNN(observation_space=env.observation_space,
+ action_space=env.action_space,
+ device=env.device,
+ unnormalized_log_prob=True)
+
+# initialize model's state dict
+policy.init_state_dict("policy")
+# [end-cnn-compact-jax]
+
# =============================================================================
# [start-rnn-sequential-torch]
diff --git a/docs/source/snippets/deterministic_model.py b/docs/source/snippets/deterministic_model.py
index 38daa3e9..2022ae71 100644
--- a/docs/source/snippets/deterministic_model.py
+++ b/docs/source/snippets/deterministic_model.py
@@ -246,6 +246,107 @@ def compute(self, inputs, role):
clip_actions=False)
# [end-cnn-functional-torch]
+# [start-cnn-setup-jax]
+import jax.numpy as jnp
+import flax.linen as nn
+
+from skrl.models.jax import Model, DeterministicMixin
+
+
+# define the model
+class CNN(DeterministicMixin, Model):
+ def __init__(self, observation_space, action_space, device=None, clip_actions=False, **kwargs):
+ Model.__init__(self, observation_space, action_space, device, **kwargs)
+ DeterministicMixin.__init__(self, clip_actions)
+
+ def setup(self):
+ self.conv1 = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")
+ self.conv2 = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")
+ self.conv3 = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")
+ self.fc1 = nn.Dense(512)
+ self.fc2 = nn.Dense(16)
+ self.fc3 = nn.Dense(64)
+ self.fc4 = nn.Dense(32)
+ self.fc5 = nn.Dense(1)
+
+ def __call__(self, inputs, role):
+ x = inputs["states"].reshape((-1, *self.observation_space.shape))
+ x = self.conv1(x)
+ x = nn.relu(x)
+ x = self.conv2(x)
+ x = nn.relu(x)
+ x = self.conv3(x)
+ x = nn.relu(x)
+ x = x.reshape((x.shape[0], -1))
+ x = self.fc1(x)
+ x = nn.relu(x)
+ x = self.fc2(x)
+ x = nn.tanh(x)
+ x = jnp.concatenate([x, inputs["taken_actions"]], axis=-1)
+ x = self.fc3(x)
+ x = nn.tanh(x)
+ x = self.fc4(x)
+ x = nn.tanh(x)
+ x = self.fc5(x)
+ return x, {}
+
+
+# instantiate the model (assumes there is a wrapped environment: env)
+critic = CNN(observation_space=env.observation_space,
+ action_space=env.action_space,
+ device=env.device,
+ clip_actions=False)
+
+# initialize model's state dict
+critic.init_state_dict("critic")
+# [end-cnn-setup-jax]
+
+# [start-cnn-compact-jax]
+import jax.numpy as jnp
+import flax.linen as nn
+
+from skrl.models.jax import Model, DeterministicMixin
+
+
+# define the model
+class CNN(DeterministicMixin, Model):
+ def __init__(self, observation_space, action_space, device=None, clip_actions=False, **kwargs):
+ Model.__init__(self, observation_space, action_space, device, **kwargs)
+ DeterministicMixin.__init__(self, clip_actions)
+
+ @nn.compact # marks the given module method allowing inlined submodules
+ def __call__(self, inputs, role):
+ x = inputs["states"].reshape((-1, *self.observation_space.shape))
+ x = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")(x)
+ x = nn.relu(x)
+ x = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")(x)
+ x = nn.relu(x)
+ x = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")(x)
+ x = nn.relu(x)
+ x = x.reshape((x.shape[0], -1))
+ x = nn.Dense(512)(x)
+ x = nn.relu(x)
+ x = nn.Dense(16)(x)
+ x = nn.tanh(x)
+ x = jnp.concatenate([x, inputs["taken_actions"]], axis=-1)
+ x = nn.Dense(64)(x)
+ x = nn.tanh(x)
+ x = nn.Dense(32)(x)
+ x = nn.tanh(x)
+ x = nn.Dense(1)(x)
+ return x, {}
+
+
+# instantiate the model (assumes there is a wrapped environment: env)
+critic = CNN(observation_space=env.observation_space,
+ action_space=env.action_space,
+ device=env.device,
+ clip_actions=False)
+
+# initialize model's state dict
+critic.init_state_dict("critic")
+# [end-cnn-compact-jax]
+
# =============================================================================
# [start-rnn-sequential-torch]
diff --git a/docs/source/snippets/gaussian_model.py b/docs/source/snippets/gaussian_model.py
index 952c994e..3f48aa58 100644
--- a/docs/source/snippets/gaussian_model.py
+++ b/docs/source/snippets/gaussian_model.py
@@ -289,6 +289,118 @@ def compute(self, inputs, role):
reduction="sum")
# [end-cnn-functional-torch]
+# [start-cnn-setup-jax]
+import jax.numpy as jnp
+import flax.linen as nn
+
+from skrl.models.jax import Model, GaussianMixin
+
+
+# define the model
+class CNN(GaussianMixin, Model):
+ def __init__(self, observation_space, action_space, device=None,
+ clip_actions=False, clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum", **kwargs):
+ Model.__init__(self, observation_space, action_space, device, **kwargs)
+ GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
+
+ def setup(self):
+ self.conv1 = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")
+ self.conv2 = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")
+ self.conv3 = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")
+ self.fc1 = nn.Dense(512)
+ self.fc2 = nn.Dense(16)
+ self.fc3 = nn.Dense(64)
+ self.fc4 = nn.Dense(32)
+ self.fc5 = nn.Dense(self.num_actions)
+
+ self.log_std_parameter = self.param("log_std_parameter", lambda _: jnp.zeros(self.num_actions))
+
+ def __call__(self, inputs, role):
+ x = inputs["states"].reshape((-1, *self.observation_space.shape))
+ x = self.conv1(x)
+ x = nn.relu(x)
+ x = self.conv2(x)
+ x = nn.relu(x)
+ x = self.conv3(x)
+ x = nn.relu(x)
+ x = x.reshape((x.shape[0], -1))
+ x = self.fc1(x)
+ x = nn.relu(x)
+ x = self.fc2(x)
+ x = nn.tanh(x)
+ x = self.fc3(x)
+ x = nn.tanh(x)
+ x = self.fc4(x)
+ x = nn.tanh(x)
+ x = self.fc5(x)
+ return nn.tanh(x), self.log_std_parameter, {}
+
+
+# instantiate the model (assumes there is a wrapped environment: env)
+policy = CNN(observation_space=env.observation_space,
+ action_space=env.action_space,
+ device=env.device,
+ clip_actions=True,
+ clip_log_std=True,
+ min_log_std=-20,
+ max_log_std=2,
+ reduction="sum")
+
+# initialize model's state dict
+policy.init_state_dict("policy")
+# [end-cnn-setup-jax]
+
+# [start-cnn-compact-jax]
+import jax.numpy as jnp
+import flax.linen as nn
+
+from skrl.models.jax import Model, GaussianMixin
+
+
+# define the model
+class CNN(GaussianMixin, Model):
+ def __init__(self, observation_space, action_space, device=None,
+ clip_actions=False, clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum", **kwargs):
+ Model.__init__(self, observation_space, action_space, device, **kwargs)
+ GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
+
+ @nn.compact # marks the given module method allowing inlined submodules
+ def __call__(self, inputs, role):
+ x = inputs["states"].reshape((-1, *self.observation_space.shape))
+ x = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")(x)
+ x = nn.relu(x)
+ x = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")(x)
+ x = nn.relu(x)
+ x = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")(x)
+ x = nn.relu(x)
+ x = x.reshape((x.shape[0], -1))
+ x = nn.Dense(512)(x)
+ x = nn.relu(x)
+ x = nn.Dense(16)(x)
+ x = nn.tanh(x)
+ x = nn.Dense(64)(x)
+ x = nn.tanh(x)
+ x = nn.Dense(32)(x)
+ x = nn.tanh(x)
+ x = nn.Dense(self.num_actions)(x)
+ log_std_parameter = self.param("log_std_parameter", lambda _: jnp.zeros(self.num_actions))
+ return nn.tanh(x), log_std_parameter, {}
+
+
+# instantiate the model (assumes there is a wrapped environment: env)
+policy = CNN(observation_space=env.observation_space,
+ action_space=env.action_space,
+ device=env.device,
+ clip_actions=True,
+ clip_log_std=True,
+ min_log_std=-20,
+ max_log_std=2,
+ reduction="sum")
+
+# initialize model's state dict
+policy.init_state_dict("policy")
+# [end-cnn-compact-jax]
+
# =============================================================================
# [start-rnn-sequential-torch]
diff --git a/docs/source/snippets/multicategorical_model.py b/docs/source/snippets/multicategorical_model.py
new file mode 100644
index 00000000..7dbcc422
--- /dev/null
+++ b/docs/source/snippets/multicategorical_model.py
@@ -0,0 +1,892 @@
+# [start-definition-torch]
+class MultiCategoricalModel(MultiCategoricalMixin, Model):
+ def __init__(self, observation_space, action_space, device=None, unnormalized_log_prob=True, reduction="sum"):
+ Model.__init__(self, observation_space, action_space, device)
+ MultiCategoricalMixin.__init__(self, unnormalized_log_prob, reduction)
+# [end-definition-torch]
+
+
+# [start-definition-jax]
+class MultiCategoricalModel(MultiCategoricalMixin, Model):
+ def __init__(self, observation_space, action_space, device=None, unnormalized_log_prob=True, reduction="sum", **kwargs):
+ Model.__init__(self, observation_space, action_space, device, **kwargs)
+ MultiCategoricalMixin.__init__(self, unnormalized_log_prob, reduction)
+# [end-definition-jax]
+
+# =============================================================================
+
+# [start-mlp-sequential-torch]
+import torch
+import torch.nn as nn
+
+from skrl.models.torch import Model, MultiCategoricalMixin
+
+
+# define the model
+class MLP(MultiCategoricalMixin, Model):
+ def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True, reduction="sum"):
+ Model.__init__(self, observation_space, action_space, device)
+ MultiCategoricalMixin.__init__(self, unnormalized_log_prob, reduction)
+
+ self.net = nn.Sequential(nn.Linear(self.num_observations, 64),
+ nn.ReLU(),
+ nn.Linear(64, 32),
+ nn.ReLU(),
+ nn.Linear(32, self.num_actions))
+
+ def compute(self, inputs, role):
+ return self.net(inputs["states"]), {}
+
+
+# instantiate the model (assumes there is a wrapped environment: env)
+policy = MLP(observation_space=env.observation_space,
+ action_space=env.action_space,
+ device=env.device,
+ unnormalized_log_prob=True,
+ reduction="sum")
+# [end-mlp-sequential-torch]
+
+# [start-mlp-functional-torch]
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from skrl.models.torch import Model, MultiCategoricalMixin
+
+
+# define the model
+class MLP(MultiCategoricalMixin, Model):
+ def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True, reduction="sum"):
+ Model.__init__(self, observation_space, action_space, device)
+ MultiCategoricalMixin.__init__(self, unnormalized_log_prob, reduction)
+
+ self.fc1 = nn.Linear(self.num_observations, 64)
+ self.fc2 = nn.Linear(64, 32)
+ self.logits = nn.Linear(32, self.num_actions)
+
+ def compute(self, inputs, role):
+ x = self.fc1(inputs["states"])
+ x = F.relu(x)
+ x = self.fc2(x)
+ x = F.relu(x)
+ return self.logits(x), {}
+
+
+# instantiate the model (assumes there is a wrapped environment: env)
+policy = MLP(observation_space=env.observation_space,
+ action_space=env.action_space,
+ device=env.device,
+ unnormalized_log_prob=True,
+ reduction="sum")
+# [end-mlp-functional-torch]
+
+# [start-mlp-setup-jax]
+import flax.linen as nn
+
+from skrl.models.jax import Model, MultiCategoricalMixin
+
+
+# define the model
+class MLP(MultiCategoricalMixin, Model):
+ def __init__(self, observation_space, action_space, device=None, unnormalized_log_prob=True, reduction="sum", **kwargs):
+ Model.__init__(self, observation_space, action_space, device, **kwargs)
+ MultiCategoricalMixin.__init__(self, unnormalized_log_prob, reduction)
+
+ def setup(self):
+ self.fc1 = nn.Dense(64)
+ self.fc2 = nn.Dense(32)
+ self.fc3 = nn.Dense(self.num_actions)
+
+ def __call__(self, inputs, role):
+ x = self.fc1(inputs["states"])
+ x = nn.relu(x)
+ x = self.fc2(x)
+ x = nn.relu(x)
+ x = self.fc3(x)
+ return x, {}
+
+
+# instantiate the model (assumes there is a wrapped environment: env)
+policy = MLP(observation_space=env.observation_space,
+ action_space=env.action_space,
+ device=env.device,
+ unnormalized_log_prob=True,
+ reduction="sum")
+
+# initialize model's state dict
+policy.init_state_dict("policy")
+# [end-mlp-setup-jax]
+
+# [start-mlp-compact-jax]
+import flax.linen as nn
+
+from skrl.models.jax import Model, MultiCategoricalMixin
+
+
+# define the model
+class MLP(MultiCategoricalMixin, Model):
+ def __init__(self, observation_space, action_space, device=None, unnormalized_log_prob=True, reduction="sum", **kwargs):
+ Model.__init__(self, observation_space, action_space, device, **kwargs)
+ MultiCategoricalMixin.__init__(self, unnormalized_log_prob, reduction)
+
+ @nn.compact # marks the given module method allowing inlined submodules
+ def __call__(self, inputs, role):
+ x = nn.Dense(64)(inputs["states"])
+ x = nn.relu(x)
+ x = nn.Dense(32)(x)
+ x = nn.relu(x)
+ x = nn.Dense(self.num_actions)(x)
+ return x, {}
+
+
+# instantiate the model (assumes there is a wrapped environment: env)
+policy = MLP(observation_space=env.observation_space,
+ action_space=env.action_space,
+ device=env.device,
+ unnormalized_log_prob=True,
+ reduction="sum")
+
+# initialize model's state dict
+policy.init_state_dict("policy")
+# [end-mlp-compact-jax]
+
+# =============================================================================
+
+# [start-cnn-sequential-torch]
+import torch
+import torch.nn as nn
+
+from skrl.models.torch import Model, MultiCategoricalMixin
+
+
+# define the model
+class CNN(MultiCategoricalMixin, Model):
+ def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True, reduction="sum"):
+ Model.__init__(self, observation_space, action_space, device)
+ MultiCategoricalMixin.__init__(self, unnormalized_log_prob, reduction)
+
+ self.net = nn.Sequential(nn.Conv2d(3, 32, kernel_size=8, stride=4),
+ nn.ReLU(),
+ nn.Conv2d(32, 64, kernel_size=4, stride=2),
+ nn.ReLU(),
+ nn.Conv2d(64, 64, kernel_size=3, stride=1),
+ nn.ReLU(),
+ nn.Flatten(),
+ nn.Linear(1024, 512),
+ nn.ReLU(),
+ nn.Linear(512, 16),
+ nn.Tanh(),
+ nn.Linear(16, 64),
+ nn.Tanh(),
+ nn.Linear(64, 32),
+ nn.Tanh(),
+ nn.Linear(32, self.num_actions))
+
+ def compute(self, inputs, role):
+ # permute (samples, width * height * channels) -> (samples, channels, width, height)
+ return self.net(inputs["states"].view(-1, *self.observation_space.shape).permute(0, 3, 1, 2)), {}
+
+
+# instantiate the model (assumes there is a wrapped environment: env)
+policy = CNN(observation_space=env.observation_space,
+ action_space=env.action_space,
+ device=env.device,
+ unnormalized_log_prob=True,
+ reduction="sum")
+# [end-cnn-sequential-torch]
+
+# [start-cnn-functional-torch]
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from skrl.models.torch import Model, MultiCategoricalMixin
+
+
+# define the model
+class CNN(MultiCategoricalMixin, Model):
+ def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True, reduction="sum"):
+ Model.__init__(self, observation_space, action_space, device)
+ MultiCategoricalMixin.__init__(self, unnormalized_log_prob, reduction)
+
+ self.conv1 = nn.Conv2d(3, 32, kernel_size=8, stride=4)
+ self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
+ self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
+ self.fc1 = nn.Linear(1024, 512)
+ self.fc2 = nn.Linear(512, 16)
+ self.fc3 = nn.Linear(16, 64)
+ self.fc4 = nn.Linear(64, 32)
+ self.fc5 = nn.Linear(32, self.num_actions)
+
+ def compute(self, inputs, role):
+ # permute (samples, width * height * channels) -> (samples, channels, width, height)
+ x = inputs["states"].view(-1, *self.observation_space.shape).permute(0, 3, 1, 2)
+ x = self.conv1(x)
+ x = F.relu(x)
+ x = self.conv2(x)
+ x = F.relu(x)
+ x = self.conv3(x)
+ x = F.relu(x)
+ x = torch.flatten(x, start_dim=1)
+ x = self.fc1(x)
+ x = F.relu(x)
+ x = self.fc2(x)
+ x = torch.tanh(x)
+ x = self.fc3(x)
+ x = torch.tanh(x)
+ x = self.fc4(x)
+ x = torch.tanh(x)
+ x = self.fc5(x)
+ return x, {}
+
+
+# instantiate the model (assumes there is a wrapped environment: env)
+policy = CNN(observation_space=env.observation_space,
+ action_space=env.action_space,
+ device=env.device,
+ unnormalized_log_prob=True,
+ reduction="sum")
+# [end-cnn-functional-torch]
+
+# [start-cnn-setup-jax]
+import flax.linen as nn
+
+from skrl.models.jax import Model, MultiCategoricalMixin
+
+
+# define the model
+class CNN(MultiCategoricalMixin, Model):
+ def __init__(self, observation_space, action_space, device=None, unnormalized_log_prob=True, reduction="sum", **kwargs):
+ Model.__init__(self, observation_space, action_space, device, **kwargs)
+ MultiCategoricalMixin.__init__(self, unnormalized_log_prob, reduction)
+
+ def setup(self):
+ self.conv1 = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")
+ self.conv2 = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")
+ self.conv3 = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")
+ self.fc1 = nn.Dense(512)
+ self.fc2 = nn.Dense(16)
+ self.fc3 = nn.Dense(64)
+ self.fc4 = nn.Dense(32)
+ self.fc5 = nn.Dense(self.num_actions)
+
+ def __call__(self, inputs, role):
+ x = inputs["states"].reshape((-1, *self.observation_space.shape))
+ x = self.conv1(x)
+ x = nn.relu(x)
+ x = self.conv2(x)
+ x = nn.relu(x)
+ x = self.conv3(x)
+ x = nn.relu(x)
+ x = x.reshape((x.shape[0], -1))
+ x = self.fc1(x)
+ x = nn.relu(x)
+ x = self.fc2(x)
+ x = nn.tanh(x)
+ x = self.fc3(x)
+ x = nn.tanh(x)
+ x = self.fc4(x)
+ x = nn.tanh(x)
+ x = self.fc5(x)
+ return x, {}
+
+
+# instantiate the model (assumes there is a wrapped environment: env)
+policy = CNN(observation_space=env.observation_space,
+ action_space=env.action_space,
+ device=env.device,
+ unnormalized_log_prob=True,
+ reduction="sum")
+
+# initialize model's state dict
+policy.init_state_dict("policy")
+# [end-cnn-setup-jax]
+
+# [start-cnn-compact-jax]
+import flax.linen as nn
+
+from skrl.models.jax import Model, MultiCategoricalMixin
+
+
+# define the model
+class CNN(MultiCategoricalMixin, Model):
+ def __init__(self, observation_space, action_space, device=None, unnormalized_log_prob=True, reduction="sum", **kwargs):
+ Model.__init__(self, observation_space, action_space, device, **kwargs)
+ MultiCategoricalMixin.__init__(self, unnormalized_log_prob, reduction)
+
+ @nn.compact # marks the given module method allowing inlined submodules
+ def __call__(self, inputs, role):
+ x = inputs["states"].reshape((-1, *self.observation_space.shape))
+ x = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")(x)
+ x = nn.relu(x)
+ x = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")(x)
+ x = nn.relu(x)
+ x = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")(x)
+ x = nn.relu(x)
+ x = x.reshape((x.shape[0], -1))
+ x = nn.Dense(512)(x)
+ x = nn.relu(x)
+ x = nn.Dense(16)(x)
+ x = nn.tanh(x)
+ x = nn.Dense(64)(x)
+ x = nn.tanh(x)
+ x = nn.Dense(32)(x)
+ x = nn.tanh(x)
+ x = nn.Dense(self.num_actions)(x)
+ return x, {}
+
+
+# instantiate the model (assumes there is a wrapped environment: env)
+policy = CNN(observation_space=env.observation_space,
+ action_space=env.action_space,
+ device=env.device,
+ unnormalized_log_prob=True,
+ reduction="sum")
+
+# initialize model's state dict
+policy.init_state_dict("policy")
+# [end-cnn-compact-jax]
+
+# =============================================================================
+
+# [start-rnn-sequential-torch]
+import torch
+import torch.nn as nn
+
+from skrl.models.torch import Model, MultiCategoricalMixin
+
+
+# define the model
+class RNN(MultiCategoricalMixin, Model):
+ def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True, reduction="sum",
+ num_envs=1, num_layers=1, hidden_size=64, sequence_length=10):
+ Model.__init__(self, observation_space, action_space, device)
+ MultiCategoricalMixin.__init__(self, unnormalized_log_prob, reduction)
+
+ self.num_envs = num_envs
+ self.num_layers = num_layers
+ self.hidden_size = hidden_size # Hout
+ self.sequence_length = sequence_length
+
+ self.rnn = nn.RNN(input_size=self.num_observations,
+ hidden_size=self.hidden_size,
+ num_layers=self.num_layers,
+ batch_first=True) # batch_first -> (batch, sequence, features)
+
+ self.net = nn.Sequential(nn.Linear(self.hidden_size, 64),
+ nn.ReLU(),
+ nn.Linear(64, 32),
+ nn.ReLU(),
+ nn.Linear(32, self.num_actions))
+
+ def get_specification(self):
+ # batch size (N) is the number of envs during rollout
+ return {"rnn": {"sequence_length": self.sequence_length,
+ "sizes": [(self.num_layers, self.num_envs, self.hidden_size)]}} # hidden states (D ∗ num_layers, N, Hout)
+
+ def compute(self, inputs, role):
+ states = inputs["states"]
+ terminated = inputs.get("terminated", None)
+ hidden_states = inputs["rnn"][0]
+
+ # training
+ if self.training:
+ rnn_input = states.view(-1, self.sequence_length, states.shape[-1]) # (N, L, Hin): N=batch_size, L=sequence_length
+ hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1]) # (D * num_layers, N, L, Hout)
+ # get the hidden states corresponding to the initial sequence
+ hidden_states = hidden_states[:,:,0,:].contiguous() # (D * num_layers, N, Hout)
+
+ # reset the RNN state in the middle of a sequence
+ if terminated is not None and torch.any(terminated):
+ rnn_outputs = []
+ terminated = terminated.view(-1, self.sequence_length)
+ indexes = [0] + (terminated[:,:-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]
+
+ for i in range(len(indexes) - 1):
+ i0, i1 = indexes[i], indexes[i + 1]
+ rnn_output, hidden_states = self.rnn(rnn_input[:,i0:i1,:], hidden_states)
+ hidden_states[:, (terminated[:,i1-1]), :] = 0
+ rnn_outputs.append(rnn_output)
+
+ rnn_output = torch.cat(rnn_outputs, dim=1)
+ # no need to reset the RNN state in the sequence
+ else:
+ rnn_output, hidden_states = self.rnn(rnn_input, hidden_states)
+ # rollout
+ else:
+ rnn_input = states.view(-1, 1, states.shape[-1]) # (N, L, Hin): N=num_envs, L=1
+ rnn_output, hidden_states = self.rnn(rnn_input, hidden_states)
+
+ # flatten the RNN output
+ rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1) # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)
+
+ return self.net(rnn_output), {"rnn": [hidden_states]}
+
+
+# instantiate the model (assumes there is a wrapped environment: env)
+policy = RNN(observation_space=env.observation_space,
+ action_space=env.action_space,
+ device=env.device,
+ unnormalized_log_prob=True,
+ reduction="sum",
+ num_envs=env.num_envs,
+ num_layers=1,
+ hidden_size=64,
+ sequence_length=10)
+# [end-rnn-sequential-torch]
+
+# [start-rnn-functional-torch]
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from skrl.models.torch import Model, MultiCategoricalMixin
+
+
+# define the model
+class RNN(MultiCategoricalMixin, Model):
+ def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True, reduction="sum",
+ num_envs=1, num_layers=1, hidden_size=64, sequence_length=10):
+ Model.__init__(self, observation_space, action_space, device)
+ MultiCategoricalMixin.__init__(self, unnormalized_log_prob, reduction)
+
+ self.num_envs = num_envs
+ self.num_layers = num_layers
+ self.hidden_size = hidden_size # Hout
+ self.sequence_length = sequence_length
+
+ self.rnn = nn.RNN(input_size=self.num_observations,
+ hidden_size=self.hidden_size,
+ num_layers=self.num_layers,
+ batch_first=True) # batch_first -> (batch, sequence, features)
+
+ self.fc1 = nn.Linear(self.hidden_size, 64)
+ self.fc2 = nn.Linear(64, 32)
+ self.logits = nn.Linear(32, self.num_actions)
+
+ def get_specification(self):
+ # batch size (N) is the number of envs during rollout
+ return {"rnn": {"sequence_length": self.sequence_length,
+ "sizes": [(self.num_layers, self.num_envs, self.hidden_size)]}} # hidden states (D ∗ num_layers, N, Hout)
+
+ def compute(self, inputs, role):
+ states = inputs["states"]
+ terminated = inputs.get("terminated", None)
+ hidden_states = inputs["rnn"][0]
+
+ # training
+ if self.training:
+ rnn_input = states.view(-1, self.sequence_length, states.shape[-1]) # (N, L, Hin): N=batch_size, L=sequence_length
+ hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1]) # (D * num_layers, N, L, Hout)
+ # get the hidden states corresponding to the initial sequence
+ hidden_states = hidden_states[:,:,0,:].contiguous() # (D * num_layers, N, Hout)
+
+ # reset the RNN state in the middle of a sequence
+ if terminated is not None and torch.any(terminated):
+ rnn_outputs = []
+ terminated = terminated.view(-1, self.sequence_length)
+ indexes = [0] + (terminated[:,:-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]
+
+ for i in range(len(indexes) - 1):
+ i0, i1 = indexes[i], indexes[i + 1]
+ rnn_output, hidden_states = self.rnn(rnn_input[:,i0:i1,:], hidden_states)
+ hidden_states[:, (terminated[:,i1-1]), :] = 0
+ rnn_outputs.append(rnn_output)
+
+ rnn_output = torch.cat(rnn_outputs, dim=1)
+ # no need to reset the RNN state in the sequence
+ else:
+ rnn_output, hidden_states = self.rnn(rnn_input, hidden_states)
+ # rollout
+ else:
+ rnn_input = states.view(-1, 1, states.shape[-1]) # (N, L, Hin): N=num_envs, L=1
+ rnn_output, hidden_states = self.rnn(rnn_input, hidden_states)
+
+ # flatten the RNN output
+ rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1) # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)
+
+ x = self.fc1(rnn_output)
+ x = F.relu(x)
+ x = self.fc2(x)
+ x = F.relu(x)
+
+ return self.logits(x), {"rnn": [hidden_states]}
+
+
+# instantiate the model (assumes there is a wrapped environment: env)
+policy = RNN(observation_space=env.observation_space,
+ action_space=env.action_space,
+ device=env.device,
+ unnormalized_log_prob=True,
+ reduction="sum",
+ num_envs=env.num_envs,
+ num_layers=1,
+ hidden_size=64,
+ sequence_length=10)
+# [end-rnn-functional-torch]
+
+# =============================================================================
+
+# [start-gru-sequential-torch]
+import torch
+import torch.nn as nn
+
+from skrl.models.torch import Model, MultiCategoricalMixin
+
+
+# define the model
+class GRU(MultiCategoricalMixin, Model):
+ def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True, reduction="sum",
+ num_envs=1, num_layers=1, hidden_size=64, sequence_length=10):
+ Model.__init__(self, observation_space, action_space, device)
+ MultiCategoricalMixin.__init__(self, unnormalized_log_prob, reduction)
+
+ self.num_envs = num_envs
+ self.num_layers = num_layers
+ self.hidden_size = hidden_size # Hout
+ self.sequence_length = sequence_length
+
+ self.gru = nn.GRU(input_size=self.num_observations,
+ hidden_size=self.hidden_size,
+ num_layers=self.num_layers,
+ batch_first=True) # batch_first -> (batch, sequence, features)
+
+ self.net = nn.Sequential(nn.Linear(self.hidden_size, 64),
+ nn.ReLU(),
+ nn.Linear(64, 32),
+ nn.ReLU(),
+ nn.Linear(32, self.num_actions))
+
+ def get_specification(self):
+ # batch size (N) is the number of envs during rollout
+ return {"rnn": {"sequence_length": self.sequence_length,
+ "sizes": [(self.num_layers, self.num_envs, self.hidden_size)]}} # hidden states (D ∗ num_layers, N, Hout)
+
+ def compute(self, inputs, role):
+ states = inputs["states"]
+ terminated = inputs.get("terminated", None)
+ hidden_states = inputs["rnn"][0]
+
+ # training
+ if self.training:
+ rnn_input = states.view(-1, self.sequence_length, states.shape[-1]) # (N, L, Hin): N=batch_size, L=sequence_length
+ hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1]) # (D * num_layers, N, L, Hout)
+ # get the hidden states corresponding to the initial sequence
+ hidden_states = hidden_states[:,:,0,:].contiguous() # (D * num_layers, N, Hout)
+
+ # reset the RNN state in the middle of a sequence
+ if terminated is not None and torch.any(terminated):
+ rnn_outputs = []
+ terminated = terminated.view(-1, self.sequence_length)
+ indexes = [0] + (terminated[:,:-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]
+
+ for i in range(len(indexes) - 1):
+ i0, i1 = indexes[i], indexes[i + 1]
+ rnn_output, hidden_states = self.gru(rnn_input[:,i0:i1,:], hidden_states)
+ hidden_states[:, (terminated[:,i1-1]), :] = 0
+ rnn_outputs.append(rnn_output)
+
+ rnn_output = torch.cat(rnn_outputs, dim=1)
+ # no need to reset the RNN state in the sequence
+ else:
+ rnn_output, hidden_states = self.gru(rnn_input, hidden_states)
+ # rollout
+ else:
+ rnn_input = states.view(-1, 1, states.shape[-1]) # (N, L, Hin): N=num_envs, L=1
+ rnn_output, hidden_states = self.gru(rnn_input, hidden_states)
+
+ # flatten the RNN output
+ rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1) # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)
+
+ return self.net(rnn_output), {"rnn": [hidden_states]}
+
+
+# instantiate the model (assumes there is a wrapped environment: env)
+policy = GRU(observation_space=env.observation_space,
+ action_space=env.action_space,
+ device=env.device,
+ unnormalized_log_prob=True,
+ reduction="sum",
+ num_envs=env.num_envs,
+ num_layers=1,
+ hidden_size=64,
+ sequence_length=10)
+# [end-gru-sequential-torch]
+
+# [start-gru-functional-torch]
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from skrl.models.torch import Model, MultiCategoricalMixin
+
+
+# define the model
+class GRU(MultiCategoricalMixin, Model):
+ def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True, reduction="sum",
+ num_envs=1, num_layers=1, hidden_size=64, sequence_length=10):
+ Model.__init__(self, observation_space, action_space, device)
+ MultiCategoricalMixin.__init__(self, unnormalized_log_prob, reduction)
+
+ self.num_envs = num_envs
+ self.num_layers = num_layers
+ self.hidden_size = hidden_size # Hout
+ self.sequence_length = sequence_length
+
+ self.gru = nn.GRU(input_size=self.num_observations,
+ hidden_size=self.hidden_size,
+ num_layers=self.num_layers,
+ batch_first=True) # batch_first -> (batch, sequence, features)
+
+ self.fc1 = nn.Linear(self.hidden_size, 64)
+ self.fc2 = nn.Linear(64, 32)
+ self.logits = nn.Linear(32, self.num_actions)
+
+ def get_specification(self):
+ # batch size (N) is the number of envs during rollout
+ return {"rnn": {"sequence_length": self.sequence_length,
+ "sizes": [(self.num_layers, self.num_envs, self.hidden_size)]}} # hidden states (D ∗ num_layers, N, Hout)
+
+ def compute(self, inputs, role):
+ states = inputs["states"]
+ terminated = inputs.get("terminated", None)
+ hidden_states = inputs["rnn"][0]
+
+ # training
+ if self.training:
+ rnn_input = states.view(-1, self.sequence_length, states.shape[-1]) # (N, L, Hin): N=batch_size, L=sequence_length
+ hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1]) # (D * num_layers, N, L, Hout)
+ # get the hidden states corresponding to the initial sequence
+ hidden_states = hidden_states[:,:,0,:].contiguous() # (D * num_layers, N, Hout)
+
+ # reset the RNN state in the middle of a sequence
+ if terminated is not None and torch.any(terminated):
+ rnn_outputs = []
+ terminated = terminated.view(-1, self.sequence_length)
+ indexes = [0] + (terminated[:,:-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]
+
+ for i in range(len(indexes) - 1):
+ i0, i1 = indexes[i], indexes[i + 1]
+ rnn_output, hidden_states = self.gru(rnn_input[:,i0:i1,:], hidden_states)
+ hidden_states[:, (terminated[:,i1-1]), :] = 0
+ rnn_outputs.append(rnn_output)
+
+ rnn_output = torch.cat(rnn_outputs, dim=1)
+ # no need to reset the RNN state in the sequence
+ else:
+ rnn_output, hidden_states = self.gru(rnn_input, hidden_states)
+ # rollout
+ else:
+ rnn_input = states.view(-1, 1, states.shape[-1]) # (N, L, Hin): N=num_envs, L=1
+ rnn_output, hidden_states = self.gru(rnn_input, hidden_states)
+
+ # flatten the RNN output
+ rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1) # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)
+
+ x = self.fc1(rnn_output)
+ x = F.relu(x)
+ x = self.fc2(x)
+ x = F.relu(x)
+
+ return self.logits(x), {"rnn": [hidden_states]}
+
+
+# instantiate the model (assumes there is a wrapped environment: env)
+policy = GRU(observation_space=env.observation_space,
+ action_space=env.action_space,
+ device=env.device,
+ unnormalized_log_prob=True,
+ reduction="sum",
+ num_envs=env.num_envs,
+ num_layers=1,
+ hidden_size=64,
+ sequence_length=10)
+# [end-gru-functional-torch]
+
+# =============================================================================
+
+# [start-lstm-sequential-torch]
+import torch
+import torch.nn as nn
+
+from skrl.models.torch import Model, MultiCategoricalMixin
+
+
+# define the model
+class LSTM(MultiCategoricalMixin, Model):
+ def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True, reduction="sum",
+ num_envs=1, num_layers=1, hidden_size=64, sequence_length=10):
+ Model.__init__(self, observation_space, action_space, device)
+ MultiCategoricalMixin.__init__(self, unnormalized_log_prob, reduction)
+
+ self.num_envs = num_envs
+ self.num_layers = num_layers
+ self.hidden_size = hidden_size # Hcell (Hout is Hcell because proj_size = 0)
+ self.sequence_length = sequence_length
+
+ self.lstm = nn.LSTM(input_size=self.num_observations,
+ hidden_size=self.hidden_size,
+ num_layers=self.num_layers,
+ batch_first=True) # batch_first -> (batch, sequence, features)
+
+ self.net = nn.Sequential(nn.Linear(self.hidden_size, 64),
+ nn.ReLU(),
+ nn.Linear(64, 32),
+ nn.ReLU(),
+ nn.Linear(32, self.num_actions))
+
+ def get_specification(self):
+ # batch size (N) is the number of envs during rollout
+ return {"rnn": {"sequence_length": self.sequence_length,
+ "sizes": [(self.num_layers, self.num_envs, self.hidden_size), # hidden states (D ∗ num_layers, N, Hout)
+ (self.num_layers, self.num_envs, self.hidden_size)]}} # cell states (D ∗ num_layers, N, Hcell)
+
+ def compute(self, inputs, role):
+ states = inputs["states"]
+ terminated = inputs.get("terminated", None)
+ hidden_states, cell_states = inputs["rnn"][0], inputs["rnn"][1]
+
+ # training
+ if self.training:
+ rnn_input = states.view(-1, self.sequence_length, states.shape[-1]) # (N, L, Hin): N=batch_size, L=sequence_length
+ hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1]) # (D * num_layers, N, L, Hout)
+ cell_states = cell_states.view(self.num_layers, -1, self.sequence_length, cell_states.shape[-1]) # (D * num_layers, N, L, Hcell)
+ # get the hidden/cell states corresponding to the initial sequence
+ hidden_states = hidden_states[:,:,0,:].contiguous() # (D * num_layers, N, Hout)
+ cell_states = cell_states[:,:,0,:].contiguous() # (D * num_layers, N, Hcell)
+
+ # reset the RNN state in the middle of a sequence
+ if terminated is not None and torch.any(terminated):
+ rnn_outputs = []
+ terminated = terminated.view(-1, self.sequence_length)
+ indexes = [0] + (terminated[:,:-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]
+
+ for i in range(len(indexes) - 1):
+ i0, i1 = indexes[i], indexes[i + 1]
+ rnn_output, (hidden_states, cell_states) = self.lstm(rnn_input[:,i0:i1,:], (hidden_states, cell_states))
+ hidden_states[:, (terminated[:,i1-1]), :] = 0
+ cell_states[:, (terminated[:,i1-1]), :] = 0
+ rnn_outputs.append(rnn_output)
+
+ rnn_states = (hidden_states, cell_states)
+ rnn_output = torch.cat(rnn_outputs, dim=1)
+ # no need to reset the RNN state in the sequence
+ else:
+ rnn_output, rnn_states = self.lstm(rnn_input, (hidden_states, cell_states))
+ # rollout
+ else:
+ rnn_input = states.view(-1, 1, states.shape[-1]) # (N, L, Hin): N=num_envs, L=1
+ rnn_output, rnn_states = self.lstm(rnn_input, (hidden_states, cell_states))
+
+ # flatten the RNN output
+ rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1) # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)
+
+ return self.net(rnn_output), {"rnn": [rnn_states[0], rnn_states[1]]}
+
+
+# instantiate the model (assumes there is a wrapped environment: env)
+policy = LSTM(observation_space=env.observation_space,
+ action_space=env.action_space,
+ device=env.device,
+ unnormalized_log_prob=True,
+ reduction="sum",
+ num_envs=env.num_envs,
+ num_layers=1,
+ hidden_size=64,
+ sequence_length=10)
+# [end-lstm-sequential-torch]
+
+# [start-lstm-functional-torch]
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from skrl.models.torch import Model, MultiCategoricalMixin
+
+
+# define the model
+class LSTM(MultiCategoricalMixin, Model):
+ def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True, reduction="sum",
+ num_envs=1, num_layers=1, hidden_size=64, sequence_length=10):
+ Model.__init__(self, observation_space, action_space, device)
+ MultiCategoricalMixin.__init__(self, unnormalized_log_prob, reduction)
+
+ self.num_envs = num_envs
+ self.num_layers = num_layers
+ self.hidden_size = hidden_size # Hcell (Hout is Hcell because proj_size = 0)
+ self.sequence_length = sequence_length
+
+ self.lstm = nn.LSTM(input_size=self.num_observations,
+ hidden_size=self.hidden_size,
+ num_layers=self.num_layers,
+ batch_first=True) # batch_first -> (batch, sequence, features)
+
+ self.fc1 = nn.Linear(self.hidden_size, 64)
+ self.fc2 = nn.Linear(64, 32)
+ self.logits = nn.Linear(32, self.num_actions)
+
+ def get_specification(self):
+ # batch size (N) is the number of envs during rollout
+ return {"rnn": {"sequence_length": self.sequence_length,
+ "sizes": [(self.num_layers, self.num_envs, self.hidden_size), # hidden states (D ∗ num_layers, N, Hout)
+ (self.num_layers, self.num_envs, self.hidden_size)]}} # cell states (D ∗ num_layers, N, Hcell)
+
+ def compute(self, inputs, role):
+ states = inputs["states"]
+ terminated = inputs.get("terminated", None)
+ hidden_states, cell_states = inputs["rnn"][0], inputs["rnn"][1]
+
+ # training
+ if self.training:
+ rnn_input = states.view(-1, self.sequence_length, states.shape[-1]) # (N, L, Hin): N=batch_size, L=sequence_length
+ hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1]) # (D * num_layers, N, L, Hout)
+ cell_states = cell_states.view(self.num_layers, -1, self.sequence_length, cell_states.shape[-1]) # (D * num_layers, N, L, Hcell)
+ # get the hidden/cell states corresponding to the initial sequence
+ hidden_states = hidden_states[:,:,0,:].contiguous() # (D * num_layers, N, Hout)
+ cell_states = cell_states[:,:,0,:].contiguous() # (D * num_layers, N, Hcell)
+
+ # reset the RNN state in the middle of a sequence
+ if terminated is not None and torch.any(terminated):
+ rnn_outputs = []
+ terminated = terminated.view(-1, self.sequence_length)
+ indexes = [0] + (terminated[:,:-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]
+
+ for i in range(len(indexes) - 1):
+ i0, i1 = indexes[i], indexes[i + 1]
+ rnn_output, (hidden_states, cell_states) = self.lstm(rnn_input[:,i0:i1,:], (hidden_states, cell_states))
+ hidden_states[:, (terminated[:,i1-1]), :] = 0
+ cell_states[:, (terminated[:,i1-1]), :] = 0
+ rnn_outputs.append(rnn_output)
+
+ rnn_states = (hidden_states, cell_states)
+ rnn_output = torch.cat(rnn_outputs, dim=1)
+ # no need to reset the RNN state in the sequence
+ else:
+ rnn_output, rnn_states = self.lstm(rnn_input, (hidden_states, cell_states))
+ # rollout
+ else:
+ rnn_input = states.view(-1, 1, states.shape[-1]) # (N, L, Hin): N=num_envs, L=1
+ rnn_output, rnn_states = self.lstm(rnn_input, (hidden_states, cell_states))
+
+ # flatten the RNN output
+ rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1) # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)
+
+ x = self.fc1(rnn_output)
+ x = F.relu(x)
+ x = self.fc2(x)
+ x = F.relu(x)
+
+ return self.logits(x), {"rnn": [rnn_states[0], rnn_states[1]]}
+
+
+# instantiate the model (assumes there is a wrapped environment: env)
+policy = LSTM(observation_space=env.observation_space,
+ action_space=env.action_space,
+ device=env.device,
+ unnormalized_log_prob=True,
+ reduction="sum",
+ num_envs=env.num_envs,
+ num_layers=1,
+ hidden_size=64,
+ sequence_length=10)
+# [end-lstm-functional-torch]
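+
+# =============================================================================
+
+# usage sketch (not rendered in the docs): any of the policies defined above can be handed to an
+# agent that supports multi-discrete action spaces through its `models` mapping, e.g. (assuming a
+# value model `value`, a `memory` and an agent configuration `cfg` created as in the skrl examples):
+#
+# from skrl.agents.torch.ppo import PPO
+#
+# agent = PPO(models={"policy": policy, "value": value},
+#             memory=memory,
+#             cfg=cfg,
+#             observation_space=env.observation_space,
+#             action_space=env.action_space,
+#             device=env.device)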
diff --git a/docs/source/snippets/trainer.py b/docs/source/snippets/trainer.py
index 7e4e7cbb..f2ce690a 100644
--- a/docs/source/snippets/trainer.py
+++ b/docs/source/snippets/trainer.py
@@ -199,15 +199,15 @@ def eval(self) -> None:
# =============================================================================
-# [pytorch-start-manual]
-from skrl.trainers.torch import ManualTrainer
+# [pytorch-start-step]
+from skrl.trainers.torch import StepTrainer
# assuming there is an environment called 'env'
# and an agent or a list of agents called 'agents'
# create a step trainer
cfg = {"timesteps": 50000, "headless": False}
-trainer = ManualTrainer(env=env, agents=agents, cfg=cfg)
+trainer = StepTrainer(env=env, agents=agents, cfg=cfg)
# train the agent(s)
for timestep in range(cfg["timesteps"]):
@@ -216,18 +216,18 @@ def eval(self) -> None:
# evaluate the agent(s)
for timestep in range(cfg["timesteps"]):
trainer.eval(timestep=timestep)
-# [pytorch-end-manual]
+# [pytorch-end-step]
-# [jax-start-manual]
-from skrl.trainers.jax import ManualTrainer
+# [jax-start-step]
+from skrl.trainers.jax import StepTrainer
# assuming there is an environment called 'env'
# and an agent or a list of agents called 'agents'
# create a step trainer
cfg = {"timesteps": 50000, "headless": False}
-trainer = ManualTrainer(env=env, agents=agents, cfg=cfg)
+trainer = StepTrainer(env=env, agents=agents, cfg=cfg)
# train the agent(s)
for timestep in range(cfg["timesteps"]):
@@ -236,4 +236,44 @@ def eval(self) -> None:
# evaluate the agent(s)
for timestep in range(cfg["timesteps"]):
trainer.eval(timestep=timestep)
-# [jax-end-manual]
+# [jax-end-step]
+
+# =============================================================================
+
+# [pytorch-start-manual-training]
+
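+# minimal sketch of a fully manual training loop, mirroring what the trainers do internally
+# (assumes an environment named 'env' and an agent named 'agents' exposing skrl's Agent API:
+# pre_interaction / act / record_transition / post_interaction)
+import torch
+
+timesteps = 50000
+
+states, infos = env.reset()
+
+for timestep in range(timesteps):
+    # pre-interaction
+    agents.pre_interaction(timestep=timestep, timesteps=timesteps)
+
+    with torch.no_grad():
+        # compute actions
+        actions = agents.act(states, timestep=timestep, timesteps=timesteps)[0]
+
+        # step the environment
+        next_states, rewards, terminated, truncated, infos = env.step(actions)
+
+        # render the environment
+        env.render()
+
+        # record the environment transition in memory
+        agents.record_transition(states=states, actions=actions, rewards=rewards,
+                                 next_states=next_states, terminated=terminated,
+                                 truncated=truncated, infos=infos,
+                                 timestep=timestep, timesteps=timesteps)
+
+    # post-interaction
+    agents.post_interaction(timestep=timestep, timesteps=timesteps)
+
+    # reset the environment
+    if terminated.any() or truncated.any():
+        states, infos = env.reset()
+    else:
+        states = next_states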
+# [pytorch-end-manual-training]
+
+# [pytorch-start-manual-evaluation]
+import torch
+
+# assuming there is an environment named 'env'
+# and an agent named 'agents' (or a state-preprocessor and a policy)
+
+states, infos = env.reset()
+
+for i in range(1000):
+ # state-preprocessor + policy
+ with torch.no_grad():
+ states = state_preprocessor(states)
+ actions = policy.act({"states": states})[0]
+
+ # step the environment
+ next_states, rewards, terminated, truncated, infos = env.step(actions)
+
+ # render the environment
+ env.render()
+
+ # check for termination/truncation
+ if terminated.any() or truncated.any():
+ states, infos = env.reset()
+ else:
+ states = next_states
+# [pytorch-end-manual-evaluation]
+
+
+# [jax-start-manual-training]
+
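+# minimal sketch of a fully manual training loop, mirroring the PyTorch example above
+# (assumes an environment named 'env' and an agent named 'agents' exposing skrl's Agent API:
+# pre_interaction / act / record_transition / post_interaction)
+timesteps = 50000
+
+states, infos = env.reset()
+
+for timestep in range(timesteps):
+    # pre-interaction
+    agents.pre_interaction(timestep=timestep, timesteps=timesteps)
+
+    # compute actions
+    actions = agents.act(states, timestep=timestep, timesteps=timesteps)[0]
+
+    # step the environment
+    next_states, rewards, terminated, truncated, infos = env.step(actions)
+
+    # render the environment
+    env.render()
+
+    # record the environment transition in memory
+    agents.record_transition(states=states, actions=actions, rewards=rewards,
+                             next_states=next_states, terminated=terminated,
+                             truncated=truncated, infos=infos,
+                             timestep=timestep, timesteps=timesteps)
+
+    # post-interaction
+    agents.post_interaction(timestep=timestep, timesteps=timesteps)
+
+    # reset the environment
+    if terminated.any() or truncated.any():
+        states, infos = env.reset()
+    else:
+        states = next_states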
+# [jax-end-manual-training]
+
+# [jax-start-manual-evaluation]
+
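+# minimal sketch mirroring the PyTorch evaluation example above
+# (assumes an environment named 'env' and an agent named 'agents',
+# or a state-preprocessor and a policy)
+
+states, infos = env.reset()
+
+for i in range(1000):
+    # state-preprocessor + policy
+    states = state_preprocessor(states)
+    actions = policy.act({"states": states})[0]
+
+    # step the environment
+    next_states, rewards, terminated, truncated, infos = env.step(actions)
+
+    # render the environment
+    env.render()
+
+    # check for termination/truncation
+    if terminated.any() or truncated.any():
+        states, infos = env.reset()
+    else:
+        states = next_states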
+# [jax-end-manual-evaluation]
diff --git a/pyproject.toml b/pyproject.toml
index 2afb84a0..6de499cb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "skrl"
-version = "1.0.0"
+version = "1.1.0"
description = "Modular and flexible library for reinforcement learning on PyTorch and JAX"
readme = "README.md"
requires-python = ">=3.6"
diff --git a/skrl/agents/jax/a2c/a2c.py b/skrl/agents/jax/a2c/a2c.py
index 60486358..fa1acb7b 100644
--- a/skrl/agents/jax/a2c/a2c.py
+++ b/skrl/agents/jax/a2c/a2c.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Tuple, Union
import copy
import functools
@@ -169,7 +169,7 @@ def _value_loss(params):
class A2C(Agent):
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -270,7 +270,7 @@ def __init__(self,
else:
self._value_preprocessor = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/agents/jax/base.py b/skrl/agents/jax/base.py
index e1d82079..71e9d091 100644
--- a/skrl/agents/jax/base.py
+++ b/skrl/agents/jax/base.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Mapping, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Tuple, Union
import collections
import copy
@@ -19,7 +19,7 @@
class Agent:
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -132,7 +132,7 @@ def _get_internal_value(self, _module: Any) -> Any:
"""
return _module.state_dict.params if hasattr(_module, "state_dict") else _module
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
This method should be called before the agent is used.
diff --git a/skrl/agents/jax/cem/cem.py b/skrl/agents/jax/cem/cem.py
index 2b0c4b9a..65c401e9 100644
--- a/skrl/agents/jax/cem/cem.py
+++ b/skrl/agents/jax/cem/cem.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Tuple, Union
import copy
import gym
@@ -52,7 +52,7 @@
class CEM(Agent):
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -130,7 +130,7 @@ def __init__(self,
else:
self._state_preprocessor = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/agents/jax/ddpg/ddpg.py b/skrl/agents/jax/ddpg/ddpg.py
index 769316ee..feb0e12f 100644
--- a/skrl/agents/jax/ddpg/ddpg.py
+++ b/skrl/agents/jax/ddpg/ddpg.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Tuple, Union
import copy
import functools
@@ -111,7 +111,7 @@ def _policy_loss(policy_params, critic_params):
class DDPG(Agent):
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -214,7 +214,7 @@ def __init__(self,
else:
self._state_preprocessor = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/agents/jax/dqn/ddqn.py b/skrl/agents/jax/dqn/ddqn.py
index 0d0968c3..47f7f53b 100644
--- a/skrl/agents/jax/dqn/ddqn.py
+++ b/skrl/agents/jax/dqn/ddqn.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Tuple, Union
import copy
import functools
@@ -89,7 +89,7 @@ def _q_network_loss(params):
class DDQN(Agent):
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -182,7 +182,7 @@ def __init__(self,
else:
self._state_preprocessor = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/agents/jax/dqn/dqn.py b/skrl/agents/jax/dqn/dqn.py
index 49cb6e34..dcec3140 100644
--- a/skrl/agents/jax/dqn/dqn.py
+++ b/skrl/agents/jax/dqn/dqn.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Tuple, Union
import copy
import functools
@@ -86,7 +86,7 @@ def _q_network_loss(params):
class DQN(Agent):
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -179,7 +179,7 @@ def __init__(self,
else:
self._state_preprocessor = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/agents/jax/ppo/ppo.py b/skrl/agents/jax/ppo/ppo.py
index 01a4942b..3437b049 100644
--- a/skrl/agents/jax/ppo/ppo.py
+++ b/skrl/agents/jax/ppo/ppo.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Tuple, Union
import copy
import functools
@@ -188,7 +188,7 @@ def _value_loss(params):
class PPO(Agent):
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -296,7 +296,7 @@ def __init__(self,
else:
self._value_preprocessor = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/agents/jax/rpo/rpo.py b/skrl/agents/jax/rpo/rpo.py
index 281a7210..3e90d06d 100644
--- a/skrl/agents/jax/rpo/rpo.py
+++ b/skrl/agents/jax/rpo/rpo.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Tuple, Union
import copy
import functools
@@ -191,7 +191,7 @@ def _value_loss(params):
class RPO(Agent):
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -300,7 +300,7 @@ def __init__(self,
else:
self._value_preprocessor = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/agents/jax/sac/sac.py b/skrl/agents/jax/sac/sac.py
index 5a7a7c5a..75cecf11 100644
--- a/skrl/agents/jax/sac/sac.py
+++ b/skrl/agents/jax/sac/sac.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Tuple, Union
import copy
import functools
@@ -122,7 +122,7 @@ def _entropy_loss(params):
class SAC(Agent):
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -255,7 +255,7 @@ def value(self):
else:
self._state_preprocessor = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/agents/jax/td3/td3.py b/skrl/agents/jax/td3/td3.py
index 568452f4..1c2544d8 100644
--- a/skrl/agents/jax/td3/td3.py
+++ b/skrl/agents/jax/td3/td3.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Tuple, Union
import copy
import functools
@@ -130,7 +130,7 @@ def _policy_loss(policy_params, critic_1_params):
class TD3(Agent):
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -250,7 +250,7 @@ def __init__(self,
else:
self._state_preprocessor = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/agents/torch/a2c/a2c.py b/skrl/agents/torch/a2c/a2c.py
index 8903a180..97af7cb3 100644
--- a/skrl/agents/torch/a2c/a2c.py
+++ b/skrl/agents/torch/a2c/a2c.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Tuple, Union
import copy
import itertools
@@ -59,7 +59,7 @@
class A2C(Agent):
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -153,7 +153,7 @@ def __init__(self,
else:
self._value_preprocessor = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/agents/torch/a2c/a2c_rnn.py b/skrl/agents/torch/a2c/a2c_rnn.py
index c17e38e2..9b24bc09 100644
--- a/skrl/agents/torch/a2c/a2c_rnn.py
+++ b/skrl/agents/torch/a2c/a2c_rnn.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Tuple, Union
import copy
import itertools
@@ -59,7 +59,7 @@
class A2C_RNN(Agent):
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -153,7 +153,7 @@ def __init__(self,
else:
self._value_preprocessor = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/agents/torch/amp/amp.py b/skrl/agents/torch/amp/amp.py
index 28889f9e..e9311dae 100644
--- a/skrl/agents/torch/amp/amp.py
+++ b/skrl/agents/torch/amp/amp.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Mapping, Optional, Tuple, Union
import copy
import itertools
@@ -76,7 +76,7 @@
class AMP(Agent):
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -218,7 +218,7 @@ def __init__(self,
else:
self._amp_state_preprocessor = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/agents/torch/base.py b/skrl/agents/torch/base.py
index b6355f52..237a0953 100644
--- a/skrl/agents/torch/base.py
+++ b/skrl/agents/torch/base.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Mapping, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Tuple, Union
import collections
import copy
@@ -18,7 +18,7 @@
class Agent:
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -125,7 +125,7 @@ def _get_internal_value(self, _module: Any) -> Any:
"""
return _module.state_dict() if hasattr(_module, "state_dict") else _module
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
This method should be called before the agent is used.
diff --git a/skrl/agents/torch/cem/cem.py b/skrl/agents/torch/cem/cem.py
index dd5e1ff3..a99eff5a 100644
--- a/skrl/agents/torch/cem/cem.py
+++ b/skrl/agents/torch/cem/cem.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Tuple, Union
import copy
import gym
@@ -49,7 +49,7 @@
class CEM(Agent):
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -126,7 +126,7 @@ def __init__(self,
else:
self._state_preprocessor = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/agents/torch/ddpg/ddpg.py b/skrl/agents/torch/ddpg/ddpg.py
index ede9b7c4..a5270909 100644
--- a/skrl/agents/torch/ddpg/ddpg.py
+++ b/skrl/agents/torch/ddpg/ddpg.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Tuple, Union
import copy
import gym
@@ -60,7 +60,7 @@
class DDPG(Agent):
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -161,7 +161,7 @@ def __init__(self,
else:
self._state_preprocessor = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/agents/torch/ddpg/ddpg_rnn.py b/skrl/agents/torch/ddpg/ddpg_rnn.py
index 1d52937a..e1a8142e 100644
--- a/skrl/agents/torch/ddpg/ddpg_rnn.py
+++ b/skrl/agents/torch/ddpg/ddpg_rnn.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Tuple, Union
import copy
import gym
@@ -60,7 +60,7 @@
class DDPG_RNN(Agent):
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -161,7 +161,7 @@ def __init__(self,
else:
self._state_preprocessor = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/agents/torch/dqn/ddqn.py b/skrl/agents/torch/dqn/ddqn.py
index 479ba934..84352027 100644
--- a/skrl/agents/torch/dqn/ddqn.py
+++ b/skrl/agents/torch/dqn/ddqn.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Tuple, Union
import copy
import math
@@ -59,7 +59,7 @@
class DDQN(Agent):
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -150,7 +150,7 @@ def __init__(self,
else:
self._state_preprocessor = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/agents/torch/dqn/dqn.py b/skrl/agents/torch/dqn/dqn.py
index 782befaf..4c485524 100644
--- a/skrl/agents/torch/dqn/dqn.py
+++ b/skrl/agents/torch/dqn/dqn.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Tuple, Union
import copy
import math
@@ -59,7 +59,7 @@
class DQN(Agent):
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -150,7 +150,7 @@ def __init__(self,
else:
self._state_preprocessor = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/agents/torch/ppo/ppo.py b/skrl/agents/torch/ppo/ppo.py
index 2c035a6a..8c2315bd 100644
--- a/skrl/agents/torch/ppo/ppo.py
+++ b/skrl/agents/torch/ppo/ppo.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Tuple, Union
import copy
import itertools
@@ -66,7 +66,7 @@
class PPO(Agent):
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -167,7 +167,7 @@ def __init__(self,
else:
self._value_preprocessor = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/agents/torch/ppo/ppo_rnn.py b/skrl/agents/torch/ppo/ppo_rnn.py
index 8086995e..ccabafca 100644
--- a/skrl/agents/torch/ppo/ppo_rnn.py
+++ b/skrl/agents/torch/ppo/ppo_rnn.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Tuple, Union
import copy
import itertools
@@ -66,7 +66,7 @@
class PPO_RNN(Agent):
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -167,7 +167,7 @@ def __init__(self,
else:
self._value_preprocessor = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/agents/torch/q_learning/q_learning.py b/skrl/agents/torch/q_learning/q_learning.py
index 5be6ac97..16212d8f 100644
--- a/skrl/agents/torch/q_learning/q_learning.py
+++ b/skrl/agents/torch/q_learning/q_learning.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Tuple, Union
import copy
import gym
@@ -39,7 +39,7 @@
class Q_LEARNING(Agent):
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -99,7 +99,7 @@ def __init__(self,
self._current_next_states = None
self._current_dones = None
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/agents/torch/rpo/rpo.py b/skrl/agents/torch/rpo/rpo.py
index e7528c33..5929f54e 100644
--- a/skrl/agents/torch/rpo/rpo.py
+++ b/skrl/agents/torch/rpo/rpo.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Tuple, Union
import copy
import itertools
@@ -67,7 +67,7 @@
class RPO(Agent):
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -169,7 +169,7 @@ def __init__(self,
else:
self._value_preprocessor = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/agents/torch/rpo/rpo_rnn.py b/skrl/agents/torch/rpo/rpo_rnn.py
index 1060adb7..382d1efb 100644
--- a/skrl/agents/torch/rpo/rpo_rnn.py
+++ b/skrl/agents/torch/rpo/rpo_rnn.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Tuple, Union
import copy
import itertools
@@ -67,7 +67,7 @@
class RPO_RNN(Agent):
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -169,7 +169,7 @@ def __init__(self,
else:
self._value_preprocessor = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/agents/torch/sac/sac.py b/skrl/agents/torch/sac/sac.py
index d41b4830..22468d80 100644
--- a/skrl/agents/torch/sac/sac.py
+++ b/skrl/agents/torch/sac/sac.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Tuple, Union
import copy
import itertools
@@ -60,7 +60,7 @@
class SAC(Agent):
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -179,7 +179,7 @@ def __init__(self,
else:
self._state_preprocessor = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/agents/torch/sac/sac_rnn.py b/skrl/agents/torch/sac/sac_rnn.py
index 501dc122..755cbeab 100644
--- a/skrl/agents/torch/sac/sac_rnn.py
+++ b/skrl/agents/torch/sac/sac_rnn.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Tuple, Union
import copy
import itertools
@@ -60,7 +60,7 @@
class SAC_RNN(Agent):
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -179,7 +179,7 @@ def __init__(self,
else:
self._state_preprocessor = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/agents/torch/sarsa/sarsa.py b/skrl/agents/torch/sarsa/sarsa.py
index bb717025..4cc14c34 100644
--- a/skrl/agents/torch/sarsa/sarsa.py
+++ b/skrl/agents/torch/sarsa/sarsa.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Tuple, Union
import copy
import gym
@@ -39,7 +39,7 @@
class SARSA(Agent):
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -99,7 +99,7 @@ def __init__(self,
self._current_next_states = None
self._current_dones = None
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/agents/torch/td3/td3.py b/skrl/agents/torch/td3/td3.py
index a6eb23c8..86275243 100644
--- a/skrl/agents/torch/td3/td3.py
+++ b/skrl/agents/torch/td3/td3.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Tuple, Union
import copy
import itertools
@@ -66,7 +66,7 @@
class TD3(Agent):
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -182,7 +182,7 @@ def __init__(self,
else:
self._state_preprocessor = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/agents/torch/td3/td3_rnn.py b/skrl/agents/torch/td3/td3_rnn.py
index ea906b15..fdd619d8 100644
--- a/skrl/agents/torch/td3/td3_rnn.py
+++ b/skrl/agents/torch/td3/td3_rnn.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Tuple, Union
import copy
import itertools
@@ -66,7 +66,7 @@
class TD3_RNN(Agent):
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -182,7 +182,7 @@ def __init__(self,
else:
self._state_preprocessor = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/agents/torch/trpo/trpo.py b/skrl/agents/torch/trpo/trpo.py
index b7f16036..2c00b69b 100644
--- a/skrl/agents/torch/trpo/trpo.py
+++ b/skrl/agents/torch/trpo/trpo.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Tuple, Union
import copy
import gym
@@ -65,7 +65,7 @@
class TRPO(Agent):
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -164,7 +164,7 @@ def __init__(self,
else:
self._value_preprocessor = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/agents/torch/trpo/trpo_rnn.py b/skrl/agents/torch/trpo/trpo_rnn.py
index bc8ae463..58b187e5 100644
--- a/skrl/agents/torch/trpo/trpo_rnn.py
+++ b/skrl/agents/torch/trpo/trpo_rnn.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Tuple, Union
import copy
import gym
@@ -65,7 +65,7 @@
class TRPO_RNN(Agent):
def __init__(self,
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memory: Optional[Union[Memory, Tuple[Memory]]] = None,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
@@ -164,7 +164,7 @@ def __init__(self,
else:
self._value_preprocessor = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/envs/wrappers/jax/base.py b/skrl/envs/wrappers/jax/base.py
index 09170dcc..0be8b742 100644
--- a/skrl/envs/wrappers/jax/base.py
+++ b/skrl/envs/wrappers/jax/base.py
@@ -26,6 +26,14 @@ def __init__(self, env: Any) -> None:
self.device = jax.devices(self._env.device.split(':')[0] if type(self._env.device) == str else self._env.device.type)[0]
except RuntimeError:
pass
+ # spaces
+ try:
+ self._action_space = self._env.single_action_space
+ self._observation_space = self._env.single_observation_space
+ except AttributeError:
+ self._action_space = self._env.action_space
+ self._observation_space = self._env.observation_space
+ self._state_space = self._env.state_space if hasattr(self._env, "state_space") else self._observation_space
def __getattr__(self, key: str) -> Any:
"""Get an attribute from the wrapped environment
@@ -100,19 +108,19 @@ def state_space(self) -> gym.Space:
If the wrapped environment does not have the ``state_space`` property,
the value of the ``observation_space`` property will be used
"""
- return self._env.state_space if hasattr(self._env, "state_space") else self._env.observation_space
+ return self._state_space
@property
def observation_space(self) -> gym.Space:
"""Observation space
"""
- return self._env.observation_space
+ return self._observation_space
@property
def action_space(self) -> gym.Space:
"""Action space
"""
- return self._env.action_space
+ return self._action_space
class MultiAgentEnvWrapper(object):
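
The space caching above is what lets the wrapper report per-sub-environment spaces when a
vectorized environment is wrapped. A minimal sketch of the same fallback, assuming gymnasium's
``SyncVectorEnv`` as a stand-in for any vectorized environment::

    import gymnasium as gym

    envs = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1")] * 4)
    # prefer the per-sub-environment spaces; fall back to the regular attributes otherwise
    observation_space = getattr(envs, "single_observation_space", envs.observation_space)
    action_space = getattr(envs, "single_action_space", envs.action_space)
    print(observation_space.shape)  # (4,) rather than the batched (4, 4)
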
diff --git a/skrl/envs/wrappers/jax/bidexhands_envs.py b/skrl/envs/wrappers/jax/bidexhands_envs.py
index 2ca535f9..fe63563c 100644
--- a/skrl/envs/wrappers/jax/bidexhands_envs.py
+++ b/skrl/envs/wrappers/jax/bidexhands_envs.py
@@ -5,8 +5,13 @@
import jax
import jax.dlpack
import numpy as np
-import torch
-import torch.utils.dlpack
+
+
+try:
+ import torch
+ import torch.utils.dlpack
+except:
+ pass # TODO: show warning message
from skrl.envs.wrappers.jax.base import MultiAgentEnvWrapper
diff --git a/skrl/envs/wrappers/jax/gym_envs.py b/skrl/envs/wrappers/jax/gym_envs.py
index 685bbbb6..3b073d9e 100644
--- a/skrl/envs/wrappers/jax/gym_envs.py
+++ b/skrl/envs/wrappers/jax/gym_envs.py
@@ -113,6 +113,8 @@ def _tensor_to_action(self, actions: np.ndarray) -> Any:
return actions.astype(space[0].dtype).reshape(-1)
elif isinstance(space, gym.spaces.Discrete):
return actions.item()
+ elif isinstance(space, gym.spaces.MultiDiscrete):
+ return actions.astype(space.dtype).reshape(space.shape)
elif isinstance(space, gym.spaces.Box):
return actions.astype(space.dtype).reshape(space.shape)
raise ValueError(f"Action space type {type(space)} not supported. Please report this issue")
diff --git a/skrl/envs/wrappers/jax/gymnasium_envs.py b/skrl/envs/wrappers/jax/gymnasium_envs.py
index c2cc1d9f..44a45633 100644
--- a/skrl/envs/wrappers/jax/gymnasium_envs.py
+++ b/skrl/envs/wrappers/jax/gymnasium_envs.py
@@ -108,6 +108,8 @@ def _tensor_to_action(self, actions: np.ndarray) -> Any:
return actions.astype(space[0].dtype).reshape(-1)
if isinstance(space, gymnasium.spaces.Discrete):
return actions.item()
+ elif isinstance(space, gymnasium.spaces.MultiDiscrete):
+ return actions.astype(space.dtype).reshape(space.shape)
elif isinstance(space, gymnasium.spaces.Box):
return actions.astype(space.dtype).reshape(space.shape)
raise ValueError(f"Action space type {type(space)} not supported. Please report this issue")
diff --git a/skrl/envs/wrappers/jax/isaac_orbit_envs.py b/skrl/envs/wrappers/jax/isaac_orbit_envs.py
index 68adacdb..c1e897c9 100644
--- a/skrl/envs/wrappers/jax/isaac_orbit_envs.py
+++ b/skrl/envs/wrappers/jax/isaac_orbit_envs.py
@@ -3,8 +3,13 @@
import jax
import jax.dlpack as jax_dlpack
import numpy as np
-import torch
-import torch.utils.dlpack as torch_dlpack
+
+
+try:
+ import torch
+ import torch.utils.dlpack as torch_dlpack
+except:
+ pass # TODO: show warning message
from skrl import logger
from skrl.envs.wrappers.jax.base import Wrapper
@@ -53,10 +58,10 @@ def step(self, actions: Union[np.ndarray, jax.Array]) -> \
actions = _jax2torch(actions, self._env.device, self._jax)
with torch.no_grad():
- self._obs_dict, reward, terminated, info = self._env.step(actions)
+ self._obs_dict, reward, terminated, truncated, info = self._env.step(actions)
terminated = terminated.to(dtype=torch.int8)
- truncated = info["time_outs"].to(dtype=torch.int8) if "time_outs" in info else torch.zeros_like(terminated)
+ truncated = truncated.to(dtype=torch.int8)
return _torch2jax(self._obs_dict["policy"], self._jax), \
_torch2jax(reward.view(-1, 1), self._jax), \
@@ -71,9 +76,9 @@ def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]:
:rtype: np.ndarray or jax.Array and any other info
"""
if self._reset_once:
- self._obs_dict = self._env.reset()
+ self._obs_dict, info = self._env.reset()
self._reset_once = False
- return _torch2jax(self._obs_dict["policy"], self._jax), {}
+ return _torch2jax(self._obs_dict["policy"], self._jax), info
def render(self, *args, **kwargs) -> None:
"""Render the environment
diff --git a/skrl/envs/wrappers/jax/isaacgym_envs.py b/skrl/envs/wrappers/jax/isaacgym_envs.py
index a426a557..459444b9 100644
--- a/skrl/envs/wrappers/jax/isaacgym_envs.py
+++ b/skrl/envs/wrappers/jax/isaacgym_envs.py
@@ -3,8 +3,13 @@
import jax
import jax.dlpack as jax_dlpack
import numpy as np
-import torch
-import torch.utils.dlpack as torch_dlpack
+
+
+try:
+ import torch
+ import torch.utils.dlpack as torch_dlpack
+except:
+ pass # TODO: show warning message
from skrl import logger
from skrl.envs.wrappers.jax.base import Wrapper
diff --git a/skrl/envs/wrappers/jax/omniverse_isaacgym_envs.py b/skrl/envs/wrappers/jax/omniverse_isaacgym_envs.py
index 23105bc1..e49ee786 100644
--- a/skrl/envs/wrappers/jax/omniverse_isaacgym_envs.py
+++ b/skrl/envs/wrappers/jax/omniverse_isaacgym_envs.py
@@ -3,8 +3,13 @@
import jax
import jax.dlpack as jax_dlpack
import numpy as np
-import torch
-import torch.utils.dlpack as torch_dlpack
+
+
+try:
+ import torch
+ import torch.utils.dlpack as torch_dlpack
+except:
+ pass # TODO: show warning message
from skrl import logger
from skrl.envs.wrappers.jax.base import Wrapper
diff --git a/skrl/envs/wrappers/torch/base.py b/skrl/envs/wrappers/torch/base.py
index 233c2a08..85e79b08 100644
--- a/skrl/envs/wrappers/torch/base.py
+++ b/skrl/envs/wrappers/torch/base.py
@@ -19,6 +19,14 @@ def __init__(self, env: Any) -> None:
self.device = torch.device(self._env.device)
else:
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ # spaces
+ try:
+ self._action_space = self._env.single_action_space
+ self._observation_space = self._env.single_observation_space
+ except AttributeError:
+ self._action_space = self._env.action_space
+ self._observation_space = self._env.observation_space
+ self._state_space = self._env.state_space if hasattr(self._env, "state_space") else self._observation_space
def __getattr__(self, key: str) -> Any:
"""Get an attribute from the wrapped environment
@@ -91,19 +99,19 @@ def state_space(self) -> gym.Space:
If the wrapped environment does not have the ``state_space`` property,
the value of the ``observation_space`` property will be used
"""
- return self._env.state_space if hasattr(self._env, "state_space") else self._env.observation_space
+ return self._state_space
@property
def observation_space(self) -> gym.Space:
"""Observation space
"""
- return self._env.observation_space
+ return self._observation_space
@property
def action_space(self) -> gym.Space:
"""Action space
"""
- return self._env.action_space
+ return self._action_space
class MultiAgentEnvWrapper(object):
diff --git a/skrl/envs/wrappers/torch/gym_envs.py b/skrl/envs/wrappers/torch/gym_envs.py
index ee50cbf9..7b8b6af8 100644
--- a/skrl/envs/wrappers/torch/gym_envs.py
+++ b/skrl/envs/wrappers/torch/gym_envs.py
@@ -113,6 +113,8 @@ def _tensor_to_action(self, actions: torch.Tensor) -> Any:
return np.array(actions.cpu().numpy(), dtype=space[0].dtype).reshape(-1)
elif isinstance(space, gym.spaces.Discrete):
return actions.item()
+ elif isinstance(space, gym.spaces.MultiDiscrete):
+ return np.array(actions.cpu().numpy(), dtype=space.dtype).reshape(space.shape)
elif isinstance(space, gym.spaces.Box):
return np.array(actions.cpu().numpy(), dtype=space.dtype).reshape(space.shape)
raise ValueError(f"Action space type {type(space)} not supported. Please report this issue")
diff --git a/skrl/envs/wrappers/torch/gymnasium_envs.py b/skrl/envs/wrappers/torch/gymnasium_envs.py
index 0009575b..74db7e0d 100644
--- a/skrl/envs/wrappers/torch/gymnasium_envs.py
+++ b/skrl/envs/wrappers/torch/gymnasium_envs.py
@@ -108,6 +108,8 @@ def _tensor_to_action(self, actions: torch.Tensor) -> Any:
return np.array(actions.cpu().numpy(), dtype=space[0].dtype).reshape(-1)
if isinstance(space, gymnasium.spaces.Discrete):
return actions.item()
+ elif isinstance(space, gymnasium.spaces.MultiDiscrete):
+ return np.array(actions.cpu().numpy(), dtype=space.dtype).reshape(space.shape)
elif isinstance(space, gymnasium.spaces.Box):
return np.array(actions.cpu().numpy(), dtype=space.dtype).reshape(space.shape)
raise ValueError(f"Action space type {type(space)} not supported. Please report this issue")
diff --git a/skrl/envs/wrappers/torch/isaac_orbit_envs.py b/skrl/envs/wrappers/torch/isaac_orbit_envs.py
index e558c698..c9670ce1 100644
--- a/skrl/envs/wrappers/torch/isaac_orbit_envs.py
+++ b/skrl/envs/wrappers/torch/isaac_orbit_envs.py
@@ -26,8 +26,7 @@ def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch
:return: Observation, reward, terminated, truncated, info
:rtype: tuple of torch.Tensor and any other info
"""
- self._obs_dict, reward, terminated, info = self._env.step(actions)
- truncated = info["time_outs"] if "time_outs" in info else torch.zeros_like(terminated)
+ self._obs_dict, reward, terminated, truncated, info = self._env.step(actions)
return self._obs_dict["policy"], reward.view(-1, 1), terminated.view(-1, 1), truncated.view(-1, 1), info
def reset(self) -> Tuple[torch.Tensor, Any]:
@@ -37,9 +36,9 @@ def reset(self) -> Tuple[torch.Tensor, Any]:
:rtype: torch.Tensor and any other info
"""
if self._reset_once:
- self._obs_dict = self._env.reset()
+ self._obs_dict, info = self._env.reset()
self._reset_once = False
- return self._obs_dict["policy"], {}
+ return self._obs_dict["policy"], info
def render(self, *args, **kwargs) -> None:
"""Render the environment
diff --git a/skrl/memories/jax/base.py b/skrl/memories/jax/base.py
index f0ea27a2..41cea1b1 100644
--- a/skrl/memories/jax/base.py
+++ b/skrl/memories/jax/base.py
@@ -126,6 +126,8 @@ def _get_space_size(self,
elif issubclass(type(space), gym.Space):
if issubclass(type(space), gym.spaces.Discrete):
return (1,) if keep_dimensions else 1
+ elif issubclass(type(space), gym.spaces.MultiDiscrete):
+ return space.nvec.shape[0]
elif issubclass(type(space), gym.spaces.Box):
return tuple(space.shape) if keep_dimensions else np.prod(space.shape)
elif issubclass(type(space), gym.spaces.Dict):
@@ -135,6 +137,8 @@ def _get_space_size(self,
elif issubclass(type(space), gymnasium.Space):
if issubclass(type(space), gymnasium.spaces.Discrete):
return (1,) if keep_dimensions else 1
+ elif issubclass(type(space), gymnasium.spaces.MultiDiscrete):
+ return space.nvec.shape[0]
elif issubclass(type(space), gymnasium.spaces.Box):
return tuple(space.shape) if keep_dimensions else np.prod(space.shape)
elif issubclass(type(space), gymnasium.spaces.Dict):
diff --git a/skrl/memories/torch/base.py b/skrl/memories/torch/base.py
index da10626f..7d0f0615 100644
--- a/skrl/memories/torch/base.py
+++ b/skrl/memories/torch/base.py
@@ -102,6 +102,8 @@ def _get_space_size(self,
elif issubclass(type(space), gym.Space):
if issubclass(type(space), gym.spaces.Discrete):
return (1,) if keep_dimensions else 1
+ elif issubclass(type(space), gym.spaces.MultiDiscrete):
+ return space.nvec.shape[0]
elif issubclass(type(space), gym.spaces.Box):
return tuple(space.shape) if keep_dimensions else np.prod(space.shape)
elif issubclass(type(space), gym.spaces.Dict):
@@ -111,6 +113,8 @@ def _get_space_size(self,
elif issubclass(type(space), gymnasium.Space):
if issubclass(type(space), gymnasium.spaces.Discrete):
return (1,) if keep_dimensions else 1
+ elif issubclass(type(space), gymnasium.spaces.MultiDiscrete):
+ return space.nvec.shape[0]
elif issubclass(type(space), gymnasium.spaces.Box):
return tuple(space.shape) if keep_dimensions else np.prod(space.shape)
elif issubclass(type(space), gymnasium.spaces.Dict):
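
With this change a ``MultiDiscrete`` space occupies one memory column per sub-action (the length
of ``nvec``), not the sum of the per-component choices. For example, with assumed values::

    import gymnasium

    space = gymnasium.spaces.MultiDiscrete([3, 2])
    print(space.nvec.shape[0])   # 2 -> two integer columns stored per sample
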
diff --git a/skrl/models/jax/__init__.py b/skrl/models/jax/__init__.py
index fd3b1064..ef2386fd 100644
--- a/skrl/models/jax/__init__.py
+++ b/skrl/models/jax/__init__.py
@@ -3,3 +3,4 @@
from skrl.models.jax.categorical import CategoricalMixin
from skrl.models.jax.deterministic import DeterministicMixin
from skrl.models.jax.gaussian import GaussianMixin
+from skrl.models.jax.multicategorical import MultiCategoricalMixin
diff --git a/skrl/models/jax/multicategorical.py b/skrl/models/jax/multicategorical.py
new file mode 100644
index 00000000..9ad1bb3e
--- /dev/null
+++ b/skrl/models/jax/multicategorical.py
@@ -0,0 +1,194 @@
+from typing import Any, Mapping, Optional, Tuple, Union
+
+from functools import partial
+
+import flax
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+from skrl import config
+
+
+# https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function
+@partial(jax.jit, static_argnames=("unnormalized_log_prob"))
+def _categorical(net_output,
+ unnormalized_log_prob,
+ taken_actions,
+ key):
+ # normalize
+ if unnormalized_log_prob:
+ logits = net_output - jax.scipy.special.logsumexp(net_output, axis=-1, keepdims=True)
+ # probs = jax.nn.softmax(logits)
+ else:
+ probs = net_output / net_output.sum(-1, keepdims=True)
+ eps = jnp.finfo(probs.dtype).eps
+ logits = jnp.log(probs.clip(min=eps, max=1 - eps))
+
+ # sample actions
+ actions = jax.random.categorical(key, logits, axis=-1, shape=None)
+
+ # log of the probability density function
+ taken_actions = actions if taken_actions is None else taken_actions.astype(jnp.int32).reshape(-1)
+ log_prob = jax.nn.log_softmax(logits)[jnp.arange(taken_actions.shape[0]), taken_actions]
+
+ return actions.reshape(-1, 1), log_prob.reshape(-1, 1)
+
+@jax.jit
+def _entropy(logits):
+ logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
+ logits = logits.clip(min=jnp.finfo(logits.dtype).min)
+ p_log_p = logits * jax.nn.softmax(logits)
+ return -p_log_p.sum(-1)
+
+
+class MultiCategoricalMixin:
+ def __init__(self, unnormalized_log_prob: bool = True, reduction: str = "sum", role: str = "") -> None:
+ """MultiCategorical mixin model (stochastic model)
+
+        :param unnormalized_log_prob: Flag to indicate how the model's output will be interpreted (default: ``True``).
+ If True, the model's output is interpreted as unnormalized log probabilities
+ (it can be any real number), otherwise as normalized probabilities
+ (the output must be non-negative, finite and have a non-zero sum)
+ :type unnormalized_log_prob: bool, optional
+        :param reduction: Reduction method for returning the log probability density function (default: ``"sum"``).
+                          Supported values are ``"mean"``, ``"sum"``, ``"prod"`` and ``"none"``. If ``"none"``, the log probability density
+ function is returned as a tensor of shape ``(num_samples, num_actions)`` instead of ``(num_samples, 1)``
+ :type reduction: str, optional
+        :param role: Role played by the model (default: ``""``)
+ :type role: str, optional
+
+ :raises ValueError: If the reduction method is not valid
+
+ Example::
+
+ # define the model
+ >>> import flax.linen as nn
+ >>> from skrl.models.jax import Model, MultiCategoricalMixin
+ >>>
+ >>> class Policy(MultiCategoricalMixin, Model):
+ ... def __init__(self, observation_space, action_space, device=None, unnormalized_log_prob=True, reduction="sum", **kwargs):
+ ... Model.__init__(self, observation_space, action_space, device, **kwargs)
+ ... MultiCategoricalMixin.__init__(self, unnormalized_log_prob, reduction)
+ ...
+ ... @nn.compact # marks the given module method allowing inlined submodules
+ ... def __call__(self, inputs, role):
+ ... x = nn.elu(nn.Dense(32)(inputs["states"]))
+ ... x = nn.elu(nn.Dense(32)(x))
+ ... x = nn.Dense(self.num_actions)(x)
+ ... return x, {}
+ ...
+ >>> # given an observation_space: gym.spaces.Box with shape (4,)
+ >>> # and an action_space: gym.spaces.MultiDiscrete with nvec = [3, 2]
+ >>> model = Policy(observation_space, action_space)
+ >>>
+ >>> print(model)
+ Policy(
+ # attributes
+ observation_space = Box(-1.0, 1.0, (4,), float32)
+ action_space = MultiDiscrete([3 2])
+ device = StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)
+ )
+ """
+ self._unnormalized_log_prob = unnormalized_log_prob
+
+ if reduction not in ["mean", "sum", "prod", "none"]:
+ raise ValueError("reduction must be one of 'mean', 'sum', 'prod' or 'none'")
+ self._reduction = jnp.mean if reduction == "mean" else jnp.sum if reduction == "sum" \
+ else jnp.prod if reduction == "prod" else None
+
+ self._i = 0
+ self._key = config.jax.key
+
+ self._action_space_nvec = np.cumsum(self.action_space.nvec).tolist()
+ self._action_space_shape = self._get_space_size(self.action_space, number_of_elements=False)
+
+ # https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.IncorrectPostInitOverrideError
+ flax.linen.Module.__post_init__(self)
+
+ def act(self,
+ inputs: Mapping[str, Union[Union[np.ndarray, jax.Array], Any]],
+ role: str = "",
+ params: Optional[jax.Array] = None) -> Tuple[jax.Array, Union[jax.Array, None], Mapping[str, Union[jax.Array, Any]]]:
+ """Act stochastically in response to the state of the environment
+
+ :param inputs: Model inputs. The most common keys are:
+
+ - ``"states"``: state of the environment used to make the decision
+ - ``"taken_actions"``: actions taken by the policy for the given states
+ :type inputs: dict where the values are typically np.ndarray or jax.Array
+        :param role: Role played by the model (default: ``""``)
+ :type role: str, optional
+ :param params: Parameters used to compute the output (default: ``None``).
+ If ``None``, internal parameters will be used
+        :type params: jax.Array, optional
+
+ :return: Model output. The first component is the action to be taken by the agent.
+ The second component is the log of the probability density function.
+ The third component is a dictionary containing the network output ``"net_output"``
+ and extra output values
+ :rtype: tuple of jax.Array, jax.Array or None, and dict
+
+ Example::
+
+ >>> # given a batch of sample states with shape (4096, 4)
+ >>> actions, log_prob, outputs = model.act({"states": states})
+ >>> print(actions.shape, log_prob.shape, outputs["net_output"].shape)
+ (4096, 2) (4096, 1) (4096, 5)
+ """
+ self._i += 1
+ subkey = jax.random.fold_in(self._key, self._i)
+ inputs["key"] = subkey
+
+ # map from states/observations to normalized probabilities or unnormalized log probabilities
+ net_output, outputs = self.apply(self.state_dict.params if params is None else params, inputs, role)
+
+ # split inputs
+ net_outputs = jnp.split(net_output, self._action_space_nvec, axis=-1)
+ if "taken_actions" in inputs:
+ taken_actions = jnp.split(inputs["taken_actions"], self._action_space_shape, axis=-1)
+ else:
+ taken_actions = [None] * self._action_space_shape
+
+ # compute actions and log_prob
+ actions, log_prob = [], []
+ for _net_output, _taken_actions in zip(net_outputs, taken_actions):
+ _actions, _log_prob = _categorical(_net_output,
+ self._unnormalized_log_prob,
+ _taken_actions,
+ subkey)
+ actions.append(_actions)
+ log_prob.append(_log_prob)
+
+ actions = jnp.concatenate(actions, axis=-1)
+ log_prob = jnp.concatenate(log_prob, axis=-1)
+
+ if self._reduction is not None:
+ log_prob = self._reduction(log_prob, axis=-1)
+ if log_prob.ndim != actions.ndim:
+ log_prob = jnp.expand_dims(log_prob, -1)
+
+ outputs["net_output"] = net_output
+ # avoid jax.errors.UnexpectedTracerError
+ outputs["stddev"] = jnp.full_like(log_prob, jnp.nan)
+ return actions, log_prob, outputs
+
+ def get_entropy(self, logits: jax.Array, role: str = "") -> jax.Array:
+ """Compute and return the entropy of the model
+
+        :param logits: Logits from which to compute the entropy
+        :type logits: jax.Array
+        :param role: Role played by the model (default: ``""``)
+        :type role: str, optional
+
+ :return: Entropy of the model
+ :rtype: jax.Array
+
+ Example::
+
+            # given the model's logits, e.g. taken from outputs["net_output"]
+            >>> entropy = model.get_entropy(logits)
+            >>> print(entropy.shape)
+            (4096,)
+ """
+ return _entropy(logits)
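
The mixin draws a fresh PRNG subkey on every ``act()`` call by folding an increasing counter into
skrl's global key. A standalone sketch of that pattern, independent of skrl (assumed key and
probabilities)::

    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)
    logits = jnp.log(jnp.array([[0.2, 0.5, 0.3]]))
    for i in range(1, 4):
        subkey = jax.random.fold_in(key, i)                       # deterministic per-call subkey
        print(int(jax.random.categorical(subkey, logits, axis=-1)[0]))
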
diff --git a/skrl/models/torch/__init__.py b/skrl/models/torch/__init__.py
index c7be0d71..774ebfeb 100644
--- a/skrl/models/torch/__init__.py
+++ b/skrl/models/torch/__init__.py
@@ -3,5 +3,6 @@
from skrl.models.torch.categorical import CategoricalMixin
from skrl.models.torch.deterministic import DeterministicMixin
from skrl.models.torch.gaussian import GaussianMixin
+from skrl.models.torch.multicategorical import MultiCategoricalMixin
from skrl.models.torch.multivariate_gaussian import MultivariateGaussianMixin
from skrl.models.torch.tabular import TabularMixin
diff --git a/skrl/models/torch/categorical.py b/skrl/models/torch/categorical.py
index 9181dc89..6b338ca5 100644
--- a/skrl/models/torch/categorical.py
+++ b/skrl/models/torch/categorical.py
@@ -52,13 +52,8 @@ def __init__(self, unnormalized_log_prob: bool = True, role: str = "") -> None:
)
)
"""
- if not hasattr(self, "_c_unnormalized_log_prob"):
- self._c_unnormalized_log_prob = {}
- self._c_unnormalized_log_prob[role] = unnormalized_log_prob
-
- if not hasattr(self, "_c_distribution"):
- self._c_distribution = {}
- self._c_distribution[role] = None
+ self._unnormalized_log_prob = unnormalized_log_prob
+ self._distribution = None
def act(self,
inputs: Mapping[str, Union[torch.Tensor, Any]],
@@ -90,15 +85,15 @@ def act(self,
net_output, outputs = self.compute(inputs, role)
# unnormalized log probabilities
- if self._c_unnormalized_log_prob[role] if role in self._c_unnormalized_log_prob else self._c_unnormalized_log_prob[""]:
- self._c_distribution[role] = Categorical(logits=net_output)
+ if self._unnormalized_log_prob:
+ self._distribution = Categorical(logits=net_output)
# normalized probabilities
else:
- self._c_distribution[role] = Categorical(probs=net_output)
+ self._distribution = Categorical(probs=net_output)
# actions and log of the probability density function
- actions = self._c_distribution[role].sample()
- log_prob = self._c_distribution[role].log_prob(inputs.get("taken_actions", actions).view(-1))
+ actions = self._distribution.sample()
+ log_prob = self._distribution.log_prob(inputs.get("taken_actions", actions).view(-1))
outputs["net_output"] = net_output
return actions.unsqueeze(-1), log_prob.unsqueeze(-1), outputs
@@ -117,10 +112,9 @@ def get_entropy(self, role: str = "") -> torch.Tensor:
>>> print(entropy.shape)
torch.Size([4096, 1])
"""
- distribution = self._c_distribution[role] if role in self._c_distribution else self._c_distribution[""]
- if distribution is None:
+ if self._distribution is None:
return torch.tensor(0.0, device=self.device)
- return distribution.entropy().to(self.device)
+ return self._distribution.entropy().to(self.device)
def distribution(self, role: str = "") -> torch.distributions.Categorical:
"""Get the current distribution of the model
@@ -136,4 +130,4 @@ def distribution(self, role: str = "") -> torch.distributions.Categorical:
>>> print(distribution)
Categorical(probs: torch.Size([4096, 2]), logits: torch.Size([4096, 2]))
"""
- return self._c_distribution[role] if role in self._c_distribution else self._c_distribution[""]
+ return self._distribution
diff --git a/skrl/models/torch/deterministic.py b/skrl/models/torch/deterministic.py
index a5624104..af6cdce5 100644
--- a/skrl/models/torch/deterministic.py
+++ b/skrl/models/torch/deterministic.py
@@ -51,14 +51,12 @@ def __init__(self, clip_actions: bool = False, role: str = "") -> None:
)
)
"""
- if not hasattr(self, "_d_clip_actions"):
- self._d_clip_actions = {}
- self._d_clip_actions[role] = clip_actions and (issubclass(type(self.action_space), gym.Space) or \
+ self._clip_actions = clip_actions and (issubclass(type(self.action_space), gym.Space) or \
issubclass(type(self.action_space), gymnasium.Space))
- if self._d_clip_actions[role]:
- self.clip_actions_min = torch.tensor(self.action_space.low, device=self.device, dtype=torch.float32)
- self.clip_actions_max = torch.tensor(self.action_space.high, device=self.device, dtype=torch.float32)
+ if self._clip_actions:
+ self._clip_actions_min = torch.tensor(self.action_space.low, device=self.device, dtype=torch.float32)
+ self._clip_actions_max = torch.tensor(self.action_space.high, device=self.device, dtype=torch.float32)
def act(self,
inputs: Mapping[str, Union[torch.Tensor, Any]],
@@ -88,7 +86,7 @@ def act(self,
actions, outputs = self.compute(inputs, role)
# clip actions
- if self._d_clip_actions[role] if role in self._d_clip_actions else self._d_clip_actions[""]:
- actions = torch.clamp(actions, min=self.clip_actions_min, max=self.clip_actions_max)
+ if self._clip_actions:
+ actions = torch.clamp(actions, min=self._clip_actions_min, max=self._clip_actions_max)
return actions, None, outputs
diff --git a/skrl/models/torch/gaussian.py b/skrl/models/torch/gaussian.py
index fb4f13d6..a9721b63 100644
--- a/skrl/models/torch/gaussian.py
+++ b/skrl/models/torch/gaussian.py
@@ -72,40 +72,24 @@ def __init__(self,
)
)
"""
- if not hasattr(self, "_g_clip_actions"):
- self._g_clip_actions = {}
- self._g_clip_actions[role] = clip_actions and (issubclass(type(self.action_space), gym.Space) or \
+ self._clip_actions = clip_actions and (issubclass(type(self.action_space), gym.Space) or \
issubclass(type(self.action_space), gymnasium.Space))
- if self._g_clip_actions[role]:
- self.clip_actions_min = torch.tensor(self.action_space.low, device=self.device, dtype=torch.float32)
- self.clip_actions_max = torch.tensor(self.action_space.high, device=self.device, dtype=torch.float32)
-
- if not hasattr(self, "_g_clip_log_std"):
- self._g_clip_log_std = {}
- self._g_clip_log_std[role] = clip_log_std
- if not hasattr(self, "_g_log_std_min"):
- self._g_log_std_min = {}
- self._g_log_std_min[role] = min_log_std
- if not hasattr(self, "_g_log_std_max"):
- self._g_log_std_max = {}
- self._g_log_std_max[role] = max_log_std
-
- if not hasattr(self, "_g_log_std"):
- self._g_log_std = {}
- self._g_log_std[role] = None
- if not hasattr(self, "_g_num_samples"):
- self._g_num_samples = {}
- self._g_num_samples[role] = None
- if not hasattr(self, "_g_distribution"):
- self._g_distribution = {}
- self._g_distribution[role] = None
+ if self._clip_actions:
+ self._clip_actions_min = torch.tensor(self.action_space.low, device=self.device, dtype=torch.float32)
+ self._clip_actions_max = torch.tensor(self.action_space.high, device=self.device, dtype=torch.float32)
+
+ self._clip_log_std = clip_log_std
+ self._log_std_min = min_log_std
+ self._log_std_max = max_log_std
+
+ self._log_std = None
+ self._num_samples = None
+ self._distribution = None
if reduction not in ["mean", "sum", "prod", "none"]:
raise ValueError("reduction must be one of 'mean', 'sum', 'prod' or 'none'")
- if not hasattr(self, "_g_reduction"):
- self._g_reduction = {}
- self._g_reduction[role] = torch.mean if reduction == "mean" else torch.sum if reduction == "sum" \
+ self._reduction = torch.mean if reduction == "mean" else torch.sum if reduction == "sum" \
else torch.prod if reduction == "prod" else None
def act(self,
@@ -138,29 +122,26 @@ def act(self,
mean_actions, log_std, outputs = self.compute(inputs, role)
# clamp log standard deviations
- if self._g_clip_log_std[role] if role in self._g_clip_log_std else self._g_clip_log_std[""]:
- log_std = torch.clamp(log_std,
- self._g_log_std_min[role] if role in self._g_log_std_min else self._g_log_std_min[""],
- self._g_log_std_max[role] if role in self._g_log_std_max else self._g_log_std_max[""])
+ if self._clip_log_std:
+ log_std = torch.clamp(log_std, self._log_std_min, self._log_std_max)
- self._g_log_std[role] = log_std
- self._g_num_samples[role] = mean_actions.shape[0]
+ self._log_std = log_std
+ self._num_samples = mean_actions.shape[0]
# distribution
- self._g_distribution[role] = Normal(mean_actions, log_std.exp())
+ self._distribution = Normal(mean_actions, log_std.exp())
# sample using the reparameterization trick
- actions = self._g_distribution[role].rsample()
+ actions = self._distribution.rsample()
# clip actions
- if self._g_clip_actions[role] if role in self._g_clip_actions else self._g_clip_actions[""]:
- actions = torch.clamp(actions, min=self.clip_actions_min, max=self.clip_actions_max)
+ if self._clip_actions:
+ actions = torch.clamp(actions, min=self._clip_actions_min, max=self._clip_actions_max)
# log of the probability density function
- log_prob = self._g_distribution[role].log_prob(inputs.get("taken_actions", actions))
- reduction = self._g_reduction[role] if role in self._g_reduction else self._g_reduction[""]
- if reduction is not None:
- log_prob = reduction(log_prob, dim=-1)
+ log_prob = self._distribution.log_prob(inputs.get("taken_actions", actions))
+ if self._reduction is not None:
+ log_prob = self._reduction(log_prob, dim=-1)
if log_prob.dim() != actions.dim():
log_prob = log_prob.unsqueeze(-1)
@@ -181,10 +162,9 @@ def get_entropy(self, role: str = "") -> torch.Tensor:
>>> print(entropy.shape)
torch.Size([4096, 8])
"""
- distribution = self._g_distribution[role] if role in self._g_distribution else self._g_distribution[""]
- if distribution is None:
+ if self._distribution is None:
return torch.tensor(0.0, device=self.device)
- return distribution.entropy().to(self.device)
+ return self._distribution.entropy().to(self.device)
def get_log_std(self, role: str = "") -> torch.Tensor:
"""Return the log standard deviation of the model
@@ -200,8 +180,7 @@ def get_log_std(self, role: str = "") -> torch.Tensor:
>>> print(log_std.shape)
torch.Size([4096, 8])
"""
- return (self._g_log_std[role] if role in self._g_log_std else self._g_log_std[""]) \
- .repeat(self._g_num_samples[role] if role in self._g_num_samples else self._g_num_samples[""], 1)
+ return self._log_std.repeat(self._num_samples, 1)
def distribution(self, role: str = "") -> torch.distributions.Normal:
"""Get the current distribution of the model
@@ -217,4 +196,4 @@ def distribution(self, role: str = "") -> torch.distributions.Normal:
>>> print(distribution)
Normal(loc: torch.Size([4096, 8]), scale: torch.Size([4096, 8]))
"""
- return self._g_distribution[role] if role in self._g_distribution else self._g_distribution[""]
+ return self._distribution
diff --git a/skrl/models/torch/multicategorical.py b/skrl/models/torch/multicategorical.py
new file mode 100644
index 00000000..2c749862
--- /dev/null
+++ b/skrl/models/torch/multicategorical.py
@@ -0,0 +1,155 @@
+from typing import Any, Mapping, Sequence, Tuple, Union
+
+import torch
+from torch.distributions import Categorical
+
+
+class MultiCategoricalMixin:
+ def __init__(self, unnormalized_log_prob: bool = True, reduction: str = "sum", role: str = "") -> None:
+ """MultiCategorical mixin model (stochastic model)
+
+        :param unnormalized_log_prob: Flag to indicate how the model's output will be interpreted (default: ``True``).
+ If True, the model's output is interpreted as unnormalized log probabilities
+ (it can be any real number), otherwise as normalized probabilities
+ (the output must be non-negative, finite and have a non-zero sum)
+ :type unnormalized_log_prob: bool, optional
+        :param reduction: Reduction method for returning the log probability density function (default: ``"sum"``).
+                          Supported values are ``"mean"``, ``"sum"``, ``"prod"`` and ``"none"``. If ``"none"``, the log probability density
+ function is returned as a tensor of shape ``(num_samples, num_actions)`` instead of ``(num_samples, 1)``
+ :type reduction: str, optional
+        :param role: Role played by the model (default: ``""``)
+ :type role: str, optional
+
+ :raises ValueError: If the reduction method is not valid
+
+ Example::
+
+ # define the model
+ >>> import torch
+ >>> import torch.nn as nn
+ >>> from skrl.models.torch import Model, MultiCategoricalMixin
+ >>>
+ >>> class Policy(MultiCategoricalMixin, Model):
+ ... def __init__(self, observation_space, action_space, device="cuda:0", unnormalized_log_prob=True, reduction="sum"):
+ ... Model.__init__(self, observation_space, action_space, device)
+ ... MultiCategoricalMixin.__init__(self, unnormalized_log_prob, reduction)
+ ...
+ ... self.net = nn.Sequential(nn.Linear(self.num_observations, 32),
+ ... nn.ELU(),
+ ... nn.Linear(32, 32),
+ ... nn.ELU(),
+ ... nn.Linear(32, self.num_actions))
+ ...
+ ... def compute(self, inputs, role):
+ ... return self.net(inputs["states"]), {}
+ ...
+ >>> # given an observation_space: gym.spaces.Box with shape (4,)
+ >>> # and an action_space: gym.spaces.MultiDiscrete with nvec = [3, 2]
+ >>> model = Policy(observation_space, action_space)
+ >>>
+ >>> print(model)
+ Policy(
+ (net): Sequential(
+ (0): Linear(in_features=4, out_features=32, bias=True)
+ (1): ELU(alpha=1.0)
+ (2): Linear(in_features=32, out_features=32, bias=True)
+ (3): ELU(alpha=1.0)
+ (4): Linear(in_features=32, out_features=5, bias=True)
+ )
+ )
+ """
+ self._unnormalized_log_prob = unnormalized_log_prob
+ self._distributions = []
+
+ if reduction not in ["mean", "sum", "prod", "none"]:
+ raise ValueError("reduction must be one of 'mean', 'sum', 'prod' or 'none'")
+ self._reduction = torch.mean if reduction == "mean" else torch.sum if reduction == "sum" \
+ else torch.prod if reduction == "prod" else None
+
+ def act(self,
+ inputs: Mapping[str, Union[torch.Tensor, Any]],
+ role: str = "") -> Tuple[torch.Tensor, Union[torch.Tensor, None], Mapping[str, Union[torch.Tensor, Any]]]:
+ """Act stochastically in response to the state of the environment
+
+ :param inputs: Model inputs. The most common keys are:
+
+ - ``"states"``: state of the environment used to make the decision
+ - ``"taken_actions"``: actions taken by the policy for the given states
+ :type inputs: dict where the values are typically torch.Tensor
+        :param role: Role played by the model (default: ``""``)
+ :type role: str, optional
+
+ :return: Model output. The first component is the action to be taken by the agent.
+ The second component is the log of the probability density function.
+ The third component is a dictionary containing the network output ``"net_output"``
+ and extra output values
+ :rtype: tuple of torch.Tensor, torch.Tensor or None, and dict
+
+ Example::
+
+ >>> # given a batch of sample states with shape (4096, 4)
+ >>> actions, log_prob, outputs = model.act({"states": states})
+ >>> print(actions.shape, log_prob.shape, outputs["net_output"].shape)
+ torch.Size([4096, 2]) torch.Size([4096, 1]) torch.Size([4096, 5])
+ """
+ # map from states/observations to normalized probabilities or unnormalized log probabilities
+ net_output, outputs = self.compute(inputs, role)
+
+ # unnormalized log probabilities
+ if self._unnormalized_log_prob:
+ self._distributions = [Categorical(logits=logits) for logits in torch.split(net_output, self.action_space.nvec.tolist(), dim=-1)]
+ # normalized probabilities
+ else:
+ self._distributions = [Categorical(probs=probs) for probs in torch.split(net_output, self.action_space.nvec.tolist(), dim=-1)]
+
+ # actions
+ actions = torch.stack([distribution.sample() for distribution in self._distributions], dim=-1)
+
+ # log of the probability density function
+ log_prob = torch.stack([distribution.log_prob(_actions.view(-1)) for _actions, distribution \
+ in zip(torch.unbind(inputs.get("taken_actions", actions), dim=-1), self._distributions)], dim=-1)
+ if self._reduction is not None:
+ log_prob = self._reduction(log_prob, dim=-1)
+ if log_prob.dim() != actions.dim():
+ log_prob = log_prob.unsqueeze(-1)
+
+ outputs["net_output"] = net_output
+ return actions, log_prob, outputs
+
+ def get_entropy(self, role: str = "") -> torch.Tensor:
+ """Compute and return the entropy of the model
+
+        :param role: Role played by the model (default: ``""``)
+        :type role: str, optional
+        :return: Entropy of the model
+        :rtype: torch.Tensor
+
+ Example::
+
+ >>> entropy = model.get_entropy()
+ >>> print(entropy.shape)
+ torch.Size([4096, 1])
+ """
+ if self._distributions:
+ entropy = torch.stack([distribution.entropy().to(self.device) for distribution in self._distributions], dim=-1)
+ if self._reduction is not None:
+ return self._reduction(entropy, dim=-1).unsqueeze(-1)
+ return entropy
+ return torch.tensor(0.0, device=self.device)
+
+ def distribution(self, role: str = "") -> torch.distributions.Categorical:
+ """Get the current distribution of the model
+
+        :param role: Role played by the model (default: ``""``)
+        :type role: str, optional
+        :return: First distribution of the model (one ``Categorical`` per action-space component is kept internally)
+        :rtype: torch.distributions.Categorical
+
+ Example::
+
+ >>> distribution = model.distribution()
+ >>> print(distribution)
+ Categorical(probs: torch.Size([10, 3]), logits: torch.Size([10, 3]))
+ """
+ # TODO: find a way to integrate in the class the distribution functions (e.g.: stddev)
+ return self._distributions[0]
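
Stripped of the model plumbing, the sampling path above amounts to one ``Categorical`` per entry
of ``nvec``. A self-contained sketch with assumed shapes::

    import torch
    from torch.distributions import Categorical

    nvec = [3, 2]                                  # action_space.nvec for MultiDiscrete([3, 2])
    net_output = torch.randn(4096, sum(nvec))      # the policy head emits one logit block per sub-action
    distributions = [Categorical(logits=logits)
                     for logits in torch.split(net_output, nvec, dim=-1)]
    actions = torch.stack([d.sample() for d in distributions], dim=-1)
    log_prob = torch.stack([d.log_prob(a) for d, a in zip(distributions, actions.unbind(dim=-1))],
                           dim=-1).sum(dim=-1, keepdim=True)      # "sum" reduction
    print(actions.shape, log_prob.shape)           # torch.Size([4096, 2]) torch.Size([4096, 1])
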
diff --git a/skrl/models/torch/multivariate_gaussian.py b/skrl/models/torch/multivariate_gaussian.py
index bf7a9ccf..0f43aadc 100644
--- a/skrl/models/torch/multivariate_gaussian.py
+++ b/skrl/models/torch/multivariate_gaussian.py
@@ -65,34 +65,20 @@ def __init__(self,
)
)
"""
- if not hasattr(self, "_mg_clip_actions"):
- self._mg_clip_actions = {}
- self._mg_clip_actions[role] = clip_actions and (issubclass(type(self.action_space), gym.Space) or \
+ self._clip_actions = clip_actions and (issubclass(type(self.action_space), gym.Space) or \
issubclass(type(self.action_space), gymnasium.Space))
- if self._mg_clip_actions[role]:
- self.clip_actions_min = torch.tensor(self.action_space.low, device=self.device, dtype=torch.float32)
- self.clip_actions_max = torch.tensor(self.action_space.high, device=self.device, dtype=torch.float32)
-
- if not hasattr(self, "_mg_clip_log_std"):
- self._mg_clip_log_std = {}
- self._mg_clip_log_std[role] = clip_log_std
- if not hasattr(self, "_mg_log_std_min"):
- self._mg_log_std_min = {}
- self._mg_log_std_min[role] = min_log_std
- if not hasattr(self, "_mg_log_std_max"):
- self._mg_log_std_max = {}
- self._mg_log_std_max[role] = max_log_std
-
- if not hasattr(self, "_mg_log_std"):
- self._mg_log_std = {}
- self._mg_log_std[role] = None
- if not hasattr(self, "_mg_num_samples"):
- self._mg_num_samples = {}
- self._mg_num_samples[role] = None
- if not hasattr(self, "_mg_distribution"):
- self._mg_distribution = {}
- self._mg_distribution[role] = None
+ if self._clip_actions:
+ self._clip_actions_min = torch.tensor(self.action_space.low, device=self.device, dtype=torch.float32)
+ self._clip_actions_max = torch.tensor(self.action_space.high, device=self.device, dtype=torch.float32)
+
+ self._clip_log_std = clip_log_std
+ self._log_std_min = min_log_std
+ self._log_std_max = max_log_std
+
+ self._log_std = None
+ self._num_samples = None
+ self._distribution = None
def act(self,
inputs: Mapping[str, Union[torch.Tensor, Any]],
@@ -124,27 +110,25 @@ def act(self,
mean_actions, log_std, outputs = self.compute(inputs, role)
# clamp log standard deviations
- if self._mg_clip_log_std[role] if role in self._mg_clip_log_std else self._mg_clip_log_std[""]:
- log_std = torch.clamp(log_std,
- self._mg_log_std_min[role] if role in self._mg_log_std_min else self._mg_log_std_min[""],
- self._mg_log_std_max[role] if role in self._mg_log_std_max else self._mg_log_std_max[""])
+ if self._clip_log_std:
+ log_std = torch.clamp(log_std, self._log_std_min, self._log_std_max)
- self._mg_log_std[role] = log_std
- self._mg_num_samples[role] = mean_actions.shape[0]
+ self._log_std = log_std
+ self._num_samples = mean_actions.shape[0]
# distribution
covariance = torch.diag(log_std.exp() * log_std.exp())
- self._mg_distribution[role] = MultivariateNormal(mean_actions, scale_tril=covariance)
+ self._distribution = MultivariateNormal(mean_actions, scale_tril=covariance)
# sample using the reparameterization trick
- actions = self._mg_distribution[role].rsample()
+ actions = self._distribution.rsample()
# clip actions
- if self._mg_clip_actions[role] if role in self._mg_clip_actions else self._mg_clip_actions[""]:
- actions = torch.clamp(actions, min=self.clip_actions_min, max=self.clip_actions_max)
+ if self._clip_actions:
+ actions = torch.clamp(actions, min=self._clip_actions_min, max=self._clip_actions_max)
# log of the probability density function
- log_prob = self._mg_distribution[role].log_prob(inputs.get("taken_actions", actions))
+ log_prob = self._distribution.log_prob(inputs.get("taken_actions", actions))
if log_prob.dim() != actions.dim():
log_prob = log_prob.unsqueeze(-1)
@@ -165,10 +149,9 @@ def get_entropy(self, role: str = "") -> torch.Tensor:
>>> print(entropy.shape)
torch.Size([4096])
"""
- distribution = self._mg_distribution[role] if role in self._mg_distribution else self._mg_distribution[""]
- if distribution is None:
+ if self._distribution is None:
return torch.tensor(0.0, device=self.device)
- return distribution.entropy().to(self.device)
+ return self._distribution.entropy().to(self.device)
def get_log_std(self, role: str = "") -> torch.Tensor:
"""Return the log standard deviation of the model
@@ -184,8 +167,7 @@ def get_log_std(self, role: str = "") -> torch.Tensor:
>>> print(log_std.shape)
torch.Size([4096, 8])
"""
- return (self._mg_log_std[role] if role in self._mg_log_std else self._mg_log_std[""]) \
- .repeat(self._mg_num_samples[role] if role in self._mg_num_samples else self._mg_num_samples[""], 1)
+ return self._log_std.repeat(self._num_samples, 1)
def distribution(self, role: str = "") -> torch.distributions.MultivariateNormal:
"""Get the current distribution of the model
@@ -201,4 +183,4 @@ def distribution(self, role: str = "") -> torch.distributions.MultivariateNormal
>>> print(distribution)
MultivariateNormal(loc: torch.Size([4096, 8]), scale_tril: torch.Size([4096, 8, 8]))
"""
- return self._mg_distribution[role] if role in self._mg_distribution else self._mg_distribution[""]
+ return self._distribution
diff --git a/skrl/multi_agents/jax/ippo/ippo.py b/skrl/multi_agents/jax/ippo/ippo.py
index 31aed63c..6dcfd328 100644
--- a/skrl/multi_agents/jax/ippo/ippo.py
+++ b/skrl/multi_agents/jax/ippo/ippo.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Mapping, Optional, Sequence, Union
+from typing import Any, Mapping, Optional, Sequence, Union
import copy
import functools
@@ -189,7 +189,7 @@ def _value_loss(params):
class IPPO(MultiAgent):
def __init__(self,
possible_agents: Sequence[str],
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memories: Optional[Mapping[str, Memory]] = None,
observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gym.Space], Mapping[str, gymnasium.Space]]] = None,
action_spaces: Optional[Union[Mapping[str, int], Mapping[str, gym.Space], Mapping[str, gymnasium.Space]]] = None,
@@ -308,7 +308,7 @@ def __init__(self,
else:
self._value_preprocessor[uid] = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/multi_agents/jax/mappo/mappo.py b/skrl/multi_agents/jax/mappo/mappo.py
index dd15c560..aa3aff34 100644
--- a/skrl/multi_agents/jax/mappo/mappo.py
+++ b/skrl/multi_agents/jax/mappo/mappo.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Mapping, Optional, Sequence, Union
+from typing import Any, Mapping, Optional, Sequence, Union
import copy
import functools
@@ -191,7 +191,7 @@ def _value_loss(params):
class MAPPO(MultiAgent):
def __init__(self,
possible_agents: Sequence[str],
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memories: Optional[Mapping[str, Memory]] = None,
observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gym.Space], Mapping[str, gymnasium.Space]]] = None,
action_spaces: Optional[Union[Mapping[str, int], Mapping[str, gym.Space], Mapping[str, gymnasium.Space]]] = None,
@@ -323,7 +323,7 @@ def __init__(self,
else:
self._value_preprocessor[uid] = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/multi_agents/torch/ippo/ippo.py b/skrl/multi_agents/torch/ippo/ippo.py
index e4a386f0..45913edd 100644
--- a/skrl/multi_agents/torch/ippo/ippo.py
+++ b/skrl/multi_agents/torch/ippo/ippo.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Mapping, Optional, Sequence, Union
+from typing import Any, Mapping, Optional, Sequence, Union
import copy
import itertools
@@ -67,7 +67,7 @@
class IPPO(MultiAgent):
def __init__(self,
possible_agents: Sequence[str],
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memories: Optional[Mapping[str, Memory]] = None,
observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gym.Space], Mapping[str, gymnasium.Space]]] = None,
action_spaces: Optional[Union[Mapping[str, int], Mapping[str, gym.Space], Mapping[str, gymnasium.Space]]] = None,
@@ -178,7 +178,7 @@ def __init__(self,
else:
self._value_preprocessor[uid] = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/multi_agents/torch/mappo/mappo.py b/skrl/multi_agents/torch/mappo/mappo.py
index 0d98275b..98fff05c 100644
--- a/skrl/multi_agents/torch/mappo/mappo.py
+++ b/skrl/multi_agents/torch/mappo/mappo.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Mapping, Optional, Sequence, Union
+from typing import Any, Mapping, Optional, Sequence, Union
import copy
import itertools
@@ -69,7 +69,7 @@
class MAPPO(MultiAgent):
def __init__(self,
possible_agents: Sequence[str],
- models: Dict[str, Model],
+ models: Mapping[str, Model],
memories: Optional[Mapping[str, Memory]] = None,
observation_spaces: Optional[Union[Mapping[str, int], Mapping[str, gym.Space], Mapping[str, gymnasium.Space]]] = None,
action_spaces: Optional[Union[Mapping[str, int], Mapping[str, gym.Space], Mapping[str, gymnasium.Space]]] = None,
@@ -193,7 +193,7 @@ def __init__(self,
else:
self._value_preprocessor[uid] = self._empty_preprocessor
- def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
+ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
"""
super().init(trainer_cfg=trainer_cfg)
diff --git a/skrl/trainers/jax/__init__.py b/skrl/trainers/jax/__init__.py
index 0db99a1b..4348d781 100644
--- a/skrl/trainers/jax/__init__.py
+++ b/skrl/trainers/jax/__init__.py
@@ -1,4 +1,4 @@
from skrl.trainers.jax.base import Trainer, generate_equally_spaced_scopes # isort:skip
-from skrl.trainers.jax.manual import ManualTrainer
from skrl.trainers.jax.sequential import SequentialTrainer
+from skrl.trainers.jax.step import StepTrainer
diff --git a/skrl/trainers/jax/base.py b/skrl/trainers/jax/base.py
index b30d0a86..b542c2e7 100644
--- a/skrl/trainers/jax/base.py
+++ b/skrl/trainers/jax/base.py
@@ -2,6 +2,7 @@
import atexit
import contextlib
+import sys
import tqdm
from skrl import logger
@@ -161,7 +162,7 @@ def single_agent_train(self) -> None:
# reset env
states, infos = self.env.reset()
- for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar):
+ for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout):
# pre-interaction
self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps)
@@ -218,7 +219,7 @@ def single_agent_eval(self) -> None:
# reset env
states, infos = self.env.reset()
- for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar):
+ for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout):
# compute actions
with contextlib.nullcontext():
@@ -274,7 +275,7 @@ def multi_agent_train(self) -> None:
states, infos = self.env.reset()
shared_states = infos.get("shared_states", None)
- for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar):
+ for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout):
# pre-interaction
self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps)
@@ -334,7 +335,7 @@ def multi_agent_eval(self) -> None:
states, infos = self.env.reset()
shared_states = infos.get("shared_states", None)
- for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar):
+ for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout):
# compute actions
with contextlib.nullcontext():
diff --git a/skrl/trainers/jax/sequential.py b/skrl/trainers/jax/sequential.py
index 0bd48278..6fbb261c 100644
--- a/skrl/trainers/jax/sequential.py
+++ b/skrl/trainers/jax/sequential.py
@@ -2,6 +2,7 @@
import contextlib
import copy
+import sys
import tqdm
import jax.numpy as jnp
@@ -11,12 +12,14 @@
from skrl.trainers.jax import Trainer
+# [start-config-dict-jax]
SEQUENTIAL_TRAINER_DEFAULT_CONFIG = {
"timesteps": 100000, # number of timesteps to train for
"headless": False, # whether to use headless mode (no rendering)
"disable_progressbar": False, # whether to disable the progressbar. If None, disable on non-TTY
"close_environment_at_exit": True, # whether to close the environment on normal program termination
}
+# [end-config-dict-jax]
class SequentialTrainer(Trainer):
@@ -84,7 +87,7 @@ def train(self) -> None:
# reset env
states, infos = self.env.reset()
- for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar):
+ for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout):
# pre-interaction
for agent in self.agents:
@@ -156,7 +159,7 @@ def eval(self) -> None:
# reset env
states, infos = self.env.reset()
- for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar):
+ for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout):
# compute actions
with contextlib.nullcontext():
diff --git a/skrl/trainers/jax/manual.py b/skrl/trainers/jax/step.py
similarity index 89%
rename from skrl/trainers/jax/manual.py
rename to skrl/trainers/jax/step.py
index b8bf0c40..ae7e5986 100644
--- a/skrl/trainers/jax/manual.py
+++ b/skrl/trainers/jax/step.py
@@ -1,33 +1,38 @@
-from typing import List, Optional, Union
+from typing import Any, List, Optional, Tuple, Union
import contextlib
import copy
+import sys
import tqdm
+import jax
import jax.numpy as jnp
+import numpy as np
from skrl.agents.jax import Agent
from skrl.envs.wrappers.jax import Wrapper
from skrl.trainers.jax import Trainer
-MANUAL_TRAINER_DEFAULT_CONFIG = {
+# [start-config-dict-jax]
+STEP_TRAINER_DEFAULT_CONFIG = {
"timesteps": 100000, # number of timesteps to train for
"headless": False, # whether to use headless mode (no rendering)
"disable_progressbar": False, # whether to disable the progressbar. If None, disable on non-TTY
"close_environment_at_exit": True, # whether to close the environment on normal program termination
}
+# [end-config-dict-jax]
-class ManualTrainer(Trainer):
+class StepTrainer(Trainer):
def __init__(self,
env: Wrapper,
agents: Union[Agent, List[Agent]],
agents_scope: Optional[List[int]] = None,
cfg: Optional[dict] = None) -> None:
- """Manual trainer
+ """Step-by-step trainer
- Train agents by manually controlling the training/evaluation loop
+ Train agents by controlling the training/evaluation loop step by step
:param env: Environment to train on
:type env: skrl.envs.wrappers.jax.Wrapper
@@ -36,10 +41,10 @@ def __init__(self,
:param agents_scope: Number of environments for each agent to train on (default: ``None``)
:type agents_scope: tuple or list of int, optional
:param cfg: Configuration dictionary (default: ``None``).
- See MANUAL_TRAINER_DEFAULT_CONFIG for default values
+ See STEP_TRAINER_DEFAULT_CONFIG for default values
:type cfg: dict, optional
"""
- _cfg = copy.deepcopy(MANUAL_TRAINER_DEFAULT_CONFIG)
+ _cfg = copy.deepcopy(STEP_TRAINER_DEFAULT_CONFIG)
_cfg.update(cfg if cfg is not None else {})
agents_scope = agents_scope if agents_scope is not None else []
super().__init__(env=env, agents=agents, agents_scope=agents_scope, cfg=_cfg)
@@ -56,7 +61,9 @@ def __init__(self,
self.states = None
- def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) -> None:
+ def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) -> \
+ Tuple[Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array],
+ Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], Any]:
"""Execute a training iteration
This method executes the following steps once:
@@ -75,6 +82,9 @@ def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
:param timesteps: Total number of timesteps (default: ``None``).
If None, the total number of timesteps is obtained from the trainer's config
:type timesteps: int, optional
+
+ :return: Observation, reward, terminated, truncated, info
+ :rtype: tuple of np.ndarray or jax.Array and any other info
"""
if timestep is None:
self._timestep += 1
@@ -82,7 +92,7 @@ def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
timesteps = self.timesteps if timesteps is None else timesteps
if self._progress is None:
- self._progress = tqdm.tqdm(total=timesteps, disable=self.disable_progressbar)
+ self._progress = tqdm.tqdm(total=timesteps, disable=self.disable_progressbar, file=sys.stdout)
self._progress.update(n=1)
# set running mode
@@ -162,7 +172,11 @@ def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
else:
self.states = next_states
- def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) -> None:
+ return next_states, rewards, terminated, truncated, infos
+
+ def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) -> \
+ Tuple[Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array],
+ Union[np.ndarray, jax.Array], Union[np.ndarray, jax.Array], Any]:
"""Evaluate the agents sequentially
This method executes the following steps in loop:
@@ -178,6 +192,9 @@ def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
:param timesteps: Total number of timesteps (default: ``None``).
If None, the total number of timesteps is obtained from the trainer's config
:type timesteps: int, optional
+
+ :return: Observation, reward, terminated, truncated, info
+ :rtype: tuple of np.ndarray or jax.Array and any other info
"""
if timestep is None:
self._timestep += 1
@@ -185,7 +202,7 @@ def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
timesteps = self.timesteps if timesteps is None else timesteps
if self._progress is None:
- self._progress = tqdm.tqdm(total=timesteps, disable=self.disable_progressbar)
+ self._progress = tqdm.tqdm(total=timesteps, disable=self.disable_progressbar, file=sys.stdout)
self._progress.update(n=1)
# set running mode
@@ -249,3 +266,5 @@ def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
self.states, infos = self.env.reset()
else:
self.states = next_states
+
+ return next_states, rewards, terminated, truncated, infos
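
Because both ``train()`` and ``eval()`` now return the transition of each call, the renamed trainer
can drive the loop from user code. A sketch assuming an already wrapped ``env`` and a configured
``agent``::

    from skrl.trainers.jax import StepTrainer

    trainer = StepTrainer(env=env, agents=agent, cfg={"timesteps": 10000})
    for timestep in range(10000):
        states, rewards, terminated, truncated, infos = trainer.train(timestep=timestep)
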
diff --git a/skrl/trainers/torch/__init__.py b/skrl/trainers/torch/__init__.py
index 9fcf4349..2d1f8e40 100644
--- a/skrl/trainers/torch/__init__.py
+++ b/skrl/trainers/torch/__init__.py
@@ -1,5 +1,5 @@
from skrl.trainers.torch.base import Trainer, generate_equally_spaced_scopes # isort:skip
-from skrl.trainers.torch.manual import ManualTrainer
from skrl.trainers.torch.parallel import ParallelTrainer
from skrl.trainers.torch.sequential import SequentialTrainer
+from skrl.trainers.torch.step import StepTrainer
diff --git a/skrl/trainers/torch/base.py b/skrl/trainers/torch/base.py
index 5c8b7a52..a13d70a4 100644
--- a/skrl/trainers/torch/base.py
+++ b/skrl/trainers/torch/base.py
@@ -1,6 +1,7 @@
from typing import List, Optional, Union
import atexit
+import sys
import tqdm
import torch
@@ -162,7 +163,7 @@ def single_agent_train(self) -> None:
# reset env
states, infos = self.env.reset()
- for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar):
+ for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout):
# pre-interaction
self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps)
@@ -218,7 +219,7 @@ def single_agent_eval(self) -> None:
# reset env
states, infos = self.env.reset()
- for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar):
+ for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout):
# compute actions
with torch.no_grad():
@@ -273,7 +274,7 @@ def multi_agent_train(self) -> None:
states, infos = self.env.reset()
shared_states = infos.get("shared_states", None)
- for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar):
+ for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout):
# pre-interaction
self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps)
@@ -332,7 +333,7 @@ def multi_agent_eval(self) -> None:
states, infos = self.env.reset()
shared_states = infos.get("shared_states", None)
- for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar):
+ for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout):
# compute actions
with torch.no_grad():
diff --git a/skrl/trainers/torch/parallel.py b/skrl/trainers/torch/parallel.py
index 57212796..68b9b9d8 100644
--- a/skrl/trainers/torch/parallel.py
+++ b/skrl/trainers/torch/parallel.py
@@ -1,6 +1,7 @@
from typing import List, Optional, Union
import copy
+import sys
import tqdm
import torch
@@ -11,12 +12,14 @@
from skrl.trainers.torch import Trainer
+# [start-config-dict-torch]
PARALLEL_TRAINER_DEFAULT_CONFIG = {
"timesteps": 100000, # number of timesteps to train for
"headless": False, # whether to use headless mode (no rendering)
"disable_progressbar": False, # whether to disable the progressbar. If None, disable on non-TTY
"close_environment_at_exit": True, # whether to close the environment on normal program termination
}
+# [end-config-dict-torch]
def fn_processor(process_index, *args):
@@ -201,7 +204,7 @@ def train(self) -> None:
if not states.is_cuda:
states.share_memory_()
- for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar):
+ for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout):
# pre-interaction
for pipe in producer_pipes:
@@ -337,7 +340,7 @@ def eval(self) -> None:
if not states.is_cuda:
states.share_memory_()
- for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar):
+ for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout):
# compute actions
with torch.no_grad():
diff --git a/skrl/trainers/torch/sequential.py b/skrl/trainers/torch/sequential.py
index 8f1ef1c2..49952351 100644
--- a/skrl/trainers/torch/sequential.py
+++ b/skrl/trainers/torch/sequential.py
@@ -1,6 +1,7 @@
from typing import List, Optional, Union
import copy
+import sys
import tqdm
import torch
@@ -10,12 +11,14 @@
from skrl.trainers.torch import Trainer
+# [start-config-dict-torch]
SEQUENTIAL_TRAINER_DEFAULT_CONFIG = {
"timesteps": 100000, # number of timesteps to train for
"headless": False, # whether to use headless mode (no rendering)
"disable_progressbar": False, # whether to disable the progressbar. If None, disable on non-TTY
"close_environment_at_exit": True, # whether to close the environment on normal program termination
}
+# [end-config-dict-torch]
class SequentialTrainer(Trainer):
@@ -83,7 +86,7 @@ def train(self) -> None:
# reset env
states, infos = self.env.reset()
- for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar):
+ for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout):
# pre-interaction
for agent in self.agents:
@@ -154,7 +157,7 @@ def eval(self) -> None:
# reset env
states, infos = self.env.reset()
- for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar):
+ for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout):
# compute actions
with torch.no_grad():
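
A short standalone sketch (not from the repository) of the behaviour the recurring `file=sys.stdout` change is after: with the progress bar bound to `sys.stdout`, redirecting or capturing standard output also captures the training/evaluation progress messages, which the default stderr stream would not.

```python
# Standalone illustration, not repository code: a tqdm bar bound to
# sys.stdout follows stdout redirection (e.g. `python train.py > run.log`
# or contextlib.redirect_stdout).
import contextlib
import io
import sys

import tqdm

buffer = io.StringIO()
with contextlib.redirect_stdout(buffer):
    # inside the context, sys.stdout is the StringIO buffer
    for _ in tqdm.tqdm(range(3), file=sys.stdout):
        pass

print("progressbar captured via stdout:", bool(buffer.getvalue()))
```
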
diff --git a/skrl/trainers/torch/manual.py b/skrl/trainers/torch/step.py
similarity index 90%
rename from skrl/trainers/torch/manual.py
rename to skrl/trainers/torch/step.py
index c0179591..c60476f1 100644
--- a/skrl/trainers/torch/manual.py
+++ b/skrl/trainers/torch/step.py
@@ -1,6 +1,7 @@
-from typing import List, Optional, Union
+from typing import Any, List, Optional, Tuple, Union
import copy
+import sys
import tqdm
import torch
@@ -10,23 +11,25 @@
from skrl.trainers.torch import Trainer
-MANUAL_TRAINER_DEFAULT_CONFIG = {
+# [start-config-dict-torch]
+STEP_TRAINER_DEFAULT_CONFIG = {
"timesteps": 100000, # number of timesteps to train for
"headless": False, # whether to use headless mode (no rendering)
"disable_progressbar": False, # whether to disable the progressbar. If None, disable on non-TTY
"close_environment_at_exit": True, # whether to close the environment on normal program termination
}
+# [end-config-dict-torch]
-class ManualTrainer(Trainer):
+class StepTrainer(Trainer):
def __init__(self,
env: Wrapper,
agents: Union[Agent, List[Agent]],
agents_scope: Optional[List[int]] = None,
cfg: Optional[dict] = None) -> None:
- """Manual trainer
+ """Step-by-step trainer
- Train agents by manually controlling the training/evaluation loop
+ Train agents by controlling the training/evaluation loop step by step
:param env: Environment to train on
:type env: skrl.envs.wrappers.torch.Wrapper
@@ -35,10 +38,10 @@ def __init__(self,
:param agents_scope: Number of environments for each agent to train on (default: ``None``)
:type agents_scope: tuple or list of int, optional
:param cfg: Configuration dictionary (default: ``None``).
- See MANUAL_TRAINER_DEFAULT_CONFIG for default values
+ See STEP_TRAINER_DEFAULT_CONFIG for default values
:type cfg: dict, optional
"""
- _cfg = copy.deepcopy(MANUAL_TRAINER_DEFAULT_CONFIG)
+ _cfg = copy.deepcopy(STEP_TRAINER_DEFAULT_CONFIG)
_cfg.update(cfg if cfg is not None else {})
agents_scope = agents_scope if agents_scope is not None else []
super().__init__(env=env, agents=agents, agents_scope=agents_scope, cfg=_cfg)
@@ -55,7 +58,8 @@ def __init__(self,
self.states = None
- def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) -> None:
+ def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) -> \
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]:
"""Execute a training iteration
This method executes the following steps once:
@@ -74,6 +78,9 @@ def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
:param timesteps: Total number of timesteps (default: ``None``).
If None, the total number of timesteps is obtained from the trainer's config
:type timesteps: int, optional
+
+ :return: Observation, reward, terminated, truncated, info
+ :rtype: tuple of torch.Tensor and any other info
"""
if timestep is None:
self._timestep += 1
@@ -81,7 +88,7 @@ def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
timesteps = self.timesteps if timesteps is None else timesteps
if self._progress is None:
- self._progress = tqdm.tqdm(total=timesteps, disable=self.disable_progressbar)
+ self._progress = tqdm.tqdm(total=timesteps, disable=self.disable_progressbar, file=sys.stdout)
self._progress.update(n=1)
# set running mode
@@ -162,7 +169,10 @@ def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
else:
self.states = next_states
- def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) -> None:
+ return next_states, rewards, terminated, truncated, infos
+
+ def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) -> \
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]:
"""Evaluate the agents sequentially
This method executes the following steps in loop:
@@ -178,6 +188,9 @@ def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
:param timesteps: Total number of timesteps (default: ``None``).
If None, the total number of timesteps is obtained from the trainer's config
:type timesteps: int, optional
+
+ :return: Observation, reward, terminated, truncated, info
+ :rtype: tuple of torch.Tensor and any other info
"""
if timestep is None:
self._timestep += 1
@@ -185,7 +198,7 @@ def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
timesteps = self.timesteps if timesteps is None else timesteps
if self._progress is None:
- self._progress = tqdm.tqdm(total=timesteps, disable=self.disable_progressbar)
+ self._progress = tqdm.tqdm(total=timesteps, disable=self.disable_progressbar, file=sys.stdout)
self._progress.update(n=1)
# set running mode
@@ -248,3 +261,5 @@ def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
self.states, infos = self.env.reset()
else:
self.states = next_states
+
+ return next_states, rewards, terminated, truncated, infos
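
To close, a hedged sketch of how the PyTorch `StepTrainer.eval()` return values might be used, for example to accumulate per-environment rewards outside the trainer. The `env` and `agent` objects are assumed to exist; only the `skrl.trainers.torch` import is taken from the patch above.

```python
# Hypothetical usage sketch, not part of the patch: stepping evaluation
# manually and aggregating the rewards returned by each eval() call.
import torch

from skrl.trainers.torch import StepTrainer

trainer = StepTrainer(env=env, agents=agent, cfg={"timesteps": 500})  # env/agent assumed

total_rewards = None
for timestep in range(500):
    # eval() now returns (observation, reward, terminated, truncated, info) as torch tensors
    states, rewards, terminated, truncated, infos = trainer.eval(timestep=timestep)
    total_rewards = rewards.clone() if total_rewards is None else total_rewards + rewards

print("mean accumulated reward:", total_rewards.mean().item())
```
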