Fix train signature and update type hints (#24)
* Hotfix for train signature

* Fix deprecated type hints

* Fix mypy

* Update optax dep for python 3.8
araffin authored Jan 16, 2024
1 parent ba597ca commit 37ed771
Showing 12 changed files with 38 additions and 27 deletions.
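
A note on the recurring "Fix deprecated type hints" change in the files below: jax.random.KeyArray has been deprecated by JAX and is slated for removal, and a plain jax.Array is the recommended annotation for PRNG keys. A minimal, self-contained sketch of the new style (illustration only, not code from this repository):

import jax

def split_key(key: jax.Array) -> jax.Array:
    # PRNG keys are ordinary jax.Array values, so the deprecated
    # jax.random.KeyArray alias is no longer needed in annotations.
    key, subkey = jax.random.split(key, 2)
    _ = jax.random.normal(subkey, (2,))  # use subkey, carry the remaining key
    return key

key = split_key(jax.random.PRNGKey(0))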
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
@@ -49,6 +49,8 @@ jobs:
- name: Type check
run: |
make type
+# skip mypy, jax doesn't have its latest version for python 3.8
+if: "!(matrix.python-version == '3.8')"
- name: Test with pytest
run: |
make pytest
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -33,6 +33,10 @@ env = [
filterwarnings = [
# Tensorboard warnings
"ignore::DeprecationWarning:tensorboard",
+# Numpy warnings from jax
+"ignore:Conversion of an array with ndim:DeprecationWarning:jax",
+# TQDM experimental
+"ignore:rich is experimental",
]
markers = [
"expensive: marks tests as expensive (deselect with '-m \"not expensive\"')"
2 changes: 1 addition & 1 deletion sbx/dqn/policies.py
@@ -57,7 +57,7 @@ def __init__(
else:
self.n_units = 256

-def build(self, key: jax.random.KeyArray, lr_schedule: Schedule) -> jax.random.KeyArray:
+def build(self, key: jax.Array, lr_schedule: Schedule) -> jax.Array:
key, qf_key = jax.random.split(key, 2)

obs = jnp.array([self.observation_space.sample()])
4 changes: 2 additions & 2 deletions sbx/ppo/policies.py
@@ -111,7 +111,7 @@ def __init__(

self.key = self.noise_key = jax.random.PRNGKey(0)

-def build(self, key: jax.random.KeyArray, lr_schedule: Schedule, max_grad_norm: float) -> jax.random.KeyArray:
+def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) -> jax.Array:
key, actor_key, vf_key = jax.random.split(key, 3)
# Keep a key for the actor
key, self.key = jax.random.split(key, 2)
@@ -190,7 +190,7 @@ def _predict(self, observation: np.ndarray, deterministic: bool = False) -> np.n
self.reset_noise()
return BaseJaxPolicy.sample_action(self.actor_state, observation, self.noise_key)

-def predict_all(self, observation: np.ndarray, key: jax.random.KeyArray) -> np.ndarray:
+def predict_all(self, observation: np.ndarray, key: jax.Array) -> np.ndarray:
return self._predict_all(self.actor_state, self.vf_state, observation, key)

@staticmethod
2 changes: 1 addition & 1 deletion sbx/sac/policies.py
@@ -139,7 +139,7 @@ def __init__(

self.key = self.noise_key = jax.random.PRNGKey(0)

-def build(self, key: jax.random.KeyArray, lr_schedule: Schedule, qf_learning_rate: float) -> jax.random.KeyArray:
+def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) -> jax.Array:
key, actor_key, qf_key, dropout_key = jax.random.split(key, 4)
# Keep a key for the actor
key, self.key = jax.random.split(key, 2)
13 changes: 7 additions & 6 deletions sbx/sac/sac.py
@@ -179,25 +179,26 @@ def learn(
progress_bar=progress_bar,
)

-def train(self, batch_size, gradient_steps):
+def train(self, gradient_steps: int, batch_size: int) -> None:
+assert self.replay_buffer is not None
# Sample all at once for efficiency (so we can jit the for loop)
data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env)
# Pre-compute the indices where we need to update the actor
# This is a hack in order to jit the train loop
# It will compile once per value of policy_delay_indices
policy_delay_indices = {i: True for i in range(gradient_steps) if ((self._n_updates + i + 1) % self.policy_delay) == 0}
-policy_delay_indices = flax.core.FrozenDict(policy_delay_indices)
+policy_delay_indices = flax.core.FrozenDict(policy_delay_indices) # type: ignore[assignment]

if isinstance(data.observations, dict):
-keys = list(self.observation_space.keys())
+keys = list(self.observation_space.keys()) # type: ignore[attr-defined]
obs = np.concatenate([data.observations[key].numpy() for key in keys], axis=1)
next_obs = np.concatenate([data.next_observations[key].numpy() for key in keys], axis=1)
else:
obs = data.observations.numpy()
next_obs = data.next_observations.numpy()

# Convert to numpy
-data = ReplayBufferSamplesNp(
+data = ReplayBufferSamplesNp( # type: ignore[assignment]
obs,
data.actions.numpy(),
next_obs,
@@ -241,7 +242,7 @@ def update_critic(
next_observations: np.ndarray,
rewards: np.ndarray,
dones: np.ndarray,
-key: jax.random.KeyArray,
+key: jax.Array,
):
key, noise_key, dropout_key_target, dropout_key_current = jax.random.split(key, 4)
# sample action from the actor
@@ -285,7 +286,7 @@ def update_actor(
qf_state: RLTrainState,
ent_coef_state: TrainState,
observations: np.ndarray,
-key: jax.random.KeyArray,
+key: jax.Array,
):
key, dropout_key, noise_key = jax.random.split(key, 3)

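
The train changes above recur almost verbatim in td3.py and tqc.py below: the signature now lists gradient_steps before batch_size, matching the argument order of the corresponding train method in stable_baselines3 (so a positional call made against the base signature no longer swaps the two values), and the added assert plus the type-ignore comments satisfy mypy. The "compile once per value of policy_delay_indices" comment refers to handing the precomputed actor-update indices to a jitted helper; the sketch below is a hypothetical, self-contained illustration of that pattern (names invented for the example, and it assumes the FrozenDict is passed as a static argument, which is what makes a recompile happen per value of the indices):

from functools import partial

import flax
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=["gradient_steps", "policy_delay_indices"])
def _train_loop(value: jax.Array, gradient_steps: int, policy_delay_indices: flax.core.FrozenDict) -> jax.Array:
    for i in range(gradient_steps):
        value = value - 0.01  # stand-in for a critic update
        if i in policy_delay_indices:
            # Static membership test, resolved at trace time: the jitted
            # function recompiles once per distinct set of indices.
            value = value + 0.001  # stand-in for a delayed actor update
    return value

def train(gradient_steps: int, batch_size: int, n_updates: int = 0, policy_delay: int = 2) -> jax.Array:
    # batch_size is unused here; it only mirrors the corrected argument order.
    indices = {i: True for i in range(gradient_steps) if (n_updates + i + 1) % policy_delay == 0}
    return _train_loop(jnp.zeros(()), gradient_steps, flax.core.FrozenDict(indices))

print(train(gradient_steps=4, batch_size=256))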
2 changes: 1 addition & 1 deletion sbx/td3/policies.py
@@ -115,7 +115,7 @@ def __init__(

self.key = self.noise_key = jax.random.PRNGKey(0)

-def build(self, key: jax.random.KeyArray, lr_schedule: Schedule, qf_learning_rate: float) -> jax.random.KeyArray:
+def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) -> jax.Array:
key, actor_key, qf_key, dropout_key = jax.random.split(key, 4)
# Keep a key for the actor
key, self.key = jax.random.split(key, 2)
13 changes: 7 additions & 6 deletions sbx/td3/td3.py
@@ -121,25 +121,26 @@ def learn(
progress_bar=progress_bar,
)

-def train(self, batch_size, gradient_steps):
+def train(self, gradient_steps: int, batch_size: int) -> None:
+assert self.replay_buffer is not None
# Sample all at once for efficiency (so we can jit the for loop)
data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env)
# Pre-compute the indices where we need to update the actor
# This is a hack in order to jit the train loop
# It will compile once per value of policy_delay_indices
policy_delay_indices = {i: True for i in range(gradient_steps) if ((self._n_updates + i + 1) % self.policy_delay) == 0}
-policy_delay_indices = flax.core.FrozenDict(policy_delay_indices)
+policy_delay_indices = flax.core.FrozenDict(policy_delay_indices) # type: ignore[assignment]

if isinstance(data.observations, dict):
-keys = list(self.observation_space.keys())
+keys = list(self.observation_space.keys()) # type: ignore[attr-defined]
obs = np.concatenate([data.observations[key].numpy() for key in keys], axis=1)
next_obs = np.concatenate([data.next_observations[key].numpy() for key in keys], axis=1)
else:
obs = data.observations.numpy()
next_obs = data.next_observations.numpy()

# Convert to numpy
-data = ReplayBufferSamplesNp(
+data = ReplayBufferSamplesNp( # type: ignore[assignment]
obs,
data.actions.numpy(),
next_obs,
@@ -182,7 +183,7 @@ def update_critic(
dones: np.ndarray,
target_policy_noise: float,
target_noise_clip: float,
-key: jax.random.KeyArray,
+key: jax.Array,
):
key, noise_key, dropout_key_target, dropout_key_current = jax.random.split(key, 4)
# Select action according to target net and add clipped noise
@@ -223,7 +224,7 @@ def update_actor(
actor_state: RLTrainState,
qf_state: RLTrainState,
observations: np.ndarray,
-key: jax.random.KeyArray,
+key: jax.Array,
):
key, dropout_key = jax.random.split(key, 2)

2 changes: 1 addition & 1 deletion sbx/tqc/policies.py
@@ -122,7 +122,7 @@ def __init__(

self.key = self.noise_key = jax.random.PRNGKey(0)

-def build(self, key: jax.random.KeyArray, lr_schedule: Schedule, qf_learning_rate: float) -> jax.random.KeyArray:
+def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) -> jax.Array:
key, actor_key, qf1_key, qf2_key = jax.random.split(key, 4)
key, dropout_key1, dropout_key2, self.key = jax.random.split(key, 4)
# Initialize noise
13 changes: 7 additions & 6 deletions sbx/tqc/tqc.py
@@ -180,25 +180,26 @@ def learn(
progress_bar=progress_bar,
)

-def train(self, batch_size, gradient_steps):
+def train(self, gradient_steps: int, batch_size: int) -> None:
+assert self.replay_buffer is not None
# Sample all at once for efficiency (so we can jit the for loop)
data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env)
# Pre-compute the indices where we need to update the actor
# This is a hack in order to jit the train loop
# It will compile once per value of policy_delay_indices
policy_delay_indices = {i: True for i in range(gradient_steps) if ((self._n_updates + i + 1) % self.policy_delay) == 0}
-policy_delay_indices = flax.core.FrozenDict(policy_delay_indices)
+policy_delay_indices = flax.core.FrozenDict(policy_delay_indices) # type: ignore[assignment]

if isinstance(data.observations, dict):
-keys = list(self.observation_space.keys())
+keys = list(self.observation_space.keys()) # type: ignore[attr-defined]
obs = np.concatenate([data.observations[key].numpy() for key in keys], axis=1)
next_obs = np.concatenate([data.next_observations[key].numpy() for key in keys], axis=1)
else:
obs = data.observations.numpy()
next_obs = data.next_observations.numpy()

# Convert to numpy
-data = ReplayBufferSamplesNp(
+data = ReplayBufferSamplesNp( # type: ignore[assignment]
obs,
data.actions.numpy(),
next_obs,
@@ -246,7 +247,7 @@ def update_critic(
next_observations: np.ndarray,
rewards: np.ndarray,
dones: np.ndarray,
-key: jax.random.KeyArray,
+key: jax.Array,
):
key, noise_key, dropout_key_1, dropout_key_2 = jax.random.split(key, 4)
key, dropout_key_3, dropout_key_4 = jax.random.split(key, 3)
@@ -327,7 +328,7 @@ def update_actor(
qf2_state: RLTrainState,
ent_coef_state: TrainState,
observations: np.ndarray,
-key: jax.random.KeyArray,
+key: jax.Array,
):
key, dropout_key_1, dropout_key_2, noise_key = jax.random.split(key, 4)

2 changes: 1 addition & 1 deletion sbx/version.txt
@@ -1 +1 @@
-0.9.1
+0.10.0
6 changes: 4 additions & 2 deletions setup.py
@@ -39,11 +39,13 @@
packages=[package for package in find_packages() if package.startswith("sbx")],
package_data={"sbx": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3>=2.2.0a9",
"stable_baselines3>=2.3.0a1",
"jax",
"jaxlib",
"flax",
"optax",
'optax; python_version >= "3.9.0"',
# See https://github.com/google-deepmind/optax/issues/711
'optax<0.1.8; python_version < "3.9.0"',
"tqdm",
"rich",
"tensorflow_probability",
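
The dependency change above pins optax below 0.1.8 only on Python 3.8, since newer optax releases dropped Python 3.8 support (see the linked issue). A small, hypothetical sketch of how pip evaluates those PEP 508 environment markers (the packaging library here is purely for illustration and is not a dependency introduced by this commit):

from packaging.markers import Marker

unpinned = Marker('python_version >= "3.9.0"')
pinned = Marker('python_version < "3.9.0"')
for python_version in ("3.8", "3.9", "3.11"):
    env = {"python_version": python_version}
    # On 3.8 only the optax<0.1.8 requirement applies; on 3.9+ the unconstrained one does.
    print(python_version, unpinned.evaluate(env), pinned.evaluate(env))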
