Need to converge to a single firing rate function #457

Open
gabrevaya opened this issue Oct 11, 2024 · 2 comments

Comments

@gabrevaya
Contributor

Currently, there are two functions for measuring mean firing rate: mean_firing_rate and firing_rate. @harisorgn, I think we should probably try to converge to a single function.

The function firing_rate has some serious performance issues, at least in the following case:

```julia
using Neuroblox
using StochasticDiffEq
using BenchmarkTools  # provides @btime

@named fsi = Striatum_FSI_Adam(N_inhib = 10)
sys = get_system(fsi; simplify = true)
prob = SDEProblem(sys, [], (0.0, 10000.0), [])
sol = solve(prob, RKMil(); dt = 0.05, saveat = 0.05, abstol = 1e-2, reltol = 1e-2);
@btime firing_rate($fsi, $sol, threshold=-25)
# 125.693 s (1299235830 allocations: 78.25 GiB)

@btime let
    spikes = detect_spikes($fsi, $sol; threshold=-25)
    # `spikes` is local to the benchmarked expression, so it must not be interpolated with `$`
    t, fr_fsi = mean_firing_rate(spikes, $sol, trim_transient=500)
end
# 100.482 ms (841 allocations: 46.38 MiB)
```

I don't have time now to dig into what is causing that performance issue (I'll probably come back to this on Tuesday), but I guess it's due to:

```julia
S = mapreduce(sparse_hcat, neurons) do neuron
    detect_spikes(neuron, sol; threshold, ts)
end
```

in the detect_spikes function, similar to what used to happen in state_timeseries here: 87eb04b.
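The concatenation pattern above can be isolated in a small sketch that uses only the SparseArrays standard library (`sparse_hcat` requires Julia 1.8 or later). `spike_train` and the synthetic spike indices are hypothetical stand-ins for the per-neuron `detect_spikes` output. Folding `sparse_hcat` over the neurons allocates a new matrix at every step and copies all columns accumulated so far, whereas splatting all columns into one `sparse_hcat` call builds the matrix once. Whether this copying actually explains the slowdown reported above is only the hypothesis stated in this issue; the sketch does not establish it.

```julia
using SparseArrays

# Hypothetical stand-in for one neuron's detect_spikes output: a sparse Bool
# vector of length `nt` that is `true` at the spiking time indices `idx`.
spike_train(nt, idx) = sparsevec(idx, fill(true, length(idx)), nt)

nt, nneurons = 1_000, 10
trains = [spike_train(nt, collect(i:50:nt)) for i in 1:nneurons]

# Pattern from the issue: fold sparse_hcat over the neurons. Each step
# allocates a new matrix and copies every column accumulated so far.
S_fold = mapreduce(identity, SparseArrays.sparse_hcat, trains)

# Alternative: gather all columns first, then concatenate once.
S_once = SparseArrays.sparse_hcat(trains...)

@assert S_fold == S_once  # same matrix, built without intermediate copies
```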

BTW, I've just added a dispatch for mean_firing_rate to account for EnsembleProblems and fixed the trimming of the initial transient in #455.
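For reference, the arithmetic behind a mean firing rate with an initial transient trimmed can be sketched as follows. This is a minimal sketch under stated assumptions, not Neuroblox's `mean_firing_rate`: `mean_rate_sketch` is a hypothetical name, time is assumed to be in ms, and the rate is averaged over the whole post-transient window. The library's function may bin spikes into windows and return a time series (the benchmark above returns `t, fr_fsi`), so treat this only as an illustration of the counting and trimming.

```julia
# Hypothetical helper, NOT Neuroblox's mean_firing_rate: assumes a
# (time x neuron) Bool spike matrix `S`, sample times `t` in ms, and a
# transient `trim_transient` (ms) to discard before counting spikes.
function mean_rate_sketch(S::AbstractMatrix{Bool}, t::AbstractVector;
                          trim_transient = 0.0)
    keep = t .>= t[1] + trim_transient             # samples after the transient
    first_kept = findfirst(keep)
    first_kept === nothing && throw(ArgumentError("transient longer than simulation"))
    duration_s = (t[end] - t[first_kept]) / 1000    # ms -> s
    counts = [count(S[keep, j]) for j in axes(S, 2)]  # spikes per neuron after trimming
    per_neuron = counts ./ duration_s                 # Hz
    return per_neuron, sum(per_neuron) / length(per_neuron)  # per neuron, population mean
end

# Usage: 1 s at 1 ms resolution, 2 neurons, 500 ms transient.
t = collect(0.0:1.0:1000.0)
S = falses(length(t), 2)
S[[101, 601, 701, 801], 1] .= true   # neuron 1: one spike in the transient, three after
S[901, 2] = true                     # neuron 2: one spike after the transient
rates, pop_rate = mean_rate_sketch(S, t; trim_transient = 500)
```

Here the spike at 100 ms falls inside the transient and is ignored, leaving 3 and 1 spikes over 0.5 s.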

@harisorgn
Member

Agreed. I wrote firing_rate sloppily to debug my tutorial. The issue is most likely the multiple calls to detect_spikes. A solution with a single call is definitely possible, i.e. something closer to what you have. We just need a dispatch where the first argument is a composite blox or a vector of bloxs.
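The single-call idea, with a method that dispatches on a composite blox or a vector of bloxs, can be illustrated with Julia's multiple dispatch. A minimal sketch, assuming hypothetical types (`AbstractBloxSketch`, `NeuronSketch`, `CompositeSketch`), a plain voltage matrix in place of the solution object, and an upward-threshold-crossing spike definition. None of these are Neuroblox's API, and the library's `detect_spikes` may define spikes differently.

```julia
using SparseArrays

# Hypothetical stand-ins for Neuroblox bloxs; these are NOT the library's types.
abstract type AbstractBloxSketch end

struct NeuronSketch <: AbstractBloxSketch
    name::Symbol
end

struct CompositeSketch <: AbstractBloxSketch
    parts::Vector{NeuronSketch}
end

# One pass over all neurons: `V` is an assumed (time x state) matrix of
# membrane voltages and `col` maps each neuron name to its column in `V`.
# A spike is assumed to be an upward crossing of `threshold`, recorded at
# the first sample at or above it.
function detect_spikes_sketch(neurons::AbstractVector{NeuronSketch}, V::AbstractMatrix,
                              col::Dict{Symbol,Int}; threshold)
    Vsub = V[:, [col[n.name] for n in neurons]]   # single bulk extraction
    crossings = (Vsub[1:end-1, :] .< threshold) .& (Vsub[2:end, :] .>= threshold)
    # Pad a row of `false` so row i of the result corresponds to sample i of `V`.
    return sparse(vcat(falses(1, size(Vsub, 2)), crossings))
end

# A composite blox forwards its neurons to the vector method above.
detect_spikes_sketch(c::CompositeSketch, V::AbstractMatrix, col::Dict{Symbol,Int}; threshold) =
    detect_spikes_sketch(c.parts, V, col; threshold)

# Usage: 4 samples, 2 neurons.
V = [-70.0 -70.0; -20.0 -70.0; -60.0 -10.0; -10.0 -70.0]
S = detect_spikes_sketch(CompositeSketch([NeuronSketch(:a), NeuronSketch(:b)]), V,
                         Dict(:a => 1, :b => 2); threshold = -25)
```

Because the composite method only unpacks its parts, the voltage extraction and spike detection run once for all neurons instead of once per neuron.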

@harisorgn
Member

Added more comments to #455 as I was going through your new dispatch. These comments are more relevant here though, apologies 😅. I am linking so we don't forget.
