[BugFix] Fix failing tests
ghstack-source-id: a43a2e3dbf76cd63c57ae00028df04b41a4e2f2b
Pull Request resolved: #2582
vmoens committed Nov 19, 2024
1 parent 408cf7d commit 863121a
Showing 20 changed files with 248 additions and 117 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/docs.yml
@@ -119,6 +119,8 @@ jobs:
           REF_TYPE=${{ github.ref_type }}
           REF_NAME=${{ github.ref_name }}
+          apt-get update
+          apt-get install rsync -y
           if [[ "${REF_TYPE}" == branch ]]; then
            if [[ "${REF_NAME}" == main ]]; then
6 changes: 4 additions & 2 deletions sota-implementations/ddpg/utils.py
@@ -234,7 +234,8 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
             OrnsteinUhlenbeckProcessModule(
                 spec=action_spec,
                 annealing_num_steps=1_000_000,
-            ).to(device),
+                device=device,
+            ),
         )
     elif cfg.network.noise_type == "gaussian":
         actor_model_explore = TensorDictSequential(
@@ -245,7 +246,8 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
                 sigma_init=1.0,
                 mean=0.0,
                 std=0.1,
-            ).to(device),
+                device=device,
+            ),
         )
     else:
         raise NotImplementedError
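Most of the hunks in this commit follow one pattern: pass device= to the exploration module's constructor instead of building it on the default device and moving it afterwards with .to(device). A minimal sketch of the pattern — the spec, actor, and sizes below are stand-ins, and the device keyword itself is taken on trust from the hunks above:

import torch
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.data import Bounded
from torchrl.modules import OrnsteinUhlenbeckProcessModule

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
action_spec = Bounded(low=-1, high=1, shape=(6,))
actor = TensorDictModule(
    torch.nn.Linear(4, 6, device=device),
    in_keys=["observation"],
    out_keys=["action"],
)
# Before: OrnsteinUhlenbeckProcessModule(spec=action_spec, ...).to(device)
# After: the module's noise buffers are allocated on `device` up front.
noise = OrnsteinUhlenbeckProcessModule(
    spec=action_spec,
    annealing_num_steps=1_000_000,
    device=device,
)
actor_model_explore = TensorDictSequential(actor, noise)

Constructing directly on the target device avoids allocating the module's state on CPU first and then copying it over.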
1 change: 1 addition & 0 deletions sota-implementations/dreamer/dreamer_utils.py
@@ -275,6 +275,7 @@ def make_dreamer(
             annealing_num_steps=1,
             mean=0.0,
             std=cfg.networks.exploration_noise,
+            device=device,
         ),
     )

1 change: 1 addition & 0 deletions sota-implementations/multiagent/maddpg_iddpg.py
@@ -108,6 +108,7 @@ def train(cfg: "DictConfig"):  # noqa: F821
             spec=env.unbatched_action_spec,
             annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
             action_key=env.action_key,
+            device=cfg.train.device,
         ),
     )

2 changes: 1 addition & 1 deletion sota-implementations/redq/config.yaml
@@ -30,7 +30,7 @@ collector:
   async_collection: 1
   frames_per_batch: 1024
   total_frames: 1_000_000
-  device: cpu
+  device:
   env_per_collector: 1
   init_random_frames: 50_000
   multi_step: 1
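Leaving device: blank appears deliberate: the REDQ helpers in sota-implementations/redq/utils.py below now resolve an empty or missing device at runtime, picking cuda:0 when CUDA is available and cpu otherwise.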
3 changes: 2 additions & 1 deletion sota-implementations/redq/redq.py
@@ -119,7 +119,8 @@ def main(cfg: "DictConfig"):  # noqa: F821
             annealing_num_steps=cfg.exploration.annealing_frames,
             sigma=cfg.exploration.ou_sigma,
             theta=cfg.exploration.ou_theta,
-        ).to(device),
+            device=device,
+        ),
     )
     if device == torch.device("cpu"):
         # mostly for debugging
68 changes: 38 additions & 30 deletions sota-implementations/redq/utils.py
@@ -21,55 +21,59 @@
 from torchrl._utils import logger as torchrl_logger, VERBOSE
 from torchrl.collectors.collectors import DataCollectorBase

-from torchrl.data import ReplayBuffer, TensorDictReplayBuffer
-from torchrl.data.postprocs import MultiStep
-from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler
-from torchrl.data.replay_buffers.storages import LazyMemmapStorage
+from torchrl.data import (
+    LazyMemmapStorage,
+    MultiStep,
+    PrioritizedSampler,
+    RandomSampler,
+    ReplayBuffer,
+    TensorDictReplayBuffer,
+)
 from torchrl.data.utils import DEVICE_TYPING
-from torchrl.envs import ParallelEnv
-from torchrl.envs.common import EnvBase
-from torchrl.envs.env_creator import env_creator, EnvCreator
-from torchrl.envs.libs.dm_control import DMControlEnv
-from torchrl.envs.libs.gym import GymEnv
-from torchrl.envs.transforms import (
+from torchrl.envs import (
     CatFrames,
     CatTensors,
     CenterCrop,
     Compose,
+    DMControlEnv,
     DoubleToFloat,
+    env_creator,
+    EnvBase,
+    EnvCreator,
+    FlattenObservation,
     GrayScale,
+    gSDENoise,
+    GymEnv,
+    InitTracker,
     NoopResetEnv,
     ObservationNorm,
+    ParallelEnv,
     Resize,
     RewardScaling,
+    StepCounter,
     ToTensorImage,
     TransformedEnv,
     VecNorm,
 )
-from torchrl.envs.transforms.transforms import (
-    FlattenObservation,
-    gSDENoise,
-    InitTracker,
-    StepCounter,
-)
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules import (
     ActorCriticOperator,
     ActorValueOperator,
+    DdpgCnnActor,
+    DdpgCnnQNet,
+    MLP,
     NoisyLinear,
     NormalParamExtractor,
+    ProbabilisticActor,
     SafeModule,
     SafeSequential,
+    TanhNormal,
+    ValueOperator,
 )
-from torchrl.modules.distributions import TanhNormal
 from torchrl.modules.distributions.continuous import SafeTanhTransform
 from torchrl.modules.models.exploration import LazygSDEModule
-from torchrl.modules.models.models import DdpgCnnActor, DdpgCnnQNet, MLP
-from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
-from torchrl.objectives import HardUpdate, SoftUpdate
-from torchrl.objectives.common import LossModule
+from torchrl.objectives import HardUpdate, LossModule, SoftUpdate, TargetNetUpdater
 from torchrl.objectives.deprecated import REDQLoss_deprecated
-from torchrl.objectives.utils import TargetNetUpdater
 from torchrl.record.loggers import Logger
 from torchrl.record.recorder import VideoRecorder
 from torchrl.trainers.helpers import sync_async_collector, sync_sync_collector
@@ -518,7 +522,7 @@ def make_redq_model(
         actor_module = SafeSequential(
             actor_module,
             SafeModule(
-                LazygSDEModule(transform=transform),
+                LazygSDEModule(transform=transform, device=device),
                 in_keys=["action", gSDE_state_key, "_eps_gSDE"],
                 out_keys=["loc", "scale", "action", "_eps_gSDE"],
             ),
@@ -606,7 +610,9 @@ def make_transformed_env(**kwargs) -> TransformedEnv:
     categorical_action_encoding = cfg.env.categorical_action_encoding

     if custom_env is None and custom_env_maker is None:
-        if isinstance(cfg.collector.device, str):
+        if cfg.collector.device in ("", None):
+            device = "cpu" if not torch.cuda.is_available() else "cuda:0"
+        elif isinstance(cfg.collector.device, str):
             device = cfg.collector.device
         elif isinstance(cfg.collector.device, Sequence):
             device = cfg.collector.device[0]
@@ -1000,11 +1006,14 @@ def make_collector_offpolicy(
         env_kwargs.update(make_env_kwargs)
     elif make_env_kwargs is not None:
         env_kwargs = make_env_kwargs
-    cfg.collector.device = (
-        cfg.collector.device
-        if len(cfg.collector.device) > 1
-        else cfg.collector.device[0]
-    )
+    if cfg.collector.device in ("", None):
+        cfg.collector.device = "cpu" if not torch.cuda.is_available() else "cuda:0"
+    else:
+        cfg.collector.device = (
+            cfg.collector.device
+            if len(cfg.collector.device) > 1
+            else cfg.collector.device[0]
+        )
     collector_helper_kwargs = {
         "env_fns": make_env,
         "env_kwargs": env_kwargs,
@@ -1017,7 +1026,6 @@
         # we already took care of building the make_parallel_env function
         "num_collectors": -cfg.num_workers // -cfg.collector.env_per_collector,
         "device": cfg.collector.device,
-        "storing_device": cfg.collector.device,
         "init_random_frames": cfg.collector.init_random_frames,
         "split_trajs": True,
         # trajectories must be separated if multi-step is used
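make_transformed_env and make_collector_offpolicy now share the same fallback for an unset collector device. A standalone sketch of that resolution logic — the helper name resolve_collector_device is hypothetical; the branches mirror the two hunks above:

import torch
from collections.abc import Sequence

def resolve_collector_device(device):
    # Hypothetical helper mirroring the fallback added in this commit.
    if device in ("", None):
        # Empty config entry: pick cuda:0 if available, else cpu.
        return "cuda:0" if torch.cuda.is_available() else "cpu"
    if isinstance(device, str):
        return device
    if isinstance(device, Sequence):
        # Unwrap single-element device lists; keep true multi-device lists.
        return device if len(device) > 1 else device[0]
    return device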
3 changes: 2 additions & 1 deletion sota-implementations/td3/utils.py
@@ -242,7 +242,8 @@ def make_td3_agent(cfg, train_env, eval_env, device):
             mean=0,
             std=0.1,
             spec=action_spec,
-        ).to(device),
+            device=device,
+        ),
     )
     return model, actor_model_explore

3 changes: 2 additions & 1 deletion sota-implementations/td3_bc/utils.py
@@ -183,7 +183,8 @@ def make_td3_agent(cfg, train_env, device):
             mean=0,
             std=0.1,
             spec=action_spec,
-        ).to(device),
+            device=device,
+        ),
     )
     return model, actor_model_explore

4 changes: 2 additions & 2 deletions test/_utils_internal.py
@@ -167,11 +167,11 @@ def get_available_devices():
 def get_default_devices():
     num_cuda = torch.cuda.device_count()
     if num_cuda == 0:
+        if torch.mps.is_available():
+            return [torch.device("mps:0")]
         return [torch.device("cpu")]
     elif num_cuda == 1:
         return [torch.device("cuda:0")]
-    elif torch.mps.is_available():
-        return [torch.device("mps:0")]
     else:
         # then run on all devices
         return get_available_devices()
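With this reordering, MPS is only selected when no CUDA device is present; a machine with a single CUDA device still returns cuda:0 rather than mps:0, and multi-GPU machines continue to test on all available devices.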
2 changes: 1 addition & 1 deletion test/test_actors.py
@@ -47,7 +47,7 @@
         ("data", "sample_log_prob"),
     ],
 )
-def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions=3):
+def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions=1):
     env = NestedCountingEnv(nested_dim=nested_dim)
     action_spec = Bounded(shape=torch.Size((nested_dim, n_actions)), high=1, low=-1)
     policy_module = TensorDictModule(
23 changes: 23 additions & 0 deletions test/test_collector.py
@@ -3172,6 +3172,29 @@ def make_and_test_policy(
     )


+@pytest.mark.parametrize(
+    "ctype", [SyncDataCollector, MultiaSyncDataCollector, MultiSyncDataCollector]
+)
+def test_no_stopiteration(ctype):
+    # Tests that there is no StopIteration raised and that the length of the collector is properly set
+    if ctype is SyncDataCollector:
+        envs = SerialEnv(16, CountingEnv)
+    else:
+        envs = [SerialEnv(8, CountingEnv), SerialEnv(8, CountingEnv)]
+
+    collector = ctype(create_env_fn=envs, frames_per_batch=173, total_frames=300)
+    try:
+        c_iter = iter(collector)
+        assert len(collector) == 2
+        for i in range(len(collector)):  # noqa: B007
+            c = next(c_iter)
+            assert c is not None
+        assert i == 1
+    finally:
+        collector.shutdown()
+        del collector
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
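The len(collector) == 2 assertion is plain ceiling division: with total_frames=300 and frames_per_batch=173, the collector needs two batches to cover the frame budget. A quick check of the arithmetic, assuming (as the test does) that the collector rounds the budget up to whole batches:

total_frames, frames_per_batch = 300, 173
n_batches = -(-total_frames // frames_per_batch)  # ceiling division
assert n_batches == 2  # one full batch of 173 frames, then the remaining 127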
31 changes: 14 additions & 17 deletions test/test_exploration.py
@@ -241,8 +241,8 @@ def test_ou(
         self, device, interface, d_obs=4, d_act=6, batch=32, n_steps=100, seed=0
     ):
         torch.manual_seed(seed)
-        net = nn.Sequential(nn.Linear(d_obs, 2 * d_act), NormalParamExtractor()).to(
-            device
+        net = nn.Sequential(
+            nn.Linear(d_obs, 2 * d_act, device=device), NormalParamExtractor()
         )
         module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
         action_spec = Bounded(-torch.ones(d_act), torch.ones(d_act), (d_act,))
@@ -252,13 +252,13 @@ def test_ou(
             in_keys=["loc", "scale"],
             distribution_class=TanhNormal,
             default_interaction_type=InteractionType.RANDOM,
-        ).to(device)
+        )

         if interface == "module":
-            ou = OrnsteinUhlenbeckProcessModule(spec=action_spec).to(device)
+            ou = OrnsteinUhlenbeckProcessModule(spec=action_spec, device=device)
             exploratory_policy = TensorDictSequential(policy, ou)
         else:
-            exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy)
+            exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy, device=device)
             ou = exploratory_policy

         tensordict = TensorDict(
@@ -338,10 +338,10 @@ def test_collector(self, device, parallel_spec, probabilistic, interface, seed=0

         if interface == "module":
             exploratory_policy = TensorDictSequential(
-                policy, OrnsteinUhlenbeckProcessModule(spec=action_spec).to(device)
+                policy, OrnsteinUhlenbeckProcessModule(spec=action_spec, device=device)
             )
         else:
-            exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy)
+            exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy, device=device)
         exploratory_policy(env.reset())
         collector = SyncDataCollector(
             create_env_fn=env,
@@ -456,10 +456,10 @@ def test_additivegaussian_sd(
             device=device,
         )
         if interface == "module":
-            exploratory_policy = AdditiveGaussianModule(action_spec).to(device)
+            exploratory_policy = AdditiveGaussianModule(action_spec, device=device)
         else:
-            net = nn.Sequential(nn.Linear(d_obs, 2 * d_act), NormalParamExtractor()).to(
-                device
+            net = nn.Sequential(
+                nn.Linear(d_obs, 2 * d_act, device=device), NormalParamExtractor()
             )
             module = SafeModule(
                 net,
@@ -473,10 +473,10 @@ def test_additivegaussian_sd(
                 in_keys=["loc", "scale"],
                 distribution_class=TanhNormal,
                 default_interaction_type=InteractionType.RANDOM,
-            ).to(device)
+            )
             given_spec = action_spec if spec_origin == "spec" else None
-            exploratory_policy = AdditiveGaussianWrapper(policy, spec=given_spec).to(
-                device
+            exploratory_policy = AdditiveGaussianWrapper(
+                policy, spec=given_spec, device=device
             )
             if spec_origin is not None:
                 sigma_init = (
@@ -727,10 +727,7 @@ def test_gsde(
 @pytest.mark.parametrize("std", [1, 2])
 @pytest.mark.parametrize("sigma_init", [None, 1.5, 3])
 @pytest.mark.parametrize("learn_sigma", [False, True])
-@pytest.mark.parametrize(
-    "device",
-    [torch.device("cuda:0") if torch.cuda.device_count() else torch.device("cpu")],
-)
+@pytest.mark.parametrize("device", get_default_devices())
 def test_gsde_init(sigma_init, state_dim, action_dim, mean, std, device, learn_sigma):
     torch.manual_seed(0)
     state = torch.randn(10000, *state_dim, device=device) * std + mean
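Replacing the hard-coded CUDA-or-CPU parametrization with get_default_devices() keeps this test aligned with the device-preference logic updated in test/_utils_internal.py above, including the new MPS fallback.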
3 changes: 3 additions & 0 deletions test/test_transforms.py
@@ -2076,7 +2076,10 @@ def test_transform_rb(self, rbclass):
         ):
             td = rb.sample(10)

+    @retry(AssertionError, tries=10, delay=0)
     def test_collector_match(self):
+        torch.manual_seed(0)
+
         # The counter in the collector should match the one from the transform
         t = TrajCounter()

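The @retry(AssertionError, tries=10, delay=0) decorator re-runs this occasionally flaky test instead of failing on its first assertion error. The decorator's implementation is not shown in this diff; a minimal sketch consistent with that call site might look like:

import functools
import time

def retry(exc_type, tries=3, delay=0):
    # Re-run the wrapped callable up to `tries` times when `exc_type` is raised.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for attempt in range(1, tries + 1):
                try:
                    return fn(*args, **kwargs)
                except exc_type:
                    if attempt == tries:
                        raise
                    time.sleep(delay)
        return wrapper
    return decorator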