Skip to content

Commit

Permalink
Plan with particle filters
Browse files Browse the repository at this point in the history
Signed-off-by: Victor Blanchi <[email protected]>
  • Loading branch information
VictorBlanchi committed Apr 11, 2024
1 parent 71a4faf commit bce0a79
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 8 deletions.
4 changes: 1 addition & 3 deletions examples/gym-cartpole/smart_agent.zls
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ let proba model obs_gym = action where
rec obs = simple_pendulum (obs_gym, (Right fby action))
and action = controller (obs)
and () = Infer_pf.factor (-10. *. (abs_float (obs.pole_angle)))
and display = draw_obs_back obs

let node smart_main () = () where
rec reset action = Infer_pf.plan 10 10 model obs every true
rec reset action = Infer_pf.plan_pf 30 10 10 model obs every true
and obs, _, stop = cart_pole_gym true (Right fby action)
and display = draw_obs_front obs
4 changes: 1 addition & 3 deletions examples/gym-cartpole/smart_pid_agent.zls
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ let proba smart_model (obs_gym) = action where
rec obs = simple_pendulum (obs_gym, (Right fby action))
and action = smart_controller (obs)
and () = Infer_pf.factor (-10. *. (abs_float (obs.pole_angle)))
and display = draw_obs_back obs


(** PID controller for the cart-pole example **)
Expand All @@ -72,11 +71,10 @@ let proba pid_model (obs, ctrl_action) = p, (i, d) where


let node smart_pid_main () = () where
rec reset action_smart = Infer_pf.plan 10 10 smart_model obs_smart
rec reset action_smart = Infer_pf.plan_pf 30 10 10 smart_model obs_smart
every true
and obs_smart, _, _ = cart_pole_gym true (Right fby action_smart)
and pid_dist = Infer_pf.infer 1000 pid_model (obs_smart, action_smart)
and () = draw_obs_front obs_smart
and (p, (i, d)) = Distribution.draw pid_dist
and obs, _, stop = cart_pole_gym true (Right fby action)
and reset action = pid_controller (obs.pole_angle, (p, i, d))
Expand Down
67 changes: 65 additions & 2 deletions probzelus/inference/infer_pf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

open Ztypes
open Owl
open Printf

type pstate = {
idx : int; (** particle index *)
Expand Down Expand Up @@ -203,6 +204,7 @@ let expectation scores =
let s = Array.fold_left ( +. ) 0. scores in
s /. float (Array.length scores)

(** [plan_step n k model_step model_copy] return a function [step] that duplicates the current particle [n] times and advances it forward [k] times.*)
let plan_step n k model_step model_copy =
let table = Hashtbl.create 7 in
let rec expected_utility (state, score) (ttl, input) =
Expand Down Expand Up @@ -244,8 +246,49 @@ let plan_step n k model_step model_copy =
in
step

(* [plan n k f x] runs n instances of [f] on the input stream *)
(* [x] but at each step, do a prediction of depth k *)
(** [plan_step_pf n k model_step model_copy] returns a function [step] that duplicates the current particle [n] times, advances it forward, copies it [n] times, applies a particle filter of size [h], and repeats this process [k] times. *)
let plan_step_pf n h k model_step model_copy =
let table = Hashtbl.create 7 in
let rec expected_utility state (ttl, input) =
let states = Array.init h (fun _ -> Probzelus_utils.copy state) in
let scores = Array.make h 0.0 in
Array.iteri
(fun i state -> ignore @@ model_step state ({ idx = i; scores }, input))
states;
let norm = Normalize.log_sum_exp scores in
let probabilities = Array.map (fun score -> exp (score -. norm)) scores in
let dist = Normalize.to_distribution (Array.init h id) probabilities in
let index = Distribution.draw dist in
let state, score = (states.(index), scores.(index)) in
if ttl < 1 then score else norm +. expected_utility state (ttl - 1, input)
in
let state_value_copy (src_st, src_val) (dst_st, dst_val) =
model_copy src_st dst_st;
dst_val := !src_val
in
let step { infer_states = states; infer_scores = scores } input =
let values =
Array.mapi
(fun i state ->
let value = model_step state ({ idx = i; scores }, input) in
scores.(i) <- expected_utility state (k, input);
value)
states
in
let states_values =
Array.mapi (fun i state -> (state, ref values.(i))) states
in
let norm = Normalize.log_sum_exp scores in
let probabilities = Array.map (fun score -> exp (score -. norm)) scores in
Normalize.resample state_value_copy n probabilities states_values;
Array.fill scores 0 n 0.0;
Hashtbl.clear table;
states_values
in
step

(** [plan n k f x] runs n instances of [f] on the input stream
[x] but at each step, do a prediction of depth [k] *)
let plan n k (Cnode model : (pstate * 't1, 't2) Ztypes.cnode) =
let alloc () = ref (model.alloc ()) in
let reset state = model.reset !state in
Expand All @@ -264,6 +307,26 @@ let plan n k (Cnode model : (pstate * 't1, 't2) Ztypes.cnode) =
in
Cnode { alloc; reset; copy; step }

(** [plan n k f x] runs n instances of [f] on the input stream
[x] but at each step, do a prediction of depth [k] and use a particle filter of size [h] *)
let plan_pf n h k (Cnode model : (pstate * 't1, 't2) Ztypes.cnode) =
let alloc () = ref (model.alloc ()) in
let reset state = model.reset !state in
let copy src dst = model.copy !src !dst in
let step_body = plan_step_pf n h k model.step model.copy in
let step plan_state input =
let states = Array.init n (fun _ -> Probzelus_utils.copy !plan_state) in
let scores = Array.make n 0.0 in
let states_values =
step_body { infer_states = states; infer_scores = scores } input
in
let dist = Normalize.normalize states_values in
let state', value = Distribution.draw dist in
plan_state := state';
!value
in
Cnode { alloc; reset; copy; step }

type 'state infd_state = {
infd_states : 'state array;
infd_scores : float array;
Expand Down
5 changes: 5 additions & 0 deletions probzelus/inference/infer_pf.zli
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ val plan :
('t1 ~D~> 't2) -S->
't1 -D-> 't2

val plan_pf :
int -S-> int -S-> int -S->
('t1 ~D~> 't2) -S->
't1 -D-> 't2

val infer_depth :
int -S-> int -S->
('t1 ~D~> 't2) -S->
Expand Down

0 comments on commit bce0a79

Please sign in to comment.