[Doc] Adding recurrent policies to export tutorial
ghstack-source-id: 1f1af399b120db8bbb1789748641f44fd3b1bd5e
Pull Request resolved: #2559
vmoens committed Nov 13, 2024
1 parent c0187a9 commit 7051238
Showing 6 changed files with 328 additions and 47 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
@@ -93,7 +93,7 @@ jobs:
cd ./docs
# timeout 7m bash -ic "MUJOCO_GL=egl sphinx-build ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi
# bash -ic "PYOPENGL_PLATFORM=egl MUJOCO_GL=egl sphinx-build ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi
PYOPENGL_PLATFORM=egl MUJOCO_GL=egl sphinx-build ./source _local_build
PYOPENGL_PLATFORM=egl MUJOCO_GL=egl TORCHRL_CONSOLE_STREAM=stdout sphinx-build ./source _local_build
cd ..
cp -r docs/_local_build/* "${RUNNER_ARTIFACT_DIR}"
3 changes: 3 additions & 0 deletions docs/source/conf.py
@@ -49,6 +49,8 @@
version = f"main ({torchrl.__version__})"
release = "main"

os.environ["TORCHRL_CONSOLE_STREAM"] = "stdout"

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
@@ -95,6 +97,7 @@
"abort_on_example_error": False,
"only_warn_on_example_error": True,
"show_memory": True,
"capture_repr": ("_repr_html_", "__repr__"), # capture representations
}

napoleon_use_ivar = True
65 changes: 65 additions & 0 deletions test/test_tensordictmodules.py
@@ -8,6 +8,8 @@

import pytest
import torch

import torchrl.modules
from tensordict import LazyStackedTensorDict, pad, TensorDict, unravel_key_list
from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential
from torch import nn
@@ -743,6 +745,41 @@ def test_set_temporal_mode(self):
lstm_module.parameters()
)

def test_python_cudnn(self):
lstm_module = LSTMModule(
input_size=3,
hidden_size=12,
batch_first=True,
dropout=0,
num_layers=2,
in_keys=["observation", "hidden0", "hidden1"],
out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
).set_recurrent_mode(True)
obs = torch.rand(10, 20, 3)

hidden0 = torch.rand(10, 20, 2, 12)
hidden1 = torch.rand(10, 20, 2, 12)

is_init = torch.zeros(10, 20, dtype=torch.bool)
assert isinstance(lstm_module.lstm, nn.LSTM)
outs_ref = lstm_module(
observation=obs, hidden0=hidden0, hidden1=hidden1, is_init=is_init
)

lstm_module.make_python_based()
assert isinstance(lstm_module.lstm, torchrl.modules.LSTM)
outs_rl = lstm_module(
observation=obs, hidden0=hidden0, hidden1=hidden1, is_init=is_init
)
torch.testing.assert_close(outs_ref, outs_rl)

lstm_module.make_cudnn_based()
assert isinstance(lstm_module.lstm, nn.LSTM)
outs_cudnn = lstm_module(
observation=obs, hidden0=hidden0, hidden1=hidden1, is_init=is_init
)
torch.testing.assert_close(outs_ref, outs_cudnn)

def test_noncontiguous(self):
lstm_module = LSTMModule(
input_size=3,
@@ -1088,6 +1125,34 @@ def test_set_temporal_mode(self):
gru_module.parameters()
)

def test_python_cudnn(self):
gru_module = GRUModule(
input_size=3,
hidden_size=12,
batch_first=True,
dropout=0,
num_layers=2,
in_keys=["observation", "hidden0"],
out_keys=["intermediate", ("next", "hidden0")],
).set_recurrent_mode(True)
obs = torch.rand(10, 20, 3)

hidden0 = torch.rand(10, 20, 2, 12)

is_init = torch.zeros(10, 20, dtype=torch.bool)
assert isinstance(gru_module.gru, nn.GRU)
outs_ref = gru_module(observation=obs, hidden0=hidden0, is_init=is_init)

gru_module.make_python_based()
assert isinstance(gru_module.gru, torchrl.modules.GRU)
outs_rl = gru_module(observation=obs, hidden0=hidden0, is_init=is_init)
torch.testing.assert_close(outs_ref, outs_rl)

gru_module.make_cudnn_based()
assert isinstance(gru_module.gru, nn.GRU)
outs_cudnn = gru_module(observation=obs, hidden0=hidden0, is_init=is_init)
torch.testing.assert_close(outs_ref, outs_cudnn)

def test_noncontiguous(self):
gru_module = GRUModule(
input_size=3,
24 changes: 21 additions & 3 deletions torchrl/_utils.py
@@ -40,7 +40,17 @@
# Remove all attached handlers
while logger.hasHandlers():
logger.removeHandler(logger.handlers[0])
console_handler = logging.StreamHandler()
stream_handlers = {
"stdout": sys.stdout,
"stderr": sys.stderr,
}
TORCHRL_CONSOLE_STREAM = os.getenv("TORCHRL_CONSOLE_STREAM")
if TORCHRL_CONSOLE_STREAM:
stream_handler = stream_handlers[TORCHRL_CONSOLE_STREAM]
else:
stream_handler = None
console_handler = logging.StreamHandler(stream=stream_handler)

console_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s [%(name)s][%(levelname)s] %(message)s")
console_handler.setFormatter(formatter)
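
A usage sketch (not part of this commit's diff): the stream is chosen once at module import, so TORCHRL_CONSOLE_STREAM has to be set before torchrl is first imported, which is exactly what the docs/source/conf.py change above does. When the variable is unset, logging.StreamHandler falls back to stderr.

# Route torchrl's logger to stdout instead of the default stderr.
import os
os.environ["TORCHRL_CONSOLE_STREAM"] = "stdout"  # read once, before importing torchrl

from torchrl._utils import logger  # noqa: E402

logger.info("this message is written to stdout")
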
@@ -86,17 +96,25 @@ def __exit__(self, exc_type, exc_val, exc_tb):
val[2] = N

@staticmethod
def print(prefix=None): # noqa: T202
def print(prefix=None) -> str: # noqa: T202
"""Prints the state of the timer.
Returns:
the string printed using the logger.
"""
keys = list(timeit._REG)
keys.sort()
string = []
for name in keys:
strings = []
if prefix:
strings.append(prefix)
strings.append(
f"{name} took {timeit._REG[name][0] * 1000:4.4} msec (total = {timeit._REG[name][1]} sec)"
)
logger.info(" -- ".join(strings))
string.append(" -- ".join(strings))
logger.info(string[-1])
return "\n".join(string)

@classmethod
def todict(cls, percall=True):
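
A small sketch (not part of the diff), assuming the usual context-manager usage of timeit: print() keeps logging every entry through the logger, and now additionally returns the joined report so callers can capture it.

from torchrl._utils import timeit

with timeit("rollout"):
    collect_data()  # hypothetical placeholder for the timed code path

report = timeit.print(prefix="train")  # each line still goes through logger.info
print(report)                          # the same text is now returned as a string
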
108 changes: 107 additions & 1 deletion torchrl/modules/tensordict_module/rnn.py
@@ -12,7 +12,7 @@

from tensordict.base import NO_DEFAULT

from tensordict.nn import TensorDictModuleBase as ModuleBase
from tensordict.nn import dispatch, TensorDictModuleBase as ModuleBase
from tensordict.utils import expand_as_right, prod, set_lazy_legacy

from torch import nn, Tensor
@@ -467,6 +467,8 @@ def __init__(
raise ValueError("The input lstm must have batch_first=True.")
if bidirectional:
raise ValueError("The input lstm cannot be bidirectional.")
if not hidden_size:
raise ValueError("hidden_size must be passed.")
if python_based:
lstm = LSTM(
input_size=input_size,
@@ -524,6 +526,58 @@ def __init__(
self.out_keys = out_keys
self._recurrent_mode = False

def make_python_based(self) -> LSTMModule:
"""Transforms the LSTM layer in its python-based version.
Returns:
self
"""
if isinstance(self.lstm, LSTM):
return self
lstm = LSTM(
input_size=self.lstm.input_size,
hidden_size=self.lstm.hidden_size,
num_layers=self.lstm.num_layers,
bias=self.lstm.bias,
dropout=self.lstm.dropout,
proj_size=self.lstm.proj_size,
device="meta",
batch_first=self.lstm.batch_first,
bidirectional=self.lstm.bidirectional,
)
from tensordict import from_module

from_module(self.lstm).to_module(lstm)
self.lstm = lstm
return self

def make_cudnn_based(self) -> LSTMModule:
"""Transforms the LSTM layer in its CuDNN-based version.
Returns:
self
"""
if isinstance(self.lstm, nn.LSTM):
return self
lstm = nn.LSTM(
input_size=self.lstm.input_size,
hidden_size=self.lstm.hidden_size,
num_layers=self.lstm.num_layers,
bias=self.lstm.bias,
dropout=self.lstm.dropout,
proj_size=self.lstm.proj_size,
device="meta",
batch_first=self.lstm.batch_first,
bidirectional=self.lstm.bidirectional,
)
from tensordict import from_module

from_module(self.lstm).to_module(lstm)
self.lstm = lstm
return self

def make_tensordict_primer(self):
"""Makes a tensordict primer for the environment.
@@ -644,6 +698,7 @@ def set_recurrent_mode(self, mode: bool = True):
out._recurrent_mode = mode
return out

@dispatch
def forward(self, tensordict: TensorDictBase):
# we want to get an error if the value input is missing, but not the hidden states
defaults = [NO_DEFAULT, None, None]
@@ -1273,6 +1328,56 @@ def __init__(
self.out_keys = out_keys
self._recurrent_mode = False

def make_python_based(self) -> GRUModule:
"""Transforms the GRU layer in its python-based version.
Returns:
self
"""
if isinstance(self.gru, GRU):
return self
gru = GRU(
input_size=self.gru.input_size,
hidden_size=self.gru.hidden_size,
num_layers=self.gru.num_layers,
bias=self.gru.bias,
dropout=self.gru.dropout,
device="meta",
batch_first=self.gru.batch_first,
bidirectional=self.gru.bidirectional,
)
from tensordict import from_module

from_module(self.gru).to_module(gru)
self.gru = gru
return self

def make_cudnn_based(self) -> GRUModule:
"""Transforms the GRU layer in its CuDNN-based version.
Returns:
self
"""
if isinstance(self.gru, nn.GRU):
return self
gru = nn.GRU(
input_size=self.gru.input_size,
hidden_size=self.gru.hidden_size,
num_layers=self.gru.num_layers,
bias=self.gru.bias,
dropout=self.gru.dropout,
device="meta",
batch_first=self.gru.batch_first,
bidirectional=self.gru.bidirectional,
)
from tensordict import from_module

from_module(self.gru).to_module(gru)
self.gru = gru
return self

def make_tensordict_primer(self):
"""Makes a tensordict primer for the environment.
@@ -1389,6 +1494,7 @@ def set_recurrent_mode(self, mode: bool = True):
out._recurrent_mode = mode
return out

@dispatch
@set_lazy_legacy(False)
def forward(self, tensordict: TensorDictBase):
# we want to get an error if the value input is missing, but not the hidden states
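
A recap sketch, mirroring test_python_cudnn above rather than the tutorial itself: both LSTMModule and GRUModule can now be swapped in place between the cuDNN-backed nn.LSTM/nn.GRU and the Python-based torchrl.modules.LSTM/GRU, and the @dispatch decorator lets forward be called with plain keyword tensors. The conversion reuses the existing weights, since the replacement cell is created on the "meta" device and then populated via from_module(...).to_module(...).

import torch
from torch import nn

import torchrl.modules
from torchrl.modules import LSTMModule

policy = LSTMModule(
    input_size=3,
    hidden_size=12,
    num_layers=2,
    dropout=0,
    batch_first=True,
    in_keys=["observation", "hidden0", "hidden1"],
    out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
).set_recurrent_mode(True)

policy.make_python_based()                      # nn.LSTM -> torchrl.modules.LSTM
assert isinstance(policy.lstm, torchrl.modules.LSTM)

# Thanks to @dispatch, the module accepts keyword tensors instead of a TensorDict.
obs = torch.rand(10, 20, 3)
hidden0 = torch.rand(10, 20, 2, 12)
hidden1 = torch.rand(10, 20, 2, 12)
is_init = torch.zeros(10, 20, dtype=torch.bool)
out = policy(observation=obs, hidden0=hidden0, hidden1=hidden1, is_init=is_init)

policy.make_cudnn_based()                       # back to the cuDNN-backed nn.LSTM
assert isinstance(policy.lstm, nn.LSTM)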