Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added allocation-free correlation #39

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 79 additions & 1 deletion src/correlations.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export ccorr
export ccorr, plan_ccorr, ccorr_psf, plan_ccorr_psf, plan_ccorr_buffer, plan_ccorr_psf_buffer

"""
ccorr(u, v[, dims]; centered=false)
Expand Down Expand Up @@ -65,3 +65,81 @@ function ccorr(u::AbstractArray{<:Real, N}, v::AbstractArray{<:Real, M},
return out
end
end

function ccorr_psf(u::AbstractArray{T, N}, psf::AbstractArray{D, M}, dims=ntuple(+, min(N, M))) where {T, D, N, M}
return ccorr(u, ifftshift(psf, dims), dims)
end

function p_ccorr_aux(P, P_inv, u, v_ft)
return (P_inv.p * ((P * u) .* conj(v_ft) .* P_inv.scale))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please use conj.(v_ft). That should avoid some allocations.

end

function plan_ccorr(u::AbstractArray{T1, N}, v::AbstractArray{T2, M}, dims=ntuple(+, N);
kwargs...) where {T1, T2, N, M}
eltype_error(T1, T2)
plan = get_plan(T1)
# do the preplanning step
P = let
# FFTW.MEASURE flag might overwrite input! Hence copy!
if (:flags in keys(kwargs) &&
(getindex(kwargs, :flags) == FFTW.MEASURE || getindex(kwargs, :flags) == FFTW.PATIENT))
plan(copy(u), dims; kwargs...)
else
plan(u, dims; kwargs...)
end
end

v_ft = fft_or_rfft(T1)(v, dims)
# construct the efficient conv function
# P and P_inv can be understood like matrices
# but their computation is fast
ccorr = let P = P,
P_inv = inv(P),
# put a different name here! See https://discourse.julialang.org/t/type-issue-with-captured-variables-let-workaround-failed/85661
v_ft = v_ft
ccorr(u, v_ft=v_ft) = p_ccorr_aux(P, P_inv, u, v_ft)
end

return v_ft, ccorr
end

function plan_ccorr_psf(u::AbstractArray{T, N}, psf::AbstractArray{T, M}, dims=ntuple(+, N);
kwargs...) where {T, N, M}
return plan_ccorr(u, ifftshift(psf, dims), dims; kwargs...)
end

function plan_ccorr_buffer(u::AbstractArray{T1, N}, v::AbstractArray{T2, M}, dims=ntuple(+, N);
kwargs...) where {T1, T2, N, M}
eltype_error(T1, T2)
plan = get_plan(T1)
# do the preplanning step
P_u = plan(u, dims; kwargs...)
P_v = plan(v, dims)

u_buff = P_u * u
v_ft = P_v * v
conj!(v_ft)
uv_buff = u_buff .* v_ft

# for fourier space we need a new plan
P = plan(u .* v, dims; kwargs...)
P_inv = inv(P)
out_buff = P_inv * uv_buff

# construct the efficient conv function
# P and P_inv can be understood like matrices
# but their computation is fast
function ccorr(u, v_ft=v_ft)
mul!(u_buff, P_u, u)
uv_buff .= u_buff .* v_ft
mul!(out_buff, P_inv, uv_buff)
return out_buff
end

return v_ft, ccorr
end

function plan_ccorr_psf_buffer(u::AbstractArray{T, N}, psf::AbstractArray{T, M}, dims=ntuple(+, N);
kwargs...) where {T, N, M}
return plan_ccorr_buffer(u, ifftshift(psf, dims), dims; kwargs...)
end