-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
81ee832
commit 107a6c5
Showing
8 changed files
with
498 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,93 @@ | ||
# Pink Noise Is All You Need | ||
# Colored Action Noise for Deep RL | ||
|
||
This repository contains easy-to-use implementations of pink noise and general colored noise for use as action noise in deep reinforcement learning. Included are the following classes: | ||
- `ColoredNoiseProcess` and `PinkNoiseProcess` for general use, based on the [colorednoise](https://github.com/felixpatzelt/colorednoise) library | ||
- `ColoredActionNoise` and `PinkActionNoise` to be used with deterministic policy algorithms like DDPG and TD3 in Stable Baselines3, both are subclasses of `stable_baselines3.common.noise.ActionNoise` | ||
- `ColoredNoiseDist`, `PinkNoiseDist` to be used with stochastic policy algorithms like SAC in Stable Baselines3 | ||
- `MPO_CN` for using colored noise (incl. pink noise) with MPO using the Tonic RL library. | ||
|
||
For more information, please see our paper: [Pink Noise Is All You Need: Colored Noise Exploration in Deep Reinforcement Learning](https://bit.ly/pink-noise-rl). | ||
|
||
## Installation | ||
You can install the library via pip: | ||
``` | ||
pip install pink-noise-rl | ||
``` | ||
Note: In Python, the import statement is simply `import pink`. | ||
|
||
## Usage | ||
We provide minimal examples for using pink noise on SAC, TD3 and MPO below. An example comparing pink noise with the default action noise of SAC is included in the `examples` directory. | ||
|
||
### Stable Baselines3: SAC, TD3 | ||
```python | ||
import gym | ||
from stable_baselines3 import SAC, TD3 | ||
|
||
# All classes mentioned above can be imported from `pink` | ||
from pink import PinkNoiseDist, PinkActionNoise | ||
|
||
# Initialize environment | ||
env = gym.make("MountainCarContinuous-v0") | ||
action_dim = env.action_space.shape[-1] | ||
seq_len = env._max_episode_steps | ||
``` | ||
|
||
#### SAC | ||
```python | ||
# Initialize agent | ||
model = SAC("MlpPolicy", env) | ||
|
||
# Set action noise | ||
model.actor.action_dist = PinkNoiseDist(action_dim, seq_len) | ||
|
||
# Train agent | ||
model.learn(total_timesteps=10_000) | ||
``` | ||
|
||
#### TD3 | ||
```python | ||
# Initialize agent | ||
model = TD3("MlpPolicy", env) | ||
|
||
# Set action noise | ||
noise_scale = 0.3*np.ones(action_dim) | ||
model.action_noise = PinkActionNoise(noise_scale, seq_len) | ||
|
||
# Train agent | ||
model.learn(total_timesteps=10_000) | ||
``` | ||
|
||
### Tonic: MPO | ||
```python | ||
import gym | ||
from tonic import Trainer | ||
from pink import MPO_CN | ||
|
||
# Initialize environment | ||
env = gym.make("MountainCarContinuous-v0") | ||
seq_len = env._max_episode_steps | ||
|
||
# Initialize agent with pink noise | ||
beta = 1 | ||
model = MPO_CN() | ||
model.initialize(beta, seq_len, env.observation_space, env.action_space) | ||
|
||
# Train agent | ||
trainer = tonic.Trainer(steps=10_000) | ||
trainer.initialize(model, env) | ||
trainer.run() | ||
``` | ||
|
||
|
||
## Citing | ||
If you use this code in your research, please cite our paper: | ||
```bibtex | ||
@misc{eberhard-2022-pink, | ||
title = {Pink {{Noise Is All You Need}}: {{Colored Noise Exploration}} in {{Deep Reinforcement Learning}}}, | ||
author = {Eberhard, Onno and Hollenstein, Jakob and Pinneri, Cristina and Martius, Georg}, | ||
date = {2022}, | ||
howpublished = {NeurIPS Deep RL Workshop 2022} | ||
} | ||
``` | ||
|
||
If there are any problems, or you have a question, don't hesitate to open an issue here on GitHub. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
"""Comparing pink action noise with the default noise on SAC.""" | ||
|
||
import gym | ||
from stable_baselines3 import SAC | ||
|
||
from pink import PinkNoiseDist | ||
|
||
# Initialize environment | ||
env = gym.make("MountainCarContinuous-v0") | ||
action_dim = env.action_space.shape[-1] | ||
seq_len = env._max_episode_steps | ||
|
||
# Initialize agents | ||
model_default = SAC("MlpPolicy", env) | ||
model_pink = SAC("MlpPolicy", env) | ||
|
||
# Set action noise | ||
model_pink.actor.action_dist = PinkNoiseDist(action_dim, seq_len) | ||
|
||
# Train agents | ||
model_default.learn(total_timesteps=10_000) | ||
model_pink.learn(total_timesteps=10_000) | ||
|
||
# Evaluate learned policies | ||
N = 100 | ||
for name, model in zip(["Default noise\n-------------", "Pink noise\n----------"], [model_default, model_pink]): | ||
solved = 0 | ||
for i in range(N): | ||
obs = env.reset() | ||
done = False | ||
while not done: | ||
obs, r, done, _ = env.step(model.predict(obs, deterministic=True)[0]) | ||
if r > 0: | ||
solved += 1 | ||
break | ||
|
||
print(name) | ||
print(f"Solved: {solved/N * 100:.0f}%\n") | ||
|
||
|
||
# - Output of this program - | ||
# Default noise | ||
# ------------- | ||
# Solved: 0% | ||
# | ||
# Pink noise | ||
# ---------- | ||
# Solved: 100% |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from .cnrl import * | ||
|
||
try: | ||
from .sb3 import * | ||
except: | ||
pass | ||
|
||
try: | ||
from .tonic import * | ||
except: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from . import colorednoise as cn | ||
|
||
class ColoredNoiseProcess(): | ||
def __init__(self, beta, scale=1, chunksize=None, largest_wavelength=None, rng=None): | ||
"""Colored noise implemented as a process that allows subsequent samples. | ||
Implemented as a buffer; every "chunksize[-1]" items, a cut to a new time series starts. | ||
""" | ||
self.beta = beta | ||
if largest_wavelength is None: | ||
self.minimum_frequency = 0 | ||
else: | ||
self.minimum_frequency = 1 / largest_wavelength | ||
self.scale = scale | ||
self.rng = rng | ||
|
||
# The last component of chunksize is the time index | ||
try: | ||
self.chunksize = list(chunksize) | ||
except TypeError: | ||
self.chunksize = [chunksize] | ||
self.time_steps = self.chunksize[-1] | ||
|
||
# Set first time-step such that buffer will be initialized | ||
self.idx = self.time_steps | ||
|
||
def sample(self): | ||
self.idx += 1 # Next time step | ||
|
||
# Refill buffer if depleted | ||
if self.idx >= self.time_steps: | ||
self.buffer = cn.powerlaw_psd_gaussian( | ||
exponent=self.beta, size=self.chunksize, fmin=self.minimum_frequency, rng=self.rng) | ||
self.idx = 0 | ||
|
||
return self.scale * self.buffer[..., self.idx] | ||
|
||
class PinkNoiseProcess(ColoredNoiseProcess): | ||
def __init__(self, scale=1, chunksize=None, largest_wavelength=None, rng=None): | ||
"""Colored noise implemented as a process that allows subsequent samples. | ||
Implemented as a buffer; every "chunksize[-1]" items, a cut to a new time series starts. | ||
""" | ||
super().__init__(1, scale, chunksize, largest_wavelength, rng) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
"""Colored noise generation script | ||
Modified from colorednoise package: https://github.com/felixpatzelt/colorednoise | ||
""" | ||
|
||
import numpy as np | ||
from numpy.fft import irfft, rfftfreq | ||
|
||
|
||
def powerlaw_psd_gaussian(exponent, size, fmin=0, rng=None): | ||
"""Gaussian (1/f)**beta noise. | ||
Based on the algorithm in: | ||
Timmer, J. and Koenig, M.: | ||
On generating power law noise. | ||
Astron. Astrophys. 300, 707-710 (1995) | ||
Normalised to unit variance | ||
Parameters: | ||
----------- | ||
exponent : float | ||
The power-spectrum of the generated noise is proportional to | ||
S(f) = (1 / f)**beta | ||
flicker / pink noise: exponent beta = 1 | ||
brown noise: exponent beta = 2 | ||
Furthermore, the autocorrelation decays proportional to lag**-gamma | ||
with gamma = 1 - beta for 0 < beta < 1. | ||
There may be finite-size issues for beta close to one. | ||
shape : int or iterable | ||
The output has the given shape, and the desired power spectrum in | ||
the last coordinate. That is, the last dimension is taken as time, | ||
and all other components are independent. | ||
fmin : float, optional | ||
Low-frequency cutoff. | ||
Default: 0 corresponds to original paper. | ||
The power-spectrum below fmin is flat. fmin is defined relative | ||
to a unit sampling rate (see numpy's rfftfreq). For convenience, | ||
the passed value is mapped to max(fmin, 1/samples) internally | ||
since 1/samples is the lowest possible finite frequency in the | ||
sample. The largest possible value is fmin = 0.5, the Nyquist | ||
frequency. The output for this value is white noise. | ||
rng : np.random.Generator, optional | ||
Random number generator (for reproducibility). If None (default), a new | ||
random number generator is created by calling np.random.default_rng(). | ||
Returns | ||
------- | ||
out : array | ||
The samples. | ||
Examples: | ||
--------- | ||
>>> # generate 1/f noise == pink noise == flicker noise | ||
>>> import colorednoise as cn | ||
>>> y = cn.powerlaw_psd_gaussian(1, 5) | ||
""" | ||
|
||
# Make sure size is a list so we can iterate it and assign to it. | ||
try: | ||
size = list(size) | ||
except TypeError: | ||
size = [size] | ||
|
||
# The number of samples in each time series | ||
samples = size[-1] | ||
|
||
# Calculate Frequencies (we asume a sample rate of one) | ||
# Use fft functions for real output (-> hermitian spectrum) | ||
f = rfftfreq(samples) | ||
|
||
# Validate / normalise fmin | ||
if 0 <= fmin <= 0.5: | ||
fmin = max(fmin, 1./samples) # Low frequency cutoff | ||
else: | ||
raise ValueError("fmin must be chosen between 0 and 0.5.") | ||
|
||
# Build scaling factors for all frequencies | ||
s_scale = f | ||
ix = np.sum(s_scale < fmin) # Index of the cutoff | ||
if ix and ix < len(s_scale): | ||
s_scale[:ix] = s_scale[ix] | ||
s_scale = s_scale**(-exponent/2.) | ||
|
||
# Calculate theoretical output standard deviation from scaling | ||
w = s_scale[1:].copy() | ||
w[-1] *= (1 + (samples % 2)) / 2. # correct f = +-0.5 | ||
sigma = 2 * np.sqrt(np.sum(w**2)) / samples | ||
|
||
# Adjust size to generate one Fourier component per frequency | ||
size[-1] = len(f) | ||
|
||
# Add empty dimension(s) to broadcast s_scale along last | ||
# dimension of generated random power + phase (below) | ||
dims_to_add = len(size) - 1 | ||
s_scale = s_scale[(None,) * dims_to_add + (Ellipsis,)] | ||
|
||
# Generate scaled random power + phase | ||
if rng is None: | ||
rng = np.random.default_rng() | ||
sr = rng.normal(scale=s_scale, size=size) | ||
si = rng.normal(scale=s_scale, size=size) | ||
|
||
# If the signal length is even, frequencies +/- 0.5 are equal | ||
# so the coefficient must be real. | ||
if not (samples % 2): | ||
si[..., -1] = 0 | ||
sr[..., -1] *= np.sqrt(2) # Fix magnitude | ||
|
||
# Regardless of signal length, the DC component must be real | ||
si[..., 0] = 0 | ||
sr[..., 0] *= np.sqrt(2) # Fix magnitude | ||
|
||
# Combine power + corrected phase to Fourier components | ||
s = sr + 1J * si | ||
|
||
# Transform to real time series & scale to unit variance | ||
y = irfft(s, n=samples, axis=-1) / sigma | ||
|
||
return y |
Oops, something went wrong.