Skip to content

Commit

Permalink
More efficient detect_spikes and firing_rate (#489)
Browse files Browse the repository at this point in the history
* add OhMyThreads dep

* import multithread utils

* allocate sparse array after nonzeros are known

* multithread spike detection per neuron

* improve firing rate function

* pass `kwargs...` to `tmapreduce`

* default to `scheduler=:serial`

* add `kwargs...` to avoid errors when dispatch is called with multithread kwargs
  • Loading branch information
harisorgn authored Nov 13, 2024
1 parent daa4f4b commit 40d4a9e
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 14 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ MetaGraphs = "626554b9-1ddb-594c-aa3c-2596fe9399a5"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Peaks = "18e31ff7-3703-566c-8e60-38913d67486b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
4 changes: 4 additions & 0 deletions src/Neuroblox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ module Neuroblox

import Base: merge

using Base.Threads: nthreads

using OhMyThreads: tmapreduce

using Reexport
@reexport using ModelingToolkit
const t = ModelingToolkit.t_nounits
Expand Down
31 changes: 17 additions & 14 deletions src/blox/blox_utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -403,13 +403,11 @@ end

replace_refractory!(V, blox, sol::SciMLBase.AbstractSolution) = V

function find_spikes(x::AbstractVector{T}; threshold=zero(T)) where {T}
spikes = spzeros(Bool, size(x))

function find_spikes(x::AbstractVector{T}; threshold=zero(T)) where {T}
spike_idxs = argmaxima(x)
peakheights!(spike_idxs, x[spike_idxs]; minheight = threshold)

spikes[spike_idxs] .= 1
spikes = sparsevec(spike_idxs, ones(length(spike_idxs)), length(x))

return spikes
end
Expand All @@ -423,7 +421,7 @@ end

function detect_spikes(
blox::AbstractNeuronBlox, sol::SciMLBase.AbstractSolution;
threshold = nothing, tolerance = 1e-3, ts = nothing
threshold = nothing, tolerance = 1e-3, ts = nothing, kwargs...
)
namespaced_name = namespaced_nameof(blox)

Expand All @@ -445,12 +443,12 @@ end

function detect_spikes(
blox::Union{CompositeBlox, AbstractVector{<:AbstractNeuronBlox}}, sol::SciMLBase.AbstractSolution;
threshold = nothing, ts=nothing
threshold = nothing, ts=nothing, scheduler=:serial, kwargs...
)

neurons = get_neurons(blox)

S = mapreduce(sparse_hcat, neurons) do neuron
S = tmapreduce(sparse_hcat, neurons; scheduler, kwargs...) do neuron
detect_spikes(neuron, sol; threshold, ts)
end

Expand All @@ -459,16 +457,20 @@ end

function firing_rate(
blox, sol::SciMLBase.AbstractSolution;
win_size = last(sol.t), win_resolution = 1e-3,
transient = 0, overlap = 0, threshold = nothing)
win_size = last(sol.t), transient = 0, overlap = 0,
threshold = nothing, scheduler=:serial, kwargs...)

spikes = detect_spikes(blox, sol; threshold, scheduler, kwargs...)
N_neurons = size(spikes, 2)

ts = sol.t
t_win_start = transient:(win_size - win_size*overlap):(last(ts) - win_size)

fr = map(t_win_start) do tws
spikes = detect_spikes(blox, sol; threshold, ts = tws:win_resolution:(tws + win_size))
N_neurons = size(spikes, 2)
1000.0 * (nnz(spikes) / N_neurons) / win_size
idx_start = findfirst(x -> x >= tws, ts)
idx_end = findfirst(x -> x >= tws + win_size, ts)

1000.0 * (nnz(spikes[idx_start:idx_end, :]) / N_neurons) / win_size
end

return fr
Expand All @@ -485,13 +487,14 @@ function mean_firing_rate(spikes::SparseMatrixCSC, sol; trim_transient = 0,
end

function mean_firing_rate(blox, sols::SciMLBase.EnsembleSolution; trim_transient = 0,
Δt = last(sols[1].t) - trim_transient, threshold = -35)
Δt = last(sols[1].t) - trim_transient, threshold = -35,
scheduler=:serial, kwargs...)

t_fr = trim_transient:Δt:last(sols[1].t)
firing_rates = Vector{Float64}[]

for sol in sols
spikes = detect_spikes(blox, sol; threshold = threshold)
spikes = detect_spikes(blox, sol; threshold, scheduler, kwargs...)
spikes = transpose(spikes)
fr = _mean_firing_rate(spikes, unique(sol.t), t_fr)
push!(firing_rates, fr)
Expand Down

0 comments on commit 40d4a9e

Please sign in to comment.