Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jun 18, 2024
1 parent 748b2ba commit 7c0a7c8
Show file tree
Hide file tree
Showing 12 changed files with 24 additions and 33 deletions.
8 changes: 2 additions & 6 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from scipy.stats import chisquare
from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase
from tensordict.utils import _unravel_key_to_tuple

from torchrl._utils import _make_ordinal_device
from torchrl.data.tensor_specs import (
_keys_to_empty_composite_spec,
BinaryDiscreteTensorSpec,
Expand All @@ -29,11 +29,7 @@
UnboundedContinuousTensorSpec,
UnboundedDiscreteTensorSpec,
)
from torchrl.data.utils import (
_make_ordinal_device,
check_no_exclusive_keys,
consolidate_spec,
)
from torchrl.data.utils import check_no_exclusive_keys, consolidate_spec


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None])
Expand Down
8 changes: 8 additions & 0 deletions torchrl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,3 +778,11 @@ def _can_be_pickled(obj):
return True
except (pickle.PickleError, AttributeError, TypeError):
return False


def _make_ordinal_device(device: torch.device):
if device is None:
return device
if device.type == "cuda" and device.index is None:
return torch.device("cuda", index=torch.cuda.current_device())
return device
2 changes: 1 addition & 1 deletion torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
)
from torchrl.collectors.utils import split_trajectories
from torchrl.data.tensor_specs import TensorSpec
from torchrl.data.utils import _make_ordinal_device, CloudpickleWrapper, DEVICE_TYPING
from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
from torchrl.envs.common import _do_nothing, EnvBase
from torchrl.envs.transforms import StepCounter, TransformedEnv
from torchrl.envs.utils import (
Expand Down
4 changes: 2 additions & 2 deletions torchrl/data/replay_buffers/replay_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from tensordict.utils import expand_as_right, expand_right
from torch import Tensor

from torchrl._utils import accept_remote_rref_udf_invocation
from torchrl._utils import _make_ordinal_device, accept_remote_rref_udf_invocation
from torchrl.data.replay_buffers.samplers import (
PrioritizedSampler,
RandomSampler,
Expand Down Expand Up @@ -58,7 +58,7 @@
Writer,
WriterEnsemble,
)
from torchrl.data.utils import _make_ordinal_device, DEVICE_TYPING
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs.transforms.transforms import _InvertTransform


Expand Down
3 changes: 1 addition & 2 deletions torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from tensordict.memmap import MemoryMappedTensor
from torch import multiprocessing as mp
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
from torchrl._utils import implement_for, logger as torchrl_logger
from torchrl._utils import _make_ordinal_device, implement_for, logger as torchrl_logger
from torchrl.data.replay_buffers.checkpointers import (
ListStorageCheckpointer,
StorageCheckpointerBase,
Expand All @@ -38,7 +38,6 @@
INT_CLASSES,
tree_iter,
)
from torchrl.data.utils import _make_ordinal_device


class Storage:
Expand Down
3 changes: 1 addition & 2 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@
unravel_key,
)
from tensordict.utils import _getitem_batch_size, NestedKey
from torchrl._utils import get_binary_env_var
from torchrl.data.utils import _make_ordinal_device
from torchrl._utils import _make_ordinal_device, get_binary_env_var

DEVICE_TYPING = Union[torch.device, str, int]

Expand Down
8 changes: 0 additions & 8 deletions torchrl/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,11 +324,3 @@ def _find_action_space(action_space):
f"action_space was not specified/not compatible and could not be retrieved from the value network. Got action_space={action_space}."
)
return action_space


def _make_ordinal_device(device: torch.device):
if device is None:
return device
if device.type == "cuda" and device.index is None:
return torch.device("cuda", index=torch.cuda.current_device())
return device
8 changes: 2 additions & 6 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,13 @@
from torch import multiprocessing as mp
from torchrl._utils import (
_check_for_faulty_process,
_make_ordinal_device,
_ProcessNoWarn,
logger as torchrl_logger,
VERBOSE,
)
from torchrl.data.tensor_specs import CompositeSpec
from torchrl.data.utils import (
_make_ordinal_device,
CloudpickleWrapper,
contains_lazy_spec,
DEVICE_TYPING,
)
from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING
from torchrl.envs.common import _do_nothing, _EnvPostInit, EnvBase, EnvMetaData
from torchrl.envs.env_creator import get_env_metadata

Expand Down
3 changes: 2 additions & 1 deletion torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from tensordict.utils import NestedKey
from torchrl._utils import (
_ends_with,
_make_ordinal_device,
_replace_last,
implement_for,
prod,
Expand All @@ -31,7 +32,7 @@
TensorSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.data.utils import _make_ordinal_device, DEVICE_TYPING
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs.utils import (
_make_compatible_policy,
_repr_by_depth,
Expand Down
4 changes: 2 additions & 2 deletions torchrl/envs/libs/habitat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import importlib.util

import torch

from torchrl.data.utils import _make_ordinal_device, DEVICE_TYPING
from torchrl._utils import _make_ordinal_device
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs.common import EnvBase
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
from torchrl.envs.utils import _classproperty
Expand Down
3 changes: 1 addition & 2 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from torch import nn, Tensor
from torch.utils._pytree import tree_map

from torchrl._utils import _append_last, _ends_with, _replace_last
from torchrl._utils import _append_last, _ends_with, _make_ordinal_device, _replace_last

from torchrl.data.tensor_specs import (
BinaryDiscreteTensorSpec,
Expand All @@ -58,7 +58,6 @@
TensorSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.data.utils import _make_ordinal_device
from torchrl.envs.common import _do_nothing, _EnvPostInit, EnvBase, make_tensordict
from torchrl.envs.transforms import functional as F
from torchrl.envs.transforms.utils import (
Expand Down
3 changes: 2 additions & 1 deletion torchrl/trainers/helpers/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
from typing import Optional

import torch
from torchrl._utils import _make_ordinal_device

from torchrl.data.replay_buffers.replay_buffers import (
ReplayBuffer,
TensorDictReplayBuffer,
)
from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler
from torchrl.data.replay_buffers.storages import LazyMemmapStorage
from torchrl.data.utils import _make_ordinal_device, DEVICE_TYPING
from torchrl.data.utils import DEVICE_TYPING


def make_replay_buffer(
Expand Down

0 comments on commit 7c0a7c8

Please sign in to comment.