Skip to content

Commit

Permalink
[Feature] RNG for RBs (#2379)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Aug 8, 2024
1 parent 342450e commit 918bfe6
Show file tree
Hide file tree
Showing 5 changed files with 295 additions and 33 deletions.
152 changes: 133 additions & 19 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@
".".join([str(s) for s in version.parse(str(torch.__version__)).release])
) >= version.parse("2.3.0")

ReplayBufferRNG = functools.partial(ReplayBuffer, generator=torch.Generator())
TensorDictReplayBufferRNG = functools.partial(
TensorDictReplayBuffer, generator=torch.Generator()
)


@pytest.mark.parametrize(
"sampler",
Expand All @@ -125,17 +130,27 @@
"rb_type,storage,datatype",
[
[ReplayBuffer, ListStorage, None],
[ReplayBufferRNG, ListStorage, None],
[TensorDictReplayBuffer, ListStorage, "tensordict"],
[TensorDictReplayBufferRNG, ListStorage, "tensordict"],
[RemoteTensorDictReplayBuffer, ListStorage, "tensordict"],
[ReplayBuffer, LazyTensorStorage, "tensor"],
[ReplayBuffer, LazyTensorStorage, "tensordict"],
[ReplayBuffer, LazyTensorStorage, "pytree"],
[ReplayBufferRNG, LazyTensorStorage, "tensor"],
[ReplayBufferRNG, LazyTensorStorage, "tensordict"],
[ReplayBufferRNG, LazyTensorStorage, "pytree"],
[TensorDictReplayBuffer, LazyTensorStorage, "tensordict"],
[TensorDictReplayBufferRNG, LazyTensorStorage, "tensordict"],
[RemoteTensorDictReplayBuffer, LazyTensorStorage, "tensordict"],
[ReplayBuffer, LazyMemmapStorage, "tensor"],
[ReplayBuffer, LazyMemmapStorage, "tensordict"],
[ReplayBuffer, LazyMemmapStorage, "pytree"],
[ReplayBufferRNG, LazyMemmapStorage, "tensor"],
[ReplayBufferRNG, LazyMemmapStorage, "tensordict"],
[ReplayBufferRNG, LazyMemmapStorage, "pytree"],
[TensorDictReplayBuffer, LazyMemmapStorage, "tensordict"],
[TensorDictReplayBufferRNG, LazyMemmapStorage, "tensordict"],
[RemoteTensorDictReplayBuffer, LazyMemmapStorage, "tensordict"],
],
)
Expand Down Expand Up @@ -1155,17 +1170,115 @@ def test_replay_buffer_trajectories(stack, reduction, datatype):
# sampled_td_filtered.batch_size = [3, 4]


class TestRNG:
def test_rb_rng(self):
state = torch.random.get_rng_state()
rb = ReplayBufferRNG(sampler=RandomSampler(), storage=LazyTensorStorage(100))
rb.extend(torch.arange(100))
rb._rng.set_state(state)
a = rb.sample(32)
rb._rng.set_state(state)
b = rb.sample(32)
assert (a == b).all()
c = rb.sample(32)
assert (a != c).any()

def test_prb_rng(self):
state = torch.random.get_rng_state()
rb = ReplayBuffer(
sampler=PrioritizedSampler(100, 1.0, 1.0),
storage=LazyTensorStorage(100),
generator=torch.Generator(),
)
rb.extend(torch.arange(100))
rb.update_priority(index=torch.arange(100), priority=torch.arange(1, 101))

rb._rng.set_state(state)
a = rb.sample(32)

rb._rng.set_state(state)
b = rb.sample(32)
assert (a == b).all()

c = rb.sample(32)
assert (a != c).any()

def test_slice_rng(self):
state = torch.random.get_rng_state()
rb = ReplayBuffer(
sampler=SliceSampler(num_slices=4),
storage=LazyTensorStorage(100),
generator=torch.Generator(),
)
done = torch.zeros(100, 1, dtype=torch.bool)
done[49] = 1
done[-1] = 1
data = TensorDict(
{
"data": torch.arange(100),
("next", "done"): done,
},
batch_size=[100],
)
rb.extend(data)

rb._rng.set_state(state)
a = rb.sample(32)

rb._rng.set_state(state)
b = rb.sample(32)
assert (a == b).all()

c = rb.sample(32)
assert (a != c).any()

def test_rng_state_dict(self):
state = torch.random.get_rng_state()
rb = ReplayBufferRNG(sampler=RandomSampler(), storage=LazyTensorStorage(100))
rb.extend(torch.arange(100))
rb._rng.set_state(state)
sd = rb.state_dict()
assert sd.get("_rng") is not None
a = rb.sample(32)

rb.load_state_dict(sd)
b = rb.sample(32)
assert (a == b).all()
c = rb.sample(32)
assert (a != c).any()

def test_rng_dumps(self, tmpdir):
state = torch.random.get_rng_state()
rb = ReplayBufferRNG(sampler=RandomSampler(), storage=LazyTensorStorage(100))
rb.extend(torch.arange(100))
rb._rng.set_state(state)
rb.dumps(tmpdir)
a = rb.sample(32)

rb.loads(tmpdir)
b = rb.sample(32)
assert (a == b).all()
c = rb.sample(32)
assert (a != c).any()


@pytest.mark.parametrize(
"rbtype,storage",
[
(ReplayBuffer, None),
(ReplayBuffer, ListStorage),
(ReplayBufferRNG, None),
(ReplayBufferRNG, ListStorage),
(PrioritizedReplayBuffer, None),
(PrioritizedReplayBuffer, ListStorage),
(TensorDictReplayBuffer, None),
(TensorDictReplayBuffer, ListStorage),
(TensorDictReplayBuffer, LazyTensorStorage),
(TensorDictReplayBuffer, LazyMemmapStorage),
(TensorDictReplayBufferRNG, None),
(TensorDictReplayBufferRNG, ListStorage),
(TensorDictReplayBufferRNG, LazyTensorStorage),
(TensorDictReplayBufferRNG, LazyMemmapStorage),
(TensorDictPrioritizedReplayBuffer, None),
(TensorDictPrioritizedReplayBuffer, ListStorage),
(TensorDictPrioritizedReplayBuffer, LazyTensorStorage),
Expand All @@ -1175,33 +1288,34 @@ def test_replay_buffer_trajectories(stack, reduction, datatype):
@pytest.mark.parametrize("size", [3, 5, 100])
@pytest.mark.parametrize("prefetch", [0])
class TestBuffers:
_default_params_rb = {}
_default_params_td_rb = {}
_default_params_prb = {"alpha": 0.8, "beta": 0.9}
_default_params_td_prb = {"alpha": 0.8, "beta": 0.9}

default_constr = {
ReplayBuffer: ReplayBuffer,
PrioritizedReplayBuffer: functools.partial(
PrioritizedReplayBuffer, alpha=0.8, beta=0.9
),
TensorDictReplayBuffer: TensorDictReplayBuffer,
TensorDictPrioritizedReplayBuffer: functools.partial(
TensorDictPrioritizedReplayBuffer, alpha=0.8, beta=0.9
),
TensorDictReplayBufferRNG: TensorDictReplayBufferRNG,
ReplayBufferRNG: ReplayBufferRNG,
}

def _get_rb(self, rbtype, size, storage, prefetch):
if storage is not None:
storage = storage(size)
if rbtype is ReplayBuffer:
params = self._default_params_rb
elif rbtype is PrioritizedReplayBuffer:
params = self._default_params_prb
elif rbtype is TensorDictReplayBuffer:
params = self._default_params_td_rb
elif rbtype is TensorDictPrioritizedReplayBuffer:
params = self._default_params_td_prb
else:
raise NotImplementedError(rbtype)
rb = rbtype(storage=storage, prefetch=prefetch, batch_size=3, **params)
rb = self.default_constr[rbtype](
storage=storage, prefetch=prefetch, batch_size=3
)
return rb

def _get_datum(self, rbtype):
if rbtype is ReplayBuffer:
if rbtype in (ReplayBuffer, ReplayBufferRNG):
data = torch.randint(100, (1,))
elif rbtype is PrioritizedReplayBuffer:
data = torch.randint(100, (1,))
elif rbtype is TensorDictReplayBuffer:
elif rbtype in (TensorDictReplayBuffer, TensorDictReplayBufferRNG):
data = TensorDict({"a": torch.randint(100, (1,))}, [])
elif rbtype is TensorDictPrioritizedReplayBuffer:
data = TensorDict({"a": torch.randint(100, (1,))}, [])
Expand All @@ -1210,11 +1324,11 @@ def _get_datum(self, rbtype):
return data

def _get_data(self, rbtype, size):
if rbtype is ReplayBuffer:
if rbtype in (ReplayBuffer, ReplayBufferRNG):
data = [torch.randint(100, (1,)) for _ in range(size)]
elif rbtype is PrioritizedReplayBuffer:
data = [torch.randint(100, (1,)) for _ in range(size)]
elif rbtype is TensorDictReplayBuffer:
elif rbtype in (TensorDictReplayBuffer, TensorDictReplayBufferRNG):
data = TensorDict(
{
"a": torch.randint(100, (size,)),
Expand Down
Loading

0 comments on commit 918bfe6

Please sign in to comment.