diff --git a/Project.toml b/Project.toml index 1658f078..5feb98b9 100644 --- a/Project.toml +++ b/Project.toml @@ -52,7 +52,7 @@ DataFrames = "1.3" Distributions = "0.25.102" ExponentialUtilities = "1" ForwardDiff = "0.10" -GraphDynamics = "0.1.5" +GraphDynamics = "0.2" Graphs = "1" Interpolations = "0.14, 0.15" MetaGraphs = "0.7" diff --git a/examples/RF_learning_simple.jl b/examples/RF_learning_simple.jl index 7f6dcdb7..c6092888 100644 --- a/examples/RF_learning_simple.jl +++ b/examples/RF_learning_simple.jl @@ -74,4 +74,4 @@ agent = Agent(g; name=:ag, t_block = 90); #define environment : contains stimuli and feedback env = ClassificationEnvironment(stim; name=:env, namespace=global_ns) -run_experiment!(agent, env; alg=Vern7(), reltol=1e-9,abstol=1e-9) +run_experiment!(agent, env; t_warmup=200, alg=Vern7(), reltol=1e-9,abstol=1e-9) diff --git a/src/GraphDynamicsInterop/GraphDynamicsInterop.jl b/src/GraphDynamicsInterop/GraphDynamicsInterop.jl index 66cbcd53..a087494c 100644 --- a/src/GraphDynamicsInterop/GraphDynamicsInterop.jl +++ b/src/GraphDynamicsInterop/GraphDynamicsInterop.jl @@ -72,8 +72,7 @@ using GraphDynamics: StateIndex, ParamIndex, event_times, - calculate_inputs, - connection_index + calculate_inputs using Random: Random, @@ -256,12 +255,6 @@ function populate_flatgraph(h, g, blox, v, g_i, h_i) if length(components(blox)) == 1 && only(components(blox)) == blox h_i += 1 add_subsystem!(h, to_subsystem(blox), Neuroblox.namespaced_nameof(blox)) - # add_vertices!(h, 1) - # subsystem = to_subsystem(blox) - # name = Neuroblox.namespaced_nameof(blox) - - # set_subsystem!(h, to_subsystem(blox), h_i) - # set_name!(h, name, h_i) if v isa Dict @assert !haskey(v, g_i) v[g_i] = h_i @@ -341,14 +334,9 @@ situations where tiny matrices like (e.g. 5x5) get stored as sparse arrays rathe function graphsystem_from_graph(_g::MetaDiGraph; sparsity_heuristic=1.0, sparse_length_cutoff=0) check_all_supported_blox(_g) g = flat_graph(_g) - - total_eltype = mapreduce(promote_type, vertices(g)) do i - eltype(get_subsystem(g, i)) - end - fix_eltype(s::Subsystem{Name}) where {Name} = convert(Subsystem{Name, total_eltype}, s) - + subsystems_and_names_flat = map(vertices(g)) do i - (subsystem = fix_eltype(get_subsystem(g, i)), name = get_name(g, i)) + (subsystem = get_subsystem(g, i), name = get_name(g, i)) end names_flat = map(last, subsystems_and_names_flat) subsystems_flat = map(first, subsystems_and_names_flat) @@ -433,9 +421,9 @@ function graphsystem_from_graph(_g::MetaDiGraph; sparsity_heuristic=1.0, sparse_ end end end - states_partitioned = map(v -> map(get_states, v), subsystems) + states_partitioned = map(v -> map(get_states, v), subsystems) params_partitioned = map(v -> map(get_params, v), subsystems) - names_partitioned = map(v -> map(last, v), subsystems_and_names) + names_partitioned = map(v -> map(last, v), subsystems_and_names) composite_continuous_events_partitioned = let if isempty(g.composite_continuous_events_builder) @@ -479,5 +467,5 @@ function graphsystem_from_graph(_g::MetaDiGraph; sparsity_heuristic=1.0, sparse_ end end - end#module GraphDynamicsInterop + diff --git a/src/GraphDynamicsInterop/connection_interop.jl b/src/GraphDynamicsInterop/connection_interop.jl index 9b8e2af2..4f2a0f7e 100644 --- a/src/GraphDynamicsInterop/connection_interop.jl +++ b/src/GraphDynamicsInterop/connection_interop.jl @@ -86,6 +86,7 @@ end struct BasicConnection <: ConnectionRule weight::Float64 end +Base.zero(::BasicConnection) = Base.zero(BasicConnection) Base.zero(::Type{<:BasicConnection}) = BasicConnection(0.0) function (c::BasicConnection)(blox_src, blox_dst) (; jcn = c.weight * output(blox_src)) @@ -177,7 +178,7 @@ function get_connection( (;conn, names) end -struct HHConnection_GAP +struct HHConnection_GAP <: ConnectionRule w::Float64 w_gap::Float64 w_gap_rev::Float64 @@ -211,7 +212,7 @@ end #---------------------------------------------- -# Kuramoto +# Kuramoto function get_connection(src::KuramotoOscillator, dst::KuramotoOscillator, kwargs) (;w_val, name) = generate_weight_param(src, dst, kwargs) (;conn=BasicConnection(w_val), names=[name]) @@ -227,40 +228,12 @@ end #---------------------------------------------- # LIFExci / LIFInh -function blox_wiring_rule!(h, blox::Union{LIFExciNeuron, LIFInhNeuron}, v, kwargs) - evbs = h.composite_discrete_events_builder - i = only(v) - push!(evbs, SpikeAffectEventBuilder(i, Int[], Int[])) -end - - -function blox_wiring_rule!(h, - blox_src::Union{LIFExciNeuron, LIFInhNeuron}, - blox_dst::Union{LIFExciNeuron, LIFInhNeuron}, - v_src, v_dst, kwargs) - #this is the fallback method for non-composite blox, hence vi and vj should have only one element - i, j = only(v_src), only(v_dst) - (; w_val, name) = generate_weight_param(blox_src, blox_dst, kwargs) - conn = BasicConnection(w_val) - - let evbs = h.composite_discrete_events_builder - idx = findfirst(evb -> (evb isa SpikeAffectEventBuilder) && (evb.idx_src == i), evbs) - if isnothing(idx) - error("SpikeAffectEventBuilder for neuron not found, this indicates its blox wiring rule never ran.") - else - if blox_dst isa LIFExciNeuron - push!(evbs[idx].idx_dsts_exci, j) - elseif blox_dst isa LIFInhNeuron - push!(evbs[idx].idx_dsts_inh, j) - end - end - end - add_edge!(h, i, j, Dict(:conn => conn, :names => [name])) -end - function (c::BasicConnection)(sys_src::Subsystem{LIFExciNeuron}, sys_dst::Union{Subsystem{LIFExciNeuron}, Subsystem{LIFInhNeuron}}) - (; jcn = 0.0) + w = c.weight + + (; jcn = w * sys_src.S_NMDA * sys_dst.g_NMDA * (sys_dst.V - sys_dst.V_E) / + (1 + sys_dst.Mg * exp(-0.062 * sys_dst.V) / 3.57)) end function (c::BasicConnection)(::Subsystem{LIFInhNeuron}, @@ -268,111 +241,65 @@ function (c::BasicConnection)(::Subsystem{LIFInhNeuron}, (; jcn = 0.0) end -struct SpikeAffectEventBuilder - idx_src::Int - idx_dsts_inh::Vector{Int} - idx_dsts_exci::Vector{Int} +const LIFExciInhNeuron = Union{LIFExciNeuron, LIFInhNeuron} +GraphDynamics.has_discrete_events(::Type{LIFExciNeuron}) = true +GraphDynamics.has_discrete_events(::Type{LIFInhNeuron}) = true +function GraphDynamics.discrete_event_condition((; t_refract_end, V, θ)::Subsystem{LIF}, t, _) where {LIF <: LIFExciInhNeuron} + # Triggers when either a refractory period is ending, or the neuron spiked (voltage exceeds threshold θ) + (V > θ) || (t_refract_end == t) end +function GraphDynamics.apply_discrete_event!(integrator, + states_view_src, params_view_src, + neuron_src::Subsystem{LIF}, + foreach_connected_neuron) where {LIF <: LIFExciInhNeuron} + t = integrator.t + if t == neuron_src.t_refract_end # Refreactory period is over + params = params_view_src[] + params_view_src[] = @set params.is_refractory = 0 + else # Neuron fired + # Begin refractory period + params_src = params_view_src[] + @reset params_src.t_refract_end = t + params_src.t_refract_duration + @reset params_src.is_refractory = 1 + + add_tstop!(integrator, params_src.t_refract_end) + params_view_src[] = params_src -struct SpikeAffectEvent{i_src, i_LIFInh, i_LIFExci} - j_src::Int - j_dsts_inh::Vector{Int} - j_dsts_exci::Vector{Int} -end + # Reset the neuron voltage + states_view_src[:V] = params_src.V_reset -function (ev::SpikeAffectEventBuilder)(index_map) - (i_src, j_src) = index_map[ev.idx_src] - i_inh, j_dsts_inh = let v = ev.idx_dsts_inh - if isempty(v) - nothing, Int[] - else - index_map[first(v)][1], map(idx -> index_map[idx][2], v) + # Now apply a function to each connected dst neuron + foreach_connected_neuron() do conn, neuron_dst, states_view_dst, params_view_dst + lif_exci_inh_update_connected_neuron(neuron_src, states_view_src, conn, neuron_dst, states_view_dst) end end - i_exci, j_dsts_exci = let v = ev.idx_dsts_exci - if isempty(v) - nothing, Int[] - else - index_map[first(v)][1], map(idx -> index_map[idx][2], v) - end +end +function lif_exci_inh_update_connected_neuron(neuron_src::Subsystem{LIFExciNeuron}, + states_view_src, + conn::BasicConnection, + neuron_dst::Subsystem{<:LIFExciInhNeuron}, + states_view_dst) + w = conn.weight + # check if the neuron is connected to itself + if states_view_src === states_view_dst + # x is the rise variable for NMDA synapses and it only applies to self-recurrent connections + states_view_dst[:x] += w end - SpikeAffectEvent{i_src, i_inh, i_exci}(j_src, j_dsts_inh, j_dsts_exci) + states_view_dst[:S_AMPA] += w + nothing end - -function GraphDynamics.discrete_event_condition(states, - params, - connection_matrices, - ev::SpikeAffectEvent{i_src, i_dst_inh, i_dsts_exci}, - t) where {i_src, i_dst_inh, i_dsts_exci} - (; j_src) = ev - neuron_src = Subsystem(states[i_src][j_src], params[i_src][j_src]) - neuron_src.V > neuron_src.θ +function lif_exci_inh_update_connected_neuron(neuron_src::Subsystem{LIFInhNeuron}, + states_view_src, + conn::BasicConnection, + neuron_dst::Subsystem{<:LIFExciInhNeuron}, + states_view_dst) + w = conn.weight + states_view_dst[:S_GABA] += w + nothing end - -function GraphDynamics.apply_discrete_event!(integrator, - states::NTuple{Len, Any}, - params::NTuple{Len, Any}, - connection_matrices, - t, - ev::SpikeAffectEvent{i_src, i_dst_inh, i_dst_exci} - ) where {i_src, i_dst_inh, i_dst_exci, Len} - (; j_src, j_dsts_inh, j_dsts_exci) = ev - - nc = connection_index(BasicConnection, connection_matrices) - - params_src = params[i_src][j_src] - @reset params_src.t_refract_end = t + params_src.t_refract_duration - @reset params_src.is_refractory = 1 - - params[i_src][j_src] = params_src - add_tstop!(integrator, params_src.t_refract_end) - - states_src = states[i_src][j_src] - states[i_src][:V, j_src] = params_src.V_reset - if (states_src isa SubsystemStates{LIFExciNeuron}) && (j_src ∈ j_dsts_exci) - # x is the rise variable for NMDA synapses and it only applies to self-recurrent connections - w = connection_matrices[nc][i_src, i_src][j_src, j_src].weight - states[i_src][:x, j_src] += w - end - - if states_src isa SubsystemStates{LIFExciNeuron} - if !isnothing(i_dst_inh) - M = connection_matrices[nc][i_src, i_dst_inh] - for j_dst ∈ j_dsts_inh - w = M[j_src, j_dst].weight - states[i_dst_inh][:S_AMPA, j_dst] += w - end - end - if !isnothing(i_dst_exci) - M = connection_matrices[nc][i_src, i_dst_exci] - for j_dst ∈ j_dsts_exci - w = M[j_src, j_dst].weight - states[i_dst_exci][:S_AMPA, j_dst] += w - end - end - elseif states_src isa SubsystemStates{LIFInhNeuron} - if !isnothing(i_dst_inh) - M = connection_matrices[nc][i_src, i_dst_inh] - for j_dst ∈ j_dsts_inh - w = M[j_src, j_dst].weight - states[i_dst_inh][:S_GABA, j_dst] += w - end - end - if !isnothing(i_dst_exci) - M = connection_matrices[nc][i_src, i_dst_exci] - for j_dst ∈ j_dsts_exci - w = M[j_src, j_dst].weight - states[i_dst_exci][:S_GABA, j_dst] += w - end - end - else - error("this should be unreachable") - end -end - function blox_wiring_rule!(h, stim::PoissonSpikeTrain, blox_dst::Union{LIFExciNeuron, LIFInhNeuron}, @@ -382,7 +309,7 @@ function blox_wiring_rule!(h, conn = PoissonSpikeConn(w_val, Set(Neuroblox.generate_spike_times(stim))) add_edge!(h, i, j, Dict(:conn => conn, :names => [name])) end -struct PoissonSpikeConn +struct PoissonSpikeConn <: ConnectionRule w::Float64 t_spikes::Set{Float64} end @@ -391,23 +318,29 @@ function ((;w)::PoissonSpikeConn)(stim::Subsystem{PoissonSpikeTrain}, blox_dst::Union{Subsystem{LIFExciNeuron}, Subsystem{LIFInhNeuron}}) (; jcn = 0.0) end -GraphDynamics.has_discrete_events(::PoissonSpikeConn) = true -GraphDynamics.has_discrete_events(::Type{PoissonSpikeConn}) = true GraphDynamics.event_times((;t_spikes)::PoissonSpikeConn) = (t_spikes) -GraphDynamics.discrete_event_condition((;t_spikes)::PoissonSpikeConn, t) = (t ∈ t_spikes) +GraphDynamics.has_discrete_events(::Type{PoissonSpikeTrain}) = true +function GraphDynamics.discrete_event_condition(p::Subsystem{PoissonSpikeTrain}, t, foreach_connected_neuron::F) where {F} + # check if any of the downstream connections from p spike at time t. + cond = mapreduce(|, foreach_connected_neuron; init=false) do conn, _, _, _ + t ∈ conn.t_spikes + end +end function GraphDynamics.apply_discrete_event!(integrator, - _, _, - vstates_dst, _, - _::PoissonSpikeConn, - _::Subsystem{PoissonSpikeTrain}, - _::Union{Subsystem{LIFExciNeuron}, Subsystem{LIFInhNeuron}}) - states = vstates_dst[] - states = @set states.S_AMPA_ext += 1 - vstates_dst[] = states - nothing + states_view_src, params_view_src, + neuron_src::Subsystem{PoissonSpikeTrain}, + foreach_connected_neuron::F) where {F} + t = integrator.t + foreach_connected_neuron() do conn, neuron_dst, states_view_dst, params_view_dst + # Check each downstream connection, if it's time to spike, increment the downstream neuron's S_AMPA_ext + if t ∈ conn.t_spikes + states_view_dst[:S_AMPA_ext] += 1 + end + end end + components(blox::Union{LIFExciCircuitBlox, LIFInhCircuitBlox}) = blox.parts issupported(::Union{LIFExciCircuitBlox, LIFInhCircuitBlox}) = true @@ -438,7 +371,6 @@ function blox_wiring_rule!(h, blox_dst::Union{LIFExciCircuitBlox, LIFInhCircuitBlox}, v_src, v_dst, kwargs) neurons_dst = components(blox_dst) - for (j, neuron_dst) ∈ enumerate(neurons_dst) blox_wiring_rule!(h, stim, neuron_dst, only(v_src), v_dst[j], kwargs) end @@ -749,7 +681,7 @@ function get_connection(discr_src::Matrisome, discr_dst::Matrisome, kwargs) MMConn(t_event) end -struct MMConn{T} +struct MMConn{T} <: ConnectionRule t_event::T end @@ -825,7 +757,7 @@ function get_connection(discr_src::TAN, discr_dst::Matrisome, kwargs) (; conn = TAN_M_Conn(w_val, t_event), names=[name]) end -struct TAN_M_Conn +struct TAN_M_Conn <: ConnectionRule w::Float64 t_event::Float64 end @@ -1019,7 +951,7 @@ end # #------------------------- # PING Network -struct PINGConnection +struct PINGConnection <: ConnectionRule w::Float64 V_E::Float64 V_I::Float64 diff --git a/src/GraphDynamicsInterop/neuron_interop.jl b/src/GraphDynamicsInterop/neuron_interop.jl index 3126cc32..f52c138a 100644 --- a/src/GraphDynamicsInterop/neuron_interop.jl +++ b/src/GraphDynamicsInterop/neuron_interop.jl @@ -128,8 +128,8 @@ function define_neurons() )) @eval begin GraphDynamics.has_continuous_events(::Type{$T}) = true - GraphDynamics.continuous_event_condition((; $(p_and_s_syms...))::Subsystem{$T}, t) = $ev_condition - function GraphDynamics.apply_continuous_event!(integrator, sview, pview, neuron::Subsystem{$T}) + GraphDynamics.continuous_event_condition((; $(p_and_s_syms...))::Subsystem{$T}, t, _) = $ev_condition + function GraphDynamics.apply_continuous_event!(integrator, sview, pview, neuron::Subsystem{$T}, _) (; $(p_and_s_syms...)) = neuron sview[] = SubsystemStates{$T}(merge(NamedTuple(get_states(neuron)), $ev_affect)) end @@ -140,16 +140,16 @@ end define_neurons() # it's useful when developing this module to have these in a function #Maybe should just encorporate this into define_neurons() -for T ∈ [:LIFExciNeuron, :LIFInhNeuron] - @eval begin - GraphDynamics.has_discrete_events(::Type{$T}) = true - GraphDynamics.discrete_event_condition((; t_refract_end)::Subsystem{$T}, t) = t_refract_end == t - function GraphDynamics.apply_discrete_event!(integrator, _, pview, neuron::Subsystem{$T}) - params = get_params(neuron) - pview[] = @set params.is_refractory = 0 - end - end -end +# for T ∈ [:LIFExciNeuron, :LIFInhNeuron] +# @eval begin +# GraphDynamics.has_discrete_events(::Type{$T}) = true +# GraphDynamics.discrete_event_condition((; t_refract_end)::Subsystem{$T}, t) = t_refract_end == t +# function GraphDynamics.apply_discrete_event!(integrator, _, pview, neuron::Subsystem{$T}, _) +# params = get_params(neuron) +# pview[] = @set params.is_refractory = 0 +# end +# end +# end issupported(::PoissonSpikeTrain) = true components(p::PoissonSpikeTrain) = (p,) @@ -199,11 +199,11 @@ function to_subsystem(s::Matrisome) end GraphDynamics.has_discrete_events(::Type{Matrisome}) = true GraphDynamics.discrete_events_require_inputs(::Type{Matrisome}) = true -function GraphDynamics.discrete_event_condition((;t_event,)::Subsystem{Matrisome}, t) +function GraphDynamics.discrete_event_condition((;t_event,)::Subsystem{Matrisome}, t, _) t == t_event end GraphDynamics.event_times((;t_event)::Subsystem{Matrisome}) = t_event -function GraphDynamics.apply_discrete_event!(integrator, _, vparams, s::Subsystem{Matrisome}, jcn) +function GraphDynamics.apply_discrete_event!(integrator, _, vparams, s::Subsystem{Matrisome}, _, jcn) # recording the values of jcn and H at the event time in the parameters jcn_ and H_ params = get_params(s) vparams[] = @set params.jcn_ = jcn @@ -270,14 +270,13 @@ end GraphDynamics.has_discrete_events(::Type{SNc}) = true GraphDynamics.discrete_events_require_inputs(::Type{SNc}) = true -function GraphDynamics.discrete_event_condition((;t_event,)::Subsystem{SNc}, t) +function GraphDynamics.discrete_event_condition((;t_event,)::Subsystem{SNc}, t, _) t == t_event end GraphDynamics.event_times((;t_event)::Subsystem{SNc}) = t_event -function GraphDynamics.apply_discrete_event!(integrator, _, vparams, s::Subsystem{SNc}, jcn) +function GraphDynamics.apply_discrete_event!(integrator, _, vparams, s::Subsystem{SNc}, _, jcn) # recording the values of jcn and H at the event time in the parameters jcn_ and H_ params = get_params(s) vparams[] = @set params.jcn_ = jcn nothing end - diff --git a/src/Neuroblox.jl b/src/Neuroblox.jl index a321b109..4a20dfcf 100644 --- a/src/Neuroblox.jl +++ b/src/Neuroblox.jl @@ -42,7 +42,7 @@ import ModelingToolkit: inputs, nameof, outputs, getdescription using Symbolics: @register_symbolic, getdefaultval, get_variables -using CSV: read +using CSV: read, write using DataFrames using Peaks: argmaxima, peakproms!, peakheights!, findmaxima diff --git a/src/blox/cortical.jl b/src/blox/cortical.jl index 97a7894b..c199a70f 100644 --- a/src/blox/cortical.jl +++ b/src/blox/cortical.jl @@ -115,6 +115,7 @@ struct LIFExciCircuitBlox <: CompositeBlox Mg = 1, # mM exci_scaling_factor = 1, inh_scaling_factor = 1, + skip_system_creation=false, kwargs... ) @@ -154,10 +155,13 @@ struct LIFExciCircuitBlox <: CompositeBlox end end - bc = connector_from_graph(g) - - sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(neurons; name) - + if skip_system_creation + bc = nothing + sys = nothing + else + bc = connector_from_graph(g) + sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(neurons; name) + end new(namespace, neurons, sys, bc, kwargs) end end @@ -191,6 +195,7 @@ struct LIFInhCircuitBlox <: CompositeBlox Mg = 1, # mM exci_scaling_factor = 1, inh_scaling_factor = 1, + skip_system_creation = false, kwargs... ) @@ -228,10 +233,14 @@ struct LIFInhCircuitBlox <: CompositeBlox end end - bc = connector_from_graph(g) + if skip_system_creation + bc = nothing + sys = nothing + else + bc = connector_from_graph(g) + sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(neurons; name) + end - sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(neurons; name) - new(namespace, neurons, sys, bc, kwargs) end end diff --git a/src/blox/reinforcement_learning.jl b/src/blox/reinforcement_learning.jl index f37944ff..a8ba1cca 100644 --- a/src/blox/reinforcement_learning.jl +++ b/src/blox/reinforcement_learning.jl @@ -31,6 +31,10 @@ function weight_gradient(hp::HebbianPlasticity, sol, w, feedback) return hp(val_pre, val_post, w, feedback) end +get_eval_times(l::HebbianPlasticity) = [l.t_pre, l.t_post] + +get_eval_states(l::HebbianPlasticity) = [l.state_pre, l.state_post] + mutable struct HebbianModulationPlasticity <: AbstractLearningRule const K const decay @@ -74,6 +78,10 @@ function weight_gradient(hmp::HebbianModulationPlasticity, sol, w, feedback) return hmp(val_pre, val_post, val_mod, w, feedback) end +get_eval_times(l::HebbianModulationPlasticity) = [l.t_pre, l.t_post, l.t_mod] + +get_eval_states(l::HebbianModulationPlasticity) = [l.state_pre, l.state_post, get_modulator_state(l.modulator)] + function maybe_set_state_pre!(lr::AbstractLearningRule, state) if isnothing(lr.state_pre) lr.state_pre = state @@ -121,7 +129,7 @@ end (env::ClassificationEnvironment)(action) = action == env.category[env.current_trial] -increment_trial!(env::AbstractEnvironment) = env.current_trial += 1 +increment_trial!(env::AbstractEnvironment) = env.current_trial = mod(env.current_trial, env.N_trials) + 1 reset!(env::AbstractEnvironment) = env.current_trial = 1 @@ -153,6 +161,10 @@ function (p::GreedyPolicy)(sol::SciMLBase.AbstractSciMLSolution) return argmax(comp_vals) end +get_eval_times(gp::GreedyPolicy) = [gp.t_decision] + +get_eval_states(gp::GreedyPolicy) = gp.competitor_states + """ function (p::GreedyPolicy)(sys::ODESystem, prob::ODEProblem) ps = parameters(sys) @@ -200,191 +212,151 @@ end reset!(ag::Agent) = ag.problem = remake(ag.problem; p = ag.init_params) -function run_experiment!(agent::Agent, env::ClassificationEnvironment, t_warmup=200.0; kwargs...) +function run_experiment!(agent::Agent, env::ClassificationEnvironment; t_warmup=0, kwargs...) N_trials = env.N_trials t_trial = env.t_trial tspan = (0, t_trial) sys = get_sys(agent) - prob = agent.problem + defs = ModelingToolkit.get_defaults(sys) + learning_rules = agent.learning_rules + + stim_params = get_trial_stimulus(env) + init_params = ModelingToolkit.MTKParameters(sys, merge(defs, stim_params)) if t_warmup > 0 - prob = remake(prob; tspan=(0,t_warmup)) - if haskey(kwargs, :alg) - sol = solve(prob, kwargs[:alg]; kwargs...) - else - sol = solve(prob; alg_hints = [:stiff], kwargs...) - end - u0 = sol[1:end,end] # last value of state vector - prob = remake(prob; tspan=tspan, u0=u0) + u0 = run_warmup(agent, env, t_warmup; kwargs...) + agent.problem = remake(agent.problem; tspan, u0=u0, p=init_params) else - prob = remake(prob; tspan) - u0 = [] + agent.problem = remake(agent.problem; tspan, p=init_params) + end + + t_stops = mapreduce(get_eval_times, union, values(learning_rules)) + + action_selection = agent.action_selection + if !isnothing(action_selection) + t_stops = union(t_stops, get_eval_times(action_selection)) end - action_selection = agent.action_selection - learning_rules = agent.learning_rules - - defs = ModelingToolkit.get_defaults(sys) weights = Dict{Num, Float64}() for w in keys(learning_rules) weights[w] = defs[w] end for _ in Base.OneTo(N_trials) - - stim_params = get_trial_stimulus(env) - - to_update = merge(weights, stim_params) - new_params = ModelingToolkit.MTKParameters(sys, merge(defs, weights, stim_params)) - - prob = remake(prob; p = new_params, u0=u0) - if haskey(kwargs, :alg) - sol = solve(prob, kwargs[:alg]; kwargs...) - else - sol = solve(prob; alg_hints = [:stiff], kwargs...) - end - - # u0 = sol[1:end,end] # next run should continue where the last one ended - # In the paper we assume sufficient time interval before net stimulus so that - # system reaches back to steady state, so we don't continue from previous trial's endpoint - - if isnothing(action_selection) - feedback = 1 - else - action = action_selection(sol) - feedback = env(action) - end - - for (w, rule) in learning_rules - w_val = weights[w] - Δw = weight_gradient(rule, sol, w_val, feedback) - weights[w] += Δw - end - increment_trial!(env) + run_trial!(agent, env, weights, nothing; saveat = t_stops, kwargs...) end - - agent.problem = prob end -function run_trial!(agent::Agent, env::ClassificationEnvironment, weights::Dict{Num, Float64}, u0::Vector{Float64}; kwargs...) - N_trials = env.N_trials - - if env.current_trial <= N_trials - t_trial = env.t_trial - tspan = (0, t_trial) - - prob = agent.problem - - action_selection = agent.action_selection - learning_rules = agent.learning_rules - - @show env.current_trial - stim_params = get_trial_stimulus(env) - @show stim_params - @show weights - prob = remake(prob; tspan=tspan, p = merge(weights, stim_params), u0=u0) - - if haskey(kwargs, :alg) - sol = solve(prob, kwargs[:alg]; kwargs...) - else - sol = solve(prob; alg_hints = [:stiff], kwargs...) - end - - if isnothing(action_selection) - feedback = 1 - else - action = action_selection(sol) - feedback = env(action) - end - - for (w, rule) in learning_rules - w_val = weights[w] - Δw = weight_gradient(rule, sol, w_val, feedback) - @show Δw - weights[w] += Δw - end - prob = remake(prob; p = merge(weights)) #updates the weights in prob - increment_trial!(env) - agent.problem = prob - # u0 = sol[1:end,end] - end -end - -function run_experiment!(agent::Agent, env::ClassificationEnvironment, save_path::String, t_warmup=200.0; kwargs...) +function run_experiment!(agent::Agent, env::ClassificationEnvironment, save_path::String; t_warmup=0, kwargs...) N_trials = env.N_trials t_trial = env.t_trial tspan = (0, t_trial) + sys = get_sys(agent) - prob = agent.problem + defs = ModelingToolkit.get_defaults(sys) + learning_rules = agent.learning_rules + + stim_params = get_trial_stimulus(env) + init_params = ModelingToolkit.MTKParameters(sys, merge(defs, stim_params)) if t_warmup > 0 - prob = remake(prob; tspan=(0,t_warmup)) - if haskey(kwargs, :alg) - sol = solve(prob, kwargs[:alg]; kwargs...) - else - sol = solve(prob; alg_hints = [:stiff], kwargs...) - end - u0 = sol[1:end,end] # last value of state vector - prob = remake(prob; tspan=tspan, u0=u0) + u0 = run_warmup(agent, env, t_warmup; kwargs...) + agent.problem = remake(agent.problem; tspan, u0=u0, p=init_params) else - prob = remake(prob; tspan) - u0 = [] + agent.problem = remake(agent.problem; tspan, p=init_params) end - action_selection = agent.action_selection - learning_rules = agent.learning_rules - - defs = ModelingToolkit.get_defaults(sys) weights = Dict{Num, Float64}() for w in keys(learning_rules) weights[w] = defs[w] end - for trial_num in Base.OneTo(N_trials) + #= + # TO DO: Ideally we should use save_idxs here to save some memory for long solves. + # However it does not seem possible currently to either do time interpolation on the solution + # or access observed states when save_idxs is used. Need to check with SciML people. + states = unknowns(sys) + idxs_V = findall(s -> occursin("₊V(t)", s), String.(Symbol.(states))) + + states_learning = mapreduce(get_eval_states, union, values(learning_rules)) + action_selection = agent.action_selection + if !isnothing(action_selection) + states_learning = union(states_learning, get_eval_states(action_selection)) + end + + idxs_learning = map(states_learning) do sl + findfirst(s -> occursin(String(Symbol(sl)), String(Symbol(s))), states) + end + filter!(!isnothing, idxs_learning) + + save_idxs = union(idxs_V, idxs_learning) + =# + + for trial in Base.OneTo(N_trials) + sol = run_trial!(agent, env, weights, nothing; kwargs...) + + save_voltages(sol, save_path, trial) + end +end + +function run_warmup(agent::Agent, env::ClassificationEnvironment, t_warmup; kwargs...) - stim_params = get_trial_stimulus(env) + prob = remake(agent.problem; tspan=(0, t_warmup)) + if haskey(kwargs, :alg) + sol = solve(prob, kwargs[:alg]; save_everystep=false, kwargs...) + else + sol = solve(prob; alg_hints = [:stiff], save_everystep=false, kwargs...) + end + u0 = sol[:,end] # last value of state vector - to_update = merge(weights, stim_params) - new_params = ModelingToolkit.MTKParameters(sys, merge(defs, weights, stim_params)) + return u0 +end - prob = remake(prob; p = new_params, u0=u0) - if haskey(kwargs, :alg) - sol = solve(prob, kwargs[:alg]; kwargs...) - else - sol = solve(prob; alg_hints = [:stiff], kwargs...) - end +function run_trial!(agent::Agent, env::ClassificationEnvironment, weights, u0; kwargs...) - # u0 = sol[1:end,end] # next run should continue where the last one ended - # In the paper we assume sufficient time interval before net stimulus so that - # system reaches back to steady state, so we don't continue from previous trial's endpoint + prob = agent.problem + action_selection = agent.action_selection + learning_rules = agent.learning_rules + sys = get_sys(agent) + defs = ModelingToolkit.get_defaults(sys) - if isnothing(action_selection) - feedback = 1 - else - action = action_selection(sol) - feedback = env(action) - end + if haskey(kwargs, :alg) + sol = solve(prob, kwargs[:alg]; kwargs...) + else + sol = solve(prob; alg_hints = [:stiff], kwargs...) + end - for (w, rule) in learning_rules - w_val = weights[w] - Δw = weight_gradient(rule, sol, w_val, feedback) - weights[w] += Δw - end - increment_trial!(env) + # u0 = sol[1:end,end] # next run should continue where the last one ended + # In the paper we assume sufficient time interval before next stimulus so that + # system reaches back to steady state, so we don't continue from previous trial's endpoint - if !isnothing(save_path) - save_voltages(sol, save_path, trial_num) - end + if isnothing(action_selection) + feedback = 1 + else + action = action_selection(sol) + feedback = env(action) + end + for (w, rule) in learning_rules + w_val = weights[w] + Δw = weight_gradient(rule, sol, w_val, feedback) + weights[w] += Δw end + + increment_trial!(env) + + stim_params = get_trial_stimulus(env) + new_params = ModelingToolkit.MTKParameters(sys, merge(defs, weights, stim_params)) + + agent.problem = remake(prob; p = new_params) - agent.problem = prob + return sol end function save_voltages(sol, filepath, numtrial) df = DataFrame(sol) fname = "sim"*lpad(numtrial, 4, "0")*".csv" fullpath = joinpath(filepath, fname) - CSV.write(fullpath, df) + write(fullpath, df) end diff --git a/test/GraphDynamicsTests/test_suite.jl b/test/GraphDynamicsTests/test_suite.jl index bdcff148..2e1f7258 100644 --- a/test/GraphDynamicsTests/test_suite.jl +++ b/test/GraphDynamicsTests/test_suite.jl @@ -70,6 +70,11 @@ function test_compare_du_and_sols(::Type{ODEProblem}, g, tspan; sol_u_reordered, du_reordered end @debug "" norm(sol_grp .- sol_mtk) / norm(sol_mtk) + for i ∈ eachindex(state_names) + if !isapprox(sol_grp[i], sol_mtk[i]; rtol=rtol) + @debug "" i state_names[i] sol_grp[i] sol_mtk[i] + end + end @test sort(du_grp) ≈ sort(du_mtk) # due to the MTK getu bug, we'll compare the sorted versions @test sol_grp ≈ sol_mtk rtol=rtol end @@ -686,13 +691,13 @@ function lif_exci_inh_tests(;tspan=(0.0, 20.0), rtol=1e-8) add_edge!(g, background_input2 => n1; weight = 0.0) add_edge!(g, stim_A => n1; weight = 1.0) add_edge!(g, stim_B => n1; weight = 1.0) - add_edge!(g, n1 => n2; weight = 1.0) - add_edge!(g, n2 => n1; weight = 2.0) - add_edge!(g, n3 => n1; weight = 3.0) + add_edge!(g, n1 => n2; weight = 1.0) + add_edge!(g, n2 => n1; weight = 2.0) + add_edge!(g, n3 => n1; weight = 3.0) test_compare_du_and_sols(ODEProblem, g, tspan; rtol, alg=Tsit5()) end -function decision_making_test(;tspan=(0.0, 9.0), rtol=1e-5) +function decision_making_test(;tspan=(0.0, 20.0), rtol=1e-5, N_E=24) ## Describe what the local variables you define are for global_ns = :g ## global name for the circuit. All components should be inside this namespace. @@ -700,7 +705,7 @@ function decision_making_test(;tspan=(0.0, 9.0), rtol=1e-5) spike_rate = 2.4 ## spikes / ms f = 0.15 ## ratio of selective excitatory to non-selective excitatory neurons - N_E = 24 ## total number of excitatory neurons + N_E ## total number of excitatory neurons N_I = Int(ceil(N_E / 4)) ## total number of inhibitory neurons N_E_selective = Int(ceil(f * N_E)) ## number of selective excitatory neurons N_E_nonselective = N_E - 2 * N_E_selective ## number of non-selective excitatory neurons @@ -761,26 +766,8 @@ function decision_making_test(;tspan=(0.0, 9.0), rtol=1e-5) add_edge!(g, n_inh => n_A; weight = 1) add_edge!(g, n_inh => n_B; weight = 1) add_edge!(g, n_inh => n_ns; weight = 1) - test_compare_du_and_sols(ODEProblem, g, tspan; rtol, alg=Tsit5()) - # local state_names - # sol_gys = let sys = graphsystem_from_graph(g) - # prob = ODEProblem(sys, [], tspan, []) - # sol = solve(prob, Tsit5()) - - # state_names = variable_symbols(sys) - # sol_u_reordered = map(state_names) do name - # sol[name][end] - # end - # end - # sol_mtk = let sys = system_from_graph(g; name=:sys) - # prob = ODEProblem(sys, [], tspan, []) - # sol = solve(prob, Tsit5()) - # sol_u_reordered = map(state_names) do name - # sol[name][end] - # end - # end - # sort(sol_gys .- sol_mtk) + test_compare_du_and_sols(ODEProblem, g, tspan; rtol, alg=Tsit5()) end function ping_tests(;tspan=(0.0, 2.0)) diff --git a/test/plasticity.jl b/test/plasticity.jl index da00d06b..f9f36a73 100644 --- a/test/plasticity.jl +++ b/test/plasticity.jl @@ -49,7 +49,7 @@ using ModelingToolkit: getp env = ClassificationEnvironment(stim; name=:env, namespace=global_ns) - run_experiment!(agent, env; alg=Tsit5(), reltol=1e-6,abstol=1e-9) + run_experiment!(agent, env; t_warmup=200, alg=Tsit5(), reltol=1e-6,abstol=1e-9) final_params = agent.problem.p # At least some weights need to be different. diff --git a/test/reinforcement_learning.jl b/test/reinforcement_learning.jl index 0dbdd61c..e2878397 100644 --- a/test/reinforcement_learning.jl +++ b/test/reinforcement_learning.jl @@ -66,7 +66,7 @@ using ModelingToolkit: getp init_params_idxs_other_params = params_at(idxs_other_params) env = ClassificationEnvironment(stim; name=:env, namespace=global_ns) - run_experiment!(agent, env; alg=Vern7(), reltol=1e-9,abstol=1e-9) + run_experiment!(agent, env; t_warmup=200, alg=Vern7(), reltol=1e-9,abstol=1e-9) final_params = reduce(vcat, agent.problem.p) # At least some weights need to be different. @@ -81,3 +81,82 @@ using ModelingToolkit: getp reset!(env) @test env.current_trial == 1 end + +@testset "RL test with save" begin + t_trial = 2 # ms + time_block_dur = 0.01 # ms + N_trials = 3 + + global_ns = :g # global namespace + @named VAC = CorticalBlox(N_wta=3, N_exci=3, namespace=global_ns, density=0.1, weight=1) + @named PFC = CorticalBlox(N_wta=2, N_exci=3, namespace=global_ns, density=0.1, weight=1) + @named STR_L = Striatum(N_inhib=2, namespace=global_ns) + @named STR_R = Striatum(N_inhib=2, namespace=global_ns) + @named SNcb = SNc(namespace=global_ns, N_time_blocks=t_trial/time_block_dur) + @named TAN_pop = TAN(;namespace=global_ns) + + @named AS = GreedyPolicy(namespace=global_ns, t_decision=0.31*t_trial) + + fn = joinpath(@__DIR__, "../examples/image_example.csv") + data = CSV.read(fn, DataFrame) + @named stim = ImageStimulus(data[1:N_trials,:]; namespace=global_ns, t_stimulus=0.4*t_trial, t_pause=0.6*t_trial) + + bloxs = [VAC, PFC, STR_L, STR_R, SNcb, TAN_pop, AS, stim] + d = Dict(b => i for (i,b) in enumerate(bloxs)) + + hebbian_mod = HebbianModulationPlasticity(K=0.2, decay=0.01, α=3, θₘ=1, modulator=SNcb, t_pre=t_trial, t_post=t_trial, t_mod=0.31*t_trial) + hebbian = HebbianPlasticity(K=0.2, W_lim=2, t_pre=t_trial, t_post=t_trial) + + g = MetaDiGraph() + add_blox!.(Ref(g), bloxs) + + add_edge!(g, d[stim], d[VAC], Dict(:weight => 1, :density => 0.1)) + add_edge!(g, d[VAC], d[PFC], Dict(:weight => 1, :density => 0.1, :learning_rule => hebbian)) + add_edge!(g, d[PFC], d[STR_L], Dict(:weight => 1, :density => 0.1, :learning_rule => hebbian_mod)) + add_edge!(g, d[PFC], d[STR_R], Dict(:weight => 1, :density => 0.1, :learning_rule => hebbian_mod)) + add_edge!(g, d[STR_R], d[STR_L], Dict(:weight => 1, :t_event => 0.3*t_trial)) + add_edge!(g, d[STR_L], d[STR_R], Dict(:weight => 1, :t_event => 0.3*t_trial)) + add_edge!(g, d[STR_L], d[SNcb], Dict(:weight => 1)) + add_edge!(g, d[STR_R], d[SNcb], Dict(:weight => 1)) + add_edge!(g, d[STR_L], d[AS]) + add_edge!(g, d[STR_R], d[AS]) + add_edge!(g, d[STR_L], d[TAN_pop], Dict(:weight => 1)) + add_edge!(g, d[STR_R], d[TAN_pop], Dict(:weight => 1)) + add_edge!(g, d[TAN_pop], d[STR_L], Dict(:weight => 1, :t_event => 0.1*t_trial)) + add_edge!(g, d[TAN_pop], d[STR_R], Dict(:weight => 1, :t_event => 0.1*t_trial)) + + agent = Agent(g; name=:ag, t_block = t_trial/5); + ps = parameters(agent.odesystem) + + + map_idxs = Int.(ModelingToolkit.varmap_to_vars([ps[i] => i for i in eachindex(ps)], ps)) + idxs_weight = findall(x -> occursin("w_", String(Symbol(x))), ps) + idx_stim = findall(x -> occursin("stim₊", String(Symbol(x))), ps) + idx_jcn = findall(x -> occursin("jcn", String(Symbol(x))), ps) + idx_spikes = findall(x -> occursin("spikes", String(Symbol(x))), ps) + idx_H = findall(x -> occursin("H", String(Symbol(x))), ps) + idx_I_bg = findall(x -> occursin("I_bg", String(Symbol(x))), ps) + idxs_other_params = setdiff(eachindex(ps), vcat(idxs_weight, idx_stim, idx_jcn, idx_spikes, idx_H, idx_I_bg)) + + params_at(idxs) = getp(agent.problem, parameters(agent.odesystem)[idxs])(agent.problem) + init_params_all = params_at(:) + init_params_idxs_weight = params_at(idxs_weight) + init_params_idxs_other_params = params_at(idxs_other_params) + + env = ClassificationEnvironment(stim; name=:env, namespace=global_ns) + run_experiment!(agent, env, "./"; t_warmup=200, alg=Vern7(), reltol=1e-9,abstol=1e-9) + + final_params = reduce(vcat, agent.problem.p) + # At least some weights need to be different. + @test any(init_params_idxs_weight .!= params_at(idxs_weight)) + # @test any(init_params[map_idxs[idxs_weight]] .!= final_params[map_idxs[idxs_weight]]) + # All non-weight parameters need to be the same. + @test all(init_params_idxs_other_params .== params_at(idxs_other_params)) + # @test all(init_params[map_idxs[idxs_other_params]] .== final_params[map_idxs[idxs_other_params]]) + + reset!(agent) + @test all(init_params_all .== params_at(:)) + @show setdiff(init_params_all, params_at(:)) + reset!(env) + @test env.current_trial == 1 +end \ No newline at end of file diff --git a/test/reinforcement_learning_flattening.jl b/test/reinforcement_learning_flattening.jl deleted file mode 100644 index ce7b5d92..00000000 --- a/test/reinforcement_learning_flattening.jl +++ /dev/null @@ -1,88 +0,0 @@ -using Neuroblox -using DifferentialEquations -using Test -using Graphs -using MetaGraphs -using DataFrames -using CSV -using ModelingToolkit: getp - -@testset "RL test with save" begin - t_trial = 2 # ms - time_block_dur = 0.01 # ms - N_trials = 3 - - global_ns = :g # global namespace - @named VAC = CorticalBlox(N_wta=3, N_exci=3, namespace=global_ns, density=0.1, weight=1) - @named PFC = CorticalBlox(N_wta=2, N_exci=3, namespace=global_ns, density=0.1, weight=1) - @named STR_L = Striatum(N_inhib=2, namespace=global_ns) - @named STR_R = Striatum(N_inhib=2, namespace=global_ns) - @named SNcb = SNc(namespace=global_ns, N_time_blocks=t_trial/time_block_dur) - @named TAN_pop = TAN(;namespace=global_ns) - - @named AS = GreedyPolicy(namespace=global_ns, t_decision=0.31*t_trial) - - fn = joinpath(@__DIR__, "../examples/image_example.csv") - data = CSV.read(fn, DataFrame) - @named stim = ImageStimulus(data[1:N_trials,:]; namespace=global_ns, t_stimulus=0.4*t_trial, t_pause=0.6*t_trial) - - bloxs = [VAC, PFC, STR_L, STR_R, SNcb, TAN_pop, AS, stim] - d = Dict(b => i for (i,b) in enumerate(bloxs)) - - hebbian_mod = HebbianModulationPlasticity(K=0.2, decay=0.01, α=3, θₘ=1, modulator=SNcb, t_pre=t_trial, t_post=t_trial, t_mod=0.31*t_trial) - hebbian = HebbianPlasticity(K=0.2, W_lim=2, t_pre=t_trial, t_post=t_trial) - - g = MetaDiGraph() - add_blox!.(Ref(g), bloxs) - - add_edge!(g, d[stim], d[VAC], Dict(:weight => 1, :density => 0.1)) - add_edge!(g, d[VAC], d[PFC], Dict(:weight => 1, :density => 0.1, :learning_rule => hebbian)) - add_edge!(g, d[PFC], d[STR_L], Dict(:weight => 1, :density => 0.1, :learning_rule => hebbian_mod)) - add_edge!(g, d[PFC], d[STR_R], Dict(:weight => 1, :density => 0.1, :learning_rule => hebbian_mod)) - add_edge!(g, d[STR_R], d[STR_L], Dict(:weight => 1, :t_event => 0.3*t_trial)) - add_edge!(g, d[STR_L], d[STR_R], Dict(:weight => 1, :t_event => 0.3*t_trial)) - add_edge!(g, d[STR_L], d[SNcb], Dict(:weight => 1)) - add_edge!(g, d[STR_R], d[SNcb], Dict(:weight => 1)) - add_edge!(g, d[STR_L], d[AS]) - add_edge!(g, d[STR_R], d[AS]) - add_edge!(g, d[STR_L], d[TAN_pop], Dict(:weight => 1)) - add_edge!(g, d[STR_R], d[TAN_pop], Dict(:weight => 1)) - add_edge!(g, d[TAN_pop], d[STR_L], Dict(:weight => 1, :t_event => 0.1*t_trial)) - add_edge!(g, d[TAN_pop], d[STR_R], Dict(:weight => 1, :t_event => 0.1*t_trial)) - - agent = Agent(g; name=:ag, t_block = t_trial/5); - ps = parameters(agent.odesystem) - - - map_idxs = Int.(ModelingToolkit.varmap_to_vars([ps[i] => i for i in eachindex(ps)], ps)) - idxs_weight = findall(x -> occursin("w_", String(Symbol(x))), ps) - idx_stim = findall(x -> occursin("stim₊", String(Symbol(x))), ps) - idx_jcn = findall(x -> occursin("jcn", String(Symbol(x))), ps) - idx_spikes = findall(x -> occursin("spikes", String(Symbol(x))), ps) - idx_H = findall(x -> occursin("H", String(Symbol(x))), ps) - idx_I_bg = findall(x -> occursin("I_bg", String(Symbol(x))), ps) - idxs_other_params = setdiff(eachindex(ps), vcat(idxs_weight, idx_stim, idx_jcn, idx_spikes, idx_H, idx_I_bg)) - - params_at(idxs) = getp(agent.problem, parameters(agent.odesystem)[idxs])(agent.problem) - init_params_all = params_at(:) - init_params_idxs_weight = params_at(idxs_weight) - init_params_idxs_other_params = params_at(idxs_other_params) - - env = ClassificationEnvironment(stim; name=:env, namespace=global_ns) - run_experiment!(agent, env, "./"; alg=Vern7(), reltol=1e-9,abstol=1e-9) - - final_params = reduce(vcat, agent.problem.p) - # At least some weights need to be different. - @test any(init_params_idxs_weight .!= params_at(idxs_weight)) - # @test any(init_params[map_idxs[idxs_weight]] .!= final_params[map_idxs[idxs_weight]]) - # All non-weight parameters need to be the same. - @test all(init_params_idxs_other_params .== params_at(idxs_other_params)) - # @test all(init_params[map_idxs[idxs_other_params]] .== final_params[map_idxs[idxs_other_params]]) - - reset!(agent) - @test all(init_params_all .== params_at(:)) - @show setdiff(init_params_all, params_at(:)) - reset!(env) - @test env.current_trial == 1 -end -