diff --git a/test/test_rb.py b/test/test_rb.py index 4b9b1a5dc9f..548e4ba9726 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -834,7 +834,7 @@ def test_set_tensorclass(self, max_size, shape, storage): @pytest.mark.parametrize("priority_key", ["pk", "td_error"]) @pytest.mark.parametrize("contiguous", [True, False]) @pytest.mark.parametrize("device", get_default_devices()) -def test_prototype_prb(priority_key, contiguous, device): +def test_ptdrb(priority_key, contiguous, device): torch.manual_seed(0) np.random.seed(0) rb = TensorDictReplayBuffer( diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index c3999806aaf..749bf0888ae 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -925,7 +925,7 @@ def update_tensordict_priority(self, data: TensorDictBase) -> None: if data.ndim: priority = self._get_priority_vector(data) else: - priority = self._get_priority_item(data) + priority = torch.as_tensor(self._get_priority_item(data)) index = data.get("index") while index.shape != priority.shape: # reduce index diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 0352b803b66..5e9b6dd75be 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -24,7 +24,6 @@ from torchrl._utils import _replace_last from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage -from torchrl.data.replay_buffers.utils import _to_numpy, INT_CLASSES try: from torchrl._torchrl import ( @@ -250,11 +249,10 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: class PrioritizedSampler(Sampler): """Prioritized sampler for replay buffer. - Presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. - Prioritized experience replay." - (https://arxiv.org/abs/1511.05952) + Presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 
2015. Prioritized experience replay." (https://arxiv.org/abs/1511.05952) Args: + max_capacity (int): maximum capacity of the buffer. alpha (float): exponent α determines how much prioritization is used, with α = 0 corresponding to the uniform case. beta (float): importance sampling negative exponent. @@ -264,6 +262,51 @@ class PrioritizedSampler(Sampler): tensordicts (ie stored trajectory). Can be one of "max", "min", "median" or "mean". + Examples: + >>> from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler + >>> from tensordict import TensorDict + >>> rb = ReplayBuffer(storage=LazyTensorStorage(10), sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0)) + >>> priority = torch.tensor([0, 1000]) + >>> data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, []) + >>> data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, []) + >>> rb.add(data_0) + >>> rb.add(data_1) + >>> rb.update_priority(torch.tensor([0, 1]), priority=priority) + >>> sample, info = rb.sample(10, return_info=True) + >>> print(sample) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False), + obs: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False), + priority: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False), + reward: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([10]), + device=cpu, + is_shared=False) + >>> print(info) + {'_weight': array([1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, + 1.e-11, 1.e-11], dtype=float32), 'index': array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])} + + .. 
note:: Using a :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer` can smoothen the + process of updating the priorities: + + >>> from torchrl.data.replay_buffers import TensorDictReplayBuffer as TDRB, LazyTensorStorage, PrioritizedSampler + >>> from tensordict import TensorDict + >>> rb = TDRB( + ... storage=LazyTensorStorage(10), + ... sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0), + ... priority_key="priority", # This kwarg isn't present in regular RBs + ... ) + >>> priority = torch.tensor([0, 1000]) + >>> data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, []) + >>> data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, []) + >>> data = torch.stack([data_0, data_1]) + >>> rb.extend(data) + >>> rb.update_priority(data) # Reads the "priority" key as indicated in the constructor + >>> sample, info = rb.sample(10, return_info=True) + >>> print(sample['index']) # The index is packed with the tensordict + tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) + """ def __init__( @@ -327,15 +370,17 @@ def sample(self, storage: Storage, batch_size: int) -> torch.Tensor: raise RuntimeError("negative p_sum") if p_min <= 0: raise RuntimeError("negative p_min") + # For some undefined reason, only np.random works here. 
+ # All PT attempts fail, even when subsequently transformed into numpy mass = np.random.uniform(0.0, p_sum, size=batch_size) + # mass = torch.zeros(batch_size, dtype=torch.double).uniform_(0.0, p_sum) + # mass = torch.rand(batch_size).mul_(p_sum) index = self._sum_tree.scan_lower_bound(mass) - if not isinstance(index, np.ndarray): - index = np.array([index]) - if isinstance(index, torch.Tensor): - index.clamp_max_(len(storage) - 1) - else: - index = np.clip(index, None, len(storage) - 1) - weight = self._sum_tree[index] + index = torch.as_tensor(index) + if not index.ndim: + index = index.unsqueeze(0) + index.clamp_max_(len(storage) - 1) + weight = torch.as_tensor(self._sum_tree[index]) # Importance sampling weight formula: # w_i = (p_i / sum(p) * N) ^ (-beta) @@ -345,9 +390,10 @@ def sample(self, storage: Storage, batch_size: int) -> torch.Tensor: # weight_i = ((p_i / sum(p) * N) / (min(p) / sum(p) * N)) ^ (-beta) # weight_i = (p_i / min(p)) ^ (-beta) # weight = np.power(weight / (p_min + self._eps), -self._beta) - weight = np.power(weight / p_min, -self._beta) + weight = torch.pow(weight / p_min, -self._beta) return index, {"_weight": weight} + @torch.no_grad() def _add_or_extend(self, index: Union[int, torch.Tensor]) -> None: priority = self.default_priority @@ -360,6 +406,11 @@ def _add_or_extend(self, index: Union[int, torch.Tensor]) -> None: "priority should be a scalar or an iterable of the same " "length as index" ) + # make sure everything is cast to cpu + if isinstance(index, torch.Tensor) and not index.is_cpu: + index = index.cpu() + if isinstance(priority, torch.Tensor) and not priority.is_cpu: + priority = priority.cpu() self._sum_tree[index] = priority self._min_tree[index] = priority @@ -377,6 +428,7 @@ def extend(self, index: torch.Tensor) -> None: index = index.cpu() self._add_or_extend(index) + @torch.no_grad() def update_priority( self, index: Union[int, torch.Tensor], priority: Union[float, torch.Tensor] ) -> None: @@ -389,28 +441,26 @@ def 
update_priority( indexed elements. """ - if isinstance(index, INT_CLASSES): - if not isinstance(priority, float): - if len(priority) != 1: - raise RuntimeError( - f"priority length should be 1, got {len(priority)}" - ) - priority = priority.item() - else: - if not ( - isinstance(priority, float) - or len(priority) == 1 - or len(index) == len(priority) - ): + priority = torch.as_tensor(priority, device=torch.device("cpu")).detach() + index = torch.as_tensor( + index, dtype=torch.long, device=torch.device("cpu") + ).detach() + # we need to reshape priority if it has more than one element and + # a different shape than index + if priority.numel() > 1 and priority.shape != index.shape: + try: + priority = priority.reshape(index.shape[:1]) + except Exception as err: raise RuntimeError( "priority should be a number or an iterable of the same " - "length as index" - ) - index = _to_numpy(index) - priority = _to_numpy(priority) - - self._max_priority = max(self._max_priority, np.max(priority)) - priority = np.power(priority + self._eps, self._alpha) + f"length as index. Got priority of shape {priority.shape} and index " + f"{index.shape}." + ) from err + elif priority.numel() <= 1: + priority = priority.squeeze() + + self._max_priority = priority.max().clamp_min(self._max_priority).item() + priority = torch.pow(priority + self._eps, self._alpha) self._sum_tree[index] = priority self._min_tree[index] = priority @@ -1233,7 +1283,7 @@ def __getitem__(self, index): if isinstance(index, slice) and index == slice(None): return self if isinstance(index, (list, range, np.ndarray)): - index = torch.tensor(index) + index = torch.as_tensor(index) if isinstance(index, torch.Tensor): if index.ndim > 1: raise RuntimeError(