[Feature] allow for tensor-type parameters in ParameterScheduler
LTluttmann committed Sep 26, 2024
1 parent 5aa2a05 commit 915d1c4
Showing 2 changed files with 79 additions and 33 deletions.
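Usage note (not part of the diff): a minimal sketch of what this change enables, based on the updated test below. It passes tensor-valued alpha/beta to a prioritized buffer and anneals them with the schedulers; the scheduler import path is inferred from the module touched by this commit and may differ between torchrl versions.

    import torch
    from tensordict import TensorDict
    from torchrl.data import ListStorage, TensorDictPrioritizedReplayBuffer
    from torchrl.data.replay_buffers.scheduler import (
        LinearScheduler,
        SchedulerList,
        StepScheduler,
    )

    # alpha and beta can now be tensors rather than plain floats
    rb = TensorDictPrioritizedReplayBuffer(
        alpha=torch.tensor(1.0),
        beta=torch.tensor(0.1),
        storage=ListStorage(max_size=1000),
    )
    rb.extend(TensorDict({"data": torch.randn(1000, 5)}, batch_size=1000))

    # anneal alpha linearly down to 0.2 over 100 steps;
    # raise beta by 0.1 every 10 steps, capped at 1.0
    alpha_scheduler = LinearScheduler(rb, param_name="alpha", final_value=0.2, num_steps=100)
    beta_scheduler = StepScheduler(
        rb, param_name="beta", gamma=0.1, n_steps=10, max_value=1.0, mode="additive"
    )
    scheduler = SchedulerList(schedulers=(alpha_scheduler, beta_scheduler))

    for _ in range(200):
        rb.sample(20)
        scheduler.step()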
78 changes: 52 additions & 26 deletions test/test_rb.py
@@ -3032,47 +3032,73 @@ def test_prioritized_slice_sampler_episodes(device):
     ), "after priority update, only episode 1 and 3 are expected to be sampled"
 
 
-def test_prioritized_parameter_scheduler():
-    INIT_ALPHA = 0.7
-    INIT_BETA = 0.6
-    GAMMA = 0.1
-    EVERY_N_STEPS = 10
-    LINEAR_STEPS = 100
-    TOTAL_STEPS = 200
+@pytest.mark.parametrize("alpha", [0.6, torch.tensor(1.0)])
+@pytest.mark.parametrize("beta", [0.7, torch.tensor(0.1)])
+@pytest.mark.parametrize("gamma", [0.1])
+@pytest.mark.parametrize("total_steps", [200])
+@pytest.mark.parametrize("n_annealing_steps", [100])
+@pytest.mark.parametrize("anneal_every_n", [10, 159])
+@pytest.mark.parametrize("alpha_min", [0, 0.2])
+@pytest.mark.parametrize("beta_max", [1, 1.4])
+def test_prioritized_parameter_scheduler(
+    alpha,
+    beta,
+    gamma,
+    total_steps,
+    n_annealing_steps,
+    anneal_every_n,
+    alpha_min,
+    beta_max,
+):
     rb = TensorDictPrioritizedReplayBuffer(
-        alpha=INIT_ALPHA, beta=INIT_BETA, storage=ListStorage(max_size=2000)
+        alpha=alpha, beta=beta, storage=ListStorage(max_size=1000)
     )
     data = TensorDict({"data": torch.randn(1000, 5)}, batch_size=1000)
     rb.extend(data)
     alpha_scheduler = LinearScheduler(
-        rb, param_name="alpha", final_value=0.0, num_steps=LINEAR_STEPS
+        rb, param_name="alpha", final_value=alpha_min, num_steps=n_annealing_steps
     )
     beta_scheduler = StepScheduler(
         rb,
         param_name="beta",
-        gamma=GAMMA,
-        n_steps=EVERY_N_STEPS,
-        max_value=1.0,
+        gamma=gamma,
+        n_steps=anneal_every_n,
+        max_value=beta_max,
         mode="additive",
     )
 
     scheduler = SchedulerList(schedulers=(alpha_scheduler, beta_scheduler))
-    expected_alpha_vals = np.linspace(INIT_ALPHA, 0.0, num=LINEAR_STEPS + 1)
-    expected_alpha_vals = np.pad(
-        expected_alpha_vals, (0, TOTAL_STEPS - LINEAR_STEPS), constant_values=0.0
+
+    alpha = alpha if torch.is_tensor(alpha) else torch.tensor(alpha)
+    alpha_min = torch.tensor(alpha_min)
+    expected_alpha_vals = torch.linspace(alpha, alpha_min, n_annealing_steps + 1)
+    expected_alpha_vals = torch.nn.functional.pad(
+        expected_alpha_vals, (0, total_steps - n_annealing_steps), value=alpha_min
     )
-    expected_beta_vals = [INIT_BETA]
-    for _ in range((TOTAL_STEPS // EVERY_N_STEPS - 1)):
-        expected_beta_vals.append(expected_beta_vals[-1] + GAMMA)
+
+    expected_beta_vals = [beta]
+    annealing_steps = total_steps // anneal_every_n
+    gammas = torch.arange(0, annealing_steps + 1, dtype=torch.float32) * gamma
     expected_beta_vals = (
-        np.atleast_2d(expected_beta_vals).repeat(EVERY_N_STEPS).clip(None, 1.0)
+        (beta + gammas).repeat_interleave(anneal_every_n).clip(None, beta_max)
    )
-    for i in range(TOTAL_STEPS):
-        assert np.isclose(
-            rb.sampler.alpha, expected_alpha_vals[i]
-        ), f"expected {expected_alpha_vals[i]}, got {rb.sampler.alpha}"
-        assert np.isclose(
-            rb.sampler.beta, expected_beta_vals[i]
-        ), f"expected {expected_beta_vals[i]}, got {rb.sampler.beta}"
+    for i in range(total_steps):
+        curr_alpha = rb.sampler.alpha
+        torch.testing.assert_close(
+            curr_alpha
+            if torch.is_tensor(curr_alpha)
+            else torch.tensor(curr_alpha).float(),
+            expected_alpha_vals[i],
+            msg=f"expected {expected_alpha_vals[i]}, got {curr_alpha}",
+        )
+        curr_beta = rb.sampler.beta
+        torch.testing.assert_close(
+            curr_beta
+            if torch.is_tensor(curr_beta)
+            else torch.tensor(curr_beta).float(),
+            expected_beta_vals[i],
+            msg=f"expected {expected_beta_vals[i]}, got {curr_beta}",
+        )
         rb.sample(20)
         scheduler.step()
 
34 changes: 27 additions & 7 deletions torchrl/data/replay_buffers/scheduler.py
@@ -1,12 +1,22 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations
 
+from abc import ABC, abstractmethod
+
 from typing import Any, Callable, Dict
 
 import numpy as np
 
+import torch
+
 from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer
 from torchrl.data.replay_buffers.samplers import Sampler
 
 
-class ParameterScheduler:
+class ParameterScheduler(ABC):
     """Scheduler to adjust the value of a given parameter of a replay buffer's sampler.
     Scheduler can for example be used to alter the alpha and beta values in the PrioritizedSampler.
@@ -34,13 +44,19 @@ def __init__(
             )
         self.sampler = obj.sampler if isinstance(obj, ReplayBuffer) else obj
         self.param_name = param_name
-        self._min_val = min_value
-        self._max_val = max_value
+        self._min_val = min_value or float("-inf")
+        self._max_val = max_value or float("inf")
         if not hasattr(self.sampler, self.param_name):
             raise ValueError(
                 f"Provided class {type(obj).__name__} does not have an attribute {param_name}"
             )
-        self.initial_val = getattr(self.sampler, self.param_name)
+        initial_val = getattr(self.sampler, self.param_name)
+        if isinstance(initial_val, torch.Tensor):
+            initial_val = initial_val.clone()
+            self.backend = torch
+        else:
+            self.backend = np
+        self.initial_val = initial_val
         self._step_cnt = 0
 
     def state_dict(self):
@@ -67,14 +83,15 @@ def step(self):
         # Apply the step function
         new_value = self._step()
         # clip value to specified range
-        new_value_clipped = np.clip(new_value, a_min=self._min_val, a_max=self._max_val)
+        new_value_clipped = self.backend.clip(new_value, self._min_val, self._max_val)
         # Set the new value of the parameter dynamically
         setattr(self.sampler, self.param_name, new_value_clipped)
 
+    @abstractmethod
     def _step(self):
         ...
 
 
 class LambdaScheduler(ParameterScheduler):
     """Sets a parameter to its initial value times a given function.
@@ -144,6 +161,9 @@ def __init__(
         num_steps: int,
     ):
         super().__init__(obj, param_name)
+        if isinstance(self.initial_val, torch.Tensor):
+            # cast to same type as initial value
+            final_value = torch.tensor(final_value).to(self.initial_val)
         self.final_val = final_value
         self.num_steps = num_steps
         self._delta = (self.final_val - self.initial_val) / self.num_steps
@@ -208,9 +228,9 @@ def __init__(
         self.n_steps = n_steps
         self.mode = mode
         if mode == "additive":
-            operator = np.add
+            operator = self.backend.add
         elif mode == "multiplicative":
-            operator = np.multiply
+            operator = self.backend.multiply
         else:
             raise ValueError(
                 f"Invalid mode: {mode}. Choose 'multiplicative' or 'additive'."
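Usage note (not part of the diff): with this change, ParameterScheduler is an abstract base class whose step() clips whatever _step() returns (via self.backend, torch for tensor-valued parameters and numpy otherwise) and writes it back to the sampler, so a custom schedule only needs to implement _step(). A hypothetical subclass, for illustration only; the base constructor arguments (obj, param_name, min_value, max_value) are assumed from the hunks above.

    import torch

    from torchrl.data.replay_buffers.scheduler import ParameterScheduler


    class ExponentialScheduler(ParameterScheduler):
        """Decays the tracked parameter by a constant factor on every step() call."""

        def __init__(self, obj, param_name, decay=0.99, min_value=None, max_value=None):
            # min_value/max_value feed the clipping done in ParameterScheduler.step()
            super().__init__(obj, param_name, min_value=min_value, max_value=max_value)
            self.decay = decay

        def _step(self):
            # works for both float and tensor parameters, since step() clips the
            # returned value with the torch or numpy backend chosen at init time
            current = getattr(self.sampler, self.param_name)
            return current * self.decay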
