From bce0a793caee6c838bcfbf4673db5bfa543f30cf Mon Sep 17 00:00:00 2001 From: Victor Blanchi Date: Thu, 11 Apr 2024 09:47:52 +0200 Subject: [PATCH] Plan with particle filters Signed-off-by: Victor Blanchi --- examples/gym-cartpole/smart_agent.zls | 4 +- examples/gym-cartpole/smart_pid_agent.zls | 4 +- probzelus/inference/infer_pf.ml | 67 ++++++++++++++++++++++- probzelus/inference/infer_pf.zli | 5 ++ 4 files changed, 72 insertions(+), 8 deletions(-) diff --git a/examples/gym-cartpole/smart_agent.zls b/examples/gym-cartpole/smart_agent.zls index dd1c1600..a51e44a3 100644 --- a/examples/gym-cartpole/smart_agent.zls +++ b/examples/gym-cartpole/smart_agent.zls @@ -27,9 +27,7 @@ let proba model obs_gym = action where rec obs = simple_pendulum (obs_gym, (Right fby action)) and action = controller (obs) and () = Infer_pf.factor (-10. *. (abs_float (obs.pole_angle))) - and display = draw_obs_back obs let node smart_main () = () where - rec reset action = Infer_pf.plan 10 10 model obs every true + rec reset action = Infer_pf.plan_pf 30 10 10 model obs every true and obs, _, stop = cart_pole_gym true (Right fby action) - and display = draw_obs_front obs diff --git a/examples/gym-cartpole/smart_pid_agent.zls b/examples/gym-cartpole/smart_pid_agent.zls index b90c1d2d..702b61e3 100644 --- a/examples/gym-cartpole/smart_pid_agent.zls +++ b/examples/gym-cartpole/smart_pid_agent.zls @@ -51,7 +51,6 @@ let proba smart_model (obs_gym) = action where rec obs = simple_pendulum (obs_gym, (Right fby action)) and action = smart_controller (obs) and () = Infer_pf.factor (-10. *. (abs_float (obs.pole_angle))) - and display = draw_obs_back obs (** PID controller for the cart-pole example **) @@ -72,11 +71,10 @@ let proba pid_model (obs, ctrl_action) = p, (i, d) where let node smart_pid_main () = () where - rec reset action_smart = Infer_pf.plan 10 10 smart_model obs_smart + rec reset action_smart = Infer_pf.plan_pf 30 10 10 smart_model obs_smart every true and obs_smart, _, _ = cart_pole_gym true (Right fby action_smart) and pid_dist = Infer_pf.infer 1000 pid_model (obs_smart, action_smart) - and () = draw_obs_front obs_smart and (p, (i, d)) = Distribution.draw pid_dist and obs, _, stop = cart_pole_gym true (Right fby action) and reset action = pid_controller (obs.pole_angle, (p, i, d)) diff --git a/probzelus/inference/infer_pf.ml b/probzelus/inference/infer_pf.ml index 58e1696a..cf2eecad 100644 --- a/probzelus/inference/infer_pf.ml +++ b/probzelus/inference/infer_pf.ml @@ -26,6 +26,7 @@ open Ztypes open Owl +open Printf type pstate = { idx : int; (** particle index *) @@ -203,6 +204,7 @@ let expectation scores = let s = Array.fold_left ( +. ) 0. scores in s /. float (Array.length scores) +(** [plan_step n k model_step model_copy] return a function [step] that duplicates the current particle [n] times and advances it forward [k] times.*) let plan_step n k model_step model_copy = let table = Hashtbl.create 7 in let rec expected_utility (state, score) (ttl, input) = @@ -244,8 +246,49 @@ let plan_step n k model_step model_copy = in step -(* [plan n k f x] runs n instances of [f] on the input stream *) -(* [x] but at each step, do a prediction of depth k *) +(** [plan_step_pf n k model_step model_copy] returns a function [step] that duplicates the current particle [n] times, advances it forward, copies it [n] times, applies a particle filter of size [h], and repeats this process [k] times. *) +let plan_step_pf n h k model_step model_copy = + let table = Hashtbl.create 7 in + let rec expected_utility state (ttl, input) = + let states = Array.init h (fun _ -> Probzelus_utils.copy state) in + let scores = Array.make h 0.0 in + Array.iteri + (fun i state -> ignore @@ model_step state ({ idx = i; scores }, input)) + states; + let norm = Normalize.log_sum_exp scores in + let probabilities = Array.map (fun score -> exp (score -. norm)) scores in + let dist = Normalize.to_distribution (Array.init h id) probabilities in + let index = Distribution.draw dist in + let state, score = (states.(index), scores.(index)) in + if ttl < 1 then score else norm +. expected_utility state (ttl - 1, input) + in + let state_value_copy (src_st, src_val) (dst_st, dst_val) = + model_copy src_st dst_st; + dst_val := !src_val + in + let step { infer_states = states; infer_scores = scores } input = + let values = + Array.mapi + (fun i state -> + let value = model_step state ({ idx = i; scores }, input) in + scores.(i) <- expected_utility state (k, input); + value) + states + in + let states_values = + Array.mapi (fun i state -> (state, ref values.(i))) states + in + let norm = Normalize.log_sum_exp scores in + let probabilities = Array.map (fun score -> exp (score -. norm)) scores in + Normalize.resample state_value_copy n probabilities states_values; + Array.fill scores 0 n 0.0; + Hashtbl.clear table; + states_values + in + step + +(** [plan n k f x] runs n instances of [f] on the input stream + [x] but at each step, do a prediction of depth [k] *) let plan n k (Cnode model : (pstate * 't1, 't2) Ztypes.cnode) = let alloc () = ref (model.alloc ()) in let reset state = model.reset !state in @@ -264,6 +307,26 @@ let plan n k (Cnode model : (pstate * 't1, 't2) Ztypes.cnode) = in Cnode { alloc; reset; copy; step } +(** [plan n k f x] runs n instances of [f] on the input stream + [x] but at each step, do a prediction of depth [k] and use a particle filter of size [h] *) +let plan_pf n h k (Cnode model : (pstate * 't1, 't2) Ztypes.cnode) = + let alloc () = ref (model.alloc ()) in + let reset state = model.reset !state in + let copy src dst = model.copy !src !dst in + let step_body = plan_step_pf n h k model.step model.copy in + let step plan_state input = + let states = Array.init n (fun _ -> Probzelus_utils.copy !plan_state) in + let scores = Array.make n 0.0 in + let states_values = + step_body { infer_states = states; infer_scores = scores } input + in + let dist = Normalize.normalize states_values in + let state', value = Distribution.draw dist in + plan_state := state'; + !value + in + Cnode { alloc; reset; copy; step } + type 'state infd_state = { infd_states : 'state array; infd_scores : float array; diff --git a/probzelus/inference/infer_pf.zli b/probzelus/inference/infer_pf.zli index 230188f6..aa246a72 100644 --- a/probzelus/inference/infer_pf.zli +++ b/probzelus/inference/infer_pf.zli @@ -59,6 +59,11 @@ val plan : ('t1 ~D~> 't2) -S-> 't1 -D-> 't2 +val plan_pf : + int -S-> int -S-> int -S-> + ('t1 ~D~> 't2) -S-> + 't1 -D-> 't2 + val infer_depth : int -S-> int -S-> ('t1 ~D~> 't2) -S->