Commit
Merge remote-tracking branch 'upstream/main' into tensorclass-losses
SandishKumarHN committed Feb 28, 2024
2 parents 582c9c5 + aebc6a2 commit 79d8a29
Showing 30 changed files with 1,302 additions and 1,477 deletions.
2 changes: 2 additions & 0 deletions test/smoke_test.py
@@ -14,3 +14,5 @@ def test_imports():
from torchrl.envs.gym_like import GymLikeEnv # noqa: F401
from torchrl.modules import SafeModule # noqa: F401
from torchrl.objectives.common import LossModule # noqa: F401

PrioritizedReplayBuffer(alpha=1.1, beta=1.1)
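
As a hedged illustration of what this new smoke check exercises — constructing a prioritized buffer and pushing/sampling a few items — a minimal sketch (the tensors, loop count and batch size below are illustrative assumptions, not part of the commit):

import torch
from torchrl.data import PrioritizedReplayBuffer

# Same constructor arguments as the smoke check above.
rb = PrioritizedReplayBuffer(alpha=1.1, beta=1.1)
for _ in range(4):
    rb.add(torch.randn(3))       # the default list storage accepts plain tensors
batch = rb.sample(batch_size=2)  # prioritized sampling over the stored items
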
534 changes: 453 additions & 81 deletions test/test_cost.py

Large diffs are not rendered by default.

17 changes: 17 additions & 0 deletions test/test_env.py
@@ -2609,6 +2609,23 @@ def forward(self, values):
env.rollout(10, policy)


def test_parallel_another_ctx():
from torch import multiprocessing as mp

sm = mp.get_start_method()
if sm == "spawn":
other_sm = "fork"
else:
other_sm = "spawn"
env = ParallelEnv(2, ContinuousActionVecMockEnv, mp_start_method=other_sm)
try:
assert env.rollout(3) is not None
assert env._workers[0]._start_method == other_sm
finally:
env.close()
del env


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
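
The new test above pins ParallelEnv workers to a multiprocessing start method other than the ambient one. A minimal usage sketch, assuming a Gym-backed environment maker (GymEnv("Pendulum-v1") is an illustrative choice, not taken from this commit):

from torchrl.envs import GymEnv, ParallelEnv

# Request a specific start method for the worker processes.
env = ParallelEnv(2, lambda: GymEnv("Pendulum-v1"), mp_start_method="spawn")
try:
    rollout = env.rollout(3)  # workers are started with the requested method
finally:
    env.close()
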
141 changes: 96 additions & 45 deletions test/test_modules.py
@@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse

from numbers import Number

import numpy as np
@@ -859,7 +860,7 @@ def _get_mock_input_td(
@pytest.mark.parametrize("share_params", [True, False])
@pytest.mark.parametrize("centralised", [True, False])
@pytest.mark.parametrize("n_agent_inputs", [6, None])
@pytest.mark.parametrize("batch", [(10,), (10, 3), ()])
@pytest.mark.parametrize("batch", [(4,), (4, 3), ()])
def test_multiagent_mlp(
self,
n_agents,
Expand Down Expand Up @@ -923,7 +924,7 @@ def test_multiagent_mlp_lazy(self):
share_params=False,
depth=2,
)
optim = torch.optim.Adam(mlp.parameters())
optim = torch.optim.SGD(mlp.parameters())
for p in mlp.parameters():
if isinstance(p, torch.nn.parameter.UninitializedParameter):
break
@@ -938,6 +939,11 @@ def test_multiagent_mlp_lazy(self):
td = self._get_mock_input_td(3, 4, batch=(10,))
obs = td.get(("agents", "observation"))
out = mlp(obs)
assert (
not mlp.params[0]
.apply(lambda x, y: torch.isclose(x, y), mlp.params[1])
.any()
)
out.mean().backward()
optim.step()
for p in mlp.parameters():
@@ -947,27 +953,58 @@
if isinstance(p, torch.nn.parameter.UninitializedParameter):
raise AssertionError("UninitializedParameter found")

@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize("share_params", [True, False])
@pytest.mark.parametrize("centralised", [True, False])
def test_multiagent_reset_mlp(
self,
n_agents,
centralised,
share_params,
):
actor_net = MultiAgentMLP(
n_agent_inputs=4,
n_agent_outputs=6,
num_cells=(4, 4),
n_agents=n_agents,
centralised=centralised,
share_params=share_params,
)
params_before = actor_net.params.clone()
actor_net.reset_parameters()
params_after = actor_net.params
assert not params_before.apply(
lambda x, y: torch.isclose(x, y), params_after, batch_size=[]
).any()
if params_after.numel() > 1:
assert (
not params_after[0]
.apply(lambda x, y: torch.isclose(x, y), params_after[1], batch_size=[])
.any()
)

@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize("share_params", [True, False])
@pytest.mark.parametrize("centralised", [True, False])
@pytest.mark.parametrize("channels", [3, None])
@pytest.mark.parametrize("batch", [(10,), (10, 3), ()])
@pytest.mark.parametrize("batch", [(4,), (4, 3), ()])
def test_multiagent_cnn(
self,
n_agents,
centralised,
share_params,
batch,
channels,
x=50,
y=50,
x=15,
y=15,
):
torch.manual_seed(0)
cnn = MultiAgentConvNet(
n_agents=n_agents,
centralised=centralised,
share_params=share_params,
in_features=channels,
kernel_sizes=3,
)
if channels is None:
channels = 3
@@ -983,21 +1020,20 @@ def test_multiagent_cnn(
obs = td[("agents", "observation")]
out = cnn(obs)
assert out.shape[:-1] == (*batch, n_agents)
for i in range(n_agents):
if centralised and share_params:
assert torch.allclose(out[..., i, :], out[..., 0, :])
else:
if centralised and share_params:
torch.testing.assert_close(out, out[..., :1, :].expand_as(out))
else:
for i in range(n_agents):
for j in range(i + 1, n_agents):
assert not torch.allclose(out[..., i, :], out[..., j, :])

obs[..., 0, 0, 0, 0] += 1
out2 = cnn(obs)
for i in range(n_agents):
if centralised:
# a modification to the input of agent 0 will impact all agents
assert not torch.allclose(out[..., i, :], out2[..., i, :])
elif i > 0:
assert torch.allclose(out[..., i, :], out2[..., i, :])
if centralised:
# a modification to the input of agent 0 will impact all agents
assert not torch.isclose(out, out2).all()
elif n_agents > 1:
assert not torch.isclose(out[..., 0, :], out2[..., 0, :]).all()
torch.testing.assert_close(out[..., 1:, :], out2[..., 1:, :])

obs = torch.randn(*batch, 1, channels, x, y).expand(
*batch, n_agents, channels, x, y
@@ -1013,13 +1049,16 @@ def test_multiagent_cnn(
assert not torch.allclose(out[..., i, :], out[..., j, :])

def test_multiagent_cnn_lazy(self):
n_agents = 5
n_channels = 3
cnn = MultiAgentConvNet(
n_agents=5,
n_agents=n_agents,
centralised=False,
share_params=False,
in_features=None,
kernel_sizes=3,
)
optim = torch.optim.Adam(cnn.parameters())
optim = torch.optim.SGD(cnn.parameters())
for p in cnn.parameters():
if isinstance(p, torch.nn.parameter.UninitializedParameter):
break
@@ -1034,14 +1073,19 @@ def test_multiagent_cnn_lazy(self):
td = TensorDict(
{
"agents": TensorDict(
{"observation": torch.randn(10, 5, 3, 50, 50)},
[10, 5],
{"observation": torch.randn(4, n_agents, n_channels, 15, 15)},
[4, 5],
)
},
batch_size=[10],
batch_size=[4],
)
obs = td[("agents", "observation")]
out = cnn(obs)
assert (
not cnn.params[0]
.apply(lambda x, y: torch.isclose(x, y), cnn.params[1])
.any()
)
out.mean().backward()
optim.step()
for p in cnn.parameters():
@@ -1052,17 +1096,36 @@ def test_multiagent_cnn_lazy(self):
raise AssertionError("UninitializedParameter found")

@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize(
"batch",
[
(10,),
(
10,
3,
),
(),
],
)
@pytest.mark.parametrize("share_params", [True, False])
@pytest.mark.parametrize("centralised", [True, False])
def test_multiagent_reset_cnn(
self,
n_agents,
centralised,
share_params,
):
actor_net = MultiAgentConvNet(
in_features=4,
num_cells=[5, 5],
n_agents=n_agents,
centralised=centralised,
share_params=share_params,
)
params_before = actor_net.params.clone()
actor_net.reset_parameters()
params_after = actor_net.params
assert not params_before.apply(
lambda x, y: torch.isclose(x, y), params_after, batch_size=[]
).any()
if params_after.numel() > 1:
assert (
not params_after[0]
.apply(lambda x, y: torch.isclose(x, y), params_after[1], batch_size=[])
.any()
)

@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize("batch", [(10,), (10, 3), ()])
def test_vdn(self, n_agents, batch):
torch.manual_seed(0)
mixer = VDNMixer(n_agents=n_agents, device="cpu")
@@ -1075,17 +1138,7 @@ def test_vdn(self, n_agents, batch):
assert torch.equal(obs.sum(-2), out)

@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize(
"batch",
[
(10,),
(
10,
3,
),
(),
],
)
@pytest.mark.parametrize("batch", [(10,), (10, 3), ()])
@pytest.mark.parametrize("state_shape", [(64, 64, 3), (10,)])
def test_qmix(self, n_agents, batch, state_shape):
torch.manual_seed(0)
@@ -1271,7 +1324,6 @@ def test_onlinedtactor(self, batch_dims, T=5):
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("bias", [True, False])
def test_python_lstm_cell(device, bias):

lstm_cell1 = LSTMCell(10, 20, device=device, bias=bias)
lstm_cell2 = nn.LSTMCell(10, 20, device=device, bias=bias)

@@ -1307,7 +1359,6 @@ def test_python_lstm_cell(device, bias):
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("bias", [True, False])
def test_python_gru_cell(device, bias):

gru_cell1 = GRUCell(10, 20, device=device, bias=bias)
gru_cell2 = nn.GRUCell(10, 20, device=device, bias=bias)

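
Outside the test harness, the reset_parameters behaviour that the new test_multiagent_reset_mlp test covers boils down to the following sketch (sizes mirror the test; the final check is illustrative, not a guaranteed API contract):

import torch
from torchrl.modules import MultiAgentMLP

net = MultiAgentMLP(
    n_agent_inputs=4,
    n_agent_outputs=6,
    num_cells=(4, 4),
    n_agents=3,
    centralised=False,
    share_params=False,
)
before = net.params.clone()
net.reset_parameters()
after = net.params
# After a reset, no parameter leaf should remain element-wise identical to its old value.
assert not before.apply(lambda x, y: torch.isclose(x, y), after, batch_size=[]).any()
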
40 changes: 38 additions & 2 deletions test/test_rb.py
@@ -32,7 +32,9 @@
from torch.utils._pytree import tree_flatten, tree_map

from torchrl.collectors import RandomPolicy, SyncDataCollector
from torchrl.collectors.utils import split_trajectories
from torchrl.data import (
MultiStep,
PrioritizedReplayBuffer,
RemoteTensorDictReplayBuffer,
ReplayBuffer,
@@ -2559,6 +2561,13 @@ def test_rb_multidim(self, datatype, datadim, rbtype, storage_cls):
rb = rbtype(storage=storage_cls(100, ndim=datadim), batch_size=4)
rb.extend(data)
assert len(rb) == 12
data = rb[:]
if datatype in ("tensordict", "tensorclass"):
assert data.numel() == 12
else:
assert all(
leaf.shape[:datadim].numel() == 12 for leaf in tree_flatten(data)[0]
)
s = rb.sample()
if datatype in ("tensordict", "tensorclass"):
assert (s.exclude("index") == 1).all()
@@ -2600,7 +2609,19 @@ def test_rb_multidim(self, datatype, datadim, rbtype, storage_cls):
),
],
)
def test_rb_multidim_collector(self, rbtype, storage_cls, writer_cls, sampler_cls):
@pytest.mark.parametrize(
"transform",
[
None,
[
lambda: split_trajectories,
functools.partial(MultiStep, gamma=0.9, n_steps=3),
],
],
)
def test_rb_multidim_collector(
self, rbtype, storage_cls, writer_cls, sampler_cls, transform
):
from _utils_internal import CARTPOLE_VERSIONED

torch.manual_seed(0)
@@ -2625,9 +2646,24 @@ def test_rb_multidim_collector(self, rbtype, storage_cls, writer_cls, sampler_cl
sampler=sampler_cls(),
writer=writer_cls(),
)
if not isinstance(rb._sampler, SliceSampler) and transform is not None:
pytest.skip("no need to test this combination")
if transform:
for t in transform:
rb.append_transform(t())
for data in collector:
rb.extend(data)
rb.sample()
if isinstance(rb, TensorDictReplayBuffer) and transform is not None:
# this should fail because we can't set the indices after executing the transform.
with pytest.raises(RuntimeError, match="Failed to set the metadata"):
rb.sample()
return
s = rb.sample()
rbtot = rb[:]
assert rbtot.shape[0] == 2
assert len(rb) == rbtot.numel()
if transform is not None:
assert s.ndim == 2


if __name__ == "__main__":
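
The multi-dimensional storage path exercised by test_rb_multidim can be sketched as follows; LazyTensorStorage and the (4, 3) batch shape are assumptions for illustration, with ndim=2 telling the storage to treat the two leading dimensions as batch dimensions:

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, ReplayBuffer

data = TensorDict({"obs": torch.ones(4, 3, 2)}, batch_size=[4, 3])
rb = ReplayBuffer(storage=LazyTensorStorage(100, ndim=2), batch_size=4)
rb.extend(data)
assert len(rb) == 12  # 4 x 3 elements once the two batch dims are flattened
sample = rb.sample()  # draws batch_size=4 elements
whole = rb[:]         # indexing the buffer returns the stored data
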