From a151923d5e1eab68a8256fa02d0946a0f57bd5d0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Jul 2024 16:14:30 +0100 Subject: [PATCH] Revert "[BugFix] Fix tensordict private imports" (#2276) --- test/test_cost.py | 2 +- torchrl/envs/batched_envs.py | 2 +- torchrl/envs/transforms/transforms.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 36a81a6906d..76fc4e651f4 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -12,7 +12,7 @@ from dataclasses import asdict, dataclass from packaging import version as pack_version -from tensordict._C import unravel_keys +from tensordict._tensordict import unravel_keys from tensordict.nn import ( InteractionType, diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index dab0c52cdf7..7f462782757 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -27,7 +27,7 @@ TensorDict, TensorDictBase, ) -from tensordict._C import unravel_key +from tensordict._tensordict import unravel_key from torch import multiprocessing as mp from torchrl._utils import ( _check_for_faulty_process, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index eb9cdce923d..bec76c603e6 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -39,7 +39,7 @@ unravel_key, unravel_key_list, ) -from tensordict._C import _unravel_key_to_tuple +from tensordict._tensordict import _unravel_key_to_tuple from tensordict.nn import dispatch, TensorDictModuleBase from tensordict.utils import expand_as_right, expand_right, NestedKey from torch import nn, Tensor