Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into prioritized_slice_sam…
Browse files Browse the repository at this point in the history
…pler
  • Loading branch information
vmoens committed Feb 7, 2024
2 parents 8463e09 + 144f547 commit 56de68a
Show file tree
Hide file tree
Showing 14 changed files with 199 additions and 88 deletions.
2 changes: 1 addition & 1 deletion test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -6100,7 +6100,7 @@ def zero_param(p):
if isinstance(p, nn.Parameter):
p.data.zero_()

params.apply(zero_param)
params.apply(zero_param, filter_empty=True)

# assert len(list(floss_fn.parameters())) == 0
with params.to_module(loss_fn):
Expand Down
23 changes: 23 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,29 @@ def test_parallel_env_with_policy(
# env_serial.close()
env0.close()

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
@pytest.mark.parametrize("heterogeneous", [False, True])
def test_transform_env_transform_no_device(self, heterogeneous):
# Tests non-regression on 1865
def make_env():
return TransformedEnv(
ContinuousActionVecMockEnv(), StepCounter(max_steps=3)
)

if heterogeneous:
make_envs = [EnvCreator(make_env), EnvCreator(make_env)]
else:
make_envs = make_env
penv = ParallelEnv(2, make_envs)
r = penv.rollout(6, break_when_any_done=False)
assert r.shape == (2, 6)
try:
env = TransformedEnv(penv)
r = env.rollout(6, break_when_any_done=False)
assert r.shape == (2, 6)
finally:
penv.close()

@pytest.mark.skipif(not _has_gym, reason="no gym")
@pytest.mark.parametrize(
"env_name",
Expand Down
2 changes: 1 addition & 1 deletion test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,7 +835,7 @@ def test_set_tensorclass(self, max_size, shape, storage):
@pytest.mark.parametrize("priority_key", ["pk", "td_error"])
@pytest.mark.parametrize("contiguous", [True, False])
@pytest.mark.parametrize("device", get_default_devices())
def test_prototype_prb(priority_key, contiguous, device):
def test_ptdrb(priority_key, contiguous, device):
torch.manual_seed(0)
np.random.seed(0)
rb = TensorDictReplayBuffer(
Expand Down
32 changes: 22 additions & 10 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,12 +316,14 @@ def _map_to_device_params(param, device):

# Create a stateless policy, then populate this copy with params on device
with param_and_buf.apply(
functools.partial(_map_to_device_params, device="meta")
functools.partial(_map_to_device_params, device="meta"),
filter_empty=False,
).to_module(policy):
policy = deepcopy(policy)

param_and_buf.apply(
functools.partial(_map_to_device_params, device=self.policy_device)
functools.partial(_map_to_device_params, device=self.policy_device),
filter_empty=False,
).to_module(policy)

return policy, get_weights_fn
Expand Down Expand Up @@ -801,22 +803,30 @@ def check_exclusive(val):
"Consider using a placeholder for missing keys."
)

policy_output._fast_apply(check_exclusive, call_on_nested=True)
policy_output._fast_apply(
check_exclusive, call_on_nested=True, filter_empty=True
)

# Use apply, because it works well with lazy stacks
# Edge-case of this approach: the policy may change the values in-place and only by a tiny bit
# or occasionally. In these cases, the keys will be missed (we can't detect if the policy has
# changed them here).
# This will cause a failure to update entries when policy and env device mismatch and
# casting is necessary.
def filter_policy(value_output, value_input, value_input_clone):
if (
(value_input is None)
or (value_output is not value_input)
or ~torch.isclose(value_output, value_input_clone).any()
):
return value_output

filtered_policy_output = policy_output.apply(
lambda value_output, value_input, value_input_clone: value_output
if (value_input is None)
or (value_output is not value_input)
or ~torch.isclose(value_output, value_input_clone).any()
else None,
filter_policy,
policy_input_copy,
policy_input_clone,
default=None,
filter_empty=True,
)
self._policy_output_keys = list(
self._policy_output_keys.union(
Expand Down Expand Up @@ -933,7 +943,7 @@ def cuda_check(tensor: torch.Tensor):
if tensor.is_cuda:
cuda_devices.add(tensor.device)

self._final_rollout.apply(cuda_check)
self._final_rollout.apply(cuda_check, filter_empty=True)
for device in cuda_devices:
streams.append(torch.cuda.Stream(device, priority=-1))
events.append(streams[-1].record_event())
Expand Down Expand Up @@ -1487,7 +1497,9 @@ def map_weight(
weight = nn.Parameter(weight, requires_grad=False)
return weight

local_policy_weights = TensorDictParams(policy_weights.apply(map_weight))
local_policy_weights = TensorDictParams(
policy_weights.apply(map_weight, filter_empty=False)
)

def _get_weight_fn(weights=policy_weights):
# This function will give the local_policy_weight the original weights.
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/replay_buffers/replay_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,7 +925,7 @@ def update_tensordict_priority(self, data: TensorDictBase) -> None:
if data.ndim:
priority = self._get_priority_vector(data)
else:
priority = self._get_priority_item(data)
priority = torch.as_tensor(self._get_priority_item(data))
index = data.get("index")
while index.shape != priority.shape:
# reduce index
Expand Down
116 changes: 83 additions & 33 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

from torchrl._utils import _replace_last
from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage
from torchrl.data.replay_buffers.utils import _to_numpy, INT_CLASSES

try:
from torchrl._torchrl import (
Expand Down Expand Up @@ -250,11 +249,10 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
class PrioritizedSampler(Sampler):
"""Prioritized sampler for replay buffer.
Presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015.
Prioritized experience replay."
(https://arxiv.org/abs/1511.05952)
Presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay." (https://arxiv.org/abs/1511.05952)
Args:
max_capacity (int): maximum capacity of the buffer.
alpha (float): exponent α determines how much prioritization is used,
with α = 0 corresponding to the uniform case.
beta (float): importance sampling negative exponent.
Expand All @@ -264,6 +262,51 @@ class PrioritizedSampler(Sampler):
tensordicts (ie stored trajectory). Can be one of "max", "min",
"median" or "mean".
Examples:
>>> from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler
>>> from tensordict import TensorDict
>>> rb = ReplayBuffer(storage=LazyTensorStorage(10), sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0))
>>> priority = torch.tensor([0, 1000])
>>> data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
>>> data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
>>> rb.add(data_0)
>>> rb.add(data_1)
>>> rb.update_priority(torch.tensor([0, 1]), priority=priority)
>>> sample, info = rb.sample(10, return_info=True)
>>> print(sample)
TensorDict(
fields={
action: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
obs: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
priority: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
reward: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([10]),
device=cpu,
is_shared=False)
>>> print(info)
{'_weight': array([1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11,
1.e-11, 1.e-11], dtype=float32), 'index': array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])}
.. note:: Using a :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer` can smoothen the
process of updating the priorities:
>>> from torchrl.data.replay_buffers import TensorDictReplayBuffer as TDRB, LazyTensorStorage, PrioritizedSampler
>>> from tensordict import TensorDict
>>> rb = TDRB(
... storage=LazyTensorStorage(10),
... sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0),
... priority_key="priority", # This kwarg isn't present in regular RBs
... )
>>> priority = torch.tensor([0, 1000])
>>> data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
>>> data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
>>> data = torch.stack([data_0, data_1])
>>> rb.extend(data)
>>> rb.update_priority(data) # Reads the "priority" key as indicated in the constructor
>>> sample, info = rb.sample(10, return_info=True)
>>> print(sample['index']) # The index is packed with the tensordict
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
"""

def __init__(
Expand Down Expand Up @@ -327,15 +370,17 @@ def sample(self, storage: Storage, batch_size: int) -> torch.Tensor:
raise RuntimeError("negative p_sum")
if p_min <= 0:
raise RuntimeError("negative p_min")
# For some undefined reason, only np.random works here.
# All PT attempts fail, even when subsequently transformed into numpy
mass = np.random.uniform(0.0, p_sum, size=batch_size)
# mass = torch.zeros(batch_size, dtype=torch.double).uniform_(0.0, p_sum)
# mass = torch.rand(batch_size).mul_(p_sum)
index = self._sum_tree.scan_lower_bound(mass)
if not isinstance(index, np.ndarray):
index = np.array([index])
if isinstance(index, torch.Tensor):
index.clamp_max_(len(storage) - 1)
else:
index = np.clip(index, None, len(storage) - 1)
weight = self._sum_tree[index]
index = torch.as_tensor(index)
if not index.ndim:
index = index.unsqueeze(0)
index.clamp_max_(len(storage) - 1)
weight = torch.as_tensor(self._sum_tree[index])

# Importance sampling weight formula:
# w_i = (p_i / sum(p) * N) ^ (-beta)
Expand All @@ -345,9 +390,10 @@ def sample(self, storage: Storage, batch_size: int) -> torch.Tensor:
# weight_i = ((p_i / sum(p) * N) / (min(p) / sum(p) * N)) ^ (-beta)
# weight_i = (p_i / min(p)) ^ (-beta)
# weight = np.power(weight / (p_min + self._eps), -self._beta)
weight = np.power(weight / p_min, -self._beta)
weight = torch.pow(weight / p_min, -self._beta)
return index, {"_weight": weight}

@torch.no_grad()
def _add_or_extend(self, index: Union[int, torch.Tensor]) -> None:
priority = self.default_priority

Expand All @@ -360,6 +406,11 @@ def _add_or_extend(self, index: Union[int, torch.Tensor]) -> None:
"priority should be a scalar or an iterable of the same "
"length as index"
)
# make sure everything is cast to cpu
if isinstance(index, torch.Tensor) and not index.is_cpu:
index = index.cpu()
if isinstance(priority, torch.Tensor) and not priority.is_cpu:
priority = priority.cpu()

self._sum_tree[index] = priority
self._min_tree[index] = priority
Expand All @@ -377,6 +428,7 @@ def extend(self, index: torch.Tensor) -> None:
index = index.cpu()
self._add_or_extend(index)

@torch.no_grad()
def update_priority(
self, index: Union[int, torch.Tensor], priority: Union[float, torch.Tensor]
) -> None:
Expand All @@ -389,28 +441,26 @@ def update_priority(
indexed elements.
"""
if isinstance(index, INT_CLASSES):
if not isinstance(priority, float):
if len(priority) != 1:
raise RuntimeError(
f"priority length should be 1, got {len(priority)}"
)
priority = priority.item()
else:
if not (
isinstance(priority, float)
or len(priority) == 1
or len(index) == len(priority)
):
priority = torch.as_tensor(priority, device=torch.device("cpu")).detach()
index = torch.as_tensor(
index, dtype=torch.long, device=torch.device("cpu")
).detach()
# we need to reshape priority if it has more than one elements or if it has
# a different shape than index
if priority.numel() > 1 and priority.shape != index.shape:
try:
priority = priority.reshape(index.shape[:1])
except Exception as err:
raise RuntimeError(
"priority should be a number or an iterable of the same "
"length as index"
)
index = _to_numpy(index)
priority = _to_numpy(priority)

self._max_priority = max(self._max_priority, np.max(priority))
priority = np.power(priority + self._eps, self._alpha)
f"length as index. Got priority of shape {priority.shape} and index "
f"{index.shape}."
) from err
elif priority.numel() <= 1:
priority = priority.squeeze()

self._max_priority = priority.max().clamp_min(self._max_priority).item()
priority = torch.pow(priority + self._eps, self._alpha)
self._sum_tree[index] = priority
self._min_tree[index] = priority

Expand Down Expand Up @@ -1471,7 +1521,7 @@ def __getitem__(self, index):
if isinstance(index, slice) and index == slice(None):
return self
if isinstance(index, (list, range, np.ndarray)):
index = torch.tensor(index)
index = torch.as_tensor(index)
if isinstance(index, torch.Tensor):
if index.ndim > 1:
raise RuntimeError(
Expand Down
Loading

0 comments on commit 56de68a

Please sign in to comment.