diff --git a/src/transfer-functions/integration.jl b/src/transfer-functions/integration.jl index f5a6184d..bd000cb6 100644 --- a/src/transfer-functions/integration.jl +++ b/src/transfer-functions/integration.jl @@ -18,10 +18,10 @@ _integration_setup( g_grid_upscale = 2, ) = IntegrationSetup(h, time, integrand, pure_radial, alloc_segbuf(T), g_grid_upscale) -function _g_fine_grid_iterate(setup::IntegrationSetup, lo, hi) - Δ = (hi - lo) / setup.g_grid_upscale - lows = range(lo, hi - Δ, setup.g_grid_upscale) - highs = range(lo + Δ, hi, setup.g_grid_upscale) +function _g_fine_grid_iterate(upscale, lo, hi) + Δ = (hi - lo) / upscale + lows = range(lo, hi - Δ, upscale) + highs = range(lo + Δ, hi, upscale) zip(lows, highs) end @@ -249,7 +249,8 @@ function _integrate_transfer_problem!( glo = g_grid[j] ghi = g_grid[j+1] flux_contrib = zero(eltype(output)) - for (g_fine_lo, g_fine_hi) in _g_fine_grid_iterate(setup, glo, ghi) + for (g_fine_lo, g_fine_hi) in + _g_fine_grid_iterate(setup.g_grid_upscale, glo, ghi) k = integrate_bin(closures, _both_branches(closures), g_fine_lo, g_fine_hi) flux_contrib += k end @@ -278,26 +279,28 @@ function _integrate_transfer_problem!( t_source_disc = setup.time(rₑ) @inbounds for j in eachindex(g_grid_view) - glo = clamp(g_grid[j], closures.branch.gmin, closures.branch.gmax) - ghi = clamp(g_grid[j+1], closures.branch.gmin, closures.branch.gmax) + g_grid_low = clamp(g_grid[j], closures.branch.gmin, closures.branch.gmax) + g_grid_hi = clamp(g_grid[j+1], closures.branch.gmin, closures.branch.gmax) # skip if bin not relevant - if glo == ghi + if g_grid_low == g_grid_hi continue end - k1 = integrate_bin(closures, _lower_branch(closures), glo, ghi) - k2 = integrate_bin(closures, _upper_branch(closures), glo, ghi) - - # find which bin to dump in - t_lower_branch, t_upper_branch = _time_bins(closures, glo, ghi) - i1 = searchsortedfirst(t_grid, t_lower_branch + t_source_disc) - i2 = searchsortedfirst(t_grid, t_upper_branch + t_source_disc) - - imax = lastindex(t_grid) - if i1 <= imax - output[j, i1] += k1 * θ - end - if i2 <= imax - output[j, i2] += k2 * θ + for (glo, ghi) in + _g_fine_grid_iterate(setup.g_grid_upscale, g_grid_low, g_grid_hi) + k1 = integrate_bin(closures, _lower_branch(closures), glo, ghi) + k2 = integrate_bin(closures, _upper_branch(closures), glo, ghi) + # find which bin to dump in + t_lower_branch, t_upper_branch = _time_bins(closures, glo, ghi) + i1 = searchsortedfirst(t_grid, t_lower_branch + t_source_disc) + i2 = searchsortedfirst(t_grid, t_upper_branch + t_source_disc) + + imax = lastindex(t_grid) + if i1 <= imax + output[j, i1] += k1 * θ + end + if i2 <= imax + output[j, i2] += k2 * θ + end end end end