Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Doc] Improve PrioritizedSampler doc and get rid of np dependency as much as possible #1881

Merged
merged 8 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,7 +834,7 @@ def test_set_tensorclass(self, max_size, shape, storage):
@pytest.mark.parametrize("priority_key", ["pk", "td_error"])
@pytest.mark.parametrize("contiguous", [True, False])
@pytest.mark.parametrize("device", get_default_devices())
def test_prototype_prb(priority_key, contiguous, device):
def test_ptdrb(priority_key, contiguous, device):
torch.manual_seed(0)
np.random.seed(0)
rb = TensorDictReplayBuffer(
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/replay_buffers/replay_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,7 +925,7 @@ def update_tensordict_priority(self, data: TensorDictBase) -> None:
if data.ndim:
priority = self._get_priority_vector(data)
else:
priority = self._get_priority_item(data)
priority = torch.as_tensor(self._get_priority_item(data))
index = data.get("index")
while index.shape != priority.shape:
# reduce index
Expand Down
116 changes: 83 additions & 33 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

from torchrl._utils import _replace_last
from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage
from torchrl.data.replay_buffers.utils import _to_numpy, INT_CLASSES

try:
from torchrl._torchrl import (
Expand Down Expand Up @@ -250,11 +249,10 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
class PrioritizedSampler(Sampler):
"""Prioritized sampler for replay buffer.

Presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015.
Prioritized experience replay."
(https://arxiv.org/abs/1511.05952)
Presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay." (https://arxiv.org/abs/1511.05952)

Args:
max_capacity (int): maximum capacity of the buffer.
alpha (float): exponent α determines how much prioritization is used,
with α = 0 corresponding to the uniform case.
beta (float): importance sampling negative exponent.
Expand All @@ -264,6 +262,51 @@ class PrioritizedSampler(Sampler):
tensordicts (i.e. stored trajectories). Can be one of "max", "min",
"median" or "mean".

Examples:
>>> from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler
>>> from tensordict import TensorDict
>>> rb = ReplayBuffer(storage=LazyTensorStorage(10), sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0))
>>> priority = torch.tensor([0, 1000])
>>> data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
>>> data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
>>> rb.add(data_0)
>>> rb.add(data_1)
>>> rb.update_priority(torch.tensor([0, 1]), priority=priority)
>>> sample, info = rb.sample(10, return_info=True)
>>> print(sample)
TensorDict(
fields={
action: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
obs: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
priority: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
reward: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([10]),
device=cpu,
is_shared=False)
>>> print(info)
{'_weight': array([1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11,
1.e-11, 1.e-11], dtype=float32), 'index': array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])}

.. note:: Using a :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer` can simplify the
process of updating the priorities:

>>> from torchrl.data.replay_buffers import TensorDictReplayBuffer as TDRB, LazyTensorStorage, PrioritizedSampler
>>> from tensordict import TensorDict
>>> rb = TDRB(
... storage=LazyTensorStorage(10),
... sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0),
... priority_key="priority", # This kwarg isn't present in regular RBs
... )
>>> priority = torch.tensor([0, 1000])
>>> data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
>>> data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
>>> data = torch.stack([data_0, data_1])
>>> rb.extend(data)
>>> rb.update_priority(data) # Reads the "priority" key as indicated in the constructor
>>> sample, info = rb.sample(10, return_info=True)
>>> print(sample['index']) # The index is packed with the tensordict
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

"""

def __init__(
Expand Down Expand Up @@ -327,15 +370,17 @@ def sample(self, storage: Storage, batch_size: int) -> torch.Tensor:
raise RuntimeError("negative p_sum")
if p_min <= 0:
raise RuntimeError("negative p_min")
# For some undefined reason, only np.random works here.
# All PT attempts fail, even when subsequently transformed into numpy
mass = np.random.uniform(0.0, p_sum, size=batch_size)
# mass = torch.zeros(batch_size, dtype=torch.double).uniform_(0.0, p_sum)
# mass = torch.rand(batch_size).mul_(p_sum)
index = self._sum_tree.scan_lower_bound(mass)
if not isinstance(index, np.ndarray):
index = np.array([index])
if isinstance(index, torch.Tensor):
index.clamp_max_(len(storage) - 1)
else:
index = np.clip(index, None, len(storage) - 1)
weight = self._sum_tree[index]
index = torch.as_tensor(index)
if not index.ndim:
index = index.unsqueeze(0)
index.clamp_max_(len(storage) - 1)
weight = torch.as_tensor(self._sum_tree[index])

# Importance sampling weight formula:
# w_i = (p_i / sum(p) * N) ^ (-beta)
Expand All @@ -345,9 +390,10 @@ def sample(self, storage: Storage, batch_size: int) -> torch.Tensor:
# weight_i = ((p_i / sum(p) * N) / (min(p) / sum(p) * N)) ^ (-beta)
# weight_i = (p_i / min(p)) ^ (-beta)
# weight = np.power(weight / (p_min + self._eps), -self._beta)
weight = np.power(weight / p_min, -self._beta)
weight = torch.pow(weight / p_min, -self._beta)
return index, {"_weight": weight}

@torch.no_grad()
def _add_or_extend(self, index: Union[int, torch.Tensor]) -> None:
priority = self.default_priority

Expand All @@ -360,6 +406,11 @@ def _add_or_extend(self, index: Union[int, torch.Tensor]) -> None:
"priority should be a scalar or an iterable of the same "
"length as index"
)
# make sure everything is cast to cpu
if isinstance(index, torch.Tensor) and not index.is_cpu:
index = index.cpu()
if isinstance(priority, torch.Tensor) and not priority.is_cpu:
priority = priority.cpu()

self._sum_tree[index] = priority
self._min_tree[index] = priority
Expand All @@ -377,6 +428,7 @@ def extend(self, index: torch.Tensor) -> None:
index = index.cpu()
self._add_or_extend(index)

@torch.no_grad()
def update_priority(
self, index: Union[int, torch.Tensor], priority: Union[float, torch.Tensor]
) -> None:
Expand All @@ -389,28 +441,26 @@ def update_priority(
indexed elements.

"""
if isinstance(index, INT_CLASSES):
if not isinstance(priority, float):
if len(priority) != 1:
raise RuntimeError(
f"priority length should be 1, got {len(priority)}"
)
priority = priority.item()
else:
if not (
isinstance(priority, float)
or len(priority) == 1
or len(index) == len(priority)
):
priority = torch.as_tensor(priority, device=torch.device("cpu")).detach()
index = torch.as_tensor(
index, dtype=torch.long, device=torch.device("cpu")
).detach()
# we need to reshape priority if it has more than one element or if it has
# a different shape than index
if priority.numel() > 1 and priority.shape != index.shape:
try:
priority = priority.reshape(index.shape[:1])
except Exception as err:
raise RuntimeError(
"priority should be a number or an iterable of the same "
"length as index"
)
index = _to_numpy(index)
priority = _to_numpy(priority)

self._max_priority = max(self._max_priority, np.max(priority))
priority = np.power(priority + self._eps, self._alpha)
f"length as index. Got priority of shape {priority.shape} and index "
f"{index.shape}."
) from err
elif priority.numel() <= 1:
priority = priority.squeeze()

self._max_priority = priority.max().clamp_min(self._max_priority).item()
priority = torch.pow(priority + self._eps, self._alpha)
self._sum_tree[index] = priority
self._min_tree[index] = priority

Expand Down Expand Up @@ -1233,7 +1283,7 @@ def __getitem__(self, index):
if isinstance(index, slice) and index == slice(None):
return self
if isinstance(index, (list, range, np.ndarray)):
index = torch.tensor(index)
index = torch.as_tensor(index)
if isinstance(index, torch.Tensor):
if index.ndim > 1:
raise RuntimeError(
Expand Down
Loading