diff --git a/src/tools.hpp b/src/tools.hpp index 696bea3..ca4336b 100644 --- a/src/tools.hpp +++ b/src/tools.hpp @@ -36,20 +36,26 @@ namespace FsGridTools { typedef int64_t GlobalID; typedef int Task_t; + //! Helper function: calculate size of the local coordinate space for the given dimension + // \param numCells Number of cells in the global Simulation, in this dimension + // \param nTasks Total number of tasks in this dimension + // \param taskIndex This task's position in this dimension + // \return Number of cells for this task's local domain (actual cells, not counting ghost cells) + static FsIndex_t calcLocalSize(FsSize_t numCells, Task_t nTasks, Task_t taskIndex) { + const FsIndex_t nPerTask = numCells / nTasks; + const FsIndex_t remainder = numCells % nTasks; + return taskIndex < remainder ? nPerTask + 1 : nPerTask; + } + //! Helper function: calculate position of the local coordinate space for the given dimension - // \param globalCells Number of cells in the global Simulation, in this dimension - // \param ntasks Total number of tasks in this dimension - // \param my_n This task's position in this dimension + // \param numCells number of cells + // \param nTasks Total number of tasks in this dimension + // \param taskIndex This task's position in this dimension // \return Cell number at which this task's domains cells start (actual cells, not counting ghost cells) - static FsIndex_t calcLocalStart(FsSize_t globalCells, Task_t ntasks, Task_t my_n) { - FsIndex_t n_per_task = globalCells / ntasks; - FsIndex_t remainder = globalCells % ntasks; - - if (my_n < remainder) { - return my_n * (n_per_task + 1); - } else { - return my_n * n_per_task + remainder; - } + static FsIndex_t calcLocalStart(FsSize_t numCells, Task_t nTasks, Task_t taskIndex) { + const FsIndex_t n_per_task = numCells / nTasks; + const FsIndex_t remainder = numCells % nTasks; + return taskIndex * calcLocalSize(numCells, nTasks, taskIndex) + (taskIndex >= remainder) * remainder; } //! Helper function: given a global cellID, calculate the global cell coordinate from it. @@ -72,21 +78,6 @@ namespace FsGridTools { return cell; } - //! Helper function: calculate size of the local coordinate space for the given dimension - // \param globalCells Number of cells in the global Simulation, in this dimension - // \param ntasks Total number of tasks in this dimension - // \param my_n This task's position in this dimension - // \return Number of cells for this task's local domain (actual cells, not counting ghost cells) - static FsIndex_t calcLocalSize(FsSize_t globalCells, Task_t ntasks, Task_t my_n) { - FsIndex_t n_per_task = globalCells / ntasks; - FsIndex_t remainder = globalCells % ntasks; - if (my_n < remainder) { - return n_per_task + 1; - } else { - return n_per_task; - } - } - //! Helper function to optimize decomposition of this grid over the given number of tasks static void computeDomainDecomposition(const std::array& GlobalSize, Task_t nProcs, std::array& processDomainDecomposition, int rank, diff --git a/tests/unit_tests/tools_tests.cpp b/tests/unit_tests/tools_tests.cpp index ce10e76..72a79cb 100644 --- a/tests/unit_tests/tools_tests.cpp +++ b/tests/unit_tests/tools_tests.cpp @@ -11,3 +11,37 @@ TEST(FsGridToolsTests, calcLocalStart1) { ASSERT_EQ(FsGridTools::calcLocalStart(numGlobalCells, numTasks, 2), 64); ASSERT_EQ(FsGridTools::calcLocalStart(numGlobalCells, numTasks, 3), 96); } + +TEST(FsGridToolsTests, calcLocalStart2) { + constexpr FsGridTools::FsSize_t numGlobalCells = 666u; + constexpr FsGridTools::Task_t numTasks = 64u; + + for (int i = 0; i < 26; i++) { + ASSERT_EQ(FsGridTools::calcLocalStart(numGlobalCells, numTasks, i), i * 11); + } + for (int i = 26; i < numTasks; i++) { + ASSERT_EQ(FsGridTools::calcLocalStart(numGlobalCells, numTasks, i), i * 10 + 26); + } +} + +TEST(FsGridToolsTests, calcLocalSize1) { + constexpr FsGridTools::FsSize_t numGlobalCells = 1024u; + constexpr FsGridTools::Task_t numTasks = 32u; + + for (int i = 0; i < numTasks; i++) { + ASSERT_EQ(FsGridTools::calcLocalSize(numGlobalCells, numTasks, i), 32); + } +} + +TEST(FsGridToolsTests, calcLocalSize2) { + constexpr FsGridTools::FsSize_t numGlobalCells = 666u; + constexpr FsGridTools::Task_t numTasks = 64u; + + for (int i = 0; i < 26; i++) { + ASSERT_EQ(FsGridTools::calcLocalSize(numGlobalCells, numTasks, i), 11); + } + + for (int i = 26; i < numTasks; i++) { + ASSERT_EQ(FsGridTools::calcLocalSize(numGlobalCells, numTasks, i), 10); + } +}