From 93af594fa31028a272760f48e468be20855c7c5e Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Fri, 25 Aug 2023 03:30:26 -0300 Subject: [PATCH] fix: get the guide to function properly --- guides/gridworld.livemd | 260 +++++++++-------------------- lib/rein.ex | 10 +- lib/rein/agents/q_learning.ex | 49 ++---- lib/rein/environments/gridworld.ex | 51 +++--- 4 files changed, 128 insertions(+), 242 deletions(-) diff --git a/guides/gridworld.livemd b/guides/gridworld.livemd index 21b10b3..82b1198 100644 --- a/guides/gridworld.livemd +++ b/guides/gridworld.livemd @@ -5,7 +5,8 @@ my_app_root = Path.join(__DIR__, "..") Mix.install( [ - {:rein, path: my_app_root, env: :dev} + {:rein, path: my_app_root}, + {:kino_vega_lite, "~> 0.1"} ], config_path: Path.join(my_app_root, "config/config.exs"), lockfile: Path.join(my_app_root, "mix.lock"), @@ -14,19 +15,21 @@ Mix.install( ) ``` -## Section +## Initializing the plot + +In the code block below, we initialize some meta variables and configure our VegaLite plot in way that it can be updated iteratively over the algorithm iterations. ```elixir alias VegaLite, as: Vl {min_x, max_x, min_y, max_y} = Rein.Environments.Gridworld.bounding_box() -possible_targets_l = [[div(min_x + max_x, 2), max_y - 2]] +possible_targets_l = [[round((min_x + max_x) / 2), max_y]] -possible_targets_l = - for x <- (min_x + 2)..(max_x - 2), y <- 2..max_y do - [x, y] - end +# possible_targets_l = +# for x <- (min_x + 2)..(max_x - 2), y <- 2..max_y do +# [x, y] +# end possible_targets = Nx.tensor(Enum.shuffle(possible_targets_l)) @@ -53,70 +56,23 @@ grid_widget = ), Vl.new() |> Vl.data(name: "trajectory") - |> Vl.mark(:line, opacity: 0.5, tooltip: [content: "data"]) + |> Vl.mark(:line, point: true, opacity: 1, tooltip: [content: "data"]) |> Vl.encode_field(:x, "x", type: :quantitative, scale: [domain: [min_x, max_x], clamp: true]) |> Vl.encode_field(:y, "y", type: :quantitative, scale: [domain: [min_y, max_y], clamp: true]) - |> Vl.encode_field(:color, "episode", - type: :nominal, - scale: [scheme: "blues"], - legend: false - ) |> Vl.encode_field(:order, "index") ]) |> Kino.VegaLite.new() |> Kino.render() -loss_widget = - Vl.new(width: width, height: height, title: "Loss") - |> Vl.data(name: "loss") - |> Vl.mark(:line, - grid: true, - tooltip: [content: "data"], - interpolate: "step-after", - point: true, - color: "blue" - ) - |> Vl.encode_field(:x, "episode", type: :quantitative) - |> Vl.encode_field(:y, "loss", - type: :quantitative, - scale: [ - domain: [0, 1], - type: "linear", - base: 10, - clamp: true - ] - ) - |> Vl.encode_field(:order, "episode") - |> Kino.VegaLite.new() - |> Kino.render() - -reward_widget = - Vl.new(width: width, height: height, title: "Total Reward per Epoch") - |> Vl.data(name: "reward") - |> Vl.mark(:line, - grid: true, - tooltip: [content: "data"], - interpolate: "step-after", - point: true, - color: "blue" - ) - |> Vl.encode_field(:x, "episode", type: :quantitative) - |> Vl.encode_field(:y, "reward", - type: :quantitative, - scale: [ - domain: [-2, 2], - type: "symlog", - base: 10, - clamp: true - ] - ) - |> Vl.encode_field(:order, "episode") - |> Kino.VegaLite.new() - |> Kino.render() - nil ``` +## Configuring and running the Q Learning Agent + +Now we're ready to start configuring our agent. The `plot_fn` function defined below is a callback that `Rein` calls at the end of each iteration, so that we can do anything with the data. 
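+
+In a nutshell, the callback receives the `Axon.Loop` state for the episode that just finished (for instance, `axon_state.epoch` holds the episode number) and hands that state back when done, just like the `plot_fn` below does. As a minimal sketch (`log_episode_fn` is only a hypothetical name, and logging is all it does), such a callback could look like:
+
+```elixir
+# Minimal illustrative callback: log the episode number and return the state unchanged.
+log_episode_fn = fn axon_state ->
+  IO.puts("Episode #{axon_state.epoch} ended")
+  axon_state
+end
+```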
+ +Usually, this means that we'll extract data to either plot, report or save somewhere. + ```elixir # 250 max_iter * 15 episodes max_points = 1000 @@ -132,172 +88,122 @@ plot_fn = fn axon_state -> grid_widget, %{ x: Nx.to_number(axon_state.step_state.environment_state.target_x), - y: Nx.to_number(axon_state.step_state.environment_state.target_y), - episode: episode, - episode_group: rem(episode, 15) + y: Nx.to_number(axon_state.step_state.environment_state.target_y) }, dataset: "target" ) - Kino.VegaLite.push( - loss_widget, - %{ - episode: episode, - loss: Nx.to_number(Nx.mean(axon_state.step_state.agent_state.loss)) - }, - dataset: "loss" - ) - IO.inspect("Episode #{episode} ended") trajectory = axon_state.step_state.trajectory - idx = Nx.to_flat_list(trajectory[0][0..(axon_state.iteration - 1)//1] |> Nx.as_type(:s64)) - x = Nx.to_flat_list(trajectory[1][0..(axon_state.iteration - 1)//1]) - y = Nx.to_flat_list(trajectory[2][0..(axon_state.iteration - 1)//1]) + iteration = Nx.to_number(axon_state.step_state.iteration) points = - [idx, x, y] - |> Enum.zip_with(fn [index, x, y] -> + trajectory[0..(iteration - 1)//1] + |> Nx.to_list() + |> Enum.with_index(fn [x, y], index -> %{ x: x, y: y, - index: index, - episode: episode, - episode_group: rem(episode, 15) + index: index } end) Kino.VegaLite.push_many(grid_widget, points, dataset: "trajectory") end - Kino.VegaLite.push( - reward_widget, - %{ - episode: axon_state.epoch, - reward: Nx.to_number(axon_state.step_state.agent_state.total_reward) - }, - dataset: "reward" - ) - axon_state end ``` -```elixir -filename = "/Users/paulo.valente/Desktop/gridworld.bin" - -{ - q_policy, - experience_replay_buffer_index, - persisted_experience_replay_buffer_entries, - experience_replay_buffer, - total_episodes -} = - try do - contents = File.read!(filename) - File.write!(filename <> "_bak", contents) - - %{serialized: serialized, total_episodes: total_episodes} = :erlang.binary_to_term(contents) - - %{ - q_policy: q_policy, - experience_replay_buffer_index: experience_replay_buffer_index, - persisted_experience_replay_buffer_entries: persisted_experience_replay_buffer_entries, - experience_replay_buffer: exp_replay_buffer - } = Nx.deserialize(serialized) - - {q_policy, experience_replay_buffer_index, persisted_experience_replay_buffer_entries, - exp_replay_buffer, total_episodes} - rescue - File.Error -> - {%{}, 0, 0, nil, 0} - end +Now, we get to the actual training! -# q_policy = %{} -# total_episodes = 0 -# experience_replay_buffer = nil -# persisted_experience_replay_buffer_entries = 0 -# experience_replay_buffer_index = 0 -total_episodes -``` +The code below calls `Rein.train` with some configuration for the `Gridworld` environment being solved through a `QLearning` agent. -```elixir -q_policy -``` +This will return the whole `Axon.Loop` struct in the `result` variable, so that we can inspect and/or save it afterwards. 
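+
+Under the hood, `Rein.Agents.QLearning` keeps a table of Q values and, after every step, nudges the entry for the visited state/action pair towards the observed reward plus the discounted value of the best next action, using the `learning_rate` and `gamma` we pass below. The snippet that follows is only a plain-number sketch of that update rule; the reward and Q values are made up for illustration (in this environment the reward is the negative Manhattan distance to the target).
+
+```elixir
+# Plain-number sketch of the tabular Q-learning update (illustrative values only).
+learning_rate = 1.0e-2
+gamma = 0.99
+
+current_q = 0.0
+reward = -3.0
+best_next_q = -2.0
+
+(1 - learning_rate) * current_q + learning_rate * (reward + gamma * best_next_q)
+```
+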
```elixir -num_observations = Rein.Environments.Gridworld.state_vector_size() -num_actions = Rein.Environments.Gridworld.num_actions() +Kino.VegaLite.clear(grid_widget) -policy_net = - Axon.input("state", shape: {nil, num_observations}) - |> Axon.dense(128, activation: :relu) - |> Axon.dense(64, activation: :relu) - |> Axon.dense(64, activation: :relu) - |> Axon.dense(num_actions) +episodes = 15_000 +max_iter = 20 -# These might seem redundant, but will make more sense for multi-input models +environment_to_state_vector_fn = fn %{x: x, y: y, target_x: target_x, target_y: target_y} -> + delta_x = Nx.subtract(x, min_x) + delta_y = Nx.subtract(y, min_y) -environment_to_input_fn = fn env_state -> - %{"state" => Rein.Environments.Gridworld.as_state_vector(env_state)} + Nx.stack([delta_x, delta_y, Nx.subtract(target_x, min_x), Nx.subtract(target_y, min_y)]) end -state_vector_to_input_fn = fn state_vector -> - %{"state" => state_vector} +state_to_trajectory_fn = fn %{environment_state: %{x: x, y: y}} -> + Nx.stack([x, y]) end -environment_to_state_vector_fn = &Rein.Environments.Gridworld.as_state_vector/1 -``` +delta_x = max_x - min_x + 1 +delta_y = max_y - min_y + 1 -```elixir -Kino.VegaLite.clear(grid_widget) -Kino.VegaLite.clear(loss_widget, dataset: "loss") -Kino.VegaLite.clear(reward_widget, dataset: "reward") - -episodes = 5000 -max_iter = 200 +state_space_shape = {delta_x, delta_y, delta_x, delta_y} {t, result} = :timer.tc(fn -> Rein.train( {Rein.Environments.Gridworld, possible_targets: possible_targets}, - {Rein.Agents.DQN, - policy_net: policy_net, - eps_max_iter: -1, - q_policy: q_policy, - persisted_experience_replay_buffer_entries: persisted_experience_replay_buffer_entries, - experience_replay_buffer_index: experience_replay_buffer_index, - experience_replay_buffer: experience_replay_buffer, - environment_to_input_fn: environment_to_input_fn, + {Rein.Agents.QLearning, + state_space_shape: state_space_shape, + num_actions: 4, environment_to_state_vector_fn: environment_to_state_vector_fn, - state_vector_to_input_fn: state_vector_to_input_fn}, + learning_rate: 1.0e-2, + gamma: 0.99, + exploration_eps: 1.0e-4}, plot_fn, + state_to_trajectory_fn, num_episodes: episodes, max_iter: max_iter ) end) -# File.write!("/Users/paulo.valente/Desktop/results_#{NaiveDateTime.utc_now() |> NaiveDateTime.to_iso8601() |> String.replace(":", "")}.bin", ) -serialized = - Nx.serialize( - Map.take(result.step_state.agent_state, [ - :q_policy, - :experience_replay_buffer_index, - :experience_replay_buffer, - :persisted_experience_replay_buffer_entries - ]) - ) +"#{Float.round(t / 1_000_000, 3)} s" +``` + +With the code below, we can check some points of interest in the learned Q matrix. -contents = - :erlang.term_to_binary(%{serialized: serialized, total_episodes: total_episodes + episodes}) +Especially, we can see below that for a target at x = 2, y = 4: -# File.write!(filename, contents) +* For the position x = 2, y = 3, the selected action is to go up; +* For the position x = 1, y = 4, the selected action is to go right; +* For the position x = 3, y = 4, the selected action is to go left. -"#{Float.round(t / 1_000_000, 3)} s" -``` +This shows that at least for the positions closer to the target, our agent already knows the best policy for those respective states! 
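+
+If you only care about which move the agent prefers, rather than the raw Q values, a tiny helper can translate a row of the Q matrix into an action name. This is just a sketch (`greedy_action` is an illustrative name) that assumes the `[up, down, right, left]` ordering used in the cell below:
+
+```elixir
+# Illustrative helper: name of the action with the highest Q value in a row.
+action_names = ~w(up down right left)a
+
+greedy_action = fn q_row ->
+  Enum.at(action_names, Nx.to_number(Nx.argmax(q_row)))
+end
+
+# For example, after running the cell below:
+# greedy_action.(result.step_state.agent_state.q_matrix[idx])
+```
+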
```elixir -File.write!(filename, contents) -result.step_state.agent_state.q_policy +state_vector_to_index = fn state_vector, shape -> + {linear_indices_offsets_list, _} = + shape + |> Tuple.to_list() + |> Enum.reverse() + |> Enum.reduce({[], 1}, fn x, {acc, multiplier} -> + {[multiplier | acc], multiplier * x} + end) + + linear_indices_offsets = Nx.tensor(linear_indices_offsets_list) + + Nx.dot(state_vector, linear_indices_offsets) +end + +# Actions are [up, down, right, left] + +# up +idx = state_vector_to_index.(Nx.tensor([2, 3, 2, 4]), {5, 5, 5, 5}) +IO.inspect(result.step_state.agent_state.q_matrix[idx]) + +# right +idx = state_vector_to_index.(Nx.tensor([1, 4, 2, 4]), {5, 5, 5, 5}) +IO.inspect(result.step_state.agent_state.q_matrix[idx]) + +# left +idx = state_vector_to_index.(Nx.tensor([3, 4, 2, 4]), {5, 5, 5, 5}) +IO.inspect(result.step_state.agent_state.q_matrix[idx]) + +nil ``` diff --git a/lib/rein.ex b/lib/rein.ex index 6b2691e..2573342 100644 --- a/lib/rein.ex +++ b/lib/rein.ex @@ -43,7 +43,7 @@ defmodule Rein do epoch_completed_callback :: (map() -> :ok), state_to_trajectory_fn :: (t() -> Nx.t()), opts :: keyword() - ) :: Axon.Loop.t() + ) :: term() # underscore vars below for doc names def train( _environment_with_options = {environment, environment_init_opts}, @@ -60,7 +60,8 @@ defmodule Rein do :checkpoint_path, checkpoint_serialization_fn: &Nx.serialize/1, accumulated_episodes: 0, - num_episodes: 100 + num_episodes: 100, + output_transform: & &1 ]) random_key = opts[:random_key] || Nx.Random.key(System.system_time()) @@ -112,7 +113,8 @@ defmodule Rein do num_episodes: num_episodes, max_iter: max_iter, model_name: model_name, - checkpoint_path: opts[:checkpoint_path] + checkpoint_path: opts[:checkpoint_path], + output_transform: opts[:output_transform] ) end @@ -121,11 +123,13 @@ defmodule Rein do state_to_trajectory_fn = Keyword.fetch!(opts, :state_to_trajectory_fn) num_episodes = Keyword.fetch!(opts, :num_episodes) max_iter = Keyword.fetch!(opts, :max_iter) + output_transform = Keyword.fetch!(opts, :output_transform) loop_fn = &batch_step(&1, &2, agent, environment, state_to_trajectory_fn) loop_fn |> Axon.Loop.loop() + |> then(&%{&1 | output_transform: output_transform}) |> Axon.Loop.handle_event( :epoch_started, &{:continue, diff --git a/lib/rein/agents/q_learning.ex b/lib/rein/agents/q_learning.ex index 5a3e91a..a993faa 100644 --- a/lib/rein/agents/q_learning.ex +++ b/lib/rein/agents/q_learning.ex @@ -14,37 +14,25 @@ defmodule Rein.Agents.QLearning do @derive {Nx.Container, containers: [ :q_matrix, - :loss, - :loss_denominator, :observation ], keep: [ :num_actions, - :environment_to_input_fn, :environment_to_state_vector_fn, - :state_vector_to_input_fn, :learning_rate, - :batch_size, :gamma, :exploration_eps, - :exploration_decay_rate, :state_space_shape ]} defstruct [ :q_matrix, :observation, - :loss, - :loss_denominator, :num_actions, - :environment_to_input_fn, :environment_to_state_vector_fn, - :state_vector_to_input_fn, :learning_rate, - :batch_size, :gamma, :exploration_eps, - :exploration_decay_rate, :state_space_shape ] @@ -57,8 +45,7 @@ defmodule Rein.Agents.QLearning do :environment_to_state_vector_fn, :learning_rate, :gamma, - :exploration_eps, - :exploration_decay_rate + :exploration_eps ]) state_space_shape = opts[:state_space_shape] @@ -76,36 +63,28 @@ defmodule Rein.Agents.QLearning do environment_to_state_vector_fn: opts[:environment_to_state_vector_fn], learning_rate: opts[:learning_rate], gamma: opts[:gamma], - exploration_decay_rate: 
opts[:exploration_decay_rate], exploration_eps: opts[:exploration_eps], - state_space_shape: state_space_shape + state_space_shape: state_space_shape, + num_actions: num_actions } - {state, random_key} + reset(random_key, state) end @impl true - def reset(random_key, %Rein{agent_state: state}) do - zero = Nx.tensor(0, type: :f32) + def reset(random_key, %Rein{agent_state: agent_state}), do: reset(random_key, agent_state) - state = Nx.broadcast(0, state.state_space_shape) + def reset(random_key, %__MODULE__{} = agent_state) do + zero = Nx.tensor(0, type: :f32) observation = %{ action: 0, - state: state, - next_state: state, + state: 0, + next_state: 0, reward: zero } - [observation, _] = Nx.broadcast_vectors([observation, random_key]) - - {%{ - state - | total_reward: zero, - loss: zero, - loss_denominator: zero, - observation: observation - }, random_key} + {%__MODULE__{agent_state | observation: observation}, random_key} end @impl true @@ -162,7 +141,7 @@ defmodule Rein.Agents.QLearning do end @impl true - defn optimize_model(state) do + defn optimize_model(rl_state) do %{ observation: %{ state: state, @@ -173,7 +152,7 @@ defmodule Rein.Agents.QLearning do q_matrix: q_matrix, gamma: gamma, learning_rate: learning_rate - } = state.agent_state + } = rl_state.agent_state # Q_table[current_state, action] = # (1-lr) * Q_table[current_state, action] + @@ -183,7 +162,9 @@ defmodule Rein.Agents.QLearning do (1 - learning_rate) * q_matrix[[state, action]] + learning_rate * (reward + gamma * Nx.reduce_max(q_matrix[next_state])) - Nx.indexed_put(q_matrix, Nx.stack([state, action]), q) + q_matrix = Nx.indexed_put(q_matrix, Nx.stack([state, action]), q) + + %{rl_state | agent_state: %{rl_state.agent_state | q_matrix: q_matrix}} end deftransformp state_vector_to_index(state_vector, shape) do diff --git a/lib/rein/environments/gridworld.ex b/lib/rein/environments/gridworld.ex index 2b08f0c..dbe4eca 100644 --- a/lib/rein/environments/gridworld.ex +++ b/lib/rein/environments/gridworld.ex @@ -39,9 +39,9 @@ defmodule Rein.Environments.Gridworld do ] @min_x 0 - @max_x 20 + @max_x 4 @min_y 0 - @max_y 20 + @max_y 4 def bounding_box, do: {@min_x, @max_x, @min_y, @max_y} @@ -75,10 +75,13 @@ defmodule Rein.Environments.Gridworld do target = Nx.reshape(target, {2}) + target_x = target[0] + target_y = target[1] + y = Nx.tensor(0, type: :s64) - [x, y, target_x, target_y, zero_bool, _key] = - Nx.broadcast_vectors([x, y, target[0], target[1], random_key, Nx.u8(0)]) + # [x, y, target_x, target_y, zero_bool, _key] = + # Nx.broadcast_vectors([x, y, target[0], target[1], random_key, Nx.u8(0)]) state = %{ state @@ -89,8 +92,8 @@ defmodule Rein.Environments.Gridworld do target_x: target_x, target_y: target_y, reward: reward, - is_terminal: zero_bool, - has_reached_target: zero_bool + is_terminal: Nx.u8(0), + has_reached_target: Nx.u8(0) } {state, random_key} @@ -104,19 +107,25 @@ defmodule Rein.Environments.Gridworld do {new_x, new_y} = cond do action == 0 -> - {x, Nx.min(y + 1, @max_y + 1)} + {x, y + 1} action == 1 -> - {x, Nx.max(y - 1, @min_y - 1)} + {x, y - 1} action == 2 -> - {Nx.min(@max_x + 1, x + 1), y} + {x + 1, y} true -> - {Nx.max(x - 1, @min_x - 1), y} + {x - 1, y} end - new_env = %{env | x: new_x, y: new_y, prev_x: x, prev_y: y} + new_env = %{ + env + | x: Nx.clip(new_x, @min_x, @max_x), + y: Nx.clip(new_y, @min_y, @max_y), + prev_x: x, + prev_y: y + } updated_env = new_env @@ -127,22 +136,8 @@ defmodule Rein.Environments.Gridworld do end defnp calculate_reward(env) do - %__MODULE__{ - is_terminal: is_terminal, 
- has_reached_target: has_reached_target - } = env - - reward = - cond do - has_reached_target -> - 1 - - is_terminal -> - -1 - - true -> - -0.01 - end + distance = Nx.abs(env.target_x - env.x) + Nx.abs(env.target_y - env.y) + reward = -1.0 * distance %{env | reward: reward} end @@ -157,7 +152,7 @@ defmodule Rein.Environments.Gridworld do end defnp has_reached_target(%__MODULE__{x: x, y: y, target_x: target_x, target_y: target_y}) do - Nx.abs(target_x - x) <= 1.5 and Nx.abs(target_y - y) <= 1.5 + target_x == x and target_y == y end defnp normalize(v, min, max), do: (v - min) / (max - min)