Merge pull request #1 from DockYard/pv-refactor/revamp-docs-and-opts
chore: prepare for release
polvalente authored Aug 25, 2023
2 parents 2150467 + 5243412 commit 5ffe59c
Showing 13 changed files with 177 additions and 97 deletions.
4 changes: 2 additions & 2 deletions guides/gridworld.livemd
@@ -1,4 +1,4 @@
# Reinforcement Learning
# First steps with Gridworld

```elixir
my_app_root = Path.join(__DIR__, "..")
@@ -9,7 +9,7 @@ Mix.install(
],
config_path: Path.join(my_app_root, "config/config.exs"),
lockfile: Path.join(my_app_root, "mix.lock"),
# change to "cuda118" to use CUDA
# change to "cuda118" or "cuda120" to use CUDA
system_env: %{"XLA_TARGET" => "cpu"}
)
```
53 changes: 28 additions & 25 deletions lib/rein.ex
@@ -5,6 +5,17 @@ defmodule Rein do

import Nx.Defn

@type t :: %__MODULE__{
agent: module(),
agent_state: term(),
environment: module(),
environment_state: term(),
episode: Nx.t(),
iteration: Nx.t(),
random_key: Nx.t(),
trajectory: Nx.t()
}

@derive {Nx.Container,
containers: [
:agent_state,
@@ -14,7 +25,7 @@
:episode,
:trajectory
],
keep: [:agent_opts]}
keep: []}
defstruct [
:agent,
:agent_state,
@@ -23,19 +34,20 @@
:random_key,
:iteration,
:episode,
:trajectory,
:agent_opts
:trajectory
]

@spec train(
{environment :: module, init_opts :: keyword()},
{agent :: module, init_opts :: keyword},
epoch_completed_callback :: (map() -> :ok),
state_to_trajectory_fn :: (t() -> Nx.t()),
opts :: keyword()
) :: Axon.Loop.t()
# underscore vars below for doc names
def train(
{environment, environment_init_opts},
{agent, agent_init_opts},
_environment_with_options = {environment, environment_init_opts},
_agent_with_options = {agent, agent_init_opts},
epoch_completed_callback,
state_to_trajectory_fn,
opts \\ []
@@ -61,27 +73,22 @@
episode = Nx.tensor(opts[:accumulated_episodes], type: :s64)
iteration = Nx.tensor(0, type: :s64)

# TO-DO: needs Nx 0.6
# [episode, iteration, _] =
# Nx.broadcast_vectors([episode, iteration, random_key], align_ranks: false)
[episode, iteration, _] =
Nx.broadcast_vectors([episode, iteration, random_key], align_ranks: false)

{environment_state, random_key} = environment.init(random_key, environment_init_opts)

{agent_state, agent_opts, random_key} =
case agent.reset(random_key, %__MODULE__{
environment_state: environment_state,
agent: agent,
agent_state: init_agent_state,
episode: episode
}) do
{s, o, k} -> {s, o, k}
{s, k} -> {s, [], k}
end
{agent_state, random_key} =
agent.reset(random_key, %__MODULE__{
environment_state: environment_state,
agent: agent,
agent_state: init_agent_state,
episode: episode
})

initial_state = %__MODULE__{
agent: agent,
agent_state: agent_state,
agent_opts: agent_opts,
environment: environment,
environment_state: environment_state,
random_key: random_key,
@@ -181,16 +188,12 @@
) do
{environment_state, random_key} = environment.reset(random_key, environment_state)

{agent_state, agent_opts, random_key} =
case agent.reset(random_key, %{loop_state | environment_state: environment_state}) do
{state, opts, key} -> {state, opts, key}
{state, key} -> {state, [], key}
end
{agent_state, random_key} =
agent.reset(random_key, %{loop_state | environment_state: environment_state})

state = %{
loop_state
| agent_state: agent_state,
agent_opts: agent_opts,
environment_state: environment_state,
random_key: random_key,
trajectory: Nx.broadcast(Nx.tensor(:nan, type: :f32), loop_state.trajectory),
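For reference, the hunks above define the `Rein` struct and the `train/5` entry point. Below is a hedged sketch of how `train/5` could be invoked per that spec; the Gridworld `:possible_targets` option, the empty agent options, and the two-entry trajectory layout are illustrative assumptions, not values taken from this commit.

```elixir
# Hypothetical invocation of Rein.train/5 based on the @spec above.
possible_targets = Nx.tensor([[2, 4], [4, 4]])

# maps the full Rein state to one trajectory row per iteration
# (here: the agent's x/y position, an assumption for this sketch)
state_to_trajectory_fn = fn %Rein{environment_state: env} ->
  Nx.stack([env.x, env.y])
end

epoch_completed_callback = fn _epoch_metrics -> :ok end

Rein.train(
  {Rein.Environments.Gridworld, possible_targets: possible_targets},
  {Rein.Agents.QLearning, []},
  epoch_completed_callback,
  state_to_trajectory_fn
)
```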
13 changes: 11 additions & 2 deletions lib/rein/agent.ex
@@ -5,17 +5,26 @@ defmodule Rein.Agent do

@typedoc "An arbitrary `Nx.Container` that holds metadata for the agent"
@type t :: Nx.Container.t()

@typedoc "The full state of the current Reinforcement Learning process, as stored in the `Rein` struct"
@type rl_state :: Rein.t()

@doc """
Initializes the agent state with the given agent-specific options.
Also calls `c:reset/2` in the end.
Should be implemented so that the result is semantically
the same as if `c:reset/2` had been called at the end of the function.
As a suggestion, initialize only fixed values here, that is,
values that don't change between sessions
(epochs for non-episodic tasks, episodes for episodic tasks), and then
call `c:reset/2` internally to initialize the remaining variable values.
"""
@callback init(random_key :: Nx.t(), opts :: keyword) :: {t(), random_key :: Nx.t()}

@doc """
Resets any values that aren't fixed for the agent state.
Resets any values that vary between sessions (which would be episodes
for episodic tasks) for the agent state.
"""
@callback reset(random_key :: Nx.t(), rl_state :: t) :: {t(), random_key :: Nx.t()}

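To make the `init`/`reset` contract above concrete, here is a hedged sketch of an agent whose `c:init/2` sets up only fixed values and delegates the episode-dependent ones to `c:reset/2`. The module name and state fields are hypothetical, and the behaviour's other callbacks (not visible in this diff) are omitted.

```elixir
defmodule MyAgent do
  @behaviour Rein.Agent

  @impl true
  def init(random_key, opts) do
    # fixed values: these never change between episodes
    agent_state = %{
      learning_rate: Nx.tensor(opts[:learning_rate] || 1.0e-3),
      epsilon: Nx.tensor(0.0)
    }

    # delegating to reset/2 keeps init/2 semantically equivalent
    # to having called reset/2 at the end
    reset(random_key, %Rein{agent_state: agent_state})
  end

  @impl true
  def reset(random_key, %Rein{agent_state: agent_state}) do
    # values re-initialized at the start of every episode
    {%{agent_state | epsilon: Nx.tensor(1.0)}, random_key}
  end
end
```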
14 changes: 1 addition & 13 deletions lib/rein/agents/ddpg.ex
@@ -3,7 +3,7 @@ defmodule Rein.Agents.DDPG do
Deep Deterministic Policy Gradient implementation.
This assumes that the Actor network will output `{nil, num_actions}` actions,
and that the Critic network accepts `"actions"` input with the same shape.
and that the Critic network accepts the `"actions"` input with the same shape.
Actions are deemed to be in a continuous space of type `:f32`.
"""
@@ -14,18 +14,6 @@

@behaviour Rein.Agent

@derive {Inspect,
except: [
# :actor_params,
# :actor_target_params,
# :critic_params,
# :critic_target_params,
# :experience_replay_buffer,
# :actor_optimizer_state,
# :critic_optimizer_state,
# :state_features_memory
]}

@derive {Nx.Container,
containers: [
:actor_params,
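The DDPG moduledoc above fixes a shape contract between the two networks. A hedged Axon sketch that satisfies it is shown below; the `"state"` input name, layer sizes, and activations are assumptions for illustration only.

```elixir
state_size = 6
num_actions = 2

# actor: maps a state to {nil, num_actions} continuous actions
actor =
  Axon.input("state", shape: {nil, state_size})
  |> Axon.dense(64, activation: :relu)
  |> Axon.dense(num_actions, activation: :tanh)

# critic: consumes the state plus an "actions" input whose shape
# matches the actor output, and emits a single Q-value
state_input = Axon.input("state", shape: {nil, state_size})
action_input = Axon.input("actions", shape: {nil, num_actions})

critic =
  Axon.concatenate(state_input, action_input)
  |> Axon.dense(64, activation: :relu)
  |> Axon.dense(1)
```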
27 changes: 10 additions & 17 deletions lib/rein/agents/dqn.ex
@@ -1,4 +1,10 @@
defmodule Rein.Agents.DQN do
@moduledoc """
Deep Q-Learning implementation.
This implementation uses a single target network
alongside the policy network.
"""
import Nx.Defn

@behaviour Rein.Agent
@@ -134,9 +140,9 @@ defmodule Rein.Agents.DQN do

# TO-DO: receive optimizer as argument
{optimizer_init_fn, optimizer_update_fn} =
Axon.Updates.clip_by_global_norm()
|> Axon.Updates.compose(
Axon.Optimizers.adamw(learning_rate: @learning_rate, eps: @eps, decay: @adamw_decay)
Polaris.Updates.clip_by_global_norm()
|> Polaris.Updates.compose(
Polaris.Optimizers.adamw(learning_rate: @learning_rate, eps: @eps, decay: @adamw_decay)
)

initial_q_policy_state = opts[:q_policy] || raise "missing initial q_policy"
@@ -514,10 +520,7 @@
state_vector_size * 2 + 3,
td_errors
),
huber_loss(expected_state_action_values, state_action_values)
# Axon.Losses.mean_squared_error(expected_state_action_values, state_action_values,
# reduction: :mean
# )
Axon.Losses.huber(expected_state_action_values, state_action_values, reduction: :mean)
}
end,
&elem(&1, 1)
@@ -613,14 +616,4 @@

Nx.indexed_put(buffer, indices, Nx.reshape(td_errors, {n}))
end

defnp huber_loss(y_true, y_pred, opts \\ [delta: 1.0]) do
delta = opts[:delta]

abs_diff = Nx.abs(y_pred - y_true)

(abs_diff <= delta)
|> Nx.select(0.5 * abs_diff ** 2, delta * abs_diff - 0.5 * delta ** 2)
|> Nx.mean()
end
end
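The hunks above swap `Axon.Updates`/`Axon.Optimizers` for their `Polaris` counterparts and replace the hand-rolled `huber_loss` with `Axon.Losses.huber/3`. A standalone sketch of both follows; the hyperparameter values and toy parameter maps are assumptions, since the module attributes used in the real code aren't visible here.

```elixir
# composed update rule: clip gradients by global norm, then apply AdamW
{init_fn, update_fn} =
  Polaris.Updates.clip_by_global_norm()
  |> Polaris.Updates.compose(
    Polaris.Optimizers.adamw(learning_rate: 1.0e-4, eps: 1.0e-7, decay: 0.01)
  )

params = %{"dense_0" => %{"kernel" => Nx.broadcast(0.5, {3, 3})}}
gradients = %{"dense_0" => %{"kernel" => Nx.broadcast(0.1, {3, 3})}}

optimizer_state = init_fn.(params)
{updates, _optimizer_state} = update_fn.(gradients, optimizer_state, params)
updated_params = Polaris.Updates.apply_updates(params, updates)

# Huber loss reduced to a scalar, as in the diff above
loss =
  Axon.Losses.huber(
    Nx.tensor([1.0, 2.0, 3.0]),
    Nx.tensor([1.5, 2.0, 2.0]),
    reduction: :mean
  )
```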
8 changes: 8 additions & 0 deletions lib/rein/agents/q_learning.ex
@@ -1,4 +1,12 @@
defmodule Rein.Agents.QLearning do
@moduledoc """
Q-Learning implementation.
This implementation uses epsilon-greedy sampling
for exploration, and doesn't use any kind
of target network.
"""

import Nx.Defn

@behaviour Rein.Agent
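Since the new moduledoc calls out epsilon-greedy sampling, below is a hedged, self-contained `Nx.Defn` sketch of that selection rule. It is not the code from `Rein.Agents.QLearning`, just an illustration of the technique.

```elixir
defmodule EpsilonGreedySketch do
  import Nx.Defn

  defn select_action(q_values, epsilon, random_key) do
    num_actions = Nx.axis_size(q_values, 0)

    # one uniform draw decides explore vs. exploit,
    # another picks the random action when exploring
    {coin, random_key} = Nx.Random.uniform(random_key)
    {u, random_key} = Nx.Random.uniform(random_key)

    random_action = Nx.as_type(Nx.floor(u * num_actions), :s64)
    greedy_action = Nx.argmax(q_values)

    action = Nx.select(coin < epsilon, random_action, greedy_action)
    {action, random_key}
  end
end

{action, _key} =
  EpsilonGreedySketch.select_action(
    Nx.tensor([0.1, 0.5, 0.2, 0.3]),
    0.05,
    Nx.Random.key(42)
  )
```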
15 changes: 8 additions & 7 deletions lib/rein/agents/sac.ex
@@ -10,13 +10,14 @@ defmodule Rein.Agents.SAC do
Actions are deemed to be in a continuous space of type `:f32`.
For simplicity of the Dual Q implementation,
`:critic_params` is vectorized with the axis `:critics` with default
size 2. Likewise, `:critic_target_params` and `:critic_optimizer_state`
are also vectorized in the same way.
Vectorized axes from `:random_key` are still propagated normally throughout
the agent state for parallel training.
The Dual Q implementation utilizes two copies of the critic network, `critic1` and `critic2`,
each with their own separate target network.
Vectorized axes from `:random_key` are propagated normally throughout
the agent state for parallel simulations, but all samples are stored in the same
circular buffer. After all simulations have run, the optimization steps are run
on a sample space consisting of all previous experiences, including all of the
parallel simulations that have just finished executing.
"""
import Nx.Defn

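A hedged illustration of the `:critics` vectorized axis described above, using toy weight matrices rather than the real critic networks: stacking two parameter sets and vectorizing the leading axis lets the same `Nx` code evaluate both critics at once.

```elixir
# two toy critic weight matrices stacked along a new leading axis of size 2
critic_kernels =
  Nx.stack([
    Nx.broadcast(0.1, {4, 1}),
    Nx.broadcast(0.2, {4, 1})
  ])
  |> Nx.vectorize(:critics)

state_action = Nx.iota({1, 4}, type: :f32)

# the dot product runs once per entry of the :critics axis,
# yielding one Q estimate per critic
q_values = Nx.dot(state_action, critic_kernels)

# devectorize to inspect both estimates as a regular {2, 1, 1} tensor
Nx.devectorize(q_values)
```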
20 changes: 13 additions & 7 deletions lib/rein/environment.ex
@@ -5,25 +5,31 @@ defmodule Rein.Environment do

@typedoc "An arbitrary `Nx.Container` that holds metadata for the environment"
@type t :: Nx.Container.t()
@type rl_state :: Rein.t()

@doc "The number of possible actions for the environment"
@callback num_actions() :: pos_integer()
@typedoc "The full state of the current Reinforcement Learning process, as stored in the `Rein` struct"
@type rl_state :: Rein.t()

@doc """
Initializes the environment state with the given environment-specific options.
Should be implemented so that the result is semantically
the same as if `c:reset/2` had been called at the end of the function.
Also calls `c:reset/2` in the end.
As a suggestion, initialize only fixed values here, that is,
values that don't change between sessions
(epochs for non-episodic tasks, episodes for episodic tasks), and then
call `c:reset/2` internally to initialize the remaining variable values.
"""
@callback init(random_key :: Nx.t(), opts :: keyword) :: {t(), random_key :: Nx.t()}

@doc """
Resets any values that aren't fixed for the environment state.
Resets any values that vary between sessions (which would be episodes
for episodic tasks, epochs for non-episodic tasks) for the environment state.
"""
@callback reset(random_key :: Nx.t(), environment_state :: t) :: {t(), random_key :: Nx.t()}

@doc """
Applies the given action to the environment.
Applies the selected action to the environment.
Returns the updated environment, also updated with the reward
and a flag indicating whether the new state is terminal.
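A hypothetical environment skeleton following the documented `init`/`reset` split is sketched below. Only the callbacks visible in this diff are implemented (the action-applying callback is omitted because its signature isn't in view), and the module name and fields are invented for illustration.

```elixir
defmodule CoinFlip do
  @behaviour Rein.Environment

  @derive {Nx.Container, containers: [:reward, :is_terminal]}
  defstruct [:reward, :is_terminal]

  @impl true
  def num_actions, do: 2

  @impl true
  def init(random_key, _opts) do
    # this toy environment has no fixed configuration,
    # so init/2 reduces to resetting the variable values
    reset(random_key, %__MODULE__{})
  end

  @impl true
  def reset(random_key, %__MODULE__{} = env) do
    {%{env | reward: Nx.tensor(0.0), is_terminal: Nx.u8(0)}, random_key}
  end
end
```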
27 changes: 16 additions & 11 deletions lib/rein/environments/gridworld.ex
@@ -1,4 +1,12 @@
defmodule Rein.Environments.Gridworld do
@moduledoc """
Gridworld environment with 4 discrete actions.
Gridworld is an environment where the agent
aims to reach a target drawn from a collection
of possible targets, choosing one of 4 actions
at each step: up, down, left, and right.
"""
import Nx.Defn

@behaviour Rein.Environment
@@ -12,7 +20,6 @@ defmodule Rein.Environments.Gridworld do
:target_x,
:target_y,
:reward,
:reward_stage,
:is_terminal,
:possible_targets,
:has_reached_target
@@ -26,7 +33,6 @@
:target_x,
:target_y,
:reward,
:reward_stage,
:is_terminal,
:possible_targets,
:has_reached_target
@@ -39,10 +45,10 @@

def bounding_box, do: {@min_x, @max_x, @min_y, @max_y}

# x, y, target_x, target_y, prev_x, prev_y, reward_stage
def state_vector_size, do: 7
# x, y, target_x, target_y, has_reached_target, distance_norm
@doc "The size of the state vector returned by `as_state_vector/1`"
def state_vector_size, do: 6

@impl true
# up, down, left, right
def num_actions, do: 4

@@ -71,12 +77,8 @@

y = Nx.tensor(0, type: :s64)

target_x = target[0]
target_y = target[1]
zero_bool = Nx.tensor(false)
# TO-DO: needs Nx 0.6
# [x, y, target_x, target_y, zero_bool, _key] =
# Nx.broadcast_vectors([x, y, target[0], target[1], random_key, Nx.u8(0)])
[x, y, target_x, target_y, zero_bool, _key] =
Nx.broadcast_vectors([x, y, target[0], target[1], random_key, Nx.u8(0)])

state = %{
state
@@ -160,6 +162,9 @@

defnp normalize(v, min, max), do: (v - min) / (max - min)

@doc """
Default function for turning the environment into a vector representation.
"""
defn as_state_vector(%{
x: x,
y: y,
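Based on the updated comment (`x, y, target_x, target_y, has_reached_target, distance_norm`) and the `normalize/3` helper visible above, here is a hedged reconstruction of what the 6-entry state vector could look like. The bounding-box values and the exact distance normalization are assumptions, not the repository's code.

```elixir
defmodule GridworldStateVectorSketch do
  import Nx.Defn

  # assumed bounding box; the real values come from Gridworld's
  # @min_x/@max_x/@min_y/@max_y module attributes
  @min_x 0
  @max_x 4
  @min_y 0
  @max_y 4

  defn as_state_vector(%{
         x: x,
         y: y,
         target_x: target_x,
         target_y: target_y,
         has_reached_target: has_reached_target
       }) do
    distance = Nx.sqrt((target_x - x) ** 2 + (target_y - y) ** 2)
    max_distance = Nx.sqrt((@max_x - @min_x) ** 2 + (@max_y - @min_y) ** 2)

    Nx.stack([
      normalize(x, @min_x, @max_x),
      normalize(y, @min_y, @max_y),
      normalize(target_x, @min_x, @max_x),
      normalize(target_y, @min_y, @max_y),
      has_reached_target,
      distance / max_distance
    ])
  end

  defnp normalize(v, min, max), do: (v - min) / (max - min)
end
```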