Get rid of gin dependency and make policy base class abstract.
PiperOrigin-RevId: 708966070
jaindeepali authored and copybara-github committed Dec 23, 2024
1 parent c0618ec commit b7b7108
Showing 5 changed files with 45 additions and 36 deletions.
12 changes: 5 additions & 7 deletions iris/normalizer.py
@@ -18,7 +18,6 @@
 import copy
 from typing import Any, Dict, Optional, Sequence, Union
 from absl import logging
-import gin
 import gym
 from gym import spaces
 from gym.spaces import utils
@@ -106,13 +105,16 @@ def state(self, new_state: Dict[str, Any]) -> None:
     pass
 
 
-@gin.configurable
 class MeanStdBuffer(Buffer):
   """Collect stats for calculating mean and std online."""
 
   def __init__(self, shape: Sequence[int] = (0,)) -> None:
     self._shape = shape
-    self._data = {}
+    self._data = {
+        N: 0,
+        MEAN: np.zeros(self._shape, dtype=np.float64),
+        UNNORM_VAR: np.zeros(self._shape, dtype=np.float64),
+    }
     self.reset()
 
   def reset(self) -> None:
@@ -283,7 +285,6 @@ def state(self, state: Dict[str, np.ndarray]) -> None:
     self._state = state.copy()
 
 
-@gin.configurable
 class NoNormalizer(Normalizer):
   """No Normalization applied to input."""
 
@@ -300,7 +301,6 @@ def __call__(
     return value
 
 
-@gin.configurable
 class ActionRangeDenormalizer(Normalizer):
   """Actions mapped to given range from [-1, 1]."""
 
@@ -341,7 +341,6 @@ def __call__(
     return action
 
 
-@gin.configurable
 class ObservationRangeNormalizer(Normalizer):
   """Observations mapped from given range to [-1, 1]."""
 
@@ -383,7 +382,6 @@ def __call__(
     return observation
 
 
-@gin.configurable
 class RunningMeanStdNormalizer(Normalizer):
   """Standardize observations with mean and std calculated online."""
 
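Note on the change above: with the @gin.configurable decorators removed, MeanStdBuffer and the normalizer classes are configured by constructing them directly in Python rather than through .gin bindings. A minimal before/after sketch of a call site (the gin binding shown is illustrative, not taken from this repository):

    # Before: a .gin file might have bound the constructor argument, e.g.
    #   MeanStdBuffer.shape = (4,)
    # After: the shape is passed explicitly at the call site.
    import numpy as np
    from iris import normalizer

    buffer = normalizer.MeanStdBuffer(shape=(4,))
    buffer.push(np.zeros(4, dtype=np.float64))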
10 changes: 5 additions & 5 deletions iris/normalizer_test.py
@@ -21,11 +21,11 @@
 class BufferTest(absltest.TestCase):
 
   def test_meanstdbuffer(self):
-    buffer = normalizer.MeanStdBuffer((1))
+    buffer = normalizer.MeanStdBuffer((1,))
     buffer.push(np.asarray(10.0))
     buffer.push(np.asarray(11.0))
 
-    new_buffer = normalizer.MeanStdBuffer((1))
+    new_buffer = normalizer.MeanStdBuffer((1,))
     new_buffer.data = buffer.data
 
     self.assertEqual(new_buffer._std, buffer._std)
@@ -145,7 +145,7 @@ def test_mean_std_buffer_empty_merge(self):
     self.assertEqual(mean_std_buffer._data['n'], 0)
 
   def test_mean_std_buffer_scalar(self):
-    mean_std_buffer = normalizer.MeanStdBuffer((1))
+    mean_std_buffer = normalizer.MeanStdBuffer((1,))
     mean_std_buffer.push(np.asarray(10.0))
     self.assertEqual(mean_std_buffer._std, 1.0)  # First value is always 1.0.
 
@@ -154,10 +154,10 @@ def test_mean_std_buffer_scalar(self):
     np.testing.assert_almost_equal(mean_std_buffer._std, np.sqrt(0.5))
 
   def test_mean_std_buffer_reject_infinity_on_merge(self):
-    mean_std_buffer = normalizer.MeanStdBuffer((1))
+    mean_std_buffer = normalizer.MeanStdBuffer((1,))
     mean_std_buffer.push(np.asarray(10.0))
 
-    infinty_buffer = normalizer.MeanStdBuffer((1))
+    infinty_buffer = normalizer.MeanStdBuffer((1,))
     infinty_buffer.push(np.asarray(np.inf))
 
     mean_std_buffer.merge(infinty_buffer.data)
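The (1) → (1,) fixes above matter because parentheses alone do not create a tuple in Python; only the trailing comma does, so the old tests were passing the plain integer 1 as the shape. A quick illustration (numpy happens to accept both an int and a tuple as a shape, which is why the tests still passed before):

    import numpy as np

    assert not isinstance((1), tuple)   # (1) is just the int 1.
    assert isinstance((1,), tuple)      # (1,) is a one-element tuple.
    assert np.zeros(1).shape == np.zeros((1,)).shape == (1,)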
11 changes: 8 additions & 3 deletions iris/policies/base_policy.py
@@ -14,14 +14,15 @@
 
 """Policy class for computing action from weights and observation vector."""
 
+import abc
 from typing import Dict, Union
 
 import gym
 from gym.spaces import utils
 import numpy as np
 
 
-class BasePolicy(object):
+class BasePolicy(abc.ABC):
   """Base policy class for reinforcement learning."""
 
   def __init__(self, ob_space: gym.Space, ac_space: gym.Space) -> None:
@@ -55,23 +56,27 @@ def set_iteration(self, value: int | None):
     self._iteration = value
 
   def update_weights(self, new_weights: np.ndarray) -> None:
+    """Updates the flat weights vector."""
     self._weights[:] = new_weights[:]
 
   def get_weights(self) -> np.ndarray:
+    """Returns the flat weights vector."""
     return self._weights
 
   def get_representation_weights(self):
+    """Returns the flat representation weights vector."""
     return self._representation_weights
 
   def update_representation_weights(
       self, new_representation_weights: np.ndarray) -> None:
+    """Updates the flat representation weights vector."""
     self._representation_weights[:] = new_representation_weights[:]
 
   def reset(self):
+    """Resets the internal policy state."""
     pass
 
+  @abc.abstractmethod
   def act(self, ob: Union[np.ndarray, Dict[str, np.ndarray]]
           ) -> Union[np.ndarray, Dict[str, np.ndarray]]:
     """Maps the observation to action."""
-    raise NotImplementedError(
-        "Should be implemented in derived classes for specific policies.")
45 changes: 27 additions & 18 deletions iris/policies/nas_policy.py
@@ -25,7 +25,7 @@
 import pyglove as pg
 
 
-class PyGlovePolicy(abc.ABC, base_policy.BasePolicy):
+class PyGlovePolicy(base_policy.BasePolicy):
   """Base class for all policies involving NAS search."""
 
   @abc.abstractmethod
@@ -42,40 +42,49 @@ def dna_spec(self) -> pg.DNASpec:
 class NumpyTopologyPolicy(PyGlovePolicy):
   """Parent class for numpy-based policies."""
 
-  def __init__(self,
-               ob_space: gym.Space,
-               ac_space: gym.Space,
-               hidden_layer_sizes: Sequence[int],
-               seed: int = 0,
-               **kwargs):
-    base_policy.BasePolicy.__init__(self, ob_space, ac_space)
+  def __init__(
+      self,
+      ob_space: gym.Space,
+      ac_space: gym.Space,
+      hidden_layer_sizes: Sequence[int],
+      seed: int = 0,
+      **kwargs
+  ):
+    super().__init__(ob_space, ac_space)
 
     self._hidden_layer_sizes = hidden_layer_sizes
-    self._total_nb_nodes = sum(
-        self._hidden_layer_sizes) + self._ob_dim + self._ac_dim
-    self._all_layer_sizes = [self._ob_dim] + list(
-        self._hidden_layer_sizes) + [self._ac_dim]
+    self._total_nb_nodes = (
+        sum(self._hidden_layer_sizes) + self._ob_dim + self._ac_dim
+    )
+    self._all_layer_sizes = (
+        [self._ob_dim] + list(self._hidden_layer_sizes) + [self._ac_dim]
+    )
 
     self._total_weight_parameters = self._total_nb_nodes**2
     self._total_bias_parameters = self._total_nb_nodes
-    self._total_nb_parameters = self._total_weight_parameters + self._total_bias_parameters
+    self._total_nb_parameters = (
+        self._total_weight_parameters + self._total_bias_parameters
+    )
 
     np.random.seed(seed)
     self._weights = np.random.uniform(
-        low=-1.0, high=1.0, size=(self._total_nb_nodes, self._total_nb_nodes))
+        low=-1.0, high=1.0, size=(self._total_nb_nodes, self._total_nb_nodes)
+    )
     self._biases = np.random.uniform(
-        low=-1.0, high=1.0, size=self._total_nb_nodes)
+        low=-1.0, high=1.0, size=self._total_nb_nodes
+    )
 
     self._edge_dict = {}
 
-  def act(self, ob: Union[np.ndarray, Dict[str, np.ndarray]]
-          ) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+  def act(
+      self, ob: Union[np.ndarray, Dict[str, np.ndarray]]
+  ) -> Union[np.ndarray, Dict[str, np.ndarray]]:
     ob = utils.flatten(self._ob_space, ob)
     values = [0.0] * self._total_nb_nodes
     for i in range(self._ob_dim):
       values[i] = ob[i]
     for i in range(self._total_nb_nodes):
-      if ((i > self._ob_dim) and (i < self._total_nb_nodes - self._ac_dim)):
+      if (i > self._ob_dim) and (i < self._total_nb_nodes - self._ac_dim):
        values[i] = np.tanh(values[i] + self._biases[i])
       if i in self._edge_dict:
         j_list = self._edge_dict[i]
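The base-class change above works because BasePolicy now carries abc.ABC itself, so PyGlovePolicy inherits the ABCMeta metaclass and no longer needs to list abc.ABC as an explicit base; its @abc.abstractmethod members (such as dna_spec) are still enforced. A small sanity-check sketch under that assumption:

    import abc
    from iris.policies import base_policy
    from iris.policies import nas_policy

    # ABCMeta is inherited through BasePolicy.
    assert type(base_policy.BasePolicy) is abc.ABCMeta
    assert type(nas_policy.PyGlovePolicy) is abc.ABCMeta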
3 changes: 0 additions & 3 deletions requirements-rl.txt
@@ -9,6 +9,3 @@ jax  # Use latest version.
 jaxlib  # Use latest version.
 flax  # Use latest version.
 tensorflow  # TODO(team): Resolve version conflicts.
-
-# Configuration + Experimentation
-gin-config>=0.5.0
