diff --git a/src/blox/reinforcement_learning.jl b/src/blox/reinforcement_learning.jl
index c79fdf77..ff51b108 100644
--- a/src/blox/reinforcement_learning.jl
+++ b/src/blox/reinforcement_learning.jl
@@ -182,29 +182,6 @@
 get_eval_times(gp::GreedyPolicy) = [gp.t_decision]
 get_eval_states(gp::GreedyPolicy) = gp.competitor_states
 
-"""
-function (p::GreedyPolicy)(sys::ODESystem, prob::ODEProblem)
-    ps = parameters(sys)
-    params = prob.p
-    map_idxs = Int.(ModelingToolkit.varmap_to_vars([ps[i] => i for i in eachindex(ps)], ps))
-    comp_params = p.competitor_params
-    idxs_cp = Int64[]
-    for i in eachindex(comp_params)
-        idxs = findall(x -> x==comp_params[i], ps)
-        push!(idxs_cp,idxs)
-    end
-    comp_vals = params[map_idxs[idxs_cp]]
-    @info comp_vals
-    return argmax(comp_vals)
-end
-"""
-
-function narrowtype(d::Dict)
-    types = unique(typeof.(values(d)))
-    U = Union{types...}
-    Dict{Num, U}(d)
-end
-
 mutable struct Agent{S,P,A,LR,C}
     odesystem::S
     problem::P
@@ -214,7 +191,7 @@ mutable struct Agent{S,P,A,LR,C}
 
     function Agent(g::MetaDiGraph; name, kwargs...)
         bc = connector_from_graph(g)
-
+        t_block = haskey(kwargs, :t_block) ? kwargs[:t_block] : missing
         # TODO: add another version that uses system_from_graph(g,bc,params;)
         sys = system_from_graph(g, bc; name, t_block, allow_parameter=false)
 
@@ -225,7 +202,7 @@ mutable struct Agent{S,P,A,LR,C}
         prob = ODEProblem(sys, u0, (0.,1.), p)
 
         policy = action_selection_from_graph(g)
-        learning_rules = narrowtype(bc.learning_rules)
+        learning_rules = narrowtype(bc.learning_rule)
 
         new{typeof(sys), typeof(prob), typeof(policy), typeof(learning_rules), typeof(bc)}(sys, prob, policy, learning_rules, bc)
     end
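
The behavioral change here is the `t_block` lookup in the `Agent` constructor: `t_block` is now read from `kwargs`, falling back to `missing` when the caller omits it, before being forwarded to `system_from_graph`. A minimal sketch of that lookup in isolation, using a hypothetical `pick_t_block` stand-in rather than the real constructor:

    # Sketch: default t_block to `missing` when the keyword is absent,
    # mirroring the line added in the constructor above.
    pick_t_block(; kwargs...) = haskey(kwargs, :t_block) ? kwargs[:t_block] : missing

    pick_t_block()                # missing
    pick_t_block(t_block = 90.0)  # 90.0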
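Note that `narrowtype` is deleted in the first hunk but still called in the last one, so it has presumably been relocated elsewhere in the package (not shown in this diff). For reference, a self-contained sketch of what it does; the real helper hardcodes `Num` keys from Symbolics, whereas this adaptation uses `keytype` so it runs without that dependency:

    # Sketch of the narrowtype idea: rebuild a Dict whose value type is the
    # Union of the concrete types actually present, so downstream code
    # dispatches on a tighter type than the original abstract value type.
    function narrowtype_demo(d::Dict)
        types = unique(typeof.(values(d)))
        U = Union{types...}
        Dict{keytype(d), U}(d)
    end

    d  = Dict(:a => 1, :b => 2.0)  # Dict{Symbol, Real} (abstract value type)
    dn = narrowtype_demo(d)        # Dict{Symbol, Union{Float64, Int64}}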