fix: get the guide to function properly
polvalente committed Aug 25, 2023
1 parent 5ffe59c commit 93af594
Showing 4 changed files with 128 additions and 242 deletions.
260 changes: 83 additions & 177 deletions guides/gridworld.livemd
@@ -5,7 +5,8 @@ my_app_root = Path.join(__DIR__, "..")

Mix.install(
[
{:rein, path: my_app_root, env: :dev}
{:rein, path: my_app_root},
{:kino_vega_lite, "~> 0.1"}
],
config_path: Path.join(my_app_root, "config/config.exs"),
lockfile: Path.join(my_app_root, "mix.lock"),
@@ -14,19 +15,21 @@ Mix.install(
)
```

## Section
## Initializing the plot

In the code block below, we initialize some meta variables and configure our VegaLite plot in such a way that it can be updated iteratively as the algorithm runs.
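
For context, the iterative updates work by pushing data points into a named dataset of a `Kino.VegaLite` widget, which re-renders automatically. A minimal sketch of that mechanism (the widget and dataset below are purely illustrative, not part of the guide):

```elixir
# Illustrative only: a tiny widget with a "points" dataset we can append to later.
sketch_widget =
  VegaLite.new(width: 200, height: 200)
  |> VegaLite.data(name: "points")
  |> VegaLite.mark(:point)
  |> VegaLite.encode_field(:x, "x", type: :quantitative)
  |> VegaLite.encode_field(:y, "y", type: :quantitative)
  |> Kino.VegaLite.new()
  |> Kino.render()

# Later (e.g. from a training callback), new data is appended in place:
Kino.VegaLite.push(sketch_widget, %{x: 1, y: 2}, dataset: "points")
```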

```elixir
alias VegaLite, as: Vl

{min_x, max_x, min_y, max_y} = Rein.Environments.Gridworld.bounding_box()

possible_targets_l = [[div(min_x + max_x, 2), max_y - 2]]
possible_targets_l = [[round((min_x + max_x) / 2), max_y]]

possible_targets_l =
for x <- (min_x + 2)..(max_x - 2), y <- 2..max_y do
[x, y]
end
# possible_targets_l =
# for x <- (min_x + 2)..(max_x - 2), y <- 2..max_y do
# [x, y]
# end

possible_targets = Nx.tensor(Enum.shuffle(possible_targets_l))

@@ -53,70 +56,23 @@ grid_widget =
),
Vl.new()
|> Vl.data(name: "trajectory")
|> Vl.mark(:line, opacity: 0.5, tooltip: [content: "data"])
|> Vl.mark(:line, point: true, opacity: 1, tooltip: [content: "data"])
|> Vl.encode_field(:x, "x", type: :quantitative, scale: [domain: [min_x, max_x], clamp: true])
|> Vl.encode_field(:y, "y", type: :quantitative, scale: [domain: [min_y, max_y], clamp: true])
|> Vl.encode_field(:color, "episode",
type: :nominal,
scale: [scheme: "blues"],
legend: false
)
|> Vl.encode_field(:order, "index")
])
|> Kino.VegaLite.new()
|> Kino.render()

loss_widget =
Vl.new(width: width, height: height, title: "Loss")
|> Vl.data(name: "loss")
|> Vl.mark(:line,
grid: true,
tooltip: [content: "data"],
interpolate: "step-after",
point: true,
color: "blue"
)
|> Vl.encode_field(:x, "episode", type: :quantitative)
|> Vl.encode_field(:y, "loss",
type: :quantitative,
scale: [
domain: [0, 1],
type: "linear",
base: 10,
clamp: true
]
)
|> Vl.encode_field(:order, "episode")
|> Kino.VegaLite.new()
|> Kino.render()

reward_widget =
Vl.new(width: width, height: height, title: "Total Reward per Epoch")
|> Vl.data(name: "reward")
|> Vl.mark(:line,
grid: true,
tooltip: [content: "data"],
interpolate: "step-after",
point: true,
color: "blue"
)
|> Vl.encode_field(:x, "episode", type: :quantitative)
|> Vl.encode_field(:y, "reward",
type: :quantitative,
scale: [
domain: [-2, 2],
type: "symlog",
base: 10,
clamp: true
]
)
|> Vl.encode_field(:order, "episode")
|> Kino.VegaLite.new()
|> Kino.render()

nil
```

## Configuring and running the Q Learning Agent

Now we're ready to start configuring our agent. The `plot_fn` function defined below is a callback that `Rein` invokes at the end of each episode, giving us access to the accumulated training data.

Usually, this means extracting data to plot, report, or save somewhere.
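
At a minimum, such a callback receives the loop state and passes it back, so a do-nothing version would be (a trivial sketch, just to show the contract):

```elixir
# A no-op callback: report the finished episode and return the state unchanged.
noop_plot_fn = fn axon_state ->
  IO.puts("Episode #{axon_state.epoch} ended")
  axon_state
end
```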

```elixir
# 250 max_iter * 15 episodes
max_points = 1000
@@ -132,172 +88,122 @@ plot_fn = fn axon_state ->
grid_widget,
%{
x: Nx.to_number(axon_state.step_state.environment_state.target_x),
y: Nx.to_number(axon_state.step_state.environment_state.target_y),
episode: episode,
episode_group: rem(episode, 15)
y: Nx.to_number(axon_state.step_state.environment_state.target_y)
},
dataset: "target"
)

Kino.VegaLite.push(
loss_widget,
%{
episode: episode,
loss: Nx.to_number(Nx.mean(axon_state.step_state.agent_state.loss))
},
dataset: "loss"
)

IO.inspect("Episode #{episode} ended")

trajectory = axon_state.step_state.trajectory

idx = Nx.to_flat_list(trajectory[0][0..(axon_state.iteration - 1)//1] |> Nx.as_type(:s64))
x = Nx.to_flat_list(trajectory[1][0..(axon_state.iteration - 1)//1])
y = Nx.to_flat_list(trajectory[2][0..(axon_state.iteration - 1)//1])
iteration = Nx.to_number(axon_state.step_state.iteration)

points =
[idx, x, y]
|> Enum.zip_with(fn [index, x, y] ->
trajectory[0..(iteration - 1)//1]
|> Nx.to_list()
|> Enum.with_index(fn [x, y], index ->
%{
x: x,
y: y,
index: index,
episode: episode,
episode_group: rem(episode, 15)
index: index
}
end)

Kino.VegaLite.push_many(grid_widget, points, dataset: "trajectory")
end

Kino.VegaLite.push(
reward_widget,
%{
episode: axon_state.epoch,
reward: Nx.to_number(axon_state.step_state.agent_state.total_reward)
},
dataset: "reward"
)

axon_state
end
```

```elixir
filename = "/Users/paulo.valente/Desktop/gridworld.bin"

{
q_policy,
experience_replay_buffer_index,
persisted_experience_replay_buffer_entries,
experience_replay_buffer,
total_episodes
} =
try do
contents = File.read!(filename)
File.write!(filename <> "_bak", contents)

%{serialized: serialized, total_episodes: total_episodes} = :erlang.binary_to_term(contents)

%{
q_policy: q_policy,
experience_replay_buffer_index: experience_replay_buffer_index,
persisted_experience_replay_buffer_entries: persisted_experience_replay_buffer_entries,
experience_replay_buffer: exp_replay_buffer
} = Nx.deserialize(serialized)

{q_policy, experience_replay_buffer_index, persisted_experience_replay_buffer_entries,
exp_replay_buffer, total_episodes}
rescue
File.Error ->
{%{}, 0, 0, nil, 0}
end
Now, we get to the actual training!

# q_policy = %{}
# total_episodes = 0
# experience_replay_buffer = nil
# persisted_experience_replay_buffer_entries = 0
# experience_replay_buffer_index = 0
total_episodes
```
The code below calls `Rein.train` with the configuration for the `Gridworld` environment, which is solved by a `QLearning` agent.

```elixir
q_policy
```
This will return the whole `Axon.Loop` struct in the `result` variable, so that we can inspect and/or save it afterwards.
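
For reference, the general shape of the call, following the spec in `lib/rein.ex`, is roughly the following (module and option names here are placeholders):

```elixir
# Sketch of the argument shape only; the concrete call for this guide comes next.
#
#   Rein.train(
#     {EnvironmentModule, environment_init_opts},
#     {AgentModule, agent_init_opts},
#     epoch_completed_callback,    # (map() -> :ok)
#     state_to_trajectory_fn,      # (Rein.t() -> Nx.t())
#     num_episodes: ..., max_iter: ...
#   )
:ok
```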

```elixir
num_observations = Rein.Environments.Gridworld.state_vector_size()
num_actions = Rein.Environments.Gridworld.num_actions()
Kino.VegaLite.clear(grid_widget)

policy_net =
Axon.input("state", shape: {nil, num_observations})
|> Axon.dense(128, activation: :relu)
|> Axon.dense(64, activation: :relu)
|> Axon.dense(64, activation: :relu)
|> Axon.dense(num_actions)
episodes = 15_000
max_iter = 20

# These might seem redundant, but will make more sense for multi-input models
environment_to_state_vector_fn = fn %{x: x, y: y, target_x: target_x, target_y: target_y} ->
delta_x = Nx.subtract(x, min_x)
delta_y = Nx.subtract(y, min_y)

environment_to_input_fn = fn env_state ->
%{"state" => Rein.Environments.Gridworld.as_state_vector(env_state)}
Nx.stack([delta_x, delta_y, Nx.subtract(target_x, min_x), Nx.subtract(target_y, min_y)])
end

state_vector_to_input_fn = fn state_vector ->
%{"state" => state_vector}
state_to_trajectory_fn = fn %{environment_state: %{x: x, y: y}} ->
Nx.stack([x, y])
end

environment_to_state_vector_fn = &Rein.Environments.Gridworld.as_state_vector/1
```
delta_x = max_x - min_x + 1
delta_y = max_y - min_y + 1

```elixir
Kino.VegaLite.clear(grid_widget)
Kino.VegaLite.clear(loss_widget, dataset: "loss")
Kino.VegaLite.clear(reward_widget, dataset: "reward")

episodes = 5000
max_iter = 200
state_space_shape = {delta_x, delta_y, delta_x, delta_y}

{t, result} =
:timer.tc(fn ->
Rein.train(
{Rein.Environments.Gridworld, possible_targets: possible_targets},
{Rein.Agents.DQN,
policy_net: policy_net,
eps_max_iter: -1,
q_policy: q_policy,
persisted_experience_replay_buffer_entries: persisted_experience_replay_buffer_entries,
experience_replay_buffer_index: experience_replay_buffer_index,
experience_replay_buffer: experience_replay_buffer,
environment_to_input_fn: environment_to_input_fn,
{Rein.Agents.QLearning,
state_space_shape: state_space_shape,
num_actions: 4,
environment_to_state_vector_fn: environment_to_state_vector_fn,
state_vector_to_input_fn: state_vector_to_input_fn},
learning_rate: 1.0e-2,
gamma: 0.99,
exploration_eps: 1.0e-4},
plot_fn,
state_to_trajectory_fn,
num_episodes: episodes,
max_iter: max_iter
)
end)

# File.write!("/Users/paulo.valente/Desktop/results_#{NaiveDateTime.utc_now() |> NaiveDateTime.to_iso8601() |> String.replace(":", "")}.bin", )
serialized =
Nx.serialize(
Map.take(result.step_state.agent_state, [
:q_policy,
:experience_replay_buffer_index,
:experience_replay_buffer,
:persisted_experience_replay_buffer_entries
])
)
"#{Float.round(t / 1_000_000, 3)} s"
```
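
Since `result` now holds the full loop state, the learned Q matrix can also be persisted for later use, for example along these lines (a sketch; the path is just an example):

```elixir
# Hypothetical example: save the learned Q matrix so it can be reloaded later.
path = Path.join(System.tmp_dir!(), "gridworld_q_matrix.bin")

serialized = Nx.serialize(result.step_state.agent_state.q_matrix)
File.write!(path, serialized)

# Reloading it is the inverse operation:
q_matrix = path |> File.read!() |> Nx.deserialize()
```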

With the code below, we can check some points of interest in the learned Q matrix.

contents =
:erlang.term_to_binary(%{serialized: serialized, total_episodes: total_episodes + episodes})
In particular, we can see below that for a target at x = 2, y = 4:

# File.write!(filename, contents)
* For the position x = 2, y = 3, the selected action is to go up;
* For the position x = 1, y = 4, the selected action is to go right;
* For the position x = 3, y = 4, the selected action is to go left.

"#{Float.round(t / 1_000_000, 3)} s"
```
This shows that, at least for positions close to the target, our agent has already learned the best policy for those states!

```elixir
File.write!(filename, contents)
result.step_state.agent_state.q_policy
state_vector_to_index = fn state_vector, shape ->
{linear_indices_offsets_list, _} =
shape
|> Tuple.to_list()
|> Enum.reverse()
|> Enum.reduce({[], 1}, fn x, {acc, multiplier} ->
{[multiplier | acc], multiplier * x}
end)

linear_indices_offsets = Nx.tensor(linear_indices_offsets_list)

Nx.dot(state_vector, linear_indices_offsets)
end

# Actions are [up, down, right, left]

# up
idx = state_vector_to_index.(Nx.tensor([2, 3, 2, 4]), {5, 5, 5, 5})
IO.inspect(result.step_state.agent_state.q_matrix[idx])

# right
idx = state_vector_to_index.(Nx.tensor([1, 4, 2, 4]), {5, 5, 5, 5})
IO.inspect(result.step_state.agent_state.q_matrix[idx])

# left
idx = state_vector_to_index.(Nx.tensor([3, 4, 2, 4]), {5, 5, 5, 5})
IO.inspect(result.step_state.agent_state.q_matrix[idx])

nil
```
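
To read those rows as concrete actions rather than raw Q values, we can take the argmax of each row and map it back to the action ordering noted in the comment above. A small sketch, assuming the `result` and `state_vector_to_index` bindings from the previous block:

```elixir
# Map the argmax of a Q row to an action name; ordering is [up, down, right, left].
actions = {:up, :down, :right, :left}

action_for = fn state_vector ->
  idx = state_vector_to_index.(state_vector, {5, 5, 5, 5})

  best_action_index =
    result.step_state.agent_state.q_matrix[idx]
    |> Nx.argmax()
    |> Nx.to_number()

  elem(actions, best_action_index)
end

# For the target at x = 2, y = 4, the state right below it should map to :up.
action_for.(Nx.tensor([2, 3, 2, 4]))
```
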
10 changes: 7 additions & 3 deletions lib/rein.ex
@@ -43,7 +43,7 @@ defmodule Rein do
epoch_completed_callback :: (map() -> :ok),
state_to_trajectory_fn :: (t() -> Nx.t()),
opts :: keyword()
) :: Axon.Loop.t()
) :: term()
# underscore vars below for doc names
def train(
_environment_with_options = {environment, environment_init_opts},
@@ -60,7 +60,8 @@
:checkpoint_path,
checkpoint_serialization_fn: &Nx.serialize/1,
accumulated_episodes: 0,
num_episodes: 100
num_episodes: 100,
output_transform: & &1
])

random_key = opts[:random_key] || Nx.Random.key(System.system_time())
@@ -112,7 +113,8 @@
num_episodes: num_episodes,
max_iter: max_iter,
model_name: model_name,
checkpoint_path: opts[:checkpoint_path]
checkpoint_path: opts[:checkpoint_path],
output_transform: opts[:output_transform]
)
end

@@ -121,11 +123,13 @@
state_to_trajectory_fn = Keyword.fetch!(opts, :state_to_trajectory_fn)
num_episodes = Keyword.fetch!(opts, :num_episodes)
max_iter = Keyword.fetch!(opts, :max_iter)
output_transform = Keyword.fetch!(opts, :output_transform)

loop_fn = &batch_step(&1, &2, agent, environment, state_to_trajectory_fn)

loop_fn
|> Axon.Loop.loop()
|> then(&%{&1 | output_transform: output_transform})
|> Axon.Loop.handle_event(
:epoch_started,
&{:continue,
