From 62d977bc65040e4217bd6005e99c645d1d8a4a5d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 6 Feb 2024 14:39:43 +0000 Subject: [PATCH 1/4] [Refactor] Use filter_empty=True in apply (#1879) --- test/test_cost.py | 2 +- torchrl/collectors/collectors.py | 22 +++++++++++++------ torchrl/envs/batched_envs.py | 31 ++++++++------------------- torchrl/envs/common.py | 2 +- torchrl/envs/transforms/transforms.py | 2 +- torchrl/objectives/cql.py | 28 ++++++++++++++++-------- 6 files changed, 46 insertions(+), 41 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index c6eb27172ee..dae1fa5f70c 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -6100,7 +6100,7 @@ def zero_param(p): if isinstance(p, nn.Parameter): p.data.zero_() - params.apply(zero_param) + params.apply(zero_param, filter_empty=True) # assert len(list(floss_fn.parameters())) == 0 with params.to_module(loss_fn): diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index bea46bb6cd4..202dcc9ead8 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -801,22 +801,30 @@ def check_exclusive(val): "Consider using a placeholder for missing keys." ) - policy_output._fast_apply(check_exclusive, call_on_nested=True) + policy_output._fast_apply( + check_exclusive, call_on_nested=True, filter_empty=True + ) + # Use apply, because it works well with lazy stacks # Edge-case of this approach: the policy may change the values in-place and only by a tiny bit # or occasionally. In these cases, the keys will be missed (we can't detect if the policy has # changed them here). # This will cause a failure to update entries when policy and env device mismatch and # casting is necessary. + def filter_policy(value_output, value_input, value_input_clone): + if ( + (value_input is None) + or (value_output is not value_input) + or ~torch.isclose(value_output, value_input_clone).any() + ): + return value_output + filtered_policy_output = policy_output.apply( - lambda value_output, value_input, value_input_clone: value_output - if (value_input is None) - or (value_output is not value_input) - or ~torch.isclose(value_output, value_input_clone).any() - else None, + filter_policy, policy_input_copy, policy_input_clone, default=None, + filter_empty=True, ) self._policy_output_keys = list( self._policy_output_keys.union( @@ -933,7 +941,7 @@ def cuda_check(tensor: torch.Tensor): if tensor.is_cuda: cuda_devices.add(tensor.device) - self._final_rollout.apply(cuda_check) + self._final_rollout.apply(cuda_check, filter_empty=True) for device in cuda_devices: streams.append(torch.cuda.Stream(device, priority=-1)) events.append(streams[-1].record_event()) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 9669963cb33..cfb977d4bb2 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -419,12 +419,8 @@ def _check_for_empty_spec(specs: CompositeSpec): def map_device(key, value, device_map=device_map): return value.to(device_map[key]) - # self._env_tensordict.named_apply( - # map_device, nested_keys=True, filter_empty=True - # ) self._env_tensordict.named_apply( - map_device, - nested_keys=True, + map_device, nested_keys=True, filter_empty=True ) self._batch_locked = meta_data.batch_locked @@ -792,16 +788,11 @@ def select_and_clone(name, tensor): if name in selected_output_keys: return tensor.clone() - # out = self.shared_tensordict_parent.named_apply( - # select_and_clone, - # nested_keys=True, - # filter_empty=True, - # ) out = 
self.shared_tensordict_parent.named_apply( select_and_clone, nested_keys=True, + filter_empty=True, ) - del out["next"] if out.device != device: if device is None: @@ -842,8 +833,7 @@ def select_and_clone(name, tensor): if name in self._selected_step_keys: return tensor.clone() - # out = next_td.named_apply(select_and_clone, nested_keys=True, filter_empty=True) - out = next_td.named_apply(select_and_clone, nested_keys=True) + out = next_td.named_apply(select_and_clone, nested_keys=True, filter_empty=True) if out.device != device: if device is None: @@ -1059,8 +1049,7 @@ def _start_workers(self) -> None: def look_for_cuda(tensor, has_cuda=has_cuda): has_cuda[0] = has_cuda[0] or tensor.is_cuda - # self.shared_tensordict_parent.apply(look_for_cuda, filter_empty=True) - self.shared_tensordict_parent.apply(look_for_cuda) + self.shared_tensordict_parent.apply(look_for_cuda, filter_empty=True) has_cuda = has_cuda[0] if has_cuda: self.event = torch.cuda.Event() @@ -1182,14 +1171,14 @@ def step_and_maybe_reset( if x.device != device else x.clone(), device=device, - # filter_empty=True, + filter_empty=True, ) tensordict_ = tensordict_._fast_apply( lambda x: x.to(device, non_blocking=True) if x.device != device else x.clone(), device=device, - # filter_empty=True, + filter_empty=True, ) else: next_td = next_td.clone().clear_device_() @@ -1244,7 +1233,7 @@ def select_and_clone(name, tensor): out = next_td.named_apply( select_and_clone, nested_keys=True, - # filter_empty=True, + filter_empty=True, ) if out.device != device: if device is None: @@ -1314,9 +1303,8 @@ def select_and_clone(name, tensor): out = self.shared_tensordict_parent.named_apply( select_and_clone, nested_keys=True, - # filter_empty=True, + filter_empty=True, ) - del out["next"] if out.device != device: if device is None: @@ -1452,8 +1440,7 @@ def _run_worker_pipe_shared_mem( def look_for_cuda(tensor, has_cuda=has_cuda): has_cuda[0] = has_cuda[0] or tensor.is_cuda - # shared_tensordict.apply(look_for_cuda, filter_empty=True) - shared_tensordict.apply(look_for_cuda) + shared_tensordict.apply(look_for_cuda, filter_empty=True) has_cuda = has_cuda[0] else: has_cuda = device.type == "cuda" diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 61cd211b6ae..746cc60f142 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -107,7 +107,7 @@ def metadata_from_env(env) -> EnvMetaData: def fill_device_map(name, val, device_map=device_map): device_map[name] = val.device - tensordict.named_apply(fill_device_map, nested_keys=True) + tensordict.named_apply(fill_device_map, nested_keys=True, filter_empty=True) return EnvMetaData( tensordict, specs, batch_size, env_str, device, batch_locked, device_map ) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index efa59e25c26..b71c6fcffc3 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -3063,7 +3063,7 @@ def __init__(self): super().__init__(in_keys=[]) def _call(self, tensordict: TensorDictBase) -> TensorDictBase: - tensordict.apply(check_finite) + tensordict.apply(check_finite, filter_empty=True) return tensordict def _reset( diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index f963f0e0b52..69a30c7f484 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -577,9 +577,14 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tensor: def _get_policy_actions(self, data, actor_params, num_actions=10): batch_size = data.batch_size batch_size = 
list(batch_size[:-1]) + [batch_size[-1] * num_actions] - tensordict = data.select(*self.actor_network.in_keys).apply( - lambda x: x.repeat_interleave(num_actions, dim=data.ndim - 1), - batch_size=batch_size, + in_keys = [unravel_key(key) for key in self.actor_network.in_keys] + + def filter_and_repeat(name, x): + if name in in_keys: + return x.repeat_interleave(num_actions, dim=data.ndim - 1) + + tensordict = data.named_apply( + filter_and_repeat, batch_size=batch_size, filter_empty=True ) with torch.no_grad(): with set_exploration_type(ExplorationType.RANDOM), actor_params.to_module( @@ -731,13 +736,18 @@ def cql_loss(self, tensordict: TensorDictBase) -> Tensor: batch_size = tensordict_q_random.batch_size batch_size = list(batch_size[:-1]) + [batch_size[-1] * self.num_random] - tensordict_q_random = tensordict_q_random.select( - *self.actor_network.in_keys - ).apply( - lambda x: x.repeat_interleave( - self.num_random, dim=tensordict_q_random.ndim - 1 - ), + in_keys = [unravel_key(key) for key in self.actor_network.in_keys] + + def filter_and_repeat(name, x): + if name in in_keys: + return x.repeat_interleave( + self.num_random, dim=tensordict_q_random.ndim - 1 + ) + + tensordict_q_random = tensordict_q_random.named_apply( + filter_and_repeat, batch_size=batch_size, + filter_empty=True, ) tensordict_q_random.set(self.tensor_keys.action, random_actions_tensor) cql_tensordict = torch.cat( From e53eb738fea2a924903f89d3faf4db7eb096c721 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 6 Feb 2024 17:43:20 +0000 Subject: [PATCH 2/4] [BugFix] Fix _reset data passing in parallel env (#1880) --- test/test_env.py | 23 +++++++++++++++++++++++ torchrl/envs/batched_envs.py | 25 +++++++++++++++++++++++-- 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/test/test_env.py b/test/test_env.py index e316e1ae10f..15bcf5e3fcb 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -624,6 +624,29 @@ def test_parallel_env_with_policy( # env_serial.close() env0.close() + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + @pytest.mark.parametrize("heterogeneous", [False, True]) + def test_transform_env_transform_no_device(self, heterogeneous): + # Tests non-regression on 1865 + def make_env(): + return TransformedEnv( + ContinuousActionVecMockEnv(), StepCounter(max_steps=3) + ) + + if heterogeneous: + make_envs = [EnvCreator(make_env), EnvCreator(make_env)] + else: + make_envs = make_env + penv = ParallelEnv(2, make_envs) + r = penv.rollout(6, break_when_any_done=False) + assert r.shape == (2, 6) + try: + env = TransformedEnv(penv) + r = env.rollout(6, break_when_any_done=False) + assert r.shape == (2, 6) + finally: + penv.close() + @pytest.mark.skipif(not _has_gym, reason="no gym") @pytest.mark.parametrize( "env_name", diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index cfb977d4bb2..2a955af1261 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1284,7 +1284,22 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: tensordict_, keys_to_update=list(self._selected_reset_keys) ) continue - out = ("reset", tensordict_) + if tensordict_ is not None: + tdkeys = list(tensordict_.keys(True, True)) + + # This way we can avoid calling select over all the keys in the shared tensordict + def tentative_update(val, other): + if other is not None: + val.copy_(other) + return val + + self.shared_tensordicts[i].apply_( + tentative_update, tensordict_, default=None + ) + out = ("reset", tdkeys) + else: + out = 
("reset", False) + channel.send(out) workers.append(i) @@ -1509,7 +1524,13 @@ def look_for_cuda(tensor, has_cuda=has_cuda): torchrl_logger.info(f"resetting worker {pid}") if not initialized: raise RuntimeError("call 'init' before resetting") - cur_td = env.reset(tensordict=data) + # we use 'data' to pass the keys that we need to pass to reset, + # because passing the entire buffer may have unwanted consequences + cur_td = env.reset( + tensordict=root_shared_tensordict.select(*data, strict=False) + if data + else None + ) shared_tensordict.update_( cur_td, keys_to_update=list(_selected_reset_keys), From 1fe745a584f0cc8e0f4fca09c4e5ac5f976b542f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 6 Feb 2024 18:31:22 +0000 Subject: [PATCH 3/4] [Refactor] Use filter_empty=False in apply for params (#1882) --- torchrl/collectors/collectors.py | 10 +++++++--- torchrl/envs/transforms/rlhf.py | 6 ++++-- torchrl/objectives/common.py | 10 +++++++--- torchrl/objectives/ddpg.py | 4 +++- torchrl/objectives/multiagent/qmixer.py | 2 +- 5 files changed, 22 insertions(+), 10 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 202dcc9ead8..98040d9640e 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -316,12 +316,14 @@ def _map_to_device_params(param, device): # Create a stateless policy, then populate this copy with params on device with param_and_buf.apply( - functools.partial(_map_to_device_params, device="meta") + functools.partial(_map_to_device_params, device="meta"), + filter_empty=False, ).to_module(policy): policy = deepcopy(policy) param_and_buf.apply( - functools.partial(_map_to_device_params, device=self.policy_device) + functools.partial(_map_to_device_params, device=self.policy_device), + filter_empty=False, ).to_module(policy) return policy, get_weights_fn @@ -1495,7 +1497,9 @@ def map_weight( weight = nn.Parameter(weight, requires_grad=False) return weight - local_policy_weights = TensorDictParams(policy_weights.apply(map_weight)) + local_policy_weights = TensorDictParams( + policy_weights.apply(map_weight, filter_empty=False) + ) def _get_weight_fn(weights=policy_weights): # This function will give the local_policy_weight the original weights. 
diff --git a/torchrl/envs/transforms/rlhf.py b/torchrl/envs/transforms/rlhf.py index 623bc2864fe..33874393038 100644 --- a/torchrl/envs/transforms/rlhf.py +++ b/torchrl/envs/transforms/rlhf.py @@ -112,7 +112,9 @@ def __init__( # check that the model has parameters params = TensorDict.from_module(actor) - with params.apply(_stateless_param, device="meta").to_module(actor): + with params.apply( + _stateless_param, device="meta", filter_empty=False + ).to_module(actor): # copy a stateless actor self.__dict__["functional_actor"] = deepcopy(actor) # we need to register these params as buffer to have `to` and similar @@ -129,7 +131,7 @@ def _make_detached_param(x): ) return x.clone() - self.frozen_params = params.apply(_make_detached_param) + self.frozen_params = params.apply(_make_detached_param, filter_empty=False) if requires_grad: # includes the frozen params/buffers in the module parameters/buffers self.frozen_params = TensorDictParams(self.frozen_params, no_convert=True) diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 1f5edcf26ed..5d620b56227 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -262,7 +262,9 @@ def _compare_and_expand(param): params = TensorDictParams( params.apply( - _compare_and_expand, batch_size=[expand_dim, *params.shape] + _compare_and_expand, + batch_size=[expand_dim, *params.shape], + filter_empty=False, ), no_convert=True, ) @@ -283,7 +285,7 @@ def _compare_and_expand(param): # set the functional module: we need to convert the params to non-differentiable params # otherwise they will appear twice in parameters with params.apply( - self._make_meta_params, device=torch.device("meta") + self._make_meta_params, device=torch.device("meta"), filter_empty=False ).to_module(module): # avoid buffers and params being exposed self.__dict__[module_name] = deepcopy(module) @@ -293,7 +295,9 @@ def _compare_and_expand(param): # if create_target_params: # we create a TensorDictParams to keep the target params as Buffer instances target_params = TensorDictParams( - params.apply(_make_target_param(clone=create_target_params)), + params.apply( + _make_target_param(clone=create_target_params), filter_empty=False + ), no_convert=True, ) setattr(self, name_params_target + "_params", target_params) diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index 6572084c8ec..03e82689ad5 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -197,7 +197,9 @@ def __init__( actor_critic = ActorCriticWrapper(actor_network, value_network) params = TensorDict.from_module(actor_critic) - params_meta = params.apply(self._make_meta_params, device=torch.device("meta")) + params_meta = params.apply( + self._make_meta_params, device=torch.device("meta"), filter_empty=False + ) with params_meta.to_module(actor_critic): self.__dict__["actor_critic"] = deepcopy(actor_critic) diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py index f7b9307a962..fcfcba49ca1 100644 --- a/torchrl/objectives/multiagent/qmixer.py +++ b/torchrl/objectives/multiagent/qmixer.py @@ -224,7 +224,7 @@ def __init__( global_value_network = SafeSequential(local_value_network, mixer_network) params = TensorDict.from_module(global_value_network) with params.apply( - self._make_meta_params, device=torch.device("meta") + self._make_meta_params, device=torch.device("meta"), filter_empty=False ).to_module(global_value_network): self.__dict__["global_value_network"] = deepcopy(global_value_network) From 
144f5470e9b0358d33cc8ac346f0fb1a07839067 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 7 Feb 2024 09:32:32 +0000 Subject: [PATCH 4/4] [Doc] Improve PrioritizedSampler doc and get rid of np dependency as much as possible (#1881) --- test/test_rb.py | 2 +- torchrl/data/replay_buffers/replay_buffers.py | 2 +- torchrl/data/replay_buffers/samplers.py | 116 +++++++++++++----- 3 files changed, 85 insertions(+), 35 deletions(-) diff --git a/test/test_rb.py b/test/test_rb.py index 4b9b1a5dc9f..548e4ba9726 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -834,7 +834,7 @@ def test_set_tensorclass(self, max_size, shape, storage): @pytest.mark.parametrize("priority_key", ["pk", "td_error"]) @pytest.mark.parametrize("contiguous", [True, False]) @pytest.mark.parametrize("device", get_default_devices()) -def test_prototype_prb(priority_key, contiguous, device): +def test_ptdrb(priority_key, contiguous, device): torch.manual_seed(0) np.random.seed(0) rb = TensorDictReplayBuffer( diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index c3999806aaf..749bf0888ae 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -925,7 +925,7 @@ def update_tensordict_priority(self, data: TensorDictBase) -> None: if data.ndim: priority = self._get_priority_vector(data) else: - priority = self._get_priority_item(data) + priority = torch.as_tensor(self._get_priority_item(data)) index = data.get("index") while index.shape != priority.shape: # reduce index diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 0352b803b66..5e9b6dd75be 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -24,7 +24,6 @@ from torchrl._utils import _replace_last from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage -from torchrl.data.replay_buffers.utils import _to_numpy, INT_CLASSES try: from torchrl._torchrl import ( @@ -250,11 +249,10 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: class PrioritizedSampler(Sampler): """Prioritized sampler for replay buffer. - Presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. - Prioritized experience replay." - (https://arxiv.org/abs/1511.05952) + Presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay." (https://arxiv.org/abs/1511.05952) Args: + max_capacity (int): maximum capacity of the buffer. alpha (float): exponent α determines how much prioritization is used, with α = 0 corresponding to the uniform case. beta (float): importance sampling negative exponent. @@ -264,6 +262,51 @@ class PrioritizedSampler(Sampler): tensordicts (ie stored trajectory). Can be one of "max", "min", "median" or "mean". 
+ Examples: + >>> from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler + >>> from tensordict import TensorDict + >>> rb = ReplayBuffer(storage=LazyTensorStorage(10), sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0)) + >>> priority = torch.tensor([0, 1000]) + >>> data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, []) + >>> data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, []) + >>> rb.add(data_0) + >>> rb.add(data_1) + >>> rb.update_priority(torch.tensor([0, 1]), priority=priority) + >>> sample, info = rb.sample(10, return_info=True) + >>> print(sample) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False), + obs: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False), + priority: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False), + reward: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([10]), + device=cpu, + is_shared=False) + >>> print(info) + {'_weight': array([1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, + 1.e-11, 1.e-11], dtype=float32), 'index': array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])} + + .. note:: Using a :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer` can smoothen the + process of updating the priorities: + + >>> from torchrl.data.replay_buffers import TensorDictReplayBuffer as TDRB, LazyTensorStorage, PrioritizedSampler + >>> from tensordict import TensorDict + >>> rb = TDRB( + ... storage=LazyTensorStorage(10), + ... sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0), + ... priority_key="priority", # This kwarg isn't present in regular RBs + ... ) + >>> priority = torch.tensor([0, 1000]) + >>> data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, []) + >>> data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, []) + >>> data = torch.stack([data_0, data_1]) + >>> rb.extend(data) + >>> rb.update_priority(data) # Reads the "priority" key as indicated in the constructor + >>> sample, info = rb.sample(10, return_info=True) + >>> print(sample['index']) # The index is packed with the tensordict + tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) + """ def __init__( @@ -327,15 +370,17 @@ def sample(self, storage: Storage, batch_size: int) -> torch.Tensor: raise RuntimeError("negative p_sum") if p_min <= 0: raise RuntimeError("negative p_min") + # For some undefined reason, only np.random works here. 
+ # All PT attempts fail, even when subsequently transformed into numpy mass = np.random.uniform(0.0, p_sum, size=batch_size) + # mass = torch.zeros(batch_size, dtype=torch.double).uniform_(0.0, p_sum) + # mass = torch.rand(batch_size).mul_(p_sum) index = self._sum_tree.scan_lower_bound(mass) - if not isinstance(index, np.ndarray): - index = np.array([index]) - if isinstance(index, torch.Tensor): - index.clamp_max_(len(storage) - 1) - else: - index = np.clip(index, None, len(storage) - 1) - weight = self._sum_tree[index] + index = torch.as_tensor(index) + if not index.ndim: + index = index.unsqueeze(0) + index.clamp_max_(len(storage) - 1) + weight = torch.as_tensor(self._sum_tree[index]) # Importance sampling weight formula: # w_i = (p_i / sum(p) * N) ^ (-beta) @@ -345,9 +390,10 @@ def sample(self, storage: Storage, batch_size: int) -> torch.Tensor: # weight_i = ((p_i / sum(p) * N) / (min(p) / sum(p) * N)) ^ (-beta) # weight_i = (p_i / min(p)) ^ (-beta) # weight = np.power(weight / (p_min + self._eps), -self._beta) - weight = np.power(weight / p_min, -self._beta) + weight = torch.pow(weight / p_min, -self._beta) return index, {"_weight": weight} + @torch.no_grad() def _add_or_extend(self, index: Union[int, torch.Tensor]) -> None: priority = self.default_priority @@ -360,6 +406,11 @@ def _add_or_extend(self, index: Union[int, torch.Tensor]) -> None: "priority should be a scalar or an iterable of the same " "length as index" ) + # make sure everything is cast to cpu + if isinstance(index, torch.Tensor) and not index.is_cpu: + index = index.cpu() + if isinstance(priority, torch.Tensor) and not priority.is_cpu: + priority = priority.cpu() self._sum_tree[index] = priority self._min_tree[index] = priority @@ -377,6 +428,7 @@ def extend(self, index: torch.Tensor) -> None: index = index.cpu() self._add_or_extend(index) + @torch.no_grad() def update_priority( self, index: Union[int, torch.Tensor], priority: Union[float, torch.Tensor] ) -> None: @@ -389,28 +441,26 @@ def update_priority( indexed elements. """ - if isinstance(index, INT_CLASSES): - if not isinstance(priority, float): - if len(priority) != 1: - raise RuntimeError( - f"priority length should be 1, got {len(priority)}" - ) - priority = priority.item() - else: - if not ( - isinstance(priority, float) - or len(priority) == 1 - or len(index) == len(priority) - ): + priority = torch.as_tensor(priority, device=torch.device("cpu")).detach() + index = torch.as_tensor( + index, dtype=torch.long, device=torch.device("cpu") + ).detach() + # we need to reshape priority if it has more than one elements or if it has + # a different shape than index + if priority.numel() > 1 and priority.shape != index.shape: + try: + priority = priority.reshape(index.shape[:1]) + except Exception as err: raise RuntimeError( "priority should be a number or an iterable of the same " - "length as index" - ) - index = _to_numpy(index) - priority = _to_numpy(priority) - - self._max_priority = max(self._max_priority, np.max(priority)) - priority = np.power(priority + self._eps, self._alpha) + f"length as index. Got priority of shape {priority.shape} and index " + f"{index.shape}." 
+ ) from err + elif priority.numel() <= 1: + priority = priority.squeeze() + + self._max_priority = priority.max().clamp_min(self._max_priority).item() + priority = torch.pow(priority + self._eps, self._alpha) self._sum_tree[index] = priority self._min_tree[index] = priority @@ -1233,7 +1283,7 @@ def __getitem__(self, index): if isinstance(index, slice) and index == slice(None): return self if isinstance(index, (list, range, np.ndarray)): - index = torch.tensor(index) + index = torch.as_tensor(index) if isinstance(index, torch.Tensor): if index.ndim > 1: raise RuntimeError(
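
# Illustrative numeric check (not part of the patch): the importance-sampling
# weight simplification quoted in the comments of PrioritizedSampler.sample()
# above, written with plain torch ops in line with the numpy-to-torch refactor.
# The priority values are arbitrary example numbers.
import torch

p = torch.tensor([0.1, 0.4, 1.5])   # per-sample priorities p_i
beta = 0.9
N = p.numel()

w = (p / p.sum() * N) ** (-beta)         # w_i = (p_i / sum(p) * N) ^ (-beta)
w_norm = w / w.max()                     # normalise by the largest weight (the min-priority sample)
w_direct = (p / p.min()) ** (-beta)      # (p_i / min(p)) ^ (-beta), the form the code computes

print(torch.allclose(w_norm, w_direct))  # True: both expressions agree
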