From 0165b6743355d52718d1a6fe03e24876a811a202 Mon Sep 17 00:00:00 2001 From: Weiqun Zhang Date: Thu, 14 Nov 2024 09:04:09 -0800 Subject: [PATCH] FFT OpenBC Solver: more optimization (#4232) --- Src/FFT/AMReX_FFT_OpenBCSolver.H | 2 + Src/FFT/AMReX_FFT_R2C.H | 97 +++++++++++++++++++++++--------- 2 files changed, 72 insertions(+), 27 deletions(-) diff --git a/Src/FFT/AMReX_FFT_OpenBCSolver.H b/Src/FFT/AMReX_FFT_OpenBCSolver.H index b3ba2df90b..1aeae31332 100644 --- a/Src/FFT/AMReX_FFT_OpenBCSolver.H +++ b/Src/FFT/AMReX_FFT_OpenBCSolver.H @@ -151,6 +151,8 @@ void OpenBCSolver::setGreensFunction (F const& greens_function) } } } + + m_r2c.prepare_openbc(); } template diff --git a/Src/FFT/AMReX_FFT_R2C.H b/Src/FFT/AMReX_FFT_R2C.H index 9277e96e13..bd0ad1ff39 100644 --- a/Src/FFT/AMReX_FFT_R2C.H +++ b/Src/FFT/AMReX_FFT_R2C.H @@ -164,6 +164,8 @@ public: template void post_forward_doit (F const& post_forward); + void prepare_openbc (); + private: static std::pair,Plan> make_c2c_plans (cMF& inout); @@ -176,6 +178,8 @@ private: Plan m_fft_bwd_y{}; Plan m_fft_fwd_z{}; Plan m_fft_bwd_z{}; + Plan m_fft_fwd_x_half{}; + Plan m_fft_bwd_x_half{}; // Comm meta-data. In the forward phase, we start with (x,y,z), // transpose to (y,x,z) and then (z,x,y). In the backward phase, we @@ -394,6 +398,60 @@ R2C::~R2C () m_fft_fwd_x.destroy(); m_fft_fwd_y.destroy(); m_fft_fwd_z.destroy(); + if (m_fft_bwd_x_half.plan != m_fft_fwd_x_half.plan) { + m_fft_bwd_x_half.destroy(); + } + m_fft_fwd_x_half.destroy(); +} + +template +void R2C::prepare_openbc () +{ +#if (AMREX_SPACEDIM == 3) + if (m_slab_decomp) { + auto* fab = detail::get_fab(m_rx); + if (fab) { + Box bottom_half = m_real_domain; + bottom_half.growHi(2,-m_real_domain.length(2)/2); + Box box = fab->box() & bottom_half; + if (box.ok()) { + auto* pr = fab->dataPtr(); + auto* pc = (typename Plan::VendorComplex *) + detail::get_fab(m_cx)->dataPtr(); +#ifdef AMREX_USE_SYCL + m_fft_fwd_x_half.template init_r2c + (box, pr, pc, m_slab_decomp); + m_fft_bwd_x_half = m_fft_fwd_x_half; +#else + if constexpr (D == Direction::both || D == Direction::forward) { + m_fft_fwd_x_half.template init_r2c + (box, pr, pc, m_slab_decomp); + } + if constexpr (D == Direction::both || D == Direction::backward) { + m_fft_bwd_x_half.template init_r2c + (box, pr, pc, m_slab_decomp); + } +#endif + } + } + } // else todo + + if (m_cmd_x2z && ! m_cmd_x2z_half) { + Box bottom_half = m_spectral_domain_z; + // Note that z-direction's index is 0 because we z is the + // unit-stride direction here. + bottom_half.growHi(0,-m_spectral_domain_z.length(0)/2); + m_cmd_x2z_half = std::make_unique + (m_cz, bottom_half, m_cx, IntVect(0), m_dtos_x2z); + } + + if (m_cmd_z2x && ! m_cmd_z2x_half) { + Box bottom_half = m_spectral_domain_x; + bottom_half.growHi(2,-m_spectral_domain_x.length(2)/2); + m_cmd_z2x_half = std::make_unique + (m_cx, bottom_half, m_cz, IntVect(0), m_dtos_z2x); + } +#endif } template @@ -406,7 +464,8 @@ void R2C::forward (MF const& inmf) if (&m_rx != &inmf) { m_rx.ParallelCopy(inmf, 0, 0, 1); } - m_fft_fwd_x.template compute_r2c(); + auto& fft_x = m_openbc_half ? m_fft_fwd_x_half : m_fft_fwd_x; + fft_x.template compute_r2c(); if ( m_cmd_x2y) { ParallelCopy(m_cy, m_cx, *m_cmd_x2y, 0, 0, 1, m_dtos_x2y); @@ -419,19 +478,16 @@ void R2C::forward (MF const& inmf) #if (AMREX_SPACEDIM == 3) else if ( m_cmd_x2z) { if (m_openbc_half) { - Box upper_half = m_spectral_domain_z; - // Note that z-direction's index is 0 because we z is the unit-stride direction here. - upper_half.growLo (0,-m_spectral_domain_z.length(0)/2); - if (! m_cmd_x2z_half) { - Box bottom_half = m_spectral_domain_z; - bottom_half.growHi(0,-m_spectral_domain_z.length(0)/2); - m_cmd_x2z_half = std::make_unique - (m_cz, bottom_half, m_cx, IntVect(0), m_dtos_x2z); - } NonLocalBC::ApplyDtosAndProjectionOnReciever packing {NonLocalBC::PackComponents{}, m_dtos_x2z}; auto handler = ParallelCopy_nowait(m_cz, m_cx, *m_cmd_x2z_half, packing); + + Box upper_half = m_spectral_domain_z; + // Note that z-direction's index is 0 because we z is the + // unit-stride direction here. + upper_half.growLo (0,-m_spectral_domain_z.length(0)/2); m_cz.setVal(0, upper_half, 0, 1); + ParallelCopy_finish(m_cz, std::move(handler), *m_cmd_x2z_half, packing); } else { ParallelCopy(m_cz, m_cx, *m_cmd_x2z, 0, 0, 1, m_dtos_x2z); @@ -459,22 +515,8 @@ void R2C::backward_doit (MF& outmf, IntVect const& ngout) } #if (AMREX_SPACEDIM == 3) else if ( m_cmd_z2x) { - if (m_openbc_half) { - Box upper_half = m_spectral_domain_x; - upper_half.growLo (2,-m_spectral_domain_x.length(2)/2); - if (! m_cmd_z2x_half) { - Box bottom_half = m_spectral_domain_x; - bottom_half.growHi(2,-m_spectral_domain_x.length(2)/2); - m_cmd_z2x_half = std::make_unique - (m_cx, bottom_half, m_cz, IntVect(0), m_dtos_z2x); - } - NonLocalBC::ApplyDtosAndProjectionOnReciever packing - {NonLocalBC::PackComponents{}, m_dtos_z2x}; - auto handler = ParallelCopy_nowait(m_cx, m_cz, *m_cmd_z2x_half, packing); - ParallelCopy_finish(m_cx, std::move(handler), *m_cmd_z2x_half, packing); - } else { - ParallelCopy(m_cx, m_cz, *m_cmd_z2x, 0, 0, 1, m_dtos_z2x); - } + auto const& cmd = m_openbc_half ? m_cmd_z2x_half : m_cmd_z2x; + ParallelCopy(m_cx, m_cz, *cmd, 0, 0, 1, m_dtos_z2x); } #endif @@ -483,7 +525,8 @@ void R2C::backward_doit (MF& outmf, IntVect const& ngout) ParallelCopy(m_cx, m_cy, *m_cmd_y2x, 0, 0, 1, m_dtos_y2x); } - m_fft_bwd_x.template compute_r2c(); + auto& fft_x = m_openbc_half ? m_fft_bwd_x_half : m_fft_bwd_x; + fft_x.template compute_r2c(); outmf.ParallelCopy(m_rx, 0, 0, 1, IntVect(0), ngout); }