From 107a6c5f9f5bc78d2d385459aeb539258ee57daf Mon Sep 17 00:00:00 2001 From: Onno Eberhard Date: Thu, 1 Dec 2022 20:23:55 +0100 Subject: [PATCH] Initial commit --- README.md | 94 ++++++++++++++++++++++++++++++- examples/example.py | 48 ++++++++++++++++ pink/__init__.py | 11 ++++ pink/cnrl.py | 42 ++++++++++++++ pink/colorednoise.py | 129 +++++++++++++++++++++++++++++++++++++++++++ pink/sb3.py | 119 +++++++++++++++++++++++++++++++++++++++ pink/tonic.py | 35 ++++++++++++ pyproject.toml | 21 +++++++ 8 files changed, 498 insertions(+), 1 deletion(-) create mode 100644 examples/example.py create mode 100644 pink/__init__.py create mode 100644 pink/cnrl.py create mode 100644 pink/colorednoise.py create mode 100644 pink/sb3.py create mode 100644 pink/tonic.py create mode 100644 pyproject.toml diff --git a/README.md b/README.md index c91b42e..c46663e 100644 --- a/README.md +++ b/README.md @@ -1 +1,93 @@ -# Pink Noise Is All You Need \ No newline at end of file +# Colored Action Noise for Deep RL + +This repository contains easy-to-use implementations of pink noise and general colored noise for use as action noise in deep reinforcement learning. Included are the following classes: +- `ColoredNoiseProcess` and `PinkNoiseProcess` for general use, based on the [colorednoise](https://github.com/felixpatzelt/colorednoise) library +- `ColoredActionNoise` and `PinkActionNoise` to be used with deterministic policy algorithms like DDPG and TD3 in Stable Baselines3, both are subclasses of `stable_baselines3.common.noise.ActionNoise` +- `ColoredNoiseDist`, `PinkNoiseDist` to be used with stochastic policy algorithms like SAC in Stable Baselines3 +- `MPO_CN` for using colored noise (incl. pink noise) with MPO using the Tonic RL library. + +For more information, please see our paper: [Pink Noise Is All You Need: Colored Noise Exploration in Deep Reinforcement Learning](https://bit.ly/pink-noise-rl). + +## Installation +You can install the library via pip: +``` +pip install pink-noise-rl +``` +Note: In Python, the import statement is simply `import pink`. + +## Usage +We provide minimal examples for using pink noise on SAC, TD3 and MPO below. An example comparing pink noise with the default action noise of SAC is included in the `examples` directory. + +### Stable Baselines3: SAC, TD3 +```python +import gym +from stable_baselines3 import SAC, TD3 + +# All classes mentioned above can be imported from `pink` +from pink import PinkNoiseDist, PinkActionNoise + +# Initialize environment +env = gym.make("MountainCarContinuous-v0") +action_dim = env.action_space.shape[-1] +seq_len = env._max_episode_steps +``` + +#### SAC +```python +# Initialize agent +model = SAC("MlpPolicy", env) + +# Set action noise +model.actor.action_dist = PinkNoiseDist(action_dim, seq_len) + +# Train agent +model.learn(total_timesteps=10_000) +``` + +#### TD3 +```python +# Initialize agent +model = TD3("MlpPolicy", env) + +# Set action noise +noise_scale = 0.3*np.ones(action_dim) +model.action_noise = PinkActionNoise(noise_scale, seq_len) + +# Train agent +model.learn(total_timesteps=10_000) +``` + +### Tonic: MPO +```python +import gym +from tonic import Trainer +from pink import MPO_CN + +# Initialize environment +env = gym.make("MountainCarContinuous-v0") +seq_len = env._max_episode_steps + +# Initialize agent with pink noise +beta = 1 +model = MPO_CN() +model.initialize(beta, seq_len, env.observation_space, env.action_space) + +# Train agent +trainer = tonic.Trainer(steps=10_000) +trainer.initialize(model, env) +trainer.run() +``` + + +## Citing +If you use this code in your research, please cite our paper: +```bibtex +@misc{eberhard-2022-pink, + title = {Pink {{Noise Is All You Need}}: {{Colored Noise Exploration}} in {{Deep Reinforcement Learning}}}, + author = {Eberhard, Onno and Hollenstein, Jakob and Pinneri, Cristina and Martius, Georg}, + date = {2022}, + howpublished = {NeurIPS Deep RL Workshop 2022} +} +``` + +If there are any problems, or you have a question, don't hesitate to open an issue here on GitHub. diff --git a/examples/example.py b/examples/example.py new file mode 100644 index 0000000..342de6a --- /dev/null +++ b/examples/example.py @@ -0,0 +1,48 @@ +"""Comparing pink action noise with the default noise on SAC.""" + +import gym +from stable_baselines3 import SAC + +from pink import PinkNoiseDist + +# Initialize environment +env = gym.make("MountainCarContinuous-v0") +action_dim = env.action_space.shape[-1] +seq_len = env._max_episode_steps + +# Initialize agents +model_default = SAC("MlpPolicy", env) +model_pink = SAC("MlpPolicy", env) + +# Set action noise +model_pink.actor.action_dist = PinkNoiseDist(action_dim, seq_len) + +# Train agents +model_default.learn(total_timesteps=10_000) +model_pink.learn(total_timesteps=10_000) + +# Evaluate learned policies +N = 100 +for name, model in zip(["Default noise\n-------------", "Pink noise\n----------"], [model_default, model_pink]): + solved = 0 + for i in range(N): + obs = env.reset() + done = False + while not done: + obs, r, done, _ = env.step(model.predict(obs, deterministic=True)[0]) + if r > 0: + solved += 1 + break + + print(name) + print(f"Solved: {solved/N * 100:.0f}%\n") + + +# - Output of this program - +# Default noise +# ------------- +# Solved: 0% +# +# Pink noise +# ---------- +# Solved: 100% diff --git a/pink/__init__.py b/pink/__init__.py new file mode 100644 index 0000000..888b2ea --- /dev/null +++ b/pink/__init__.py @@ -0,0 +1,11 @@ +from .cnrl import * + +try: + from .sb3 import * +except: + pass + +try: + from .tonic import * +except: + pass diff --git a/pink/cnrl.py b/pink/cnrl.py new file mode 100644 index 0000000..30dbac9 --- /dev/null +++ b/pink/cnrl.py @@ -0,0 +1,42 @@ +from . import colorednoise as cn + +class ColoredNoiseProcess(): + def __init__(self, beta, scale=1, chunksize=None, largest_wavelength=None, rng=None): + """Colored noise implemented as a process that allows subsequent samples. + Implemented as a buffer; every "chunksize[-1]" items, a cut to a new time series starts. + """ + self.beta = beta + if largest_wavelength is None: + self.minimum_frequency = 0 + else: + self.minimum_frequency = 1 / largest_wavelength + self.scale = scale + self.rng = rng + + # The last component of chunksize is the time index + try: + self.chunksize = list(chunksize) + except TypeError: + self.chunksize = [chunksize] + self.time_steps = self.chunksize[-1] + + # Set first time-step such that buffer will be initialized + self.idx = self.time_steps + + def sample(self): + self.idx += 1 # Next time step + + # Refill buffer if depleted + if self.idx >= self.time_steps: + self.buffer = cn.powerlaw_psd_gaussian( + exponent=self.beta, size=self.chunksize, fmin=self.minimum_frequency, rng=self.rng) + self.idx = 0 + + return self.scale * self.buffer[..., self.idx] + +class PinkNoiseProcess(ColoredNoiseProcess): + def __init__(self, scale=1, chunksize=None, largest_wavelength=None, rng=None): + """Colored noise implemented as a process that allows subsequent samples. + Implemented as a buffer; every "chunksize[-1]" items, a cut to a new time series starts. + """ + super().__init__(1, scale, chunksize, largest_wavelength, rng) diff --git a/pink/colorednoise.py b/pink/colorednoise.py new file mode 100644 index 0000000..f2edbc4 --- /dev/null +++ b/pink/colorednoise.py @@ -0,0 +1,129 @@ +"""Colored noise generation script +Modified from colorednoise package: https://github.com/felixpatzelt/colorednoise +""" + +import numpy as np +from numpy.fft import irfft, rfftfreq + + +def powerlaw_psd_gaussian(exponent, size, fmin=0, rng=None): + """Gaussian (1/f)**beta noise. + + Based on the algorithm in: + Timmer, J. and Koenig, M.: + On generating power law noise. + Astron. Astrophys. 300, 707-710 (1995) + + Normalised to unit variance + + Parameters: + ----------- + + exponent : float + The power-spectrum of the generated noise is proportional to + + S(f) = (1 / f)**beta + flicker / pink noise: exponent beta = 1 + brown noise: exponent beta = 2 + + Furthermore, the autocorrelation decays proportional to lag**-gamma + with gamma = 1 - beta for 0 < beta < 1. + There may be finite-size issues for beta close to one. + + shape : int or iterable + The output has the given shape, and the desired power spectrum in + the last coordinate. That is, the last dimension is taken as time, + and all other components are independent. + + fmin : float, optional + Low-frequency cutoff. + Default: 0 corresponds to original paper. + + The power-spectrum below fmin is flat. fmin is defined relative + to a unit sampling rate (see numpy's rfftfreq). For convenience, + the passed value is mapped to max(fmin, 1/samples) internally + since 1/samples is the lowest possible finite frequency in the + sample. The largest possible value is fmin = 0.5, the Nyquist + frequency. The output for this value is white noise. + + rng : np.random.Generator, optional + Random number generator (for reproducibility). If None (default), a new + random number generator is created by calling np.random.default_rng(). + + + Returns + ------- + out : array + The samples. + + + Examples: + --------- + + >>> # generate 1/f noise == pink noise == flicker noise + >>> import colorednoise as cn + >>> y = cn.powerlaw_psd_gaussian(1, 5) + """ + + # Make sure size is a list so we can iterate it and assign to it. + try: + size = list(size) + except TypeError: + size = [size] + + # The number of samples in each time series + samples = size[-1] + + # Calculate Frequencies (we asume a sample rate of one) + # Use fft functions for real output (-> hermitian spectrum) + f = rfftfreq(samples) + + # Validate / normalise fmin + if 0 <= fmin <= 0.5: + fmin = max(fmin, 1./samples) # Low frequency cutoff + else: + raise ValueError("fmin must be chosen between 0 and 0.5.") + + # Build scaling factors for all frequencies + s_scale = f + ix = np.sum(s_scale < fmin) # Index of the cutoff + if ix and ix < len(s_scale): + s_scale[:ix] = s_scale[ix] + s_scale = s_scale**(-exponent/2.) + + # Calculate theoretical output standard deviation from scaling + w = s_scale[1:].copy() + w[-1] *= (1 + (samples % 2)) / 2. # correct f = +-0.5 + sigma = 2 * np.sqrt(np.sum(w**2)) / samples + + # Adjust size to generate one Fourier component per frequency + size[-1] = len(f) + + # Add empty dimension(s) to broadcast s_scale along last + # dimension of generated random power + phase (below) + dims_to_add = len(size) - 1 + s_scale = s_scale[(None,) * dims_to_add + (Ellipsis,)] + + # Generate scaled random power + phase + if rng is None: + rng = np.random.default_rng() + sr = rng.normal(scale=s_scale, size=size) + si = rng.normal(scale=s_scale, size=size) + + # If the signal length is even, frequencies +/- 0.5 are equal + # so the coefficient must be real. + if not (samples % 2): + si[..., -1] = 0 + sr[..., -1] *= np.sqrt(2) # Fix magnitude + + # Regardless of signal length, the DC component must be real + si[..., 0] = 0 + sr[..., 0] *= np.sqrt(2) # Fix magnitude + + # Combine power + corrected phase to Fourier components + s = sr + 1J * si + + # Transform to real time series & scale to unit variance + y = irfft(s, n=samples, axis=-1) / sigma + + return y diff --git a/pink/sb3.py b/pink/sb3.py new file mode 100644 index 0000000..f2485a0 --- /dev/null +++ b/pink/sb3.py @@ -0,0 +1,119 @@ +"""Colored noise implementations for Stable Baselines3""" + +import numpy as np +import torch as th +from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution +from stable_baselines3.common.noise import ActionNoise + +from .cnrl import ColoredNoiseProcess + + +class ColoredActionNoise(ActionNoise): + def __init__(self, beta, sigma, seq_len, rng=None): + """Action noise from a colored noise process. + + Parameters + ---------- + beta : array_like + Exponents of colored noise power-law spectra. Should be a list of the same dimensionality as the action + space (one beta for each action dimension). + sigma : array_like + Noise scales of colored noise signals. Should be a list of the same dimensionality as the action + space (one scale for each action dimension). + seq_len : int + Length of sampled colored noise signals. If sampled for longer than `seq_len` steps, a new + colored noise signal of the same length is sampled. Should usually be set to the episode length + (horizon) of the RL task. + rng : np.random.Generator, optional, by default None + Random number generator (for reproducibility). If None, a new random number generator is created by calling + `np.random.default_rng()`. + """ + super().__init__() + self._beta = beta + self._sigma = sigma + self._gen = [ColoredNoiseProcess(beta=b, scale=s, chunksize=seq_len, largest_wavelength=None, rng=rng) + for b, s in zip(beta, sigma)] + + def __call__(self) -> np.ndarray: + return np.asarray([g.sample() for g in self._gen]) + + def __repr__(self) -> str: + return f"ColoredActionNoise(beta={self._beta}, sigma={self._sigma})" + + +class PinkActionNoise(ColoredActionNoise): + def __init__(self, sigma, seq_len, rng=None): + """Action noise from a pink noise process. + + Parameters + ---------- + sigma : array_like + Noise scales of pink noise signals. Should be a list of the same dimensionality as the action + space (one scale for each action dimension). + seq_len : int + Length of sampled pink noise signals. If sampled for longer than `seq_len` steps, a new + pink noise signal of the same length is sampled. Should usually be set to the episode length + (horizon) of the RL task. + rng : np.random.Generator, optional, by default None + Random number generator (for reproducibility). If None, a new random number generator is created by calling + `np.random.default_rng()`. + """ + super().__init__(np.ones_like(sigma), sigma, seq_len, rng) + + +class ColoredNoiseDist(SquashedDiagGaussianDistribution): + def __init__(self, beta, seq_len, rng=None, epsilon=1e-6): + """ + Gaussian colored noise distribution for using colored action noise with stochastic policies. + + The colored noise is only used for sampling actions. In all other respects, this class acts like its parent + class (`SquashedDiagGaussianDistribution`). + + Parameters + ---------- + beta : array_like + Exponents of colored noise power-law spectra. Should be a list of the same dimensionality as the action + space (one beta for each action dimension). + seq_len : int + Length of sampled colored noise signals. If sampled for longer than `seq_len` steps, a new + colored noise signal of the same length is sampled. Should usually be set to the episode length + (horizon) of the RL task. + rng : np.random.Generator, optional, by default None + Random number generator (for reproducibility). If None, a new random number generator is created by calling + `np.random.default_rng()`. + epsilon : float, optional, by default 1e-6 + A small value to avoid NaN due to numerical imprecision. + """ + super().__init__(len(beta), epsilon) + self.cn_processes = [ColoredNoiseProcess(beta=b, chunksize=seq_len, largest_wavelength=None, rng=rng) + for b in beta] + + def sample(self) -> th.Tensor: + cn_sample = th.tensor([cnp.sample() for cnp in self.cn_processes]).float() + self.gaussian_actions = self.distribution.mean + self.distribution.stddev*cn_sample + return th.tanh(self.gaussian_actions) + + +class PinkNoiseDist(ColoredNoiseDist): + def __init__(self, action_dim, seq_len, rng=None, epsilon=1e-6): + """ + Gaussian pink noise distribution for using pink action noise with stochastic policies. + + The pink noise is only used for sampling actions. In all other respects, this class acts like its parent + class (`SquashedDiagGaussianDistribution`). + + Parameters + ---------- + action_dim : int + Dimension of the action space. + seq_len : int + Length of sampled colored noise signals. If sampled for longer than `seq_len` steps, a new + colored noise signal of the same length is sampled. Should usually be set to the episode length + (horizon) of the RL task. + rng : np.random.Generator, optional, by default None + Random number generator (for reproducibility). If None, a new random number generator is created by calling + `np.random.default_rng()`. + epsilon : float, optional, by default 1e-6 + A small value to avoid NaN due to numerical imprecision. + """ + super().__init__(np.ones(action_dim), seq_len, rng, epsilon) diff --git a/pink/tonic.py b/pink/tonic.py new file mode 100644 index 0000000..fc78075 --- /dev/null +++ b/pink/tonic.py @@ -0,0 +1,35 @@ +"""Colored noise implementations for Tonic RL library""" + +import numpy as np +import torch as th +from tonic.torch.agents import MPO + +from .cnrl import ColoredNoiseProcess + + +class MPO_CN(MPO): + """MPO with colored noise exploration""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def initialize(self, beta, seq_len, observation_space, action_space, rng=None, seed=None): + """For documentation of beta, seq_len, rng see `pink.sb3.ColoredNoiseDist`.""" + super().initialize(observation_space, action_space, seed) + self.seq_len = seq_len + self.rng = rng + self.action_space = action_space + self.set_beta(beta) + + def set_beta(self, beta): + if np.isscalar(beta): + beta = [beta] * self.action_space.shape[0] + self.cn_processes = [ + ColoredNoiseProcess(beta=b, chunksize=self.seq_len, largest_wavelength=None, rng=self.rng) for b in beta] + + def _step(self, observations): + observations = th.as_tensor(observations, dtype=th.float32) + cn_sample = th.tensor([[cnp.sample() for cnp in self.cn_processes]]).float() + with th.no_grad(): + loc = self.model.actor(observations).loc + scale = self.model.actor(observations).scale + return loc + scale*cn_sample diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..df0e7ec --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,21 @@ +[tool.poetry] +name = "pink-noise-rl" +version = "1.0.0" +description = "Pink noise for exploration in reinforcement learning" +authors = ["Onno Eberhard "] +license = "MIT" +readme = "README.md" +repository = "https://github.com/martius-lab/pink-noise-rl" +packages = [ + { include = "pink" } +] + + +[tool.poetry.dependencies] +python = "^3.8" +numpy = "*" + + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api"