Commit
[Feature] Add scheduler for alpha/beta parameters of PrioritizedSampler
LTluttmann committed Sep 24, 2024
1 parent 33e86c5 commit e2337ef
Showing 3 changed files with 297 additions and 0 deletions.
45 changes: 45 additions & 0 deletions test/test_rb.py
@@ -98,6 +98,12 @@
    UnsqueezeTransform,
    VecNorm,
)
from torchrl.data.replay_buffers.scheduler import (
    LinearScheduler,
    StepScheduler,
    SchedulerList,
)


OLD_TORCH = parse(torch.__version__) < parse("2.0.0")
_has_tv = importlib.util.find_spec("torchvision") is not None
@@ -3026,6 +3032,45 @@ def test_prioritized_slice_sampler_episodes(device):
), "after priority update, only episode 1 and 3 are expected to be sampled"


def test_prioritized_parameter_scheduler():
    INIT_ALPHA = 0.7
    INIT_BETA = 0.6
    GAMMA = 0.1
    EVERY_N_STEPS = 10
    LINEAR_STEPS = 100
    TOTAL_STEPS = 200
    rb = TensorDictPrioritizedReplayBuffer(
        alpha=INIT_ALPHA,
        beta=INIT_BETA,
        storage=ListStorage(max_size=2000),
    )
    data = TensorDict({"data": torch.randn(1000, 5)}, batch_size=[1000])
    rb.extend(data)
    alpha_scheduler = LinearScheduler(
        rb, param_name="alpha", final_value=0.0, num_steps=LINEAR_STEPS
    )
    beta_scheduler = StepScheduler(
        rb,
        param_name="beta",
        gamma=GAMMA,
        n_steps=EVERY_N_STEPS,
        max_value=1.0,
        mode="additive",
    )
    scheduler = SchedulerList(scheduler=(alpha_scheduler, beta_scheduler))
    # alpha is annealed linearly from 0.7 to 0.0 over the first 100 steps, then held at 0.0
    expected_alpha_vals = np.linspace(INIT_ALPHA, 0.0, num=LINEAR_STEPS + 1)
    expected_alpha_vals = np.pad(
        expected_alpha_vals, (0, TOTAL_STEPS - LINEAR_STEPS), constant_values=0.0
    )
    # beta is increased by 0.1 every 10 steps and capped at 1.0
    expected_beta_vals = [INIT_BETA]
    for _ in range(TOTAL_STEPS // EVERY_N_STEPS - 1):
        expected_beta_vals.append(expected_beta_vals[-1] + GAMMA)
    expected_beta_vals = np.array(expected_beta_vals).repeat(EVERY_N_STEPS).clip(None, 1.0)
    for i in range(TOTAL_STEPS):
        assert np.isclose(
            rb.sampler.alpha, expected_alpha_vals[i]
        ), f"expected {expected_alpha_vals[i]}, got {rb.sampler.alpha}"
        assert np.isclose(
            rb.sampler.beta, expected_beta_vals[i]
        ), f"expected {expected_beta_vals[i]}, got {rb.sampler.beta}"
        rb.sample(20)
        scheduler.step()


class TestEnsemble:
    def _make_data(self, data_type):
        if data_type is torch.Tensor:
17 changes: 17 additions & 0 deletions torchrl/data/replay_buffers/samplers.py
@@ -394,6 +394,23 @@ def __repr__(self):
    @property
    def max_size(self):
        return self._max_capacity

    @property
    def alpha(self):
        return self._alpha

    @alpha.setter
    def alpha(self, value):
        self._alpha = value

    @property
    def beta(self):
        return self._beta

    @beta.setter
    def beta(self, value):
        self._beta = value

    def __getstate__(self):
        if get_spawning_popen() is not None:
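The new `alpha` and `beta` properties expose the sampler's private `_alpha` and `_beta` attributes under their public names; the schedulers introduced below update them via `setattr(sampler, param_name, value)`. A minimal sketch, assuming the usual `PrioritizedSampler` constructor arguments:

from torchrl.data.replay_buffers.samplers import PrioritizedSampler

sampler = PrioritizedSampler(max_capacity=1000, alpha=0.7, beta=0.6)
sampler.beta = 0.8  # routed through the new setter, updates sampler._beta
print(sampler.alpha, sampler.beta)  # 0.7 0.8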
235 changes: 235 additions & 0 deletions torchrl/data/replay_buffers/scheduler.py
@@ -0,0 +1,235 @@
from __future__ import annotations

from typing import Any, Callable, Dict

import numpy as np

from .replay_buffers import ReplayBuffer
from .samplers import Sampler


class ParameterScheduler:
    """Scheduler to adjust the value of a given parameter of a replay buffer's sampler,
    e.g. the alpha and beta values in the PrioritizedSampler.

    Args:
        obj (ReplayBuffer or Sampler): the replay buffer whose sampler to adjust (or the sampler itself).
        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the beta parameter.
        min_value (int | float, optional): a lower bound for the parameter to be adjusted.
            Defaults to None.
        max_value (int | float, optional): an upper bound for the parameter to be adjusted.
            Defaults to None.
    """

    def __init__(
        self,
        obj: ReplayBuffer | Sampler,
        param_name: str,
        min_value: int | float | None = None,
        max_value: int | float | None = None,
    ):
        if not isinstance(obj, (ReplayBuffer, Sampler)):
            raise TypeError(
                f"ParameterScheduler only supports Sampler classes. "
                f"Pass either a ReplayBuffer or a Sampler object; got {type(obj)}."
            )
        self.sampler = obj.sampler if isinstance(obj, ReplayBuffer) else obj
        self.param_name = param_name
        self._min_val = min_value
        self._max_val = max_value
        if not hasattr(self.sampler, self.param_name):
            raise ValueError(
                f"Provided class {type(obj).__name__} does not have an attribute {param_name}."
            )
        self.initial_val = getattr(self.sampler, self.param_name)
        self._step_cnt = 0

    def state_dict(self):
        """Return the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the sampler.
        """
        return {key: value for key, value in self.__dict__.items() if key != "sampler"}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Load the scheduler's state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def step(self):
        self._step_cnt += 1
        # apply the subclass-specific step function to compute the new parameter value
        new_value = self._step()
        # clip the value to the specified range, if any
        new_value_clipped = np.clip(new_value, a_min=self._min_val, a_max=self._max_val)
        # set the new value of the parameter dynamically on the sampler
        setattr(self.sampler, self.param_name, new_value_clipped)

    def _step(self):
        raise NotImplementedError


class LambdaScheduler(ParameterScheduler):
    """Sets a parameter to its initial value times a given function, similar to :class:`torch.optim.lr_scheduler.LambdaLR`.

    Args:
        obj (ReplayBuffer | Sampler): the replay buffer whose sampler to adjust (or the sampler itself).
        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
            beta parameter.
        lambda_fn (Callable[[int], float]): a function which computes a multiplicative factor given
            the integer step count.
        min_value (int | float, optional): a lower bound for the parameter to be adjusted.
            Defaults to None.
        max_value (int | float, optional): an upper bound for the parameter to be adjusted.
            Defaults to None.
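
    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming sampler uses initial beta = 0.6,
        >>> # beta = 0.6 * (0.9 ** step) at every step
        >>> scheduler = LambdaScheduler(sampler, param_name='beta', lambda_fn=lambda step: 0.9**step)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()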
"""

def __init__(
self,
obj: ReplayBuffer | Sampler,
param_name: str,
lambda_fn: Callable[[int], float],
min_value: int | float = None,
max_value: int | float = None
):
super().__init__(obj, param_name, min_value, max_value)
self.lambda_fn = lambda_fn

def _step(self):
return self.initial_val * self.lambda_fn(self._step_cnt)



class LinearScheduler(ParameterScheduler):
    """A linear scheduler for gradually altering a parameter in an object over a given number of steps.

    This scheduler linearly interpolates between the initial value of the parameter and a final target value.

    Args:
        obj (ReplayBuffer | Sampler): the replay buffer whose sampler to adjust (or the sampler itself).
        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
            beta parameter.
        final_value (int | float): the final value that the parameter will reach after the
            specified number of steps.
        num_steps (int): the total number of steps over which the parameter
            will be linearly altered.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming sampler uses initial beta = 0.6
        >>> # beta = 0.7 if step == 1
        >>> # beta = 0.8 if step == 2
        >>> # beta = 0.9 if step == 3
        >>> # beta = 1.0 if step >= 4
        >>> scheduler = LinearScheduler(sampler, param_name='beta', final_value=1.0, num_steps=4)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """
    def __init__(
        self,
        obj: ReplayBuffer | Sampler,
        param_name: str,
        final_value: int | float,
        num_steps: int,
    ):
        super().__init__(obj, param_name)
        self.final_val = final_value
        self.num_steps = num_steps
        self._delta = (self.final_val - self.initial_val) / self.num_steps

    def _step(self):
        if self._step_cnt < self.num_steps:
            return self.initial_val + (self._delta * self._step_cnt)
        else:
            return self.final_val



class StepScheduler(ParameterScheduler):
    """A step scheduler that alters a parameter after every n steps using either multiplicative or additive changes.

    The scheduler can apply:
    1. Multiplicative changes: `new_val = curr_val * gamma`
    2. Additive changes: `new_val = curr_val + gamma`

    Args:
        obj (ReplayBuffer | Sampler): the replay buffer whose sampler to adjust (or the sampler itself).
        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
            beta parameter.
        gamma (int | float, optional): the value by which to adjust the parameter,
            either multiplicatively or additively. Defaults to 0.9.
        n_steps (int, optional): the number of steps after which the parameter should be altered.
            Defaults to 1.
        mode (str, optional): the mode of scheduling. Can be either 'multiplicative' or 'additive'.
            Defaults to 'multiplicative'.
        min_value (int | float, optional): a lower bound for the parameter to be adjusted.
            Defaults to None.
        max_value (int | float, optional): an upper bound for the parameter to be adjusted.
            Defaults to None.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming sampler uses initial beta = 0.6
        >>> # beta = 0.6 if step < 10
        >>> # beta = 0.7 if step == 10
        >>> # beta = 0.8 if step == 20
        >>> # beta = 0.9 if step == 30
        >>> # beta = 1.0 if step >= 40
        >>> scheduler = StepScheduler(sampler, param_name='beta', gamma=0.1, n_steps=10, mode='additive', max_value=1.0)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(
        self,
        obj: ReplayBuffer | Sampler,
        param_name: str,
        gamma: int | float = 0.9,
        n_steps: int = 1,
        mode: str = "multiplicative",
        min_value: int | float | None = None,
        max_value: int | float | None = None,
    ):
        super().__init__(obj, param_name, min_value, max_value)
        self.gamma = gamma
        self.n_steps = n_steps
        if mode == "additive":
            operator = np.add
        elif mode == "multiplicative":
            operator = np.multiply
        else:
            raise ValueError(f"Invalid mode: {mode}. Choose 'multiplicative' or 'additive'.")
        self.operator = operator

    def _step(self):
        """Applies the scheduling logic to alter the parameter value every `n_steps`."""
        current_val = getattr(self.sampler, self.param_name)
        # only adjust the parameter when the step count is a multiple of n_steps
        if self._step_cnt % self.n_steps == 0:
            return self.operator(current_val, self.gamma)
        else:
            return current_val


class SchedulerList:
    """Container that groups several :class:`ParameterScheduler` instances and steps them together."""

    def __init__(self, scheduler: list[ParameterScheduler]) -> None:
        if isinstance(scheduler, ParameterScheduler):
            scheduler = [scheduler]
        self.scheduler = scheduler

    def append(self, scheduler: ParameterScheduler):
        self.scheduler.append(scheduler)

    def step(self):
        for scheduler in self.scheduler:
            scheduler.step()

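Taken together, the new test above already demonstrates the intended end-to-end usage: one scheduler per parameter, wrapped in a SchedulerList that is stepped once per training iteration. A condensed sketch along those lines (class names and arguments as added in this commit; the buffer contents are purely illustrative):

import torch
from tensordict import TensorDict

from torchrl.data import ListStorage, TensorDictPrioritizedReplayBuffer
from torchrl.data.replay_buffers.scheduler import LinearScheduler, SchedulerList, StepScheduler

rb = TensorDictPrioritizedReplayBuffer(alpha=0.7, beta=0.6, storage=ListStorage(max_size=2000))
rb.extend(TensorDict({"data": torch.randn(1000, 5)}, batch_size=[1000]))

# anneal alpha linearly towards 0 over 100 steps; raise beta by 0.1 every 10 steps, capped at 1.0
alpha_scheduler = LinearScheduler(rb, param_name="alpha", final_value=0.0, num_steps=100)
beta_scheduler = StepScheduler(
    rb, param_name="beta", gamma=0.1, n_steps=10, mode="additive", max_value=1.0
)
scheduler = SchedulerList(scheduler=[alpha_scheduler, beta_scheduler])

for _ in range(200):
    rb.sample(20)     # sampling uses the current alpha/beta
    scheduler.step()  # advance both schedulers once per iteration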