Skip to content

Commit

Permalink
Reduce memory usage in reinforcement learning (#485)
Browse files Browse the repository at this point in the history
* move `t_warmup` to kwargs

* more RL save test

* add `run_warmup` function

* update `run_trial!` function

* add getter functions for states and times when actions and learning rules are evaluated

* update `run_experiment!` dispatches

* move RL tests in a single file

* remove `save_idxs` kwarg

* import `CSV.write` for RL save
  • Loading branch information
harisorgn authored Nov 5, 2024
1 parent 3b9221d commit a77cb63
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 231 deletions.
2 changes: 1 addition & 1 deletion examples/RF_learning_simple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,4 @@ agent = Agent(g; name=:ag, t_block = 90);
#define environment : contains stimuli and feedback
env = ClassificationEnvironment(stim; name=:env, namespace=global_ns)

run_experiment!(agent, env; alg=Vern7(), reltol=1e-9,abstol=1e-9)
run_experiment!(agent, env; t_warmup=200, alg=Vern7(), reltol=1e-9,abstol=1e-9)
2 changes: 1 addition & 1 deletion src/Neuroblox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ import ModelingToolkit: inputs, nameof, outputs, getdescription

using Symbolics: @register_symbolic, getdefaultval, get_variables

using CSV: read
using CSV: read, write
using DataFrames

using Peaks: argmaxima, peakproms!, peakheights!, findmaxima
Expand Down
250 changes: 111 additions & 139 deletions src/blox/reinforcement_learning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ function weight_gradient(hp::HebbianPlasticity, sol, w, feedback)
return hp(val_pre, val_post, w, feedback)
end

get_eval_times(l::HebbianPlasticity) = [l.t_pre, l.t_post]

get_eval_states(l::HebbianPlasticity) = [l.state_pre, l.state_post]

mutable struct HebbianModulationPlasticity <: AbstractLearningRule
const K
const decay
Expand Down Expand Up @@ -74,6 +78,10 @@ function weight_gradient(hmp::HebbianModulationPlasticity, sol, w, feedback)
return hmp(val_pre, val_post, val_mod, w, feedback)
end

get_eval_times(l::HebbianModulationPlasticity) = [l.t_pre, l.t_post, l.t_mod]

get_eval_states(l::HebbianModulationPlasticity) = [l.state_pre, l.state_post, get_modulator_state(l.modulator)]

function maybe_set_state_pre!(lr::AbstractLearningRule, state)
if isnothing(lr.state_pre)
lr.state_pre = state
Expand Down Expand Up @@ -121,7 +129,7 @@ end

(env::ClassificationEnvironment)(action) = action == env.category[env.current_trial]

increment_trial!(env::AbstractEnvironment) = env.current_trial += 1
increment_trial!(env::AbstractEnvironment) = env.current_trial = mod(env.current_trial, env.N_trials) + 1

reset!(env::AbstractEnvironment) = env.current_trial = 1

Expand Down Expand Up @@ -153,6 +161,10 @@ function (p::GreedyPolicy)(sol::SciMLBase.AbstractSciMLSolution)
return argmax(comp_vals)
end

get_eval_times(gp::GreedyPolicy) = [gp.t_decision]

get_eval_states(gp::GreedyPolicy) = gp.competitor_states

"""
function (p::GreedyPolicy)(sys::ODESystem, prob::ODEProblem)
ps = parameters(sys)
Expand Down Expand Up @@ -200,191 +212,151 @@ end

reset!(ag::Agent) = ag.problem = remake(ag.problem; p = ag.init_params)

function run_experiment!(agent::Agent, env::ClassificationEnvironment, t_warmup=200.0; kwargs...)
function run_experiment!(agent::Agent, env::ClassificationEnvironment; t_warmup=0, kwargs...)
N_trials = env.N_trials
t_trial = env.t_trial
tspan = (0, t_trial)

sys = get_sys(agent)
prob = agent.problem
defs = ModelingToolkit.get_defaults(sys)
learning_rules = agent.learning_rules

stim_params = get_trial_stimulus(env)
init_params = ModelingToolkit.MTKParameters(sys, merge(defs, stim_params))

if t_warmup > 0
prob = remake(prob; tspan=(0,t_warmup))
if haskey(kwargs, :alg)
sol = solve(prob, kwargs[:alg]; kwargs...)
else
sol = solve(prob; alg_hints = [:stiff], kwargs...)
end
u0 = sol[1:end,end] # last value of state vector
prob = remake(prob; tspan=tspan, u0=u0)
u0 = run_warmup(agent, env, t_warmup; kwargs...)
agent.problem = remake(agent.problem; tspan, u0=u0, p=init_params)
else
prob = remake(prob; tspan)
u0 = []
agent.problem = remake(agent.problem; tspan, p=init_params)
end

t_stops = mapreduce(get_eval_times, union, values(learning_rules))

action_selection = agent.action_selection
if !isnothing(action_selection)
t_stops = union(t_stops, get_eval_times(action_selection))
end

action_selection = agent.action_selection
learning_rules = agent.learning_rules

defs = ModelingToolkit.get_defaults(sys)
weights = Dict{Num, Float64}()
for w in keys(learning_rules)
weights[w] = defs[w]
end

for _ in Base.OneTo(N_trials)

stim_params = get_trial_stimulus(env)

to_update = merge(weights, stim_params)
new_params = ModelingToolkit.MTKParameters(sys, merge(defs, weights, stim_params))

prob = remake(prob; p = new_params, u0=u0)
if haskey(kwargs, :alg)
sol = solve(prob, kwargs[:alg]; kwargs...)
else
sol = solve(prob; alg_hints = [:stiff], kwargs...)
end

# u0 = sol[1:end,end] # next run should continue where the last one ended
# In the paper we assume sufficient time interval before net stimulus so that
# system reaches back to steady state, so we don't continue from previous trial's endpoint

if isnothing(action_selection)
feedback = 1
else
action = action_selection(sol)
feedback = env(action)
end

for (w, rule) in learning_rules
w_val = weights[w]
Δw = weight_gradient(rule, sol, w_val, feedback)
weights[w] += Δw
end
increment_trial!(env)
run_trial!(agent, env, weights, nothing; saveat = t_stops, kwargs...)
end

agent.problem = prob
end

function run_trial!(agent::Agent, env::ClassificationEnvironment, weights::Dict{Num, Float64}, u0::Vector{Float64}; kwargs...)
N_trials = env.N_trials

if env.current_trial <= N_trials
t_trial = env.t_trial
tspan = (0, t_trial)

prob = agent.problem

action_selection = agent.action_selection
learning_rules = agent.learning_rules

@show env.current_trial
stim_params = get_trial_stimulus(env)
@show stim_params
@show weights
prob = remake(prob; tspan=tspan, p = merge(weights, stim_params), u0=u0)

if haskey(kwargs, :alg)
sol = solve(prob, kwargs[:alg]; kwargs...)
else
sol = solve(prob; alg_hints = [:stiff], kwargs...)
end

if isnothing(action_selection)
feedback = 1
else
action = action_selection(sol)
feedback = env(action)
end

for (w, rule) in learning_rules
w_val = weights[w]
Δw = weight_gradient(rule, sol, w_val, feedback)
@show Δw
weights[w] += Δw
end
prob = remake(prob; p = merge(weights)) #updates the weights in prob
increment_trial!(env)
agent.problem = prob
# u0 = sol[1:end,end]
end
end

function run_experiment!(agent::Agent, env::ClassificationEnvironment, save_path::String, t_warmup=200.0; kwargs...)
function run_experiment!(agent::Agent, env::ClassificationEnvironment, save_path::String; t_warmup=0, kwargs...)
N_trials = env.N_trials
t_trial = env.t_trial
tspan = (0, t_trial)

sys = get_sys(agent)
prob = agent.problem
defs = ModelingToolkit.get_defaults(sys)
learning_rules = agent.learning_rules

stim_params = get_trial_stimulus(env)
init_params = ModelingToolkit.MTKParameters(sys, merge(defs, stim_params))

if t_warmup > 0
prob = remake(prob; tspan=(0,t_warmup))
if haskey(kwargs, :alg)
sol = solve(prob, kwargs[:alg]; kwargs...)
else
sol = solve(prob; alg_hints = [:stiff], kwargs...)
end
u0 = sol[1:end,end] # last value of state vector
prob = remake(prob; tspan=tspan, u0=u0)
u0 = run_warmup(agent, env, t_warmup; kwargs...)
agent.problem = remake(agent.problem; tspan, u0=u0, p=init_params)
else
prob = remake(prob; tspan)
u0 = []
agent.problem = remake(agent.problem; tspan, p=init_params)
end

action_selection = agent.action_selection
learning_rules = agent.learning_rules

defs = ModelingToolkit.get_defaults(sys)
weights = Dict{Num, Float64}()
for w in keys(learning_rules)
weights[w] = defs[w]
end

for trial_num in Base.OneTo(N_trials)
#=
# TO DO: Ideally we should use save_idxs here to save some memory for long solves.
# However it does not seem possible currently to either do time interpolation on the solution
# or access observed states when save_idxs is used. Need to check with SciML people.
states = unknowns(sys)
idxs_V = findall(s -> occursin("₊V(t)", s), String.(Symbol.(states)))
states_learning = mapreduce(get_eval_states, union, values(learning_rules))
action_selection = agent.action_selection
if !isnothing(action_selection)
states_learning = union(states_learning, get_eval_states(action_selection))
end
idxs_learning = map(states_learning) do sl
findfirst(s -> occursin(String(Symbol(sl)), String(Symbol(s))), states)
end
filter!(!isnothing, idxs_learning)
save_idxs = union(idxs_V, idxs_learning)
=#

for trial in Base.OneTo(N_trials)
sol = run_trial!(agent, env, weights, nothing; kwargs...)

save_voltages(sol, save_path, trial)
end
end

function run_warmup(agent::Agent, env::ClassificationEnvironment, t_warmup; kwargs...)

stim_params = get_trial_stimulus(env)
prob = remake(agent.problem; tspan=(0, t_warmup))
if haskey(kwargs, :alg)
sol = solve(prob, kwargs[:alg]; save_everystep=false, kwargs...)
else
sol = solve(prob; alg_hints = [:stiff], save_everystep=false, kwargs...)
end
u0 = sol[:,end] # last value of state vector

to_update = merge(weights, stim_params)
new_params = ModelingToolkit.MTKParameters(sys, merge(defs, weights, stim_params))
return u0
end

prob = remake(prob; p = new_params, u0=u0)
if haskey(kwargs, :alg)
sol = solve(prob, kwargs[:alg]; kwargs...)
else
sol = solve(prob; alg_hints = [:stiff], kwargs...)
end
function run_trial!(agent::Agent, env::ClassificationEnvironment, weights, u0; kwargs...)

# u0 = sol[1:end,end] # next run should continue where the last one ended
# In the paper we assume sufficient time interval before net stimulus so that
# system reaches back to steady state, so we don't continue from previous trial's endpoint
prob = agent.problem
action_selection = agent.action_selection
learning_rules = agent.learning_rules
sys = get_sys(agent)
defs = ModelingToolkit.get_defaults(sys)

if isnothing(action_selection)
feedback = 1
else
action = action_selection(sol)
feedback = env(action)
end
if haskey(kwargs, :alg)
sol = solve(prob, kwargs[:alg]; kwargs...)
else
sol = solve(prob; alg_hints = [:stiff], kwargs...)
end

for (w, rule) in learning_rules
w_val = weights[w]
Δw = weight_gradient(rule, sol, w_val, feedback)
weights[w] += Δw
end
increment_trial!(env)
# u0 = sol[1:end,end] # next run should continue where the last one ended
# In the paper we assume sufficient time interval before next stimulus so that
# system reaches back to steady state, so we don't continue from previous trial's endpoint

if !isnothing(save_path)
save_voltages(sol, save_path, trial_num)
end
if isnothing(action_selection)
feedback = 1
else
action = action_selection(sol)
feedback = env(action)
end

for (w, rule) in learning_rules
w_val = weights[w]
Δw = weight_gradient(rule, sol, w_val, feedback)
weights[w] += Δw
end

increment_trial!(env)

stim_params = get_trial_stimulus(env)
new_params = ModelingToolkit.MTKParameters(sys, merge(defs, weights, stim_params))

agent.problem = remake(prob; p = new_params)

agent.problem = prob
return sol
end

function save_voltages(sol, filepath, numtrial)
df = DataFrame(sol)
fname = "sim"*lpad(numtrial, 4, "0")*".csv"
fullpath = joinpath(filepath, fname)
CSV.write(fullpath, df)
write(fullpath, df)
end
2 changes: 1 addition & 1 deletion test/plasticity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ using ModelingToolkit: getp


env = ClassificationEnvironment(stim; name=:env, namespace=global_ns)
run_experiment!(agent, env; alg=Tsit5(), reltol=1e-6,abstol=1e-9)
run_experiment!(agent, env; t_warmup=200, alg=Tsit5(), reltol=1e-6,abstol=1e-9)

final_params = agent.problem.p
# At least some weights need to be different.
Expand Down
Loading

0 comments on commit a77cb63

Please sign in to comment.